File size: 30,103 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
"""
Integration tests for pipeline routing logic.

Tests that the pipeline correctly routes execution based on flags like
hires_fix, img2img, adetailer, etc. All model loading
is mocked to avoid loading real weights.
"""

import os
import sys
import pytest
import torch
from pathlib import Path
from unittest.mock import patch, MagicMock, call, ANY
from typing import Tuple

# Add project root to path: the directory three `.parent` hops above this
# file's resolved location is put at the front of sys.path so the `src.*`
# imports used inside the tests below can be resolved.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))

# Tag every test collected from this module as `slow`.
pytestmark = pytest.mark.slow


# =============================================================================
# Test Fixtures
# =============================================================================

@pytest.fixture
def mock_all_heavy_dependencies(request):
    """
    Comprehensive mock that patches all heavy dependencies to allow
    testing pipeline routing logic without loading real models.

    Yields:
        dict mapping a short key (``'loader'``, ``'ksampler'``, ``'hdr'``,
        ``'app_instance'``, ...) to the mock that replaced the corresponding
        dependency. Default return values are configured so a mocked
        ``pipeline()`` call can run end to end.

    Patches are applied through an ``ExitStack``. If any target fails to
    patch, the patches already applied are undone immediately. On teardown
    they are undone in reverse order. Only this fixture's own patches are
    stopped (no ``patch.stopall()``), so patches owned by other fixtures stay
    active. ``request`` is kept in the signature but is not used.
    """
    from contextlib import ExitStack

    def _clear_model_cache():
        # Best-effort reset of the global model cache so cached fake
        # checkpoints don't leak between tests; any error is ignored.
        try:
            from src.Device.ModelCache import get_model_cache
            get_model_cache().clear_cache()
        except Exception:
            pass

    # Dependency key -> dotted path of the object replaced by a MagicMock.
    patch_targets = {
        # Model loading
        'loader': 'src.FileManaging.Loader.CheckpointLoaderSimple',
        'model_cache': 'src.Device.ModelCache.get_model_cache',
        'load_model': 'src.user.model_loader.load_model_for_pipeline',
        # CLIP operations
        'clip_encode': 'src.clip.Clip.CLIPTextEncode',
        'clip_set_layer': 'src.clip.Clip.CLIPSetLastLayer',
        # VAE operations
        'vae_decode': 'src.AutoEncoders.VariationalAE.VAEDecode',
        'vae_loader': 'src.AutoEncoders.VariationalAE.VAELoader',
        # Latent operations
        'empty_latent': 'src.Utilities.Latent.EmptyLatentImage',
        'latent_upscale': 'src.Utilities.upscale.LatentUpscale',
        # Sampler
        'ksampler': 'src.sample.sampling.KSampler',
        # Image operations
        'save_image': 'src.FileManaging.ImageSaver.SaveImage',
        # LoRA
        'lora_loader': 'src.Model.LoRas.LoraLoader',
        # Optimizations (mocked so tests can make `.called` checks)
        'sf_applier': 'src.StableFast.StableFast.ApplyStableFastUnet',
        'dc_applier': 'src.WaveSpeed.deepcache_nodes.ApplyDeepCacheOnModel',
        # HiDiffusion
        'hidiff': 'src.hidiffusion.msw_msa_attention.ApplyMSWMSAAttentionSimple',
        # HDR
        'hdr': 'src.AutoHDR.ahdr.HDREffects',
        # Downloader, to avoid network calls
        'downloader': 'src.FileManaging.Downloader.CheckAndDownload',
    }

    # Clear the real cache BEFORE patching. Once 'model_cache' is patched,
    # get_model_cache() returns a mock and clear_cache() would be a no-op.
    _clear_model_cache()

    with ExitStack() as stack:
        # Registered first so it runs last, after every patch has been undone,
        # i.e. against the real cache rather than the mock.
        stack.callback(_clear_model_cache)

        mocks = {
            name: stack.enter_context(patch(target))
            for name, target in patch_targets.items()
        }

        # Mock app_instance - explicitly set interrupt_flag to False
        mock_app = MagicMock()
        mock_app.interrupt_flag = False
        mocks['app_instance'] = stack.enter_context(
            patch('src.user.app_instance.app', mock_app)
        )

        # Configure default return values
        from conftest import MockModelPatcher
        mock_model_patcher = MockModelPatcher()
        mock_clip = MagicMock()
        mock_vae = MagicMock()

        mocks['loader'].return_value.load_checkpoint.return_value = (
            mock_model_patcher, mock_clip, mock_vae
        )
        mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None
        mocks['load_model'].return_value = ("SD15", (mock_model_patcher, mock_clip, mock_vae))

        # Mock CLIP encoding
        mock_cond = [[torch.randn(1, 77, 768), {}]]
        mocks['clip_encode'].return_value.encode.return_value = (mock_cond,)
        mocks['clip_set_layer'].return_value.set_last_layer.return_value = (mock_clip,)

        # Mock VAE decoding
        mocks['vae_decode'].return_value.decode.return_value = (torch.rand(1, 512, 512, 3),)

        # Mock latent generation
        mock_latent = {"samples": torch.randn(1, 4, 64, 64)}
        mocks['empty_latent'].return_value.generate.return_value = (mock_latent,)
        mocks['latent_upscale'].return_value.upscale.return_value = ({"samples": torch.randn(1, 4, 128, 128)},)

        # Mock sampler
        mocks['ksampler'].return_value.sample.return_value = ({"samples": torch.randn(1, 4, 64, 64)},)

        # Mock LoRA loader
        mocks['lora_loader'].return_value.load_lora.return_value = (
            mock_model_patcher, mock_clip, mock_vae
        )

        # Mock HiDiffusion
        mocks['hidiff'].return_value.go.return_value = (mock_model_patcher,)

        # Mock HDR
        mocks['hdr'].return_value.apply_hdr2.return_value = (torch.rand(1, 512, 512, 3),)

        # Mock image saver
        mocks['save_image'].return_value.save_images.return_value = {"ui": {"images": []}}
        mocks['save_image'].return_value.save_images_async = MagicMock()

        # Configure optimization appliers to return the mock model in a tuple
        mocks['sf_applier'].return_value.apply_stable_fast.return_value = (mock_model_patcher,)
        mocks['dc_applier'].return_value.patch.return_value = (mock_model_patcher,)

        yield mocks


