# NOTE(review): removed non-code paste artifacts ("Spaces:", "Running") that preceded the module docstring.
"""
src/ablation.py
---------------
Lambda ablation study for the hierarchical KL + MSE loss.

Sweeps lambda_kl over [0.0, 0.25, 0.50, 0.75, 1.0] on a 10k subset
to justify the choice of lambda_kl = 0.5 used in the proposed model.
This ablation is reported in the paper as justification for the
balanced KL + MSE formulation. It is run BEFORE full training.

Output
------
outputs/figures/ablation/table_lambda_ablation.csv
outputs/figures/ablation/fig_lambda_ablation.pdf
outputs/figures/ablation/fig_lambda_ablation.png

Usage
-----
cd ~/galaxy
nohup python -m src.ablation --config configs/ablation.yaml \
    > outputs/logs/ablation.log 2>&1 &
echo "PID: $!"
"""
import argparse
import copy
import gc
import logging
import random
import sys
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # headless backend; must be selected before importing pyplot

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig, OmegaConf
from torch.amp import GradScaler, autocast
from tqdm import tqdm

from src.dataset import build_dataloaders
from src.loss import HierarchicalLoss
from src.metrics import compute_metrics, predictions_to_numpy
from src.model import build_model
# ── Logging & sweep configuration ──────────────────────────────────
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("ablation")

