# File size: 37,264 Bytes
# c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
from .engine import sample_fsampler


def sample_step_euler(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args,
                      epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                      step_index, total_steps, skip_mode="none", skip_stats=None, debug=False):
    """Standard Euler step using the Karras ODE derivative formulation.

    Implements the standard k-diffusion Euler method:
    - Converts denoised to ODE derivative: d = (x - denoised) / sigma
    - Takes Euler step: x = x + d * dt, where dt = sigma_next - sigma_current

    Supports model call skipping via epsilon extrapolation. When
    ``should_skip_model_call`` allows it and the extrapolation is usable,
    denoised is reconstructed as ``x + epsilon_hat``. Here epsilon_hat is divided
    by ``learning_ratio`` once at least 3 real epsilons are recorded. In that
    case no model call is made.

    Args:
        model: Denoiser called as ``model(x, sigma_current * s_in, **extra_args)``.
        noisy_latent: Current latent x.
        sigma_current, sigma_next: Noise levels for this step (tensor or float).
        s_in: Multiplier applied to ``sigma_current`` for the model call.
        extra_args: Extra keyword arguments forwarded to ``model``.
        epsilon_history: List of REAL epsilons (denoised - x). Only real model
            calls append to it, and it is mutated in place.
        learning_ratio: Current learning ratio L used to rescale extrapolations.
        smoothing_beta: EMA coefficient for the learning-ratio update.
        predictor_type: "richardson" selects Richardson extrapolation for the
            learning update; any other value uses linear extrapolation.
        step_index, total_steps: Schedule position, forwarded to the skip decision.
        skip_mode: Forwarded to ``should_skip_model_call``.
        skip_stats: Optional dict whose "total_steps", "skipped" and
            "model_calls" counters are incremented in place.
        debug: Print per-step diagnostics.

    Returns:
        (x_next, learning_ratio)
    """
    x = noisy_latent
    # Euler step size in sigma space (negative while denoising).
    dt = sigma_next - sigma_current

    # Update skip statistics
    if skip_stats is not None:
        skip_stats["total_steps"] += 1

    # Check if we should skip the model call
    should_skip, skip_method = should_skip_model_call(
        1.0,  # error_ratio - Euler doesn't track this, use neutral value
        step_index,
        total_steps,
        skip_mode,
        epsilon_history
    )

    # Get denoised: either from model call or extrapolation
    was_skipped = False

    if should_skip and skip_method is not None:
        # SKIP: Use epsilon extrapolation
        if skip_method == "linear":
            epsilon = extrapolate_epsilon_linear(epsilon_history)
        elif skip_method == "richardson":
            epsilon = extrapolate_epsilon_richardson(epsilon_history)
        else:
            epsilon = None

        # Safety check: if extrapolation failed, fall back to model call
        if epsilon is None or torch.isnan(epsilon).any():
            should_skip = False
            if debug:
                print(f"euler step {step_index}: extrapolation failed, falling back to model call")

        if should_skip and epsilon is not None:
            # Successful skip - reconstruct denoised from extrapolated epsilon
            # Apply universal learning stabilizer if we have enough REAL history (>=3)
            if len(epsilon_history) >= 3:
                epsilon = epsilon / max(learning_ratio, 1e-8)
            denoised = x + epsilon
            was_skipped = True
            if skip_stats is not None:
                skip_stats["skipped"] += 1
            if debug:
                e_norm = torch.norm(epsilon).item()
                # float() accepts both Python floats and one-element tensors.
                print(f"euler step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}, dt={float(dt):.4f}")

    if not was_skipped:
        # CALL MODEL: Normal path. Gating on was_skipped (not should_skip) also
        # covers a skip decision without a method, which would otherwise leave
        # `denoised` unbound.
        denoised = model(x, sigma_current * s_in, **extra_args)
        if skip_stats is not None:
            skip_stats["model_calls"] += 1

    # Karras ODE derivative: d = (x - denoised) / sigma
    # This is the standard k-diffusion formulation
    d = (x - denoised) / sigma_current

    # Euler step in sigma space: x = x + d * dt
    x = x + d * dt

    # Store REAL epsilon for extrapolation/learning (append full history for this run)
    if not was_skipped:
        epsilon = denoised - noisy_latent
        epsilon_history.append(epsilon)
        # Universal learning update only when enough REAL history exists (>=3)
        if len(epsilon_history) >= 3:
            # Compute predictor-matched epsilon_hat from REAL history
            if predictor_type == "richardson":
                epsilon_hat = extrapolate_epsilon_richardson(epsilon_history)
            else:
                epsilon_hat = extrapolate_epsilon_linear(epsilon_history)
            if epsilon_hat is not None:
                learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon) + 1e-8)).item()
                # EMA update with smoothing_beta, then clamp to [0.5, 2.0]
                learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs
                if learning_ratio < 0.5:
                    learning_ratio = 0.5
                elif learning_ratio > 2.0:
                    learning_ratio = 2.0

    if debug and not was_skipped:
        d_norm = torch.norm(d).item()
        e_norm = torch.norm(epsilon).item()
        # float(dt) instead of dt.item(): dt is a plain float when sigmas are floats.
        if len(epsilon_history) >= 3:
            print(f"euler step {step_index}: e_norm={e_norm:.2f}, d_norm={d_norm:.2f}, dt={float(dt):.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}")
        else:
            print(f"euler step {step_index}: e_norm={e_norm:.2f}, d_norm={d_norm:.2f}, dt={float(dt):.4f}")
    return x, learning_ratio


