File size: 14,694 Bytes
c6535db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 | import math
import torch
import re
from .extrapolation import (
extrapolate_epsilon_linear,
extrapolate_epsilon_richardson,
)
def _parse_hs_mode(skip_mode: str):
    """Decode a decoupled history/stride mode string such as ``"h2/s3"``.

    ``hN`` is the history order (how many REAL epsilons the predictor needs) and
    ``sK`` is the number of model calls that precede each skip.

    Accepted forms (case-insensitive, surrounding whitespace ignored):
      - the bare shorthands ``"h2"``, ``"h3"``, ``"h4"``, meaning ``hN/sN``;
      - ``"hN/sK"`` with ``2 <= N <= 4`` and ``N <= K <= 6``. That is h2/s2..s6,
        h3/s3..s6 and h4/s4..s6. h4/s2 and h4/s3 are deliberately excluded to
        avoid an overly aggressive cadence for the high-order predictor.

    Returns:
        ``(history_order, skip_calls)`` as ints, or ``None`` when ``skip_mode`` is
        not a string or is not one of the accepted forms.
    """
    if not isinstance(skip_mode, str):
        return None
    mode = skip_mode.strip().lower()

    shorthand = {"h2": (2, 2), "h3": (3, 3), "h4": (4, 4)}
    if mode in shorthand:
        return shorthand[mode]

    head, sep, tail = mode.partition("/")
    if not sep:
        return None
    head = head.strip()
    tail = tail.strip()
    if len(head) < 2 or len(tail) < 2 or head[0] != "h" or tail[0] != "s":
        return None

    order_txt = head[1:]
    stride_txt = tail[1:]
    if not (order_txt.isdigit() and stride_txt.isdigit()):
        return None
    try:
        order = int(order_txt)
        stride = int(stride_txt)
    except ValueError:
        # isdigit() also accepts characters (e.g. superscripts) that int() rejects.
        return None

    # Allowed table: h2/s2..6, h3/s3..6, h4/s4..6  <=>  2<=N<=4 and N<=K<=6.
    if 2 <= order <= 4 and order <= stride <= 6:
        return order, stride
    return None
# Alias/legacy name -> canonical internal skip mode (see _normalize_skip_mode).
_SKIP_MODE_ALIASES = {
    "h2": "linear",
    "linear": "linear",
    "h3": "richardson",
    "richardson": "richardson",
    "h4": "h4",
    "quad": "h4",
    "adaptive": "adaptive",
}


def _normalize_skip_mode(skip_mode: str):
    """Map UI aliases and legacy names to the canonical internal skip mode.

    Canonical results: ``"none"``, ``"linear"``, ``"richardson"``, ``"h4"``,
    ``"adaptive"``.

    Accepted inputs (case-insensitive, surrounding whitespace ignored, matching
    how ``_parse_hs_mode`` treats its input):
      - ``"h2"`` / ``"linear"``     -> ``"linear"``
      - ``"h3"`` / ``"richardson"`` -> ``"richardson"``
      - ``"h4"`` / ``"quad"``       -> ``"h4"`` (legacy ``"quad"`` kept for back-compat)
      - ``"adaptive"``              -> ``"adaptive"``

    Anything else, including non-string input, maps to ``"none"``.
    """
    if not isinstance(skip_mode, str):
        return "none"
    return _SKIP_MODE_ALIASES.get(skip_mode.strip().lower(), "none")
