Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified | """ | |
| Helpers for ddpm_cond_eval.ipynb: R², P(k) on log N_HI fields, DDIM batches. | |
| Uses evaluate_conditional for PowerSpectrum, sampling, and label z-scoring. | |
| """ | |
| from __future__ import annotations | |
| import numpy as np | |
| import torch | |
| from matplotlib.colors import LinearSegmentedColormap | |
| import evaluate_conditional as ec | |
# Lower/upper log10(N_HI / cm^-2) values that image intensities 0 and 1 map to.
LO_LOG = 14.0
HI_LOG = 22.0
def r2_score_1d(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Univariate coefficient of determination R², matching sklearn's default.

    Follows ``sklearn.metrics.r2_score`` (default ``force_finite=True``) for
    1-D data:

    * fewer than two samples -> ``nan`` (R² is undefined);
    * constant ``y_true`` (SS_tot < 1e-30) -> ``1.0`` for a perfect
      prediction, otherwise ``0.0``;
    * otherwise ``1 - SS_res / SS_tot``.

    Args:
        y_true: Reference values; flattened to 1-D.
        y_pred: Predicted values; flattened to 1-D. Must contain the same
            number of elements as ``y_true``.

    Returns:
        R² as a Python float.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in size.
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true.size != y_pred.size:
        # Without this check NumPy would silently broadcast e.g. a length-1 array.
        raise ValueError(
            f"y_true and y_pred must have the same size, "
            f"got {y_true.size} and {y_pred.size}"
        )
    if y_true.size < 2:
        return float("nan")
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    if ss_tot < 1e-30:
        # Constant target: the ratio is undefined, so use sklearn's finite convention.
        return 1.0 if ss_res < 1e-30 else 0.0
    return float(1.0 - ss_res / ss_tot)
def cmap_r2_hiflow() -> LinearSegmentedColormap:
    """Build the 256-level R² colormap in the HIFlow Fig. 7 style.

    Colors run green -> yellow -> red -> purple -> dark blue.
    """
    color_stops = ["#00a651", "#ffcc00", "#e74c3c", "#7d3c98", "#0d1b5c"]
    return LinearSegmentedColormap.from_list("r2_hiflow", color_stops, N=256)
def images01_to_log_nhi(img01: np.ndarray, lo: float = LO_LOG, hi: float = HI_LOG) -> np.ndarray:
    """Map normalized image values in [0, 1] to log10(N_HI / cm^-2).

    The mapping is linear in *log10* column density, not in column density
    itself: 0 -> ``lo`` and 1 -> ``hi``. Values outside [0, 1] are clipped
    before the mapping is applied.

    Args:
        img01: Array-like of normalized pixel values, any shape.
        lo: log10 N_HI assigned to an input of 0.
        hi: log10 N_HI assigned to an input of 1.

    Returns:
        float64 array with the same shape as ``img01``.
    """
    # Convert first so lists / ints / float32 all go through the same float64 path.
    x = np.clip(np.asarray(img01, dtype=np.float64), 0.0, 1.0)
    return lo + (hi - lo) * x
def per_map_power_spectra_log(
    images_01: np.ndarray, box_size: float = 25.0, lo: float = LO_LOG, hi: float = HI_LOG
) -> tuple[np.ndarray, np.ndarray]:
    """Compute one power spectrum per map on the log10 N_HI field.

    Each map is converted with ``images01_to_log_nhi`` and passed to
    ``ec.PowerSpectrum`` with ``N`` equal to the last-axis size and pixel
    size ``dl = box_size / N``.

    Args:
        images_01: Stack of maps with values in [0, 1]; the first axis indexes
            maps. NOTE(review): only the last axis sets ``N``, so square maps
            are assumed — confirm.
        box_size: Side length of each map, in the same length units that
            ``ec.PowerSpectrum`` uses for ``dl``.
        lo, hi: log10 N_HI bounds forwarded to ``images01_to_log_nhi``.

    Returns:
        ``(dk, Pk)``: ``dk`` is the k-binning returned for the first map, and
        ``Pk`` stacks the per-map spectra with shape ``(n_maps, n_bins)``.

    Raises:
        ValueError: If ``images_01`` contains no maps.
    """
    logf = images01_to_log_nhi(images_01, lo, hi)
    if logf.shape[0] == 0:
        raise ValueError("per_map_power_spectra_log needs at least one map")
    npix = logf.shape[-1]
    dl = box_size / npix
    # Compute the first map's spectrum once and reuse its k-bins, rather than
    # running PowerSpectrum on it a second time.
    dk, pk_first = ec.PowerSpectrum(logf[0], N=npix, dl=dl)
    pks = [pk_first]
    pks.extend(ec.PowerSpectrum(field, N=npix, dl=dl)[1] for field in logf[1:])
    return dk, np.stack(pks)
def sample_batch(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    normalize_labels: bool,
    height: int,
    width: int,
    device: torch.device,
    ddim_steps: int,
    progress: bool,
) -> np.ndarray:
    """Run DDIM sampling of single-channel maps conditioned on a label batch.

    Args:
        model: Conditional model exposing ``sample(...)`` with the keyword
            arguments used below.
        labels_np: Conditioning labels, shape ``(B, label_dim)``.
        label_mean: Per-label means, length ``label_dim``; used only when
            ``normalize_labels`` is True.
        label_std: Per-label standard deviations, length ``label_dim``; used
            only when ``normalize_labels`` is True.
        normalize_labels: If True, labels go through
            ``ec.prepare_labels_for_model`` with ``label_mean``/``label_std``.
            Otherwise they are passed to the model unchanged, as float32.
        height: Height of the generated maps.
        width: Width of the generated maps.
        device: Device for the label tensor and for sampling.
        ddim_steps: Number of DDIM steps.
        progress: Forwarded to ``model.sample`` (presumably a progress-bar
            toggle — confirm).

    Returns:
        The result of ``ec.from_model_output`` on the sampled tensor.
    """
    # ascontiguousarray rather than asarray: torch.from_numpy rejects arrays
    # with negative strides (e.g. a reversed slice such as labels[::-1]).
    labels_np = np.ascontiguousarray(labels_np, dtype=np.float32)
    mean = np.asarray(label_mean, dtype=np.float32)
    std = np.asarray(label_std, dtype=np.float32)
    if normalize_labels:
        cond = ec.prepare_labels_for_model(labels_np, mean, std).to(device)
    else:
        # labels_np is already float32, so the tensor is float32 as well.
        cond = torch.from_numpy(labels_np).to(device)
    with torch.no_grad():
        out = model.sample(
            labels=cond,
            channels=1,
            height=height,
            width=width,
            device=device,
            progress=progress,
            use_ddim=True,
            ddim_steps=ddim_steps,
        )
    return ec.from_model_output(out)