import torch
import torch.nn as nn
from torchvision import models


class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone with a small dropout/MLP classification head.

    The torchvision ResNet-18's final ``fc`` layer is replaced by
    ``Dropout -> Linear(in_features, fc_dim) -> ReLU -> Dropout ->
    Linear(fc_dim, num_classes)``.

    Args:
        num_classes: Number of output logits produced by the head.
        dropout: Dropout probability used by both ``nn.Dropout`` layers in
            the head.
        fc_dim: Width of the hidden layer in the head.
        freeze_backbone: If True, only the new head's parameters keep
            ``requires_grad=True``; every other ResNet-18 parameter is frozen.
        pretrained: If True (the default, matching the original behavior),
            initialize the backbone from ``ResNet18_Weights.DEFAULT``. If
            False, the backbone is randomly initialized and no weights are
            downloaded.
    """

    def __init__(
        self,
        num_classes: int,
        dropout: float = 0.5,
        fc_dim: int = 256,
        freeze_backbone: bool = True,
        *,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)

        # Read the feature width before the original ``fc`` layer is replaced.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

        if freeze_backbone:
            # The new head is registered as ``self.backbone.fc``, so its
            # parameters are exactly those whose names start with "fc.".
            # Everything else belongs to the original backbone and is frozen.
            # NOTE(review): requires_grad=False does not stop BatchNorm layers
            # from updating their running statistics while the module is in
            # train mode. Put the backbone's BN layers in eval mode if a fully
            # frozen feature extractor is intended.
            for name, param in self.backbone.named_parameters():
                param.requires_grad = name.startswith("fc.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of width ``num_classes`` for the input batch.

        ``x`` is passed straight to the torchvision ResNet-18. It is presumably
        a float image batch shaped (N, 3, H, W), normalized to match the
        pretrained weights' transforms; confirm this against the data pipeline.
        """
        return self.backbone(x)