def parse_skip_indices_config(text: str):
    """Parse an explicit skip-indices configuration string.

    Examples:
        - ``"h2, 3, 4, 7, 9"`` -> ``("linear", {3, 4, 7, 9})``
        - ``"3 6 8"``          -> ``("linear", {3, 6, 8})`` (h2 is the default)
        - ``"h4, 10, 12"``     -> ``("h4", {10, 12})``

    Behavior:
        - Tokens are separated by commas and/or whitespace; matching is
          case-insensitive.
        - The first ``hN`` token with N >= 2 selects the predictor: h2 ->
          ``'linear'``, h3 -> ``'richardson'``, any N >= 4 -> ``'h4'``.
          ``h0``/``h1`` and any later ``hN`` tokens are ignored. The predictor
          defaults to ``'linear'`` when none is given.
        - Integer tokens (optionally signed) are collected into a set.
          Unparseable tokens are dropped silently.
        - Indices below 2 (0, 1 and negatives) are filtered out here. Any bound
          against the total step count is left to the caller.

    Returns:
        ``(predictor, indices)`` with predictor in ``{'linear', 'richardson',
        'h4'}`` and ``indices`` a possibly-empty ``set[int]``.
    """
    if not isinstance(text, str):
        return "linear", set()

    tokens = [t.strip() for t in re.split(r"[\s,]+", text.strip()) if t.strip()]
    predictor = None
    indices = set()

    for tok in tokens:
        tl = tok.lower()

        # Predictor selection (first hN wins). isdecimal() admits exactly the
        # characters int() accepts; isdigit() would also admit e.g. "²" and
        # make int() raise.
        if predictor is None and len(tl) >= 2 and tl[0] == "h" and tl[1:].isdecimal():
            try:
                n = int(tl[1:])
            except ValueError:
                # e.g. digit strings longer than the interpreter's int limit
                continue
            if n >= 4:
                predictor = "h4"
            elif n == 3:
                predictor = "richardson"
            elif n == 2:
                predictor = "linear"
            continue

        # Integer index, optionally signed.
        if tl.lstrip("+-").isdecimal():
            try:
                v = int(tl)
            except ValueError:
                # e.g. "+-5" (multiple signs) or oversized digit strings
                continue
            # Steps 0 and 1 (and negatives) are never skippable.
            if v >= 2:
                indices.add(v)

    if predictor is None:
        predictor = "linear"
    return predictor, indices
def should_skip_model_call(error_ratio, step_index, total_steps, skip_mode, epsilon_history, protect_last_steps=4, protect_first_steps=2):
    """Decide whether to skip the model call at ``step_index``.

    Mirrors the existing logic used in sampling_engine.py, including the
    first/last step protections and the history length each predictor requires.

    Args:
        error_ratio: Only consulted in legacy ``"adaptive"`` mode, where a skip is
            allowed only inside narrow bands around 1.0. ``None`` means "no skip".
        step_index: 0-based index of the current step.
        total_steps: Total number of steps in the schedule.
        skip_mode: A decoupled ``"hN/sK"`` string (or bare ``h2``/``h3``/``h4``),
            otherwise a legacy/alias mode normalized by ``_normalize_skip_mode``.
        epsilon_history: Sequence of REAL epsilons collected so far; only its
            length is inspected here.
        protect_last_steps: Trailing steps that are never skipped (clamped to >= 1;
            falls back to 4 if not int-convertible).
        protect_first_steps: Leading steps that are never skipped (clamped to >= 0;
            falls back to 2 if not int-convertible).

    Returns:
        ``(should_skip, predictor)``. ``predictor`` is ``"linear"``,
        ``"richardson"`` or ``"h4"`` when skipping, otherwise ``None``.
    """
    # Decoupled hN/sK modes (including the bare h2/h3/h4 shorthands) take precedence.
    hs = _parse_hs_mode(skip_mode)

    # Never skip the first few steps (history must be built first).
    try:
        pfs = int(protect_first_steps)
    except (TypeError, ValueError, OverflowError):
        pfs = 2
    pfs = max(pfs, 0)

    # Never skip the last few steps (critical for quality).
    try:
        pls = int(protect_last_steps)
    except (TypeError, ValueError, OverflowError):
        pls = 4
    pls = max(pls, 1)

    # --- Decoupled history/stride path -------------------------------------
    if hs is not None:
        history_order, skip_calls = hs
        if step_index < pfs or step_index >= total_steps - pls:
            return False, None
        # Require sufficient REAL epsilon history for this predictor order.
        if len(epsilon_history) < history_order:
            return False, None
        # Align the first eligible skip to the later of warmup or required history.
        anchor = max(pfs, history_order)
        if step_index < anchor:
            # Python's % is non-negative, so without this guard step anchor-1
            # would land on the skip slot of the cycle.
            return False, None
        # Pattern: K calls then one skip -> cycle length K+1, skip on its last slot.
        if (step_index - anchor) % (skip_calls + 1) == skip_calls:
            if history_order >= 4:
                return True, "h4"
            if history_order == 3:
                return True, "richardson"
            return True, "linear"
        return False, None

    # --- Legacy / alias modes -----------------------------------------------
    mode = _normalize_skip_mode(skip_mode)
    if mode == "none":
        return False, None
    if step_index < pfs or step_index >= total_steps - pls:
        return False, None
    if len(epsilon_history) < 2:
        return False, None

    if mode == "h4":
        # Call 4, skip 1 using the 4-point predictor. It needs 4 REAL epsilons,
        # the same requirement the decoupled h4/sK path enforces.
        if len(epsilon_history) < 4:
            return False, None
        anchor = max(pfs, 4)
        if step_index >= anchor and (step_index - anchor) % 5 == 4:
            return True, "h4"
        return False, None

    if mode == "linear":
        # Call, call, skip. step_index >= pfs is guaranteed above.
        if (step_index - pfs) % 3 == 2:
            return True, "linear"
        return False, None

    if mode == "richardson":
        # Call x3, skip; needs 3 REAL epsilons.
        if len(epsilon_history) < 3:
            return False, None
        # Shifted so the default pfs=2 keeps the historical anchor of 3.
        anchor = pfs + 1
        # step_index == pfs sits one before the anchor. Without the guard,
        # (-1) % 4 == 3 would skip the very first eligible step.
        if step_index >= anchor and (step_index - anchor) % 4 == 3:
            return True, "richardson"
        return False, None

    if mode == "adaptive":
        # Legacy pattern-gated bands on error_ratio. The offsets are fixed (not pfs-relative).
        if error_ratio is None:
            return False, None
        # Potential skip every 3rd step with a tight band.
        if (step_index - 2) % 3 == 2 and 0.97 <= error_ratio <= 1.03 and len(epsilon_history) >= 2:
            return True, "linear"
        # Every 4th step with a very tight band.
        if (step_index - 3) % 4 == 3 and 0.99 <= error_ratio <= 1.01 and len(epsilon_history) >= 3:
            return True, "richardson"
        return False, None

    return False, None
