""" src/baselines.py ---------------- Consolidated baseline training for the GZ2 hierarchical probabilistic regression paper. ALL baselines are trained from this single script. Replaces the three separate scripts: src/baselines.py (was: ResNet-18 MSE + ViT MSE) src/run_resnet_kl.py (was: ResNet-18 KL+MSE — now merged here) src/train_dirichlet.py (was: ViT Dirichlet — now merged here) DELETE those three original files after switching to this one. Baselines trained ----------------- B1. ResNet-18 + independent MSE (sigmoid) — CNN, no hierarchy, no KL. Demonstrates the cost of ignoring the decision-tree structure. B2. ResNet-18 + hierarchical KL+MSE — Same loss as proposed, CNN backbone. Isolates ViT vs. CNN contribution. B3. ViT-Base + hierarchical MSE only (no KL) — Same backbone as proposed, KL term removed. Isolates contribution of the KL term. B4. ViT-Base + Dirichlet NLL (Zoobot-style) — Direct comparison with the established Zoobot approach (Walmsley et al. 2022, MNRAS 509, 3966). Proposed model (not trained here — trained via src/train.py): ViT-Base + hierarchical KL+MSE → outputs/checkpoints/best_full_train.pt Consistency guarantee --------------------- All baselines use identical: - Random seed, data split, batch size, epochs, early stopping - AdamW optimiser, CosineAnnealingLR, gradient clipping - Image transforms and evaluation metric (compute_metrics on same test split) The ONLY differences between models are the backbone and/or loss function. Usage ----- cd ~/galaxy nohup python -m src.baselines --config configs/full_train.yaml \ > outputs/logs/baselines.log 2>&1 & echo "PID: $!" """ import argparse import logging import random import sys from pathlib import Path import numpy as np import pandas as pd import torch import timm import torch.nn as nn import torch.nn.functional as F import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from torch.amp import autocast, GradScaler from omegaconf import OmegaConf from tqdm import tqdm import wandb from src.dataset import build_dataloaders, QUESTION_GROUPS from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss from src.metrics import (compute_metrics, predictions_to_numpy, dirichlet_predictions_to_numpy, simplex_violation_rate) from src.model import build_model, build_dirichlet_model logging.basicConfig( format="%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout, ) log = logging.getLogger("baselines") QUESTION_LABELS = { "t01": "Smooth or features", "t02": "Edge-on disk", "t03": "Bar", "t04": "Spiral arms", "t05": "Bulge prominence", "t06": "Odd feature", "t07": "Roundedness", "t08": "Odd feature type", "t09": "Bulge shape", "t10": "Arms winding", "t11": "Arms number", } # ───────────────────────────────────────────────────────────── # Reproducibility # ───────────────────────────────────────────────────────────── def set_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # ───────────────────────────────────────────────────────────── # Early stopping (mirrors train.py exactly) # ───────────────────────────────────────────────────────────── class EarlyStopping: def __init__(self, patience, min_delta, checkpoint_path): self.patience = patience self.min_delta = min_delta self.checkpoint_path = checkpoint_path self.best_loss = float("inf") self.counter = 0 self.best_epoch = 0 def step(self, val_loss, model, epoch) -> bool: if val_loss < self.best_loss - self.min_delta: self.best_loss = val_loss self.counter = 0 self.best_epoch = epoch torch.save( {"epoch": epoch, "model_state": model.state_dict(), "val_loss": val_loss}, self.checkpoint_path, ) log.info(" [ckpt] saved val_loss=%.6f epoch=%d", val_loss, epoch) else: self.counter += 1 log.info(" [early_stop] %d/%d best=%.6f", self.counter, self.patience, self.best_loss) return self.counter >= self.patience def restore_best(self, model) -> float: ckpt = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True) model.load_state_dict(ckpt["model_state"]) log.info("Restored best weights epoch=%d val_loss=%.6f", ckpt["epoch"], ckpt["val_loss"]) return ckpt["val_loss"] # ───────────────────────────────────────────────────────────── # Baseline Model 1: ResNet-18 + independent MSE # ───────────────────────────────────────────────────────────── class ResNet18Baseline(nn.Module): """ ResNet-18 pretrained on ImageNet with a dropout + linear head. Used for both the sigmoid-MSE baseline and the KL+MSE baseline. """ def __init__(self, dropout: float = 0.3): super().__init__() self.backbone = timm.create_model( "resnet18", pretrained=True, num_classes=0 ) self.head = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(self.backbone.num_features, 37), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.head(self.backbone(x)) class IndependentMSELoss(nn.Module): """ Plain MSE over all 37 targets independently. No hierarchical weighting, no KL divergence. Sigmoid applied to predictions before MSE to constrain range [0,1]. Note: predictions do NOT sum to 1 per question group by construction. This is documented and the simplex_violation_rate metric quantifies this invalidity to allow fair comparison with the proposed method. """ def forward(self, predictions, targets, weights): pred_prob = torch.sigmoid(predictions) loss = F.mse_loss(pred_prob, targets) return loss, {"loss/total": loss.detach().item()} # ───────────────────────────────────────────────────────────── # Shared training loop # ───────────────────────────────────────────────────────────── def _train_epoch(model, loader, loss_fn, optimizer, scaler, device, cfg, epoch, label): model.train() total = 0.0 nb = 0 for images, targets, weights, _ in tqdm( loader, desc=f"{label} E{epoch}", leave=False ): images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) weights = weights.to(device, non_blocking=True) optimizer.zero_grad(set_to_none=True) with autocast("cuda", enabled=cfg.training.mixed_precision): logits = model(images) loss, _ = loss_fn(logits, targets, weights) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip) scaler.step(optimizer) scaler.update() total += loss.item() nb += 1 return total / nb def _train_epoch_dirichlet(model, loader, loss_fn, optimizer, scaler, device, cfg, epoch, label): """Training epoch for Dirichlet model (outputs alpha, not logits).""" model.train() total = 0.0 nb = 0 for images, targets, weights, _ in tqdm( loader, desc=f"{label} E{epoch}", leave=False ): images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) weights = weights.to(device, non_blocking=True) optimizer.zero_grad(set_to_none=True) with autocast("cuda", enabled=cfg.training.mixed_precision): alpha = model(images) loss, _ = loss_fn(alpha, targets, weights) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip) scaler.step(optimizer) scaler.update() total += loss.item() nb += 1 return total / nb def _val_epoch(model, loader, loss_fn, device, cfg, epoch, label, use_sigmoid=False): model.eval() total = 0.0 nb = 0 all_preds, all_targets, all_weights = [], [], [] with torch.no_grad(): for images, targets, weights, _ in tqdm( loader, desc=f"{label} Val E{epoch}", leave=False ): images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) weights = weights.to(device, non_blocking=True) with autocast("cuda", enabled=cfg.training.mixed_precision): logits = model(images) loss, _ = loss_fn(logits, targets, weights) total += loss.item() nb += 1 if use_sigmoid: pred_prob = torch.sigmoid(logits).detach().cpu().numpy() else: pred_cpu = logits.detach().cpu().clone() for q, (s, e) in QUESTION_GROUPS.items(): pred_cpu[:, s:e] = torch.softmax(pred_cpu[:, s:e], dim=-1) pred_prob = pred_cpu.numpy() all_preds.append(pred_prob) all_targets.append(targets.detach().cpu().numpy()) all_weights.append(weights.detach().cpu().numpy()) all_preds = np.concatenate(all_preds) all_targets = np.concatenate(all_targets) all_weights = np.concatenate(all_weights) metrics = compute_metrics(all_preds, all_targets, all_weights) return total / nb, metrics def _val_epoch_dirichlet(model, loader, loss_fn, device, cfg, epoch, label): model.eval() total = 0.0 nb = 0 all_preds, all_targets, all_weights = [], [], [] with torch.no_grad(): for images, targets, weights, _ in tqdm( loader, desc=f"{label} Val E{epoch}", leave=False ): images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) weights = weights.to(device, non_blocking=True) with autocast("cuda", enabled=cfg.training.mixed_precision): alpha = model(images) loss, _ = loss_fn(alpha, targets, weights) total += loss.item() nb += 1 p, t, w = dirichlet_predictions_to_numpy(alpha, targets, weights) all_preds.append(p) all_targets.append(t) all_weights.append(w) all_preds = np.concatenate(all_preds) all_targets = np.concatenate(all_targets) all_weights = np.concatenate(all_weights) metrics = compute_metrics(all_preds, all_targets, all_weights) return total / nb, metrics # ───────────────────────────────────────────────────────────── # Generic train_and_evaluate (non-Dirichlet) # ───────────────────────────────────────────────────────────── def train_and_evaluate( model, loss_fn, cfg, device, label, checkpoint_path, use_layerwise_lr=True, use_sigmoid=False, ): """ Full training loop consistent with train.py. Returns (test_metrics, best_val_loss, best_epoch, history). If checkpoint exists, loads it and skips training. """ # Check if checkpoint exists - if so, skip training if Path(checkpoint_path).exists(): log.info("%s: checkpoint found - loading and skipping training", label) ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True) model.load_state_dict(ckpt["model_state"]) best_epoch = ckpt.get("epoch", 0) best_val = ckpt.get("val_loss", float("inf")) log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val) # Evaluate on test set _, _, test_loader = build_dataloaders(cfg) _, test_metrics = _val_epoch( model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid ) return test_metrics, best_val, best_epoch, [] train_loader, val_loader, test_loader = build_dataloaders(cfg) if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"): optimizer = torch.optim.AdamW( [ {"params": model.backbone.parameters(), "lr": cfg.training.learning_rate * 0.1}, {"params": model.head.parameters(), "lr": cfg.training.learning_rate}, ], weight_decay=cfg.training.weight_decay, ) log.info("%s: layer-wise lr — backbone=%.1e head=%.1e", label, cfg.training.learning_rate * 0.1, cfg.training.learning_rate) else: optimizer = torch.optim.AdamW( model.parameters(), lr=cfg.training.learning_rate, weight_decay=cfg.training.weight_decay, ) log.info("%s: single lr=%.1e", label, cfg.training.learning_rate) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min ) scaler = GradScaler("cuda") early_stop = EarlyStopping( patience=cfg.early_stopping.patience, min_delta=cfg.early_stopping.min_delta, checkpoint_path=checkpoint_path, ) wandb.init( project=cfg.wandb.project, name=label, config={ "model": label, "backbone": "resnet18" if "ResNet" in label else "vit_base_patch16_224", "batch_size": cfg.training.batch_size, "lr": cfg.training.learning_rate, "epochs": cfg.training.epochs, "seed": cfg.seed, "lambda_kl": cfg.loss.lambda_kl, "lambda_mse": cfg.loss.lambda_mse, }, reinit=True, ) history = [] for epoch in range(1, cfg.training.epochs + 1): train_loss = _train_epoch( model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label ) val_loss, val_metrics = _val_epoch( model, val_loader, loss_fn, device, cfg, epoch, label, use_sigmoid=use_sigmoid ) scheduler.step() lr = scheduler.get_last_lr()[0] val_mae = val_metrics.get("mae/weighted_avg", 0) log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e", label, epoch, train_loss, val_loss, val_mae, lr) history.append({ "epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, "val_mae": val_mae, }) wandb.log({ "train_loss": train_loss, "val_loss": val_loss, "val_mae": val_mae, "lr": lr, }, step=epoch) if early_stop.step(val_loss, model, epoch): log.info("%s: early stopping at epoch %d best=%d", label, epoch, early_stop.best_epoch) break best_val = early_stop.restore_best(model) wandb.finish() log.info("%s: evaluating on test set...", label) _, test_metrics = _val_epoch( model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid ) return test_metrics, best_val, early_stop.best_epoch, history # ───────────────────────────────────────────────────────────── # Dirichlet train_and_evaluate # ───────────────────────────────────────────────────────────── def train_and_evaluate_dirichlet(model, loss_fn, cfg, device, label, checkpoint_path): """Training loop for Dirichlet model. Skips training if checkpoint exists.""" # Check if checkpoint exists - if so, skip training if Path(checkpoint_path).exists(): log.info("%s: checkpoint found - loading and skipping training", label) ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True) model.load_state_dict(ckpt["model_state"]) best_epoch = ckpt.get("epoch", 0) best_val = ckpt.get("val_loss", float("inf")) log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val) # Evaluate on test set _, _, test_loader = build_dataloaders(cfg) _, test_metrics = _val_epoch_dirichlet( model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test" ) return test_metrics, best_val, best_epoch, [] train_loader, val_loader, test_loader = build_dataloaders(cfg) optimizer = torch.optim.AdamW( [ {"params": model.backbone.parameters(), "lr": cfg.training.learning_rate * 0.1}, {"params": model.head.parameters(), "lr": cfg.training.learning_rate}, ], weight_decay=cfg.training.weight_decay, ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min ) scaler = GradScaler("cuda") early_stop = EarlyStopping( patience=cfg.early_stopping.patience, min_delta=cfg.early_stopping.min_delta, checkpoint_path=checkpoint_path, ) wandb.init( project=cfg.wandb.project, name=label, config={"model": label, "loss": "DirichletNLL", "seed": cfg.seed, "epochs": cfg.training.epochs}, reinit=True, ) history = [] for epoch in range(1, cfg.training.epochs + 1): train_loss = _train_epoch_dirichlet( model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label ) val_loss, val_metrics = _val_epoch_dirichlet( model, val_loader, loss_fn, device, cfg, epoch, label ) scheduler.step() lr = scheduler.get_last_lr()[0] val_mae = val_metrics.get("mae/weighted_avg", 0) log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e", label, epoch, train_loss, val_loss, val_mae, lr) history.append({ "epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, "val_mae": val_mae, }) wandb.log({ "train_loss": train_loss, "val_loss": val_loss, "val_mae": val_mae, "lr": lr, }, step=epoch) if early_stop.step(val_loss, model, epoch): log.info("%s: early stopping at epoch %d", label, epoch) break best_val = early_stop.restore_best(model) wandb.finish() log.info("%s: evaluating on test set...", label) _, test_metrics = _val_epoch_dirichlet( model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test" ) return test_metrics, best_val, early_stop.best_epoch, history # ───────────────────────────────────────────────────────────── # Figures # ───────────────────────────────────────────────────────────── def _save_comparison_figures(all_results, all_histories, save_dir): """ Saves: 1. Per-question MAE + RMSE bar chart 2. Validation MAE learning curves 3. Simplex violation table for sigmoid baseline All figure names follow IEEE journal conventions. """ q_names = list(QUESTION_GROUPS.keys()) n_models = len(all_results) x = np.arange(len(q_names)) width = 0.80 / n_models palette = ["#c0392b", "#e67e22", "#2980b9", "#27ae60", "#8e44ad"] # ── Figure 1: Per-question MAE and RMSE ─────────────────── fig, axes = plt.subplots(1, 2, figsize=(16, 6)) for metric, ax, ylabel in [ ("mae", axes[0], "Mean Absolute Error (MAE)"), ("rmse", axes[1], "Root Mean Squared Error (RMSE)"), ]: for i, (row_d, color) in enumerate(zip(all_results, palette)): vals = [row_d.get(f"{metric}_{q}", np.nan) for q in q_names] ax.bar(x + i * width, vals, width, label=row_d["model"], color=color, alpha=0.85, edgecolor="white", linewidth=0.5) ax.set_xticks(x + width * (n_models - 1) / 2) ax.set_xticklabels( [f"{q}\n({QUESTION_LABELS[q][:10]})" for q in q_names], rotation=45, ha="right", fontsize=7, ) ax.set_ylabel(ylabel, fontsize=11) ax.set_title(f"Per-question {metric.upper()} — baseline comparison", fontsize=11) ax.legend(fontsize=7, loc="upper right") ax.grid(True, alpha=0.3, axis="y") ax.set_axisbelow(True) plt.suptitle( "Baseline comparison — GZ2 hierarchical probabilistic regression\n" "Full 239,267-sample dataset, identical seed/split/protocol", fontsize=12, y=1.02, ) plt.tight_layout() fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.pdf", dpi=300, bbox_inches="tight") fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.png", dpi=300, bbox_inches="tight") plt.close(fig) log.info("Saved: fig_baseline_comparison_mae_rmse") # ── Figure 2: Validation MAE learning curves ─────────────── fig2, ax2 = plt.subplots(figsize=(10, 5)) styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1))] markers = ["o", "s", "^", "D", "v"] for (name, hist), ls, color, mk in zip( all_histories.items(), styles, palette, markers ): epochs_h = [h["epoch"] for h in hist] val_maes = [h["val_mae"] for h in hist] ax2.plot(epochs_h, val_maes, linestyle=ls, color=color, linewidth=1.8, label=name, marker=mk, markersize=3, markevery=5) ax2.set_xlabel("Epoch", fontsize=11) ax2.set_ylabel("Validation MAE (weighted average)", fontsize=11) ax2.set_title("Validation MAE during training — all baseline models", fontsize=11) ax2.legend(fontsize=9) ax2.grid(True, alpha=0.3) plt.tight_layout() fig2.savefig(save_dir / "fig_baseline_val_mae_curves.pdf", dpi=300, bbox_inches="tight") fig2.savefig(save_dir / "fig_baseline_val_mae_curves.png", dpi=300, bbox_inches="tight") plt.close(fig2) log.info("Saved: fig_baseline_val_mae_curves") # ───────────────────────────────────────────────────────────── # Main # ───────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", required=True) args = parser.parse_args() base_cfg = OmegaConf.load("configs/base.yaml") exp_cfg = OmegaConf.load(args.config) cfg = OmegaConf.merge(base_cfg, exp_cfg) set_seed(cfg.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") log.info("Device: %s Dataset: %s", device, "full 239k" if cfg.data.n_samples is None else f"{cfg.data.n_samples:,}") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True save_dir = Path(cfg.outputs.figures_dir) / "comparison" ckpt_dir = Path(cfg.outputs.checkpoint_dir) save_dir.mkdir(parents=True, exist_ok=True) all_results = [] all_histories = {} # ─── B1: ResNet-18 + independent MSE (sigmoid) ──────────── log.info("=" * 60) log.info("B1: ResNet-18 + independent MSE (sigmoid, no hierarchy)") log.info("=" * 60) set_seed(cfg.seed) rn_mse_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device) rn_mse_loss = IndependentMSELoss() log.info("ResNet-18 params: %s", f"{sum(p.numel() for p in rn_mse_model.parameters()):,}") rn_mse_metrics, rn_mse_val, rn_mse_epoch, rn_mse_hist = train_and_evaluate( rn_mse_model, rn_mse_loss, cfg, device, label = "B1-ResNet18-MSE", checkpoint_path = str(ckpt_dir / "baseline_resnet18_mse.pt"), use_layerwise_lr = False, use_sigmoid = True, ) # Simplex violation for this baseline _, _, test_loader_tmp = build_dataloaders(cfg) rn_mse_model.eval() tmp_preds = [] with torch.no_grad(): for images, _, _, _ in test_loader_tmp: images = images.to(device, non_blocking=True) logits = rn_mse_model(images) tmp_preds.append(torch.sigmoid(logits).cpu().numpy()) tmp_preds = np.concatenate(tmp_preds) svr = simplex_violation_rate(tmp_preds, tolerance=0.02) log.info("B1 simplex violation rate (mean): %.4f", svr["mean"]) row = { "model": "ResNet-18 + MSE (sigmoid, no hierarchy)", "backbone": "ResNet-18", "loss": "Independent MSE", "hierarchy": "None", "best_epoch": rn_mse_epoch, "best_val_loss": round(rn_mse_val, 5), "mae_weighted" : round(rn_mse_metrics["mae/weighted_avg"], 5), "rmse_weighted": round(rn_mse_metrics["rmse/weighted_avg"], 5), "simplex_violation_mean": round(svr["mean"], 4), } for q in QUESTION_GROUPS: row[f"mae_{q}"] = round(rn_mse_metrics[f"mae/{q}"], 5) row[f"rmse_{q}"] = round(rn_mse_metrics[f"rmse/{q}"], 5) all_results.append(row) all_histories["ResNet-18 + MSE (sigmoid)"] = rn_mse_hist log.info("B1 done: MAE=%.5f RMSE=%.5f SimplexViol=%.4f", rn_mse_metrics["mae/weighted_avg"], rn_mse_metrics["rmse/weighted_avg"], svr["mean"]) # ─── B2: ResNet-18 + hierarchical KL+MSE ────────────────── log.info("=" * 60) log.info("B2: ResNet-18 + hierarchical KL+MSE (same loss as proposed)") log.info("=" * 60) set_seed(cfg.seed) rn_kl_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device) rn_kl_loss = HierarchicalLoss(cfg) rn_kl_metrics, rn_kl_val, rn_kl_epoch, rn_kl_hist = train_and_evaluate( rn_kl_model, rn_kl_loss, cfg, device, label = "B2-ResNet18-KL+MSE", checkpoint_path = str(ckpt_dir / "baseline_resnet18_klmse.pt"), use_layerwise_lr = False, use_sigmoid = False, ) row = { "model": "ResNet-18 + hierarchical KL+MSE", "backbone": "ResNet-18", "loss": "Hierarchical KL+MSE (λ=0.5)", "hierarchy": "Full (weights + KL)", "best_epoch": rn_kl_epoch, "best_val_loss": round(rn_kl_val, 5), "mae_weighted" : round(rn_kl_metrics["mae/weighted_avg"], 5), "rmse_weighted": round(rn_kl_metrics["rmse/weighted_avg"], 5), "simplex_violation_mean": 0.0, # softmax guarantees validity } for q in QUESTION_GROUPS: row[f"mae_{q}"] = round(rn_kl_metrics[f"mae/{q}"], 5) row[f"rmse_{q}"] = round(rn_kl_metrics[f"rmse/{q}"], 5) all_results.append(row) all_histories["ResNet-18 + KL+MSE"] = rn_kl_hist log.info("B2 done: MAE=%.5f RMSE=%.5f", rn_kl_metrics["mae/weighted_avg"], rn_kl_metrics["rmse/weighted_avg"]) # ─── B3: ViT-Base + hierarchical MSE only ───────────────── log.info("=" * 60) log.info("B3: ViT-Base + hierarchical MSE only (no KL term)") log.info("=" * 60) set_seed(cfg.seed) from omegaconf import OmegaConf as OC vit_mse_cfg = OC.merge(cfg, OC.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})) vit_mse_model = build_model(vit_mse_cfg).to(device) vit_mse_loss = MSEOnlyLoss(vit_mse_cfg) vit_mse_metrics, vit_mse_val, vit_mse_epoch, vit_mse_hist = train_and_evaluate( vit_mse_model, vit_mse_loss, vit_mse_cfg, device, label = "B3-ViT-MSE", checkpoint_path = str(ckpt_dir / "baseline_vit_mse.pt"), use_layerwise_lr = True, use_sigmoid = False, ) row = { "model": "ViT-Base + hierarchical MSE (no KL)", "backbone": "ViT-Base/16", "loss": "Hierarchical MSE (λ_KL=0)", "hierarchy": "Weights only", "best_epoch": vit_mse_epoch, "best_val_loss": round(vit_mse_val, 5), "mae_weighted" : round(vit_mse_metrics["mae/weighted_avg"], 5), "rmse_weighted": round(vit_mse_metrics["rmse/weighted_avg"], 5), "simplex_violation_mean": 0.0, } for q in QUESTION_GROUPS: row[f"mae_{q}"] = round(vit_mse_metrics[f"mae/{q}"], 5) row[f"rmse_{q}"] = round(vit_mse_metrics[f"rmse/{q}"], 5) all_results.append(row) all_histories["ViT-Base + MSE only"] = vit_mse_hist log.info("B3 done: MAE=%.5f RMSE=%.5f", vit_mse_metrics["mae/weighted_avg"], vit_mse_metrics["rmse/weighted_avg"]) # ─── B4: ViT-Base + Dirichlet NLL (Zoobot-style) ────────── log.info("=" * 60) log.info("B4: ViT-Base + Dirichlet NLL (Walmsley et al. 2022)") log.info("=" * 60) set_seed(cfg.seed) vit_dir_model = build_dirichlet_model(cfg).to(device) vit_dir_loss = DirichletLoss(cfg) vit_dir_metrics, vit_dir_val, vit_dir_epoch, vit_dir_hist = train_and_evaluate_dirichlet( vit_dir_model, vit_dir_loss, cfg, device, label = "B4-ViT-Dirichlet", checkpoint_path = str(ckpt_dir / "baseline_vit_dirichlet.pt"), ) row = { "model": "ViT-Base + Dirichlet NLL (Zoobot-style)", "backbone": "ViT-Base/16", "loss": "Dirichlet NLL", "hierarchy": "Full (weights + Dirichlet)", "best_epoch": vit_dir_epoch, "best_val_loss": round(vit_dir_val, 5), "mae_weighted" : round(vit_dir_metrics["mae/weighted_avg"], 5), "rmse_weighted": round(vit_dir_metrics["rmse/weighted_avg"], 5), "simplex_violation_mean": 0.0, } for q in QUESTION_GROUPS: row[f"mae_{q}"] = round(vit_dir_metrics[f"mae/{q}"], 5) row[f"rmse_{q}"] = round(vit_dir_metrics[f"rmse/{q}"], 5) all_results.append(row) all_histories["ViT-Base + Dirichlet"] = vit_dir_hist log.info("B4 done: MAE=%.5f RMSE=%.5f", vit_dir_metrics["mae/weighted_avg"], vit_dir_metrics["rmse/weighted_avg"]) # ─── Proposed: load existing checkpoint for final table ──── proposed_ckpt = ckpt_dir / "best_full_train.pt" if proposed_ckpt.exists(): log.info("=" * 60) log.info("PROPOSED: Loading ViT-Base + hierarchical KL+MSE") log.info("=" * 60) proposed_model = build_model(cfg).to(device) proposed_model.load_state_dict( torch.load(proposed_ckpt, map_location="cpu", weights_only=True)["model_state"] ) _, _, test_loader_p = build_dataloaders(cfg) _, proposed_metrics = _val_epoch( proposed_model, test_loader_p, HierarchicalLoss(cfg), device, cfg, epoch=0, label="Proposed-test", use_sigmoid=False ) ckpt_info = torch.load(proposed_ckpt, map_location="cpu", weights_only=True) row = { "model": "ViT-Base + hierarchical KL+MSE (proposed)", "backbone": "ViT-Base/16", "loss": "Hierarchical KL+MSE (λ=0.5)", "hierarchy": "Full (weights + KL)", "best_epoch": ckpt_info["epoch"], "best_val_loss": round(ckpt_info["val_loss"], 5), "mae_weighted" : round(proposed_metrics["mae/weighted_avg"], 5), "rmse_weighted": round(proposed_metrics["rmse/weighted_avg"], 5), "simplex_violation_mean": 0.0, } for q in QUESTION_GROUPS: row[f"mae_{q}"] = round(proposed_metrics[f"mae/{q}"], 5) row[f"rmse_{q}"] = round(proposed_metrics[f"rmse/{q}"], 5) all_results.append(row) log.info("Proposed: MAE=%.5f RMSE=%.5f", proposed_metrics["mae/weighted_avg"], proposed_metrics["rmse/weighted_avg"]) # ─── Save results ────────────────────────────────────────── df = pd.DataFrame(all_results) df.to_csv(save_dir / "table_baseline_comparison.csv", index=False) summary_cols = ["model", "loss", "hierarchy", "best_epoch", "best_val_loss", "mae_weighted", "rmse_weighted", "simplex_violation_mean"] summary = df[[c for c in summary_cols if c in df.columns]].copy() summary.to_csv(save_dir / "table_baseline_summary.csv", index=False) print() print("=" * 80) print("BASELINE COMPARISON — FINAL RESULTS") print("=" * 80) print(summary.to_string(index=False)) print() # ─── Figures ─────────────────────────────────────────────── _save_comparison_figures(all_results, all_histories, save_dir) log.info("All baseline outputs saved to: %s", save_dir) if __name__ == "__main__": main()