""" EfficientNet-B7 Backbone for Print Quality Analysis. Dedicated backbone for the print quality classification head. Higher resolution and more parameters than ResNet50, specialized for detecting print patterns, color consistency, and artifacts. """ import torch import torch.nn as nn import torchvision.models as models class EfficientNetB7Backbone(nn.Module): """ EfficientNet-B7 feature extractor for print quality. - ImageNet-pretrained weights - Fine-tune last 2 blocks only (freeze earlier blocks) - Output: 2560-dim feature vector """ def __init__(self, pretrained: bool = True, freeze_early: bool = True): """ Initialize EfficientNet-B7 backbone. Args: pretrained: Use ImageNet-pretrained weights freeze_early: Freeze early blocks (default True) """ super().__init__() weights = models.EfficientNet_B7_Weights.DEFAULT if pretrained else None efficientnet = models.efficientnet_b7(weights=weights) # Extract features and pooling layers self.features = efficientnet.features self.avgpool = efficientnet.avgpool self.output_dim = 2560 if freeze_early: self._freeze_early_blocks() def _freeze_early_blocks(self): """ Freeze early feature blocks. EfficientNet-B7 has 8 blocks (indices 0-7). Freeze blocks 0-5, fine-tune blocks 6-7. """ for i, block in enumerate(self.features): if i < 6: for param in block.parameters(): param.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: """ Extract print quality features. Args: x: Input tensor (B, 3, 224, 224) Returns: Feature vector (B, 2560) """ x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) return x def get_layer_groups(self): """ Get parameter groups split by block depth for discriminative fine-tuning. Returns: List of 2 param lists: [block6_params, block7_params] (Earlier blocks 0-5 are frozen and excluded.) """ groups = [[], []] for i, block in enumerate(self.features): if i == 6: groups[0].extend([p for p in block.parameters() if p.requires_grad]) elif i >= 7: groups[1].extend([p for p in block.parameters() if p.requires_grad]) return groups def get_trainable_params(self): """Get count of trainable vs frozen parameters.""" trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) total = sum(p.numel() for p in self.parameters()) return {"trainable": trainable, "frozen": total - trainable, "total": total}