import torch
from torch import nn


def get_model(num_classes: int = 5, *, pretrained: bool = True) -> nn.Module:
    """Build the MobileNetV2 variant used by this project.

    The backbone is fetched with ``torch.hub.load`` from the ``pytorch/vision``
    repository, which needs network access (or a populated hub cache). The
    repository reference is not pinned to a tag, so the code that gets loaded
    may change over time.

    Two layers are replaced after loading:

    * ``features[0][0]`` (the stem convolution) gets a new
      ``Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)``.
      The new layer is freshly initialized, so any pretrained weights for
      this layer are discarded.
    * ``classifier[1]`` becomes a ``Linear`` layer with the same
      ``in_features`` as before and ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits for the new classification
            layer. Defaults to 5, the value this function always used.
        pretrained: Forwarded to the hub entry point. Pass ``False`` to skip
            downloading pretrained weights, e.g. when a fine-tuned state dict
            will be loaded right afterwards.

    Returns:
        The modified MobileNetV2 model.
    """
    model = torch.hub.load("pytorch/vision", "mobilenet_v2", pretrained=pretrained)

    # NOTE(review): for torchvision's stock MobileNetV2 (width_mult=1.0) the
    # stem conv already has exactly these hyperparameters, so this swap does
    # not change the architecture. It only re-initializes the layer, throwing
    # away its pretrained weights. It is kept so that models built here match
    # the setup they were trained with; confirm whether that reset is
    # intended.
    model.features[0][0] = nn.Conv2d(
        3, 32, kernel_size=3, stride=2, padding=1, bias=False
    )

    # Swap the final linear layer for a task-specific head and keep the
    # backbone's feature width.
    head_in = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=head_in, out_features=num_classes)
    return model