| """Shared spectral expansion and transition scheduling utilities.""" |
| from __future__ import annotations |
|
|
| import math |
| import os |
| import re |
| from pathlib import Path |
| from typing import Iterable, List, Sequence, Tuple |
|
|
| import numpy as np |
| import pywt |
| import torch |
| import yaml |
| from scipy.fft import dctn, idctn |
|
|
|
|
def power_spectrum(omega: float, A: float, beta: float) -> float:
    """Evaluate the radial power-law spectrum at a single frequency.

    Computes ``P(omega) = A * |omega|**(-beta)``.

    Args:
        omega: Frequency; only its absolute value enters the result.
        A: Amplitude that multiplies the power law.
        beta: Exponent; it is negated before being applied to ``|omega|``.

    Returns:
        The spectral power at ``omega``.
    """
    radius = abs(omega)
    decay = radius ** (-beta)
    return A * decay
|
|
|
|
def activation_time(P_omega: float, delta: float) -> float:
    """Compute the activation time of one frequency from its spectral power.

    The value returned is
    ``1 / (1 + sqrt(delta / (P * (1 + P - delta))))`` with ``P = P_omega``.

    Args:
        P_omega: Spectral power at the frequency of interest.
        delta: Threshold parameter; must be strictly less than ``1 + P_omega``.

    Returns:
        The activation time for this frequency.

    Raises:
        ValueError: If ``delta >= 1 + P_omega``.
    """
    ceiling = 1.0 + P_omega
    if delta >= ceiling:
        raise ValueError(
            f"delta={delta} >= 1 + P={1 + P_omega:.4f}; noise-dominated "
            "criterion is trivially satisfied at all t."
        )
    # ``ceiling - delta`` is strictly positive thanks to the guard above.
    ratio = delta / (P_omega * (ceiling - delta))
    return 1.0 / (1.0 + math.sqrt(ratio))
|
|
|
|
def delta_optimal_transitions(
    scales: Sequence[float],
    delta: float,
    A: float,
    beta: float,
    H: int,
    W: int,
) -> List[float]:
    """Compute one transition time per pair of adjacent scales.

    Every scale except the last is mapped to the frequency
    ``scale * min(H, W) / 2``; the activation time of the power-law
    spectrum at that frequency becomes the transition time.

    Args:
        scales: Scale list; checked with ``validate_scales`` first.
        delta: Threshold forwarded to ``activation_time``.
        A: Spectrum amplitude forwarded to ``power_spectrum``.
        beta: Spectrum exponent forwarded to ``power_spectrum``.
        H: Height used to derive the largest frequency.
        W: Width used to derive the largest frequency.

    Returns:
        A list with ``len(scales) - 1`` transition times.

    Raises:
        ValueError: Propagated from ``validate_scales`` or
            ``activation_time``.
    """
    validate_scales(scales)
    top_frequency = min(H, W) / 2.0
    return [
        activation_time(power_spectrum(scale * top_frequency, A, beta), delta)
        for scale in scales[:-1]
    ]
|
|
|
|
def align_timestep(t: float, r: float) -> float:
    """Map a flow-matching time to its aligned value after a scale jump.

    Args:
        t: Time before the jump.
        r: Scale ratio of the jump.

    Returns:
        ``r * t / (1 + (r - 1) * t)``.
    """
    denominator = 1.0 + (r - 1.0) * t
    return r * t / denominator
|
|
|
|
def kappa(t: float, r: float) -> float:
    """Compute the state-rescaling factor applied at a scale jump.

    Args:
        t: Time at which the jump happens.
        r: Scale ratio of the jump.

    Returns:
        ``r / (1 + (r - 1) * t)``.
    """
    denominator = 1.0 + (r - 1.0) * t
    return r / denominator
