| """ |
| src/train.py |
| |
| Two-phase training loop: |
| Phase 1 (freeze) — backbone frozen, only head trains for `freeze_epochs` |
| Phase 2 (unfreeze) — entire network trains for `unfreeze_epochs` at lower lr |
| |
| All runs are logged to MLflow. The best checkpoint (by val_macro_f1) is saved. |
| |
| Usage: |
| python -m src.train --config configs/baseline.yaml |
| python -m src.train --config configs/baseline.yaml --epochs 1 # quick check |
| """ |
|
|
import argparse
import logging
from pathlib import Path

import mlflow
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.config import LABELS, MLRUNS_DIR, NUM_LABELS, SEED
from src.dataset import BDDMultiLabelDataset, load_pos_weight
from src.model import build_model, count_params, freeze_backbone, unfreeze_all
from src.utils import get_device, load_yaml, set_seed
|
|
# Configure the root logger once at import: INFO and above, rendered as "LEVEL message".
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
# Module-scoped logger for this script's progress messages.
log = logging.getLogger(name=__name__)
|
|
|
|
| |
| |
| |
|
|
def train_one_epoch(model, loader, criterion, optimizer, device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    For every batch: move it to ``device``, reset gradients, compute
    ``criterion(model(imgs), labels)``, back-propagate and apply ``optimizer.step()``.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(imgs, labels)`` batches (a ``DataLoader`` in this file).
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer holding the parameters to update.
        device: device each batch is moved to before the forward pass.

    Returns:
        The loss averaged over every sample actually processed this epoch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for imgs, labels in tqdm(loader, desc=" train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = len(imgs)
        # Assumes `criterion` returns a batch mean (true for the default-reduction
        # BCEWithLogitsLoss built in train()), so re-weight by batch size before summing.
        total_loss += loss.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    # Divide by the samples actually seen, not len(loader.dataset): the two differ
    # with drop_last=True or a subsetting/distributed sampler.
    return total_loss / n_seen
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, criterion, device, threshold: float = 0.5) -> dict:
    """Evaluate ``model`` on ``loader`` without gradients; return loss and F1 scores.

    Sigmoid probabilities are binarised per label with ``probs >= threshold``. They are
    then scored against the targets with scikit-learn's ``f1_score``. Because
    ``zero_division=0``, a label whose F1 is undefined (no true and no predicted
    positives) contributes 0.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(imgs, labels)`` batches. Labels are presumably multi-hot
            ``(batch, num_labels)`` tensors; ``np.vstack`` below needs 2-D batches.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device each batch is moved to before the forward pass.
        threshold: probability cut-off for a positive prediction.

    Returns:
        ``{"loss": mean loss per sample seen, "micro_f1": float, "macro_f1": float}``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_seen = 0
    pred_batches, target_batches = [], []

    for imgs, labels in tqdm(loader, desc=" eval ", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        # Re-weight the (batch-mean) loss so the epoch value is a per-sample mean.
        total_loss += loss.item() * len(imgs)
        n_seen += len(imgs)

        probs = torch.sigmoid(logits).cpu().numpy()
        pred_batches.append((probs >= threshold).astype(int))
        target_batches.append(labels.cpu().numpy())

    # Fail with a clear message instead of np.vstack's "need at least one array".
    if n_seen == 0:
        raise ValueError("evaluate: loader yielded no batches")

    all_preds = np.vstack(pred_batches)
    all_targets = np.vstack(target_batches)

    micro_f1 = f1_score(all_targets, all_preds, average="micro", zero_division=0)
    macro_f1 = f1_score(all_targets, all_preds, average="macro", zero_division=0)

    return {
        # Divide by samples actually seen (matches train_one_epoch) rather than
        # len(loader.dataset), which is wrong under drop_last or a subsetting sampler.
        "loss": total_loss / n_seen,
        "micro_f1": micro_f1,
        "macro_f1": macro_f1,
    }
|
|
|
|
| |
| |
| |
|
|
def _make_loaders(cfg: dict, device) -> tuple[DataLoader, DataLoader]:
    """Build the shuffled train loader and the unshuffled, double-batch val loader."""
    num_workers = cfg.get("num_workers", 0)
    # Pinned host memory only speeds up host->CUDA copies; on CPU/MPS it just warns.
    # torch.device(...) accepts either a string or a torch.device from get_device().
    pin_memory = torch.device(device).type == "cuda"
    train_loader = DataLoader(
        BDDMultiLabelDataset("train"),
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        BDDMultiLabelDataset("val"),
        batch_size=cfg["batch_size"] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, val_loader


def _build_criterion(cfg: dict, device) -> nn.Module:
    """Return BCEWithLogitsLoss, optionally weighted by the stored per-label pos_weight."""
    if cfg.get("use_pos_weight", False):
        pos_weight = load_pos_weight(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        log.info("Using pos_weight for class imbalance")
        return criterion
    return nn.BCEWithLogitsLoss()


def _run_phase(
    phase: str,
    n_epochs: int,
    *,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    threshold: float,
    best_ckpt: Path,
    best_val_macro: float,
    global_step: int,
    scheduler=None,
) -> tuple[float, int]:
    """Train/evaluate for ``n_epochs``, log each epoch, and checkpoint on a new best macro-F1.

    Returns the updated ``(best_val_macro, global_step)``.
    """
    for epoch in range(1, n_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = evaluate(model, val_loader, criterion, device, threshold)
        if scheduler is not None:
            scheduler.step()  # one scheduler step per epoch

        log.info(
            "Epoch %d/%d (%s) | train_loss=%.4f | val_loss=%.4f | val_micro_f1=%.4f | val_macro_f1=%.4f",
            epoch, n_epochs, phase, train_loss, val_metrics["loss"],
            val_metrics["micro_f1"], val_metrics["macro_f1"],
        )
        mlflow.log_metrics(
            {"train_loss": train_loss, "val_loss": val_metrics["loss"],
             "val_micro_f1": val_metrics["micro_f1"], "val_macro_f1": val_metrics["macro_f1"]},
            step=global_step,
        )
        global_step += 1

        if val_metrics["macro_f1"] > best_val_macro:
            best_val_macro = val_metrics["macro_f1"]
            torch.save(model.state_dict(), best_ckpt)
    return best_val_macro, global_step


def train(cfg: dict, override_epochs: int | None = None):
    """Run the two-phase (frozen backbone, then full fine-tune) training described by ``cfg``.

    Everything is logged to MLflow under ``cfg["experiment_name"]`` / ``cfg["run_name"]``.
    The state dict with the best validation macro-F1 is saved to
    ``<save_dir>/<run_name>_best.pt`` and logged as an artifact.

    Args:
        cfg: parsed YAML config. Keys read here: batch_size, run_name, experiment_name,
            backbone, freeze_epochs, unfreeze_epochs, lr_head, lr_finetune, weight_decay.
            Optional keys: num_workers, use_pos_weight, threshold, save_dir.
        override_epochs: when not None, run a quick check of at most 1 frozen epoch and
            exactly 1 fine-tune epoch. The value itself is not used (see the ``--epochs``
            CLI help).
    """
    set_seed(SEED)
    device = get_device()
    log.info("Device: %s", device)

    train_loader, val_loader = _make_loaders(cfg, device)

    # NOTE(review): cfg["backbone"] is logged to MLflow but not passed to build_model();
    # confirm build_model() selects the same backbone.
    model = build_model().to(device)
    params = count_params(model)
    log.info("Params — total: %s trainable: %s", f"{params['total']:,}", f"{params['trainable']:,}")

    criterion = _build_criterion(cfg, device)

    threshold = cfg.get("threshold", 0.5)
    save_dir = Path(cfg.get("save_dir", "experiments/checkpoints"))
    save_dir.mkdir(parents=True, exist_ok=True)
    best_ckpt = save_dir / f"{cfg['run_name']}_best.pt"

    mlflow.set_tracking_uri(str(MLRUNS_DIR))
    mlflow.set_experiment(cfg["experiment_name"])

    with mlflow.start_run(run_name=cfg["run_name"]):
        mlflow.log_params({
            "backbone": cfg["backbone"],
            "freeze_epochs": cfg["freeze_epochs"],
            "unfreeze_epochs": cfg["unfreeze_epochs"],
            "batch_size": cfg["batch_size"],
            "lr_head": cfg["lr_head"],
            "lr_finetune": cfg["lr_finetune"],
            "weight_decay": cfg["weight_decay"],
            "use_pos_weight": cfg.get("use_pos_weight", False),
            "threshold": threshold,
            "num_labels": NUM_LABELS,
        })
        # Resolve next to this file so the artifact does not depend on the current directory.
        mlflow.log_artifact(str(Path(__file__).resolve().with_name("config.py")))

        # Start at -inf, not 0.0. Otherwise a run whose macro-F1 never exceeds 0 writes no
        # checkpoint, and the final log_artifact either fails or uploads a stale file
        # left by an earlier run with the same run_name.
        best_val_macro = float("-inf")
        global_step = 0
        quick_check = override_epochs is not None
        phase_kwargs = dict(
            model=model, train_loader=train_loader, val_loader=val_loader,
            criterion=criterion, device=device, threshold=threshold, best_ckpt=best_ckpt,
        )

        # Phase 1: backbone frozen, only trainable (head) parameters are optimised.
        freeze_epochs = min(1, cfg["freeze_epochs"]) if quick_check else cfg["freeze_epochs"]
        freeze_backbone(model)
        # NOTE(review): phase 1 uses Adam (L2-coupled weight decay), phase 2 uses AdamW
        # (decoupled). The same weight_decay value therefore regularises differently; confirm intended.
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg["lr_head"],
            weight_decay=cfg["weight_decay"],
        )
        log.info("Phase 1: frozen backbone (%d epochs)", freeze_epochs)
        best_val_macro, global_step = _run_phase(
            "freeze", freeze_epochs, optimizer=optimizer,
            best_val_macro=best_val_macro, global_step=global_step, **phase_kwargs,
        )

        # Phase 2: everything trainable, fresh optimizer with cosine-annealed lr.
        unfreeze_epochs = 1 if quick_check else cfg["unfreeze_epochs"]
        unfreeze_all(model)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg["lr_finetune"],
            weight_decay=cfg["weight_decay"],
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=unfreeze_epochs, eta_min=1e-6)
        log.info("Phase 2: full fine-tune (%d epochs)", unfreeze_epochs)
        best_val_macro, global_step = _run_phase(
            "unfreeze", unfreeze_epochs, optimizer=optimizer, scheduler=scheduler,
            best_val_macro=best_val_macro, global_step=global_step, **phase_kwargs,
        )

        # Still -inf only if zero epochs ran in both phases; nothing was saved this run.
        if best_val_macro == float("-inf"):
            log.warning("No epochs were run; no checkpoint saved for %s", cfg["run_name"])
            return

        mlflow.log_metric("best_val_macro_f1", best_val_macro)
        mlflow.log_artifact(str(best_ckpt))
        log.info("Best val_macro_f1: %.4f | checkpoint: %s", best_val_macro, best_ckpt)
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # CLI entry point: read a YAML config and launch the two-phase training run.
    cli = argparse.ArgumentParser(description="Train multi-label road scene model")
    cli.add_argument("--config", required=True, help="Path to YAML config file")
    # Any integer value acts as a "quick check" flag; train() ignores the number itself.
    cli.add_argument(
        "--epochs",
        type=int,
        default=None,
        help="Override both freeze_epochs and unfreeze_epochs to 1 (quick sanity check)",
    )
    cli_args = cli.parse_args()
    train(load_yaml(cli_args.config), override_epochs=cli_args.epochs)
|
|