| import math |
| import torch |
| import re |
| from .extrapolation import ( |
| extrapolate_epsilon_linear, |
| extrapolate_epsilon_richardson, |
| ) |
|
|
|
|
| def _parse_hs_mode(skip_mode: str): |
| """Parse decoupled history/stride mode strings like 'h2/s3'. |
| |
| Returns (history_order:int, skip_calls:int) if recognized and allowed, |
| otherwise None. Allowed set (explicit, compact): |
| - h2/sK for K in {2,3,4,5,6} |
| - h3/sK for K in {3,4,5,6} |
| - h4/sK for K in {4,5,6} |
| (No h4/s2 or h4/s3 to avoid overly aggressive cadence for high-order.) |
| """ |
| try: |
| if not isinstance(skip_mode, str): |
| return None |
| m = skip_mode.strip().lower() |
| if m == "h2": |
| return (2, 2) |
| if m == "h3": |
| return (3, 3) |
| if m == "h4": |
| return (4, 4) |
| |
| if "/" in m: |
| left, right = m.split("/", 1) |
| left = left.strip() |
| right = right.strip() |
| if left.startswith("h") and right.startswith("s") and len(left) > 1 and len(right) > 1: |
| n_str = left[1:] |
| k_str = right[1:] |
| if all(ch.isdigit() for ch in n_str) and all(ch.isdigit() for ch in k_str): |
| n = int(n_str) |
| k = int(k_str) |
| allowed = set((2, kk) for kk in (2, 3, 4, 5, 6)) |
| allowed |= set((3, kk) for kk in (3, 4, 5, 6)) |
| allowed |= set((4, kk) for kk in (4, 5, 6)) |
| if (n, k) in allowed: |
| return (n, k) |
| return None |
| except Exception: |
| return None |
|
|
|
|
| def _normalize_skip_mode(skip_mode: str): |
| """Map UI aliases to canonical internal modes. |
| |
| Canonical: none | linear | richardson | quad | adaptive |
| Aliases: h2 -> linear, h3 -> richardson, h4 -> quad |
| Legacy: linear, richardson, quad kept for back-compat |
| """ |
| if not isinstance(skip_mode, str): |
| return "none" |
| m = skip_mode.lower() |
| if m in ("h2", "linear"): |
| return "linear" |
| if m in ("h3", "richardson"): |
| return "richardson" |
| if m in ("h4", "quad"): |
| return "h4" |
| if m == "adaptive": |
| return "adaptive" |
| return "none" |
|
|
|
|
def parse_skip_indices_config(text: str):
    """Parse an explicit skip-indices configuration string.

    Input examples:
      - "h2, 3, 4, 7, 9"
      - "3 6 8"          (no predictor token -> defaults to h2 / linear)
      - "h4, 10, 12"

    Behavior:
      - Tokens are separated by commas and/or whitespace; case-insensitive.
      - The first ``hN`` token with N >= 2 selects the predictor
        (2 -> linear, 3 -> richardson, 4 or more -> h4). ``h0`` / ``h1`` are
        consumed without selecting anything. Later ``hN`` tokens are ignored.
      - Integer tokens (optionally signed) are collected into a set.
      - Values below 2 (step indices 0 and 1, and negatives) are filtered
        out here; bounding by the total step count is left to the caller.
      - Unparseable tokens are dropped silently; this function does not raise
        on malformed text.

    Returns:
        ``(predictor, indices)`` where predictor is ``'linear'``,
        ``'richardson'`` or ``'h4'`` and indices is a ``set`` of ints (empty
        if nothing parsed). Non-string input yields ``("linear", set())``.
    """
    if not isinstance(text, str):
        return "linear", set()

    predictor = None
    indices = set()

    for raw in re.split(r"[\s,]+", text.strip()):
        token = raw.strip().lower()
        if not token:
            continue

        # Predictor token. isdecimal() (not isdigit()) is used because
        # isdigit() also accepts characters such as superscript digits that
        # int() cannot convert, which would raise instead of dropping the token.
        if predictor is None and token.startswith("h") and token[1:].isdecimal():
            try:
                order = int(token[1:])
            except ValueError:
                continue
            if order >= 4:
                predictor = "h4"
            elif order == 3:
                predictor = "richardson"
            elif order == 2:
                predictor = "linear"
            continue

        # Step index token; int() still rejects mixed signs such as "+-3".
        if token.lstrip("+-").isdecimal():
            try:
                value = int(token)
            except ValueError:
                continue
            if value >= 2:
                indices.add(value)

    if predictor is None:
        predictor = "linear"

    return predictor, indices
