#!/usr/bin/env python3
"""
Compare 2-parameter and 6-parameter conditional DDPMs (CAMELS LH) side-by-side:

  • Random-draw vs test-conditioned triplets (CAMELS | DDPM-2 | DDPM-6)
  • Six anchor cosmologies: P(k) and PDF diagnostics (triple curves per panel where applicable)
  • LHS R² cosmology plots (LHS-50 × 15 maps — expensive)
  • MLP P(k) → label recovery ( sklearn MLP, two models + shared CAMELS calibration )
  • Surrogate posterior on (Ωm, σ8) for a fixed test index
  • Training / validation loss on one axis (Slurm .out for DDPM-6; DDPM-2 defaults to bundled JSON)

Outputs under --output-dir (default: Models/ddpm_comparison_out/).

GPU: both models are resident while generating comparison panels; use a single GPU with
sufficient memory, or run heavier steps separately with refactors.
"""

from __future__ import annotations

import argparse
import gc
import sys
import traceback
from pathlib import Path
from typing import Dict, Sequence, Tuple

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch

# --- Repo paths ---
MODELS_ROOT = Path(__file__).resolve().parents[1]
CODE_6 = (MODELS_ROOT / "6param_ddpm_hi_lh6").resolve()
if str(CODE_6) not in sys.path:
    sys.path.insert(0, str(CODE_6))

import evaluate_conditional as ec  # noqa: E402
import eval_model as em  # noqa: E402
from figure9_posterior import build_cosmo_grid, log_pk_observed  # noqa: E402
from plot_r2_cosmology_lhs import compute_lhs_r2, plot_r2_cosmology_figure  # noqa: E402
from compare_ddpm_training_curves import (  # noqa: E402
    load_train_val_series,
    parse_slurm_training_log,
)

DEFAULT_SLURM_6 = Path(
    "/april_26/ddpm_hi_lh6/scripts/shell/slurm-698243.out"
)
# Bundled train/val (no 2-param Slurm log in-repo); see ``ddpm_2param_training_loss.json``.
DEFAULT_DDPM2_TRAINING = (Path(__file__).resolve().parent / "ddpm_2param_training_loss.json")

# Box size handed to every power-spectrum helper in this script.
BOX_SIZE = 25.0
# Pixel values in [0, 1] are mapped linearly onto this log N_HI range for the PDF panels.
LOG_NHI_MIN = 14.0
LOG_NHI_MAX = 22.0
# Number of anchor cosmologies; the PDF figures are laid out as 6 rows and P(k) as 2 x 3.
N_ANCHORS = 6
# Upper bound on the number of random-LHS triplets drawn in one figure.
MAX_RANDOM_TRIPLETS = 32
# Maps generated per LHS point for the R² figures (matches the "× 15" in the module docstring).
LHS_MAPS_PER_POINT = 15


def _fmt_title(lab: np.ndarray) -> str:
    """Format a label vector as a panel title.

    The first two entries are printed as Ωm and σ8; any further entries are appended
    after a ``|`` separator with three significant digits.  ``lab`` must contain at
    least two values.
    """
    t = np.asarray(lab, dtype=float).ravel()
    if t.size <= 2:
        return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f}"
    tail = ", ".join(f"{float(v):.3g}" for v in t[2:])
    return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f} | " + tail


