Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified | """ | |
| Figure 6 style (arXiv:2409.09101) helpers for DDPM surrogate posteriors — use with ddpm_posterior_six_anchors / run_ddpm_figure6_suite. | |
| """ | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib import gridspec | |
| from figure6_2409_style import ( | |
| create_comparison_marginal_vs_profile, | |
| create_figure6_style_plot, | |
| ) | |
| from sigma_contour_utils import compute_sigma_levels | |
def _save_fig6_png_and_close(fig, path: Path, dpi: int) -> None:
    """Save ``fig`` to ``path`` as a tight-bbox PNG and close it, even if saving fails.

    The ``finally`` matters in batch runs: a failed ``savefig`` would otherwise
    leave the figure registered with pyplot, and repeated failures accumulate
    open figures.
    """
    try:
        fig.savefig(path, dpi=dpi, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved: {path}")


def integrate_figure6_with_ddpm2(
    Wmap: np.ndarray,
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_om: float,
    true_s8: float,
    test_index: int,
    output_dir: Path,
    model_name: str = "DDPM-2",
) -> None:
    """Single-map Figure 6 style: marginal, profile and marginal-vs-profile PNGs.

    Writes three files into ``output_dir`` (created if missing):

    * ``fig6_style_<nm>_ix<test_index>_marginal.png`` -- ``create_figure6_style_plot``
      with ``show_profile=False``;
    * ``fig6_style_<nm>_ix<test_index>_profile.png`` -- same with ``show_profile=True``;
    * ``fig6_marg_vs_prof_<nm>_ix<test_index>.png`` --
      ``create_comparison_marginal_vs_profile``.

    ``<nm>`` is ``model_name`` lower-cased with spaces replaced by ``-``.

    Args:
        Wmap: 2D posterior map handed unchanged to the plotting helpers
            (presumably laid out on the (om_grid, s8_grid) mesh -- confirm
            against ``create_figure6_style_plot``).
        om_grid: Omega_m grid values.
        s8_grid: sigma_8 grid values.
        true_om: True Omega_m, marked on the plots.
        true_s8: True sigma_8, marked on the plots.
        test_index: Test-map index used in titles and filenames.
        output_dir: Destination directory.
        model_name: Label used in titles and (normalised) in filenames.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    nm = model_name.replace(" ", "-").lower()

    # The marginal and profile plots differ only in the title suffix, the
    # show_profile flag and the filename suffix.
    for label, show_profile in (("Marginal", False), ("Profile", True)):
        fig = create_figure6_style_plot(
            Wmap,
            om_grid,
            s8_grid,
            true_param1=true_om,
            true_param2=true_s8,
            param1_label=r"$\Omega_m$",
            param2_label=r"$\sigma_8$",
            title=f"{model_name} — Test ix={test_index} ({label})",
            show_profile=show_profile,
            figsize=(10, 10),
        )
        save_path = output_dir / f"fig6_style_{nm}_ix{test_index}_{label.lower()}.png"
        _save_fig6_png_and_close(fig, save_path, dpi=200)

    fig_cmp = create_comparison_marginal_vs_profile(
        Wmap,
        om_grid,
        s8_grid,
        true_param1=true_om,
        true_param2=true_s8,
        title=f"{model_name} marginal vs profile — ix={test_index}",
        figsize=(11, 4.2),
    )
    cmp_path = output_dir / f"fig6_marg_vs_prof_{nm}_ix{test_index}.png"
    _save_fig6_png_and_close(fig_cmp, cmp_path, dpi=185)
def _increasing_contour_levels(levels) -> np.ndarray:
    """Return ``levels`` as finite, de-duplicated, strictly increasing floats.

    ``Axes.contour`` raises ``ValueError`` ("Contour levels must be increasing")
    for repeated or descending explicit levels. An already strictly increasing
    finite sequence is returned unchanged.
    """
    arr = np.unique(np.asarray(levels, dtype=float).ravel())  # sorted ascending
    return arr[np.isfinite(arr)]


def integrate_figure6_with_multi_anchor(
    posteriors_list: list[np.ndarray],
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_values_list: list[tuple[float, float]],
    test_indices: list[int],
    output_dir: Path,
    model_name: str = "DDPM-2",
) -> None:
    """Save a grid of Figure 6-ish panels (2D posterior + top/side marginals), 3 per row.

    Each panel shows the normalised posterior as filled contours, the contour
    levels from ``compute_sigma_levels(..., [0.683, 0.954])``, the true
    parameters as a red cross/dashed lines, and the Omega_m (top) and sigma_8
    (side) marginals. The grid has at least 2 rows (the original 2x3 layout)
    and grows by one row per extra 3 anchors. Panels are paired with
    ``zip``, so extra items in the longer lists are ignored.

    Args:
        posteriors_list: Posteriors on the (om_grid, s8_grid) mesh; axis 0 is
            Omega_m and axis 1 is sigma_8 (shape (len(om_grid), len(s8_grid))).
            Any overall scale; each is normalised to unit sum.
        om_grid: Omega_m grid values.
        s8_grid: sigma_8 grid values.
        true_values_list: ``(true_om, true_s8)`` per posterior.
        test_indices: Test-map index per posterior, used in panel titles.
        output_dir: Destination directory; created if missing.
        model_name: Used in the suptitle and (normalised) in the filename.
    """
    output_dir = Path(output_dir)
    # Match integrate_figure6_with_ddpm2: savefig fails on a missing directory.
    output_dir.mkdir(parents=True, exist_ok=True)
    nm = model_name.replace(" ", "-").lower()

    panels = list(zip(posteriors_list, true_values_list, test_indices))
    n_cols = 3
    # Keep the historical 2x3 layout for <=6 anchors; add rows beyond that
    # instead of indexing past the GridSpec.
    n_rows = max(2, -(-len(panels) // n_cols))
    # Shared mesh; "ij" indexing matches posterior[om_idx, s8_idx].
    P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")

    fig = plt.figure(figsize=(20, 7 * n_rows))
    try:
        gs_o = gridspec.GridSpec(n_rows, n_cols, figure=fig, hspace=0.33, wspace=0.32)
        for idx, (posterior, (true_om, true_s8), test_ix) in enumerate(panels):
            row, col = divmod(idx, n_cols)
            posterior_norm = posterior / posterior.sum()
            sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])
            contour_levels = _increasing_contour_levels(sigma_levels)

            gs_sub = gridspec.GridSpecFromSubplotSpec(
                2,
                2,
                subplot_spec=gs_o[row, col],
                width_ratios=[4, 1],
                height_ratios=[1, 4],
                hspace=0.06,
                wspace=0.06,
            )
            ax_main = fig.add_subplot(gs_sub[1, 0])
            ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)

            ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85)
            if contour_levels.size:
                ax_main.contour(
                    P1,
                    P2,
                    posterior_norm,
                    levels=contour_levels,
                    colors=["darkblue", "steelblue"],
                    linewidths=[2.0, 1.5],
                )
            ax_main.scatter(true_om, true_s8, s=100, c="red", marker="x", linewidths=2.5, zorder=10)
            ax_main.set_xlim(om_grid[0], om_grid[-1])
            ax_main.set_ylim(s8_grid[0], s8_grid[-1])
            # Axis labels only on the bottom row / left column.
            ax_main.set_xlabel(r"$\Omega_m$" if row == n_rows - 1 else "", fontsize=11)
            ax_main.set_ylabel(r"$\sigma_8$" if col == 0 else "", fontsize=11)
            ax_main.set_title(f"Test ix={test_ix}", fontsize=11, pad=5)
            ax_main.grid(True, alpha=0.2)

            marginal_om = posterior_norm.sum(axis=1)
            marginal_om /= marginal_om.sum() + 1e-30  # epsilon avoids 0/0
            ax_top.fill_between(
                om_grid,
                0.0,
                marginal_om,
                alpha=0.6,
                color="steelblue",
                edgecolor="steelblue",
            )
            ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2)
            ax_top.set_xlim(om_grid[0], om_grid[-1])
            ax_top.set_ylim(0, marginal_om.max() * 1.1)
            ax_top.tick_params(labelbottom=False, labelsize=9)
            ax_top.set_ylabel("$P(\\Omega_m)$", fontsize=9)
            ax_top.grid(True, alpha=0.2)

            ax_side = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
            marginal_s8 = posterior_norm.sum(axis=0)
            marginal_s8 /= marginal_s8.sum() + 1e-30
            ax_side.fill_betweenx(s8_grid, 0.0, marginal_s8, alpha=0.6, color="steelblue", edgecolor="steelblue")
            ax_side.axhline(true_s8, color="red", linestyle="--", linewidth=2)
            ax_side.set_ylim(s8_grid[0], s8_grid[-1])
            ax_side.set_xlim(0, marginal_s8.max() * 1.15)
            ax_side.tick_params(labelleft=False)

        fig.suptitle(
            f"{model_name} — Figure 6 Style: Six Test Anchors",
            fontsize=15,
            y=0.995,
            fontweight="bold",
        )
        save_path = output_dir / f"fig6_style_{nm}_all_anchors.png"
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    print(f"Saved multi-anchor grid: {save_path}")
def integrate_figure6_model_comparison(
    posteriors_dict: dict[str, np.ndarray],
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_om: float,
    true_s8: float,
    test_index: int,
    output_dir: Path,
) -> None:
    """Save side-by-side Figure 6-style panels comparing several models on one test map.

    Each model gets one panel. The panel shows the normalised posterior as
    filled contours and the contour levels from
    ``compute_sigma_levels(..., [0.683, 0.954])``. It also has a top Omega_m
    marginal and a right-hand sigma_8 marginal, with the true values marked in
    red. Figure width grows with the model count, capped at four panels' worth.

    Args:
        posteriors_dict: Model name -> posterior on the (om_grid, s8_grid)
            mesh; axis 0 is Omega_m and axis 1 is sigma_8
            (shape (len(om_grid), len(s8_grid))). Any overall scale; each is
            normalised to unit sum. Must be non-empty.
        om_grid: Omega_m grid values.
        s8_grid: sigma_8 grid values.
        true_om: True Omega_m, marked on every panel.
        true_s8: True sigma_8, marked on every panel.
        test_index: Test-map index used in the suptitle and filename.
        output_dir: Destination directory; created if missing.

    Raises:
        ValueError: If ``posteriors_dict`` is empty.
    """
    n_models = len(posteriors_dict)
    if n_models == 0:
        # Fail clearly before touching the filesystem or creating a figure.
        raise ValueError("posteriors_dict must contain at least one model posterior")

    output_dir = Path(output_dir)
    # Match integrate_figure6_with_ddpm2: savefig fails on a missing directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Shared mesh; "ij" indexing matches posterior[om_idx, s8_idx].
    P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")

    fig = plt.figure(figsize=(8 * min(n_models, 4), 8))
    try:
        gs_outer = gridspec.GridSpec(1, n_models, figure=fig, wspace=0.32)
        for idx, (model_name, posterior) in enumerate(posteriors_dict.items()):
            gs_sub = gridspec.GridSpecFromSubplotSpec(
                2,
                2,
                subplot_spec=gs_outer[0, idx],
                width_ratios=[4, 1],
                height_ratios=[1, 4],
                hspace=0.06,
                wspace=0.06,
            )
            posterior_norm = posterior / posterior.sum()
            sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])
            # contour() rejects repeated/non-increasing explicit levels; sort,
            # de-duplicate and drop non-finite values (a valid increasing
            # sequence passes through unchanged).
            contour_levels = np.unique(np.asarray(sigma_levels, dtype=float).ravel())
            contour_levels = contour_levels[np.isfinite(contour_levels)]

            ax_main = fig.add_subplot(gs_sub[1, 0])
            ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)

            ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85)
            if contour_levels.size:
                ax_main.contour(
                    P1,
                    P2,
                    posterior_norm,
                    levels=contour_levels,
                    colors=["darkblue", "steelblue"],
                    linewidths=[2.5, 2.0],
                )
            ax_main.scatter(true_om, true_s8, s=120, c="red", marker="x", linewidths=3, zorder=10)
            ax_main.set_xlabel(r"$\Omega_m$", fontsize=13)
            ax_main.set_ylabel(r"$\sigma_8$", fontsize=13)
            ax_main.set_title(model_name, fontsize=13, pad=10, fontweight="bold")
            ax_main.grid(True, alpha=0.3)
            ax_main.set_xlim(om_grid[0], om_grid[-1])
            ax_main.set_ylim(s8_grid[0], s8_grid[-1])

            marginal = posterior_norm.sum(axis=1)
            marginal /= marginal.sum() + 1e-30  # epsilon avoids 0/0
            ax_top.fill_between(om_grid, 0.0, marginal, alpha=0.6, color="steelblue")
            ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2.5)
            ax_top.set_xlim(om_grid[0], om_grid[-1])
            ax_top.set_ylim(0, marginal.max() * 1.12)
            ax_top.tick_params(labelbottom=False)
            ax_top.grid(True, alpha=0.25)

            ax_sb = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
            marginal_s = posterior_norm.sum(axis=0)
            marginal_s /= marginal_s.sum() + 1e-30
            ax_sb.fill_betweenx(s8_grid, 0.0, marginal_s, alpha=0.6, color="steelblue")
            ax_sb.axhline(true_s8, color="red", linestyle="--", linewidth=2.5)
            ax_sb.set_ylim(s8_grid[0], s8_grid[-1])
            ax_sb.set_xlim(0, marginal_s.max() * 1.15)
            # The shared-y tick labels would otherwise be drawn in the narrow
            # gap and overlap the main panel (the multi-anchor grid hides them too).
            ax_sb.tick_params(labelleft=False)

        fig.suptitle(
            f"Model Comparison (Figure 6 Style) — Test ix={test_index}",
            fontsize=15,
            y=0.995,
            fontweight="bold",
        )
        save_path = output_dir / f"fig6_style_model_comparison_ix{test_index}.png"
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    print(f"Saved model comparison: {save_path}")
def print_integration_guide() -> None:
    """Print a copy-paste snippet with the imports needed to use these Figure 6 helpers."""
    snippet_lines = (
        "# Add imports next to posterior code:",
        "from figure6_2409_style import create_figure6_style_plot",
        "from ddpm_figure6_integration import (",
        "    integrate_figure6_with_ddpm2,",
        "    integrate_figure6_with_multi_anchor,",
        "    integrate_figure6_model_comparison,",
        ")",
    )
    print("\n".join(snippet_lines))