| import torch | |
| from torch import nn | |
def get_model(num_classes: int = 5) -> nn.Module:
    """Build a MobileNetV2 with a fresh stem convolution and classifier head.

    The network is obtained through ``torch.hub`` from the ``pytorch/vision``
    repository with ``pretrained=True``. Two layers are then replaced by
    newly constructed modules: the first convolution in ``features`` and the
    final linear layer in ``classifier``.

    Args:
        num_classes: Number of outputs of the replacement classifier layer.
            Defaults to 5, which is the value this function previously
            hard-coded.

    Returns:
        The modified model.
    """
    # torch.hub resolves the 'mobilenet_v2' entry point of pytorch/vision;
    # by default this goes through GitHub / the local hub cache.
    model = torch.hub.load('pytorch/vision', 'mobilenet_v2', pretrained=True)

    # A brand-new Conv2d is installed here, so this layer's parameters are
    # freshly initialised rather than the ones that came with the loaded model.
    # NOTE(review): these hyper-parameters look the same as torchvision's stock
    # first conv -- confirm that re-initialising this layer is intended (e.g.
    # because a checkpoint is loaded over the model afterwards).
    model.features[0][0] = nn.Conv2d(
        3, 32, kernel_size=3, stride=2, padding=1, bias=False
    )

    # Swap the final classifier layer for one with `num_classes` outputs,
    # keeping the input width of the layer it replaces.
    head_in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(
        in_features=head_in_features, out_features=num_classes
    )
    return model