def _latin_hypercube(n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Classic LHS (same as notebook).

    Returns an ``(n, d)`` float32 array with one sample per stratum along every
    dimension, scaled into the box ``[lo, hi]``.
    """
    d = int(lo.shape[0])
    u = rng.random((n, d))
    cut = np.linspace(0.0, 1.0, n + 1)
    a, b = cut[:-1], cut[1:]
    width = (b - a)[:, np.newaxis]
    rd = a[:, np.newaxis] + u * width
    # Decouple the dimensions by permuting each column independently.
    for j in range(d):
        rng.shuffle(rd[:, j])
    span = (hi - lo).astype(np.float64)
    return (lo + rd * span).astype(np.float32)


@torch.no_grad()
def generate_maps(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
) -> np.ndarray:
    """Draw one DDIM sample per row of ``labels_np`` and return them stacked.

    Labels are processed in chunks of ``batch_size``; each chunk is prepared with
    ``ec.prepare_labels_for_model`` and the sampler output is converted with
    ``ec.from_model_output``.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``labels_np`` has no rows.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n = labels_np.shape[0]
    if n == 0:
        raise ValueError("generate_maps needs at least one label row")

    out = []
    for j0 in range(0, n, batch_size):
        chunk = labels_np[j0 : j0 + batch_size]
        bt = ec.prepare_labels_for_model(chunk.astype(np.float32), label_mean, label_std).to(device)
        g = model.sample(
            labels=bt,
            channels=1,
            height=H,
            width=W,
            device=device,
            progress=False,
            use_ddim=True,
            ddim_steps=ddim_steps,
        )
        out.append(ec.from_model_output(g))
    return np.concatenate(out, axis=0)


def load_model(bundle_args: Path, ckpt: Path, device: torch.device):
    """Build a model from its saved training config, load ``ckpt`` and switch to eval mode.

    Returns:
        ``(model, cfg)`` where ``cfg`` is whatever ``ec.load_training_config`` returned.
    """
    cfg = ec.load_training_config(str(bundle_args))
    model = ec.build_model(cfg, device)
    ec.load_checkpoint(model, str(ckpt), device)
    model.eval()
    return model, cfg


def free_torch():
    """Run the garbage collector and release cached CUDA memory when CUDA is available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _run_step(failure_prefix: str, step, *args, **kwargs) -> bool:
    """Run one best-effort stage of the suite; report a failure instead of aborting.

    Prints ``failure_prefix`` followed by the exception and the full traceback, so a
    failing stage is diagnosable while the remaining stages still run.

    Returns:
        ``True`` when ``step`` completed, ``False`` when it raised.
    """
    try:
        step(*args, **kwargs)
    except Exception as exc:  # deliberate: every stage of the suite is optional
        print(failure_prefix, exc)
        traceback.print_exc()
        return False
    return True


def _show_map(ax, img: np.ndarray, *, keep_xlabel: bool = False) -> None:
    """Draw a [0, 1]-scaled map without axis decorations.

    ``ax.axis("off")`` also suppresses axis labels, so when the caller wants to attach
    an x-label only ticks and spines are hidden instead.
    """
    ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
    if not keep_xlabel:
        ax.axis("off")
        return
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def _anchor_indices(n_rows: int) -> np.ndarray:
    """Return ``N_ANCHORS`` evenly spaced row indices in ``[0, n_rows - 1]``."""
    return np.linspace(0, n_rows - 1, num=N_ANCHORS, dtype=int)


def _nearest_rows(labels: np.ndarray, anchor_ix: int, k: int) -> np.ndarray:
    """Indices of the ``k`` rows of ``labels`` closest (Euclidean) to row ``anchor_ix``.

    The anchor's own distance is set to infinity, so it sorts last and is returned
    only when ``k`` reaches the number of rows.
    """
    dist = np.linalg.norm(labels - labels[anchor_ix][None, :], axis=1).astype(np.float64)
    dist[anchor_ix] = np.inf
    return np.argsort(dist)[:k]


def _log_nhi_pdfs(maps: np.ndarray, bin_edges: np.ndarray) -> np.ndarray:
    """Per-map density histograms of pixel values mapped onto the log N_HI range.

    Returns an array with one histogram row per map in ``maps``.
    """
    rows = []
    for one_map in maps:
        px = np.clip(np.asarray(one_map).ravel(), 0.0, 1.0)
        log_nhi = LOG_NHI_MIN + (LOG_NHI_MAX - LOG_NHI_MIN) * px
        rows.append(np.histogram(log_nhi, bins=bin_edges, density=True)[0])
    return np.asarray(rows)


def _plot_pk_band(ax, k: np.ndarray, mean: np.ndarray, std: np.ndarray, *, label: str, color: str,
                  alpha: float) -> None:
    """Plot mean P(k) with a ±1 std band, skipping the first bin of ``mean``/``std``.

    ``k`` must already have its first bin removed.
    """
    ax.plot(k, mean[1:], lw=2, label=label, color=color)
    ax.fill_between(k, mean[1:] - std[1:], mean[1:] + std[1:], alpha=alpha, color=color)


def plot_training_overlay(
    out_dir: Path,
    slurm6: Path | None,
    slurm2: Path | None,
) -> None:
    """Train + val curves for DDPM6 and optionally DDPM2 on one logarithmic-loss axis.

    A missing input only produces a warning; when neither series can be read a
    placeholder note is drawn instead of curves.  The figure is always written to
    ``comparison_training_train_val_overlay.png`` in ``out_dir``.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    plotted = False

    if slurm6 and Path(slurm6).is_file():
        ep, tr, va = parse_slurm_training_log(slurm6)
        ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-6 train", color="#1f77b4", alpha=0.85)
        ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-6 val", color="#174a75", alpha=0.95)
        plotted = True
    else:
        print("Warning: 6-param Slurm log not found; skipped overlay for DDPM-6.")

    if slurm2 and Path(slurm2).is_file():
        ep, tr, va = load_train_val_series(slurm2)
        ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-2 train", color="#ff7f0e", alpha=0.85)
        ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-2 val", color="#994d00", alpha=0.95)
        plotted = True
    elif slurm2 is not None:
        print(f"Warning: 2-param training series not found ({slurm2}); use --slurm-2param or restore bundled JSON.")

    if not plotted:
        print("No Slurm logs parsed — writing placeholder note instead of curves.")
        ax.text(
            0.5,
            0.5,
            "Pass --slurm-6param; DDPM-2 uses bundled JSON by default (--slurm-2param).",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    else:
        # Curve decorations only make sense when at least one series was drawn.
        ax.set_yscale("log")
        ax.grid(True, alpha=0.3)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("MSE diffusion loss")
        ax.legend(loc="upper right", fontsize=8)

    ax.set_title("Training / validation curves (combined)")
    outp = out_dir / "comparison_training_train_val_overlay.png"
    fig.savefig(outp, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", outp)


def run_random_theta_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Random LHS targets in CAMELS bbox; CAMELS column = NN real map.

    At most ``MAX_RANDOM_TRIPLETS`` triplets are drawn.  DDPM-2 is conditioned on the
    first two label columns of each target, DDPM-6 on the full target.
    """
    if n_pairs <= 0:
        print("Random LHS triplets skipped: n_pairs must be positive.")
        return

    rng = np.random.default_rng(seed)
    lo, hi = lab6.min(0), lab6.max(0)
    # The layout below is driven by the number of targets actually drawn, so a request
    # above the cap no longer indexes past the end of ``targets``.
    n_shown = min(n_pairs, MAX_RANDOM_TRIPLETS)
    if n_shown < n_pairs:
        print(f"Random LHS triplets: showing {n_shown} of the {n_pairs} requested.")
    targets = _latin_hypercube(n_shown, lo, hi, rng).astype(np.float32)
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])

    # One batched sampler call per model instead of one call per target.
    gen2 = generate_maps(
        model2, np.ascontiguousarray(targets[:, :2]), mean2, std2, H, W, device, ddim_steps, batch_size
    )
    gen6 = generate_maps(model6, targets, mean6, std6, H, W, device, ddim_steps, batch_size)

    titles = ("CAMELS (NN)", "DDPM-2", "DDPM-6")
    fig = plt.figure(figsize=(3.8 * max(3, n_shown * 3), 4.1))
    for i, theta6 in enumerate(targets):
        dist = np.linalg.norm(lab6 - theta6[None, :], axis=1).astype(np.float64)
        nn_img = imgs6[int(np.argmin(dist))]
        for j, img in enumerate((nn_img, gen2[i], gen6[i])):
            ax = fig.add_subplot(1, n_shown * 3, i * 3 + j + 1)
            _show_map(ax, img)
            ax.set_title(titles[j], fontsize=8)

    fig.suptitle(
        "Random LHS cosmologies — CAMELS = nearest-neighbour truth | gens conditioned on LHS labels",
        fontsize=10,
        y=1.02,
    )
    p = out_dir / "comparison_random_lhs_triplets_camels_ddpm2_ddpm6.png"
    fig.tight_layout()
    fig.savefig(p, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)


def run_conditioned_test_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Same rows from test split: conditioned on truth labels.

    Rows are drawn without replacement using ``seed + 1``; the number of triplets is
    capped at the size of the split.
    """
    n_shown = min(n_pairs, len(imgs6))
    if n_shown <= 0:
        print("Conditioned test triplets skipped: nothing to draw.")
        return

    rng = np.random.default_rng(seed + 1)
    idx = rng.choice(len(imgs6), size=n_shown, replace=False)
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])

    truth = lab6[idx].astype(np.float32)
    gen2 = generate_maps(
        model2, np.ascontiguousarray(truth[:, :2]), mean2, std2, H, W, device, ddim_steps, batch_size
    )
    gen6 = generate_maps(model6, truth, mean6, std6, H, W, device, ddim_steps, batch_size)

    # Size the grid from the rows actually drawn so no empty axes are left over.
    fig, axes = plt.subplots(1, n_shown * 3, figsize=(2.9 * n_shown * 3, 3.8), squeeze=False)
    column_titles = ("CAMELS", "DDPM-2", "DDPM-6")
    for ii, ix in enumerate(idx):
        for j, img in enumerate((imgs6[ix], gen2[ii], gen6[ii])):
            ax = axes[0, ii * 3 + j]
            # The first panel of each triplet carries the cosmology label as its x-label,
            # so it must not be switched off entirely.
            _show_map(ax, img, keep_xlabel=(j == 0))
            if ii == 0:
                ax.set_title(column_titles[j], fontsize=8)
        axes[0, ii * 3].set_xlabel(_fmt_title(truth[ii]), fontsize=7)

    fig.suptitle(f"Random test ix (conditioned on truth labels), n={len(idx)}", fontsize=10, y=1.06)
    p = out_dir / "comparison_test_conditioned_camels_ddpm2_ddpm6.png"
    fig.savefig(p, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)


def pk_pdf_six_sets(
    out_dir: Path,
    name: str,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    model,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
):
    """Six anchor rows (evenly spaced ix in test split), N_PER_SET DDIM samples.

    For every anchor the ``n_per_set`` nearest real maps (anchor excluded) are compared
    with ``n_per_set`` maps generated at the anchor's labels: mean ± std P(k) in one
    figure, mean and std of the per-map log N_HI PDF in another.
    """
    H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
    ldim = int(labels_split.shape[1])
    pdf_edges = np.linspace(LOG_NHI_MIN, LOG_NHI_MAX, 101)
    pdf_centres = 0.5 * (pdf_edges[:-1] + pdf_edges[1:])

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()
    fig_pdf, axes_pdf = plt.subplots(N_ANCHORS, 2, figsize=(12, 4.8 * 2), squeeze=False)

    for si, anchor_ix in enumerate(_anchor_indices(len(labels_split))):
        target_l = labels_split[anchor_ix]
        real_batch = images_split[_nearest_rows(labels_split, int(anchor_ix), n_per_set)]
        rep = np.tile(target_l[None, :], (n_per_set, 1))
        gen = generate_maps(model, rep, label_mean, label_std, H, W, device, ddim_steps, batch_size)

        dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=BOX_SIZE)
        _, mg, sg = ec.calculate_power_spectrum_batch(gen, box_size=BOX_SIZE)
        k = dk_r[1:]
        axpk = axes_pk[si]
        _plot_pk_band(axpk, k, mr, sr, label="CAMELS NN", color="#333", alpha=0.08)
        _plot_pk_band(axpk, k, mg, sg, label=f"Generated (ldim={ldim})", color="#d95f02", alpha=0.08)
        axpk.set_yscale("log")
        axpk.grid(alpha=0.25)
        axpk.set_title(_fmt_title(target_l), fontsize=8)

        # PDF µ/σ — each batch is histogrammed over its own length, so a split with
        # fewer than n_per_set neighbours no longer overruns the real batch.
        real_pdf = _log_nhi_pdfs(real_batch, pdf_edges)
        gen_pdf = _log_nhi_pdfs(gen, pdf_edges)
        axes_pdf[si, 0].plot(pdf_centres, real_pdf.mean(axis=0), lw=2, label="CAMELS NN", color="#333")
        axes_pdf[si, 0].plot(pdf_centres, gen_pdf.mean(axis=0), lw=2, label="Generated", color="#d95f02")
        axes_pdf[si, 1].plot(pdf_centres, real_pdf.std(axis=0), lw=2, ls="-", label="CAMELS σ", color="#333")
        axes_pdf[si, 1].plot(pdf_centres, gen_pdf.std(axis=0), lw=2, ls="--", label="Gen σ", color="#d95f02")

    axes_pk[0].legend(fontsize=7, loc="lower left")
    fig_pk.suptitle(f"$P(k)$ mean±std — six anchors — {name}", fontsize=10)
    fig_pk.tight_layout()
    p_pk = out_dir / f"six_anchor_pk_{name}.png"
    fig_pk.savefig(p_pk, dpi=160)
    plt.close(fig_pk)

    # The PDF curves carry labels; show them once on the top row.
    axes_pdf[0, 0].legend(fontsize=7)
    axes_pdf[0, 1].legend(fontsize=7)
    axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    fig_pdf.suptitle(f"PDF mean & σ — six anchors × {n_per_set} — {name}")
    fig_pdf.tight_layout()
    p_pdf = out_dir / f"six_anchor_pdf_mu_sigma_{name}.png"
    fig_pdf.savefig(p_pdf, dpi=160)
    plt.close(fig_pdf)
    print("Saved", p_pk)
    print("Saved", p_pdf)


