Spaces:
Running
Running
| """ | |
| src/evaluate_full.py | |
| -------------------- | |
| Full evaluation of all trained models on the held-out test set. | |
| Generates all paper figures and tables: | |
| Tables | |
| ------ | |
| table_metrics_proposed.csv β MAE / RMSE / bias / ECE for proposed model | |
| table_reached_branch_mae.csv β reached-branch MAE across all 5 models | |
| table_simplex_violation.csv β simplex validity for sigmoid baseline | |
| Figures (PDF + PNG, IEEE naming convention) | |
| ------------------------------------------- | |
| fig_scatter_predicted_vs_true.pdf β predicted vs true vote fractions (proposed) | |
| fig_calibration_reliability.pdf β reliability diagrams, all models | |
| fig_ece_comparison.pdf β ECE bar chart, all models | |
| fig_attention_rollout_gallery.pdf β full 12-layer attention rollout gallery | |
| fig_attention_entropy_depth.pdf β CLS attention entropy vs. layer depth | |
| Usage | |
| ----- | |
| cd ~/galaxy | |
| nohup python -m src.evaluate_full --config configs/full_train.yaml \ | |
| > outputs/logs/evaluate.log 2>&1 & | |
| echo "PID: $!" | |
| """ | |
| import argparse | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from torch.amp import autocast | |
| from omegaconf import OmegaConf | |
| from tqdm import tqdm | |
| from src.dataset import build_dataloaders, QUESTION_GROUPS | |
| from src.model import build_model, build_dirichlet_model | |
| from src.metrics import (compute_metrics, predictions_to_numpy, | |
| compute_reached_branch_mae_table, | |
| dirichlet_predictions_to_numpy, | |
| simplex_violation_rate, _compute_ece) | |
| from src.attention_viz import plot_attention_grid, plot_attention_entropy | |
| from src.baselines import ResNet18Baseline | |
# Console logging: timestamped INFO-level lines to stdout, so progress is
# captured by the nohup redirect shown in the module docstring.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("evaluate_full")

# -- Global matplotlib style --------------------------------------------------
# Applied once at import time so every figure in this module is consistent.
plt.rcParams.update({
    "figure.dpi"       : 150,
    "savefig.dpi"      : 300,
    "font.family"      : "serif",
    "font.size"        : 11,
    "axes.titlesize"   : 11,
    "axes.labelsize"   : 11,
    "xtick.labelsize"  : 9,
    "ytick.labelsize"  : 9,
    "legend.fontsize"  : 9,
    "figure.facecolor" : "white",
    "axes.facecolor"   : "white",
    "axes.grid"        : True,
    "grid.alpha"       : 0.3,
    "pdf.fonttype"     : 42,  # editable text in PDF
    "ps.fonttype"      : 42,
})

# Human-readable captions for the 11 Galaxy Zoo decision-tree questions;
# keys must match the question names used by QUESTION_GROUPS.
QUESTION_LABELS = {
    "t01": "Smooth or features",
    "t02": "Edge-on disk",
    "t03": "Bar",
    "t04": "Spiral arms",
    "t05": "Bulge prominence",
    "t06": "Odd feature",
    "t07": "Roundedness",
    "t08": "Odd feature type",
    "t09": "Bulge shape",
    "t10": "Arms winding",
    "t11": "Arms number",
}

