enyasantos committed on
Commit
f960225
·
verified ·
1 Parent(s): a9c5e7f

upload scripts

Browse files
Files changed (2) hide show
  1. model.ipynb +214 -0
  2. model.py +102 -0
model.ipynb ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torch.nn as nn\n",
11
+ "\n",
12
+ "from torchvision.models import densenet121, DenseNet121_Weights\n",
13
+ "from torchvision.models import resnet50, ResNet50_Weights\n",
14
+ "from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights\n",
15
+ "from torchvision.models import alexnet, AlexNet_Weights\n",
16
+ "from torchvision.models import vgg16, VGG16_Weights\n",
17
+ "from torchvision.models import vgg19, VGG19_Weights"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
27
+ "print(device)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "def changedClassifierLayer(model, modelName, N_CLASSES=10):\n",
37
+ " for param in model.parameters():\n",
38
+ " param.requires_grad = False\n",
39
+ "\n",
40
+ " if modelName == \"DenseNet121\":\n",
41
+ " num_input = model.classifier.in_features\n",
42
+ "\n",
43
+ " elif modelName == \"ResNet50\":\n",
44
+ " num_input = model.fc.in_features\n",
45
+ "\n",
46
+ " elif modelName == \"EfficientNet-V2-M\" or modelName == \"AlexNet\":\n",
47
+ " num_input = model.classifier[1].in_features\n",
48
+ "\n",
49
+ " elif modelName == \"VGG19\" or modelName == \"VGG16\":\n",
50
+ " num_input = model.classifier[0].in_features\n",
51
+ "\n",
52
+ " classifier = nn.Sequential(\n",
53
+ " nn.Linear(num_input, 256),\n",
54
+ " nn.ReLU(),\n",
55
+ " nn.Dropout(0.2),\n",
56
+ " nn.Linear(256, 128),\n",
57
+ " nn.ReLU(),\n",
58
+ " nn.Dropout(0.2),\n",
59
+ " nn.Linear(128, N_CLASSES),\n",
60
+ " nn.LogSoftmax(dim=1)\n",
61
+ " )\n",
62
+ "\n",
63
+ " if modelName == \"ResNet50\":\n",
64
+ " model.fc = classifier\n",
65
+ " else:\n",
66
+ " model.classifier = classifier"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "efficientnet_weights_path = 'models/EfficientNet-V2-M.pth'\n",
76
+ "densenet_weights_path = 'models/DenseNet121.pth'\n",
77
+ "resnet_weights_path = 'models/ResNet50.pth'\n",
78
+ "alexnet_weights_path = 'models/AlexNet.pth'\n",
79
+ "vgg16_weights_path = 'models/VGG16.pth'\n",
80
+ "vgg19_weights_path = 'models/VGG19.pth'"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 4,
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "efficientnetV2M_model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)\n",
90
+ "densenet_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)\n",
91
+ "resnet_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)\n",
92
+ "alexnet_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)\n",
93
+ "vgg16_model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)\n",
94
+ "vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)\n"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 5,
100
+ "metadata": {},
101
+ "outputs": [
102
+ {
103
+ "name": "stdout",
104
+ "output_type": "stream",
105
+ "text": [
106
+ "EfficientNet-V2-M\n",
107
+ "DenseNet121\n",
108
+ "ResNet50\n",
109
+ "AlexNet\n"
110
+ ]
111
+ }
112
+ ],
113
+ "source": [
114
+ "changedClassifierLayer(efficientnetV2M_model, \"EfficientNet-V2-M\")\n",
115
+ "changedClassifierLayer(densenet_model, \"DenseNet121\")\n",
116
+ "changedClassifierLayer(resnet_model, \"ResNet50\")\n",
117
+ "changedClassifierLayer(alexnet_model, \"AlexNet\")\n",
118
+ "changedClassifierLayer(vgg16_model, \"VGG16\")\n",
119
+ "changedClassifierLayer(vgg19_model, \"VGG19\")"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 6,
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "data": {
129
+ "text/plain": [
130
+ "<All keys matched successfully>"
131
+ ]
132
+ },
133
+ "execution_count": 6,
134
+ "metadata": {},
135
+ "output_type": "execute_result"
136
+ }
137
+ ],
138
+ "source": [
139
+ "efficientnetV2M_model.load_state_dict(torch.load(efficientnet_weights_path))\n",
140
+ "densenet_model.load_state_dict(torch.load(densenet_weights_path))\n",
141
+ "resnet_model.load_state_dict(torch.load(resnet_weights_path))\n",
142
+ "alexnet_model.load_state_dict(torch.load(alexnet_weights_path))\n",
143
+ "vgg16_model.load_state_dict(torch.load(vgg16_weights_path))\n",
144
+ "vgg19_model.load_state_dict(torch.load(vgg19_weights_path))"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 7,
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "class EnsembleModel(nn.Module):\n",
154
+ " def __init__(self, model_list, weights=None):\n",
155
+ " super(EnsembleModel, self).__init__()\n",
156
+ " self.models = nn.ModuleList(model_list)\n",
157
+ " self.weights = weights\n",
158
+ "\n",
159
+ " def forward(self, x):\n",
160
+ " outputs = [model(x.to(next(model.parameters()).device)) for model in self.models]\n",
161
+ "\n",
162
+ " if self.weights is None:\n",
163
+ " ensemble_output = torch.mean(torch.stack(outputs), dim=0)\n",
164
+ "\n",
165
+ " \n",
166
+ " #ensemble_output, _ = torch.max(torch.stack(outputs), dim=0)\n",
167
+ " else:\n",
168
+ " weighted_outputs = torch.stack([w * output for w, output in zip(self.weights, outputs)])\n",
169
+ " ensemble_output = torch.sum(weighted_outputs, dim=0)\n",
170
+ "\n",
171
+ " return ensemble_output"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 8,
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "models_list = [\n",
181
+ " efficientnetV2M_model,\n",
182
+ " densenet_model,\n",
183
+ " resnet_model,\n",
184
+ " alexnet_model,\n",
185
+ " vgg16_model,\n",
186
+ " vgg19_model\n",
187
+ "]\n",
188
+ "\n",
189
+ "ensemble_model = EnsembleModel(models_list)"
190
+ ]
191
+ }
192
+ ],
193
+ "metadata": {
194
+ "kernelspec": {
195
+ "display_name": "Python 3",
196
+ "language": "python",
197
+ "name": "python3"
198
+ },
199
+ "language_info": {
200
+ "codemirror_mode": {
201
+ "name": "ipython",
202
+ "version": 3
203
+ },
204
+ "file_extension": ".py",
205
+ "mimetype": "text/x-python",
206
+ "name": "python",
207
+ "nbconvert_exporter": "python",
208
+ "pygments_lexer": "ipython3",
209
+ "version": "3.11.0"
210
+ }
211
+ },
212
+ "nbformat": 4,
213
+ "nbformat_minor": 2
214
+ }
model.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torchvision.models import densenet121, DenseNet121_Weights
5
+ from torchvision.models import resnet50, ResNet50_Weights
6
+ from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
7
+ from torchvision.models import alexnet, AlexNet_Weights
8
+ from torchvision.models import vgg16, VGG16_Weights
9
+ from torchvision.models import vgg19, VGG19_Weights
10
+
11
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
12
+
13
def changedClassifierLayer(model, modelName, N_CLASSES=10):
    """Freeze ``model`` and replace its classification head in place.

    Every parameter present on ``model`` at call time gets
    ``requires_grad = False``. The head is then replaced by::

        Linear(in, 256) -> ReLU -> Dropout(0.2)
        -> Linear(256, 128) -> ReLU -> Dropout(0.2)
        -> Linear(128, N_CLASSES) -> LogSoftmax(dim=1)

    The new layers are created after the freeze, so they keep PyTorch's
    default ``requires_grad=True``.

    Args:
        model: Network to modify. The head is read from ``model.fc`` for
            ``"ResNet50"`` and from ``model.classifier`` for every other
            supported name.
        modelName: One of ``"DenseNet121"``, ``"ResNet50"``,
            ``"EfficientNet-V2-M"``, ``"AlexNet"``, ``"VGG16"``, ``"VGG19"``.
        N_CLASSES: Size of the final linear layer's output.

    Returns:
        None. ``model`` is mutated in place.

    Raises:
        ValueError: If ``modelName`` is not one of the supported names. The
            check runs before anything is modified, so ``model`` is left
            untouched in that case.
    """
    # Work out the head's input width first so that an unsupported name fails
    # loudly before the model has been frozen or altered.
    if modelName == "DenseNet121":
        num_input = model.classifier.in_features
    elif modelName == "ResNet50":
        num_input = model.fc.in_features
    elif modelName in ("EfficientNet-V2-M", "AlexNet"):
        # For these two names the width is read from classifier[1].
        num_input = model.classifier[1].in_features
    elif modelName in ("VGG16", "VGG19"):
        # For the VGG names the width is read from classifier[0].
        num_input = model.classifier[0].in_features
    else:
        raise ValueError(
            f"Unsupported modelName {modelName!r}; expected one of "
            "'DenseNet121', 'ResNet50', 'EfficientNet-V2-M', 'AlexNet', "
            "'VGG16', 'VGG19'."
        )

    # Freeze everything that exists so far; the head built below is new and
    # therefore stays trainable.
    for param in model.parameters():
        param.requires_grad = False

    classifier = nn.Sequential(
        nn.Linear(num_input, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, N_CLASSES),
        nn.LogSoftmax(dim=1),
    )

    if modelName == "ResNet50":
        model.fc = classifier
    else:
        model.classifier = classifier
