Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
1f3e7a2 verified | #!/usr/bin/env python3 | |
| """ | |
| Reproduce the Latin-hypercube R² figure (μ(P) and σ(P)) in (Ωm, σ8) with a layout | |
| that avoids colorbar / suptitle overlap. | |
| Usage (full run — slow): | |
| python plot_r2_cosmology_lhs.py --output ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png | |
| Replay plot only from saved arrays: | |
| python plot_r2_cosmology_lhs.py --from-npz r2_lhs_data.npz --output out.png | |
| Defaults match ddpm_conditional / evaluate_conditional 6-param setup. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from matplotlib.cm import ScalarMappable | |
| from matplotlib.colors import Normalize | |
| from matplotlib.gridspec import GridSpec | |
| import evaluate_conditional as ec | |
| import eval_model as em | |
# Directory that contains this script; the default checkpoint and the default
# output image path are both built relative to it.
_SCRIPT_DIR = Path(__file__).resolve().parent
# Default checkpoint: "best_model.pt" of one specific, timestamped training run
# located next to this script. Override with --checkpoint.
_DEFAULT_CKPT = _SCRIPT_DIR / "outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt"
# Default dataset directory used for --data-dir.
# NOTE(review): "<DDPM_ROOT>" looks like an unexpanded placeholder rather than a
# real path - confirm, and pass --data-dir explicitly until it is resolved.
_DEFAULT_DATA = "<DDPM_ROOT>/data/LH_data/params_6"
def latin_hypercube_scaled(
    n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator
) -> np.ndarray:
    """Draw ``n`` Latin-hypercube points inside the box ``[lo, hi]``.

    The unit interval is cut into ``n`` equal strata. One uniform jitter is
    drawn inside every stratum for every dimension, and each column is then
    shuffled independently so that the strata are paired at random across
    dimensions while every stratum of every dimension holds exactly one point.

    Args:
        n: Number of points (rows) to draw.
        lo: Per-dimension lower bounds; its first axis length sets the
            dimensionality ``d``.
        hi: Per-dimension upper bounds, broadcastable against ``lo``.
        rng: NumPy generator supplying the jitter and the permutations.

    Returns:
        ``float32`` array of shape ``(n, d)``.
    """
    n_dims = int(lo.shape[0])
    jitter = rng.random((n, n_dims))

    # Stratum edges on [0, 1]; every stratum has a left edge and a width.
    edges = np.linspace(0.0, 1.0, n + 1)
    left_edge = edges[:-1, np.newaxis]
    stratum_width = np.diff(edges)[:, np.newaxis]
    unit_pts = left_edge + jitter * stratum_width

    # Rows of the transpose are views onto the columns, so shuffling them
    # permutes each dimension in place, independently of the others.
    for column in unit_pts.T:
        rng.shuffle(column)

    extent = (hi - lo).astype(np.float64)
    return (lo + unit_pts * extent).astype(np.float32)
def compute_lhs_r2(
    model: torch.nn.Module,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    device: torch.device,
    lhs_n: int,
    maps_per_point: int,
    batch_size: int,
    box_size_mpc: float,
    ddim_steps: int,
    seed: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute power-spectrum R² scores at Latin-hypercube points in label space.

    For every LHS point the function
      1. picks the ``maps_per_point`` rows of ``labels_split`` closest to the
         point (Euclidean distance) and uses their images as the real batch,
      2. samples ``maps_per_point`` maps from ``model`` conditioned on the point
         (DDIM sampling, in chunks of at most ``batch_size``),
      3. computes one power spectrum per map with ``ec.PowerSpectrum``, and
      4. scores the per-bin mean and the per-bin standard deviation of the
         generated spectra against the real ones with ``em.r2_score_1d``,
         skipping the first bin.

    Args:
        model: Conditional model exposing ``sample(...)``; it is switched to
            eval mode and left there.
        images_split: Real images; the last two axes are taken as height and
            width.
        labels_split: Labels, one row per image; the per-column min/max define
            the LHS box.
        label_mean: Label statistics forwarded to ``ec.prepare_labels_for_model``.
        label_std: Label statistics forwarded to ``ec.prepare_labels_for_model``.
        device: Device on which sampling runs.
        lhs_n: Number of LHS points.
        maps_per_point: Real neighbours and generated maps per LHS point.
        batch_size: Upper bound on the sampling chunk size.
        box_size_mpc: Box side length; divided by the image width to obtain the
            ``dl`` argument of ``ec.PowerSpectrum``.
        ddim_steps: Number of DDIM steps passed to ``model.sample``.
        seed: Seed of the NumPy generator that draws the LHS points.

    Returns:
        ``(lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b)``: the LHS points, the R²
        of the mean spectrum and of the spectrum scatter per point (``NaN``
        where never filled), and the lower/upper label bounds.

    Raises:
        ValueError: If ``maps_per_point`` or ``batch_size`` is not positive.
    """
    if maps_per_point < 1 or batch_size < 1:
        raise ValueError(
            "maps_per_point and batch_size must be >= 1 "
            f"(got maps_per_point={maps_per_point}, batch_size={batch_size})"
        )

    lo_b = labels_split.min(axis=0)
    hi_b = labels_split.max(axis=0)
    rng = np.random.default_rng(seed)
    lhs_pts = latin_hypercube_scaled(lhs_n, lo_b, hi_b, rng)

    h, w = int(images_split.shape[-2]), int(images_split.shape[-1])
    bs = min(batch_size, maps_per_point)
    # Pixel size handed to the power-spectrum routine (box length / image width).
    dl = box_size_mpc / w

    def pk_stack(imgs: np.ndarray) -> np.ndarray:
        # Element [1] of the PowerSpectrum result is used as the spectrum values
        # (presumably element [0] holds the k bins - TODO confirm in `ec`).
        return np.stack([ec.PowerSpectrum(im, N=w, dl=dl)[1] for im in imgs], axis=0)

    r2_mu_arr = np.full(lhs_n, np.nan, dtype=np.float64)
    r2_sig_arr = np.full(lhs_n, np.nan, dtype=np.float64)

    model.eval()
    for ti in range(lhs_n):
        theta = lhs_pts[ti]

        # NOTE(review): the distance is taken in raw label units, so columns
        # with a wide numeric range dominate the neighbour search - confirm
        # this is intended before comparing against normalised-space results.
        dist = np.linalg.norm(labels_split - theta, axis=1)
        nn_idx = np.argsort(dist)[:maps_per_point]
        real_batch = images_split[nn_idx]

        # The same conditioning vector is repeated for every generated map.
        rep = np.tile(theta[None, :], (maps_per_point, 1))
        gen_chunks = []
        for j in range(0, maps_per_point, bs):
            chunk = rep[j : j + bs]
            bt = ec.prepare_labels_for_model(chunk, label_mean, label_std).to(device)
            with torch.no_grad():
                g = model.sample(
                    labels=bt,
                    channels=1,
                    height=h,
                    width=w,
                    device=device,
                    progress=False,
                    use_ddim=True,
                    ddim_steps=ddim_steps,
                )
            gen_chunks.append(ec.from_model_output(g))
        gen_batch = np.concatenate(gen_chunks, axis=0)

        pk_r = pk_stack(real_batch)
        pk_g = pk_stack(gen_batch)
        mu_r, mu_g = pk_r.mean(axis=0), pk_g.mean(axis=0)
        sr, sg = pk_r.std(axis=0), pk_g.std(axis=0)

        # The first spectrum bin is excluded from both scores.
        r2_mu_arr[ti] = em.r2_score_1d(mu_r[1:], mu_g[1:])
        r2_sig_arr[ti] = em.r2_score_1d(sr[1:], sg[1:])

    return lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b
