Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified | #!/usr/bin/env python3 | |
| """ | |
| posterior_inference.py — VLB-based cosmological inference (Mudur et al. 2023 §4 style). | |
| Pure inference-time; frozen DDPM weights. Script lives next to diffusion_conditional.py. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import ast | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.gridspec as gridspec | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patheffects as mpathe | |
| import numpy as np | |
| import torch | |
| # ── Project imports ──────────────────────────────────────────────────────────── | |
| _ROOT = Path(__file__).resolve().parent | |
| if (_ROOT / "diffusion_conditional.py").is_file(): | |
| sys.path.insert(0, str(_ROOT)) | |
| from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel | |
| from unet_conditional import ConditionalUNet | |
# ── Global figure style ────────────────────────────────────────────────────────
# White canvas for on-screen and saved figures.
plt.rcParams.update({
    "figure.facecolor": "white",
    "axes.facecolor": "white",
    "savefig.facecolor": "white",
})
# Thin dark axes with the top/right spines removed.
plt.rcParams.update({
    "axes.edgecolor": "#222",
    "axes.linewidth": 0.7,
    "axes.spines.top": False,
    "axes.spines.right": False,
})
# Typeface used for all text.
plt.rcParams.update({"font.family": "DejaVu Sans", "font.size": 9.5})

# Plot colours.
REAL_COLOR = "#CC3333"
GEN_COLOR = "#2266BB"

# Contour levels drawn on the -2*Delta(ln L) surfaces, with the fill colour and
# the inline label used for each level (innermost level first).
# NOTE(review): 2.30 / 6.17 / 11.83 look like the 1/2/3-sigma delta-chi^2
# thresholds for two jointly estimated parameters — confirm.
SIGMA_LEVELS = [2.30, 6.17, 11.83]
SIGMA_COLORS = ["#1a5c9e", "#5590d0", "#99c0ea"]
SIGMA_LABELS = dict(zip(SIGMA_LEVELS, (r"$1\sigma$", r"$2\sigma$", r"$3\sigma$")))
def load_config(path: str) -> Dict:
    """Read the training configuration saved alongside a checkpoint.

    Two formats are understood:

    * files whose suffix is ``.json`` (any letter case) are parsed as JSON;
    * any other file is read as plain text with one ``key: value`` pair per
      line.  Each value is parsed with ``ast.literal_eval`` so numbers, lists,
      booleans, ... come back as Python objects; values that are not valid
      Python literals are kept as the stripped raw string.  Lines without a
      colon are skipped.

    Args:
        path: Location of the configuration file.

    Returns:
        The parsed configuration (for the text format: a dict mapping each
        stripped key to its parsed value).
    """
    p = Path(path)
    # Read as UTF-8 explicitly so the result does not depend on the locale.
    if p.suffix.lower() == ".json":
        with open(p, encoding="utf-8") as f:
            return json.load(f)

    cfg: Dict = {}
    with open(p, encoding="utf-8") as f:
        for line in f:
            if ":" not in line:
                continue
            # Split on the first colon only: values may contain colons too.
            key, raw = line.strip().split(":", 1)
            key, raw = key.strip(), raw.strip()
            try:
                cfg[key] = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Not a Python literal (e.g. a bare word such as "linear").
                cfg[key] = raw
    return cfg
def load_model(ckpt: str, cfg: Dict, device: torch.device) -> ConditionalDiffusionModel:
    """Rebuild the conditional DDPM from ``cfg`` and load frozen weights.

    The U-Net and the diffusion schedule are constructed from ``cfg`` (with
    the fallbacks given below for missing keys), then the checkpoint is
    loaded with the following priority:

    1. ``ck["ema_shadow"]`` — entries whose key exists in the model's state
       dict overwrite the corresponding tensors; other entries are ignored;
    2. ``ck["model_state_dict"]``;
    3. the checkpoint object itself, treated as a state dict.

    Args:
        ckpt: Path to the checkpoint file.
        cfg: Training configuration (see ``load_config``).
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model in ``eval()`` mode with ``requires_grad`` disabled on every
        parameter.

    Raises:
        RuntimeError: If the checkpoint has an ``ema_shadow`` entry but none
            of its keys matches the model (previously this left the model at
            its random initialisation while reporting success).
    """
    unet = ConditionalUNet(
        in_channels=1, out_channels=1,
        label_dim=int(cfg.get("label_dim", 2)),
        base_channels=int(cfg.get("base_channels", 64)),
        channel_multipliers=list(cfg.get("channel_multipliers", [1, 2, 4, 8])),
        attention_levels=list(cfg.get("attention_levels", [2, 3])),
        dropout=float(cfg.get("dropout", 0.1)),
    )
    diff = GaussianDiffusion(
        timesteps=int(cfg.get("timesteps", 1500)),
        beta_start=float(cfg.get("beta_start", 1e-4)),
        beta_end=float(cfg.get("beta_end", 0.02)),
        schedule_type=str(cfg.get("schedule_type", "linear")),
    )
    model = ConditionalDiffusionModel(unet, diff).to(device)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints that come from a trusted source.
    ck = torch.load(ckpt, map_location=device, weights_only=False)

    if isinstance(ck, dict) and "ema_shadow" in ck:
        cur = model.state_dict()
        shadow = ck["ema_shadow"]
        matched = [k for k in shadow if k in cur]
        if not matched:
            # Nothing would be overwritten: fail loudly instead of running
            # inference with untrained weights.
            raise RuntimeError(
                f"Checkpoint {ckpt!r} has an 'ema_shadow' entry but none of its "
                f"{len(shadow)} keys matches the model's state dict"
            )
        for k in matched:
            cur[k] = shadow[k]
        model.load_state_dict(cur)
        print(" Loaded EMA weights")
        if len(matched) < len(cur):
            # Entries not covered by the shadow keep their freshly built values.
            print(
                f" Note: EMA shadow covered {len(matched)}/{len(cur)} "
                f"state-dict entries"
            )
    elif isinstance(ck, dict) and "model_state_dict" in ck:
        model.load_state_dict(ck["model_state_dict"])
    else:
        model.load_state_dict(ck)

    # Inference only: evaluation mode and no gradient tracking on the weights.
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model
def load_test_data(
    data_dir: str, n_fields: int, seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Draw a random subset of test maps plus the label normalisation stats.

    File names are resolved from what exists in ``data_dir``: the label files
    get the ``_2`` suffix when ``train_labels_LH_2.npy`` is present, and the
    image file gets the ``_6`` suffix when ``train_LH.npy`` is absent.

    Args:
        data_dir: Directory holding the ``.npy`` files.
        n_fields: Number of distinct test maps to draw (without replacement).
        seed: Seed of the NumPy generator used for the draw.

    Returns:
        ``(images, labels, label_mu, label_std)`` as float32 arrays: the
        selected test maps, their labels, and the per-column mean / standard
        deviation of the *training* labels.  Columns with zero standard
        deviation report 1.0 so later divisions stay finite.

    Raises:
        ValueError: If ``n_fields`` exceeds the number of test maps.
    """
    dp = Path(data_dir)
    lsuf = "_2" if (dp / "train_labels_LH_2.npy").exists() else ""
    isuf = "" if (dp / "train_LH.npy").exists() else "_6"

    all_imgs = np.load(dp / f"test_LH{isuf}.npy")
    all_labels = np.load(dp / f"test_labels_LH{lsuf}.npy")
    tr_lab = np.load(dp / f"train_labels_LH{lsuf}.npy").astype(np.float32)

    if n_fields > len(all_imgs):
        raise ValueError(
            f"n_fields={n_fields} exceeds the {len(all_imgs)} test maps in {dp}"
        )
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(all_imgs), n_fields, replace=False)

    # Sub-sample first, cast second: only the selected maps are converted.
    imgs = all_imgs[idx].astype(np.float32)
    labels = all_labels[idx].astype(np.float32)

    label_mu = tr_lab.mean(0)
    raw_std = tr_lab.std(0)
    label_std = np.where(raw_std == 0, 1.0, raw_std)
    return imgs, labels, label_mu, label_std
