File size: 2,867 Bytes
eb725f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Helpers for ddpm_cond_eval.ipynb: R², P(k) on log N_HI fields, DDIM batches.

Uses evaluate_conditional for PowerSpectrum, sampling, and label z-scoring.
"""
from __future__ import annotations

import numpy as np
import torch
from matplotlib.colors import LinearSegmentedColormap

import evaluate_conditional as ec

# Default lower / upper bounds (log10 N_HI) that normalized [0, 1] images are
# mapped onto; used as the default `lo` / `hi` arguments of the helpers below.
LO_LOG, HI_LOG = 14.0, 22.0


def r2_score_1d(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the univariate coefficient of determination R².

    Both inputs are converted to float64 and flattened before comparison.
    Degenerate cases follow scikit-learn's default (``force_finite=True``)
    convention so the result stays finite and usable in plots and averages:

    * (near-)constant ``y_true`` with a perfect prediction -> 1.0
    * (near-)constant ``y_true`` with an imperfect prediction -> 0.0
    * empty ``y_true`` -> NaN, since R² is undefined without samples

    Args:
        y_true: Reference values (any array-like, any shape).
        y_pred: Predicted values, broadcast-compatible with ``y_true`` once
            flattened.

    Returns:
        R² as a Python float.
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true.size == 0:
        # Avoid np.mean on an empty array (RuntimeWarning) and a bogus score.
        return float("nan")
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    if ss_tot < 1e-30:
        # Zero target variance: 1 - ss_res / ss_tot is undefined, so report
        # 1.0 for an exact match and 0.0 otherwise instead of 0.0 / -inf.
        return 1.0 if ss_res < 1e-30 else 0.0
    return float(1.0 - ss_res / ss_tot)


def cmap_r2_hiflow() -> LinearSegmentedColormap:
    """Build the 256-level "r2_hiflow" colormap (HIFlow Fig. 7 style).

    The color stops run green → yellow → red → purple → dark blue.
    """
    color_stops = [
        "#00a651",  # green
        "#ffcc00",  # yellow
        "#e74c3c",  # red
        "#7d3c98",  # purple
        "#0d1b5c",  # dark blue
    ]
    return LinearSegmentedColormap.from_list("r2_hiflow", color_stops, N=256)


def images01_to_log_nhi(img01: np.ndarray, lo: float = LO_LOG, hi: float = HI_LOG) -> np.ndarray:
    """Rescale normalized maps from [0, 1] onto log10(N_HI/cm^-2) in [lo, hi].

    Values outside [0, 1] are clipped before rescaling; the result is float64.
    """
    clipped = np.clip(img01, 0.0, 1.0).astype(np.float64)
    span = hi - lo
    return lo + span * clipped


def per_map_power_spectra_log(
    images_01: np.ndarray, box_size: float = 25.0, lo: float = LO_LOG, hi: float = HI_LOG
) -> tuple[np.ndarray, np.ndarray]:
    """Return (dk, Pk) with Pk shape (N, n_bins) using log10 N_HI field.

    Each map along the first axis of ``images_01`` is converted to log10 N_HI
    with ``images01_to_log_nhi`` and passed to ``ec.PowerSpectrum`` exactly
    once. ``dk`` is taken from the first map's result.

    Args:
        images_01: Stack of normalized maps indexed along axis 0. The length
            of the last axis is used as the pixel count (assumes square maps
            — TODO confirm against callers).
        box_size: Side length of the map; the pixel size passed on is
            ``box_size / npix`` (units not visible here — TODO confirm).
        lo: Lower bound of the log10 N_HI mapping.
        hi: Upper bound of the log10 N_HI mapping.

    Returns:
        ``(dk, pks)`` where ``pks`` stacks one spectrum per input map.

    Raises:
        IndexError: if ``images_01`` contains no maps.
    """
    logf = images01_to_log_nhi(images_01, lo, hi)
    npix = logf.shape[-1]
    dl = box_size / npix
    # One PowerSpectrum call per map; the first map's dk is reused rather
    # than recomputing its spectrum a second time just to obtain the bins.
    spectra = [ec.PowerSpectrum(field, N=npix, dl=dl) for field in logf]
    dk, _ = spectra[0]
    pks = np.stack([pk for _, pk in spectra])
    return dk, pks


def sample_batch(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    normalize_labels: bool,
    height: int,
    width: int,
    device: torch.device,
    ddim_steps: int,
    progress: bool,
) -> np.ndarray:
    """Draw one DDIM batch from ``model`` conditioned on ``labels_np``.

    ``labels_np`` has shape (B, label_dim); ``label_mean`` / ``label_std``
    have length label_dim. When ``normalize_labels`` is true the labels go
    through ``ec.prepare_labels_for_model``; otherwise they are sent to
    ``device`` unchanged as float32. Sampling runs without gradient tracking
    and the output is converted with ``ec.from_model_output``.
    """
    raw_labels = np.asarray(labels_np, dtype=np.float32)
    mean_arr = np.asarray(label_mean, dtype=np.float32)
    std_arr = np.asarray(label_std, dtype=np.float32)

    label_tensor = (
        ec.prepare_labels_for_model(raw_labels, mean_arr, std_arr)
        if normalize_labels
        else torch.from_numpy(raw_labels).float()
    ).to(device)

    sampler_kwargs = dict(
        labels=label_tensor,
        channels=1,
        height=height,
        width=width,
        device=device,
        progress=progress,
        use_ddim=True,
        ddim_steps=ddim_steps,
    )
    with torch.no_grad():
        generated = model.sample(**sampler_kwargs)
    return ec.from_model_output(generated)