# eshwar-gz2-api / src/uncertainty_analysis.py
# Uploaded via huggingface_hub (commit e36eee4, verified) by sreshwarprasad.
"""
src/uncertainty_analysis.py
----------------------------
MC Dropout epistemic uncertainty analysis for the proposed model.
MC Dropout (Gal & Ghahramani 2016) is used as a post-hoc uncertainty
estimator. At inference time, dropout is kept active and N=30 stochastic
forward passes are run per batch. The standard deviation across passes
is used as the epistemic uncertainty estimate per galaxy per question.
Key findings reported
---------------------
1. Uncertainty distributions: right-skewed, well-separated means across
questions reflecting the conditional nature of the decision tree.
2. Uncertainty vs. error correlation: Spearman ρ reported per question.
Strong positive correlation for root and shallow-branch questions
(t01, t02, t04, t07) indicates the model is well-calibrated in
uncertainty. Weak or near-zero correlation for deep conditional
branches (t03, t05, t08, t09, t10, t11) is expected — these branches
have small effective sample sizes and aleatoric uncertainty dominates.
3. Morphology selection benchmark: F1 score at threshold τ for downstream
binary morphology classification tasks.
Output files
------------
outputs/figures/uncertainty/
fig_uncertainty_distributions.pdf
fig_uncertainty_vs_error.pdf
fig_morphology_f1_comparison.pdf
table_uncertainty_summary.csv
table_morphology_selection_benchmark.csv
mc_cache/ — cached numpy arrays (crash-safe)
Usage
-----
cd ~/galaxy
nohup python -m src.uncertainty_analysis \
--config configs/full_train.yaml --n_passes 30 \
> outputs/logs/uncertainty.log 2>&1 &
echo "PID: $!"
"""
import argparse
import logging
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy import stats as scipy_stats
from torch.amp import autocast
from omegaconf import OmegaConf
from tqdm import tqdm
from src.dataset import build_dataloaders, QUESTION_GROUPS
from src.model import build_model, build_dirichlet_model
from src.baselines import ResNet18Baseline
from src.metrics import predictions_to_numpy, dirichlet_predictions_to_numpy
# Console logging with timestamps to stdout, so `nohup ... > log` captures
# a time-ordered record of the (long-running) analysis.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("uncertainty")

# Publication-quality figure defaults.  fonttype 42 embeds TrueType fonts in
# PDF/PS output so figure text stays editable/searchable in vector tools.
plt.rcParams.update({
    "figure.dpi": 150, "savefig.dpi": 300,
    "font.family": "serif", "font.size": 11,
    "axes.titlesize": 10, "axes.labelsize": 10,
    "xtick.labelsize": 8, "ytick.labelsize": 8,
    "legend.fontsize": 8,
    "figure.facecolor": "white", "axes.facecolor": "white",
    "axes.grid": True, "grid.alpha": 0.3,
    "pdf.fonttype": 42, "ps.fonttype": 42,
})

# Human-readable description of each decision-tree question, keyed by the
# question ids used in QUESTION_GROUPS.
QUESTION_LABELS = {
    "t01": "Smooth or features", "t02": "Edge-on disk",
    "t03": "Bar", "t04": "Spiral arms",
    "t05": "Bulge prominence", "t06": "Odd feature",
    "t07": "Roundedness", "t08": "Odd feature type",
    "t09": "Bulge shape", "t10": "Arms winding",
    "t11": "Arms number",
}

# Fixed per-model colors so models are identifiable across all figures.
MODEL_COLORS = {
    "ViT-Base + KL+MSE (proposed)" : "#27ae60",
    "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
    "ResNet-18 + MSE (sigmoid)" : "#c0392b",
    "ResNet-18 + KL+MSE" : "#e67e22",
}

# Vote-fraction thresholds swept by the downstream selection benchmark.
SELECTION_THRESHOLDS = [0.5, 0.7, 0.8, 0.9]