# Consistent colours and line styles for all models across all figures.
# Keys must match the display names used in main()'s model_results dict.
MODEL_COLORS = {
    "ResNet-18 + MSE (sigmoid)"          : "#c0392b",
    "ResNet-18 + KL+MSE"                 : "#e67e22",
    "ViT-Base + MSE only"                : "#2980b9",
    "ViT-Base + KL+MSE (proposed)"       : "#27ae60",
    "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
}
MODEL_STYLES = {
    "ResNet-18 + MSE (sigmoid)"          : "-",
    "ResNet-18 + KL+MSE"                 : "-.",
    "ViT-Base + MSE only"                : "--",
    "ViT-Base + KL+MSE (proposed)"       : "-",
    "ViT-Base + Dirichlet (Zoobot-style)": ":",
}
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Inference helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _infer_vit(model, loader, device, cfg,
               collect_attn=True, n_attn=16):
    """Run a ViT model over `loader`; optionally collect attention maps.

    Parameters
    ----------
    model : torch.nn.Module
        ViT exposing ``get_all_attention_weights()`` after a forward pass.
    loader : iterable
        Yields (images, targets, weights, image_ids) batches.
    device : torch.device
        Device the batches are moved to.
    cfg : config object
        ``cfg.training.mixed_precision`` toggles CUDA autocast.
    collect_attn : bool
        Whether to gather per-layer attention for the gallery figures.
    n_attn : int
        Maximum number of images to keep attention maps for.

    Returns
    -------
    tuple
        (preds, targets, weights, attn_images, layer_attns, attn_ids) where
        the first three are numpy arrays concatenated over the whole loader;
        attn_images is a CPU tensor of at most n_attn images (or None);
        layer_attns is a list with one CPU attention tensor per transformer
        layer for those same images (or None); attn_ids are their image ids.
    """
    model.eval()
    all_preds, all_targets, all_weights = [], [], []
    attn_images, all_layer_attns, attn_ids = [], [], []
    attn_done = False  # flipped once n_attn attention samples are collected
    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(loader, desc="ViT inference"):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)
            if collect_attn and not attn_done:
                # Attention cached by the model during the forward pass;
                # may be None if the model does not expose it.
                layers = model.get_all_attention_weights()
                if layers is not None:
                    n = min(n_attn, images.shape[0])
                    attn_images.append(images[:n].cpu())
                    all_layer_attns.append([l[:n].cpu() for l in layers])
                    attn_ids.extend([int(i) for i in image_ids[:n]])
                    if len(attn_ids) >= n_attn:
                        attn_done = True
    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_targets)
    weights = np.concatenate(all_weights)
    # Merge attention batches and trim to exactly n_attn samples.
    attn_imgs_t = torch.cat(attn_images, dim=0)[:n_attn] if attn_images else None
    merged_layers = None
    if all_layer_attns:
        merged_layers = [
            torch.cat([b[li] for b in all_layer_attns], dim=0)[:n_attn]
            for li in range(len(all_layer_attns[0]))
        ]
    return preds, targets, weights, attn_imgs_t, merged_layers, attn_ids
def _infer_resnet(model, loader, device, cfg, use_sigmoid: bool):
    """Run a ResNet baseline over `loader` and return numpy arrays.

    When ``use_sigmoid`` is True, raw logits are squashed independently per
    answer (the simplex-violating baseline); otherwise a per-question softmax
    is applied over each answer group defined by QUESTION_GROUPS.

    Returns (preds, targets, weights) concatenated over the whole loader.
    """
    model.eval()
    pred_chunks, target_chunks, weight_chunks = [], [], []

    def _to_fractions(raw):
        # Map raw logits to vote-fraction predictions as a numpy array.
        if use_sigmoid:
            return torch.sigmoid(raw).cpu().numpy()
        out = raw.detach().cpu().clone()
        for _q, (lo, hi) in QUESTION_GROUPS.items():
            out[:, lo:hi] = F.softmax(out[:, lo:hi], dim=-1)
        return out.numpy()

    with torch.no_grad():
        for batch_imgs, batch_targets, batch_weights, _ in tqdm(
                loader, desc="ResNet inference"):
            batch_imgs = batch_imgs.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                raw = model(batch_imgs)
            pred_chunks.append(_to_fractions(raw))
            target_chunks.append(batch_targets.numpy())
            weight_chunks.append(batch_weights.numpy())

    return (np.concatenate(pred_chunks),
            np.concatenate(target_chunks),
            np.concatenate(weight_chunks))
