Eji-Sensei14's picture
Upload folder using huggingface_hub
c6535db verified
import math
import torch
import re
from .extrapolation import (
extrapolate_epsilon_linear,
extrapolate_epsilon_richardson,
)
def _parse_hs_mode(skip_mode: str):
    """Parse a decoupled history/stride mode string such as ``'h2/s3'``.

    Matching is case-insensitive; whitespace around the whole string and
    around each side of the slash is ignored. The bare forms ``'h2'``,
    ``'h3'`` and ``'h4'`` are shorthand for ``'h2/s2'``, ``'h3/s3'`` and
    ``'h4/s4'``.

    Accepted ``hN/sK`` combinations:
      - h2/sK for K in 2..6
      - h3/sK for K in 3..6
      - h4/sK for K in 4..6
    (h4/s2 and h4/s3 are excluded to avoid an overly aggressive cadence for
    the high-order predictor.)

    Returns:
        ``(history_order, skip_calls)`` as a tuple of ints, or ``None`` when
        the input is not a string, is not in a recognized form, or names a
        combination outside the accepted set. Never raises: any unexpected
        error is also reported as ``None``.
    """
    try:
        if not isinstance(skip_mode, str):
            return None
        mode = skip_mode.strip().lower()

        # Bare hN means "history N, stride N".
        shorthand = {"h2": (2, 2), "h3": (3, 3), "h4": (4, 4)}
        if mode in shorthand:
            return shorthand[mode]

        # hN/sK form: split on the first slash only, so extra slashes end up
        # in the stride part and fail the digit check below.
        head, slash, tail = mode.partition("/")
        if not slash:
            return None
        head = head.strip()
        tail = tail.strip()
        if len(head) < 2 or len(tail) < 2 or head[0] != "h" or tail[0] != "s":
            return None

        order_digits = head[1:]
        stride_digits = tail[1:]
        if not (order_digits.isdigit() and stride_digits.isdigit()):
            return None

        # int() can still reject some characters isdigit() accepts; the outer
        # handler turns that into None.
        order = int(order_digits)
        stride = int(stride_digits)

        # Same set as {(2, 2..6), (3, 3..6), (4, 4..6)}.
        if 2 <= order <= 4 and order <= stride <= 6:
            return (order, stride)
        return None
    except Exception:
        return None
def _normalize_skip_mode(skip_mode: str):
    """Map UI aliases and legacy names onto the canonical internal mode.

    Canonical results: ``"none"`` | ``"linear"`` | ``"richardson"`` |
    ``"h4"`` | ``"adaptive"``.

    Accepted inputs (case-insensitive, surrounding whitespace ignored):
      - ``h2`` / ``linear``      -> ``"linear"``
      - ``h3`` / ``richardson``  -> ``"richardson"``
      - ``h4`` / ``quad``        -> ``"h4"``
      - ``adaptive``             -> ``"adaptive"``

    Anything else, including non-string input, maps to ``"none"``.
    """
    if not isinstance(skip_mode, str):
        return "none"

    # Strip as well as lower-case so padded values are treated the same way
    # the hN/sK parser in this module treats them.
    key = skip_mode.strip().lower()

    canonical = {
        "h2": "linear",
        "linear": "linear",
        "h3": "richardson",
        "richardson": "richardson",
        # The 4-point mode is reported as "h4"; "quad" is only a legacy alias.
        "h4": "h4",
        "quad": "h4",
        "adaptive": "adaptive",
    }
    return canonical.get(key, "none")
def parse_skip_indices_config(text: str):
    """Parse an explicit skip-indices configuration string.

    Input examples:
      - ``"h2, 3, 4, 7, 9"``
      - ``"3 6 8"``       (no hN token, so the predictor defaults to h2)
      - ``"h4, 10, 12"``

    Behavior:
      - Tokens are separated by commas and/or whitespace; matching is
        case-insensitive.
      - The first ``hN`` token with N >= 2 selects the predictor: h2 ->
        ``"linear"``, h3 -> ``"richardson"``, h4 and above -> ``"h4"``.
        Later ``hN`` tokens, and ``h0`` / ``h1``, are ignored.
      - Optionally signed integer tokens are collected into a set; values
        below 2 are dropped here (the engine bounds indices against the
        total step count separately).
      - Every other token is silently dropped, including tokens whose
        digits ``int()`` cannot convert (for example superscript digits).

    Returns:
        ``(predictor, indices)`` where ``predictor`` is ``"linear"``,
        ``"richardson"`` or ``"h4"`` and ``indices`` is a ``set[int]``
        (empty when no index was parsed). Non-string input yields
        ``("linear", set())``.
    """
    if not isinstance(text, str):
        return "linear", set()

    predictor = None
    indices = set()
    for raw in re.split(r"[\s,]+", text.strip()):
        tok = raw.strip().lower()
        if not tok:
            continue

        # Predictor selection: only considered until one has been chosen.
        if predictor is None and len(tok) >= 2 and tok[0] == "h" and tok[1:].isdigit():
            try:
                order = int(tok[1:])
            except ValueError:
                # str.isdigit() accepts characters int() rejects (e.g. "²");
                # treat such a token as invalid instead of raising.
                continue
            if order >= 4:
                predictor = "h4"
            elif order == 3:
                predictor = "richardson"
            elif order == 2:
                predictor = "linear"
            continue

        # Integer step index, optionally signed.
        if tok.lstrip("+-").isdigit():
            try:
                value = int(tok)
            except ValueError:
                # e.g. "+-3" passes the digit pre-check but is not an int.
                continue
            if value >= 2:
                indices.add(value)

    if predictor is None:
        predictor = "linear"
    return predictor, indices
