# kishore-9's picture
# Add road scene classifier app
# 9466fff
"""
src/train.py
Two-phase training loop:
Phase 1 (freeze) — backbone frozen, only head trains for `freeze_epochs`
Phase 2 (unfreeze) — entire network trains for `unfreeze_epochs` at lower lr
All runs are logged to MLflow. The best checkpoint (by val_macro_f1) is saved.
Usage:
python -m src.train --config configs/baseline.yaml
python -m src.train --config configs/baseline.yaml --epochs 1 # quick check
"""
import argparse
import logging
from pathlib import Path

import mlflow
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.config import LABELS, MLRUNS_DIR, NUM_LABELS, SEED
from src.dataset import BDDMultiLabelDataset, load_pos_weight
from src.model import build_model, count_params, freeze_backbone, unfreeze_all
from src.utils import get_device, load_yaml, set_seed
# Import-time root logger setup: INFO level, "LEVELNAME message" lines.
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
# Module logger shared by every function in this file.
log = logging.getLogger(name=__name__)
# ---------------------------------------------------------------------------
# One epoch helpers
# ---------------------------------------------------------------------------
def train_one_epoch(model, loader, criterion, optimizer, device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Args:
        model: network to train; put into ``train()`` mode here.
        loader: iterable of ``(imgs, labels)`` batches.
        criterion: loss callable on ``(logits, labels)``; assumed to return a
            batch-mean scalar (as ``nn.BCEWithLogitsLoss`` does with its default
            ``reduction="mean"``).
        optimizer: optimizer over the parameters that should be updated.
        device: device each batch is moved to.

    Returns:
        Sample-weighted mean of the per-batch losses over the samples actually seen.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for imgs, labels in tqdm(loader, desc=" train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        batch_size = len(imgs)
        # loss is a batch mean; re-weight by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * batch_size
        n_seen += batch_size
    if n_seen == 0:
        raise ValueError("train_one_epoch() received a loader that yielded no samples")
    # Divide by samples actually iterated, not len(loader.dataset): stays correct with
    # drop_last=True, custom samplers, or datasets that do not define __len__.
    return total_loss / n_seen
@torch.no_grad()
def evaluate(model, loader, criterion, device, threshold: float = 0.5) -> dict:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; put into ``eval()`` mode here (and left there).
        loader: iterable of ``(imgs, labels)`` batches. Labels are presumably a 0/1
            multi-hot matrix of shape (batch, num_labels); TODO confirm against the dataset.
        criterion: loss callable on ``(logits, labels)``, assumed to return a batch mean.
        device: device each batch is moved to.
        threshold: a label is predicted positive when ``sigmoid(logit) >= threshold``.

    Returns:
        ``{"loss", "micro_f1", "macro_f1"}``. ``loss`` is the sample-weighted mean loss.
        The F1 scores come from scikit-learn over all samples, with ``zero_division=0``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_seen = 0
    all_preds, all_targets = [], []
    for imgs, labels in tqdm(loader, desc=" eval ", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        batch_size = len(imgs)
        total_loss += loss.item() * batch_size
        n_seen += batch_size
        probs = torch.sigmoid(logits).cpu().numpy()
        all_preds.append((probs >= threshold).astype(int))
        all_targets.append(labels.cpu().numpy())
    if n_seen == 0:
        # np.vstack([]) would otherwise fail with an unrelated-looking error.
        raise ValueError("evaluate() received a loader that yielded no samples")
    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_targets)
    return {
        # Per-sample mean over samples actually seen (robust to drop_last / samplers).
        "loss": total_loss / n_seen,
        "micro_f1": float(f1_score(y_true, y_pred, average="micro", zero_division=0)),
        "macro_f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
    }
# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------
def _run_phase(
    phase: str,
    n_epochs: int,
    *,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    threshold: float,
    best_ckpt: Path,
    best_val_macro: float,
    global_step: int,
    scheduler=None,
) -> tuple[float, int]:
    """Train and validate for ``n_epochs`` epochs of one training phase.

    Each epoch is logged to the console and to the active MLflow run at ``global_step``.
    When val macro-F1 strictly improves on ``best_val_macro``, the model's
    ``state_dict`` is saved to ``best_ckpt``. ``scheduler``, if given, is stepped
    once per epoch after validation.

    Returns:
        The updated ``(best_val_macro, global_step)``.
    """
    for epoch in range(1, n_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = evaluate(model, val_loader, criterion, device, threshold)
        if scheduler is not None:
            scheduler.step()
        log.info(
            "Epoch %d/%d (%s) | train_loss=%.4f | val_loss=%.4f | val_micro_f1=%.4f | val_macro_f1=%.4f",
            epoch, n_epochs, phase, train_loss, val_metrics["loss"],
            val_metrics["micro_f1"], val_metrics["macro_f1"],
        )
        mlflow.log_metrics(
            {"train_loss": train_loss, "val_loss": val_metrics["loss"],
             "val_micro_f1": val_metrics["micro_f1"], "val_macro_f1": val_metrics["macro_f1"]},
            step=global_step,
        )
        global_step += 1
        if val_metrics["macro_f1"] > best_val_macro:
            best_val_macro = val_metrics["macro_f1"]
            torch.save(model.state_dict(), best_ckpt)
    return best_val_macro, global_step


def train(cfg: dict, override_epochs: int | None = None):
    """Run the two-phase job: frozen-backbone head training, then a full fine-tune.

    Args:
        cfg: experiment config.
            Required keys: ``batch_size``, ``backbone``, ``freeze_epochs``,
            ``unfreeze_epochs``, ``lr_head``, ``lr_finetune``, ``weight_decay``,
            ``run_name``, ``experiment_name``.
            Optional keys and defaults: ``num_workers`` (0), ``use_pos_weight``
            (False), ``threshold`` (0.5), ``save_dir`` ("experiments/checkpoints").
        override_epochs: quick-check switch. When not ``None``, phase 1 runs
            ``min(1, freeze_epochs)`` epochs and phase 2 runs exactly 1 epoch. The
            numeric value is not otherwise used.

    Side effects:
        - Saves the best ``state_dict`` (by val macro-F1) to
          ``<save_dir>/<run_name>_best.pt``.
        - Logs params, per-epoch metrics, ``config.py`` and the checkpoint to MLflow.
    """
    set_seed(SEED)
    device = get_device()
    log.info("Device: %s", device)
    if override_epochs is not None and override_epochs != 1:
        log.warning(
            "override_epochs=%d is only a quick-check flag; running at most 1 epoch per phase",
            override_epochs,
        )

    # --- data ---
    train_ds = BDDMultiLabelDataset("train")
    val_ds = BDDMultiLabelDataset("val")
    num_workers = cfg.get("num_workers", 0)
    # Pinned host memory only benefits host->CUDA copies; skip it on CPU/MPS devices.
    pin_memory = str(device).startswith("cuda")
    train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"] * 2, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    # --- model ---
    # NOTE(review): build_model() receives no config, so cfg["backbone"] is only logged to
    # MLflow below; confirm it matches the backbone build_model actually constructs.
    model = build_model().to(device)
    params = count_params(model)
    log.info("Params — total: %s trainable: %s", f"{params['total']:,}", f"{params['trainable']:,}")

    # --- loss ---
    if cfg.get("use_pos_weight", False):
        pos_weight = load_pos_weight(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        log.info("Using pos_weight for class imbalance")
    else:
        criterion = nn.BCEWithLogitsLoss()

    threshold = cfg.get("threshold", 0.5)
    save_dir = Path(cfg.get("save_dir", "experiments/checkpoints"))
    save_dir.mkdir(parents=True, exist_ok=True)
    best_ckpt = save_dir / f"{cfg['run_name']}_best.pt"

    # Everything _run_phase needs besides the per-phase optimizer/scheduler.
    phase_ctx = dict(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        device=device,
        threshold=threshold,
        best_ckpt=best_ckpt,
    )

    # --- MLflow ---
    mlflow.set_tracking_uri(str(MLRUNS_DIR))
    mlflow.set_experiment(cfg["experiment_name"])
    with mlflow.start_run(run_name=cfg["run_name"]):
        mlflow.log_params({
            "backbone": cfg["backbone"],
            "freeze_epochs": cfg["freeze_epochs"],
            "unfreeze_epochs": cfg["unfreeze_epochs"],
            "batch_size": cfg["batch_size"],
            "lr_head": cfg["lr_head"],
            "lr_finetune": cfg["lr_finetune"],
            "weight_decay": cfg["weight_decay"],
            "use_pos_weight": cfg.get("use_pos_weight", False),
            "threshold": threshold,
            "num_labels": NUM_LABELS,
        })
        # Resolve next to this module (src/) instead of relying on the current working directory.
        mlflow.log_artifact(str(Path(__file__).resolve().parent / "config.py"))

        # Start at -inf (not 0.0) so the first evaluated epoch always writes a checkpoint.
        # With 0.0 and a strict ">", a run whose macro-F1 never rose above 0 saved nothing.
        # The final log_artifact would then fail on a missing file, or upload a stale
        # checkpoint left by an earlier run with the same run_name.
        best_val_macro = float("-inf")
        global_step = 0

        # ---------------------------------------------------------------
        # Phase 1: freeze backbone, train head only
        # ---------------------------------------------------------------
        freeze_epochs = cfg["freeze_epochs"] if override_epochs is None else min(1, cfg["freeze_epochs"])
        freeze_backbone(model)
        # NOTE(review): plain Adam applies weight_decay as coupled L2, while phase 2 uses
        # AdamW (decoupled decay). Confirm this difference is intended.
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg["lr_head"],
            weight_decay=cfg["weight_decay"],
        )
        log.info("Phase 1: frozen backbone (%d epochs)", freeze_epochs)
        best_val_macro, global_step = _run_phase(
            "freeze", freeze_epochs, optimizer=optimizer,
            best_val_macro=best_val_macro, global_step=global_step, **phase_ctx,
        )

        # ---------------------------------------------------------------
        # Phase 2: unfreeze everything, fine-tune at lower lr
        # ---------------------------------------------------------------
        unfreeze_epochs = cfg["unfreeze_epochs"] if override_epochs is None else 1
        unfreeze_all(model)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg["lr_finetune"],
            weight_decay=cfg["weight_decay"],
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=unfreeze_epochs, eta_min=1e-6)
        log.info("Phase 2: full fine-tune (%d epochs)", unfreeze_epochs)
        best_val_macro, global_step = _run_phase(
            "unfreeze", unfreeze_epochs, optimizer=optimizer, scheduler=scheduler,
            best_val_macro=best_val_macro, global_step=global_step, **phase_ctx,
        )

        # ---------------------------------------------------------------
        # Log final best
        # ---------------------------------------------------------------
        if global_step == 0:
            # No epoch ran, so no checkpoint was written by this run; don't log a
            # meaningless -inf metric or a missing/stale file.
            log.warning(
                "No epochs were run (freeze_epochs=%d, unfreeze_epochs=%d); no checkpoint saved",
                freeze_epochs, unfreeze_epochs,
            )
            return
        mlflow.log_metric("best_val_macro_f1", best_val_macro)
        mlflow.log_artifact(str(best_ckpt))
        log.info("Best val_macro_f1: %.4f | checkpoint: %s", best_val_macro, best_ckpt)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: python -m src.train --config <yaml> [--epochs N]
    cli = argparse.ArgumentParser(description="Train multi-label road scene model")
    cli.add_argument("--config", required=True, help="Path to YAML config file")
    cli.add_argument(
        "--epochs",
        default=None,
        type=int,
        help="Override both freeze_epochs and unfreeze_epochs to 1 (quick sanity check)",
    )
    cli_args = cli.parse_args()
    train(load_yaml(cli_args.config), override_epochs=cli_args.epochs)