#!/usr/bin/env python3
"""
Corner-style triangle plot for surrogate $(\\Omega_m,\\sigma_8)$ chains from
``ddpm_triangle_integration.py``.

Loads one or more (typically two) ``.npz`` files and draws the 1D marginals of
both parameters plus their joint 2D density. Each file must provide either

* ``omega_m`` and ``sigma_8`` -- arrays of equal size (flattened on load), or
* ``samples`` -- a 2D array whose first two columns are Omega_m and sigma_8
  (this key takes precedence when present);

and may optionally provide the scalars ``truth_Omega_m`` and ``truth_sigma_8``,
which are marked with a cross on the 2D panel.

If you substitute a different producer script, keep the ``--inputs`` CLI and
these ``.npz`` keys compatible.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np

# 2D bins whose normalised density does not exceed this value are treated as
# empty and masked out before contouring.
_EMPTY_BIN_DENSITY = 1e-20


def _load_chain(path: Path) -> tuple[np.ndarray, np.ndarray, tuple[float, float] | None]:
    """Load an (Omega_m, sigma_8) sample chain from an ``.npz`` file.

    Args:
        path: ``.npz`` file containing either ``samples`` (2D, first two
            columns are Omega_m and sigma_8) or ``omega_m`` and ``sigma_8``.

    Returns:
        ``(omega_m, sigma_8, truth)``. The first two are float64 arrays of equal
        shape. ``truth`` is ``(truth_Omega_m, truth_sigma_8)`` when both keys are
        present in the file, else ``None``.

    Raises:
        KeyError: if neither ``samples`` nor both ``omega_m``/``sigma_8`` exist.
        ValueError: if ``samples`` is not 2D with at least two columns, or the
            two parameter arrays differ in size.
    """
    # allow_pickle stays at numpy's default (off): the chains are plain numeric
    # arrays, and unpickling a file of unknown provenance can run arbitrary code.
    # The context manager closes the underlying zip file handle.
    with np.load(path) as data:
        if "samples" in data.files:
            samples = np.asarray(data["samples"], dtype=np.float64)
            if samples.ndim != 2 or samples.shape[1] < 2:
                raise ValueError(
                    f"{path}: 'samples' must have shape (N, >=2), got {samples.shape}"
                )
            om, s8 = samples[:, 0], samples[:, 1]
        else:
            om = np.asarray(data["omega_m"], dtype=np.float64).ravel()
            s8 = np.asarray(data["sigma_8"], dtype=np.float64).ravel()

        truth = None
        if "truth_Omega_m" in data.files and "truth_sigma_8" in data.files:
            truth = (float(data["truth_Omega_m"]), float(data["truth_sigma_8"]))

    if om.shape != s8.shape:
        raise ValueError(
            f"{path}: omega_m has {om.size} samples but sigma_8 has {s8.size}"
        )
    return om, s8, truth


def _density_grid(om: np.ndarray, s8: np.ndarray, bins) -> tuple[np.ndarray, np.ndarray, np.ma.MaskedArray]:
    """Bin the joint samples into a normalised 2D density evaluated at bin centres.

    Args:
        om: Omega_m samples.
        s8: sigma_8 samples (same length as ``om``).
        bins: anything ``np.histogram2d`` accepts as ``bins`` (int or [nx, ny]).

    Returns:
        ``(X, Y, Z)``, all of shape ``(nx, ny)``. ``X[i, j]`` and ``Y[i, j]`` are
        the centres of Omega_m bin ``i`` and sigma_8 bin ``j``. ``Z[i, j]`` is the
        density there, masked where it is <= ``_EMPTY_BIN_DENSITY``.
    """
    h2, xe, ye = np.histogram2d(om, s8, bins=bins, density=True)
    xc = 0.5 * (xe[1:] + xe[:-1])
    yc = 0.5 * (ye[1:] + ye[:-1])
    # histogram2d indexes h2[x_bin, y_bin], which is exactly meshgrid's "ij"
    # layout. No transpose is needed here; a transpose is only correct with the
    # default "xy" layout.
    X, Y = np.meshgrid(xc, yc, indexing="ij")
    Z = np.ma.masked_where(h2 <= _EMPTY_BIN_DENSITY, h2)
    return X, Y, Z


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser used by :func:`main`."""
    parser = argparse.ArgumentParser(description="Triangle / corner plot for Ωm–σ8 surrogate chains.")
    parser.add_argument(
        "--inputs",
        "-i",
        nargs="+",
        type=Path,
        required=True,
        help="One or two .npz outputs from ddpm_triangle_integration.py",
    )
    parser.add_argument(
        "--labels",
        nargs="*",
        default=None,
        help="Legend entries (default: paths' stems).",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=None,
        help="Output PNG (default: triangle_<input stems joined by '_'>.png next to the first input).",
    )
    parser.add_argument("--bins-1d", type=int, default=40)
    parser.add_argument("--bins-2d", type=int, default=45)
    return parser


