| import torch |
| import threading |
|
|
|
|
| |
| |
| |
# Per-thread slot holding the sigma the next extrapolation should aim at.
_sigma_ctx = threading.local()


def set_sigma_target(target):
    """Store *target* as this thread's sigma target (called by the engine).

    Storing ``None`` means "no target", which makes the sigma-aware paths
    fall back to uniform-spacing extrapolation.
    """
    _sigma_ctx.target = target


def _get_sigma_target():
    """Return this thread's sigma target, or ``None`` when none was stored."""
    try:
        return _sigma_ctx.target
    except AttributeError:
        # Nothing has been stored on this thread yet.
        return None
|
|
|
|
| |
| |
| |
# Per-thread slot holding the current noisy latent used by denoised mode.
_denoised_ctx = threading.local()


def set_current_x(x):
    """Store *x*, the current noisy latent, for this thread (called by engine).

    Denoised-mode extrapolation is only attempted while a non-``None`` latent
    is stored; storing ``None`` turns it off.
    """
    _denoised_ctx.current_x = x


def _get_current_x():
    """Return this thread's noisy latent, or ``None`` when none was stored."""
    try:
        return _denoised_ctx.current_x
    except AttributeError:
        # Nothing has been stored on this thread yet.
        return None
|
|
|
|
| |
| |
| |
class SigmaAwareHistory(list):
    """Epsilon history that also tracks sigma and denoised per entry.

    Behaves like a plain ``list`` of epsilons, so sampler code that treats
    ``epsilon_history`` as a list keeps working. Two parallel lists are kept
    index-aligned with the epsilons:

    * ``sigmas[i]``   -- the sigma registered with ``set_pending_sigma``
      before ``self[i]`` was added, else ``None``.
    * ``denoised[i]`` -- ``self[i] + x`` where ``x`` was registered with
      ``set_pending_x`` before ``self[i]`` was added, else ``None``.

    Every list mutator that adds or removes entries (``append``, ``extend``,
    ``+=``, ``insert``, ``pop``, ``remove``, ``clear``, ``del``) updates the
    parallel lists as well. Removed entries therefore stop being referenced
    from ``sigmas``/``denoised`` instead of piling up there. Item assignment
    and reordering (``h[i] = eps``, ``sort``, ``reverse``) only affect the
    epsilons.
    """

    def __init__(self):
        super().__init__()
        self.sigmas = []
        self.denoised = []
        # One-shot values consumed by the next append().
        self._pending_sigma = None
        self._pending_x = None

    def set_pending_sigma(self, sigma):
        """Set the sigma that will be recorded with the next appended epsilon."""
        self._pending_sigma = sigma

    def set_pending_x(self, x):
        """Set the noisy latent so denoised can be computed on the next append."""
        self._pending_x = x

    def append(self, epsilon):
        """Append *epsilon* and record the pending sigma / denoised with it.

        Pending values are consumed, so they never attach to a later entry.
        A missing pending value is recorded as ``None``.
        """
        x = self._pending_x
        # Compute before mutating anything: if the addition raises (e.g. a
        # shape mismatch) the three lists stay aligned.
        denoised = None if x is None else epsilon + x
        super().append(epsilon)
        self.sigmas.append(self._pending_sigma)
        self.denoised.append(denoised)
        self._pending_sigma = None
        self._pending_x = None

    def extend(self, iterable):
        """Append each item via ``append`` so metadata stays aligned.

        Any pending sigma / x is attached to the first appended item only.
        """
        # Materialize first so ``h.extend(h)`` does not loop forever.
        for item in list(iterable):
            self.append(item)

    def __iadd__(self, other):
        self.extend(other)
        return self

    def insert(self, index, epsilon):
        """Insert *epsilon* at *index* with no recorded sigma / denoised."""
        super().insert(index, epsilon)
        # All three lists have equal length, so the same index resolves to
        # the same position in each.
        self.sigmas.insert(index, None)
        self.denoised.insert(index, None)

    def pop(self, index=-1):
        """Remove and return the epsilon at *index*, dropping its metadata."""
        epsilon = super().pop(index)
        self.sigmas.pop(index)
        self.denoised.pop(index)
        return epsilon

    def remove(self, value):
        """Remove the first occurrence of *value* together with its metadata."""
        del self[self.index(value)]

    def clear(self):
        """Remove all epsilons and their recorded sigma / denoised values."""
        super().clear()
        self.sigmas.clear()
        self.denoised.clear()

    def __delitem__(self, key):
        # Works for both integer indices and slices.
        super().__delitem__(key)
        del self.sigmas[key]
        del self.denoised[key]
