""" Figure-6-inspired layout for 2-parameter posteriors (arXiv:2409.09101 style): main 2D panel with 1D marginals on adjacent edges — marginal sums vs profiles (max) optional. """ from __future__ import annotations from typing import Tuple import matplotlib.pyplot as plt import numpy as np from matplotlib import gridspec as mgs from sigma_contour_utils import compute_sigma_levels def create_figure6_style_plot( Wmap: np.ndarray, om_ax: np.ndarray, s8_ax: np.ndarray, *, true_param1: float, true_param2: float, param1_label: str = r"$\Omega_m$", param2_label: str = r"$\sigma_8$", title: str = "", show_profile: bool = False, figsize: Tuple[float, float] = (10, 10), ): """ Parameters ---------- Wmap : (G, G) posterior masses on grid (same layout as DDPM OM meshgrid with indexing='ij'). om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``. show_profile : False → 1D marginals are sums (*marginal*) over the other parameter. True → 1D marginals are max (*profile*) over the other parameter (then normalized). """ p = np.asarray(Wmap, dtype=np.float64) p = p / (p.sum() + 1e-30) P1, P2 = np.meshgrid(om_ax, s8_ax, indexing="ij") if show_profile: m1 = np.max(p, axis=1) m2 = np.max(p, axis=0) else: m1 = p.sum(axis=1) m2 = p.sum(axis=0) m1 = np.asarray(m1, dtype=np.float64) m2 = np.asarray(m2, dtype=np.float64) m1 /= m1.max() + 1e-30 m2 /= m2.max() + 1e-30 fig = plt.figure(figsize=figsize) gs = mgs.GridSpec( nrows=2, ncols=2, figure=fig, width_ratios=[4.0, 1.05], height_ratios=[1.05, 4.0], wspace=0.035, hspace=0.035, left=0.12, right=0.98, bottom=0.1, top=0.92, ) ax_main = fig.add_subplot(gs[1, 0]) ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main) ax_r = fig.add_subplot(gs[1, 1], sharey=ax_main) ax_empty = fig.add_subplot(gs[0, 1]) ax_empty.axis("off") lvl = compute_sigma_levels(p, [0.683, 0.954]) ax_main.contourf(P1, P2, p, levels=20, cmap="Blues", alpha=0.88) if len(set(lvl)) >= 2: ax_main.contour(P1, P2, p, levels=lvl, colors=["darkblue", "steelblue"], linewidths=[2.0, 1.5]) ax_main.scatter( true_param1, true_param2, s=120, c="red", marker="x", linewidths=2.8, zorder=15, label="true", ) ax_main.set_xlabel(param1_label, fontsize=13) ax_main.set_ylabel(param2_label, fontsize=13) ax_main.grid(True, alpha=0.28) ax_main.legend(fontsize=8, loc="upper right") ax_top.fill_between(om_ax, 0.0, m1, alpha=0.62, color="steelblue") ax_top.axvline(true_param1, color="red", ls="--", lw=2.0) ax_top.set_ylim(0.0, float(np.max(m1) * 1.12)) ax_top.set_ylabel("$P(\\mathrm{prof.})$" if show_profile else "$P(\\mathrm{margin.})$", fontsize=10) ax_top.tick_params(labelbottom=False) ax_top.grid(True, alpha=0.25) ax_r.fill_betweenx(s8_ax, 0.0, m2, alpha=0.62, color="steelblue") ax_r.axhline(true_param2, color="red", ls="--", lw=2.0) ax_r.set_xlim(0.0, float(np.max(m2) * 1.12)) ax_r.set_xlabel("$P$", fontsize=10) ax_r.tick_params(labelleft=False) ax_r.grid(True, alpha=0.25) kind = "Profile" if show_profile else "Marginal" fig.suptitle(f"{title} ({kind})", fontsize=14, fontweight="bold", y=0.98) plt.setp(ax_top.get_xticklabels(), visible=False) return fig def create_comparison_marginal_vs_profile( Wmap: np.ndarray, om_ax: np.ndarray, s8_ax: np.ndarray, *, true_param1: float, true_param2: float, param1_label: str = r"$\Omega_m$", param2_label: str = r"$\sigma_8$", title: str = "", figsize: Tuple[float, float] = (10, 4.2), ): """Two rows: Ωm and σ8 marginals (sum) vs profile (max) on shared parameter axes.""" p = np.asarray(Wmap, dtype=np.float64) p /= p.sum() + 1e-30 marg_om = p.sum(axis=1) marg_s8 = p.sum(axis=0) prof_om = np.max(p, axis=1) prof_s8 = np.max(p, axis=0) marg_om /= marg_om.sum() + 1e-30 marg_s8 /= marg_s8.sum() + 1e-30 prof_om /= prof_om.max() + 1e-30 prof_s8 /= prof_s8.max() + 1e-30 fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=False) for ax, xaxis, marg, prof, xlab, xv in zip( axes, (om_ax, s8_ax), (marg_om, marg_s8), (prof_om, prof_s8), (param1_label, param2_label), (true_param1, true_param2), ): ax.plot(xaxis, marg, lw=2.0, ls="-", label="marginal") ax.plot(xaxis, prof, lw=2.0, ls="--", label="profile") ax.axvline(xv, color="crimson", ls=":", lw=1.8) ax.set_xlabel(xlab, fontsize=12) ax.set_ylabel("norm. density", fontsize=10) ax.legend(fontsize=9) ax.grid(True, alpha=0.3) fig.suptitle(title, fontsize=12, fontweight="bold") fig.tight_layout(rect=(0, 0, 1, 0.93)) return fig