# For each benchmarked question: (answer index within the question's answer
# slice, human-readable label of the "positive" answer used for selection).
SELECTION_ANSWERS = {
    "t01": (0, "smooth"),
    "t02": (0, "edge-on"),
    "t03": (0, "bar"),
    "t04": (0, "spiral"),
    "t06": (0, "odd feature"),
}
# ─────────────────────────────────────────────────────────────
# MC Dropout inference — Welford online algorithm, crash-safe
# ─────────────────────────────────────────────────────────────
def run_mc_inference(model, loader, device, cfg,
                     n_passes=30, cache_dir=None):
    """
    Fast batched MC Dropout inference.

    Uses Welford's online algorithm to compute mean and std per batch
    without storing all n_passes × N predictions.
    Memory usage: O(N × 37) regardless of n_passes.

    Parameters
    ----------
    model : GalaxyViT with enable_mc_dropout() available
    loader : test DataLoader
    device : inference device
    cfg : OmegaConf config
    n_passes : number of stochastic forward passes (default 30)
    cache_dir : if given, saves .npy files and skips if they exist

    Returns
    -------
    mean_all, std_all : [N, 37] float32
    targets_all : [N, 37] float32
    weights_all : [N, 11] float32
    """
    if cache_dir is not None:
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        fp_mean = cache_dir / "mc_mean.npy"
        fp_std = cache_dir / "mc_std.npy"
        fp_targets = cache_dir / "mc_targets.npy"
        fp_weights = cache_dir / "mc_weights.npy"
        # Crash-safe resume: only skip when ALL four arrays are present.
        if all(p.exists() for p in [fp_mean, fp_std, fp_targets, fp_weights]):
            log.info("MC cache found — loading from disk (skipping inference).")
            return (np.load(fp_mean), np.load(fp_std),
                    np.load(fp_targets), np.load(fp_weights))

    model.eval()
    model.enable_mc_dropout()  # keep dropout stochastic despite eval()
    all_means, all_stds, all_targets, all_weights = [], [], [], []
    log.info("MC Dropout: %d passes × %d-image batches = %d total forward passes",
             n_passes, loader.batch_size, n_passes * len(loader))
    try:
        for images, targets, weights, _ in tqdm(loader, desc="MC Dropout"):
            images_dev = images.to(device, non_blocking=True)
            # Welford online accumulators: running mean and sum of squared
            # deviations (M2), initialised lazily from the first pass.
            mean_acc = None
            M2_acc = None
            count = 0
            with torch.no_grad():  # hoisted: one context for all passes
                for _ in range(n_passes):
                    with autocast("cuda", enabled=cfg.training.mixed_precision):
                        logits = model(images_dev)
                    # Per-question softmax so each answer group sums to 1.
                    pred = torch.zeros_like(logits)
                    for q, (s, e) in QUESTION_GROUPS.items():
                        pred[:, s:e] = F.softmax(logits[:, s:e], dim=-1)
                    pred_np = pred.cpu().float().numpy()  # [B, 37]
                    count += 1
                    if mean_acc is None:
                        mean_acc = pred_np.copy()
                        M2_acc = np.zeros_like(pred_np)
                    else:
                        delta = pred_np - mean_acc
                        mean_acc += delta / count
                        M2_acc += delta * (pred_np - mean_acc)
            # Sample std (ddof=1); defined as zero when only one pass ran.
            std_acc = np.sqrt(M2_acc / (count - 1) if count > 1
                              else np.zeros_like(M2_acc))
            all_means.append(mean_acc)
            all_stds.append(std_acc)
            all_targets.append(targets.numpy())
            all_weights.append(weights.numpy())
    finally:
        # Restore deterministic inference even if interrupted mid-run.
        model.disable_mc_dropout()
    mean_all = np.concatenate(all_means)
    std_all = np.concatenate(all_stds)
    targets_all = np.concatenate(all_targets)
    weights_all = np.concatenate(all_weights)
    if cache_dir is not None:
        np.save(fp_mean, mean_all)
        np.save(fp_std, std_all)
        np.save(fp_targets, targets_all)
        np.save(fp_weights, weights_all)
        log.info("MC results cached: %s", cache_dir)
    return mean_all, std_all, targets_all, weights_all
