# eshwar-gz2-api β€” src/evaluate_full.py
# Uploaded by sreshwarprasad via huggingface_hub (revision e36eee4, verified).
"""
src/evaluate_full.py
--------------------
Full evaluation of all trained models on the held-out test set.
Generates all paper figures and tables:
Tables
------
table_metrics_proposed.csv β€” MAE / RMSE / bias / ECE for proposed model
table_reached_branch_mae.csv β€” reached-branch MAE across all 5 models
table_simplex_violation.csv β€” simplex validity for sigmoid baseline
Figures (PDF + PNG, IEEE naming convention)
-------------------------------------------
fig_scatter_predicted_vs_true.pdf β€” predicted vs true vote fractions (proposed)
fig_calibration_reliability.pdf β€” reliability diagrams, all models
fig_ece_comparison.pdf β€” ECE bar chart, all models
fig_attention_rollout_gallery.pdf β€” full 12-layer attention rollout gallery
fig_attention_entropy_depth.pdf β€” CLS attention entropy vs. layer depth
Usage
-----
cd ~/galaxy
nohup python -m src.evaluate_full --config configs/full_train.yaml \
> outputs/logs/evaluate.log 2>&1 &
echo "PID: $!"
"""
import argparse
import logging
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from torch.amp import autocast
from omegaconf import OmegaConf
from tqdm import tqdm
from src.dataset import build_dataloaders, QUESTION_GROUPS
from src.model import build_model, build_dirichlet_model
from src.metrics import (compute_metrics, predictions_to_numpy,
compute_reached_branch_mae_table,
dirichlet_predictions_to_numpy,
simplex_violation_rate, _compute_ece)
from src.attention_viz import plot_attention_grid, plot_attention_entropy
from src.baselines import ResNet18Baseline
# Root logger writes timestamped lines to stdout so the nohup log file
# (see module docstring) interleaves cleanly with tqdm/torch output.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("evaluate_full")

# ── Global matplotlib style ────────────────────────────────────────────────────
# Shared serif / white-background style applied to every figure in this module,
# matching IEEE journal figure conventions.
plt.rcParams.update({
    "figure.dpi"       : 150,
    "savefig.dpi"      : 300,
    "font.family"      : "serif",
    "font.size"        : 11,
    "axes.titlesize"   : 11,
    "axes.labelsize"   : 11,
    "xtick.labelsize"  : 9,
    "ytick.labelsize"  : 9,
    "legend.fontsize"  : 9,
    "figure.facecolor" : "white",
    "axes.facecolor"   : "white",
    "axes.grid"        : True,
    "grid.alpha"       : 0.3,
    "pdf.fonttype"     : 42,  # editable text in PDF
    "ps.fonttype"      : 42,
})

# Human-readable captions for the 11 Galaxy Zoo 2 decision-tree questions;
# keys match the question keys used by QUESTION_GROUPS.
QUESTION_LABELS = {
    "t01": "Smooth or features",
    "t02": "Edge-on disk",
    "t03": "Bar",
    "t04": "Spiral arms",
    "t05": "Bulge prominence",
    "t06": "Odd feature",
    "t07": "Roundedness",
    "t08": "Odd feature type",
    "t09": "Bulge shape",
    "t10": "Arms winding",
    "t11": "Arms number",
}