|
|
|
|
def _dct_expand_np(
    x_np: np.ndarray, target_hw: Tuple[int, int], t: float, seed: int,
) -> np.ndarray:
    """Expand the last two axes of an array in the orthonormal DCT-II domain.

    For every 2-D slice over the leading axes, a target-sized coefficient
    grid is filled with standard-normal noise scaled by ``t``, its top-left
    corner is overwritten with the slice's DCT coefficients, and the grid
    is inverse-transformed.

    Args:
        x_np: Input array; the last two axes are transformed.
        target_hw: ``(height, width)`` of the output's last two axes.
        t: Multiplier applied to the noise coefficients.
        seed: Seed for ``np.random.default_rng``; one generator is shared
            by all slices.

    Returns:
        A float32 array whose leading axes match ``x_np`` and whose last
        two axes have shape ``target_hw``.

    Raises:
        ValueError: If either target dimension is smaller than the source.
    """
    tgt_h, tgt_w = target_hw
    src_h = x_np.shape[-2]
    src_w = x_np.shape[-1]
    if tgt_h < src_h or tgt_w < src_w:
        raise ValueError(
            f"DCT expand: target {target_hw} smaller than source ({src_h}, {src_w})."
        )
    lead_shape = x_np.shape[:-2]
    generator = np.random.default_rng(seed)
    result = np.empty(lead_shape + (tgt_h, tgt_w), dtype=np.float32)
    for plane in np.ndindex(*lead_shape):
        known = dctn(x_np[plane], type=2, norm="ortho")
        # Start from noise everywhere, then overwrite the corner that the
        # source already determines.
        spectrum = t * generator.standard_normal((tgt_h, tgt_w)).astype(np.float32)
        spectrum[:src_h, :src_w] = known
        restored = idctn(spectrum, type=2, norm="ortho")
        result[plane] = restored.astype(np.float32)
    return result
|
|
|
|
def _dwt_expand_np(x_np: np.ndarray, t: float, seed: int) -> np.ndarray:
    """Double the last two axes of an array with one inverse Haar step.

    For every 2-D slice over the leading axes, the slice itself is passed
    to ``pywt.waverec2`` as the first (approximation) entry, together with
    three same-shaped bands of standard-normal noise scaled by ``t``.

    Args:
        x_np: Input array; the last two axes are expanded.
        t: Multiplier applied to the three noise bands.
        seed: Seed for ``np.random.default_rng``; one generator is shared
            by all slices.

    Returns:
        A float32 array whose last two axes are twice as long as those of
        ``x_np``.
    """
    src_h = x_np.shape[-2]
    src_w = x_np.shape[-1]
    lead_shape = x_np.shape[:-2]
    generator = np.random.default_rng(seed)
    result = np.empty(lead_shape + (src_h * 2, src_w * 2), dtype=np.float32)
    for plane in np.ndindex(*lead_shape):
        approx = x_np[plane]
        # The three bands are drawn one after another; this draw order
        # fixes the output for a given seed.
        details = tuple(
            t * generator.standard_normal(approx.shape).astype(np.float32)
            for _ in range(3)
        )
        rebuilt = pywt.waverec2([approx, details], "haar", mode="periodization")
        result[plane] = rebuilt.astype(np.float32)
    return result
|
|
|
|
def _fft_expand_np(
    x_np: np.ndarray, target_hw: Tuple[int, int], t: float, seed: int,
) -> np.ndarray:
    """NumPy core for FFT spectral expansion.

    For every 2-D slice over the leading axes, a target-sized complex
    spectrum is filled with Gaussian noise scaled by ``t / sqrt(2)`` per
    real/imaginary part, the centred source spectrum is written into its
    middle, and the real part of the inverse transform is kept.

    Args:
        x_np: Input array; the last two axes are transformed.
        target_hw: ``(height, width)`` of the output's last two axes.
        t: Multiplier applied to the noise coefficients.
        seed: Seed for ``np.random.default_rng``; one generator is shared
            by all slices.

    Returns:
        A float32 array whose leading axes match ``x_np`` and whose last
        two axes have shape ``target_hw``.

    Raises:
        ValueError: If either target dimension is smaller than the source.
    """
    H_tgt, W_tgt = target_hw
    H_src, W_src = x_np.shape[-2], x_np.shape[-1]
    if H_tgt < H_src or W_tgt < W_src:
        raise ValueError(
            f"FFT expand: target {target_hw} smaller than source ({H_src}, {W_src})."
        )
    rng = np.random.default_rng(seed)
    # fftshift places the zero-frequency bin at index n // 2.  The source
    # block must therefore start at H_tgt // 2 - H_src // 2 so that its DC
    # bin lands on the target's DC bin.  The naive (H_tgt - H_src) // 2 is
    # off by one when the source size is odd and the target size is even,
    # which shifts every source frequency by one bin.
    top = H_tgt // 2 - H_src // 2
    left = W_tgt // 2 - W_src // 2
    out = np.empty(x_np.shape[:-2] + (H_tgt, W_tgt), dtype=np.float32)
    for idx in np.ndindex(*x_np.shape[:-2]):
        X_src = np.fft.fftshift(np.fft.fft2(x_np[idx], norm="ortho"))
        # Real part is drawn before the imaginary part; this order fixes
        # the output for a given seed.
        noise_re = rng.standard_normal((H_tgt, W_tgt)).astype(np.float32)
        noise_im = rng.standard_normal((H_tgt, W_tgt)).astype(np.float32)
        X_big = np.fft.fftshift(t * (noise_re + 1j * noise_im) / np.sqrt(2.0))
        X_big[top:top + H_src, left:left + W_src] = X_src
        spatial = np.fft.ifft2(np.fft.ifftshift(X_big), norm="ortho")
        out[idx] = spatial.real.astype(np.float32)
    return out