def pk_six_triplet_combined(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2: torch.nn.Module,
    model6: torch.nn.Module,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
) -> None:
    """Six anchors — mean P(k) for CAMELS vs DDPM-2 vs DDPM-6; analogous PDF overlays.

    Neighbours are selected with the full 6-parameter label distance; DDPM-2 is
    conditioned on the first two label columns of each anchor.
    """
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    pdf_edges = np.linspace(LOG_NHI_MIN, LOG_NHI_MAX, 101)
    pdf_centres = 0.5 * (pdf_edges[:-1] + pdf_edges[1:])

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()
    fig_pdf, axes_pdf = plt.subplots(N_ANCHORS, 2, figsize=(12, 10.5))

    for si, anchor_ix in enumerate(_anchor_indices(len(lab6))):
        target_l = lab6[anchor_ix]
        real_batch = imgs6[_nearest_rows(lab6, int(anchor_ix), n_per_set)]

        tg2 = np.tile(target_l[:2][None, :], (n_per_set, 1)).astype(np.float32)
        tg6 = np.tile(target_l[None, :], (n_per_set, 1)).astype(np.float32)
        gen2 = generate_maps(model2, tg2, mean2, std2, H, W, device, ddim_steps, batch_size)
        gen6 = generate_maps(model6, tg6, mean6, std6, H, W, device, ddim_steps, batch_size)

        dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=BOX_SIZE)
        _, m2, s2 = ec.calculate_power_spectrum_batch(gen2, box_size=BOX_SIZE)
        _, m6, s6 = ec.calculate_power_spectrum_batch(gen6, box_size=BOX_SIZE)
        k = dk_r[1:]
        axpk = axes_pk[si]
        _plot_pk_band(axpk, k, mr, sr, label="CAMELS NN", color="#222", alpha=0.06)
        _plot_pk_band(axpk, k, m2, s2, label="DDPM-2 μ", color="#ff7f0e", alpha=0.06)
        _plot_pk_band(axpk, k, m6, s6, label="DDPM-6 μ", color="#1f77b4", alpha=0.06)
        axpk.set_yscale("log")
        axpk.grid(alpha=0.25)
        axpk.set_title(_fmt_title(target_l), fontsize=8)
        if si == 0:
            axpk.legend(fontsize=6.2, loc="lower left")

        cam_pdf = _log_nhi_pdfs(real_batch, pdf_edges)
        d2_pdf = _log_nhi_pdfs(gen2, pdf_edges)
        d6_pdf = _log_nhi_pdfs(gen6, pdf_edges)
        axes_pdf[si, 0].plot(pdf_centres, cam_pdf.mean(axis=0), lw=2, color="#222", label="CAMELS μ")
        axes_pdf[si, 0].plot(pdf_centres, d2_pdf.mean(axis=0), lw=2, color="#ff7f0e", label="DDPM-2 μ")
        axes_pdf[si, 0].plot(pdf_centres, d6_pdf.mean(axis=0), lw=2, color="#1f77b4", label="DDPM-6 μ")
        axes_pdf[si, 1].plot(pdf_centres, cam_pdf.std(axis=0), lw=2, color="#222")
        axes_pdf[si, 1].plot(pdf_centres, d2_pdf.std(axis=0), lw=2, ls="--", color="#ff7f0e")
        axes_pdf[si, 1].plot(pdf_centres, d6_pdf.std(axis=0), lw=2, ls="--", color="#1f77b4")

    fig_pk.suptitle("$P(k)$ CAMELS vs DDPM-2 vs DDPM-6 — six Ωm–σ8 anchors", fontsize=11)
    fig_pk.tight_layout()
    p_pk = out_dir / "six_anchor_pk_overlay_camels_ddpm2_ddpm6.png"
    fig_pk.savefig(p_pk, dpi=160)
    plt.close(fig_pk)

    # The mean curves carry labels; show them once on the top row (same colours apply to σ).
    axes_pdf[0, 0].legend(fontsize=7)
    axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    fig_pdf.suptitle(r"PDF mean ($\mu$) and std ($\sigma$) overlays", fontsize=10)
    fig_pdf.tight_layout()
    p_pdf = out_dir / "six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png"
    fig_pdf.savefig(p_pdf, dpi=160)
    plt.close(fig_pdf)
    print("Saved", p_pk)
    print("Saved", p_pdf)


