DDPM-6param / src /posterior_inference.py
collins909's picture
Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified
#!/usr/bin/env python3
"""
posterior_inference.py — VLB-based cosmological inference (Mudur et al. 2023 §4 style).
Pure inference-time; frozen DDPM weights. Script lives next to diffusion_conditional.py.
"""
from __future__ import annotations
import argparse
import ast
import json
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.patheffects as mpathe
import numpy as np
import torch
# ── Project imports ────────────────────────────────────────────────────────────
_ROOT = Path(__file__).resolve().parent
# When the sibling model modules sit next to this script, put this directory
# first on sys.path so the imports below work from any working directory.
if _ROOT.joinpath("diffusion_conditional.py").is_file():
    sys.path.insert(0, str(_ROOT))
from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel  # noqa: E402
from unet_conditional import ConditionalUNet  # noqa: E402
# Global matplotlib look shared by every figure this script writes.
plt.rcParams.update({
    "font.family": "DejaVu Sans",
    "font.size": 9.5,
    "figure.facecolor": "white",
    "axes.facecolor": "white",
    "savefig.facecolor": "white",
    "axes.edgecolor": "#222",
    "axes.linewidth": 0.7,
    "axes.spines.top": False,
    "axes.spines.right": False,
})
# Colour constants.
REAL_COLOR = "#CC3333"
GEN_COLOR = "#2266BB"
# -2ΔlnL thresholds for two jointly constrained parameters (Δχ² with 2 degrees
# of freedom) enclosing ~68.3%, ~95.4% and ~99.7% of the probability, i.e. the
# 1σ / 2σ / 3σ credible regions.
SIGMA_LEVELS = [2.30, 6.17, 11.83]
SIGMA_COLORS = ["#1a5c9e", "#5590d0", "#99c0ea"]  # dark -> light for 1σ -> 3σ
SIGMA_LABELS = dict(zip(SIGMA_LEVELS, (r"$1\sigma$", r"$2\sigma$", r"$3\sigma$")))
def load_config(path: str) -> Dict:
    """Load a training-run configuration file.

    ``.json`` files (any case of the suffix) are parsed with :func:`json.load`.
    Any other file is read as ``key: value`` lines. Each value is parsed with
    :func:`ast.literal_eval` when it is a valid Python literal and otherwise
    kept as the stripped raw text. Lines without a ``:`` or with an empty key
    are skipped.

    Args:
        path: Path to ``args.json`` or a ``key: value`` text file.

    Returns:
        Mapping from option name to parsed value.
    """
    p = Path(path)
    if p.suffix.lower() == ".json":
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    cfg: Dict = {}
    with open(p, encoding="utf-8") as f:
        for line in f:
            if ":" not in line:
                continue
            key, raw = (part.strip() for part in line.split(":", 1))
            if not key:
                continue
            try:
                cfg[key] = ast.literal_eval(raw)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a Python literal (e.g. a bare path or name): keep the text.
                cfg[key] = raw
    return cfg
