File size: 3,615 Bytes
f960225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import torch.nn as nn

from torchvision.models import densenet121, DenseNet121_Weights
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import vgg16, VGG16_Weights
from torchvision.models import vgg19, VGG19_Weights

# Prefer the first CUDA GPU when PyTorch reports one is available; otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

def changedClassifierLayer(model, modelName, N_CLASSES=10):
    """Freeze ``model`` and replace its classification head in place.

    Every parameter currently in ``model`` gets ``requires_grad = False``.
    A new head (Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout ->
    Linear -> LogSoftmax) is then attached; its freshly created parameters
    remain trainable. The head is stored on ``model.fc`` for "ResNet50" and on
    ``model.classifier`` for every other supported name.

    Args:
        model: The network to modify. It must expose the head attribute that
            matches ``modelName`` (``fc`` or ``classifier``).
        modelName: One of "DenseNet121", "ResNet50", "EfficientNet-V2-M",
            "AlexNet", "VGG16" or "VGG19". It selects where the input width
            of the existing head is read from.
        N_CLASSES: Number of outputs of the final linear layer.

    Returns:
        None. ``model`` is mutated in place.

    Raises:
        ValueError: If ``modelName`` is not one of the supported names.
    """
    # Resolve the input width of the existing head first, so that an
    # unsupported name fails before the model has been touched.
    if modelName == "DenseNet121":
        num_input = model.classifier.in_features
    elif modelName == "ResNet50":
        num_input = model.fc.in_features
    elif modelName in ("EfficientNet-V2-M", "AlexNet"):
        # For these two the width is taken from the second entry of the head.
        num_input = model.classifier[1].in_features
    elif modelName in ("VGG19", "VGG16"):
        # For VGG the width is taken from the first entry of the head.
        num_input = model.classifier[0].in_features
    else:
        raise ValueError(
            f"Unsupported modelName {modelName!r}; expected one of "
            "'DenseNet121', 'ResNet50', 'EfficientNet-V2-M', 'AlexNet', "
            "'VGG16', 'VGG19'"
        )

    # Freeze everything that exists now; the head created below is built
    # afterwards and therefore stays trainable.
    for param in model.parameters():
        param.requires_grad = False

    classifier = nn.Sequential(
        nn.Linear(num_input, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, N_CLASSES),
        nn.LogSoftmax(dim=1),
    )

    if modelName == "ResNet50":
        model.fc = classifier
    else:
        model.classifier = classifier

# Locations of the checkpoints loaded into the models below, relative to the
# process working directory.
efficientnet_weights_path = 'models/EfficientNet-V2-M.pth'
densenet_weights_path = 'models/DenseNet121.pth'
resnet_weights_path = 'models/ResNet50.pth'
alexnet_weights_path = 'models/AlexNet.pth'
vgg16_weights_path = 'models/VGG16.pth'
vgg19_weights_path = 'models/VGG19.pth'

# Instantiate each backbone with its ImageNet weights.
# NOTE(review): if the checkpoints hold complete state dicts, these pretrained
# weights are overwritten by load_state_dict below and weights=None would
# avoid the download -- confirm before changing.
efficientnetV2M_model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
densenet_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
resnet_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
alexnet_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
# The VGG variants must be built with their own constructors: passing VGG
# weight enums to alexnet() does not produce a VGG network.
vgg16_model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

# Swap in the custom head so the module structure matches the checkpoint keys.
changedClassifierLayer(efficientnetV2M_model, "EfficientNet-V2-M")
changedClassifierLayer(densenet_model, "DenseNet121")
changedClassifierLayer(resnet_model, "ResNet50")
changedClassifierLayer(alexnet_model, "AlexNet")
changedClassifierLayer(vgg16_model, "VGG16")
changedClassifierLayer(vgg19_model, "VGG19")

# The models are still on the CPU here, so map checkpoint tensors to the CPU
# as well; this keeps loading working on hosts without CUDA even when the
# checkpoint was written from a GPU.
# NOTE(review): nothing in this block calls .eval(); the Dropout layers in the
# heads stay in training mode unless the consumer switches modes -- confirm.
efficientnetV2M_model.load_state_dict(torch.load(efficientnet_weights_path, map_location='cpu'))
densenet_model.load_state_dict(torch.load(densenet_weights_path, map_location='cpu'))
resnet_model.load_state_dict(torch.load(resnet_weights_path, map_location='cpu'))
alexnet_model.load_state_dict(torch.load(alexnet_weights_path, map_location='cpu'))
vgg16_model.load_state_dict(torch.load(vgg16_weights_path, map_location='cpu'))
vgg19_model.load_state_dict(torch.load(vgg19_weights_path, map_location='cpu'))

class EnsembleModel(nn.Module):
    """Combine the outputs of several models into a single prediction.

    Without ``weights`` the member outputs are averaged element-wise. With
    ``weights`` each member output is multiplied by its weight and the
    products are summed (the weights are used as given, not normalised).

    Args:
        model_list: The member models; registered through ``nn.ModuleList``.
        weights: Optional sequence with one weight per member, in the same
            order as ``model_list``.

    Raises:
        ValueError: If ``weights`` is given and its length differs from the
            number of models.
    """

    def __init__(self, model_list, weights=None):
        super().__init__()
        self.models = nn.ModuleList(model_list)
        # zip() in forward() would silently ignore surplus models or weights,
        # so reject a length mismatch up front.
        if weights is not None and len(weights) != len(self.models):
            raise ValueError(
                f"Expected {len(self.models)} weights (one per model), "
                f"got {len(weights)}"
            )
        self.weights = weights

    def forward(self, x):
        """Run every member on ``x`` and return the combined output.

        The input is moved to each member's device (taken from its first
        parameter) before the call; members without parameters receive ``x``
        unchanged. Member outputs are gathered on the first output's device
        before being combined.
        """
        outputs = []
        for model in self.models:
            first_param = next(model.parameters(), None)
            inputs = x if first_param is None else x.to(first_param.device)
            outputs.append(model(inputs))

        if not outputs:
            raise RuntimeError("EnsembleModel has no member models to evaluate")

        # Stacking requires all tensors on one device; this is a no-op when
        # the members already share a device.
        target_device = outputs[0].device
        outputs = [output.to(target_device) for output in outputs]

        if self.weights is None:
            return torch.mean(torch.stack(outputs), dim=0)

        weighted_outputs = torch.stack(
            [w * output for w, output in zip(self.weights, outputs)]
        )
        return torch.sum(weighted_outputs, dim=0)
    
# Members of the ensemble, in the order their outputs are combined.
models_list = [
    efficientnetV2M_model, densenet_model, resnet_model,
    alexnet_model, vgg16_model, vgg19_model,
]

# No per-model weights are supplied, so the ensemble averages member outputs.
ensemble_model = EnsembleModel(models_list)