# Uploaded by enyasantos — commit f960225 ("upload scripts", verified).
import torch
import torch.nn as nn
from torchvision.models import densenet121, DenseNet121_Weights
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import vgg16, VGG16_Weights
from torchvision.models import vgg19, VGG19_Weights
# Preferred compute device: the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def changedClassifierLayer(model, modelName, N_CLASSES=10):
    """Freeze a pretrained backbone and replace its classification head in place.

    Every existing parameter of ``model`` gets ``requires_grad = False``. The
    old head is then replaced by a fresh, trainable MLP:
    ``Linear(in, 256) -> ReLU -> Dropout(0.2) -> Linear(256, 128) -> ReLU ->
    Dropout(0.2) -> Linear(128, N_CLASSES) -> LogSoftmax(dim=1)``.
    The modified model therefore emits log-probabilities.

    Args:
        model: torchvision model to modify in place.
        modelName: one of "DenseNet121", "ResNet50", "EfficientNet-V2-M",
            "AlexNet", "VGG16", "VGG19". It selects which attribute holds the
            head (``fc`` for ResNet50, ``classifier`` otherwise) and which
            layer the input width is read from.
        N_CLASSES: number of outputs of the new head.

    Returns:
        The same ``model`` object (already modified in place), for convenience.

    Raises:
        ValueError: if ``modelName`` is not one of the supported names. The
            model is left untouched in that case.
    """
    # Resolve the head's input width *before* mutating anything, so an
    # unsupported name fails cleanly instead of leaving a half-frozen model.
    if modelName == "DenseNet121":
        num_input = model.classifier.in_features
    elif modelName == "ResNet50":
        num_input = model.fc.in_features
    elif modelName in ("EfficientNet-V2-M", "AlexNet"):
        # torchvision's heads here start with Dropout, so the first Linear is index 1.
        num_input = model.classifier[1].in_features
    elif modelName in ("VGG19", "VGG16"):
        # torchvision's VGG head starts directly with its first Linear layer.
        num_input = model.classifier[0].in_features
    else:
        raise ValueError(f"Unsupported modelName: {modelName!r}")

    # Freeze the pretrained backbone; only the new head created below will train.
    for param in model.parameters():
        param.requires_grad = False

    classifier = nn.Sequential(
        nn.Linear(num_input, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, N_CLASSES),
        nn.LogSoftmax(dim=1),
    )

    if modelName == "ResNet50":
        model.fc = classifier
    else:
        model.classifier = classifier
    return model
# Fine-tuned checkpoints, one per backbone. They are loaded *after*
# changedClassifierLayer has replaced each head, so each file must hold a
# state_dict for that modified architecture.
efficientnet_weights_path = 'models/EfficientNet-V2-M.pth'
densenet_weights_path = 'models/DenseNet121.pth'
resnet_weights_path = 'models/ResNet50.pth'
alexnet_weights_path = 'models/AlexNet.pth'
vgg16_weights_path = 'models/VGG16.pth'
vgg19_weights_path = 'models/VGG19.pth'

# Instantiate each backbone with its ImageNet-pretrained weights.
# NOTE(review): load_state_dict below uses its default strict=True, so every
# parameter and buffer is overwritten by the checkpoint. Passing weights=None
# here would likely skip the ImageNet download without changing the result;
# confirm before switching.
efficientnetV2M_model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
densenet_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
resnet_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
alexnet_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
# Each VGG weight enum must go to its matching builder. torchvision's alexnet()
# rejects VGG weight enums. AlexNet's head would also not match the
# "VGG16"/"VGG19" branch of changedClassifierLayer (classifier[0] is a Dropout there).
vgg16_model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

# Freeze each backbone and swap in the new trainable classification head.
changedClassifierLayer(efficientnetV2M_model, "EfficientNet-V2-M")
changedClassifierLayer(densenet_model, "DenseNet121")
changedClassifierLayer(resnet_model, "ResNet50")
changedClassifierLayer(alexnet_model, "AlexNet")
changedClassifierLayer(vgg16_model, "VGG16")
changedClassifierLayer(vgg19_model, "VGG19")

# Restore the fine-tuned weights. map_location="cpu" is used because the models
# are still on the CPU here. It also keeps checkpoints saved from a GPU loadable
# on machines without CUDA.
efficientnetV2M_model.load_state_dict(torch.load(efficientnet_weights_path, map_location="cpu"))
densenet_model.load_state_dict(torch.load(densenet_weights_path, map_location="cpu"))
resnet_model.load_state_dict(torch.load(resnet_weights_path, map_location="cpu"))
alexnet_model.load_state_dict(torch.load(alexnet_weights_path, map_location="cpu"))
vgg16_model.load_state_dict(torch.load(vgg16_weights_path, map_location="cpu"))
vgg19_model.load_state_dict(torch.load(vgg19_weights_path, map_location="cpu"))
class EnsembleModel(nn.Module):
    """Combine several models by averaging (or weighting) their outputs.

    Every sub-model receives the same input batch, moved to that model's
    device. The outputs are gathered back on the input's device and then:

    * averaged element-wise when ``weights`` is None, or
    * summed after multiplying each model's output by its weight.

    When the sub-models output log-probabilities, the plain mean is the log of
    the (unnormalised) geometric mean of their probabilities. The argmax is
    unaffected by the missing normalisation.

    Args:
        model_list: iterable of ``nn.Module`` sub-models (at least one).
        weights: optional sequence with one multiplier per model, applied in
            the same order as ``model_list``.

    Raises:
        ValueError: if ``model_list`` is empty or ``weights`` has a different
            length than ``model_list``.
    """

    def __init__(self, model_list, weights=None):
        super().__init__()
        # ModuleList registers the sub-models so .to()/.eval()/state_dict reach them.
        self.models = nn.ModuleList(model_list)
        if len(self.models) == 0:
            raise ValueError("EnsembleModel needs at least one model")
        # zip() would silently drop the extra models or weights on a mismatch.
        if weights is not None and len(weights) != len(self.models):
            raise ValueError(
                f"got {len(weights)} weights for {len(self.models)} models"
            )
        self.weights = weights

    @staticmethod
    def _device_of(model, default):
        """Return the device of model's first parameter, or default if it has none."""
        param = next(model.parameters(), None)
        return default if param is None else param.device

    def forward(self, x):
        outputs = []
        for model in self.models:
            model_device = self._device_of(model, x.device)
            # Bring each output back to x's device so torch.stack works even
            # when the sub-models are spread over several devices.
            outputs.append(model(x.to(model_device)).to(x.device))

        if self.weights is None:
            return torch.mean(torch.stack(outputs), dim=0)

        weighted_outputs = torch.stack(
            [w * output for w, output in zip(self.weights, outputs)]
        )
        return torch.sum(weighted_outputs, dim=0)
# The six fine-tuned backbones, in a fixed order, combined into an ensemble
# that averages their outputs (no per-model weights are given).
models_list = [
    efficientnetV2M_model, densenet_model, resnet_model,
    alexnet_model, vgg16_model, vgg19_model,
]
ensemble_model = EnsembleModel(model_list=models_list, weights=None)