def load_model(ckpt: str, cfg: Dict, device: torch.device) -> ConditionalDiffusionModel:
    """Build the conditional DDPM described by ``cfg`` and load frozen weights.

    Checkpoint formats are tried in this order:
      * a dict with ``"ema_shadow"``: every EMA tensor whose key exists in the
        freshly built model's state dict replaces that entry, and warnings are
        printed when EMA keys go unmatched or model parameters are not covered
        (otherwise a prefix mismatch would silently leave random weights);
      * a dict with ``"model_state_dict"``: loaded directly;
      * anything else: treated as a bare state dict.

    Args:
        ckpt: Checkpoint path passed to :func:`torch.load`.
        cfg: Training configuration; missing keys fall back to the defaults below.
        device: Device the model and checkpoint tensors are mapped to.

    Returns:
        The model in eval mode with ``requires_grad`` disabled on all parameters.
    """
    unet = ConditionalUNet(
        in_channels=1, out_channels=1,
        label_dim=int(cfg.get("label_dim", 2)),
        base_channels=int(cfg.get("base_channels", 64)),
        channel_multipliers=list(cfg.get("channel_multipliers", [1, 2, 4, 8])),
        attention_levels=list(cfg.get("attention_levels", [2, 3])),
        dropout=float(cfg.get("dropout", 0.1)),
    )
    diff = GaussianDiffusion(
        timesteps=int(cfg.get("timesteps", 1500)),
        beta_start=float(cfg.get("beta_start", 1e-4)),
        beta_end=float(cfg.get("beta_end", 0.02)),
        schedule_type=str(cfg.get("schedule_type", "linear")),
    )
    model = ConditionalDiffusionModel(unet, diff).to(device)
    # SECURITY: weights_only=False runs the full pickle loader, which can
    # execute arbitrary code. Only load checkpoints from trusted sources.
    ck = torch.load(ckpt, map_location=device, weights_only=False)
    if isinstance(ck, dict) and "ema_shadow" in ck:
        cur = model.state_dict()
        shadow = ck["ema_shadow"]
        matched = [k for k in shadow if k in cur]
        for k in matched:
            cur[k] = shadow[k]
        model.load_state_dict(cur)
        if not matched:
            print("  WARNING: no EMA key matched the model state dict; "
                  "weights are still at their random initialisation")
        elif len(matched) < len(shadow):
            print(f"  WARNING: {len(shadow) - len(matched)} EMA key(s) had no "
                  "match in the model and were ignored")
        uncovered = [name for name, _ in model.named_parameters() if name not in shadow]
        if matched and uncovered:
            print(f"  WARNING: {len(uncovered)} model parameter(s) are not in the "
                  "EMA shadow and keep their initial values")
        print(" Loaded EMA weights")
    elif isinstance(ck, dict) and "model_state_dict" in ck:
        model.load_state_dict(ck["model_state_dict"])
    else:
        model.load_state_dict(ck)
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model
def load_test_data(
    data_dir: str, n_fields: int, seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Draw ``n_fields`` random test maps and their labels, plus label statistics.

    File names depend on what exists in ``data_dir``: label files get the
    ``_2`` suffix when ``train_labels_LH_2.npy`` exists, and image files get
    the ``_6`` suffix unless ``train_LH.npy`` exists.

    Returns:
        ``(images, labels, label_mean, label_std)``. Images and labels are
        float32 and selected without replacement with
        ``np.random.default_rng(seed)``. The mean and std are per column of
        the training labels, with zero std replaced by 1.
    """
    root = Path(data_dir)
    label_suffix = "_2" if (root / "train_labels_LH_2.npy").exists() else ""
    image_suffix = "" if (root / "train_LH.npy").exists() else "_6"

    def _load(name: str) -> np.ndarray:
        return np.load(root / name).astype(np.float32)

    images = _load(f"test_LH{image_suffix}.npy")
    test_labels = _load(f"test_labels_LH{label_suffix}.npy")
    train_labels = _load(f"train_labels_LH{label_suffix}.npy")

    picked = np.random.default_rng(seed).choice(len(images), n_fields, replace=False)
    train_mean = train_labels.mean(0)
    train_std = train_labels.std(0)
    safe_std = np.where(train_std == 0, 1.0, train_std)
    return images[picked], test_labels[picked], train_mean, safe_std
def normal_kl(mean1, log_var1, mean2, log_var2):
    """Elementwise KL( N(mean1, exp(log_var1)) || N(mean2, exp(log_var2)) ).

    Arguments are broadcastable tensors (the log-variances must be tensors
    because they go through ``torch.exp``). The result is not reduced.
    """
    variance_ratio = torch.exp(log_var1 - log_var2)
    scaled_sq_gap = ((mean1 - mean2) ** 2) * torch.exp(-log_var2)
    return 0.5 * (-1.0 + log_var2 - log_var1 + variance_ratio + scaled_sq_gap)
def _approx_standard_normal_cdf(x):
    """Tanh-based approximation of the standard normal CDF, applied elementwise to tensor ``x``."""
    scale = np.sqrt(2.0 / np.pi)
    cubic = x + 0.044715 * x ** 3
    return 0.5 * (1.0 + torch.tanh(scale * cubic))
def discretised_gaussian_log_likelihood(x_0, mean, log_var):
    """Elementwise log-probability of ``x_0`` under a binned Gaussian.

    The density N(mean, exp(log_var)) is integrated over [x_0 - 1/255,
    x_0 + 1/255]. Where ``x_0 < -0.999`` the lower tail is open (CDF up to the
    upper edge), and where ``x_0 > 0.999`` the upper tail is open (1 - CDF from
    the lower edge). Probabilities are clamped at 1e-12 before the log. No
    reduction is applied.

    NOTE(review): a half-bin of 1/255 corresponds to 256 levels on [-1, 1]
    (8-bit images). For continuous-valued fields this is a modelling choice.
    Confirm it is intended.
    """
    half_bin = 1.0 / 255.0
    offset = x_0 - mean
    inv_std = torch.exp(-0.5 * log_var)
    upper_cdf = _approx_standard_normal_cdf(inv_std * (offset + half_bin))
    lower_cdf = _approx_standard_normal_cdf(inv_std * (offset - half_bin))
    log_upper = torch.log(upper_cdf.clamp(min=1e-12))
    log_one_minus_lower = torch.log((1.0 - lower_cdf).clamp(min=1e-12))
    bin_mass = (upper_cdf - lower_cdf).clamp(min=1e-12)
    interior_or_top = torch.where(x_0 > 0.999, log_one_minus_lower, torch.log(bin_mass))
    return torch.where(x_0 < -0.999, log_upper, interior_or_top)
def predict_x_start_from_eps(diff: GaussianDiffusion, x_t: torch.Tensor,
                             t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
    """Recover the x_0 estimate from a noisy sample and a noise prediction.

    Computes ``recip_sqrt_alphas_cumprod[t] * x_t - sqrt_recip_minus_one[t] * eps``.
    Per their names, these buffers presumably hold 1/sqrt(ᾱ_t) and
    sqrt(1/ᾱ_t - 1); confirm in diffusion_conditional.py. The comment in the
    original states this mirrors ``GaussianDiffusion._predict_xstart_from_noise``.
    """
    coef_x = diff._extract(diff.recip_sqrt_alphas_cumprod, t, x_t.shape)
    coef_eps = diff._extract(diff.sqrt_recip_minus_one, t, x_t.shape)
    return coef_x * x_t - coef_eps * eps
def q_posterior_mean_var(diff: GaussianDiffusion, x_start: torch.Tensor,
                         x_t: torch.Tensor, t: torch.Tensor):
    """Gaussian posterior q(x_{t-1} | x_t, x_start) from the diffusion buffers.

    Returns:
        ``(mean, variance, clipped_log_variance)``. The mean is
        ``coef1[t] * x_start + coef2[t] * x_t``. The variance terms are the
        per-step buffers broadcast to ``x_t``'s rank by ``diff._extract``.
    """
    shape = x_t.shape

    def pick(buffer):
        return diff._extract(buffer, t, shape)

    mean = pick(diff.posterior_mean_coef1) * x_start + pick(diff.posterior_mean_coef2) * x_t
    variance = pick(diff.posterior_variance)
    log_variance = pick(diff.posterior_log_variance_clipped)
    return mean, variance, log_variance
@torch.no_grad()
def compute_L_t(
    model: ConditionalDiffusionModel,
    x_0: torch.Tensor,
    labels_n: torch.Tensor,
    t: int,
    fixed_eps: torch.Tensor,
) -> torch.Tensor:
    """Per-sample VLB term at diffusion index ``t`` for a frozen noise draw.

    The forward process uses ``fixed_eps`` rather than fresh noise, so the
    same realisation can be reused across label hypotheses.

    * ``t > 0``: KL between q(x_{t-1} | x_t, x_0) and the model's
      distribution. The model distribution is the same posterior built from
      the clamped x̂_0 prediction, with the fixed posterior-variance buffers
      (not a learned variance).
    * ``t == 0``: negative discretised Gaussian log-likelihood of ``x_0``
      under the model posterior formed at diffusion index 1.

    Returns:
        Tensor of shape ``(B,)`` (summed over dims 1-3).

    NOTE(review): for ``t == 0`` the posterior is built at index 1, i.e. it is
    the distribution of the index-0 latent given the index-1 latent. This is
    a decoder for the clean x_0 only if ``alphas_cumprod[0] == 1`` in
    GaussianDiffusion. With a first beta > 0 (e.g. ``linspace(beta_start,
    ...)``), guided-diffusion-style code evaluates the decoder at index 0
    instead. Confirm against diffusion_conditional.py.
    """
    diffusion = model.diffusion
    n_batch = x_0.shape[0]

    def _steps(value: int) -> torch.Tensor:
        return torch.full((n_batch,), value, device=x_0.device, dtype=torch.long)

    def _q_sample(steps: torch.Tensor) -> torch.Tensor:
        # Noise x_0 to the given step with the frozen noise draw.
        abar = diffusion._extract(diffusion.alphas_cumprod, steps, x_0.shape)
        return torch.sqrt(abar) * x_0 + torch.sqrt(1.0 - abar) * fixed_eps

    def _model_posterior(x_noisy: torch.Tensor, steps: torch.Tensor):
        # Model step distribution: posterior around the clamped x̂_0 prediction.
        eps_hat = model(x_noisy, steps, labels_n)
        x0_hat = predict_x_start_from_eps(diffusion, x_noisy, steps, eps_hat).clamp(-1, 1)
        mu, _, log_var = q_posterior_mean_var(diffusion, x0_hat, x_noisy, steps)
        return mu, log_var

    if t == 0:
        step_one = _steps(1)
        mu, log_var = _model_posterior(_q_sample(step_one), step_one)
        log_p = discretised_gaussian_log_likelihood(x_0, mu, log_var)
        return -log_p.sum(dim=(1, 2, 3))

    steps = _steps(t)
    x_noisy = _q_sample(steps)
    true_mu, _, true_log_var = q_posterior_mean_var(diffusion, x_0, x_noisy, steps)
    model_mu, model_log_var = _model_posterior(x_noisy, steps)
    kl = normal_kl(true_mu, true_log_var, model_mu, model_log_var)
    return kl.sum(dim=(1, 2, 3))
@torch.no_grad()
def compute_L_T_analytic(diff: GaussianDiffusion, x_0: torch.Tensor) -> torch.Tensor:
    """Prior term: KL( q(x at index T-1 | x_0) || N(0, I) ), summed per sample.

    Uses ``alphas_cumprod`` at the last index ``diff.timesteps - 1``. The
    variance 1 - ᾱ is clamped at 1e-30 before the log. Returns shape ``(B,)``.
    """
    last_index = torch.full(
        (x_0.shape[0],), diff.timesteps - 1, device=x_0.device, dtype=torch.long
    )
    abar_last = diff._extract(diff.alphas_cumprod, last_index, x_0.shape)
    q_mean = torch.sqrt(abar_last) * x_0
    q_log_var = torch.log((1.0 - abar_last).clamp(min=1e-30))
    prior_mean = torch.zeros_like(q_mean)
    prior_log_var = torch.zeros_like(q_log_var)
    return normal_kl(q_mean, q_log_var, prior_mean, prior_log_var).sum(dim=(1, 2, 3))
def build_eval_grid(
    Om_true: float,
    s8_true: float,
    grid_size: int,
    span: float = 0.1,
    Om_range: Tuple[float, float] = (0.10, 0.50),
    s8_range: Tuple[float, float] = (0.60, 1.00),
) -> Tuple[np.ndarray, np.ndarray]:
    """1-D evaluation axes centred on the true parameters.

    Each axis covers ``[true - span, true + span]`` clipped to its prior range
    and sampled at ``grid_size`` evenly spaced points.

    Returns:
        ``(Om_axis, s8_axis)``.
    """
    def _axis(centre: float, bounds: Tuple[float, float]) -> np.ndarray:
        floor, ceiling = bounds
        return np.linspace(max(centre - span, floor), min(centre + span, ceiling), grid_size)

    return _axis(Om_true, Om_range), _axis(s8_true, s8_range)
@torch.no_grad()
def evaluate_vlb_surface(
    model: ConditionalDiffusionModel,
    x_0: torch.Tensor,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    label_mu: np.ndarray,
    label_std: np.ndarray,
    t_values: List[int],
    n_seeds: int = 4,
    batch_size: int = 32,
    label_dim: int = 2,
    fixed_seed: int = 0,
    device: Optional[torch.device] = None,
) -> Dict[int, np.ndarray]:
    """Evaluate the per-timestep VLB terms L_t of one field over an (Ωm, σ8) grid.

    Each grid point is conditioned on the label (Ωm, σ8). When
    ``label_dim > 2``, the remaining entries are filled with the training
    means ``label_mu[2:label_dim]``, which become 0 after normalisation.
    Labels are normalised as ``(raw - label_mu) / label_std``. ``n_seeds``
    noise maps are drawn once from a generator seeded with ``fixed_seed`` and
    reused at every grid point and timestep (common random numbers). The
    returned surfaces are the mean over those draws.

    Args:
        model: Frozen conditional DDPM; passed to :func:`compute_L_t`.
        x_0: The field, broadcast over the batch with ``expand``. Assumed
            shape (1, 1, H, W) to match the (1, 1, H, W) noise draws; confirm
            with the caller.
        Om_grid, s8_grid: 1-D parameter axes.
        label_mu, label_std: Training label mean / std (length >= label_dim).
        t_values: Timesteps to evaluate. Duplicates are evaluated once.
        n_seeds: Number of frozen noise draws to average (>= 1).
        batch_size: Grid points per forward pass (>= 1).
        label_dim: Width of the model's conditioning vector (>= 2).
        fixed_seed: Seed for the noise generator.
        device: Device for labels and noise. Defaults to ``x_0.device``.

    Returns:
        ``{t: array of shape (len(Om_grid), len(s8_grid))}``; axis 0 indexes Ωm.

    Raises:
        ValueError: If ``n_seeds``, ``batch_size`` or ``label_dim`` are out of
            range, or ``label_mu`` is shorter than ``label_dim``.
    """
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be >= 1, got {n_seeds}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if label_dim < 2:
        raise ValueError(f"label_dim must be >= 2 (Omega_m, sigma_8), got {label_dim}")
    if len(label_mu) < label_dim:
        raise ValueError(
            f"label_mu has {len(label_mu)} entries but label_dim={label_dim}"
        )
    if device is None:
        device = x_0.device
    # A repeated timestep would otherwise be accumulated twice into one surface.
    t_unique = list(dict.fromkeys(t_values))

    n_om, n_s8 = len(Om_grid), len(s8_grid)
    n_pts = n_om * n_s8
    om_mesh, s8_mesh = np.meshgrid(Om_grid, s8_grid, indexing="ij")
    raw_labels = np.column_stack([om_mesh.ravel(), s8_mesh.ravel()])
    if label_dim > 2:
        extra = np.asarray(label_mu[2:label_dim], dtype=np.float32)
        raw_labels = np.concatenate([raw_labels, np.tile(extra, (n_pts, 1))], axis=1)
    norm_labels = (raw_labels - label_mu) / label_std
    norm_labels_t = torch.from_numpy(norm_labels.astype(np.float32)).to(device)

    L_surfaces = {t: np.zeros(n_pts, dtype=np.float64) for t in t_unique}
    H, W = x_0.shape[-2], x_0.shape[-1]
    gen = torch.Generator(device=device).manual_seed(fixed_seed)
    seeds_eps = [
        torch.randn(1, 1, H, W, generator=gen, device=device)
        for _ in range(n_seeds)
    ]
    for fixed_eps in seeds_eps:
        for t in t_unique:
            for start in range(0, n_pts, batch_size):
                end = min(start + batch_size, n_pts)
                bsz = end - start
                x_b = x_0.expand(bsz, -1, -1, -1)
                eps_b = fixed_eps.expand(bsz, -1, -1, -1)
                L_t = compute_L_t(model, x_b, norm_labels_t[start:end], t=t, fixed_eps=eps_b)
                L_surfaces[t][start:end] += L_t.cpu().numpy() / n_seeds
    return {t: L_surfaces[t].reshape(n_om, n_s8) for t in t_unique}
def marginal_from_neg2dL(
    neg2dL: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Tuple[float, float]]:
    """Convert a -2ΔlnL grid into normalised 1-D marginals and their peaks.

    The surface is mapped to relative weights exp(-neg2dL / 2), shifted by the
    maximum log-weight for numerical stability. Axis 0 pairs with ``Om_grid``
    and axis 1 with ``s8_grid``. Each marginal is the plain sum over the other
    axis, normalised to sum to 1.

    Returns:
        ``(Om_marginal, s8_marginal, (Om_at_peak, s8_at_peak))``. The peaks
        are the grid values where each 1-D marginal is largest.
    """
    log_weight = -0.5 * neg2dL
    weights = np.exp(log_weight - log_weight.max())

    def _normalised(v: np.ndarray) -> np.ndarray:
        return v / v.sum()

    om_marg = _normalised(weights.sum(axis=1))
    s8_marg = _normalised(weights.sum(axis=0))
    peaks = (float(Om_grid[om_marg.argmax()]), float(s8_grid[s8_marg.argmax()]))
    return om_marg, s8_marg, peaks
def credible_interval_68(values: np.ndarray, probs: np.ndarray) -> Tuple[float, float, float]:
    """Median and central 68% credible interval of a gridded 1-D distribution.

    ``probs[i]`` is treated as the probability mass of the grid cell centred on
    ``values[i]`` (``values`` ascending, e.g. from ``np.linspace``). The
    cumulative distribution is evaluated at the cell centres as
    ``cumsum(probs) - probs / 2``. Using the plain ``cumsum`` would attach each
    cell's full mass to its centre and shift every quantile half a grid step
    toward lower values. Quantiles are found by linear interpolation.
    ``probs`` need not be normalised.

    Returns:
        ``(median, q16, q84)``.

    Raises:
        ValueError: If ``probs`` is empty or its total mass is not positive.
    """
    p = np.asarray(probs, dtype=np.float64)
    if p.size == 0:
        raise ValueError("credible_interval_68 needs at least one grid point")
    total = p.sum()
    if not total > 0:
        raise ValueError("credible_interval_68 needs probabilities with a positive sum")
    cdf = (np.cumsum(p) - 0.5 * p) / total
    median = float(np.interp(0.50, cdf, values))
    lo = float(np.interp(0.16, cdf, values))
    hi = float(np.interp(0.84, cdf, values))
    return median, lo, hi
def fig_contours_per_t(
    surfaces: Dict[int, np.ndarray],
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> None:
    """Overlay the 1σ contour (Δ = 2.30) of -2ΔlnL_t for every evaluated timestep.

    Each surface is turned into ``2 * (L_t - min L_t)``. Timesteps are drawn
    in ascending order with colours sampled from viridis. The true parameters
    are marked and the figure is written to ``out_path``.
    """
    fig, ax = plt.subplots(figsize=(7, 6.5), dpi=dpi)
    palette = plt.cm.viridis(np.linspace(0.05, 0.95, len(surfaces)))
    for idx, (t, loss) in enumerate(sorted(surfaces.items())):
        colour = palette[idx]
        delta = 2.0 * (loss - loss.min())
        ax.contour(
            Om_grid, s8_grid, delta.T,
            levels=[2.30], colors=[colour], linewidths=[1.6], linestyles=["-"],
        )
        # Empty line used only to give this timestep's colour a legend entry.
        ax.plot([], [], color=colour, lw=1.8, label=f"t={t}")
    ax.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)
    ax.set_xlabel(r"$\Omega_m$", fontsize=12)
    ax.set_ylabel(r"$\sigma_8$", fontsize=12)
    ax.set_title(
        r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contour per timestep"
        "\n(Mudur-style) smaller $t$ → tighter constraint",
        fontweight="bold", fontsize=10,
    )
    ax.legend(fontsize=8, loc="best", ncol=2, framealpha=0.92)
    ax.grid(alpha=0.18)
    ax.set_xlim(Om_grid[0], Om_grid[-1])
    ax.set_ylim(s8_grid[0], s8_grid[-1])
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
def _L0_posterior_smoothed(
    L0_surface: np.ndarray, smooth_sigma: float = 0.6,
):
    """Turn an L_0 loss surface into a smoothed -2ΔlnL surface for plotting.

    Args:
        L0_surface: 2-D loss surface (lower = more likely).
        smooth_sigma: Gaussian smoothing width in grid cells.

    Returns:
        ``(surface_sm, neg2dL)``. ``neg2dL = 2 * (L0 - min L0)`` (minimum 0).
        ``surface_sm`` is ``neg2dL`` after Gaussian smoothing, shifted so its
        own minimum is exactly 0.
    """
    from scipy.ndimage import gaussian_filter as gf
    neg2dL = 2.0 * (L0_surface - L0_surface.min())
    surface_sm = gf(neg2dL, sigma=smooth_sigma)
    # Smoothing a non-negative surface raises its minimum above 0. The
    # Δ = 2.30 / 6.17 / 11.83 contour levels must be measured from the
    # smoothed surface's own minimum, otherwise the credible regions shrink or
    # vanish for sharply peaked posteriors. The shift does not change the
    # marginals, which are invariant to an additive constant.
    surface_sm = surface_sm - surface_sm.min()
    return surface_sm, neg2dL
def draw_L0_posterior_main_panel(
    ax,
    surface_sm: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    Om_pred: float,
    s8_pred: float,
    *,
    clabel_fontsize: float = 9.5,
    marker_ms: float = 16,
) -> None:
    """Draw the 2-D L_0 posterior (shading, 1/2/3σ regions, markers) onto ``ax``.

    ``surface_sm`` is a -2ΔlnL grid with axis 0 indexing ``Om_grid``. It is
    transposed because the contour routines expect rows to index the y-axis.
    Adds artists labelled "True" and "MAP" but does not create a legend.
    """
    Z = surface_sm.T
    # Soft background shading of the whole surface.
    ax.contourf(Om_grid, s8_grid, Z, levels=60, cmap="Blues_r",
                vmin=0, vmax=SIGMA_LEVELS[-1] * 3, extend="max", alpha=0.55)
    # Nested filled regions, widest (3σ) first so the 1σ fill is drawn last, on top.
    for level, colour in zip(SIGMA_LEVELS[::-1], SIGMA_COLORS[::-1]):
        ax.contourf(Om_grid, s8_grid, Z, levels=[0, level], colors=[colour], alpha=0.78)
    outlines = ax.contour(
        Om_grid, s8_grid, Z,
        levels=SIGMA_LEVELS,
        colors=["white"] * 3,
        linewidths=[2.2, 1.6, 1.2],
        linestyles=["-", "--", "-."],
    )
    ax.clabel(outlines, fmt=SIGMA_LABELS, inline=True, fontsize=clabel_fontsize, colors="white")
    # Dotted cross-hair through the true parameters.
    ax.axvline(Om_true, color="red", lw=0.7, ls=":", alpha=0.6)
    ax.axhline(s8_true, color="red", lw=0.7, ls=":", alpha=0.6)
    ax.plot(Om_true, s8_true, "r+", ms=marker_ms, mew=2.5, zorder=6, label="True")
    map_ms = max(6, marker_ms * 0.55)
    ax.plot(Om_pred, s8_pred, "w^", ms=map_ms, mew=1.2, zorder=6, label="MAP")
    ax.set_xlim(Om_grid[0], Om_grid[-1])
    ax.set_ylim(s8_grid[0], s8_grid[-1])
    ax.grid(alpha=0.18)
def fig_posterior_L0_mosaic_3x3(
    out_dir: Path,
    n_fields: int,
    out_path: Path,
    mosaic_side_px: int = 10_000,
    panel_inches: float = 4.0,
) -> None:
    """Render up to nine saved L_0 posteriors as one square 3×3 mosaic.

    For each panel, reads ``out_dir / f"field{idx:02d}_surfaces.npz"`` (keys
    ``L_t0``, ``Om_grid``, ``s8_grid``, ``Om_true``, ``s8_true``). It smooths
    and marginalises the surface and draws it with
    :func:`draw_L0_posterior_main_panel`. Panels beyond ``min(n_fields, 9)``
    are hidden. The dpi is ``mosaic_side_px / (3 * panel_inches)``, so the
    canvas is about ``mosaic_side_px`` pixels per side before
    ``bbox_inches="tight"`` adjusts it.
    """
    from matplotlib.patches import Patch
    n_plot = min(n_fields, 9)
    fig_side = panel_inches * 3
    dpi = mosaic_side_px / fig_side
    fig, axes = plt.subplots(
        3, 3, figsize=(fig_side, fig_side), dpi=dpi,
        squeeze=False,
    )
    for idx in range(9):
        r, c = divmod(idx, 3)
        ax = axes[r][c]
        if idx >= n_plot:
            ax.set_visible(False)
            continue
        # NpzFile keeps its file open until closed; read everything inside the context.
        with np.load(out_dir / f"field{idx:02d}_surfaces.npz") as nz:
            L0 = np.asarray(nz["L_t0"])
            Om_grid = np.asarray(nz["Om_grid"])
            s8_grid = np.asarray(nz["s8_grid"])
            Om_true = float(nz["Om_true"])
            s8_true = float(nz["s8_true"])
        surface_sm, _ = _L0_posterior_smoothed(L0, smooth_sigma=0.6)
        _, _, (Om_pred, s8_pred) = marginal_from_neg2dL(surface_sm, Om_grid, s8_grid)
        draw_L0_posterior_main_panel(
            ax, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
            clabel_fontsize=7.0, marker_ms=11,
        )
        if idx == 0:
            # Only the first panel carries a legend: σ-region swatches + markers.
            legend_patches = [
                Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
                Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
                Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
            ]
            hs, ls_ = ax.get_legend_handles_labels()
            ax.legend(
                handles=legend_patches + hs,
                labels=[p.get_label() for p in legend_patches] + ls_,
                fontsize=6, loc="upper right", framealpha=0.9,
            )
        ax.set_title(
            rf"field {idx}: $\Omega_m^{{\rm true}}={Om_true:.3f}$, $\sigma_8^{{\rm true}}={s8_true:.3f}$",
            fontsize=8,
        )
        ax.set_xlabel(r"$\Omega_m$", fontsize=8)
        ax.set_ylabel(r"$\sigma_8$", fontsize=8)
    fig.suptitle(
        rf"VLB $L_0$ posterior (2D) — {n_plot} test fields",
        fontsize=11, fontweight="bold", y=0.995,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path} (≈ {mosaic_side_px}×{mosaic_side_px} px)")
def fig_main_posterior(
    L0_surface: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> Tuple[float, float, Tuple[float, float], Tuple[float, float]]:
    """Corner-style figure of the L_0 posterior: 2-D panel plus both 1-D marginals.

    The L_0 surface is smoothed (``_L0_posterior_smoothed``) and marginalised
    (``marginal_from_neg2dL``). Medians and 16-84% intervals are computed with
    ``credible_interval_68``. The figure is saved to ``out_path``.

    Returns:
        ``(Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi))``. The "pred"
        values are the peaks of the 1-D marginals.
    """
    from matplotlib.patches import Patch
    surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)
    Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL(
        surface_sm, Om_grid, s8_grid
    )
    Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg)
    s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg)
    # Layout: big joint panel bottom-left, Ωm marginal above it, σ8 marginal to its right.
    fig = plt.figure(figsize=(8.5, 8.5), dpi=dpi)
    gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4],
                           hspace=0.05, wspace=0.05,
                           left=0.10, right=0.95, top=0.95, bottom=0.08)
    ax_main = fig.add_subplot(gs[1, 0])
    ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
    ax_rt = fig.add_subplot(gs[1, 1], sharey=ax_main)
    # NOTE(review): the "MAP" marker drawn here is the pair of 1-D marginal
    # peaks, not the joint 2-D maximum.
    draw_L0_posterior_main_panel(
        ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
    )
    ax_main.set_xlabel(r"$\Omega_m$", fontsize=11)
    ax_main.set_ylabel(r"$\sigma_8$", fontsize=11)
    # Top marginal: Ωm density, truth (red dotted), peak (white dashed), 68% band.
    ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4)
    ax_top.axvline(Om_true, color="red", lw=1.0, ls=":")
    ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--",
                   path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI")
    ax_top.set_yticks([])
    ax_top.tick_params(labelbottom=False)
    ax_top.set_title(
        rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$"
        rf" (true: {Om_true:.3f})",
        fontsize=9,
    )
    # Right marginal: σ8 density drawn sideways, same decorations as the top panel.
    ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4)
    ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":")
    ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--",
                  path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18)
    ax_rt.set_xticks([])
    ax_rt.tick_params(labelleft=False)
    ax_rt.set_ylabel(
        rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$"
        rf" (true: {s8_true:.3f})",
        fontsize=9, rotation=270, labelpad=15,
    )
    ax_rt.yaxis.set_label_position("right")
    # Legend: σ-region swatches first, then the "True"/"MAP" marker entries.
    legend_patches = [
        Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
        Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
        Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
    ]
    hs, ls_ = ax_main.get_legend_handles_labels()
    ax_main.legend(
        handles=legend_patches + hs,
        labels=[p.get_label() for p in legend_patches] + ls_,
        fontsize=8, loc="upper right", framealpha=0.92,
    )
    fig.suptitle(
        r"VLB Posterior using $L_0$ — joint and marginal distributions",
        fontsize=10, fontweight="bold", y=0.99,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
    return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
def fig_pred_vs_true(pred_results: List[Dict], out_path: Path, dpi: int = 200) -> None:
    """Plot predicted against true Ωm and σ8 with 68% CI error bars, and report RMSE.

    Each entry of ``pred_results`` must provide ``Om_true``, ``s8_true``,
    ``Om_pred``, ``s8_pred``, ``Om_lo``, ``Om_hi``, ``s8_lo`` and ``s8_hi``.
    The figure is written to ``out_path`` and the per-parameter RMSE of
    prediction vs truth is printed.
    """
    def column(key: str) -> np.ndarray:
        return np.array([row[key] for row in pred_results])

    panel_specs = (
        ("Om", r"$\Omega_m$", (0.10, 0.50)),
        ("s8", r"$\sigma_8$", (0.60, 1.00)),
    )
    panels = []
    for prefix, name, prange in panel_specs:
        truth = column(f"{prefix}_true")
        pred = column(f"{prefix}_pred")
        below = pred - column(f"{prefix}_lo")
        above = column(f"{prefix}_hi") - pred
        rmse = np.sqrt(((pred - truth) ** 2).mean())
        panels.append((truth, pred, below, above, name, prange, rmse))

    fig, axes = plt.subplots(1, 2, figsize=(11, 5), dpi=dpi)
    for ax, (truth, pred, below, above, name, prange, rmse) in zip(axes, panels):
        # The marginal peak can sit outside the 16-84% band; clip negative bar lengths to 0.
        ax.errorbar(
            truth, pred, yerr=[np.maximum(below, 0), np.maximum(above, 0)],
            fmt="o", color=GEN_COLOR, ecolor=SIGMA_COLORS[1],
            elinewidth=1.2, capsize=3, ms=6,
            label="DDPM-VLB inference (68% CI)",
        )
        ax.plot(prange, prange, "k--", lw=1.0, alpha=0.5, label="Identity")
        ax.set_xlabel(f"True {name}", fontsize=11)
        ax.set_ylabel(f"Predicted {name}", fontsize=11)
        ax.set_xlim(*prange)
        ax.set_ylim(*prange)
        ax.grid(alpha=0.2)
        ax.legend(fontsize=9, loc="lower right")
        ax.text(
            0.04, 0.92, f"RMSE = {rmse:.4f}",
            transform=ax.transAxes, fontsize=10,
            bbox=dict(facecolor="white", edgecolor="#ccc", alpha=0.92, pad=4),
        )
        ax.set_title(f"{name}: predicted vs true", fontweight="bold", fontsize=10)
    fig.suptitle(
        "VLB Parameter Inference: predicted vs true\n"
        r"Error bars = 68% CI from $L_0$ marginal posterior",
        fontsize=10, fontweight="bold", y=1.01,
    )
    plt.tight_layout()
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    rmse_Om, rmse_s8 = panels[0][-1], panels[1][-1]
    print(f" Saved -> {out_path}")
    print(f" RMSE: Omega_m={rmse_Om:.4f} sigma_8={rmse_s8:.4f}")
def fig_posterior_and_contours_combined(
    surfaces: Dict[int, np.ndarray],
    L0_surface: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> Tuple[float, float, Tuple[float, float], Tuple[float, float]]:
    """
    Side-by-side figure: per-timestep 1σ contours (left) and the L_0 posterior (right).

    Left: for every entry of ``surfaces``, the Δ = 2.30 contour of
    ``2 * (L_t - min L_t)``, coloured by ascending timestep. Right: the same
    joint + marginal layout as the single-posterior figure, built from
    ``L0_surface`` (smoothed, marginalised, 16-84% intervals). Saved to
    ``out_path``.

    Returns:
        ``(Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi))``. The "pred"
        values are the peaks of the 1-D L_0 marginals.
    """
    from matplotlib.patches import Patch
    surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)
    Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL(
        surface_sm, Om_grid, s8_grid
    )
    Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg)
    s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg)
    # Grid columns: [contours | narrow spacer | joint posterior | σ8 marginal];
    # rows: [Ωm marginal, joint]. The contour panel spans both rows.
    fig = plt.figure(figsize=(16, 7), dpi=dpi)
    gs = gridspec.GridSpec(2, 4, width_ratios=[4, 0.3, 4, 1], height_ratios=[1, 4],
                           hspace=0.08, wspace=0.10,
                           left=0.08, right=0.96, top=0.94, bottom=0.08)
    # Left panel: Contours per timestep
    ax_contours = fig.add_subplot(gs[:, 0])
    cmap = plt.cm.viridis
    n_t = len(surfaces)
    colors = cmap(np.linspace(0.05, 0.95, n_t))
    for (t, L_surf), col in zip(sorted(surfaces.items()), colors):
        neg2dL = 2.0 * (L_surf - L_surf.min())
        ax_contours.contour(
            Om_grid, s8_grid, neg2dL.T,
            levels=[2.30], colors=[col], linewidths=[1.6], linestyles=["-"],
        )
        # Empty line used only to give this timestep's colour a legend entry.
        ax_contours.plot([], [], color=col, lw=1.8, label=f"t={t}")
    ax_contours.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)
    ax_contours.set_xlabel(r"$\Omega_m$", fontsize=12)
    ax_contours.set_ylabel(r"$\sigma_8$", fontsize=12)
    ax_contours.set_title(
        r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contours per timestep",
        fontweight="bold", fontsize=11,
    )
    ax_contours.legend(fontsize=8, loc="best", ncol=1, framealpha=0.92)
    ax_contours.grid(alpha=0.18)
    ax_contours.set_xlim(Om_grid[0], Om_grid[-1])
    ax_contours.set_ylim(s8_grid[0], s8_grid[-1])
    # Right panel: Posterior L_0 (similar to fig_main_posterior layout)
    ax_main = fig.add_subplot(gs[1, 2])
    ax_top = fig.add_subplot(gs[0, 2], sharex=ax_main)
    ax_rt = fig.add_subplot(gs[1, 3], sharey=ax_main)
    # NOTE(review): the "MAP" marker is the pair of 1-D marginal peaks, not the joint maximum.
    draw_L0_posterior_main_panel(
        ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
    )
    ax_main.set_xlabel(r"$\Omega_m$", fontsize=11)
    ax_main.set_ylabel(r"$\sigma_8$", fontsize=11)
    # Ωm marginal: density, truth (red dotted), peak (white dashed), 68% band.
    ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4)
    ax_top.axvline(Om_true, color="red", lw=1.0, ls=":")
    ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--",
                   path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI")
    ax_top.set_yticks([])
    ax_top.tick_params(labelbottom=False)
    ax_top.set_title(
        rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$"
        rf" (true: {Om_true:.3f})",
        fontsize=9,
    )
    # σ8 marginal drawn sideways with the same decorations.
    ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4)
    ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":")
    ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--",
                  path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18)
    ax_rt.set_xticks([])
    ax_rt.tick_params(labelleft=False)
    ax_rt.set_ylabel(
        rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$"
        rf" (true: {s8_true:.3f})",
        fontsize=9, rotation=270, labelpad=15,
    )
    ax_rt.yaxis.set_label_position("right")
    # Legend: σ-region swatches first, then the "True"/"MAP" marker entries.
    legend_patches = [
        Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
        Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
        Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
    ]
    hs, ls_ = ax_main.get_legend_handles_labels()
    ax_main.legend(
        handles=legend_patches + hs,
        labels=[p.get_label() for p in legend_patches] + ls_,
        fontsize=8, loc="upper right", framealpha=0.92,
    )
    fig.suptitle(
        r"VLB Inference: $L_t$ contours (left) and $L_0$ posterior (right)",
        fontsize=12, fontweight="bold", y=0.99,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
    return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for VLB parameter inference.

    ``--grid_size`` defaults to 50, which is below the 300-point limit that
    ``main`` enforces unless ``--allow_huge_grid`` is given, so a run with
    default options is accepted. (The mosaic output resolution is controlled
    separately by ``--mosaic_side_px``.)
    """
    p = argparse.ArgumentParser(
        description="VLB-based parameter inference for trained conditional DDPM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--training_args", default=None)
    p.add_argument("--data_dir", default="./data/params_2")
    p.add_argument("--output_dir", default="vlb_inference_outputs")
    p.add_argument("--n_fields", type=int, default=9)
    p.add_argument(
        "--grid_size", type=int, default=50,
        help="Ωm×σ8 evaluation grid resolution (each side). Values > 300 require "
             "--allow_huge_grid (very long runs for large grid_size).",
    )
    p.add_argument(
        "--allow_huge_grid", action="store_true",
        help="Required when --grid_size > 300 (avoids accidental multi-week GPU jobs).",
    )
    p.add_argument(
        "--mosaic_side_px", type=int, default=10_000,
        help="Pixel width/height of posterior_L0_mosaic_3x3.png (square).",
    )
    p.add_argument(
        "--mosaic_panel_inches", type=float, default=4.0,
        help="Matplotlib size (inches) of each 3×3 panel; dpi = mosaic_side_px / (3× this).",
    )
    p.add_argument("--span", type=float, default=0.10)
    p.add_argument("--t_subset", type=int, nargs="+",
                   default=[0, 1, 2, 5, 8, 10, 15, 20])
    p.add_argument("--n_seeds", type=int, default=4)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--device", default="auto")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dpi", type=int, default=200)
    return p.parse_args()
