from typing import Iterable, Optional, Union

import torch
import torch.nn as nn
from torchvision import models

# Parameter-name prefixes left trainable by default: the last two residual
# stages ("layer3", "layer4") plus the classifier head ("fc").
_DEFAULT_TRAINABLE_PREFIXES = ("layer3", "layer4", "fc")


def build_model(
    num_classes: int,
    device: torch.device,
    weights: Optional[models.ResNet50_Weights] = models.ResNet50_Weights.IMAGENET1K_V2,
    trainable_prefixes: Union[str, Iterable[str]] = _DEFAULT_TRAINABLE_PREFIXES,
) -> nn.Module:
    """Build a ResNet-50 for fine-tuning with a fresh ``num_classes`` head.

    Every backbone parameter whose name does not start with one of
    ``trainable_prefixes`` is frozen (``requires_grad = False``).  The final
    ``fc`` layer is then replaced by a new ``nn.Linear`` with
    ``num_classes`` outputs, and the model is moved to ``device``.

    Args:
        num_classes: Number of output classes for the new classifier head.
        device: Device the returned model is moved to.
        weights: Weights passed to ``torchvision.models.resnet50``.  Defaults
            to ``ResNet50_Weights.IMAGENET1K_V2`` (the original behaviour);
            pass ``None`` for random initialisation (no download).
        trainable_prefixes: Parameter-name prefixes that stay trainable.
            A single string is treated as one prefix.  Defaults to
            ``("layer3", "layer4", "fc")``.

    Returns:
        The configured ``nn.Module`` on ``device``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    model = models.resnet50(weights=weights)

    # str.startswith needs a tuple; wrap a bare string so it is not split
    # into individual characters by tuple().
    if isinstance(trainable_prefixes, str):
        prefixes = (trainable_prefixes,)
    else:
        prefixes = tuple(trainable_prefixes)

    # Freeze early layers.  NOTE: requires_grad=False stops gradient updates
    # only; BatchNorm running statistics in frozen stages still update while
    # the model is in train() mode.
    for name, param in model.named_parameters():
        if not name.startswith(prefixes):
            param.requires_grad = False

    # Replace the classifier.  It is created after the freezing loop, so its
    # parameters are always trainable regardless of trainable_prefixes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)