import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class CustomModel(nn.Module):
    """EfficientNetV2-S backbone followed by a three-layer MLP classification head.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        pretrained: When True (the default, matching the original behaviour),
            initialise the backbone with torchvision's ImageNet-1K V1 weights,
            which may trigger a download. Pass False when the weights will be
            overwritten anyway (e.g. by ``load_model``) to avoid that download.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        weights = (
            models.EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
        )
        backbone = models.efficientnet_v2_s(weights=weights)
        # Input width of the stock classifier's Linear layer (index 1 of the
        # torchvision classifier), i.e. the backbone's feature dimension.
        num_features = backbone.classifier[1].in_features
        # Keep every top-level child except the final classifier. The result is
        # stored under the same attribute name and child indices as before, so
        # existing state_dict checkpoints keep loading.
        self.efficientnet = nn.Sequential(*list(backbone.children())[:-1])
        # NOTE(review): if torchvision's EfficientNet children are
        # (features, avgpool, classifier), the retained children already pool to
        # 1x1 and this layer is a no-op; kept so the module layout is unchanged.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(num_features, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch; presumably a ``(batch, 3, H, W)`` float tensor
                normalised as the ImageNet weights expect — TODO confirm.
        """
        x = self.efficientnet(x)
        x = self.gap(x)
        x = torch.flatten(x, 1)  # (batch, num_features)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.fc3(x)


def load_model(model_path, num_classes=4, device="cpu"):
    """Build a ``CustomModel``, load its weights from *model_path*, set eval mode.

    Args:
        model_path: Path to a checkpoint saved with
            ``torch.save(model.state_dict(), path)`` (must be a state_dict
            mapping, since it is passed to ``load_state_dict``).
        num_classes: Head size; must match the checkpoint. Defaults to 4.
        device: Device the checkpoint tensors and the model are placed on.
            Defaults to CPU, as before.

    Returns:
        The loaded model in eval mode.
    """
    # load_state_dict is strict, so every parameter/buffer is overwritten by the
    # checkpoint; downloading ImageNet weights first would be wasted work.
    model = CustomModel(num_classes=num_classes, pretrained=False)
    target = torch.device(device)
    # SECURITY: torch.load unpickles the file and can execute arbitrary code.
    # Only load trusted checkpoints, or pass weights_only=True on torch versions
    # that support it.
    state_dict = torch.load(model_path, map_location=target)
    model.load_state_dict(state_dict)
    # map_location only places the loaded tensors; move the model itself too.
    model.to(target)
    model.eval()
    return model