File size: 2,886 Bytes
0ba6002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61ba70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ba6002
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
EfficientNet-B7 Backbone for Print Quality Analysis.

Dedicated backbone for the print quality classification head.
Higher resolution and more parameters than ResNet50, specialized
for detecting print patterns, color consistency, and artifacts.
"""

import torch
import torch.nn as nn
import torchvision.models as models


class EfficientNetB7Backbone(nn.Module):
    """
    EfficientNet-B7 feature extractor for print quality.

    - ImageNet-pretrained weights (torchvision ``EfficientNet_B7_Weights.DEFAULT``)
    - Fine-tune the last stages only: modules 0-5 of ``features`` are frozen,
      modules 6 onward stay trainable
    - BatchNorm layers inside frozen modules are kept in eval mode, so their
      pretrained running statistics are neither used nor overwritten with
      batch statistics during training
    - Output: 2560-dim feature vector
    """

    # Index of the first ``features`` module that stays trainable. Every module
    # before it is frozen by ``_freeze_early_blocks``. ``get_layer_groups``
    # uses the same boundary.
    _FIRST_TRAINABLE_BLOCK = 6

    # Normalization layers that must stay in eval mode inside frozen modules.
    # requires_grad=False alone does not stop their running-stat updates.
    _NORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    def __init__(self, pretrained: bool = True, freeze_early: bool = True):
        """
        Initialize EfficientNet-B7 backbone.

        Args:
            pretrained: Use ImageNet-pretrained weights
            freeze_early: Freeze early blocks (default True)
        """
        super().__init__()

        weights = models.EfficientNet_B7_Weights.DEFAULT if pretrained else None
        efficientnet = models.efficientnet_b7(weights=weights)

        # Keep only the convolutional trunk and global pooling; the ImageNet
        # classifier head is discarded.
        self.features = efficientnet.features
        self.avgpool = efficientnet.avgpool

        self.output_dim = 2560

        if freeze_early:
            self._freeze_early_blocks()

    def _freeze_early_blocks(self) -> None:
        """
        Freeze the early modules of ``self.features``.

        torchvision's EfficientNet ``features`` is a Sequential of 9 modules:
        the stem conv (0), seven MBConv stages (1-7) and a final 1x1 conv (8).
        Modules 0-5 get ``requires_grad=False``; modules 6-8 stay trainable.
        The current train/eval mode is then re-applied so the BatchNorm layers
        of the frozen modules switch to eval mode immediately.
        """
        for block in self.features[: self._FIRST_TRAINABLE_BLOCK]:
            for param in block.parameters():
                param.requires_grad = False
        # Re-apply the current mode. The ``train`` override puts frozen
        # BatchNorm layers in eval mode, even if the caller never calls .train().
        self.train(self.training)

    @staticmethod
    def _is_frozen(block: nn.Module) -> bool:
        """Return True if ``block`` has parameters and none of them require grad."""
        params = list(block.parameters())
        return bool(params) and not any(p.requires_grad for p in params)

    def train(self, mode: bool = True) -> "EfficientNetB7Backbone":
        """
        Set train/eval mode, keeping BatchNorm layers of frozen blocks in eval.

        ``requires_grad=False`` stops gradient updates but not BatchNorm
        running-statistic updates. In train mode, BatchNorm also normalizes
        with batch statistics instead of the pretrained running ones. Keeping
        those layers in eval mode makes frozen blocks behave as pretrained.
        Other layers (e.g. stochastic depth) follow ``mode`` as usual.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self
        """
        super().train(mode)
        if mode:
            for block in self.features:
                if not self._is_frozen(block):
                    continue
                for module in block.modules():
                    if isinstance(module, self._NORM_TYPES):
                        module.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract print quality features.

        Args:
            x: Input tensor (B, 3, H, W). The pooling layer comes from
               torchvision's model and is expected to be adaptive, so any
               spatial size should work. The pretrained B7 weights were
               trained at 600x600; verify the resolution used by the pipeline.

        Returns:
            Feature vector (B, 2560)
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def get_layer_groups(self):
        """
        Get parameter groups split by block depth for discriminative fine-tuning.

        Returns:
            List of 2 param lists: [block6_params, block7_and_later_params]
            (block7_and_later includes the final 1x1 conv, module 8). Only
            parameters with ``requires_grad=True`` are included.

        Note:
            Modules 0-5 are always excluded, even when the backbone was built
            with ``freeze_early=False``. An optimizer built only from these
            groups will not update them in that case.
        """
        first = self._FIRST_TRAINABLE_BLOCK
        groups = [[], []]
        for i, block in enumerate(self.features):
            if i < first:
                continue
            target = groups[0] if i == first else groups[1]
            target.extend(p for p in block.parameters() if p.requires_grad)
        return groups

    def get_trainable_params(self):
        """
        Count trainable vs frozen parameters.

        Returns:
            Dict with keys "trainable", "frozen" and "total" (element counts).
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {"trainable": trainable, "frozen": total - trainable, "total": total}