eshwar-gz2-api / src /baselines.py
sreshwarprasad's picture
Upload folder using huggingface_hub
e36eee4 verified
"""
src/baselines.py
----------------
Consolidated baseline training for the GZ2 hierarchical probabilistic
regression paper. ALL baselines are trained from this single script.
Replaces the three separate scripts:
src/baselines.py (was: ResNet-18 MSE + ViT MSE)
src/run_resnet_kl.py (was: ResNet-18 KL+MSE — now merged here)
src/train_dirichlet.py (was: ViT Dirichlet — now merged here)
DELETE those three original files after switching to this one.
Baselines trained
-----------------
B1. ResNet-18 + independent MSE (sigmoid)
— CNN, no hierarchy, no KL. Demonstrates the cost of
ignoring the decision-tree structure.
B2. ResNet-18 + hierarchical KL+MSE
— Same loss as proposed, CNN backbone.
Isolates ViT vs. CNN contribution.
B3. ViT-Base + hierarchical MSE only (no KL)
— Same backbone as proposed, KL term removed.
Isolates contribution of the KL term.
B4. ViT-Base + Dirichlet NLL (Zoobot-style)
— Direct comparison with the established Zoobot approach
(Walmsley et al. 2022, MNRAS 509, 3966).
Proposed model (not trained here — trained via src/train.py):
ViT-Base + hierarchical KL+MSE → outputs/checkpoints/best_full_train.pt
Consistency guarantee
---------------------
All baselines use identical:
- Random seed, data split, batch size, epochs, early stopping
- AdamW optimiser, CosineAnnealingLR, gradient clipping
- Image transforms and evaluation metric (compute_metrics on same test split)
The ONLY differences between models are the backbone and/or loss function.
Usage
-----
cd ~/galaxy
nohup python -m src.baselines --config configs/full_train.yaml \
> outputs/logs/baselines.log 2>&1 &
echo "PID: $!"
"""
import argparse
import logging
import random
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import timm
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from omegaconf import OmegaConf
from tqdm import tqdm
import wandb
from src.dataset import build_dataloaders, QUESTION_GROUPS
from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
from src.metrics import (compute_metrics, predictions_to_numpy,
dirichlet_predictions_to_numpy, simplex_violation_rate)
from src.model import build_model, build_dirichlet_model
# Send every log record to stdout with a compact, timestamped format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("baselines")

