Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
def build_model(
    num_classes: int,
    device: torch.device,
    *,
    weights=models.ResNet50_Weights.IMAGENET1K_V2,
    trainable_prefixes=("layer3", "layer4", "fc"),
) -> nn.Module:
    """Build a ResNet-50 classifier set up for partial fine-tuning.

    Loads a torchvision ResNet-50 (``IMAGENET1K_V2`` weights by default),
    freezes every parameter whose name does not start with one of
    ``trainable_prefixes``, replaces the final ``fc`` layer with a fresh
    ``nn.Linear`` producing ``num_classes`` outputs, and moves the model to
    ``device``.

    Args:
        num_classes: Number of output classes for the new head; must be >= 1.
        device: Device the returned model is moved to.
        weights: Weights passed to ``models.resnet50``. Pass ``None`` for
            random initialisation (no pretrained download).
        trainable_prefixes: Parameter-name prefix (or iterable of prefixes)
            that stay trainable. The replacement head is always trainable
            regardless, because it is created after the freezing pass.

    Returns:
        The modified model, on ``device``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.

    Note:
        Only parameters are frozen (``requires_grad = False``). BatchNorm
        layers inside the frozen stages still update their running
        mean/variance whenever the model is in ``train()`` mode; switch those
        modules to ``eval()`` during training if that matters.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    model = models.resnet50(weights=weights)

    # str.startswith needs a tuple to test several prefixes at once. A bare
    # string must not be exploded into single characters by tuple().
    if isinstance(trainable_prefixes, str):
        prefixes = (trainable_prefixes,)
    else:
        prefixes = tuple(trainable_prefixes)

    # Freeze everything outside the selected stages.
    for name, param in model.named_parameters():
        if not name.startswith(prefixes):
            param.requires_grad = False

    # Replace the classifier; the new Linear's parameters default to
    # requires_grad=True since they are created after the freezing loop.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)