# kishore-9's picture
# Add road scene classifier app
# 9466fff
"""
src/model.py
Builds the EfficientNet-B0 model with a custom 24-label output head.
Design decisions worth knowing for interviews:
- We replace only the final classifier layer, keeping the feature extractor intact.
- EfficientNet-B0's penultimate representation has 1280 channels (after global
average pooling). A single Linear(1280, NUM_LABELS) projects to 24 logits.
- No sigmoid here — BCEWithLogitsLoss fuses sigmoid + loss in one numerically
stable operation, so we keep raw logits until inference.
- Dropout(0.2) before the head is EfficientNet-B0's own convention; we preserve it.
"""
import torch
import torch.nn as nn
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0
from src.config import NUM_LABELS
def build_model(num_labels: int = NUM_LABELS, pretrained: bool = True) -> nn.Module:
    """
    Build an EfficientNet-B0 whose final Linear layer emits ``num_labels`` logits.

    Args:
        num_labels: Output width of the replacement head.
        pretrained: When true, the network is created with the
            ``IMAGENET1K_V1`` weights; otherwise ``weights=None`` is passed
            and no pretrained weights are requested.

    Returns:
        The torchvision EfficientNet-B0 module with ``classifier[1]`` swapped
        for a freshly initialised ``nn.Linear``. Parameters whose names start
        with ``classifier`` form the head; everything else is the backbone.
        (Presumably train.py uses that split for per-group learning rates;
        confirm against the training script.)
    """
    if pretrained:
        weights = EfficientNet_B0_Weights.IMAGENET1K_V1
    else:
        weights = None
    net = efficientnet_b0(weights=weights)

    # The stock classifier is Sequential(Dropout, Linear); index 0 (Dropout)
    # is left alone and only the Linear at index 1 is replaced.
    old_head = net.classifier[1]
    net.classifier[1] = nn.Linear(old_head.in_features, num_labels)
    return net
def freeze_backbone(model: nn.Module) -> None:
    """
    Disable gradients for every parameter outside the classifier head.

    A parameter counts as part of the head when its qualified name starts
    with ``"classifier"``. Head parameters are left exactly as they are:
    this function never switches a flag back on. The model is modified
    in place.
    """
    backbone_weights = (
        weight
        for param_name, weight in model.named_parameters()
        if not param_name.startswith("classifier")
    )
    for weight in backbone_weights:
        weight.requires_grad = False
def unfreeze_all(model: nn.Module) -> None:
    """
    Re-enable gradients on every parameter of ``model`` (full fine-tuning).

    The model is modified in place and nothing is returned.
    """
    # Module.requires_grad_ sets the flag on each parameter of the module;
    # its return value (the module itself) is intentionally discarded.
    model.requires_grad_(True)
def count_params(model: nn.Module) -> dict:
    """
    Count the scalar parameters of ``model``.

    Returns:
        A dict with two integer entries: ``"total"`` is the number of
        elements across all parameters, and ``"trainable"`` is the subset
        belonging to parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total += size
        if tensor.requires_grad:
            trainable += size
    return {"total": total, "trainable": trainable}