# eshwar-gz2-api / src/ablation.py
# (uploaded via huggingface_hub, commit e36eee4 — web-page residue converted
#  to comments so the module remains valid Python)
"""
src/ablation.py
---------------
Lambda ablation study for the hierarchical KL + MSE loss.
Sweeps lambda_kl over [0.0, 0.25, 0.50, 0.75, 1.0] on a 10k subset
to justify the choice of lambda_kl = 0.5 used in the proposed model.
This ablation is reported in the paper as justification for the
balanced KL + MSE formulation. It is run BEFORE full training.
Output
------
outputs/figures/ablation/table_lambda_ablation.csv
outputs/figures/ablation/fig_lambda_ablation.pdf
outputs/figures/ablation/fig_lambda_ablation.png
Usage
-----
cd ~/galaxy
nohup python -m src.ablation --config configs/ablation.yaml \
> outputs/logs/ablation.log 2>&1 &
echo "PID: $!"
"""
import argparse
import copy
import logging
import random
import sys
import gc
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm
from src.dataset import build_dataloaders
from src.model import build_model
from src.loss import HierarchicalLoss
from src.metrics import compute_metrics, predictions_to_numpy
# Module-level logging: timestamped INFO lines to stdout so that the nohup
# log file (see module docstring) is readable while the sweep runs.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("ablation")

# Sweep grid for the KL weight; the MSE weight is always 1 - lambda_kl,
# so 0.0 is pure MSE and 1.0 is pure KL.
LAMBDA_VALUES = [0.0, 0.25, 0.50, 0.75, 1.0]
ABLATION_EPOCHS = 15  # sufficient to converge on 10k subset
ABLATION_SAMPLES = 10000  # subset size shared by every ablation run
def _set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def _train_one_epoch(model, loader, loss_fn, optimizer, scaler, device,
                     amp_enabled, desc):
    """One AMP training epoch with gradient clipping at max-norm 1.0."""
    model.train()
    for images, targets, weights, _ in tqdm(loader, desc=desc, leave=False):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with autocast(device.type, enabled=amp_enabled):
            logits = model(images)
            loss, _ = loss_fn(logits, targets, weights)
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm threshold applies to true grads.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()


def _mean_val_loss(model, loader, loss_fn, device, amp_enabled):
    """Average loss over a loader under no_grad; used for model selection."""
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for images, targets, weights, _ in loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast(device.type, enabled=amp_enabled):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            total += loss.item()
            n_batches += 1
    return total / n_batches


def _collect_test_arrays(model, loader, device, amp_enabled):
    """Forward the test loader; return concatenated numpy (preds, targets, weights)."""
    model.eval()
    preds, targs, wts = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast(device.type, enabled=amp_enabled):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            preds.append(p)
            targs.append(t)
            wts.append(w)
    return np.concatenate(preds), np.concatenate(targs), np.concatenate(wts)