|
|
|
|
def should_skip_model_call(error_ratio, step_index, total_steps, skip_mode, epsilon_history, protect_last_steps=4, protect_first_steps=2):
    """Decide whether the model call at ``step_index`` may be skipped.

    Two families of ``skip_mode`` are handled:

    * Decoupled history/stride strings accepted by ``_parse_hs_mode``
      (``h2``, ``h3``, ``h4``, ``hN/sK``): after ``skip_calls`` real calls,
      one step is skipped, counted from ``max(protect_first_steps, N)``.
    * Names accepted by ``_normalize_skip_mode`` (``linear``, ``richardson``,
      ``quad``, ``adaptive``), each with a fixed cadence. ``adaptive`` also
      requires ``error_ratio`` to lie in a narrow band around 1.0.

    In every mode the first ``protect_first_steps`` and the last
    ``protect_last_steps`` step indices are never skipped, and a minimum
    amount of ``epsilon_history`` is required.

    Args:
        error_ratio: Only read in ``adaptive`` mode. ``None`` means no skip.
        step_index: Zero-based index of the current step.
        total_steps: Total number of steps in the schedule.
        skip_mode: Mode string (see above). Unknown values never skip.
        epsilon_history: Sized container of past epsilons; only its length
            is inspected here.
        protect_last_steps: Trailing steps that are never skipped
            (coerced to int, floor 1, fallback 4).
        protect_first_steps: Leading steps that are never skipped
            (coerced to int, floor 0, fallback 2).

    Returns:
        ``(should_skip, predictor)`` where predictor is ``"linear"``,
        ``"richardson"`` or ``"h4"`` when skipping and ``None`` otherwise.
    """
    hs = _parse_hs_mode(skip_mode)

    # Best-effort coercion of the guard sizes: anything int() cannot handle
    # falls back to the defaults rather than aborting.
    try:
        pfs = max(0, int(protect_first_steps))
    except Exception:
        pfs = 2
    try:
        pls = max(1, int(protect_last_steps))
    except Exception:
        pls = 4

    # ---- Decoupled hN/sK cadence -------------------------------------
    if hs is not None:
        history_order, skip_calls = hs
        if step_index < pfs or step_index >= total_steps - pls:
            return False, None
        if len(epsilon_history) < history_order:
            return False, None

        anchor = max(pfs, history_order)
        cycle_len = int(skip_calls) + 1
        # NOTE(review): Python's % is non-negative for a positive modulus, so
        # a step_index below `anchor` can also land on the skip slot; only
        # the history-length check above stands in the way in that case.
        if (step_index - anchor) % cycle_len != cycle_len - 1:
            return False, None
        if history_order >= 4:
            return True, "h4"
        if history_order == 3:
            return True, "richardson"
        return True, "linear"

    # ---- Named modes ---------------------------------------------------
    mode = _normalize_skip_mode(skip_mode)
    if mode == "none":
        return False, None
    if step_index < pfs or step_index >= total_steps - pls:
        return False, None
    if len(epsilon_history) < 2:
        return False, None

    if mode == "h4":
        # Reached for the legacy "quad" spelling; "h4" itself is handled by
        # the hN/sK path above.
        # NOTE(review): unlike that path, no explicit check for 4 history
        # entries is made here -- confirm the h4 predictor tolerates fewer.
        anchor = max(pfs, 4)
        if step_index >= anchor and (step_index - anchor) % 5 == 4:
            return True, "h4"
        return False, None

    if mode == "linear":
        # Two real calls, then one skip, counted from the first unprotected step.
        if (step_index - pfs) % 3 == 2:
            return True, "linear"
        return False, None

    if mode == "richardson":
        if len(epsilon_history) < 3:
            return False, None
        anchor = pfs + 1
        # NOTE(review): for step_index == pfs this evaluates (-1) % 4 == 3,
        # so the first unprotected step is itself a skip slot when enough
        # history exists -- confirm this is the intended cadence.
        if (step_index - anchor) % 4 == 3:
            return True, "richardson"
        return False, None

    if mode == "adaptive":
        # Without an error ratio neither tolerance band can be evaluated;
        # comparing None against floats would raise TypeError.
        if error_ratio is None:
            return False, None
        # Linear slot: fixed anchor 2, cycle 3, ratio within +/-3% of 1.0.
        # (At least two history entries are already guaranteed above.)
        if (step_index - 2) % 3 == 2 and 0.97 <= error_ratio <= 1.03:
            return True, "linear"
        # Richardson slot: fixed anchor 3, cycle 4, ratio within +/-1% of 1.0.
        if (step_index - 3) % 4 == 3 and 0.99 <= error_ratio <= 1.01 and len(epsilon_history) >= 3:
            return True, "richardson"
        return False, None

    return False, None
