efficientnet_b0 / modeling_efficientnet.py
Thastp's picture
Upload model
d1a7921 verified
raw
history blame
1.47 kB
from torch import nn
from functools import partial
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
from timm import create_model
from configuration_efficientnet import EfficientNetConfig
class EfficientNetModel(PreTrainedModel):
    """timm EfficientNet backbone exposed as a Hugging Face base model.

    The timm model named by ``config.model_name`` is built with
    ``timm.create_model`` (``config.pretrained`` is forwarded as its
    ``pretrained`` flag). ``forward`` returns the backbone feature map and,
    when the timm model provides ``forward_head``, the pooled features.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = create_model(config.model_name, pretrained=config.pretrained)

    def forward(self, pixel_values):
        """Run the backbone on ``pixel_values``.

        Args:
            pixel_values: Image batch passed unchanged to the timm model
                (presumably ``(batch, channels, height, width)`` — confirm
                against the image processor in use).

        Returns:
            BaseModelOutputWithPoolingAndNoAttention where
            ``last_hidden_state`` is the output of timm's ``forward_features``
            and ``pooler_output`` is the pre-logits output of timm's
            ``forward_head`` (``None`` if the timm model has no
            ``forward_head``).
        """
        last_hidden_state = self.model.forward_features(pixel_values)

        # The "WithPooling" output type promises a pooled representation;
        # previously it was always None. timm's forward_head with
        # pre_logits=True returns the pooled features instead of classifier
        # logits. Guarded so timm models without forward_head keep the old
        # behaviour (pooler_output=None).
        pooler_output = None
        forward_head = getattr(self.model, "forward_head", None)
        if forward_head is not None:
            pooler_output = forward_head(last_hidden_state, pre_logits=True)

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )
class EfficientNetModelForImageClassification(PreTrainedModel):
    """timm EfficientNet exposed as a Hugging Face image classifier.

    The timm model named by ``config.model_name`` is built with
    ``timm.create_model`` (``config.pretrained`` is forwarded as its
    ``pretrained`` flag). Its full forward pass, including the timm
    classifier head, produces the logits.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = create_model(config.model_name, pretrained=config.pretrained)

    def forward(self, pixel_values, labels=None):
        """Classify ``pixel_values`` and optionally compute a loss.

        Args:
            pixel_values: Image batch passed unchanged to the timm model
                (presumably ``(batch, channels, height, width)`` — confirm
                against the image processor in use).
            labels: Optional integer class indices. When given, the
                cross-entropy loss between ``logits`` and ``labels`` is
                returned as ``loss``.

        Returns:
            ImageClassifierOutputWithNoAttention with ``logits`` and ``loss``
            (``None`` when ``labels`` is not provided).
        """
        logits = self.model(pixel_values)
        loss = None
        if labels is not None:
            # nn.CrossEntropyLoss is a module: construct it first, then call
            # it. Passing (logits, labels) to the constructor would treat
            # them as `weight` / `size_average` instead of computing a loss.
            # Flatten so labels shaped (batch,) or (batch, 1) both work.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )
        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits)
# Names exported by `from <module> import *`: the base backbone wrapper and
# the image-classification wrapper defined above.
__all__ = ["EfficientNetModel", "EfficientNetModelForImageClassification"]