musheff / modeling.py
blasisd's picture
Added both the model and its config in the same modeling.py file, deleted extra configuration file.
8e66c4a
import torchvision.models as models
from transformers import PreTrainedModel
import torch.nn as nn
from transformers import PretrainedConfig
class MusheffConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters read by ``Musheff``.

    Args:
        num_classes: Number of output units of the classification head
            (used as ``out_features`` of the final linear layer).
        dropout_rate: Probability passed to the dropout layer that sits
            in front of the final linear layer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier string recorded for this configuration class.
    model_type = "efficientnet_b3"

    def __init__(self, num_classes=12, dropout_rate=0.3, **kwargs):
        # The model-specific fields are assigned first, and only then is
        # the base-class initialiser given the remaining keyword arguments.
        self.num_classes, self.dropout_rate = num_classes, dropout_rate
        super().__init__(**kwargs)
class Musheff(PreTrainedModel):
    """EfficientNet-B3 from torchvision with a replaced classification head.

    The wrapped network is stored on ``self.model``. Its classifier is
    swapped for ``Dropout -> Linear`` sized from the configuration's
    ``dropout_rate`` and ``num_classes``.
    """

    # Configuration class associated with this model.
    config_class = MusheffConfig

    def __init__(self, config):
        """Build the backbone and attach the new classifier head.

        Args:
            config: Configuration object providing ``num_classes`` and
                ``dropout_rate``.
        """
        super().__init__(config)

        # weights=None: no torchvision checkpoint is requested, so the
        # backbone is created with freshly initialised parameters.
        backbone = models.efficientnet_b3(weights=None)

        # The stock classifier's second entry is indexed for its
        # ``in_features``; the replacement head reuses that input width.
        head_width = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=config.dropout_rate, inplace=True),
            nn.Linear(head_width, config.num_classes),
        )

        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output.

        Args:
            x: Input passed straight to the torchvision model
                (presumably an image batch -- confirm against the caller).

        Returns:
            Whatever the wrapped network produces; its last layer is the
            linear head with ``config.num_classes`` output features.
        """
        outputs = self.model(x)
        return outputs