def _infer_dirichlet(model, loader, device, cfg):
    """Run the Dirichlet-head model over `loader`.

    The model outputs concentration parameters ``alpha``; conversion to
    vote-fraction predictions is delegated to
    ``dirichlet_predictions_to_numpy``. Returns (preds, targets, weights)
    concatenated over the whole loader.
    """
    model.eval()
    pred_chunks = []
    target_chunks = []
    weight_chunks = []
    with torch.no_grad():
        for batch_imgs, batch_targets, batch_weights, _ in tqdm(
                loader, desc="Dirichlet inference"):
            batch_imgs = batch_imgs.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(batch_imgs)
            p, t, w = dirichlet_predictions_to_numpy(
                alpha, batch_targets, batch_weights)
            pred_chunks.append(p)
            target_chunks.append(t)
            weight_chunks.append(w)
    return (np.concatenate(pred_chunks),
            np.concatenate(target_chunks),
            np.concatenate(weight_chunks))
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 1: Predicted vs true scatter (proposed model) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def fig_scatter_predicted_vs_true(preds, targets, weights, save_dir):
    """Figure 1: predicted vs. true vote fractions for the proposed model.

    One panel per question on a 3x4 grid (the unused 12th axis is hidden).
    Only samples whose question weight is >= 0.05 ("reached" branches) are
    plotted; each panel is annotated with the MAE over those samples.
    Regeneration is skipped when both output files already exist.

    Parameters
    ----------
    preds, targets, weights : np.ndarray
        Per-sample predictions, true fractions, and per-question weights.
    save_dir : pathlib.Path
        Directory receiving the PDF/PNG pair.
    """
    path_pdf = save_dir / "fig_scatter_predicted_vs_true.pdf"
    path_png = save_dir / "fig_scatter_predicted_vs_true.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_scatter_predicted_vs_true"); return
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        mask = weights[:, q_idx] >= 0.05
        pq = preds[mask, start:end].flatten()
        tq = targets[mask, start:end].flatten()
        ax.scatter(tq, pq, alpha=0.06, s=1, color="#2563eb", rasterized=True)
        ax.plot([0, 1], [0, 1], "r--", linewidth=1, alpha=0.8)
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
        ax.set_xlabel("True vote fraction")
        ax.set_ylabel("Predicted vote fraction")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {mask.sum():,} (w β₯ 0.05)",
            fontsize=9,
        )
        ax.set_aspect("equal")
        # FIX: np.abs([]).mean() on an empty selection emits a RuntimeWarning
        # and would annotate "MAE = nan" -- only annotate when samples exist.
        if pq.size:
            mae = np.abs(pq - tq).mean()
            ax.text(0.05, 0.92, f"MAE = {mae:.3f}",
                    transform=ax.transAxes, fontsize=8,
                    bbox=dict(boxstyle="round,pad=0.2", facecolor="white",
                              edgecolor="grey", alpha=0.85))
    axes[-1].axis("off")  # 11 questions on a 12-axis grid
    plt.suptitle(
        "Predicted vs. true vote fractions β reached branches (w β₯ 0.05)\n"
        "ViT-Base/16 + hierarchical KL+MSE (proposed model, test set)",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_scatter_predicted_vs_true")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 2: Calibration reliability diagrams | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def fig_calibration_reliability(model_results, save_dir, n_bins=15):
    """Figure 2: reliability diagrams (mean predicted vs. mean true) for all
    models on eight representative questions.

    For each (question, model) pair, predictions on reached branches
    (weight >= 0.05) are grouped into adaptive equal-frequency bins; the
    per-bin mean prediction is plotted against the per-bin mean true
    fraction. A perfectly calibrated model lies on the diagonal.

    Parameters
    ----------
    model_results : dict
        Maps model display name -> (preds, targets, weights) numpy arrays.
    save_dir : pathlib.Path
        Directory receiving the PDF/PNG pair.
    n_bins : int
        Requested number of equal-frequency bins per curve.
    """
    path_pdf = save_dir / "fig_calibration_reliability.pdf"
    path_png = save_dir / "fig_calibration_reliability.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_calibration_reliability"); return
    # Show 8 representative questions (skip t02 - bimodal, shown separately)
    q_show = ["t01", "t03", "t04", "t06", "t07", "t09", "t10", "t11"]
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()
    for ax_idx, q_name in enumerate(q_show):
        ax = axes[ax_idx]
        start, end = QUESTION_GROUPS[q_name]
        # Column index of this question in the per-question weight matrix.
        q_idx = list(QUESTION_GROUPS.keys()).index(q_name)
        for model_name, (preds, targets, weights) in model_results.items():
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                # Too few reached samples for a meaningful curve.
                continue
            pf = preds[mask, start:end].flatten()
            tf = targets[mask, start:end].flatten()
            # Adaptive bins (equal-frequency) - consistent with ECE computation.
            # np.unique collapses duplicate percentile edges (heavily repeated
            # prediction values), so fewer than n_bins bins may remain.
            percentiles = np.linspace(0, 100, n_bins + 1)
            bin_edges = np.unique(np.percentile(pf, percentiles))
            if len(bin_edges) < 2:
                continue
            # Digitize against inner edges only, then clip so every sample
            # lands in one of the len(bin_edges) - 1 bins.
            bin_ids = np.clip(
                np.digitize(pf, bin_edges[1:-1]), 0, len(bin_edges) - 2
            )
            mp = np.array([
                pf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            mt = np.array([
                tf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            valid = ~np.isnan(mp) & ~np.isnan(mt)
            ax.plot(
                mp[valid], mt[valid],
                MODEL_STYLES.get(model_name, "-"),
                color=MODEL_COLORS.get(model_name, "#888888"),
                linewidth=1.8, marker="o", markersize=3.5,
                label=model_name, alpha=0.9,
            )
        ax.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.5, label="Perfect")
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
        ax.set_xlabel("Mean predicted", fontsize=8)
        ax.set_ylabel("Mean true", fontsize=8)
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.set_aspect("equal")
        if ax_idx == 0:
            # Legend on the first panel only, to avoid clutter.
            ax.legend(fontsize=6.5, loc="upper left")
    plt.suptitle(
        "Calibration reliability diagrams β all models (test set)\n"
        "Reached branches only (w β₯ 0.05). Adaptive equal-frequency bins. "
        "Closer to diagonal = better calibrated.",
        fontsize=11,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_calibration_reliability")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 3: ECE bar chart | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def fig_ece_comparison(model_results, save_dir):
    """Figure 3 (and companion CSV): per-question ECE bar chart, all models.

    Computes ECE per question per model on reached branches (weight >= 0.05),
    writes table_ece_comparison.csv, and renders a grouped bar chart.
    Regeneration is skipped when both output files already exist.

    Parameters
    ----------
    model_results : dict
        Maps model display name -> (preds, targets, weights) numpy arrays.
    save_dir : pathlib.Path
        Directory receiving the CSV and the PDF/PNG pair.
    """
    path_pdf = save_dir / "fig_ece_comparison.pdf"
    path_png = save_dir / "fig_ece_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_ece_comparison"); return
    q_names = list(QUESTION_GROUPS.keys())
    ece_rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        row = {"model": model_name}
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                # Too few reached samples for a stable ECE estimate.
                row[q_name] = float("nan")
            else:
                row[q_name] = _compute_ece(
                    preds[mask, start:end].flatten(),
                    targets[mask, start:end].flatten(),
                    n_bins=15,
                )
        # FIX: np.nanmean raises a RuntimeWarning (and returns nan) on an
        # all-NaN slice -- average the finite entries explicitly instead.
        finite = [row[q] for q in q_names if not np.isnan(row[q])]
        row["mean_ece"] = float(np.mean(finite)) if finite else float("nan")
        ece_rows.append(row)
    df_ece = pd.DataFrame(ece_rows)
    df_ece.to_csv(save_dir / "table_ece_comparison.csv", index=False)
    x = np.arange(len(q_names))
    width = 0.80 / len(model_results)
    palette = list(MODEL_COLORS.values())
    fig, ax = plt.subplots(figsize=(14, 5))
    # FIX: plot straight from the row dicts; the previous per-model DataFrame
    # boolean filter + `.values[0]` lookup was redundant and fragile.
    for i, row in enumerate(ece_rows):
        model_name = row["model"]
        vals = [float(row[q]) for q in q_names]
        ax.bar(
            x + i * width, vals, width,
            label=model_name,
            color=MODEL_COLORS.get(model_name, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )
    # Centre the group labels under each cluster of bars.
    ax.set_xticks(x + width * (len(model_results) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({QUESTION_LABELS[q][:12]})" for q in q_names],
        rotation=30, ha="right", fontsize=8,
    )
    ax.set_ylabel("Expected Calibration Error (ECE)", fontsize=11)
    ax.set_title(
        "Expected Calibration Error β all models (test set)\n"
        "Reached branches (w β₯ 0.05). Adaptive equal-frequency binning. "
        "Lower is better.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_ece_comparison")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 4: Attention rollout gallery | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir):
    """Figure 4: attention-rollout gallery for the proposed ViT.

    Renders the gallery twice: once as a 300 dpi PDF/PNG pair, and once as a
    separate 600 dpi PNG for journal submission. Each rendering is skipped
    when its outputs already exist. NOTE(review): this assumes
    `plot_attention_grid` writes the PNG via its `save_path` argument --
    confirm against src.attention_viz.

    Parameters
    ----------
    attn_imgs : torch.Tensor or None
        Images collected by `_infer_vit` (None when attention is unavailable).
    all_layers : list or None
        Per-layer attention tensors for those images.
    attn_ids : list[int]
        Image ids matching `attn_imgs`.
    save_dir : pathlib.Path
        Directory receiving the figure files.
    """
    if attn_imgs is None or all_layers is None:
        log.warning("No attention data β skipping gallery."); return
    path_pdf = save_dir / "fig_attention_rollout_gallery.pdf"
    path_png = save_dir / "fig_attention_rollout_gallery.png"
    # FIX: previously only the PDF was checked, so a missing/deleted PNG was
    # never regenerated; check both files, consistent with the other figures.
    if not (path_pdf.exists() and path_png.exists()):
        fig = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            save_path=str(path_png),
            n_cols=4, rollout_mode="full",
        )
        fig.savefig(path_pdf, dpi=300, bbox_inches="tight", facecolor="black")
        plt.close(fig)
        log.info("Saved: fig_attention_rollout_gallery")
    # High-resolution PNG for journal submission
    path_hq = save_dir / "fig_attention_rollout_gallery_HQ.png"
    if not path_hq.exists():
        fig2 = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            n_cols=4, rollout_mode="full",
        )
        fig2.savefig(path_hq, dpi=600, bbox_inches="tight", facecolor="black")
        plt.close(fig2)
        log.info("Saved: fig_attention_rollout_gallery_HQ (600 dpi)")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Figure 5: Attention entropy vs. depth | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def fig_attention_entropy_depth(all_layers, save_dir):
    """Figure 5: CLS attention entropy as a function of transformer depth.

    Delegates the plotting to `plot_attention_entropy` (which also writes the
    PNG via `save_path`), then saves the PDF copy. Skipped when both outputs
    already exist, or when no attention data was collected.
    """
    if all_layers is None:
        log.warning("No attention layers β skipping entropy plot.")
        return
    pdf_path = save_dir / "fig_attention_entropy_depth.pdf"
    png_path = save_dir / "fig_attention_entropy_depth.png"
    if pdf_path.exists() and png_path.exists():
        log.info("Skip (exists): fig_attention_entropy_depth")
        return
    entropy_fig = plot_attention_entropy(all_layers, save_path=str(png_path))
    entropy_fig.savefig(pdf_path, dpi=300, bbox_inches="tight")
    plt.close(entropy_fig)
    log.info("Saved: fig_attention_entropy_depth")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Table: metrics for proposed model | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def table_metrics_proposed(preds, targets, weights, save_dir):
    """Write table_metrics_proposed.csv: per-question MAE/RMSE/bias/ECE for
    the proposed model, plus a weighted-average summary row.

    Returns the full metrics dict from `compute_metrics` so callers can log
    headline numbers without recomputing.
    """
    metrics = compute_metrics(preds, targets, weights)
    records = [
        {
            "question"   : q_name,
            "description": QUESTION_LABELS[q_name],
            "MAE"        : round(metrics[f"mae/{q_name}"], 5),
            "RMSE"       : round(metrics[f"rmse/{q_name}"], 5),
            "bias"       : round(metrics[f"bias/{q_name}"], 5),
            "ECE"        : round(metrics[f"ece/{q_name}"], 5),
        }
        for q_name in QUESTION_GROUPS
    ]
    # Summary row; bias is left blank because only per-question bias is defined.
    records.append({
        "question": "weighted_avg", "description": "Weighted average",
        "MAE" : round(metrics["mae/weighted_avg"], 5),
        "RMSE": round(metrics["rmse/weighted_avg"], 5),
        "bias": "",
        "ECE" : round(metrics["ece/mean"], 5),
    })
    table = pd.DataFrame(records)
    table.to_csv(save_dir / "table_metrics_proposed.csv", index=False)
    log.info("\n%s\n", table.to_string(index=False))
    return metrics
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Table: simplex violation for sigmoid baseline | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def table_simplex_violation(model_results, save_dir):
    """Write table_simplex_violation.csv and return it as a DataFrame.

    For each model, reports the fraction of test samples whose per-question
    predictions do not sum to 1 +/- 0.02. Expected to be ~0 for softmax
    models and nonzero for the sigmoid baseline; this explains why the
    sigmoid baseline can achieve lower raw per-answer MAE despite being
    scientifically invalid (unconstrained sigmoid outputs fit each marginal
    independently).
    """
    records = []
    for model_name, (preds, _, _) in model_results.items():
        rates = simplex_violation_rate(preds, tolerance=0.02)
        entry = {"model": model_name}
        for q in QUESTION_GROUPS:
            entry[q] = round(rates[q], 4)
        entry["mean"] = round(rates["mean"], 4)
        records.append(entry)
    df = pd.DataFrame(records)
    df.to_csv(save_dir / "table_simplex_violation.csv", index=False)
    log.info("Saved: table_simplex_violation.csv")
    log.info("\n%s\n", df[["model", "mean"]].to_string(index=False))
    return df
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Main | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def main():
    """Entry point: load config and checkpoints, run test-set inference for
    every model, then emit all paper tables and figures into
    ``<figures_dir>/evaluation``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    # The experiment config overrides the shared base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "evaluation"
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    # Only the test split is evaluated here.
    _, _, test_loader = build_dataloaders(cfg)
    # -- Load all models ----------------------------------------
    log.info("Loading models from: %s", ckpt_dir)
    def _load(path, model):
        # Restore weights on CPU; the caller moves the model to `device`.
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        return model
    vit_proposed = _load(
        ckpt_dir / "best_full_train.pt", build_model(cfg)
    ).to(device)
    vit_mse = _load(
        ckpt_dir / "baseline_vit_mse.pt", build_model(cfg)
    ).to(device)
    rn_mse = _load(
        ckpt_dir / "baseline_resnet18_mse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)
    rn_kl = _load(
        ckpt_dir / "baseline_resnet18_klmse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)
    # The Dirichlet baseline is optional: evaluated only if its checkpoint exists.
    vit_dirichlet = None
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dirichlet = _load(dp, build_dirichlet_model(cfg)).to(device)
        log.info("Loaded: ViT-Base + Dirichlet")
    # -- Run inference ------------------------------------------
    log.info("Running inference on test set...")
    # Attention maps are collected only for the proposed model (first 16 images).
    (p_proposed, t_proposed, w_proposed,
     attn_imgs, all_layers, attn_ids) = _infer_vit(
        vit_proposed, test_loader, device, cfg,
        collect_attn=True, n_attn=16,
    )
    # _infer_vit returns a 6-tuple; only the prediction arrays are needed here.
    p_vit_mse, t_vit_mse, w_vit_mse = _infer_vit(
        vit_mse, test_loader, device, cfg, collect_attn=False
    )[:3]
    p_rn_mse, t_rn_mse, w_rn_mse = _infer_resnet(
        rn_mse, test_loader, device, cfg, use_sigmoid=True
    )
    p_rn_kl, t_rn_kl, w_rn_kl = _infer_resnet(
        rn_kl, test_loader, device, cfg, use_sigmoid=False
    )
    # Build model_results dict (order determines legend order in figures).
    # Keys must match MODEL_COLORS / MODEL_STYLES.
    model_results = {
        "ResNet-18 + MSE (sigmoid)"    : (p_rn_mse, t_rn_mse, w_rn_mse),
        "ResNet-18 + KL+MSE"           : (p_rn_kl, t_rn_kl, w_rn_kl),
        "ViT-Base + MSE only"          : (p_vit_mse, t_vit_mse, w_vit_mse),
        "ViT-Base + KL+MSE (proposed)" : (p_proposed, t_proposed, w_proposed),
    }
    if vit_dirichlet is not None:
        p_dir, t_dir, w_dir = _infer_dirichlet(
            vit_dirichlet, test_loader, device, cfg
        )
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (p_dir, t_dir, w_dir)
    # -- Tables -------------------------------------------------
    log.info("Computing metrics...")
    table_metrics_proposed(p_proposed, t_proposed, w_proposed, save_dir)
    log.info("Computing reached-branch MAE table...")
    df_r = compute_reached_branch_mae_table(model_results)
    df_r.to_csv(save_dir / "table_reached_branch_mae.csv", index=False)
    log.info("Saved: table_reached_branch_mae.csv")
    log.info("Computing simplex violation table...")
    table_simplex_violation(model_results, save_dir)
    # -- Figures ------------------------------------------------
    log.info("Generating figures...")
    fig_scatter_predicted_vs_true(p_proposed, t_proposed, w_proposed, save_dir)
    fig_calibration_reliability(model_results, save_dir)
    fig_ece_comparison(model_results, save_dir)
    fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir)
    fig_attention_entropy_depth(all_layers, save_dir)
    log.info("=" * 60)
    log.info("ALL OUTPUTS SAVED TO: %s", save_dir)
    log.info("=" * 60)
    # Headline numbers for the proposed model (mirrors the CSV table).
    metrics = compute_metrics(p_proposed, t_proposed, w_proposed)
    log.info("Proposed model β test set results:")
    log.info(" Weighted MAE = %.5f", metrics["mae/weighted_avg"])
    log.info(" Weighted RMSE = %.5f", metrics["rmse/weighted_avg"])
    log.info(" Mean ECE = %.5f", metrics["ece/mean"])
if __name__ == "__main__":
    main()