File size: 4,118 Bytes
9c69ae8 c57c180 9c69ae8 c57c180 9c69ae8 c57c180 9c69ae8 c57c180 9c69ae8 6febcc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from torch import nn, Tensor, tensor
from typing import Union, List, Optional
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention
)
from timm import create_model
from .configuration_efficientnet import EfficientNetConfig
class EfficientNetModel(PreTrainedModel):
    """
    Feature-extraction wrapper around a `timm` model, exposed through
    Hugging Face's `PreTrainedModel` interface.

    The backbone is built with `timm.create_model` from fields of the
    supplied configuration. `forward` returns the output of the backbone's
    `forward_features` together with the output of its `forward_head`
    called with `pre_logits=True`.

    Attributes
    ----------
    config :
        Configuration object the backbone was built from.
    model :
        Backbone instance returned by `timm.create_model`.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        """
        Parameters
        ----------
        config :
            Configuration providing `model_name`, `pretrained`,
            `num_classes` and `global_pool`, which are forwarded to
            `timm.create_model`.
        """
        super().__init__(config)
        self.config = config
        backbone_name = config.model_name
        backbone_kwargs = {
            "pretrained": config.pretrained,
            "num_classes": config.num_classes,
            "global_pool": config.global_pool,
        }
        self.model = create_model(backbone_name, **backbone_kwargs)

    def forward(self, pixel_values: Tensor) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Run the backbone and return its features.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.

        Returns
        -------
        BaseModelOutputWithPoolingAndNoAttention
            `last_hidden_state` holds the result of `forward_features`;
            `pooler_output` holds the result of
            `forward_head(..., pre_logits=True)`.
        """
        backbone = self.model
        features = backbone.forward_features(pixel_values)
        # NOTE(review): pre_logits=True presumably skips the classifier layer
        # and returns the pooled features -- confirm against the timm version
        # in use.
        pooled = backbone.forward_head(features, pre_logits=True)
        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=features,
            pooler_output=pooled,
        )
class EfficientNetModelForImageClassification(PreTrainedModel):
    """
    Image-classification wrapper around a `timm` model, exposed through
    Hugging Face's `PreTrainedModel` interface.

    The backbone is built with `timm.create_model` from fields of the
    supplied configuration. `forward` returns the backbone's logits and,
    when labels are provided, a cross-entropy loss so the model can be
    trained.

    Attributes
    ----------
    config :
        Configuration object containing model parameters.
    model :
        Backbone instance returned by `timm.create_model`.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        """
        Parameters
        ----------
        config :
            Configuration providing `model_name`, `pretrained`,
            `num_classes` and `global_pool`, which are forwarded to
            `timm.create_model`.
        """
        super().__init__(config)
        self.config = config
        self.model = create_model(
            config.model_name,
            pretrained=config.pretrained,
            num_classes=config.num_classes,
            global_pool=config.global_pool,
        )

    def compute_loss(self, logits, labels=None):
        """
        Compute the cross-entropy loss between `logits` and `labels`.

        Parameters
        ----------
        logits : torch.Tensor
            Model outputs passed as the input of `nn.CrossEntropyLoss`.
        labels : Optional[Union[List[int], torch.Tensor]]
            Targets for the loss. May be omitted.

        Returns
        -------
        Optional[torch.Tensor]
            The loss, or `None` when `labels` is `None`.
        """
        if labels is None:
            return None

        # Targets must live on the same device as the logits. An existing
        # tensor is moved (not re-wrapped with `tensor(...)`, which copies and
        # emits a copy-construct warning); anything else is materialised
        # directly on the logits' device instead of the default one.
        if isinstance(labels, Tensor):
            targets = labels.to(logits.device)
        else:
            targets = tensor(labels, device=logits.device)

        criterion = nn.CrossEntropyLoss()
        return criterion(logits, targets)

    def forward(
        self,
        pixel_values: Tensor,
        labels: Optional[Union[List[int], Tensor]] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        """
        Run the backbone and optionally compute the training loss.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.
        labels : Optional[Union[List[int], torch.Tensor]]
            Ground truth labels for training and computing loss.
            List of integers/tensor representing class IDs.

        Returns
        -------
        ImageClassifierOutputWithNoAttention
            Object containing `logits` and `loss` (`loss` is `None` when no
            labels are given).
        """
        logits = self.model(pixel_values)
        loss = self.compute_loss(logits, labels)
        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
        )
# Public API of this module.
__all__ = [
    "EfficientNetModel",
    "EfficientNetModelForImageClassification",
]