def validate_epsilon_hat(eps_hat, prev_eps=None, min_abs=1e-8, min_rel=1e-6):
    """Sanity-check an extrapolated epsilon before it stands in for a model call.

    Args:
        eps_hat: Candidate (extrapolated) epsilon tensor, or ``None``.
        prev_eps: Optional reference epsilon. When given and its norm is positive,
            ``||eps_hat||`` must be at least ``min_rel * ||prev_eps||``.
        min_abs: Absolute lower bound on ``||eps_hat||``.
        min_rel: Relative lower bound, as described for ``prev_eps``.

    Returns:
        ``(ok, reason, hat_norm, prev_norm)``.
        - ``reason`` is ``''`` on success, otherwise one of ``'none'``,
          ``'nan_inf'``, ``'too_small_abs'`` or ``'too_small_rel'``.
        - ``hat_norm`` is ``0.0`` for a ``None`` input and NaN when the tensor
          check or norm computation failed.
        - ``prev_norm`` is only populated once the absolute checks pass and
          ``prev_eps`` yields a norm; otherwise it is ``None``.
    """
    if eps_hat is None:
        return False, 'none', 0.0, None

    # Any failure while inspecting the tensor is reported as 'nan_inf'.
    try:
        if not bool(torch.isfinite(eps_hat).all()):
            return False, 'nan_inf', float('nan'), None
        hat_norm = torch.norm(eps_hat).item()
    except Exception:
        return False, 'nan_inf', float('nan'), None

    # The norm itself can overflow even when every element is finite.
    if not math.isfinite(hat_norm):
        return False, 'nan_inf', hat_norm, None
    if hat_norm < min_abs:
        return False, 'too_small_abs', hat_norm, None

    prev_norm = None
    if prev_eps is not None:
        try:
            prev_norm = torch.norm(prev_eps).item()
        except Exception:
            prev_norm = None

    collapsed = prev_norm is not None and prev_norm > 0 and hat_norm < (min_rel * prev_norm)
    if collapsed:
        return False, 'too_small_rel', hat_norm, prev_norm
    return True, '', hat_norm, prev_norm