# =============================================================================
# Pipeline Flag Routing Tests
# =============================================================================

@pytest.mark.slow
class TestPipelineBasicRouting:
    """Smoke tests for the default pipeline path (no special flags)."""

    def test_pipeline_runs_without_exception(self, mock_all_heavy_dependencies):
        """A fully mocked single-image run should finish and return something."""
        from src.user.pipeline import pipeline

        # Any exception raised here fails the test.
        outcome = pipeline(prompt="a test prompt", w=512, h=512, number=1, batch=1)

        assert outcome is not None

    def test_pipeline_returns_result_dict(self, mock_all_heavy_dependencies):
        """The pipeline's return value should be a dict describing the run."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="a test prompt", w=512, h=512)

        assert isinstance(outcome, dict), f"Expected dict, got {type(outcome)}"
        # At least one of the known top-level result keys must be present.
        known_keys = ("original_prompt", "batched_results")
        assert any(key in outcome for key in known_keys)


@pytest.mark.slow
class TestHiresFixRouting:
    """Test hires_fix flag routing."""

    def test_hires_fix_triggers_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=True should trigger latent upscaling."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )

        # Verify upscale was called
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called when hires_fix=True"
        )

    def test_hires_fix_false_skips_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=False should not trigger latent upscaling."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=False,
        )

        # Verify upscale was NOT called
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert not latent_upscale.return_value.upscale.called, (
            "Latent upscale should NOT be called when hires_fix=False"
        )

    def test_hires_fix_runs_additional_sampling_pass(self, mock_all_heavy_dependencies):
        """hires_fix should run an additional sampling pass at higher resolution."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )

        ksampler = mock_all_heavy_dependencies['ksampler']
        # Should have been called at least twice (initial + hires pass)
        assert ksampler.return_value.sample.call_count >= 2, (
            f"Expected at least 2 sampling passes for hires_fix, "
            f"got {ksampler.return_value.sample.call_count}"
        )

    def test_batched_hires_fix_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched hires_fix with SDXL refiner should call latent upscaling using refiner prompts."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt=["one"],
            w=512,
            h=512,
            batch=1,
            hires_fix=True,
            per_sample_info=[{"hires_fix": True}],
            refiner_model_path="refiner.safetensors",
            refiner_switch_step=1,
            model_path="my_sdxl_model.safetensors",
        )

        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called for batched hires_fix with refiner"
        )

    def test_hires_fix_reloads_base_model_after_refiner(self, mock_all_heavy_dependencies):
        """If a refiner unloaded the base model, the pipeline must reload the base model before HiresFix."""
        from src.user.pipeline import pipeline

        # Patch HiresFix.apply so we can inspect the `model` argument passed to it.
        # `patch` and `torch` are the module-level imports.
        with patch('src.Processors.HiresFix.HiresFix.apply') as mock_hires_apply:
            mock_hires_apply.return_value = {"samples": torch.randn(1, 4, 128, 128)}

            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                hires_fix=True,
                per_sample_info=[{"hires_fix": True}],
                refiner_model_path="refiner.safetensors",
                refiner_switch_step=1,
                model_path="my_sdxl_model.safetensors",
            )

            assert mock_hires_apply.called, "HiresFix.apply should be invoked"
            # Third positional argument; this test treats it as the model.
            called_model = mock_hires_apply.call_args[0][2]

            # The model passed to HiresFix must be loaded and have an inner model object
            assert getattr(called_model, 'is_loaded', False), "Base model must be loaded when passed to HiresFix"
            assert getattr(called_model, 'model', None) is not None, "Base model.model must be present for the hires pass"

    def test_batched_adetailer_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched adetailer with SDXL refiner should call Adetailer.apply without NameError."""
        from src.user.pipeline import pipeline

        with patch('src.Processors.Adetailer.Adetailer.apply') as mock_adetail_apply:
            mock_adetail_apply.return_value = (torch.rand(1, 512, 512, 3), [])
            # The refiner/SDXL arguments mirror test_batched_hires_fix_with_refiner_sdxl,
            # so the refiner code path this test is named after is actually exercised.
            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                adetailer=True,
                per_sample_info=[{"adetailer": True}],
                refiner_model_path="refiner.safetensors",
                refiner_switch_step=1,
                model_path="my_sdxl_model.safetensors",
            )
            assert mock_adetail_apply.called, "Adetailer.apply should be called for batched adetailer with refiner"

    def test_hires_fix_with_flux_model(self, mock_all_heavy_dependencies):
        """HiresFix should work with Flux model (no refiner) and call upscale."""
        from src.user.pipeline import pipeline

        # NOTE(review): the fixture's load_model mock still reports "SD15"; this
        # presumably relies on the pipeline inferring Flux from model_path — confirm.
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="flux_model.safetensors",
        )

        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, "Latent upscale should be called for Flux model"

    def test_hires_fix_injects_size_conditioning_for_sdxl(self, mock_all_heavy_dependencies):
        """HiresFix should inject width/height into prompt conditioning for SDXL models."""
        from src.user.pipeline import pipeline
        from conftest import MockCheckpointResult

        # Force the loader to return an SDXL checkpoint to emulate SDXL behavior
        mock_all_heavy_dependencies['load_model'].return_value = ("SDXL", MockCheckpointResult("SDXL").as_tuple())

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="my_sdxl_model.safetensors",
        )

        ksampler = mock_all_heavy_dependencies['ksampler']
        # Inspect the last sampler call (hires pass)
        assert ksampler.return_value.sample.call_count >= 2
        last_call = ksampler.return_value.sample.call_args_list[-1]
        kwargs = last_call.kwargs
        positive = kwargs.get('positive')

        # The conditioning metadata should include updated width/height for the
        # hires pass (1024 assumes the hires pass doubles 512x512).
        assert isinstance(positive, list)
        meta = positive[0][1]
        assert meta.get('width') == 1024
        assert meta.get('height') == 1024


class TestImg2ImgRouting:
    """Routing checks for the ``img2img`` flag."""

    def test_img2img_requires_image_source(self, mock_all_heavy_dependencies, tmp_path):
        """With img2img=True the pipeline should consume the supplied image path."""
        from src.user.pipeline import pipeline
        from PIL import Image

        # A small solid-red PNG acts as the img2img source image.
        source_path = tmp_path / "test.png"
        Image.new('RGB', (256, 256), color='red').save(source_path)

        usdu_target = 'src.UltimateSDUpscale.UltimateSDUpscale.UltimateSDUpscale'
        loader_target = 'src.UltimateSDUpscale.USDU_upscaler.UpscaleModelLoader'
        with patch(usdu_target) as mock_upscale, patch(loader_target) as mock_loader:
            mock_upscale.return_value.upscale.return_value = (torch.rand(1, 512, 512, 3),)
            mock_loader.return_value.load_model.return_value = (MagicMock(),)

            pipeline(
                prompt="test",
                w=512,
                h=512,
                img2img=True,
                img2img_image=str(source_path),
            )

            # The img2img route is expected to go through UltimateSDUpscale.
            assert mock_upscale.called, (
                "UltimateSDUpscale should be used for img2img"
            )

    def test_img2img_false_uses_text2img(self, mock_all_heavy_dependencies):
        """With img2img=False the pipeline should start from an empty latent."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, img2img=False)

        generate = mock_all_heavy_dependencies['empty_latent'].return_value.generate
        assert generate.called, (
            "EmptyLatentImage.generate should be called for text2img"
        )


