Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
f513198 verified | #!/usr/bin/env python3 | |
| """ | |
| Compare 2-parameter and 6-parameter conditional DDPMs (CAMELS LH) side-by-side: | |
| • Random-draw vs test-conditioned triplets (CAMELS | DDPM-2 | DDPM-6) | |
| • Six anchor cosmologies: P(k) and PDF diagnostics (triple curves per panel where applicable) | |
| • LHS R² cosmology plots (LHS-50 × 15 maps — expensive) | |
| • MLP P(k) → label recovery (sklearn MLP, two models + shared CAMELS calibration) | |
| • Surrogate posterior on (Ωm, σ8) for a fixed test index | |
| • Training / validation loss on one axis (Slurm .out for DDPM-6; DDPM-2 defaults to bundled JSON) | |
| Outputs under --output-dir (default: Models/ddpm_comparison_out/). | |
| GPU: both models stay resident on the device for the whole run; use a single GPU with | |
| enough memory for both, or refactor the heavier steps so they can be run separately. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gc | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Sequence, Tuple | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
# --- Repo paths ---
# MODELS_ROOT is the directory one level above this script's own directory.
MODELS_ROOT = Path(__file__).resolve().parents[1]
CODE_6 = MODELS_ROOT.joinpath("6param_ddpm_hi_lh6").resolve()
# Prepend (rather than append) so modules in CODE_6 take precedence on import.
if sys.path.count(str(CODE_6)) == 0:
    sys.path.insert(0, str(CODE_6))
| import evaluate_conditional as ec # noqa: E402 | |
| import eval_model as em # noqa: E402 | |
| from figure9_posterior import build_cosmo_grid, log_pk_observed # noqa: E402 | |
| from plot_r2_cosmology_lhs import compute_lhs_r2, plot_r2_cosmology_figure # noqa: E402 | |
| from compare_ddpm_training_curves import ( # noqa: E402 | |
| load_train_val_series, | |
| parse_slurm_training_log, | |
| ) | |
# Default Slurm stdout file of the 6-parameter run (used for its loss curves).
DEFAULT_SLURM_6 = Path("<DDPM_ROOT>/april_26/ddpm_hi_lh6/scripts/shell/slurm-698243.out")
# Bundled train/val (no 2-param Slurm log in-repo); see ``ddpm_2param_training_loss.json``.
DEFAULT_DDPM2_TRAINING = Path(__file__).resolve().with_name("ddpm_2param_training_loss.json")
| def _fmt_title(lab: np.ndarray) -> str: | |
| t = np.asarray(lab, dtype=float).ravel() | |
| if t.size <= 2: | |
| return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f}" | |
| tail = ", ".join(f"{float(v):.3g}" for v in t[2:]) | |
| return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f} | " + tail | |
| def _latin_hypercube(n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator) -> np.ndarray: | |
| """Classic LHS (same as notebook).""" | |
| d = int(lo.shape[0]) | |
| u = rng.random((n, d)) | |
| cut = np.linspace(0.0, 1.0, n + 1) | |
| a, b = cut[:-1], cut[1:] | |
| width = (b - a)[:, np.newaxis] | |
| rd = a[:, np.newaxis] + u * width | |
| for j in range(d): | |
| rng.shuffle(rd[:, j]) | |
| span = (hi - lo).astype(np.float64) | |
| return (lo + rd * span).astype(np.float32) | |
def generate_maps(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
) -> np.ndarray:
    """Sample one single-channel map per label row with DDIM, in mini-batches.

    Args:
        model: conditional diffusion model exposing ``sample(...)``.
        labels_np: label rows, one per map to generate (first axis = samples).
        label_mean, label_std: statistics handed to
            ``ec.prepare_labels_for_model`` (presumably for label
            normalisation — confirm against that helper).
        H, W: height and width of the maps to generate.
        device: device the label tensor is moved to and sampling runs on.
        ddim_steps: number of DDIM steps passed to ``model.sample``.
        batch_size: maximum number of label rows per ``model.sample`` call.

    Returns:
        The per-batch outputs of ``ec.from_model_output`` concatenated along
        axis 0, in the same order as ``labels_np``.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``labels_np`` has no rows.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n = labels_np.shape[0]
    if n == 0:
        raise ValueError("labels_np has no rows; nothing to generate")

    out = []
    # Sampling never needs gradients; no_grad avoids recording autograd state
    # in case model.sample() does not already disable it itself.
    with torch.no_grad():
        for j0 in range(0, n, batch_size):
            chunk = labels_np[j0 : j0 + batch_size]
            bt = ec.prepare_labels_for_model(chunk.astype(np.float32), label_mean, label_std).to(device)
            g = model.sample(
                labels=bt,
                channels=1,
                height=H,
                width=W,
                device=device,
                progress=False,
                use_ddim=True,
                ddim_steps=ddim_steps,
            )
            out.append(ec.from_model_output(g))
    return np.concatenate(out, axis=0)
