"""
src/ablation.py
---------------
Lambda ablation study for the hierarchical KL + MSE loss.

Sweeps lambda_kl over [0.0, 0.25, 0.50, 0.75, 1.0] on a 10k subset
to justify the choice of lambda_kl = 0.5 used in the proposed model.

This ablation is reported in the paper as justification for the
balanced KL + MSE formulation. It is run BEFORE full training.

Output
------
outputs/figures/ablation/table_lambda_ablation.csv
outputs/figures/ablation/fig_lambda_ablation.pdf
outputs/figures/ablation/fig_lambda_ablation.png

Usage
-----
cd ~/galaxy
nohup python -m src.ablation --config configs/ablation.yaml \
    > outputs/logs/ablation.log 2>&1 &
echo "PID: $!"
"""

import argparse
import copy
import logging
import random
import sys
import gc
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import matplotlib
matplotlib.use("Agg")  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm

from src.dataset import build_dataloaders
from src.model import build_model
from src.loss import HierarchicalLoss
from src.metrics import compute_metrics, predictions_to_numpy

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
    stream=sys.stdout,
)
log = logging.getLogger("ablation")

LAMBDA_VALUES = [0.0, 0.25, 0.50, 0.75, 1.0]
ABLATION_EPOCHS = 15  # sufficient to converge on 10k subset
ABLATION_SAMPLES = 10000