def plot_r2_cosmology_figure(
    lhs_pts: np.ndarray,
    r2_mu_arr: np.ndarray,
    r2_sig_arr: np.ndarray,
    lo_b: np.ndarray,
    hi_b: np.ndarray,
    out_path: Path,
    r2_vmin: float = 0.90,
    r2_vmax: float = 1.0,
    lhs_n: int | None = None,
    maps_per_point: int | None = None,
    dpi: int = 160,
) -> None:
    """Save a two-panel scatter of R² in the (Ωm, σ8) plane.

    The left panel is coloured by ``r2_mu_arr`` and the right panel by
    ``r2_sig_arr``. Both share one colour scale, drawn in a dedicated narrow
    GridSpec column so the colorbar never overlaps the data panels.

    Args:
        lhs_pts: Points of shape ``(n, d)``; column 0 is plotted on x and
            column 1 on y. With ``d < 2`` all points are drawn at y = 0.
        r2_mu_arr: R² values for the left panel; non-finite entries are skipped.
        r2_sig_arr: R² values for the right panel; non-finite entries are skipped.
        lo_b: Lower label bounds; entries 0 and 1 set the axis limits.
        hi_b: Upper label bounds; entries 0 and 1 set the axis limits.
        out_path: Destination image path; missing parent folders are created.
        r2_vmin: Lower end of the colour scale (values are clipped to it).
        r2_vmax: Upper end of the colour scale (values are clipped to it).
        lhs_n: Sample count quoted in the title; defaults to ``len(r2_mu_arr)``.
        maps_per_point: Maps-per-point count quoted in the title; 15 is quoted
            when it is not supplied.
        dpi: Resolution of the saved image.
    """
    # Number of points actually plotted. It is kept separate from ``lhs_n``,
    # which only feeds the title and may be supplied independently by callers.
    n_pts = int(lhs_pts.shape[0])
    lhs_n = lhs_n if lhs_n is not None else len(r2_mu_arr)
    maps_per_point = maps_per_point if maps_per_point is not None else 15

    ldim = lhs_pts.shape[1]
    om_plot = lhs_pts[:, 0]
    s8_plot = lhs_pts[:, 1] if ldim >= 2 else np.zeros(n_pts)

    cmap = em.cmap_r2_hiflow()
    norm = Normalize(vmin=r2_vmin, vmax=r2_vmax)
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])

    fig = plt.figure(figsize=(11.5, 4.9))
    try:
        # Left: data panels; narrow right strip: colorbar only
        # (avoids the fig.colorbar + tight_layout clash).
        gs = GridSpec(
            nrows=1,
            ncols=3,
            figure=fig,
            width_ratios=[1.0, 1.0, 0.065],
            wspace=0.26,
            left=0.07,
            right=0.98,
            top=0.82,
            bottom=0.14,
        )
        ax0 = fig.add_subplot(gs[0, 0])
        ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)
        cax = fig.add_subplot(gs[0, 2])

        # Pad the limits by 2 % of the range; the epsilon keeps a zero-width
        # range from collapsing the axis.
        pad_x = 0.02 * (float(hi_b[0] - lo_b[0]) + 1e-6)
        ax0.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
        ax1.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
        if ldim >= 2:
            pad_y = 0.02 * (float(hi_b[1] - lo_b[1]) + 1e-6)
            # ax1 shares its y axis with ax0, so setting ax0 is enough.
            ax0.set_ylim(float(lo_b[1]) - pad_y, float(hi_b[1]) + pad_y)

        for ax, r2v, subtitle in zip(
            (ax0, ax1),
            (r2_mu_arr, r2_sig_arr),
            (r"$R^2$ for $\mu(P)$", r"$R^2$ for $\sigma(P)$"),
        ):
            ok = np.isfinite(r2v)
            ax.scatter(
                om_plot[ok],
                s8_plot[ok],
                c=np.clip(r2v[ok], r2_vmin, r2_vmax),
                cmap=cmap,
                norm=norm,
                s=52,
                alpha=0.92,
                edgecolors="k",
                linewidths=0.35,
            )
            ax.set_xlabel(r"$\Omega_m$", fontsize=12)
            ax.set_title(subtitle, fontsize=11)
            ax.grid(True, alpha=0.25)

        ax0.set_ylabel(r"$\sigma_8$", fontsize=12)
        plt.setp(ax1.get_yticklabels(), visible=False)

        cb = fig.colorbar(sm, cax=cax)
        cb.set_label(r"$R^2$", fontsize=11)
        cax.tick_params(labelsize=9)

        fig.suptitle(
            r"Visual summary of $R^2$ (CAMELS vs conditional DDPM) vs cosmology — "
            + f"{lhs_n} Latin Hypercube samples; {maps_per_point} maps / point",
            fontsize=11,
            fontweight="bold",
            y=0.96,
        )

        out_path = Path(out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # Do not use bbox_inches="tight" - it rebalances the axes and can
        # squeeze the colorbar into the panels.
        fig.savefig(out_path, dpi=dpi)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
| def _resolve_training_args(checkpoint: Path) -> Path | None: | |
| run = checkpoint.parent.parent if checkpoint.parent.name == "checkpoints" else checkpoint.parent | |
| for name in ("args.json", "args.txt"): | |
| p = run / name | |
| if p.is_file(): | |
| return p | |
| return None | |
def parse_args() -> argparse.Namespace:
    """Define the command-line options of the script and parse ``sys.argv``.

    Options fall into three groups: input/output locations (checkpoint, data
    directory, split, output image), the optional npz cache for replotting
    (``--from-npz`` / ``--save-npz``), and numeric settings for sampling and
    plotting.
    """
    parser = argparse.ArgumentParser(description="LHS R² cosmology figure (fixed colorbar layout)")
    default_png = _SCRIPT_DIR / "ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png"

    # Input / output locations.
    parser.add_argument("--checkpoint", type=str, default=str(_DEFAULT_CKPT))
    parser.add_argument("--data-dir", type=str, default=_DEFAULT_DATA)
    parser.add_argument("--split", type=str, default="test", choices=("train", "val", "test"))
    parser.add_argument("--output", type=str, default=str(default_png))

    # Optional cache of the computed arrays.
    for flag, help_text in (
        ("--from-npz", "Load lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b"),
        ("--save-npz", "After compute, save arrays for --from-npz replot"),
    ):
        parser.add_argument(flag, type=str, default=None, help=help_text)

    # Numeric settings: (flag, type, default), registered in this order.
    for flag, kind, default in (
        ("--lhs-n", int, 50),
        ("--maps-per-point", int, 15),
        ("--batch-size", int, 8),
        ("--ddim-steps", int, 50),
        ("--box-size-mpc", float, 25.0),
        ("--seed", int, 42),
        ("--r2-vmin", float, 0.90),
        ("--r2-vmax", float, 1.0),
        ("--dpi", int, 160),
    ):
        parser.add_argument(flag, type=kind, default=default)

    return parser.parse_args()
def main() -> None:
    """Script entry point: obtain the R² arrays, then draw and save the figure.

    With ``--from-npz`` the arrays are loaded from disk and no model is run.
    Otherwise the checkpoint is loaded, the R² values are computed with
    ``compute_lhs_r2`` and, when ``--save-npz`` is given, written to disk for
    later replotting.

    The counts quoted in the figure title describe the data that is actually
    plotted. In replay mode the number of LHS points is taken from the loaded
    arrays, and ``maps_per_point`` is read from the npz when it was stored
    there; files without that entry fall back to ``--maps-per-point``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist (compute mode).
    """
    args = parse_args()
    out_path = Path(args.output)
    lhs_n = args.lhs_n
    maps_per_point = args.maps_per_point

    if args.from_npz:
        # The context manager closes the underlying zip file once the arrays
        # have been read into memory.
        with np.load(args.from_npz, allow_pickle=False) as z:
            lhs_pts = z["lhs_pts"]
            r2_mu_arr = z["r2_mu_arr"]
            r2_sig_arr = z["r2_sig_arr"]
            lo_b = z["lo_b"]
            hi_b = z["hi_b"]
            if "maps_per_point" in z.files:
                maps_per_point = int(z["maps_per_point"])
        # Label the plot with what the file contains, not with the CLI default.
        lhs_n = int(len(r2_mu_arr))
    else:
        ckpt = Path(args.checkpoint).expanduser().resolve()
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

        # Without a training-args file the model is built from an empty config.
        ta = _resolve_training_args(ckpt)
        config: dict = {}
        if ta is not None:
            config = ec.load_training_config(str(ta))

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = ec.build_model(config, device)
        ec.load_checkpoint(model, str(ckpt), device)

        data_dir = Path(args.data_dir)
        images_split, labels_split = ec.load_split(data_dir, args.split)
        label_mean, label_std = ec.load_label_stats(data_dir)

        lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b = compute_lhs_r2(
            model,
            images_split,
            labels_split,
            label_mean,
            label_std,
            device,
            lhs_n=args.lhs_n,
            maps_per_point=args.maps_per_point,
            batch_size=args.batch_size,
            box_size_mpc=args.box_size_mpc,
            ddim_steps=args.ddim_steps,
            seed=args.seed,
        )

        if args.save_npz:
            Path(args.save_npz).parent.mkdir(parents=True, exist_ok=True)
            np.savez(
                args.save_npz,
                lhs_pts=lhs_pts,
                r2_mu_arr=r2_mu_arr,
                r2_sig_arr=r2_sig_arr,
                lo_b=lo_b,
                hi_b=hi_b,
                # Stored so that a later --from-npz replot can title itself correctly.
                maps_per_point=np.asarray(args.maps_per_point),
            )
            print("Saved", args.save_npz)

    plot_r2_cosmology_figure(
        lhs_pts,
        r2_mu_arr,
        r2_sig_arr,
        lo_b,
        hi_b,
        out_path,
        r2_vmin=args.r2_vmin,
        r2_vmax=args.r2_vmax,
        lhs_n=lhs_n,
        maps_per_point=maps_per_point,
        dpi=args.dpi,
    )
    print("Saved", out_path.resolve())


if __name__ == "__main__":
    main()