#!/usr/bin/env python3 """ Compute surrogate posteriors and emit Figure-6 style figures (arXiv:2409.09101-inspired). """ from __future__ import annotations import argparse import gc import sys from pathlib import Path import numpy as np import torch _SCRIPTS = Path(__file__).resolve().parent MODELS_ROOT = Path(__file__).resolve().parents[1] CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6" if str(_SCRIPTS) not in sys.path: sys.path.insert(0, str(_SCRIPTS)) if str(CODE_6.resolve()) not in sys.path: sys.path.insert(0, str(CODE_6.resolve())) import evaluate_conditional as ec # noqa: E402 import ddpm_posterior_six_anchors as dps # noqa: E402 from ddpm_figure6_integration import ( # noqa: E402 integrate_figure6_model_comparison, integrate_figure6_with_ddpm2, integrate_figure6_with_multi_anchor, print_integration_guide, ) def main() -> None: p = argparse.ArgumentParser(description="DDPM Figure-6 style posterior suite.") p.add_argument("--output-dir", type=Path, default=MODELS_ROOT / "ddpm_figure6_out") p.add_argument("--data-2param", type=Path, default=Path("/data/LH_data/params_2")) p.add_argument("--data-6param", type=Path, default=Path("/data/LH_data/params_6")) p.add_argument( "--bundle-2param", type=Path, default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200", ) p.add_argument( "--bundle-6param", type=Path, default=MODELS_ROOT / "notebook_model_weights" / "6param_best", ) p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"]) p.add_argument("--test-index", type=int, default=56, help="Index for single comparison + per-map fig6.") p.add_argument("--grid", type=int, default=14) p.add_argument("--ddim-steps", type=int, default=50) p.add_argument("--batch-size", type=int, default=8) p.add_argument( "--six-anchors-only", action="store_true", help="Only 2×3 multi-anchor plots (skip triple model comparison at --test-index).", ) p.add_argument( "--no-six-grid", action="store_true", help="Skip multi-anchor 2×3 panels.", ) p.add_argument( "--no-single-fig6", action="store_true", help="Skip per-map marginal/profile for test-index on DDPM-2 and DDPM-6 (truth tail).", ) p.add_argument( "--guide", action="store_true", help="Print markdown-style integration notes and exit.", ) args = p.parse_args() if args.guide: print_integration_guide() return out = Path(args.output_dir).resolve() out.mkdir(parents=True, exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("device:", device) data2 = Path(args.data_2param) data6 = Path(args.data_6param) imgs2, lab2 = ec.load_split(data2, args.split) imgs6, lab6 = ec.load_split(data6, args.split) n = min(len(lab2), len(lab6)) anchor_ix = np.linspace(0, n - 1, num=6, dtype=int) low_tail, hi_tail = dps.tail_lhs_bounds(data6) ck2 = args.bundle_2param / "checkpoint_epoch_200.pt" aj2 = args.bundle_2param / "args.json" ck6 = args.bundle_6param / "best_model.pt" aj6 = args.bundle_6param / "args.json" mean2, std2 = ec.load_label_stats(data2) mean6, std6 = ec.load_label_stats(data6) ix = int(args.test_index) if not (0 <= ix < n): raise SystemExit(f"--test-index {ix} invalid (max {n - 1})") lab_box = lab6[:, :2].copy() if not args.six_anchors_only: print(">>> Loading models for ix=", ix, "...") m2, c2 = dps.load_model(aj2, ck2, device) m6, c6 = dps.load_model(aj6, ck6, device) normalize2 = bool(c2.get("normalize_labels", True)) normalize6 = bool(c6.get("normalize_labels", True)) obs2 = imgs2[ix] obs6 = imgs6[ix] lt2 = lab2[ix].astype(np.float64) lt6 = lab6[ix].astype(np.float64) ta2om, ta2s8 = float(lt2[0]), float(lt2[1]) tom, ts8 = float(lt6[0]), float(lt6[1]) full2, om_ax, s8_ax = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2) Wm2, _, _ = dps.posterior_weights( obs2, full2, om_ax, s8_ax, mean2, std2, normalize2, m2, H=int(obs2.shape[-2]), W=int(obs2.shape[-1]), device=device, grid=args.grid, batch_sz=args.batch_size, ddim_steps=args.ddim_steps, ) full6truth, om6, s86 = dps.build_full_grid_2d( lab6, args.grid, tail=lab6[ix, 2:6].astype(np.float32), lab_dim=6 ) Wm6t, _, _ = dps.posterior_weights( obs6, full6truth, om6, s86, mean6, std6, normalize6, m6, H=int(obs6.shape[-2]), W=int(obs6.shape[-1]), device=device, grid=args.grid, batch_sz=args.batch_size, ddim_steps=args.ddim_steps, ) full6lo, om_b, s8_b = dps.build_full_grid_2d(lab6, args.grid, tail=low_tail, lab_dim=6) Wm6lo, _, _ = dps.posterior_weights( obs6, full6lo, om_b, s8_b, mean6, std6, normalize6, m6, H=int(obs6.shape[-2]), W=int(obs6.shape[-1]), device=device, grid=args.grid, batch_sz=args.batch_size, ddim_steps=args.ddim_steps, ) full6hi, om_c, s8_c = dps.build_full_grid_2d(lab6, args.grid, tail=hi_tail, lab_dim=6) Wm6hi, _, _ = dps.posterior_weights( obs6, full6hi, om_c, s8_c, mean6, std6, normalize6, m6, H=int(obs6.shape[-2]), W=int(obs6.shape[-1]), device=device, grid=args.grid, batch_sz=args.batch_size, ddim_steps=args.ddim_steps, ) if not (np.allclose(om_ax, om_b, rtol=0, atol=1e-12) and np.allclose(s8_ax, s86)): print("Warning: Ωm–σ8 grids differ between setups; plotting uses DDPM-2 Ωm/σ8 axes.") integrate_figure6_model_comparison( { "DDPM-2": Wm2, "DDPM-6 (truth-tail)": Wm6t, "DDPM-6 (min-tail)": Wm6lo, "DDPM-6 (max-tail)": Wm6hi, }, om_ax, s8_ax, tom, ts8, ix, out, ) if not args.no_single_fig6: integrate_figure6_with_ddpm2(Wm2, om_ax, s8_ax, ta2om, ta2s8, ix, out, model_name="DDPM-2") integrate_figure6_with_ddpm2(Wm6t, om_ax, s8_ax, tom, ts8, ix, out, model_name="DDPM-6-truth") del m2, m6 gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # --- Six anchors: multi grids for DDPM-2 + DDPM-6 truth tail --- if not args.no_six_grid: print(">>> Six-anchor Figure 6 grids...") model2, cfg2 = dps.load_model(aj2, ck2, device) model6, cfg6 = dps.load_model(aj6, ck6, device) nz2 = bool(cfg2.get("normalize_labels", True)) nz6 = bool(cfg6.get("normalize_labels", True)) post2: list[np.ndarray] = [] post6: list[np.ndarray] = [] truths: list[tuple[float, float]] = [] indices: list[int] = [] for k, jx in enumerate(anchor_ix.ravel()): indices.append(int(jx)) o2 = imgs2[jx] lb2 = lab2[jx].astype(np.float64) f2, oa, sa = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2) W2, _, _ = dps.posterior_weights( o2, f2, oa, sa, mean2, std2, nz2, model2, H=int(o2.shape[-2]), W=int(o2.shape[-1]), device=device, grid=args.grid, batch_sz=args.batch_size, ddim_steps=args.ddim_steps, ) post2.append(W2) o6 = imgs6[jx] lb6 = lab6[jx] tail_truth = lb6.astype(np.float32)[2:6] f6, oa6, sa6 = dps.build_full_grid_2d(lab6, args.grid, tail=tail_truth, lab_dim=6) W6, _, _ = dps.posterior_weights( o6, f6, oa6, sa6, mean6, std6, nz6, model6, H=int(o6.shape[-2]), W=int(o6.shape[-1]), device=device, grid=args.grid, batch_sz=args.batch_size, ddim_steps=args.ddim_steps, ) post6.append(W6) truths.append((float(lb2[0]), float(lb2[1]))) integrate_figure6_with_multi_anchor( post2, oa, sa, truths, indices, out, model_name="DDPM-2", ) truths6 = [(float(lab6[int(j)][0]), float(lab6[int(j)][1])) for j in anchor_ix] integrate_figure6_with_multi_anchor( post6, oa6, sa6, truths6, indices, out, model_name="DDPM-6-truth-tail", ) del model2, model6 gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"Done. Outputs in {out}") if __name__ == "__main__": main()