DDPM-6param / cross_model /poster.py
collins909's picture
Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
1f3e7a2 verified
#!/usr/bin/env python3
"""
ddpm_posterior_six_anchors_corrected.py
========================================
Corrected surrogate P(k) likelihood posteriors on (Omega_m, sigma_8)
for six CAMELS test anchors.
CORRECTIONS OVER THE ORIGINAL SCRIPT
--------------------------------------
(1) STOCHASTIC EMULATOR NOISE [was: 1 DDPM sample/grid point → fragmented posteriors]
Now: average log P(k) over `--n-pk-samples` (default 8) DDPM draws per grid
point, suppressing emulator variance by ~1/sqrt(N_s).
(2) CALIBRATED LIKELIHOOD NOISE SCALE [was: hard-coded sigma=0.25]
Now: sigma_pk is estimated from the scatter of log P(k) across repeated DDPM
draws at a sample of validation labels — making the noise scale physically
meaningful and data-driven.
(3) PROPER MARGINALIZATION OVER ASTROPHYSICAL PARAMETERS [was: fix to LHS min/max]
For DDPM-6, dims 2–5 are now integrated out via Monte Carlo:
p(Om, s8 | d) ≈ (1/N) Σ_i L(d | Om, s8, θ_extra^i), θ_extra^i ~ Uniform(LHS)
replacing the incorrect conditional likelihoods p(d | Om, s8, ΞΈ_extra = fixed).
(4) GRID RESOLUTION [was: 14×14 = 196 points]
Now: 30×30 = 900 points (configurable via --grid).
(5) EFFECTIVE SAMPLE SIZE [was: none]
n_eff = 1 / Σ w_i^2 is printed for every panel. Values ≪ 30 flag collapse.
(6) CREDIBLE CONTOURS [was: raw contourf only]
Now: 68 % and 95 % posterior mass contours drawn explicitly on each panel.
(7) S8 DERIVED PARAMETER [was: absent]
S8 = sigma_8 * (Omega_m / 0.3)^0.5 reported for the posterior mean.
(8) POSTERIOR PREDICTIVE CHECK [was: absent]
A separate figure shows the 68/95 % posterior-predictive P(k) envelope
versus the observed P(k) for each anchor — a standard emulator
validation step.
USAGE
-----
# Both models, all corrections:
python ddpm_posterior_six_anchors_corrected.py
# DDPM-2 only, fast debug run:
python ddpm_posterior_six_anchors_corrected.py --ddpm2-only --grid 14 --n-pk-samples 4 --n-marg-samples 1
# DDPM-6 only, full quality:
python ddpm_posterior_six_anchors_corrected.py --ddpm6-only --grid 30 --n-pk-samples 12 --n-marg-samples 30
"""
from __future__ import annotations
import argparse
import gc
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import torch
# ── Path setup ────────────────────────────────────────────────────────────────
# Directory that contains this script; weight bundles and the default output
# directory are resolved relative to it.
MODELS_ROOT = Path(__file__).resolve().parent
# Directory prepended to sys.path below; presumably where the helper modules
# `evaluate_conditional` and `eval_model` live (verify against the repo layout).
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
# Insert only once so repeated imports of this module do not grow sys.path.
if str(CODE_6.resolve()) not in sys.path:
    sys.path.insert(0, str(CODE_6))
