# DDPM-6param / src/figure9_posterior.py
# Author: collins909
# Commit eb725f8 (verified): Upload 6-parameter conditional DDPM (HI emulation,
#   CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
"""
Figure-9 style surrogate posteriors: build (Ωm, σ8) grids and log P(k) for observed maps.
Used by ddpm_cond_eval.ipynb. Sampling and P(k) live in eval_model.py.
"""
from __future__ import annotations
import numpy as np
import eval_model as em
def build_cosmo_grid(
    g: int,
    om_lo: float,
    om_hi: float,
    s8_lo: float,
    s8_hi: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build a regular ``g`` x ``g`` grid over (Omega_m, sigma_8).

    Returns
    -------
    om_axis, s8_axis
        1-D float64 axes holding ``g`` evenly spaced points on
        ``[om_lo, om_hi]`` and ``[s8_lo, s8_hi]`` (endpoints included).
    og, sg
        ``(g, g)`` float64 coordinate arrays in matrix (``"ij"``) indexing,
        so ``og[i, j] == om_axis[i]`` and ``sg[i, j] == s8_axis[j]``.
    grid_labels
        ``(g * g, 2)`` float32 array of ``(Omega_m, sigma_8)`` pairs, listed
        in row-major order of the ``(g, g)`` grid.
    """
    def _axis(lo: float, hi: float) -> np.ndarray:
        # Same resolution on both parameters: g points per axis.
        return np.linspace(lo, hi, g, dtype=np.float64)

    om_axis, s8_axis = _axis(om_lo, om_hi), _axis(s8_lo, s8_hi)
    om_mesh, s8_mesh = np.meshgrid(om_axis, s8_axis, indexing="ij")
    # Flatten both meshes in C order and pair them column-wise.
    pairs = np.column_stack((om_mesh.ravel(), s8_mesh.ravel()))
    return om_axis, s8_axis, om_mesh, s8_mesh, pairs.astype(np.float32)
def log_pk_observed(
    img01: np.ndarray,
    box_size: float,
    dk: np.ndarray,
    *,
    eps: float = 1e-30,
) -> np.ndarray:
    """Single map -> log P(k) on bins where ``dk > 0``.

    Parameters
    ----------
    img01
        One map. A leading batch axis is prepended before calling
        ``eval_model.per_map_power_spectra_log`` (the name suggests values
        scaled to [0, 1] -- confirm with the caller).
    box_size
        Forwarded unchanged to ``eval_model.per_map_power_spectra_log``.
    dk
        Array-like with one entry per P(k) bin. Only bins with ``dk > 0``
        are kept. Lists and tuples are accepted as well as ndarrays.
    eps
        Offset added to P(k) before the log, so zero-valued entries give a
        finite result instead of ``-inf``. Keyword-only; the default keeps
        the historical value ``1e-30``.

    Returns
    -------
    np.ndarray
        ``log(P(k) + eps)`` for the single map, restricted to the kept bins.

    Raises
    ------
    ValueError
        If the number of P(k) bins differs from ``len(dk)``.
    """
    # Normalize so the ``dk > 0`` mask below also works for list/tuple input.
    dk = np.asarray(dk)
    _, pk = em.per_map_power_spectra_log(img01[np.newaxis, ...], box_size)
    # pk is indexed as (map, bin) here; presumably (n_maps, n_bins) -- we
    # passed exactly one map, so row 0 is ours.
    n_bins = pk.shape[1]
    if n_bins != len(dk):
        raise ValueError(
            f"P(k) bin count mismatch vs dk: P(k) has {n_bins} bins, "
            f"dk has {len(dk)} entries"
        )
    valid = dk > 0
    return np.log(pk[0, valid] + eps)