File size: 14,694 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
import math
import torch
import re
from .extrapolation import (
    extrapolate_epsilon_linear,
    extrapolate_epsilon_richardson,
)


def _parse_hs_mode(skip_mode: str):
    """Decode a decoupled history/stride mode such as 'h2/s3'.

    The input is stripped and lower-cased before matching. Two spellings are
    understood:

      - Shorthand 'h2', 'h3', 'h4', which mean (2, 2), (3, 3) and (4, 4).
      - 'hN/sK', optionally with whitespace around the slash, where the pair
        (N, K) must satisfy 2 <= N <= 4 and N <= K <= 6. That is:
          h2/sK for K in {2,3,4,5,6}
          h3/sK for K in {3,4,5,6}
          h4/sK for K in {4,5,6}

    Returns:
        (history_order, skip_calls) as a tuple of ints for a recognized and
        allowed mode; None for anything else, including non-string input.
    """
    if not isinstance(skip_mode, str):
        return None

    mode = skip_mode.strip().lower()

    # Bare hN is shorthand for hN/sN.
    shorthand = {"h2": (2, 2), "h3": (3, 3), "h4": (4, 4)}
    if mode in shorthand:
        return shorthand[mode]

    # Only the first slash separates the two halves; any further slash ends
    # up in the stride text and fails the digit check below.
    head, slash, tail = mode.partition("/")
    if not slash:
        return None
    head = head.strip()
    tail = tail.strip()
    if not (head.startswith("h") and tail.startswith("s")):
        return None

    order_text = head[1:]
    stride_text = tail[1:]
    # str.isdigit() is False for the empty string, so a bare 'h' or 's' is rejected here.
    if not (order_text.isdigit() and stride_text.isdigit()):
        return None

    try:
        order = int(order_text)
        stride = int(stride_text)
    except ValueError:
        # Some characters (e.g. superscript digits) satisfy isdigit() but are
        # not accepted by int(); treat them as unrecognized.
        return None

    if 2 <= order <= 4 and order <= stride <= 6:
        return (order, stride)
    return None


def _normalize_skip_mode(skip_mode: str):
    """Map UI aliases and legacy names to the canonical internal mode string.

    Canonical results: "none" | "linear" | "richardson" | "h4" | "adaptive"

    Mapping (case-insensitive, surrounding whitespace ignored):
      - "h2", "linear"      -> "linear"
      - "h3", "richardson"  -> "richardson"
      - "h4", "quad"        -> "h4"
      - "adaptive"          -> "adaptive"

    Anything else, including non-string input, maps to "none".
    """
    if not isinstance(skip_mode, str):
        return "none"

    canonical = {
        "h2": "linear",
        "linear": "linear",
        "h3": "richardson",
        "richardson": "richardson",
        "h4": "h4",
        "quad": "h4",
        "adaptive": "adaptive",
    }
    # Strip as well as lower-case so padded values such as " linear " are not
    # silently treated as "none".
    return canonical.get(skip_mode.strip().lower(), "none")


def parse_skip_indices_config(text: str):
    """Parse an explicit skip-indices configuration string.

    Input examples:
      - "h2, 3, 4, 7, 9"
      - "3 6 8" (defaults to h2)
      - "h4, 10, 12" (first hN wins)

    Behavior:
      - Tokens are separated by commas and/or whitespace; matching is
        case-insensitive.
      - The first hN token seen while no predictor is chosen is consumed as a
        predictor token: N == 2 -> 'linear', N == 3 -> 'richardson',
        N >= 4 -> 'h4'. An hN token with N < 2 is dropped without choosing a
        predictor. hN tokens after a predictor is chosen are ignored.
      - Integer tokens (an optional sign is accepted) are collected into a
        set; values below 2 are filtered out here.
      - Tokens that cannot be converted are dropped rather than raising. This
        includes digit-like characters such as '²' that satisfy
        str.isdigit() but are rejected by int().

    Returns:
        (predictor, indices) where predictor is 'linear', 'richardson' or
        'h4' (default 'linear') and indices is a set of ints, empty when
        nothing was parsed. Non-string input yields ('linear', set()).
    """
    if not isinstance(text, str):
        return "linear", set()

    predictor = None
    indices = set()

    for raw in re.split(r"[\s,]+", text.strip()):
        token = raw.strip().lower()
        if not token:
            continue

        # Predictor selection (first hN wins).
        if predictor is None and token[0] == "h" and token[1:].isdigit():
            try:
                order = int(token[1:])
            except ValueError:
                # isdigit() accepted the text but int() did not: invalid token.
                continue
            if order >= 4:
                predictor = "h4"
            elif order == 3:
                predictor = "richardson"
            elif order == 2:
                predictor = "linear"
            continue

        # Integer step index.
        if token.lstrip("+-").isdigit():
            try:
                value = int(token)
            except ValueError:
                # e.g. "+-3" or digit-like characters int() cannot parse.
                continue
            if value >= 2:
                indices.add(value)

    if predictor is None:
        predictor = "linear"

    return predictor, indices


