"""Hugging Face ``PreTrainedModel`` wrappers around timm EfficientNet models."""

from torch import nn
from functools import partial
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
)
from timm import create_model
from configuration_efficientnet import EfficientNetConfig


class EfficientNetModel(PreTrainedModel):
    """Backbone wrapper that exposes the timm model's feature extractor.

    The wrapped network is built with ``timm.create_model`` from
    ``config.model_name`` and ``config.pretrained``.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        """Build the underlying timm model described by ``config``.

        Args:
            config: An ``EfficientNetConfig``; its ``model_name`` and
                ``pretrained`` attributes are forwarded to
                ``timm.create_model``.
        """
        super().__init__(config)
        self.model = create_model(config.model_name, pretrained=config.pretrained)

    def forward(self, pixel_values):
        """Run the feature extractor on ``pixel_values``.

        Args:
            pixel_values: Input batch passed straight to the timm model's
                ``forward_features`` (presumably an image tensor of shape
                ``(batch, channels, height, width)`` -- confirm with callers).

        Returns:
            A ``BaseModelOutputWithPoolingAndNoAttention`` whose
            ``last_hidden_state`` is the ``forward_features`` output. No
            other field of the output is populated.
        """
        last_hidden_state = self.model.forward_features(pixel_values)
        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
        )


class EfficientNetModelForImageClassification(PreTrainedModel):
    """Classification wrapper that returns logits and an optional loss.

    The wrapped network is built with ``timm.create_model`` from
    ``config.model_name`` and ``config.pretrained``.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        """Build the underlying timm model described by ``config``.

        Args:
            config: An ``EfficientNetConfig``; its ``model_name`` and
                ``pretrained`` attributes are forwarded to
                ``timm.create_model``.
        """
        super().__init__(config)
        self.model = create_model(config.model_name, pretrained=config.pretrained)

    def forward(self, pixel_values, labels=None):
        """Classify ``pixel_values`` and optionally compute the training loss.

        Args:
            pixel_values: Input batch passed straight to the wrapped timm
                model.
            labels: Optional classification targets. When given, a
                cross-entropy loss between the logits and ``labels`` is
                computed; when ``None``, no loss is computed.

        Returns:
            An ``ImageClassifierOutputWithNoAttention`` carrying ``logits``
            and ``loss`` (``None`` when ``labels`` is not provided).
        """
        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            # CrossEntropyLoss is a module: instantiate it first, then call it
            # on (input, target). Passing the tensors to the constructor would
            # bind them to its configuration arguments instead of computing
            # a loss value.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
        )


__all__ = [
    "EfficientNetModel",
    "EfficientNetModelForImageClassification",
]