def should_skip_model_call(
    error_ratio,
    step_index,
    total_steps,
    skip_mode,
    epsilon_history,
    protect_last_steps=4,
    protect_first_steps=2,
):
    """Decide whether the model call at ``step_index`` may be skipped.

    The decision is purely pattern based (plus an error band in adaptive
    mode). Per the original note, this is meant to mirror the logic used in
    sampling_engine.py, including first/last step protection and the history
    lengths required by each predictor.

    Args:
        error_ratio: Only read in ``adaptive`` mode, where it must fall in a
            narrow band around 1.0 for a skip to be allowed. ``None`` means
            "unknown" and never permits a skip.
        step_index: Index of the current step.
        total_steps: Total number of steps; used for the trailing guard.
        skip_mode: Either a decoupled ``hN/sK`` string (or bare ``h2``/
            ``h3``/``h4``), or a legacy name (``linear``, ``richardson``,
            ``quad``, ``adaptive``). Anything unrecognized disables skipping.
        epsilon_history: Sized collection of real epsilon outputs; only its
            length is inspected here.
        protect_last_steps: Number of trailing steps that are never skipped
            (values below 1 become 1, unparsable values become 4).
        protect_first_steps: Number of leading steps that are never skipped
            (negative values become 0, unparsable values become 2).

    Returns:
        ``(should_skip, predictor)``. ``predictor`` is ``"linear"``,
        ``"richardson"`` or ``"h4"`` when ``should_skip`` is True, else
        ``None``.
    """
    # Decoupled hN/sK strings (including bare h2/h3/h4) take precedence over
    # the legacy names handled further down.
    hs = _parse_hs_mode(skip_mode)

    # Leading guard window (history has to be built first).
    try:
        pfs = int(protect_first_steps)
    except Exception:
        pfs = 2
    pfs = max(pfs, 0)

    # Trailing guard window (last steps are critical for quality).
    try:
        pls = int(protect_last_steps)
    except Exception:
        pls = 4
    pls = max(pls, 1)

    # ---- Decoupled history/stride path -----------------------------------
    if hs is not None:
        history_order, skip_calls = hs
        if step_index < pfs or step_index >= total_steps - pls:
            return False, None
        # The predictor needs `history_order` real epsilons.
        if len(epsilon_history) < history_order:
            return False, None

        # The cycle starts at the later of warmup end and required history.
        anchor = max(pfs, history_order)
        # Python's % is non-negative for a positive modulus, so without this
        # guard step anchor-1 would map onto the skip slot (cycle_len - 1)
        # whenever the history is already long enough.
        if step_index < anchor:
            return False, None

        # Pattern: K calls, then one skip -> cycle length K+1, skip last.
        cycle_len = int(skip_calls) + 1
        if (step_index - anchor) % cycle_len != cycle_len - 1:
            return False, None

        if history_order >= 4:
            return True, "h4"
        if history_order == 3:
            return True, "richardson"
        return True, "linear"

    # ---- Legacy named modes ----------------------------------------------
    skip_mode = _normalize_skip_mode(skip_mode)
    if skip_mode == "none":
        return False, None
    if step_index < pfs or step_index >= total_steps - pls:
        return False, None
    if len(epsilon_history) < 2:
        return False, None

    if skip_mode == "h4":
        # Call x4, skip x1, starting no earlier than step 4.
        anchor = max(pfs, 4)
        if step_index >= anchor and (step_index - anchor) % 5 == 4:
            return True, "h4"
        return False, None

    if skip_mode == "linear":
        # Call, call, skip.
        if (step_index - pfs) % 3 == 2:
            return True, "linear"
        return False, None

    if skip_mode == "richardson":
        # Call, call, call, skip; needs three history entries.
        if len(epsilon_history) < 3:
            return False, None
        # Shifted by one so the default pfs=2 keeps the earlier anchor of 3.
        anchor = pfs + 1
        if (step_index - anchor) % 4 == 3:
            return True, "richardson"
        return False, None

    if skip_mode == "adaptive":
        # Without an error ratio the bands below cannot be evaluated; refuse
        # to skip rather than fail on a None comparison.
        if error_ratio is None:
            return False, None
        # Legacy cadence with fixed offsets (not tied to pfs).
        # Every 3rd step, tight band, linear predictor.
        if (step_index - 2) % 3 == 2 and 0.97 <= error_ratio <= 1.03 and len(epsilon_history) >= 2:
            return True, "linear"
        # Every 4th step, very tight band, richardson predictor.
        if (step_index - 3) % 4 == 3 and 0.99 <= error_ratio <= 1.01 and len(epsilon_history) >= 3:
            return True, "richardson"
        return False, None

    return False, None