def should_skip_model_call(error_ratio, step_index, total_steps, skip_mode, epsilon_history, protect_last_steps=4, protect_first_steps=2):
    """Decide whether this step's model call is skipped, and with which predictor.

    Per the original note, this mirrors the logic used in sampling_engine.py,
    including the first/last step protections and the history lengths the
    linear/richardson patterns require.

    Args:
        error_ratio: Only read in "adaptive" mode, where it is compared
            against fixed bands around 1.0.
        step_index: Index of the current step.
        total_steps: Total number of steps; the last `protect_last_steps`
            indices are never skipped.
        skip_mode: Either a decoupled 'hN/sK' (or bare 'hN') string handled
            via _parse_hs_mode, or a legacy name handled via
            _normalize_skip_mode.
        epsilon_history: Sized container of past epsilons; only its length
            is used here.
        protect_last_steps: Tail guard; coerced with int(), falling back to 4
            on failure, and clamped to at least 1.
        protect_first_steps: Warm-up guard; coerced with int(), falling back
            to 2 on failure, and clamped to at least 0.

    Returns:
        (should_skip, predictor): predictor is "linear", "richardson" or "h4"
        when should_skip is True, otherwise None.
    """
    no_skip = (False, None)

    def _bounded_int(value, fallback, floor):
        # int() coercion with a fallback, then clamp from below.
        try:
            number = int(value)
        except Exception:
            number = fallback
        return max(number, floor)

    # Decoupled hN/sK modes are recognized first.
    decoupled = _parse_hs_mode(skip_mode)

    warmup = _bounded_int(protect_first_steps, 2, 0)  # never skip while history builds
    tail = _bounded_int(protect_last_steps, 4, 1)     # never skip the final steps

    # ---- Decoupled history/stride path ----
    if decoupled is not None:
        order, calls_per_cycle = decoupled
        if step_index < warmup or step_index >= total_steps - tail:
            return no_skip
        # Require enough epsilon history for the chosen order.
        if len(epsilon_history) < order:
            return no_skip
        # Cycle is K calls followed by one skip, phased from the later of the
        # warm-up boundary and the required history order.
        period = int(calls_per_cycle) + 1
        phase = (step_index - max(warmup, order)) % period
        if phase != period - 1:
            return no_skip
        if order >= 4:
            return True, "h4"
        if order == 3:
            return True, "richardson"
        return True, "linear"

    # ---- Legacy named modes ----
    mode = _normalize_skip_mode(skip_mode)
    if mode == "none":
        return no_skip

    if step_index < warmup or step_index >= total_steps - tail:
        return no_skip

    history_len = len(epsilon_history)
    if history_len < 2:
        return no_skip

    if mode == "h4":
        # Call 4, skip 1; no skip before the anchor step.
        start = max(warmup, 4)
        if step_index >= start and (step_index - start) % 5 == 4:
            return True, "h4"
        return no_skip

    if mode == "linear":
        # Call, call, skip.
        if (step_index - warmup) % 3 == 2:
            return True, "linear"
        return no_skip

    if mode == "richardson":
        # Call, call, call, skip; needs three history entries.
        # The anchor is warmup + 1, so the default warm-up of 2 anchors at 3.
        if history_len >= 3 and (step_index - (warmup + 1)) % 4 == 3:
            return True, "richardson"
        return no_skip

    if mode == "adaptive":
        # Pattern-gated bands on error_ratio (legacy behavior). The offsets
        # 2 and 3 are fixed and do not follow protect_first_steps.
        if (step_index - 2) % 3 == 2 and 0.97 <= error_ratio <= 1.03 and history_len >= 2:
            return True, "linear"
        if (step_index - 3) % 4 == 3 and 0.99 <= error_ratio <= 1.01 and history_len >= 3:
            return True, "richardson"
        return no_skip

    return no_skip