|
|
|
|
| |
| |
| |
def _sigma_aware_ok(epsilon_history, n):
    """Decide whether sigma-aware extrapolation can use the last *n* entries.

    Returns ``(target, sigmas)``, where ``sigmas`` holds the *n* most recently
    recorded sigmas ordered oldest to newest. Returns ``None`` if any of the
    following holds:

    * no sigma target is set on this thread;
    * the history has no ``sigmas`` attribute;
    * fewer than *n* sigmas were recorded;
    * any of the trailing sigmas is ``None``.
    """
    target = _get_sigma_target()
    if target is None or not hasattr(epsilon_history, 'sigmas'):
        return None

    recorded = epsilon_history.sigmas
    if len(recorded) < n:
        return None

    # Negative offsets -n .. -1 walk the tail from oldest to newest.
    window = [recorded[-k] for k in range(n, 0, -1)]
    if not all(s is not None for s in window):
        return None
    return target, window
|
|
|
|
def _denoised_mode_ok(epsilon_history, n):
    """Return this thread's noisy latent if denoised extrapolation is usable.

    Denoised extrapolation is usable only when all of these hold:

    * a latent has been set on this thread;
    * the history has a ``denoised`` attribute;
    * that attribute has at least *n* entries;
    * none of the last *n* of those entries is ``None``.

    Otherwise returns ``None``.
    """
    latent = _get_current_x()
    if latent is None or not hasattr(epsilon_history, 'denoised'):
        return None

    recorded = epsilon_history.denoised
    if len(recorded) < n:
        return None

    # Scan the tail oldest -> newest; bail out at the first gap.
    for k in range(n, 0, -1):
        if recorded[-k] is None:
            return None
    return latent
|
|
|
|
def _get_values(epsilon_history, n, current_x):
    """Return the last *n* values to extrapolate, ordered oldest to newest.

    The values are taken from ``epsilon_history.denoised`` when *current_x*
    is not ``None`` (denoised mode), and from the epsilons themselves
    otherwise.
    """
    source = epsilon_history if current_x is None else epsilon_history.denoised
    return [source[-k] for k in range(n, 0, -1)]
|
|
|
|
def _maybe_convert(result, current_x):
    """Map a denoised-mode prediction back to epsilon space.

    When *current_x* is set, the stored relation is ``denoised = eps + x``,
    so this returns ``result - current_x``. Otherwise *result* already is an
    epsilon and is returned unchanged.
    """
    return result if current_x is None else result - current_x
|
|
|
|
| |
| |
| |
|
|
def extrapolate_epsilon_linear(epsilon_history):
    """Predict the next epsilon from the last two REAL epsilons (linear).

    In denoised mode the line is fitted through the stored denoised values
    and the prediction is converted back to epsilon.

    If a sigma target is set and the two trailing sigmas differ by more than
    1e-12, the line through (s0, v0) and (s1, v1) is evaluated at the target
    sigma. Otherwise unit step spacing is assumed: ``v1 + (v1 - v0)``.

    Args:
        epsilon_history: list[Tensor] of REAL epsilons, oldest..newest
    Returns:
        Tensor, or None when fewer than two entries exist.
    """
    if len(epsilon_history) < 2:
        return None

    current_x = _denoised_mode_ok(epsilon_history, 2)
    prev, last = _get_values(epsilon_history, 2, current_x)

    sigma_info = _sigma_aware_ok(epsilon_history, 2)
    if sigma_info is not None:
        target, (s_prev, s_last) = sigma_info
        if abs(s_last - s_prev) > 1e-12:
            # Two-node Lagrange weights evaluated at the target sigma.
            w_prev = (target - s_last) / (s_prev - s_last)
            w_last = (target - s_prev) / (s_last - s_prev)
            return _maybe_convert(w_prev * prev + w_last * last, current_x)

    # Uniform-spacing fallback.
    return _maybe_convert(last + (last - prev), current_x)