def _set_seed(seed: int) -> None:
    """Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _to_device(device: torch.device, *tensors):
    """Move every tensor in *tensors* to *device* (non-blocking copy)."""
    return tuple(t.to(device, non_blocking=True) for t in tensors)


def _mean_val_loss(model, loader, loss_fn, device: torch.device,
                   use_amp: bool) -> float:
    """
    Return the mean per-batch loss of *model* over *loader*.

    The loader is expected to yield ``(images, targets, weights, _)``
    tuples and ``loss_fn`` to return a pair whose first item is the
    scalar loss, as in the training loop of ``run_single``.

    Raises
    ------
    ValueError
        If the loader yields no batches (the mean would be undefined).
    """
    model.eval()
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for images, targets, weights, _ in loader:
            images, targets, weights = _to_device(
                device, images, targets, weights
            )
            with autocast(device.type, enabled=use_amp):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            total += loss.item()
            n_batches += 1
    if n_batches == 0:
        raise ValueError(
            "Validation loader yielded no batches; cannot compute val_loss."
        )
    return total / n_batches


def _collect_predictions(model, loader, device: torch.device, use_amp: bool):
    """
    Run *model* over *loader* and return concatenated NumPy arrays
    ``(preds, targets, weights)`` produced by ``predictions_to_numpy``.

    Raises
    ------
    ValueError
        If the loader yields no batches.
    """
    model.eval()
    preds, tgts, wts = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in loader:
            images, targets, weights = _to_device(
                device, images, targets, weights
            )
            with autocast(device.type, enabled=use_amp):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            preds.append(p)
            tgts.append(t)
            wts.append(w)
    if not preds:
        raise ValueError("Test loader yielded no batches; nothing to evaluate.")
    return np.concatenate(preds), np.concatenate(tgts), np.concatenate(wts)


def run_single(cfg: DictConfig, lambda_kl: float) -> dict:
    """
    Train one model with the given lambda_kl on a 10k subset and return
    test metrics. All other settings are identical across runs.

    Parameters
    ----------
    cfg
        Merged experiment config. It is deep-copied before being modified,
        so the caller's object is left untouched.
    lambda_kl
        Weight of the KL term; the MSE weight is set to ``1 - lambda_kl``.

    Returns
    -------
    dict
        ``lambda_kl``, ``lambda_mse``, ``best_val_loss`` and the rounded
        test metrics ``mae_weighted``, ``rmse_weighted``, ``ece_mean``.

    Raises
    ------
    ValueError
        If the validation or test loader yields no batches.
    """
    _set_seed(cfg.seed)

    cfg = copy.deepcopy(cfg)
    cfg.loss.lambda_kl = lambda_kl
    cfg.loss.lambda_mse = 1.0 - lambda_kl
    cfg.data.n_samples = ABLATION_SAMPLES
    cfg.training.epochs = ABLATION_EPOCHS

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only meaningful on CUDA; on CPU run in full
    # precision instead of requesting a CUDA autocast context / scaler.
    use_amp = device.type == "cuda"

    train_loader, val_loader, test_loader = build_dataloaders(cfg)
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # The backbone is trained at one tenth of the head's learning rate.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=ABLATION_EPOCHS, eta_min=1e-6
    )
    # A disabled scaler makes scale/unscale_/update no-ops and step() a
    # plain optimizer.step(), so the loop below is the same on CPU and GPU.
    scaler = GradScaler("cuda", enabled=use_amp)

    best_val = float("inf")
    best_state = None

    for epoch in range(1, ABLATION_EPOCHS + 1):
        # ── train ──────────────────────────────────────────────
        model.train()
        for images, targets, weights, _ in tqdm(
            train_loader, desc=f"λ={lambda_kl:.2f} E{epoch}", leave=False
        ):
            images, targets, weights = _to_device(
                device, images, targets, weights
            )

            optimizer.zero_grad(set_to_none=True)
            with autocast(device.type, enabled=use_amp):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)

            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true grads.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

        # ── validate ───────────────────────────────────────────
        val_loss = _mean_val_loss(model, val_loader, loss_fn, device, use_amp)
        log.info(" λ_kl=%.2f epoch=%d val_loss=%.5f", lambda_kl, epoch,
                 val_loss)

        if val_loss < best_val:
            best_val = val_loss
            # Snapshot on CPU so the checkpoint does not occupy GPU memory.
            best_state = {k: v.cpu().clone()
                          for k, v in model.state_dict().items()}

    # ── test evaluation ────────────────────────────────────────
    if best_state is None:
        # Happens when every epoch's val_loss was NaN (NaN < inf is False).
        log.warning(
            "lambda_kl=%.2f: no epoch improved on the initial val_loss; "
            "evaluating the final-epoch weights.", lambda_kl
        )
    else:
        model.load_state_dict(best_state)

    all_preds, all_targets, all_weights = _collect_predictions(
        model, test_loader, device, use_amp
    )
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return {
        "lambda_kl": lambda_kl,
        "lambda_mse": round(1.0 - lambda_kl, 2),
        "best_val_loss": round(best_val, 5),
        "mae_weighted": round(metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(metrics["rmse/weighted_avg"], 5),
        "ece_mean": round(metrics["ece/mean"], 5),
    }


def _plot_ablation(df: pd.DataFrame, save_dir: Path, *, seed: int = 42,
                   n_samples: int = ABLATION_SAMPLES) -> None:
    """
    Plot weighted MAE, weighted RMSE and mean ECE against lambda_kl and
    save the figure as ``fig_lambda_ablation.pdf`` / ``.png`` in *save_dir*.

    Parameters
    ----------
    df
        One row per lambda with the columns ``lambda_kl``, ``mae_weighted``,
        ``rmse_weighted`` and ``ece_mean``.
    save_dir
        Existing directory the two figure files are written to.
    seed, n_samples
        Values quoted in the figure title. The defaults reproduce the
        previously hard-coded "10,000-sample subset, seed=42" text.
    """
    # The row with the lowest weighted MAE is marked in every panel.
    best_row = df.loc[df["mae_weighted"].idxmin()]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    metrics_cfg = [
        ("mae_weighted", "Weighted MAE", "#2980b9"),
        ("rmse_weighted", "Weighted RMSE", "#c0392b"),
        ("ece_mean", "Mean ECE", "#27ae60"),
    ]

    for ax, (col, ylabel, color) in zip(axes, metrics_cfg):
        ax.plot(df["lambda_kl"], df[col], "-o", color=color,
                linewidth=2, markersize=8)
        ax.axvline(best_row["lambda_kl"], color="#7f8c8d", linestyle="--",
                   alpha=0.8, label=f"Best λ = {best_row['lambda_kl']:.2f}")
        ax.set_xlabel("$\\lambda_{\\mathrm{KL}}$ "
                      "(0 = pure MSE, 1 = pure KL)", fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Lambda ablation — {ylabel}", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(df["lambda_kl"].tolist())

    plt.suptitle(
        "Ablation study: effect of $\\lambda_{\\mathrm{KL}}$ in the hierarchical loss\n"
        f"{n_samples:,}-sample subset, seed={seed}. Best: $\\lambda_{{\\mathrm{{KL}}}}$"
        f" = {best_row['lambda_kl']:.2f} (MAE = {best_row['mae_weighted']:.5f})",
        fontsize=11, y=1.02,
    )
    plt.tight_layout()
    fig.savefig(save_dir / "fig_lambda_ablation.pdf", dpi=300,
                bbox_inches="tight")
    fig.savefig(save_dir / "fig_lambda_ablation.png", dpi=300,
                bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_lambda_ablation")


def main() -> None:
    """
    Run the full lambda sweep: merge ``configs/base.yaml`` with the file
    given by ``--config``, train once per value in ``LAMBDA_VALUES``, then
    write the results table and figure under ``<figures_dir>/ablation``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment settings override the base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    save_dir = Path(cfg.outputs.figures_dir) / "ablation"
    save_dir.mkdir(parents=True, exist_ok=True)

    results = []
    for lam in LAMBDA_VALUES:
        log.info("=" * 55)
        log.info("Ablation: lambda_kl=%.2f lambda_mse=%.2f", lam, 1.0 - lam)
        log.info("=" * 55)
        result = run_single(cfg, lam)
        results.append(result)
        log.info("Result: %s", result)

        # Free up RAM and GPU memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(results)
    df.to_csv(save_dir / "table_lambda_ablation.csv", index=False)
    log.info("Saved: table_lambda_ablation.csv")

    print()
    print(df.to_string(index=False))
    print()

    best = df.loc[df["mae_weighted"].idxmin()]
    log.info("Best: lambda_kl=%.2f MAE=%.5f RMSE=%.5f",
             best["lambda_kl"], best["mae_weighted"], best["rmse_weighted"])

    # Quote the seed that was actually used rather than a fixed value.
    _plot_ablation(df, save_dir, seed=cfg.seed)


if __name__ == "__main__":
    main()