def sample_step_ddim(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args,
                     epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                     step_index, total_steps, skip_mode="none", debug=False):
    """Deterministic DDIM step (eta = 0) with optional extrapolated skipping.

    Every path moves x along the line towards a clean estimate x0:
        x_next = x0 + (sigma_next / sigma_current) * (x - x0)

    Real step: x0 is the model's denoised output. The real epsilon
    (denoised - x) is appended to ``epsilon_history``. Once at least three
    real epsilons are stored, ``learning_ratio`` is EMA-updated with
    ``smoothing_beta`` and clamped to [0.5, 2.0].

    Skipped step: x0 is rebuilt as ``x + epsilon_hat`` from the extrapolated
    history. Here epsilon_hat is divided by ``learning_ratio`` once at least 3
    real epsilons exist. ``epsilon_history`` is left untouched.

    Returns:
        (x_next, learning_ratio)
    """
    def _extrapolate(kind):
        # "richardson" picks Richardson extrapolation; every other value is linear.
        if kind == "richardson":
            return extrapolate_epsilon_richardson(epsilon_history)
        return extrapolate_epsilon_linear(epsilon_history)

    latent = noisy_latent
    skip_requested, method = should_skip_model_call(
        1.0, step_index, total_steps, skip_mode, epsilon_history
    )

    if skip_requested and method is not None:
        eps_hat = _extrapolate(method)
        if eps_hat is not None and not torch.isnan(eps_hat).any():
            if len(epsilon_history) >= 3:
                eps_hat = eps_hat / max(learning_ratio, 1e-8)
            x0_hat = latent + eps_hat
            latent = x0_hat + (sigma_next / sigma_current) * (latent - x0_hat)
            if debug:
                print(f"ddim step {step_index} [SKIPPED-{method}]: e_norm={torch.norm(eps_hat).item():.2f}, L={learning_ratio:.4f}")
        else:
            # Extrapolation missing or NaN: fall through to a real model call.
            skip_requested = False

    if skip_requested:
        # A skip request is still standing here. Either the extrapolated update
        # above was applied, or no method was given (latent left as-is).
        # Neither case calls the model.
        return latent, learning_ratio

    denoised = model(latent, sigma_current * s_in, **extra_args)
    latent = denoised + (sigma_next / sigma_current) * (latent - denoised)

    eps_real = denoised - noisy_latent
    epsilon_history.append(eps_real)
    if len(epsilon_history) < 3:
        return latent, learning_ratio

    eps_hat = _extrapolate(predictor_type)
    if eps_hat is None:
        return latent, learning_ratio

    learn_obs = (torch.norm(eps_hat) / (torch.norm(eps_real) + 1e-8)).item()
    learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs
    if learning_ratio < 0.5:
        learning_ratio = 0.5
    elif learning_ratio > 2.0:
        learning_ratio = 2.0
    if debug:
        print(f"ddim step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}")
    return latent, learning_ratio


def sample_step_dpmpp_2m(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args,
                         epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                         step_index, total_steps, skip_mode="none", skip_stats=None, debug=False):
    """Two-step multistep update labelled "DPM++ 2M", with learning + skip.

    The derivative is d = (x - denoised) / sigma = -epsilon / sigma. The update
    uses fixed Adams-Bashforth coefficients in sigma space:
        x_next = x + dt * (1.5 * d_n - 0.5 * d_{n-1}),  dt = sigma_next - sigma_current
    It falls back to a plain Euler step ``x + dt * d_n`` when no previous
    derivative is available, i.e. when ``sigma_previous`` is None or
    ``epsilon_history`` is empty.

    On a skip, d_n comes from the extrapolated epsilon_hat (divided by
    ``learning_ratio`` once >= 3 real epsilons exist). The function then returns
    immediately without touching ``epsilon_history``. On a real call the epsilon
    is appended and ``learning_ratio`` is EMA-updated and clamped to [0.5, 2.0].

    NOTE(review): d_{n-1} is rebuilt from ``epsilon_history[-1] / sigma_previous``.
    Skipped steps never append to that history, so after a skip the entry
    presumably belongs to an earlier sigma than ``sigma_previous``. Confirm
    this against the caller.

    Returns:
        (x_next, learning_ratio)
    """
    def _extrapolate(kind):
        # "richardson" -> Richardson extrapolation, anything else -> linear.
        if kind == "richardson":
            return extrapolate_epsilon_richardson(epsilon_history)
        return extrapolate_epsilon_linear(epsilon_history)

    def _advance(latent, deriv, prev):
        # AB2 when a previous derivative exists, otherwise a single Euler step.
        step = sigma_next - sigma_current
        if prev is None:
            return latent + step * deriv
        return latent + step * (1.5 * deriv - 0.5 * prev)

    latent = noisy_latent
    if skip_stats is not None:
        skip_stats["total_steps"] += 1

    skip_requested, method = should_skip_model_call(1.0, step_index, total_steps, skip_mode, epsilon_history)

    has_prev = sigma_previous is not None and len(epsilon_history) >= 1
    prev_deriv = -(epsilon_history[-1]) / sigma_previous if has_prev else None

    if skip_requested and method is not None:
        eps_hat = _extrapolate(method)
        if eps_hat is not None and not torch.isnan(eps_hat).any():
            if len(epsilon_history) >= 3:
                eps_hat = eps_hat / max(learning_ratio, 1e-8)
            deriv = -(eps_hat) / sigma_current
            latent = _advance(latent, deriv, prev_deriv)
            if skip_stats is not None:
                skip_stats["skipped"] += 1
            if debug:
                print(f"dpmpp_2m step {step_index} [SKIPPED-{method}]: d_norm={torch.norm(deriv).item():.2f}, L={learning_ratio:.4f}")
            return latent, learning_ratio

    # Real model evaluation (also reached when extrapolation was unusable).
    denoised = model(latent, sigma_current * s_in, **extra_args)
    eps_real = denoised - latent
    deriv = -eps_real / sigma_current
    latent = _advance(latent, deriv, prev_deriv)
    if skip_stats is not None:
        skip_stats["model_calls"] += 1

    epsilon_history.append(eps_real)
    if len(epsilon_history) >= 3:
        eps_hat = _extrapolate(predictor_type)
        if eps_hat is not None:
            learn_obs = (torch.norm(eps_hat) / (torch.norm(eps_real) + 1e-8)).item()
            learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs
            if learning_ratio < 0.5:
                learning_ratio = 0.5
            elif learning_ratio > 2.0:
                learning_ratio = 2.0
            if debug:
                print(f"dpmpp_2m step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}")

    if debug:
        mode = "AB2" if prev_deriv is not None else "Euler"
        print(f"dpmpp_2m step {step_index}: d_norm={torch.norm(deriv).item():.2f}, {mode}")

    return latent, learning_ratio


