import torch.nn as nn
from torchvision import models


def build_model(
    num_classes,
    device,
    *,
    trainable_layers=("layer3", "layer4"),
    weights=models.ResNet18_Weights.IMAGENET1K_V1,
):
    """Build a ResNet-18 for fine-tuning with a fresh classification head.

    Every backbone parameter is frozen, the submodules named in
    ``trainable_layers`` are unfrozen, and the final fully connected layer is
    replaced with a new ``nn.Linear`` sized for ``num_classes``.

    Args:
        num_classes: Number of output features of the new classifier head.
        device: Target passed to ``model.to`` (e.g. ``"cpu"``, ``"cuda"`` or
            a ``torch.device``).
        trainable_layers: Names of submodules to unfreeze, resolved with
            ``nn.Module.get_submodule`` (so dotted paths such as
            ``"layer4.1"`` also work). A single string is treated as one
            name. Defaults to ``("layer3", "layer4")``, the original
            behaviour.
        weights: torchvision weights enum forwarded to ``models.resnet18``.
            Defaults to the ImageNet weights; pass ``None`` for random
            initialisation (no download required).

    Returns:
        The ResNet-18 ``nn.Module`` moved to ``device``.

    Raises:
        ValueError: If a name in ``trainable_layers`` does not refer to a
            submodule of the model.
    """
    model = models.resnet18(weights=weights)

    # Freeze the whole network first so that only explicitly chosen parts
    # receive gradients.
    for param in model.parameters():
        param.requires_grad = False

    # Guard against a bare string, which would otherwise be iterated
    # character by character.
    if isinstance(trainable_layers, str):
        trainable_layers = (trainable_layers,)

    # Unfreeze the requested (typically deeper, more task-specific) stages.
    # NOTE: requires_grad=False does not stop BatchNorm layers in the frozen
    # stages from updating their running statistics while in train() mode.
    for layer_name in trainable_layers:
        try:
            layer = model.get_submodule(layer_name)
        except AttributeError as err:
            raise ValueError(
                f"unknown layer {layer_name!r} in trainable_layers"
            ) from err
        for param in layer.parameters():
            param.requires_grad = True

    # Replace the classifier for our number of classes. A freshly created
    # nn.Linear has requires_grad=True, so the new head is always trainable.
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)

    model = model.to(device)
    return model