# Consistent colours and line styles for all models across all figures
MODEL_COLORS = {
    "ResNet-18 + MSE (sigmoid)"          : "#c0392b",
    "ResNet-18 + KL+MSE"                 : "#e67e22",
    "ViT-Base + MSE only"                : "#2980b9",
    "ViT-Base + KL+MSE (proposed)"       : "#27ae60",
    "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
}
# Line styles are keyed identically so figures can zip colour + dash per model.
MODEL_STYLES = {
    "ResNet-18 + MSE (sigmoid)"          : "-",
    "ResNet-18 + KL+MSE"                 : "-.",
    "ViT-Base + MSE only"                : "--",
    "ViT-Base + KL+MSE (proposed)"       : "-",
    "ViT-Base + Dirichlet (Zoobot-style)": ":",
}
# ─────────────────────────────────────────────────────────────
# Inference helpers
# ─────────────────────────────────────────────────────────────
def _infer_vit(model, loader, device, cfg,
               collect_attn=True, n_attn=16):
    """Run ViT inference over *loader*, optionally harvesting attention maps.

    Parameters
    ----------
    model : ViT exposing ``get_all_attention_weights()`` (may return None).
    loader : DataLoader yielding (images, targets, weights, image_ids).
    device : torch device for inference.
    cfg : config; only ``cfg.training.mixed_precision`` is read here.
    collect_attn : keep attention tensors for the first ``n_attn`` samples.
    n_attn : maximum number of samples whose attention maps are retained.

    Returns
    -------
    (preds, targets, weights, attn_imgs, layer_attns, attn_ids) β€”
    the first three are numpy arrays concatenated over the whole loader;
    the last three are None / empty when no attention was collected.
    """
    model.eval()
    all_preds, all_targets, all_weights = [], [], []
    attn_images, all_layer_attns, attn_ids = [], [], []
    # When attention is not requested, treat collection as already finished.
    attn_done = not collect_attn
    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(loader, desc="ViT inference"):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)
            if not attn_done:
                layers = model.get_all_attention_weights()
                if layers is not None:
                    # Bug fix: take only as many samples as are still needed.
                    # Previously every qualifying batch contributed up to
                    # min(n_attn, batch_size) ids, so with batches smaller
                    # than n_attn the attn_ids list could grow past n_attn
                    # while the image/attention tensors were later truncated
                    # to n_attn β€” leaving ids misaligned with images.
                    n = min(n_attn - len(attn_ids), images.shape[0])
                    if n > 0:
                        attn_images.append(images[:n].cpu())
                        all_layer_attns.append([l[:n].cpu() for l in layers])
                        attn_ids.extend(int(i) for i in image_ids[:n])
                    if len(attn_ids) >= n_attn:
                        attn_done = True
    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_targets)
    weights = np.concatenate(all_weights)
    attn_imgs_t = torch.cat(attn_images, dim=0)[:n_attn] if attn_images else None
    merged_layers = None
    if all_layer_attns:
        # Re-assemble one (≀ n_attn)-sample tensor per transformer layer.
        merged_layers = [
            torch.cat([b[li] for b in all_layer_attns], dim=0)[:n_attn]
            for li in range(len(all_layer_attns[0]))
        ]
    return preds, targets, weights, attn_imgs_t, merged_layers, attn_ids
def _infer_resnet(model, loader, device, cfg, use_sigmoid: bool):
    """Run a ResNet baseline over *loader* and stack the outputs.

    use_sigmoid=True  -> independent per-answer sigmoid probabilities
    (the scientifically-invalid baseline); use_sigmoid=False -> softmax
    over each question's answer slice. Returns (preds, targets, weights)
    as concatenated numpy arrays.
    """
    model.eval()
    pred_chunks, target_chunks, weight_chunks = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in tqdm(loader, desc="ResNet inference"):
            images = images.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
            if use_sigmoid:
                batch_pred = torch.sigmoid(logits).cpu().numpy()
            else:
                # Normalise each question's answer slice onto the simplex.
                buf = logits.detach().cpu().clone()
                for q, (s, e) in QUESTION_GROUPS.items():
                    buf[:, s:e] = F.softmax(buf[:, s:e], dim=-1)
                batch_pred = buf.numpy()
            pred_chunks.append(batch_pred)
            target_chunks.append(targets.numpy())
            weight_chunks.append(weights.numpy())
    return (np.concatenate(pred_chunks),
            np.concatenate(target_chunks),
            np.concatenate(weight_chunks))
def _infer_dirichlet(model, loader, device, cfg):
    """Dirichlet-head inference: the model emits concentration parameters
    Ξ±, converted to expected vote fractions by dirichlet_predictions_to_numpy.
    Returns (preds, targets, weights) concatenated over the loader.
    """
    model.eval()
    chunks = {"p": [], "t": [], "w": []}
    with torch.no_grad():
        for images, targets, weights, _ in tqdm(loader, desc="Dirichlet inference"):
            images = images.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(images)
            p, t, w = dirichlet_predictions_to_numpy(alpha, targets, weights)
            chunks["p"].append(p)
            chunks["t"].append(t)
            chunks["w"].append(w)
    return tuple(np.concatenate(chunks[k]) for k in ("p", "t", "w"))