def validate_epsilon_hat(eps_hat, prev_eps=None, min_abs=1e-8, min_rel=1e-6):
    """Sanity-check an extrapolated epsilon before it is used for a skip step.

    Checks, in order:
      1. ``eps_hat`` is not None                      -> reason ``'none'``
      2. no NaN/Inf entries and a finite norm          -> reason ``'nan_inf'``
         (also reported when the checks themselves raise)
      3. norm is at least ``min_abs``                  -> reason ``'too_small_abs'``
      4. if ``prev_eps`` has a positive norm, the norm is at least
         ``min_rel`` times it                          -> reason ``'too_small_rel'``

    Returns:
        ``(ok, reason, hat_norm, prev_norm)``. On success ``reason`` is the
        empty string. ``prev_norm`` is only populated once the relative
        check is reached and the norm of ``prev_eps`` could be computed;
        otherwise it is ``None``.
    """
    if eps_hat is None:
        return False, 'none', 0.0, None

    try:
        contaminated = bool(torch.isnan(eps_hat).any()) or bool(torch.isinf(eps_hat).any())
        # A NaN placeholder routes contaminated tensors through the same
        # non-finite rejection as an overflowing norm.
        magnitude = float('nan') if contaminated else torch.norm(eps_hat).item()
    except Exception:
        return False, 'nan_inf', float('nan'), None

    if not math.isfinite(magnitude):
        return False, 'nan_inf', magnitude, None
    if magnitude < min_abs:
        return False, 'too_small_abs', magnitude, None

    reference = None
    if prev_eps is not None:
        try:
            reference = torch.norm(prev_eps).item()
        except Exception:
            # A bad reference tensor disables the relative check only.
            reference = None

    collapsed = (
        reference is not None
        and reference > 0
        and magnitude < (min_rel * reference)
    )
    if collapsed:
        return False, 'too_small_rel', magnitude, reference
    return True, '', magnitude, reference
