"""
ResNet50 Backbone for Card Authentication.

Provides the shared feature extractor used by 5 of 6 classification heads.
ImageNet-pretrained with optional partial freezing strategy.
"""

import torch
import torch.nn as nn
import torchvision.models as models


class ResNet50Backbone(nn.Module):
    """
    ResNet50 feature extractor.

    - ImageNet-pretrained weights
    - Freeze early layers (conv1 through layer2) by default
    - Fine-tune layer3 and layer4
    - Output: 2048-dim feature vector
    """

    def __init__(self, pretrained: bool = True, freeze_early: bool = True):
        """
        Initialize ResNet50 backbone.

        Args:
            pretrained: Use ImageNet-pretrained weights
            freeze_early: Freeze conv1 through layer2 (default True)
        """
        super().__init__()

        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        resnet = models.resnet50(weights=weights)

        # Keep the stem, residual stages, and avgpool; drop only the fc classifier
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool

        self.output_dim = 2048

        if freeze_early:
            self._freeze_early_layers()

    def _freeze_early_layers(self):
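        """
        Freeze conv1 through layer2 to preserve low-level features.

        Only gradients are disabled here. BatchNorm layers inside the frozen
        modules still update their running statistics in train() mode.
        """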
        for module in [self.conv1, self.bn1, self.layer1, self.layer2]:
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features from input image.

        Args:
            x: Input tensor (B, 3, H, W), typically 224x224; the adaptive
               avgpool yields a 2048-dim vector for other spatial sizes too

        Returns:
            Feature vector (B, 2048)
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x

    def get_layer_groups(self):
        """
        Get parameter groups split by layer depth for discriminative fine-tuning.

        Returns:
            List of 2 param lists: [layer3_params, layer4_params]
            (conv1 through layer2 are never included, even when
            freeze_early=False leaves them trainable.)
        """
        return [
            [p for p in self.layer3.parameters() if p.requires_grad],
            [p for p in self.layer4.parameters() if p.requires_grad],
        ]

    def get_trainable_params(self):
        """Return parameter counts as a dict with keys 'trainable', 'frozen', 'total'."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {"trainable": trainable, "frozen": total - trainable, "total": total}
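

# --- Illustrative sketch (not part of the original module) -------------------
# get_layer_groups() returns [layer3_params, layer4_params] for discriminative
# fine-tuning. One possible way to consume that list is to give each deeper
# group a larger learning rate. The helper name and the `base_lr` / `lr_mult`
# parameters below are hypothetical and chosen only for illustration; the
# returned dicts follow the per-parameter-group format accepted by torch.optim
# optimizers ({"params": ..., "lr": ...}).
def build_discriminative_param_groups(layer_groups, base_lr=1e-4, lr_mult=10.0):
    """Map depth-ordered parameter lists to optimizer param-group dicts.

    The first (shallowest) group gets base_lr; each deeper group gets the
    previous group's rate multiplied by lr_mult. Empty groups are skipped.
    """
    param_groups = []
    lr = base_lr
    for params in layer_groups:
        params = list(params)
        if params:
            param_groups.append({"params": params, "lr": lr})
        lr *= lr_mult
    return param_groups


# Hypothetical usage, assuming `backbone = ResNet50Backbone()`:
#     groups = build_discriminative_param_groups(backbone.get_layer_groups())
#     optimizer = torch.optim.AdamW(groups)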