44
+
45
# Checkpoint files whose state dicts are loaded into the models built below.
efficientnet_weights_path = 'models/EfficientNet-V2-M.pth'
densenet_weights_path = 'models/DenseNet121.pth'
resnet_weights_path = 'models/ResNet50.pth'
alexnet_weights_path = 'models/AlexNet.pth'
vgg16_weights_path = 'models/VGG16.pth'
vgg19_weights_path = 'models/VGG19.pth'

# Instantiate each architecture with its torchvision ImageNet weights.
# NOTE: every parameter is overwritten by load_state_dict() further down, so
# these pretrained weights only serve as the initial state.
efficientnetV2M_model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
densenet_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
resnet_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
alexnet_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
# The VGG models must be built with their own constructors: the "VGG16" /
# "VGG19" head swap reads classifier[0].in_features and the checkpoints are
# VGG state dicts, neither of which fits an AlexNet instance.
vgg16_model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

# Swap in the custom head so the module layout matches the checkpoints
# before their state dicts are loaded.
changedClassifierLayer(efficientnetV2M_model, "EfficientNet-V2-M")
changedClassifierLayer(densenet_model, "DenseNet121")
changedClassifierLayer(resnet_model, "ResNet50")
changedClassifierLayer(alexnet_model, "AlexNet")
changedClassifierLayer(vgg16_model, "VGG16")
changedClassifierLayer(vgg19_model, "VGG19")