def load_model(bundle_args: Path, ckpt: Path, device: torch.device):
    """Build a model from a saved training config and load checkpoint weights.

    ``bundle_args`` is the config file read by ``ec.load_training_config`` and
    ``ckpt`` the checkpoint read by ``ec.load_checkpoint``. Returns
    ``(model, cfg)`` with the model switched to eval mode.
    """
    config = ec.load_training_config(str(bundle_args))
    net = ec.build_model(config, device)
    ec.load_checkpoint(net, str(ckpt), device)
    net.eval()
    return net, config
def free_torch():
    """Run the garbage collector, then empty the CUDA cache if CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def plot_training_overlay(
    out_dir: Path,
    slurm6: Path | None,
    slurm2: Path | None,
) -> None:
    """Train + val curves for DDPM-6 and optionally DDPM-2 on one log-loss axis.

    Args:
        out_dir: directory the PNG is written to (must already exist).
        slurm6: Slurm ``.out`` file of the DDPM-6 run, parsed with
            ``parse_slurm_training_log``; skipped with a warning if missing.
        slurm2: DDPM-2 training series, read with ``load_train_val_series``;
            skipped with a warning if given but missing.

    When neither source could be read, a placeholder note is drawn instead of
    curves. The figure is always closed, even if parsing or saving raises.
    """
    fig, ax = plt.subplots(figsize=(9, 5))

    def _draw(series, tag: str, train_color: str, val_color: str) -> None:
        # One (epoch, train, val) triple -> solid train line + dashed val line.
        ep, tr, va = series
        ax.plot(ep, tr, lw=1.4, ls="-", label=f"{tag} train", color=train_color, alpha=0.85)
        ax.plot(ep, va, lw=1.8, ls="--", label=f"{tag} val", color=val_color, alpha=0.95)

    try:
        plotted = False
        if slurm6 and Path(slurm6).is_file():
            _draw(parse_slurm_training_log(slurm6), "DDPM-6", "#1f77b4", "#174a75")
            plotted = True
        else:
            print("Warning: 6-param Slurm log not found; skipped overlay for DDPM-6.")

        if slurm2 and Path(slurm2).is_file():
            _draw(load_train_val_series(slurm2), "DDPM-2", "#ff7f0e", "#994d00")
            plotted = True
        elif slurm2 is not None:
            print(f"Warning: 2-param training series not found ({slurm2}); use --slurm-2param or restore bundled JSON.")

        if plotted:
            ax.set_yscale("log")
            ax.grid(True, alpha=0.3)
            # Only request a legend when labelled curves exist; an empty
            # legend makes matplotlib emit a "no artists with labels" warning.
            ax.legend(loc="upper right", fontsize=8)
        else:
            print("No Slurm logs parsed — writing placeholder note instead of curves.")
            ax.text(
                0.5,
                0.5,
                "Pass --slurm-6param; DDPM-2 uses bundled JSON by default (--slurm-2param).",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )

        ax.set_xlabel("Epoch")
        ax.set_ylabel("MSE diffusion loss")
        ax.set_title("Training / validation curves (combined)")

        outp = out_dir / "comparison_training_train_val_overlay.png"
        fig.savefig(outp, dpi=170, bbox_inches="tight")
        print("Saved", outp)
    finally:
        # Release the figure even when a parser or savefig raised.
        plt.close(fig)
def run_random_theta_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Random LHS targets in CAMELS bbox; CAMELS column = NN real map.

    Draws Latin-hypercube label vectors inside the per-dimension min/max of
    ``lab6`` and, for each, shows a row of three panels: the real map whose
    label is nearest to the target, a DDPM-2 sample conditioned on the first
    two label entries, and a DDPM-6 sample conditioned on the full vector.

    At most 32 triplets are drawn. The panel count is clamped to that cap, so
    requesting more no longer indexes past the end of the LHS draw. Nothing
    is written when the resulting count is zero.
    """
    n_show = max(0, min(int(n_pairs), 32))
    if n_show == 0:
        print("No random LHS triplets requested; skipping.")
        return

    rng = np.random.default_rng(seed)
    lo, hi = lab6.min(0), lab6.max(0)
    targets = _latin_hypercube(n_show, lo, hi, rng)
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    tg2 = targets[:, :2].astype(np.float32)
    titles = ("CAMELS (NN)", "DDPM-2", "DDPM-6")

    fig = plt.figure(figsize=(3.8 * max(3, n_show * 3), 4.1))
    try:
        for i in range(n_show):
            theta6 = targets[i].astype(np.float32)
            theta2 = tg2[i]
            # NOTE(review): the distance uses raw (unnormalised) labels, so any
            # parameter with a larger numeric range dominates the match —
            # confirm this is intended.
            dist = np.linalg.norm(lab6 - theta6[None, :], axis=1).astype(np.float64)
            nn_img = imgs6[int(np.argmin(dist))]
            gen2 = generate_maps(
                model2, theta2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size
            )
            gen6 = generate_maps(
                model6, theta6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size
            )
            for j, img in enumerate((nn_img, gen2[0], gen6[0])):
                ax = fig.add_subplot(1, n_show * 3, i * 3 + j + 1)
                ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
                ax.axis("off")
                ax.set_title(titles[j], fontsize=8)

        fig.suptitle(
            "Random LHS cosmologies — CAMELS = nearest-neighbour truth | gens conditioned on LHS labels",
            fontsize=10,
            y=1.02,
        )
        p = out_dir / "comparison_random_lhs_triplets_camels_ddpm2_ddpm6.png"
        fig.tight_layout()
        fig.savefig(p, dpi=160, bbox_inches="tight")
        print("Saved", p)
    finally:
        # Close the figure even if sampling or saving failed part-way.
        plt.close(fig)
