DDPM-6param / src /plot_r2_cosmology_lhs.py
collins909's picture
Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
1f3e7a2 verified
#!/usr/bin/env python3
"""
Reproduce the Latin-hypercube R² figure — R² of the mean μ(P) and standard deviation σ(P)
of the power spectrum, shown in the (Ωm, σ8) plane — with a layout that avoids colorbar /
suptitle overlap.
Usage (full run — slow):
python plot_r2_cosmology_lhs.py --output ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png
Replay the plot only, from arrays saved by an earlier run with --save-npz:
python plot_r2_cosmology_lhs.py --from-npz r2_lhs_data.npz --output out.png
Defaults match ddpm_conditional / evaluate_conditional 6-param setup.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from matplotlib.gridspec import GridSpec
import evaluate_conditional as ec
import eval_model as em
# Directory containing this script; used to build the default checkpoint and output paths.
_SCRIPT_DIR = Path(__file__).resolve().parent
# Default checkpoint: best_model.pt of one specific (timestamped) training run next to this script.
_DEFAULT_CKPT = _SCRIPT_DIR / "outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt"
# NOTE(review): "<DDPM_ROOT>" looks like an unfilled placeholder — pass --data-dir explicitly
# unless it has been replaced by a real path.
_DEFAULT_DATA = "<DDPM_ROOT>/data/LH_data/params_6"
def latin_hypercube_scaled(
    n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator
) -> np.ndarray:
    """Draw ``n`` Latin-hypercube samples inside the box ``[lo, hi]``.

    The unit interval of every dimension is cut into ``n`` equal strata and one
    point is drawn uniformly inside each stratum. The strata are then permuted
    independently per dimension before the points are scaled to ``[lo, hi]``.

    Args:
        n: Number of samples (one per stratum).
        lo: Per-dimension lower bounds, shape ``(d,)``.
        hi: Per-dimension upper bounds, shape ``(d,)``.
        rng: Random generator, used for the jitter draw and the column shuffles.

    Returns:
        ``float32`` array of shape ``(n, d)``.
    """
    ndim = int(lo.shape[0])
    jitter = rng.random((n, ndim))
    edges = np.linspace(0.0, 1.0, n + 1)
    left_edges = edges[:-1]
    stratum_width = edges[1:] - left_edges
    unit_samples = left_edges[:, None] + jitter * stratum_width[:, None]
    # Shuffle each column in place, one dimension after another.
    for col in range(ndim):
        rng.shuffle(unit_samples[:, col])
    extent = np.asarray(hi - lo, dtype=np.float64)
    return np.asarray(lo + unit_samples * extent, dtype=np.float32)
def compute_lhs_r2(
    model: torch.nn.Module,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    device: torch.device,
    lhs_n: int,
    maps_per_point: int,
    batch_size: int,
    box_size_mpc: float,
    ddim_steps: int,
    seed: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compare DDPM and real power-spectrum statistics at Latin-hypercube points.

    ``lhs_n`` parameter vectors are drawn by Latin-hypercube sampling inside the
    per-dimension min/max box of ``labels_split``. At each point:

    - the model generates ``maps_per_point`` maps (DDIM sampling, in batches of at
      most ``batch_size``);
    - the ``maps_per_point`` real maps whose labels are closest to the point are
      taken as the reference;
    - R² is computed between the real and generated mean power spectra, and
      between their standard deviations, skipping the first k bin.

    Nearest neighbours are found after dividing every label dimension by its
    range in ``labels_split``. Without this, parameters with a wide numeric range
    would dominate the Euclidean distance. The chosen "real" maps could then lie
    far from the point in Ωm / σ8, the very axes the result is plotted against.

    Args:
        model: Conditional diffusion model exposing ``sample(...)``.
        images_split: Real maps; the last two axes are (height, width).
        labels_split: Parameter vectors, shape ``(n_maps, n_params)``.
        label_mean, label_std: Label statistics forwarded to
            ``ec.prepare_labels_for_model``.
        device: Torch device used for sampling.
        lhs_n: Number of Latin-hypercube points (>= 0).
        maps_per_point: Maps generated / matched per point (>= 1).
        batch_size: Upper bound on the sampling batch size (>= 1).
        box_size_mpc: Physical side length of a map; pixel size is
            ``box_size_mpc / width``.
        ddim_steps: Number of DDIM steps per generated batch.
        seed: Seed for the Latin-hypercube draw.

    Returns:
        ``(lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b)``:
        - ``lhs_pts``: the LHS points, shape ``(lhs_n, n_params)``;
        - ``r2_mu_arr``, ``r2_sig_arr``: R² arrays of length ``lhs_n``, filled
          from ``em.r2_score_1d``;
        - ``lo_b``, ``hi_b``: the label bounds used for sampling.

    Raises:
        ValueError: If ``lhs_n < 0``, ``maps_per_point < 1`` or ``batch_size < 1``.
    """
    if lhs_n < 0:
        raise ValueError(f"lhs_n must be >= 0, got {lhs_n}")
    if maps_per_point < 1:
        raise ValueError(f"maps_per_point must be >= 1, got {maps_per_point}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    lo_b = labels_split.min(axis=0)
    hi_b = labels_split.max(axis=0)
    rng = np.random.default_rng(seed)
    lhs_pts = latin_hypercube_scaled(lhs_n, lo_b, hi_b, rng)

    # Per-dimension scale for the nearest-neighbour search. Constant dimensions
    # keep scale 1 so they cannot cause a division by zero.
    label_scale = (hi_b - lo_b).astype(np.float64)
    label_scale[label_scale <= 0] = 1.0
    labels_scaled = labels_split / label_scale

    h, w = int(images_split.shape[-2]), int(images_split.shape[-1])
    bs = min(batch_size, maps_per_point)
    npix = w
    dl = box_size_mpc / npix

    def pk_stack(imgs: np.ndarray) -> np.ndarray:
        # Element [1] of ec.PowerSpectrum's result is stacked as the spectrum P(k).
        return np.stack([ec.PowerSpectrum(im, N=npix, dl=dl)[1] for im in imgs], axis=0)

    r2_mu_arr = np.full(lhs_n, np.nan, dtype=np.float64)
    r2_sig_arr = np.full(lhs_n, np.nan, dtype=np.float64)
    model.eval()
    for ti, theta in enumerate(lhs_pts):
        dist = np.linalg.norm(labels_scaled - theta / label_scale, axis=1)
        # Stable sort: ties (e.g. maps sharing identical labels) resolve by index.
        nn_idx = np.argsort(dist, kind="stable")[:maps_per_point]
        real_batch = images_split[nn_idx]

        rep = np.tile(theta[None, :], (maps_per_point, 1))
        gen_chunks = []
        with torch.no_grad():
            for start in range(0, maps_per_point, bs):
                bt = ec.prepare_labels_for_model(
                    rep[start : start + bs], label_mean, label_std
                ).to(device)
                g = model.sample(
                    labels=bt,
                    channels=1,
                    height=h,
                    width=w,
                    device=device,
                    progress=False,
                    use_ddim=True,
                    ddim_steps=ddim_steps,
                )
                gen_chunks.append(ec.from_model_output(g))
        gen_batch = np.concatenate(gen_chunks, axis=0)

        pk_r = pk_stack(real_batch)
        pk_g = pk_stack(gen_batch)
        # Skip the first k bin (index 0) — presumably the k = 0 / DC term; confirm in ec.PowerSpectrum.
        mu_r, mu_g = pk_r.mean(axis=0)[1:], pk_g.mean(axis=0)[1:]
        sr, sg = pk_r.std(axis=0)[1:], pk_g.std(axis=0)[1:]
        r2_mu_arr[ti] = em.r2_score_1d(mu_r, mu_g)
        r2_sig_arr[ti] = em.r2_score_1d(sr, sg)
    return lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b
def plot_r2_cosmology_figure(
lhs_pts: np.ndarray,
r2_mu_arr: np.ndarray,
r2_sig_arr: np.ndarray,
lo_b: np.ndarray,
hi_b: np.ndarray,
out_path: Path,
r2_vmin: float = 0.90,
r2_vmax: float = 1.0,
lhs_n: int | None = None,
maps_per_point: int | None = None,
dpi: int = 160,
) -> None:
"""
Two-panel scatter in (Ωm, σ8) with a dedicated colorbar column (no overlap with heatmap).
"""
lhs_n = lhs_n if lhs_n is not None else len(r2_mu_arr)
maps_per_point = maps_per_point if maps_per_point is not None else 15
ldim = lhs_pts.shape[1]
om_plot = lhs_pts[:, 0]
s8_plot = lhs_pts[:, 1] if ldim >= 2 else np.zeros(lhs_n)
cmap = em.cmap_r2_hiflow()
norm = Normalize(vmin=r2_vmin, vmax=r2_vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
fig = plt.figure(figsize=(11.5, 4.9))
# Left: data panels; narrow right strip: colorbar only (avoids fig.colorbar + tight_layout clash)
gs = GridSpec(
nrows=1,
ncols=3,
figure=fig,
width_ratios=[1.0, 1.0, 0.065],
wspace=0.26,
left=0.07,
right=0.98,
top=0.82,
bottom=0.14,
)
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)
cax = fig.add_subplot(gs[0, 2])
pad_x = 0.02 * (float(hi_b[0] - lo_b[0]) + 1e-6)
ax0.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
ax1.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
if ldim >= 2:
pad_y = 0.02 * (float(hi_b[1] - lo_b[1]) + 1e-6)
ax0.set_ylim(float(lo_b[1]) - pad_y, float(hi_b[1]) + pad_y)
for ax, r2v, subtitle in zip(
(ax0, ax1),
(r2_mu_arr, r2_sig_arr),
(r"$R^2$ for $\mu(P)$", r"$R^2$ for $\sigma(P)$"),
):
ok = np.isfinite(r2v)
ax.scatter(
om_plot[ok],
s8_plot[ok],
c=np.clip(r2v[ok], r2_vmin, r2_vmax),
cmap=cmap,
norm=norm,
s=52,
alpha=0.92,
edgecolors="k",
linewidths=0.35,
)
ax.set_xlabel(r"$\Omega_m$", fontsize=12)
ax.set_title(subtitle, fontsize=11)
ax.grid(True, alpha=0.25)
ax0.set_ylabel(r"$\sigma_8$", fontsize=12)
plt.setp(ax1.get_yticklabels(), visible=False)
cb = fig.colorbar(sm, cax=cax)
cb.set_label(r"$R^2$", fontsize=11)
cax.tick_params(labelsize=9)
fig.suptitle(
r"Visual summary of $R^2$ (CAMELS vs conditional DDPM) vs cosmology — "
+ f"{lhs_n} Latin Hypercube samples; {maps_per_point} maps / point",
fontsize=11,
fontweight="bold",
y=0.96,
)
out_path = Path(out_path)
out_path.parent.mkdir(parents=True, exist_ok=True)
# Do not use bbox_inches="tight" — it rebalance axes and can squeeze the colorbar into the panels.
fig.savefig(out_path, dpi=dpi)
plt.close(fig)
def _resolve_training_args(checkpoint: Path) -> Path | None:
run = checkpoint.parent.parent if checkpoint.parent.name == "checkpoints" else checkpoint.parent
for name in ("args.json", "args.txt"):
p = run / name
if p.is_file():
return p
return None
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls behave as before.
            Passing an explicit list lets tests or notebooks drive the parser.

    Returns:
        The parsed option namespace.
    """
    p = argparse.ArgumentParser(description="LHS R² cosmology figure (fixed colorbar layout)")
    p.add_argument("--checkpoint", type=str, default=str(_DEFAULT_CKPT))
    p.add_argument("--data-dir", type=str, default=_DEFAULT_DATA)
    p.add_argument("--split", type=str, default="test", choices=("train", "val", "test"))
    p.add_argument("--output", type=str, default=str(_SCRIPT_DIR / "ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png"))
    p.add_argument("--from-npz", type=str, default=None, help="Load lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b")
    p.add_argument("--save-npz", type=str, default=None, help="After compute, save arrays for --from-npz replot")
    p.add_argument("--lhs-n", type=int, default=50)
    p.add_argument("--maps-per-point", type=int, default=15)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--box-size-mpc", type=float, default=25.0)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--r2-vmin", type=float, default=0.90)
    p.add_argument("--r2-vmax", type=float, default=1.0)
    p.add_argument("--dpi", type=int, default=160)
    return p.parse_args(argv)
def main() -> None:
    """Compute (or reload) the LHS R² arrays and write the figure.

    With ``--from-npz``, model sampling is skipped and the arrays are read back
    from a file written by an earlier ``--save-npz`` run. Otherwise the
    checkpoint and data split are loaded and ``compute_lhs_r2`` is run.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist (compute mode).
    """
    args = parse_args()
    out_path = Path(args.output)
    # Used only for the figure title; replaced by the value stored in the npz when present.
    maps_per_point = args.maps_per_point
    if args.from_npz:
        # `with` closes the NpzFile's underlying file handle once the arrays are read.
        with np.load(args.from_npz, allow_pickle=False) as z:
            lhs_pts = z["lhs_pts"]
            r2_mu_arr = z["r2_mu_arr"]
            r2_sig_arr = z["r2_sig_arr"]
            lo_b = z["lo_b"]
            hi_b = z["hi_b"]
            # Files saved before this key existed fall back to the CLI value.
            if "maps_per_point" in z.files:
                maps_per_point = int(z["maps_per_point"])
    else:
        ckpt = Path(args.checkpoint).expanduser().resolve()
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
        ta = _resolve_training_args(ckpt)
        config: dict = {}
        if ta is not None:
            config = ec.load_training_config(str(ta))
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = ec.build_model(config, device)
        ec.load_checkpoint(model, str(ckpt), device)
        # expanduser for consistency with --checkpoint handling.
        data_dir = Path(args.data_dir).expanduser()
        images_split, labels_split = ec.load_split(data_dir, args.split)
        label_mean, label_std = ec.load_label_stats(data_dir)
        lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b = compute_lhs_r2(
            model,
            images_split,
            labels_split,
            label_mean,
            label_std,
            device,
            lhs_n=args.lhs_n,
            maps_per_point=args.maps_per_point,
            batch_size=args.batch_size,
            box_size_mpc=args.box_size_mpc,
            ddim_steps=args.ddim_steps,
            seed=args.seed,
        )
        if args.save_npz:
            np.savez(
                args.save_npz,
                lhs_pts=lhs_pts,
                r2_mu_arr=r2_mu_arr,
                r2_sig_arr=r2_sig_arr,
                lo_b=lo_b,
                hi_b=hi_b,
                # Stored so a later --from-npz replay titles the figure correctly.
                maps_per_point=np.int64(args.maps_per_point),
            )
            print("Saved", args.save_npz)
    plot_r2_cosmology_figure(
        lhs_pts,
        r2_mu_arr,
        r2_sig_arr,
        lo_b,
        hi_b,
        out_path,
        r2_vmin=args.r2_vmin,
        r2_vmax=args.r2_vmax,
        # Count the points actually plotted; with --from-npz the CLI --lhs-n need not match the file.
        lhs_n=len(lhs_pts),
        maps_per_point=maps_per_point,
        dpi=args.dpi,
    )
    print("Saved", out_path.resolve())


if __name__ == "__main__":
    main()