import torch
import torch.nn as nn
from torchvision import models


class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone with a two-layer MLP classification head.

    The torchvision ResNet-18's final ``fc`` layer is replaced by
    ``Dropout -> Linear(in_features, fc_dim) -> ReLU -> Dropout ->
    Linear(fc_dim, num_classes)``. Which backbone parameters remain
    trainable is controlled by ``fine_tune_mode``. The new head is always
    trainable.

    Args:
        num_classes: Number of output classes (size of the final Linear).
        dropout: Dropout probability used before each Linear in the head.
        fc_dim: Width of the hidden layer in the head.
        fine_tune_mode: One of:
            * ``"frozen"`` - all backbone parameters frozen; only the head trains.
            * ``"layer4"`` - only ``backbone.layer4`` (plus the head) trains.
            * ``"full"``   - every parameter trains.
        pretrained: If True (default), load ``ResNet18_Weights.DEFAULT``;
            otherwise build the backbone with random initialisation
            (no weight download).

    Raises:
        ValueError: If ``fine_tune_mode`` is not one of the supported modes.
    """

    FINE_TUNE_MODES = ("frozen", "layer4", "full")

    def __init__(
        self,
        num_classes: int,
        dropout: float = 0.4,
        fc_dim: int = 256,
        fine_tune_mode: str = "layer4",
        pretrained: bool = True,
    ) -> None:
        super().__init__()

        # Validate before building the backbone so a typo fails fast instead of
        # first constructing the network (and possibly downloading weights).
        if fine_tune_mode not in self.FINE_TUNE_MODES:
            raise ValueError(f"Unsupported fine_tune_mode: {fine_tune_mode}")

        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)
        # Read before replacing fc: this is the width of the pooled features.
        in_features = self.backbone.fc.in_features

        # Everything is trainable only in "full" mode; otherwise freeze it all
        # and selectively re-enable below.
        train_all = fine_tune_mode == "full"
        for param in self.backbone.parameters():
            param.requires_grad = train_all
        if fine_tune_mode == "layer4":
            for param in self.backbone.layer4.parameters():
                param.requires_grad = True

        # NOTE: requires_grad=False does not stop BatchNorm layers from updating
        # their running statistics while the module is in train() mode. If the
        # frozen part must be fully static, put its BN modules in eval() mode.

        # The head is created after the freezing step, so its freshly
        # initialised parameters already have requires_grad=True.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and head on ``x``.

        Assumes ``x`` is an image batch in the layout torchvision's ResNet
        expects (presumably (N, 3, H, W), normalised per the chosen weights;
        TODO confirm against the data pipeline). Returns raw class scores
        (no softmax) whose last dimension is ``num_classes``.
        """
        return self.backbone(x)