def sample_step_dpmpp_2s(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args,
                         epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                         step_index, total_steps, skip_mode="none", debug=False):
    """Two-stage predictor/corrector step labelled "DPM++ 2S", with learning + skip.

    Real step (two model evaluations), with dt = sigma_next - sigma_current:
        d1     = (x - model(x, sigma_current)) / sigma_current
        x_pred = x + dt * d1
        d2     = (x_pred - model(x_pred, sigma_next)) / sigma_next
        x_next = x + dt * 0.5 * (d1 + d2)
    This is a Heun-style average of the two derivatives in sigma space.

    Final step (|sigma_next| <= 1e-8): one model call and the denoised output is
    returned directly, which avoids dividing by sigma_next.

    Skip: a single Euler-like update with epsilon_hat. epsilon_hat is divided by
    ``learning_ratio`` once >= 3 real epsilons exist, and no history is appended.

    The learning update only ever uses the stage-1 epsilon (den1 - x).

    Returns:
        (x_next, learning_ratio)
    """
    def _extrapolate(kind):
        # "richardson" -> Richardson extrapolation, anything else -> linear.
        if kind == "richardson":
            return extrapolate_epsilon_richardson(epsilon_history)
        return extrapolate_epsilon_linear(epsilon_history)

    def _learn(eps_real, ratio, label):
        # Append a REAL epsilon, then EMA-update and clamp the learning ratio.
        epsilon_history.append(eps_real)
        if len(epsilon_history) < 3:
            return ratio
        eps_hat = _extrapolate(predictor_type)
        if eps_hat is None:
            return ratio
        learn_obs = (torch.norm(eps_hat) / (torch.norm(eps_real) + 1e-8)).item()
        ratio = smoothing_beta * ratio + (1.0 - smoothing_beta) * learn_obs
        if ratio < 0.5:
            ratio = 0.5
        elif ratio > 2.0:
            ratio = 2.0
        if debug:
            print(f"{label} [LEARN]: learn_obs={learn_obs:.4f}, L={ratio:.4f}, beta={smoothing_beta}")
        return ratio

    latent = noisy_latent
    sigma_next_scalar = sigma_next.item() if torch.is_tensor(sigma_next) else float(sigma_next)

    if abs(sigma_next_scalar) <= 1e-8:
        # Final step: land exactly on the model's denoised prediction.
        denoised = model(latent, sigma_current * s_in, **extra_args)
        learning_ratio = _learn(denoised - noisy_latent, learning_ratio,
                                f"dpmpp_2s step {step_index} (final)")
        if debug:
            print(f"dpmpp_2s step {step_index} (final step): landing on denoised")
        return denoised, learning_ratio

    skip_requested, method = should_skip_model_call(1.0, step_index, total_steps, skip_mode, epsilon_history)

    if skip_requested and method is not None:
        eps_hat = _extrapolate(method)
        if eps_hat is not None and not torch.isnan(eps_hat).any():
            if len(epsilon_history) >= 3:
                eps_hat = eps_hat / max(learning_ratio, 1e-8)
            deriv = -(eps_hat) / sigma_current
            latent = latent + (sigma_next - sigma_current) * deriv
            if debug:
                print(f"dpmpp_2s step {step_index} [SKIPPED-{method}]: e_norm={torch.norm(eps_hat).item():.2f}, L={learning_ratio:.4f}")
            return latent, learning_ratio

    # Two real evaluations: predictor at sigma_current, corrector at sigma_next.
    step = sigma_next - sigma_current
    den1 = model(latent, sigma_current * s_in, **extra_args)
    d1 = (latent - den1) / sigma_current
    x_pred = latent + step * d1
    den2 = model(x_pred, sigma_next * s_in, **extra_args)
    d2 = (x_pred - den2) / sigma_next
    latent = latent + step * 0.5 * (d1 + d2)

    learning_ratio = _learn(den1 - noisy_latent, learning_ratio, f"dpmpp_2s step {step_index}")

    if debug:
        print(f"dpmpp_2s step {step_index}: d1_norm={torch.norm(d1).item():.2f}, d2_norm={torch.norm(d2).item():.2f}")

    return latent, learning_ratio


