DDPM-6param / cross_model /scripts /ddpm_figure6_integration.py
collins909's picture
Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified
"""
Figure 6 style (arXiv:2409.09101) plotting helpers for DDPM surrogate posteriors.

Intended for use with ddpm_posterior_six_anchors / run_ddpm_figure6_suite.
"""
from __future__ import annotations
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
from figure6_2409_style import (
create_comparison_marginal_vs_profile,
create_figure6_style_plot,
)
from sigma_contour_utils import compute_sigma_levels
def integrate_figure6_with_ddpm2(
    Wmap: np.ndarray,
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_om: float,
    true_s8: float,
    test_index: int,
    output_dir: Path,
    model_name: str = "DDPM-2",
) -> None:
    """Write the single-map Figure 6 style PNGs for one test posterior.

    Three files are written to ``output_dir``, which is created if missing:
    the marginal view and the profile view (both via
    ``create_figure6_style_plot``), followed by a marginal-vs-profile
    comparison (via ``create_comparison_marginal_vs_profile``). Each saved
    path is printed.

    Args:
        Wmap: Posterior map passed straight through to the plotting helpers.
        om_grid: Omega_m grid values.
        s8_grid: sigma_8 grid values.
        true_om: True Omega_m, marked on the plots.
        true_s8: True sigma_8, marked on the plots.
        test_index: Test-set index, used in titles and filenames.
        output_dir: Destination directory.
        model_name: Label for titles; lower-cased with spaces -> dashes
            for filenames.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    slug = model_name.replace(" ", "-").lower()

    # (title tag, show_profile flag, filename tag): marginal first, then profile.
    variants = (
        ("Marginal", False, "marginal"),
        ("Profile", True, "profile"),
    )
    for title_tag, profile_flag, file_tag in variants:
        figure = create_figure6_style_plot(
            Wmap,
            om_grid,
            s8_grid,
            true_param1=true_om,
            true_param2=true_s8,
            param1_label=r"$\Omega_m$",
            param2_label=r"$\sigma_8$",
            title=f"{model_name} — Test ix={test_index} ({title_tag})",
            show_profile=profile_flag,
            figsize=(10, 10),
        )
        target = out / f"fig6_style_{slug}_ix{test_index}_{file_tag}.png"
        figure.savefig(target, dpi=200, bbox_inches="tight")
        plt.close(figure)
        print(f"Saved: {target}")

    comparison = create_comparison_marginal_vs_profile(
        Wmap,
        om_grid,
        s8_grid,
        true_param1=true_om,
        true_param2=true_s8,
        title=f"{model_name} marginal vs profile — ix={test_index}",
        figsize=(11, 4.2),
    )
    comparison_target = out / f"fig6_marg_vs_prof_{slug}_ix{test_index}.png"
    comparison.savefig(comparison_target, dpi=185, bbox_inches="tight")
    plt.close(comparison)
    print(f"Saved: {comparison_target}")
def integrate_figure6_with_multi_anchor(
    posteriors_list: list[np.ndarray],
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_values_list: list[tuple[float, float]],
    test_indices: list[int],
    output_dir: Path,
    model_name: str = "DDPM-2",
) -> None:
    """Draw several anchor posteriors as a grid of Figure 6 style panels.

    Panels are laid out three per row: a 2x3 grid for up to six anchors, with
    extra rows added beyond that. Each panel shows:

    - the normalized 2D posterior as filled contours;
    - the 68.3%/95.4% levels from ``compute_sigma_levels`` as line contours;
    - the true point as a red cross;
    - the Omega_m and sigma_8 marginals in top and side strips, with the
      true values as dashed lines.

    Args:
        posteriors_list: One 2D posterior per anchor, indexed
            ``[i_om, i_s8]``. It is plotted against an ``indexing="ij"``
            meshgrid of ``om_grid`` x ``s8_grid``.
        om_grid: Omega_m grid values (x axis).
        s8_grid: sigma_8 grid values (y axis).
        true_values_list: ``(true_om, true_s8)`` for each anchor.
        test_indices: Test-set index for each anchor, used in panel titles.
        output_dir: Directory for the PNG; created if missing.
        model_name: Label for the figure title; lower-cased with
            spaces -> dashes for the filename.

    Raises:
        ValueError: If the three per-anchor sequences differ in length, or if
            a posterior's total mass is not finite and positive.
    """
    n_panels = len(posteriors_list)
    if not n_panels == len(true_values_list) == len(test_indices):
        # zip() would silently drop panels and mislabel the rest.
        raise ValueError(
            "posteriors_list, true_values_list and test_indices must have the "
            f"same length (got {n_panels}, {len(true_values_list)}, "
            f"{len(test_indices)})"
        )

    output_dir = Path(output_dir)
    # Create the directory so savefig does not fail on a fresh path.
    output_dir.mkdir(parents=True, exist_ok=True)
    nm = model_name.replace(" ", "-").lower()
    save_path = output_dir / f"fig6_style_{nm}_all_anchors.png"

    n_cols = 3
    # Keep the 2x3 layout for <= 6 anchors; add rows instead of indexing
    # past the grid when there are more.
    n_rows = max(2, (n_panels + n_cols - 1) // n_cols)

    # Every panel shares the same parameter grid, so build the mesh once.
    P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")

    fig = plt.figure(figsize=(20, 7 * n_rows))
    try:
        gs_o = gridspec.GridSpec(
            n_rows, n_cols, figure=fig, hspace=0.33, wspace=0.32
        )
        for idx, (posterior, (true_om, true_s8), test_ix) in enumerate(
            zip(posteriors_list, true_values_list, test_indices)
        ):
            row, col = divmod(idx, n_cols)

            posterior = np.asarray(posterior)
            total = posterior.sum()
            if not (np.isfinite(total) and total > 0):
                raise ValueError(
                    f"posterior for test ix={test_ix} has non-positive or "
                    f"non-finite total mass ({total})"
                )
            posterior_norm = posterior / total
            sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])

            gs_sub = gridspec.GridSpecFromSubplotSpec(
                2,
                2,
                subplot_spec=gs_o[row, col],
                width_ratios=[4, 1],
                height_ratios=[1, 4],
                hspace=0.06,
                wspace=0.06,
            )
            ax_main = fig.add_subplot(gs_sub[1, 0])
            ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)

            ax_main.contourf(
                P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85
            )

            # matplotlib's contour requires strictly increasing levels. Drop
            # non-finite and duplicate levels and sort the rest, keeping the
            # colour/width originally paired with each level.
            level_styles: dict[float, tuple[str, float]] = {}
            for level, color, width in zip(
                np.ravel(sigma_levels), ("darkblue", "steelblue"), (2.0, 1.5)
            ):
                level = float(level)
                if np.isfinite(level) and level not in level_styles:
                    level_styles[level] = (color, width)
            if level_styles:
                ordered = sorted(level_styles)
                ax_main.contour(
                    P1,
                    P2,
                    posterior_norm,
                    levels=ordered,
                    colors=[level_styles[lv][0] for lv in ordered],
                    linewidths=[level_styles[lv][1] for lv in ordered],
                )

            ax_main.scatter(
                true_om, true_s8, s=100, c="red", marker="x",
                linewidths=2.5, zorder=10,
            )
            ax_main.set_xlim(om_grid[0], om_grid[-1])
            ax_main.set_ylim(s8_grid[0], s8_grid[-1])
            # Label the x axis on the lowest panel of each column, i.e. the
            # one with no panel below it.
            is_column_bottom = idx + n_cols >= n_panels
            ax_main.set_xlabel(
                r"$\Omega_m$" if is_column_bottom else "", fontsize=11
            )
            ax_main.set_ylabel(r"$\sigma_8$" if col == 0 else "", fontsize=11)
            ax_main.set_title(f"Test ix={test_ix}", fontsize=11, pad=5)
            ax_main.grid(True, alpha=0.2)

            # Omega_m marginal: sum over the sigma_8 axis (axis 1).
            marginal_om = posterior_norm.sum(axis=1)
            marginal_om /= marginal_om.sum() + 1e-30
            ax_top.fill_between(
                om_grid,
                0.0,
                marginal_om,
                alpha=0.6,
                color="steelblue",
                edgecolor="steelblue",
            )
            ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2)
            ax_top.set_xlim(om_grid[0], om_grid[-1])
            ax_top.set_ylim(0, marginal_om.max() * 1.1)
            ax_top.tick_params(labelbottom=False, labelsize=9)
            ax_top.set_ylabel("$P(\\Omega_m)$", fontsize=9)
            ax_top.grid(True, alpha=0.2)

            # sigma_8 marginal: sum over the Omega_m axis (axis 0).
            ax_side = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
            marginal_s8 = posterior_norm.sum(axis=0)
            marginal_s8 /= marginal_s8.sum() + 1e-30
            ax_side.fill_betweenx(
                s8_grid, 0.0, marginal_s8, alpha=0.6,
                color="steelblue", edgecolor="steelblue",
            )
            ax_side.axhline(true_s8, color="red", linestyle="--", linewidth=2)
            ax_side.set_ylim(s8_grid[0], s8_grid[-1])
            ax_side.set_xlim(0, marginal_s8.max() * 1.15)
            ax_side.tick_params(labelleft=False)

        anchors_label = (
            "Six Test Anchors" if n_panels == 6 else f"{n_panels} Test Anchors"
        )
        fig.suptitle(
            f"{model_name} — Figure 6 Style: {anchors_label}",
            fontsize=15,
            y=0.995,
            fontweight="bold",
        )
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    print(f"Saved multi-anchor grid: {save_path}")
def integrate_figure6_model_comparison(
    posteriors_dict: dict[str, np.ndarray],
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_om: float,
    true_s8: float,
    test_index: int,
    output_dir: Path,
) -> None:
    """Plot one test case's posterior from several models side by side.

    There is one Figure 6 style panel per model, left to right in dict order.
    Each panel shows:

    - the normalized 2D posterior as filled contours;
    - the 68.3%/95.4% levels from ``compute_sigma_levels`` as line contours;
    - the true point as a red cross;
    - top and side marginal strips, with the true values as dashed lines.

    Figure width is 8 in per model, capped at four models' worth (32 in).

    Args:
        posteriors_dict: Model name -> 2D posterior indexed ``[i_om, i_s8]``.
            It is plotted against an ``indexing="ij"`` meshgrid of
            ``om_grid`` x ``s8_grid``.
        om_grid: Omega_m grid values (x axis).
        s8_grid: sigma_8 grid values (y axis).
        true_om: True Omega_m, marked in every panel.
        true_s8: True sigma_8, marked in every panel.
        test_index: Test-set index, used in the title and filename.
        output_dir: Directory for the PNG; created if missing.

    Raises:
        ValueError: If ``posteriors_dict`` is empty, or if a posterior's total
            mass is not finite and positive.
    """
    if not posteriors_dict:
        raise ValueError("posteriors_dict must contain at least one posterior")

    output_dir = Path(output_dir)
    # Create the directory so savefig does not fail on a fresh path.
    output_dir.mkdir(parents=True, exist_ok=True)
    n_models = len(posteriors_dict)
    save_path = output_dir / f"fig6_style_model_comparison_ix{test_index}.png"

    # Every panel shares the same parameter grid, so build the mesh once.
    P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")

    fig = plt.figure(figsize=(8 * min(n_models, 4), 8))
    try:
        gs_outer = gridspec.GridSpec(1, n_models, figure=fig, wspace=0.32)
        for idx, (model_name, posterior) in enumerate(posteriors_dict.items()):
            gs_sub = gridspec.GridSpecFromSubplotSpec(
                2,
                2,
                subplot_spec=gs_outer[0, idx],
                width_ratios=[4, 1],
                height_ratios=[1, 4],
                hspace=0.06,
                wspace=0.06,
            )

            posterior = np.asarray(posterior)
            total = posterior.sum()
            if not (np.isfinite(total) and total > 0):
                raise ValueError(
                    f"posterior for model {model_name!r} has non-positive or "
                    f"non-finite total mass ({total})"
                )
            posterior_norm = posterior / total
            sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])

            ax_main = fig.add_subplot(gs_sub[1, 0])
            ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)
            ax_main.contourf(
                P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85
            )

            # matplotlib's contour requires strictly increasing levels. Drop
            # non-finite and duplicate levels and sort the rest, keeping the
            # colour/width originally paired with each level.
            level_styles: dict[float, tuple[str, float]] = {}
            for level, color, width in zip(
                np.ravel(sigma_levels), ("darkblue", "steelblue"), (2.5, 2.0)
            ):
                level = float(level)
                if np.isfinite(level) and level not in level_styles:
                    level_styles[level] = (color, width)
            if level_styles:
                ordered = sorted(level_styles)
                ax_main.contour(
                    P1,
                    P2,
                    posterior_norm,
                    levels=ordered,
                    colors=[level_styles[lv][0] for lv in ordered],
                    linewidths=[level_styles[lv][1] for lv in ordered],
                )

            ax_main.scatter(
                true_om, true_s8, s=120, c="red", marker="x",
                linewidths=3, zorder=10,
            )
            ax_main.set_xlabel(r"$\Omega_m$", fontsize=13)
            ax_main.set_ylabel(r"$\sigma_8$", fontsize=13)
            ax_main.set_title(model_name, fontsize=13, pad=10, fontweight="bold")
            ax_main.grid(True, alpha=0.3)
            ax_main.set_xlim(om_grid[0], om_grid[-1])
            ax_main.set_ylim(s8_grid[0], s8_grid[-1])

            # Omega_m marginal: sum over the sigma_8 axis (axis 1).
            marginal = posterior_norm.sum(axis=1)
            marginal /= marginal.sum() + 1e-30
            ax_top.fill_between(om_grid, 0.0, marginal, alpha=0.6, color="steelblue")
            ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2.5)
            ax_top.set_xlim(om_grid[0], om_grid[-1])
            ax_top.set_ylim(0, marginal.max() * 1.12)
            ax_top.tick_params(labelbottom=False)
            ax_top.grid(True, alpha=0.25)

            # sigma_8 marginal: sum over the Omega_m axis (axis 0).
            ax_sb = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
            marginal_s = posterior_norm.sum(axis=0)
            marginal_s /= marginal_s.sum() + 1e-30
            ax_sb.fill_betweenx(s8_grid, 0.0, marginal_s, alpha=0.6, color="steelblue")
            ax_sb.axhline(true_s8, color="red", linestyle="--", linewidth=2.5)
            ax_sb.set_ylim(s8_grid[0], s8_grid[-1])
            ax_sb.set_xlim(0, marginal_s.max() * 1.12)
            # The shared y tick labels would overlap the main panel at this
            # narrow wspace; the main panel already shows them.
            ax_sb.tick_params(labelleft=False)

        fig.suptitle(
            f"Model Comparison (Figure 6 Style) — Test ix={test_index}",
            fontsize=15,
            y=0.995,
            fontweight="bold",
        )
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    print(f"Saved model comparison: {save_path}")
def print_integration_guide() -> None:
    """Print a copy-paste import snippet for wiring these helpers into posterior code."""
    snippet_lines = (
        "# Add imports next to posterior code:",
        "from figure6_2409_style import create_figure6_style_plot",
        "from ddpm_figure6_integration import (",
        "integrate_figure6_with_ddpm2,",
        "integrate_figure6_with_multi_anchor,",
        "integrate_figure6_model_comparison,",
        ")",
    )
    print("\n".join(snippet_lines))