File size: 5,033 Bytes
c496462 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | """
Figure-6-inspired layout for 2-parameter posteriors (arXiv:2409.09101 style):
a main 2D panel with 1D curves along its top and right edges, drawn either as
marginals (sum over the other parameter) or profiles (max over it).
"""
from __future__ import annotations
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec as mgs
from sigma_contour_utils import compute_sigma_levels
def create_figure6_style_plot(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    show_profile: bool = False,
    figsize: Tuple[float, float] = (10, 10),
):
    """
    Plot a 2D posterior with 1D summary curves on its top and right edges.

    Parameters
    ----------
    Wmap : (G1, G2) posterior masses on grid (same layout as DDPM OM meshgrid
        with indexing='ij'). Rescaled internally to sum to 1; the caller's
        array is never modified.
    om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``.
    true_param1, true_param2 : reference point, drawn as a red cross in the
        main panel and as dashed red lines in the edge panels.
    param1_label, param2_label : axis labels of the main panel.
    title : suptitle prefix; "(Marginal)" or "(Profile)" is appended. With an
        empty title the suptitle is just that parenthesised kind.
    show_profile :
        False → 1D curves are sums (*marginal*) over the other parameter.
        True → 1D curves are max (*profile*) over the other parameter.
        Either way each 1D curve is rescaled to a peak of 1.
    figsize : forwarded to ``plt.figure``.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure (not shown or saved here).

    Raises
    ------
    ValueError
        If ``Wmap`` does not have shape ``(len(om_ax), len(s8_ax))``.
    """
    p = np.asarray(Wmap, dtype=np.float64)
    expected_shape = (np.size(om_ax), np.size(s8_ax))
    if p.shape != expected_shape:
        raise ValueError(
            f"Wmap has shape {p.shape}, expected {expected_shape} "
            "= (len(om_ax), len(s8_ax))"
        )
    # Out-of-place on purpose: np.asarray returns the caller's array itself
    # when it is already float64, so an in-place division would rescale it.
    p = p / (p.sum() + 1e-30)
    P1, P2 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # Reduce over the other parameter: axis 1 -> curve over om_ax,
    # axis 0 -> curve over s8_ax. Profile takes the max, marginal the sum.
    reduce_fn = np.max if show_profile else np.sum
    m1 = reduce_fn(p, axis=1)
    m2 = reduce_fn(p, axis=0)
    # Peak-normalise both curves; the epsilon guards an all-zero map.
    m1 = m1 / (m1.max() + 1e-30)
    m2 = m2 / (m2.max() + 1e-30)

    fig = plt.figure(figsize=figsize)
    gs = mgs.GridSpec(
        nrows=2,
        ncols=2,
        figure=fig,
        width_ratios=[4.0, 1.05],
        height_ratios=[1.05, 4.0],
        wspace=0.035,
        hspace=0.035,
        left=0.12,
        right=0.98,
        bottom=0.1,
        top=0.92,
    )
    ax_main = fig.add_subplot(gs[1, 0])
    ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
    ax_r = fig.add_subplot(gs[1, 1], sharey=ax_main)
    ax_empty = fig.add_subplot(gs[0, 1])  # top-right corner stays blank
    ax_empty.axis("off")

    # NOTE(review): assumed to return one density threshold per requested
    # fraction (68.3% / 95.4%); confirm against sigma_contour_utils.
    sigma_levels = [float(v) for v in compute_sigma_levels(p, [0.683, 0.954])]
    ax_main.contourf(P1, P2, p, levels=20, cmap="Blues", alpha=0.88)
    # Skip the line contours when the thresholds coincide: matplotlib needs
    # distinct levels.
    if len(set(sigma_levels)) >= 2:
        # Axes.contour requires increasing levels, and nothing visible here
        # guarantees the helper's output order. Sort, keeping each level
        # paired with the colour/width it was assigned.
        styled = sorted(
            zip(sigma_levels, ("darkblue", "steelblue"), (2.0, 1.5))
        )
        ax_main.contour(
            P1,
            P2,
            p,
            levels=[level for level, _, _ in styled],
            colors=[colour for _, colour, _ in styled],
            linewidths=[width for _, _, width in styled],
        )
    ax_main.scatter(
        true_param1,
        true_param2,
        s=120,
        c="red",
        marker="x",
        linewidths=2.8,
        zorder=15,
        label="true",
    )
    ax_main.set_xlabel(param1_label, fontsize=13)
    ax_main.set_ylabel(param2_label, fontsize=13)
    ax_main.grid(True, alpha=0.28)
    ax_main.legend(fontsize=8, loc="upper right")

    ax_top.fill_between(om_ax, 0.0, m1, alpha=0.62, color="steelblue")
    ax_top.axvline(true_param1, color="red", ls="--", lw=2.0)
    ax_top.set_ylim(0.0, float(np.max(m1) * 1.12))  # 12% headroom
    ax_top.set_ylabel(
        "$P(\\mathrm{prof.})$" if show_profile else "$P(\\mathrm{margin.})$",
        fontsize=10,
    )
    ax_top.tick_params(labelbottom=False)
    ax_top.grid(True, alpha=0.25)

    ax_r.fill_betweenx(s8_ax, 0.0, m2, alpha=0.62, color="steelblue")
    ax_r.axhline(true_param2, color="red", ls="--", lw=2.0)
    ax_r.set_xlim(0.0, float(np.max(m2) * 1.12))  # 12% headroom
    ax_r.set_xlabel("$P$", fontsize=10)
    ax_r.tick_params(labelleft=False)
    ax_r.grid(True, alpha=0.25)

    kind = "Profile" if show_profile else "Marginal"
    # Avoid a leading space in the suptitle when no title is given.
    suptitle = f"{title} ({kind})" if title else f"({kind})"
    fig.suptitle(suptitle, fontsize=14, fontweight="bold", y=0.98)
    plt.setp(ax_top.get_xticklabels(), visible=False)
    return fig
