""" src/uncertainty_analysis.py ---------------------------- MC Dropout epistemic uncertainty analysis for the proposed model. MC Dropout (Gal & Ghahramani 2016) is used as a post-hoc uncertainty estimator. At inference time, dropout is kept active and N=30 stochastic forward passes are run per batch. The standard deviation across passes is used as the epistemic uncertainty estimate per galaxy per question. Key findings reported --------------------- 1. Uncertainty distributions: right-skewed, well-separated means across questions reflecting the conditional nature of the decision tree. 2. Uncertainty vs. error correlation: Spearman ρ reported per question. Strong positive correlation for root and shallow-branch questions (t01, t02, t04, t07) indicates the model is well-calibrated in uncertainty. Weak or near-zero correlation for deep conditional branches (t03, t05, t08, t09, t10, t11) is expected — these branches have small effective sample sizes and aleatoric uncertainty dominates. 3. Morphology selection benchmark: F1 score at threshold τ for downstream binary morphology classification tasks. Output files ------------ outputs/figures/uncertainty/ fig_uncertainty_distributions.pdf fig_uncertainty_vs_error.pdf fig_morphology_f1_comparison.pdf table_uncertainty_summary.csv table_morphology_selection_benchmark.csv mc_cache/ — cached numpy arrays (crash-safe) Usage ----- cd ~/galaxy nohup python -m src.uncertainty_analysis \ --config configs/full_train.yaml --n_passes 30 \ > outputs/logs/uncertainty.log 2>&1 & echo "PID: $!" """ import argparse import logging import sys from pathlib import Path import numpy as np import pandas as pd import torch import torch.nn.functional as F import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from scipy import stats as scipy_stats from torch.amp import autocast from omegaconf import OmegaConf from tqdm import tqdm from src.dataset import build_dataloaders, QUESTION_GROUPS from src.model import build_model, build_dirichlet_model from src.baselines import ResNet18Baseline from src.metrics import predictions_to_numpy, dirichlet_predictions_to_numpy logging.basicConfig( format="%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout, ) log = logging.getLogger("uncertainty") plt.rcParams.update({ "figure.dpi": 150, "savefig.dpi": 300, "font.family": "serif", "font.size": 11, "axes.titlesize": 10, "axes.labelsize": 10, "xtick.labelsize": 8, "ytick.labelsize": 8, "legend.fontsize": 8, "figure.facecolor": "white", "axes.facecolor": "white", "axes.grid": True, "grid.alpha": 0.3, "pdf.fonttype": 42, "ps.fonttype": 42, }) QUESTION_LABELS = { "t01": "Smooth or features", "t02": "Edge-on disk", "t03": "Bar", "t04": "Spiral arms", "t05": "Bulge prominence", "t06": "Odd feature", "t07": "Roundedness", "t08": "Odd feature type", "t09": "Bulge shape", "t10": "Arms winding", "t11": "Arms number", } MODEL_COLORS = { "ViT-Base + KL+MSE (proposed)" : "#27ae60", "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad", "ResNet-18 + MSE (sigmoid)" : "#c0392b", "ResNet-18 + KL+MSE" : "#e67e22", } SELECTION_THRESHOLDS = [0.5, 0.7, 0.8, 0.9] SELECTION_ANSWERS = { "t01": (0, "smooth"), "t02": (0, "edge-on"), "t03": (0, "bar"), "t04": (0, "spiral"), "t06": (0, "odd feature"), } # ───────────────────────────────────────────────────────────── # MC Dropout inference — Welford online algorithm, crash-safe # ───────────────────────────────────────────────────────────── def run_mc_inference(model, loader, device, cfg, n_passes=30, cache_dir=None): """ Fast batched MC Dropout inference. Uses Welford's online algorithm to compute mean and std per batch without storing all n_passes × N predictions. Memory usage: O(N × 37) regardless of n_passes. Parameters ---------- model : GalaxyViT with enable_mc_dropout() available loader : test DataLoader device : inference device cfg : OmegaConf config n_passes : number of stochastic forward passes (default 30) cache_dir : if given, saves .npy files and skips if they exist Returns ------- mean_all, std_all : [N, 37] float32 targets_all : [N, 37] float32 weights_all : [N, 11] float32 """ if cache_dir is not None: cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) fp_mean = cache_dir / "mc_mean.npy" fp_std = cache_dir / "mc_std.npy" fp_targets = cache_dir / "mc_targets.npy" fp_weights = cache_dir / "mc_weights.npy" if all(p.exists() for p in [fp_mean, fp_std, fp_targets, fp_weights]): log.info("MC cache found — loading from disk (skipping inference).") return (np.load(fp_mean), np.load(fp_std), np.load(fp_targets), np.load(fp_weights)) model.eval() model.enable_mc_dropout() all_means, all_stds, all_targets, all_weights = [], [], [], [] log.info("MC Dropout: %d passes × %d-image batches = %d total forward passes", n_passes, loader.batch_size, n_passes * len(loader)) for images, targets, weights, _ in tqdm(loader, desc="MC Dropout"): images_dev = images.to(device, non_blocking=True) # Welford online mean and M2 mean_acc = None M2_acc = None count = 0 for _ in range(n_passes): with torch.no_grad(): with autocast("cuda", enabled=cfg.training.mixed_precision): logits = model(images_dev) pred = torch.zeros_like(logits) for q, (s, e) in QUESTION_GROUPS.items(): pred[:, s:e] = F.softmax(logits[:, s:e], dim=-1) pred_np = pred.cpu().float().numpy() # [B, 37] count += 1 if mean_acc is None: mean_acc = pred_np.copy() M2_acc = np.zeros_like(pred_np) else: delta = pred_np - mean_acc mean_acc += delta / count M2_acc += delta * (pred_np - mean_acc) std_acc = np.sqrt(M2_acc / (count - 1) if count > 1 else np.zeros_like(M2_acc)) all_means.append(mean_acc) all_stds.append(std_acc) all_targets.append(targets.numpy()) all_weights.append(weights.numpy()) model.disable_mc_dropout() mean_all = np.concatenate(all_means) std_all = np.concatenate(all_stds) targets_all = np.concatenate(all_targets) weights_all = np.concatenate(all_weights) if cache_dir is not None: np.save(fp_mean, mean_all) np.save(fp_std, std_all) np.save(fp_targets, targets_all) np.save(fp_weights, weights_all) log.info("MC results cached: %s", cache_dir) return mean_all, std_all, targets_all, weights_all # ───────────────────────────────────────────────────────────── # Figure 1: Uncertainty distributions # ───────────────────────────────────────────────────────────── def fig_uncertainty_distributions(mean_preds, std_preds, targets, weights, save_dir): path_pdf = save_dir / "fig_uncertainty_distributions.pdf" path_png = save_dir / "fig_uncertainty_distributions.png" if path_pdf.exists() and path_png.exists(): log.info("Skip (exists): fig_uncertainty_distributions"); return fig, axes = plt.subplots(3, 4, figsize=(16, 12)) axes = axes.flatten() for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()): ax = axes[q_idx] mask = weights[:, q_idx] >= 0.05 std_q = std_preds[mask, start:end].mean(axis=1) ax.hist(std_q, bins=50, color="#6366f1", alpha=0.85, edgecolor="none", density=True) ax.axvline(std_q.mean(), color="#c0392b", linewidth=1.8, linestyle="--", label=f"Mean = {std_q.mean():.4f}") ax.set_xlabel("MC Dropout std (epistemic uncertainty)") ax.set_ylabel("Density") ax.set_title( f"{q_name}: {QUESTION_LABELS[q_name]}\n" f"$n$ = {mask.sum():,} (w ≥ 0.05)", fontsize=9, ) ax.legend(fontsize=7) axes[-1].axis("off") plt.suptitle( "Epistemic uncertainty distributions — MC Dropout (30 passes)\n" "Proposed model (ViT-Base/16 + hierarchical KL+MSE), test set", fontsize=12, ) plt.tight_layout() fig.savefig(path_pdf, dpi=300, bbox_inches="tight") fig.savefig(path_png, dpi=300, bbox_inches="tight") plt.close(fig) log.info("Saved: fig_uncertainty_distributions") # ───────────────────────────────────────────────────────────── # Figure 2: Uncertainty vs. error (Spearman ρ) # ───────────────────────────────────────────────────────────── def fig_uncertainty_vs_error(mean_preds, std_preds, targets, weights, save_dir): path_pdf = save_dir / "fig_uncertainty_vs_error.pdf" path_png = save_dir / "fig_uncertainty_vs_error.png" if path_pdf.exists() and path_png.exists(): log.info("Skip (exists): fig_uncertainty_vs_error"); return fig, axes = plt.subplots(3, 4, figsize=(16, 12)) axes = axes.flatten() for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()): ax = axes[q_idx] mask = weights[:, q_idx] >= 0.05 unc = std_preds[mask, start:end].mean(axis=1) err = np.abs(mean_preds[mask, start:end] - targets[mask, start:end]).mean(axis=1) # Adaptive bin means for trend line n_bins = 15 unc_bins = np.unique(np.percentile(unc, np.linspace(0, 100, n_bins + 1))) bin_ids = np.clip(np.digitize(unc, unc_bins) - 1, 0, len(unc_bins) - 2) bn_unc = [unc[bin_ids == b].mean() for b in range(len(unc_bins) - 1) if (bin_ids == b).any()] bn_err = [err[bin_ids == b].mean() for b in range(len(unc_bins) - 1) if (bin_ids == b).any()] ax.scatter(unc, err, alpha=0.04, s=1, color="#94a3b8", rasterized=True) ax.plot(bn_unc, bn_err, "r-o", markersize=4, linewidth=2, label="Bin mean") # Spearman rank correlation (more robust than Pearson for this data) rho, pval = scipy_stats.spearmanr(unc, err) p_str = f"p < 0.001" if pval < 0.001 else f"p = {pval:.3f}" ax.text(0.05, 0.90, f"Spearman ρ = {rho:.3f}\n{p_str}", transform=ax.transAxes, fontsize=7.5, bbox=dict(boxstyle="round,pad=0.25", facecolor="white", edgecolor="grey", alpha=0.85)) ax.set_xlabel("Uncertainty (MC std)") ax.set_ylabel("Absolute error") ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9) ax.legend(fontsize=7) axes[-1].axis("off") plt.suptitle( "Epistemic uncertainty vs. absolute prediction error — per morphological question\n" "Strong Spearman ρ for root/shallow questions; weak ρ for deep conditional branches " "(expected: aleatoric uncertainty dominates when branch is rarely reached)", fontsize=10, ) plt.tight_layout() fig.savefig(path_pdf, dpi=300, bbox_inches="tight") fig.savefig(path_png, dpi=300, bbox_inches="tight") plt.close(fig) log.info("Saved: fig_uncertainty_vs_error") # ───────────────────────────────────────────────────────────── # Table: uncertainty summary # ───────────────────────────────────────────────────────────── def table_uncertainty_summary(mean_preds, std_preds, targets, weights, save_dir): path = save_dir / "table_uncertainty_summary.csv" if path.exists(): log.info("Skip (exists): table_uncertainty_summary"); return rows = [] for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()): mask = weights[:, q_idx] >= 0.05 unc = std_preds[mask, start:end].mean(axis=1) err = np.abs(mean_preds[mask, start:end] - targets[mask, start:end]).mean(axis=1) if mask.sum() > 10: rho, pval = scipy_stats.spearmanr(unc, err) else: rho, pval = float("nan"), float("nan") rows.append({ "question" : q_name, "description" : QUESTION_LABELS[q_name], "n_reached" : int(mask.sum()), "mean_uncertainty": round(float(unc.mean()), 5), "std_uncertainty" : round(float(unc.std()), 5), "mean_mae" : round(float(err.mean()), 5), "spearman_rho" : round(float(rho), 4), "spearman_pval" : round(float(pval), 4), }) df = pd.DataFrame(rows) df.to_csv(path, index=False) log.info("Saved: table_uncertainty_summary.csv") print("\n" + df.to_string(index=False) + "\n") return df # ───────────────────────────────────────────────────────────── # Figure 3 + Table: Morphology selection benchmark # ───────────────────────────────────────────────────────────── def morphology_selection_benchmark(model_results, save_dir): csv_path = save_dir / "table_morphology_selection_benchmark.csv" if csv_path.exists(): log.info("Loading existing morphology benchmark...") df = pd.read_csv(csv_path) _fig_morphology_f1(df, save_dir) return df rows = [] for model_name, (preds, targets, weights) in model_results.items(): for q_name, (ans_idx, ans_label) in SELECTION_ANSWERS.items(): start, end = QUESTION_GROUPS[q_name] q_idx = list(QUESTION_GROUPS.keys()).index(q_name) mask = weights[:, q_idx] >= 0.05 pred_a = preds[mask, start + ans_idx] true_a = targets[mask, start + ans_idx] for thresh in SELECTION_THRESHOLDS: sel = pred_a >= thresh true_pos = true_a >= thresh n_sel = sel.sum() n_tp_all = true_pos.sum() n_tp = (sel & true_pos).sum() prec = n_tp / n_sel if n_sel > 0 else 0.0 rec = n_tp / n_tp_all if n_tp_all > 0 else 0.0 f1 = (2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0) rows.append({ "model" : model_name, "question" : q_name, "answer" : ans_label, "threshold" : thresh, "n_selected": int(n_sel), "n_true_pos": int(n_tp_all), "precision" : round(float(prec), 4), "recall" : round(float(rec), 4), "f1" : round(float(f1), 4), }) df = pd.DataFrame(rows) df.to_csv(csv_path, index=False) log.info("Saved: table_morphology_selection_benchmark.csv") _fig_morphology_f1(df, save_dir) return df def _fig_morphology_f1(df, save_dir): path_pdf = save_dir / "fig_morphology_f1_comparison.pdf" path_png = save_dir / "fig_morphology_f1_comparison.png" if path_pdf.exists() and path_png.exists(): log.info("Skip (exists): fig_morphology_f1_comparison"); return thresh = 0.8 sub = df[df["threshold"] == thresh] q_list = list(SELECTION_ANSWERS.keys()) models = list(df["model"].unique()) x = np.arange(len(q_list)) width = 0.80 / len(models) palette = list(MODEL_COLORS.values()) fig, ax = plt.subplots(figsize=(12, 5)) for i, model in enumerate(models): f1s = [] for q in q_list: row = sub[(sub["model"] == model) & (sub["question"] == q)] f1s.append(float(row["f1"].values[0]) if len(row) > 0 else 0.0) ax.bar( x + i * width, f1s, width, label=model, color=MODEL_COLORS.get(model, palette[i % len(palette)]), alpha=0.85, edgecolor="white", linewidth=0.5, ) ax.set_xticks(x + width * (len(models) - 1) / 2) ax.set_xticklabels( [f"{q}\n({SELECTION_ANSWERS[q][1]})" for q in q_list], fontsize=9 ) ax.set_ylabel("F$_1$ score", fontsize=11) ax.set_title( f"Downstream morphology selection — F$_1$ at threshold $\\tau$ = {thresh}\n" "Higher F$_1$ indicates cleaner, more complete morphological sample selection.", fontsize=11, ) ax.legend(fontsize=8) ax.set_ylim(0, 1) ax.grid(True, alpha=0.3, axis="y") ax.set_axisbelow(True) plt.tight_layout() fig.savefig(path_pdf, dpi=300, bbox_inches="tight") fig.savefig(path_png, dpi=300, bbox_inches="tight") plt.close(fig) log.info("Saved: fig_morphology_f1_comparison") # ───────────────────────────────────────────────────────────── # Main # ───────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", required=True) parser.add_argument("--n_passes", type=int, default=30) args = parser.parse_args() base_cfg = OmegaConf.load("configs/base.yaml") exp_cfg = OmegaConf.load(args.config) cfg = OmegaConf.merge(base_cfg, exp_cfg) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") save_dir = Path(cfg.outputs.figures_dir) / "uncertainty" save_dir.mkdir(parents=True, exist_ok=True) cache_dir = save_dir / "mc_cache" ckpt_dir = Path(cfg.outputs.checkpoint_dir) _, _, test_loader = build_dataloaders(cfg) # ── 1. MC Dropout on proposed model ─────────────────────── log.info("Loading proposed model...") proposed = build_model(cfg).to(device) proposed.load_state_dict( torch.load(ckpt_dir / "best_full_train.pt", map_location="cpu", weights_only=True)["model_state"] ) mean_preds, std_preds, targets, weights = run_mc_inference( proposed, test_loader, device, cfg, n_passes=args.n_passes, cache_dir=cache_dir, ) log.info("MC Dropout complete: %d galaxies, %d passes.", len(mean_preds), args.n_passes) # ── 2. Uncertainty figures and table ────────────────────── fig_uncertainty_distributions(mean_preds, std_preds, targets, weights, save_dir) fig_uncertainty_vs_error(mean_preds, std_preds, targets, weights, save_dir) table_uncertainty_summary(mean_preds, std_preds, targets, weights, save_dir) # ── 3. Downstream benchmark across all models ───────────── log.info("Building model_results for downstream benchmark...") model_results = { "ViT-Base + KL+MSE (proposed)": (mean_preds, targets, weights), } def _load_resnet(ckpt_name, use_sigmoid): m = ResNet18Baseline(dropout=cfg.model.dropout).to(device) m.load_state_dict( torch.load(ckpt_dir / ckpt_name, map_location="cpu", weights_only=True)["model_state"] ) m.eval() preds_l, tgts_l, wgts_l = [], [], [] with torch.no_grad(): for images, tgts, wgts, _ in tqdm(test_loader, desc=f"ResNet {ckpt_name}"): images = images.to(device, non_blocking=True) with autocast("cuda", enabled=cfg.training.mixed_precision): logits = m(images) if use_sigmoid: p = torch.sigmoid(logits).cpu().numpy() else: p = logits.detach().cpu().clone() for q, (s, e) in QUESTION_GROUPS.items(): p[:, s:e] = F.softmax(p[:, s:e], dim=-1) p = p.numpy() preds_l.append(p) tgts_l.append(tgts.numpy()) wgts_l.append(wgts.numpy()) return (np.concatenate(preds_l), np.concatenate(tgts_l), np.concatenate(wgts_l)) rn_mse_ckpt = "baseline_resnet18_mse.pt" rn_klm_ckpt = "baseline_resnet18_klmse.pt" if (ckpt_dir / rn_mse_ckpt).exists(): model_results["ResNet-18 + MSE (sigmoid)"] = _load_resnet( rn_mse_ckpt, use_sigmoid=True ) if (ckpt_dir / rn_klm_ckpt).exists(): model_results["ResNet-18 + KL+MSE"] = _load_resnet( rn_klm_ckpt, use_sigmoid=False ) dp = ckpt_dir / "baseline_vit_dirichlet.pt" if dp.exists(): vit_dir = build_dirichlet_model(cfg).to(device) vit_dir.load_state_dict( torch.load(dp, map_location="cpu", weights_only=True)["model_state"] ) vit_dir.eval() d_p, d_t, d_w = [], [], [] with torch.no_grad(): for images, tgts, wgts, _ in tqdm(test_loader, desc="Dirichlet"): images = images.to(device, non_blocking=True) with autocast("cuda", enabled=cfg.training.mixed_precision): alpha = vit_dir(images) p, t, w = dirichlet_predictions_to_numpy(alpha, tgts, wgts) d_p.append(p); d_t.append(t); d_w.append(w) model_results["ViT-Base + Dirichlet (Zoobot-style)"] = ( np.concatenate(d_p), np.concatenate(d_t), np.concatenate(d_w), ) df_sel = morphology_selection_benchmark(model_results, save_dir) log.info("=" * 60) log.info("DOWNSTREAM F1 @ τ = 0.8") log.info("=" * 60) summary = df_sel[df_sel["threshold"] == 0.8][ ["model", "question", "answer", "precision", "recall", "f1"] ] log.info("\n%s\n", summary.to_string(index=False)) log.info("All outputs saved to: %s", save_dir) if __name__ == "__main__": main()