# These imports must come after the sys.path tweak above (hence the E402 waiver).
import evaluate_conditional as ec # noqa: E402
import eval_model as em # noqa: E402
# ═════════════════════════════════════════════════════════════════════════════
# § 1  GRID CONSTRUCTION
# ═════════════════════════════════════════════════════════════════════════════
def build_cosmo_grid(
    grid: int,
    om_lo: float, om_hi: float,
    s8_lo: float, s8_hi: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Lay out a regular (grid x grid) lattice over (Omega_m, sigma_8).

    Returns
    -------
    om_ax : 1-D float32 array, shape (grid,)
    s8_ax : 1-D float32 array, shape (grid,)
    grid2 : 2-D float32 array, shape (grid^2, 2); column 0 is Omega_m,
            column 1 is sigma_8. Rows follow "ij" ordering, i.e. sigma_8
            varies fastest and Omega_m slowest, so that
            `values.reshape(grid, grid)[i, j]` belongs to (om_ax[i], s8_ax[j]).
    """
    om_ax = np.linspace(om_lo, om_hi, grid, dtype=np.float32)
    s8_ax = np.linspace(s8_lo, s8_hi, grid, dtype=np.float32)
    # Each Omega_m value is held while sigma_8 sweeps its whole axis.
    om_column = np.repeat(om_ax, grid)
    s8_column = np.tile(s8_ax, grid)
    grid2 = np.column_stack((om_column, s8_column)).astype(np.float32)
    return om_ax, s8_ax, grid2
def build_full_grid(
    labels_ref: np.ndarray,
    grid: int,
    tail: Optional[np.ndarray],
    lab_dim: int,
    pad_frac: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build the label matrix evaluated on the (Omega_m, sigma_8) posterior grid.

    Parameters
    ----------
    labels_ref : reference labels; columns 0 and 1 define the (Om, s8) range
    grid       : grid points per axis
    tail       : fixed values written to columns 2-5; None leaves every
                 column beyond the first two at zero (DDPM-2 usage)
    lab_dim    : total label dimension (2 or 6)
    pad_frac   : fractional padding added on both sides of the data range

    Returns
    -------
    full  : (grid^2, lab_dim) float32
    om_ax : (grid,) float32
    s8_ax : (grid,) float32

    Raises
    ------
    ValueError
        If `lab_dim` < 2, if `tail` does not have shape (4,), or if a tail
        is supplied while `lab_dim` < 6 (columns 2-5 would not exist).
    """
    # Explicit exceptions instead of `assert`: asserts vanish under `python -O`.
    if lab_dim < 2:
        raise ValueError(f"lab_dim must be >= 2, got {lab_dim}")
    if tail is not None:
        tail = np.asarray(tail, dtype=np.float32)
        if tail.shape != (4,):
            raise ValueError(f"tail must be shape (4,), got {tail.shape}")
        if lab_dim < 6:
            raise ValueError(
                f"a tail for columns 2-5 needs lab_dim >= 6, got {lab_dim}"
            )

    lo0, hi0 = float(labels_ref[:, 0].min()), float(labels_ref[:, 0].max())
    lo1, hi1 = float(labels_ref[:, 1].min()), float(labels_ref[:, 1].max())
    # The tiny epsilon keeps the padding non-zero when a column is constant.
    p0 = pad_frac * (hi0 - lo0 + 1e-12)
    p1 = pad_frac * (hi1 - lo1 + 1e-12)
    om_ax, s8_ax, grid2 = build_cosmo_grid(grid, lo0 - p0, hi0 + p0,
                                           lo1 - p1, hi1 + p1)

    full = np.zeros((grid2.shape[0], lab_dim), dtype=np.float32)
    full[:, :2] = grid2
    if tail is not None:
        full[:, 2:6] = tail[np.newaxis, :]
    return full, om_ax, s8_ax
# ═════════════════════════════════════════════════════════════════════════════
# § 2  LHS BOUNDS
# ═════════════════════════════════════════════════════════════════════════════
def _train_label_path(data_dir: Path) -> Path:
    """
    Locate the training-label file inside `data_dir`.

    The candidate names are tried in order and the first one that exists as
    a regular file wins; FileNotFoundError is raised when neither is present.
    """
    candidates = (
        data_dir / fname
        for fname in ("train_labels_LH.npy", "train_labels_LH_2.npy")
    )
    found = next((path for path in candidates if path.is_file()), None)
    if found is None:
        raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
    return found
def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return the per-column min and max of training-label columns 2-5.

    Parameters
    ----------
    data_dir : directory searched by `_train_label_path` for the label file

    Returns
    -------
    lo, hi : float32 arrays of shape (4,)

    Raises
    ------
    FileNotFoundError
        Propagated from `_train_label_path` when no label file exists.
    ValueError
        If the loaded array is not 2-D or has fewer than six columns.
    """
    train_labels = np.load(_train_label_path(data_dir))
    # Check the rank first so a 1-D file yields a clear error instead of an
    # IndexError on `shape[1]`.
    if train_labels.ndim != 2 or train_labels.shape[1] < 6:
        raise ValueError(
            f"Expected ≥6 label columns, got shape {train_labels.shape}"
        )
    nuisance = train_labels[:, 2:6]
    lo = nuisance.min(axis=0).astype(np.float32)
    hi = nuisance.max(axis=0).astype(np.float32)
    return lo, hi
# ═════════════════════════════════════════════════════════════════════════════
# § 3  OBSERVED LOG P(k)
# ═════════════════════════════════════════════════════════════════════════════
def log_pk_observed(
    obs_image: np.ndarray,
    box_size: float = 25.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the natural-log power spectrum of one observed HI map.

    The map is first passed through `em.images01_to_log_nhi` (presumably a
    [0, 1] pixel scale -> log10(N_HI) conversion; confirm in `eval_model`),
    then `ec.PowerSpectrum` is evaluated with pixel size `box_size / npix`.

    Parameters
    ----------
    obs_image : 2-D map; the last axis length is taken as the pixel count
    box_size  : side length of the box, in the units `ec.PowerSpectrum` expects

    Returns
    -------
    dk     : full k array as returned by `ec.PowerSpectrum` (NOT masked)
    log_pd : np.log(P(k) + 1e-30) restricted to the bins where dk > 0
    valid  : boolean mask (same length as dk) selecting dk > 0
    """
    # A leading batch axis is added for the converter and stripped again below.
    log_nhi = em.images01_to_log_nhi(obs_image[np.newaxis])
    npix = obs_image.shape[-1]
    dl = box_size / npix
    dk, pk = ec.PowerSpectrum(log_nhi[0], N=npix, dl=dl)
    valid = dk > 0
    # The 1e-30 offset guards against log(0) in empty bins.
    log_pd = np.log(pk[valid] + 1e-30)
    return dk, log_pd, valid
# ═════════════════════════════════════════════════════════════════════════════
# § 4  SIGMA_PK CALIBRATION (Correction #2)
# ═════════════════════════════════════════════════════════════════════════════
def calibrate_sigma_pk(
    model: torch.nn.Module,
    images_val: np.ndarray,
    labels_val: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    device: torch.device,
    box_size: float = 25.0,
    ddim_steps: int = 50,
    n_pairs: int = 30,
    seed: int = 0,
) -> float:
    """
    Estimate the log-P(k) noise scale from the emulator's own scatter.

    For up to `n_pairs` randomly chosen validation labels, two DDPM samples
    are drawn at the same label and std(log Pk_a - log Pk_b) / sqrt(2) is
    recorded (the std is taken across k bins). The median over pairs is the
    calibrated sigma, floored at 0.01.

    Parameters
    ----------
    images_val : validation images; only the last two axes (H, W) are used
    labels_val : validation labels, one row per map
    n_pairs    : number of labels to probe (capped at len(labels_val))
    seed       : seed for the choice of labels

    Raises
    ------
    ValueError
        If no pair can be drawn (empty validation set or n_pairs < 1), or
        if the resulting median is not finite. Either case would otherwise
        hand a NaN sigma to the likelihood, because max(nan, 0.01) is nan.
    """
    n_val = min(n_pairs, len(labels_val))
    if n_val < 1:
        raise ValueError(
            "calibrate_sigma_pk needs at least one validation label and "
            f"n_pairs >= 1 (n_pairs={n_pairs}, "
            f"len(labels_val)={len(labels_val)})"
        )

    rng = np.random.default_rng(seed)
    idx = rng.choice(len(labels_val), size=n_val, replace=False)
    labs = labels_val[idx].astype(np.float32)  # (n_val, lab_dim)
    H, W = int(images_val.shape[-2]), int(images_val.shape[-1])

    sigmas = []
    for i in range(n_val):
        lab_i = labs[i:i + 1]                          # (1, lab_dim)
        pair = np.concatenate([lab_i, lab_i], axis=0)  # same label twice
        # Trailing positional `False` is passed through unchanged; its meaning
        # is defined by em.sample_batch (not visible here).
        imgs = em.sample_batch(
            model, pair, lab_mean, lab_std, normalize,
            H, W, device, ddim_steps, False,
        )
        _, log_pk_a, _ = log_pk_observed(imgs[0], box_size)
        _, log_pk_b, _ = log_pk_observed(imgs[1], box_size)
        # The difference of two i.i.d. draws has variance 2*sigma^2.
        sigmas.append(float(np.std(log_pk_a - log_pk_b) / np.sqrt(2.0)))

    sigma_cal = float(np.median(sigmas))
    if not np.isfinite(sigma_cal):
        raise ValueError(
            f"calibrated sigma_pk is not finite ({sigma_cal}); "
            "check the sampled maps for NaN/inf power spectra"
        )
    print(
        f" [calibrate_sigma_pk] n_pairs={n_val} "
        f"median σ_pk={sigma_cal:.4f} "
        f"(was hard-coded 0.25)"
    )
    return max(sigma_cal, 0.01)  # safety floor
# ═════════════════════════════════════════════════════════════════════════════
# § 5  AVERAGED LOG-LIKELIHOOD (Correction #1)
# ═════════════════════════════════════════════════════════════════════════════
def averaged_log_likelihood(
    obs_image: np.ndarray,
    full: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    device: torch.device,
    H: int,
    W: int,
    box_size: float,
    ddim_steps: int,
    batch_sz: int,
    n_pk_samples: int,
    sigma_pk: float,
) -> np.ndarray:
    """
    Evaluate the surrogate Gaussian log-likelihood at every grid label.

    For each row of `full`, `n_pk_samples` DDPM maps are generated, their
    log P(k) are averaged, and the result is compared with the observed
    log P(k):

        log_w = - mean_k[(log Pd - <log Pg>)^2] / (2 * sigma_pk^2)

    Note the *mean* over k bins (not the sum): the exponent is therefore
    scaled by 1/n_bins relative to an independent-bin Gaussian likelihood.

    Parameters
    ----------
    full         : (ngrid, lab_dim) array of grid labels
    batch_sz     : number of grid labels sampled per `em.sample_batch` call
    n_pk_samples : number of DDPM draws averaged per grid point (>= 1)
    sigma_pk     : log-P(k) noise scale (> 0)

    Returns
    -------
    log_w : (ngrid,) float64 unnormalised log-weights

    Raises
    ------
    ValueError
        If `n_pk_samples` < 1, `batch_sz` < 1, or `sigma_pk` is not a
        positive finite number. These would otherwise yield NaN/inf weights.
    """
    if n_pk_samples < 1:
        raise ValueError(f"n_pk_samples must be >= 1, got {n_pk_samples}")
    if batch_sz < 1:
        raise ValueError(f"batch_sz must be >= 1, got {batch_sz}")
    if not (np.isfinite(sigma_pk) and sigma_pk > 0):
        raise ValueError(f"sigma_pk must be positive and finite, got {sigma_pk}")

    _, log_pd, valid = log_pk_observed(obs_image, box_size)
    ngrid = full.shape[0]

    # Running sum of log P(k) per grid point, filled chunk by chunk so no
    # intermediate (ngrid, n_bins) copy has to be concatenated per draw.
    sum_log_pg = np.zeros((ngrid, int(valid.sum())), dtype=np.float64)
    for _ in range(n_pk_samples):
        for j0 in range(0, ngrid, batch_sz):
            chunk = full[j0:j0 + batch_sz]
            imgs = em.sample_batch(
                model, chunk, lab_mean, lab_std, normalize,
                H, W, device, ddim_steps, False,
            )
            # Assumes these spectra share the binning of `valid` (which comes
            # from ec.PowerSpectrum) - TODO confirm in eval_model.
            _, pks = em.per_map_power_spectra_log(imgs, box_size)
            sum_log_pg[j0:j0 + len(chunk)] += np.log(pks[:, valid] + 1e-30)

    mean_log_pg = sum_log_pg / n_pk_samples  # (ngrid, n_valid)
    mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1)
    log_w = -mse / (2.0 * sigma_pk ** 2)
    return log_w.astype(np.float64)