def create_comparison_marginal_vs_profile(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    figsize: Tuple[float, float] = (10, 4.2),
):
    """Overlay Ωm and σ8 marginals (sum) and profiles (max) in side-by-side panels.

    One row, two columns: the left panel shows curves over ``om_ax``, the
    right panel curves over ``s8_ax``. Every curve is rescaled to a peak of 1
    (the same convention as ``create_figure6_style_plot``), so marginal and
    profile shapes are directly comparable on one y-axis.

    Parameters
    ----------
    Wmap : (G1, G2) posterior masses on grid. The caller's array is never
        modified.
    om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``.
    true_param1, true_param2 : reference values, drawn as dotted vertical lines.
    param1_label, param2_label : x-axis labels of the left / right panel.
    title : figure suptitle.
    figsize : forwarded to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure (not shown or saved here).

    Raises
    ------
    ValueError
        If ``Wmap`` does not have shape ``(len(om_ax), len(s8_ax))``.
    """
    p = np.asarray(Wmap, dtype=np.float64)
    expected_shape = (np.size(om_ax), np.size(s8_ax))
    if p.shape != expected_shape:
        raise ValueError(
            f"Wmap has shape {p.shape}, expected {expected_shape} "
            "= (len(om_ax), len(s8_ax))"
        )
    # Out-of-place on purpose: np.asarray returns the caller's array itself
    # when it is already float64, so `p /= ...` would rescale the caller's map.
    p = p / (p.sum() + 1e-30)

    def _unit_peak(curve: np.ndarray) -> np.ndarray:
        # Scale so the maximum is 1; the epsilon guards an all-zero map.
        return curve / (curve.max() + 1e-30)

    # axis=1 reduces over s8 (curve over om_ax); axis=0 reduces over om.
    marg_om = _unit_peak(p.sum(axis=1))
    marg_s8 = _unit_peak(p.sum(axis=0))
    prof_om = _unit_peak(p.max(axis=1))
    prof_s8 = _unit_peak(p.max(axis=0))

    fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=False)
    for ax, xaxis, marg, prof, xlab, xv in zip(
        axes,
        (om_ax, s8_ax),
        (marg_om, marg_s8),
        (prof_om, prof_s8),
        (param1_label, param2_label),
        (true_param1, true_param2),
    ):
        ax.plot(xaxis, marg, lw=2.0, ls="-", label="marginal")
        ax.plot(xaxis, prof, lw=2.0, ls="--", label="profile")
        ax.axvline(xv, color="crimson", ls=":", lw=1.8)
        ax.set_xlabel(xlab, fontsize=12)
        ax.set_ylabel("norm. density", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
    fig.suptitle(title, fontsize=12, fontweight="bold")
    # Leave the top 7% of the figure free for the suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.93))
    return fig
|