DDPM-6param / cross_model /scripts /triangle_plot_posterior.py
collins909's picture
Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified
#!/usr/bin/env python3
"""
Corner-style triangle plot for surrogate $(\\Omega_m,\\sigma_8)$ chains from ``ddpm_triangle_integration.py``.
Loads one or two ``.npz`` files (keys ``omega_m``, ``sigma_8`` / ``samples``, ``truth_*``) and draws
1D marginals + 2D density. If you substitute a script from your Downloads, keep ``--inputs``
and the expected ``.npz`` keys compatible.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
def _load_chain(path: Path) -> tuple[np.ndarray, np.ndarray, tuple[float, float] | None]:
d = np.load(path, allow_pickle=True)
if "samples" in d:
s = np.asarray(d["samples"], dtype=np.float64)
om, s8 = s[:, 0], s[:, 1]
else:
om = np.asarray(d["omega_m"], dtype=np.float64).ravel()
s8 = np.asarray(d["sigma_8"], dtype=np.float64).ravel()
truth = None
if "truth_Omega_m" in d.files and "truth_sigma_8" in d.files:
truth = (float(d["truth_Omega_m"]), float(d["truth_sigma_8"]))
return om, s8, truth
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser (inputs, labels, output path, bin counts)."""
    parser = argparse.ArgumentParser(description="Triangle / corner plot for Ωm–σ8 surrogate chains.")
    parser.add_argument(
        "--inputs",
        "-i",
        nargs="+",
        type=Path,
        required=True,
        help="One or two .npz outputs from ddpm_triangle_integration.py",
    )
    parser.add_argument(
        "--labels",
        nargs="*",
        default=None,
        help="Legend entries (default: paths' stems).",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=None,
        # Matches the fallback actually computed in main().
        help="Output PNG (default: triangle_<stem1>_<stem2>.png next to first input).",
    )
    parser.add_argument("--bins-1d", type=int, default=40)
    parser.add_argument("--bins-2d", type=int, default=45)
    return parser


def _density_grid(
    om: np.ndarray, s8: np.ndarray, bins: int
) -> tuple[np.ndarray, np.ndarray, np.ma.MaskedArray]:
    """Return ``(x_centres, y_centres, Z)`` for contouring the (om, s8) 2D density.

    ``np.histogram2d(x, y)`` puts x along axis 0 of its result, whereas
    ``contour``/``contourf`` called with 1-D coordinates draws ``Z[row, col]``
    at ``(x[col], y[row])``. The histogram is therefore transposed, which keeps
    Omega_m on the horizontal axis. Cells with density <= 1e-20 are masked so
    empty bins stay blank.
    """
    h2, xe, ye = np.histogram2d(om, s8, bins=bins, density=True)
    xc = 0.5 * (xe[1:] + xe[:-1])
    yc = 0.5 * (ye[1:] + ye[:-1])
    z = h2.T
    return xc, yc, np.ma.masked_where(z <= 1e-20, z)


def main() -> None:
    """Plot all input chains on one corner-style figure and save it as a PNG.

    Layout:
        - top left: Omega_m marginal (with legend);
        - bottom left: 2D density;
        - bottom right: sigma_8 marginal;
        - top right: left blank.

    The first chain is drawn as filled contours with a colour bar, and later
    chains as coloured contour lines. Truth points, when present in a file,
    are marked with an ``x``.

    Raises:
        SystemExit: if ``--labels`` is given with a different count than ``--inputs``.
    """
    args = _build_parser().parse_args()
    paths = [Path(x).resolve() for x in args.inputs]
    names = args.labels if args.labels else [path.stem for path in paths]
    if len(names) != len(paths):
        raise SystemExit("--labels count must match --inputs")
    # Load every chain before building the figure so a broken file fails fast.
    chains = [_load_chain(path) for path in paths]

    colors = ("#1f77b4", "#d95f02", "#2ca02c")
    fig = plt.figure(figsize=(8.2, 8.0))
    ax00 = fig.add_axes([0.1, 0.55, 0.35, 0.35])
    ax_cont = fig.add_axes([0.1, 0.1, 0.35, 0.35])
    ax11 = fig.add_axes([0.55, 0.1, 0.35, 0.35])
    ax_blank = fig.add_axes([0.55, 0.55, 0.35, 0.35])
    ax_blank.axis("off")

    for i, (name, (om, s8, truth)) in enumerate(zip(names, chains)):
        c = colors[i % len(colors)]  # colours cycle after three chains
        ax00.hist(om, bins=args.bins_1d, density=True, histtype="step", color=c, lw=2.0, label=name)
        ax11.hist(s8, bins=args.bins_1d, density=True, histtype="step", color=c, lw=2.0)
        xc, yc, z = _density_grid(om, s8, args.bins_2d)
        if np.ma.count(z) > 0:  # skip chains whose 2D histogram is entirely empty
            if i == 0:
                cf = ax_cont.contourf(xc, yc, z, alpha=0.45, cmap="Blues")
                fig.colorbar(cf, ax=ax_cont, fraction=0.046, pad=0.04)
            else:
                ax_cont.contour(xc, yc, z, colors=[c], linewidths=[1.85])
        if truth is not None:
            ax_cont.scatter(truth[0], truth[1], marker="x", s=88, color=c, zorder=6)

    ax00.set_title(r"$P(\Omega_m)$ marginal")
    ax00.set_ylabel("density")
    ax00.legend(fontsize=8, loc="upper right")
    ax_cont.set_title(r"$2D$ surrogate posterior density")
    ax_cont.set_xlabel(r"$\Omega_m$")
    ax_cont.set_ylabel(r"$\sigma_8$")
    ax11.set_title(r"$P(\sigma_8)$ marginal")
    # The horizontal axis of this marginal is sigma_8; density is on the vertical axis.
    ax11.set_xlabel(r"$\sigma_8$")
    ax11.set_ylabel("density")

    out = args.output or (paths[0].parent / ("triangle_" + "_".join(path.stem for path in paths) + ".png"))
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", out)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()