def _ab2_update(x, dt, d_curr, d_prev=None):
    if d_prev is not None:
        return x + dt * (1.5 * d_curr - 0.5 * d_prev)
    else:
        return x + dt * d_curr


def sample_step_lms(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args,
                    epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                    step_index, total_steps, skip_mode="none", skip_stats=None, debug=False):
    """LMS step (AB2 baseline) with learning + skip.

    Uses d = (x - denoised) / sigma = -epsilon / sigma and delegates the update
    to ``_ab2_update``, with dt = sigma_next - sigma_current:
      - AB2 ``x + dt * (1.5 * d_n - 0.5 * d_{n-1})`` when a previous derivative exists.
      - Euler ``x + dt * d_n`` otherwise.

    d_{n-1} is taken as ``-epsilon_history[-1] / sigma_previous``. Skipped steps
    return early and never append to ``epsilon_history``. NOTE(review): after a
    skip that entry presumably belongs to an older sigma, so confirm this
    against the caller.

    Returns:
        (x_next, learning_ratio)
    """
    latent = noisy_latent
    if skip_stats is not None:
        skip_stats["total_steps"] += 1

    skip_requested, method = should_skip_model_call(1.0, step_index, total_steps, skip_mode, epsilon_history)

    prev_d = None
    if sigma_previous is not None and len(epsilon_history) > 0:
        prev_d = -(epsilon_history[-1]) / sigma_previous

    step = sigma_next - sigma_current

    if skip_requested and method is not None:
        predict = extrapolate_epsilon_richardson if method == "richardson" else extrapolate_epsilon_linear
        eps_hat = predict(epsilon_history)
        if eps_hat is not None and not torch.isnan(eps_hat).any():
            if len(epsilon_history) >= 3:
                eps_hat = eps_hat / max(learning_ratio, 1e-8)
            d_now = -eps_hat / sigma_current
            latent = _ab2_update(latent, step, d_now, prev_d)
            if skip_stats is not None:
                skip_stats["skipped"] += 1
            if debug:
                print(f"lms step {step_index} [SKIPPED-{method}]: d_norm={torch.norm(d_now).item():.2f}, L={learning_ratio:.4f}")
            return latent, learning_ratio

    # Real model evaluation (also reached when extrapolation was unusable).
    denoised = model(latent, sigma_current * s_in, **extra_args)
    eps_real = denoised - latent
    d_now = -eps_real / sigma_current
    latent = _ab2_update(latent, step, d_now, prev_d)
    if skip_stats is not None:
        skip_stats["model_calls"] += 1

    epsilon_history.append(eps_real)
    if len(epsilon_history) >= 3:
        predict = extrapolate_epsilon_richardson if predictor_type == "richardson" else extrapolate_epsilon_linear
        eps_hat = predict(epsilon_history)
        if eps_hat is not None:
            learn_obs = (torch.norm(eps_hat) / (torch.norm(eps_real) + 1e-8)).item()
            learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs
            if learning_ratio < 0.5:
                learning_ratio = 0.5
            elif learning_ratio > 2.0:
                learning_ratio = 2.0
            if debug:
                print(f"lms step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}")

    if debug:
        suffix = ", AB2" if prev_d is not None else ", Euler"
        print(f"lms step {step_index}: d_norm={torch.norm(d_now).item():.2f}{suffix}")

    return latent, learning_ratio


def sample_step_plms(model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args,
                     epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                     step_index, total_steps, skip_mode="none", skip_stats=None, debug=False):
    """PLMS step; currently a straight pass-through to ``sample_step_lms`` (AB2).

    A full 4-step PLMS (PNDM) update would need the sigmas of several earlier
    steps, which are not threaded through yet. Until then every argument is
    forwarded positionally to the LMS step.

    ``sample_step_lms`` is resolved at call time, so this delegates to whatever
    that module-level name is bound to when the function runs.

    Returns:
        Whatever ``sample_step_lms`` returns.
    """
    forwarded = (
        model, noisy_latent, sigma_current, sigma_next, sigma_previous, s_in, extra_args,
        epsilon_history, learning_ratio, smoothing_beta, predictor_type,
        step_index, total_steps, skip_mode, skip_stats, debug,
    )
    return sample_step_lms(*forwarded)

# Rebind local names to refactored implementations (ensures imports take precedence)
# These assignments replace the local sample_step_euler / _ddim / _dpmpp_2m /
# _dpmpp_2s / _lms definitions above with the .samplers.* versions. Callers that
# look the name up at call time (e.g. sample_step_plms -> sample_step_lms) get
# the imported implementation.
# NOTE(review): a module-level assignment only wins over definitions that come
# *before* it. sample_step_res_2m and sample_step_res_2s are defined again with
# `def` further down this module. Unless they are rebound once more later in the
# file, those local versions replace _res2m_impl / _res2s_impl bound here. Move
# this block to the end of the module if the refactored implementations are
# meant to take precedence.
from .samplers.euler import sample_step_euler as _euler_impl
from .samplers.res2m import sample_step_res_2m as _res2m_impl
from .samplers.res2s import sample_step_res_2s as _res2s_impl
from .samplers.ddim import sample_step_ddim as _ddim_impl
from .samplers.dpmpp_2m import sample_step_dpmpp_2m as _dpmpp2m_impl
from .samplers.dpmpp_2s import sample_step_dpmpp_2s as _dpmpp2s_impl
from .samplers.lms import sample_step_lms as _lms_impl

