Spaces:
Starting
Starting
| """ | |
| ResNet50 Backbone for Card Authentication. | |
| Provides the shared feature extractor used by 5 of 6 classification heads. | |
| ImageNet-pretrained with optional partial freezing strategy. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
class ResNet50Backbone(nn.Module):
    """
    ResNet50 feature extractor.

    - Optionally loads ImageNet-pretrained weights from torchvision
    - Freezes the early stages (conv1, bn1, layer1, layer2) by default
    - Leaves layer3 and layer4 trainable for fine-tuning
    - Output: 2048-dim feature vector per image

    When the early stages are frozen they are also held in eval mode, so
    their BatchNorm running statistics stay fixed during training instead
    of drifting with every forward pass.
    """

    def __init__(self, pretrained: bool = True, freeze_early: bool = True):
        """
        Initialize ResNet50 backbone.

        Args:
            pretrained: Use ImageNet-pretrained weights
            freeze_early: Freeze conv1 through layer2 (default True)
        """
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        resnet = models.resnet50(weights=weights)

        # Keep everything up to and including global average pooling;
        # only the final fully-connected classifier (fc) is dropped.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool
        self.output_dim = 2048

        # Plain Python flag (not part of state_dict) read by train().
        self._early_frozen = False
        if freeze_early:
            self._freeze_early_layers()

    def _early_modules(self) -> list:
        """Return the modules that make up the freezable early stages."""
        return [self.conv1, self.bn1, self.layer1, self.layer2]

    def _freeze_early_layers(self) -> None:
        """Freeze conv1 through layer2 to preserve low-level features."""
        for module in self._early_modules():
            # Disabling gradients alone is not enough: BatchNorm layers in
            # train mode keep updating running_mean / running_var. Eval
            # mode pins those statistics as well.
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
        self._early_frozen = True

    def train(self, mode: bool = True) -> "ResNet50Backbone":
        """
        Set training mode while keeping frozen early stages in eval mode.

        Args:
            mode: True for training mode, False for evaluation mode

        Returns:
            self
        """
        super().train(mode)
        # nn.Module.train() switches every child to train mode; undo that
        # for the frozen stages. getattr() tolerates instances that were
        # created without running __init__.
        if mode and getattr(self, "_early_frozen", False):
            for module in self._early_modules():
                module.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features from input image.

        Args:
            x: Input tensor (B, 3, 224, 224)

        Returns:
            Feature vector (B, 2048)
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Collapse the pooled (B, C, 1, 1) map to (B, C).
        x = torch.flatten(x, 1)
        return x

    def get_layer_groups(self) -> list:
        """
        Get parameter groups split by layer depth for discriminative fine-tuning.

        Returns:
            List of 2 param lists: [layer3_params, layer4_params]

        Note:
            conv1 through layer2 are never included, even when the
            backbone was built with freeze_early=False. In that case
            their parameters must be handed to the optimizer separately.
        """
        return [
            [p for p in self.layer3.parameters() if p.requires_grad],
            [p for p in self.layer4.parameters() if p.requires_grad],
        ]

    def get_trainable_params(self) -> dict:
        """
        Get count of trainable vs frozen parameters.

        Returns:
            Dict with integer element counts under the keys
            "trainable", "frozen" and "total".
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {"trainable": trainable, "frozen": total - trainable, "total": total}