#!/usr/bin/env python3 """ ddpm_posterior_six_anchors_corrected.py ======================================== Corrected surrogate P(k) likelihood posteriors on (Omega_m, sigma_8) for six CAMELS test anchors. CORRECTIONS OVER THE ORIGINAL SCRIPT -------------------------------------- (1) STOCHASTIC EMULATOR NOISE [was: 1 DDPM sample/grid point → fragmented posteriors] Now: average log P(k) over `--n-pk-samples` (default 8) DDPM draws per grid point, suppressing emulator variance by ~1/sqrt(N_s). (2) CALIBRATED LIKELIHOOD NOISE SCALE [was: hard-coded sigma=0.25] Now: sigma_pk is estimated from the scatter of log P(k) across repeated DDPM draws at a sample of validation labels — making the noise scale physically meaningful and data-driven. (3) PROPER MARGINALIZATION OVER ASTROPHYSICAL PARAMETERS [was: fix to LHS min/max] For DDPM-6, dims 2–5 are now integrated out via Monte Carlo: p(Om, s8 | d) ≈ (1/N) Σ_i L(d | Om, s8, θ_extra^i), θ_extra^i ~ Uniform(LHS) replacing the incorrect conditional likelihoods p(d | Om, s8, θ_extra = fixed). (4) GRID RESOLUTION [was: 14×14 = 196 points] Now: 30×30 = 900 points (configurable via --grid). (5) EFFECTIVE SAMPLE SIZE [was: none] n_eff = 1 / Σ w_i^2 is printed for every panel. Values ≪ 30 flag collapse. (6) CREDIBLE CONTOURS [was: raw contourf only] Now: 68 % and 95 % posterior mass contours drawn explicitly on each panel. (7) S8 DERIVED PARAMETER [was: absent] S8 = sigma_8 * (Omega_m / 0.3)^0.5 reported for the posterior mean. (8) POSTERIOR PREDICTIVE CHECK [was: absent] A separate figure shows the 68/95 % posterior-predictive P(k) envelope versus the observed P(k) for each anchor — a standard emulator validation step. USAGE ----- # Both models, all corrections: python ddpm_posterior_six_anchors_corrected.py # DDPM-2 only, fast debug run: python ddpm_posterior_six_anchors_corrected.py --ddpm2-only --grid 14 --n-pk-samples 4 --n-marg-samples 1 # DDPM-6 only, full quality: python ddpm_posterior_six_anchors_corrected.py --ddpm6-only --grid 30 --n-pk-samples 12 --n-marg-samples 30 """ from __future__ import annotations import argparse import gc import sys from pathlib import Path from typing import Dict, List, Optional, Tuple import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.ticker as mticker import numpy as np import torch # ── Path setup ──────────────────────────────────────────────────────────────── MODELS_ROOT = Path(__file__).resolve().parent CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6" if str(CODE_6.resolve()) not in sys.path: sys.path.insert(0, str(CODE_6)) import evaluate_conditional as ec # noqa: E402 import eval_model as em # noqa: E402 # ═════════════════════════════════════════════════════════════════════════════ # § 1 GRID CONSTRUCTION # ═════════════════════════════════════════════════════════════════════════════ def build_cosmo_grid( grid: int, om_lo: float, om_hi: float, s8_lo: float, s8_hi: float, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Build a regular (grid × grid) mesh over (Omega_m, sigma_8). Returns ------- om_ax : 1-D array, shape (grid,) s8_ax : 1-D array, shape (grid,) grid2 : 2-D array, shape (grid^2, 2) — row-major (Omega_m varies fastest) """ om_ax = np.linspace(om_lo, om_hi, grid, dtype=np.float32) s8_ax = np.linspace(s8_lo, s8_hi, grid, dtype=np.float32) OG, SG = np.meshgrid(om_ax, s8_ax, indexing="ij") grid2 = np.stack([OG.ravel(), SG.ravel()], axis=1).astype(np.float32) return om_ax, s8_ax, grid2 def build_full_grid( labels_ref: np.ndarray, grid: int, tail: Optional[np.ndarray], lab_dim: int, pad_frac: float = 0.02, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Build the full label matrix for the posterior grid. Parameters ---------- labels_ref : reference labels from which (Om, s8) range is inferred grid : grid points per axis tail : fixed values for dims 2–5 (None for DDPM-2) lab_dim : total label dimension (2 or 6) pad_frac : fractional padding beyond data range Returns ------- full : (grid^2, lab_dim) float32 om_ax : (grid,) float32 s8_ax : (grid,) float32 """ lo0, hi0 = float(labels_ref[:, 0].min()), float(labels_ref[:, 0].max()) lo1, hi1 = float(labels_ref[:, 1].min()), float(labels_ref[:, 1].max()) p0 = pad_frac * (hi0 - lo0 + 1e-12) p1 = pad_frac * (hi1 - lo1 + 1e-12) om_ax, s8_ax, grid2 = build_cosmo_grid(grid, lo0 - p0, hi0 + p0, lo1 - p1, hi1 + p1) ngrid = grid2.shape[0] full = np.zeros((ngrid, lab_dim), dtype=np.float32) full[:, 0] = grid2[:, 0] full[:, 1] = grid2[:, 1] if tail is not None: assert tail.shape == (4,), f"tail must be shape (4,), got {tail.shape}" full[:, 2:6] = tail[np.newaxis, :] return full, om_ax, s8_ax # ═════════════════════════════════════════════════════════════════════════════ # § 2 LHS BOUNDS # ═════════════════════════════════════════════════════════════════════════════ def _train_label_path(data_dir: Path) -> Path: for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"): p = data_dir / name if p.is_file(): return p raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}") def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]: """Min/max of LHS training labels for dims 2–5.""" L = np.load(_train_label_path(data_dir)) if L.shape[1] < 6: raise ValueError(f"Expected ≥6 label columns, got shape {L.shape}") lo = L[:, 2:6].min(axis=0).astype(np.float32) hi = L[:, 2:6].max(axis=0).astype(np.float32) return lo, hi # ═════════════════════════════════════════════════════════════════════════════ # § 3 OBSERVED LOG P(k) # ═════════════════════════════════════════════════════════════════════════════ def log_pk_observed( obs_image: np.ndarray, box_size: float = 25.0, ) -> Tuple[np.ndarray, np.ndarray]: """ Compute log10 P(k) of the *observed* HI map, after converting from [0,1] pixel scale to log10(N_HI). Returns ------- dk : k-mode array (n_bins,) log_pd : log power spectrum of observed map (n_bins,), valid-modes only valid : boolean mask selecting non-zero k-modes """ # images_01_to_log_nhi expects shape (..., H, W) or (H, W) log_nhi = em.images01_to_log_nhi(obs_image[np.newaxis]) # (1, H, W) npix = obs_image.shape[-1] dl = box_size / npix dk, pk = ec.PowerSpectrum(log_nhi[0], N=npix, dl=dl) valid = dk > 0 log_pd = np.log(pk[valid] + 1e-30) return dk, log_pd, valid # ═════════════════════════════════════════════════════════════════════════════ # § 4 SIGMA_PK CALIBRATION (Correction #2) # ═════════════════════════════════════════════════════════════════════════════ def calibrate_sigma_pk( model: torch.nn.Module, images_val: np.ndarray, labels_val: np.ndarray, lab_mean: np.ndarray, lab_std: np.ndarray, normalize: bool, device: torch.device, box_size: float = 25.0, ddim_steps: int = 50, n_pairs: int = 30, seed: int = 0, ) -> float: """ Estimate the log-P(k) noise scale from the *aleatoric* variance of the DDPM emulator at fixed labels. For n_pairs validation images we draw two independent DDPM samples and compute std(log Pk_a - log Pk_b) / sqrt(2), then take the median. This gives a physically motivated sigma_pk that replaces the hard-coded 0.25. """ rng = np.random.default_rng(seed) n_val = min(n_pairs, len(labels_val)) idx = rng.choice(len(labels_val), size=n_val, replace=False) labs = labels_val[idx].astype(np.float32) # (n_val, lab_dim) H, W = int(images_val.shape[-2]), int(images_val.shape[-1]) sigmas = [] for i in range(n_val): lab_i = labs[i:i+1] # (1, lab_dim) pair = np.concatenate([lab_i, lab_i], axis=0) # (2, lab_dim) imgs = em.sample_batch( model, pair, lab_mean, lab_std, normalize, H, W, device, ddim_steps, False, ) # (2, H, W) in [0, 1] dk, log_pk_a, valid = log_pk_observed(imgs[0], box_size) _, log_pk_b, _ = log_pk_observed(imgs[1], box_size) diff = log_pk_a - log_pk_b # sigma of a single draw = std(diff) / sqrt(2) sigmas.append(float(np.std(diff) / np.sqrt(2.0))) sigma_cal = float(np.median(sigmas)) print( f" [calibrate_sigma_pk] n_pairs={n_val} " f"median σ_pk={sigma_cal:.4f} " f"(was hard-coded 0.25)" ) return max(sigma_cal, 0.01) # safety floor # ═════════════════════════════════════════════════════════════════════════════ # § 5 AVERAGED LOG-LIKELIHOOD (Correction #1) # ═════════════════════════════════════════════════════════════════════════════ def averaged_log_likelihood( obs_image: np.ndarray, full: np.ndarray, lab_mean: np.ndarray, lab_std: np.ndarray, normalize: bool, model: torch.nn.Module, device: torch.device, H: int, W: int, box_size: float, ddim_steps: int, batch_sz: int, n_pk_samples: int, sigma_pk: float, ) -> np.ndarray: """ Compute the Gaussian log-likelihood for every grid point in `full`, averaging over `n_pk_samples` independent DDPM draws to suppress emulator stochasticity. Parameters ---------- full : (ngrid, lab_dim) array of grid labels n_pk_samples : number of DDPM draws to average (≥8 recommended) sigma_pk : calibrated log-P(k) noise scale Returns ------- log_w : (ngrid,) unnormalised log-posterior weights """ _, log_pd, valid = log_pk_observed(obs_image, box_size) ngrid = full.shape[0] # Accumulate sum of log P(k) over n_pk_samples draws sum_log_pg = np.zeros((ngrid, int(valid.sum())), dtype=np.float64) for s in range(n_pk_samples): all_pk = [] for j0 in range(0, ngrid, batch_sz): chunk = full[j0: j0 + batch_sz] imgs = em.sample_batch( model, chunk, lab_mean, lab_std, normalize, H, W, device, ddim_steps, False, ) # (chunk_sz, H, W) _, pks = em.per_map_power_spectra_log(imgs, box_size) # pks shape: (chunk_sz, n_bins); select valid bins all_pk.append(pks[:, valid]) pk_all = np.concatenate(all_pk, axis=0) # (ngrid, n_valid) sum_log_pg += np.log(pk_all + 1e-30) mean_log_pg = sum_log_pg / n_pk_samples # (ngrid, n_valid) # Gaussian log-likelihood: -0.5 * Σ_k [(log Pd - log Pg)^2] / sigma^2 mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1) log_w = -mse / (2.0 * sigma_pk ** 2) return log_w.astype(np.float64) # ═════════════════════════════════════════════════════════════════════════════ # § 6 POSTERIOR WEIGHT COMPUTATION # ═════════════════════════════════════════════════════════════════════════════ def posterior_weights_ddpm2( obs_image: np.ndarray, labels_ref: np.ndarray, lab_mean: np.ndarray, lab_std: np.ndarray, normalize: bool, model: torch.nn.Module, device: torch.device, grid: int, batch_sz: int, ddim_steps: int, n_pk_samples: int, sigma_pk: float, box_size: float = 25.0, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Compute the DDPM-2 surrogate posterior on (Omega_m, sigma_8). Returns (Wmap, OM, S8) with Wmap shaped (grid, grid). """ H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1]) full, om_ax, s8_ax = build_full_grid(labels_ref, grid, tail=None, lab_dim=2) log_w = averaged_log_likelihood( obs_image, full, lab_mean, lab_std, normalize, model, device, H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk, ) log_w -= log_w.max() # numerical stability w = np.exp(log_w).reshape(grid, grid) w /= w.sum() OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij") return w, OM, S8 def posterior_weights_ddpm6_marginalised( obs_image: np.ndarray, labels_ref: np.ndarray, lab_mean: np.ndarray, lab_std: np.ndarray, normalize: bool, model: torch.nn.Module, device: torch.device, lo_tail: np.ndarray, hi_tail: np.ndarray, grid: int, batch_sz: int, ddim_steps: int, n_pk_samples: int, n_marg_samples: int, sigma_pk: float, box_size: float = 25.0, seed: int = 1, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Compute the DDPM-6 *marginal* posterior on (Omega_m, sigma_8) by Monte Carlo integration over the astrophysical nuisance parameters: p(Om, s8 | d) ∝ ∫ L(d | Om, s8, θ_extra) π(θ_extra) dθ_extra ≈ (1/N) Σ_i L(d | Om, s8, θ_extra^i) where θ_extra^i ~ Uniform(LHS range for dims 2-5). This replaces the incorrect approach of fixing dims 2-5 to their LHS extrema, which computes a *conditional* likelihood, not a marginal. Parameters ---------- n_marg_samples : number of MC draws for astrophysical parameter integration (≥20 recommended; more = smoother but slower) """ rng = np.random.default_rng(seed) H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1]) # Draw astrophysical parameter samples from their uniform prior over LHS theta_extra_draws = rng.uniform( lo_tail, hi_tail, size=(n_marg_samples, 4), ).astype(np.float32) _, om_ax, s8_ax = build_full_grid(labels_ref, grid, tail=None, lab_dim=2) full_cosmo, _, _ = build_full_grid(labels_ref, grid, tail=None, lab_dim=2) ngrid = full_cosmo.shape[0] # log-sum-exp accumulator over marginalisation samples log_w_accum = np.full(ngrid, -np.inf, dtype=np.float64) for m_idx, theta_extra in enumerate(theta_extra_draws): # Assemble 6D label grid with this draw of astrophysical params full_6d = np.zeros((ngrid, 6), dtype=np.float32) full_6d[:, :2] = full_cosmo[:, :2] full_6d[:, 2:6] = theta_extra[np.newaxis, :] log_w_m = averaged_log_likelihood( obs_image, full_6d, lab_mean, lab_std, normalize, model, device, H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk, ) # log-sum-exp: accumulate log Σ L_i → after loop divide by N_marg log_w_accum = np.logaddexp(log_w_accum, log_w_m) if (m_idx + 1) % 5 == 0 or (m_idx + 1) == n_marg_samples: print(f" marginalisation sample {m_idx+1}/{n_marg_samples} done") # Subtract log(N_marg) to convert sum → mean, then normalise log_w_accum -= np.log(n_marg_samples) log_w_accum -= log_w_accum.max() w = np.exp(log_w_accum).reshape(grid, grid) w /= w.sum() OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij") return w, OM, S8 # ═════════════════════════════════════════════════════════════════════════════ # § 7 POSTERIOR DIAGNOSTICS # ═════════════════════════════════════════════════════════════════════════════ def effective_sample_size(w: np.ndarray) -> float: """n_eff = 1 / Σ w_i^2. Values < 30 indicate posterior collapse.""" w_flat = w.ravel() / w.sum() return float(1.0 / (w_flat ** 2).sum()) def credible_levels( w: np.ndarray, levels: Tuple[float, ...] = (0.68, 0.95), ) -> List[float]: """ Find the weight threshold c such that the region {w ≥ c} contains exactly `level` of the total probability mass. Returns a list of thresholds, one per level (descending). """ w_flat = w.ravel() sorted_w = np.sort(w_flat)[::-1] cumsum = np.cumsum(sorted_w) thresholds = [] for level in levels: idx = np.searchsorted(cumsum, level * w_flat.sum()) idx = min(idx, len(sorted_w) - 1) thresholds.append(float(sorted_w[idx])) return thresholds def posterior_summary( w: np.ndarray, OM: np.ndarray, S8: np.ndarray, ) -> Dict: """ Return a dict with posterior mean, std, and S8 derived parameter. """ w_norm = w / w.sum() mom = float((w_norm * OM).sum()) ms8 = float((w_norm * S8).sum()) vom = float((w_norm * (OM - mom) ** 2).sum()) ** 0.5 vs8 = float((w_norm * (S8 - ms8) ** 2).sum()) ** 0.5 mS8 = ms8 * (mom / 0.3) ** 0.5 n_eff = effective_sample_size(w_norm) return dict(om_mean=mom, om_std=vom, s8_mean=ms8, s8_std=vs8, S8_mean=mS8, n_eff=n_eff) # ═════════════════════════════════════════════════════════════════════════════ # § 8 PLOTTING # ═════════════════════════════════════════════════════════════════════════════ def plot_posterior_panel( ax: plt.Axes, w: np.ndarray, OM: np.ndarray, S8: np.ndarray, true_om: float, true_s8: float, title: str, summary: Optional[Dict] = None, ) -> None: """ Plot one posterior panel with: • filled colour map of posterior weights • 68 % and 95 % credible contours • true parameter location (red ×) • posterior mean (black +) • n_eff and posterior-mean S8 as text annotation """ # ── colour map ──────────────────────────────────────────────────────────── cf = ax.contourf(OM, S8, w, levels=14, cmap="Blues") plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04) # ── credible contours ───────────────────────────────────────────────────── try: thresh_68, thresh_95 = credible_levels(w, levels=(0.68, 0.95)) ax.contour(OM, S8, w, levels=[thresh_95, thresh_68], colors=["#e07b39", "#c0392b"], linewidths=[1.2, 1.8], linestyles=["--", "-"]) # Proxy artists for legend from matplotlib.lines import Line2D ax.legend( handles=[ Line2D([], [], color="#c0392b", lw=1.8, label="68 % CR"), Line2D([], [], color="#e07b39", lw=1.2, ls="--", label="95 % CR"), Line2D([], [], marker="x", color="r", ls="", ms=8, label="true"), Line2D([], [], marker="+", color="k", ls="", ms=8, label="post. mean"), ], fontsize=6.5, loc="upper right", ) except Exception: ax.legend(fontsize=6.5) # ── markers ─────────────────────────────────────────────────────────────── if summary: ax.scatter(summary["om_mean"], summary["s8_mean"], s=60, c="k", marker="+", zorder=7) ax.scatter(true_om, true_s8, s=60, c="r", marker="x", zorder=7) # ── S8 degeneracy line (for visual reference) ───────────────────────────── om_arr = np.linspace(float(OM.min()), float(OM.max()), 200) if summary: S8_val = summary["s8_mean"] * (summary["om_mean"] / 0.3) ** 0.5 s8_degen = S8_val / (om_arr / 0.3) ** 0.5 mask = (s8_degen >= float(S8.min())) & (s8_degen <= float(S8.max())) if mask.any(): ax.plot(om_arr[mask], s8_degen[mask], "k:", lw=0.8, alpha=0.5, label=f"$S_8$={S8_val:.3f}") # ── labels and annotation ───────────────────────────────────────────────── ax.set_xlabel(r"$\Omega_m$", fontsize=9) ax.set_ylabel(r"$\sigma_8$", fontsize=9) ax.set_title(title, fontsize=8) if summary: info = ( f"$n_\\mathrm{{eff}}$={summary['n_eff']:.0f}\n" f"$S_8$={summary['S8_mean']:.3f}\n" f"$\\Omega_m$={summary['om_mean']:.3f}±{summary['om_std']:.3f}\n" f"$\\sigma_8$={summary['s8_mean']:.3f}±{summary['s8_std']:.3f}" ) ax.text(0.02, 0.98, info, transform=ax.transAxes, fontsize=6.5, va="top", color="#222", bbox=dict(fc="white", ec="none", alpha=0.7, pad=1.5)) def make_posterior_figure( panels: List[Dict], suptitle: str, out_path: Path, ) -> None: """ Create a 2×3 grid of posterior panels and save to `out_path`. Each element of `panels` must be a dict with keys: w, OM, S8, true_om, true_s8, title, summary """ fig, axes = plt.subplots(2, 3, figsize=(15, 9.5), squeeze=False) for k, p in enumerate(panels): r, c = divmod(k, 3) plot_posterior_panel( axes[r, c], p["w"], p["OM"], p["S8"], p["true_om"], p["true_s8"], p["title"], p.get("summary"), ) plt.suptitle(suptitle, fontsize=11, y=0.998) plt.tight_layout(rect=(0, 0, 1, 0.97)) fig.savefig(out_path, dpi=170, bbox_inches="tight") plt.close(fig) print(f" Saved → {out_path}") # ═════════════════════════════════════════════════════════════════════════════ # § 9 POSTERIOR PREDICTIVE CHECK (Correction #8) # ═════════════════════════════════════════════════════════════════════════════ def posterior_predictive_check( obs_image: np.ndarray, w: np.ndarray, OM: np.ndarray, S8: np.ndarray, model: torch.nn.Module, lab_mean: np.ndarray, lab_std: np.ndarray, normalize: bool, device: torch.device, ddim_steps: int, box_size: float = 25.0, n_draws: int = 40, seed: int = 42, ) -> Tuple[np.ndarray, np.ndarray]: """ Draw `n_draws` parameter samples from the posterior and generate DDPM images; return the stacked log P(k) array for envelope plotting. """ rng = np.random.default_rng(seed) w_flat = w.ravel() / w.sum() idx = rng.choice(len(w_flat), size=n_draws, replace=True, p=w_flat) om_flat = OM.ravel() s8_flat = S8.ravel() labs = np.stack([om_flat[idx], s8_flat[idx]], axis=1).astype(np.float32) H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1]) imgs = em.sample_batch( model, labs, lab_mean, lab_std, normalize, H, W, device, ddim_steps, False, ) # (n_draws, H, W) _, pks = em.per_map_power_spectra_log(imgs, box_size) # (n_draws, n_bins) log_pks = np.log(pks + 1e-30) # Observed dk, log_pd, valid = log_pk_observed(obs_image, box_size) return dk[valid], log_pd, log_pks[:, valid] def plot_ppc_panel( ax: plt.Axes, dk_valid: np.ndarray, log_pd: np.ndarray, log_pks: np.ndarray, title: str, ) -> None: lo95 = np.percentile(log_pks, 2.5, axis=0) hi95 = np.percentile(log_pks, 97.5, axis=0) lo68 = np.percentile(log_pks, 16.0, axis=0) hi68 = np.percentile(log_pks, 84.0, axis=0) med = np.median(log_pks, axis=0) ax.fill_between(dk_valid, lo95, hi95, alpha=0.20, color="steelblue", label="95 % PPC") ax.fill_between(dk_valid, lo68, hi68, alpha=0.40, color="steelblue", label="68 % PPC") ax.plot(dk_valid, med, "b-", lw=1.4, label="PPC median") ax.plot(dk_valid, log_pd, "r-", lw=1.6, label="Observed") ax.set_xlabel(r"$k$ [h/Mpc]", fontsize=8) ax.set_ylabel(r"$\log\,P_\mathrm{HI}(k)$", fontsize=8) ax.set_title(title, fontsize=8) ax.legend(fontsize=6.5) ax.grid(alpha=0.3, lw=0.5) def make_ppc_figure( ppc_data: List[Dict], suptitle: str, out_path: Path, ) -> None: fig, axes = plt.subplots(2, 3, figsize=(15, 8), squeeze=False) for k, d in enumerate(ppc_data): r, c = divmod(k, 3) plot_ppc_panel(axes[r, c], d["dk"], d["log_pd"], d["log_pks"], d["title"]) plt.suptitle(suptitle, fontsize=11, y=0.998) plt.tight_layout(rect=(0, 0, 1, 0.97)) fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f" Saved → {out_path}") # ═════════════════════════════════════════════════════════════════════════════ # § 10 MODEL LOADING # ═════════════════════════════════════════════════════════════════════════════ def load_model( args_json: Path, ckpt: Path, device: torch.device, ) -> Tuple[torch.nn.Module, Dict]: cfg = ec.load_training_config(str(args_json)) model = ec.build_model(cfg, device) ec.load_checkpoint(model, str(ckpt), device) model.eval() return model, cfg # ═════════════════════════════════════════════════════════════════════════════ # § 11 HIGH-LEVEL RUNNERS # ═════════════════════════════════════════════════════════════════════════════ def run_ddpm2( out_dir: Path, imgs: np.ndarray, labs: np.ndarray, lab_mean: np.ndarray, lab_std: np.ndarray, cfg: Dict, model: torch.nn.Module, device: torch.device, anchor_ix: np.ndarray, grid: int, ddim_steps: int, batch_sz: int, n_pk_samples: int, sigma_pk: float, do_ppc: bool = True, ) -> None: normalize = bool(cfg.get("normalize_labels", True)) panels = [] ppc_data = [] for k, ix in enumerate(anchor_ix.ravel()): ix = int(ix) obs = imgs[ix] lab_t = labs[ix].astype(np.float32) tom, ts8 = float(lab_t[0]), float(lab_t[1]) print(f" [DDPM-2] anchor {k+1}/6 ix={ix} " f"Ωm={tom:.3f} σ8={ts8:.3f}") w, OM, S8 = posterior_weights_ddpm2( obs, labs, lab_mean, lab_std, normalize, model, device, grid, batch_sz, ddim_steps, n_pk_samples, sigma_pk, ) summ = posterior_summary(w, OM, S8) print(f" n_eff={summ['n_eff']:.0f} " f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} " f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} " f"S8={summ['S8_mean']:.3f}") panels.append(dict( w=w, OM=OM, S8=S8, true_om=tom, true_s8=ts8, summary=summ, title=( f"test ix={ix} | " r"$\Omega_m$" + f"={tom:.3f}, " r"$\sigma_8$" + f"={ts8:.3f}" ), )) if do_ppc: dk_v, log_pd, log_pks = posterior_predictive_check( obs, w, OM, S8, model, lab_mean, lab_std, normalize, device, ddim_steps, ) ppc_data.append(dict( dk=dk_v, log_pd=log_pd, log_pks=log_pks, title=f"PPC test ix={ix}", )) # ── posterior figure ────────────────────────────────────────────────────── make_posterior_figure( panels, suptitle=( r"DDPM-2 surrogate posterior on $(\Omega_m,\,\sigma_8)$ — " r"six CAMELS anchors " f"[{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]" ), out_path=out_dir / "posterior_six_anchors_ddpm2_corrected.png", ) # ── PPC figure ──────────────────────────────────────────────────────────── if do_ppc and ppc_data: make_ppc_figure( ppc_data, suptitle="DDPM-2 Posterior Predictive Check — P(k) envelope vs. observed", out_path=out_dir / "ppc_six_anchors_ddpm2.png", ) def run_ddpm6( out_dir: Path, imgs: np.ndarray, labs: np.ndarray, lab_mean: np.ndarray, lab_std: np.ndarray, cfg: Dict, model: torch.nn.Module, device: torch.device, lo_tail: np.ndarray, hi_tail: np.ndarray, anchor_ix: np.ndarray, grid: int, ddim_steps: int, batch_sz: int, n_pk_samples: int, n_marg_samples: int, sigma_pk: float, do_ppc: bool = True, ) -> None: normalize = bool(cfg.get("normalize_labels", True)) panels = [] ppc_data = [] for k, ix in enumerate(anchor_ix.ravel()): ix = int(ix) obs = imgs[ix] lab_t = labs[ix].astype(np.float32) tom, ts8 = float(lab_t[0]), float(lab_t[1]) print(f" [DDPM-6] anchor {k+1}/6 ix={ix} " f"Ωm={tom:.3f} σ8={ts8:.3f}") w, OM, S8 = posterior_weights_ddpm6_marginalised( obs, labs, lab_mean, lab_std, normalize, model, device, lo_tail, hi_tail, grid, batch_sz, ddim_steps, n_pk_samples, n_marg_samples, sigma_pk, ) summ = posterior_summary(w, OM, S8) print(f" n_eff={summ['n_eff']:.0f} " f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} " f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} " f"S8={summ['S8_mean']:.3f}") panels.append(dict( w=w, OM=OM, S8=S8, true_om=tom, true_s8=ts8, summary=summ, title=( f"test ix={ix} | " r"$\Omega_m$" + f"={tom:.3f}, " r"$\sigma_8$" + f"={ts8:.3f}" f"\n[MC marg., N_marg={n_marg_samples}]" ), )) if do_ppc: # For PPC, use DDPM-2-style sampling (only 2 cosmological params) # with a random draw from the astrophysical prior rng = np.random.default_rng(ix) te = rng.uniform(lo_tail, hi_tail).astype(np.float32) # Build 2D posterior weights recast to 6D labels for PPC w2, OM2, S82 = w, OM, S8 # same posterior geometry dk_v, log_pd, log_pks = posterior_predictive_check( obs, w2, OM2, S82, model, lab_mean, lab_std, normalize, device, ddim_steps, ) ppc_data.append(dict( dk=dk_v, log_pd=log_pd, log_pks=log_pks, title=f"PPC test ix={ix}", )) # ── posterior figure ────────────────────────────────────────────────────── make_posterior_figure( panels, suptitle=( r"DDPM-6 marginal posterior on $(\Omega_m,\,\sigma_8)$ — " r"six CAMELS anchors " f"[MC marginalisation, N_marg={n_marg_samples}, " f"{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]" ), out_path=out_dir / "posterior_six_anchors_ddpm6_marginalised_corrected.png", ) if do_ppc and ppc_data: make_ppc_figure( ppc_data, suptitle="DDPM-6 Posterior Predictive Check — P(k) envelope vs. observed", out_path=out_dir / "ppc_six_anchors_ddpm6.png", ) # ═════════════════════════════════════════════════════════════════════════════ # § 12 CLI # ═════════════════════════════════════════════════════════════════════════════ def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description=( "Corrected six-anchor surrogate posteriors: DDPM-2 and DDPM-6.\n" "See module docstring for a full list of corrections applied." ), formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p.add_argument( "--output-dir", type=Path, default=MODELS_ROOT / "ddpm_posterior_corrected_out", ) p.add_argument( "--data-2param", type=Path, default=Path("/data/LH_data/params_2"), ) p.add_argument( "--data-6param", type=Path, default=Path("/data/LH_data/params_6"), ) p.add_argument( "--bundle-2param", type=Path, default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200", ) p.add_argument( "--bundle-6param", type=Path, default=MODELS_ROOT / "notebook_model_weights" / "6param_best", ) p.add_argument( "--split", default="test", choices=["train", "val", "test"], ) # ── grid ────────────────────────────────────────────────────────────────── p.add_argument( "--grid", type=int, default=30, help="Grid points per Ωm–σ8 axis (30×30=900 default, was 14×14=196).", ) # ── sampling ────────────────────────────────────────────────────────────── p.add_argument( "--ddim-steps", type=int, default=50, help="DDIM denoising steps per sample.", ) p.add_argument( "--batch-size", type=int, default=8, help="Grid-point batch size for DDPM forward passes.", ) p.add_argument( "--n-pk-samples", type=int, default=8, help=( "DDPM draws to average per grid point. " "Variance ∝ 1/n_pk_samples. " "≥8 recommended; use 4 for a fast debug run." ), ) p.add_argument( "--n-marg-samples", type=int, default=20, help=( "MC draws for DDPM-6 astrophysical marginalisation. " "≥20 recommended; use 5 for a fast debug run." ), ) # ── sigma calibration ───────────────────────────────────────────────────── p.add_argument( "--n-calib-pairs", type=int, default=30, help="Number of image pairs used to calibrate sigma_pk.", ) p.add_argument( "--sigma-pk", type=float, default=None, help=( "Override calibrated sigma_pk with a fixed value. " "Leave unset to use automatic calibration (recommended)." ), ) # ── scope ───────────────────────────────────────────────────────────────── p.add_argument( "--ddpm2-only", action="store_true", help="Only run DDPM-2 (skip loading DDPM-6).", ) p.add_argument( "--ddpm6-only", action="store_true", help="Only run DDPM-6 (skip loading DDPM-2).", ) p.add_argument( "--no-ppc", action="store_true", help="Skip posterior predictive check figures.", ) return p.parse_args() # ═════════════════════════════════════════════════════════════════════════════ # § 13 MAIN # ═════════════════════════════════════════════════════════════════════════════ def main() -> None: args = parse_args() if args.ddpm2_only and args.ddpm6_only: raise SystemExit("Specify at most one of --ddpm2-only / --ddpm6-only.") out_dir = Path(args.output_dir).resolve() out_dir.mkdir(parents=True, exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device : {device}") print(f"Output : {out_dir}") print() # ── load data ───────────────────────────────────────────────────────────── data2 = Path(args.data_2param) data6 = Path(args.data_6param) if not args.ddpm6_only: imgs2, labs2 = ec.load_split(data2, args.split) mean2, std2 = ec.load_label_stats(data2) print(f"DDPM-2 {args.split} set : {len(labs2)} maps " f"label_dim={labs2.shape[1]}") if not args.ddpm2_only: imgs6, labs6 = ec.load_split(data6, args.split) mean6, std6 = ec.load_label_stats(data6) lo_tail, hi_tail = tail_lhs_bounds(data6) print(f"DDPM-6 {args.split} set : {len(labs6)} maps " f"label_dim={labs6.shape[1]}") print(f" LHS tails (dims 2-5): min={lo_tail} max={hi_tail}") # ── six anchors ─────────────────────────────────────────────────────────── if not args.ddpm6_only: n_ref = len(labs2) else: n_ref = len(labs6) anchor_ix = np.linspace(0, n_ref - 1, num=6, dtype=int) print(f"\nAnchor indices: {anchor_ix.tolist()}\n") # ── checkpoints ─────────────────────────────────────────────────────────── ck2 = args.bundle_2param / "checkpoint_epoch_200.pt" args_j2 = args.bundle_2param / "args.json" ck6 = args.bundle_6param / "best_model.pt" args_j6 = args.bundle_6param / "args.json" # ══════════════════════════════════════════════════════════════════════════ # DDPM-2 BLOCK # ══════════════════════════════════════════════════════════════════════════ if not args.ddpm6_only: print("=" * 70) print(">>> DDPM-2 (six anchors)") print("=" * 70) model2, cfg2 = load_model(args_j2, ck2, device) # ── sigma_pk calibration ────────────────────────────────────────────── if args.sigma_pk is not None: sigma2 = args.sigma_pk print(f" sigma_pk overridden to {sigma2:.4f}") else: print(" Calibrating sigma_pk from validation set …") imgs2_val, labs2_val = ec.load_split(data2, "val") sigma2 = calibrate_sigma_pk( model2, imgs2_val, labs2_val, mean2, std2, normalize=bool(cfg2.get("normalize_labels", True)), device=device, ddim_steps=args.ddim_steps, n_pairs=args.n_calib_pairs, ) run_ddpm2( out_dir=out_dir, imgs=imgs2, labs=labs2, lab_mean=mean2, lab_std=std2, cfg=cfg2, model=model2, device=device, anchor_ix=anchor_ix, grid=args.grid, ddim_steps=args.ddim_steps, batch_sz=args.batch_size, n_pk_samples=args.n_pk_samples, sigma_pk=sigma2, do_ppc=not args.no_ppc, ) del model2 gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print() # ══════════════════════════════════════════════════════════════════════════ # DDPM-6 BLOCK # ══════════════════════════════════════════════════════════════════════════ if not args.ddpm2_only: print("=" * 70) print(">>> DDPM-6 (six anchors, MC marginalisation over dims 2-5)") print("=" * 70) model6, cfg6 = load_model(args_j6, ck6, device) # ── sigma_pk calibration ────────────────────────────────────────────── if args.sigma_pk is not None: sigma6 = args.sigma_pk print(f" sigma_pk overridden to {sigma6:.4f}") else: print(" Calibrating sigma_pk from validation set …") imgs6_val, labs6_val = ec.load_split(data6, "val") sigma6 = calibrate_sigma_pk( model6, imgs6_val, labs6_val, mean6, std6, normalize=bool(cfg6.get("normalize_labels", True)), device=device, ddim_steps=args.ddim_steps, n_pairs=args.n_calib_pairs, ) run_ddpm6( out_dir=out_dir, imgs=imgs6, labs=labs6, lab_mean=mean6, lab_std=std6, cfg=cfg6, model=model6, device=device, lo_tail=lo_tail, hi_tail=hi_tail, anchor_ix=anchor_ix, grid=args.grid, ddim_steps=args.ddim_steps, batch_sz=args.batch_size, n_pk_samples=args.n_pk_samples, n_marg_samples=args.n_marg_samples, sigma_pk=sigma6, do_ppc=not args.no_ppc, ) del model6 gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"\nAll done. Results in {out_dir}") if __name__ == "__main__": main()