def run_single(cfg: DictConfig, lambda_kl: float) -> dict:
    """
    Train one model with the given lambda_kl on a 10k subset and
    return test metrics. All other settings are identical across runs.

    Parameters
    ----------
    cfg : DictConfig
        Merged experiment config; deep-copied before mutation so the
        caller's config is identical for every sweep point.
    lambda_kl : float
        KL-term weight; the MSE weight is set to ``1 - lambda_kl``.

    Returns
    -------
    dict
        Swept lambdas, best validation loss, and weighted test metrics
        (MAE / RMSE / mean ECE), rounded for the results table.
    """
    _set_seed(cfg.seed)
    cfg = copy.deepcopy(cfg)  # never mutate the shared config between runs
    cfg.loss.lambda_kl = lambda_kl
    cfg.loss.lambda_mse = 1.0 - lambda_kl
    cfg.data.n_samples = ABLATION_SAMPLES
    cfg.training.epochs = ABLATION_EPOCHS

    # BUGFIX: autocast/GradScaler were hard-coded to "cuda" even though the
    # device falls back to CPU when CUDA is unavailable. AMP is now tied to
    # the actual device type and disabled entirely on CPU.
    amp_enabled = torch.cuda.is_available()
    device = torch.device("cuda" if amp_enabled else "cpu")

    train_loader, val_loader, test_loader = build_dataloaders(cfg)
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)
    # Discriminative learning rates: pretrained backbone at 10% of head LR.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=ABLATION_EPOCHS, eta_min=1e-6
    )
    scaler = GradScaler(device.type, enabled=amp_enabled)

    best_val = float("inf")
    best_state = None
    for epoch in range(1, ABLATION_EPOCHS + 1):
        _train_one_epoch(model, train_loader, loss_fn, optimizer, scaler,
                         device, amp_enabled,
                         desc=f"λ={lambda_kl:.2f} E{epoch}")
        scheduler.step()  # cosine decay is stepped once per epoch

        val_loss = _mean_val_loss(model, val_loader, loss_fn, device, amp_enabled)
        log.info(" λ_kl=%.2f epoch=%d val_loss=%.5f", lambda_kl, epoch, val_loss)
        if val_loss < best_val:
            best_val = val_loss
            # Snapshot best weights on CPU so GPU memory is not held.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # Evaluate the best checkpoint (by validation loss) on the test split.
    model.load_state_dict(best_state)
    all_preds, all_targets, all_weights = _collect_test_arrays(
        model, test_loader, device, amp_enabled
    )
    metrics = compute_metrics(all_preds, all_targets, all_weights)
    return {
        "lambda_kl"    : lambda_kl,
        "lambda_mse"   : round(1.0 - lambda_kl, 2),
        "best_val_loss": round(best_val, 5),
        "mae_weighted" : round(metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(metrics["rmse/weighted_avg"], 5),
        "ece_mean"     : round(metrics["ece/mean"], 5),
    }
def _plot_ablation(df: pd.DataFrame, save_dir: Path):
    """Render the three-panel lambda-ablation figure (MAE / RMSE / ECE)
    and save it as both PDF and PNG under *save_dir*."""
    winner = df.loc[df["mae_weighted"].idxmin()]
    panels = (
        ("mae_weighted", "Weighted MAE", "#2980b9"),
        ("rmse_weighted", "Weighted RMSE", "#c0392b"),
        ("ece_mean", "Mean ECE", "#27ae60"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (column, label, line_color) in zip(axes, panels):
        ax.plot(df["lambda_kl"], df[column], "-o", color=line_color,
                linewidth=2, markersize=8)
        # Mark the best lambda (lowest weighted MAE) on every panel.
        ax.axvline(winner["lambda_kl"], color="#7f8c8d", linestyle="--",
                   alpha=0.8, label=f"Best λ = {winner['lambda_kl']:.2f}")
        ax.set_xlabel("$\\lambda_{\\mathrm{KL}}$ (0 = pure MSE, 1 = pure KL)",
                      fontsize=11)
        ax.set_ylabel(label, fontsize=11)
        ax.set_title(f"Lambda ablation — {label}", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(df["lambda_kl"].tolist())
    plt.suptitle(
        "Ablation study: effect of $\\lambda_{\\mathrm{KL}}$ in the hierarchical loss\n"
        f"10,000-sample subset, seed=42. Best: $\\lambda_{{\\mathrm{{KL}}}}$"
        f" = {winner['lambda_kl']:.2f} (MAE = {winner['mae_weighted']:.5f})",
        fontsize=11, y=1.02,
    )
    plt.tight_layout()
    for ext in ("pdf", "png"):
        fig.savefig(save_dir / f"fig_lambda_ablation.{ext}", dpi=300,
                    bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_lambda_ablation")
def main():
    """CLI entry point: sweep LAMBDA_VALUES, then write the results table
    (CSV) and the three-panel ablation figure (PDF + PNG)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True,
                        help="Experiment YAML merged over the base config.")
    # Generalized: the base config path was hard-coded; the default keeps
    # the previous behavior for existing invocations.
    parser.add_argument("--base-config", default="configs/base.yaml",
                        help="Base YAML providing shared defaults.")
    args = parser.parse_args()

    # Experiment settings override base defaults on merge.
    base_cfg = OmegaConf.load(args.base_config)
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    save_dir = Path(cfg.outputs.figures_dir) / "ablation"
    save_dir.mkdir(parents=True, exist_ok=True)

    results = []
    for lam in LAMBDA_VALUES:
        log.info("=" * 55)
        log.info("Ablation: lambda_kl=%.2f lambda_mse=%.2f",
                 lam, 1.0 - lam)
        log.info("=" * 55)
        result = run_single(cfg, lam)
        results.append(result)
        log.info("Result: %s", result)
        # Release host RAM and GPU memory between runs so later sweep
        # points are not starved by the previous model/dataloaders.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(results)
    df.to_csv(save_dir / "table_lambda_ablation.csv", index=False)
    log.info("Saved: table_lambda_ablation.csv")
    print()
    print(df.to_string(index=False))
    print()

    best = df.loc[df["mae_weighted"].idxmin()]
    log.info("Best: lambda_kl=%.2f MAE=%.5f RMSE=%.5f",
             best["lambda_kl"], best["mae_weighted"], best["rmse_weighted"])
    _plot_ablation(df, save_dir)


if __name__ == "__main__":
    main()