"""
ResNet50 Backbone for Card Authentication.
Provides the shared feature extractor used by 5 of 6 classification heads.
ImageNet-pretrained with optional partial freezing strategy.
"""
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet50Backbone(nn.Module):
    """
    ResNet50 feature extractor.

    - ImageNet-pretrained weights (optional)
    - Freeze early layers (conv1 through layer2) by default
    - Fine-tune layer3 and layer4
    - Output: ``output_dim``-dim feature vector (2048 for ResNet50)

    When the early layers are frozen, their parameters stop receiving
    gradients AND the modules are held in eval mode, so the BatchNorm
    running statistics inside them are not overwritten during fine-tuning.
    """

    def __init__(self, pretrained: bool = True, freeze_early: bool = True):
        """
        Initialize ResNet50 backbone.

        Args:
            pretrained: Use ImageNet-pretrained weights
            freeze_early: Freeze conv1 through layer2 (default True)
        """
        super().__init__()
        # Read by train(); flipped to True by _freeze_early_layers().
        self._early_frozen = False

        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        resnet = models.resnet50(weights=weights)

        # Drop only the final fc classifier. The global average pool is kept
        # because forward() uses it to reduce the feature map to a vector.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool

        # Width of the pooled features == input width of the discarded fc
        # layer (2048 for ResNet50).
        self.output_dim = resnet.fc.in_features

        if freeze_early:
            self._freeze_early_layers()

    def _early_modules(self):
        """Return the modules that make up the freezable early stage."""
        return [self.conv1, self.bn1, self.layer1, self.layer2]

    def _freeze_early_layers(self):
        """Freeze conv1 through layer2 to preserve low-level features."""
        for module in self._early_modules():
            for param in module.parameters():
                param.requires_grad = False
            # requires_grad=False alone does not stop BatchNorm from updating
            # its running mean/var in train mode; eval mode does.
            module.eval()
        self._early_frozen = True

    def train(self, mode: bool = True):
        """
        Set training mode, keeping frozen early layers in eval mode.

        Args:
            mode: True for training mode, False for evaluation mode

        Returns:
            self
        """
        super().train(mode)
        # super().train(True) just switched every child back to train mode;
        # undo that for the frozen stage so its BatchNorm stats stay fixed.
        if mode and getattr(self, "_early_frozen", False):
            for module in self._early_modules():
                module.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features from input image.

        Args:
            x: Input tensor (B, 3, H, W), typically H = W = 224

        Returns:
            Feature vector (B, output_dim)
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Collapse the pooled (B, C, 1, 1) map to (B, C).
        return torch.flatten(x, 1)

    def get_layer_groups(self):
        """
        Get parameter groups split by layer depth for discriminative fine-tuning.

        Returns:
            List of 2 param lists: [layer3_params, layer4_params]

        Note:
            Only layer3 and layer4 are ever returned. With the default
            ``freeze_early=True`` the earlier layers are frozen, so nothing
            is lost. With ``freeze_early=False`` the earlier layers are
            trainable but still NOT included here; callers must add them to
            the optimizer separately if they should be updated.
        """
        return [
            [p for p in self.layer3.parameters() if p.requires_grad],
            [p for p in self.layer4.parameters() if p.requires_grad],
        ]

    def get_trainable_params(self):
        """
        Get count of trainable vs frozen parameters.

        Returns:
            Dict with integer element counts under the keys
            "trainable", "frozen" and "total".
        """
        trainable = 0
        total = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return {"trainable": trainable, "frozen": total - trainable, "total": total}
|