# Short human-readable title for each question key (used for axis tick labels).
QUESTION_LABELS = dict(
    t01="Smooth or features",
    t02="Edge-on disk",
    t03="Bar",
    t04="Spiral arms",
    t05="Bulge prominence",
    t06="Odd feature",
    t07="Roundedness",
    t08="Odd feature type",
    t09="Bulge shape",
    t10="Arms winding",
    t11="Arms number",
)
# ─────────────────────────────────────────────────────────────
# Reproducibility
# ─────────────────────────────────────────────────────────────
def set_seed(seed: int):
    """Seed the Python, NumPy and torch RNGs and pin cuDNN to deterministic mode."""
    # CPU-side generators first, all with the same seed value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Seeds the generator of every CUDA device.
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
# ─────────────────────────────────────────────────────────────
# Early stopping (mirrors train.py exactly)
# ─────────────────────────────────────────────────────────────
class EarlyStopping:
    """
    Patience-based early stopping that checkpoints the best model so far.

    Tracks the lowest validation loss seen. When a new loss beats the
    previous best by more than ``min_delta`` the model's ``state_dict`` is
    written to ``checkpoint_path`` and the patience counter is reset;
    otherwise the counter is incremented.

    Args:
        patience: number of consecutive non-improving ``step`` calls after
            which ``step`` returns True.
        min_delta: minimum decrease in validation loss that counts as an
            improvement.
        checkpoint_path: file the best weights are saved to and restored from.
    """

    def __init__(self, patience, min_delta, checkpoint_path):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_epoch = 0

    def step(self, val_loss, model, epoch) -> bool:
        """
        Record one epoch's validation loss.

        Saves a checkpoint on improvement. Returns True once the patience
        budget is exhausted, i.e. when training should stop.
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_epoch = epoch
            # The checkpoint directory may not exist on a fresh run; create
            # it so torch.save does not fail after a full epoch of training.
            Path(self.checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(
                {"epoch": epoch, "model_state": model.state_dict(),
                 "val_loss": val_loss},
                self.checkpoint_path,
            )
            log.info(" [ckpt] saved val_loss=%.6f epoch=%d", val_loss, epoch)
        else:
            self.counter += 1
            log.info(" [early_stop] %d/%d best=%.6f",
                     self.counter, self.patience, self.best_loss)
        return self.counter >= self.patience

    def restore_best(self, model) -> float:
        """
        Load the checkpointed weights back into ``model``.

        Returns the validation loss stored in the checkpoint. Raises whatever
        ``torch.load`` raises if no checkpoint file was ever written.
        """
        ckpt = torch.load(self.checkpoint_path, map_location="cpu",
                          weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        log.info("Restored best weights epoch=%d val_loss=%.6f",
                 ckpt["epoch"], ckpt["val_loss"])
        return ckpt["val_loss"]
# ─────────────────────────────────────────────────────────────
# Baseline Model 1: ResNet-18 + independent MSE
# ─────────────────────────────────────────────────────────────
class ResNet18Baseline(nn.Module):
    """
    ResNet-18 backbone (timm) with a dropout + linear regression head.
    Used for both the sigmoid-MSE baseline and the KL+MSE baseline.

    Args:
        dropout: dropout probability applied before the linear head.
        num_outputs: width of the output layer. Defaults to 37, the value
            previously hard-coded here.
        pretrained: whether timm should load its pretrained weights.
            Defaults to True, matching the previous behaviour.

    ``forward`` returns the raw (un-normalised) head outputs; turning them
    into probabilities is left to the loss / evaluation code.
    """

    def __init__(self, dropout: float = 0.3, num_outputs: int = 37,
                 pretrained: bool = True):
        super().__init__()
        # num_classes=0 asks timm for the backbone without its classifier.
        self.backbone = timm.create_model(
            "resnet18", pretrained=pretrained, num_classes=0
        )
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.backbone.num_features, num_outputs),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and then the head on a batch of images."""
        return self.head(self.backbone(x))
class IndependentMSELoss(nn.Module):
    """
    Unweighted mean-squared error across every target column.

    Raw predictions are passed through a sigmoid first, so each output lies
    in [0, 1]. There is no hierarchical weighting and no KL term.

    Note: nothing forces the outputs of a question group to sum to 1. The
    simplex_violation_rate metric is used elsewhere in this script to
    quantify that for this baseline.
    """

    def forward(self, predictions, targets, weights):
        """Return ``(loss, {"loss/total": float})``; ``weights`` is ignored."""
        squashed = predictions.sigmoid()
        mse = F.mse_loss(squashed, targets)
        log_dict = {"loss/total": mse.detach().item()}
        return mse, log_dict
# ─────────────────────────────────────────────────────────────
# Shared training loop
# ─────────────────────────────────────────────────────────────
def _train_epoch(model, loader, loss_fn, optimizer, scaler,
                 device, cfg, epoch, label):
    """
    Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is expected to unpack as ``(images, targets, weights, _)``.
    The forward pass runs under CUDA autocast when
    ``cfg.training.mixed_precision`` is set; gradients are unscaled and
    clipped to ``cfg.training.grad_clip`` before the optimizer step.

    Raises:
        ValueError: if ``loader`` yields no batches. Previously this
            surfaced as a bare ZeroDivisionError.
    """
    model.train()
    # Config lookups are loop-invariant; resolve them once.
    use_amp = cfg.training.mixed_precision
    max_grad_norm = cfg.training.grad_clip

    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"{label} E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=use_amp):
            logits = model(images)
            loss, _ = loss_fn(logits, targets, weights)

        scaler.scale(loss).backward()
        # Unscale first so the clip threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        nb += 1

    if nb == 0:
        raise ValueError(
            f"{label}: training loader yielded no batches at epoch {epoch}"
        )
    return total / nb