def mlp_recovery_dual(
    out_dir: Path,
    data_train: Path,
    imgs_te: np.ndarray,
    lab_te: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    model_ddpm: torch.nn.Module,
    tag: str,
    device: torch.device,
    ddim_steps: int,
    seed: int,
) -> None:
    """Fit an sklearn MLP on CAMELS train P(k) and score it on real vs generated spectra.

    The training split is sub-sampled to at most 2000 maps.  Up to the first 40 test
    rows are evaluated twice: on their real maps and on maps generated from their
    labels.  A predicted-vs-true scatter per label dimension is written to
    ``mlp_pk_parameter_recovery_{tag}.png``.
    """
    from sklearn.metrics import mean_squared_error
    from sklearn.neural_network import MLPRegressor

    ldim = lab_te.shape[1]
    npix = imgs_te.shape[-1]
    dl = BOX_SIZE / npix

    def pk_row(im):
        """Power spectrum of one map with its first bin dropped, as float32."""
        _dk, pk = ec.PowerSpectrum(np.asarray(im, dtype=np.float64), N=npix, dl=dl)
        return pk[1:].astype(np.float32)

    img_tr_np, lab_tr_np = ec.load_split(data_train, "train")
    if len(img_tr_np) > 2000:
        rng = np.random.default_rng(seed)
        keep = rng.choice(len(img_tr_np), 2000, replace=False)
        img_tr_np, lab_tr_np = img_tr_np[keep], lab_tr_np[keep]
    X_train = np.stack([pk_row(im) for im in img_tr_np], axis=0)
    y_train = lab_tr_np.astype(np.float32)

    mlp = MLPRegressor(
        hidden_layer_sizes=(64, 64),
        alpha=1e-4,
        random_state=seed,
        max_iter=250,
        early_stopping=True,
        validation_fraction=0.1,
    )
    mlp.fit(X_train, y_train)

    n_ev = min(40, len(imgs_te))
    X_real = np.stack([pk_row(im) for im in imgs_te[:n_ev]], axis=0)
    y_true = lab_te[:n_ev]
    preds_real = mlp.predict(X_real)

    # generate_maps already chunks its input, so a single call with batch size 8
    # replaces the manual chunk loop.
    H, W = int(imgs_te.shape[-2]), int(imgs_te.shape[-1])
    generated = generate_maps(model_ddpm, y_true, mean, std, H, W, device, ddim_steps, 8)
    X_gen = np.stack([pk_row(im) for im in generated], axis=0)
    preds_gen = mlp.predict(X_gen)

    rmse_real = np.sqrt(mean_squared_error(y_true, preds_real, multioutput="raw_values"))
    rmse_gen = np.sqrt(mean_squared_error(y_true, preds_gen, multioutput="raw_values"))

    # squeeze=False guarantees a 2-D axes array even for a single label dimension.
    fig, axes = plt.subplots(2, ldim, figsize=(max(9.0, 2.8 * max(ldim, 2)), 4.9), squeeze=False)
    rows = (
        (0, preds_real, rmse_real, "CAMELS P(k) predictions"),
        (1, preds_gen, rmse_gen, f"{tag}: generated P(k)"),
    )
    for k in range(ldim):
        lo = float(y_true[:, k].min())
        hi = float(y_true[:, k].max())
        pad = 0.03 * (hi - lo + 1e-12)
        for row, preds, rmv, ylab in rows:
            ax = axes[row, k]
            ax.scatter(y_true[:, k], preds[:, k], s=14, alpha=0.72, edgecolors="none", c="#333")
            ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], color="crimson", lw=1.0)
            ax.grid(True, alpha=0.28)
            ax.set_title(f"dim {k} RMSE={float(rmv[k]):.4f}", fontsize=8)
            if k == 0:
                ax.set_ylabel(ylab, fontsize=8)

    fig.suptitle(
        "MLP: train on CAMELS train P(k), test on CAMELS vs DDPM-drawn spectra",
        fontsize=10,
        y=1.02,
    )
    fig.tight_layout()
    p = out_dir / f"mlp_pk_parameter_recovery_{tag}.png"
    fig.savefig(p, dpi=165, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)


