| import torchvision.models as models | |
| from transformers import PreTrainedModel | |
| import torch.nn as nn | |
| from transformers import PretrainedConfig | |
class MusheffConfig(PretrainedConfig):
    """Configuration for the ``Musheff`` image classifier.

    Holds the two hyperparameters that ``Musheff`` reads when it builds its
    classification head. Every other keyword argument is forwarded to
    ``PretrainedConfig.__init__`` unchanged.

    Args:
        num_classes: Output width of the final linear layer.
        dropout_rate: Dropout probability applied just before that layer.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "efficientnet_b3"

    def __init__(self, num_classes: int = 12, dropout_rate: float = 0.3, **kwargs):
        # Store the custom fields first (num_classes, then dropout_rate),
        # then let the base class consume the remaining keyword arguments.
        self.num_classes, self.dropout_rate = num_classes, dropout_rate
        super().__init__(**kwargs)
class Musheff(PreTrainedModel):
    """EfficientNet-B3 classifier exposed as a Hugging Face ``PreTrainedModel``.

    The torchvision EfficientNet-B3 backbone is created with ``weights=None``.
    Trained weights are expected to arrive via ``from_pretrained``
    (TODO confirm with callers). The backbone's stock classifier head is
    replaced by dropout followed by a linear layer, both sized from ``config``.
    """

    config_class = MusheffConfig  # lets PreTrainedModel build/load the matching config

    def __init__(self, config):
        super().__init__(config)

        backbone = models.efficientnet_b3(weights=None)

        # Index 1 of torchvision's stock head is assumed to be its Linear
        # layer. Reusing its input width keeps the new head compatible with
        # whatever feature size the backbone produces.
        feature_dim = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=config.dropout_rate, inplace=True),
            nn.Linear(feature_dim, config.num_classes),
        )

        # Registered under the name "model" so state-dict keys stay "model.*".
        self.model = backbone

    def forward(self, x):
        """Run the wrapped EfficientNet on ``x`` and return its raw logits.

        ``x`` is presumably a (batch, 3, H, W) image tensor; confirm against
        callers. No softmax is applied.
        """
        logits = self.model(x)
        return logits