""" src/train.py Two-phase training loop: Phase 1 (freeze) — backbone frozen, only head trains for `freeze_epochs` Phase 2 (unfreeze) — entire network trains for `unfreeze_epochs` at lower lr All runs are logged to MLflow. The best checkpoint (by val_macro_f1) is saved. Usage: python -m src.train --config configs/baseline.yaml python -m src.train --config configs/baseline.yaml --epochs 1 # quick check """ import argparse import logging from pathlib import Path import mlflow import torch import torch.nn as nn from sklearn.metrics import f1_score from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader from tqdm import tqdm from src.config import LABELS, MLRUNS_DIR, NUM_LABELS, SEED from src.dataset import BDDMultiLabelDataset, load_pos_weight from src.model import build_model, count_params, freeze_backbone, unfreeze_all from src.utils import get_device, load_yaml, set_seed logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") log = logging.getLogger(__name__) # --------------------------------------------------------------------------- # One epoch helpers # --------------------------------------------------------------------------- def train_one_epoch(model, loader, criterion, optimizer, device) -> float: model.train() total_loss = 0.0 for imgs, labels in tqdm(loader, desc=" train", leave=False): imgs, labels = imgs.to(device), labels.to(device) optimizer.zero_grad() logits = model(imgs) loss = criterion(logits, labels) loss.backward() optimizer.step() total_loss += loss.item() * len(imgs) return total_loss / len(loader.dataset) @torch.no_grad() def evaluate(model, loader, criterion, device, threshold: float = 0.5) -> dict: model.eval() total_loss = 0.0 all_preds, all_targets = [], [] for imgs, labels in tqdm(loader, desc=" eval ", leave=False): imgs, labels = imgs.to(device), labels.to(device) logits = model(imgs) loss = criterion(logits, labels) total_loss += loss.item() * len(imgs) probs = torch.sigmoid(logits).cpu().numpy() preds = (probs >= threshold).astype(int) all_preds.append(preds) all_targets.append(labels.cpu().numpy()) import numpy as np all_preds = np.vstack(all_preds) all_targets = np.vstack(all_targets) micro_f1 = f1_score(all_targets, all_preds, average="micro", zero_division=0) macro_f1 = f1_score(all_targets, all_preds, average="macro", zero_division=0) return { "loss": total_loss / len(loader.dataset), "micro_f1": micro_f1, "macro_f1": macro_f1, } # --------------------------------------------------------------------------- # Main training loop # --------------------------------------------------------------------------- def train(cfg: dict, override_epochs: int | None = None): set_seed(SEED) device = get_device() log.info("Device: %s", device) # --- data --- train_ds = BDDMultiLabelDataset("train") val_ds = BDDMultiLabelDataset("val") num_workers = cfg.get("num_workers", 0) train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True, num_workers=num_workers, pin_memory=True) val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"] * 2, shuffle=False, num_workers=num_workers, pin_memory=True) # --- model --- model = build_model().to(device) params = count_params(model) log.info("Params — total: %s trainable: %s", f"{params['total']:,}", f"{params['trainable']:,}") # --- loss --- if cfg.get("use_pos_weight", False): pos_weight = load_pos_weight(device) criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) log.info("Using pos_weight for class imbalance") else: criterion = nn.BCEWithLogitsLoss() threshold = cfg.get("threshold", 0.5) save_dir = Path(cfg.get("save_dir", "experiments/checkpoints")) save_dir.mkdir(parents=True, exist_ok=True) best_ckpt = save_dir / f"{cfg['run_name']}_best.pt" # --- MLflow --- mlflow.set_tracking_uri(str(MLRUNS_DIR)) mlflow.set_experiment(cfg["experiment_name"]) with mlflow.start_run(run_name=cfg["run_name"]): mlflow.log_params({ "backbone": cfg["backbone"], "freeze_epochs": cfg["freeze_epochs"], "unfreeze_epochs": cfg["unfreeze_epochs"], "batch_size": cfg["batch_size"], "lr_head": cfg["lr_head"], "lr_finetune": cfg["lr_finetune"], "weight_decay": cfg["weight_decay"], "use_pos_weight": cfg.get("use_pos_weight", False), "threshold": threshold, "num_labels": NUM_LABELS, }) mlflow.log_artifact(str(Path("src/config.py"))) best_val_macro = 0.0 global_step = 0 # --------------------------------------------------------------- # Phase 1: freeze backbone, train head only # --------------------------------------------------------------- freeze_epochs = cfg["freeze_epochs"] if override_epochs is None else min(1, cfg["freeze_epochs"]) freeze_backbone(model) optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["lr_head"], weight_decay=cfg["weight_decay"], ) log.info("Phase 1: frozen backbone (%d epochs)", freeze_epochs) for epoch in range(1, freeze_epochs + 1): train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device) val_metrics = evaluate(model, val_loader, criterion, device, threshold) log.info( "Epoch %d/%d (freeze) | train_loss=%.4f | val_loss=%.4f | val_micro_f1=%.4f | val_macro_f1=%.4f", epoch, freeze_epochs, train_loss, val_metrics["loss"], val_metrics["micro_f1"], val_metrics["macro_f1"], ) mlflow.log_metrics( {"train_loss": train_loss, "val_loss": val_metrics["loss"], "val_micro_f1": val_metrics["micro_f1"], "val_macro_f1": val_metrics["macro_f1"]}, step=global_step, ) global_step += 1 if val_metrics["macro_f1"] > best_val_macro: best_val_macro = val_metrics["macro_f1"] torch.save(model.state_dict(), best_ckpt) # --------------------------------------------------------------- # Phase 2: unfreeze everything, fine-tune at lower lr # --------------------------------------------------------------- unfreeze_epochs = cfg["unfreeze_epochs"] if override_epochs is None else 1 unfreeze_all(model) optimizer = torch.optim.AdamW( model.parameters(), lr=cfg["lr_finetune"], weight_decay=cfg["weight_decay"], ) scheduler = CosineAnnealingLR(optimizer, T_max=unfreeze_epochs, eta_min=1e-6) log.info("Phase 2: full fine-tune (%d epochs)", unfreeze_epochs) for epoch in range(1, unfreeze_epochs + 1): train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device) val_metrics = evaluate(model, val_loader, criterion, device, threshold) scheduler.step() log.info( "Epoch %d/%d (unfreeze) | train_loss=%.4f | val_loss=%.4f | val_micro_f1=%.4f | val_macro_f1=%.4f", epoch, unfreeze_epochs, train_loss, val_metrics["loss"], val_metrics["micro_f1"], val_metrics["macro_f1"], ) mlflow.log_metrics( {"train_loss": train_loss, "val_loss": val_metrics["loss"], "val_micro_f1": val_metrics["micro_f1"], "val_macro_f1": val_metrics["macro_f1"]}, step=global_step, ) global_step += 1 if val_metrics["macro_f1"] > best_val_macro: best_val_macro = val_metrics["macro_f1"] torch.save(model.state_dict(), best_ckpt) # --------------------------------------------------------------- # Log final best # --------------------------------------------------------------- mlflow.log_metric("best_val_macro_f1", best_val_macro) mlflow.log_artifact(str(best_ckpt)) log.info("Best val_macro_f1: %.4f | checkpoint: %s", best_val_macro, best_ckpt) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train multi-label road scene model") parser.add_argument("--config", required=True, help="Path to YAML config file") parser.add_argument( "--epochs", type=int, default=None, help="Override both freeze_epochs and unfreeze_epochs to 1 (quick sanity check)" ) args = parser.parse_args() cfg = load_yaml(args.config) train(cfg, override_epochs=args.epochs)