DDPM-2param / cross_model /scripts /run_ddpm_figure6_suite.py
collins909's picture
Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
f513198 verified
#!/usr/bin/env python3
"""
Compute surrogate posteriors and emit Figure-6 style figures (arXiv:2409.09101-inspired).
"""
from __future__ import annotations
import argparse
import gc
import sys
from pathlib import Path
import numpy as np
import torch
# Directory containing this script, and the models root one level above it.
_SCRIPTS = Path(__file__).resolve().parent
MODELS_ROOT = _SCRIPTS.parents[0]
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"

# Prepend both source trees to the import path (skipping entries already
# present) so the project modules imported below can be resolved. CODE_6 is
# inserted last, so it ends up ahead of _SCRIPTS in sys.path.
for _path_entry in (str(_SCRIPTS), str(CODE_6.resolve())):
    if _path_entry not in sys.path:
        sys.path.insert(0, _path_entry)
del _path_entry
import evaluate_conditional as ec # noqa: E402
import ddpm_posterior_six_anchors as dps # noqa: E402
from ddpm_figure6_integration import ( # noqa: E402
integrate_figure6_model_comparison,
integrate_figure6_with_ddpm2,
integrate_figure6_with_multi_anchor,
print_integration_guide,
)
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Figure-6 posterior suite."""
    p = argparse.ArgumentParser(description="DDPM Figure-6 style posterior suite.")
    p.add_argument("--output-dir", type=Path, default=MODELS_ROOT / "ddpm_figure6_out")
    p.add_argument("--data-2param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--test-index", type=int, default=56, help="Index for single comparison + per-map fig6.")
    p.add_argument("--grid", type=int, default=14)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument(
        "--six-anchors-only",
        action="store_true",
        help="Only 2×3 multi-anchor plots (skip triple model comparison at --test-index).",
    )
    p.add_argument(
        "--no-six-grid",
        action="store_true",
        help="Skip multi-anchor 2×3 panels.",
    )
    p.add_argument(
        "--no-single-fig6",
        action="store_true",
        help="Skip per-map marginal/profile for test-index on DDPM-2 and DDPM-6 (truth tail).",
    )
    p.add_argument(
        "--guide",
        action="store_true",
        help="Print markdown-style integration notes and exit.",
    )
    return p


def _posterior_map(
    obs,
    grid_spec: tuple,
    mean,
    std,
    normalize: bool,
    model,
    *,
    args: argparse.Namespace,
    device: torch.device,
) -> np.ndarray:
    """Evaluate the posterior weight map of one observed map over one parameter grid.

    ``grid_spec`` is the ``(full_grid, om_axis, s8_axis)`` triple returned by
    ``dps.build_full_grid_2d``. ``mean`` and ``std`` come from
    ``ec.load_label_stats``, and ``normalize`` comes from the checkpoint config's
    ``normalize_labels`` flag. Grid resolution, batch size and DDIM step count
    are taken from the parsed CLI ``args``.

    Returns only the first output of ``dps.posterior_weights``, which is the
    weight array that is plotted.
    """
    full_grid, om_axis, s8_axis = grid_spec
    weights, _, _ = dps.posterior_weights(
        obs,
        full_grid,
        om_axis,
        s8_axis,
        mean,
        std,
        normalize,
        model,
        H=int(obs.shape[-2]),
        W=int(obs.shape[-1]),
        device=device,
        grid=args.grid,
        batch_sz=args.batch_size,
        ddim_steps=args.ddim_steps,
    )
    return weights


def _axes_agree(reference: tuple, others: tuple, atol: float = 1e-12) -> bool:
    """Return True if every grid in *others* has the same Ωm and σ8 axes as *reference*.

    Each grid is a ``(full_grid, om_axis, s8_axis)`` triple as returned by
    ``dps.build_full_grid_2d``. Shapes are compared explicitly because
    ``np.allclose`` broadcasts, which would let e.g. a length-1 axis "match"
    any axis filled with the same value.
    """
    _, om_ref, s8_ref = reference

    def _same(a, b) -> bool:
        return np.shape(a) == np.shape(b) and np.allclose(a, b, rtol=0, atol=atol)

    return all(_same(om_ref, om) and _same(s8_ref, s8) for _, om, s8 in others)


def _release_device_memory() -> None:
    """Run the garbage collector and, when CUDA is available, release cached GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def main() -> None:
    """Compute DDPM-2 / DDPM-6 posteriors and write Figure-6 style plots to ``--output-dir``.

    Both checkpoints are loaded once and shared by two optional stages:

    * Model comparison at ``--test-index`` (skipped by ``--six-anchors-only``).
      It covers DDPM-2 and DDPM-6, with DDPM-6 evaluated at the true tail
      parameters and at the lower/upper bounds from ``dps.tail_lhs_bounds``.
      Per-map figures follow unless ``--no-single-fig6`` is given.
    * Six-anchor 2×3 grids for DDPM-2 and DDPM-6 with the true tail
      (skipped by ``--no-six-grid``).

    Raises:
        SystemExit: if the split is empty or ``--test-index`` is out of range.
    """
    args = _build_parser().parse_args()
    if args.guide:
        print_integration_guide()
        return

    out = Path(args.output_dir).resolve()
    out.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)
    imgs2, lab2 = ec.load_split(data2, args.split)
    imgs6, lab6 = ec.load_split(data6, args.split)
    # Maps from the two datasets are paired by index, so only the common
    # prefix is usable. A size mismatch suggests they may not be paired at all.
    n = min(len(lab2), len(lab6))
    if len(lab2) != len(lab6):
        print(
            f"Warning: split sizes differ (2-param: {len(lab2)}, 6-param: {len(lab6)}); "
            f"using the first {n} maps."
        )
    if n == 0:
        raise SystemExit(f"split {args.split!r} contains no maps")
    ix = int(args.test_index)
    if not (0 <= ix < n):
        raise SystemExit(f"--test-index {ix} invalid (max {n - 1})")
    # Six evenly spaced anchor indices over the usable range (repeats when n < 6).
    anchor_ix = [int(j) for j in np.linspace(0, n - 1, num=6, dtype=int)]

    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    aj2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    aj6 = args.bundle_6param / "args.json"
    mean2, std2 = ec.load_label_stats(data2)
    mean6, std6 = ec.load_label_stats(data6)

    run_comparison = not args.six_anchors_only
    run_six_grid = not args.no_six_grid
    if run_comparison or run_six_grid:
        print(">>> Loading models...")
        m2, c2 = dps.load_model(aj2, ck2, device)
        m6, c6 = dps.load_model(aj6, ck6, device)
        nz2 = bool(c2.get("normalize_labels", True))
        nz6 = bool(c6.get("normalize_labels", True))
        # The DDPM-2 grid takes no tail, so a single grid serves every map.
        # Its box is built from the 6-param labels' first two columns,
        # presumably so it matches the Ωm/σ8 box of the DDPM-6 grids.
        grid2 = dps.build_full_grid_2d(lab6[:, :2].copy(), args.grid, tail=None, lab_dim=2)
        _, om_ax, s8_ax = grid2

        if run_comparison:
            print(">>> Posterior comparison for ix=", ix, "...")
            obs2 = imgs2[ix]
            obs6 = imgs6[ix]
            lt2 = lab2[ix].astype(np.float64)
            lt6 = lab6[ix].astype(np.float64)
            ta2om, ta2s8 = float(lt2[0]), float(lt2[1])
            tom, ts8 = float(lt6[0]), float(lt6[1])
            low_tail, hi_tail = dps.tail_lhs_bounds(data6)

            Wm2 = _posterior_map(obs2, grid2, mean2, std2, nz2, m2, args=args, device=device)
            grid6_truth = dps.build_full_grid_2d(
                lab6, args.grid, tail=lab6[ix, 2:6].astype(np.float32), lab_dim=6
            )
            Wm6t = _posterior_map(obs6, grid6_truth, mean6, std6, nz6, m6, args=args, device=device)
            grid6_lo = dps.build_full_grid_2d(lab6, args.grid, tail=low_tail, lab_dim=6)
            Wm6lo = _posterior_map(obs6, grid6_lo, mean6, std6, nz6, m6, args=args, device=device)
            grid6_hi = dps.build_full_grid_2d(lab6, args.grid, tail=hi_tail, lab_dim=6)
            Wm6hi = _posterior_map(obs6, grid6_hi, mean6, std6, nz6, m6, args=args, device=device)

            # The comparison figure draws all four maps on the DDPM-2 axes, so
            # every DDPM-6 grid's Ωm and σ8 axes are checked against them.
            if not _axes_agree(grid2, (grid6_truth, grid6_lo, grid6_hi)):
                print("Warning: Ωm–σ8 grids differ between setups; plotting uses DDPM-2 Ωm/σ8 axes.")
            integrate_figure6_model_comparison(
                {
                    "DDPM-2": Wm2,
                    "DDPM-6 (truth-tail)": Wm6t,
                    "DDPM-6 (min-tail)": Wm6lo,
                    "DDPM-6 (max-tail)": Wm6hi,
                },
                om_ax,
                s8_ax,
                tom,
                ts8,
                ix,
                out,
            )
            if not args.no_single_fig6:
                integrate_figure6_with_ddpm2(Wm2, om_ax, s8_ax, ta2om, ta2s8, ix, out, model_name="DDPM-2")
                # Plot the DDPM-6 posterior on the axes of the grid it was evaluated on.
                _, om6, s86 = grid6_truth
                integrate_figure6_with_ddpm2(Wm6t, om6, s86, tom, ts8, ix, out, model_name="DDPM-6-truth")

        if run_six_grid:
            print(">>> Six-anchor Figure 6 grids...")
            post2: list[np.ndarray] = []
            post6: list[np.ndarray] = []
            for jx in anchor_ix:
                post2.append(_posterior_map(imgs2[jx], grid2, mean2, std2, nz2, m2, args=args, device=device))
                tail_truth = lab6[jx].astype(np.float32)[2:6]
                grid6 = dps.build_full_grid_2d(lab6, args.grid, tail=tail_truth, lab_dim=6)
                post6.append(_posterior_map(imgs6[jx], grid6, mean6, std6, nz6, m6, args=args, device=device))
            truths2 = [(float(lab2[j][0]), float(lab2[j][1])) for j in anchor_ix]
            truths6 = [(float(lab6[j][0]), float(lab6[j][1])) for j in anchor_ix]
            integrate_figure6_with_multi_anchor(
                post2,
                om_ax,
                s8_ax,
                truths2,
                anchor_ix,
                out,
                model_name="DDPM-2",
            )
            # DDPM-6 panels use the axes of the last anchor's grid; anchors
            # differ only in the tail values passed to build_full_grid_2d.
            _, oa6, sa6 = grid6
            integrate_figure6_with_multi_anchor(
                post6,
                oa6,
                sa6,
                truths6,
                anchor_ix,
                out,
                model_name="DDPM-6-truth-tail",
            )

        del m2, m6
        _release_device_memory()

    print(f"Done. Outputs in {out}")
# Script entry point: run the suite; main() returns None, so the exit status is 0.
if __name__ == "__main__":
    sys.exit(main())