tcg-space / Code /Model /src /dl /efficientnet.py
github-actions[bot]
deploy: backend bundle from 9c864b98f64c05462a27b71841ae97fb4451e449
c61ba70
"""
EfficientNet-B7 Backbone for Print Quality Analysis.
Dedicated backbone for the print quality classification head.
Higher resolution and more parameters than ResNet50, specialized
for detecting print patterns, color consistency, and artifacts.
"""
import torch
import torch.nn as nn
import torchvision.models as models
class EfficientNetB7Backbone(nn.Module):
    """
    EfficientNet-B7 feature extractor for print quality.

    - ImageNet-pretrained weights (when ``pretrained=True``)
    - Fine-tune only the late feature blocks; earlier blocks can be frozen
    - Output: 2560-dim feature vector (recorded in ``output_dim``)

    Blocks are the direct children of ``self.features``. Children with an
    index below ``NUM_FROZEN_BLOCKS`` are the "early" blocks; the child at
    that index and every later child stay trainable.
    NOTE(review): torchvision's EfficientNet ``features`` is expected to
    hold 9 children (stem conv, 7 MBConv stages, head conv; indices 0-8) --
    confirm against the installed torchvision version.
    """

    # Single source of truth for the freeze boundary used by both
    # ``_freeze_early_blocks`` and ``get_layer_groups``.
    NUM_FROZEN_BLOCKS = 6

    def __init__(self, pretrained: bool = True, freeze_early: bool = True):
        """
        Initialize EfficientNet-B7 backbone.

        Args:
            pretrained: Use ImageNet-pretrained weights
            freeze_early: Freeze early blocks (default True)
        """
        super().__init__()
        weights = models.EfficientNet_B7_Weights.DEFAULT if pretrained else None
        efficientnet = models.efficientnet_b7(weights=weights)

        # Keep only the convolutional trunk and the pooling layer; the
        # classification head of the torchvision model is discarded.
        self.features = efficientnet.features
        self.avgpool = efficientnet.avgpool
        self.output_dim = 2560

        if freeze_early:
            self._freeze_early_blocks()

    def _freeze_early_blocks(self):
        """
        Freeze early feature blocks.

        Every child of ``self.features`` with index below
        ``NUM_FROZEN_BLOCKS`` (0-5) has its parameters excluded from
        gradient updates and is switched to eval mode.
        """
        for index, block in enumerate(self.features):
            if index >= self.NUM_FROZEN_BLOCKS:
                break
            for param in block.parameters():
                param.requires_grad = False
            # requires_grad=False alone does not stop normalization layers
            # from updating their running statistics in train mode, so the
            # block is also put in eval mode.
            block.eval()

    def _keep_frozen_blocks_in_eval(self):
        """Put every fully frozen feature block (no trainable params) in eval mode."""
        for block in self.features:
            params = list(block.parameters())
            if params and not any(p.requires_grad for p in params):
                block.eval()

    def train(self, mode: bool = True):
        """
        Set training mode while keeping fully frozen blocks in eval mode.

        Args:
            mode: True for training mode, False for evaluation mode

        Returns:
            self
        """
        super().train(mode)
        if mode:
            # super().train(True) flips every submodule back to train mode;
            # undo that for blocks whose parameters are all frozen.
            self._keep_frozen_blocks_in_eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract print quality features.

        Args:
            x: Input image batch (B, 3, H, W). The original documentation
                quotes 224x224 inputs; nothing in this method enforces a
                particular spatial size.

        Returns:
            Feature tensor flattened from dim 1 onward; (B, 2560) assuming
            the pooling layer reduces the spatial dims to 1x1 -- TODO confirm.
        """
        x = self.features(x)
        x = self.avgpool(x)
        return torch.flatten(x, 1)

    def get_layer_groups(self) -> list:
        """
        Get parameter groups split by block depth for discriminative fine-tuning.

        Returns:
            List of 2 param lists: ``[lower_params, upper_params]``.
            ``lower_params`` holds the trainable parameters of block 6 plus
            any trainable parameters of blocks 0-5 (empty contribution when
            those blocks are frozen, i.e. the default). ``upper_params``
            holds the trainable parameters of block 7 and later.
        """
        boundary = self.NUM_FROZEN_BLOCKS
        groups = [[], []]
        for index, block in enumerate(self.features):
            # Early blocks are folded into the lower group so that, when
            # they were left trainable (freeze_early=False), they are not
            # silently missing from an optimizer built from these groups.
            target = groups[0] if index <= boundary else groups[1]
            target.extend(p for p in block.parameters() if p.requires_grad)
        return groups

    def get_trainable_params(self) -> dict:
        """
        Get count of trainable vs frozen parameters.

        Returns:
            Dict with integer element counts under the keys
            ``"trainable"``, ``"frozen"`` and ``"total"``.
        """
        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return {"trainable": trainable, "frozen": total - trainable, "total": total}