{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ensemble of trained image classifiers\n",
    "\n",
    "Rebuilds four torchvision backbones (EfficientNet-V2-M, DenseNet121, ResNet50, AlexNet) with a custom `N_CLASSES`-way head, restores their state dicts from `models-2/`, and wraps them in a single `EnsembleModel`.\n",
    "\n",
    "The notebook is meant to be run top to bottom on a fresh kernel; the four `.pth` files must exist under `models-2/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from torchvision.models import alexnet, densenet121, efficientnet_v2_m, resnet50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the first GPU when available, otherwise on CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Output size of the replacement head; must match the saved checkpoints.\n",
    "N_CLASSES = 10\n",
    "\n",
    "# State dicts for each model.\n",
    "efficientnet_weights_path = 'models-2/EfficientNet-V2-M.pth'\n",
    "densenet_weights_path = 'models-2/DenseNet121.pth'\n",
    "resnet_weights_path = 'models-2/ResNet50.pth'\n",
    "alexnet_weights_path = 'models-2/AlexNet.pth'\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifier head\n",
    "\n",
    "Every backbone gets the same MLP head ending in `LogSoftmax`, so all models emit log-probabilities over `N_CLASSES` classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def changedClassifierLayer(model, modelName, N_CLASSES=10):\n",
    "    \"\"\"Replace the classification head of a torchvision backbone in place.\n",
    "\n",
    "    The new head is Linear(in, 256) -> ReLU -> Dropout(0.2) -> Linear(256, 128)\n",
    "    -> ReLU -> Dropout(0.2) -> Linear(128, N_CLASSES) -> LogSoftmax(dim=1), so\n",
    "    the model outputs log-probabilities. It must stay identical to the head used\n",
    "    when the checkpoints were saved, otherwise `load_state_dict` rejects them.\n",
    "    All parameters are also marked trainable (`requires_grad = True`).\n",
    "\n",
    "    Args:\n",
    "        model: torchvision DenseNet121, ResNet50, EfficientNet-V2-M or AlexNet.\n",
    "        modelName: 'DenseNet121', 'ResNet50', 'EfficientNet-V2-M' or 'AlexNet'.\n",
    "        N_CLASSES: number of output classes of the new head.\n",
    "\n",
    "    Returns:\n",
    "        The same `model` object, mutated in place.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `modelName` is not one of the supported names.\n",
    "    \"\"\"\n",
    "    print(modelName)\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad = True\n",
    "\n",
    "    if modelName == 'DenseNet121':\n",
    "        num_input = model.classifier.in_features\n",
    "    elif modelName == 'ResNet50':\n",
    "        num_input = model.fc.in_features\n",
    "    elif modelName in ('EfficientNet-V2-M', 'AlexNet'):\n",
    "        # Their `classifier` is a Sequential starting with Dropout; index 1 is\n",
    "        # its first Linear layer. For AlexNet the whole original MLP is replaced.\n",
    "        num_input = model.classifier[1].in_features\n",
    "    else:\n",
    "        # Without this branch `num_input` would be unbound below.\n",
    "        raise ValueError(f'Unsupported modelName: {modelName!r}')\n",
    "\n",
    "    classifier = nn.Sequential(\n",
    "        nn.Linear(num_input, 256),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(0.2),\n",
    "        nn.Linear(256, 128),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(0.2),\n",
    "        nn.Linear(128, N_CLASSES),\n",
    "        nn.LogSoftmax(dim=1),\n",
    "    )\n",
    "\n",
    "    # ResNet names its head `fc`; the other three use `classifier`.\n",
    "    if modelName == 'ResNet50':\n",
    "        model.fc = classifier\n",
    "    else:\n",
    "        model.classifier = classifier\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load checkpoints\n",
    "\n",
    "Backbones are built with `weights=None`: downloading ImageNet weights is pointless because the strict `load_state_dict` call overwrites every parameter and buffer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(builder, modelName, weights_path):\n",
    "    \"\"\"Build a backbone with the custom head and restore its saved state dict.\n",
    "\n",
    "    Args:\n",
    "        builder: torchvision constructor such as `resnet50`.\n",
    "        modelName: name understood by `changedClassifierLayer`.\n",
    "        weights_path: path to a state dict readable by `torch.load`.\n",
    "\n",
    "    Returns:\n",
    "        The model with the checkpoint loaded, moved to `device`.\n",
    "    \"\"\"\n",
    "    model = changedClassifierLayer(builder(weights=None), modelName, N_CLASSES)\n",
    "    # map_location lets GPU-saved checkpoints load on a CPU-only machine;\n",
    "    # weights_only avoids unpickling arbitrary objects from the file.\n",
    "    state_dict = torch.load(weights_path, map_location=device, weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    return model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "efficientnetV2M_model = load_model(efficientnet_v2_m, 'EfficientNet-V2-M', efficientnet_weights_path)\n",
    "densenet_model = load_model(densenet121, 'DenseNet121', densenet_weights_path)\n",
    "resnet_model = load_model(resnet50, 'ResNet50', resnet_weights_path)\n",
    "alexnet_model = load_model(alexnet, 'AlexNet', alexnet_weights_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ensemble\n",
    "\n",
    "With no `weights`, the ensemble returns the element-wise maximum of the models' log-probabilities; pass one weight per model to get a weighted sum instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EnsembleModel(nn.Module):\n",
    "    \"\"\"Combine the outputs of several classifiers with matching output shapes.\n",
    "\n",
    "    With `weights=None` the forward pass returns the element-wise maximum over\n",
    "    the models' outputs; otherwise it returns sum_i weights[i] * output_i.\n",
    "\n",
    "    Args:\n",
    "        model_list: non-empty sequence of nn.Module classifiers.\n",
    "        weights: optional sequence with exactly one scalar per model.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `model_list` is empty or `weights` has the wrong length.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_list, weights=None):\n",
    "        super().__init__()\n",
    "        self.models = nn.ModuleList(model_list)\n",
    "        if len(self.models) == 0:\n",
    "            raise ValueError('model_list must contain at least one model')\n",
    "        # zip() in forward would otherwise silently drop unmatched models/weights.\n",
    "        if weights is not None and len(weights) != len(self.models):\n",
    "            raise ValueError(\n",
    "                f'expected {len(self.models)} weights, got {len(weights)}'\n",
    "            )\n",
    "        self.weights = weights\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Run each model on the device holding its parameters, then gather the\n",
    "        # outputs onto the first output's device so they can always be stacked.\n",
    "        outputs = [model(x.to(next(model.parameters()).device)) for model in self.models]\n",
    "        target_device = outputs[0].device\n",
    "        stacked = torch.stack([output.to(target_device) for output in outputs])\n",
    "\n",
    "        if self.weights is None:\n",
    "            return stacked.max(dim=0).values\n",
    "        weighted = torch.stack([w * output for w, output in zip(self.weights, stacked)])\n",
    "        return weighted.sum(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models_list = [\n",
    "    efficientnetV2M_model,\n",
    "    densenet_model,\n",
    "    resnet_model,\n",
    "    alexnet_model,\n",
    "]\n",
    "\n",
    "# eval() propagates through the ModuleList: it disables the heads' Dropout and\n",
    "# makes BatchNorm layers use their running statistics during inference.\n",
    "ensemble_model = EnsembleModel(models_list).eval()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}