"""
Helpers for ddpm_cond_eval.ipynb: R², P(k) on log N_HI fields, DDIM batches.

Uses evaluate_conditional for PowerSpectrum, sampling, and label z-scoring.
"""

from __future__ import annotations

import numpy as np
import torch
from matplotlib.colors import LinearSegmentedColormap

import evaluate_conditional as ec

# Default lower/upper bounds used when mapping [0, 1] images to log10 N_HI.
LO_LOG, HI_LOG = 14.0, 22.0


def r2_score_1d(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the univariate coefficient of determination R².

    Both inputs are converted to float64 and flattened, then
    ``1 - SS_res / SS_tot`` is returned.

    When the targets are (numerically) constant, i.e. ``SS_tot < 1e-30``,
    the ratio is undefined and this function returns ``0.0`` if the
    residual sum of squares is also below ``1e-30`` and ``-inf`` otherwise.

    NOTE(review): this degenerate-case convention appears to differ from
    scikit-learn's default ``r2_score`` (which reports 1.0 / 0.0 for
    constant targets); for non-constant targets the formula is the usual
    one. Confirm which convention the notebook expects.
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    if ss_tot < 1e-30:
        # Constant targets: avoid dividing by (almost) zero.
        return 0.0 if ss_res < 1e-30 else float("-inf")
    return float(1.0 - ss_res / ss_tot)


def cmap_r2_hiflow() -> LinearSegmentedColormap:
    """Green → yellow → red → purple → dark blue (HIFlow Fig. 7 style).

    Returns a 256-level ``LinearSegmentedColormap`` named ``"r2_hiflow"``.
    """
    return LinearSegmentedColormap.from_list(
        "r2_hiflow",
        ["#00a651", "#ffcc00", "#e74c3c", "#7d3c98", "#0d1b5c"],
        N=256,
    )


def images01_to_log_nhi(
    img01: np.ndarray, lo: float = LO_LOG, hi: float = HI_LOG
) -> np.ndarray:
    """Map images in [0, 1] to log10 N_HI values in [lo, hi].

    The input is clipped to [0, 1], cast to float64 and rescaled with the
    affine map ``lo + (hi - lo) * x``; i.e. the image is treated as linear
    in log10 column density, with 0 ↦ ``lo`` and 1 ↦ ``hi``.
    """
    return lo + (hi - lo) * np.clip(img01, 0.0, 1.0).astype(np.float64)


def per_map_power_spectra_log(
    images_01: np.ndarray,
    box_size: float = 25.0,
    lo: float = LO_LOG,
    hi: float = HI_LOG,
) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(dk, Pk)`` computed on the log10 N_HI field of every map.

    Each map along the first axis of ``images_01`` is converted with
    ``images01_to_log_nhi`` and passed to ``ec.PowerSpectrum`` with
    ``N`` set to the size of the last axis and ``dl = box_size / N``.

    Returns:
        dk: the first element returned by ``ec.PowerSpectrum`` for the
            first map (presumably identical for all maps of equal size).
        Pk: the per-map spectra stacked along a new leading axis, so its
            first dimension equals the number of maps.

    Raises:
        IndexError: if ``images_01`` contains no maps.
    """
    logf = images01_to_log_nhi(images_01, lo, hi)
    n = logf.shape[0]
    npix = logf.shape[-1]  # assumes square maps: last axis sets the grid size
    if n == 0:
        raise IndexError("images_01 must contain at least one map")
    dl = box_size / npix

    dk = None
    pks = []
    for field in logf:
        # One PowerSpectrum call per map; the k-bins are kept from the first
        # call instead of recomputing the first map's spectrum a second time.
        k_bins, pk = ec.PowerSpectrum(field, N=npix, dl=dl)
        if dk is None:
            dk = k_bins
        pks.append(pk)
    return dk, np.stack(pks)


def sample_batch(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    normalize_labels: bool,
    height: int,
    width: int,
    device: torch.device,
    ddim_steps: int,
    progress: bool,
) -> np.ndarray:
    """Draw one DDIM sample batch conditioned on ``labels_np``.

    Args:
        model: network exposing a ``sample(...)`` method.
        labels_np: labels of shape (B, label_dim); cast to float32.
        label_mean: per-label mean, same length as label_dim.
        label_std: per-label std, same length as label_dim.
        normalize_labels: if true, labels are passed through
            ``ec.prepare_labels_for_model`` with mean/std; otherwise the
            raw labels are used and mean/std are ignored.
        height: height passed to ``model.sample``.
        width: width passed to ``model.sample``.
        device: device the labels are moved to and sampling runs on.
        ddim_steps: number of DDIM steps passed to ``model.sample``.
        progress: forwarded to ``model.sample``.

    Returns:
        The result of ``ec.from_model_output`` applied to the sampled
        single-channel batch.
    """
    labels_np = np.asarray(labels_np, dtype=np.float32)
    if normalize_labels:
        # mean/std are only needed (and only converted) when normalizing.
        mean = np.asarray(label_mean, dtype=np.float32)
        std = np.asarray(label_std, dtype=np.float32)
        t = ec.prepare_labels_for_model(labels_np, mean, std).to(device)
    else:
        t = torch.from_numpy(labels_np).float().to(device)

    # Sampling is inference only; no autograd graph is needed.
    with torch.no_grad():
        out = model.sample(
            labels=t,
            channels=1,
            height=height,
            width=width,
            device=device,
            progress=progress,
            use_ddim=True,
            ddim_steps=ddim_steps,
        )
    return ec.from_model_output(out)