def validate_epsilon_hat(eps_hat, prev_eps=None, min_abs=1e-8, min_rel=1e-6):
    """Sanity-check an extrapolated epsilon before it is used for a skip step.

    Args:
        eps_hat: Extrapolated epsilon tensor, or None.
        prev_eps: Optional reference tensor for the relative-magnitude check.
        min_abs: Minimum accepted norm of eps_hat.
        min_rel: Minimum accepted ratio of eps_hat's norm to prev_eps's norm.

    Returns:
        (ok, reason, hat_norm, prev_norm). reason is '' on success, otherwise
        one of 'none', 'nan_inf', 'too_small_abs', 'too_small_rel'.
        prev_norm is only populated once the relative check has been reached
        and the norm of prev_eps could be computed; it is None otherwise.
    """
    if eps_hat is None:
        return False, 'none', 0.0, None

    try:
        all_finite = bool(torch.isfinite(eps_hat).all())
        if not all_finite:
            return False, 'nan_inf', float('nan'), None
        hat_norm = torch.norm(eps_hat).item()
    except Exception:
        # Anything torch cannot evaluate is reported like a non-finite value.
        return False, 'nan_inf', float('nan'), None

    # The elements may be finite while the norm itself is not.
    if not math.isfinite(hat_norm):
        return False, 'nan_inf', hat_norm, None
    if hat_norm < min_abs:
        return False, 'too_small_abs', hat_norm, None

    if prev_eps is None:
        return True, '', hat_norm, None

    try:
        prev_norm = torch.norm(prev_eps).item()
    except Exception:
        prev_norm = None

    # The relative check applies only when the reference norm is known and positive.
    collapsed = (
        prev_norm is not None
        and prev_norm > 0
        and hat_norm < (min_rel * prev_norm)
    )
    if collapsed:
        return False, 'too_small_rel', hat_norm, prev_norm
    return True, '', hat_norm, prev_norm


def _predict_next_states(
    eps_hat_hi,
    eps_hat_lo,
    epsilon_history,
    x_current,
    sigma_current,
    sigma_next,
    sampler_kind,
    sigma_previous,
):
    """Apply one sampler update to x_current with each epsilon prediction.

    Returns (x_next_hi, x_next_lo), or None when sampler_kind is not one of
    the kinds handled below.
    """
    dt = sigma_next - sigma_current

    if sampler_kind in ("euler", "res_2s", "dpmpp_2s"):
        # Euler-like update
        d_hi = -eps_hat_hi / sigma_current
        d_lo = -eps_hat_lo / sigma_current
        return x_current + dt * d_hi, x_current + dt * d_lo

    if sampler_kind == "ddim":
        # x_next = x0 + (sigma_next/sigma_current)*(x - x0), x0 = x + eps
        scale = sigma_next / sigma_current
        x0_hi = x_current + eps_hat_hi
        x0_lo = x_current + eps_hat_lo
        return (
            x0_hi + scale * (x_current - x0_hi),
            x0_lo + scale * (x_current - x0_lo),
        )

    if sampler_kind in ("dpmpp_2m", "lms"):
        # AB2-style with optional previous derivative from the last history entry
        d_hi = -eps_hat_hi / sigma_current
        d_lo = -eps_hat_lo / sigma_current
        if sigma_previous is not None and len(epsilon_history) >= 1:
            d_prev = -(epsilon_history[-1]) / sigma_previous
            return (
                x_current + dt * (1.5 * d_hi - 0.5 * d_prev),
                x_current + dt * (1.5 * d_lo - 0.5 * d_prev),
            )
        return x_current + dt * d_hi, x_current + dt * d_lo

    return None