# ═════════════════════════════════════════════════════════════════════════════
# § 6  POSTERIOR WEIGHT COMPUTATION
# ═════════════════════════════════════════════════════════════════════════════
def posterior_weights_ddpm2(
    obs_image: np.ndarray,
    labels_ref: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    device: torch.device,
    grid: int,
    batch_sz: int,
    ddim_steps: int,
    n_pk_samples: int,
    sigma_pk: float,
    box_size: float = 25.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Surrogate DDPM-2 posterior over the (Omega_m, sigma_8) grid.

    The grid range comes from `labels_ref` via `build_full_grid`, the
    log-weights from `averaged_log_likelihood`.

    Returns
    -------
    (Wmap, OM, S8) : normalised weights and the two coordinate meshes, each
                     of shape (grid, grid) with "ij" indexing.
    """
    n_rows, n_cols = (int(n) for n in obs_image.shape[-2:])
    grid_labels, om_axis, s8_axis = build_full_grid(
        labels_ref, grid, tail=None, lab_dim=2,
    )
    log_like = averaged_log_likelihood(
        obs_image, grid_labels, lab_mean, lab_std, normalize, model, device,
        n_rows, n_cols, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk,
    )
    # Shift by the maximum before exponentiating to avoid underflow.
    weights = np.exp(log_like - log_like.max()).reshape(grid, grid)
    weights = weights / weights.sum()
    om_mesh, s8_mesh = np.meshgrid(om_axis, s8_axis, indexing="ij")
    return weights, om_mesh, s8_mesh
def posterior_weights_ddpm6_marginalised(
    obs_image: np.ndarray,
    labels_ref: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    device: torch.device,
    lo_tail: np.ndarray,
    hi_tail: np.ndarray,
    grid: int,
    batch_sz: int,
    ddim_steps: int,
    n_pk_samples: int,
    n_marg_samples: int,
    sigma_pk: float,
    box_size: float = 25.0,
    seed: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    DDPM-6 *marginal* posterior on (Omega_m, sigma_8), obtained by Monte
    Carlo integration over label columns 2-5:

        p(Om, s8 | d) ∝ ∫ L(d | Om, s8, θ_extra) π(θ_extra) dθ_extra
                      ≈ (1/N) Σ_i L(d | Om, s8, θ_extra^i)

    with θ_extra^i ~ Uniform(lo_tail, hi_tail). Every draw is applied to
    the whole (Om, s8) grid and the likelihoods are combined in log space.

    Parameters
    ----------
    lo_tail, hi_tail : lower/upper bounds for columns 2-5 (broadcast to 4)
    n_marg_samples   : number of MC draws of the nuisance parameters (>= 1;
                       >= 20 recommended, more = smoother but slower)
    seed             : seed for the nuisance-parameter draws

    Returns
    -------
    (Wmap, OM, S8) : normalised weights and coordinate meshes, each of
                     shape (grid, grid) with "ij" indexing.

    Raises
    ------
    ValueError
        If `n_marg_samples` < 1 (the Monte Carlo mean would be -inf - (-inf)
        = NaN everywhere).
    """
    if n_marg_samples < 1:
        raise ValueError(f"n_marg_samples must be >= 1, got {n_marg_samples}")

    rng = np.random.default_rng(seed)
    H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])

    # Nuisance draws from their uniform prior over the supplied bounds.
    theta_extra_draws = rng.uniform(
        lo_tail, hi_tail,
        size=(n_marg_samples, 4),
    ).astype(np.float32)

    # One grid build provides both the cosmology columns and the axes.
    full_cosmo, om_ax, s8_ax = build_full_grid(
        labels_ref, grid, tail=None, lab_dim=2,
    )
    ngrid = full_cosmo.shape[0]

    # Reusable 6-D label buffer: columns 0-1 never change, 2-5 are
    # overwritten for every nuisance draw.
    full_6d = np.zeros((ngrid, 6), dtype=np.float32)
    full_6d[:, :2] = full_cosmo[:, :2]

    # log-sum-exp accumulator over marginalisation samples
    log_w_accum = np.full(ngrid, -np.inf, dtype=np.float64)
    for m_idx, theta_extra in enumerate(theta_extra_draws):
        full_6d[:, 2:6] = theta_extra[np.newaxis, :]
        log_w_m = averaged_log_likelihood(
            obs_image, full_6d, lab_mean, lab_std, normalize, model, device,
            H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk,
        )
        log_w_accum = np.logaddexp(log_w_accum, log_w_m)
        if (m_idx + 1) % 5 == 0 or (m_idx + 1) == n_marg_samples:
            print(f" marginalisation sample {m_idx+1}/{n_marg_samples} done")

    # sum -> mean, then shift for numerical stability and normalise
    log_w_accum -= np.log(n_marg_samples)
    log_w_accum -= log_w_accum.max()
    w = np.exp(log_w_accum).reshape(grid, grid)
    w /= w.sum()
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    return w, OM, S8
# ═════════════════════════════════════════════════════════════════════════════
# § 7  POSTERIOR DIAGNOSTICS
# ═════════════════════════════════════════════════════════════════════════════
def effective_sample_size(w: np.ndarray) -> float:
    """
    Kish effective sample size, n_eff = 1 / sum_i p_i^2, where p is `w`
    normalised to unit sum. Values below ~30 indicate posterior collapse.
    """
    probs = np.ravel(w) / w.sum()
    return float(1.0 / np.sum(probs ** 2))
