File size: 1,296 Bytes
54afcfb
e6b3256
 
3169467
54afcfb
 
 
 
 
3169467
 
54afcfb
3169467
 
54afcfb
3169467
54afcfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3169467
54afcfb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class CustomModel(nn.Module):
    """EfficientNetV2-S feature extractor followed by a small MLP classifier.

    The torchvision classifier is stripped from the backbone and replaced by
    three linear layers (``num_features -> 512 -> 256 -> num_classes``) with
    ReLU activations and dropout in between.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        pretrained: When True (the default, matching the previous behaviour)
            the backbone is initialised with the ``IMAGENET1K_V1`` weights,
            which torchvision may need to download. Pass False to build the
            architecture only, e.g. when a checkpoint is loaded right after.
    """

    def __init__(self, num_classes: int = 4, pretrained: bool = True):
        super().__init__()

        # Resolve the weights lazily so that no download is triggered when the
        # caller is going to overwrite the parameters anyway.
        weights = (
            models.EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
        )
        backbone = models.efficientnet_v2_s(weights=weights)

        # Input width of the linear layer inside the original classifier;
        # this is the channel count the new head has to accept.
        num_features = backbone.classifier[1].in_features

        # Keep every child except the last one (the classifier). The attribute
        # name and the Sequential layout are part of the state_dict keys, so
        # they must stay as they are for existing checkpoints to load.
        self.efficientnet = nn.Sequential(*list(backbone.children())[:-1])

        # NOTE(review): torchvision's EfficientNet is expected to end with its
        # own adaptive average pool before the classifier, which would make
        # this layer a no-op; it is kept so the forward pass is unchanged.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(num_features, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits for a batch.

        Args:
            x: Input batch fed to the backbone (presumably images shaped
                ``(N, 3, H, W)`` - confirm against the caller).

        Returns:
            Tensor with one row per sample and ``num_classes`` columns. No
            softmax is applied.
        """
        x = self.efficientnet(x)
        x = self.gap(x)
        # Collapse everything but the batch dimension for the linear layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


def load_model(model_path, num_classes: int = 4):
    """Build a ``CustomModel`` and restore its parameters from a checkpoint.

    Args:
        model_path: Path (or file-like object) of a state_dict saved with
            ``torch.save``.
        num_classes: Size of the classification head the checkpoint was
            trained with. Defaults to 4, the previous hard-coded value.

    Returns:
        The model on the CPU, switched to evaluation mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or shapes do not match the model.
    """
    # Every parameter is replaced by the checkpoint below, so fetching the
    # ImageNet weights first would only cost time and require network access.
    model = CustomModel(num_classes=num_classes, pretrained=False)

    # NOTE(review): torch.load unpickles the file - only load checkpoints from
    # trusted sources (or pass weights_only=True where PyTorch supports it).
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model