zoolang / model.py
vsimmer's picture
Upload 4 files
756deb2 verified
raw
history blame contribute delete
377 Bytes
import torch
from torch import nn
def get_model(num_classes: int = 5, pretrained: bool = True) -> nn.Module:
    """Build the MobileNetV2 classifier architecture used by this project.

    Loads MobileNetV2 through ``torch.hub`` from the ``pytorch/vision``
    repository. It then re-creates the first stem convolution and replaces
    the final classification layer with one that produces ``num_classes``
    outputs. This replicates the architecture the project's weights were
    trained with. Presumably the caller then loads a saved ``state_dict``
    into the returned model; confirm this against the loading code.

    Args:
        num_classes: Number of outputs of the final ``nn.Linear`` layer.
            Defaults to 5, the value previously hard-coded here.
        pretrained: Forwarded to the hub ``mobilenet_v2`` entrypoint. When
            True, torchvision's pretrained backbone weights are downloaded.
            The two layers replaced below get a fresh random init either way.

    Returns:
        The modified MobileNetV2 ``nn.Module``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.

    Note:
        ``torch.hub.load`` fetches the repository (and, if ``pretrained``,
        the weights). It therefore needs network access or an already
        populated torch.hub cache.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # NOTE(review): the hub repo is unpinned, so this uses whatever
    # torchvision revision is current (or already cached). Consider pinning a
    # tag such as 'pytorch/vision:v0.15.2' for reproducible builds.
    model = torch.hub.load('pytorch/vision', 'mobilenet_v2', pretrained=pretrained)

    # These hyperparameters should match torchvision's own stem conv (verify
    # for the pinned version). So the architecture is unchanged, but this
    # layer's pretrained weights are discarded in favour of a fresh random
    # init. Kept deliberately to replicate the original model exactly.
    model.features[0][0] = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)

    # Swap the last classifier layer for a head sized to num_classes, keeping
    # the input width of the layer it replaces.
    old_head = model.classifier[1]
    model.classifier[1] = nn.Linear(in_features=old_head.in_features, out_features=num_classes)
    return model