| from torch import nn | |
| from functools import partial | |
| from transformers import PreTrainedModel | |
| from timm import create_model | |
| from configuration_efficientnet import EfficientNetConfig | |
class EfficientNetModel(PreTrainedModel):
    """Hugging Face wrapper around a timm network used as a feature extractor.

    The timm model named by ``config.model_name`` is built once in
    ``__init__`` (optionally with pretrained weights, per
    ``config.pretrained``) and stored as ``self.model``. ``forward`` returns
    whatever that model's ``forward_features`` method produces.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        # The attribute name "model" is kept unchanged: it is the prefix of
        # every parameter key in this wrapper's state dict.
        backbone = create_model(
            config.model_name,
            pretrained=config.pretrained,
        )
        self.model = backbone

    def forward(self, pixel_values):
        """Return ``self.model.forward_features(pixel_values)``.

        NOTE(review): presumably ``pixel_values`` is an image batch shaped as
        the timm model expects (e.g. (batch, channels, height, width)) —
        confirm against callers.
        """
        features = self.model.forward_features(pixel_values)
        return features
class EfficientNetModelForImageClassification(PreTrainedModel):
    """Hugging Face wrapper around a full timm classification network.

    The timm model named by ``config.model_name`` is built in ``__init__``
    (optionally with pretrained weights, per ``config.pretrained``) and stored
    as ``self.model``. ``forward`` runs the whole network to get logits and,
    when ``labels`` are supplied, also computes a cross-entropy loss.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        # Attribute name "model" is the state-dict key prefix; keep it stable.
        self.model = create_model(config.model_name, pretrained=config.pretrained)

    def forward(self, pixel_values, labels=None):
        """Compute classification logits and, optionally, the training loss.

        Args:
            pixel_values: input batch passed straight to the timm model
                (presumably (batch, channels, height, width) — confirm).
            labels: optional targets in any form accepted by
                ``torch.nn.functional.cross_entropy`` (class indices or
                class probabilities).

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise the bare ``logits`` tensor.
        """
        logits = self.model(pixel_values)
        if labels is None:
            return logits
        # nn.CrossEntropyLoss is a module class; calling it with
        # (logits, labels) binds them to its constructor arguments
        # (weight, size_average) rather than evaluating a loss. The
        # functional form actually computes the scalar loss.
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
# Public API of this module: the feature-extractor wrapper and the
# image-classification wrapper.
__all__ = ["EfficientNetModel", "EfficientNetModelForImageClassification"]