def _train_epoch_dirichlet(model, loader, loss_fn, optimizer, scaler,
                           device, cfg, epoch, label):
    """
    Training epoch for Dirichlet model (outputs alpha, not logits).

    The optimisation step is identical to ``_train_epoch``: the model output
    is handed straight to ``loss_fn`` whatever it represents. This function
    therefore delegates instead of keeping a second copy of the loop.
    Returns the mean batch loss.
    """
    return _train_epoch(model, loader, loss_fn, optimizer, scaler,
                        device, cfg, epoch, label)
def _val_epoch(model, loader, loss_fn, device, cfg, epoch, label,
               use_sigmoid=False):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        use_sigmoid: if True, predictions are ``sigmoid(logits)`` for every
            column. Otherwise a softmax is applied separately inside each
            ``(start, end)`` slice listed in ``QUESTION_GROUPS``.

    Returns:
        ``(mean_loss, metrics)`` where ``metrics`` is whatever
        ``compute_metrics`` returns for the concatenated predictions,
        targets and weights.
    """
    model.eval()
    use_amp = cfg.training.mixed_precision
    # Only the slice bounds are needed, not the question keys.
    group_slices = list(QUESTION_GROUPS.values())

    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in tqdm(
            loader, desc=f"{label} Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=use_amp):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            total += loss.item()
            nb += 1

            # Under autocast the logits may be float16. Convert to float32
            # before normalising so softmax/sigmoid and the downstream
            # metrics are not computed in half precision on the CPU.
            # clone() guards against .float()/.cpu() returning the same tensor.
            pred_cpu = logits.detach().float().cpu().clone()
            if use_sigmoid:
                pred_prob = torch.sigmoid(pred_cpu).numpy()
            else:
                for s, e in group_slices:
                    pred_cpu[:, s:e] = torch.softmax(pred_cpu[:, s:e], dim=-1)
                pred_prob = pred_cpu.numpy()

            all_preds.append(pred_prob)
            all_targets.append(targets.detach().cpu().numpy())
            all_weights.append(weights.detach().cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)
    return total / nb, metrics
def _val_epoch_dirichlet(model, loader, loss_fn, device, cfg, epoch, label):
    """
    Evaluate the Dirichlet model on ``loader`` without gradient tracking.

    Per-batch conversion of the model output to numpy predictions is done by
    ``dirichlet_predictions_to_numpy``. Returns ``(mean_loss, metrics)``,
    with ``metrics`` coming from ``compute_metrics``.
    """
    model.eval()
    loss_sum, batch_count = 0.0, 0
    pred_chunks, target_chunks, weight_chunks = [], [], []

    progress = tqdm(loader, desc=f"{label} Val E{epoch}", leave=False)
    with torch.no_grad():
        for batch_images, batch_targets, batch_weights, _ in progress:
            batch_images = batch_images.to(device, non_blocking=True)
            batch_targets = batch_targets.to(device, non_blocking=True)
            batch_weights = batch_weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(batch_images)
                batch_loss, _ = loss_fn(alpha, batch_targets, batch_weights)

            loss_sum += batch_loss.item()
            batch_count += 1

            p, t, w = dirichlet_predictions_to_numpy(
                alpha, batch_targets, batch_weights
            )
            pred_chunks.append(p)
            target_chunks.append(t)
            weight_chunks.append(w)

    metrics = compute_metrics(
        np.concatenate(pred_chunks),
        np.concatenate(target_chunks),
        np.concatenate(weight_chunks),
    )
    return loss_sum / batch_count, metrics
# ─────────────────────────────────────────────────────────────
# Generic train_and_evaluate (non-Dirichlet)
# ─────────────────────────────────────────────────────────────
def train_and_evaluate(
    model, loss_fn, cfg, device,
    label, checkpoint_path,
    use_layerwise_lr=True,
    use_sigmoid=False,
):
    """
    Train a (non-Dirichlet) baseline and evaluate it on the test split.

    Intended to mirror the loop in train.py (not visible from this file).

    Args:
        use_layerwise_lr: when True and the model exposes ``backbone`` and
            ``head`` attributes, the backbone is trained at 0.1x the
            configured learning rate.
        use_sigmoid: forwarded to ``_val_epoch`` to choose how logits are
            turned into predictions.

    Returns:
        ``(test_metrics, best_val_loss, best_epoch, history)``.
        If ``checkpoint_path`` already exists, training is skipped, the
        checkpoint is loaded, and ``history`` is an empty list.
    """
    # Check if checkpoint exists - if so, skip training
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)
        # Evaluate on test set
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    base_lr = cfg.training.learning_rate
    if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
        optimizer = torch.optim.AdamW(
            [
                {"params": model.backbone.parameters(), "lr": base_lr * 0.1},
                {"params": model.head.parameters(), "lr": base_lr},
            ],
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: layer-wise lr — backbone=%.1e head=%.1e",
                 label, base_lr * 0.1, base_lr)
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=base_lr,
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: single lr=%.1e", label, base_lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    wandb.init(
        project=cfg.wandb.project,
        name=label,
        config={
            "model": label, "backbone": "resnet18" if "ResNet" in label else "vit_base_patch16_224",
            "batch_size": cfg.training.batch_size, "lr": cfg.training.learning_rate,
            "epochs": cfg.training.epochs, "seed": cfg.seed,
            "lambda_kl": cfg.loss.lambda_kl, "lambda_mse": cfg.loss.lambda_mse,
        },
        reinit=True,
    )

    history = []
    # try/finally closes the wandb run even if an epoch raises, so a crash
    # in one baseline does not leave a dangling run open.
    try:
        for epoch in range(1, cfg.training.epochs + 1):
            train_loss = _train_epoch(
                model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
            )
            val_loss, val_metrics = _val_epoch(
                model, val_loader, loss_fn, device, cfg, epoch, label,
                use_sigmoid=use_sigmoid
            )
            scheduler.step()
            # First param group only (the backbone group when layer-wise).
            lr = scheduler.get_last_lr()[0]
            val_mae = val_metrics.get("mae/weighted_avg", 0)
            log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                     label, epoch, train_loss, val_loss, val_mae, lr)
            history.append({
                "epoch": epoch, "train_loss": train_loss,
                "val_loss": val_loss, "val_mae": val_mae,
            })
            wandb.log({
                "train_loss": train_loss, "val_loss": val_loss,
                "val_mae": val_mae, "lr": lr,
            }, step=epoch)
            if early_stop.step(val_loss, model, epoch):
                log.info("%s: early stopping at epoch %d best=%d",
                         label, epoch, early_stop.best_epoch)
                break
    finally:
        wandb.finish()

    best_val = early_stop.restore_best(model)

    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch(
        model, test_loader, loss_fn, device, cfg,
        epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
    )
    return test_metrics, best_val, early_stop.best_epoch, history
# ─────────────────────────────────────────────────────────────
# Dirichlet train_and_evaluate
# ─────────────────────────────────────────────────────────────
def train_and_evaluate_dirichlet(model, loss_fn, cfg, device,
                                 label, checkpoint_path):
    """
    Train the Dirichlet baseline and evaluate it on the test split.

    The model must expose ``backbone`` and ``head`` attributes: the backbone
    is always trained at 0.1x the configured learning rate.

    Returns:
        ``(test_metrics, best_val_loss, best_epoch, history)``.
        If ``checkpoint_path`` already exists, training is skipped, the
        checkpoint is loaded, and ``history`` is an empty list.
    """
    # Check if checkpoint exists - if so, skip training
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)
        # Evaluate on test set
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch_dirichlet(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test"
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    base_lr = cfg.training.learning_rate
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(), "lr": base_lr * 0.1},
            {"params": model.head.parameters(), "lr": base_lr},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    wandb.init(
        project=cfg.wandb.project, name=label,
        config={"model": label, "loss": "DirichletNLL",
                "seed": cfg.seed, "epochs": cfg.training.epochs},
        reinit=True,
    )

    history = []
    # try/finally closes the wandb run even if an epoch raises.
    try:
        for epoch in range(1, cfg.training.epochs + 1):
            train_loss = _train_epoch_dirichlet(
                model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
            )
            val_loss, val_metrics = _val_epoch_dirichlet(
                model, val_loader, loss_fn, device, cfg, epoch, label
            )
            scheduler.step()
            # First param group only, i.e. the backbone learning rate.
            lr = scheduler.get_last_lr()[0]
            val_mae = val_metrics.get("mae/weighted_avg", 0)
            log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                     label, epoch, train_loss, val_loss, val_mae, lr)
            history.append({
                "epoch": epoch, "train_loss": train_loss,
                "val_loss": val_loss, "val_mae": val_mae,
            })
            wandb.log({
                "train_loss": train_loss, "val_loss": val_loss,
                "val_mae": val_mae, "lr": lr,
            }, step=epoch)
            if early_stop.step(val_loss, model, epoch):
                log.info("%s: early stopping at epoch %d", label, epoch)
                break
    finally:
        wandb.finish()

    best_val = early_stop.restore_best(model)

    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch_dirichlet(
        model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test"
    )
    return test_metrics, best_val, early_stop.best_epoch, history
# ─────────────────────────────────────────────────────────────
# Figures
# ─────────────────────────────────────────────────────────────
def _save_comparison_figures(all_results, all_histories, save_dir):
    """
    Saves two figures, each as PDF and PNG, into ``save_dir``:
      1. Per-question MAE + RMSE grouped bar chart
         (fig_baseline_comparison_mae_rmse)
      2. Validation MAE learning curves
         (fig_baseline_val_mae_curves)

    Args:
        all_results: list of row dicts; each needs a "model" key and may
            carry "mae_<q>" / "rmse_<q>" entries (missing ones plot as NaN).
        all_histories: mapping of display name -> list of per-epoch dicts
            with "epoch" and "val_mae". Empty histories (e.g. a model
            restored from a checkpoint) are skipped.
        save_dir: output directory (str or Path).

    Does nothing, apart from logging a warning, when ``all_results`` is
    empty.
    """
    if not all_results:
        # Avoids dividing by zero when computing the bar width below.
        log.warning("No results to plot - skipping comparison figures")
        return

    save_dir = Path(save_dir)
    q_names = list(QUESTION_GROUPS.keys())
    n_models = len(all_results)
    x = np.arange(len(q_names))
    width = 0.80 / n_models
    palette = ["#c0392b", "#e67e22", "#2980b9", "#27ae60", "#8e44ad"]

    # ── Figure 1: Per-question MAE and RMSE ───────────────────
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    for metric, ax, ylabel in [
        ("mae", axes[0], "Mean Absolute Error (MAE)"),
        ("rmse", axes[1], "Root Mean Squared Error (RMSE)"),
    ]:
        for i, row_d in enumerate(all_results):
            # Cycle the palette so models beyond the fifth are still drawn
            # (zip against the palette used to drop them silently).
            color = palette[i % len(palette)]
            vals = [row_d.get(f"{metric}_{q}", np.nan) for q in q_names]
            ax.bar(x + i * width, vals, width,
                   label=row_d["model"], color=color,
                   alpha=0.85, edgecolor="white", linewidth=0.5)
        ax.set_xticks(x + width * (n_models - 1) / 2)
        ax.set_xticklabels(
            [f"{q}\n({QUESTION_LABELS[q][:10]})" for q in q_names],
            rotation=45, ha="right", fontsize=7,
        )
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Per-question {metric.upper()} — baseline comparison", fontsize=11)
        ax.legend(fontsize=7, loc="upper right")
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_axisbelow(True)
    # Address the figure explicitly instead of relying on pyplot's notion
    # of the "current" figure.
    fig.suptitle(
        "Baseline comparison — GZ2 hierarchical probabilistic regression\n"
        "Full 239,267-sample dataset, identical seed/split/protocol",
        fontsize=12, y=1.02,
    )
    fig.tight_layout()
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.pdf",
                dpi=300, bbox_inches="tight")
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.png",
                dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_baseline_comparison_mae_rmse")

    # ── Figure 2: Validation MAE learning curves ───────────────
    fig2, ax2 = plt.subplots(figsize=(10, 5))
    styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1))]
    markers = ["o", "s", "^", "D", "v"]
    n_curves = 0
    for i, (name, hist) in enumerate(all_histories.items()):
        if not hist:
            # No per-epoch history is available for this model.
            log.info("No training history for %s - curve skipped", name)
            continue
        epochs_h = [h["epoch"] for h in hist]
        val_maes = [h["val_mae"] for h in hist]
        ax2.plot(epochs_h, val_maes,
                 linestyle=styles[i % len(styles)],
                 color=palette[i % len(palette)],
                 linewidth=1.8, label=name,
                 marker=markers[i % len(markers)],
                 markersize=3, markevery=5)
        n_curves += 1
    ax2.set_xlabel("Epoch", fontsize=11)
    ax2.set_ylabel("Validation MAE (weighted average)", fontsize=11)
    ax2.set_title("Validation MAE during training — all baseline models", fontsize=11)
    if n_curves:
        # legend() warns when there are no labelled artists.
        ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    fig2.tight_layout()
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.pdf",
                 dpi=300, bbox_inches="tight")
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.png",
                 dpi=300, bbox_inches="tight")
    plt.close(fig2)
    log.info("Saved: fig_baseline_val_mae_curves")
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def _result_row(model, backbone, loss, hierarchy, best_epoch, best_val,
                metrics, simplex_violation=0.0):
    """Build one comparison-table row (summary columns, then per-question MAE/RMSE)."""
    row = {
        "model": model,
        "backbone": backbone,
        "loss": loss,
        "hierarchy": hierarchy,
        "best_epoch": best_epoch,
        "best_val_loss": round(best_val, 5),
        "mae_weighted": round(metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": round(simplex_violation, 4),
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(metrics[f"rmse/{q}"], 5)
    return row


def main():
    """
    Train (or load) baselines B1-B4, evaluate each on the test split, add the
    proposed model if its checkpoint exists, then write the comparison CSVs
    and figures under ``<figures_dir>/comparison``.

    Configuration is ``configs/base.yaml`` merged with the file passed via
    ``--config``; the base path is relative to the working directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s Dataset: %s",
             device, "full 239k" if cfg.data.n_samples is None
             else f"{cfg.data.n_samples:,}")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    save_dir = Path(cfg.outputs.figures_dir) / "comparison"
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    # Checkpoints are written here during training; without this a fresh
    # run fails at the first save.
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    all_results = []
    all_histories = {}

    # ─── B1: ResNet-18 + independent MSE (sigmoid) ────────────
    log.info("=" * 60)
    log.info("B1: ResNet-18 + independent MSE (sigmoid, no hierarchy)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    rn_mse_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_mse_loss = IndependentMSELoss()
    log.info("ResNet-18 params: %s", f"{sum(p.numel() for p in rn_mse_model.parameters()):,}")
    rn_mse_metrics, rn_mse_val, rn_mse_epoch, rn_mse_hist = train_and_evaluate(
        rn_mse_model, rn_mse_loss, cfg, device,
        label="B1-ResNet18-MSE",
        checkpoint_path=str(ckpt_dir / "baseline_resnet18_mse.pt"),
        use_layerwise_lr=False,
        use_sigmoid=True,
    )

    # Simplex violation for this baseline
    _, _, test_loader_tmp = build_dataloaders(cfg)
    rn_mse_model.eval()
    tmp_preds = []
    with torch.no_grad():
        for images, _, _, _ in test_loader_tmp:
            images = images.to(device, non_blocking=True)
            logits = rn_mse_model(images)
            tmp_preds.append(torch.sigmoid(logits).cpu().numpy())
    tmp_preds = np.concatenate(tmp_preds)
    svr = simplex_violation_rate(tmp_preds, tolerance=0.02)
    log.info("B1 simplex violation rate (mean): %.4f", svr["mean"])

    all_results.append(_result_row(
        "ResNet-18 + MSE (sigmoid, no hierarchy)",
        "ResNet-18", "Independent MSE", "None",
        rn_mse_epoch, rn_mse_val, rn_mse_metrics,
        simplex_violation=svr["mean"],
    ))
    all_histories["ResNet-18 + MSE (sigmoid)"] = rn_mse_hist
    log.info("B1 done: MAE=%.5f RMSE=%.5f SimplexViol=%.4f",
             rn_mse_metrics["mae/weighted_avg"],
             rn_mse_metrics["rmse/weighted_avg"],
             svr["mean"])

    # ─── B2: ResNet-18 + hierarchical KL+MSE ──────────────────
    log.info("=" * 60)
    log.info("B2: ResNet-18 + hierarchical KL+MSE (same loss as proposed)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    rn_kl_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_kl_loss = HierarchicalLoss(cfg)
    rn_kl_metrics, rn_kl_val, rn_kl_epoch, rn_kl_hist = train_and_evaluate(
        rn_kl_model, rn_kl_loss, cfg, device,
        label="B2-ResNet18-KL+MSE",
        checkpoint_path=str(ckpt_dir / "baseline_resnet18_klmse.pt"),
        use_layerwise_lr=False,
        use_sigmoid=False,
    )
    # simplex violation stays 0.0: predictions are softmax-normalised per group
    all_results.append(_result_row(
        "ResNet-18 + hierarchical KL+MSE",
        "ResNet-18", "Hierarchical KL+MSE (λ=0.5)", "Full (weights + KL)",
        rn_kl_epoch, rn_kl_val, rn_kl_metrics,
    ))
    all_histories["ResNet-18 + KL+MSE"] = rn_kl_hist
    log.info("B2 done: MAE=%.5f RMSE=%.5f",
             rn_kl_metrics["mae/weighted_avg"],
             rn_kl_metrics["rmse/weighted_avg"])

    # ─── B3: ViT-Base + hierarchical MSE only ─────────────────
    log.info("=" * 60)
    log.info("B3: ViT-Base + hierarchical MSE only (no KL term)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    # Same config with the KL term switched off.
    vit_mse_cfg = OmegaConf.merge(
        cfg, OmegaConf.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})
    )
    vit_mse_model = build_model(vit_mse_cfg).to(device)
    vit_mse_loss = MSEOnlyLoss(vit_mse_cfg)
    vit_mse_metrics, vit_mse_val, vit_mse_epoch, vit_mse_hist = train_and_evaluate(
        vit_mse_model, vit_mse_loss, vit_mse_cfg, device,
        label="B3-ViT-MSE",
        checkpoint_path=str(ckpt_dir / "baseline_vit_mse.pt"),
        use_layerwise_lr=True,
        use_sigmoid=False,
    )
    all_results.append(_result_row(
        "ViT-Base + hierarchical MSE (no KL)",
        "ViT-Base/16", "Hierarchical MSE (λ_KL=0)", "Weights only",
        vit_mse_epoch, vit_mse_val, vit_mse_metrics,
    ))
    all_histories["ViT-Base + MSE only"] = vit_mse_hist
    log.info("B3 done: MAE=%.5f RMSE=%.5f",
             vit_mse_metrics["mae/weighted_avg"],
             vit_mse_metrics["rmse/weighted_avg"])

    # ─── B4: ViT-Base + Dirichlet NLL (Zoobot-style) ──────────
    log.info("=" * 60)
    log.info("B4: ViT-Base + Dirichlet NLL (Walmsley et al. 2022)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    vit_dir_model = build_dirichlet_model(cfg).to(device)
    vit_dir_loss = DirichletLoss(cfg)
    vit_dir_metrics, vit_dir_val, vit_dir_epoch, vit_dir_hist = train_and_evaluate_dirichlet(
        vit_dir_model, vit_dir_loss, cfg, device,
        label="B4-ViT-Dirichlet",
        checkpoint_path=str(ckpt_dir / "baseline_vit_dirichlet.pt"),
    )
    all_results.append(_result_row(
        "ViT-Base + Dirichlet NLL (Zoobot-style)",
        "ViT-Base/16", "Dirichlet NLL", "Full (weights + Dirichlet)",
        vit_dir_epoch, vit_dir_val, vit_dir_metrics,
    ))
    all_histories["ViT-Base + Dirichlet"] = vit_dir_hist
    log.info("B4 done: MAE=%.5f RMSE=%.5f",
             vit_dir_metrics["mae/weighted_avg"],
             vit_dir_metrics["rmse/weighted_avg"])

    # ─── Proposed: load existing checkpoint for final table ────
    proposed_ckpt = ckpt_dir / "best_full_train.pt"
    if proposed_ckpt.exists():
        log.info("=" * 60)
        log.info("PROPOSED: Loading ViT-Base + hierarchical KL+MSE")
        log.info("=" * 60)
        # Read the checkpoint once; it supplies both the weights and the
        # epoch / val_loss columns.
        ckpt_info = torch.load(proposed_ckpt, map_location="cpu", weights_only=True)
        proposed_model = build_model(cfg).to(device)
        proposed_model.load_state_dict(ckpt_info["model_state"])
        _, _, test_loader_p = build_dataloaders(cfg)
        _, proposed_metrics = _val_epoch(
            proposed_model, test_loader_p, HierarchicalLoss(cfg), device, cfg,
            epoch=0, label="Proposed-test", use_sigmoid=False
        )
        all_results.append(_result_row(
            "ViT-Base + hierarchical KL+MSE (proposed)",
            "ViT-Base/16", "Hierarchical KL+MSE (λ=0.5)", "Full (weights + KL)",
            ckpt_info["epoch"], ckpt_info["val_loss"], proposed_metrics,
        ))
        log.info("Proposed: MAE=%.5f RMSE=%.5f",
                 proposed_metrics["mae/weighted_avg"],
                 proposed_metrics["rmse/weighted_avg"])

    # ─── Save results ──────────────────────────────────────────
    df = pd.DataFrame(all_results)
    df.to_csv(save_dir / "table_baseline_comparison.csv", index=False)
    summary_cols = ["model", "loss", "hierarchy", "best_epoch",
                    "best_val_loss", "mae_weighted", "rmse_weighted",
                    "simplex_violation_mean"]
    summary = df[[c for c in summary_cols if c in df.columns]].copy()
    summary.to_csv(save_dir / "table_baseline_summary.csv", index=False)

    print()
    print("=" * 80)
    print("BASELINE COMPARISON — FINAL RESULTS")
    print("=" * 80)
    print(summary.to_string(index=False))
    print()

    # ─── Figures ───────────────────────────────────────────────
    _save_comparison_figures(all_results, all_histories, save_dir)
    log.info("All baseline outputs saved to: %s", save_dir)
# Script entry point: run the full baseline suite only when executed directly
# (e.g. `python -m src.baselines`), not when the module is imported.
if __name__ == "__main__":
    main()