File size: 377 Bytes
756deb2 | 1 2 3 4 5 6 7 8 9 | import torch
from torch import nn
def get_model(num_classes: int = 5, *, pretrained: bool = True) -> nn.Module:
    """Build the MobileNetV2 variant used by this project.

    The backbone is fetched with ``torch.hub`` from the ``pytorch/vision``
    GitHub repository, so network access (or a populated hub cache) is
    required. Two layers are then replaced:

    * ``features[0][0]`` (the stem convolution) with a freshly constructed
      ``nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)``;
    * ``classifier[1]`` (the final linear layer) with an ``nn.Linear`` that
      keeps the original input width and emits ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the new classification head.
            Defaults to 5.
        pretrained: If true (the default), request the ``IMAGENET1K_V1``
            ImageNet weights, which are the weights the legacy
            ``pretrained=True`` flag selected. Pass ``False`` to skip the
            weight download, e.g. when a fine-tuned ``state_dict`` will be
            loaded into the returned model immediately afterwards.

    Returns:
        The modified MobileNetV2 module.
    """
    # ``weights`` replaces the deprecated ``pretrained`` flag (torchvision
    # >= 0.13). Weight names may be passed as strings through torch.hub.
    weights = "IMAGENET1K_V1" if pretrained else None
    model = torch.hub.load("pytorch/vision", "mobilenet_v2", weights=weights)

    # Replicate the trained architecture exactly.
    # NOTE(review): these hyperparameters appear to match torchvision's stock
    # MobileNetV2 stem conv, so this swap only throws away the stem's
    # pretrained weights (random re-init). That is harmless when a fine-tuned
    # checkpoint is loaded afterwards; otherwise confirm it is intended.
    model.features[0][0] = nn.Conv2d(
        3, 32, kernel_size=3, stride=2, padding=1, bias=False
    )

    # Replace the head with a ``num_classes``-way layer. The existing input
    # width is reused so the head always matches the backbone's feature size.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)
    return model