DDPM-6param / src /eval_model.py
collins909's picture
Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified
"""
Helpers for ddpm_cond_eval.ipynb: R², P(k) on log N_HI fields, DDIM batches.
Uses evaluate_conditional for PowerSpectrum, sampling, and label z-scoring.
"""
from __future__ import annotations
import numpy as np
import torch
from matplotlib.colors import LinearSegmentedColormap
import evaluate_conditional as ec
# Default lower/upper bounds used when rescaling [0, 1] images to log10 N_HI;
# they are the default `lo`/`hi` arguments of images01_to_log_nhi and
# per_map_power_spectra_log.
LO_LOG, HI_LOG = 14.0, 22.0
def r2_score_1d(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the coefficient of determination R² of two flattened arrays.

    Both inputs are converted to float64 and raveled, so any array-like with
    matching element counts is accepted.

    Args:
        y_true: Reference values.
        y_pred: Predicted values.

    Returns:
        ``1 - SS_res / SS_tot`` as a Python float. When the total sum of
        squares is below 1e-30 (an effectively constant target) the result is
        0.0 if the residual sum of squares is also below 1e-30, otherwise
        ``-inf``.

    NOTE(review): scikit-learn's ``r2_score`` documents different values for
    the constant-target case (1.0 for perfect, 0.0 for imperfect predictions);
    confirm that the deviation here is intended.
    """
    truth = np.asarray(y_true, dtype=np.float64).ravel()
    guess = np.asarray(y_pred, dtype=np.float64).ravel()

    residual = truth - guess
    centered = truth - truth.mean()
    ss_res = (residual ** 2).sum()
    ss_tot = (centered ** 2).sum()

    # Degenerate target: the ratio below would divide by (almost) zero.
    if ss_tot < 1e-30:
        if ss_res < 1e-30:
            return 0.0
        return float("-inf")
    return float(1.0 - ss_res / ss_tot)
def cmap_r2_hiflow() -> LinearSegmentedColormap:
    """Return the ``"r2_hiflow"`` colormap (HIFlow Fig. 7 style).

    The map interpolates through five colour stops, in order:
    green → yellow → red → purple → dark blue, quantised to 256 levels.
    """
    color_stops = [
        "#00a651",  # green
        "#ffcc00",  # yellow
        "#e74c3c",  # red
        "#7d3c98",  # purple
        "#0d1b5c",  # dark blue
    ]
    return LinearSegmentedColormap.from_list("r2_hiflow", color_stops, N=256)
def images01_to_log_nhi(img01: np.ndarray, lo: float = LO_LOG, hi: float = HI_LOG) -> np.ndarray:
    """Rescale normalised images to log10(N_HI/cm^-2) values.

    Pixel values are clipped to [0, 1], cast to float64 and mapped linearly
    onto ``[lo, hi]`` (0 → ``lo``, 1 → ``hi``).

    Args:
        img01: Image data nominally in [0, 1]; out-of-range values are clipped.
        lo: Value assigned to a pixel of 0.
        hi: Value assigned to a pixel of 1.

    Returns:
        float64 array with the same shape as ``img01``.
    """
    unit = np.clip(img01, 0.0, 1.0).astype(np.float64)
    span = hi - lo
    return lo + span * unit
def per_map_power_spectra_log(
    images_01: np.ndarray, box_size: float = 25.0, lo: float = LO_LOG, hi: float = HI_LOG
) -> tuple[np.ndarray, np.ndarray]:
    """Compute one power spectrum per map on the log10 N_HI field.

    The images are first rescaled with ``images01_to_log_nhi`` and each map
    (each entry along the first axis) is passed to
    ``evaluate_conditional.PowerSpectrum``.

    Args:
        images_01: Stack of images nominally in [0, 1]; the first axis indexes
            maps and the last axis length is taken as the pixel count per side.
        box_size: Side length of the map; the pixel size passed to
            ``PowerSpectrum`` is ``box_size / npix``. Units are whatever
            ``PowerSpectrum`` expects.
        lo: Log-scale value corresponding to a pixel of 0.
        hi: Log-scale value corresponding to a pixel of 1.

    Returns:
        ``(dk, Pk)`` where ``dk`` is the first output of ``PowerSpectrum`` for
        the first map and ``Pk`` stacks the second output of every map, i.e.
        shape ``(N, n_bins)``.

    Raises:
        IndexError: If ``images_01`` contains no maps.
    """
    logf = images01_to_log_nhi(images_01, lo, hi)
    npix = logf.shape[-1]
    dl = box_size / npix

    # Evaluate each map exactly once; previously the first map was processed
    # twice (once only to obtain dk, then again inside the per-map loop).
    spectra = [ec.PowerSpectrum(field, N=npix, dl=dl) for field in logf]

    # dk is taken from the first map only — assumes PowerSpectrum returns the
    # same bins for every map given identical N and dl.
    dk = spectra[0][0]
    pks = np.stack([result[1] for result in spectra])
    return dk, pks
def sample_batch(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    normalize_labels: bool,
    height: int,
    width: int,
    device: torch.device,
    ddim_steps: int,
    progress: bool,
) -> np.ndarray:
    """Draw one batch of single-channel samples from ``model`` using DDIM.

    Args:
        model: Object exposing a ``sample(...)`` method accepting the keyword
            arguments used below.
        labels_np: Conditioning labels, shape (B, label_dim).
        label_mean: Per-label mean, same length as label_dim.
        label_std: Per-label standard deviation, same length as label_dim.
        normalize_labels: If true, labels go through
            ``ec.prepare_labels_for_model`` with the given mean/std; otherwise
            they are passed to the model as a float tensor unchanged.
        height: Requested sample height.
        width: Requested sample width.
        device: Device for the label tensor and for sampling.
        ddim_steps: Number of DDIM steps forwarded to the sampler.
        progress: Forwarded to the sampler's ``progress`` flag.

    Returns:
        The sampler output converted by ``ec.from_model_output``.
    """
    raw_labels = np.asarray(labels_np, dtype=np.float32)
    mu = np.asarray(label_mean, dtype=np.float32)
    sigma = np.asarray(label_std, dtype=np.float32)

    label_tensor = (
        ec.prepare_labels_for_model(raw_labels, mu, sigma)
        if normalize_labels
        else torch.from_numpy(raw_labels).float()
    ).to(device)

    sampler_kwargs = {
        "labels": label_tensor,
        "channels": 1,
        "height": height,
        "width": width,
        "device": device,
        "progress": progress,
        "use_ddim": True,
        "ddim_steps": ddim_steps,
    }

    # Sampling is inference only; gradient tracking is disabled.
    with torch.no_grad():
        generated = model.sample(**sampler_kwargs)
    return ec.from_model_output(generated)