Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
f513198 verified | #!/usr/bin/env python3 | |
| """ | |
| Compute surrogate posteriors and emit Figure-6 style figures (arXiv:2409.09101-inspired). | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gc | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
# Locate this script once; everything else is derived from that path.
_THIS_FILE = Path(__file__).resolve()
_SCRIPTS = _THIS_FILE.parent  # directory holding this script
MODELS_ROOT = _THIS_FILE.parents[1]  # one level above the scripts directory
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"

# Prepend the script directory first and the 6-parameter code directory second
# (so the latter ends up ahead on sys.path), skipping entries already present.
# The project-local imports that follow rely on these entries.
for _search_dir in (str(_SCRIPTS), str(CODE_6.resolve())):
    if _search_dir not in sys.path:
        sys.path.insert(0, _search_dir)
| import evaluate_conditional as ec # noqa: E402 | |
| import ddpm_posterior_six_anchors as dps # noqa: E402 | |
| from ddpm_figure6_integration import ( # noqa: E402 | |
| integrate_figure6_model_comparison, | |
| integrate_figure6_with_ddpm2, | |
| integrate_figure6_with_multi_anchor, | |
| print_integration_guide, | |
| ) | |
def _axes_match(reference, other, atol=1e-12):
    """Return True when two parameter axes have equal shape and agree within ``atol``."""
    ref = np.asarray(reference)
    oth = np.asarray(other)
    # Compare shapes first: np.allclose raises on non-broadcastable shapes,
    # and a shape mismatch should simply count as "different".
    return ref.shape == oth.shape and bool(np.allclose(ref, oth, rtol=0, atol=atol))


def _posterior_on_grid(
    obs,
    full_grid,
    om_axis,
    s8_axis,
    mean,
    std,
    normalize,
    model,
    *,
    device,
    grid,
    batch_size,
    ddim_steps,
):
    """Call ``dps.posterior_weights`` for one observed map and return its first output.

    ``H`` and ``W`` are read from the last two axes of ``obs``. The two trailing
    return values of ``dps.posterior_weights`` are discarded.
    """
    weights, _, _ = dps.posterior_weights(
        obs,
        full_grid,
        om_axis,
        s8_axis,
        mean,
        std,
        normalize,
        model,
        H=int(obs.shape[-2]),
        W=int(obs.shape[-1]),
        device=device,
        grid=grid,
        batch_sz=batch_size,
        ddim_steps=ddim_steps,
    )
    return weights


def main() -> None:
    """Command-line entry point for the Figure-6 style posterior suite.

    Two optional stages are run:

    * model comparison at ``--test-index`` (DDPM-2 against DDPM-6 with the
      truth, minimum and maximum tails), plus per-map figures unless
      ``--no-single-fig6`` is given; skipped with ``--six-anchors-only``;
    * multi-anchor panels for six evenly spaced sample indices, for DDPM-2 and
      for DDPM-6 with the truth tail; skipped with ``--no-six-grid``.

    Both models are loaded once and shared between the stages.

    Raises:
        SystemExit: if the selected split is empty, or if ``--test-index`` is
            out of range while the comparison stage is enabled.
    """
    p = argparse.ArgumentParser(description="DDPM Figure-6 style posterior suite.")
    p.add_argument("--output-dir", type=Path, default=MODELS_ROOT / "ddpm_figure6_out")
    p.add_argument("--data-2param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--test-index", type=int, default=56, help="Index for single comparison + per-map fig6.")
    p.add_argument("--grid", type=int, default=14)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument(
        "--six-anchors-only",
        action="store_true",
        help="Only 2×3 multi-anchor plots (skip triple model comparison at --test-index).",
    )
    p.add_argument(
        "--no-six-grid",
        action="store_true",
        help="Skip multi-anchor 2×3 panels.",
    )
    p.add_argument(
        "--no-single-fig6",
        action="store_true",
        help="Skip per-map marginal/profile for test-index on DDPM-2 and DDPM-6 (truth tail).",
    )
    p.add_argument(
        "--guide",
        action="store_true",
        help="Print markdown-style integration notes and exit.",
    )
    args = p.parse_args()
    if args.guide:
        print_integration_guide()
        return

    run_comparison = not args.six_anchors_only
    run_anchors = not args.no_six_grid

    out = Path(args.output_dir).resolve()
    out.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)
    imgs2, lab2 = ec.load_split(data2, args.split)
    imgs6, lab6 = ec.load_split(data6, args.split)

    # Only indices valid in both datasets can be used.
    n = min(len(lab2), len(lab6))
    if n == 0:
        raise SystemExit(f"split '{args.split}' contains no samples")
    anchor_ix = np.linspace(0, n - 1, num=6, dtype=int)
    low_tail, hi_tail = dps.tail_lhs_bounds(data6)

    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    aj2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    aj6 = args.bundle_6param / "args.json"
    mean2, std2 = ec.load_label_stats(data2)
    mean6, std6 = ec.load_label_stats(data6)

    ix = int(args.test_index)
    # --test-index is consumed only by the comparison stage, so validate it only there.
    if run_comparison and not (0 <= ix < n):
        raise SystemExit(f"--test-index {ix} invalid (max {n - 1})")

    # The 2-parameter grid is spanned by the first two label columns of the
    # 6-parameter set.
    lab_box = lab6[:, :2].copy()

    if not (run_comparison or run_anchors):
        print(f"Done. Outputs in {out}")
        return

    if run_comparison:
        print(">>> Loading models for ix=", ix, "...")
    # Load each model once and share it between both stages.
    model2, cfg2 = dps.load_model(aj2, ck2, device)
    model6, cfg6 = dps.load_model(aj6, ck6, device)
    normalize2 = bool(cfg2.get("normalize_labels", True))
    normalize6 = bool(cfg6.get("normalize_labels", True))

    # Sampler settings shared by every posterior evaluation.
    sampler_kw = {
        "device": device,
        "grid": args.grid,
        "batch_size": args.batch_size,
        "ddim_steps": args.ddim_steps,
    }

    try:
        if run_comparison:
            obs2 = imgs2[ix]
            obs6 = imgs6[ix]
            lt2 = lab2[ix].astype(np.float64)
            lt6 = lab6[ix].astype(np.float64)
            ta2om, ta2s8 = float(lt2[0]), float(lt2[1])
            tom, ts8 = float(lt6[0]), float(lt6[1])

            full2, om_ax, s8_ax = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2)
            posteriors = {
                "DDPM-2": _posterior_on_grid(
                    obs2, full2, om_ax, s8_ax, mean2, std2, normalize2, model2, **sampler_kw
                )
            }

            # DDPM-6 is evaluated three times; only the fixed tail (label
            # columns 2..5) differs. Dict order is also the plotting order.
            truth_key = "DDPM-6 (truth-tail)"
            tails = {
                truth_key: lab6[ix, 2:6].astype(np.float32),
                "DDPM-6 (min-tail)": low_tail,
                "DDPM-6 (max-tail)": hi_tail,
            }
            axes6 = {}
            for label, tail in tails.items():
                full6, om6, s86 = dps.build_full_grid_2d(lab6, args.grid, tail=tail, lab_dim=6)
                posteriors[label] = _posterior_on_grid(
                    obs6, full6, om6, s86, mean6, std6, normalize6, model6, **sampler_kw
                )
                axes6[label] = (om6, s86)

            # Check every DDPM-6 grid, on both axes, against the DDPM-2 axes
            # used by the comparison plot.
            grids_agree = all(
                _axes_match(om_ax, om6) and _axes_match(s8_ax, s86)
                for om6, s86 in axes6.values()
            )
            if not grids_agree:
                print("Warning: Ωm–σ8 grids differ between setups; plotting uses DDPM-2 Ωm/σ8 axes.")

            integrate_figure6_model_comparison(posteriors, om_ax, s8_ax, tom, ts8, ix, out)

            if not args.no_single_fig6:
                integrate_figure6_with_ddpm2(
                    posteriors["DDPM-2"], om_ax, s8_ax, ta2om, ta2s8, ix, out, model_name="DDPM-2"
                )
                # Plot the DDPM-6 posterior on the axes of its own grid, not
                # on the DDPM-2 axes.
                om6_truth, s86_truth = axes6[truth_key]
                integrate_figure6_with_ddpm2(
                    posteriors[truth_key],
                    om6_truth,
                    s86_truth,
                    tom,
                    ts8,
                    ix,
                    out,
                    model_name="DDPM-6-truth",
                )

        # --- Six anchors: multi grids for DDPM-2 + DDPM-6 truth tail ---
        if run_anchors:
            print(">>> Six-anchor Figure 6 grids...")
            indices = [int(j) for j in anchor_ix.ravel()]
            post2: list[np.ndarray] = []
            post6: list[np.ndarray] = []
            truths2: list[tuple[float, float]] = []
            truths6: list[tuple[float, float]] = []
            for jx in indices:
                o2 = imgs2[jx]
                lb2 = lab2[jx].astype(np.float64)
                # Rebuilt per anchor: cheap, and avoids assuming that
                # posterior_weights leaves the grid it receives untouched.
                f2, oa, sa = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2)
                post2.append(
                    _posterior_on_grid(o2, f2, oa, sa, mean2, std2, normalize2, model2, **sampler_kw)
                )

                o6 = imgs6[jx]
                lb6 = lab6[jx]
                tail_truth = lb6.astype(np.float32)[2:6]
                f6, oa6, sa6 = dps.build_full_grid_2d(lab6, args.grid, tail=tail_truth, lab_dim=6)
                post6.append(
                    _posterior_on_grid(o6, f6, oa6, sa6, mean6, std6, normalize6, model6, **sampler_kw)
                )

                truths2.append((float(lb2[0]), float(lb2[1])))
                truths6.append((float(lb6[0]), float(lb6[1])))

            # The axes passed below are those of the last anchor processed.
            integrate_figure6_with_multi_anchor(
                post2,
                oa,
                sa,
                truths2,
                indices,
                out,
                model_name="DDPM-2",
            )
            integrate_figure6_with_multi_anchor(
                post6,
                oa6,
                sa6,
                truths6,
                indices,
                out,
                model_name="DDPM-6-truth-tail",
            )
    finally:
        # Drop the models even if a stage fails, so a caller that keeps the
        # interpreter alive does not retain them.
        del model2, model6
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Done. Outputs in {out}")


if __name__ == "__main__":
    main()