def decide_skip_adaptive(
    epsilon_history,
    step_index,
    total_steps,
    protect_last_steps=4,
    protect_first_steps=2,
    tol_relative=0.10,
    anchor_interval=None,
    max_consecutive_skips=None,
    skip_stats=None,
    # Optional: predicted-state (x_next) gating context
    x_current=None,
    sigma_current=None,
    sigma_next=None,
    sampler_kind=None,
    sigma_previous=None,
):
    """Adaptive skip decision based on agreement between two predictors.

    Two epsilon predictions are built from the real epsilon history: a
    higher-order one (richardson, h3) and a lower-order one (linear, h2).
    Their disagreement, measured as an RMS ratio, is compared against
    ``tol_relative``:

      - When ``x_current``, ``sigma_current``, ``sigma_next`` and
        ``sampler_kind`` are all given and the sampler is one of euler,
        res_2s, dpmpp_2s, ddim, dpmpp_2m or lms, both predictions are pushed
        through that sampler's update and compared in next-state space:
        ``rms(x_hi - x_lo) / max(rms(x_hi), 1e-6)``.
      - Otherwise (or if that computation raises) the comparison happens
        directly in epsilon space with the same formula.

    Guard rails applied before any prediction is made, in this order:
    protected first/last steps, at least three history entries, the cap on
    consecutive skips (read from ``skip_stats["consecutive_skips"]``), and a
    forced real call on every ``anchor_interval``-th step counted from
    ``protect_first_steps``. ``anchor_interval`` defaults to 4 and
    ``max_consecutive_skips`` to 2 when left as None.

    Returns:
        ``(should_skip, epsilon_hat, meta)``. ``epsilon_hat`` is the
        higher-order prediction when skipping, else ``None``. ``meta`` holds
        either ``{"reason": ...}`` for a refusal made before a metric was
        computed, or ``relative_error`` / ``hi_order`` / ``lo_order`` /
        ``space`` (``"x_next"`` or ``"epsilon"``) once one was.
    """

    def _decline(reason):
        return False, None, {"reason": reason}

    def _rms(values):
        # Root-mean-square in float32; any failure counts as "infinitely far".
        try:
            squared = values.float() ** 2
            return float(torch.sqrt(torch.mean(squared)).item())
        except Exception:
            return float('inf')

    def _project(eps_hat):
        # Predicted next state for this sampler, or None if it is not handled.
        dt = sigma_next - sigma_current
        if sampler_kind in ("euler", "res_2s", "dpmpp_2s"):
            # Euler-like update
            return x_current + dt * (-eps_hat / sigma_current)
        if sampler_kind == "ddim":
            # x_next = x0 + (sigma_next/sigma_current)*(x - x0), x0 = x + eps
            ratio = sigma_next / sigma_current
            x0 = x_current + eps_hat
            return x0 + ratio * (x_current - x0)
        if sampler_kind in ("dpmpp_2m", "lms"):
            # AB2-style step; uses the last real epsilon as the previous
            # derivative when sigma_previous is available.
            slope = -eps_hat / sigma_current
            if sigma_previous is not None and len(epsilon_history) >= 1:
                prev_slope = -(epsilon_history[-1]) / sigma_previous
                return x_current + dt * (1.5 * slope - 0.5 * prev_slope)
            return x_current + dt * slope
        return None

    # --- Guard rails --------------------------------------------------------
    try:
        head_guard = max(0, int(protect_first_steps))
    except Exception:
        head_guard = 2
    try:
        tail_guard = max(1, int(protect_last_steps))
    except Exception:
        tail_guard = 4

    if step_index < head_guard or step_index >= total_steps - tail_guard:
        return _decline("protected_region")

    # Richardson needs three real epsilons.
    if len(epsilon_history) < 3:
        return _decline("insufficient_history")

    # Defaults for the aggressive profile.
    cadence = 4 if anchor_interval is None else anchor_interval
    skip_cap = 2 if max_consecutive_skips is None else max_consecutive_skips

    # Local cap on back-to-back skips.
    if isinstance(skip_stats, dict) and skip_stats.get("consecutive_skips", 0) >= skip_cap:
        return _decline("max_consecutive")

    # Absolute cadence: force a real call on every Nth index counted from the
    # first eligible step; it does not reset on real calls.
    if cadence and cadence > 0 and step_index >= head_guard:
        try:
            on_anchor = (step_index - head_guard) % int(cadence) == 0
        except Exception:
            on_anchor = False
        if on_anchor:
            return _decline("anchor_abs")

    # --- Predictions --------------------------------------------------------
    eps_hat_hi = extrapolate_epsilon_richardson(epsilon_history)
    eps_hat_lo = extrapolate_epsilon_linear(epsilon_history)
    if eps_hat_hi is None or eps_hat_lo is None:
        return _decline("predict_failed")
    if not (torch.isfinite(eps_hat_hi).all() and torch.isfinite(eps_hat_lo).all()):
        return _decline("non_finite")

    # --- Preferred gate: predicted next state -------------------------------
    context = (x_current, sigma_current, sigma_next, sampler_kind)
    if all(item is not None for item in context):
        try:
            x_next_hi = _project(eps_hat_hi)
            x_next_lo = _project(eps_hat_lo)
            if x_next_hi is not None and x_next_lo is not None:
                gap = _rms(x_next_hi - x_next_lo)
                scale = max(_rms(x_next_hi), 1e-6)
                rel_err = gap / scale
                meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "x_next"}
                accepted = rel_err <= float(tol_relative)
                return accepted, (eps_hat_hi if accepted else None), meta
        except Exception:
            # Any failure here drops down to the epsilon-space gate.
            pass

    # --- Fallback gate: epsilon space ---------------------------------------
    gap = _rms(eps_hat_hi - eps_hat_lo)
    scale = max(_rms(eps_hat_hi), 1e-6)
    if not math.isfinite(gap) or not math.isfinite(scale) or scale <= 0:
        return _decline("bad_metric")
    rel_err = gap / scale
    meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "epsilon"}
    accepted = rel_err <= float(tol_relative)
    return accepted, (eps_hat_hi if accepted else None), meta