Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
c496462 verified | """ | |
| Figure-6-inspired layout for 2-parameter posteriors (arXiv:2409.09101 style): | |
| a main 2-D panel with 1-D curves on the adjacent edges — marginals (sums) by default, or profiles (maxima) optionally. | |
| """ | |
| from __future__ import annotations | |
| from typing import Tuple | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib import gridspec as mgs | |
| from sigma_contour_utils import compute_sigma_levels | |
def create_figure6_style_plot(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    show_profile: bool = False,
    figsize: Tuple[float, float] = (10, 10),
):
    """
    Draw a 2-D posterior with 1-D summary curves on the top and right edges.

    Parameters
    ----------
    Wmap : (G1, G2) posterior masses on grid (same layout as DDPM OM meshgrid
        with indexing='ij'): axis 0 follows ``om_ax``, axis 1 follows ``s8_ax``.
        It is rescaled to unit total mass internally; the caller's array is
        not modified.
    om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``.
    true_param1, true_param2 : reference point, drawn as a red cross in the
        main panel and as dashed red lines in the edge panels.
    param1_label, param2_label : axis labels for the two parameters.
    title : figure title; " (Marginal)" or " (Profile)" is appended
        (just "(Marginal)" / "(Profile)" when ``title`` is empty).
    show_profile :
        False → 1D curves are sums (*marginal*) over the other parameter.
        True → 1D curves are max (*profile*) over the other parameter.
        In both cases each curve is rescaled so that its peak is 1.
    figsize : figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
    """
    p = np.asarray(Wmap, dtype=np.float64)
    p = p / (p.sum() + 1e-30)  # out-of-place so the caller's array is untouched
    P1, P2 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # Collapse the other parameter by max (profile) or sum (marginal), then
    # rescale each 1-D curve to unit peak for display.
    collapse = np.max if show_profile else np.sum
    m1 = np.asarray(collapse(p, axis=1), dtype=np.float64)
    m2 = np.asarray(collapse(p, axis=0), dtype=np.float64)
    m1 /= m1.max() + 1e-30
    m2 /= m2.max() + 1e-30

    fig = plt.figure(figsize=figsize)
    gs = mgs.GridSpec(
        nrows=2,
        ncols=2,
        figure=fig,
        width_ratios=[4.0, 1.05],
        height_ratios=[1.05, 4.0],
        wspace=0.035,
        hspace=0.035,
        left=0.12,
        right=0.98,
        bottom=0.1,
        top=0.92,
    )
    ax_main = fig.add_subplot(gs[1, 0])
    ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
    ax_r = fig.add_subplot(gs[1, 1], sharey=ax_main)
    ax_empty = fig.add_subplot(gs[0, 1])
    ax_empty.axis("off")

    # compute_sigma_levels (external helper) presumably returns density
    # thresholds for the 68.3% / 95.4% regions; its ordering is not visible
    # here. Axes.contour requires strictly increasing, finite levels, so drop
    # non-finite values, then sort and de-duplicate.
    lvl = np.asarray(compute_sigma_levels(p, [0.683, 0.954]), dtype=np.float64).ravel()
    lvl = np.unique(lvl[np.isfinite(lvl)])

    ax_main.contourf(P1, P2, p, levels=20, cmap="Blues", alpha=0.88)
    if lvl.size >= 2:
        ax_main.contour(P1, P2, p, levels=lvl, colors=["darkblue", "steelblue"], linewidths=[2.0, 1.5])
    ax_main.scatter(
        true_param1,
        true_param2,
        s=120,
        c="red",
        marker="x",
        linewidths=2.8,
        zorder=15,
        label="true",
    )
    ax_main.set_xlabel(param1_label, fontsize=13)
    ax_main.set_ylabel(param2_label, fontsize=13)
    ax_main.grid(True, alpha=0.28)
    ax_main.legend(fontsize=8, loc="upper right")

    # Curves are peak-normalised, so their peak is 1 unless the map carried no
    # mass; fall back to 1 then to avoid a singular (zero-height) axis range.
    top_peak = float(np.max(m1)) or 1.0
    right_peak = float(np.max(m2)) or 1.0

    ax_top.fill_between(om_ax, 0.0, m1, alpha=0.62, color="steelblue")
    ax_top.axvline(true_param1, color="red", ls="--", lw=2.0)
    ax_top.set_ylim(0.0, top_peak * 1.12)
    ax_top.set_ylabel("$P(\\mathrm{prof.})$" if show_profile else "$P(\\mathrm{margin.})$", fontsize=10)
    ax_top.tick_params(labelbottom=False)  # x tick labels live on the main panel
    ax_top.grid(True, alpha=0.25)

    ax_r.fill_betweenx(s8_ax, 0.0, m2, alpha=0.62, color="steelblue")
    ax_r.axhline(true_param2, color="red", ls="--", lw=2.0)
    ax_r.set_xlim(0.0, right_peak * 1.12)
    ax_r.set_xlabel("$P$", fontsize=10)
    ax_r.tick_params(labelleft=False)  # y tick labels live on the main panel
    ax_r.grid(True, alpha=0.25)

    kind = "Profile" if show_profile else "Marginal"
    heading = f"{title} ({kind})" if title else f"({kind})"
    fig.suptitle(heading, fontsize=14, fontweight="bold", y=0.98)
    return fig
