""" ResNet50 Backbone for Card Authentication. Provides the shared feature extractor used by 5 of 6 classification heads. ImageNet-pretrained with optional partial freezing strategy. """ import torch import torch.nn as nn import torchvision.models as models class ResNet50Backbone(nn.Module): """ ResNet50 feature extractor. - ImageNet-pretrained weights - Freeze early layers (conv1 through layer2) by default - Fine-tune layer3 and layer4 - Output: 2048-dim feature vector """ def __init__(self, pretrained: bool = True, freeze_early: bool = True): """ Initialize ResNet50 backbone. Args: pretrained: Use ImageNet-pretrained weights freeze_early: Freeze conv1 through layer2 (default True) """ super().__init__() weights = models.ResNet50_Weights.DEFAULT if pretrained else None resnet = models.resnet50(weights=weights) # Remove the classification head (avgpool + fc) self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 self.avgpool = resnet.avgpool self.output_dim = 2048 if freeze_early: self._freeze_early_layers() def _freeze_early_layers(self): """Freeze conv1 through layer2 to preserve low-level features.""" for module in [self.conv1, self.bn1, self.layer1, self.layer2]: for param in module.parameters(): param.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: """ Extract features from input image. Args: x: Input tensor (B, 3, 224, 224) Returns: Feature vector (B, 2048) """ x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) return x def get_layer_groups(self): """ Get parameter groups split by layer depth for discriminative fine-tuning. Returns: List of 2 param lists: [layer3_params, layer4_params] (Earlier layers are frozen and excluded.) """ return [ [p for p in self.layer3.parameters() if p.requires_grad], [p for p in self.layer4.parameters() if p.requires_grad], ] def get_trainable_params(self): """Get count of trainable vs frozen parameters.""" trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) total = sum(p.numel() for p in self.parameters()) return {"trainable": trainable, "frozen": total - trainable, "total": total}