"""Hugging Face wrapper around a torchvision EfficientNet-B3 classifier."""

import torch.nn as nn
import torchvision.models as models
from transformers import PretrainedConfig, PreTrainedModel


class MusheffConfig(PretrainedConfig):
    """Configuration for :class:`Musheff`.

    Args:
        num_classes: Output width of the final linear layer.
        dropout_rate: Probability passed to the dropout layer that precedes
            the final linear layer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "efficientnet_b3"

    def __init__(self, num_classes: int = 12, dropout_rate: float = 0.3, **kwargs) -> None:
        # Custom fields are stored before delegating to the base class,
        # which receives only the remaining keyword arguments.
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        super().__init__(**kwargs)


class Musheff(PreTrainedModel):
    """EfficientNet-B3 backbone with a replaced classification head.

    The torchvision network is stored as ``self.model``; its ``classifier``
    is swapped for ``Dropout -> Linear`` sized from the config.
    """

    # Config class associated with this model.
    config_class = MusheffConfig

    def __init__(self, config: MusheffConfig) -> None:
        """Build the backbone and attach a head driven by ``config``.

        Args:
            config: Supplies ``num_classes`` and ``dropout_rate``.
        """
        super().__init__(config)

        # weights=None: no torchvision pretrained weights are loaded here.
        backbone = models.efficientnet_b3(weights=None)

        # Index 1 of the stock classifier is read for its input width so the
        # replacement head accepts the same feature size.
        head_inputs = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=config.dropout_rate, inplace=True),
            nn.Linear(head_inputs, config.num_classes),
        )

        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its raw output.

        Args:
            x: Input passed straight to the backbone
                (presumably an image batch tensor — TODO confirm with callers).

        Returns:
            Whatever ``self.model`` returns for ``x``; no post-processing is
            applied.
        """
        outputs = self.model(x)
        return outputs