Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| from torchvision import models | |
def build_model(num_classes, device, *, trainable_layers=("layer3", "layer4")):
    """Build an ImageNet-pretrained ResNet-18 with a new ``num_classes`` head.

    Every backbone parameter is frozen except those belonging to the
    submodules named in ``trainable_layers``. The final fully connected
    layer is then replaced by a freshly initialised ``nn.Linear`` sized for
    ``num_classes``; that new head is always trainable.

    Args:
        num_classes: Number of output classes for the replacement classifier.
        device: Target passed to ``model.to`` (e.g. ``"cpu"``, ``"cuda"`` or
            a ``torch.device``).
        trainable_layers: Names of top-level ResNet-18 submodules (such as
            ``"layer1"`` … ``"layer4"``) whose parameters stay trainable.
            Defaults to ``("layer3", "layer4")``, the original behaviour.
            A single string is treated as one layer name. Pass ``()`` to
            train only the new classifier head.

    Returns:
        The modified ResNet-18 module, moved to ``device``.

    Raises:
        ValueError: If ``num_classes`` is less than 1, or if a name in
            ``trainable_layers`` is not a top-level submodule of ResNet-18.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    # Guard against build_model(..., trainable_layers="layer4"), which would
    # otherwise be iterated character by character.
    if isinstance(trainable_layers, str):
        trainable_layers = (trainable_layers,)

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # Freeze the whole network first, then selectively re-enable gradients
    # for the deeper stages that should be fine-tuned.
    model.requires_grad_(False)
    submodules = dict(model.named_children())
    for layer_name in trainable_layers:
        if layer_name not in submodules:
            raise ValueError(
                f"unknown ResNet-18 submodule {layer_name!r}; "
                f"expected one of {sorted(submodules)}"
            )
        submodules[layer_name].requires_grad_(True)

    # NOTE(review): freezing parameters does not stop BatchNorm layers in the
    # frozen stages from updating their running statistics while the model is
    # in train mode; put those layers in eval() after model.train() if they
    # should stay fixed.

    # The replacement head is created after freezing, so its parameters keep
    # nn.Linear's default requires_grad=True and are always trained.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)