# DDPM-2param / cross_model / scripts / compare_ddpm_models.py
# Uploaded by collins909: 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
# Commit f513198 (verified)
#!/usr/bin/env python3
"""
Compare 2-parameter and 6-parameter conditional DDPMs (CAMELS LH) side-by-side:
• Random-draw vs test-conditioned triplets (CAMELS | DDPM-2 | DDPM-6)
• Six anchor cosmologies: P(k) and PDF diagnostics (triple curves per panel where applicable)
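• LHS R² cosmology plots (--lhs-n LHS points, default 50, × 15 maps each — expensive)
• MLP P(k) → label recovery (sklearn MLP trained on CAMELS train-split P(k), one per model; each scored on CAMELS test spectra and on DDPM-generated spectra)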
• Surrogate posterior on (Ωm, σ8) for a fixed test index
• Training / validation loss on one axis (Slurm .out for DDPM-6; DDPM-2 defaults to bundled JSON)
Outputs under --output-dir (default: Models/ddpm_comparison_out/).
GPU: both models stay resident on the device for the whole run; use a single GPU with
enough memory for both, or split the heavier steps into separate runs (requires refactoring).
"""
from __future__ import annotations
import argparse
import gc
import sys
from pathlib import Path
from typing import Dict, Sequence, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
# --- Repo paths ---
# MODELS_ROOT is the grandparent directory of this script file.
MODELS_ROOT = Path(__file__).resolve().parents[1]
# The 6-param code tree supplies the helper modules imported below, so it is put at the
# front of sys.path (only once, even if this module is imported repeatedly).
CODE_6 = MODELS_ROOT.joinpath("6param_ddpm_hi_lh6").resolve()
if not str(CODE_6) in sys.path:
    sys.path.insert(0, str(CODE_6))
import evaluate_conditional as ec # noqa: E402
import eval_model as em # noqa: E402
from figure9_posterior import build_cosmo_grid, log_pk_observed # noqa: E402
from plot_r2_cosmology_lhs import compute_lhs_r2, plot_r2_cosmology_figure # noqa: E402
from compare_ddpm_training_curves import ( # noqa: E402
load_train_val_series,
parse_slurm_training_log,
)
# Default Slurm stdout of the 6-param training run (site-specific placeholder path).
DEFAULT_SLURM_6 = Path("<DDPM_ROOT>/april_26/ddpm_hi_lh6/scripts/shell/slurm-698243.out")
# No 2-param Slurm log ships with the repo, so the DDPM-2 train/val series defaults to the
# bundled ``ddpm_2param_training_loss.json`` that sits next to this script.
DEFAULT_DDPM2_TRAINING = Path(__file__).resolve().parent.joinpath("ddpm_2param_training_loss.json")
def _fmt_title(lab: np.ndarray) -> str:
t = np.asarray(lab, dtype=float).ravel()
if t.size <= 2:
return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f}"
tail = ", ".join(f"{float(v):.3g}" for v in t[2:])
return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f} | " + tail
def _latin_hypercube(n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator) -> np.ndarray:
"""Classic LHS (same as notebook)."""
d = int(lo.shape[0])
u = rng.random((n, d))
cut = np.linspace(0.0, 1.0, n + 1)
a, b = cut[:-1], cut[1:]
width = (b - a)[:, np.newaxis]
rd = a[:, np.newaxis] + u * width
for j in range(d):
rng.shuffle(rd[:, j])
span = (hi - lo).astype(np.float64)
return (lo + rd * span).astype(np.float32)
@torch.no_grad()
def generate_maps(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
) -> np.ndarray:
    """Sample one map per row of ``labels_np``, ``batch_size`` rows at a time.

    Each chunk of raw labels is cast to float32 and converted with
    ``ec.prepare_labels_for_model`` (using ``label_mean`` / ``label_std``). It is then
    moved to ``device`` and passed to ``model.sample`` with DDIM (``ddim_steps``) for a
    single-channel ``H`` x ``W`` output. Each result goes through
    ``ec.from_model_output``, and the pieces are concatenated along axis 0. Gradients are
    disabled by the decorator.
    """
    pieces = []
    for start in range(0, labels_np.shape[0], batch_size):
        raw = labels_np[start : start + batch_size].astype(np.float32)
        cond = ec.prepare_labels_for_model(raw, label_mean, label_std).to(device)
        sampled = model.sample(
            labels=cond,
            channels=1,
            height=H,
            width=W,
            device=device,
            progress=False,
            use_ddim=True,
            ddim_steps=ddim_steps,
        )
        pieces.append(ec.from_model_output(sampled))
    return np.concatenate(pieces, axis=0)
def load_model(bundle_args: Path, ckpt: Path, device: torch.device):
    """Build a model from a training-config file and load checkpoint weights into it.

    ``bundle_args`` is the config path handed to ``ec.load_training_config`` (``main``
    passes a bundle's ``args.json``). ``ckpt`` is loaded with ``ec.load_checkpoint``.
    Returns ``(model, cfg)`` with the model switched to eval mode.
    """
    config = ec.load_training_config(str(bundle_args))
    net = ec.build_model(config, device)
    ec.load_checkpoint(net, str(ckpt), device)
    net.eval()
    return net, config
def free_torch():
    """Run Python's garbage collector and, if CUDA is available, empty the CUDA cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def plot_training_overlay(
    out_dir: Path,
    slurm6: Path | None,
    slurm2: Path | None,
) -> None:
    """Plot DDPM-6 and (optionally) DDPM-2 train/val loss curves on one log-scale axis.

    ``slurm6`` is parsed as a Slurm ``.out`` log with ``parse_slurm_training_log``.
    ``slurm2`` is read with ``load_train_val_series``. Missing inputs are reported and
    skipped. If neither series is plotted, a placeholder note is drawn instead of curves.
    The figure is saved to ``out_dir / "comparison_training_train_val_overlay.png"``
    and is always closed, even when parsing or saving raises.
    """
    outp = out_dir / "comparison_training_train_val_overlay.png"
    fig, ax = plt.subplots(figsize=(9, 5))
    try:
        plotted = False
        if slurm6 and Path(slurm6).is_file():
            ep, tr, va = parse_slurm_training_log(slurm6)
            ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-6 train", color="#1f77b4", alpha=0.85)
            ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-6 val", color="#174a75", alpha=0.95)
            plotted = True
        else:
            print("Warning: 6-param Slurm log not found; skipped overlay for DDPM-6.")
        if slurm2 and Path(slurm2).is_file():
            ep, tr, va = load_train_val_series(slurm2)
            ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-2 train", color="#ff7f0e", alpha=0.85)
            ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-2 val", color="#994d00", alpha=0.95)
            plotted = True
        elif slurm2 is not None:
            print(f"Warning: 2-param training series not found ({slurm2}); use --slurm-2param or restore bundled JSON.")
        if not plotted:
            print("No Slurm logs parsed — writing placeholder note instead of curves.")
            ax.text(
                0.5,
                0.5,
                "Pass --slurm-6param; DDPM-2 uses bundled JSON by default (--slurm-2param).",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
        else:
            # Losses span orders of magnitude over training, hence the log axis.
            ax.set_yscale("log")
            ax.grid(True, alpha=0.3)
            ax.set_xlabel("Epoch")
            ax.set_ylabel("MSE diffusion loss")
            ax.legend(loc="upper right", fontsize=8)
        ax.set_title("Training / validation curves (combined)")
        fig.savefig(outp, dpi=170, bbox_inches="tight")
    finally:
        # pyplot keeps every open figure alive; release this one on all paths.
        plt.close(fig)
    print("Saved", outp)
def run_random_theta_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Plot CAMELS (NN) | DDPM-2 | DDPM-6 triplets for random LHS label vectors.

    Targets are drawn with ``_latin_hypercube`` inside the per-dimension bounding box of
    ``lab6``; at most 32 targets are used. For each target:

    * the CAMELS column is the ``imgs6`` map whose label is nearest in Euclidean
      distance over raw (unstandardized) label units;
    * DDPM-6 is conditioned on the full target;
    * DDPM-2 is conditioned on its first two entries.

    NOTE(review): with raw units, the label dimensions with the widest ranges dominate
    the nearest-neighbour choice. Standardize (e.g. by ``std6``) if the NN should track
    cosmology.

    Saves ``comparison_random_lhs_triplets_camels_ddpm2_ddpm6.png`` in ``out_dir``.
    """
    rng = np.random.default_rng(seed)
    lo, hi = lab6.min(0), lab6.max(0)
    # The LHS is capped at 32 targets; size the loop and subplot grid from what was
    # actually drawn so n_pairs > 32 does not index past the end of ``targets``.
    n_show = min(n_pairs, 32)
    targets = _latin_hypercube(n_show, lo, hi, rng)
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    titles = ("CAMELS (NN)", "DDPM-2", "DDPM-6")
    fig = plt.figure(figsize=(3.8 * max(3, n_show * 3), 4.1))
    try:
        for i in range(n_show):
            theta6 = targets[i].astype(np.float32)
            theta2 = theta6[:2]
            dist = np.linalg.norm(lab6 - theta6[None, :], axis=1).astype(np.float64)
            nn_img = imgs6[int(np.argmin(dist))]
            gen2 = generate_maps(
                model2, theta2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size
            )
            gen6 = generate_maps(
                model6, theta6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size
            )
            for j, img in enumerate((nn_img, gen2[0], gen6[0])):
                ax = fig.add_subplot(1, n_show * 3, i * 3 + j + 1)
                ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
                ax.axis("off")
                ax.set_title(titles[j], fontsize=8)
        fig.suptitle(
            "Random LHS cosmologies — CAMELS = nearest-neighbour truth | gens conditioned on LHS labels",
            fontsize=10,
            y=1.02,
        )
        p = out_dir / "comparison_random_lhs_triplets_camels_ddpm2_ddpm6.png"
        fig.tight_layout()
        fig.savefig(p, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
    print("Saved", p)
def run_conditioned_test_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Plot CAMELS | DDPM-2 | DDPM-6 triplets for randomly chosen test rows.

    Up to ``n_pairs`` distinct rows of the 6-param test split are chosen with
    ``default_rng(seed + 1)``. DDPM-6 is conditioned on the row's full truth label and
    DDPM-2 on its first two entries. Each triplet is annotated with the truth label
    under its CAMELS panel. The figure is saved as
    ``comparison_test_conditioned_camels_ddpm2_ddpm6.png`` in ``out_dir``.
    """
    rng = np.random.default_rng(seed + 1)
    idx = rng.choice(len(imgs6), size=min(n_pairs, len(imgs6)), replace=False)
    n_show = len(idx)
    if n_show == 0:
        # plt.subplots(1, 0) would raise; nothing to draw.
        print("No test rows selected; skipped conditioned triplets.")
        return
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    col_titles = ("CAMELS", "DDPM-2", "DDPM-6")
    fig, axes = plt.subplots(1, n_show * 3, figsize=(2.9 * n_show * 3, 3.8), squeeze=False)
    try:
        for ii, ix in enumerate(idx):
            tg6 = lab6[ix].astype(np.float32)
            tg2 = tg6[:2]
            rm = imgs6[ix]
            g2 = generate_maps(model2, tg2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size)[0]
            g6 = generate_maps(model6, tg6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size)[0]
            for j, img in enumerate((rm, g2, g6)):
                ax = axes[0, ii * 3 + j]
                ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
                # Hide ticks and frame but keep the axis on: ``axis("off")`` would also
                # hide the xlabel that carries the cosmology annotation below.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if ii == 0:
                    ax.set_title(col_titles[j], fontsize=8)
            axes[0, ii * 3].set_xlabel(_fmt_title(tg6), fontsize=7)
        fig.suptitle(f"Random test ix (conditioned on truth labels), n={len(idx)}", fontsize=10, y=1.06)
        p = out_dir / "comparison_test_conditioned_camels_ddpm2_ddpm6.png"
        fig.savefig(p, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
    print("Saved", p)
def pk_pdf_six_sets(
    out_dir: Path,
    name: str,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    model,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
):
    """Single-model P(k) and pixel-PDF diagnostics at six anchor labels.

    Anchors are six evenly spaced row indices of ``labels_split``. For each anchor:

    * The CAMELS reference set is the ``n_per_set`` rows nearest in Euclidean distance
      over raw label units, with the anchor row itself excluded.
    * ``n_per_set`` maps are generated conditioned on the anchor label.

    Two figures are written to ``out_dir``:

    * ``six_anchor_pk_{name}.png``: mean ± std of P(k) from
      ``ec.calculate_power_spectrum_batch``, first bin dropped.
    * ``six_anchor_pdf_mu_sigma_{name}.png``: per-bin mean and std across maps of
      density histograms of ``14 + 8 * clip(pixel, 0, 1)``, i.e. the pixel values
      mapped onto the [14, 22] log N_HI axis.

    Both figures are closed even if a step raises.
    """
    H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
    ldim = int(labels_split.shape[1])
    idx = np.linspace(0, len(labels_split) - 1, num=6, dtype=int)
    targets = labels_split[idx].copy()
    box = 25.0
    pdf_edges = np.linspace(14.0, 22.0, 101)
    pdf_centres = 0.5 * (pdf_edges[:-1] + pdf_edges[1:])

    def pixel_pdfs(batch: np.ndarray) -> np.ndarray:
        """One density histogram per map in ``batch`` (iterates the batch's own length)."""
        return np.asarray(
            [
                np.histogram(
                    14.0 + (22.0 - 14.0) * np.clip(np.ravel(im), 0.0, 1.0),
                    bins=pdf_edges,
                    density=True,
                )[0]
                for im in batch
            ]
        )

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()
    fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 4.8 * 2), squeeze=False)
    try:
        for si, (ex, target_l) in enumerate(zip(idx, targets)):
            dist = np.linalg.norm(labels_split - target_l[None, :], axis=1).astype(np.float64)
            dist[ex] = np.inf  # the anchor row is not its own neighbour
            nn_idx = np.argsort(dist)[:n_per_set]
            real_batch = images_split[nn_idx]
            rep = np.tile(target_l[None, :], (n_per_set, 1))
            gen = generate_maps(model, rep, label_mean, label_std, H, W, device, ddim_steps, batch_size)

            dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
            _, mg, sg = ec.calculate_power_spectrum_batch(gen, box_size=box)
            x = dk_r[1:]
            axpk = axes_pk[si]
            axpk.plot(x, mr[1:], lw=2, label="CAMELS NN", color="#333")
            axpk.fill_between(x, mr[1:] - sr[1:], mr[1:] + sr[1:], alpha=0.08, color="#333")
            axpk.plot(x, mg[1:], lw=2, label=f"Generated (ldim={ldim})", color="#d95f02")
            axpk.fill_between(x, mg[1:] - sg[1:], mg[1:] + sg[1:], alpha=0.08, color="#d95f02")
            axpk.set_yscale("log")
            axpk.grid(alpha=0.25)
            axpk.set_title(_fmt_title(target_l), fontsize=8)

            # PDF µ/σ; each set is histogrammed over its own length, so a short
            # neighbour set (small split) no longer raises IndexError.
            real_pdf = pixel_pdfs(real_batch)
            gen_pdf = pixel_pdfs(gen)
            axes_pdf[si, 0].plot(pdf_centres, real_pdf.mean(axis=0), lw=2, label="CAMELS NN", color="#333")
            axes_pdf[si, 0].plot(pdf_centres, gen_pdf.mean(axis=0), lw=2, label="Generated", color="#d95f02")
            axes_pdf[si, 1].plot(pdf_centres, real_pdf.std(axis=0), lw=2, ls="-", label="CAMELS σ", color="#333")
            axes_pdf[si, 1].plot(pdf_centres, gen_pdf.std(axis=0), lw=2, ls="--", label="Gen σ", color="#d95f02")

        axes_pk[0].legend(fontsize=7, loc="lower left")
        fig_pk.suptitle(f"$P(k)$ mean±std — six anchors — {name}", fontsize=10)
        fig_pk.tight_layout()
        p_pk = out_dir / f"six_anchor_pk_{name}.png"
        fig_pk.savefig(p_pk, dpi=160)

        # The PDF lines carry labels; draw them once per column so the figure is readable.
        axes_pdf[0, 0].legend(fontsize=7)
        axes_pdf[0, 1].legend(fontsize=7)
        axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        fig_pdf.suptitle(f"PDF mean & σ — six anchors × {n_per_set} — {name}")
        fig_pdf.tight_layout()
        p_pdf = out_dir / f"six_anchor_pdf_mu_sigma_{name}.png"
        fig_pdf.savefig(p_pdf, dpi=160)
    finally:
        plt.close(fig_pk)
        plt.close(fig_pdf)
    print("Saved", p_pk)
    print("Saved", p_pdf)
def pk_six_triplet_combined(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2: torch.nn.Module,
    model6: torch.nn.Module,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
) -> None:
    """Overlay CAMELS, DDPM-2 and DDPM-6 P(k) and pixel PDFs at six anchor labels.

    Anchors are six evenly spaced rows of the 6-param test labels ``lab6``. For each
    anchor:

    * The CAMELS reference set is the ``n_per_set`` nearest other rows (Euclidean
      distance in raw label units, anchor row excluded).
    * DDPM-6 is conditioned on the full anchor label and DDPM-2 on its first two
      entries, with ``n_per_set`` samples each.

    Two figures are written to ``out_dir``:

    * ``six_anchor_pk_overlay_camels_ddpm2_ddpm6.png``: mean ± std P(k), first bin
      dropped.
    * ``six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png``: per-bin mean and std of
      per-map density histograms of ``14 + 8 * clip(pixel, 0, 1)``.

    Both figures are closed even if a step raises.
    """
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    idx = np.linspace(0, len(lab6) - 1, num=6, dtype=int)
    targets = lab6[idx].copy()
    box = 25.0
    pdf_edges = np.linspace(14.0, 22.0, 101)
    pdf_centres = 0.5 * (pdf_edges[:-1] + pdf_edges[1:])

    def pixel_pdfs(batch: np.ndarray) -> np.ndarray:
        """One density histogram per map in ``batch``."""
        return np.asarray(
            [
                np.histogram(
                    14.0 + (22.0 - 14.0) * np.clip(np.ravel(im), 0.0, 1.0),
                    bins=pdf_edges,
                    density=True,
                )[0]
                for im in batch
            ]
        )

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()
    fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 10.5))
    try:
        for si, (ex, target_l) in enumerate(zip(idx, targets)):
            dist = np.linalg.norm(lab6 - target_l[None, :], axis=1).astype(np.float64)
            dist[int(ex)] = np.inf  # the anchor row is not its own neighbour
            nn_idx = np.argsort(dist)[:n_per_set]
            real_batch = imgs6[nn_idx]
            tg2 = np.tile(target_l[:2][None, :], (n_per_set, 1)).astype(np.float32)
            tg6 = np.tile(target_l[None, :], (n_per_set, 1)).astype(np.float32)
            gen2 = generate_maps(model2, tg2, mean2, std2, H, W, device, ddim_steps, batch_size)
            gen6 = generate_maps(model6, tg6, mean6, std6, H, W, device, ddim_steps, batch_size)

            dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
            _, m2, s2 = ec.calculate_power_spectrum_batch(gen2, box_size=box)
            _, mG, sG = ec.calculate_power_spectrum_batch(gen6, box_size=box)
            x = dk_r[1:]
            axpk = axes_pk[si]
            axpk.plot(x, mr[1:], lw=2, label="CAMELS NN", color="#222")
            axpk.fill_between(x, mr[1:] - sr[1:], mr[1:] + sr[1:], alpha=0.06, color="#222")
            axpk.plot(x, m2[1:], lw=2, label="DDPM-2 μ", color="#ff7f0e")
            axpk.fill_between(x, m2[1:] - s2[1:], m2[1:] + s2[1:], alpha=0.06, color="#ff7f0e")
            axpk.plot(x, mG[1:], lw=2, label="DDPM-6 μ", color="#1f77b4")
            axpk.fill_between(x, mG[1:] - sG[1:], mG[1:] + sG[1:], alpha=0.06, color="#1f77b4")
            axpk.set_yscale("log")
            axpk.grid(alpha=0.25)
            axpk.set_title(_fmt_title(target_l), fontsize=8)
            if si == 0:
                axpk.legend(fontsize=6.2, loc="lower left")

            cam_pdf, d2_pdf, d6_pdf = (pixel_pdfs(b) for b in (real_batch, gen2, gen6))
            axes_pdf[si, 0].plot(pdf_centres, cam_pdf.mean(axis=0), lw=2, color="#222", label="CAMELS μ")
            axes_pdf[si, 0].plot(pdf_centres, d2_pdf.mean(axis=0), lw=2, color="#ff7f0e", label="DDPM-2 μ")
            axes_pdf[si, 0].plot(pdf_centres, d6_pdf.mean(axis=0), lw=2, color="#1f77b4", label="DDPM-6 μ")
            axes_pdf[si, 1].plot(pdf_centres, cam_pdf.std(axis=0), lw=2, color="#222", label="CAMELS σ")
            axes_pdf[si, 1].plot(pdf_centres, d2_pdf.std(axis=0), lw=2, ls="--", color="#ff7f0e", label="DDPM-2 σ")
            axes_pdf[si, 1].plot(pdf_centres, d6_pdf.std(axis=0), lw=2, ls="--", color="#1f77b4", label="DDPM-6 σ")

        fig_pk.suptitle("$P(k)$ CAMELS vs DDPM-2 vs DDPM-6 — six Ωm–σ8 anchors", fontsize=11)
        fig_pk.tight_layout()
        p_pk = out_dir / "six_anchor_pk_overlay_camels_ddpm2_ddpm6.png"
        fig_pk.savefig(p_pk, dpi=160)

        # Without a legend the three PDF curves per panel cannot be told apart.
        axes_pdf[0, 0].legend(fontsize=7)
        axes_pdf[0, 1].legend(fontsize=7)
        axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
        fig_pdf.suptitle(r"PDF mean ($\mu$) and std ($\sigma$) overlays", fontsize=10)
        fig_pdf.tight_layout()
        p_pdf = out_dir / "six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png"
        fig_pdf.savefig(p_pdf, dpi=160)
    finally:
        plt.close(fig_pk)
        plt.close(fig_pdf)
    print("Saved", p_pk)
    print("Saved", p_pdf)
def mlp_recovery_dual(
    out_dir: Path,
    data_train: Path,
    imgs_te: np.ndarray,
    lab_te: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    model_ddpm: torch.nn.Module,
    tag: str,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    batch_size: int = 8,
) -> None:
    """Fit a P(k) → label MLP on CAMELS train maps; score it on real vs generated spectra.

    Features are single-map P(k) from ``ec.PowerSpectrum`` (``dl = 25 / Npix``, first
    bin dropped). The regressor is fit on at most 2000 randomly chosen maps (``seed``)
    of the ``"train"`` split under ``data_train``. Features are standardized
    (``StandardScaler``) before the MLP, because scikit-learn documents MLPs as
    sensitive to feature scaling.

    Evaluation uses the first ``min(40, len(imgs_te))`` test maps and one DDPM sample
    per test label, generated ``batch_size`` at a time.

    Saves ``mlp_pk_parameter_recovery_{tag}.png``: two rows (CAMELS predictions,
    generated-map predictions) × one column per label dimension, showing predicted vs
    true values with per-dimension RMSE.
    """
    from sklearn.metrics import mean_squared_error
    from sklearn.neural_network import MLPRegressor
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    ldim = lab_te.shape[1]
    npix = imgs_te.shape[-1]
    dl = 25.0 / npix

    def pk_row(im):
        """P(k) of one map without the first bin, as float32."""
        _dk, pk = ec.PowerSpectrum(np.asarray(im, dtype=np.float64), N=npix, dl=dl)
        return pk[1:].astype(np.float32)

    img_tr_np, lab_tr_np = ec.load_split(data_train, "train")
    if len(img_tr_np) > 2000:
        rng = np.random.default_rng(seed)
        jj = rng.choice(len(img_tr_np), 2000, replace=False)
        img_tr_np, lab_tr_np = img_tr_np[jj], lab_tr_np[jj]
    X_train = np.stack([pk_row(im) for im in img_tr_np], axis=0)
    y_train = lab_tr_np.astype(np.float32)
    mlp = make_pipeline(
        StandardScaler(),
        MLPRegressor(
            hidden_layer_sizes=(64, 64),
            alpha=1e-4,
            random_state=seed,
            max_iter=250,
            early_stopping=True,
            validation_fraction=0.1,
        ),
    )
    mlp.fit(X_train, y_train)

    n_ev = min(40, len(imgs_te))
    eval_idx = np.arange(n_ev)
    X_real = np.stack([pk_row(imgs_te[i]) for i in eval_idx], axis=0)
    y_true = lab_te[eval_idx]
    preds_real = mlp.predict(X_real)

    H, W = int(imgs_te.shape[-2]), int(imgs_te.shape[-1])
    gens = []
    for i0 in range(0, n_ev, batch_size):
        lbl = y_true[i0 : i0 + batch_size]
        g = generate_maps(model_ddpm, lbl, mean, std, H, W, device, ddim_steps, len(lbl))
        gens.extend(pk_row(gm) for gm in g)
    X_gen = np.stack(gens, axis=0)
    preds_gen = mlp.predict(X_gen)

    rmse_real = np.sqrt(mean_squared_error(y_true, preds_real, multioutput="raw_values"))
    rmse_gen = np.sqrt(mean_squared_error(y_true, preds_gen, multioutput="raw_values"))
    # squeeze=False already yields a (2, ldim) array, including ldim == 1.
    fig, axes = plt.subplots(2, ldim, figsize=(max(9.0, 2.8 * max(ldim, 2)), 4.9), squeeze=False)
    try:
        for k in range(ldim):
            lo = float(y_true[:, k].min())
            hi = float(y_true[:, k].max())
            pad = 0.03 * (hi - lo + 1e-12)
            for row, preds, rmv, ylab in (
                (0, preds_real, rmse_real, "CAMELS P(k) predictions"),
                (1, preds_gen, rmse_gen, f"{tag}: generated P(k)"),
            ):
                ax = axes[row, k]
                ax.scatter(y_true[:, k], preds[:, k], s=14, alpha=0.72, edgecolors="none", c="#333")
                ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], color="crimson", lw=1.0)
                ax.grid(True, alpha=0.28)
                ax.set_title(f"dim {k} RMSE={float(rmv[k]):.4f}", fontsize=8)
                if k == 0:
                    ax.set_ylabel(ylab, fontsize=8)
        fig.suptitle(
            "MLP: train on CAMELS train P(k), test on CAMELS vs DDPM-drawn spectra",
            fontsize=10,
            y=1.02,
        )
        fig.tight_layout()
        p = out_dir / f"mlp_pk_parameter_recovery_{tag}.png"
        fig.savefig(p, dpi=165, bbox_inches="tight")
    finally:
        plt.close(fig)
    print("Saved", p)
def posterior_one_index(
    out_dir: Path,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    model,
    cfg: Dict,
    device,
    ix: int,
    tag: str,
    ddim_steps: int,
    grid: int,
    batch_sz: int,
    sigma_log_pk: float = 0.25,
):
    """Grid-based surrogate posterior over (Omega_m, sigma_8) for one observed map.

    The observation is ``images_split[ix]``.

    * A ``grid`` x ``grid`` (Omega_m, sigma_8) grid spanning the split's label range,
      padded by 2% on each side, is built with ``build_cosmo_grid``.
    * Every other label entry is held at the truth ``labels_split[ix]``.
    * One map is sampled per grid point with ``em.sample_batch``, in chunks of
      ``batch_sz``.
    * Each sample is scored with a Gaussian pseudo-likelihood on log P(k):
      ``-mean((log P_obs - log P_gen)^2) / (2 * sigma_log_pk**2)``. The default
      ``sigma_log_pk=0.25`` reproduces the previous hard-coded width.

    The normalised weights are drawn as filled contours, with the truth (red x) and
    posterior mean (black +) marked. The figure is saved to
    ``out_dir / f"posterior_surrogate_test_ix_{ix}_{tag}.png"``.
    """
    normalize = bool(cfg.get("normalize_labels", True))
    lab_dim = labels_split.shape[1]
    H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
    obs = images_split[ix]
    label_anchor_full = labels_split[ix].astype(np.float32)
    lo0 = float(labels_split[:, 0].min())
    hi0 = float(labels_split[:, 0].max())
    lo1 = float(labels_split[:, 1].min())
    hi1 = float(labels_split[:, 1].max())
    pad0 = 0.02 * (hi0 - lo0 + 1e-12)
    pad1 = 0.02 * (hi1 - lo1 + 1e-12)
    _om_ax, _s8_ax, _og, _sg, grid2 = build_cosmo_grid(
        grid, lo0 - pad0, hi0 + pad0, lo1 - pad1, hi1 + pad1
    )
    ngrid = grid2.shape[0]
    # Take the plotting/averaging coordinates from grid2 itself, reshaped exactly like the
    # weights below (same row order), so Wmap, OM and S8 line up whatever meshgrid
    # convention build_cosmo_grid uses internally.
    OM = np.asarray(grid2[:, 0], dtype=np.float64).reshape(grid, grid)
    S8 = np.asarray(grid2[:, 1], dtype=np.float64).reshape(grid, grid)

    npix = int(obs.shape[-1])
    dl = 25.0 / npix
    dk, _ = ec.PowerSpectrum(em.images01_to_log_nhi(obs), N=npix, dl=dl)
    valid = dk > 0
    # NOTE(review): assumes log_pk_observed returns one value per ``valid`` k-bin, matching
    # ``pkc[:, valid]`` below — confirm in figure9_posterior.
    log_pd = log_pk_observed(obs, 25.0, dk)

    # Grid labels: cosmology varies, remaining entries fixed at the truth.
    full = np.tile(label_anchor_full[np.newaxis, :], (ngrid, 1))
    full[:, 0] = grid2[:, 0].astype(np.float32)
    full[:, 1] = grid2[:, 1].astype(np.float32)

    scores = []
    for j0 in range(0, ngrid, batch_sz):
        imgs = em.sample_batch(
            model,
            full[j0 : j0 + batch_sz],
            lab_mean,
            lab_std,
            normalize,
            H,
            W,
            device,
            ddim_steps,
            False,
        )
        _, pkc = em.per_map_power_spectra_log(imgs, 25.0)
        log_pg = np.log(pkc[:, valid] + 1e-30)
        mse = np.mean((log_pd[np.newaxis, :] - log_pg) ** 2, axis=1)
        scores.append(-mse / (2.0 * sigma_log_pk**2))
    sc = np.concatenate(scores)
    sc -= sc.max()  # stabilise exp()
    Wmap = np.exp(sc).reshape(grid, grid)
    Wmap /= Wmap.sum()

    tom, ts8 = float(label_anchor_full[0]), float(label_anchor_full[1])
    mom = float((Wmap * OM).sum())
    ms8 = float((Wmap * S8).sum())
    fig, ax = plt.subplots(figsize=(5.2, 4.6))
    try:
        cf = ax.contourf(OM, S8, Wmap, levels=12, cmap="Blues")
        fig.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
        ax.scatter(tom, ts8, s=55, c="r", marker="x", zorder=6, label="true")
        ax.scatter(mom, ms8, s=60, c="k", marker="+", zorder=6, label="post. mean")
        ax.set_xlabel(r"$\Omega_m$")
        ax.set_ylabel(r"$\sigma_8$")
        ax.legend(fontsize=8)
        ax.set_title(f"Surrogate posterior (test ix={ix}, ldim={lab_dim})", fontsize=10)
        p = out_dir / f"posterior_surrogate_test_ix_{ix}_{tag}.png"
        fig.savefig(p, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
    print("Saved", p)
def _run_step(description: str, step, *args, **kwargs) -> bool:
    """Run one best-effort step of the comparison suite.

    An ``Exception`` raised by ``step`` is reported as ``"<description>: <exc>"``
    followed by the traceback, and ``False`` is returned so the caller can continue
    with the remaining diagnostics. Returns ``True`` on success.
    """
    import traceback

    try:
        step(*args, **kwargs)
    except Exception as exc:  # deliberate best-effort boundary: keep the suite running
        print(f"{description}: {exc}")
        traceback.print_exc()
        return False
    return True


def _lhs_r2_for_model(
    out_dir: Path,
    label: str,
    model,
    imgs: np.ndarray,
    labs: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    device: torch.device,
    lhs_n: int,
    n_maps: int,
    batch_size: int,
    ddim_steps: int,
    seed: int,
) -> None:
    """Compute LHS R² diagnostics for one model; save the figure and the raw arrays."""
    lhs_pts, r2_mu, r2_sig, lo_b, hi_b = compute_lhs_r2(
        model, imgs, labs, mean, std, device, lhs_n, n_maps, batch_size, 25.0, ddim_steps, seed
    )
    outp = out_dir / f"r2_cosmology_lhs{lhs_n}_{label}.png"
    plot_r2_cosmology_figure(lhs_pts, r2_mu, r2_sig, lo_b, hi_b, outp, dpi=160)
    print("Saved", outp)
    np.savez(
        out_dir / f"r2_lhs_data_{label}.npz",
        lhs_pts=lhs_pts,
        r2_mu_arr=r2_mu,
        r2_sig_arr=r2_sig,
        lo_b=lo_b,
        hi_b=hi_b,
    )


def main(argv: Sequence[str] | None = None) -> None:
    """Parse CLI arguments and run the DDPM-2 vs DDPM-6 comparison suite.

    Data, label statistics and both models are loaded up front; failures there abort
    the run. Every diagnostic after that is best-effort and isolated (via
    ``_run_step``), so one failing plot or model does not skip the others. Failures
    are reported with a traceback.
    """
    p = argparse.ArgumentParser(description="DDPM-2 vs DDPM-6 comparison suite.")
    p.add_argument(
        "--output-dir",
        type=Path,
        default=MODELS_ROOT / "ddpm_comparison_out",
    )
    p.add_argument("--data-2param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("<DDPM_ROOT>/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--posterior-index", type=int, default=56)
    p.add_argument("--lhs-n", type=int, default=50)
    p.add_argument("--six-n-per-anchor", type=int, default=15)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--slurm-6param", type=Path, default=DEFAULT_SLURM_6)
    p.add_argument(
        "--slurm-2param",
        type=Path,
        default=DEFAULT_DDPM2_TRAINING,
        help="DDPM-2 train/val series: Slurm .out (parsed) or bundled ddpm_2param_training_loss.json.",
    )
    p.add_argument("--skip-lhs-r2", action="store_true", help="LHS R² plots are expensive; skip if set.")
    p.add_argument("--n-random-triplets", type=int, default=4)
    args = p.parse_args(list(argv) if argv is not None else None)

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    args2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    args6 = args.bundle_6param / "args.json"
    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)
    imgs6, lab6 = ec.load_split(data6, "test")
    mean6, std6 = ec.load_label_stats(data6)
    mean2, std2 = ec.load_label_stats(data2)
    # A malformed log must not abort the (much more expensive) model diagnostics.
    _run_step(
        "Training-curve overlay failed",
        plot_training_overlay,
        out_dir,
        args.slurm_6param,
        args.slurm_2param,
    )
    imgs2, lab2 = ec.load_split(data2, "test")

    print(">>> Loading DDPM-2...")
    model2, cfg2 = load_model(args2, ck2, device)
    print(">>> Loading DDPM-6...")
    model6, cfg6 = load_model(args6, ck6, device)

    print(">>> Random LHS + conditioned triplets...")
    _run_step(
        "Random LHS triplet grid failed",
        run_random_theta_triplets,
        out_dir, imgs6, lab6, mean6, std6, mean2, std2, model2, model6,
        device=device,
        ddim_steps=args.ddim_steps,
        seed=args.seed,
        n_pairs=args.n_random_triplets,
        batch_size=args.batch_size,
    )
    _run_step(
        "Test-conditioned triplet grid failed",
        run_conditioned_test_triplets,
        out_dir, imgs6, lab6, mean6, std6, mean2, std2, model2, model6,
        device=device,
        ddim_steps=args.ddim_steps,
        seed=args.seed,
        n_pairs=args.n_random_triplets,
        batch_size=args.batch_size,
    )

    print(">>> Six-anchor overlays (combined + per-model)...")
    _run_step(
        "Combined six-anchor P(k)/PDF plots failed",
        pk_six_triplet_combined,
        out_dir, imgs6, lab6, mean6, std6, mean2, std2, model2, model6,
        device=device,
        ddim_steps=args.ddim_steps,
        batch_size=args.batch_size,
        n_per_set=args.six_n_per_anchor,
    )
    _run_step(
        "Six-anchor P(k)/PDF plots (ddpm6_only) failed",
        pk_pdf_six_sets,
        out_dir, "ddpm6_only", imgs6, lab6, mean6, std6, model6,
        device, args.ddim_steps, args.batch_size, args.six_n_per_anchor,
    )
    _run_step(
        "Six-anchor P(k)/PDF plots (ddpm2_only) failed",
        pk_pdf_six_sets,
        out_dir, "ddpm2_only", imgs2, lab2, mean2, std2, model2,
        device, args.ddim_steps, args.batch_size, args.six_n_per_anchor,
    )

    if args.skip_lhs_r2:
        print("(Skipping LHS R².)")
    else:
        n_maps_per_point = 15
        print(f">>> LHS R² (LHS-{args.lhs_n} × {n_maps_per_point} DDIM each — long)...")
        # Labels follow --lhs-n (they were hard-coded to "lhs50" before); the default
        # run keeps the same file names.
        label2 = f"ddpm2_lhs{args.lhs_n}"
        _run_step(
            f"LHS R² ({label2}) skipped",
            _lhs_r2_for_model,
            out_dir, label2, model2, imgs2, lab2, mean2, std2, device,
            args.lhs_n, n_maps_per_point, args.batch_size, args.ddim_steps, args.seed,
        )
        label6 = f"ddpm6_lhs{args.lhs_n}"
        _run_step(
            f"LHS R² ({label6}) skipped",
            _lhs_r2_for_model,
            out_dir, label6, model6, imgs6, lab6, mean6, std6, device,
            args.lhs_n, n_maps_per_point, args.batch_size, args.ddim_steps, args.seed,
        )

    print(">>> MLP P(k) parameter recovery...")
    _run_step(
        "MLP recovery (ddpm2param) skipped",
        mlp_recovery_dual,
        out_dir, data2, imgs2[:40], lab2[:40], mean2, std2, model2, "ddpm2param",
        device, args.ddim_steps, args.seed,
    )
    _run_step(
        "MLP recovery (ddpm6param) skipped",
        mlp_recovery_dual,
        out_dir, data6, imgs6[:40], lab6[:40], mean6, std6, model6, "ddpm6param",
        device, args.ddim_steps, args.seed,
    )

    ix = int(args.posterior_index)
    print(f">>> Surrogate posteriors (test index {args.posterior_index})...")
    _run_step(
        "Posterior panel (ddpm6) skipped",
        posterior_one_index,
        out_dir, imgs6, lab6, mean6, std6, model6, cfg6, device, ix, "ddpm6",
        args.ddim_steps, 14, args.batch_size,
    )
    _run_step(
        "Posterior panel (ddpm2) skipped",
        posterior_one_index,
        out_dir, imgs2, lab2, mean2, std2, model2, cfg2, device, ix, "ddpm2",
        args.ddim_steps, 14, args.batch_size,
    )

    del model2, model6
    free_torch()
    print(f"Done. Outputs in {out_dir}")
# Script entry point; main() returns None, so the process exits with status 0 on success.
if __name__ == "__main__":
    sys.exit(main())