tcg-space / Code /Model /src /dl /backbone.py
github-actions[bot]
deploy: backend bundle from 9c864b98f64c05462a27b71841ae97fb4451e449
c61ba70
"""
ResNet50 Backbone for Card Authentication.
Provides the shared feature extractor used by 5 of 6 classification heads.
ImageNet-pretrained with optional partial freezing strategy.
"""
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet50Backbone(nn.Module):
    """
    ResNet50 feature extractor.

    - ImageNet-pretrained weights (optional)
    - Freeze early layers (conv1 through layer2) by default
    - Fine-tune layer3 and layer4
    - Output: pooled, flattened feature vector of size ``output_dim``
      (2048 for ResNet50)
    """

    def __init__(self, pretrained: bool = True, freeze_early: bool = True):
        """
        Initialize ResNet50 backbone.

        Args:
            pretrained: Use ImageNet-pretrained weights
            freeze_early: Freeze conv1 through layer2 (default True)
        """
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        resnet = models.resnet50(weights=weights)

        # Keep the convolutional trunk and the global average pool; only the
        # final fully-connected classifier (resnet.fc) is dropped.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool

        # Width of the pooled features equals the input width of the dropped
        # fc head, so read it from there instead of hard-coding it.
        self.output_dim = resnet.fc.in_features

        # Plain bool (not a buffer/parameter), so state_dict is unaffected.
        self._early_frozen = False
        if freeze_early:
            self._freeze_early_layers()

    def _early_modules(self):
        """Return the modules covered by the early-layer freeze."""
        return [self.conv1, self.bn1, self.layer1, self.layer2]

    def _freeze_early_layers(self):
        """
        Freeze conv1 through layer2 to preserve low-level features.

        Disables gradients for their parameters and switches the modules to
        eval mode, so BatchNorm running statistics stop updating as well.
        """
        for module in self._early_modules():
            for param in module.parameters():
                param.requires_grad = False
            # requires_grad=False alone does not stop BatchNorm from updating
            # running_mean / running_var during training-mode forward passes.
            module.eval()
        self._early_frozen = True

    def train(self, mode: bool = True):
        """
        Set training mode while keeping frozen early layers in eval mode.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self
        """
        super().train(mode)
        # getattr guards instances created without running __init__
        # (e.g. restored objects that predate this attribute).
        if mode and getattr(self, "_early_frozen", False):
            for module in self._early_modules():
                module.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features from input image.

        Args:
            x: Input tensor (B, 3, 224, 224)

        Returns:
            Feature vector (B, 2048)
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Collapse everything after the batch dimension into one feature axis.
        x = torch.flatten(x, 1)
        return x

    def get_layer_groups(self) -> list:
        """
        Get parameter groups split by layer depth for discriminative fine-tuning.

        Returns:
            List of 2 param lists: [layer3_params, layer4_params], each holding
            only parameters with requires_grad=True.

        Note:
            conv1 through layer2 are never included, even when the backbone
            was built with freeze_early=False; callers that want to train
            them must add those parameters to the optimizer themselves.
        """
        return [
            [p for p in self.layer3.parameters() if p.requires_grad],
            [p for p in self.layer4.parameters() if p.requires_grad],
        ]

    def get_trainable_params(self) -> dict:
        """
        Get count of trainable vs frozen parameters.

        Returns:
            Dict with element counts under "trainable", "frozen" and "total".
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {"trainable": trainable, "frozen": total - trainable, "total": total}