class TestADetailerRouting:
    """Routing checks for the ``adetailer`` flag."""

    def test_adetailer_enabled_triggers_detection(self, mock_all_heavy_dependencies):
        """With adetailer=True the detection stack (starting with SAM) should run."""
        from src.user.pipeline import pipeline

        # One multi-item `with` enters the patches in the same order as nesting them.
        with patch('src.AutoDetailer.SAM.SAMLoader') as mock_sam, \
                patch('src.AutoDetailer.bbox.UltralyticsDetectorProvider') as mock_detector, \
                patch('src.AutoDetailer.bbox.BboxDetectorForEach') as mock_bbox, \
                patch('src.AutoDetailer.SAM.SAMDetectorCombined') as mock_sam_combined, \
                patch('src.AutoDetailer.SEGS.SegsBitwiseAndMask') as mock_segs, \
                patch('src.AutoDetailer.ADetailer.DetailerForEachTest') as mock_detailer:
            mock_sam.return_value.load_model.return_value = (MagicMock(),)
            mock_detector.return_value.doit.return_value = (MagicMock(),)
            mock_bbox.return_value.doit.return_value = MagicMock()
            mock_sam_combined.return_value.doit.return_value = (torch.ones(1, 512, 512),)
            mock_segs.return_value.doit.return_value = (MagicMock(),)
            detailed_image = torch.rand(1, 512, 512, 3)
            mock_detailer.return_value.doit.return_value = (detailed_image, 12345)

            pipeline(prompt="test", w=512, h=512, adetailer=True)

            assert mock_sam.return_value.load_model.called, (
                "SAMLoader should be called when adetailer=True"
            )

    def test_adetailer_disabled_skips_detection(self, mock_all_heavy_dependencies):
        """With adetailer=False no SAM loader should be created."""
        from src.user.pipeline import pipeline

        with patch('src.AutoDetailer.SAM.SAMLoader') as sam_loader:
            pipeline(prompt="test", w=512, h=512, adetailer=False)

            assert not sam_loader.called, (
                "SAMLoader should NOT be called when adetailer=False"
            )


