| from torch import nn | |
| from functools import partial | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention | |
| from timm import create_model | |
| from configuration_efficientnet import EfficientNetConfig | |
class EfficientNetModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper exposing a timm backbone's features.

    The underlying network is built with ``timm.create_model`` from
    ``config.model_name``, loading pretrained weights when
    ``config.pretrained`` is truthy.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        # The whole timm model is stored as ``self.model``; only its
        # ``forward_features`` method is used by ``forward`` below.
        self.model = create_model(
            model_name=config.model_name,
            pretrained=config.pretrained,
        )

    def forward(self, pixel_values):
        """Return the backbone's ``forward_features`` output.

        Args:
            pixel_values: Image batch passed unchanged to
                ``self.model.forward_features`` (presumably a
                ``(batch, channels, height, width)`` tensor — TODO confirm).

        Returns:
            ``BaseModelOutputWithPoolingAndNoAttention`` whose
            ``last_hidden_state`` holds the backbone features; ``pooler_output``
            is not populated.
        """
        features = self.model.forward_features(pixel_values)
        output = BaseModelOutputWithPoolingAndNoAttention(last_hidden_state=features)
        return output
class EfficientNetModelForImageClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a timm classifier.

    The network is built with ``timm.create_model`` from ``config.model_name``,
    loading pretrained weights when ``config.pretrained`` is truthy. ``forward``
    returns the model's logits and, when ``labels`` are supplied, a
    cross-entropy loss.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = create_model(config.model_name, pretrained=config.pretrained)

    def forward(self, pixel_values, labels=None):
        """Classify a batch of images and optionally compute the training loss.

        Args:
            pixel_values: Image batch fed to the timm model's full forward pass.
            labels: Optional targets for ``nn.CrossEntropyLoss`` (assumed to be
                integer class indices of shape ``(batch,)`` — TODO confirm).

        Returns:
            ``ImageClassifierOutputWithNoAttention`` with ``logits`` and, when
            ``labels`` is not ``None``, the cross-entropy ``loss`` (default
            ``'mean'`` reduction); otherwise ``loss`` is ``None``.
        """
        logits = self.model(pixel_values)
        loss = None
        if labels is not None:
            # ``nn.CrossEntropyLoss(logits, labels)`` only *constructs* a loss
            # module (binding the tensors to its constructor arguments) and
            # never computes a loss value. Instantiate the criterion, then call it.
            loss_fct = nn.CrossEntropyLoss()
            # Targets must live on the same device as the logits.
            loss = loss_fct(logits, labels.to(logits.device))
        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits)
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "EfficientNetModel",
    "EfficientNetModelForImageClassification",
]