# ─────────────────────────────────────────────────────────────
# Figure 1: Predicted vs true scatter (proposed model)
# ─────────────────────────────────────────────────────────────
def fig_scatter_predicted_vs_true(preds, targets, weights, save_dir):
    """Figure 1: per-question scatter of predicted vs. true vote fractions
    for the proposed model, restricted to reached branches (w β‰₯ 0.05).
    Saves PDF + PNG; skips entirely if both already exist.
    """
    path_pdf = save_dir / "fig_scatter_predicted_vs_true.pdf"
    path_png = save_dir / "fig_scatter_predicted_vs_true.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_scatter_predicted_vs_true")
        return
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    panels = axes.flatten()
    for idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = panels[idx]
        # Keep only samples where enough voters reached this branch.
        reached = weights[:, idx] >= 0.05
        p_flat = preds[reached, start:end].flatten()
        t_flat = targets[reached, start:end].flatten()
        # Rasterize the dense cloud so the PDF stays small.
        ax.scatter(t_flat, p_flat, alpha=0.06, s=1, color="#2563eb", rasterized=True)
        ax.plot([0, 1], [0, 1], "r--", linewidth=1, alpha=0.8)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xlabel("True vote fraction")
        ax.set_ylabel("Predicted vote fraction")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {reached.sum():,} (w β‰₯ 0.05)",
            fontsize=9,
        )
        ax.set_aspect("equal")
        panel_mae = np.abs(p_flat - t_flat).mean()
        ax.text(0.05, 0.92, f"MAE = {panel_mae:.3f}",
                transform=ax.transAxes, fontsize=8,
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white",
                          edgecolor="grey", alpha=0.85))
    # 11 questions on a 3Γ—4 grid β€” blank the unused final panel.
    panels[-1].axis("off")
    plt.suptitle(
        "Predicted vs. true vote fractions β€” reached branches (w β‰₯ 0.05)\n"
        "ViT-Base/16 + hierarchical KL+MSE (proposed model, test set)",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_scatter_predicted_vs_true")
# ─────────────────────────────────────────────────────────────
# Figure 2: Calibration reliability diagrams
# ─────────────────────────────────────────────────────────────
def fig_calibration_reliability(model_results, save_dir, n_bins=15):
    """Figure 2: reliability diagrams (mean predicted vs. mean true fraction)
    for all models on 8 representative questions, using adaptive
    equal-frequency bins so every bin holds roughly the same number of
    predictions.

    model_results : dict mapping model name -> (preds, targets, weights)
                    numpy arrays as produced by the _infer_* helpers.
    n_bins        : target number of adaptive bins per curve.
    """
    path_pdf = save_dir / "fig_calibration_reliability.pdf"
    path_png = save_dir / "fig_calibration_reliability.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_calibration_reliability"); return
    # Show 8 representative questions (skip t02 β€” bimodal, shown separately)
    q_show = ["t01", "t03", "t04", "t06", "t07", "t09", "t10", "t11"]
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()
    for ax_idx, q_name in enumerate(q_show):
        ax = axes[ax_idx]
        start, end = QUESTION_GROUPS[q_name]
        # Position of this question in the per-question weight matrix.
        q_idx = list(QUESTION_GROUPS.keys()).index(q_name)
        for model_name, (preds, targets, weights) in model_results.items():
            # Reached branches only β€” same threshold used everywhere else.
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                continue  # too few samples for a meaningful curve
            pf = preds[mask, start:end].flatten()
            tf = targets[mask, start:end].flatten()
            # Adaptive bins (equal-frequency) β€” consistent with ECE computation
            percentiles = np.linspace(0, 100, n_bins + 1)
            # np.unique collapses duplicate edges when predictions cluster.
            bin_edges = np.unique(np.percentile(pf, percentiles))
            if len(bin_edges) < 2:
                continue  # degenerate case: (nearly) constant predictions
            # digitize against interior edges; clip keeps extremes in range.
            bin_ids = np.clip(
                np.digitize(pf, bin_edges[1:-1]), 0, len(bin_edges) - 2
            )
            # Per-bin mean predicted / mean true values (NaN for empty bins).
            mp = np.array([
                pf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            mt = np.array([
                tf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            valid = ~np.isnan(mp) & ~np.isnan(mt)
            ax.plot(
                mp[valid], mt[valid],
                MODEL_STYLES.get(model_name, "-"),
                color=MODEL_COLORS.get(model_name, "#888888"),
                linewidth=1.8, marker="o", markersize=3.5,
                label=model_name, alpha=0.9,
            )
        # Perfect-calibration diagonal for reference.
        ax.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.5, label="Perfect")
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
        ax.set_xlabel("Mean predicted", fontsize=8)
        ax.set_ylabel("Mean true", fontsize=8)
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.set_aspect("equal")
        # Legend once, on the first panel only, to avoid clutter.
        if ax_idx == 0:
            ax.legend(fontsize=6.5, loc="upper left")
    plt.suptitle(
        "Calibration reliability diagrams β€” all models (test set)\n"
        "Reached branches only (w β‰₯ 0.05). Adaptive equal-frequency bins. "
        "Closer to diagonal = better calibrated.",
        fontsize=11,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_calibration_reliability")
# ─────────────────────────────────────────────────────────────
# Figure 3: ECE bar chart
# ─────────────────────────────────────────────────────────────
def fig_ece_comparison(model_results, save_dir):
    """Figure 3: grouped bar chart of per-question ECE for every model.

    Also writes table_ece_comparison.csv with the underlying numbers
    before plotting. Skips when both figure files already exist.
    """
    path_pdf = save_dir / "fig_ece_comparison.pdf"
    path_png = save_dir / "fig_ece_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_ece_comparison")
        return
    q_names = list(QUESTION_GROUPS.keys())

    # Per-model, per-question ECE (NaN where too few reached samples).
    ece_rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        row = {"model": model_name}
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                row[q_name] = float("nan")
            else:
                row[q_name] = _compute_ece(
                    preds[mask, start:end].flatten(),
                    targets[mask, start:end].flatten(),
                    n_bins=15,
                )
        row["mean_ece"] = float(np.nanmean([row[q] for q in q_names]))
        ece_rows.append(row)
    df_ece = pd.DataFrame(ece_rows)
    df_ece.to_csv(save_dir / "table_ece_comparison.csv", index=False)

    # Grouped bars: one cluster per question, one bar per model.
    x = np.arange(len(q_names))
    width = 0.80 / len(model_results)
    palette = list(MODEL_COLORS.values())
    fig, ax = plt.subplots(figsize=(14, 5))
    # ece_rows is in model_results order, so the legend order matches.
    for i, row in enumerate(ece_rows):
        model_name = row["model"]
        vals = [float(row[q]) for q in q_names]
        ax.bar(
            x + i * width, vals, width,
            label=model_name,
            color=MODEL_COLORS.get(model_name, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )
    # Centre the tick under each cluster of bars.
    ax.set_xticks(x + width * (len(model_results) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({QUESTION_LABELS[q][:12]})" for q in q_names],
        rotation=30, ha="right", fontsize=8,
    )
    ax.set_ylabel("Expected Calibration Error (ECE)", fontsize=11)
    ax.set_title(
        "Expected Calibration Error β€” all models (test set)\n"
        "Reached branches (w β‰₯ 0.05). Adaptive equal-frequency binning. "
        "Lower is better.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_ece_comparison")
# ─────────────────────────────────────────────────────────────
# Figure 4: Attention rollout gallery
# ─────────────────────────────────────────────────────────────
def fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir):
    """Figure 4: attention-rollout gallery, saved twice β€” a 300 dpi
    PDF/PNG pair and a 600 dpi PNG for journal submission. Each output
    is skipped independently if it already exists.
    """
    if attn_imgs is None or all_layers is None:
        log.warning("No attention data β€” skipping gallery.")
        return
    path_pdf = save_dir / "fig_attention_rollout_gallery.pdf"
    path_png = save_dir / "fig_attention_rollout_gallery.png"
    if not path_pdf.exists():
        # plot_attention_grid writes the PNG itself via save_path.
        grid_fig = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            save_path=str(path_png),
            n_cols=4, rollout_mode="full",
        )
        grid_fig.savefig(path_pdf, dpi=300, bbox_inches="tight", facecolor="black")
        plt.close(grid_fig)
        log.info("Saved: fig_attention_rollout_gallery")
    # High-resolution PNG for journal submission
    path_hq = save_dir / "fig_attention_rollout_gallery_HQ.png"
    if not path_hq.exists():
        hq_fig = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            n_cols=4, rollout_mode="full",
        )
        hq_fig.savefig(path_hq, dpi=600, bbox_inches="tight", facecolor="black")
        plt.close(hq_fig)
        log.info("Saved: fig_attention_rollout_gallery_HQ (600 dpi)")
# ─────────────────────────────────────────────────────────────
# Figure 5: Attention entropy vs. depth
# ─────────────────────────────────────────────────────────────
def fig_attention_entropy_depth(all_layers, save_dir):
    """Figure 5: CLS-attention entropy as a function of transformer depth.
    Thin wrapper around plot_attention_entropy; saves PDF + PNG.
    """
    if all_layers is None:
        log.warning("No attention layers β€” skipping entropy plot.")
        return
    path_pdf = save_dir / "fig_attention_entropy_depth.pdf"
    path_png = save_dir / "fig_attention_entropy_depth.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_attention_entropy_depth")
        return
    # The helper writes the PNG; we add the PDF alongside it.
    entropy_fig = plot_attention_entropy(all_layers, save_path=str(path_png))
    entropy_fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    plt.close(entropy_fig)
    log.info("Saved: fig_attention_entropy_depth")
# ─────────────────────────────────────────────────────────────
# Table: metrics for proposed model
# ─────────────────────────────────────────────────────────────
def table_metrics_proposed(preds, targets, weights, save_dir):
    """Write table_metrics_proposed.csv: per-question MAE / RMSE / bias /
    ECE for the proposed model plus a weighted-average summary row.
    Returns the full metrics dict from compute_metrics.
    """
    metrics = compute_metrics(preds, targets, weights)
    # One row per question, values rounded for the paper table.
    rows = [
        {
            "question"   : q_name,
            "description": QUESTION_LABELS[q_name],
            "MAE"        : round(metrics[f"mae/{q_name}"], 5),
            "RMSE"       : round(metrics[f"rmse/{q_name}"], 5),
            "bias"       : round(metrics[f"bias/{q_name}"], 5),
            "ECE"        : round(metrics[f"ece/{q_name}"], 5),
        }
        for q_name in QUESTION_GROUPS
    ]
    # Summary row (bias has no weighted average, so it is left blank).
    rows.append({
        "question": "weighted_avg", "description": "Weighted average",
        "MAE" : round(metrics["mae/weighted_avg"], 5),
        "RMSE": round(metrics["rmse/weighted_avg"], 5),
        "bias": "",
        "ECE" : round(metrics["ece/mean"], 5),
    })
    df = pd.DataFrame(rows)
    df.to_csv(save_dir / "table_metrics_proposed.csv", index=False)
    log.info("\n%s\n", df.to_string(index=False))
    return metrics
# ─────────────────────────────────────────────────────────────
# Table: simplex violation for sigmoid baseline
# ─────────────────────────────────────────────────────────────
def table_simplex_violation(model_results, save_dir):
    """Report how often per-question predictions leave the probability simplex.

    For each model, computes the fraction of test samples whose per-question
    predictions fail to sum to 1 Β± 0.02. Softmax / Dirichlet heads should sit
    near zero; the sigmoid baseline will not β€” which explains why it can post
    a lower raw per-answer MAE while being scientifically invalid: its
    unconstrained outputs fit each marginal independently.
    """
    rows = []
    for model_name, (preds, _, _) in model_results.items():
        svr = simplex_violation_rate(preds, tolerance=0.02)
        entry = {"model": model_name}
        for q in QUESTION_GROUPS:
            entry[q] = round(svr[q], 4)
        entry["mean"] = round(svr["mean"], 4)
        rows.append(entry)
    df = pd.DataFrame(rows)
    df.to_csv(save_dir / "table_simplex_violation.csv", index=False)
    log.info("Saved: table_simplex_violation.csv")
    log.info("\n%s\n", df[["model", "mean"]].to_string(index=False))
    return df
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    """Entry point: load config + checkpoints, run test-set inference for
    all five models, then emit every paper table and figure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    # Experiment config overrides the shared base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "evaluation"
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    # Only the held-out test split is used here.
    _, _, test_loader = build_dataloaders(cfg)
    # ── Load all models ────────────────────────────────────────
    log.info("Loading models from: %s", ckpt_dir)
    def _load(path, model):
        # weights_only=True avoids unpickling arbitrary objects from ckpts.
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        return model
    vit_proposed = _load(
        ckpt_dir / "best_full_train.pt", build_model(cfg)
    ).to(device)
    vit_mse = _load(
        ckpt_dir / "baseline_vit_mse.pt", build_model(cfg)
    ).to(device)
    rn_mse = _load(
        ckpt_dir / "baseline_resnet18_mse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)
    rn_kl = _load(
        ckpt_dir / "baseline_resnet18_klmse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)
    # The Dirichlet baseline is optional β€” evaluate it only if trained.
    vit_dirichlet = None
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dirichlet = _load(dp, build_dirichlet_model(cfg)).to(device)
        log.info("Loaded: ViT-Base + Dirichlet")
    # ── Run inference ──────────────────────────────────────────
    log.info("Running inference on test set...")
    # Attention maps are harvested only from the proposed model.
    (p_proposed, t_proposed, w_proposed,
     attn_imgs, all_layers, attn_ids) = _infer_vit(
        vit_proposed, test_loader, device, cfg,
        collect_attn=True, n_attn=16,
    )
    p_vit_mse, t_vit_mse, w_vit_mse = _infer_vit(
        vit_mse, test_loader, device, cfg, collect_attn=False
    )[:3]
    p_rn_mse, t_rn_mse, w_rn_mse = _infer_resnet(
        rn_mse, test_loader, device, cfg, use_sigmoid=True
    )
    p_rn_kl, t_rn_kl, w_rn_kl = _infer_resnet(
        rn_kl, test_loader, device, cfg, use_sigmoid=False
    )
    # Build model_results dict (order determines legend order in figures)
    model_results = {
        "ResNet-18 + MSE (sigmoid)"    : (p_rn_mse, t_rn_mse, w_rn_mse),
        "ResNet-18 + KL+MSE"           : (p_rn_kl, t_rn_kl, w_rn_kl),
        "ViT-Base + MSE only"          : (p_vit_mse, t_vit_mse, w_vit_mse),
        "ViT-Base + KL+MSE (proposed)" : (p_proposed, t_proposed, w_proposed),
    }
    if vit_dirichlet is not None:
        p_dir, t_dir, w_dir = _infer_dirichlet(
            vit_dirichlet, test_loader, device, cfg
        )
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (p_dir, t_dir, w_dir)
    # ── Tables ─────────────────────────────────────────────────
    log.info("Computing metrics...")
    table_metrics_proposed(p_proposed, t_proposed, w_proposed, save_dir)
    log.info("Computing reached-branch MAE table...")
    df_r = compute_reached_branch_mae_table(model_results)
    df_r.to_csv(save_dir / "table_reached_branch_mae.csv", index=False)
    log.info("Saved: table_reached_branch_mae.csv")
    log.info("Computing simplex violation table...")
    table_simplex_violation(model_results, save_dir)
    # ── Figures ────────────────────────────────────────────────
    log.info("Generating figures...")
    fig_scatter_predicted_vs_true(p_proposed, t_proposed, w_proposed, save_dir)
    fig_calibration_reliability(model_results, save_dir)
    fig_ece_comparison(model_results, save_dir)
    fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir)
    fig_attention_entropy_depth(all_layers, save_dir)
    log.info("=" * 60)
    log.info("ALL OUTPUTS SAVED TO: %s", save_dir)
    log.info("=" * 60)
    # Final console summary of the headline numbers for the proposed model.
    metrics = compute_metrics(p_proposed, t_proposed, w_proposed)
    log.info("Proposed model β€” test set results:")
    log.info("  Weighted MAE  = %.5f", metrics["mae/weighted_avg"])
    log.info("  Weighted RMSE = %.5f", metrics["rmse/weighted_avg"])
    log.info("  Mean ECE      = %.5f", metrics["ece/mean"])
# Script entry point (intended to be run as `python -m src.evaluate_full`).
if __name__ == "__main__":
    main()