class TestMultiscaleRouting:
    """Test multiscale diffusion parameter routing."""

    def test_multiscale_preset_applied(self, mock_all_heavy_dependencies):
        """multiscale_preset should still lead to a sampling call.

        The keyword under which multiscale settings reach the sampler is not
        pinned here (it may be passed positionally). This test only requires
        that sampling actually happened, so it fails if the preset breaks the
        sampling path.
        """
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="performance",
        )

        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.called, (
            "KSampler.sample should be called with multiscale_preset='performance'"
        )

    def test_multiscale_disabled_preset(self, mock_all_heavy_dependencies):
        """multiscale_preset='disabled' should disable multiscale."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="disabled",
        )

        # Should still run without error
        ksampler = mock_all_heavy_dependencies['ksampler']
        assert ksampler.return_value.sample.called


class TestDeepCacheRouting:
    """Routing checks for the ``deepcache_enabled`` flag."""

    def test_deepcache_enabled_applies_patch(self, mock_all_heavy_dependencies):
        """Turning DeepCache on should route the model through the DeepCache patcher."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=True)

        dc_patch = mock_all_heavy_dependencies['dc_applier'].return_value.patch
        assert dc_patch.called, (
            "DeepCache should be applied when deepcache_enabled=True"
        )

    def test_deepcache_disabled_skips_patch(self, mock_all_heavy_dependencies):
        """Turning DeepCache off should leave the DeepCache patcher unused."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=False)

        dc_patch = mock_all_heavy_dependencies['dc_applier'].return_value.patch
        assert not dc_patch.called, (
            "DeepCache should NOT be applied when deepcache_enabled=False"
        )


class TestStableFastRouting:
    """Routing checks for the ``stable_fast`` flag."""

    def test_stable_fast_enabled_applies_optimization(self, mock_all_heavy_dependencies):
        """Turning StableFast on should invoke the StableFast UNet applier."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, stable_fast=True)

        apply_sf = mock_all_heavy_dependencies['sf_applier'].return_value.apply_stable_fast
        assert apply_sf.called, (
            "StableFast should be applied when stable_fast=True"
        )