def main() -> None:
    """Parse the CLI, draw the triangle plot for every input chain and save it as PNG."""
    # Select the non-interactive backend before pyplot is imported. Doing this
    # here rather than at module import keeps the loading/binning helpers
    # importable without changing matplotlib's global backend.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    parser = _build_parser()
    args = parser.parse_args()
    if args.bins_1d < 1 or args.bins_2d < 1:
        parser.error("--bins-1d and --bins-2d must be positive integers")

    paths = [path.resolve() for path in args.inputs]
    names = args.labels if args.labels else [path.stem for path in paths]
    if len(names) != len(paths):
        raise SystemExit("--labels count must match --inputs")

    colors = ("#1f77b4", "#d95f02", "#2ca02c")
    fig = plt.figure(figsize=(8.2, 8.0))
    # Manual 2x2 layout:
    #   - top-left: Omega_m marginal, above the 2D density (bottom-left);
    #   - bottom-right: sigma_8 marginal;
    #   - top-right: left blank.
    ax00 = fig.add_axes([0.1, 0.55, 0.35, 0.35])
    ax_cont = fig.add_axes([0.1, 0.1, 0.35, 0.35])
    ax11 = fig.add_axes([0.55, 0.1, 0.35, 0.35])
    ax_blank = fig.add_axes([0.55, 0.55, 0.35, 0.35])
    ax_blank.axis("off")

    for i, (path, name) in enumerate(zip(paths, names)):
        om, s8, truth = _load_chain(path)
        # NaN/inf samples make numpy's automatic histogram range fail, so drop
        # any pair where either coordinate is non-finite.
        finite = np.isfinite(om) & np.isfinite(s8)
        if not finite.all():
            print(f"{path.name}: dropping {np.count_nonzero(~finite)} non-finite sample(s)")
            om, s8 = om[finite], s8[finite]
        if om.size == 0:
            raise SystemExit(f"{path}: no finite (Omega_m, sigma_8) samples to plot")

        c = colors[i % len(colors)]
        step_kw = dict(bins=args.bins_1d, density=True, histtype="step", color=c, lw=2.0)
        ax00.hist(om, label=name, **step_kw)
        ax11.hist(s8, **step_kw)

        X, Y, Z = _density_grid(om, s8, args.bins_2d)
        if np.ma.count(Z) > 0:
            if i == 0:
                # First chain: filled density with a colour bar.
                cf = ax_cont.contourf(X, Y, Z, alpha=0.45, cmap="Blues")
                fig.colorbar(cf, ax=ax_cont, fraction=0.046, pad=0.04)
            else:
                # Later chains: line contours in the chain's own colour.
                ax_cont.contour(X, Y, Z, colors=[c], linewidths=[1.85])
        if truth is not None:
            tx, ty = truth
            ax_cont.scatter(tx, ty, marker="x", s=88, color=c, zorder=6)

    ax00.set_title(r"$P(\Omega_m)$ marginal")
    ax00.set_ylabel("density")
    ax00.legend(fontsize=8, loc="upper right")
    ax_cont.set_title(r"$2D$ surrogate posterior density")
    ax_cont.set_xlabel(r"$\Omega_m$")
    ax_cont.set_ylabel(r"$\sigma_8$")
    ax11.set_title(r"$P(\sigma_8)$ marginal")
    # The sigma_8 histogram is vertical: sigma_8 values run along x and density
    # along y. The density axis goes on the right, away from the 2D panel's
    # colour bar.
    ax11.set_xlabel(r"$\sigma_8$")
    ax11.set_ylabel("density")
    ax11.yaxis.set_label_position("right")
    ax11.yaxis.tick_right()

    out = args.output or paths[0].parent / ("triangle_" + "_".join(path.stem for path in paths) + ".png")
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", out)


if __name__ == "__main__":
    main()