File size: 4,204 Bytes
c496462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""
Corner-style triangle plot for surrogate $(\\Omega_m,\\sigma_8)$ chains from ``ddpm_triangle_integration.py``.

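Loads one or two ``.npz`` files and draws 1D marginals + a 2D density. Each file must contain
either ``samples`` (a 2D array whose first two columns are $\\Omega_m$ and $\\sigma_8$) or the
pair ``omega_m`` / ``sigma_8``; ``truth_Omega_m`` and ``truth_sigma_8`` are optional and, when
both are present, mark the true parameters. If you substitute a script from your Downloads,
keep ``--inputs`` and the expected ``.npz`` keys compatible.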
"""

from __future__ import annotations

import argparse
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def _load_chain(path: Path) -> tuple[np.ndarray, np.ndarray, tuple[float, float] | None]:
    """Load the Omega_m / sigma_8 samples and the optional truth point from one ``.npz`` file.

    Args:
        path: Archive holding either a ``samples`` array whose columns 0 and 1 are
            Omega_m and sigma_8, or separate ``omega_m`` and ``sigma_8`` arrays
            (these are flattened to 1D).

    Returns:
        ``(om, s8, truth)``: two float64 arrays and ``truth`` as
        ``(truth_Omega_m, truth_sigma_8)`` Python floats, or ``None`` unless both
        truth keys are present in the archive.

    Raises:
        KeyError: If ``samples`` is absent and ``omega_m`` or ``sigma_8`` is missing.
        ValueError: If a truth entry does not hold exactly one value.
    """
    # NOTE(review): allow_pickle=True means object arrays in the archive are unpickled
    # when read, which can execute arbitrary code. Only open trusted files; consider
    # dropping the flag if the producer never writes object arrays.
    # The context manager closes the underlying zip file handle once the arrays are read.
    with np.load(path, allow_pickle=True) as d:
        keys = d.files
        if "samples" in keys:
            s = np.asarray(d["samples"], dtype=np.float64)
            om, s8 = s[:, 0], s[:, 1]
        else:
            om = np.asarray(d["omega_m"], dtype=np.float64).ravel()
            s8 = np.asarray(d["sigma_8"], dtype=np.float64).ravel()
        truth = None
        if "truth_Omega_m" in keys and "truth_sigma_8" in keys:
            # .item() accepts both 0-d and 1-element arrays; float() on the latter is
            # deprecated by NumPy.
            truth = (
                float(np.asarray(d["truth_Omega_m"], dtype=np.float64).item()),
                float(np.asarray(d["truth_sigma_8"], dtype=np.float64).item()),
            )
    return om, s8, truth


def main(argv: list[str] | None = None) -> None:
    """Parse CLI arguments, draw the triangle plot for the given chains and save it as a PNG.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None`` (the default)
            keeps the normal command-line behaviour.

    Raises:
        SystemExit: If an input file does not exist, if the number of ``--labels`` differs
            from the number of ``--inputs``, or if argparse rejects the arguments.
    """
    parser = argparse.ArgumentParser(description="Triangle / corner plot for Ωm–σ8 surrogate chains.")
    parser.add_argument(
        "--inputs",
        "-i",
        nargs="+",
        type=Path,
        required=True,
        help="One or two .npz outputs from ddpm_triangle_integration.py",
    )
    parser.add_argument(
        "--labels",
        nargs="*",
        default=None,
        help="Legend entries (default: paths' stems).",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=None,
        help="Output PNG (default: triangle_<input stems joined by '_'>.png next to first input).",
    )
    parser.add_argument("--bins-1d", type=int, default=40)
    parser.add_argument("--bins-2d", type=int, default=45)
    args = parser.parse_args(argv)

    paths = [Path(x).resolve() for x in args.inputs]
    missing = [str(path) for path in paths if not path.is_file()]
    if missing:
        raise SystemExit("input file(s) not found: " + ", ".join(missing))
    names = args.labels if args.labels else [path.stem for path in paths]
    if len(names) != len(paths):
        raise SystemExit("--labels count must match --inputs")

    colors = ("#1f77b4", "#d95f02", "#2ca02c")
    fig = plt.figure(figsize=(8.2, 8.0))

    # 2x2 layout: Omega_m marginal (top-left), 2D density (bottom-left),
    # sigma_8 marginal (bottom-right); the top-right cell is left empty.
    ax00 = fig.add_axes([0.1, 0.55, 0.35, 0.35])
    ax_cont = fig.add_axes([0.1, 0.1, 0.35, 0.35])
    ax11 = fig.add_axes([0.55, 0.1, 0.35, 0.35])
    ax_blank = fig.add_axes([0.55, 0.55, 0.35, 0.35])
    ax_blank.axis("off")

    for i, path in enumerate(paths):
        om, s8, truth = _load_chain(path)
        c = colors[i % len(colors)]  # cycle colours if more chains than colours
        ax00.hist(
            om,
            bins=args.bins_1d,
            density=True,
            histtype="step",
            color=c,
            lw=2.0,
            label=names[i],
        )
        # Drawn horizontally so sigma_8 runs along the y-axis, like in the 2D panel
        # to its left, and the density runs along the x-axis as labelled below.
        ax11.hist(
            s8,
            bins=args.bins_1d,
            density=True,
            histtype="step",
            color=c,
            lw=2.0,
            orientation="horizontal",
        )
        h2, xe, ye = np.histogram2d(om, s8, bins=args.bins_2d, density=True)
        xc = 0.5 * (xe[1:] + xe[:-1])  # bin centres along Omega_m
        yc = 0.5 * (ye[1:] + ye[:-1])  # bin centres along sigma_8
        # histogram2d returns h2[ix, iy]; with indexing="ij" the grids are X[ix, iy] and
        # Y[ix, iy], so h2 must be used untransposed to keep each density at its own bin.
        X, Y = np.meshgrid(xc, yc, indexing="ij")
        Z = np.ma.masked_where(h2 <= 1e-20, h2)  # hide empty bins
        has_density = np.ma.count(Z) > 0
        if i == 0 and has_density:
            # The first chain is drawn as filled contours with a colour bar ...
            cf = ax_cont.contourf(X, Y, Z, alpha=0.45, cmap="Blues")
            fig.colorbar(cf, ax=ax_cont, fraction=0.046, pad=0.04)
        elif has_density:
            # ... later chains as line contours on top of it.
            ax_cont.contour(X, Y, Z, colors=[c], linewidths=[1.85])
        if truth is not None:
            tx, ty = truth
            ax_cont.scatter(tx, ty, marker="x", s=88, color=c, zorder=6)

    ax00.set_title(r"$P(\Omega_m)$ marginal")
    ax00.set_ylabel("density")
    ax00.legend(fontsize=8, loc="upper right")
    ax_cont.set_title(r"$2D$ surrogate posterior density")
    ax_cont.set_xlabel(r"$\Omega_m$")
    ax_cont.set_ylabel(r"$\sigma_8$")
    ax11.set_title(r"$P(\sigma_8)$ marginal")
    ax11.set_xlabel("density")
    ax11.set_ylabel(r"$\sigma_8$")

    if args.output is not None:
        out = args.output
    else:
        out = paths[0].parent / ("triangle_" + "_".join(path.stem for path in paths) + ".png")
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", out)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()