def credible_levels(
    w: np.ndarray,
    levels: Tuple[float, ...] = (0.68, 0.95),
) -> List[float]:
    """
    Highest-density thresholds for the requested probability masses.

    For each `level`, the returned value c is the largest cell weight such
    that the cells with w >= c together hold at least `level` of the total
    mass. Thresholds are returned in the order of `levels`, so increasing
    levels give non-increasing thresholds.
    """
    flat = w.ravel()
    ranked = np.sort(flat)[::-1]   # heaviest cells first
    running = np.cumsum(ranked)
    last = len(ranked) - 1
    # Clamping guards against round-off pushing the target past the final
    # cumulative sum.
    return [
        float(ranked[min(np.searchsorted(running, lv * flat.sum()), last)])
        for lv in levels
    ]
def posterior_summary(
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
) -> Dict:
    """
    Summarise a gridded posterior.

    Returns a dict with the weighted mean and standard deviation of both
    axes (`om_mean`, `om_std`, `s8_mean`, `s8_std`), the derived
    `S8_mean` = s8_mean * sqrt(om_mean / 0.3) evaluated at the posterior
    means, and the effective sample size `n_eff`.
    """
    probs = w / w.sum()

    def _moments(axis_values: np.ndarray) -> Tuple[float, float]:
        # Weighted mean and standard deviation along one parameter axis.
        centre = float((probs * axis_values).sum())
        spread = float((probs * (axis_values - centre) ** 2).sum()) ** 0.5
        return centre, spread

    om_mean, om_std = _moments(OM)
    s8_mean, s8_std = _moments(S8)
    return {
        "om_mean": om_mean,
        "om_std": om_std,
        "s8_mean": s8_mean,
        "s8_std": s8_std,
        "S8_mean": s8_mean * (om_mean / 0.3) ** 0.5,
        "n_eff": effective_sample_size(probs),
    }
# ═════════════════════════════════════════════════════════════════════════════
# § 8  PLOTTING
# ═════════════════════════════════════════════════════════════════════════════
def plot_posterior_panel(
    ax: plt.Axes,
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    true_om: float,
    true_s8: float,
    title: str,
    summary: Optional[Dict] = None,
) -> None:
    """
    Draw one posterior panel on `ax`:
      • filled colour map of the posterior weights (with colorbar)
      • 68 % and 95 % credible contours
      • true parameter location (red ×)
      • posterior mean (black +) and the S8 = const line through it
      • n_eff, S8 and the marginal means/stds as a text box

    Parameters
    ----------
    w, OM, S8 : weight map and coordinate meshes of identical 2-D shape
    summary   : dict as returned by `posterior_summary`; when falsy, the
                posterior-mean marker, S8 line and text box are omitted
    """
    from matplotlib.lines import Line2D

    # ── colour map ───────────────────────────────────────────────────────────
    cf = ax.contourf(OM, S8, w, levels=14, cmap="Blues")
    # Attach the colorbar through the axes' own figure rather than pyplot's
    # implicit "current figure".
    ax.figure.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)

    # ── credible contours ────────────────────────────────────────────────────
    thresh_68, thresh_95 = credible_levels(w, levels=(0.68, 0.95))
    # (level, colour, linewidth, linestyle); the 95 % threshold is the lower.
    contour_specs = [
        (thresh_95, "#e07b39", 1.2, "--"),
        (thresh_68, "#c0392b", 1.8, "-"),
    ]
    # Matplotlib requires strictly increasing levels. On a collapsed posterior
    # both thresholds can coincide; keep a single contour (68 % style) then
    # instead of failing and losing the contours altogether.
    usable = []
    for spec in contour_specs:
        if not np.isfinite(spec[0]):
            continue
        if usable and spec[0] <= usable[-1][0]:
            usable[-1] = spec
            continue
        usable.append(spec)
    if usable:
        lvls, cols, widths, styles = (list(col) for col in zip(*usable))
        ax.contour(OM, S8, w, levels=lvls, colors=cols,
                   linewidths=widths, linestyles=styles)

    # Proxy artists so the legend is the same on every panel.
    legend_handles = [
        Line2D([], [], color="#c0392b", lw=1.8, label="68 % CR"),
        Line2D([], [], color="#e07b39", lw=1.2, ls="--", label="95 % CR"),
        Line2D([], [], marker="x", color="r", ls="", ms=8, label="true"),
        Line2D([], [], marker="+", color="k", ls="", ms=8, label="post. mean"),
    ]

    # ── markers ──────────────────────────────────────────────────────────────
    if summary:
        ax.scatter(summary["om_mean"], summary["s8_mean"],
                   s=60, c="k", marker="+", zorder=7)
    ax.scatter(true_om, true_s8, s=60, c="r", marker="x", zorder=7)

    # ── S8 degeneracy line (for visual reference) ────────────────────────────
    if summary:
        om_arr = np.linspace(float(OM.min()), float(OM.max()), 200)
        S8_val = summary["s8_mean"] * (summary["om_mean"] / 0.3) ** 0.5
        s8_degen = S8_val / (om_arr / 0.3) ** 0.5
        # Only the part of the curve inside the plotted sigma_8 range is shown.
        mask = (s8_degen >= float(S8.min())) & (s8_degen <= float(S8.max()))
        if mask.any():
            (degen_line,) = ax.plot(om_arr[mask], s8_degen[mask], "k:",
                                    lw=0.8, alpha=0.5,
                                    label=f"$S_8$={S8_val:.3f}")
            # Add the line to the legend so its label is actually displayed.
            legend_handles.append(degen_line)

    ax.legend(handles=legend_handles, fontsize=6.5, loc="upper right")

    # ── labels and annotation ────────────────────────────────────────────────
    ax.set_xlabel(r"$\Omega_m$", fontsize=9)
    ax.set_ylabel(r"$\sigma_8$", fontsize=9)
    ax.set_title(title, fontsize=8)
    if summary:
        info = (
            f"$n_\\mathrm{{eff}}$={summary['n_eff']:.0f}\n"
            f"$S_8$={summary['S8_mean']:.3f}\n"
            f"$\\Omega_m$={summary['om_mean']:.3f}±{summary['om_std']:.3f}\n"
            f"$\\sigma_8$={summary['s8_mean']:.3f}±{summary['s8_std']:.3f}"
        )
        ax.text(0.02, 0.98, info, transform=ax.transAxes,
                fontsize=6.5, va="top", color="#222",
                bbox=dict(fc="white", ec="none", alpha=0.7, pad=1.5))
