""" src/evaluate_full.py -------------------- Full evaluation of all trained models on the held-out test set. Generates all paper figures and tables: Tables ------ table_metrics_proposed.csv — MAE / RMSE / bias / ECE for proposed model table_reached_branch_mae.csv — reached-branch MAE across all 5 models table_simplex_violation.csv — simplex validity for sigmoid baseline Figures (PDF + PNG, IEEE naming convention) ------------------------------------------- fig_scatter_predicted_vs_true.pdf — predicted vs true vote fractions (proposed) fig_calibration_reliability.pdf — reliability diagrams, all models fig_ece_comparison.pdf — ECE bar chart, all models fig_attention_rollout_gallery.pdf — full 12-layer attention rollout gallery fig_attention_entropy_depth.pdf — CLS attention entropy vs. layer depth Usage ----- cd ~/galaxy nohup python -m src.evaluate_full --config configs/full_train.yaml \ > outputs/logs/evaluate.log 2>&1 & echo "PID: $!" """ import argparse import logging import sys from pathlib import Path import numpy as np import pandas as pd import torch import torch.nn.functional as F import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from torch.amp import autocast from omegaconf import OmegaConf from tqdm import tqdm from src.dataset import build_dataloaders, QUESTION_GROUPS from src.model import build_model, build_dirichlet_model from src.metrics import (compute_metrics, predictions_to_numpy, compute_reached_branch_mae_table, dirichlet_predictions_to_numpy, simplex_violation_rate, _compute_ece) from src.attention_viz import plot_attention_grid, plot_attention_entropy from src.baselines import ResNet18Baseline logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout, ) log = logging.getLogger("evaluate_full") # ── Global matplotlib style ──────────────────────────────────────────────────── plt.rcParams.update({ "figure.dpi" : 150, "savefig.dpi" : 300, 
"font.family" : "serif", "font.size" : 11, "axes.titlesize" : 11, "axes.labelsize" : 11, "xtick.labelsize" : 9, "ytick.labelsize" : 9, "legend.fontsize" : 9, "figure.facecolor" : "white", "axes.facecolor" : "white", "axes.grid" : True, "grid.alpha" : 0.3, "pdf.fonttype" : 42, # editable text in PDF "ps.fonttype" : 42, }) QUESTION_LABELS = { "t01": "Smooth or features", "t02": "Edge-on disk", "t03": "Bar", "t04": "Spiral arms", "t05": "Bulge prominence", "t06": "Odd feature", "t07": "Roundedness", "t08": "Odd feature type", "t09": "Bulge shape", "t10": "Arms winding", "t11": "Arms number", } # Consistent colours and line styles for all models across all figures MODEL_COLORS = { "ResNet-18 + MSE (sigmoid)" : "#c0392b", "ResNet-18 + KL+MSE" : "#e67e22", "ViT-Base + MSE only" : "#2980b9", "ViT-Base + KL+MSE (proposed)" : "#27ae60", "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad", } MODEL_STYLES = { "ResNet-18 + MSE (sigmoid)" : "-", "ResNet-18 + KL+MSE" : "-.", "ViT-Base + MSE only" : "--", "ViT-Base + KL+MSE (proposed)" : "-", "ViT-Base + Dirichlet (Zoobot-style)": ":", } # ───────────────────────────────────────────────────────────── # Inference helpers # ───────────────────────────────────────────────────────────── def _infer_vit(model, loader, device, cfg, collect_attn=True, n_attn=16): model.eval() all_preds, all_targets, all_weights = [], [], [] attn_images, all_layer_attns, attn_ids = [], [], [] attn_done = False with torch.no_grad(): for images, targets, weights, image_ids in tqdm(loader, desc="ViT inference"): images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) weights = weights.to(device, non_blocking=True) with autocast("cuda", enabled=cfg.training.mixed_precision): logits = model(images) p, t, w = predictions_to_numpy(logits, targets, weights) all_preds.append(p) all_targets.append(t) all_weights.append(w) if collect_attn and not attn_done: layers = model.get_all_attention_weights() if layers is not None: n = 
min(n_attn, images.shape[0]) attn_images.append(images[:n].cpu()) all_layer_attns.append([l[:n].cpu() for l in layers]) attn_ids.extend([int(i) for i in image_ids[:n]]) if len(attn_ids) >= n_attn: attn_done = True preds = np.concatenate(all_preds) targets = np.concatenate(all_targets) weights = np.concatenate(all_weights) attn_imgs_t = torch.cat(attn_images, dim=0)[:n_attn] if attn_images else None merged_layers = None if all_layer_attns: merged_layers = [ torch.cat([b[li] for b in all_layer_attns], dim=0)[:n_attn] for li in range(len(all_layer_attns[0])) ] return preds, targets, weights, attn_imgs_t, merged_layers, attn_ids def _infer_resnet(model, loader, device, cfg, use_sigmoid: bool): model.eval() all_preds, all_targets, all_weights = [], [], [] with torch.no_grad(): for images, targets, weights, _ in tqdm(loader, desc="ResNet inference"): images = images.to(device, non_blocking=True) with autocast("cuda", enabled=cfg.training.mixed_precision): logits = model(images) if use_sigmoid: pred = torch.sigmoid(logits).cpu().numpy() else: pred = logits.detach().cpu().clone() for q, (s, e) in QUESTION_GROUPS.items(): pred[:, s:e] = F.softmax(pred[:, s:e], dim=-1) pred = pred.numpy() all_preds.append(pred) all_targets.append(targets.numpy()) all_weights.append(weights.numpy()) return (np.concatenate(all_preds), np.concatenate(all_targets), np.concatenate(all_weights)) def _infer_dirichlet(model, loader, device, cfg): model.eval() all_preds, all_targets, all_weights = [], [], [] with torch.no_grad(): for images, targets, weights, _ in tqdm(loader, desc="Dirichlet inference"): images = images.to(device, non_blocking=True) with autocast("cuda", enabled=cfg.training.mixed_precision): alpha = model(images) p, t, w = dirichlet_predictions_to_numpy(alpha, targets, weights) all_preds.append(p) all_targets.append(t) all_weights.append(w) return (np.concatenate(all_preds), np.concatenate(all_targets), np.concatenate(all_weights)) # 
# ─────────────────────────────────────────────────────────────
# Figure 1: Predicted vs true scatter (proposed model)
# ─────────────────────────────────────────────────────────────
def fig_scatter_predicted_vs_true(preds, targets, weights, save_dir):
    """One scatter panel per question: predicted vs true vote fraction.

    Only samples whose branch weight is >= 0.05 (reached branches) are shown.
    Skips regeneration when both PDF and PNG already exist.
    """
    path_pdf = save_dir / "fig_scatter_predicted_vs_true.pdf"
    path_png = save_dir / "fig_scatter_predicted_vs_true.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_scatter_predicted_vs_true")
        return
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        mask = weights[:, q_idx] >= 0.05
        pq = preds[mask, start:end].flatten()
        tq = targets[mask, start:end].flatten()
        ax.scatter(tq, pq, alpha=0.06, s=1, color="#2563eb", rasterized=True)
        ax.plot([0, 1], [0, 1], "r--", linewidth=1, alpha=0.8)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xlabel("True vote fraction")
        ax.set_ylabel("Predicted vote fraction")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {mask.sum():,} (w ≥ 0.05)",
            fontsize=9,
        )
        ax.set_aspect("equal")
        # FIX: guard the empty-branch case — np.mean of an empty array
        # raises a RuntimeWarning and would print "MAE = nan" on the panel.
        if pq.size:
            mae = np.abs(pq - tq).mean()
            ax.text(0.05, 0.92, f"MAE = {mae:.3f}", transform=ax.transAxes,
                    fontsize=8,
                    bbox=dict(boxstyle="round,pad=0.2", facecolor="white",
                              edgecolor="grey", alpha=0.85))
    axes[-1].axis("off")  # 11 questions in a 3×4 grid — blank the last panel
    plt.suptitle(
        "Predicted vs. true vote fractions — reached branches (w ≥ 0.05)\n"
        "ViT-Base/16 + hierarchical KL+MSE (proposed model, test set)",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_scatter_predicted_vs_true")


# ─────────────────────────────────────────────────────────────
# Figure 2: Calibration reliability diagrams
# ─────────────────────────────────────────────────────────────
def fig_calibration_reliability(model_results, save_dir, n_bins=15):
    """Reliability diagrams (mean predicted vs mean true) for all models.

    Uses adaptive equal-frequency bins so the plot matches the ECE
    computation. Questions with fewer than 50 reached samples are skipped.
    """
    path_pdf = save_dir / "fig_calibration_reliability.pdf"
    path_png = save_dir / "fig_calibration_reliability.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_calibration_reliability")
        return
    # Show 8 representative questions (skip t02 — bimodal, shown separately)
    q_show = ["t01", "t03", "t04", "t06", "t07", "t09", "t10", "t11"]
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()
    for ax_idx, q_name in enumerate(q_show):
        ax = axes[ax_idx]
        start, end = QUESTION_GROUPS[q_name]
        q_idx = list(QUESTION_GROUPS.keys()).index(q_name)
        for model_name, (preds, targets, weights) in model_results.items():
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                continue
            pf = preds[mask, start:end].flatten()
            tf = targets[mask, start:end].flatten()
            # Adaptive bins (equal-frequency) — consistent with ECE computation
            percentiles = np.linspace(0, 100, n_bins + 1)
            bin_edges = np.unique(np.percentile(pf, percentiles))
            if len(bin_edges) < 2:
                continue
            bin_ids = np.clip(
                np.digitize(pf, bin_edges[1:-1]), 0, len(bin_edges) - 2
            )
            mp = np.array([
                pf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            mt = np.array([
                tf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            valid = ~np.isnan(mp) & ~np.isnan(mt)
            ax.plot(
                mp[valid], mt[valid],
                MODEL_STYLES.get(model_name, "-"),
                color=MODEL_COLORS.get(model_name, "#888888"),
                linewidth=1.8, marker="o", markersize=3.5,
                label=model_name, alpha=0.9,
            )
        ax.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.5, label="Perfect")
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xlabel("Mean predicted", fontsize=8)
        ax.set_ylabel("Mean true", fontsize=8)
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.set_aspect("equal")
        if ax_idx == 0:
            ax.legend(fontsize=6.5, loc="upper left")
    plt.suptitle(
        "Calibration reliability diagrams — all models (test set)\n"
        "Reached branches only (w ≥ 0.05). Adaptive equal-frequency bins. "
        "Closer to diagonal = better calibrated.",
        fontsize=11,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_calibration_reliability")


# ─────────────────────────────────────────────────────────────
# Figure 3: ECE bar chart
# ─────────────────────────────────────────────────────────────
def fig_ece_comparison(model_results, save_dir):
    """Grouped bar chart of per-question ECE across all models.

    Also writes table_ece_comparison.csv with the raw values.
    """
    path_pdf = save_dir / "fig_ece_comparison.pdf"
    path_png = save_dir / "fig_ece_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_ece_comparison")
        return
    q_names = list(QUESTION_GROUPS.keys())
    ece_rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        row = {"model": model_name}
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                row[q_name] = float("nan")
            else:
                row[q_name] = _compute_ece(
                    preds[mask, start:end].flatten(),
                    targets[mask, start:end].flatten(),
                    n_bins=15,
                )
        row["mean_ece"] = float(np.nanmean([row[q] for q in q_names]))
        ece_rows.append(row)
    df_ece = pd.DataFrame(ece_rows)
    df_ece.to_csv(save_dir / "table_ece_comparison.csv", index=False)

    x = np.arange(len(q_names))
    width = 0.80 / len(model_results)
    palette = list(MODEL_COLORS.values())
    fig, ax = plt.subplots(figsize=(14, 5))
    for i, (model_name, _) in enumerate(model_results.items()):
        vals = [
            float(df_ece[df_ece["model"] == model_name][q].values[0])
            for q in q_names
        ]
        ax.bar(
            x + i * width, vals, width,
            label=model_name,
            color=MODEL_COLORS.get(model_name, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )
    ax.set_xticks(x + width * (len(model_results) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({QUESTION_LABELS[q][:12]})" for q in q_names],
        rotation=30, ha="right", fontsize=8,
    )
    ax.set_ylabel("Expected Calibration Error (ECE)", fontsize=11)
    ax.set_title(
        "Expected Calibration Error — all models (test set)\n"
        "Reached branches (w ≥ 0.05). Adaptive equal-frequency binning. "
        "Lower is better.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_ece_comparison")


# ─────────────────────────────────────────────────────────────
# Figure 4: Attention rollout gallery
# ─────────────────────────────────────────────────────────────
def fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir):
    """Attention-rollout gallery (normal + 600 dpi HQ variants)."""
    if attn_imgs is None or all_layers is None:
        log.warning("No attention data — skipping gallery.")
        return
    path_pdf = save_dir / "fig_attention_rollout_gallery.pdf"
    path_png = save_dir / "fig_attention_rollout_gallery.png"
    if not path_pdf.exists():
        fig = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            save_path=str(path_png),
            n_cols=4, rollout_mode="full",
        )
        fig.savefig(path_pdf, dpi=300, bbox_inches="tight", facecolor="black")
        plt.close(fig)
        log.info("Saved: fig_attention_rollout_gallery")
    # High-resolution PNG for journal submission
    path_hq = save_dir / "fig_attention_rollout_gallery_HQ.png"
    if not path_hq.exists():
        fig2 = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            n_cols=4, rollout_mode="full",
        )
        fig2.savefig(path_hq, dpi=600, bbox_inches="tight", facecolor="black")
        plt.close(fig2)
        log.info("Saved: fig_attention_rollout_gallery_HQ (600 dpi)")


# ─────────────────────────────────────────────────────────────
# Figure 5: Attention entropy vs. depth
# ─────────────────────────────────────────────────────────────
def fig_attention_entropy_depth(all_layers, save_dir):
    """CLS attention entropy as a function of transformer layer depth."""
    if all_layers is None:
        log.warning("No attention layers — skipping entropy plot.")
        return
    path_pdf = save_dir / "fig_attention_entropy_depth.pdf"
    path_png = save_dir / "fig_attention_entropy_depth.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_attention_entropy_depth")
        return
    fig = plot_attention_entropy(all_layers, save_path=str(path_png))
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_attention_entropy_depth")


# ─────────────────────────────────────────────────────────────
# Table: metrics for proposed model
# ─────────────────────────────────────────────────────────────
def table_metrics_proposed(preds, targets, weights, save_dir):
    """Write table_metrics_proposed.csv (per-question + weighted average).

    Returns the full metrics dict from ``compute_metrics``.
    """
    metrics = compute_metrics(preds, targets, weights)
    rows = []
    for q_name in QUESTION_GROUPS:
        rows.append({
            "question"   : q_name,
            "description": QUESTION_LABELS[q_name],
            "MAE"        : round(metrics[f"mae/{q_name}"], 5),
            "RMSE"       : round(metrics[f"rmse/{q_name}"], 5),
            "bias"       : round(metrics[f"bias/{q_name}"], 5),
            "ECE"        : round(metrics[f"ece/{q_name}"], 5),
        })
    rows.append({
        "question": "weighted_avg",
        "description": "Weighted average",
        "MAE" : round(metrics["mae/weighted_avg"], 5),
        "RMSE": round(metrics["rmse/weighted_avg"], 5),
        "bias": "",  # bias is not aggregated across questions
        "ECE" : round(metrics["ece/mean"], 5),
    })
    df = pd.DataFrame(rows)
    df.to_csv(save_dir / "table_metrics_proposed.csv", index=False)
    log.info("\n%s\n", df.to_string(index=False))
    return metrics


# ─────────────────────────────────────────────────────────────
# Table: simplex violation for sigmoid baseline
# ─────────────────────────────────────────────────────────────
def table_simplex_violation(model_results, save_dir):
    """For each model, report the fraction of test samples where
    per-question predictions do not sum to 1 ± 0.02.

    Expected: ~0 for softmax models, nonzero for sigmoid baseline.
    This table explains why the sigmoid baseline achieves lower raw
    per-answer MAE despite being scientifically invalid: unconstrained
    sigmoid outputs fit each marginal independently.
    """
    rows = []
    for model_name, (preds, _, _) in model_results.items():
        svr = simplex_violation_rate(preds, tolerance=0.02)
        row = {"model": model_name}
        row.update({q: round(svr[q], 4) for q in QUESTION_GROUPS})
        row["mean"] = round(svr["mean"], 4)
        rows.append(row)
    df = pd.DataFrame(rows)
    df.to_csv(save_dir / "table_simplex_violation.csv", index=False)
    log.info("Saved: table_simplex_violation.csv")
    log.info("\n%s\n", df[["model", "mean"]].to_string(index=False))
    return df


# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    """Load all checkpoints, run test-set inference, emit tables + figures."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "evaluation"
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    _, _, test_loader = build_dataloaders(cfg)

    # ── Load all models ────────────────────────────────────────
    log.info("Loading models from: %s", ckpt_dir)

    def _load(path, model):
        # weights_only=True: safe load, state dict only (no pickled code)
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        return model

    vit_proposed = _load(
        ckpt_dir / "best_full_train.pt", build_model(cfg)
    ).to(device)
    vit_mse = _load(
        ckpt_dir / "baseline_vit_mse.pt", build_model(cfg)
    ).to(device)
    rn_mse = _load(
        ckpt_dir / "baseline_resnet18_mse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)
    rn_kl = _load(
        ckpt_dir / "baseline_resnet18_klmse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)
    # The Dirichlet baseline is optional — evaluate it only if trained.
    vit_dirichlet = None
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dirichlet = _load(dp, build_dirichlet_model(cfg)).to(device)
        log.info("Loaded: ViT-Base + Dirichlet")

    # ── Run inference ──────────────────────────────────────────
    log.info("Running inference on test set...")
    (p_proposed, t_proposed, w_proposed,
     attn_imgs, all_layers, attn_ids) = _infer_vit(
        vit_proposed, test_loader, device, cfg, collect_attn=True, n_attn=16,
    )
    p_vit_mse, t_vit_mse, w_vit_mse = _infer_vit(
        vit_mse, test_loader, device, cfg, collect_attn=False
    )[:3]
    p_rn_mse, t_rn_mse, w_rn_mse = _infer_resnet(
        rn_mse, test_loader, device, cfg, use_sigmoid=True
    )
    p_rn_kl, t_rn_kl, w_rn_kl = _infer_resnet(
        rn_kl, test_loader, device, cfg, use_sigmoid=False
    )

    # Build model_results dict (order determines legend order in figures)
    model_results = {
        "ResNet-18 + MSE (sigmoid)"    : (p_rn_mse, t_rn_mse, w_rn_mse),
        "ResNet-18 + KL+MSE"           : (p_rn_kl, t_rn_kl, w_rn_kl),
        "ViT-Base + MSE only"          : (p_vit_mse, t_vit_mse, w_vit_mse),
        "ViT-Base + KL+MSE (proposed)" : (p_proposed, t_proposed, w_proposed),
    }
    if vit_dirichlet is not None:
        p_dir, t_dir, w_dir = _infer_dirichlet(
            vit_dirichlet, test_loader, device, cfg
        )
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (p_dir, t_dir, w_dir)

    # ── Tables ─────────────────────────────────────────────────
    log.info("Computing metrics...")
    table_metrics_proposed(p_proposed, t_proposed, w_proposed, save_dir)
    log.info("Computing reached-branch MAE table...")
    df_r = compute_reached_branch_mae_table(model_results)
    df_r.to_csv(save_dir / "table_reached_branch_mae.csv", index=False)
    log.info("Saved: table_reached_branch_mae.csv")
    log.info("Computing simplex violation table...")
    table_simplex_violation(model_results, save_dir)

    # ── Figures ────────────────────────────────────────────────
    log.info("Generating figures...")
    fig_scatter_predicted_vs_true(p_proposed, t_proposed, w_proposed, save_dir)
    fig_calibration_reliability(model_results, save_dir)
    fig_ece_comparison(model_results, save_dir)
    fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir)
    fig_attention_entropy_depth(all_layers, save_dir)

    log.info("=" * 60)
    log.info("ALL OUTPUTS SAVED TO: %s", save_dir)
    log.info("=" * 60)
    metrics = compute_metrics(p_proposed, t_proposed, w_proposed)
    log.info("Proposed model — test set results:")
    log.info("  Weighted MAE  = %.5f", metrics["mae/weighted_avg"])
    log.info("  Weighted RMSE = %.5f", metrics["rmse/weighted_avg"])
    log.info("  Mean ECE      = %.5f", metrics["ece/mean"])


if __name__ == "__main__":
    main()