class TestBatchedPromptRouting:
    """Routing checks for a list of prompts submitted in one call."""

    @pytest.mark.skip(reason="Batched prompts require dynamic mock tensor sizing which is complex to set up; tested manually")
    def test_batched_prompts_use_batched_path(self, mock_all_heavy_dependencies):
        """A list of prompts should be routed through the batched generation path.

        Skipped: the pipeline iterates over per-prompt batched results, while
        the fixture's mocks return fixed single-item tensors. Exercising this
        path would need mocks that size their outputs dynamically.
        """
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt=["prompt 1", "prompt 2", "prompt 3"], w=512, h=512)

        # Only inspect the batched payload when the pipeline reports one.
        if "batched_results" in outcome:
            assert isinstance(outcome["batched_results"], dict)


class TestAutoHDRRouting:
    """Test AutoHDR parameter routing."""

    def test_autohdr_enabled_applies_effect(self, mock_all_heavy_dependencies):
        """autohdr=True should apply HDR effect."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=True,
        )

        hdr = mock_all_heavy_dependencies['hdr']
        assert hdr.return_value.apply_hdr2.called, (
            "HDR effect should be applied when autohdr=True"
        )

    def test_autohdr_disabled_skips_effect(self, mock_all_heavy_dependencies):
        """autohdr=False should skip HDR effect."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=False,
        )

        hdr = mock_all_heavy_dependencies['hdr']
        # The implementation may still construct the HDR object, so only the
        # effect call itself is required to be skipped (mirrors the enabled test).
        assert not hdr.return_value.apply_hdr2.called, (
            "HDR effect should NOT be applied when autohdr=False"
        )


class TestCFGFreeRouting:
    """Routing checks for CFG-free sampling options."""

    def test_cfg_free_params_passed_to_sampler(self, mock_all_heavy_dependencies):
        """Enabling CFG-free sampling should still drive the sampler.

        Only checks that sampling happened; the keyword names under which the
        CFG-free settings are forwarded are not asserted here.
        """
        from src.user.pipeline import pipeline

        cfg_free_options = {"cfg_free_enabled": True, "cfg_free_start_percent": 70.0}
        pipeline(prompt="test", w=512, h=512, **cfg_free_options)

        assert mock_all_heavy_dependencies['ksampler'].return_value.sample.called


class TestTokenMergingRouting:
    """Routing checks for Token Merging (ToMe) options."""

    def test_tome_params_passed_when_enabled(self, mock_all_heavy_dependencies):
        """Enabling ToMe should not break the pipeline.

        Smoke test only: it passes as long as the call below does not raise,
        even though ToMe itself is not mocked by the fixture.
        """
        from src.user.pipeline import pipeline

        tome_options = dict(tome_enabled=True, tome_ratio=0.5)
        pipeline(prompt="test", w=512, h=512, **tome_options)