def normal_kl(mean1, log_var1, mean2, log_var2):
    """Element-wise KL( N(mean1, e^log_var1) || N(mean2, e^log_var2) ).

    All arguments are combined with ordinary broadcasting; the result is not
    reduced over any axis.
    """
    variance_ratio = torch.exp(log_var1 - log_var2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * torch.exp(-log_var2)
    return 0.5 * (-1.0 + log_var2 - log_var1 + variance_ratio + scaled_sq_diff)
def _approx_standard_normal_cdf(x):
    """Tanh-based approximation of the standard normal CDF, element-wise."""
    scale = np.sqrt(2.0 / np.pi)
    cubic = x + 0.044715 * x ** 3
    return 0.5 * (1.0 + torch.tanh(scale * cubic))
def discretised_gaussian_log_likelihood(x_0, mean, log_var):
    """Per-element log-probability of ``x_0`` under a discretised Gaussian.

    Each element gets the Gaussian mass (mean ``mean``, log-variance
    ``log_var``) of the interval ``x_0 ± 1/255``.  Elements below -0.999 use
    the whole lower tail up to the upper edge of their interval and elements
    above 0.999 the whole upper tail from the lower edge.  Probabilities are
    floored at 1e-12 before taking the log.  No reduction is applied.
    """
    half_bin = 1.0 / 255.0
    offset = x_0 - mean
    inv_std = torch.exp(-0.5 * log_var)

    upper_cdf = _approx_standard_normal_cdf(inv_std * (offset + half_bin))
    lower_cdf = _approx_standard_normal_cdf(inv_std * (offset - half_bin))

    log_lower_tail = torch.log(upper_cdf.clamp(min=1e-12))
    log_upper_tail = torch.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_interior = torch.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    # Pick the tail expressions at the two ends of the [-1, 1] data range.
    upper_or_interior = torch.where(x_0 > 0.999, log_upper_tail, log_interior)
    return torch.where(x_0 < -0.999, log_lower_tail, upper_or_interior)
def predict_x_start_from_eps(diff: GaussianDiffusion, x_t: torch.Tensor,
                             t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
    """Estimate the clean sample from ``x_t`` and a noise prediction ``eps``.

    Computes ``recip_sqrt_alphas_cumprod[t] * x_t - sqrt_recip_minus_one[t] * eps``
    with both schedule coefficients fetched through ``diff._extract``.
    Intended to match ``GaussianDiffusion._predict_xstart_from_noise`` in
    diffusion_conditional.py.
    """
    shape = x_t.shape
    signal_scale = diff._extract(diff.recip_sqrt_alphas_cumprod, t, shape)
    noise_scale = diff._extract(diff.sqrt_recip_minus_one, t, shape)
    return signal_scale * x_t - noise_scale * eps
def q_posterior_mean_var(diff: GaussianDiffusion, x_start: torch.Tensor,
                         x_t: torch.Tensor, t: torch.Tensor):
    """Return mean, variance and clipped log-variance of the step posterior.

    The mean is ``posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t] * x_t``;
    the variance and clipped log-variance are read from the schedule tables
    of ``diff``.  All coefficients are fetched through ``diff._extract`` for
    the shape of ``x_t``.
    """
    shape = x_t.shape
    w_start = diff._extract(diff.posterior_mean_coef1, t, shape)
    w_noisy = diff._extract(diff.posterior_mean_coef2, t, shape)
    posterior_mean = w_start * x_start + w_noisy * x_t
    posterior_var = diff._extract(diff.posterior_variance, t, shape)
    posterior_log_var = diff._extract(diff.posterior_log_variance_clipped, t, shape)
    return posterior_mean, posterior_var, posterior_log_var
def compute_L_t(
    model: ConditionalDiffusionModel,
    x_0: torch.Tensor,
    labels_n: torch.Tensor,
    t: int,
    fixed_eps: torch.Tensor,
) -> torch.Tensor:
    """Evaluate one variational-bound term per batch element.

    ``x_0`` is noised with ``fixed_eps`` at the chosen schedule index, the
    model predicts the noise given ``labels_n``, and the implied clean-image
    estimate (clamped to [-1, 1]) defines the model's reverse-step Gaussian.

    * ``t == 0``: returns the negative discretised log-likelihood of ``x_0``
      under that Gaussian.
    * otherwise: returns the KL divergence from the true posterior (built
      from the real ``x_0``) to the model's Gaussian.

    In both cases the per-element values are summed over dims 1-3, giving a
    tensor of shape ``(B,)``.
    """
    diff = model.diffusion

    # NOTE(review): the t == 0 term is evaluated at schedule index 1 (noising
    # level, model timestep and posterior coefficients), exactly like t == 1.
    # Confirm this matches the indexing convention of GaussianDiffusion.
    step = 1 if t == 0 else t
    t_idx = torch.full((x_0.shape[0],), step, device=x_0.device, dtype=torch.long)

    abar = diff._extract(diff.alphas_cumprod, t_idx, x_0.shape)
    x_noisy = torch.sqrt(abar) * x_0 + torch.sqrt(1.0 - abar) * fixed_eps

    eps_hat = model(x_noisy, t_idx, labels_n)
    x0_hat = predict_x_start_from_eps(diff, x_noisy, t_idx, eps_hat).clamp(-1, 1)
    p_mean, _, p_log_var = q_posterior_mean_var(diff, x0_hat, x_noisy, t_idx)

    if t == 0:
        log_p = discretised_gaussian_log_likelihood(x_0, p_mean, p_log_var)
        return -log_p.sum(dim=(1, 2, 3))

    q_mean, _, q_log_var = q_posterior_mean_var(diff, x_0, x_noisy, t_idx)
    return normal_kl(q_mean, q_log_var, p_mean, p_log_var).sum(dim=(1, 2, 3))
def compute_L_T_analytic(diff: GaussianDiffusion, x_0: torch.Tensor) -> torch.Tensor:
    """Closed-form KL between q(x_T | x_0) and a standard normal.

    Uses the cumulative alpha at the last schedule index
    (``diff.timesteps - 1``); the forward marginal has mean
    ``sqrt(abar) * x_0`` and variance ``1 - abar`` (floored at 1e-30 before
    the log).  The KL is summed over dims 1-3, giving one value per batch
    element.
    """
    batch = x_0.shape[0]
    last = torch.full((batch,), diff.timesteps - 1, device=x_0.device, dtype=torch.long)
    abar_last = diff._extract(diff.alphas_cumprod, last, x_0.shape)

    q_mean = torch.sqrt(abar_last) * x_0
    q_log_var = torch.log((1.0 - abar_last).clamp(min=1e-30))

    # Standard normal prior: zero mean, zero log-variance.
    prior_mean = torch.zeros_like(q_mean)
    prior_log_var = torch.zeros_like(q_log_var)
    return normal_kl(q_mean, q_log_var, prior_mean, prior_log_var).sum(dim=(1, 2, 3))
def build_eval_grid(
    Om_true: float,
    s8_true: float,
    grid_size: int,
    span: float = 0.1,
    Om_range: Tuple[float, float] = (0.10, 0.50),
    s8_range: Tuple[float, float] = (0.60, 1.00),
) -> Tuple[np.ndarray, np.ndarray]:
    """Build the two 1-D axes of the evaluation grid around the true values.

    Each axis spans ``true ± span``; the lower end is raised to the lower
    bound of the matching range and the upper end lowered to its upper
    bound.  Both axes hold ``grid_size`` evenly spaced points.

    Returns:
        ``(Om_1d, s8_1d)``.
    """
    def _axis(centre: float, bounds: Tuple[float, float]) -> np.ndarray:
        lower = max(centre - span, bounds[0])
        upper = min(centre + span, bounds[1])
        return np.linspace(lower, upper, grid_size)

    return _axis(Om_true, Om_range), _axis(s8_true, s8_range)
def evaluate_vlb_surface(
    model: ConditionalDiffusionModel,
    x_0: torch.Tensor,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    label_mu: np.ndarray,
    label_std: np.ndarray,
    t_values: List[int],
    n_seeds: int = 4,
    batch_size: int = 32,
    label_dim: int = 2,
    fixed_seed: int = 0,
    device: Optional[torch.device] = None,
) -> Dict[int, np.ndarray]:
    """Evaluate ``L_t`` for one field on every point of the (Om, s8) grid.

    For each requested timestep the bound term from ``compute_L_t`` is
    averaged over ``n_seeds`` noise realisations.  The same noise maps are
    used at every grid point, so differences across the grid come only from
    the conditioning label.

    Args:
        model: Model passed through to ``compute_L_t``.
        x_0: 4-D tensor with batch size 1; it is expanded along dim 0 to the
            size of each label batch.
        Om_grid, s8_grid: 1-D axes of the grid (first two label columns).
        label_mu, label_std: Per-column statistics used to normalise labels.
        t_values: Timesteps to evaluate.
        n_seeds: Number of noise maps averaged per grid point.
        batch_size: Grid points evaluated per forward pass.
        label_dim: Total label width; columns beyond the first two are held
            at ``label_mu`` (i.e. 0 after normalisation).
        fixed_seed: Seed of the generator that draws the noise maps.
        device: Compute device; defaults to the device of ``x_0``.

    Returns:
        Dict mapping each ``t`` to a float64 array of shape
        ``(len(Om_grid), len(s8_grid))`` — axis 0 follows ``Om_grid``.

    Raises:
        ValueError: If ``label_mu`` has fewer than ``label_dim`` entries.
    """
    device = device or x_0.device
    nO, nS = len(Om_grid), len(s8_grid)
    n_pts = nO * nS

    # "ij" indexing + ravel: flat index = i_Om * nS + i_s8, undone by the
    # final reshape(nO, nS).
    Omg, s8g = np.meshgrid(Om_grid, s8_grid, indexing="ij")
    raw_labels = np.column_stack([Omg.ravel(), s8g.ravel()])
    if label_dim > 2:
        # Hold every non-inferred parameter at its training-set mean.
        extra = np.asarray(label_mu[2:label_dim], dtype=np.float32)
        if extra.shape[0] != label_dim - 2:
            raise ValueError(
                f"label_mu has {len(label_mu)} entries but label_dim={label_dim}"
            )
        pad = np.broadcast_to(extra, (n_pts, label_dim - 2))
        raw_labels = np.concatenate([raw_labels, pad], axis=1)
    norm_labels = (raw_labels - label_mu) / label_std
    norm_labels_t = torch.from_numpy(norm_labels.astype(np.float32)).to(device)

    L_surfaces = {t: np.zeros(n_pts, dtype=np.float64) for t in t_values}

    H, W = x_0.shape[-2], x_0.shape[-1]
    rng_torch = torch.Generator(device=device).manual_seed(fixed_seed)
    seeds_eps = [
        torch.randn(1, 1, H, W, generator=rng_torch, device=device)
        for _ in range(n_seeds)
    ]

    # Pure inference: never build an autograd graph, so .numpy() below works
    # even if the model's parameters still require grad.
    with torch.no_grad():
        for fixed_eps in seeds_eps:
            for t in t_values:
                for start in range(0, n_pts, batch_size):
                    end = min(start + batch_size, n_pts)
                    bsz = end - start
                    # expand() creates views: the image and the noise map are
                    # shared by the whole label batch without copying.
                    x_b = x_0.expand(bsz, -1, -1, -1)
                    eps_b = fixed_eps.expand(bsz, -1, -1, -1)
                    lbl_b = norm_labels_t[start:end]
                    L_t = compute_L_t(model, x_b, lbl_b, t=t, fixed_eps=eps_b)
                    L_surfaces[t][start:end] += L_t.detach().cpu().numpy() / n_seeds

    return {t: L_surfaces[t].reshape(nO, nS) for t in t_values}
def marginal_from_neg2dL(
    neg2dL: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Tuple[float, float]]:
    """Turn a ``-2 ln L`` surface into normalised 1-D marginals and their peaks.

    The surface (axis 0 = Om, axis 1 = s8) is mapped to ``exp(-neg2dL / 2)``
    after subtracting the maximum log-value, then summed over the other axis
    and normalised to unit sum.

    Returns:
        ``(Om_marginal, s8_marginal, (Om_pred, s8_pred))`` where the
        predictions are the grid values at the maximum of each marginal.
    """
    log_like = -0.5 * neg2dL
    weights = np.exp(log_like - log_like.max())

    over_s8 = weights.sum(axis=1)
    Om_marginal = over_s8 / over_s8.sum()
    over_Om = weights.sum(axis=0)
    s8_marginal = over_Om / over_Om.sum()

    peaks = (
        float(Om_grid[np.argmax(Om_marginal)]),
        float(s8_grid[np.argmax(s8_marginal)]),
    )
    return Om_marginal, s8_marginal, peaks
def credible_interval_68(values: np.ndarray, probs: np.ndarray) -> Tuple[float, float, float]:
    """Return ``(median, lo, hi)`` of a discretised 1-D distribution.

    The cumulative sum of ``probs`` is normalised to end at 1 and linearly
    interpolated against ``values`` at the 50 %, 16 % and 84 % levels.
    """
    cdf = np.cumsum(probs)
    cdf /= cdf[-1]
    q50, q16, q84 = np.interp([0.50, 0.16, 0.84], cdf, values)
    return float(q50), float(q16), float(q84)
def fig_contours_per_t(
    surfaces: Dict[int, np.ndarray],
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> None:
    """Overlay one 2.30-level contour per timestep and save the figure.

    Every surface is shifted so its minimum is zero and doubled before
    contouring.  Timesteps are drawn in ascending order with viridis colours
    and the true parameters are marked with a red cross.
    """
    fig, ax = plt.subplots(figsize=(7, 6.5), dpi=dpi)

    palette = plt.cm.viridis(np.linspace(0.05, 0.95, len(surfaces)))
    for (step, surf), shade in zip(sorted(surfaces.items()), palette):
        delta = 2.0 * (surf - surf.min())
        ax.contour(
            Om_grid, s8_grid, delta.T,
            levels=[2.30], colors=[shade], linewidths=[1.6], linestyles=["-"],
        )
        # Contour sets carry no legend entry; add an empty proxy line instead.
        ax.plot([], [], color=shade, lw=1.8, label=f"t={step}")

    ax.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)

    ax.set_xlabel(r"$\Omega_m$", fontsize=12)
    ax.set_ylabel(r"$\sigma_8$", fontsize=12)
    heading = (
        r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contour per timestep"
        "\n(Mudur-style) smaller $t$ → tighter constraint"
    )
    ax.set_title(heading, fontweight="bold", fontsize=10)
    ax.legend(fontsize=8, loc="best", ncol=2, framealpha=0.92)
    ax.grid(alpha=0.18)
    ax.set_xlim(Om_grid[0], Om_grid[-1])
    ax.set_ylim(s8_grid[0], s8_grid[-1])

    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
def _L0_posterior_smoothed(
    L0_surface: np.ndarray, smooth_sigma: float = 0.6,
):
    """Convert an ``L_0`` surface to ``2 * (L - min L)`` and smooth it.

    Returns:
        ``(smoothed, unsmoothed)`` — the Gaussian-filtered surface (filter
        width ``smooth_sigma``) followed by the surface before filtering.
    """
    from scipy import ndimage

    shifted = L0_surface - L0_surface.min()
    neg2dL = 2.0 * shifted
    return ndimage.gaussian_filter(neg2dL, sigma=smooth_sigma), neg2dL
def draw_L0_posterior_main_panel(
    ax,
    surface_sm: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    Om_pred: float,
    s8_pred: float,
    *,
    clabel_fontsize: float = 9.5,
    marker_ms: float = 16,
) -> None:
    """Draw the 2-D posterior panel on ``ax``.

    Layers, bottom to top: a faint 60-level background fill, one solid band
    per entry of ``SIGMA_LEVELS`` (widest first, so narrower bands stay
    visible), white labelled contour lines, dotted red cross-hairs at the
    true values, and markers for the true point (red "+") and the predicted
    point (white triangle, legend entry "MAP").  Axis limits are set to the
    grid extent.
    """
    # contour/contourf expect Z[y, x]; the surface is stored as [Om, s8].
    field = surface_sm.T

    ax.contourf(
        Om_grid, s8_grid, field,
        levels=60, cmap="Blues_r",
        vmin=0, vmax=SIGMA_LEVELS[-1] * 3,
        extend="max", alpha=0.55,
    )
    for level, fill in reversed(list(zip(SIGMA_LEVELS, SIGMA_COLORS))):
        ax.contourf(Om_grid, s8_grid, field, levels=[0, level], colors=[fill], alpha=0.78)

    outlines = ax.contour(
        Om_grid, s8_grid, field,
        levels=SIGMA_LEVELS,
        colors=["white"] * 3,
        linewidths=[2.2, 1.6, 1.2],
        linestyles=["-", "--", "-."],
    )
    ax.clabel(outlines, fmt=SIGMA_LABELS, inline=True, fontsize=clabel_fontsize, colors="white")

    guide = dict(color="red", lw=0.7, ls=":", alpha=0.6)
    ax.axvline(Om_true, **guide)
    ax.axhline(s8_true, **guide)

    ax.plot(Om_true, s8_true, "r+", ms=marker_ms, mew=2.5, zorder=6, label="True")
    pred_size = max(6, marker_ms * 0.55)
    ax.plot(Om_pred, s8_pred, "w^", ms=pred_size, mew=1.2, zorder=6, label="MAP")

    ax.set_xlim(Om_grid[0], Om_grid[-1])
    ax.set_ylim(s8_grid[0], s8_grid[-1])
    ax.grid(alpha=0.18)
def fig_posterior_L0_mosaic_3x3(
    out_dir: Path,
    n_fields: int,
    out_path: Path,
    mosaic_side_px: int = 10_000,
    panel_inches: float = 4.0,
) -> None:
    """Assemble up to nine per-field ``L_0`` posteriors into a 3x3 mosaic.

    Each panel is rebuilt from ``out_dir / "fieldNN_surfaces.npz"`` (keys
    ``L_t0``, ``Om_grid``, ``s8_grid``, ``Om_true``, ``s8_true``).  Panels
    beyond ``min(n_fields, 9)`` are hidden, as is any panel whose archive
    has no ``L_t0`` entry (a warning is printed for the latter).

    Args:
        out_dir: Directory containing the per-field ``.npz`` archives.
        n_fields: Number of fields available; at most nine are drawn.
        out_path: Destination image file.
        mosaic_side_px: Target side length of the square figure in pixels.
        panel_inches: Size of one panel in inches; the dpi is chosen as
            ``mosaic_side_px / (3 * panel_inches)``.
    """
    from matplotlib.patches import Patch

    n_plot = min(n_fields, 9)
    fig_side = panel_inches * 3
    dpi = mosaic_side_px / fig_side
    fig, axes = plt.subplots(
        3, 3, figsize=(fig_side, fig_side), dpi=dpi,
        squeeze=False,
    )
    try:
        for idx, ax in enumerate(axes.flat):
            if idx >= n_plot:
                ax.set_visible(False)
                continue

            # NpzFile keeps the archive open; the context manager closes it.
            with np.load(out_dir / f"field{idx:02d}_surfaces.npz") as nz:
                if "L_t0" not in nz.files:
                    # Happens when the run did not include t = 0.
                    print(f" Warning: field {idx} has no L_t0 surface — panel skipped")
                    ax.set_visible(False)
                    continue
                L0 = np.asarray(nz["L_t0"])
                Om_grid = np.asarray(nz["Om_grid"])
                s8_grid = np.asarray(nz["s8_grid"])
                Om_true = float(nz["Om_true"])
                s8_true = float(nz["s8_true"])

            surface_sm, _ = _L0_posterior_smoothed(L0, smooth_sigma=0.6)
            _, _, (Om_pred, s8_pred) = marginal_from_neg2dL(surface_sm, Om_grid, s8_grid)
            draw_L0_posterior_main_panel(
                ax, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
                clabel_fontsize=7.0, marker_ms=11,
            )

            if idx == 0:
                # One legend, on the first panel only, keeps the mosaic readable.
                legend_patches = [
                    Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
                    Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
                    Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
                ]
                hs, ls_ = ax.get_legend_handles_labels()
                ax.legend(
                    handles=legend_patches + hs,
                    labels=[p.get_label() for p in legend_patches] + ls_,
                    fontsize=6, loc="upper right", framealpha=0.9,
                )

            ax.set_title(
                rf"field {idx}: $\Omega_m^{{\rm true}}={Om_true:.3f}$, $\sigma_8^{{\rm true}}={s8_true:.3f}$",
                fontsize=8,
            )
            ax.set_xlabel(r"$\Omega_m$", fontsize=8)
            ax.set_ylabel(r"$\sigma_8$", fontsize=8)

        fig.suptitle(
            rf"VLB $L_0$ posterior (2D) — {n_plot} test fields",
            fontsize=11, fontweight="bold", y=0.995,
        )
        fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    finally:
        # Release the (potentially very large) canvas even when loading or
        # saving fails.
        plt.close(fig)
    print(f" Saved -> {out_path} (≈ {mosaic_side_px}×{mosaic_side_px} px)")
def fig_main_posterior(
    L0_surface: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
):
    """Save the joint ``L_0`` posterior with its two marginals.

    The surface is smoothed, turned into marginals, and drawn as a
    corner-style figure: joint panel at lower-left, Om marginal on top, s8
    marginal on the right.  Each marginal shows the true value (red dotted),
    the marginal peak (white dashed with a black outline) and the 16-84 %
    band.

    Returns:
        ``(Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi))``.
    """
    from matplotlib.patches import Patch

    smoothed, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)
    marg_Om, marg_s8, (Om_pred, s8_pred) = marginal_from_neg2dL(
        smoothed, Om_grid, s8_grid
    )
    Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, marg_Om)
    s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, marg_s8)

    fig = plt.figure(figsize=(8.5, 8.5), dpi=dpi)
    layout = gridspec.GridSpec(
        2, 2,
        width_ratios=[4, 1], height_ratios=[1, 4],
        hspace=0.05, wspace=0.05,
        left=0.10, right=0.95, top=0.95, bottom=0.08,
    )
    joint_ax = fig.add_subplot(layout[1, 0])
    top_ax = fig.add_subplot(layout[0, 0], sharex=joint_ax)
    side_ax = fig.add_subplot(layout[1, 1], sharey=joint_ax)

    # ── joint panel ──
    draw_L0_posterior_main_panel(
        joint_ax, smoothed, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
    )
    joint_ax.set_xlabel(r"$\Omega_m$", fontsize=11)
    joint_ax.set_ylabel(r"$\sigma_8$", fontsize=11)

    line_col, fill_col = SIGMA_COLORS[0], SIGMA_COLORS[1]

    # ── Om marginal (top) ──
    top_ax.fill_between(Om_grid, 0, marg_Om, color=fill_col, alpha=0.6)
    top_ax.plot(Om_grid, marg_Om, color=line_col, lw=1.4)
    top_ax.axvline(Om_true, color="red", lw=1.0, ls=":")
    top_ax.axvline(
        Om_pred, color="white", lw=1.5, ls="--",
        path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")],
    )
    top_ax.axvspan(Om_lo, Om_hi, color=line_col, alpha=0.18, label=r"68% CI")
    top_ax.set_yticks([])
    top_ax.tick_params(labelbottom=False)
    Om_up, Om_down = Om_hi - Om_med, Om_med - Om_lo
    top_ax.set_title(
        rf"$\Omega_m={Om_med:.3f}^{{+{Om_up:.3f}}}_{{-{Om_down:.3f}}}$"
        rf" (true: {Om_true:.3f})",
        fontsize=9,
    )

    # ── s8 marginal (right) ──
    side_ax.fill_betweenx(s8_grid, 0, marg_s8, color=fill_col, alpha=0.6)
    side_ax.plot(marg_s8, s8_grid, color=line_col, lw=1.4)
    side_ax.axhline(s8_true, color="red", lw=1.0, ls=":")
    side_ax.axhline(
        s8_pred, color="white", lw=1.5, ls="--",
        path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")],
    )
    side_ax.axhspan(s8_lo, s8_hi, color=line_col, alpha=0.18)
    side_ax.set_xticks([])
    side_ax.tick_params(labelleft=False)
    s8_up, s8_down = s8_hi - s8_med, s8_med - s8_lo
    side_ax.set_ylabel(
        rf"$\sigma_8={s8_med:.3f}^{{+{s8_up:.3f}}}_{{-{s8_down:.3f}}}$"
        rf" (true: {s8_true:.3f})",
        fontsize=9, rotation=270, labelpad=15,
    )
    side_ax.yaxis.set_label_position("right")

    # ── legend: sigma bands first, then the markers drawn on the joint panel ──
    band_handles = [
        Patch(facecolor=colour, label=text)
        for colour, text in zip(
            SIGMA_COLORS, (r"$1\sigma$", r"$2\sigma$", r"$3\sigma$")
        )
    ]
    marker_handles, marker_texts = joint_ax.get_legend_handles_labels()
    joint_ax.legend(
        handles=band_handles + marker_handles,
        labels=[h.get_label() for h in band_handles] + marker_texts,
        fontsize=8, loc="upper right", framealpha=0.92,
    )

    fig.suptitle(
        r"VLB Posterior using $L_0$ — joint and marginal distributions",
        fontsize=10, fontweight="bold", y=0.99,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
    return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
def fig_pred_vs_true(pred_results: List[Dict], out_path: Path, dpi: int = 200) -> None:
    """Scatter predicted against true Om and s8 with asymmetric error bars.

    ``pred_results`` holds one dict per field with the keys ``Om_true``,
    ``Om_pred``, ``Om_lo``, ``Om_hi`` and the ``s8_*`` equivalents.  Error
    bars run from the prediction to the lo/hi bounds (negative lengths are
    clipped to zero).  The RMSE of each parameter is written on its panel
    and printed.
    """
    def column(key: str) -> np.ndarray:
        return np.array([entry[key] for entry in pred_results])

    Om_true, s8_true = column("Om_true"), column("s8_true")
    Om_pred, s8_pred = column("Om_pred"), column("s8_pred")
    Om_err_lo = np.array([e["Om_pred"] - e["Om_lo"] for e in pred_results])
    Om_err_hi = np.array([e["Om_hi"] - e["Om_pred"] for e in pred_results])
    s8_err_lo = np.array([e["s8_pred"] - e["s8_lo"] for e in pred_results])
    s8_err_hi = np.array([e["s8_hi"] - e["s8_pred"] for e in pred_results])

    rmse_Om = np.sqrt(((Om_pred - Om_true) ** 2).mean())
    rmse_s8 = np.sqrt(((s8_pred - s8_true) ** 2).mean())

    panels = [
        (Om_true, Om_pred, Om_err_lo, Om_err_hi, r"$\Omega_m$", (0.10, 0.50), rmse_Om),
        (s8_true, s8_pred, s8_err_lo, s8_err_hi, r"$\sigma_8$", (0.60, 1.00), rmse_s8),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(11, 5), dpi=dpi)
    for ax, panel in zip(axes, panels):
        truth, guess, below, above, symbol, limits, rmse = panel
        # errorbar rejects negative lengths, so clip them to zero.
        bars = [np.maximum(below, 0), np.maximum(above, 0)]
        ax.errorbar(
            truth, guess, yerr=bars,
            fmt="o", color=GEN_COLOR, ecolor=SIGMA_COLORS[1],
            elinewidth=1.2, capsize=3, ms=6,
            label="DDPM-VLB inference (68% CI)",
        )
        ax.plot(limits, limits, "k--", lw=1.0, alpha=0.5, label="Identity")
        ax.set_xlabel(f"True {symbol}", fontsize=11)
        ax.set_ylabel(f"Predicted {symbol}", fontsize=11)
        ax.set_xlim(*limits)
        ax.set_ylim(*limits)
        ax.grid(alpha=0.2)
        ax.legend(fontsize=9, loc="lower right")
        ax.text(
            0.04, 0.92, f"RMSE = {rmse:.4f}",
            transform=ax.transAxes, fontsize=10,
            bbox=dict(facecolor="white", edgecolor="#ccc", alpha=0.92, pad=4),
        )
        ax.set_title(f"{symbol}: predicted vs true", fontweight="bold", fontsize=10)

    fig.suptitle(
        "VLB Parameter Inference: predicted vs true\n"
        r"Error bars = 68% CI from $L_0$ marginal posterior",
        fontsize=10, fontweight="bold", y=1.01,
    )
    plt.tight_layout()
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
    print(f" RMSE: Omega_m={rmse_Om:.4f} sigma_8={rmse_s8:.4f}")
def fig_posterior_and_contours_combined(
    surfaces: Dict[int, np.ndarray],
    L0_surface: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> Tuple[float, float, Tuple[float, float], Tuple[float, float]]:
    """Save a two-part figure: per-timestep contours and the ``L_0`` posterior.

    Left: one 2.30-level contour per entry of ``surfaces`` (ascending ``t``,
    viridis colours).  Right: the smoothed ``L_0`` joint posterior with its
    Om (top) and s8 (side) marginals, laid out like ``fig_main_posterior``.

    Returns:
        ``(Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi))``.
    """
    from matplotlib.patches import Patch

    smoothed, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)
    marg_Om, marg_s8, (Om_pred, s8_pred) = marginal_from_neg2dL(
        smoothed, Om_grid, s8_grid
    )
    Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, marg_Om)
    s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, marg_s8)

    fig = plt.figure(figsize=(16, 7), dpi=dpi)
    # Column 1 is a narrow spacer between the two halves of the figure.
    layout = gridspec.GridSpec(
        2, 4,
        width_ratios=[4, 0.3, 4, 1], height_ratios=[1, 4],
        hspace=0.08, wspace=0.10,
        left=0.08, right=0.96, top=0.94, bottom=0.08,
    )

    # ── Left: contours per timestep ──
    contour_ax = fig.add_subplot(layout[:, 0])
    palette = plt.cm.viridis(np.linspace(0.05, 0.95, len(surfaces)))
    for (step, surf), shade in zip(sorted(surfaces.items()), palette):
        delta = 2.0 * (surf - surf.min())
        contour_ax.contour(
            Om_grid, s8_grid, delta.T,
            levels=[2.30], colors=[shade], linewidths=[1.6], linestyles=["-"],
        )
        # Empty proxy line so the timestep appears in the legend.
        contour_ax.plot([], [], color=shade, lw=1.8, label=f"t={step}")
    contour_ax.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)
    contour_ax.set_xlabel(r"$\Omega_m$", fontsize=12)
    contour_ax.set_ylabel(r"$\sigma_8$", fontsize=12)
    contour_ax.set_title(
        r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contours per timestep",
        fontweight="bold", fontsize=11,
    )
    contour_ax.legend(fontsize=8, loc="best", ncol=1, framealpha=0.92)
    contour_ax.grid(alpha=0.18)
    contour_ax.set_xlim(Om_grid[0], Om_grid[-1])
    contour_ax.set_ylim(s8_grid[0], s8_grid[-1])

    # ── Right: L_0 posterior with marginals ──
    joint_ax = fig.add_subplot(layout[1, 2])
    top_ax = fig.add_subplot(layout[0, 2], sharex=joint_ax)
    side_ax = fig.add_subplot(layout[1, 3], sharey=joint_ax)

    draw_L0_posterior_main_panel(
        joint_ax, smoothed, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
    )
    joint_ax.set_xlabel(r"$\Omega_m$", fontsize=11)
    joint_ax.set_ylabel(r"$\sigma_8$", fontsize=11)

    line_col, fill_col = SIGMA_COLORS[0], SIGMA_COLORS[1]

    top_ax.fill_between(Om_grid, 0, marg_Om, color=fill_col, alpha=0.6)
    top_ax.plot(Om_grid, marg_Om, color=line_col, lw=1.4)
    top_ax.axvline(Om_true, color="red", lw=1.0, ls=":")
    top_ax.axvline(
        Om_pred, color="white", lw=1.5, ls="--",
        path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")],
    )
    top_ax.axvspan(Om_lo, Om_hi, color=line_col, alpha=0.18, label=r"68% CI")
    top_ax.set_yticks([])
    top_ax.tick_params(labelbottom=False)
    Om_up, Om_down = Om_hi - Om_med, Om_med - Om_lo
    top_ax.set_title(
        rf"$\Omega_m={Om_med:.3f}^{{+{Om_up:.3f}}}_{{-{Om_down:.3f}}}$"
        rf" (true: {Om_true:.3f})",
        fontsize=9,
    )

    side_ax.fill_betweenx(s8_grid, 0, marg_s8, color=fill_col, alpha=0.6)
    side_ax.plot(marg_s8, s8_grid, color=line_col, lw=1.4)
    side_ax.axhline(s8_true, color="red", lw=1.0, ls=":")
    side_ax.axhline(
        s8_pred, color="white", lw=1.5, ls="--",
        path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")],
    )
    side_ax.axhspan(s8_lo, s8_hi, color=line_col, alpha=0.18)
    side_ax.set_xticks([])
    side_ax.tick_params(labelleft=False)
    s8_up, s8_down = s8_hi - s8_med, s8_med - s8_lo
    side_ax.set_ylabel(
        rf"$\sigma_8={s8_med:.3f}^{{+{s8_up:.3f}}}_{{-{s8_down:.3f}}}$"
        rf" (true: {s8_true:.3f})",
        fontsize=9, rotation=270, labelpad=15,
    )
    side_ax.yaxis.set_label_position("right")

    band_handles = [
        Patch(facecolor=colour, label=text)
        for colour, text in zip(
            SIGMA_COLORS, (r"$1\sigma$", r"$2\sigma$", r"$3\sigma$")
        )
    ]
    marker_handles, marker_texts = joint_ax.get_legend_handles_labels()
    joint_ax.legend(
        handles=band_handles + marker_handles,
        labels=[h.get_label() for h in band_handles] + marker_texts,
        fontsize=8, loc="upper right", framealpha=0.92,
    )

    fig.suptitle(
        r"VLB Inference: $L_t$ contours (left) and $L_0$ posterior (right)",
        fontsize=12, fontweight="bold", y=0.99,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
    return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    ``--grid_size`` defaults to 50.  ``main`` refuses any value above 300
    unless ``--allow_huge_grid`` is given, so the former default of 10_000
    made a default invocation exit immediately; 50 is the quick-run size
    that ``main``'s own refusal message recommends.
    """
    p = argparse.ArgumentParser(
        description="VLB-based parameter inference for trained conditional DDPM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--training_args", default=None)
    p.add_argument("--data_dir", default="./data/params_2")
    p.add_argument("--output_dir", default="vlb_inference_outputs")
    p.add_argument("--n_fields", type=int, default=9)
    p.add_argument(
        "--grid_size", type=int, default=50,
        help="Ωm×σ8 evaluation grid resolution (each side). Values > 300 require "
             "--allow_huge_grid (very long runs for large grid_size).",
    )
    p.add_argument(
        "--allow_huge_grid", action="store_true",
        help="Required when --grid_size > 300 (avoids accidental multi-week GPU jobs).",
    )
    # Output resolution of the mosaic is independent of the evaluation grid.
    p.add_argument(
        "--mosaic_side_px", type=int, default=10_000,
        help="Pixel width/height of posterior_L0_mosaic_3x3.png (square).",
    )
    p.add_argument(
        "--mosaic_panel_inches", type=float, default=4.0,
        help="Matplotlib size (inches) of each 3×3 panel; dpi = mosaic_side_px / (3× this).",
    )
    p.add_argument("--span", type=float, default=0.10)
    p.add_argument("--t_subset", type=int, nargs="+",
                   default=[0, 1, 2, 5, 8, 10, 15, 20])
    p.add_argument("--n_seeds", type=int, default=4)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--device", default="auto")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dpi", type=int, default=200)
    return p.parse_args()
def autodetect_args() -> Optional[str]:
    """Locate a training-args file under ``outputs_conditional_*/``.

    Searches the current working directory, preferring ``args.json`` over
    ``args.txt``.  When several files match a pattern, the one with the
    largest ``os.path.getctime`` wins.  Returns ``None`` if nothing matches.
    """
    here = Path(".")
    for pattern in ("outputs_conditional_*/args.json", "outputs_conditional_*/args.txt"):
        matches = list(here.glob(pattern))
        if matches:
            return str(max(matches, key=os.path.getctime))
    return None
def main() -> None:
    """Run VLB-based inference on a sample of test fields and write figures.

    Steps: parse options, refuse oversized grids, load the config / model /
    test data, then for every field evaluate the ``L_t`` surfaces on a grid
    around the true (Om, s8), store them as ``fieldNN_surfaces.npz`` and
    draw the per-field figures.  Summary outputs (pred-vs-true plot,
    ``summary.npz``, 3x3 mosaic) depend on the ``L_0`` surface and are
    produced only when timestep 0 is part of ``--t_subset``.

    Raises:
        SystemExit: exit code 2 if ``--grid_size`` > 300 without
            ``--allow_huge_grid``.
        FileNotFoundError: if no training-args file is given or detected.
        ValueError: if the label width of the data does not match the
            ``label_dim`` of the training configuration.
    """
    args = parse_args()
    if args.grid_size > 300 and not args.allow_huge_grid:
        print(
            "\nRefusing --grid_size={} (> 300) without --allow_huge_grid.\n"
            "A 10_000×10_000 grid is roughly (200)^2 ≈ 40_000× more forward passes per\n"
            "field than 50×50. For a quick run use e.g. --grid_size 50; for a high-res\n"
            "summary figure use the default --mosaic_side_px without increasing grid_size.\n"
            "To proceed anyway: add --allow_huge_grid\n".format(args.grid_size)
        )
        raise SystemExit(2)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = (
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if args.device == "auto"
        else torch.device(args.device)
    )
    print(f"\nDevice: {device}")

    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    if args.training_args is None:
        args.training_args = autodetect_args()
        if args.training_args is None:
            raise FileNotFoundError("Cannot find args.json — pass --training_args")
        print(f" Auto-detected args: {args.training_args}")
    cfg = load_config(args.training_args)

    print("\nLoading model ...")
    model = load_model(args.checkpoint, cfg, device)
    n_p = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_p:,} T={model.diffusion.timesteps}")

    print(f"\nLoading {args.n_fields} test fields ...")
    test_imgs, test_labels, label_mu, label_std = load_test_data(
        args.data_dir, args.n_fields, seed=args.seed,
    )
    print(f" Image shape: {test_imgs.shape[1:]}")
    print(f" Label dim: {test_labels.shape[1]}")
    print(f" Label μ/σ: {label_mu} / {label_std}")

    label_dim = int(cfg.get("label_dim", 2))
    if label_mu.shape[0] != label_dim:
        # Fail here with a clear message rather than deep inside the label
        # normalisation of evaluate_vlb_surface.
        raise ValueError(
            f"Label width {label_mu.shape[0]} in {args.data_dir!r} does not match "
            f"label_dim={label_dim} from {args.training_args!r}"
        )

    print(
        f"\nEvaluating L_t on {args.grid_size}x{args.grid_size} grid for "
        f"{len(args.t_subset)} timesteps × {args.n_seeds} seeds ..."
    )
    print(
        f" -> {args.grid_size ** 2 * len(args.t_subset) * args.n_seeds:,} "
        f"forward-pass groups per field (× seeds averaged)"
    )

    # Posterior figures and all summaries are built from the t = 0 surface.
    has_L0 = 0 in args.t_subset
    pred_results = []
    for fi in range(args.n_fields):
        Om_true = float(test_labels[fi, 0])
        s8_true = float(test_labels[fi, 1])
        print(f"\n [{fi+1}/{args.n_fields}] field with "
              f"Om={Om_true:.3f}, s8={s8_true:.3f}")

        # Rescale the map to [-1, 1] and add the channel axis.
        # NOTE(review): assumes the stored maps are normalised to [0, 1] — confirm.
        x_0 = torch.from_numpy(test_imgs[fi:fi + 1] * 2.0 - 1.0).unsqueeze(1).to(device)
        Om_grid, s8_grid = build_eval_grid(Om_true, s8_true, args.grid_size, args.span)

        t_start = time.time()
        surfaces = evaluate_vlb_surface(
            model=model,
            x_0=x_0,
            Om_grid=Om_grid,
            s8_grid=s8_grid,
            label_mu=label_mu,
            label_std=label_std,
            t_values=args.t_subset,
            n_seeds=args.n_seeds,
            batch_size=args.batch_size,
            label_dim=label_dim,
            fixed_seed=args.seed + fi,
            device=device,
        )
        elapsed = time.time() - t_start
        print(f" Evaluation time: {elapsed:.1f}s")

        np.savez(
            out / f"field{fi:02d}_surfaces.npz",
            **{f"L_t{t}": s for t, s in surfaces.items()},
            Om_grid=Om_grid, s8_grid=s8_grid,
            Om_true=Om_true, s8_true=s8_true,
        )

        if has_L0:
            # Combined figure: contours_per_t + posterior on same plot
            Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi) = fig_posterior_and_contours_combined(
                surfaces, surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
                out / f"field{fi:02d}_combined.png", dpi=args.dpi,
            )
            pred_results.append(dict(
                Om_true=Om_true, s8_true=s8_true,
                Om_pred=Om_pred, s8_pred=s8_pred,
                Om_lo=Om_lo, Om_hi=Om_hi,
                s8_lo=s8_lo, s8_hi=s8_hi,
            ))

        # Also save individual figures for detailed inspection
        fig_contours_per_t(
            surfaces, Om_grid, s8_grid, Om_true, s8_true,
            out / f"field{fi:02d}_contours_per_t.png", dpi=args.dpi,
        )
        if has_L0:
            fig_main_posterior(
                surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
                out / f"field{fi:02d}_posterior_L0.png", dpi=args.dpi,
            )

    if len(pred_results) >= 2:
        fig_pred_vs_true(pred_results, out / "summary_pred_vs_true.png", dpi=args.dpi)
    if pred_results:
        # Empty when t = 0 was not evaluated; pred_results[0] would then raise.
        np.savez(
            out / "summary.npz",
            **{k: np.array([r[k] for r in pred_results])
               for k in pred_results[0].keys()},
        )

    if (
        has_L0
        and args.n_fields >= 9
        and all((out / f"field{i:02d}_surfaces.npz").is_file() for i in range(9))
    ):
        fig_posterior_L0_mosaic_3x3(
            out, args.n_fields, out / "posterior_L0_mosaic_3x3.png",
            mosaic_side_px=args.mosaic_side_px,
            panel_inches=args.mosaic_panel_inches,
        )

    print(f"\nAll outputs -> {out.resolve()}/")
    for f in sorted(out.glob("*.png")):
        print(f" {f.name}")


if __name__ == "__main__":
    main()