|
|
|
|
def extrapolate_epsilon_richardson(epsilon_history):
    """Predict the next epsilon from the last three REAL epsilons (quadratic).

    In denoised mode the quadratic is fitted through the stored denoised
    values and the prediction is converted back to epsilon.

    If a sigma target is set and every Lagrange denominator
    ``(s_i - s_j)(s_i - s_k)`` exceeds 1e-12 in magnitude, the quadratic
    through the three (sigma, value) pairs is evaluated at the target.
    Otherwise unit step spacing is assumed: ``3*v2 - 3*v1 + v0``.

    Delegates to the linear rule when fewer than three entries exist.

    Args:
        epsilon_history: list[Tensor] of REAL epsilons, oldest..newest
    Returns:
        Tensor or None
    """
    if len(epsilon_history) < 3:
        return extrapolate_epsilon_linear(epsilon_history)

    current_x = _denoised_mode_ok(epsilon_history, 3)
    oldest, middle, newest = _get_values(epsilon_history, 3, current_x)

    sigma_info = _sigma_aware_ok(epsilon_history, 3)
    if sigma_info is not None:
        target, (s0, s1, s2) = sigma_info
        den0 = (s0 - s1) * (s0 - s2)
        den1 = (s1 - s0) * (s1 - s2)
        den2 = (s2 - s0) * (s2 - s1)
        if all(abs(den) > 1e-12 for den in (den0, den1, den2)):
            w0 = (target - s1) * (target - s2) / den0
            w1 = (target - s0) * (target - s2) / den1
            w2 = (target - s0) * (target - s1) / den2
            return _maybe_convert(w0 * oldest + w1 * middle + w2 * newest, current_x)

    # Uniform spacing: quadratic through t = -2, -1, 0 evaluated at t = 1.
    return _maybe_convert(3 * newest - 3 * middle + oldest, current_x)
|
|
|
|
def extrapolate_epsilon_h4(epsilon_history):
    """4-point (cubic) epsilon extrapolation using the last four REAL epsilons.

    Sigma-aware mode: when a sigma target is set and every Lagrange
    denominator ``prod_{j != i}(s_i - s_j)`` exceeds 1e-12 in magnitude, the
    cubic through the four (sigma, value) pairs is evaluated at the target
    sigma. This is the same strict threshold the 2- and 3-point variants use,
    and NaN sigmas also fail it.

    Otherwise uniform step spacing in the prediction index is assumed. Uses
    Lagrange coefficients for points at t = [-3, -2, -1, 0] to predict at t = 1:
        eps_hat_{n+1} = -1*eps_{n-3} + 4*eps_{n-2} - 6*eps_{n-1} + 4*eps_{n}

    In denoised mode the extrapolation runs on the stored denoised values and
    the result is converted back to epsilon.

    Falls back to 3-point when history is insufficient.

    Args:
        epsilon_history: list[Tensor] of REAL epsilons, oldest..newest
    Returns:
        Tensor or None
    """
    if len(epsilon_history) < 4:
        return extrapolate_epsilon_richardson(epsilon_history)

    current_x = _denoised_mode_ok(epsilon_history, 4)
    vals = _get_values(epsilon_history, 4, current_x)

    sa = _sigma_aware_ok(epsilon_history, 4)
    if sa is not None:
        target, nodes = sa

        # Lagrange denominators; a tiny one means two nodes (nearly) coincide.
        denoms = []
        for i, s_i in enumerate(nodes):
            d = 1.0
            for j, s_j in enumerate(nodes):
                if j != i:
                    d *= (s_i - s_j)
            denoms.append(d)

        if all(abs(d) > 1e-12 for d in denoms):
            weights = []
            for i, s_i in enumerate(nodes):
                basis = 1.0
                for j, s_j in enumerate(nodes):
                    if j != i:
                        basis *= (target - s_j) / (s_i - s_j)
                weights.append(basis)
            # sum() starts from 0, so no zero buffer is allocated and the
            # values need not be torch tensors.
            result = sum(w * v for w, v in zip(weights, vals))
            return _maybe_convert(result, current_x)

    return _maybe_convert((-1.0) * vals[0] + 4.0 * vals[1] - 6.0 * vals[2] + 4.0 * vals[3], current_x)
|
|