Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified | #!/usr/bin/env python3 | |
| """ | |
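Corner-style triangle plot for surrogate $(\\Omega_m,\\sigma_8)$ chains from ``ddpm_triangle_integration.py``.
Loads one or two ``.npz`` files and draws 1D marginals plus a 2D density. Each file must contain either
``samples`` (an ``(N, >=2)`` array whose first two columns are Omega_m and sigma_8) or separate
``omega_m`` and ``sigma_8`` arrays; optional ``truth_Omega_m`` / ``truth_sigma_8`` entries are marked
on the 2D panel. If you substitute a different version of this script (e.g. one from your Downloads
folder), keep the ``--inputs`` option and the expected ``.npz`` keys compatible.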
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| def _load_chain(path: Path) -> tuple[np.ndarray, np.ndarray, tuple[float, float] | None]: | |
| d = np.load(path, allow_pickle=True) | |
| if "samples" in d: | |
| s = np.asarray(d["samples"], dtype=np.float64) | |
| om, s8 = s[:, 0], s[:, 1] | |
| else: | |
| om = np.asarray(d["omega_m"], dtype=np.float64).ravel() | |
| s8 = np.asarray(d["sigma_8"], dtype=np.float64).ravel() | |
| truth = None | |
| if "truth_Omega_m" in d.files and "truth_sigma_8" in d.files: | |
| truth = (float(d["truth_Omega_m"]), float(d["truth_sigma_8"])) | |
| return om, s8, truth | |
| def main() -> None: | |
| p = argparse.ArgumentParser(description="Triangle / corner plot for Ωm–σ8 surrogate chains.") | |
| p.add_argument( | |
| "--inputs", | |
| "-i", | |
| nargs="+", | |
| type=Path, | |
| required=True, | |
| help="One or two .npz outputs from ddpm_triangle_integration.py", | |
| ) | |
| p.add_argument( | |
| "--labels", | |
| nargs="*", | |
| default=None, | |
| help="Legend entries (default: paths' stems).", | |
| ) | |
| p.add_argument( | |
| "--output", | |
| "-o", | |
| type=Path, | |
| default=None, | |
| help="Output PNG (default: triangle_posterior_ddpm2_ddpm6.png next to first input).", | |
| ) | |
| p.add_argument("--bins-1d", type=int, default=40) | |
| p.add_argument("--bins-2d", type=int, default=45) | |
| args = p.parse_args() | |
| paths = [Path(x).resolve() for x in args.inputs] | |
| names = args.labels if args.labels else [p.stem for p in paths] | |
| if len(names) != len(paths): | |
| raise SystemExit("--labels count must match --inputs") | |
| colors = ("#1f77b4", "#d95f02", "#2ca02c") | |
| fig = plt.figure(figsize=(8.2, 8.0)) | |
| ax00 = fig.add_axes([0.1, 0.55, 0.35, 0.35]) | |
| ax_cont = fig.add_axes([0.1, 0.1, 0.35, 0.35]) | |
| ax11 = fig.add_axes([0.55, 0.1, 0.35, 0.35]) | |
| ax_blank = fig.add_axes([0.55, 0.55, 0.35, 0.35]) | |
| ax_blank.axis("off") | |
| for i, path in enumerate(paths): | |
| om, s8, truth = _load_chain(path) | |
| c = colors[i % len(colors)] | |
| ax00.hist( | |
| om, | |
| bins=args.bins_1d, | |
| density=True, | |
| histtype="step", | |
| color=c, | |
| lw=2.0, | |
| label=names[i], | |
| ) | |
| ax11.hist( | |
| s8, | |
| bins=args.bins_1d, | |
| density=True, | |
| histtype="step", | |
| color=c, | |
| lw=2.0, | |
| ) | |
| h2, xe, ye = np.histogram2d(om, s8, bins=args.bins_2d, density=True) | |
| xc = 0.5 * (xe[1:] + xe[:-1]) | |
| yc = 0.5 * (ye[1:] + ye[:-1]) | |
| X, Y = np.meshgrid(xc, yc, indexing="ij") | |
| Z = np.ma.masked_where(h2.T <= 1e-20, h2.T) | |
| if i == 0 and np.ma.count(Z) > 0: | |
| cf = ax_cont.contourf(X, Y, Z, alpha=0.45, cmap="Blues") | |
| fig.colorbar(cf, ax=ax_cont, fraction=0.046, pad=0.04) | |
| elif np.ma.count(Z) > 0: | |
| ax_cont.contour(X, Y, Z, colors=[c], linewidths=[1.85]) | |
| if truth: | |
| tx, ty = truth | |
| ax_cont.scatter(tx, ty, marker="x", s=88, color=c, zorder=6) | |
| ax00.set_title(r"$P(\Omega_m)$ marginal") | |
| ax00.set_ylabel("density") | |
| ax00.legend(fontsize=8, loc="upper right") | |
| ax_cont.set_title(r"$2D$ surrogate posterior density") | |
| ax_cont.set_xlabel(r"$\Omega_m$") | |
| ax_cont.set_ylabel(r"$\sigma_8$") | |
| ax11.set_title(r"$P(\sigma_8)$ marginal") | |
| ax11.set_xlabel("density") | |
| out = args.output or (paths[0].parent / ("triangle_" + "_".join(p.stem for p in paths) + ".png")) | |
| out.parent.mkdir(parents=True, exist_ok=True) | |
| fig.savefig(out, dpi=170, bbox_inches="tight") | |
| plt.close(fig) | |
| print("Saved", out) | |
if __name__ == "__main__":
    # Plot only when executed as a script; importing this module has no side effects.
    main()