def _adaptive_next_states(
    eps_hat_hi,
    eps_hat_lo,
    x_current,
    sigma_current,
    sigma_next,
    sampler_kind,
    sigma_previous,
    epsilon_history,
):
    """Project both epsilon predictions to the next latent state for ``sampler_kind``.

    Returns ``(x_next_hi, x_next_lo)``, or ``(None, None)`` if the sampler kind is
    not supported. Arithmetic errors propagate to the caller.
    """
    dt = sigma_next - sigma_current
    # NOTE(review): these updates use d = -eps / sigma and x0 = x + eps, which
    # implies "epsilon" here is stored as (denoised - x) rather than raw noise.
    # Confirm against the engine.
    if sampler_kind in ("euler", "res_2s", "dpmpp_2s"):
        # Euler-like update.
        d_hi = -eps_hat_hi / sigma_current
        d_lo = -eps_hat_lo / sigma_current
        return x_current + dt * d_hi, x_current + dt * d_lo
    if sampler_kind == "ddim":
        # x_next = x0 + (sigma_next / sigma_current) * (x - x0), with x0 = x + eps.
        scale = sigma_next / sigma_current
        x0_hi = x_current + eps_hat_hi
        x0_lo = x_current + eps_hat_lo
        return x0_hi + scale * (x_current - x0_hi), x0_lo + scale * (x_current - x0_lo)
    if sampler_kind in ("dpmpp_2m", "lms"):
        # AB2-style step, using the last REAL epsilon as the previous derivative when possible.
        d_hi = -eps_hat_hi / sigma_current
        d_lo = -eps_hat_lo / sigma_current
        if sigma_previous is not None and len(epsilon_history) >= 1:
            d_prev = -(epsilon_history[-1]) / sigma_previous
            return (
                x_current + dt * (1.5 * d_hi - 0.5 * d_prev),
                x_current + dt * (1.5 * d_lo - 0.5 * d_prev),
            )
        return x_current + dt * d_hi, x_current + dt * d_lo
    return None, None


