Spaces:
Runtime error
Runtime error
| """ | |
| EfficientNet-B7 Backbone for Print Quality Analysis. | |
| Dedicated backbone for the print quality classification head. | |
| Higher resolution and more parameters than ResNet50, specialized | |
| for detecting print patterns, color consistency, and artifacts. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
class EfficientNetB7Backbone(nn.Module):
    """
    EfficientNet-B7 feature extractor for print quality.

    - ImageNet-pretrained weights (torchvision ``EfficientNet_B7_Weights.DEFAULT``)
    - With ``freeze_early=True``, ``features`` modules 0-5 are frozen and
      modules 6-8 (the last two MBConv stages plus the final 1x1 conv) are
      fine-tuned.
    - Output: 2560-dim feature vector

    torchvision's ``efficientnet_b7().features`` is an ``nn.Sequential`` of 9
    modules: index 0 is the stem conv, 1-7 are the MBConv stages and 8 is the
    final conv projecting to 2560 channels.
    """

    # Index of the first ``features`` module left trainable when freeze_early=True.
    _FIRST_TRAINABLE_BLOCK = 6
    # Index of the first ``features`` module placed in the upper (second)
    # parameter group returned by get_layer_groups().
    _UPPER_GROUP_START = 7

    def __init__(self, pretrained: bool = True, freeze_early: bool = True):
        """
        Initialize EfficientNet-B7 backbone.

        Args:
            pretrained: Use ImageNet-pretrained weights.
            freeze_early: Freeze ``features`` modules 0-5 (default True).
        """
        super().__init__()
        weights = models.EfficientNet_B7_Weights.DEFAULT if pretrained else None
        efficientnet = models.efficientnet_b7(weights=weights)
        # Keep only the convolutional trunk and global pooling; the ImageNet
        # classifier head of the torchvision model is discarded.
        self.features = efficientnet.features
        self.avgpool = efficientnet.avgpool
        # Channel count produced by the final features module (1x1 conv).
        self.output_dim = 2560
        if freeze_early:
            self._freeze_early_blocks()

    def _freeze_early_blocks(self):
        """
        Freeze ``features`` modules 0-5 (stem + MBConv stages 1-5).

        Modules 6-8 (MBConv stages 6-7 and the final 1x1 conv) stay trainable.

        NOTE: only ``requires_grad`` is cleared. BatchNorm layers inside the
        frozen modules still update their running statistics whenever the
        backbone is in ``train()`` mode. Put those modules in ``eval()`` if
        they must stay completely fixed.
        """
        for block in self.features[: self._FIRST_TRAINABLE_BLOCK]:
            for param in block.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract print quality features.

        Args:
            x: Input tensor (B, 3, H, W). The pooling is adaptive, so any
               spatial size works. torchvision's pretrained B7 weights were
               trained on 600x600 inputs.

        Returns:
            Feature vector (B, 2560).
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def get_layer_groups(self):
        """
        Get trainable parameter groups split by depth for discriminative fine-tuning.

        Returns:
            List of 2 param lists ``[lower, upper]``:
              - lower: trainable params of ``features`` modules 0-6. With
                ``freeze_early=True`` only module 6 contributes, because
                modules 0-5 are frozen.
              - upper: trainable params of modules 7-8 (last MBConv stage
                and the final 1x1 conv).
            Frozen parameters are excluded. Every trainable backbone
            parameter appears in exactly one group, so an optimizer built
            from these groups updates the whole trainable backbone.
        """
        lower, upper = [], []
        for idx, block in enumerate(self.features):
            # Earlier modules (including unfrozen 0-5 when freeze_early=False)
            # go into the lower group so they are not silently left out.
            target = upper if idx >= self._UPPER_GROUP_START else lower
            target.extend(p for p in block.parameters() if p.requires_grad)
        return [lower, upper]

    def get_trainable_params(self):
        """
        Count trainable vs frozen parameters.

        Returns:
            Dict with integer element counts under "trainable", "frozen"
            and "total".
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {"trainable": trainable, "frozen": total - trainable, "total": total}