sample_step_euler = _euler_impl
sample_step_res_2m = _res2m_impl  # overridden by the later `def sample_step_res_2m` (see NOTE above)
sample_step_res_2s = _res2s_impl  # overridden by the later `def sample_step_res_2s` (see NOTE above)
sample_step_ddim = _ddim_impl
sample_step_dpmpp_2m = _dpmpp2m_impl
sample_step_dpmpp_2s = _dpmpp2s_impl
sample_step_lms = _lms_impl


def sample_step_res_2m(model, noisy_latent, sigma_current, sigma_next, sigma_previous,
                       s_in, extra_args, error_history, epsilon_history, prev_was_skipped, step_index, total_steps,
                       adaptive_mode="none", smoothing_beta=0.9, smoothed_error_ratio=1.0,
                       learning_ratio=1.0, predictor_type="linear",
                       skip_mode="none", skip_stats=None, debug=False):
    """res_2m: 2-multistep exponential integrator using history from previous steps.

    Matches RES4LYF implementation:
    - Stores denoised predictions in history (not epsilon directly)
    - Recomputes epsilon from stored denoised each step
    - Uses c2 = (-h_prev / h) for multistep coefficients

    The update uses h = -log(sigma_next / sigma_current) and
    epsilon = denoised - x_0. The multistep update is
        x = x_0 + h * (b1 * eps_curr + b2 * eps_prev),  b2 = phi_2 / c2,  b1 = phi_1 - b2.
    When no previous denoised prediction is available, the step falls back to a
    first-order (Euler) update. On the final step (sigma_next == 0) it returns
    denoised.

    Args:
        model: Denoiser called as ``model(x, sigma_current * s_in, **extra_args)``.
        noisy_latent: Current latent x_0.
        sigma_current, sigma_next, sigma_previous: Noise levels. sigma_previous
            may be None (first step). Their ratio is passed to ``torch.log``, so
            tensors appear to be required here.
        s_in: Multiplier applied to sigma_current for the model call.
        extra_args: Extra keyword arguments forwarded to ``model``.
        error_history: Rolling list (at most 2 entries) of DENOISED predictions
            from previous steps, despite its name. Skipped steps are included,
            and the list is mutated in place.
        epsilon_history: REAL epsilons only (denoised - x_0 from model calls).
            Mutated in place and used for extrapolation and learning.
        prev_was_skipped: Only used to label the Euler fallback in debug output.
        step_index, total_steps: Schedule position, forwarded to the skip decision.
        adaptive_mode: Controls the phi-weights:
            - "none" uses baseline phi-weights.
            - "learning" rescales them by an EMA-smoothed error ratio.
            - Other values compute the error ratio but apply no adjustment.
        smoothing_beta: EMA coefficient for the error ratio and the learning ratio.
        smoothed_error_ratio: Previous smoothed error ratio (also passed to the
            skip decision).
        learning_ratio: Current learning ratio L. Extrapolated epsilons are
            divided by it once >= 3 real epsilons exist.
        predictor_type: "richardson" or (anything else) linear extrapolation for
            the learning update.
        skip_mode: Forwarded to ``should_skip_model_call``.
        skip_stats: Optional dict of "total_steps"/"skipped"/"model_calls" counters.
        debug: Print per-step diagnostics.

    Returns:
        (x_next, smoothed_error_ratio_next, learning_ratio, was_skipped)
    """
    x_0 = noisy_latent  # Starting point for this step
    # Set only on the multistep path; read by the combined learning debug print.
    error_ratio = None

    # Update skip statistics
    if skip_stats is not None:
        skip_stats["total_steps"] += 1

    # Check if we should skip the model call
    should_skip, skip_method = should_skip_model_call(
        smoothed_error_ratio, step_index, total_steps, skip_mode, epsilon_history
    )

    # Get epsilon: either from model call or extrapolation
    was_skipped = False  # Track if this step used extrapolation

    if should_skip and skip_method is not None:
        # SKIP: Use extrapolation
        if skip_method == "linear":
            epsilon_current = extrapolate_epsilon_linear(epsilon_history)
        elif skip_method == "richardson":
            epsilon_current = extrapolate_epsilon_richardson(epsilon_history)
        else:
            epsilon_current = None

        # Safety check: if extrapolation failed, fall back to model call
        if epsilon_current is None or torch.isnan(epsilon_current).any():
            should_skip = False
            if debug:
                print(f"res_2m step {step_index}: extrapolation failed, falling back to model call")

        if should_skip and epsilon_current is not None:
            # Successful skip - reconstruct denoised from extrapolated epsilon
            if len(epsilon_history) >= 3:
                epsilon_current = epsilon_current / max(learning_ratio, 1e-8)
            denoised = x_0 + epsilon_current
            was_skipped = True
            if skip_stats is not None:
                skip_stats["skipped"] += 1
            if debug:
                e_norm = torch.norm(epsilon_current).item()
                print(f"res_2m step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}")

    if not was_skipped:
        # CALL MODEL: Normal path. Gating on was_skipped (not should_skip) also
        # covers a skip decision without a method, which would otherwise leave
        # `denoised` and `epsilon_current` unbound.
        denoised = model(noisy_latent, sigma_current * s_in, **extra_args)
        epsilon_current = denoised - x_0
        if skip_stats is not None:
            skip_stats["model_calls"] += 1

    # Step size in log space: h = -log(sigma_next / sigma_current)
    h = -torch.log(sigma_next / sigma_current)

    # Check if this is the final step (sigma_next = 0)
    # RES4LYF line 178: if sigma_next == 0
    sigma_next_value = sigma_next.item() if torch.is_tensor(sigma_next) else sigma_next
    is_final_step = (sigma_next_value == 0)

    # Check if we have history and can use multistep
    # RES4LYF stores denoised in data_[] array, loads it as: eps_[1] = -(x_0 - data_[1])
    if len(error_history) >= 1 and sigma_previous is not None and not is_final_step:
        # Load previous denoised from history and compute epsilon from it
        # RES4LYF line 215: eps_[1] = -(x_0 - data_[1]) = data_[1] - x_0
        denoised_previous = error_history[-1]
        epsilon_previous = denoised_previous - x_0

        # Multistep coefficient: RES4LYF line 808: c2 = (-h_prev / h).item()
        h_prev = -torch.log(sigma_current / sigma_previous)
        c2 = (-h_prev / h).item()

        # Phi function weights: RES4LYF lines 889-890
        # b2 = φ(2)/c2, b1 = φ(1) - b2
        phi_1 = phi_function(order=1, step_size=-h)
        phi_2 = phi_function(order=2, step_size=-h)

        b2_base = phi_2 / c2
        b1_base = phi_1 - b2_base

        # Adaptive weight adjustment based on error ratio
        # IMPORTANT: Only calculate error_ratio on real model calls, not extrapolated epsilon
        if adaptive_mode != "none" and not was_skipped:
            # Calculate error ratio (only on real model calls)
            error_curr = torch.norm(epsilon_current).item()
            error_prev = torch.norm(epsilon_previous).item()
            error_ratio = error_curr / (error_prev + 1e-8)  # Avoid division by zero

            if adaptive_mode == "learning":
                # MODE 2: EMA smoothed adjustment (learned pattern)
                smoothed_error_ratio_next = (smoothing_beta * smoothed_error_ratio +
                                             (1 - smoothing_beta) * error_ratio)
                adjustment = 1.0 / smoothed_error_ratio_next
                adjustment = max(0.5, min(2.0, adjustment))  # Clamp to [0.5, 2.0]
            else:
                adjustment = 1.0
                smoothed_error_ratio_next = 1.0

            # Apply adjustment to weights
            b1_adjusted = b1_base * adjustment
            b2_adjusted = b2_base / adjustment

            # Normalize to preserve sum (maintains phi_1 constraint)
            sum_adjusted = b1_adjusted + b2_adjusted
            sum_target = b1_base + b2_base  # Should equal phi_1
            scale = sum_target / sum_adjusted

            b1 = b1_adjusted * scale
            b2 = b2_adjusted * scale
        elif adaptive_mode != "none" and was_skipped:
            # Skipped step: preserve previous smoothed_error_ratio, use baseline weights
            # Don't poison the adaptive system with extrapolated epsilon
            b1 = b1_base
            b2 = b2_base
            adjustment = 1.0
            smoothed_error_ratio_next = smoothed_error_ratio  # Preserve previous value
            error_ratio = None  # Mark as not calculated
        else:
            # No adaptation (baseline RES2M)
            b1 = b1_base
            b2 = b2_base
            adjustment = 1.0
            smoothed_error_ratio_next = 1.0
            error_ratio = None

        # Integration: RES4LYF line 364: x = x_0 + h * rk.b_k_sum(eps_, 0)
        # For 2-multistep: b = [b1, b2], eps_ = [eps_current, eps_previous]
        # So: b_k_sum = b1*eps_current + b2*eps_previous
        x = x_0 + h * (b1 * epsilon_current + b2 * epsilon_previous)

        if debug:
            eps_prev_norm = torch.norm(epsilon_previous).item()
            eps_curr_norm = torch.norm(epsilon_current).item()
            if adaptive_mode != "none":
                # Only print immediate EXTRAPOLATED case here; REAL case is printed after learning update
                if error_ratio is None:
                    print(
                        f"res_2m step {step_index} [learning] [EXTRAPOLATED]: "
                        f"baseline φ-weights (adaptive error_ratio preserved); ε̂ scaled by 1/L={learning_ratio:.4f}; "
                        f"b1={b1.item():.4f}, b2={b2.item():.4f}"
                    )
            else:
                print(f"res_2m step {step_index}: eps_prev_norm={eps_prev_norm:.2f}, eps_curr_norm={eps_curr_norm:.2f}, "
                      f"c2={c2:.4f}, b1={b1.item():.4f}, b2={b2.item():.4f}")
    else:
        # First step / post-skip reanchor / final step
        if is_final_step:
            # Final step: sigma_next = 0
            # Return denoised directly (Euler method for final step)
            # Note: Computing h = -log(0/sigma) would give infinity, causing NaN
            # Full DEIS final step would require porting get_deis_coeff_list() from res4lyf
            # For now, standard Euler works perfectly for the final step
            x = denoised
            if debug:
                print(f"res_2m step {step_index} (final step): using Euler")
        else:
            # Use standard Euler integration when we cannot form a valid previous step
            # Reason classification improves log clarity
            if prev_was_skipped:
                reason = "post-skip reanchor"
            elif sigma_previous is None or len(error_history) == 0:
                reason = "first step"
            else:
                reason = "no-history reanchor"
            # First-order step in sigma space, identical to k-diffusion Euler:
            #   x + (sigma_next - sigma) * (x - denoised) / sigma
            #     = x_0 + (1 - sigma_next / sigma_current) * epsilon_current
            # This also equals h * phi_1(-h) * eps = (1 - exp(-h)) * eps, the
            # first-order member of the same exponential family the multistep
            # branch uses. Plain h * eps overshoots because h > 1 - exp(-h).
            x = x_0 + (1.0 - sigma_next / sigma_current) * epsilon_current
            if debug:
                print(f"res_2m step {step_index} ({reason}): using Euler")

        # No adaptation on first/final steps
        smoothed_error_ratio_next = 1.0

    # Store denoised for NEXT step (include SKIPPED to preserve multistep continuity)
    error_history.append(denoised)
    if len(error_history) > 2:
        error_history.pop(0)

    # Store REAL epsilon only for extrapolation/learning; keep full history (no cap)
    if not was_skipped:
        epsilon_history.append(epsilon_current)
        # Universal learning update when enough REAL history exists (>=3)
        if len(epsilon_history) >= 3:
            if predictor_type == "richardson":
                epsilon_hat = extrapolate_epsilon_richardson(epsilon_history)
            else:
                epsilon_hat = extrapolate_epsilon_linear(epsilon_history)
            if epsilon_hat is not None:
                learn_obs = (torch.norm(epsilon_hat) / (torch.norm(epsilon_current) + 1e-8)).item()
                learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs
                if learning_ratio < 0.5:
                    learning_ratio = 0.5
                elif learning_ratio > 2.0:
                    learning_ratio = 2.0
                if debug:
                    if adaptive_mode != "none" and error_ratio is not None:
                        # Combined one-line print for learning mode on REAL step
                        print(
                            f"res_2m step {step_index} [learning] [REAL]: "
                            f"err_ratio={error_ratio:.4f}, adjust={adjustment:.4f}, "
                            f"b1={b1.item():.4f}({b1_base.item():.4f}), b2={b2.item():.4f}({b2_base.item():.4f})"
                            f" | learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}"
                        )
                    elif adaptive_mode != "none":
                        # If for any reason error_ratio wasn't available, still show learning update succinctly
                        print(f"res_2m step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}")

    return x, smoothed_error_ratio_next, learning_ratio, was_skipped


