# Spaces: Running on Zero
import torch
import torch.nn as nn
from torchvision import models
class ResNet18Classifier(nn.Module):
    """ResNet-18 classifier whose final ``fc`` layer is replaced by a small MLP head.

    The torchvision ResNet-18 backbone is built (optionally with torchvision's
    default pretrained weights) and its ``fc`` layer is swapped for::

        Dropout -> Linear(in_features, fc_dim) -> ReLU -> Dropout -> Linear(fc_dim, num_classes)

    Args:
        num_classes: Number of output logits produced by the head.
        dropout: Dropout probability used by both dropout layers of the head.
        fc_dim: Width of the hidden layer of the head.
        freeze_backbone: If true, every parameter of the original backbone has
            ``requires_grad`` set to ``False`` so only the new head is trained.
        pretrained: Keyword-only. If true (the default, preserving the previous
            behaviour), load ``models.ResNet18_Weights.DEFAULT``; if false, build
            the backbone with random initialisation and no weight download.

    Note:
        Freezing only disables gradients. BatchNorm layers inside the backbone
        still update their running statistics whenever the module is in
        ``train()`` mode. NOTE(review): confirm whether that is intended.
    """

    def __init__(
        self,
        num_classes: int,
        dropout: float = 0.5,
        fc_dim: int = 256,
        freeze_backbone: bool = True,
        *,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)

        # Capture the head's input width before the original ``fc`` is replaced.
        in_features = self.backbone.fc.in_features

        if freeze_backbone:
            # This runs before the new head exists, so it only touches the
            # original backbone parameters (including the ``fc`` discarded below).
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Newly constructed layers have requires_grad=True by default, so the
        # replacement head is always trainable regardless of ``freeze_backbone``.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and replaced head on ``x`` and return the logits.

        Args:
            x: Input image batch. Presumably ``(batch, 3, H, W)``, as the
                torchvision ResNet expects; the shape is not checked here.

        Returns:
            The output of the final ``Linear(fc_dim, num_classes)`` layer, i.e.
            ``num_classes`` logits per input sample.
        """
        return self.backbone(x)