class TestNegativePromptRouting:
    """Handling of the ``negative_prompt`` argument."""

    def test_empty_negative_prompt_uses_default(self, mock_all_heavy_dependencies):
        """An empty negative prompt should be accepted (default used) and still produce a result."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="test", w=512, h=512, negative_prompt="")

        assert outcome is not None

    def test_custom_negative_prompt_passed(self, mock_all_heavy_dependencies):
        """A custom negative prompt should go through CLIP text encoding."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            negative_prompt="ugly, bad quality, distorted",
        )

        # Indirect check: the text encoder must have been used.
        encode = mock_all_heavy_dependencies['clip_encode'].return_value.encode
        assert encode.called


class TestSeedRouting:
    """Seed handling, including the ``reuse_seed`` flag."""

    def test_reuse_seed_uses_last_seed(self, mock_all_heavy_dependencies):
        """A run with reuse_seed=True after a normal run should still sample."""
        from src.user.pipeline import pipeline

        # First run records a seed; the second asks to reuse it.
        for reuse in (False, True):
            pipeline(prompt="test", w=512, h=512, reuse_seed=reuse)

        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.call_count >= 2


class TestModelPathRouting:
    """Test model_path parameter routing."""

    def test_custom_model_path_used(self, mock_all_heavy_dependencies):
        """Custom model_path should reach whichever checkpoint loader the pipeline uses."""
        from src.user.pipeline import pipeline

        custom_path = "/path/to/custom_model.safetensors"

        pipeline(
            prompt="test",
            w=512,
            h=512,
            model_path=custom_path,
        )

        # The fixture mocks both the high-level loading helper and the raw
        # checkpoint loader. Inspect both so the check cannot pass vacuously
        # when the pipeline loads through only one of them.
        mocks = mock_all_heavy_dependencies
        load_calls = (
            list(mocks['load_model'].call_args_list)
            + list(mocks['loader'].return_value.load_checkpoint.call_args_list)
        )
        assert load_calls, "Pipeline should load a checkpoint for the custom model_path"

        # Match on the file name so the check tolerates path normalisation
        # (e.g. separator changes on Windows).
        model_name = os.path.basename(custom_path)
        assert any(model_name in str(c) for c in load_calls), (
            f"Custom model path should be used: {load_calls}"
        )


class TestErrorHandling:
    """Failure-path behaviour of the pipeline."""

    def test_invalid_model_path_raises_error(self, mock_all_heavy_dependencies):
        """A checkpoint that cannot be loaded should surface as FileNotFoundError."""
        from src.user.pipeline import pipeline

        mocks = mock_all_heavy_dependencies
        # Make the checkpoint loader fail and force a cache miss.
        checkpoint_loader = mocks['loader'].return_value
        checkpoint_loader.load_checkpoint.side_effect = FileNotFoundError("Model not found")
        mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None

        with pytest.raises(FileNotFoundError):
            pipeline(prompt="test", w=512, h=512, model_path="/nonexistent/model.safetensors")

    def test_interruption_handled_gracefully(self, mock_all_heavy_dependencies):
        """An already-set interrupt flag should abort the run with InterruptedError."""
        from src.user.pipeline import pipeline

        interrupted_app = MagicMock()
        interrupted_app.interrupt_flag = True

        with patch('src.user.app_instance.app', interrupted_app), pytest.raises(InterruptedError):
            pipeline(prompt="test", w=512, h=512)


class TestSchedulerSamplerRouting:
    """Scheduler and sampler name routing."""

    @pytest.mark.parametrize("scheduler", (
        "normal", "karras", "simple", "beta", "ays", "ays_sd15", "ays_sdxl",
    ))
    def test_scheduler_passed_to_sampler(self, mock_all_heavy_dependencies, scheduler):
        """Each supported scheduler name should still lead to a KSampler call."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, scheduler=scheduler)

        assert mock_all_heavy_dependencies['ksampler'].return_value.sample.called

    @pytest.mark.parametrize("sampler", (
        "euler", "euler_ancestral", "euler_cfgpp",
        "euler_ancestral_cfgpp", "dpmpp_2m_cfgpp", "dpmpp_sde_cfgpp",
    ))
    def test_sampler_passed_to_ksampler(self, mock_all_heavy_dependencies, sampler):
        """Each supported sampler name should still lead to a KSampler call."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, sampler=sampler)

        assert mock_all_heavy_dependencies['ksampler'].return_value.sample.called