def _res_2s_learning_update(epsilon_history, epsilon_real, learning_ratio,
                            smoothing_beta, predictor_type, step_index, debug):
    """Record a REAL epsilon and update the EMA learning ratio (res_2s helper).

    Appends ``epsilon_real`` to ``epsilon_history`` in place. Once at least
    three REAL entries exist, an ``epsilon_hat`` is extrapolated from the
    history with the configured predictor. The observed norm ratio
    ``||epsilon_hat|| / ||epsilon_real||`` is then blended into
    ``learning_ratio`` with an exponential moving average
    (weight ``smoothing_beta`` on the old value) and clamped to [0.5, 2.0].

    Non-finite observations (NaN/inf) are ignored. A NaN would slip past the
    clamp, because every comparison with NaN is False, and would then poison
    every later skip step that divides by the learning ratio.

    Returns:
        The updated (or unchanged) learning ratio.
    """
    epsilon_history.append(epsilon_real)
    if len(epsilon_history) < 3:
        return learning_ratio

    if predictor_type == "richardson":
        epsilon_hat = extrapolate_epsilon_richardson(epsilon_history)
    else:
        epsilon_hat = extrapolate_epsilon_linear(epsilon_history)
    if epsilon_hat is None:
        return learning_ratio

    ratio = torch.norm(epsilon_hat) / (torch.norm(epsilon_real) + 1e-8)
    if not bool(torch.isfinite(ratio)):
        if debug:
            print(f"res_2s step {step_index} [LEARN]: non-finite learn_obs ignored, L={learning_ratio:.4f}")
        return learning_ratio

    learn_obs = ratio.item()
    learning_ratio = smoothing_beta * learning_ratio + (1.0 - smoothing_beta) * learn_obs
    learning_ratio = min(max(learning_ratio, 0.5), 2.0)
    if debug:
        print(f"res_2s step {step_index} [LEARN]: learn_obs={learn_obs:.4f}, L={learning_ratio:.4f}, beta={smoothing_beta}")
    return learning_ratio