# ─────────────────────────────────────────────────────────────
# Figure 1: Uncertainty distributions
# ─────────────────────────────────────────────────────────────
def fig_uncertainty_distributions(mean_preds, std_preds,
                                  targets, weights, save_dir, n_passes=30):
    """
    Figure 1: per-question histograms of MC Dropout epistemic uncertainty.

    For each question, galaxies are restricted to those where the question
    was effectively reached (answer weight >= 0.05) and the per-answer MC
    std is averaged over the question's answer slice.  Skips silently when
    both output files already exist (crash-safe re-runs).

    Parameters
    ----------
    mean_preds : [N, 37] MC mean predictions (unused; kept for uniform API)
    std_preds : [N, 37] MC std per answer
    targets : [N, 37] vote-fraction targets (unused; kept for uniform API)
    weights : [N, 11] per-question answer weights
    save_dir : output directory (Path)
    n_passes : number of MC passes, displayed in the title (default 30)
    """
    path_pdf = save_dir / "fig_uncertainty_distributions.pdf"
    path_png = save_dir / "fig_uncertainty_distributions.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_uncertainty_distributions"); return
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        # Only galaxies for which this question was effectively asked.
        mask = weights[:, q_idx] >= 0.05
        # Mean MC std across the question's answers -> one value per galaxy.
        std_q = std_preds[mask, start:end].mean(axis=1)
        ax.hist(std_q, bins=50, color="#6366f1", alpha=0.85,
                edgecolor="none", density=True)
        ax.axvline(std_q.mean(), color="#c0392b", linewidth=1.8,
                   linestyle="--", label=f"Mean = {std_q.mean():.4f}")
        ax.set_xlabel("MC Dropout std (epistemic uncertainty)")
        ax.set_ylabel("Density")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {mask.sum():,} (w ≥ 0.05)",
            fontsize=9,
        )
        ax.legend(fontsize=7)
    axes[-1].axis("off")  # 11 questions on a 3x4 grid: blank the 12th panel
    plt.suptitle(
        f"Epistemic uncertainty distributions — MC Dropout ({n_passes} passes)\n"
        "Proposed model (ViT-Base/16 + hierarchical KL+MSE), test set",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_uncertainty_distributions")
# ─────────────────────────────────────────────────────────────
# Figure 2: Uncertainty vs. error (Spearman ρ)
# ─────────────────────────────────────────────────────────────
def fig_uncertainty_vs_error(mean_preds, std_preds,
                             targets, weights, save_dir):
    """
    Figure 2: epistemic uncertainty vs. absolute error, per question.

    Scatters per-galaxy MC std against per-galaxy MAE, overlays adaptive
    (percentile-binned) bin means as a trend line, and annotates each panel
    with the Spearman rank correlation.  Skips silently when both output
    files already exist (crash-safe re-runs).

    Parameters
    ----------
    mean_preds : [N, 37] MC mean predictions
    std_preds : [N, 37] MC std per answer
    targets : [N, 37] vote-fraction targets
    weights : [N, 11] per-question answer weights
    save_dir : output directory (Path)
    """
    path_pdf = save_dir / "fig_uncertainty_vs_error.pdf"
    path_png = save_dir / "fig_uncertainty_vs_error.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_uncertainty_vs_error"); return
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        mask = weights[:, q_idx] >= 0.05
        unc = std_preds[mask, start:end].mean(axis=1)
        err = np.abs(mean_preds[mask, start:end] -
                     targets[mask, start:end]).mean(axis=1)
        # Adaptive (equal-count) bins for the trend line; np.unique collapses
        # duplicate percentile edges when the uncertainty range is narrow.
        n_bins = 15
        unc_bins = np.unique(np.percentile(unc, np.linspace(0, 100, n_bins + 1)))
        bin_ids = np.clip(np.digitize(unc, unc_bins) - 1, 0, len(unc_bins) - 2)
        bn_unc = [unc[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
                  if (bin_ids == b).any()]
        bn_err = [err[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
                  if (bin_ids == b).any()]
        ax.scatter(unc, err, alpha=0.04, s=1, color="#94a3b8", rasterized=True)
        ax.plot(bn_unc, bn_err, "r-o", markersize=4, linewidth=2,
                label="Bin mean")
        # Spearman rank correlation (more robust than Pearson for this data)
        rho, pval = scipy_stats.spearmanr(unc, err)
        # Fixed: constant string had a stray f-prefix with no placeholders.
        p_str = "p < 0.001" if pval < 0.001 else f"p = {pval:.3f}"
        ax.text(0.05, 0.90,
                f"Spearman ρ = {rho:.3f}\n{p_str}",
                transform=ax.transAxes, fontsize=7.5,
                bbox=dict(boxstyle="round,pad=0.25", facecolor="white",
                          edgecolor="grey", alpha=0.85))
        ax.set_xlabel("Uncertainty (MC std)")
        ax.set_ylabel("Absolute error")
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.legend(fontsize=7)
    axes[-1].axis("off")
    plt.suptitle(
        "Epistemic uncertainty vs. absolute prediction error — per morphological question\n"
        "Strong Spearman ρ for root/shallow questions; weak ρ for deep conditional branches "
        "(expected: aleatoric uncertainty dominates when branch is rarely reached)",
        fontsize=10,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_uncertainty_vs_error")
# ─────────────────────────────────────────────────────────────
# Table: uncertainty summary
# ─────────────────────────────────────────────────────────────
def table_uncertainty_summary(mean_preds, std_preds,
                              targets, weights, save_dir):
    """
    Per-question uncertainty summary CSV.

    Columns: n reached (w >= 0.05), mean/std of MC uncertainty, mean MAE,
    and Spearman rho / p-value between uncertainty and absolute error.

    Returns
    -------
    pd.DataFrame with one row per question.  If the CSV already exists it
    is loaded and returned, so the return value is consistent on every
    path (previously the skip branch returned None).
    """
    path = save_dir / "table_uncertainty_summary.csv"
    if path.exists():
        log.info("Skip (exists): table_uncertainty_summary")
        return pd.read_csv(path)
    rows = []
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        mask = weights[:, q_idx] >= 0.05
        unc = std_preds[mask, start:end].mean(axis=1)
        err = np.abs(mean_preds[mask, start:end] -
                     targets[mask, start:end]).mean(axis=1)
        # Spearman rho is not meaningful for tiny samples; report NaN.
        if mask.sum() > 10:
            rho, pval = scipy_stats.spearmanr(unc, err)
        else:
            rho, pval = float("nan"), float("nan")
        rows.append({
            "question" : q_name,
            "description" : QUESTION_LABELS[q_name],
            "n_reached" : int(mask.sum()),
            "mean_uncertainty": round(float(unc.mean()), 5),
            "std_uncertainty" : round(float(unc.std()), 5),
            "mean_mae" : round(float(err.mean()), 5),
            "spearman_rho" : round(float(rho), 4),
            "spearman_pval" : round(float(pval), 4),
        })
    df = pd.DataFrame(rows)
    df.to_csv(path, index=False)
    log.info("Saved: table_uncertainty_summary.csv")
    print("\n" + df.to_string(index=False) + "\n")
    return df
# ─────────────────────────────────────────────────────────────
# Figure 3 + Table: Morphology selection benchmark
# ─────────────────────────────────────────────────────────────
def morphology_selection_benchmark(model_results, save_dir):
    """
    Benchmark threshold-based morphology sample selection for every model.

    A galaxy is "selected" for an answer when its predicted vote fraction
    reaches the threshold; precision/recall/F1 are computed against the
    volunteer vote fractions thresholded the same way.  The table is
    cached to CSV and simply re-plotted from cache on subsequent runs.

    Parameters
    ----------
    model_results : dict name -> (preds [N,37], targets [N,37], weights [N,11])
    save_dir : output directory (Path)

    Returns
    -------
    pd.DataFrame with one row per (model, question, threshold).
    """
    csv_path = save_dir / "table_morphology_selection_benchmark.csv"
    if csv_path.exists():
        log.info("Loading existing morphology benchmark...")
        cached = pd.read_csv(csv_path)
        _fig_morphology_f1(cached, save_dir)
        return cached
    question_order = list(QUESTION_GROUPS.keys())
    records = []
    for name, (preds, targets, weights) in model_results.items():
        for q_name, (ans_idx, ans_label) in SELECTION_ANSWERS.items():
            start = QUESTION_GROUPS[q_name][0]
            # Restrict to galaxies where this question was effectively asked.
            reached = weights[:, question_order.index(q_name)] >= 0.05
            col = start + ans_idx
            pred_frac = preds[reached, col]
            true_frac = targets[reached, col]
            for tau in SELECTION_THRESHOLDS:
                chosen = pred_frac >= tau
                positive = true_frac >= tau
                n_chosen = chosen.sum()
                n_positive = positive.sum()
                hits = (chosen & positive).sum()
                precision = hits / n_chosen if n_chosen > 0 else 0.0
                recall = hits / n_positive if n_positive > 0 else 0.0
                denom = precision + recall
                f1 = 2 * precision * recall / denom if denom > 0 else 0.0
                records.append({
                    "model" : name,
                    "question" : q_name,
                    "answer" : ans_label,
                    "threshold" : tau,
                    "n_selected": int(n_chosen),
                    "n_true_pos": int(n_positive),
                    "precision" : round(float(precision), 4),
                    "recall" : round(float(recall), 4),
                    "f1" : round(float(f1), 4),
                })
    table = pd.DataFrame(records)
    table.to_csv(csv_path, index=False)
    log.info("Saved: table_morphology_selection_benchmark.csv")
    _fig_morphology_f1(table, save_dir)
    return table
def _fig_morphology_f1(df, save_dir):
    """
    Grouped bar chart of downstream selection F1 per model at tau = 0.8.

    Parameters
    ----------
    df : benchmark DataFrame produced by morphology_selection_benchmark
    save_dir : output directory (Path)
    """
    path_pdf = save_dir / "fig_morphology_f1_comparison.pdf"
    path_png = save_dir / "fig_morphology_f1_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_morphology_f1_comparison"); return
    thresh = 0.8
    sub = df[df["threshold"] == thresh]
    q_list = list(SELECTION_ANSWERS.keys())
    models = list(df["model"].unique())
    x = np.arange(len(q_list))
    width = 0.80 / len(models)
    palette = list(MODEL_COLORS.values())
    fig, ax = plt.subplots(figsize=(12, 5))
    for i, model in enumerate(models):
        f1s = []
        for q in q_list:
            row = sub[(sub["model"] == model) & (sub["question"] == q)]
            f1s.append(float(row["f1"].values[0]) if len(row) > 0 else 0.0)
        ax.bar(
            x + i * width, f1s, width,
            label=model,
            # Cycle through the palette for models without a fixed color.
            color=MODEL_COLORS.get(model, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )
    # Center the group tick under the len(models) bars of each question.
    ax.set_xticks(x + width * (len(models) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({SELECTION_ANSWERS[q][1]})" for q in q_list], fontsize=9
    )
    ax.set_ylabel("F$_1$ score", fontsize=11)
    ax.set_title(
        f"Downstream morphology selection — F$_1$ at threshold $\\tau$ = {thresh}\n"
        "Higher F$_1$ indicates cleaner, more complete morphological sample selection.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_morphology_f1_comparison")
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    """
    CLI entry point.

    Runs MC Dropout uncertainty analysis on the proposed model, produces
    the uncertainty figures/tables, then benchmarks downstream morphology
    sample selection across every baseline whose checkpoint is present.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--n_passes", type=int, default=30)
    args = parser.parse_args()
    # Experiment config is layered on top of the shared base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "uncertainty"
    save_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = save_dir / "mc_cache"
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    _, _, test_loader = build_dataloaders(cfg)

    # ── 1. MC Dropout on proposed model ───────────────────────
    log.info("Loading proposed model...")
    proposed = build_model(cfg).to(device)
    proposed.load_state_dict(
        torch.load(ckpt_dir / "best_full_train.pt",
                   map_location="cpu", weights_only=True)["model_state"]
    )
    mean_preds, std_preds, targets, weights = run_mc_inference(
        proposed, test_loader, device, cfg,
        n_passes=args.n_passes, cache_dir=cache_dir,
    )
    log.info("MC Dropout complete: %d galaxies, %d passes.",
             len(mean_preds), args.n_passes)

    # ── 2. Uncertainty figures and table ──────────────────────
    fig_uncertainty_distributions(mean_preds, std_preds, targets, weights, save_dir)
    fig_uncertainty_vs_error(mean_preds, std_preds, targets, weights, save_dir)
    table_uncertainty_summary(mean_preds, std_preds, targets, weights, save_dir)

    # ── 3. Downstream benchmark across all models ─────────────
    log.info("Building model_results for downstream benchmark...")
    # The proposed model re-uses its MC-mean predictions (no extra pass).
    model_results = {
        "ViT-Base + KL+MSE (proposed)": (mean_preds, targets, weights),
    }

    def _load_resnet(ckpt_name, use_sigmoid):
        # Load a ResNet-18 baseline checkpoint and run deterministic
        # inference over the test set.
        #   use_sigmoid=True  -> independent sigmoids (MSE baseline)
        #   use_sigmoid=False -> per-question softmax (KL+MSE baseline)
        m = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        m.load_state_dict(
            torch.load(ckpt_dir / ckpt_name, map_location="cpu",
                       weights_only=True)["model_state"]
        )
        m.eval()
        preds_l, tgts_l, wgts_l = [], [], []
        with torch.no_grad():
            for images, tgts, wgts, _ in tqdm(test_loader, desc=f"ResNet {ckpt_name}"):
                images = images.to(device, non_blocking=True)
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    logits = m(images)
                if use_sigmoid:
                    p = torch.sigmoid(logits).cpu().numpy()
                else:
                    p = logits.detach().cpu().clone()
                    for q, (s, e) in QUESTION_GROUPS.items():
                        p[:, s:e] = F.softmax(p[:, s:e], dim=-1)
                    p = p.numpy()
                preds_l.append(p)
                tgts_l.append(tgts.numpy())
                wgts_l.append(wgts.numpy())
        return (np.concatenate(preds_l),
                np.concatenate(tgts_l),
                np.concatenate(wgts_l))

    # Baselines are optional: include each only if its checkpoint exists.
    rn_mse_ckpt = "baseline_resnet18_mse.pt"
    rn_klm_ckpt = "baseline_resnet18_klmse.pt"
    if (ckpt_dir / rn_mse_ckpt).exists():
        model_results["ResNet-18 + MSE (sigmoid)"] = _load_resnet(
            rn_mse_ckpt, use_sigmoid=True
        )
    if (ckpt_dir / rn_klm_ckpt).exists():
        model_results["ResNet-18 + KL+MSE"] = _load_resnet(
            rn_klm_ckpt, use_sigmoid=False
        )
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dir = build_dirichlet_model(cfg).to(device)
        vit_dir.load_state_dict(
            torch.load(dp, map_location="cpu", weights_only=True)["model_state"]
        )
        vit_dir.eval()
        d_p, d_t, d_w = [], [], []
        with torch.no_grad():
            for images, tgts, wgts, _ in tqdm(test_loader, desc="Dirichlet"):
                images = images.to(device, non_blocking=True)
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    alpha = vit_dir(images)
                # Dirichlet concentrations -> expected vote fractions.
                p, t, w = dirichlet_predictions_to_numpy(alpha, tgts, wgts)
                d_p.append(p); d_t.append(t); d_w.append(w)
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (
            np.concatenate(d_p),
            np.concatenate(d_t),
            np.concatenate(d_w),
        )
    df_sel = morphology_selection_benchmark(model_results, save_dir)
    log.info("=" * 60)
    log.info("DOWNSTREAM F1 @ τ = 0.8")  # fixed mojibake tau in log output
    log.info("=" * 60)
    summary = df_sel[df_sel["threshold"] == 0.8][
        ["model", "question", "answer", "precision", "recall", "f1"]
    ]
    log.info("\n%s\n", summary.to_string(index=False))
    log.info("All outputs saved to: %s", save_dir)


if __name__ == "__main__":
    main()