def make_posterior_figure(
    panels: List[Dict],
    suptitle: str,
    out_path: Path,
) -> None:
    """
    Draw the posterior panels on a 3-column grid and save it to `out_path`.

    The grid has two rows for up to six panels (the standard 2×3 layout) and
    grows by whole rows for more; axes without a panel are hidden.

    Each element of `panels` must be a dict with keys:
        w, OM, S8, true_om, true_s8, title  and optionally  summary
    """
    n_cols = 3
    n_rows = max(2, -(-len(panels) // n_cols))  # ceil division, minimum 2
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(15, 4.75 * n_rows), squeeze=False)
    try:
        flat_axes = axes.ravel()  # row-major: panel k sits at (k // 3, k % 3)
        for ax, p in zip(flat_axes, panels):
            plot_posterior_panel(
                ax,
                p["w"], p["OM"], p["S8"],
                p["true_om"], p["true_s8"],
                p["title"], p.get("summary"),
            )
        for ax in flat_axes[len(panels):]:
            ax.set_visible(False)
        # Figure-level calls avoid depending on pyplot's current figure.
        fig.suptitle(suptitle, fontsize=11, y=0.998)
        fig.tight_layout(rect=(0, 0, 1, 0.97))
        fig.savefig(out_path, dpi=170, bbox_inches="tight")
    finally:
        # Release the figure even when a panel or the save step fails.
        plt.close(fig)
    print(f" Saved → {out_path}")
# ═════════════════════════════════════════════════════════════════════════════
# § 9  POSTERIOR PREDICTIVE CHECK (Correction #8)
# ═════════════════════════════════════════════════════════════════════════════
def posterior_predictive_check(
    obs_image: np.ndarray,
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    model: torch.nn.Module,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    device: torch.device,
    ddim_steps: int,
    box_size: float = 25.0,
    n_draws: int = 40,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample `n_draws` (Omega_m, sigma_8) pairs from the gridded posterior,
    generate one DDPM map per pair and return their log P(k) next to the
    observed one, ready for envelope plotting.

    The labels passed to the model have exactly two columns (Om, s8), so
    this function matches a 2-parameter model as written.

    Returns
    -------
    dk_valid : k values of the bins with k > 0
    log_pd   : observed log P(k) on those bins
    log_pks  : (n_draws, n_valid) log P(k) of the generated maps
    """
    rng = np.random.default_rng(seed)
    # Grid cells are drawn with probability equal to their posterior weight.
    w_flat = w.ravel() / w.sum()
    idx = rng.choice(len(w_flat), size=n_draws, replace=True, p=w_flat)
    labs = np.stack(
        [OM.ravel()[idx], S8.ravel()[idx]], axis=1,
    ).astype(np.float32)

    H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])
    imgs = em.sample_batch(
        model, labs, lab_mean, lab_std, normalize,
        H, W, device, ddim_steps, False,
    )
    _, pks = em.per_map_power_spectra_log(imgs, box_size)
    log_pks = np.log(pks + 1e-30)

    # Observed spectrum; its mask also selects the bins of the generated ones.
    dk, log_pd, valid = log_pk_observed(obs_image, box_size)
    return dk[valid], log_pd, log_pks[:, valid]
def plot_ppc_panel(
    ax: plt.Axes,
    dk_valid: np.ndarray,
    log_pd: np.ndarray,
    log_pks: np.ndarray,
    title: str,
) -> None:
    """
    Plot the posterior-predictive log P(k) envelope against the observation.

    `log_pks` holds one spectrum per row; percentiles are taken across rows
    to form the central 95 % and 68 % bands and the median curve.
    """
    # (lower percentile, upper percentile, alpha, legend label); the wider
    # band is drawn first so the narrower one sits on top of it.
    bands = (
        (2.5, 97.5, 0.20, "95 % PPC"),
        (16.0, 84.0, 0.40, "68 % PPC"),
    )
    for q_lo, q_hi, alpha, label in bands:
        lower = np.percentile(log_pks, q_lo, axis=0)
        upper = np.percentile(log_pks, q_hi, axis=0)
        ax.fill_between(dk_valid, lower, upper,
                        alpha=alpha, color="steelblue", label=label)

    median_curve = np.median(log_pks, axis=0)
    ax.plot(dk_valid, median_curve, "b-", lw=1.4, label="PPC median")
    ax.plot(dk_valid, log_pd, "r-", lw=1.6, label="Observed")

    ax.set_xlabel(r"$k$ [h/Mpc]", fontsize=8)
    ax.set_ylabel(r"$\log\,P_\mathrm{HI}(k)$", fontsize=8)
    ax.set_title(title, fontsize=8)
    ax.legend(fontsize=6.5)
    ax.grid(alpha=0.3, lw=0.5)
def make_ppc_figure(
    ppc_data: List[Dict],
    suptitle: str,
    out_path: Path,
) -> None:
    """
    Draw the posterior-predictive panels on a 3-column grid and save it.

    The grid has two rows for up to six panels and grows by whole rows for
    more; axes without a panel are hidden.

    Each element of `ppc_data` must be a dict with keys:
        dk, log_pd, log_pks, title
    """
    n_cols = 3
    n_rows = max(2, -(-len(ppc_data) // n_cols))  # ceil division, minimum 2
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(15, 4.0 * n_rows), squeeze=False)
    try:
        flat_axes = axes.ravel()  # row-major: panel k sits at (k // 3, k % 3)
        for ax, d in zip(flat_axes, ppc_data):
            plot_ppc_panel(ax, d["dk"], d["log_pd"], d["log_pks"], d["title"])
        for ax in flat_axes[len(ppc_data):]:
            ax.set_visible(False)
        # Figure-level calls avoid depending on pyplot's current figure.
        fig.suptitle(suptitle, fontsize=11, y=0.998)
        fig.tight_layout(rect=(0, 0, 1, 0.97))
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when a panel or the save step fails.
        plt.close(fig)
    print(f" Saved → {out_path}")
# ═════════════════════════════════════════════════════════════════════════════
# § 10  MODEL LOADING
# ═════════════════════════════════════════════════════════════════════════════
def load_model(
    args_json: Path,
    ckpt: Path,
    device: torch.device,
) -> Tuple[torch.nn.Module, Dict]:
    """
    Rebuild a trained model from its saved training arguments and weights.

    The configuration is read from `args_json`, the network is constructed
    on `device`, the checkpoint `ckpt` is loaded into it and the network is
    switched to eval mode. Returns (model, config).
    """
    train_cfg = ec.load_training_config(str(args_json))
    net = ec.build_model(train_cfg, device)
    ec.load_checkpoint(net, str(ckpt), device)
    net.eval()
    return net, train_cfg
# ═════════════════════════════════════════════════════════════════════════════
# § 11  HIGH-LEVEL RUNNERS
# ═════════════════════════════════════════════════════════════════════════════
def run_ddpm2(
    out_dir: Path,
    imgs: np.ndarray,
    labs: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    cfg: Dict,
    model: torch.nn.Module,
    device: torch.device,
    anchor_ix: np.ndarray,
    grid: int,
    ddim_steps: int,
    batch_sz: int,
    n_pk_samples: int,
    sigma_pk: float,
    do_ppc: bool = True,
) -> None:
    """
    Run the DDPM-2 posterior (and optional posterior-predictive check) for
    every anchor map and write the summary figures into `out_dir`.

    Parameters
    ----------
    imgs, labs : images and labels of the evaluated split; `anchor_ix`
                 indexes into both
    cfg        : training config; only `normalize_labels` is read here
    anchor_ix  : indices of the anchor maps
    do_ppc     : also produce the posterior-predictive figure

    Output files: `posterior_six_anchors_ddpm2_corrected.png` and, when
    `do_ppc` is set, `ppc_six_anchors_ddpm2.png`.
    """
    normalize = bool(cfg.get("normalize_labels", True))
    anchors = [int(ix) for ix in anchor_ix.ravel()]
    n_anchors = len(anchors)
    panels = []
    ppc_data = []
    for k, ix in enumerate(anchors, start=1):
        obs = imgs[ix]
        lab_t = labs[ix].astype(np.float32)
        tom, ts8 = float(lab_t[0]), float(lab_t[1])
        print(f" [DDPM-2] anchor {k}/{n_anchors} ix={ix} "
              f"Ωm={tom:.3f} σ8={ts8:.3f}")
        # The whole label set defines the (Om, s8) grid range.
        w, OM, S8 = posterior_weights_ddpm2(
            obs, labs, lab_mean, lab_std, normalize, model, device,
            grid, batch_sz, ddim_steps, n_pk_samples, sigma_pk,
        )
        summ = posterior_summary(w, OM, S8)
        print(f" n_eff={summ['n_eff']:.0f} "
              f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} "
              f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} "
              f"S8={summ['S8_mean']:.3f}")
        panels.append(dict(
            w=w, OM=OM, S8=S8,
            true_om=tom, true_s8=ts8, summary=summ,
            title=(
                rf"test ix={ix} | $\Omega_m$={tom:.3f}, "
                rf"$\sigma_8$={ts8:.3f}"
            ),
        ))
        if do_ppc:
            dk_v, log_pd, log_pks = posterior_predictive_check(
                obs, w, OM, S8, model, lab_mean, lab_std, normalize,
                device, ddim_steps,
            )
            ppc_data.append(dict(
                dk=dk_v, log_pd=log_pd, log_pks=log_pks,
                title=f"PPC test ix={ix}",
            ))

    # ── posterior figure ─────────────────────────────────────────────────────
    make_posterior_figure(
        panels,
        suptitle=(
            r"DDPM-2 surrogate posterior on $(\Omega_m,\,\sigma_8)$ — "
            r"six CAMELS anchors "
            f"[{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]"
        ),
        out_path=out_dir / "posterior_six_anchors_ddpm2_corrected.png",
    )
    # ── PPC figure ───────────────────────────────────────────────────────────
    if do_ppc and ppc_data:
        make_ppc_figure(
            ppc_data,
            suptitle="DDPM-2 Posterior Predictive Check — P(k) envelope vs. observed",
            out_path=out_dir / "ppc_six_anchors_ddpm2.png",
        )
def _posterior_predictive_check_ddpm6(
    obs_image: np.ndarray,
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    tail: np.ndarray,
    model: torch.nn.Module,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    device: torch.device,
    ddim_steps: int,
    box_size: float = 25.0,
    n_draws: int = 40,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Posterior-predictive spectra for the 6-parameter model.

    (Omega_m, sigma_8) pairs are drawn from the gridded marginal posterior
    `w`; columns 2-5 of every label are set to `tail` so that the label
    matrix has the six columns the DDPM-6 model is conditioned on.

    Returns (dk_valid, log_pd, log_pks) in the same layout as
    `posterior_predictive_check`.
    """
    rng = np.random.default_rng(seed)
    probs = w.ravel() / w.sum()
    idx = rng.choice(probs.size, size=n_draws, replace=True, p=probs)

    labs = np.empty((n_draws, 6), dtype=np.float32)
    labs[:, 0] = OM.ravel()[idx]
    labs[:, 1] = S8.ravel()[idx]
    labs[:, 2:6] = np.asarray(tail, dtype=np.float32)[np.newaxis, :]

    H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])
    imgs = em.sample_batch(
        model, labs, lab_mean, lab_std, normalize,
        H, W, device, ddim_steps, False,
    )
    _, pks = em.per_map_power_spectra_log(imgs, box_size)
    log_pks = np.log(pks + 1e-30)
    dk, log_pd, valid = log_pk_observed(obs_image, box_size)
    return dk[valid], log_pd, log_pks[:, valid]


