#!/usr/bin/env python3 """ posterior_inference.py — VLB-based cosmological inference (Mudur et al. 2023 §4 style). Pure inference-time; frozen DDPM weights. Script lives next to diffusion_conditional.py. """ from __future__ import annotations import argparse import ast import json import os import sys import time from pathlib import Path from typing import Dict, List, Optional, Tuple import matplotlib matplotlib.use("Agg") import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import matplotlib.patheffects as mpathe import numpy as np import torch # ── Project imports ──────────────────────────────────────────────────────────── _ROOT = Path(__file__).resolve().parent if (_ROOT / "diffusion_conditional.py").is_file(): sys.path.insert(0, str(_ROOT)) from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel from unet_conditional import ConditionalUNet plt.rcParams.update({ "figure.facecolor": "white", "axes.facecolor": "white", "axes.edgecolor": "#222", "axes.linewidth": 0.7, "axes.spines.top": False, "axes.spines.right": False, "font.family": "DejaVu Sans", "font.size": 9.5, "savefig.facecolor": "white", }) REAL_COLOR = "#CC3333" GEN_COLOR = "#2266BB" SIGMA_LEVELS = [2.30, 6.17, 11.83] SIGMA_COLORS = ["#1a5c9e", "#5590d0", "#99c0ea"] SIGMA_LABELS = {2.30: r"$1\sigma$", 6.17: r"$2\sigma$", 11.83: r"$3\sigma$"} def load_config(path: str) -> Dict: p = Path(path) if p.suffix == ".json": with open(p) as f: return json.load(f) cfg = {} with open(p) as f: for line in f: if ":" not in line: continue k, v = line.strip().split(":", 1) try: cfg[k.strip()] = ast.literal_eval(v.strip()) except Exception: cfg[k.strip()] = v.strip() return cfg def load_model(ckpt: str, cfg: Dict, device: torch.device) -> ConditionalDiffusionModel: unet = ConditionalUNet( in_channels=1, out_channels=1, label_dim=int(cfg.get("label_dim", 2)), base_channels=int(cfg.get("base_channels", 64)), channel_multipliers=list(cfg.get("channel_multipliers", [1, 2, 4, 8])), attention_levels=list(cfg.get("attention_levels", [2, 3])), dropout=float(cfg.get("dropout", 0.1)), ) diff = GaussianDiffusion( timesteps=int(cfg.get("timesteps", 1500)), beta_start=float(cfg.get("beta_start", 1e-4)), beta_end=float(cfg.get("beta_end", 0.02)), schedule_type=str(cfg.get("schedule_type", "linear")), ) model = ConditionalDiffusionModel(unet, diff).to(device) ck = torch.load(ckpt, map_location=device, weights_only=False) if isinstance(ck, dict) and "ema_shadow" in ck: cur = model.state_dict() for k, v in ck["ema_shadow"].items(): if k in cur: cur[k] = v model.load_state_dict(cur) print(" Loaded EMA weights") elif isinstance(ck, dict) and "model_state_dict" in ck: model.load_state_dict(ck["model_state_dict"]) else: model.load_state_dict(ck) model.eval() for p in model.parameters(): p.requires_grad_(False) return model def load_test_data( data_dir: str, n_fields: int, seed: int = 42 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: dp = Path(data_dir) lsuf = "_2" if (dp / "train_labels_LH_2.npy").exists() else "" isuf = "" if (dp / "train_LH.npy").exists() else "_6" imgs = np.load(dp / f"test_LH{isuf}.npy").astype(np.float32) labels = np.load(dp / f"test_labels_LH{lsuf}.npy").astype(np.float32) tr_lab = np.load(dp / f"train_labels_LH{lsuf}.npy").astype(np.float32) rng = np.random.default_rng(seed) idx = rng.choice(len(imgs), n_fields, replace=False) label_mu = tr_lab.mean(0) label_std = np.where(tr_lab.std(0) == 0, 1.0, tr_lab.std(0)) return imgs[idx], labels[idx], label_mu, label_std def normal_kl(mean1, log_var1, mean2, log_var2): return 0.5 * ( -1.0 + log_var2 - log_var1 + torch.exp(log_var1 - log_var2) + ((mean1 - mean2) ** 2) * torch.exp(-log_var2) ) def _approx_standard_normal_cdf(x): return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3))) def discretised_gaussian_log_likelihood(x_0, mean, log_var): centered_x = x_0 - mean inv_stdv = torch.exp(-0.5 * log_var) plus_in = inv_stdv * (centered_x + 1.0 / 255.0) min_in = inv_stdv * (centered_x - 1.0 / 255.0) cdf_plus = _approx_standard_normal_cdf(plus_in) cdf_min = _approx_standard_normal_cdf(min_in) log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) cdf_delta = (cdf_plus - cdf_min).clamp(min=1e-12) log_probs = torch.where( x_0 < -0.999, log_cdf_plus, torch.where(x_0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta)), ) return log_probs def predict_x_start_from_eps(diff: GaussianDiffusion, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor: # Matches GaussianDiffusion._predict_xstart_from_noise (diffusion_conditional.py) return ( diff._extract(diff.recip_sqrt_alphas_cumprod, t, x_t.shape) * x_t - diff._extract(diff.sqrt_recip_minus_one, t, x_t.shape) * eps ) def q_posterior_mean_var(diff: GaussianDiffusion, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor): mean = ( diff._extract(diff.posterior_mean_coef1, t, x_t.shape) * x_start + diff._extract(diff.posterior_mean_coef2, t, x_t.shape) * x_t ) var = diff._extract(diff.posterior_variance, t, x_t.shape) log_var_c = diff._extract(diff.posterior_log_variance_clipped, t, x_t.shape) return mean, var, log_var_c @torch.no_grad() def compute_L_t( model: ConditionalDiffusionModel, x_0: torch.Tensor, labels_n: torch.Tensor, t: int, fixed_eps: torch.Tensor, ) -> torch.Tensor: diff = model.diffusion device = x_0.device B = x_0.shape[0] t_vec = torch.full((B,), t, device=device, dtype=torch.long) if t == 0: t1 = torch.full((B,), 1, device=device, dtype=torch.long) ab1 = diff._extract(diff.alphas_cumprod, t1, x_0.shape) x_1 = torch.sqrt(ab1) * x_0 + torch.sqrt(1.0 - ab1) * fixed_eps eps_pred = model(x_1, t1, labels_n) x_start_pred = predict_x_start_from_eps(diff, x_1, t1, eps_pred).clamp(-1, 1) mean, _, log_var = q_posterior_mean_var(diff, x_start_pred, x_1, t1) log_p = discretised_gaussian_log_likelihood(x_0, mean, log_var) return -log_p.sum(dim=(1, 2, 3)) ab_t = diff._extract(diff.alphas_cumprod, t_vec, x_0.shape) x_t = torch.sqrt(ab_t) * x_0 + torch.sqrt(1.0 - ab_t) * fixed_eps true_mean, _, true_log_var = q_posterior_mean_var(diff, x_0, x_t, t_vec) eps_pred = model(x_t, t_vec, labels_n) x_start_pred = predict_x_start_from_eps(diff, x_t, t_vec, eps_pred).clamp(-1, 1) model_mean, _, model_log_var = q_posterior_mean_var(diff, x_start_pred, x_t, t_vec) kl = normal_kl(true_mean, true_log_var, model_mean, model_log_var) return kl.sum(dim=(1, 2, 3)) @torch.no_grad() def compute_L_T_analytic(diff: GaussianDiffusion, x_0: torch.Tensor) -> torch.Tensor: T = diff.timesteps t_vec = torch.full((x_0.shape[0],), T - 1, device=x_0.device, dtype=torch.long) abar_T = diff._extract(diff.alphas_cumprod, t_vec, x_0.shape) mean1 = torch.sqrt(abar_T) * x_0 log_var1 = torch.log((1.0 - abar_T).clamp(min=1e-30)) kl = normal_kl(mean1, log_var1, torch.zeros_like(mean1), torch.zeros_like(log_var1)) return kl.sum(dim=(1, 2, 3)) def build_eval_grid( Om_true: float, s8_true: float, grid_size: int, span: float = 0.1, Om_range: Tuple[float, float] = (0.10, 0.50), s8_range: Tuple[float, float] = (0.60, 1.00), ) -> Tuple[np.ndarray, np.ndarray]: Om_lo = max(Om_true - span, Om_range[0]) Om_hi = min(Om_true + span, Om_range[1]) s8_lo = max(s8_true - span, s8_range[0]) s8_hi = min(s8_true + span, s8_range[1]) Om_1d = np.linspace(Om_lo, Om_hi, grid_size) s8_1d = np.linspace(s8_lo, s8_hi, grid_size) return Om_1d, s8_1d @torch.no_grad() def evaluate_vlb_surface( model: ConditionalDiffusionModel, x_0: torch.Tensor, Om_grid: np.ndarray, s8_grid: np.ndarray, label_mu: np.ndarray, label_std: np.ndarray, t_values: List[int], n_seeds: int = 4, batch_size: int = 32, label_dim: int = 2, fixed_seed: int = 0, device: Optional[torch.device] = None, ) -> Dict[int, np.ndarray]: device = device or x_0.device nO, nS = len(Om_grid), len(s8_grid) n_pts = nO * nS Omg, s8g = np.meshgrid(Om_grid, s8_grid, indexing="ij") raw_labels = np.column_stack([Omg.ravel(), s8g.ravel()]) if label_dim > 2: pad = np.zeros((n_pts, label_dim - 2), dtype=np.float32) for i in range(label_dim - 2): pad[:, i] = label_mu[2 + i] raw_labels = np.concatenate([raw_labels, pad], axis=1) norm_labels = (raw_labels - label_mu) / label_std norm_labels_t = torch.from_numpy(norm_labels.astype(np.float32)).to(device) L_surfaces = {t: np.zeros(n_pts, dtype=np.float64) for t in t_values} H, W = x_0.shape[-2], x_0.shape[-1] rng_torch = torch.Generator(device=device).manual_seed(fixed_seed) seeds_eps = [ torch.randn(1, 1, H, W, generator=rng_torch, device=device) for _ in range(n_seeds) ] for _, fixed_eps in enumerate(seeds_eps): for t in t_values: for start in range(0, n_pts, batch_size): end = min(start + batch_size, n_pts) bsz = end - start x_b = x_0.expand(bsz, -1, -1, -1) lbl_b = norm_labels_t[start:end] eps_b = fixed_eps.expand(bsz, -1, -1, -1) L_t = compute_L_t(model, x_b, lbl_b, t=t, fixed_eps=eps_b) L_surfaces[t][start:end] += L_t.cpu().numpy() / n_seeds return {t: L_surfaces[t].reshape(nO, nS) for t in t_values} def marginal_from_neg2dL( neg2dL: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray ) -> Tuple[np.ndarray, np.ndarray, Tuple[float, float]]: L = -0.5 * neg2dL L = L - L.max() P = np.exp(L) Om_marginal = P.sum(axis=1) Om_marginal /= Om_marginal.sum() s8_marginal = P.sum(axis=0) s8_marginal /= s8_marginal.sum() Om_pred = float(Om_grid[np.argmax(Om_marginal)]) s8_pred = float(s8_grid[np.argmax(s8_marginal)]) return Om_marginal, s8_marginal, (Om_pred, s8_pred) def credible_interval_68(values: np.ndarray, probs: np.ndarray) -> Tuple[float, float, float]: cdf = np.cumsum(probs) cdf /= cdf[-1] median = float(np.interp(0.50, cdf, values)) lo = float(np.interp(0.16, cdf, values)) hi = float(np.interp(0.84, cdf, values)) return median, lo, hi def fig_contours_per_t( surfaces: Dict[int, np.ndarray], Om_grid: np.ndarray, s8_grid: np.ndarray, Om_true: float, s8_true: float, out_path: Path, dpi: int = 200, ) -> None: fig, ax = plt.subplots(figsize=(7, 6.5), dpi=dpi) cmap = plt.cm.viridis n_t = len(surfaces) colors = cmap(np.linspace(0.05, 0.95, n_t)) for (t, L_surf), col in zip(sorted(surfaces.items()), colors): neg2dL = 2.0 * (L_surf - L_surf.min()) ax.contour( Om_grid, s8_grid, neg2dL.T, levels=[2.30], colors=[col], linewidths=[1.6], linestyles=["-"], ) ax.plot([], [], color=col, lw=1.8, label=f"t={t}") ax.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10) ax.set_xlabel(r"$\Omega_m$", fontsize=12) ax.set_ylabel(r"$\sigma_8$", fontsize=12) ax.set_title( r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contour per timestep" "\n(Mudur-style) smaller $t$ → tighter constraint", fontweight="bold", fontsize=10, ) ax.legend(fontsize=8, loc="best", ncol=2, framealpha=0.92) ax.grid(alpha=0.18) ax.set_xlim(Om_grid[0], Om_grid[-1]) ax.set_ylim(s8_grid[0], s8_grid[-1]) fig.savefig(out_path, bbox_inches="tight", dpi=dpi) plt.close(fig) print(f" Saved -> {out_path}") def _L0_posterior_smoothed( L0_surface: np.ndarray, smooth_sigma: float = 0.6, ): from scipy.ndimage import gaussian_filter as gf neg2dL = 2.0 * (L0_surface - L0_surface.min()) surface_sm = gf(neg2dL, sigma=smooth_sigma) return surface_sm, neg2dL def draw_L0_posterior_main_panel( ax, surface_sm: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray, Om_true: float, s8_true: float, Om_pred: float, s8_pred: float, *, clabel_fontsize: float = 9.5, marker_ms: float = 16, ) -> None: ax.contourf(Om_grid, s8_grid, surface_sm.T, levels=60, cmap="Blues_r", vmin=0, vmax=SIGMA_LEVELS[-1] * 3, extend="max", alpha=0.55) for lv, co in zip(reversed(SIGMA_LEVELS), reversed(SIGMA_COLORS)): ax.contourf(Om_grid, s8_grid, surface_sm.T, levels=[0, lv], colors=[co], alpha=0.78) cs = ax.contour( Om_grid, s8_grid, surface_sm.T, levels=SIGMA_LEVELS, colors=["white", "white", "white"], linewidths=[2.2, 1.6, 1.2], linestyles=["-", "--", "-."], ) ax.clabel(cs, fmt=SIGMA_LABELS, inline=True, fontsize=clabel_fontsize, colors="white") ax.axvline(Om_true, color="red", lw=0.7, ls=":", alpha=0.6) ax.axhline(s8_true, color="red", lw=0.7, ls=":", alpha=0.6) ax.plot(Om_true, s8_true, "r+", ms=marker_ms, mew=2.5, zorder=6, label="True") ax.plot(Om_pred, s8_pred, "w^", ms=max(6, marker_ms * 0.55), mew=1.2, zorder=6, label="MAP") ax.set_xlim(Om_grid[0], Om_grid[-1]) ax.set_ylim(s8_grid[0], s8_grid[-1]) ax.grid(alpha=0.18) def fig_posterior_L0_mosaic_3x3( out_dir: Path, n_fields: int, out_path: Path, mosaic_side_px: int = 10_000, panel_inches: float = 4.0, ) -> None: from matplotlib.patches import Patch n_plot = min(n_fields, 9) fig_side = panel_inches * 3 dpi = mosaic_side_px / fig_side fig, axes = plt.subplots( 3, 3, figsize=(fig_side, fig_side), dpi=dpi, squeeze=False, ) for idx in range(9): r, c = divmod(idx, 3) ax = axes[r][c] if idx >= n_plot: ax.set_visible(False) continue nz = np.load(out_dir / f"field{idx:02d}_surfaces.npz") L0 = np.asarray(nz["L_t0"]) Om_grid = np.asarray(nz["Om_grid"]) s8_grid = np.asarray(nz["s8_grid"]) Om_true = float(nz["Om_true"]) s8_true = float(nz["s8_true"]) surface_sm, _ = _L0_posterior_smoothed(L0, smooth_sigma=0.6) _, _, (Om_pred, s8_pred) = marginal_from_neg2dL(surface_sm, Om_grid, s8_grid) draw_L0_posterior_main_panel( ax, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred, clabel_fontsize=7.0, marker_ms=11, ) if idx == 0: legend_patches = [ Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"), Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"), Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"), ] hs, ls_ = ax.get_legend_handles_labels() ax.legend( handles=legend_patches + hs, labels=[p.get_label() for p in legend_patches] + ls_, fontsize=6, loc="upper right", framealpha=0.9, ) else: leg = ax.get_legend() if leg is not None: leg.remove() ax.set_title( rf"field {idx}: $\Omega_m^{{\rm true}}={Om_true:.3f}$, $\sigma_8^{{\rm true}}={s8_true:.3f}$", fontsize=8, ) ax.set_xlabel(r"$\Omega_m$", fontsize=8) ax.set_ylabel(r"$\sigma_8$", fontsize=8) fig.suptitle( r"VLB $L_0$ posterior (2D) — 9 test fields", fontsize=11, fontweight="bold", y=0.995, ) fig.savefig(out_path, bbox_inches="tight", dpi=dpi) plt.close(fig) print(f" Saved -> {out_path} (≈ {mosaic_side_px}×{mosaic_side_px} px)") def fig_main_posterior( L0_surface: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray, Om_true: float, s8_true: float, out_path: Path, dpi: int = 200, ): from matplotlib.patches import Patch surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6) Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL( surface_sm, Om_grid, s8_grid ) Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg) s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg) fig = plt.figure(figsize=(8.5, 8.5), dpi=dpi) gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4], hspace=0.05, wspace=0.05, left=0.10, right=0.95, top=0.95, bottom=0.08) ax_main = fig.add_subplot(gs[1, 0]) ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main) ax_rt = fig.add_subplot(gs[1, 1], sharey=ax_main) draw_L0_posterior_main_panel( ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred, ) ax_main.set_xlabel(r"$\Omega_m$", fontsize=11) ax_main.set_ylabel(r"$\sigma_8$", fontsize=11) ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6) ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4) ax_top.axvline(Om_true, color="red", lw=1.0, ls=":") ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--", path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")]) ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI") ax_top.set_yticks([]) ax_top.tick_params(labelbottom=False) ax_top.set_title( rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$" rf" (true: {Om_true:.3f})", fontsize=9, ) ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6) ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4) ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":") ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--", path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")]) ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18) ax_rt.set_xticks([]) ax_rt.tick_params(labelleft=False) ax_rt.set_ylabel( rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$" rf" (true: {s8_true:.3f})", fontsize=9, rotation=270, labelpad=15, ) ax_rt.yaxis.set_label_position("right") legend_patches = [ Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"), Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"), Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"), ] hs, ls_ = ax_main.get_legend_handles_labels() ax_main.legend( handles=legend_patches + hs, labels=[p.get_label() for p in legend_patches] + ls_, fontsize=8, loc="upper right", framealpha=0.92, ) fig.suptitle( r"VLB Posterior using $L_0$ — joint and marginal distributions", fontsize=10, fontweight="bold", y=0.99, ) fig.savefig(out_path, bbox_inches="tight", dpi=dpi) plt.close(fig) print(f" Saved -> {out_path}") return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi) def fig_pred_vs_true(pred_results: List[Dict], out_path: Path, dpi: int = 200) -> None: Om_true = np.array([r["Om_true"] for r in pred_results]) s8_true = np.array([r["s8_true"] for r in pred_results]) Om_pred = np.array([r["Om_pred"] for r in pred_results]) s8_pred = np.array([r["s8_pred"] for r in pred_results]) Om_err_lo = np.array([r["Om_pred"] - r["Om_lo"] for r in pred_results]) Om_err_hi = np.array([r["Om_hi"] - r["Om_pred"] for r in pred_results]) s8_err_lo = np.array([r["s8_pred"] - r["s8_lo"] for r in pred_results]) s8_err_hi = np.array([r["s8_hi"] - r["s8_pred"] for r in pred_results]) rmse_Om = np.sqrt(((Om_pred - Om_true) ** 2).mean()) rmse_s8 = np.sqrt(((s8_pred - s8_true) ** 2).mean()) fig, axes = plt.subplots(1, 2, figsize=(11, 5), dpi=dpi) for ax, (true, pred, err_lo, err_hi, name, prange, rmse) in zip(axes, [ (Om_true, Om_pred, Om_err_lo, Om_err_hi, r"$\Omega_m$", (0.10, 0.50), rmse_Om), (s8_true, s8_pred, s8_err_lo, s8_err_hi, r"$\sigma_8$", (0.60, 1.00), rmse_s8), ]): ax.errorbar( true, pred, yerr=[np.maximum(err_lo, 0), np.maximum(err_hi, 0)], fmt="o", color=GEN_COLOR, ecolor=SIGMA_COLORS[1], elinewidth=1.2, capsize=3, ms=6, label="DDPM-VLB inference (68% CI)", ) ax.plot(prange, prange, "k--", lw=1.0, alpha=0.5, label="Identity") ax.set_xlabel(f"True {name}", fontsize=11) ax.set_ylabel(f"Predicted {name}", fontsize=11) ax.set_xlim(*prange) ax.set_ylim(*prange) ax.grid(alpha=0.2) ax.legend(fontsize=9, loc="lower right") ax.text( 0.04, 0.92, f"RMSE = {rmse:.4f}", transform=ax.transAxes, fontsize=10, bbox=dict(facecolor="white", edgecolor="#ccc", alpha=0.92, pad=4), ) ax.set_title(f"{name}: predicted vs true", fontweight="bold", fontsize=10) fig.suptitle( "VLB Parameter Inference: predicted vs true\n" r"Error bars = 68% CI from $L_0$ marginal posterior", fontsize=10, fontweight="bold", y=1.01, ) plt.tight_layout() fig.savefig(out_path, bbox_inches="tight", dpi=dpi) plt.close(fig) print(f" Saved -> {out_path}") print(f" RMSE: Omega_m={rmse_Om:.4f} sigma_8={rmse_s8:.4f}") def fig_posterior_and_contours_combined( surfaces: Dict[int, np.ndarray], L0_surface: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray, Om_true: float, s8_true: float, out_path: Path, dpi: int = 200, ) -> Tuple[float, float, Tuple[float, float], Tuple[float, float]]: """ Create a combined figure with contours_per_t on left and posterior on right. """ from matplotlib.patches import Patch surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6) Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL( surface_sm, Om_grid, s8_grid ) Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg) s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg) fig = plt.figure(figsize=(16, 7), dpi=dpi) gs = gridspec.GridSpec(2, 4, width_ratios=[4, 0.3, 4, 1], height_ratios=[1, 4], hspace=0.08, wspace=0.10, left=0.08, right=0.96, top=0.94, bottom=0.08) # Left panel: Contours per timestep ax_contours = fig.add_subplot(gs[:, 0]) cmap = plt.cm.viridis n_t = len(surfaces) colors = cmap(np.linspace(0.05, 0.95, n_t)) for (t, L_surf), col in zip(sorted(surfaces.items()), colors): neg2dL = 2.0 * (L_surf - L_surf.min()) ax_contours.contour( Om_grid, s8_grid, neg2dL.T, levels=[2.30], colors=[col], linewidths=[1.6], linestyles=["-"], ) ax_contours.plot([], [], color=col, lw=1.8, label=f"t={t}") ax_contours.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10) ax_contours.set_xlabel(r"$\Omega_m$", fontsize=12) ax_contours.set_ylabel(r"$\sigma_8$", fontsize=12) ax_contours.set_title( r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contours per timestep", fontweight="bold", fontsize=11, ) ax_contours.legend(fontsize=8, loc="best", ncol=1, framealpha=0.92) ax_contours.grid(alpha=0.18) ax_contours.set_xlim(Om_grid[0], Om_grid[-1]) ax_contours.set_ylim(s8_grid[0], s8_grid[-1]) # Right panel: Posterior L_0 (similar to fig_main_posterior layout) ax_main = fig.add_subplot(gs[1, 2]) ax_top = fig.add_subplot(gs[0, 2], sharex=ax_main) ax_rt = fig.add_subplot(gs[1, 3], sharey=ax_main) draw_L0_posterior_main_panel( ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred, ) ax_main.set_xlabel(r"$\Omega_m$", fontsize=11) ax_main.set_ylabel(r"$\sigma_8$", fontsize=11) ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6) ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4) ax_top.axvline(Om_true, color="red", lw=1.0, ls=":") ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--", path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")]) ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI") ax_top.set_yticks([]) ax_top.tick_params(labelbottom=False) ax_top.set_title( rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$" rf" (true: {Om_true:.3f})", fontsize=9, ) ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6) ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4) ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":") ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--", path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")]) ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18) ax_rt.set_xticks([]) ax_rt.tick_params(labelleft=False) ax_rt.set_ylabel( rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$" rf" (true: {s8_true:.3f})", fontsize=9, rotation=270, labelpad=15, ) ax_rt.yaxis.set_label_position("right") legend_patches = [ Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"), Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"), Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"), ] hs, ls_ = ax_main.get_legend_handles_labels() ax_main.legend( handles=legend_patches + hs, labels=[p.get_label() for p in legend_patches] + ls_, fontsize=8, loc="upper right", framealpha=0.92, ) fig.suptitle( r"VLB Inference: $L_t$ contours (left) and $L_0$ posterior (right)", fontsize=12, fontweight="bold", y=0.99, ) fig.savefig(out_path, bbox_inches="tight", dpi=dpi) plt.close(fig) print(f" Saved -> {out_path}") return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="VLB-based parameter inference for trained conditional DDPM.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p.add_argument("--checkpoint", required=True) p.add_argument("--training_args", default=None) p.add_argument("--data_dir", default="./data/params_2") p.add_argument("--output_dir", default="vlb_inference_outputs") p.add_argument("--n_fields", type=int, default=9) p.add_argument( "--grid_size", type=int, default=10_000, help="Ωm×σ8 evaluation grid resolution (each side). Values > 300 require " "--allow_huge_grid (very long runs for large grid_size).", ) p.add_argument( "--allow_huge_grid", action="store_true", help="Required when --grid_size > 300 (avoids accidental multi-week GPU jobs).", ) p.add_argument( "--mosaic_side_px", type=int, default=10_000, help="Pixel width/height of posterior_L0_mosaic_3x3.png (square).", ) p.add_argument( "--mosaic_panel_inches", type=float, default=4.0, help="Matplotlib size (inches) of each 3×3 panel; dpi = mosaic_side_px / (3× this).", ) p.add_argument("--span", type=float, default=0.10) p.add_argument("--t_subset", type=int, nargs="+", default=[0, 1, 2, 5, 8, 10, 15, 20]) p.add_argument("--n_seeds", type=int, default=4) p.add_argument("--batch_size", type=int, default=32) p.add_argument("--device", default="auto") p.add_argument("--seed", type=int, default=42) p.add_argument("--dpi", type=int, default=200) return p.parse_args() def autodetect_args() -> Optional[str]: for pat in ["outputs_conditional_*/args.json", "outputs_conditional_*/args.txt"]: cands = sorted(Path(".").glob(pat), key=os.path.getctime, reverse=True) if cands: return str(cands[0]) return None def main() -> None: args = parse_args() if args.grid_size > 300 and not args.allow_huge_grid: print( "\nRefusing --grid_size={} (> 300) without --allow_huge_grid.\n" "A 10_000×10_000 grid is roughly (200)^2 ≈ 40_000× more forward passes per\n" "field than 50×50. For a quick run use e.g. --grid_size 50; for a high-res\n" "summary figure use the default --mosaic_side_px without increasing grid_size.\n" "To proceed anyway: add --allow_huge_grid\n".format(args.grid_size) ) raise SystemExit(2) torch.manual_seed(args.seed) np.random.seed(args.seed) device = ( torch.device("cuda" if torch.cuda.is_available() else "cpu") if args.device == "auto" else torch.device(args.device) ) print(f"\nDevice: {device}") out = Path(args.output_dir) out.mkdir(parents=True, exist_ok=True) if args.training_args is None: args.training_args = autodetect_args() if args.training_args is None: raise FileNotFoundError("Cannot find args.json — pass --training_args") print(f" Auto-detected args: {args.training_args}") cfg = load_config(args.training_args) print("\nLoading model ...") model = load_model(args.checkpoint, cfg, device) n_p = sum(p.numel() for p in model.parameters()) print(f" Parameters: {n_p:,} T={model.diffusion.timesteps}") print(f"\nLoading {args.n_fields} test fields ...") test_imgs, test_labels, label_mu, label_std = load_test_data( args.data_dir, args.n_fields, seed=args.seed, ) print(f" Image shape: {test_imgs.shape[1:]}") print(f" Label dim: {test_labels.shape[1]}") print(f" Label μ/σ: {label_mu} / {label_std}") print( f"\nEvaluating L_t on {args.grid_size}x{args.grid_size} grid for " f"{len(args.t_subset)} timesteps × {args.n_seeds} seeds ..." ) print( f" -> {args.grid_size ** 2 * len(args.t_subset) * args.n_seeds:,} " f"forward-pass groups per field (× seeds averaged)" ) pred_results = [] label_dim = int(cfg.get("label_dim", 2)) for fi in range(args.n_fields): Om_true = float(test_labels[fi, 0]) s8_true = float(test_labels[fi, 1]) print(f"\n [{fi+1}/{args.n_fields}] field with " f"Om={Om_true:.3f}, s8={s8_true:.3f}") x_0 = torch.from_numpy(test_imgs[fi:fi + 1] * 2.0 - 1.0).unsqueeze(1).to(device) Om_grid, s8_grid = build_eval_grid(Om_true, s8_true, args.grid_size, args.span) t_start = time.time() surfaces = evaluate_vlb_surface( model=model, x_0=x_0, Om_grid=Om_grid, s8_grid=s8_grid, label_mu=label_mu, label_std=label_std, t_values=args.t_subset, n_seeds=args.n_seeds, batch_size=args.batch_size, label_dim=label_dim, fixed_seed=args.seed + fi, device=device, ) elapsed = time.time() - t_start print(f" Evaluation time: {elapsed:.1f}s") np.savez( out / f"field{fi:02d}_surfaces.npz", **{f"L_t{t}": s for t, s in surfaces.items()}, Om_grid=Om_grid, s8_grid=s8_grid, Om_true=Om_true, s8_true=s8_true, ) if 0 in surfaces: # Combined figure: contours_per_t + posterior on same plot Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi) = fig_posterior_and_contours_combined( surfaces, surfaces[0], Om_grid, s8_grid, Om_true, s8_true, out / f"field{fi:02d}_combined.png", dpi=args.dpi, ) pred_results.append(dict( Om_true=Om_true, s8_true=s8_true, Om_pred=Om_pred, s8_pred=s8_pred, Om_lo=Om_lo, Om_hi=Om_hi, s8_lo=s8_lo, s8_hi=s8_hi, )) # Also save individual figures for detailed inspection fig_contours_per_t( surfaces, Om_grid, s8_grid, Om_true, s8_true, out / f"field{fi:02d}_contours_per_t.png", dpi=args.dpi, ) if 0 in surfaces: fig_main_posterior( surfaces[0], Om_grid, s8_grid, Om_true, s8_true, out / f"field{fi:02d}_posterior_L0.png", dpi=args.dpi, ) if len(pred_results) >= 2: fig_pred_vs_true(pred_results, out / "summary_pred_vs_true.png", dpi=args.dpi) np.savez( out / "summary.npz", **{k: np.array([r[k] for r in pred_results]) for k in pred_results[0].keys()}, ) if args.n_fields >= 9 and all((out / f"field{i:02d}_surfaces.npz").is_file() for i in range(9)): fig_posterior_L0_mosaic_3x3( out, args.n_fields, out / "posterior_L0_mosaic_3x3.png", mosaic_side_px=args.mosaic_side_px, panel_inches=args.mosaic_panel_inches, ) print(f"\nAll outputs -> {out.resolve()}/") for f in sorted(out.glob("*.png")): print(f" {f.name}") if __name__ == "__main__": main()