def posterior_one_index(
    out_dir: Path,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    model,
    cfg: Dict,
    device,
    ix: int,
    tag: str,
    ddim_steps: int,
    grid: int,
    batch_sz: int,
):
    """Surrogate posterior over (Ωm, σ8) for one observed map.

    A ``grid`` × ``grid`` lattice spanning the split's range of the first two label
    columns (padded by 2 %) is scored by generating one map per lattice point and
    comparing its log P(k) with that of the observed map.  Remaining label columns are
    held at the observed row's values.  Scores are turned into normalised weights and
    drawn as filled contours with the true and posterior-mean points marked.

    Raises:
        IndexError: if ``ix`` is not a valid row of ``images_split``.
    """
    n_rows = len(images_split)
    if not -n_rows <= ix < n_rows:
        raise IndexError(f"posterior index {ix} is out of range for a split of {n_rows} maps ({tag})")

    normalize = bool(cfg.get("normalize_labels", True))
    lab_dim = labels_split.shape[1]
    H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
    obs = images_split[ix]
    label_anchor_full = labels_split[ix].astype(np.float32)

    lo0 = float(labels_split[:, 0].min())
    hi0 = float(labels_split[:, 0].max())
    lo1 = float(labels_split[:, 1].min())
    hi1 = float(labels_split[:, 1].max())
    pad0 = 0.02 * (hi0 - lo0 + 1e-12)
    pad1 = 0.02 * (hi1 - lo1 + 1e-12)
    om_ax, s8_ax, _og, _sg, grid2 = build_cosmo_grid(
        grid, lo0 - pad0, hi0 + pad0, lo1 - pad1, hi1 + pad1
    )
    ngrid = grid2.shape[0]

    npix = int(obs.shape[-1])
    dl = BOX_SIZE / npix
    dk, _ = ec.PowerSpectrum(em.images01_to_log_nhi(obs), N=npix, dl=dl)
    valid = dk > 0
    log_pd = log_pk_observed(obs, BOX_SIZE, dk)
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # Every lattice point keeps the observed row's extra labels; only columns 0/1 vary.
    full = np.tile(label_anchor_full[np.newaxis, :], (ngrid, 1))
    full[:, 0] = grid2[:, 0].astype(np.float32)
    full[:, 1] = grid2[:, 1].astype(np.float32)

    scores = []
    for j0 in range(0, ngrid, batch_sz):
        chunk = full[j0 : j0 + batch_sz]
        imgs = em.sample_batch(
            model,
            chunk,
            lab_mean,
            lab_std,
            normalize,
            H,
            W,
            device,
            ddim_steps,
            False,
        )
        _, pkc = em.per_map_power_spectra_log(imgs, BOX_SIZE)
        log_pg = np.log(pkc[:, valid] + 1e-30)
        mse = np.mean((log_pd[np.newaxis, :] - log_pg) ** 2, axis=1)
        # Gaussian log-score with a fixed width of 0.25 on the mean squared log-P(k) error.
        scores.append(-mse / (2.0 * 0.25**2))
    sc = np.concatenate(scores)
    sc -= sc.max()  # stabilise the exponential; the best lattice point gets weight exp(0)
    Wmap = np.exp(sc).reshape(grid, grid)
    Wmap /= Wmap.sum()

    tom, ts8 = float(label_anchor_full[0]), float(label_anchor_full[1])
    mom = float((Wmap * OM).sum())
    ms8 = float((Wmap * S8).sum())

    fig, ax = plt.subplots(figsize=(5.2, 4.6))
    cf = ax.contourf(OM, S8, Wmap, levels=12, cmap="Blues")
    fig.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
    ax.scatter(tom, ts8, s=55, c="r", marker="x", zorder=6, label="true")
    ax.scatter(mom, ms8, s=60, c="k", marker="+", zorder=6, label="post. mean")
    ax.set_xlabel(r"$\Omega_m$")
    ax.set_ylabel(r"$\sigma_8$")
    ax.legend(fontsize=8)
    ax.set_title(f"Surrogate posterior (test ix={ix}, ldim={lab_dim})", fontsize=10)
    p = out_dir / f"posterior_surrogate_test_ix_{ix}_{tag}.png"
    fig.savefig(p, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)


