#!/usr/bin/env python3
# Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
# 1f3e7a2 verified
"""
Surrogate posterior on $(\\Omega_m, \\sigma_8)$ → triangle/MCMC-style chains for one test map.

Loads the same surrogate likelihood used in ``ddpm_posterior_six_anchors``, resamples the discrete
posterior masses into ``--n-hist`` correlated $(\\Omega_m,\\sigma_8)$ pairs, and writes them to an ``.npz`` file.

DDPM-2: sweeps $(\\Omega_m,\\sigma_8)$.
DDPM-6: dims 2–5 are fixed per ``--six-tail-mode`` (``truth`` uses the test-map labels 2–5; ``min``/``max``
use the LHS extrema from the training labels).

If you replace this file with a copy from your machine (Downloads), keep the argparse interface compatible or wrap it.
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path

import numpy as np
import torch

# Directory holding this script. It is put at the front of sys.path so that
# ``ddpm_posterior_six_anchors`` (presumably a sibling script — confirm) can be
# imported when this file is run directly.
_SCRIPTS = Path(__file__).resolve().parent
if str(_SCRIPTS) not in sys.path:
    sys.path.insert(0, str(_SCRIPTS))

import ddpm_posterior_six_anchors as dps  # noqa: E402

# Parent of the scripts directory; the default checkpoint bundles used by
# ``main`` are resolved relative to it.
MODELS_ROOT = Path(__file__).resolve().parents[1]
# Added to sys.path before importing ``evaluate_conditional`` (presumably the
# module lives in this directory — confirm).
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
if str(CODE_6.resolve()) not in sys.path:
    sys.path.insert(0, str(CODE_6.resolve()))

import evaluate_conditional as ec  # noqa: E402
def _tail_vec(
    mode: str,
    lab_full: np.ndarray,
    data6: Path,
) -> np.ndarray | None:
    """Return the values at which label dims 2-5 are held fixed.

    Args:
        mode: One of ``"truth"``, ``"min"`` or ``"max"``.
        lab_full: Label vector of the selected map.
        data6: Data directory forwarded to ``dps.tail_lhs_bounds``.

    Returns:
        ``None`` when ``lab_full`` has at most two entries. Otherwise, for
        ``"truth"``, entries 2-5 of ``lab_full`` cast to float32; for ``"min"`` /
        ``"max"``, the first / second value returned by
        ``dps.tail_lhs_bounds(data6)``.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    if lab_full.size <= 2:
        return None
    if mode not in ("truth", "min", "max"):
        # Reject an unsupported mode before dps.tail_lhs_bounds is invoked, so
        # the caller sees this error rather than one raised while handling data6.
        raise ValueError("six-tail-mode must be truth|min|max")
    if mode == "truth":
        return lab_full[2:6].astype(np.float32)
    low, hi = dps.tail_lhs_bounds(data6)
    return low if mode == "min" else hi
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    p = argparse.ArgumentParser(description="DDPM surrogate posterior → resampled Ωm σ8 chains (.npz).")
    p.add_argument(
        "--label-dim",
        type=int,
        choices=[2, 6],
        required=True,
        help="Which model to use.",
    )
    p.add_argument(
        "--bundle",
        type=Path,
        default=None,
        # The help text names the directories that main() really falls back to.
        help="Checkpoint bundle dir with args.json "
        "(default: notebook_model_weights/<2param_epoch200|6param_best>).",
    )
    p.add_argument(
        "--checkpoint-name",
        type=str,
        default=None,
        help="Checkpoint file under bundle (defaults: DDPM2 epoch200, DDPM6 best_model).",
    )
    p.add_argument(
        "--data-dir",
        type=Path,
        default=None,
        help="LH data dir matching label_dim (default: params_2 vs params_6).",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--test-index", type=int, default=56, help="Index into split for CAMELS observation.")
    p.add_argument("--grid", type=int, default=14)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument(
        "--n-hist",
        type=int,
        default=10_000,
        help="Resampled posterior pairs (with replacement).",
    )
    p.add_argument(
        "--six-tail-mode",
        type=str,
        default="truth",
        choices=["truth", "min", "max"],
        help="Applies only to label_dim==6 — how dims 2–5 are fixed.",
    )
    p.add_argument(
        "--output",
        "-o",
        type=Path,
        required=True,
        help="Output .npz path.",
    )
    p.add_argument("--seed", type=int, default=42)
    return p


def _npz_path(path: Path) -> Path:
    """Return the file name ``np.savez_compressed`` writes when given ``path``.

    NumPy appends ``.npz`` to a file name that does not already end with it.
    """
    text = str(path)
    return path if text.endswith(".npz") else Path(text + ".npz")


def main() -> None:
    """Resample a gridded surrogate posterior into (Ωm, σ8) pairs and save them.

    Steps:
      1. Parse the CLI and pick per-``--label-dim`` defaults for the data
         directory, checkpoint bundle and checkpoint file name.
      2. Load the requested split and select the map / label row at
         ``--test-index``.
      3. For ``--label-dim 6``, obtain the fixed values for dims 2-5 from
         ``_tail_vec``.
      4. Load the model, build the 2-D parameter grid and evaluate the
         posterior weights through ``ddpm_posterior_six_anchors``.
      5. Normalise the non-negative weights, draw ``--n-hist`` grid cells with
         replacement and write samples plus metadata to a compressed ``.npz``.

    Raises:
        SystemExit: If ``--test-index`` is out of range, or ``--label-dim 6`` is
            used with labels that do not have 6 entries.
        RuntimeError: If the posterior weights are non-finite or sum to zero.
    """
    args = _build_parser().parse_args()

    ld = args.label_dim
    if ld == 2:
        data_dir = args.data_dir or Path("<DDPM_ROOT>/data/LH_data/params_2")
        bundle = args.bundle or (MODELS_ROOT / "notebook_model_weights" / "2param_epoch200")
        ck_name = args.checkpoint_name or "checkpoint_epoch_200.pt"
    else:
        data_dir = args.data_dir or Path("<DDPM_ROOT>/data/LH_data/params_6")
        bundle = args.bundle or (MODELS_ROOT / "notebook_model_weights" / "6param_best")
        ck_name = args.checkpoint_name or "best_model.pt"

    # NOTE(review): --seed only seeds this NumPy generator, which is used for the
    # resampling below. Whether the sampling inside dps.posterior_weights is
    # reproducible cannot be determined from this file — confirm.
    rng = np.random.default_rng(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    imgs, labs = ec.load_split(data_dir, args.split)
    ix = int(args.test_index)
    if not (0 <= ix < len(labs)):
        raise SystemExit(f"test-index {ix} out of range for split ({len(labs)} rows)")
    lab_t = labs[ix].astype(np.float64)
    obs = imgs[ix]

    ckpt = bundle / ck_name
    args_json = bundle / "args.json"
    mean, std = ec.load_label_stats(data_dir)

    tail = None
    if ld == 6:
        if lab_t.shape[0] != 6:
            raise SystemExit("--label-dim 6 requires labels with 6 columns in data-dir")
        tail = _tail_vec(args.six_tail_mode, lab_t, Path(data_dir))

    model, cfg = dps.load_model(args_json, ckpt, device)
    normalize = bool(cfg.get("normalize_labels", True))
    # The last two axes of the observation are passed on as height and width.
    H = int(obs.shape[-2])
    W = int(obs.shape[-1])
    gsz = args.grid

    full, om_ax, s8_ax = dps.build_full_grid_2d(labs, gsz, tail=tail, lab_dim=ld)
    Wmap, OM, S8 = dps.posterior_weights(
        obs,
        full,
        om_ax,
        s8_ax,
        mean,
        std,
        normalize,
        model,
        H=H,
        W=W,
        device=device,
        grid=gsz,
        batch_sz=args.batch_size,
        ddim_steps=args.ddim_steps,
    )

    # Negative weights are clipped to zero before normalising to probabilities.
    wflat = np.clip(Wmap.ravel().astype(np.float64), 0.0, None)
    if not np.all(np.isfinite(wflat)):
        # NaN/inf would pass the sum check below and only fail inside
        # rng.choice with a less specific message.
        raise RuntimeError("Posterior masses contain non-finite values.")
    if wflat.sum() <= 0:
        raise RuntimeError("Posterior masses collapsed to zero.")
    wflat /= wflat.sum()

    omapflat = OM.ravel()
    s8flat = S8.ravel()
    # Passing an int population is equivalent to passing np.arange(n).
    draws = rng.choice(len(wflat), size=args.n_hist, replace=True, p=wflat)
    samp_om = omapflat[draws].astype(np.float64)
    samp_s8 = s8flat[draws].astype(np.float64)

    out = Path(args.output).resolve()
    out.parent.mkdir(parents=True, exist_ok=True)
    tag = f"ddpm{ld}_{args.six_tail_mode}" if ld == 6 else "ddpm2"
    np.savez_compressed(
        out,
        omega_m=samp_om,
        sigma_8=samp_s8,
        samples=np.column_stack([samp_om, samp_s8]),
        truth_Omega_m=float(lab_t[0]),
        truth_sigma_8=float(lab_t[1]),
        posterior_map=Wmap,
        OM=OM,
        S8=S8,
        index=np.array(ix, dtype=np.int32),
        label_dim=np.array(ld, dtype=np.int16),
        meta_tag=np.array(tag, dtype="U128"),
        six_tail_mode=np.array(args.six_tail_mode if ld == 6 else "", dtype="U16"),
    )
    # Report the name NumPy actually wrote (it appends ".npz" when missing).
    print("Saved", _npz_path(out), "pairs:", args.n_hist, "device:", device)


if __name__ == "__main__":
    main()