File size: 9,896 Bytes
0d57b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""Shared spectral expansion and transition scheduling utilities."""
from __future__ import annotations

import math
import os
import re
from pathlib import Path
from typing import Iterable, List, Sequence, Tuple

import numpy as np
import pywt
import torch
import yaml
from scipy.fft import dctn, idctn


def power_spectrum(omega: float, A: float, beta: float) -> float:
    """Evaluate the radial power law ``A * |omega| ** (-beta)``.

    Args:
        omega: Frequency; only its magnitude is used.
        A: Amplitude multiplying the power law.
        beta: Decay exponent (applied with a negative sign).

    Returns:
        The spectrum value at ``omega``.
    """
    magnitude = abs(omega)
    decay = magnitude ** (-beta)
    return A * decay


def activation_time(P_omega: float, delta: float) -> float:
    """Return the activation time for one frequency.

    Evaluates ``1 / (1 + sqrt(delta / (P * (1 + P - delta))))`` where ``P`` is
    ``P_omega``.

    Args:
        P_omega: Spectral power at the frequency; must be strictly positive.
        delta: Threshold; must satisfy ``0 <= delta < 1 + P_omega``.

    Returns:
        The activation time. ``delta == 0`` yields exactly ``1.0``.

    Raises:
        ValueError: If ``delta >= 1 + P_omega``, if ``P_omega`` is not
            strictly positive (including NaN), or if ``delta`` is negative
            (including NaN).
    """
    if delta >= 1.0 + P_omega:
        raise ValueError(
            f"delta={delta} >= 1 + P={1 + P_omega:.4f}; noise-dominated "
            "criterion is trivially satisfied at all t."
        )
    # Written as negated comparisons so NaN inputs are rejected as well.
    if not P_omega > 0.0:
        # P == 0 would otherwise divide by zero; P < 0 can silently produce a
        # meaningless value when the radicand happens to be non-negative.
        raise ValueError(f"P_omega must be > 0, got {P_omega}")
    if not delta >= 0.0:
        # A negative delta makes the radicand negative ("math domain error").
        raise ValueError(f"delta must be >= 0, got {delta}")
    radicand = delta / (P_omega * (1.0 + P_omega - delta))
    return 1.0 / (1.0 + math.sqrt(radicand))


def delta_optimal_transitions(
    scales: Sequence[float],
    delta: float,
    A: float,
    beta: float,
    H: int,
    W: int,
) -> List[float]:
    """Compute one transition time per pair of adjacent scales.

    Every scale except the last is mapped to the frequency
    ``scale * min(H, W) / 2``; the power spectrum at that frequency is fed to
    ``activation_time``.

    Args:
        scales: Scale list, checked by ``validate_scales`` first.
        delta: Threshold forwarded to ``activation_time``.
        A: Amplitude forwarded to ``power_spectrum``.
        beta: Exponent forwarded to ``power_spectrum``.
        H: Height used to derive the maximum frequency.
        W: Width used to derive the maximum frequency.

    Returns:
        A list of ``len(scales) - 1`` transition times, in scale order.
    """
    validate_scales(scales)
    nyquist = min(H, W) / 2.0
    # The final scale has no successor, hence no transition of its own.
    return [
        activation_time(power_spectrum(scale * nyquist, A, beta), delta)
        for scale in scales[:-1]
    ]


def align_timestep(t: float, r: float) -> float:
    """Map time ``t`` to the aligned flow-matching time after a scale jump.

    Args:
        t: Time before the jump.
        r: Scale ratio of the jump.

    Returns:
        ``r * t / (1 + (r - 1) * t)``.
    """
    denominator = 1.0 + (r - 1.0) * t
    return r * t / denominator


def kappa(t: float, r: float) -> float:
    """Compute the state-rescaling factor applied after a scale jump.

    Args:
        t: Time at which the jump happens.
        r: Scale ratio of the jump.

    Returns:
        ``r / (1 + (r - 1) * t)``.
    """
    denominator = 1.0 + (r - 1.0) * t
    return r / denominator


def _dct_expand_np(
    x_np: np.ndarray, target_hw: Tuple[int, int], t: float, seed: int,
) -> np.ndarray:
    """Enlarge the last two axes of ``x_np`` in the orthonormal DCT-II domain.

    For every 2-D slice, a target-sized coefficient grid is filled with
    standard-normal noise scaled by ``t``; its top-left corner is then
    overwritten with the slice's DCT coefficients and the grid is inverted.

    Args:
        x_np: Array whose trailing two axes are the spatial axes.
        target_hw: ``(height, width)`` of the output spatial axes.
        t: Multiplier applied to the noise coefficients.
        seed: Seed for the NumPy generator; one generator is shared by all
            slices, which are visited in ``np.ndindex`` order.

    Returns:
        A float32 array with the same leading axes and spatial size
        ``target_hw``.

    Raises:
        ValueError: If either target dimension is smaller than the source.
    """
    src_h, src_w = x_np.shape[-2], x_np.shape[-1]
    tgt_h, tgt_w = target_hw
    if tgt_h < src_h or tgt_w < src_w:
        raise ValueError(
            f"DCT expand: target {target_hw} smaller than source ({src_h}, {src_w})."
        )

    batch_shape = x_np.shape[:-2]
    generator = np.random.default_rng(seed)
    result = np.empty(batch_shape + (tgt_h, tgt_w), dtype=np.float32)
    for position in np.ndindex(*batch_shape):
        # Start from noise everywhere, then restore the known low frequencies.
        spectrum = t * generator.standard_normal((tgt_h, tgt_w)).astype(np.float32)
        spectrum[:src_h, :src_w] = dctn(x_np[position], type=2, norm="ortho")
        restored = idctn(spectrum, type=2, norm="ortho")
        result[position] = restored.astype(np.float32)
    return result


def _dwt_expand_np(x_np: np.ndarray, t: float, seed: int) -> np.ndarray:
    """Double the last two axes of ``x_np`` with a one-level Haar reconstruction.

    Each 2-D slice is used as the approximation band; the three detail bands
    are standard-normal noise scaled by ``t``. The slice is rebuilt with
    ``pywt.waverec2`` in ``"periodization"`` mode.

    Args:
        x_np: Array whose trailing two axes are the spatial axes.
        t: Multiplier applied to the noise detail bands.
        seed: Seed for the NumPy generator; one generator is shared by all
            slices, which are visited in ``np.ndindex`` order.

    Returns:
        A float32 array with the same leading axes and doubled spatial size.
    """
    src_h, src_w = x_np.shape[-2], x_np.shape[-1]
    batch_shape = x_np.shape[:-2]
    generator = np.random.default_rng(seed)
    result = np.empty(batch_shape + (src_h * 2, src_w * 2), dtype=np.float32)
    for position in np.ndindex(*batch_shape):
        approximation = x_np[position]
        # Three consecutive draws, passed to pywt in draw order.
        details = tuple(
            t * generator.standard_normal(approximation.shape).astype(np.float32)
            for _ in range(3)
        )
        rebuilt = pywt.waverec2(
            [approximation, details], "haar", mode="periodization"
        )
        result[position] = rebuilt.astype(np.float32)
    return result


def _fft_expand_np(
    x_np: np.ndarray, target_hw: Tuple[int, int], t: float, seed: int,
) -> np.ndarray:
    """NumPy core for FFT spectral expansion.

    For every 2-D slice, a target-sized centred spectrum is filled with
    complex Gaussian noise scaled by ``t / sqrt(2)``. The slice's centred
    orthonormal FFT is then written into the middle of that spectrum, and the
    real part of the inverse transform is returned.

    Args:
        x_np: Array whose trailing two axes are the spatial axes.
        target_hw: ``(height, width)`` of the output spatial axes.
        t: Multiplier applied to the noise coefficients.
        seed: Seed for the NumPy generator; one generator is shared by all
            slices, which are visited in ``np.ndindex`` order.

    Returns:
        A float32 array with the same leading axes and spatial size
        ``target_hw``.

    Raises:
        ValueError: If either target dimension is smaller than the source.
    """
    H_tgt, W_tgt = target_hw
    H_src, W_src = x_np.shape[-2], x_np.shape[-1]
    if H_tgt < H_src or W_tgt < W_src:
        raise ValueError(
            f"FFT expand: target {target_hw} smaller than source ({H_src}, {W_src})."
        )
    rng = np.random.default_rng(seed)
    # After fftshift the zero-frequency bin of a length-N axis sits at N // 2.
    # Offsetting by the difference of the two centres keeps the source DC on
    # the target DC for every parity combination; (H_tgt - H_src) // 2 is off
    # by one when the source axis is odd and the target axis is even.
    pad_h = H_tgt // 2 - H_src // 2
    pad_w = W_tgt // 2 - W_src // 2
    out = np.empty(x_np.shape[:-2] + (H_tgt, W_tgt), dtype=np.float32)
    for idx in np.ndindex(*x_np.shape[:-2]):
        X_src = np.fft.fftshift(np.fft.fft2(x_np[idx], norm="ortho"))
        nr = rng.standard_normal((H_tgt, W_tgt)).astype(np.float32)
        ni = rng.standard_normal((H_tgt, W_tgt)).astype(np.float32)
        X_big = np.fft.fftshift(t * (nr + 1j * ni) / np.sqrt(2.0))
        X_big[pad_h:pad_h + H_src, pad_w:pad_w + W_src] = X_src
        # Only the real part is kept; the noise is not Hermitian-symmetric.
        out[idx] = np.fft.ifft2(np.fft.ifftshift(X_big), norm="ortho").real.astype(np.float32)
    return out


def dct_expand_2d(
    x: torch.Tensor, target_hw: Tuple[int, int], t: float, seed: int,
) -> torch.Tensor:
    """Torch wrapper around the DCT-II expansion of the trailing spatial axes.

    The tensor is detached, moved to CPU as float32, and expanded by
    ``_dct_expand_np``; the result is returned on the device and in the dtype
    of ``x``.
    """
    host_array = x.detach().cpu().float().numpy()
    expanded = _dct_expand_np(host_array, target_hw, t, seed)
    result = torch.from_numpy(expanded)
    return result.to(device=x.device, dtype=x.dtype)


def dwt_expand_2d(x: torch.Tensor, t: float, seed: int) -> torch.Tensor:
    """Torch wrapper around the one-level Haar expansion of the spatial axes.

    The tensor is detached, moved to CPU as float32, and expanded by
    ``_dwt_expand_np``; the result is returned on the device and in the dtype
    of ``x``.
    """
    host_array = x.detach().cpu().float().numpy()
    expanded = _dwt_expand_np(host_array, t, seed)
    result = torch.from_numpy(expanded)
    return result.to(device=x.device, dtype=x.dtype)


def fft_expand_2d(
    x: torch.Tensor, target_hw: Tuple[int, int], t: float, seed: int,
) -> torch.Tensor:
    """Torch wrapper around the centred-FFT expansion of the spatial axes.

    The tensor is detached, moved to CPU as float32, and expanded by
    ``_fft_expand_np``; the result is returned on the device and in the dtype
    of ``x``.
    """
    host_array = x.detach().cpu().float().numpy()
    expanded = _fft_expand_np(host_array, target_hw, t, seed)
    result = torch.from_numpy(expanded)
    return result.to(device=x.device, dtype=x.dtype)


def spectral_expand_and_align(
    x: torch.Tensor,
    s_i: float,
    s_next: float,
    t: float,
    transform: str,
    seed: int,
    H: int,
    W: int,
) -> Tuple[torch.Tensor, float]:
    """Grow ``x`` from scale ``s_i`` to ``s_next`` and realign its time.

    The tensor is converted to a CPU float32 array, expanded with the chosen
    transform, multiplied by ``kappa(t, r)`` and converted back to the device
    and dtype of ``x``. ``r`` is the ratio of target to source size.

    Args:
        x: Tensor to expand; its trailing two axes are handed to the
            transform as the spatial axes.
        s_i: Current scale.
        s_next: Next scale; must satisfy ``0 < s_i < s_next <= 1``.
        t: Current time; also scales the injected noise.
        transform: One of ``"dct"``, ``"dwt"`` or ``"fft"``.
        seed: Seed forwarded to the transform's noise generator.
        H: Full-resolution height; ``s * H`` must be an integer for both scales.
        W: Full-resolution width; ``s * W`` must be an integer for both scales.

    Returns:
        ``(x_tilde, t_aligned)``: the expanded, rescaled tensor and
        ``align_timestep(t, r)``.

    Raises:
        ValueError: For an unknown transform, out-of-order scales,
            non-integer scaled dimensions, differing height/width ratios, or
            a DWT request whose ratio is not 2.
    """
    if transform not in ("dct", "dwt", "fft"):
        raise ValueError(f"transform must be 'dct'|'dwt'|'fft', got {transform!r}")
    if not (0.0 < s_i < s_next <= 1.0):
        raise ValueError(f"require 0 < s_i < s_next <= 1, got s_i={s_i}, s_next={s_next}")

    tolerance = 1e-6

    # Both scales must land exactly on the integer grid.
    src_h, src_w = round(s_i * H), round(s_i * W)
    tgt_h, tgt_w = round(s_next * H), round(s_next * W)
    if abs(src_h - s_i * H) > tolerance or abs(src_w - s_i * W) > tolerance:
        raise ValueError(
            f"scale {s_i} does not give integer dims at ({H}, {W}): "
            f"s_i*H = {s_i * H}, s_i*W = {s_i * W}"
        )
    if abs(tgt_h - s_next * H) > tolerance or abs(tgt_w - s_next * W) > tolerance:
        raise ValueError(
            f"scale {s_next} does not give integer dims at ({H}, {W}): "
            f"s_next*H = {s_next * H}, s_next*W = {s_next * W}"
        )

    # A single ratio drives both the rescaling and the time alignment.
    ratio_h, ratio_w = tgt_h / src_h, tgt_w / src_w
    if abs(ratio_h - ratio_w) > tolerance:
        raise ValueError(
            f"non-uniform scale ratio not supported: r_h={ratio_h}, r_w={ratio_w}"
        )
    ratio = ratio_h

    host_array = x.detach().cpu().float().numpy()
    target_hw = (tgt_h, tgt_w)
    if transform == "dct":
        grown = _dct_expand_np(host_array, target_hw, t, seed)
    elif transform == "fft":
        grown = _fft_expand_np(host_array, target_hw, t, seed)
    else:  # dwt
        if abs(ratio - 2.0) > tolerance:
            raise ValueError(
                f"DWT requires a 2x scale ratio between consecutive scales; "
                f"got s_next/s_i = {s_next}/{s_i} = {ratio:.4f}. "
                f"Use --transform dct or --transform fft for non-dyadic scales."
            )
        grown = _dwt_expand_np(host_array, t, seed)

    # Rescale while still in float32; the cast to x.dtype happens last.
    scaled = (kappa(t, ratio) * grown).astype(np.float32)
    x_tilde = torch.from_numpy(scaled).to(device=x.device, dtype=x.dtype)
    return x_tilde, align_timestep(t, ratio)


def find_first_step_below(sigmas: Iterable[float], threshold: float) -> int:
    """Locate the first step whose sigma is at or below ``threshold``.

    The last entry of ``sigmas`` is never examined. Entries exposing
    ``.item()`` are unwrapped with it; others are converted with ``float``.

    Args:
        sigmas: Sigma values, one more than the number of steps.
        threshold: Inclusive upper bound to search for.

    Returns:
        The index of the first qualifying entry, or ``len(sigmas) - 1`` when
        no examined entry qualifies.
    """
    values = list(sigmas)
    for step, sigma in enumerate(values[:-1]):
        level = sigma.item() if hasattr(sigma, "item") else float(sigma)
        if level <= threshold:
            return step
    return len(values) - 1


def reset_scheduler_state(scheduler, step_index: int) -> None:
    """Clear solver buffers on ``scheduler`` and set its step index.

    Only attributes that already exist on the scheduler are reset:
    ``model_outputs`` becomes a list of ``None`` whose length is
    ``scheduler.config.solver_order`` (1 if that is missing),
    ``lower_order_nums`` becomes ``0`` and ``last_sample`` becomes ``None``.
    ``_step_index`` is always assigned.

    Args:
        scheduler: Scheduler object to mutate in place.
        step_index: Value written to ``scheduler._step_index``.
    """
    if hasattr(scheduler, "model_outputs"):
        history_length = getattr(scheduler.config, "solver_order", 1)
        scheduler.model_outputs = [None] * history_length

    for attribute, cleared in (("lower_order_nums", 0), ("last_sample", None)):
        if hasattr(scheduler, attribute):
            setattr(scheduler, attribute, cleared)

    scheduler._step_index = step_index


def validate_scales(scales: Sequence[float]) -> None:
    """Validate a strictly increasing scale list ending at 1.0.

    Args:
        scales: Candidate scales.

    Raises:
        ValueError: If ``scales`` is empty, any value lies outside ``(0, 1]``
            (NaN included), the last value differs from 1.0 by more than
            1e-6, or the values are not strictly increasing.
    """
    if len(scales) == 0:
        raise ValueError("--scales is empty; supply at least one value.")
    # Positive form of the range test: NaN fails every comparison, so it is
    # rejected here instead of slipping past a negated check.
    if not all(0.0 < s <= 1.0 for s in scales):
        raise ValueError(f"every scale must be in (0, 1]; got {list(scales)}")
    if abs(scales[-1] - 1.0) > 1e-6:
        raise ValueError(f"last scale must equal 1.0 (full resolution); got {scales[-1]}")
    for a, b in zip(scales[:-1], scales[1:]):
        if not (a < b):
            raise ValueError(f"scales must be strictly increasing; got {list(scales)}")


# Matches ``${NAME}`` placeholders; group 1 captures NAME, which must start
# with a letter or underscore and continue with letters, digits or underscores.
_ENV_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def _expand_env(value):
    """Recursively replace ``${NAME}`` placeholders with environment values.

    Strings have every placeholder substituted. Dict values and list items
    are processed recursively (dict keys are left as they are). Any other
    value is returned unchanged.

    Raises:
        KeyError: If a referenced environment variable is not set.
    """
    if isinstance(value, dict):
        return {key: _expand_env(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_expand_env(item) for item in value]
    if not isinstance(value, str):
        return value

    def _substitute(match: re.Match) -> str:
        name = match.group(1)
        # Environment values are always strings, so None can only mean "unset".
        resolved = os.environ.get(name)
        if resolved is None:
            raise KeyError(
                f"Environment variable {name!r} referenced in configs.yaml "
                "is not set."
            )
        return resolved

    return _ENV_PATTERN.sub(_substitute, value)


def load_config(yaml_path: str | Path, model_key: str) -> dict:
    """Load a model config and expand ``${ENV_VAR}`` placeholders.

    Args:
        yaml_path: Path of the YAML file, read as UTF-8.
        model_key: Top-level key whose value is returned.

    Returns:
        The entry stored under ``model_key`` after it has been passed through
        ``_expand_env``.

    Raises:
        KeyError: If ``model_key`` is not a top-level key (an empty file is
            treated as having none), or if a referenced environment variable
            is not set.
    """
    # Read explicitly as UTF-8 rather than relying on the platform default.
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # An empty document parses to None; treat it as "no models" so the lookup
    # below raises the descriptive KeyError instead of a TypeError.
    if data is None:
        data = {}
    if model_key not in data:
        raise KeyError(f"model {model_key!r} not in {yaml_path}; have {list(data)}")
    return _expand_env(data[model_key])