def run_ddpm6(
    out_dir: Path,
    imgs: np.ndarray,
    labs: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    cfg: Dict,
    model: torch.nn.Module,
    device: torch.device,
    lo_tail: np.ndarray,
    hi_tail: np.ndarray,
    anchor_ix: np.ndarray,
    grid: int,
    ddim_steps: int,
    batch_sz: int,
    n_pk_samples: int,
    n_marg_samples: int,
    sigma_pk: float,
    do_ppc: bool = True,
) -> None:
    """
    Run the marginalised DDPM-6 posterior (and optional posterior-predictive
    check) for every anchor map and write the summary figures into `out_dir`.

    Parameters
    ----------
    imgs, labs       : images and labels of the evaluated split; `anchor_ix`
                       indexes into both
    cfg              : training config; only `normalize_labels` is read here
    lo_tail, hi_tail : bounds of label columns 2-5 used for marginalisation
    anchor_ix        : indices of the anchor maps
    do_ppc           : also produce the posterior-predictive figure

    Output files: `posterior_six_anchors_ddpm6_marginalised_corrected.png`
    and, when `do_ppc` is set, `ppc_six_anchors_ddpm6.png`.
    """
    normalize = bool(cfg.get("normalize_labels", True))
    anchors = [int(ix) for ix in anchor_ix.ravel()]
    n_anchors = len(anchors)
    panels = []
    ppc_data = []
    for k, ix in enumerate(anchors, start=1):
        obs = imgs[ix]
        lab_t = labs[ix].astype(np.float32)
        tom, ts8 = float(lab_t[0]), float(lab_t[1])
        print(f" [DDPM-6] anchor {k}/{n_anchors} ix={ix} "
              f"Ωm={tom:.3f} σ8={ts8:.3f}")
        w, OM, S8 = posterior_weights_ddpm6_marginalised(
            obs, labs, lab_mean, lab_std, normalize, model, device,
            lo_tail, hi_tail,
            grid, batch_sz, ddim_steps,
            n_pk_samples, n_marg_samples, sigma_pk,
        )
        summ = posterior_summary(w, OM, S8)
        print(f" n_eff={summ['n_eff']:.0f} "
              f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} "
              f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} "
              f"S8={summ['S8_mean']:.3f}")
        panels.append(dict(
            w=w, OM=OM, S8=S8,
            true_om=tom, true_s8=ts8, summary=summ,
            title=(
                rf"test ix={ix} | $\Omega_m$={tom:.3f}, "
                rf"$\sigma_8$={ts8:.3f}"
                f"\n[MC marg., N_marg={n_marg_samples}]"
            ),
        ))
        if do_ppc:
            # The model is conditioned on six labels, so the 2-D posterior
            # draws are completed with one draw of columns 2-5 from their
            # uniform prior (seeded by the anchor index for reproducibility).
            tail_rng = np.random.default_rng(ix)
            te = tail_rng.uniform(lo_tail, hi_tail).astype(np.float32)
            dk_v, log_pd, log_pks = _posterior_predictive_check_ddpm6(
                obs, w, OM, S8, te, model, lab_mean, lab_std, normalize,
                device, ddim_steps,
            )
            ppc_data.append(dict(
                dk=dk_v, log_pd=log_pd, log_pks=log_pks,
                title=f"PPC test ix={ix}",
            ))

    # ── posterior figure ─────────────────────────────────────────────────────
    make_posterior_figure(
        panels,
        suptitle=(
            r"DDPM-6 marginal posterior on $(\Omega_m,\,\sigma_8)$ — "
            r"six CAMELS anchors "
            f"[MC marginalisation, N_marg={n_marg_samples}, "
            f"{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]"
        ),
        out_path=out_dir / "posterior_six_anchors_ddpm6_marginalised_corrected.png",
    )
    # ── PPC figure ───────────────────────────────────────────────────────────
    if do_ppc and ppc_data:
        make_ppc_figure(
            ppc_data,
            suptitle="DDPM-6 Posterior Predictive Check — P(k) envelope vs. observed",
            out_path=out_dir / "ppc_six_anchors_ddpm6.png",
        )
