efficientnet_b0 / modeling_efficientnet.py
Thastp's picture
Upload model
6febcc2 verified
raw
history blame
1.15 kB
from torch import nn
from functools import partial
from transformers import PreTrainedModel
from timm import create_model
from configuration_efficientnet import EfficientNetConfig
class EfficientNetModel(PreTrainedModel):
    """Headless EfficientNet wrapper around a timm backbone.

    The backbone is created with ``timm.create_model`` using
    ``config.model_name`` as the architecture name and passing
    ``config.pretrained`` through as timm's ``pretrained`` flag.
    ``forward`` returns the backbone's ``forward_features`` output rather than
    the full model output.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        arch_name = config.model_name
        use_pretrained = config.pretrained
        self.model = create_model(arch_name, pretrained=use_pretrained)

    def forward(self, pixel_values):
        """Return ``self.model.forward_features(pixel_values)``.

        NOTE(review): in timm this is the feature path without the classifier
        head; the exact output shape depends on the chosen architecture.
        """
        backbone = self.model
        return backbone.forward_features(pixel_values)
class EfficientNetModelForImageClassification(PreTrainedModel):
    """EfficientNet classifier: a full timm model including its head.

    The model is created with ``timm.create_model(config.model_name,
    pretrained=config.pretrained)``; ``forward`` calls it directly, so its
    output is whatever the timm model returns (its logits).
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = create_model(config.model_name, pretrained=config.pretrained)

    def forward(self, pixel_values, labels=None):
        """Run the timm model and, when ``labels`` is given, a cross-entropy loss.

        Args:
            pixel_values: input batch forwarded unchanged to the timm model.
            labels: optional targets for ``nn.CrossEntropyLoss``
                (presumably integer class indices per sample — TODO confirm).

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is not None,
            otherwise the logits tensor alone (kept for backward compatibility).
        """
        logits = self.model(pixel_values)
        if labels is None:
            return logits
        # CrossEntropyLoss is an nn.Module: construct it first, then call it on
        # (input, target). Passing the tensors to the constructor would bind
        # them to `weight`/`size_average` and never compute a loss.
        loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
# Names exported by `from <module> import *`: the headless backbone wrapper
# and the image-classification wrapper defined above in this module.
__all__ = ["EfficientNetModel", "EfficientNetModelForImageClassification"]