File size: 377 Bytes
756deb2
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import torch
from torch import nn

def get_model(num_classes: int = 5) -> nn.Module:
    """Build the MobileNetV2 classifier architecture used by this project.

    The base network is fetched with ``torch.hub`` from the ``pytorch/vision``
    repository. No tag is pinned, so the repo's default branch is used; this
    needs network access or a warm hub cache. It loads the ImageNet
    ``IMAGENET1K_V1`` weights, the same checkpoint the deprecated
    ``pretrained=True`` flag selected.

    Args:
        num_classes: Number of outputs of the final linear layer. Defaults to 5,
            the value that was previously hard-coded, so existing callers are
            unaffected.

    Returns:
        The modified MobileNetV2 model.
    """
    # Replicate your architecture exactly.
    # ``weights=`` replaces the deprecated ``pretrained=True``. The hub entry point
    # follows the repo's default branch, so leaning on a deprecated kwarg is fragile.
    # "IMAGENET1K_V1" is what pretrained=True mapped to; "DEFAULT" would pick V2.
    model = torch.hub.load('pytorch/vision', 'mobilenet_v2', weights='IMAGENET1K_V1')
    # NOTE(review): this swaps the first conv of the stem for a freshly (randomly)
    # initialised Conv2d, discarding that layer's ImageNet weights. torchvision's
    # stock stem appears to use this same configuration, so parameter names and
    # shapes should still match a saved state_dict. Confirm the re-init is intended
    # if this model is trained from here rather than loaded from a checkpoint.
    model.features[0][0] = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
    # Replace the classification head, keeping its input width, with ``num_classes`` outputs.
    model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=num_classes)
    return model