def autodetect_args() -> Optional[str]:
    """Locate a training-args file under ``outputs_conditional_*/`` in the cwd.

    ``args.json`` files are preferred over ``args.txt``. Within the first
    pattern that has any match, the path with the latest ``os.path.getctime``
    is returned. Returns ``None`` when nothing matches.
    """
    here = Path(".")
    for pattern in ("outputs_conditional_*/args.json", "outputs_conditional_*/args.txt"):
        matches = list(here.glob(pattern))
        if not matches:
            continue
        return str(max(matches, key=os.path.getctime))
    return None
def main() -> None:
    """Command-line entry point: run VLB inference on test fields and write figures.

    Steps: parse options, refusing ``--grid_size > 300`` without
    ``--allow_huge_grid`` (exit status 2). Load the training config
    (auto-detected if not given) and the model, then draw ``--n_fields`` test
    fields. For each field, evaluate L_t surfaces on an (Ωm, σ8) grid, save
    them to ``fieldNN_surfaces.npz`` and plot them. Finally write the summary
    figure/arrays and, when at least nine L_0 surfaces exist, the 3×3 mosaic.
    Posterior figures, the summary and the mosaic need t=0 in ``--t_subset``.
    """
    args = parse_args()
    if args.grid_size > 300 and not args.allow_huge_grid:
        side = args.grid_size
        ratio = side / 50
        print(
            f"\nRefusing --grid_size={side} (> 300) without --allow_huge_grid.\n"
            f"A {side}×{side} grid is roughly ({ratio:g})^2 ≈ {ratio ** 2:,.0f}× more forward passes per\n"
            "field than 50×50. For a quick run use e.g. --grid_size 50; for a high-res\n"
            "summary figure use the default --mosaic_side_px without increasing grid_size.\n"
            "To proceed anyway: add --allow_huge_grid\n"
        )
        raise SystemExit(2)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = (
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if args.device == "auto"
        else torch.device(args.device)
    )
    print(f"\nDevice: {device}")
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)
    if args.training_args is None:
        args.training_args = autodetect_args()
        if args.training_args is None:
            raise FileNotFoundError("Cannot find args.json — pass --training_args")
        print(f" Auto-detected args: {args.training_args}")
    cfg = load_config(args.training_args)
    print("\nLoading model ...")
    model = load_model(args.checkpoint, cfg, device)
    n_p = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_p:,} T={model.diffusion.timesteps}")
    print(f"\nLoading {args.n_fields} test fields ...")
    test_imgs, test_labels, label_mu, label_std = load_test_data(
        args.data_dir, args.n_fields, seed=args.seed,
    )
    print(f" Image shape: {test_imgs.shape[1:]}")
    print(f" Label dim: {test_labels.shape[1]}")
    print(f" Label μ/σ: {label_mu} / {label_std}")
    has_t0 = 0 in args.t_subset
    if not has_t0:
        print(" NOTE: t=0 is not in --t_subset, so no L_0 posterior, "
              "prediction summary or mosaic will be produced.")
    print(
        f"\nEvaluating L_t on {args.grid_size}x{args.grid_size} grid for "
        f"{len(args.t_subset)} timesteps × {args.n_seeds} seeds ..."
    )
    print(
        f" -> {args.grid_size ** 2 * len(args.t_subset) * args.n_seeds:,} "
        f"forward-pass groups per field (× seeds averaged)"
    )
    pred_results = []
    label_dim = int(cfg.get("label_dim", 2))
    for fi in range(args.n_fields):
        Om_true = float(test_labels[fi, 0])
        s8_true = float(test_labels[fi, 1])
        print(f"\n [{fi+1}/{args.n_fields}] field with "
              f"Om={Om_true:.3f}, s8={s8_true:.3f}")
        # Rescale to [-1, 1] (assumes stored maps are in [0, 1] — TODO confirm) and
        # add a channel axis (assumes test_imgs is (N, H, W)).
        x_0 = torch.from_numpy(test_imgs[fi:fi + 1] * 2.0 - 1.0).unsqueeze(1).to(device)
        Om_grid, s8_grid = build_eval_grid(Om_true, s8_true, args.grid_size, args.span)
        t_start = time.time()
        surfaces = evaluate_vlb_surface(
            model=model,
            x_0=x_0,
            Om_grid=Om_grid,
            s8_grid=s8_grid,
            label_mu=label_mu,
            label_std=label_std,
            t_values=args.t_subset,
            n_seeds=args.n_seeds,
            batch_size=args.batch_size,
            label_dim=label_dim,
            fixed_seed=args.seed + fi,
            device=device,
        )
        elapsed = time.time() - t_start
        print(f" Evaluation time: {elapsed:.1f}s")
        np.savez(
            out / f"field{fi:02d}_surfaces.npz",
            **{f"L_t{t}": s for t, s in surfaces.items()},
            Om_grid=Om_grid, s8_grid=s8_grid,
            Om_true=Om_true, s8_true=s8_true,
        )
        if 0 in surfaces:
            # Combined figure: contours_per_t + posterior on same plot
            Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi) = fig_posterior_and_contours_combined(
                surfaces, surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
                out / f"field{fi:02d}_combined.png", dpi=args.dpi,
            )
            pred_results.append(dict(
                Om_true=Om_true, s8_true=s8_true,
                Om_pred=Om_pred, s8_pred=s8_pred,
                Om_lo=Om_lo, Om_hi=Om_hi,
                s8_lo=s8_lo, s8_hi=s8_hi,
            ))
        # Also save individual figures for detailed inspection
        fig_contours_per_t(
            surfaces, Om_grid, s8_grid, Om_true, s8_true,
            out / f"field{fi:02d}_contours_per_t.png", dpi=args.dpi,
        )
        if 0 in surfaces:
            fig_main_posterior(
                surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
                out / f"field{fi:02d}_posterior_L0.png", dpi=args.dpi,
            )
    if len(pred_results) >= 2:
        fig_pred_vs_true(pred_results, out / "summary_pred_vs_true.png", dpi=args.dpi)
        np.savez(
            out / "summary.npz",
            **{k: np.array([r[k] for r in pred_results])
               for k in pred_results[0].keys()},
        )
    # The mosaic reads "L_t0" from each npz, which only exists when t=0 was evaluated.
    if (
        has_t0
        and args.n_fields >= 9
        and all((out / f"field{i:02d}_surfaces.npz").is_file() for i in range(9))
    ):
        fig_posterior_L0_mosaic_3x3(
            out, args.n_fields, out / "posterior_L0_mosaic_3x3.png",
            mosaic_side_px=args.mosaic_side_px,
            panel_inches=args.mosaic_panel_inches,
        )
    print(f"\nAll outputs -> {out.resolve()}/")
    for f in sorted(out.glob("*.png")):
        print(f" {f.name}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()