# KL weights to sweep; lambda_mse is always set to 1 - lambda_kl.
LAMBDA_VALUES = [0.0, 0.25, 0.50, 0.75, 1.0]
# Short schedule on a small subset: enough to rank lambdas, not to fully train.
ABLATION_EPOCHS = 15  # sufficient to converge on 10k subset
ABLATION_SAMPLES = 10000
| def _set_seed(seed: int): | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
def run_single(cfg: DictConfig, lambda_kl: float) -> dict:
    """
    Train one model with the given lambda_kl on a 10k subset and
    return test metrics. All other settings are identical across runs.

    Parameters
    ----------
    cfg : DictConfig
        Merged experiment config. It is deep-copied before mutation,
        so the caller's config is left untouched.
    lambda_kl : float
        Weight of the KL term; the MSE weight is set to 1 - lambda_kl.

    Returns
    -------
    dict
        One summary row: the lambda pair, best validation loss, and
        weighted test metrics (MAE, RMSE, mean ECE), rounded to 5 dp.
    """
    _set_seed(cfg.seed)

    # Mutate a private copy so the next sweep value starts from the same cfg.
    cfg = copy.deepcopy(cfg)
    cfg.loss.lambda_kl = lambda_kl
    cfg.loss.lambda_mse = 1.0 - lambda_kl
    cfg.data.n_samples = ABLATION_SAMPLES
    cfg.training.epochs = ABLATION_EPOCHS

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # BUGFIX: amp was hard-coded to "cuda" even though `device` can fall
    # back to CPU; derive the amp device from the actual device and only
    # enable grad scaling on CUDA (GradScaler is a transparent no-op when
    # disabled, so the step/update calls below stay unchanged).
    amp_device = device.type
    use_amp = device.type == "cuda"

    train_loader, val_loader, test_loader = build_dataloaders(cfg)
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative learning rates: pretrained backbone at 10% of head LR.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=ABLATION_EPOCHS, eta_min=1e-6
    )
    scaler = GradScaler(amp_device, enabled=use_amp)

    best_val = float("inf")
    best_state = None
    for epoch in range(1, ABLATION_EPOCHS + 1):
        # ── train ────────────────────────────────────────────────
        model.train()
        for images, targets, weights, _ in tqdm(
            train_loader, desc=f"λ={lambda_kl:.2f} E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with autocast(amp_device, enabled=use_amp):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured in true units.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()  # cosine step once per epoch (T_max = epochs)

        # ── validate ─────────────────────────────────────────────
        model.eval()
        val_loss = 0.0
        nb = 0
        with torch.no_grad():
            for images, targets, weights, _ in val_loader:
                images = images.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                weights = weights.to(device, non_blocking=True)
                with autocast(amp_device, enabled=use_amp):
                    logits = model(images)
                    loss, _ = loss_fn(logits, targets, weights)
                val_loss += loss.item()
                nb += 1
        val_loss /= max(nb, 1)  # guard against an empty val loader
        log.info("  λ_kl=%.2f epoch=%d val_loss=%.5f", lambda_kl, epoch, val_loss)
        if val_loss < best_val:
            best_val = val_loss
            # Keep the best weights on CPU so they survive empty_cache().
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # ── test evaluation (best-val checkpoint) ────────────────────
    model.load_state_dict(best_state)
    model.eval()
    all_preds, all_targets, all_weights = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in test_loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast(amp_device, enabled=use_amp):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return {
        "lambda_kl"    : lambda_kl,
        "lambda_mse"   : round(1.0 - lambda_kl, 2),
        "best_val_loss": round(best_val, 5),
        "mae_weighted" : round(metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(metrics["rmse/weighted_avg"], 5),
        "ece_mean"     : round(metrics["ece/mean"], 5),
    }
def _plot_ablation(df: pd.DataFrame, save_dir: Path):
    """Render the 3-panel lambda-ablation figure and save it as PDF + PNG."""
    best_row = df.loc[df["mae_weighted"].idxmin()]
    best_lam = best_row["lambda_kl"]
    lambdas = df["lambda_kl"]

    # (column, y-axis label, line color) for each panel, left to right.
    panels = (
        ("mae_weighted", "Weighted MAE", "#2980b9"),
        ("rmse_weighted", "Weighted RMSE", "#c0392b"),
        ("ece_mean", "Mean ECE", "#27ae60"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (column, ylabel, color) in zip(axes, panels):
        ax.plot(lambdas, df[column], "-o",
                color=color, linewidth=2, markersize=8)
        # Mark the MAE-optimal lambda on every panel for easy comparison.
        ax.axvline(best_lam, color="#7f8c8d", linestyle="--", alpha=0.8,
                   label=f"Best λ = {best_lam:.2f}")
        ax.set_xlabel("$\\lambda_{\\mathrm{KL}}$ "
                      "(0 = pure MSE, 1 = pure KL)", fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Lambda ablation — {ylabel}", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(lambdas.tolist())

    plt.suptitle(
        "Ablation study: effect of $\\lambda_{\\mathrm{KL}}$ in the hierarchical loss\n"
        f"10,000-sample subset, seed=42. Best: $\\lambda_{{\\mathrm{{KL}}}}$"
        f" = {best_lam:.2f} (MAE = {best_row['mae_weighted']:.5f})",
        fontsize=11, y=1.02,
    )
    plt.tight_layout()
    for ext in ("pdf", "png"):
        fig.savefig(save_dir / f"fig_lambda_ablation.{ext}",
                    dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_lambda_ablation")
def main():
    """
    Run the full lambda sweep and write the results table + figure.

    For each value in LAMBDA_VALUES, trains one subset model via
    run_single(), then saves table_lambda_ablation.csv and the
    three-panel figure under <figures_dir>/ablation.
    """
    parser = argparse.ArgumentParser(description="Lambda ablation sweep")
    parser.add_argument("--config", required=True,
                        help="Experiment YAML merged on top of the base config.")
    # Generalization: the base config path used to be hard-coded; the
    # default preserves the old behavior exactly.
    parser.add_argument("--base-config", default="configs/base.yaml",
                        help="Base YAML providing shared defaults.")
    args = parser.parse_args()

    base_cfg = OmegaConf.load(args.base_config)
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)  # experiment keys win

    save_dir = Path(cfg.outputs.figures_dir) / "ablation"
    save_dir.mkdir(parents=True, exist_ok=True)

    results = []
    for lam in LAMBDA_VALUES:
        log.info("=" * 55)
        log.info("Ablation: lambda_kl=%.2f lambda_mse=%.2f", lam, 1.0 - lam)
        log.info("=" * 55)
        result = run_single(cfg, lam)
        results.append(result)
        log.info("Result: %s", result)
        # Free host RAM and GPU memory before the next run.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(results)
    df.to_csv(save_dir / "table_lambda_ablation.csv", index=False)
    log.info("Saved: table_lambda_ablation.csv")

    print()
    print(df.to_string(index=False))
    print()

    best = df.loc[df["mae_weighted"].idxmin()]
    log.info("Best: lambda_kl=%.2f MAE=%.5f RMSE=%.5f",
             best["lambda_kl"], best["mae_weighted"], best["rmse_weighted"])
    _plot_ablation(df, save_dir)


if __name__ == "__main__":
    main()