def _lhs_r2_for_model(
    out_dir: Path,
    label: str,
    lhs_n: int,
    model,
    imgs: np.ndarray,
    labs: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    device: torch.device,
    batch_size: int,
    ddim_steps: int,
    seed: int,
) -> None:
    """Compute the LHS R² diagnostics for one model, then save the figure and raw arrays."""
    lhs_pts, r2_mu, r2_sig, lo_b, hi_b = compute_lhs_r2(
        model,
        imgs,
        labs,
        mean,
        std,
        device,
        lhs_n,
        LHS_MAPS_PER_POINT,
        batch_size,
        BOX_SIZE,
        ddim_steps,
        seed,
    )
    outp = out_dir / f"r2_cosmology_lhs{lhs_n}_{label}.png"
    plot_r2_cosmology_figure(lhs_pts, r2_mu, r2_sig, lo_b, hi_b, outp, dpi=160)
    print("Saved", outp)
    np.savez(
        out_dir / f"r2_lhs_data_{label}.npz",
        lhs_pts=lhs_pts,
        r2_mu_arr=r2_mu,
        r2_sig_arr=r2_sig,
        lo_b=lo_b,
        hi_b=hi_b,
    )


def main(argv: Sequence[str] | None = None) -> None:
    """Parse the command line and run every comparison stage.

    Each stage is best-effort: a failure is reported with its traceback and the
    remaining stages still run, so one broken diagnostic does not cost the others.
    """
    p = argparse.ArgumentParser(description="DDPM-2 vs DDPM-6 comparison suite.")
    p.add_argument(
        "--output-dir",
        type=Path,
        default=MODELS_ROOT / "ddpm_comparison_out",
    )
    p.add_argument("--data-2param", type=Path, default=Path("/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--posterior-index", type=int, default=56)
    p.add_argument("--lhs-n", type=int, default=50)
    p.add_argument("--six-n-per-anchor", type=int, default=15)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--slurm-6param", type=Path, default=DEFAULT_SLURM_6)
    p.add_argument(
        "--slurm-2param",
        type=Path,
        default=DEFAULT_DDPM2_TRAINING,
        help="DDPM-2 train/val series: Slurm .out (parsed) or bundled ddpm_2param_training_loss.json.",
    )
    p.add_argument("--skip-lhs-r2", action="store_true", help="LHS R² plots are expensive; skip if set.")
    p.add_argument("--n-random-triplets", type=int, default=4)
    args = p.parse_args(list(argv) if argv is not None else None)

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    args2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    args6 = args.bundle_6param / "args.json"
    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)

    imgs6, lab6 = ec.load_split(data6, "test")
    mean6, std6 = ec.load_label_stats(data6)
    mean2, std2 = ec.load_label_stats(data2)

    plot_training_overlay(out_dir, args.slurm_6param, args.slurm_2param)

    imgs2, lab2 = ec.load_split(data2, "test")
    print(">>> Loading DDPM-2...")
    model2, cfg2 = load_model(args2, ck2, device)
    print(">>> Loading DDPM-6...")
    model6, cfg6 = load_model(args6, ck6, device)

    print(">>> Random LHS + conditioned triplets...")
    for triplet_step in (run_random_theta_triplets, run_conditioned_test_triplets):
        _run_step(
            "Triplet grids failed:",
            triplet_step,
            out_dir,
            imgs6,
            lab6,
            mean6,
            std6,
            mean2,
            std2,
            model2,
            model6,
            device=device,
            ddim_steps=args.ddim_steps,
            seed=args.seed,
            n_pairs=args.n_random_triplets,
            batch_size=args.batch_size,
        )

    print(">>> Six-anchor overlays (combined + per-model)...")
    six_fail = "P(k)/PDF six-anchor plots failed:"
    _run_step(
        six_fail,
        pk_six_triplet_combined,
        out_dir,
        imgs6,
        lab6,
        mean6,
        std6,
        mean2,
        std2,
        model2,
        model6,
        device=device,
        ddim_steps=args.ddim_steps,
        batch_size=args.batch_size,
        n_per_set=args.six_n_per_anchor,
    )
    for name, imgs, labs, mn, sd, mdl in (
        ("ddpm6_only", imgs6, lab6, mean6, std6, model6),
        ("ddpm2_only", imgs2, lab2, mean2, std2, model2),
    ):
        _run_step(
            six_fail,
            pk_pdf_six_sets,
            out_dir,
            name,
            imgs,
            labs,
            mn,
            sd,
            mdl,
            device,
            args.ddim_steps,
            args.batch_size,
            args.six_n_per_anchor,
        )

    if not args.skip_lhs_r2:
        print(f">>> LHS R² (LHS-{args.lhs_n} × {LHS_MAPS_PER_POINT} DDIM each — long)...")
        # The tag follows --lhs-n so file names never claim a sample count that was not used
        # (identical to the previous names at the default of 50).
        for label, imgs, labs, mn, sd, mdl in (
            (f"ddpm2_lhs{args.lhs_n}", imgs2, lab2, mean2, std2, model2),
            (f"ddpm6_lhs{args.lhs_n}", imgs6, lab6, mean6, std6, model6),
        ):
            _run_step(
                "LHS R² skipped:",
                _lhs_r2_for_model,
                out_dir,
                label,
                args.lhs_n,
                mdl,
                imgs,
                labs,
                mn,
                sd,
                device,
                args.batch_size,
                args.ddim_steps,
                args.seed,
            )
    else:
        print("(Skipping LHS R².)")

    print(">>> MLP P(k) parameter recovery...")
    for data_dir, imgs, labs, mn, sd, mdl, tag in (
        (data2, imgs2, lab2, mean2, std2, model2, "ddpm2param"),
        (data6, imgs6, lab6, mean6, std6, model6, "ddpm6param"),
    ):
        _run_step(
            "MLP recovery skipped:",
            mlp_recovery_dual,
            out_dir,
            data_dir,
            imgs[:40],
            labs[:40],
            mn,
            sd,
            mdl,
            tag,
            device,
            args.ddim_steps,
            args.seed,
        )

    print(f">>> Surrogate posteriors (test index {args.posterior_index})...")
    ix = int(args.posterior_index)
    for imgs, labs, mn, sd, mdl, cfg, tag in (
        (imgs6, lab6, mean6, std6, model6, cfg6, "ddpm6"),
        (imgs2, lab2, mean2, std2, model2, cfg2, "ddpm2"),
    ):
        _run_step(
            "Posterior panels skipped:",
            posterior_one_index,
            out_dir,
            imgs,
            labs,
            mn,
            sd,
            mdl,
            cfg,
            device,
            ix,
            tag,
            args.ddim_steps,
            14,
            args.batch_size,
        )

    del model2, model6, mdl
    free_torch()
    print(f"Done. Outputs in {out_dir}")


if __name__ == "__main__":
    main()