"""
src/model.py

Builds the EfficientNet-B0 model with a custom multi-label output head
(NUM_LABELS outputs; 24 at the time of writing).

Design decisions worth knowing for interviews:
- We replace only the final classifier layer, keeping the feature extractor
  intact.
- EfficientNet-B0's penultimate representation has 1280 channels (after
  global average pooling). A single Linear(1280, NUM_LABELS) projects it to
  one logit per label.
- No sigmoid here — BCEWithLogitsLoss fuses sigmoid + loss in one
  numerically stable operation, so we keep raw logits until inference.
- Dropout(0.2) before the head is torchvision's EfficientNet-B0 convention;
  we preserve it.
"""

import torch
import torch.nn as nn
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0

from src.config import NUM_LABELS

# Parameter-name prefix of the classification head. torchvision's EfficientNet
# exposes the head as ``model.classifier`` = Sequential(Dropout, Linear), so the
# head's trainable tensors are named "classifier.1.weight" / "classifier.1.bias".
# The trailing dot avoids accidentally matching e.g. "classifier_extra.*".
_HEAD_PREFIX = "classifier."


def build_model(num_labels: int = NUM_LABELS, pretrained: bool = True) -> nn.Module:
    """
    Return EfficientNet-B0 with a freshly initialised ``num_labels``-way head.

    Args:
        num_labels: Number of output logits (one per label). Must be >= 1.
        pretrained: If True, load torchvision's ImageNet-1k (V1) weights for
            the backbone; otherwise the whole network is randomly initialised.

    Returns:
        The model; its forward pass emits raw logits of shape
        (batch, num_labels) — no sigmoid is applied.

    Raises:
        ValueError: If ``num_labels`` is less than 1.

    Parameters can be split by name for per-group learning rates (as
    train.py does): names starting with "classifier." form the "head"
    (the new Linear layer); everything else is the "backbone".
    """
    if num_labels < 1:
        raise ValueError(f"num_labels must be >= 1, got {num_labels}")

    weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
    model = efficientnet_b0(weights=weights)

    # EfficientNet-B0 classifier: Sequential(Dropout(0.2), Linear(1280, 1000)).
    # Keep the Dropout, replace only the Linear.
    in_features = model.classifier[1].in_features  # 1280
    model.classifier[1] = nn.Linear(in_features, num_labels)
    return model


def freeze_backbone(model: nn.Module) -> None:
    """
    Freeze every parameter except the classifier head (head-only training).

    Head parameters are explicitly marked trainable, so this is correct even
    if the model was fully frozen beforehand.

    NOTE: ``requires_grad = False`` only stops gradient updates. BatchNorm
    layers in the backbone still update their running statistics whenever
    the model is in ``train()`` mode; switch them to ``eval()`` if they must
    stay fixed as well.
    """
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(_HEAD_PREFIX)


def unfreeze_all(model: nn.Module) -> None:
    """Unfreeze all parameters (full fine-tuning phase)."""
    for param in model.parameters():
        param.requires_grad = True


def count_params(model: nn.Module) -> dict:
    """
    Count the model's parameters.

    Returns:
        ``{"total": <all parameter elements>, "trainable": <elements with
        requires_grad=True>}`` — counts are numbers of scalar elements, not
        numbers of tensors.
    """
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {"total": total, "trainable": trainable}