def sample_step_res_2s(model, noisy_latent, sigma_current, sigma_next, s_in, extra_args,
                       epsilon_history, learning_ratio, smoothing_beta, predictor_type,
                       step_index, total_steps, debug=False, skip_mode="none"):
    """res_2s: 2-stage exponential integrator with optional inter-step skipping.

    - Optional skip: if ``should_skip_model_call`` approves, the model call is
      replaced by an Euler-like update using an epsilon extrapolated from the
      REAL history. The learning ratio rescales it once >= 3 entries exist.
    - Final step (``sigma_next == 0``): a single model call, landing directly
      on the model's prediction.
    - Otherwise:
      - Stage 1 evaluates at ``sigma_current``.
      - Stage 2 evaluates at the midpoint in log-sigma.
      - The two stages are combined with phi-based weights.
    - On every REAL (non-skipped) step, the stage-1 epsilon is appended to
      ``epsilon_history`` and the learning ratio is updated.

    Here "epsilon" means ``model_prediction - x_start``, where ``x_start`` is
    the latent at the start of the step.

    Args:
        model: Callable ``model(x, sigma * s_in, **extra_args)`` returning a
            prediction with the same shape as ``x``.
        noisy_latent: Latent at ``sigma_current``.
        sigma_current, sigma_next: Current and target noise levels.
        s_in: Multiplier used to broadcast sigma for the model call.
        extra_args: Extra keyword arguments forwarded to ``model``.
        epsilon_history: List of REAL epsilons. It is appended to in place on
            REAL steps.
        learning_ratio: Current EMA learning ratio.
        smoothing_beta: EMA weight kept on the previous learning ratio.
        predictor_type: ``"richardson"`` or anything else (linear), used for
            the learning-ratio extrapolation.
        step_index, total_steps: Position in the schedule. Used for skip
            decisions and debug output.
        debug: Print per-step diagnostics.
        skip_mode: Forwarded to ``should_skip_model_call``.

    Returns:
        Tuple ``(new_latent, learning_ratio)``.
    """
    noisy_latent_at_step_start = noisy_latent

    # Inter-step skip support (baseline: Euler-like update with ε̂).
    # res_2s doesn't track an error ratio, so a neutral 1.0 is passed.
    should_skip, skip_method = should_skip_model_call(
        1.0,
        step_index,
        total_steps,
        skip_mode,
        epsilon_history
    )
    # Per the original note, should_skip_model_call guards the first/last steps
    # and the history length itself.
    if should_skip and skip_method is not None:
        # Build epsilon_hat from REAL history
        if skip_method == "richardson":
            epsilon_hat = extrapolate_epsilon_richardson(epsilon_history)
        else:
            epsilon_hat = extrapolate_epsilon_linear(epsilon_history)

        # Only skip when the extrapolation is usable. A None or non-finite
        # (NaN or inf) result falls through to a real model call below.
        if epsilon_hat is not None and bool(torch.isfinite(epsilon_hat).all()):
            # Scale by learning ratio if we have >= 3 REAL eps in history
            if len(epsilon_history) >= 3:
                epsilon_hat = epsilon_hat / max(learning_ratio, 1e-8)

            # Euler-like update: d = (x - denoised) / sigma = -epsilon / sigma
            d = -(epsilon_hat) / sigma_current
            dt = sigma_next - sigma_current
            noisy_latent = noisy_latent + d * dt

            if debug:
                e_norm = torch.norm(epsilon_hat).item()
                dt_val = dt.item() if torch.is_tensor(dt) else float(dt)
                print(f"res_2s step {step_index} [SKIPPED-{skip_method}]: e_norm={e_norm:.2f}, L={learning_ratio:.4f}, dt={dt_val:.4f}")

            return noisy_latent, learning_ratio

    # Final step (sigma_next == 0): the log-space step size would be infinite,
    # so take a single Euler step, which lands exactly on the model prediction.
    sigma_next_value = sigma_next.item() if torch.is_tensor(sigma_next) else sigma_next
    if sigma_next_value == 0:
        model_prediction = model(noisy_latent, sigma_current * s_in, **extra_args)
        if debug:
            print(f"res_2s step {step_index} (final step): using Euler")
        learning_ratio = _res_2s_learning_update(
            epsilon_history, model_prediction - noisy_latent_at_step_start,
            learning_ratio, smoothing_beta, predictor_type, step_index, debug,
        )
        return model_prediction, learning_ratio

    # Step size in log space: h = log(sigma_current / sigma_next)
    step_size = -torch.log(sigma_next / sigma_current)

    midpoint_fraction = 0.5  # c2: evaluate stage 2 at the log-sigma midpoint

    # Phi function weights for the 2-stage method.
    # Assumes phi_function(order=k, step_size=z) returns phi_k(z) — TODO confirm.
    phi_1_value = phi_function(order=1, step_size=-step_size)
    phi_2_value = phi_function(order=2, step_size=-step_size)

    # Final integration weights: b2 = phi_2 / c2, b1 = phi_1 - b2
    weight_stage2 = phi_2_value / midpoint_fraction
    weight_stage1 = phi_1_value - weight_stage2

    # Weight for advancing to stage 2: a21 = c2 * phi_1(-c2 * h)
    phi_1_at_midpoint = phi_function(order=1, step_size=-step_size * midpoint_fraction)
    stage2_advance_weight = midpoint_fraction * phi_1_at_midpoint

    # Stage 1: Evaluate at current sigma
    model_prediction_stage1 = model(noisy_latent, sigma_current * s_in, **extra_args)
    error_stage1 = model_prediction_stage1 - noisy_latent_at_step_start  # epsilon at current sigma

    # Stage 2: Evaluate at midpoint sigma = sigma_current * exp(-c2 * h)
    sigma_midpoint = torch.exp(-(-torch.log(sigma_current) + step_size * midpoint_fraction))
    noisy_latent_midpoint = noisy_latent_at_step_start + (step_size * stage2_advance_weight) * error_stage1

    model_prediction_stage2 = model(noisy_latent_midpoint, sigma_midpoint * s_in, **extra_args)
    # Epsilon at the midpoint, measured against the step-start latent as well.
    error_stage2 = model_prediction_stage2 - noisy_latent_at_step_start

    # Final integration with weighted stages
    noisy_latent = noisy_latent_at_step_start + step_size * (
        weight_stage1 * error_stage1 +
        weight_stage2 * error_stage2
    )

    if debug:
        stage1_norm = torch.norm(error_stage1).item()
        stage2_norm = torch.norm(error_stage2).item()
        print(f"res_2s step {step_index}: stage1_norm={stage1_norm:.2f}, stage2_norm={stage2_norm:.2f}, "
              f"weight_s1={weight_stage1.item():.4f}, weight_s2={weight_stage2.item():.4f}")

    # Learning update on REAL call (use epsilon at current sigma: error_stage1)
    learning_ratio = _res_2s_learning_update(
        epsilon_history, error_stage1, learning_ratio,
        smoothing_beta, predictor_type, step_index, debug,
    )

    return noisy_latent, learning_ratio