# ═════════════════════════════════════════════════════════════════════════════
# § 12  CLI
# ═════════════════════════════════════════════════════════════════════════════
def parse_args() -> argparse.Namespace:
    """
    Parse the command-line options of the script.

    Options cover input/output locations, the (Omega_m, sigma_8) grid size,
    DDPM sampling settings, sigma_pk calibration, and which models to run.
    Defaults are shown in `--help` (ArgumentDefaultsHelpFormatter).
    """
    p = argparse.ArgumentParser(
        description=(
            "Corrected six-anchor surrogate posteriors: DDPM-2 and DDPM-6.\n"
            "See module docstring for a full list of corrections applied."
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # ── locations ────────────────────────────────────────────────────────────
    p.add_argument(
        "--output-dir", type=Path,
        default=MODELS_ROOT / "ddpm_posterior_corrected_out",
    )
    p.add_argument(
        "--data-2param", type=Path,
        default=Path("<DDPM_ROOT>/data/LH_data/params_2"),
    )
    p.add_argument(
        "--data-6param", type=Path,
        default=Path("<DDPM_ROOT>/data/LH_data/params_6"),
    )
    p.add_argument(
        "--bundle-2param", type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param", type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument(
        "--split", default="test", choices=["train", "val", "test"],
    )
    # ── grid ─────────────────────────────────────────────────────────────────
    p.add_argument(
        "--grid", type=int, default=30,
        help="Grid points per Ωm–σ8 axis (30×30=900 default, was 14×14=196).",
    )
    # ── sampling ─────────────────────────────────────────────────────────────
    p.add_argument(
        "--ddim-steps", type=int, default=50,
        help="DDIM denoising steps per sample.",
    )
    p.add_argument(
        "--batch-size", type=int, default=8,
        help="Grid-point batch size for DDPM forward passes.",
    )
    p.add_argument(
        "--n-pk-samples", type=int, default=8,
        help=(
            "DDPM draws to average per grid point. "
            "Variance ∝ 1/n_pk_samples. "
            "≥8 recommended; use 4 for a fast debug run."
        ),
    )
    p.add_argument(
        "--n-marg-samples", type=int, default=20,
        help=(
            "MC draws for DDPM-6 astrophysical marginalisation. "
            "≥20 recommended; use 5 for a fast debug run."
        ),
    )
    # ── sigma calibration ────────────────────────────────────────────────────
    p.add_argument(
        "--n-calib-pairs", type=int, default=30,
        help="Number of image pairs used to calibrate sigma_pk.",
    )
    p.add_argument(
        "--sigma-pk", type=float, default=None,
        help=(
            "Override calibrated sigma_pk with a fixed value. "
            "Leave unset to use automatic calibration (recommended)."
        ),
    )
    # ── scope ────────────────────────────────────────────────────────────────
    p.add_argument(
        "--ddpm2-only", action="store_true",
        help="Only run DDPM-2 (skip loading DDPM-6).",
    )
    p.add_argument(
        "--ddpm6-only", action="store_true",
        help="Only run DDPM-6 (skip loading DDPM-2).",
    )
    p.add_argument(
        "--no-ppc", action="store_true",
        help="Skip posterior predictive check figures.",
    )
    return p.parse_args()
# ═════════════════════════════════════════════════════════════════════════════
# § 13  MAIN
# ═════════════════════════════════════════════════════════════════════════════
def _resolve_sigma_pk(args, model, cfg, data_dir, lab_mean, lab_std, device):
    """Return the --sigma-pk override if given, else calibrate on the val split."""
    if args.sigma_pk is not None:
        sigma = args.sigma_pk
        print(f" sigma_pk overridden to {sigma:.4f}")
        return sigma
    print(" Calibrating sigma_pk from validation set …")
    val_imgs, val_labs = ec.load_split(data_dir, "val")
    return calibrate_sigma_pk(
        model, val_imgs, val_labs,
        lab_mean, lab_std,
        normalize=bool(cfg.get("normalize_labels", True)),
        device=device,
        ddim_steps=args.ddim_steps,
        n_pairs=args.n_calib_pairs,
    )


def main() -> None:
    """
    Script entry point: parse the CLI, load data and checkpoints, then run
    the DDPM-2 and/or DDPM-6 posterior pipelines one after the other,
    releasing each model before the next one is loaded.
    """
    args = parse_args()
    if args.ddpm2_only and args.ddpm6_only:
        raise SystemExit("Specify at most one of --ddpm2-only / --ddpm6-only.")
    want_ddpm2 = not args.ddpm6_only
    want_ddpm6 = not args.ddpm2_only

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device : {device}")
    print(f"Output : {out_dir}")
    print()

    # ── load data ────────────────────────────────────────────────────────────
    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)
    if want_ddpm2:
        imgs2, labs2 = ec.load_split(data2, args.split)
        mean2, std2 = ec.load_label_stats(data2)
        print(f"DDPM-2 {args.split} set : {len(labs2)} maps "
              f"label_dim={labs2.shape[1]}")
    if want_ddpm6:
        imgs6, labs6 = ec.load_split(data6, args.split)
        mean6, std6 = ec.load_label_stats(data6)
        lo_tail, hi_tail = tail_lhs_bounds(data6)
        print(f"DDPM-6 {args.split} set : {len(labs6)} maps "
              f"label_dim={labs6.shape[1]}")
        print(f" LHS tails (dims 2-5): min={lo_tail} max={hi_tail}")

    # ── six anchors ──────────────────────────────────────────────────────────
    # Evenly spaced indices; sized from the DDPM-2 split whenever it is loaded.
    n_ref = len(labs2) if want_ddpm2 else len(labs6)
    anchor_ix = np.linspace(0, n_ref - 1, num=6, dtype=int)
    print(f"\nAnchor indices: {anchor_ix.tolist()}\n")

    # ── checkpoints ──────────────────────────────────────────────────────────
    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    args_j2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    args_j6 = args.bundle_6param / "args.json"

    banner = "=" * 70

    # ── DDPM-2 ───────────────────────────────────────────────────────────────
    if want_ddpm2:
        print(banner)
        print(">>> DDPM-2 (six anchors)")
        print(banner)
        model2, cfg2 = load_model(args_j2, ck2, device)
        sigma2 = _resolve_sigma_pk(args, model2, cfg2, data2,
                                   mean2, std2, device)
        run_ddpm2(
            out_dir=out_dir,
            imgs=imgs2, labs=labs2,
            lab_mean=mean2, lab_std=std2,
            cfg=cfg2, model=model2, device=device,
            anchor_ix=anchor_ix,
            grid=args.grid,
            ddim_steps=args.ddim_steps,
            batch_sz=args.batch_size,
            n_pk_samples=args.n_pk_samples,
            sigma_pk=sigma2,
            do_ppc=not args.no_ppc,
        )
        # Drop the model before the next one is loaded.
        del model2
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print()

    # ── DDPM-6 ───────────────────────────────────────────────────────────────
    if want_ddpm6:
        print(banner)
        print(">>> DDPM-6 (six anchors, MC marginalisation over dims 2-5)")
        print(banner)
        model6, cfg6 = load_model(args_j6, ck6, device)
        sigma6 = _resolve_sigma_pk(args, model6, cfg6, data6,
                                   mean6, std6, device)
        run_ddpm6(
            out_dir=out_dir,
            imgs=imgs6, labs=labs6,
            lab_mean=mean6, lab_std=std6,
            cfg=cfg6, model=model6, device=device,
            lo_tail=lo_tail, hi_tail=hi_tail,
            anchor_ix=anchor_ix,
            grid=args.grid,
            ddim_steps=args.ddim_steps,
            batch_sz=args.batch_size,
            n_pk_samples=args.n_pk_samples,
            n_marg_samples=args.n_marg_samples,
            sigma_pk=sigma6,
            do_ppc=not args.no_ppc,
        )
        del model6
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\nAll done. Results in {out_dir}")
# Run the pipeline only when the file is executed as a script, so importing
# the module (e.g. to reuse its helpers) has no side effects beyond the
# path setup at the top.
if __name__ == "__main__":
    main()