def run_conditioned_test_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Same rows from test split: conditioned on truth labels.

    Picks up to ``n_pairs`` distinct rows at random (seeded with ``seed + 1``)
    and draws, per row, the real map next to a DDPM-2 sample (first two label
    entries) and a DDPM-6 sample (full label). The label of each row is
    written under its first panel.

    The grid is sized by the number of rows actually drawn, so a split smaller
    than ``n_pairs`` leaves no empty panels. Nothing is written when there are
    no rows to show.
    """
    n_show = min(int(n_pairs), len(imgs6))
    if n_show <= 0:
        print("No test rows available for conditioned triplets; skipping.")
        return

    rng = np.random.default_rng(seed + 1)
    idx = rng.choice(len(imgs6), size=n_show, replace=False)
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    col_titles = ("CAMELS", "DDPM-2", "DDPM-6")

    fig, axes = plt.subplots(1, n_show * 3, figsize=(2.9 * n_show * 3, 3.8), squeeze=False)
    try:
        for ii, ix in enumerate(idx):
            tg6 = lab6[ix].astype(np.float32)
            tg2 = tg6[:2]
            rm = imgs6[ix]
            g2 = generate_maps(model2, tg2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size)[0]
            g6 = generate_maps(model6, tg6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size)[0]
            for j, img in enumerate((rm, g2, g6)):
                ax = axes[0, ii * 3 + j]
                ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
                # Hide ticks and frame by hand: ax.axis("off") would also
                # suppress the x-label that carries the cosmology below.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if ii == 0:
                    ax.set_title(col_titles[j], fontsize=8)
            axes[0, ii * 3].set_xlabel(_fmt_title(tg6), fontsize=7)

        fig.suptitle(f"Random test ix (conditioned on truth labels), n={len(idx)}", fontsize=10, y=1.06)
        p = out_dir / "comparison_test_conditioned_camels_ddpm2_ddpm6.png"
        fig.savefig(p, dpi=160, bbox_inches="tight")
        print("Saved", p)
    finally:
        # Close the figure even if sampling or saving failed part-way.
        plt.close(fig)
def pk_pdf_six_sets(
    out_dir: Path,
    name: str,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    model,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
):
    """Six anchor rows (evenly spaced ix in test split), N_PER_SET DDIM samples.

    For each anchor label the real comparison set is the ``n_per_set`` rows of
    the split whose labels are nearest to the anchor (the anchor row itself is
    excluded), and the generated set is ``n_per_set`` samples conditioned on
    the anchor label. Two figures are written to ``out_dir``:

    * ``six_anchor_pk_{name}.png`` — mean ± std power spectrum per anchor;
    * ``six_anchor_pdf_mu_sigma_{name}.png`` — mean and std of the per-map
      pixel histograms (pixel values in [0, 1] are mapped linearly onto
      14–22, the axis labelled log N_HI).

    Raises:
        ValueError: if ``n_per_set`` < 1 or the split has fewer than two rows
            (no neighbour would remain after excluding the anchor).
    """
    n_rows = len(labels_split)
    if n_per_set < 1:
        raise ValueError(f"n_per_set must be >= 1, got {n_per_set}")
    if n_rows < 2:
        raise ValueError("need at least two labelled maps to build a neighbour set")

    H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
    ldim = int(labels_split.shape[1])
    idx = np.linspace(0, n_rows - 1, num=6, dtype=int)
    targets = labels_split[idx].copy()
    box = 25.0  # box size handed to the power-spectrum helper
    # Never take more neighbours than exist once the anchor is excluded.
    n_real = min(n_per_set, n_rows - 1)

    pdf_edges = np.linspace(14.0, 22.0, 101)
    pdf_centers = 0.5 * (pdf_edges[:-1] + pdf_edges[1:])

    def _pdf_rows(maps: np.ndarray) -> np.ndarray:
        # One density histogram per map, stacked into (n_maps, n_bins).
        rows = []
        for one_map in maps:
            px = np.clip(one_map.ravel(), 0.0, 1.0)
            log_n = 14.0 + (22.0 - 14.0) * px
            rows.append(np.histogram(log_n, bins=pdf_edges, density=True)[0])
        return np.asarray(rows)

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()
    fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 4.8 * 2), squeeze=False)
    try:
        for si, target_l in enumerate(targets):
            dist = np.linalg.norm(labels_split - target_l[None, :], axis=1).astype(np.float64)
            # Push the anchor itself to the end of the ordering so it is not
            # counted as its own neighbour.
            dist[idx[si]] = np.inf
            nn_idx = np.argsort(dist)[:n_real]
            real_batch = images_split[nn_idx]

            rep = np.tile(target_l[None, :], (n_per_set, 1))
            gen = generate_maps(model, rep, label_mean, label_std, H, W, device, ddim_steps, batch_size)

            dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
            _, mg, sg = ec.calculate_power_spectrum_batch(gen, box_size=box)
            x = dk_r[1:]  # first bin is dropped from every curve

            axpk = axes_pk[si]
            axpk.plot(x, mr[1:], lw=2, label="CAMELS NN", color="#333")
            axpk.fill_between(x, mr[1:] - sr[1:], mr[1:] + sr[1:], alpha=0.08, color="#333")
            axpk.plot(x, mg[1:], lw=2, label=f"Generated (ldim={ldim})", color="#d95f02")
            axpk.fill_between(x, mg[1:] - sg[1:], mg[1:] + sg[1:], alpha=0.08, color="#d95f02")
            axpk.set_yscale("log")
            axpk.grid(alpha=0.25)
            axpk.set_title(_fmt_title(target_l), fontsize=8)

            # PDF µ/σ
            real_pdf = _pdf_rows(real_batch)
            gen_pdf = _pdf_rows(gen)
            axes_pdf[si, 0].plot(pdf_centers, real_pdf.mean(axis=0), lw=2, label="CAMELS NN", color="#333")
            axes_pdf[si, 0].plot(pdf_centers, gen_pdf.mean(axis=0), lw=2, label="Generated", color="#d95f02")
            axes_pdf[si, 1].plot(pdf_centers, real_pdf.std(axis=0), lw=2, ls="-", label="CAMELS σ", color="#333")
            axes_pdf[si, 1].plot(pdf_centers, gen_pdf.std(axis=0), lw=2, ls="--", label="Gen σ", color="#d95f02")

        axes_pk[0].legend(fontsize=7, loc="lower left")
        fig_pk.suptitle(f"$P(k)$ mean±std — six anchors — {name}", fontsize=10)
        fig_pk.tight_layout()
        p_pk = out_dir / f"six_anchor_pk_{name}.png"
        fig_pk.savefig(p_pk, dpi=160)

        # The PDF curves carry labels; show them once, on the top row.
        axes_pdf[0, 0].legend(fontsize=7)
        axes_pdf[0, 1].legend(fontsize=7)
        axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        fig_pdf.suptitle(f"PDF mean & σ — six anchors × {n_per_set} — {name}")
        fig_pdf.tight_layout()
        p_pdf = out_dir / f"six_anchor_pdf_mu_sigma_{name}.png"
        fig_pdf.savefig(p_pdf, dpi=160)

        print("Saved", p_pk)
        print("Saved", p_pdf)
    finally:
        # Close both figures even if sampling or saving failed part-way.
        plt.close(fig_pk)
        plt.close(fig_pdf)
def pk_six_triplet_combined(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2: torch.nn.Module,
    model6: torch.nn.Module,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
) -> None:
    """Six anchors — mean P(k) for CAMELS vs DDPM-2 vs DDPM-6; analogous PDF overlays.

    Anchors are six evenly spaced rows of ``lab6``. For each anchor the real
    set is its ``n_per_set`` nearest rows (anchor excluded); DDPM-2 is
    conditioned on the first two label entries and DDPM-6 on the full label,
    ``n_per_set`` samples each. Writes
    ``six_anchor_pk_overlay_camels_ddpm2_ddpm6.png`` and
    ``six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png`` into ``out_dir``.

    Raises:
        ValueError: if ``n_per_set`` < 1 or ``lab6`` has fewer than two rows
            (no neighbour would remain after excluding the anchor).
    """
    n_rows = len(lab6)
    if n_per_set < 1:
        raise ValueError(f"n_per_set must be >= 1, got {n_per_set}")
    if n_rows < 2:
        raise ValueError("need at least two labelled maps to build a neighbour set")

    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    idx = np.linspace(0, n_rows - 1, num=6, dtype=int)
    targets = lab6[idx].copy()
    box = 25.0  # box size handed to the power-spectrum helper
    # Never take more neighbours than exist once the anchor is excluded.
    n_real = min(n_per_set, n_rows - 1)

    pdf_edges = np.linspace(14.0, 22.0, 101)
    pdf_centers = 0.5 * (pdf_edges[:-1] + pdf_edges[1:])

    def _pdf_rows(maps: np.ndarray) -> np.ndarray:
        # One density histogram per map, stacked into (n_maps, n_bins).
        rows = []
        for one_map in maps[:n_per_set]:
            px = np.clip(one_map.ravel(), 0.0, 1.0)
            log_n = 14.0 + (22.0 - 14.0) * px
            rows.append(np.histogram(log_n, bins=pdf_edges, density=True)[0])
        return np.asarray(rows)

    # (mean-curve label, colour) per source, in plotting order.
    styles = (("CAMELS", "#222"), ("DDPM-2", "#ff7f0e"), ("DDPM-6", "#1f77b4"))

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()
    fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 10.5))
    try:
        for si, target_l in enumerate(targets):
            dist = np.linalg.norm(lab6 - target_l[None, :], axis=1).astype(np.float64)
            # Push the anchor itself to the end of the ordering so it is not
            # counted as its own neighbour.
            dist[int(idx[si])] = np.inf
            nn_idx = np.argsort(dist)[:n_real]
            real_batch = imgs6[nn_idx]

            tg2 = np.tile(target_l[:2][None, :], (n_per_set, 1)).astype(np.float32)
            tg6 = np.tile(target_l[None, :], (n_per_set, 1)).astype(np.float32)
            gen2 = generate_maps(model2, tg2, mean2, std2, H, W, device, ddim_steps, batch_size)
            gen6 = generate_maps(model6, tg6, mean6, std6, H, W, device, ddim_steps, batch_size)

            dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
            _, m2, s2 = ec.calculate_power_spectrum_batch(gen2, box_size=box)
            _, mG, sG = ec.calculate_power_spectrum_batch(gen6, box_size=box)
            x = dk_r[1:]  # first bin is dropped from every curve

            axpk = axes_pk[si]
            pk_curves = (
                ("CAMELS NN", "#222", mr, sr),
                ("DDPM-2 μ", "#ff7f0e", m2, s2),
                ("DDPM-6 μ", "#1f77b4", mG, sG),
            )
            for label, color, mu, sd in pk_curves:
                axpk.plot(x, mu[1:], lw=2, label=label, color=color)
                axpk.fill_between(x, mu[1:] - sd[1:], mu[1:] + sd[1:], alpha=0.06, color=color)
            axpk.set_yscale("log")
            axpk.grid(alpha=0.25)
            axpk.set_title(_fmt_title(target_l), fontsize=8)
            if si == 0:
                axpk.legend(fontsize=6.2, loc="lower left")

            for (tag, color), maps in zip(styles, (real_batch, gen2, gen6)):
                pdf = _pdf_rows(maps)
                # Real-data σ is solid, generated σ dashed (as for the P(k) colours).
                sigma_ls = "-" if tag == "CAMELS" else "--"
                axes_pdf[si, 0].plot(pdf_centers, pdf.mean(axis=0), lw=2, color=color, label=f"{tag} μ")
                axes_pdf[si, 1].plot(pdf_centers, pdf.std(axis=0), lw=2, ls=sigma_ls, color=color, label=f"{tag} σ")

        fig_pk.suptitle("$P(k)$ CAMELS vs DDPM-2 vs DDPM-6 — six Ωm–σ8 anchors", fontsize=11)
        fig_pk.tight_layout()
        p_pk = out_dir / "six_anchor_pk_overlay_camels_ddpm2_ddpm6.png"
        fig_pk.savefig(p_pk, dpi=160)

        # The PDF curves carry labels; show them once, on the top row.
        axes_pdf[0, 0].legend(fontsize=7)
        axes_pdf[0, 1].legend(fontsize=7)
        axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        fig_pdf.suptitle(r"PDF mean ($\mu$) and std ($\sigma$) overlays", fontsize=10)
        fig_pdf.tight_layout()
        p_pdf = out_dir / "six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png"
        fig_pdf.savefig(p_pdf, dpi=160)

        print("Saved", p_pk)
        print("Saved", p_pdf)
    finally:
        # Close both figures even if sampling or saving failed part-way.
        plt.close(fig_pk)
        plt.close(fig_pdf)
def mlp_recovery_dual(
    out_dir: Path,
    data_train: Path,
    imgs_te: np.ndarray,
    lab_te: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    model_ddpm: torch.nn.Module,
    tag: str,
    device: torch.device,
    ddim_steps: int,
    seed: int,
) -> None:
    """Label recovery from P(k) with a small MLP: real vs generated spectra.

    An sklearn ``MLPRegressor`` is fitted on the power spectra of (at most
    2000 randomly chosen) maps of the ``train`` split under ``data_train``.
    It then predicts labels for up to 40 test maps, once from their own
    spectra and once from spectra of maps generated by ``model_ddpm`` for the
    same labels. Per-dimension scatter plots with RMSE are saved to
    ``out_dir / f"mlp_pk_parameter_recovery_{tag}.png"``.
    """
    from sklearn.metrics import mean_squared_error
    from sklearn.neural_network import MLPRegressor

    n_label_dims = lab_te.shape[1]
    n_pix = imgs_te.shape[-1]
    pixel_size = 25.0 / n_pix

    def pk_row(im):
        # Power spectrum of one map with the first bin dropped.
        _k, power = ec.PowerSpectrum(np.asarray(im, dtype=np.float64), N=n_pix, dl=pixel_size)
        return power[1:].astype(np.float32)

    train_imgs, train_labs = ec.load_split(data_train, "train")
    if len(train_imgs) > 2000:
        # Subsample without replacement to keep the fit cheap.
        picker = np.random.default_rng(seed)
        keep = picker.choice(len(train_imgs), 2000, replace=False)
        train_imgs, train_labs = train_imgs[keep], train_labs[keep]

    X_train = np.stack([pk_row(train_imgs[i]) for i in range(len(train_imgs))], axis=0)
    y_train = train_labs.astype(np.float32)

    regressor = MLPRegressor(
        hidden_layer_sizes=(64, 64),
        alpha=1e-4,
        random_state=seed,
        max_iter=250,
        early_stopping=True,
        validation_fraction=0.1,
    )
    regressor.fit(X_train, y_train)

    n_eval = min(40, len(imgs_te))
    eval_rows = np.arange(n_eval)
    X_real = np.stack([pk_row(imgs_te[i]) for i in eval_rows], axis=0)
    y_true = lab_te[eval_rows]
    preds_real = regressor.predict(X_real)

    height, width = int(imgs_te.shape[-2]), int(imgs_te.shape[-1])
    gen_spectra = []
    # Generate in chunks of at most eight labels.
    for start in range(0, n_eval, 8):
        chunk_len = min(8, n_eval - start)
        chunk_labels = y_true[start : start + chunk_len]
        maps = generate_maps(model_ddpm, chunk_labels, mean, std, height, width, device, ddim_steps, chunk_len)
        gen_spectra.extend([pk_row(maps[j]) for j in range(len(maps))])
    X_gen = np.stack(gen_spectra, axis=0)
    preds_gen = regressor.predict(X_gen)

    rmse_real = np.sqrt(mean_squared_error(y_true, preds_real, multioutput="raw_values"))
    rmse_gen = np.sqrt(mean_squared_error(y_true, preds_gen, multioutput="raw_values"))

    # Row 0: real spectra, row 1: generated spectra; one column per label dim.
    fig, axes = plt.subplots(
        2, n_label_dims, figsize=(max(9.0, 2.8 * max(n_label_dims, 2)), 4.9), squeeze=False
    )
    panel_rows = (
        (0, preds_real, rmse_real, "CAMELS P(k) predictions"),
        (1, preds_gen, rmse_gen, f"{tag}: generated P(k)"),
    )
    for k in range(n_label_dims):
        truth_k = y_true[:, k]
        for row, preds, rmse, ylab in panel_rows:
            ax = axes[row, k]
            lo = float(truth_k.min())
            hi = float(truth_k.max())
            pad = 0.03 * (hi - lo + 1e-12)
            ax.scatter(truth_k, preds[:, k], s=14, alpha=0.72, edgecolors="none", c="#333")
            # Identity line: perfect recovery.
            ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], color="crimson", lw=1.0)
            ax.grid(True, alpha=0.28)
            ax.set_title(f"dim {k} RMSE={float(rmse[k]):.4f}", fontsize=8)
            if k == 0:
                ax.set_ylabel(ylab, fontsize=8)

    fig.suptitle(
        "MLP: train on CAMELS train P(k), test on CAMELS vs DDPM-drawn spectra",
        fontsize=10,
        y=1.02,
    )
    fig.tight_layout()
    p = out_dir / f"mlp_pk_parameter_recovery_{tag}.png"
    fig.savefig(p, dpi=165, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)
def posterior_one_index(
    out_dir: Path,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    model,
    cfg: Dict,
    device,
    ix: int,
    tag: str,
    ddim_steps: int,
    grid: int,
    batch_sz: int,
):
    """Surrogate posterior over the first two label entries for one test map.

    A ``grid`` × ``grid`` lattice is laid over the (slightly padded) range of
    label columns 0 and 1. For every lattice point one map is sampled with the
    remaining label entries fixed to those of row ``ix``, and its log power
    spectrum is compared with that of the observed map. The score
    ``-MSE / (2 * 0.25**2)`` is exponentiated and normalised to sum to one.
    The weight map, the true label and the weighted mean are plotted and
    saved as ``posterior_surrogate_test_ix_{ix}_{tag}.png``.
    """
    use_normalized_labels = bool(cfg.get("normalize_labels", True))
    n_label_dims = labels_split.shape[1]
    height, width = int(images_split.shape[-2]), int(images_split.shape[-1])
    observed = images_split[ix]
    anchor_label = labels_split[ix].astype(np.float32)

    # Grid bounds: data range of the first two label columns plus 2 % padding.
    col0, col1 = labels_split[:, 0], labels_split[:, 1]
    lo0, hi0 = float(col0.min()), float(col0.max())
    lo1, hi1 = float(col1.min()), float(col1.max())
    pad0 = 0.02 * (hi0 - lo0 + 1e-12)
    pad1 = 0.02 * (hi1 - lo1 + 1e-12)
    om_ax, s8_ax, _og, _sg, grid_points = build_cosmo_grid(
        grid, lo0 - pad0, hi0 + pad0, lo1 - pad1, hi1 + pad1
    )
    n_points = grid_points.shape[0]

    n_pix = int(observed.shape[-1])
    pixel_size = 25.0 / n_pix
    k_bins, _ = ec.PowerSpectrum(em.images01_to_log_nhi(observed), N=n_pix, dl=pixel_size)
    usable = k_bins > 0
    log_pk_obs = log_pk_observed(observed, 25.0, k_bins)

    # NOTE(review): the reshape below assumes grid_points is ordered like this
    # "ij" meshgrid flattened row-major — confirm against build_cosmo_grid.
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # Conditioning labels: anchor row everywhere, first two columns swept.
    cond_labels = np.tile(anchor_label[np.newaxis, :], (n_points, 1))
    cond_labels[:, 0] = grid_points[:, 0].astype(np.float32)
    cond_labels[:, 1] = grid_points[:, 1].astype(np.float32)

    score_chunks = []
    for start in range(0, n_points, batch_sz):
        batch_labels = cond_labels[start : start + batch_sz]
        sampled = em.sample_batch(
            model,
            batch_labels,
            lab_mean,
            lab_std,
            use_normalized_labels,
            height,
            width,
            device,
            ddim_steps,
            False,
        )
        _, pk_maps = em.per_map_power_spectra_log(sampled, 25.0)
        log_pk_gen = np.log(pk_maps[:, usable] + 1e-30)
        mse = np.mean((log_pk_obs[np.newaxis, :] - log_pk_gen) ** 2, axis=1)
        score_chunks.append(-mse / (2.0 * 0.25**2))

    scores = np.concatenate(score_chunks)
    # Subtract the maximum before exponentiating for numerical stability.
    scores = scores - scores.max()
    Wmap = np.exp(scores).reshape(grid, grid)
    Wmap = Wmap / Wmap.sum()

    true_om, true_s8 = float(anchor_label[0]), float(anchor_label[1])
    mean_om = float((Wmap * OM).sum())
    mean_s8 = float((Wmap * S8).sum())

    fig, ax = plt.subplots(figsize=(5.2, 4.6))
    filled = ax.contourf(OM, S8, Wmap, levels=12, cmap="Blues")
    plt.colorbar(filled, ax=ax, fraction=0.046, pad=0.04)
    ax.scatter(true_om, true_s8, s=55, c="r", marker="x", zorder=6, label="true")
    ax.scatter(mean_om, mean_s8, s=60, c="k", marker="+", zorder=6, label="post. mean")
    ax.set_xlabel(r"$\Omega_m$")
    ax.set_ylabel(r"$\sigma_8$")
    ax.legend(fontsize=8)
    ax.set_title(f"Surrogate posterior (test ix={ix}, ldim={n_label_dims})", fontsize=10)
    p = out_dir / f"posterior_surrogate_test_ix_{ix}_{tag}.png"
    fig.savefig(p, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)
def main(argv: Sequence[str] | None = None) -> None:
    """Run the full DDPM-2 vs DDPM-6 comparison suite.

    Parses the command line (``argv`` or ``sys.argv`` when ``None``), loads
    both test splits and both models, then runs each diagnostic in turn:
    training overlay, triplet grids, six-anchor P(k)/PDF plots, optional LHS
    R² plots, MLP label recovery and surrogate posteriors.

    Every diagnostic is best-effort: a failure is reported with its traceback
    and the remaining steps still run. Data and model loading are not
    guarded, because nothing else can run without them.
    """
    # Local import: only used for reporting failures of best-effort steps.
    import traceback

    p = argparse.ArgumentParser(description="DDPM-2 vs DDPM-6 comparison suite.")
    p.add_argument(
        "--output-dir",
        type=Path,
        default=MODELS_ROOT / "ddpm_comparison_out",
    )
    p.add_argument("--data-2param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--posterior-index", type=int, default=56)
    p.add_argument("--lhs-n", type=int, default=50)
    p.add_argument("--six-n-per-anchor", type=int, default=15)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--slurm-6param", type=Path, default=DEFAULT_SLURM_6)
    p.add_argument(
        "--slurm-2param",
        type=Path,
        default=DEFAULT_DDPM2_TRAINING,
        help="DDPM-2 train/val series: Slurm .out (parsed) or bundled ddpm_2param_training_loss.json.",
    )
    p.add_argument("--skip-lhs-r2", action="store_true", help="LHS R² plots are expensive; skip if set.")
    p.add_argument("--n-random-triplets", type=int, default=4)
    args = p.parse_args(list(argv) if argv is not None else None)

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    args2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    args6 = args.bundle_6param / "args.json"
    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)

    imgs6, lab6 = ec.load_split(data6, "test")
    mean6, std6 = ec.load_label_stats(data6)
    mean2, std2 = ec.load_label_stats(data2)

    # Guarded like every other diagnostic so a malformed log cannot abort the
    # run before any model-based figure has been produced.
    try:
        plot_training_overlay(out_dir, args.slurm_6param, args.slurm_2param)
    except Exception as exc:
        print("Training overlay failed:", exc)
        traceback.print_exc()

    imgs2, lab2 = ec.load_split(data2, "test")

    print(">>> Loading DDPM-2...")
    model2, cfg2 = load_model(args2, ck2, device)
    print(">>> Loading DDPM-6...")
    model6, cfg6 = load_model(args6, ck6, device)

    print(">>> Random LHS + conditioned triplets...")
    try:
        run_random_theta_triplets(
            out_dir,
            imgs6,
            lab6,
            mean6,
            std6,
            mean2,
            std2,
            model2,
            model6,
            device=device,
            ddim_steps=args.ddim_steps,
            seed=args.seed,
            n_pairs=args.n_random_triplets,
            batch_size=args.batch_size,
        )
        run_conditioned_test_triplets(
            out_dir,
            imgs6,
            lab6,
            mean6,
            std6,
            mean2,
            std2,
            model2,
            model6,
            device=device,
            ddim_steps=args.ddim_steps,
            seed=args.seed,
            n_pairs=args.n_random_triplets,
            batch_size=args.batch_size,
        )
    except Exception as exc:
        print("Triplet grids failed:", exc)
        traceback.print_exc()

    print(">>> Six-anchor overlays (combined + per-model)...")
    try:
        pk_six_triplet_combined(
            out_dir,
            imgs6,
            lab6,
            mean6,
            std6,
            mean2,
            std2,
            model2,
            model6,
            device=device,
            ddim_steps=args.ddim_steps,
            batch_size=args.batch_size,
            n_per_set=args.six_n_per_anchor,
        )
        pk_pdf_six_sets(
            out_dir,
            "ddpm6_only",
            imgs6,
            lab6,
            mean6,
            std6,
            model6,
            device,
            args.ddim_steps,
            args.batch_size,
            args.six_n_per_anchor,
        )
        pk_pdf_six_sets(
            out_dir,
            "ddpm2_only",
            imgs2,
            lab2,
            mean2,
            std2,
            model2,
            device,
            args.ddim_steps,
            args.batch_size,
            args.six_n_per_anchor,
        )
    except Exception as exc:
        print("P(k)/PDF six-anchor plots failed:", exc)
        traceback.print_exc()

    if not args.skip_lhs_r2:
        maps_per_point = 15  # DDIM samples drawn per LHS design point
        # Banner and file tags follow --lhs-n; at the default of 50 they are
        # identical to the previously hard-coded "lhs50" names.
        print(f">>> LHS R² (LHS-{args.lhs_n} × {maps_per_point} DDIM each — long)...")
        try:
            for label, imgs, labs, mn, sd, mdl in (
                (f"ddpm2_lhs{args.lhs_n}", imgs2, lab2, mean2, std2, model2),
                (f"ddpm6_lhs{args.lhs_n}", imgs6, lab6, mean6, std6, model6),
            ):
                lhs_pts, r2_mu, r2_sig, lo_b, hi_b = compute_lhs_r2(
                    mdl,
                    imgs,
                    labs,
                    mn,
                    sd,
                    device,
                    args.lhs_n,
                    maps_per_point,
                    args.batch_size,
                    25.0,
                    args.ddim_steps,
                    args.seed,
                )
                outp = out_dir / f"r2_cosmology_lhs{args.lhs_n}_{label}.png"
                plot_r2_cosmology_figure(lhs_pts, r2_mu, r2_sig, lo_b, hi_b, outp, dpi=160)
                print("Saved", outp)
                np.savez(
                    out_dir / f"r2_lhs_data_{label}.npz",
                    lhs_pts=lhs_pts,
                    r2_mu_arr=r2_mu,
                    r2_sig_arr=r2_sig,
                    lo_b=lo_b,
                    hi_b=hi_b,
                )
        except Exception as exc:
            print("LHS R² skipped:", exc)
            traceback.print_exc()
    else:
        print("(Skipping LHS R².)")

    print(">>> MLP P(k) parameter recovery...")
    try:
        mlp_recovery_dual(
            out_dir, data2, imgs2[:40], lab2[:40], mean2, std2, model2, "ddpm2param", device, args.ddim_steps, args.seed
        )
        mlp_recovery_dual(
            out_dir, data6, imgs6[:40], lab6[:40], mean6, std6, model6, "ddpm6param", device, args.ddim_steps, args.seed
        )
    except Exception as exc:
        print("MLP recovery skipped:", exc)
        traceback.print_exc()

    print(f">>> Surrogate posteriors (test index {args.posterior_index})...")
    try:
        ix = int(args.posterior_index)
        posterior_grid = 14  # lattice points per axis for the surrogate posterior
        posterior_one_index(
            out_dir,
            imgs6,
            lab6,
            mean6,
            std6,
            model6,
            cfg6,
            device,
            ix,
            "ddpm6",
            args.ddim_steps,
            posterior_grid,
            args.batch_size,
        )
        posterior_one_index(
            out_dir,
            imgs2,
            lab2,
            mean2,
            std2,
            model2,
            cfg2,
            device,
            ix,
            "ddpm2",
            args.ddim_steps,
            posterior_grid,
            args.batch_size,
        )
    except Exception as exc:
        print("Posterior panels skipped:", exc)
        traceback.print_exc()

    # Drop both models before clearing caches so their memory can be released.
    del model2, model6
    free_torch()
    print(f"Done. Outputs in {out_dir}")


if __name__ == "__main__":
    main()