|
|
|
|
def dct_expand_2d(
    x: torch.Tensor, target_hw: Tuple[int, int], t: float, seed: int,
) -> torch.Tensor:
    """Torch wrapper around the NumPy DCT-II expansion.

    The tensor is detached, moved to the CPU and cast to float32 before
    being handed to ``_dct_expand_np``; the result is moved back to the
    device and dtype of ``x``.

    Args:
        x: Input tensor; the trailing two axes are expanded.
        target_hw: ``(height, width)`` of the expanded trailing axes.
        t: Multiplier applied to the noise coefficients.
        seed: Seed forwarded to the NumPy routine.

    Returns:
        The expanded tensor on ``x.device`` with dtype ``x.dtype``.
    """
    host_array = x.detach().cpu().float().numpy()
    expanded = _dct_expand_np(host_array, target_hw, t, seed)
    return torch.from_numpy(expanded).to(device=x.device, dtype=x.dtype)
|
|
|
|
def dwt_expand_2d(x: torch.Tensor, t: float, seed: int) -> torch.Tensor:
    """Torch wrapper around the NumPy one-level Haar expansion.

    The tensor is detached, moved to the CPU and cast to float32 before
    being handed to ``_dwt_expand_np``; the result is moved back to the
    device and dtype of ``x``.

    Args:
        x: Input tensor; the trailing two axes are expanded.
        t: Multiplier applied to the noise bands.
        seed: Seed forwarded to the NumPy routine.

    Returns:
        The expanded tensor on ``x.device`` with dtype ``x.dtype``.
    """
    host_array = x.detach().cpu().float().numpy()
    expanded = _dwt_expand_np(host_array, t, seed)
    return torch.from_numpy(expanded).to(device=x.device, dtype=x.dtype)
|
|
|
|
def fft_expand_2d(
    x: torch.Tensor, target_hw: Tuple[int, int], t: float, seed: int,
) -> torch.Tensor:
    """Torch wrapper around the NumPy centred-FFT expansion.

    The tensor is detached, moved to the CPU and cast to float32 before
    being handed to ``_fft_expand_np``; the result is moved back to the
    device and dtype of ``x``.

    Args:
        x: Input tensor; the trailing two axes are expanded.
        target_hw: ``(height, width)`` of the expanded trailing axes.
        t: Multiplier applied to the noise coefficients.
        seed: Seed forwarded to the NumPy routine.

    Returns:
        The expanded tensor on ``x.device`` with dtype ``x.dtype``.
    """
    host_array = x.detach().cpu().float().numpy()
    expanded = _fft_expand_np(host_array, target_hw, t, seed)
    return torch.from_numpy(expanded).to(device=x.device, dtype=x.dtype)