|
|
|
|
| def validate_epsilon_hat(eps_hat, prev_eps=None, min_abs=1e-8, min_rel=1e-6): |
| """Validate extrapolated epsilon before using it for a skip step. |
| |
| Returns (ok, reason, hat_norm, prev_norm). |
| Reasons: 'none', 'nan_inf', 'too_small_abs', 'too_small_rel' |
| """ |
| prev_norm = None |
| if eps_hat is None: |
| return False, 'none', 0.0, prev_norm |
| try: |
| if torch.isnan(eps_hat).any() or torch.isinf(eps_hat).any(): |
| return False, 'nan_inf', float('nan'), None |
| hat_norm = torch.norm(eps_hat).item() |
| except Exception: |
| return False, 'nan_inf', float('nan'), None |
|
|
| if not math.isfinite(hat_norm): |
| return False, 'nan_inf', hat_norm, None |
| if hat_norm < min_abs: |
| return False, 'too_small_abs', hat_norm, None |
|
|
| if prev_eps is not None: |
| try: |
| prev_norm = torch.norm(prev_eps).item() |
| except Exception: |
| prev_norm = None |
| if prev_norm is not None and prev_norm > 0 and hat_norm < (min_rel * prev_norm): |
| return False, 'too_small_rel', hat_norm, prev_norm |
|
|
| return True, '', hat_norm, prev_norm |
|
|
|
|
def decide_skip_adaptive(
    epsilon_history,
    step_index,
    total_steps,
    protect_last_steps=4,
    protect_first_steps=2,
    tol_relative=0.10,
    anchor_interval=None,
    max_consecutive_skips=None,
    skip_stats=None,

    x_current=None,
    sigma_current=None,
    sigma_next=None,
    sampler_kind=None,
    sigma_previous=None,
):
    """Adaptive skip decision using a dual-order prediction gate.

    Two predictions are built from ``epsilon_history``: a higher-order one
    (``extrapolate_epsilon_richardson``) and a lower-order one
    (``extrapolate_epsilon_linear``). Their disagreement is measured as a
    relative RMS error and compared against ``tol_relative``:

    * If ``x_current``, ``sigma_current``, ``sigma_next`` and a recognized
      ``sampler_kind`` are supplied, both predictions are pushed through a
      one-step update for that sampler and the error is measured on the
      resulting ``x_next`` values: ``rms(x_hi - x_lo) / max(rms(x_hi), 1e-6)``.
    * Otherwise (or if the x-space evaluation raises, or the sampler kind is
      not recognized) the error is measured directly in epsilon space:
      ``rms(eps_hi - eps_lo) / max(rms(eps_hi), 1e-6)``.

    Guard rails evaluated before any prediction is made, in this order:
    protected first/last steps, at least 3 history entries, the
    ``max_consecutive_skips`` cap (read from ``skip_stats``), and periodic
    anchor steps every ``anchor_interval`` steps counted from the first
    unprotected step.

    Args:
        epsilon_history: Sequence of past epsilon tensors (most recent last).
        step_index: Zero-based index of the current step.
        total_steps: Total number of steps in the schedule.
        protect_last_steps: Trailing steps never skipped (floor 1, fallback 4).
        protect_first_steps: Leading steps never skipped (floor 0, fallback 2).
        tol_relative: Maximum relative error that still permits a skip.
        anchor_interval: Spacing of forced real calls; ``None`` means 4.
            A falsy or non-positive value disables anchors.
        max_consecutive_skips: Cap on back-to-back skips; ``None`` means 2.
        skip_stats: Optional dict; ``"consecutive_skips"`` is read from it.
            A missing or ``None`` counter is treated as 0.
        x_current, sigma_current, sigma_next, sampler_kind, sigma_previous:
            Optional sampler context enabling the x_next-space gate.

    Returns:
        ``(should_skip, epsilon_hat, meta)``. On a skip ``epsilon_hat`` is the
        higher-order prediction. ``meta`` holds either ``{"reason": ...}`` for
        a guard-rail rejection or ``relative_error`` / ``hi_order`` /
        ``lo_order`` / ``space`` for a gate evaluation.
    """
    # Best-effort coercion of the guard sizes; bad values fall back to defaults.
    try:
        pfs = max(0, int(protect_first_steps))
    except Exception:
        pfs = 2
    try:
        pls = max(1, int(protect_last_steps))
    except Exception:
        pls = 4
    if step_index < pfs or step_index >= total_steps - pls:
        return False, None, {"reason": "protected_region"}

    # The higher-order prediction is only attempted with 3 or more entries.
    if len(epsilon_history) < 3:
        return False, None, {"reason": "insufficient_history"}

    if anchor_interval is None:
        anchor_interval = 4
    if max_consecutive_skips is None:
        max_consecutive_skips = 2

    if isinstance(skip_stats, dict):
        # `or 0` also covers a counter that is present but still None,
        # which would otherwise raise TypeError in the comparison below.
        consecutive = skip_stats.get("consecutive_skips") or 0
        if consecutive >= max_consecutive_skips:
            return False, None, {"reason": "max_consecutive"}

    # Forced real call every `anchor_interval` steps, counted from the first
    # unprotected step (step_index >= pfs is guaranteed at this point).
    if anchor_interval and anchor_interval > 0:
        try:
            if (step_index - pfs) % int(anchor_interval) == 0:
                return False, None, {"reason": "anchor_abs"}
        except Exception:
            # e.g. a fractional interval truncating to 0; treated as "no anchor".
            pass

    eps_hat_hi = extrapolate_epsilon_richardson(epsilon_history)
    eps_hat_lo = extrapolate_epsilon_linear(epsilon_history)
    if eps_hat_hi is None or eps_hat_lo is None:
        return False, None, {"reason": "predict_failed"}

    if (not torch.isfinite(eps_hat_hi).all()) or (not torch.isfinite(eps_hat_lo).all()):
        return False, None, {"reason": "non_finite"}

    def _rms(t):
        # Root-mean-square in float32; an uncomputable value counts as infinite.
        try:
            return float(torch.sqrt(torch.mean((t.float()) ** 2)).item())
        except Exception:
            return float('inf')

    have_sampler_context = (
        x_current is not None
        and sigma_current is not None
        and sigma_next is not None
        and sampler_kind is not None
    )
    if have_sampler_context:
        try:
            dt = sigma_next - sigma_current
            x_next_hi = None
            x_next_lo = None

            if sampler_kind in ("euler", "res_2s", "dpmpp_2s"):
                # Single first-order step with derivative d = -eps / sigma.
                d_hi = -eps_hat_hi / sigma_current
                d_lo = -eps_hat_lo / sigma_current
                x_next_hi = x_current + dt * d_hi
                x_next_lo = x_current + dt * d_lo
            elif sampler_kind == "ddim":
                # Interpolate between the x0 estimate and x_current.
                # NOTE(review): uses x0 = x + eps; presumably epsilon is
                # stored with that sign convention -- confirm with the caller.
                scale = (sigma_next / sigma_current)
                x0_hi = x_current + eps_hat_hi
                x0_lo = x_current + eps_hat_lo
                x_next_hi = x0_hi + scale * (x_current - x0_hi)
                x_next_lo = x0_lo + scale * (x_current - x0_lo)
            elif sampler_kind in ("dpmpp_2m", "lms"):
                # Two-step form (1.5 * d - 0.5 * d_prev) when the previous
                # sigma is known, otherwise the single-step form.
                d_hi = -eps_hat_hi / sigma_current
                d_lo = -eps_hat_lo / sigma_current
                d_prev = None
                if sigma_previous is not None and len(epsilon_history) >= 1:
                    d_prev = -(epsilon_history[-1]) / sigma_previous
                if d_prev is not None:
                    x_next_hi = x_current + dt * (1.5 * d_hi - 0.5 * d_prev)
                    x_next_lo = x_current + dt * (1.5 * d_lo - 0.5 * d_prev)
                else:
                    x_next_hi = x_current + dt * d_hi
                    x_next_lo = x_current + dt * d_lo
            # Any other sampler kind leaves both None and falls through to
            # the epsilon-space gate below.

            if x_next_hi is not None and x_next_lo is not None:
                num = _rms(x_next_hi - x_next_lo)
                den = max(_rms(x_next_hi), 1e-6)
                rel_err = num / den
                meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "x_next"}
                if rel_err <= float(tol_relative):
                    return True, eps_hat_hi, meta
                return False, None, meta
        except Exception:
            # Deliberate best-effort: if the x-space evaluation cannot be
            # carried out, fall back to the epsilon-space gate.
            pass

    # Epsilon-space gate.
    num = _rms(eps_hat_hi - eps_hat_lo)
    den = max(_rms(eps_hat_hi), 1e-6)
    if not math.isfinite(num) or not math.isfinite(den) or den <= 0:
        return False, None, {"reason": "bad_metric"}
    rel_err = num / den
    meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "epsilon"}
    if rel_err <= float(tol_relative):
        return True, eps_hat_hi, meta
    return False, None, meta
|
|