""" Figure 6 style (arXiv:2409.09101) helpers for DDPM surrogate posteriors — use with ddpm_posterior_six_anchors / run_ddpm_figure6_suite. """ from __future__ import annotations from pathlib import Path import matplotlib.pyplot as plt import numpy as np from matplotlib import gridspec from figure6_2409_style import ( create_comparison_marginal_vs_profile, create_figure6_style_plot, ) from sigma_contour_utils import compute_sigma_levels def integrate_figure6_with_ddpm2( Wmap: np.ndarray, om_grid: np.ndarray, s8_grid: np.ndarray, true_om: float, true_s8: float, test_index: int, output_dir: Path, model_name: str = "DDPM-2", ) -> None: """Single-map Figure 6 style: marginal and profile PNGs.""" output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) nm = model_name.replace(" ", "-").lower() fig_marginal = create_figure6_style_plot( Wmap, om_grid, s8_grid, true_param1=true_om, true_param2=true_s8, param1_label=r"$\Omega_m$", param2_label=r"$\sigma_8$", title=f"{model_name} — Test ix={test_index} (Marginal)", show_profile=False, figsize=(10, 10), ) save_path_marginal = output_dir / f"fig6_style_{nm}_ix{test_index}_marginal.png" fig_marginal.savefig(save_path_marginal, dpi=200, bbox_inches="tight") plt.close(fig_marginal) print(f"Saved: {save_path_marginal}") fig_profile = create_figure6_style_plot( Wmap, om_grid, s8_grid, true_param1=true_om, true_param2=true_s8, param1_label=r"$\Omega_m$", param2_label=r"$\sigma_8$", title=f"{model_name} — Test ix={test_index} (Profile)", show_profile=True, figsize=(10, 10), ) save_path_profile = output_dir / f"fig6_style_{nm}_ix{test_index}_profile.png" fig_profile.savefig(save_path_profile, dpi=200, bbox_inches="tight") plt.close(fig_profile) print(f"Saved: {save_path_profile}") fig_cmp = create_comparison_marginal_vs_profile( Wmap, om_grid, s8_grid, true_param1=true_om, true_param2=true_s8, title=f"{model_name} marginal vs profile — ix={test_index}", figsize=(11, 4.2), ) cmp_path = output_dir / f"fig6_marg_vs_prof_{nm}_ix{test_index}.png" fig_cmp.savefig(cmp_path, dpi=185, bbox_inches="tight") plt.close(fig_cmp) print(f"Saved: {cmp_path}") def integrate_figure6_with_multi_anchor( posteriors_list: list[np.ndarray], om_grid: np.ndarray, s8_grid: np.ndarray, true_values_list: list[tuple[float, float]], test_indices: list[int], output_dir: Path, model_name: str = "DDPM-2", ) -> None: """2×3 grid with Figure–6-ish 2D + top marginal.""" output_dir = Path(output_dir) nm = model_name.replace(" ", "-").lower() fig = plt.figure(figsize=(20, 14)) gs_o = gridspec.GridSpec(2, 3, figure=fig, hspace=0.33, wspace=0.32) for idx, (posterior, true_vals, test_ix) in enumerate( zip(posteriors_list, true_values_list, test_indices) ): true_om, true_s8 = true_vals row, col = divmod(idx, 3) posterior_norm = posterior / posterior.sum() sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954]) P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij") gs_sub = gridspec.GridSpecFromSubplotSpec( 2, 2, subplot_spec=gs_o[row, col], width_ratios=[4, 1], height_ratios=[1, 4], hspace=0.06, wspace=0.06, ) ax_main = fig.add_subplot(gs_sub[1, 0]) ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main) ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85) if len(set(sigma_levels)) >= 1: ax_main.contour( P1, P2, posterior_norm, levels=sigma_levels, colors=["darkblue", "steelblue"], linewidths=[2.0, 1.5], ) ax_main.scatter(true_om, true_s8, s=100, c="red", marker="x", linewidths=2.5, zorder=10) ax_main.set_xlim(om_grid[0], om_grid[-1]) ax_main.set_ylim(s8_grid[0], s8_grid[-1]) ax_main.set_xlabel(r"$\Omega_m$" if row == 1 else "", fontsize=11) ax_main.set_ylabel(r"$\sigma_8$" if col == 0 else "", fontsize=11) ax_main.set_title(f"Test ix={test_ix}", fontsize=11, pad=5) ax_main.grid(True, alpha=0.2) marginal_om = posterior_norm.sum(axis=1) marginal_om /= marginal_om.sum() + 1e-30 ax_top.fill_between( om_grid, 0.0, marginal_om, alpha=0.6, color="steelblue", edgecolor="steelblue", ) ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2) ax_top.set_xlim(om_grid[0], om_grid[-1]) ax_top.set_ylim(0, marginal_om.max() * 1.1) ax_top.tick_params(labelbottom=False, labelsize=9) ax_top.set_ylabel("$P(\\Omega_m)$", fontsize=9) ax_top.grid(True, alpha=0.2) ax_side = fig.add_subplot(gs_sub[1, 1], sharey=ax_main) marginal_s8 = posterior_norm.sum(axis=0) marginal_s8 /= marginal_s8.sum() + 1e-30 ax_side.fill_betweenx(s8_grid, 0.0, marginal_s8, alpha=0.6, color="steelblue", edgecolor="steelblue") ax_side.axhline(true_s8, color="red", linestyle="--", linewidth=2) ax_side.set_ylim(s8_grid[0], s8_grid[-1]) ax_side.set_xlim(0, marginal_s8.max() * 1.15) ax_side.tick_params(labelleft=False) fig.suptitle( f"{model_name} — Figure 6 Style: Six Test Anchors", fontsize=15, y=0.995, fontweight="bold", ) save_path = output_dir / f"fig6_style_{nm}_all_anchors.png" fig.savefig(save_path, dpi=200, bbox_inches="tight") plt.close(fig) print(f"Saved multi-anchor grid: {save_path}") def integrate_figure6_model_comparison( posteriors_dict: dict[str, np.ndarray], om_grid: np.ndarray, s8_grid: np.ndarray, true_om: float, true_s8: float, test_index: int, output_dir: Path, ) -> None: """Side-by-side model comparison panels.""" output_dir = Path(output_dir) n_models = len(posteriors_dict) fig = plt.figure(figsize=(8 * max(1, min(n_models, 4)), 8)) gs_outer = gridspec.GridSpec(1, n_models, figure=fig, wspace=0.32) for idx, (model_name, posterior) in enumerate(posteriors_dict.items()): gs_sub = gridspec.GridSpecFromSubplotSpec( 2, 2, subplot_spec=gs_outer[0, idx], width_ratios=[4, 1], height_ratios=[1, 4], hspace=0.06, wspace=0.06, ) posterior_norm = posterior / posterior.sum() sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954]) P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij") ax_main = fig.add_subplot(gs_sub[1, 0]) ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main) ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85) if len(set(sigma_levels)) >= 1: ax_main.contour( P1, P2, posterior_norm, levels=sigma_levels, colors=["darkblue", "steelblue"], linewidths=[2.5, 2.0], ) ax_main.scatter(true_om, true_s8, s=120, c="red", marker="x", linewidths=3, zorder=10) ax_main.set_xlabel(r"$\Omega_m$", fontsize=13) ax_main.set_ylabel(r"$\sigma_8$", fontsize=13) ax_main.set_title(model_name, fontsize=13, pad=10, fontweight="bold") ax_main.grid(True, alpha=0.3) ax_main.set_xlim(om_grid[0], om_grid[-1]) ax_main.set_ylim(s8_grid[0], s8_grid[-1]) marginal = posterior_norm.sum(axis=1) marginal /= marginal.sum() + 1e-30 ax_top.fill_between(om_grid, 0.0, marginal, alpha=0.6, color="steelblue") ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2.5) ax_top.set_xlim(om_grid[0], om_grid[-1]) ax_top.set_ylim(0, marginal.max() * 1.12) ax_top.tick_params(labelbottom=False) ax_top.grid(True, alpha=0.25) ax_sb = fig.add_subplot(gs_sub[1, 1], sharey=ax_main) marginal_s = posterior_norm.sum(axis=0) marginal_s /= marginal_s.sum() + 1e-30 ax_sb.fill_betweenx(s8_grid, 0.0, marginal_s, alpha=0.6, color="steelblue") ax_sb.axhline(true_s8, color="red", linestyle="--", linewidth=2.5) ax_sb.set_ylim(s8_grid[0], s8_grid[-1]) fig.suptitle( f"Model Comparison (Figure 6 Style) — Test ix={test_index}", fontsize=15, y=0.995, fontweight="bold", ) save_path = output_dir / f"fig6_style_model_comparison_ix{test_index}.png" fig.savefig(save_path, dpi=200, bbox_inches="tight") plt.close(fig) print(f"Saved model comparison: {save_path}") def print_integration_guide() -> None: example_integration = """ # Add imports next to posterior code: from figure6_2409_style import create_figure6_style_plot from ddpm_figure6_integration import ( integrate_figure6_with_ddpm2, integrate_figure6_with_multi_anchor, integrate_figure6_model_comparison, ) """ print(example_integration.strip())