|
|
|
|
def spectral_expand_and_align(
    x: torch.Tensor,
    s_i: float,
    s_next: float,
    t: float,
    transform: str,
    seed: int,
    H: int,
    W: int,
) -> Tuple[torch.Tensor, float]:
    """Expand ``x`` from scale ``s_i`` to ``s_next`` and align the time.

    Source and target sizes are ``scale * H`` by ``scale * W`` and must be
    integers (within ``1e-6``) with the same ratio on both axes.  The
    chosen transform expands ``x``; the result is multiplied by
    ``kappa(t, r)`` and returned with ``align_timestep(t, r)``.

    Args:
        x: Tensor to expand; its trailing two axes are transformed.
        s_i: Current scale, with ``0 < s_i < s_next``.
        s_next: Next scale, with ``s_next <= 1``.
        t: Current time; also the noise multiplier for the transform.
        transform: One of ``"dct"``, ``"dwt"`` or ``"fft"``.
        seed: Seed for the noise generator.
        H: Height that the scales are multiplied by.
        W: Width that the scales are multiplied by.

    Returns:
        ``(x_tilde, t_aligned)`` where ``x_tilde`` is on ``x.device`` with
        dtype ``x.dtype``.

    Raises:
        ValueError: For an unknown transform, out-of-order scales,
            non-integer dimensions, a non-uniform ratio, or a DWT ratio
            other than 2.
    """
    if transform not in ("dct", "dwt", "fft"):
        raise ValueError(f"transform must be 'dct'|'dwt'|'fft', got {transform!r}")
    if not (0.0 < s_i < s_next <= 1.0):
        raise ValueError(f"require 0 < s_i < s_next <= 1, got s_i={s_i}, s_next={s_next}")

    def _integer_dims(scale: float, label: str) -> Tuple[int, int]:
        # Round scale * (H, W) and reject scales that do not land on integers.
        rows, cols = round(scale * H), round(scale * W)
        if abs(rows - scale * H) > 1e-6 or abs(cols - scale * W) > 1e-6:
            raise ValueError(
                f"scale {scale} does not give integer dims at ({H}, {W}): "
                f"{label}*H = {scale * H}, {label}*W = {scale * W}"
            )
        return rows, cols

    H_src, W_src = _integer_dims(s_i, "s_i")
    H_tgt, W_tgt = _integer_dims(s_next, "s_next")

    ratio_h = H_tgt / H_src
    ratio_w = W_tgt / W_src
    if abs(ratio_h - ratio_w) > 1e-6:
        raise ValueError(
            f"non-uniform scale ratio not supported: r_h={ratio_h}, r_w={ratio_w}"
        )
    r = ratio_h

    x_np = x.detach().cpu().float().numpy()
    target_hw = (H_tgt, W_tgt)
    if transform == "dct":
        expanded = _dct_expand_np(x_np, target_hw, t, seed)
    elif transform == "fft":
        expanded = _fft_expand_np(x_np, target_hw, t, seed)
    else:
        # Only "dwt" remains after the membership check above.
        if abs(r - 2.0) > 1e-6:
            raise ValueError(
                f"DWT requires a 2x scale ratio between consecutive scales; "
                f"got s_next/s_i = {s_next}/{s_i} = {r:.4f}. "
                f"Use --transform dct or --transform fft for non-dyadic scales."
            )
        expanded = _dwt_expand_np(x_np, t, seed)

    rescaled = (kappa(t, r) * expanded).astype(np.float32)
    x_tilde = torch.from_numpy(rescaled).to(device=x.device, dtype=x.dtype)
    return x_tilde, align_timestep(t, r)
|
|
|
|
def find_first_step_below(sigmas: Iterable[float], threshold: float) -> int:
    """Find the first step whose sigma is at or below ``threshold``.

    The final entry of ``sigmas`` is never examined: only indices
    ``0 .. len(sigmas) - 2`` are candidates.  Entries that expose an
    ``item()`` method are unwrapped through it; all others go through
    ``float()``.

    Args:
        sigmas: Sigma values, one per step plus one trailing entry.
        threshold: Inclusive upper bound a sigma must meet.

    Returns:
        The index of the first qualifying step, or ``len(sigmas) - 1``
        when no examined entry qualifies.
    """
    sigma_list = list(sigmas)
    for step, raw in enumerate(sigma_list[:-1]):
        value = raw.item() if hasattr(raw, "item") else float(raw)
        if value <= threshold:
            return step
    return len(sigma_list) - 1
