# DDPM-2param / cross_model / scripts / figure6_2409_style.py
# Uploaded by collins909 (commit c496462, verified):
# Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
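"""
Figure-6-inspired layout for 2-parameter posteriors (arXiv:2409.09101 style).

A main 2D posterior panel with 1D curves on its adjacent top and right edges. The 1D
curves are marginals (sum over the other parameter) by default, or optionally profiles
(max over the other parameter).
"""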
from __future__ import annotations
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec as mgs
from sigma_contour_utils import compute_sigma_levels
def create_figure6_style_plot(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    show_profile: bool = False,
    figsize: Tuple[float, float] = (10, 10),
):
    """Draw a 2D posterior panel with 1D curves along its top and right edges.

    Parameters
    ----------
    Wmap : (G, G) posterior masses on the parameter grid (same layout as a
        meshgrid built with ``indexing='ij'``). It is rescaled to unit sum
        on a copy; the caller's array is left untouched.
    om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``.
    true_param1, true_param2 : reference values, drawn as a red cross in the
        main panel and as dashed red lines in the edge panels.
    param1_label, param2_label : axis labels of the main panel.
    title : text placed in front of the "(Marginal)" / "(Profile)" suffix of
        the figure title.
    show_profile :
        False → edge curves are sums (*marginal*) over the other parameter.
        True  → edge curves are maxima (*profile*) over the other parameter.
        Either way each curve is rescaled so that its peak is 1.
    figsize : figure size handed to ``plt.figure``.

    Returns
    -------
    The matplotlib figure holding the main panel, the two edge panels and a
    hidden corner axes.
    """
    posterior = np.asarray(Wmap, dtype=np.float64)
    posterior = posterior / (posterior.sum() + 1e-30)
    grid_p1, grid_p2 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # One reduction serves both modes: sum → marginal, max → profile.
    if show_profile:
        collapse = np.max
        kind = "Profile"
        top_ylabel = "$P(\\mathrm{prof.})$"
    else:
        collapse = np.sum
        kind = "Marginal"
        top_ylabel = "$P(\\mathrm{margin.})$"

    curve_p1 = collapse(posterior, axis=1)
    curve_p2 = collapse(posterior, axis=0)
    # Peak-normalise; the tiny offset keeps an all-zero map from dividing by zero.
    curve_p1 = curve_p1 / (curve_p1.max() + 1e-30)
    curve_p2 = curve_p2 / (curve_p2.max() + 1e-30)

    fig = plt.figure(figsize=figsize)
    layout = mgs.GridSpec(
        2,
        2,
        figure=fig,
        width_ratios=[4.0, 1.05],
        height_ratios=[1.05, 4.0],
        wspace=0.035,
        hspace=0.035,
        left=0.12,
        right=0.98,
        bottom=0.1,
        top=0.92,
    )
    ax_main = fig.add_subplot(layout[1, 0])
    ax_top = fig.add_subplot(layout[0, 0], sharex=ax_main)
    ax_right = fig.add_subplot(layout[1, 1], sharey=ax_main)
    # The top-right corner cell stays blank.
    fig.add_subplot(layout[0, 1]).axis("off")

    # --- main panel: filled density, contour lines, reference marker ---
    sigma_levels = compute_sigma_levels(posterior, [0.683, 0.954])
    ax_main.contourf(grid_p1, grid_p2, posterior, levels=20, cmap="Blues", alpha=0.88)
    distinct_levels = set(sigma_levels)
    # Contour lines are skipped when the helper returns fewer than two distinct levels.
    if len(distinct_levels) > 1:
        ax_main.contour(
            grid_p1,
            grid_p2,
            posterior,
            levels=sigma_levels,
            colors=["darkblue", "steelblue"],
            linewidths=[2.0, 1.5],
        )
    ax_main.scatter(
        true_param1,
        true_param2,
        s=120,
        c="red",
        marker="x",
        linewidths=2.8,
        zorder=15,
        label="true",
    )
    ax_main.set_xlabel(param1_label, fontsize=13)
    ax_main.set_ylabel(param2_label, fontsize=13)
    ax_main.grid(True, alpha=0.28)
    ax_main.legend(fontsize=8, loc="upper right")

    # --- top panel: curve over the first parameter ---
    ax_top.fill_between(om_ax, 0.0, curve_p1, alpha=0.62, color="steelblue")
    ax_top.axvline(true_param1, color="red", ls="--", lw=2.0)
    ax_top.set_ylim(0.0, float(np.max(curve_p1) * 1.12))
    ax_top.set_ylabel(top_ylabel, fontsize=10)
    ax_top.tick_params(labelbottom=False)
    ax_top.grid(True, alpha=0.25)

    # --- right panel: curve over the second parameter, drawn sideways ---
    ax_right.fill_betweenx(s8_ax, 0.0, curve_p2, alpha=0.62, color="steelblue")
    ax_right.axhline(true_param2, color="red", ls="--", lw=2.0)
    ax_right.set_xlim(0.0, float(np.max(curve_p2) * 1.12))
    ax_right.set_xlabel("$P$", fontsize=10)
    ax_right.tick_params(labelleft=False)
    ax_right.grid(True, alpha=0.25)

    fig.suptitle(f"{title} ({kind})", fontsize=14, fontweight="bold", y=0.98)
    plt.setp(ax_top.get_xticklabels(), visible=False)
    return fig
def create_comparison_marginal_vs_profile(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    figsize: Tuple[float, float] = (10, 4.2),
):
    """Plot marginal (sum) and profile (max) curves together, one panel per parameter.

    The figure has one row of two panels: the left panel shows the curves over
    ``om_ax``, the right panel the curves over ``s8_ax``.

    Parameters
    ----------
    Wmap : 2-D posterior masses; axis 0 is reduced against ``om_ax`` and
        axis 1 against ``s8_ax``. The input array is not modified.
    om_ax, s8_ax : 1-D grids used as the x-axis of the left / right panel.
    true_param1, true_param2 : reference values, drawn as dotted vertical lines.
    param1_label, param2_label : x-axis labels of the left / right panel.
    title : figure title.
    figsize : figure size handed to ``plt.subplots``.

    Returns
    -------
    The matplotlib figure.

    Notes
    -----
    The two curves use different normalisations: marginals are scaled to unit
    sum, profiles to unit peak, so their heights are not directly comparable.
    """
    p = np.asarray(Wmap, dtype=np.float64)
    # ``np.asarray`` hands back the caller's own array when it is already
    # float64, so normalise out of place instead of with ``/=``.
    p = p / (p.sum() + 1e-30)

    marg_om = p.sum(axis=1)
    marg_s8 = p.sum(axis=0)
    prof_om = np.max(p, axis=1)
    prof_s8 = np.max(p, axis=0)

    # These are fresh arrays produced by the reductions, so in-place scaling is safe.
    marg_om /= marg_om.sum() + 1e-30
    marg_s8 /= marg_s8.sum() + 1e-30
    prof_om /= prof_om.max() + 1e-30
    prof_s8 /= prof_s8.max() + 1e-30

    fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=False)
    panels = zip(
        axes,
        (om_ax, s8_ax),
        (marg_om, marg_s8),
        (prof_om, prof_s8),
        (param1_label, param2_label),
        (true_param1, true_param2),
    )
    for ax, xaxis, marg, prof, xlab, xv in panels:
        ax.plot(xaxis, marg, lw=2.0, ls="-", label="marginal")
        ax.plot(xaxis, prof, lw=2.0, ls="--", label="profile")
        ax.axvline(xv, color="crimson", ls=":", lw=1.8)
        ax.set_xlabel(xlab, fontsize=12)
        ax.set_ylabel("norm. density", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=12, fontweight="bold")
    # Leave headroom at the top so the suptitle does not collide with the panels.
    fig.tight_layout(rect=(0, 0, 1, 0.93))
    return fig