DDPM-2param / cross_model /scripts /ddpm_triangle_integration.py
collins909's picture
Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
f513198 verified
#!/usr/bin/env python3
"""
Surrogate posterior on $(\\Omega_m, \\sigma_8)$ → triangle/MCMC-style chains for one test map.
Loads the same surrogate likelihood used in ``ddpm_posterior_six_anchors``, resamples discrete
posterior masses to ``--n-hist`` correlated $(\\Omega_m,\\sigma_8)$ pairs, and writes ``.npz``.
DDPM-2: sweeps $(\\Omega_m,\\sigma_8)$.
DDPM-6: dims 2–5 fixed per ``--six-tail-mode`` (``truth`` uses the test-map labels 2–5; ``min``/``max``
use LHS extrema from training labels).
If you replace this file with a local copy (e.g. from Downloads), keep the argparse interface compatible, or wrap the copy so it still accepts these arguments.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
# Put this script's own directory on sys.path so the sibling module
# ddpm_posterior_six_anchors can be imported when run as a plain script.
_SCRIPTS = Path(__file__).resolve().parent
if str(_SCRIPTS) not in sys.path:
    sys.path.insert(0, str(_SCRIPTS))
import ddpm_posterior_six_anchors as dps  # noqa: E402

# MODELS_ROOT is the parent of the directory holding this file (parents[1]).
# The 6-parameter code tree beneath it is added to sys.path so that
# evaluate_conditional can be imported (presumably it lives there — verify).
# NOTE(review): dps is imported *before* CODE_6 is on sys.path; keep this order
# unless you have confirmed dps does not depend on it.
MODELS_ROOT = Path(__file__).resolve().parents[1]
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
if str(CODE_6.resolve()) not in sys.path:
    sys.path.insert(0, str(CODE_6.resolve()))
import evaluate_conditional as ec  # noqa: E402
def _tail_vec(
mode: str,
lab_full: np.ndarray,
data6: Path,
) -> np.ndarray | None:
if lab_full.size <= 2:
return None
if mode == "truth":
return lab_full[2:6].astype(np.float32)
low, hi = dps.tail_lhs_bounds(data6)
if mode == "min":
return low
if mode == "max":
return hi
raise ValueError("six-tail-mode must be truth|min|max")
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by :func:`main`."""
    p = argparse.ArgumentParser(description="DDPM surrogate posterior → resampled Ωm σ8 chains (.npz).")
    p.add_argument(
        "--label-dim",
        type=int,
        choices=[2, 6],
        required=True,
        help="Which model to use.",
    )
    p.add_argument(
        "--bundle",
        type=Path,
        default=None,
        help=(
            "Checkpoint bundle dir with args.json (default: notebook_model_weights/"
            "2param_epoch200 for DDPM2, notebook_model_weights/6param_best for DDPM6)."
        ),
    )
    p.add_argument(
        "--checkpoint-name",
        type=str,
        default=None,
        help="Checkpoint file under bundle (defaults: DDPM2 epoch200, DDPM6 best_model).",
    )
    p.add_argument(
        "--data-dir",
        type=Path,
        default=None,
        help="LH data dir matching label_dim (default: params_2 vs params_6).",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--test-index", type=int, default=56, help="Index into split for CAMELS observation.")
    p.add_argument("--grid", type=int, default=14, help="Resolution of the (Ωm, σ8) posterior grid.")
    p.add_argument(
        "--ddim-steps",
        type=int,
        default=50,
        help="DDIM steps used when evaluating the surrogate likelihood.",
    )
    p.add_argument("--batch-size", type=int, default=8, help="Batch size for model evaluation.")
    p.add_argument(
        "--n-hist",
        type=int,
        default=10_000,
        help="Resampled posterior pairs (with replacement).",
    )
    p.add_argument(
        "--six-tail-mode",
        type=str,
        default="truth",
        choices=["truth", "min", "max"],
        help="Applies only to label_dim==6 — how dims 2–5 are fixed.",
    )
    p.add_argument(
        "--output",
        "-o",
        type=Path,
        required=True,
        help="Output .npz path.",
    )
    p.add_argument("--seed", type=int, default=42, help="Seed for resampling (numpy) and torch.")
    return p


def _default_paths(label_dim: int) -> tuple[Path, Path, str]:
    """Return the default ``(data_dir, bundle_dir, checkpoint_name)`` for a model variant."""
    weights = MODELS_ROOT / "notebook_model_weights"
    if label_dim == 2:
        return (
            Path("<DDPM_ROOT>/data/LH_data/params_2"),
            weights / "2param_epoch200",
            "checkpoint_epoch_200.pt",
        )
    return (
        Path("<DDPM_ROOT>/data/LH_data/params_6"),
        weights / "6param_best",
        "best_model.pt",
    )


def _resample_pairs(
    wmap: np.ndarray,
    om: np.ndarray,
    s8: np.ndarray,
    n: int,
    rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
    """Draw ``n`` (Ωm, σ8) grid points with replacement, weighted by the posterior masses.

    Negative masses are clipped to zero before normalisation.

    Raises:
        RuntimeError: if the masses are non-finite or sum to zero.
    """
    weights = np.clip(np.asarray(wmap).ravel().astype(np.float64), 0.0, None)
    total = weights.sum()
    # NaN/inf would slip past a plain "<= 0" test and only fail later, opaquely, in rng.choice.
    if not np.isfinite(total):
        raise RuntimeError("Posterior masses are non-finite (NaN/inf).")
    if total <= 0:
        raise RuntimeError("Posterior masses collapsed to zero.")
    weights /= total
    draws = rng.choice(weights.size, size=n, replace=True, p=weights)
    samp_om = np.asarray(om).ravel()[draws].astype(np.float64)
    samp_s8 = np.asarray(s8).ravel()[draws].astype(np.float64)
    return samp_om, samp_s8


def main() -> None:
    """Run the surrogate posterior for one observation and save resampled (Ωm, σ8) pairs.

    The steps are:

    1. Parse CLI arguments.
    2. Load the requested split and pick one observation.
    3. Load the DDPM checkpoint.
    4. Evaluate posterior masses on the (Ωm, σ8) grid via ``dps.posterior_weights``.
    5. Resample ``--n-hist`` pairs with replacement.
    6. Write the samples, truth values and grid to ``--output`` as a compressed ``.npz``.

    Raises:
        SystemExit: on an out-of-range ``--test-index``, on ``--label-dim 6`` with labels
            that do not have 6 columns, or on invalid CLI arguments.
        RuntimeError: if the posterior masses are non-finite or sum to zero.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.n_hist < 1:
        parser.error("--n-hist must be >= 1")

    ld = args.label_dim
    default_data, default_bundle, default_ck = _default_paths(ld)
    data_dir = args.data_dir if args.data_dir is not None else default_data
    bundle = args.bundle if args.bundle is not None else default_bundle
    ck_name = args.checkpoint_name or default_ck

    rng = np.random.default_rng(args.seed)
    # posterior_weights presumably runs stochastic diffusion sampling; seeding torch
    # makes --seed cover that step too. This is harmless if dps seeds internally.
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    imgs, labs = ec.load_split(data_dir, args.split)
    ix = int(args.test_index)
    if not 0 <= ix < len(labs):
        raise SystemExit(f"test-index {ix} out of range for split ({len(labs)} rows)")
    lab_t = labs[ix].astype(np.float64)
    obs = imgs[ix]
    ckpt = bundle / ck_name
    args_json = bundle / "args.json"
    mean, std = ec.load_label_stats(data_dir)

    tail = None
    if ld == 6:
        if lab_t.shape[0] != 6:
            raise SystemExit("--label-dim 6 requires labels with 6 columns in data-dir")
        tail = _tail_vec(args.six_tail_mode, lab_t, Path(data_dir))

    model, cfg = dps.load_model(args_json, ckpt, device)
    normalize = bool(cfg.get("normalize_labels", True))
    H = int(obs.shape[-2])
    W = int(obs.shape[-1])
    gsz = args.grid
    full, om_ax, s8_ax = dps.build_full_grid_2d(labs, gsz, tail=tail, lab_dim=ld)
    Wmap, OM, S8 = dps.posterior_weights(
        obs,
        full,
        om_ax,
        s8_ax,
        mean,
        std,
        normalize,
        model,
        H=H,
        W=W,
        device=device,
        grid=gsz,
        batch_sz=args.batch_size,
        ddim_steps=args.ddim_steps,
    )
    samp_om, samp_s8 = _resample_pairs(Wmap, OM, S8, args.n_hist, rng)

    out = Path(args.output).resolve()
    out.parent.mkdir(parents=True, exist_ok=True)
    tag = f"ddpm{ld}_{args.six_tail_mode}" if ld == 6 else "ddpm2"
    np.savez_compressed(
        out,
        omega_m=samp_om,
        sigma_8=samp_s8,
        samples=np.column_stack([samp_om, samp_s8]),
        truth_Omega_m=float(lab_t[0]),
        truth_sigma_8=float(lab_t[1]),
        posterior_map=Wmap,
        OM=OM,
        S8=S8,
        index=np.array(ix, dtype=np.int32),
        label_dim=np.array(ld, dtype=np.int16),
        meta_tag=np.array(tag, dtype="U128"),
        six_tail_mode=np.array(args.six_tail_mode if ld == 6 else "", dtype="U16"),
    )
    print("Saved", out, "pairs:", args.n_hist, "device:", device)


if __name__ == "__main__":
    main()