# map_location=device lets a checkpoint that was saved from a GPU be read on
# a CPU-only machine; without it torch.load tries to restore tensors onto the
# device they were saved from.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
efficientnetV2M_model.load_state_dict(torch.load(efficientnet_weights_path, map_location=device))
densenet_model.load_state_dict(torch.load(densenet_weights_path, map_location=device))
resnet_model.load_state_dict(torch.load(resnet_weights_path, map_location=device))
alexnet_model.load_state_dict(torch.load(alexnet_weights_path, map_location=device))
vgg16_model.load_state_dict(torch.load(vgg16_weights_path, map_location=device))
vgg19_model.load_state_dict(torch.load(vgg19_weights_path, map_location=device))
72
+
73
class EnsembleModel(nn.Module):
    """Combine the outputs of several models into a single prediction.

    Each member receives the same input. With ``weights=None`` the member
    outputs are averaged element-wise; otherwise each output is multiplied by
    its weight and the products are summed. The weights are used exactly as
    given (they are not normalised).

    Args:
        model_list: Iterable of ``nn.Module`` members. They are registered in
            an ``nn.ModuleList`` so they follow ``.to()``, ``.eval()`` etc.
        weights: Optional per-member weights, in the same order as
            ``model_list``.

    Raises:
        ValueError: If ``weights`` is a sized collection whose length differs
            from the number of members. Without this check ``zip`` would
            silently drop the surplus members or weights.
    """

    def __init__(self, model_list, weights=None):
        super().__init__()
        self.models = nn.ModuleList(model_list)
        if (
            weights is not None
            and hasattr(weights, "__len__")
            and len(weights) != len(self.models)
        ):
            raise ValueError(
                f"Expected {len(self.models)} ensemble weights, "
                f"got {len(weights)}."
            )
        self.weights = weights

    def forward(self, x):
        """Run every member on ``x`` and return the combined output."""
        outputs = []
        for model in self.models:
            # Feed each member on the device its parameters live on. A member
            # without parameters has no device of its own, so it gets ``x``
            # unchanged instead of failing on an empty parameter iterator.
            first_param = next(model.parameters(), None)
            inputs = x if first_param is None else x.to(first_param.device)
            outputs.append(model(inputs))

        if self.weights is None:
            return torch.mean(torch.stack(outputs), dim=0)

        weighted_outputs = torch.stack(
            [w * output for w, output in zip(self.weights, outputs)]
        )
        return torch.sum(weighted_outputs, dim=0)
91
+
92
# Ensemble members; forward() stacks their outputs in this order.
models_list = [
    efficientnetV2M_model, densenet_model, resnet_model,
    alexnet_model, vgg16_model, vgg19_model,
]

# No weights are passed, so the ensemble averages the member outputs.
ensemble_model = EnsembleModel(model_list=models_list)
102
+