#!/usr/bin/env python3 """ Reproduce the Latin-hypercube R² figure (μ(P) and σ(P)) in (Ωm, σ8) with a layout that avoids colorbar / suptitle overlap. Usage (full run — slow): python plot_r2_cosmology_lhs.py --output ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png Replay plot only from saved arrays: python plot_r2_cosmology_lhs.py --from-npz r2_lhs_data.npz --output out.png Defaults match ddpm_conditional / evaluate_conditional 6-param setup. """ from __future__ import annotations import argparse from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.cm import ScalarMappable from matplotlib.colors import Normalize from matplotlib.gridspec import GridSpec import evaluate_conditional as ec import eval_model as em _SCRIPT_DIR = Path(__file__).resolve().parent _DEFAULT_CKPT = _SCRIPT_DIR / "outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt" _DEFAULT_DATA = "/data/LH_data/params_6" def latin_hypercube_scaled( n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator ) -> np.ndarray: """n points in [lo, hi] per dimension (classic LHS).""" d = int(lo.shape[0]) u = rng.random((n, d)) cut = np.linspace(0.0, 1.0, n + 1) a, b = cut[:-1], cut[1:] width = (b - a)[:, np.newaxis] rd = a[:, np.newaxis] + u * width for j in range(d): rng.shuffle(rd[:, j]) span = (hi - lo).astype(np.float64) return (lo + rd * span).astype(np.float32) def compute_lhs_r2( model: torch.nn.Module, images_split: np.ndarray, labels_split: np.ndarray, label_mean: np.ndarray, label_std: np.ndarray, device: torch.device, lhs_n: int, maps_per_point: int, batch_size: int, box_size_mpc: float, ddim_steps: int, seed: int, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Returns lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b.""" lo_b = labels_split.min(axis=0) hi_b = labels_split.max(axis=0) rng = np.random.default_rng(seed) lhs_pts = latin_hypercube_scaled(lhs_n, lo_b, hi_b, rng) ldim = labels_split.shape[1] h, w = int(images_split.shape[-2]), int(images_split.shape[-1]) bs = min(batch_size, maps_per_point) npix = int(images_split.shape[-1]) dl = box_size_mpc / npix def pk_stack(imgs: np.ndarray) -> np.ndarray: return np.stack([ec.PowerSpectrum(im, N=npix, dl=dl)[1] for im in imgs], axis=0) r2_mu_arr = np.full(lhs_n, np.nan, dtype=np.float64) r2_sig_arr = np.full(lhs_n, np.nan, dtype=np.float64) model.eval() for ti in range(lhs_n): theta = lhs_pts[ti] dist = np.linalg.norm(labels_split - theta, axis=1) nn_idx = np.argsort(dist)[:maps_per_point] real_batch = images_split[nn_idx] rep = np.tile(theta[None, :], (maps_per_point, 1)) gen_chunks = [] for j in range(0, maps_per_point, bs): chunk = rep[j : j + bs] bt = ec.prepare_labels_for_model(chunk, label_mean, label_std).to(device) with torch.no_grad(): g = model.sample( labels=bt, channels=1, height=h, width=w, device=device, progress=False, use_ddim=True, ddim_steps=ddim_steps, ) gen_chunks.append(ec.from_model_output(g)) gen_batch = np.concatenate(gen_chunks, axis=0) pk_r = pk_stack(real_batch) pk_g = pk_stack(gen_batch) km = np.arange(pk_r.shape[1], dtype=int) > 0 mu_r, mu_g = pk_r.mean(axis=0), pk_g.mean(axis=0) sr, sg = pk_r.std(axis=0), pk_g.std(axis=0) r2_mu_arr[ti] = em.r2_score_1d(mu_r[km], mu_g[km]) r2_sig_arr[ti] = em.r2_score_1d(sr[km], sg[km]) return lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b def plot_r2_cosmology_figure( lhs_pts: np.ndarray, r2_mu_arr: np.ndarray, r2_sig_arr: np.ndarray, lo_b: np.ndarray, hi_b: np.ndarray, out_path: Path, r2_vmin: float = 0.90, r2_vmax: float = 1.0, lhs_n: int | None = None, maps_per_point: int | None = None, dpi: int = 160, ) -> None: """ Two-panel scatter in (Ωm, σ8) with a dedicated colorbar column (no overlap with heatmap). """ lhs_n = lhs_n if lhs_n is not None else len(r2_mu_arr) maps_per_point = maps_per_point if maps_per_point is not None else 15 ldim = lhs_pts.shape[1] om_plot = lhs_pts[:, 0] s8_plot = lhs_pts[:, 1] if ldim >= 2 else np.zeros(lhs_n) cmap = em.cmap_r2_hiflow() norm = Normalize(vmin=r2_vmin, vmax=r2_vmax) sm = ScalarMappable(norm=norm, cmap=cmap) sm.set_array([]) fig = plt.figure(figsize=(11.5, 4.9)) # Left: data panels; narrow right strip: colorbar only (avoids fig.colorbar + tight_layout clash) gs = GridSpec( nrows=1, ncols=3, figure=fig, width_ratios=[1.0, 1.0, 0.065], wspace=0.26, left=0.07, right=0.98, top=0.82, bottom=0.14, ) ax0 = fig.add_subplot(gs[0, 0]) ax1 = fig.add_subplot(gs[0, 1], sharey=ax0) cax = fig.add_subplot(gs[0, 2]) pad_x = 0.02 * (float(hi_b[0] - lo_b[0]) + 1e-6) ax0.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x) ax1.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x) if ldim >= 2: pad_y = 0.02 * (float(hi_b[1] - lo_b[1]) + 1e-6) ax0.set_ylim(float(lo_b[1]) - pad_y, float(hi_b[1]) + pad_y) for ax, r2v, subtitle in zip( (ax0, ax1), (r2_mu_arr, r2_sig_arr), (r"$R^2$ for $\mu(P)$", r"$R^2$ for $\sigma(P)$"), ): ok = np.isfinite(r2v) ax.scatter( om_plot[ok], s8_plot[ok], c=np.clip(r2v[ok], r2_vmin, r2_vmax), cmap=cmap, norm=norm, s=52, alpha=0.92, edgecolors="k", linewidths=0.35, ) ax.set_xlabel(r"$\Omega_m$", fontsize=12) ax.set_title(subtitle, fontsize=11) ax.grid(True, alpha=0.25) ax0.set_ylabel(r"$\sigma_8$", fontsize=12) plt.setp(ax1.get_yticklabels(), visible=False) cb = fig.colorbar(sm, cax=cax) cb.set_label(r"$R^2$", fontsize=11) cax.tick_params(labelsize=9) fig.suptitle( r"Visual summary of $R^2$ (CAMELS vs conditional DDPM) vs cosmology — " + f"{lhs_n} Latin Hypercube samples; {maps_per_point} maps / point", fontsize=11, fontweight="bold", y=0.96, ) out_path = Path(out_path) out_path.parent.mkdir(parents=True, exist_ok=True) # Do not use bbox_inches="tight" — it rebalance axes and can squeeze the colorbar into the panels. fig.savefig(out_path, dpi=dpi) plt.close(fig) def _resolve_training_args(checkpoint: Path) -> Path | None: run = checkpoint.parent.parent if checkpoint.parent.name == "checkpoints" else checkpoint.parent for name in ("args.json", "args.txt"): p = run / name if p.is_file(): return p return None def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="LHS R² cosmology figure (fixed colorbar layout)") p.add_argument("--checkpoint", type=str, default=str(_DEFAULT_CKPT)) p.add_argument("--data-dir", type=str, default=_DEFAULT_DATA) p.add_argument("--split", type=str, default="test", choices=("train", "val", "test")) p.add_argument("--output", type=str, default=str(_SCRIPT_DIR / "ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png")) p.add_argument("--from-npz", type=str, default=None, help="Load lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b") p.add_argument("--save-npz", type=str, default=None, help="After compute, save arrays for --from-npz replot") p.add_argument("--lhs-n", type=int, default=50) p.add_argument("--maps-per-point", type=int, default=15) p.add_argument("--batch-size", type=int, default=8) p.add_argument("--ddim-steps", type=int, default=50) p.add_argument("--box-size-mpc", type=float, default=25.0) p.add_argument("--seed", type=int, default=42) p.add_argument("--r2-vmin", type=float, default=0.90) p.add_argument("--r2-vmax", type=float, default=1.0) p.add_argument("--dpi", type=int, default=160) return p.parse_args() def main() -> None: args = parse_args() out_path = Path(args.output) if args.from_npz: z = np.load(args.from_npz, allow_pickle=False) lhs_pts = z["lhs_pts"] r2_mu_arr = z["r2_mu_arr"] r2_sig_arr = z["r2_sig_arr"] lo_b = z["lo_b"] hi_b = z["hi_b"] else: ckpt = Path(args.checkpoint).expanduser().resolve() if not ckpt.is_file(): raise FileNotFoundError(f"Checkpoint not found: {ckpt}") ta = _resolve_training_args(ckpt) config: dict = {} if ta is not None: config = ec.load_training_config(str(ta)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = ec.build_model(config, device) ec.load_checkpoint(model, str(ckpt), device) data_dir = Path(args.data_dir) images_split, labels_split = ec.load_split(data_dir, args.split) label_mean, label_std = ec.load_label_stats(data_dir) lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b = compute_lhs_r2( model, images_split, labels_split, label_mean, label_std, device, lhs_n=args.lhs_n, maps_per_point=args.maps_per_point, batch_size=args.batch_size, box_size_mpc=args.box_size_mpc, ddim_steps=args.ddim_steps, seed=args.seed, ) if args.save_npz: np.savez( args.save_npz, lhs_pts=lhs_pts, r2_mu_arr=r2_mu_arr, r2_sig_arr=r2_sig_arr, lo_b=lo_b, hi_b=hi_b, ) print("Saved", args.save_npz) plot_r2_cosmology_figure( lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b, out_path, r2_vmin=args.r2_vmin, r2_vmax=args.r2_vmax, lhs_n=args.lhs_n, maps_per_point=args.maps_per_point, dpi=args.dpi, ) print("Saved", out_path.resolve()) if __name__ == "__main__": main()