enyasantos commited on
Commit
b7af805
·
verified ·
1 Parent(s): 55a6ba6

upload script

Browse files
Files changed (2) hide show
  1. model.ipynb +199 -0
  2. model.py +87 -0
model.ipynb ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torch.nn as nn\n",
11
+ "\n",
12
+ "from torchvision.models import densenet121, DenseNet121_Weights\n",
13
+ "from torchvision.models import resnet50, ResNet50_Weights\n",
14
+ "from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights\n",
15
+ "from torchvision.models import alexnet, AlexNet_Weights"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
25
+ "print(device)"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 2,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "def changedClassifierLayer(model, modelName, N_CLASSES=10):\n",
35
+ " print(modelName)\n",
36
+ " for param in model.parameters():\n",
37
+ " param.requires_grad = True\n",
38
+ "\n",
39
+ " if modelName == \"DenseNet121\":\n",
40
+ " num_input = model.classifier.in_features\n",
41
+ "\n",
42
+ " elif modelName == \"ResNet50\":\n",
43
+ " num_input = model.fc.in_features\n",
44
+ "\n",
45
+ " elif modelName == \"EfficientNet-V2-M\" or modelName == \"AlexNet\":\n",
46
+ " num_input = model.classifier[1].in_features\n",
47
+ "\n",
48
+ " classifier = nn.Sequential(\n",
49
+ " nn.Linear(num_input, 256),\n",
50
+ " nn.ReLU(),\n",
51
+ " nn.Dropout(0.2),\n",
52
+ " nn.Linear(256, 128),\n",
53
+ " nn.ReLU(),\n",
54
+ " nn.Dropout(0.2),\n",
55
+ " nn.Linear(128, N_CLASSES),\n",
56
+ " nn.LogSoftmax(dim=1)\n",
57
+ " )\n",
58
+ "\n",
59
+ " if modelName == \"ResNet50\":\n",
60
+ " model.fc = classifier\n",
61
+ " else:\n",
62
+ " model.classifier = classifier"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 3,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "efficientnet_weights_path = 'models-2/EfficientNet-V2-M.pth'\n",
72
+ "densenet_weights_path = 'models-2/DenseNet121.pth'\n",
73
+ "resnet_weights_path = 'models-2/ResNet50.pth'\n",
74
+ "alexnet_weights_path = 'models-2/AlexNet.pth'"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 4,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "efficientnetV2M_model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)\n",
84
+ "densenet_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)\n",
85
+ "resnet_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)\n",
86
+ "alexnet_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 5,
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "name": "stdout",
96
+ "output_type": "stream",
97
+ "text": [
98
+ "EfficientNet-V2-M\n",
99
+ "DenseNet121\n",
100
+ "ResNet50\n",
101
+ "AlexNet\n"
102
+ ]
103
+ }
104
+ ],
105
+ "source": [
106
+ "changedClassifierLayer(efficientnetV2M_model, \"EfficientNet-V2-M\")\n",
107
+ "changedClassifierLayer(densenet_model, \"DenseNet121\")\n",
108
+ "changedClassifierLayer(resnet_model, \"ResNet50\")\n",
109
+ "changedClassifierLayer(alexnet_model, \"AlexNet\")"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 6,
115
+ "metadata": {},
116
+ "outputs": [
117
+ {
118
+ "data": {
119
+ "text/plain": [
120
+ "<All keys matched successfully>"
121
+ ]
122
+ },
123
+ "execution_count": 6,
124
+ "metadata": {},
125
+ "output_type": "execute_result"
126
+ }
127
+ ],
128
+ "source": [
129
+ "efficientnetV2M_model.load_state_dict(torch.load(efficientnet_weights_path))\n",
130
+ "densenet_model.load_state_dict(torch.load(densenet_weights_path))\n",
131
+ "resnet_model.load_state_dict(torch.load(resnet_weights_path))\n",
132
+ "alexnet_model.load_state_dict(torch.load(alexnet_weights_path))"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 7,
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "class EnsembleModel(nn.Module):\n",
142
+ " def __init__(self, model_list, weights=None):\n",
143
+ " super(EnsembleModel, self).__init__()\n",
144
+ " self.models = nn.ModuleList(model_list)\n",
145
+ " self.weights = weights\n",
146
+ "\n",
147
+ " def forward(self, x):\n",
148
+ " outputs = [model(x.to(next(model.parameters()).device)) for model in self.models]\n",
149
+ "\n",
150
+ " if self.weights is None:\n",
151
+ " # ensemble_output = torch.mean(torch.stack(outputs), dim=0)\n",
152
+ "\n",
153
+ " ensemble_output, _ = torch.max(torch.stack(outputs), dim=0)\n",
154
+ " else:\n",
155
+ " weighted_outputs = torch.stack([w * output for w, output in zip(self.weights, outputs)])\n",
156
+ " ensemble_output = torch.sum(weighted_outputs, dim=0)\n",
157
+ "\n",
158
+ " return ensemble_output"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 8,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "models_list = [\n",
168
+ " efficientnetV2M_model,\n",
169
+ " densenet_model,\n",
170
+ " resnet_model,\n",
171
+ " alexnet_model\n",
172
+ "]\n",
173
+ "\n",
174
+ "ensemble_model = EnsembleModel(models_list)"
175
+ ]
176
+ }
177
+ ],
178
+ "metadata": {
179
+ "kernelspec": {
180
+ "display_name": "Python 3",
181
+ "language": "python",
182
+ "name": "python3"
183
+ },
184
+ "language_info": {
185
+ "codemirror_mode": {
186
+ "name": "ipython",
187
+ "version": 3
188
+ },
189
+ "file_extension": ".py",
190
+ "mimetype": "text/x-python",
191
+ "name": "python",
192
+ "nbconvert_exporter": "python",
193
+ "pygments_lexer": "ipython3",
194
+ "version": "3.11.0"
195
+ }
196
+ },
197
+ "nbformat": 4,
198
+ "nbformat_minor": 2
199
+ }
model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torchvision.models import densenet121, DenseNet121_Weights
5
+ from torchvision.models import resnet50, ResNet50_Weights
6
+ from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
7
+ from torchvision.models import alexnet, AlexNet_Weights
8
+
9
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
10
+
11
+ def changedClassifierLayer(model, modelName, N_CLASSES=10):
12
+ for param in model.parameters():
13
+ param.requires_grad = False
14
+
15
+ if modelName == "DenseNet121":
16
+ num_input = model.classifier.in_features
17
+
18
+ elif modelName == "ResNet50":
19
+ num_input = model.fc.in_features
20
+
21
+ elif modelName == "EfficientNet-V2-M" or modelName == "AlexNet":
22
+ num_input = model.classifier[1].in_features
23
+
24
+ classifier = nn.Sequential(
25
+ nn.Linear(num_input, 256),
26
+ nn.ReLU(),
27
+ nn.Dropout(0.2),
28
+ nn.Linear(256, 128),
29
+ nn.ReLU(),
30
+ nn.Dropout(0.2),
31
+ nn.Linear(128, N_CLASSES),
32
+ nn.LogSoftmax(dim=1)
33
+ )
34
+
35
+ if modelName == "ResNet50":
36
+ model.fc = classifier
37
+ else:
38
+ model.classifier = classifier
39
+
40
+ efficientnet_weights_path = 'models/EfficientNet-V2-M.pth'
41
+ densenet_weights_path = 'models/DenseNet121.pth'
42
+ resnet_weights_path = 'models/ResNet50.pth'
43
+ alexnet_weights_path = 'models/AlexNet.pth'
44
+
45
+ efficientnetV2M_model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
46
+ densenet_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
47
+ resnet_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
48
+ alexnet_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
49
+
50
+ changedClassifierLayer(efficientnetV2M_model, "EfficientNet-V2-M")
51
+ changedClassifierLayer(densenet_model, "DenseNet121")
52
+ changedClassifierLayer(resnet_model, "ResNet50")
53
+ changedClassifierLayer(alexnet_model, "AlexNet")
54
+
55
+ efficientnetV2M_model.load_state_dict(torch.load(efficientnet_weights_path))
56
+ densenet_model.load_state_dict(torch.load(densenet_weights_path))
57
+ resnet_model.load_state_dict(torch.load(resnet_weights_path))
58
+ alexnet_model.load_state_dict(torch.load(alexnet_weights_path))
59
+
60
+ class EnsembleModel(nn.Module):
61
+ def __init__(self, model_list, weights=None):
62
+ super(EnsembleModel, self).__init__()
63
+ self.models = nn.ModuleList(model_list)
64
+ self.weights = weights
65
+
66
+ def forward(self, x):
67
+ outputs = [model(x.to(next(model.parameters()).device)) for model in self.models]
68
+
69
+ if self.weights is None:
70
+ # ensemble_output = torch.mean(torch.stack(outputs), dim=0)
71
+
72
+ ensemble_output, _ = torch.max(torch.stack(outputs), dim=0)
73
+ else:
74
+ weighted_outputs = torch.stack([w * output for w, output in zip(self.weights, outputs)])
75
+ ensemble_output = torch.sum(weighted_outputs, dim=0)
76
+
77
+ return ensemble_output
78
+
79
+ models_list = [
80
+ efficientnetV2M_model,
81
+ densenet_model,
82
+ resnet_model,
83
+ alexnet_model
84
+ ]
85
+
86
+ ensemble_model = EnsembleModel(models_list)
87
+