File size: 307 Bytes
b7265bc
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from __future__ import annotations
import torch.nn as nn
from torchvision.models import resnet50

def build_model(num_classes: int, pretrained: bool = False) -> nn.Module:
    """Build a torchvision ResNet-50 with a classification head of ``num_classes`` outputs.

    Args:
        num_classes: Number of output units for the replacement ``fc`` layer.
            Must be at least 1.
        pretrained: If True, load torchvision's ImageNet ``IMAGENET1K_V1``
            weights into the backbone (the same weights the legacy
            ``pretrained=True`` flag selected). If False, the backbone is
            randomly initialised.

    Returns:
        The ResNet-50 module whose ``fc`` layer has been replaced by a freshly
        initialised ``nn.Linear(in_features, num_classes)``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Fail fast on a misconfigured label count: 0 would build a head with no
    # outputs, and a negative value would only surface later as an opaque
    # tensor-shape error inside nn.Linear.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")

    try:
        # torchvision >= 0.13 exposes the weights enum; the boolean
        # ``pretrained`` flag is deprecated there and emits a warning.
        from torchvision.models import ResNet50_Weights
    except ImportError:
        # Older torchvision only understands the boolean flag.
        model = resnet50(pretrained=pretrained)
    else:
        # IMAGENET1K_V1 is exactly what ``pretrained=True`` resolved to, so
        # results match the legacy call (DEFAULT would silently switch to V2).
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = resnet50(weights=weights)

    # Read the width from the existing head so the replacement always matches
    # the backbone's feature dimension.
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model