def decide_skip_adaptive(
    epsilon_history,
    step_index,
    total_steps,
    protect_last_steps=4,
    protect_first_steps=2,
    tol_relative=0.10,
    anchor_interval=None,
    max_consecutive_skips=None,
    skip_stats=None,
    # Optional: predicted-state (x_next) gating context
    x_current=None,
    sigma_current=None,
    sigma_next=None,
    sampler_kind=None,
    sigma_previous=None,
):
    """Adaptive skip decision gated on agreement between two epsilon predictors.

    From the REAL epsilon history, two predictions are built for the current
    step: a higher-order one (Richardson, "h3") and a lower-order one (linear,
    "h2"). Their disagreement is measured as a relative RMS error::

        rel_err = rms(a_hi - a_lo) / max(rms(a_hi), 1e-6)

    ``a`` is the predicted next latent state (``space="x_next"``) when
    ``x_current``, ``sigma_current``, ``sigma_next`` and ``sampler_kind`` are all
    given and the sampler kind is one of euler / res_2s / dpmpp_2s / ddim /
    dpmpp_2m / lms. Otherwise, or if that computation raises, ``a`` is the
    epsilon prediction itself (``space="epsilon"``). The step may be skipped
    when ``rel_err <= tol_relative``.

    Guard rails, checked in this order, each return ``(False, None, {"reason": ...})``:
      - ``"protected_region"``: inside the first ``protect_first_steps`` or the
        last ``protect_last_steps`` steps.
      - ``"insufficient_history"``: fewer than 3 REAL epsilons.
      - ``"max_consecutive"``: ``skip_stats["consecutive_skips"]`` has reached
        ``max_consecutive_skips`` (default 2).
      - ``"anchor_abs"``: ``step_index - pfs`` is a multiple of ``anchor_interval``
        (default 4), where ``pfs`` is the sanitized ``protect_first_steps``. This
        forces a REAL call on a fixed cadence that does not reset on REAL calls.
      - ``"predict_failed"`` / ``"non_finite"``: an extrapolator returned None or
        a non-finite tensor.
      - ``"bad_metric"``: the error metric itself is not finite.

    Returns:
        ``(should_skip, epsilon_hat, meta)``. ``epsilon_hat`` is the high-order
        prediction when skipping, else ``None``. ``meta`` holds either a
        ``"reason"`` or ``"relative_error"`` with ``"hi_order"``,
        ``"lo_order"`` and ``"space"``.
    """
    # Guard: skipping is disabled near the start/end.
    try:
        pfs = max(0, int(protect_first_steps))
    except (TypeError, ValueError, OverflowError):
        pfs = 2
    try:
        pls = max(1, int(protect_last_steps))
    except (TypeError, ValueError, OverflowError):
        pls = 4
    if step_index < pfs or step_index >= total_steps - pls:
        return False, None, {"reason": "protected_region"}

    # Richardson needs at least 3 REAL epsilons.
    if len(epsilon_history) < 3:
        return False, None, {"reason": "insufficient_history"}

    # Defaults for the aggressive profile if not provided.
    if anchor_interval is None:
        anchor_interval = 4
    if max_consecutive_skips is None:
        max_consecutive_skips = 2

    # Local cap on back-to-back skips.
    if isinstance(skip_stats, dict) and skip_stats.get("consecutive_skips", 0) >= max_consecutive_skips:
        return False, None, {"reason": "max_consecutive"}

    # Absolute anchor cadence, offset by pfs (step_index >= pfs holds here).
    # Convert once, under the guard, so non-numeric intervals disable the anchor
    # instead of raising from the comparison.
    try:
        interval = int(anchor_interval) if anchor_interval else 0
    except (TypeError, ValueError, OverflowError):
        interval = 0
    if interval > 0 and (step_index - pfs) % interval == 0:
        return False, None, {"reason": "anchor_abs"}

    # Build both predictions.
    eps_hat_hi = extrapolate_epsilon_richardson(epsilon_history)
    eps_hat_lo = extrapolate_epsilon_linear(epsilon_history)
    if eps_hat_hi is None or eps_hat_lo is None:
        return False, None, {"reason": "predict_failed"}
    if not (torch.isfinite(eps_hat_hi).all() and torch.isfinite(eps_hat_lo).all()):
        return False, None, {"reason": "non_finite"}

    def _rms(t):
        # Root-mean-square as a Python float; inf signals a failed computation.
        try:
            return float(torch.sqrt(torch.mean((t.float()) ** 2)).item())
        except Exception:
            return float('inf')

    # Prefer predicted-state (x_next) gating when the context is provided.
    if x_current is not None and sigma_current is not None and sigma_next is not None and sampler_kind is not None:
        try:
            x_next_hi, x_next_lo = _adaptive_next_states(
                eps_hat_hi, eps_hat_lo, x_current, sigma_current, sigma_next,
                sampler_kind, sigma_previous, epsilon_history,
            )
            if x_next_hi is not None and x_next_lo is not None:
                num = _rms(x_next_hi - x_next_lo)
                den = max(_rms(x_next_hi), 1e-6)
                # Same finiteness gate as the epsilon path: an overflowed
                # denominator would otherwise give rel_err == 0 and force a skip.
                if not (math.isfinite(num) and math.isfinite(den)):
                    return False, None, {"reason": "bad_metric", "space": "x_next"}
                rel_err = num / den
                meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "x_next"}
                if rel_err <= float(tol_relative):
                    return True, eps_hat_hi, meta
                return False, None, meta
        except Exception:
            # Deliberate best-effort: any failure projecting states falls back
            # to the epsilon-space metric below.
            pass

    # Epsilon-space fallback.
    num = _rms(eps_hat_hi - eps_hat_lo)
    den = max(_rms(eps_hat_hi), 1e-6)
    if not math.isfinite(num) or not math.isfinite(den) or den <= 0:
        return False, None, {"reason": "bad_metric"}
    rel_err = num / den
    meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "epsilon"}
    if rel_err <= float(tol_relative):
        return True, eps_hat_hi, meta
    return False, None, meta
|