|
|
|
|
def reset_scheduler_state(scheduler, step_index: int) -> None:
    """Clear solver buffers on ``scheduler`` and set its step index.

    Only attributes that already exist on the scheduler are reset:

    * ``model_outputs`` becomes a list of ``None`` whose length is
      ``scheduler.config.solver_order`` (1 when that is missing).
    * ``lower_order_nums`` becomes ``0``.
    * ``last_sample`` becomes ``None``.

    ``_step_index`` is always assigned.

    Args:
        scheduler: Object whose buffers are reset in place.
        step_index: Value stored in ``scheduler._step_index``.
    """
    if hasattr(scheduler, "model_outputs"):
        solver_order = getattr(scheduler.config, "solver_order", 1)
        scheduler.model_outputs = [None] * solver_order
    for attr_name, cleared in (("lower_order_nums", 0), ("last_sample", None)):
        if hasattr(scheduler, attr_name):
            setattr(scheduler, attr_name, cleared)
    scheduler._step_index = step_index
|
|
|
|
def validate_scales(scales: Sequence[float]) -> None:
    """Check that ``scales`` is a usable scale schedule.

    The checks run in this order: the sequence is non-empty, every entry
    lies in ``(0, 1]``, the last entry equals 1.0 within ``1e-6``, and
    the entries are strictly increasing.

    Args:
        scales: Candidate scale values.

    Raises:
        ValueError: On the first check that fails.
    """
    if len(scales) == 0:
        raise ValueError("--scales is empty; supply at least one value.")
    for scale in scales:
        if scale <= 0.0 or scale > 1.0:
            raise ValueError(f"every scale must be in (0, 1]; got {list(scales)}")
    final_scale = scales[-1]
    if abs(final_scale - 1.0) > 1e-6:
        raise ValueError(f"last scale must equal 1.0 (full resolution); got {scales[-1]}")
    # zip stops at the shorter argument, so this walks adjacent pairs.
    for lower, upper in zip(scales, scales[1:]):
        if not (lower < upper):
            raise ValueError(f"scales must be strictly increasing; got {list(scales)}")
|
|
|
|
# Matches ``${NAME}`` where NAME is a letter or underscore followed by
# letters, digits or underscores; group 1 is NAME.
_ENV_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def _expand_env(value):
    """Recursively substitute ``${NAME}`` placeholders from the environment.

    Strings have every placeholder replaced by the matching environment
    variable.  Dicts and lists are rebuilt with their values/items
    expanded (dict keys are left alone).  Any other value is returned
    unchanged.

    Args:
        value: A string, dict, list, or arbitrary other object.

    Returns:
        The expanded counterpart of ``value``.

    Raises:
        KeyError: If a placeholder names an environment variable that is
            not set.
    """
    if isinstance(value, dict):
        return {key: _expand_env(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_expand_env(item) for item in value]
    if not isinstance(value, str):
        return value

    def _lookup(match: re.Match) -> str:
        var = match.group(1)
        resolved = os.environ.get(var)
        if resolved is None:
            raise KeyError(
                f"Environment variable {var!r} referenced in configs.yaml "
                "is not set."
            )
        return resolved

    return _ENV_PATTERN.sub(_lookup, value)
|
|
|
|
def load_config(yaml_path: str | Path, model_key: str) -> dict:
    """Load a model config and expand ``${ENV_VAR}`` placeholders.

    Args:
        yaml_path: Path of the YAML file to read.
        model_key: Top-level key whose value is returned.

    Returns:
        The value stored under ``model_key`` with every ``${NAME}``
        placeholder replaced via ``_expand_env``.

    Raises:
        KeyError: If ``model_key`` is not a top-level key of the file
            (including when the file is empty), or if a placeholder names
            an environment variable that is not set.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the locale's
    # default encoding.
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # An empty document parses to None; treat it as "no models" so the
    # lookup below raises the documented KeyError instead of a TypeError.
    if data is None:
        data = {}
    if model_key not in data:
        raise KeyError(f"model {model_key!r} not in {yaml_path}; have {list(data)}")
    return _expand_env(data[model_key])
|
|