def decide_skip_adaptive(
    epsilon_history,
    step_index,
    total_steps,
    protect_last_steps=4,
    protect_first_steps=2,
    tol_relative=0.10,
    anchor_interval=None,
    max_consecutive_skips=None,
    skip_stats=None,
    # Optional: predicted-state (x_next) gating context
    x_current=None,
    sigma_current=None,
    sigma_next=None,
    sampler_kind=None,
    sigma_previous=None,
):
    """Adaptive skip decision using a dual-order gate.

    Two predictions are built from the epsilon history: a higher-order one
    (extrapolate_epsilon_richardson, reported as order 3) and a lower-order
    one (extrapolate_epsilon_linear, reported as order 2). Their disagreement
    is turned into a relative error, and the step may be skipped when that
    error is within `tol_relative` and the guard rails permit.

    The relative error is RMS(hi - lo) / max(RMS(hi), 1e-6). It is measured
    on the predicted next state when x_current, sigma_current, sigma_next and
    sampler_kind are all given and the sampler kind is handled; otherwise
    (or if that computation raises) it is measured directly on the epsilon
    predictions.

    Args:
        epsilon_history: Sequence of past epsilons; at least 3 are required.
        step_index: Index of the current step.
        total_steps: Total number of steps.
        protect_last_steps: Tail guard (int-coerced, at least 1, fallback 4).
        protect_first_steps: Warm-up guard (int-coerced, at least 0, fallback 2).
        tol_relative: Maximum relative error that still allows a skip.
        anchor_interval: Every Nth step counted from protect_first_steps is
            never skipped; defaults to 4. Falsy or non-positive disables it.
        max_consecutive_skips: Cap on back-to-back skips; defaults to 2.
        skip_stats: Optional dict; its "consecutive_skips" entry (default 0)
            is compared against the cap.
        x_current, sigma_current, sigma_next, sampler_kind, sigma_previous:
            Optional context for predicted-state gating.

    Returns:
        (should_skip, epsilon_hat, meta). epsilon_hat is the higher-order
        prediction when should_skip is True, else None. meta carries either
        a "reason" key (guard rail or failure) or "relative_error",
        "hi_order", "lo_order" and "space" ("x_next" or "epsilon").
    """
    # Guard: skip disabled near start/end
    try:
        pfs = max(0, int(protect_first_steps))
    except Exception:
        pfs = 2
    try:
        pls = max(1, int(protect_last_steps))
    except Exception:
        pls = 4
    if step_index < pfs or step_index >= total_steps - pls:
        return False, None, {"reason": "protected_region"}

    # History requirement (the higher-order prediction needs >= 3 entries)
    if len(epsilon_history) < 3:
        return False, None, {"reason": "insufficient_history"}

    # Defaults for aggressive profile if not provided
    if anchor_interval is None:
        anchor_interval = 4
    if max_consecutive_skips is None:
        max_consecutive_skips = 2

    # Local consecutive cap
    if isinstance(skip_stats, dict):
        consec = skip_stats.get("consecutive_skips", 0)
        if consec >= max_consecutive_skips:
            return False, None, {"reason": "max_consecutive"}

    # Absolute anchor cadence: never skip every Nth index counted from pfs.
    if anchor_interval and anchor_interval > 0 and step_index >= pfs:
        try:
            if (step_index - pfs) % int(anchor_interval) == 0:
                return False, None, {"reason": "anchor_abs"}
        except Exception:
            # Best effort: an interval that cannot be applied (e.g. it
            # truncates to 0) simply disables the anchor rule.
            pass

    # Build predictions
    eps_hat_hi = extrapolate_epsilon_richardson(epsilon_history)
    eps_hat_lo = extrapolate_epsilon_linear(epsilon_history)
    if eps_hat_hi is None or eps_hat_lo is None:
        return False, None, {"reason": "predict_failed"}

    # Finite checks
    if (not torch.isfinite(eps_hat_hi).all()) or (not torch.isfinite(eps_hat_lo).all()):
        return False, None, {"reason": "non_finite"}

    def _rms(t):
        # Root-mean-square as a Python float; inf when it cannot be computed.
        try:
            return float(torch.sqrt(torch.mean((t.float()) ** 2)).item())
        except Exception:
            return float('inf')

    # Prefer predicted-state (x_next) gating when context is provided
    if x_current is not None and sigma_current is not None and sigma_next is not None and sampler_kind is not None:
        try:
            states = _predict_next_states(
                eps_hat_hi,
                eps_hat_lo,
                epsilon_history,
                x_current,
                sigma_current,
                sigma_next,
                sampler_kind,
                sigma_previous,
            )
            if states is not None:
                x_next_hi, x_next_lo = states
                num = _rms(x_next_hi - x_next_lo)
                den = max(_rms(x_next_hi), 1e-6)
                rel_err = num / den
                meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "x_next"}
                within_tol = rel_err <= float(tol_relative)
                # An infinite denominator makes rel_err 0; never approve a
                # skip on a metric whose inputs are not finite.
                if within_tol and math.isfinite(num) and math.isfinite(den):
                    return True, eps_hat_hi, meta
                return False, None, meta
        except Exception:
            # Any failure in the predicted-state path falls through to the
            # epsilon-space metric below.
            pass

    # Epsilon-space fallback
    diff = eps_hat_hi - eps_hat_lo
    num = _rms(diff)
    den = max(_rms(eps_hat_hi), 1e-6)
    if not math.isfinite(num) or not math.isfinite(den) or den <= 0:
        return False, None, {"reason": "bad_metric"}
    rel_err = num / den
    meta = {"relative_error": rel_err, "hi_order": 3, "lo_order": 2, "space": "epsilon"}
    if rel_err <= float(tol_relative):
        return True, eps_hat_hi, meta
    return False, None, meta