import torch.nn as nn
import torchvision.models as models


class Musheff(nn.Module):
    """EfficientNet-B3 backbone with its classifier head replaced.

    The backbone is created through ``torchvision.models.efficientnet_b3``
    using ``EfficientNet_B3_Weights.DEFAULT``. Its ``classifier`` attribute is
    then swapped for a dropout layer followed by a linear layer sized from
    the supplied configuration.
    """

    def __init__(self, config):
        """Build the network from ``config``.

        Args:
            config: Subscriptable object providing the keys
                ``"num_classes"`` (output size of the new linear layer) and
                ``"dropout_rate"`` (``p`` of the new dropout layer). Both
                keys are read with ``config[...]`` before the backbone is
                loaded.
        """
        super().__init__()

        # Read the head settings first, in this order, ahead of loading the
        # backbone.
        n_outputs = config["num_classes"]
        drop_p = config["dropout_rate"]

        # Instantiate the backbone with torchvision's default weights.
        backbone = models.efficientnet_b3(
            weights=models.EfficientNet_B3_Weights.DEFAULT
        )

        # The new linear layer takes the same input width as the linear
        # layer at index 1 of the stock classifier.
        head_width = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=drop_p, inplace=True),
            nn.Linear(head_width, n_outputs),
        )

        self.model = backbone

    def forward(self, x):
        """Pass ``x`` through the wrapped network and return its output."""
        return self.model(x)