from __future__ import annotations

import torch.nn as nn
from torchvision.models import resnet50

try:  # torchvision >= 0.13 exposes the multi-weight enum API.
    from torchvision.models import ResNet50_Weights
except ImportError:  # Older torchvision only understands ``pretrained=``.
    ResNet50_Weights = None  # type: ignore[assignment,misc]


def build_model(num_classes: int, pretrained: bool = False) -> nn.Module:
    """Build a ResNet-50 whose final fully-connected layer outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output units for the replacement ``fc`` head.
            Must be at least 1.
        pretrained: If True, load torchvision's ImageNet backbone weights
            (``IMAGENET1K_V1``, the same weights the legacy ``pretrained=True``
            flag selected). The new ``fc`` head is always freshly initialised.

    Returns:
        The ResNet-50 module with its ``fc`` layer replaced by
        ``nn.Linear(in_features, num_classes)``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        # Fail early with a clear message instead of building a zero-width head
        # or hitting a low-level tensor-shape error inside nn.Linear.
        raise ValueError(
            f"num_classes must be a positive integer, got {num_classes!r}"
        )

    if ResNet50_Weights is not None:
        # ``pretrained=`` is deprecated since torchvision 0.13. IMAGENET1K_V1 is
        # what ``pretrained=True`` maps to; ``DEFAULT`` (V2) would change results.
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = resnet50(weights=weights)
    else:
        model = resnet50(pretrained=pretrained)

    # Swap the classification head, keeping the backbone's feature width.
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model