Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| from torchvision import models | |
class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone with a two-layer MLP classification head.

    The torchvision ResNet-18's final ``fc`` layer is replaced by
    ``Dropout -> Linear(in_features, fc_dim) -> ReLU -> Dropout ->
    Linear(fc_dim, num_classes)``. Which backbone parameters are trainable is
    controlled by ``fine_tune_mode``; the new head is always trainable.

    Args:
        num_classes: Number of output logits; must be >= 1.
        dropout: Dropout probability applied before each Linear layer of the
            head (range-checked by ``nn.Dropout``).
        fc_dim: Width of the hidden layer in the head; must be >= 1.
        fine_tune_mode: One of
            ``"frozen"`` - only the head is trainable;
            ``"layer4"`` - the head and ``backbone.layer4`` are trainable;
            ``"full"``   - every parameter is trainable.
        pretrained: If True (the default, matching the previous behaviour),
            initialise the backbone with ``ResNet18_Weights.DEFAULT``; if
            False, start from random weights without fetching any weights,
            e.g. when a checkpoint will be loaded afterwards.

    Raises:
        ValueError: If ``fine_tune_mode`` is not supported, or if
            ``num_classes`` / ``fc_dim`` is smaller than 1. These checks run
            before any backbone weights are loaded.
    """

    # Accepted values for ``fine_tune_mode``.
    _FINE_TUNE_MODES = ("frozen", "layer4", "full")

    def __init__(
        self,
        num_classes: int,
        dropout: float = 0.4,
        fc_dim: int = 256,
        fine_tune_mode: str = "layer4",
        pretrained: bool = True,
    ):
        super().__init__()

        # Validate the cheap arguments first so a bad config fails fast
        # instead of after (potentially) downloading pretrained weights.
        if fine_tune_mode not in self._FINE_TUNE_MODES:
            raise ValueError(f"Unsupported fine_tune_mode: {fine_tune_mode}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if fc_dim < 1:
            raise ValueError(f"fc_dim must be >= 1, got {fc_dim}")

        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)
        in_features = self.backbone.fc.in_features

        # Freeze the backbone (or unfreeze it entirely for "full"), then
        # re-enable layer4 for the "layer4" strategy. "frozen" trains only
        # the head built below.
        # NOTE: requires_grad=False does not stop BatchNorm layers from
        # updating their running statistics while the module is in train()
        # mode; call .eval() on frozen parts if that matters.
        train_all = fine_tune_mode == "full"
        for param in self.backbone.parameters():
            param.requires_grad = train_all
        if fine_tune_mode == "layer4":
            for param in self.backbone.layer4.parameters():
                param.requires_grad = True

        # Replace torchvision's single Linear layer with an MLP head.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )
        # Always train the classifier head, regardless of fine_tune_mode.
        for param in self.backbone.fc.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Run ``x`` through the backbone and head, returning raw logits.

        ``x`` is passed unchanged to torchvision's ResNet-18, presumably an
        image batch shaped ``(N, 3, H, W)`` - confirm against the data
        pipeline. The output's last dimension is ``num_classes``.
        """
        return self.backbone(x)