File size: 2,195 Bytes
9466fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
src/model.py

Builds the EfficientNet-B0 model with a custom 24-label output head.

Design decisions worth knowing for interviews:
- We replace only the final classifier layer, keeping the feature extractor intact.
- EfficientNet-B0's penultimate representation has 1280 channels (after global
  average pooling). A single Linear(1280, NUM_LABELS) projects to 24 logits.
- No sigmoid here — BCEWithLogitsLoss fuses sigmoid + loss in one numerically
  stable operation, so we keep raw logits until inference.
- Dropout(0.2) before the head is torchvision's EfficientNet-B0 default; we preserve it.
"""

import torch
import torch.nn as nn
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0

from src.config import NUM_LABELS


def build_model(num_labels: int = NUM_LABELS, pretrained: bool = True) -> nn.Module:
    """
    Build EfficientNet-B0 with a freshly initialised ``num_labels``-way head.

    Args:
        num_labels: Number of output logits produced by the new head.
        pretrained: If True, load torchvision's ImageNet-1k (V1) weights into
            the backbone; if False, the backbone is randomly initialised.

    Returns:
        The model. Its head is ``model.classifier[1]``, so the head's
        parameters are named ``classifier.1.weight`` / ``classifier.1.bias``
        and every other parameter belongs to the backbone. Code that wants
        different learning rates for head and backbone (presumably train.py —
        verify) can split ``named_parameters()`` on the ``classifier`` prefix.
    """
    model = efficientnet_b0(
        weights=EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
    )

    # torchvision's classifier is Sequential(Dropout(0.2), Linear(1280, 1000)).
    # Swap only the Linear so the stock Dropout stays in front of the new head.
    old_head = model.classifier[1]
    model.classifier[1] = nn.Linear(old_head.in_features, num_labels)
    return model


def freeze_backbone(model: nn.Module) -> None:
    """
    Put ``model`` into the head-only training configuration.

    Every parameter whose qualified name starts with ``"classifier"`` is made
    trainable, and every other parameter (the backbone) is frozen. The head is
    set explicitly rather than left untouched. This guarantees it is trainable
    even if the model was fully frozen beforehand, e.g. for evaluation or export.

    Note: this only toggles ``requires_grad``. If the backbone contains
    BatchNorm layers, their running statistics still update whenever the
    model is in ``train()`` mode. Put those modules in ``eval()`` if that
    matters for the head-only phase.

    Args:
        model: Model whose head lives under the ``classifier`` attribute.
    """
    for name, param in model.named_parameters():
        # Head parameters are named "classifier.<idx>.<weight|bias>".
        param.requires_grad = name.startswith("classifier")


def unfreeze_all(model: nn.Module) -> None:
    """
    Make every parameter of ``model`` trainable (full fine-tuning phase).

    Args:
        model: Any module; all of its parameters, recursively, get
            ``requires_grad=True``.
    """
    # Module.requires_grad_ walks model.parameters() and flips each in place.
    # Its return value (the module itself) is intentionally discarded.
    model.requires_grad_(True)


def count_params(model: nn.Module) -> dict:
    """
    Count the scalar parameters in ``model``.

    Args:
        model: Module to inspect.

    Returns:
        ``{"total": int, "trainable": int}``. ``total`` counts every element
        of every parameter yielded by ``model.parameters()``. ``trainable``
        counts only those whose parameter has ``requires_grad=True``.
    """
    total = trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return {"total": total, "trainable": trainable}