def create_comparison_marginal_vs_profile(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    figsize: Tuple[float, float] = (10, 4.2),
):
    """Compare marginal (sum) and profile (max) 1-D curves for both parameters.

    One row with two panels: the left panel is for the parameter on axis 0 of
    ``Wmap`` (``om_ax``), the right panel for axis 1 (``s8_ax``). Each panel
    overlays the marginal (solid) and the profile (dashed), both rescaled to a
    peak of 1 so their shapes are directly comparable, plus a dotted vertical
    line at the true value.

    Parameters
    ----------
    Wmap : (G1, G2) posterior masses on the grid; axis 0 follows ``om_ax`` and
        axis 1 follows ``s8_ax``. The caller's array is not modified.
    om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``.
    true_param1, true_param2 : reference values marked in each panel.
    param1_label, param2_label : x-axis labels for the two panels.
    title : figure suptitle.
    figsize : figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
    """
    p = np.asarray(Wmap, dtype=np.float64)
    # Out-of-place: np.asarray returns the caller's own array when it is
    # already float64, so an in-place ``/=`` would silently modify the input.
    p = p / (p.sum() + 1e-30)

    def _unit_peak(curve: np.ndarray) -> np.ndarray:
        # Scale a 1-D curve so its maximum is 1 (epsilon guards an all-zero map).
        return curve / (curve.max() + 1e-30)

    # Same peak normalisation for both summaries; unit-sum marginals would be
    # ~grid-size times smaller than unit-peak profiles on the shared y-axis.
    marg_om = _unit_peak(p.sum(axis=1))
    marg_s8 = _unit_peak(p.sum(axis=0))
    prof_om = _unit_peak(p.max(axis=1))
    prof_s8 = _unit_peak(p.max(axis=0))

    fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=False)
    for ax, xaxis, marg, prof, xlab, xv in zip(
        axes,
        (om_ax, s8_ax),
        (marg_om, marg_s8),
        (prof_om, prof_s8),
        (param1_label, param2_label),
        (true_param1, true_param2),
    ):
        ax.plot(xaxis, marg, lw=2.0, ls="-", label="marginal")
        ax.plot(xaxis, prof, lw=2.0, ls="--", label="profile")
        ax.axvline(xv, color="crimson", ls=":", lw=1.8)
        ax.set_xlabel(xlab, fontsize=12)
        ax.set_ylabel("norm. density", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
    fig.suptitle(title, fontsize=12, fontweight="bold")
    fig.tight_layout(rect=(0, 0, 1, 0.93))  # leave headroom for the suptitle
    return fig