#!/usr/bin/env python3
"""
Surrogate posterior on $(\\Omega_m, \\sigma_8)$ → triangle/MCMC-style chains for one test map.

Loads the same surrogate likelihood used in ``ddpm_posterior_six_anchors``, resamples
discrete posterior masses to ``--n-hist`` correlated $(\\Omega_m,\\sigma_8)$ pairs, and
writes ``.npz``.

DDPM-2: sweeps $(\\Omega_m,\\sigma_8)$. DDPM-6: dims 2–5 fixed per ``--six-tail-mode``
(``truth`` uses the test-map labels 2–5; ``min``/``max`` use LHS extrema from training labels).

If you replace this file with a copy from your machine (Downloads), keep argparse
compatible or wrap it.
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path

import numpy as np
import torch

# The sibling script directory must be importable before ``dps`` is imported.
_SCRIPTS = Path(__file__).resolve().parent
if str(_SCRIPTS) not in sys.path:
    sys.path.insert(0, str(_SCRIPTS))
import ddpm_posterior_six_anchors as dps  # noqa: E402

# Same for the directory that provides ``evaluate_conditional``.
MODELS_ROOT = Path(__file__).resolve().parents[1]
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
if str(CODE_6.resolve()) not in sys.path:
    sys.path.insert(0, str(CODE_6.resolve()))
import evaluate_conditional as ec  # noqa: E402


def _tail_vec(
    mode: str,
    lab_full: np.ndarray,
    data6: Path,
) -> np.ndarray | None:
    """Return the fixed values for label dims 2–5, or ``None`` when there are none.

    Args:
        mode: ``"truth"``, ``"min"`` or ``"max"`` (the ``--six-tail-mode`` choice).
        lab_full: Label vector of the selected map.
        data6: Data directory handed to ``dps.tail_lhs_bounds`` for ``min``/``max``.

    Returns:
        ``None`` if ``lab_full`` has at most two entries; ``lab_full[2:6]`` cast to
        float32 for ``"truth"``; otherwise the first (``"min"``) or second (``"max"``)
        value returned by ``dps.tail_lhs_bounds``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if lab_full.size <= 2:
        return None
    if mode == "truth":
        return lab_full[2:6].astype(np.float32)
    low, hi = dps.tail_lhs_bounds(data6)
    if mode == "min":
        return low
    if mode == "max":
        return hi
    raise ValueError("six-tail-mode must be truth|min|max")


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    p = argparse.ArgumentParser(
        description="DDPM surrogate posterior → resampled Ωm σ8 chains (.npz)."
    )
    p.add_argument(
        "--label-dim",
        type=int,
        choices=[2, 6],
        required=True,
        help="Which model to use.",
    )
    p.add_argument(
        "--bundle",
        type=Path,
        default=None,
        help=(
            "Checkpoint bundle dir with args.json "
            "(default: notebook_model_weights/<2param_epoch200|6param_best>)."
        ),
    )
    p.add_argument(
        "--checkpoint-name",
        type=str,
        default=None,
        help="Checkpoint file under bundle (defaults: DDPM2 epoch200, DDPM6 best_model).",
    )
    p.add_argument(
        "--data-dir",
        type=Path,
        default=None,
        help="LH data dir matching label_dim (default: params_2 vs params_6).",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument(
        "--test-index",
        type=int,
        default=56,
        help="Index into split for CAMELS observation.",
    )
    p.add_argument("--grid", type=int, default=14)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument(
        "--n-hist",
        type=int,
        default=10_000,
        help="Resampled posterior pairs (with replacement).",
    )
    p.add_argument(
        "--six-tail-mode",
        type=str,
        default="truth",
        choices=["truth", "min", "max"],
        help="Applies only to label_dim==6 — how dims 2–5 are fixed.",
    )
    p.add_argument(
        "--output",
        "-o",
        type=Path,
        required=True,
        help="Output .npz path.",
    )
    p.add_argument("--seed", type=int, default=42)
    return p


def _posterior_probabilities(weight_map: np.ndarray) -> np.ndarray:
    """Flatten posterior masses into a float64 probability vector summing to one.

    Negative masses are clipped to zero before normalising.

    Raises:
        RuntimeError: If any mass is NaN/inf, or if the clipped masses sum to zero.
    """
    masses = weight_map.ravel().astype(np.float64)
    # NaN/inf would pass the clip and the "<= 0" test below and only fail later
    # inside rng.choice with an unrelated-looking message, so reject them here.
    if not np.all(np.isfinite(masses)):
        raise RuntimeError("Posterior masses contain NaN or infinite values.")
    masses = np.clip(masses, 0.0, None)
    total = masses.sum()
    if total <= 0:
        raise RuntimeError("Posterior masses collapsed to zero.")
    return masses / total


def main() -> None:
    """Evaluate the surrogate posterior for one map and save resampled pairs.

    Parses the command line, loads the selected split and model, evaluates the
    posterior weights on a ``--grid`` sized (Ωm, σ8) grid via ``dps``, draws
    ``--n-hist`` grid points with replacement according to those weights, and
    writes the samples plus metadata to a compressed ``.npz`` file.

    Raises:
        SystemExit: If ``--test-index`` is out of range, or ``--label-dim 6`` is
            requested but the labels do not have 6 columns.
        RuntimeError: If the posterior masses are non-finite or sum to zero.
    """
    args = _build_parser().parse_args()

    ld = args.label_dim
    if ld == 2:
        data_dir = args.data_dir or Path("/data/LH_data/params_2")
        bundle = args.bundle or MODELS_ROOT / "notebook_model_weights" / "2param_epoch200"
        ck_name = args.checkpoint_name or "checkpoint_epoch_200.pt"
    else:
        data_dir = args.data_dir or Path("/data/LH_data/params_6")
        bundle = args.bundle or MODELS_ROOT / "notebook_model_weights" / "6param_best"
        ck_name = args.checkpoint_name or "best_model.pt"

    # NOTE(review): only the NumPy resampling RNG is seeded here; whether the
    # sampling inside dps.posterior_weights is deterministic is not visible
    # from this file — confirm if full reproducibility is expected from --seed.
    rng = np.random.default_rng(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    imgs, labs = ec.load_split(data_dir, args.split)
    ix = int(args.test_index)
    if not (0 <= ix < len(labs)):
        raise SystemExit(f"test-index {ix} out of range for split ({len(labs)} rows)")

    lab_t = labs[ix].astype(np.float64)
    obs = imgs[ix]

    ckpt = bundle / ck_name
    args_json = bundle / "args.json"
    mean, std = ec.load_label_stats(data_dir)

    tail = None
    if ld == 6:
        if lab_t.shape[0] != 6:
            raise SystemExit("--label-dim 6 requires labels with 6 columns in data-dir")
        tail = _tail_vec(args.six_tail_mode, lab_t, Path(data_dir))

    model, cfg = dps.load_model(args_json, ckpt, device)
    normalize = bool(cfg.get("normalize_labels", True))

    H = int(obs.shape[-2])
    W = int(obs.shape[-1])
    gsz = args.grid

    full, om_ax, s8_ax = dps.build_full_grid_2d(labs, gsz, tail=tail, lab_dim=ld)
    Wmap, OM, S8 = dps.posterior_weights(
        obs,
        full,
        om_ax,
        s8_ax,
        mean,
        std,
        normalize,
        model,
        H=H,
        W=W,
        device=device,
        grid=gsz,
        batch_sz=args.batch_size,
        ddim_steps=args.ddim_steps,
    )

    probs = _posterior_probabilities(Wmap)
    # Draw flat grid indices with replacement, weighted by posterior mass.
    draws = rng.choice(probs.size, size=args.n_hist, replace=True, p=probs)
    samp_om = OM.ravel()[draws].astype(np.float64)
    samp_s8 = S8.ravel()[draws].astype(np.float64)

    out = Path(args.output).resolve()
    # np.savez_compressed appends ".npz" when the name lacks it; apply the same
    # rule up front so the path printed below is the file actually written.
    if not out.name.endswith(".npz"):
        out = out.with_name(out.name + ".npz")
    out.parent.mkdir(parents=True, exist_ok=True)

    tag = f"ddpm{ld}_{args.six_tail_mode}" if ld == 6 else "ddpm2"
    np.savez_compressed(
        out,
        omega_m=samp_om,
        sigma_8=samp_s8,
        samples=np.column_stack([samp_om, samp_s8]),
        truth_Omega_m=float(lab_t[0]),
        truth_sigma_8=float(lab_t[1]),
        posterior_map=Wmap,
        OM=OM,
        S8=S8,
        index=np.array(ix, dtype=np.int32),
        label_dim=np.array(ld, dtype=np.int16),
        meta_tag=np.array(tag, dtype="U128"),
        six_tail_mode=np.array(args.six_tail_mode if ld == 6 else "", dtype="U16"),
    )
    print("Saved", out, "pairs:", args.n_hist, "device:", device)


if __name__ == "__main__":
    main()