""" Integration tests for pipeline routing logic. Tests that the pipeline correctly routes execution based on flags like hires_fix, img2img, adetailer, etc. All model loading is mocked to avoid loading real weights. """ import os import sys import pytest import torch from pathlib import Path from unittest.mock import patch, MagicMock, call, ANY from typing import Tuple # Add project root to path project_root = Path(__file__).resolve().parent.parent.parent sys.path.insert(0, str(project_root)) pytestmark = pytest.mark.slow # ============================================================================= # Test Fixtures # ============================================================================= @pytest.fixture def mock_all_heavy_dependencies(request): """ Comprehensive mock that patches all heavy dependencies to allow testing pipeline routing logic without loading real models. """ patches = {} # Mock model loading patches['loader'] = patch('src.FileManaging.Loader.CheckpointLoaderSimple') patches['model_cache'] = patch('src.Device.ModelCache.get_model_cache') patches['load_model'] = patch('src.user.model_loader.load_model_for_pipeline') # Mock CLIP operations patches['clip_encode'] = patch('src.clip.Clip.CLIPTextEncode') patches['clip_set_layer'] = patch('src.clip.Clip.CLIPSetLastLayer') # Mock VAE operations patches['vae_decode'] = patch('src.AutoEncoders.VariationalAE.VAEDecode') patches['vae_loader'] = patch('src.AutoEncoders.VariationalAE.VAELoader') # Mock Latent operations patches['empty_latent'] = patch('src.Utilities.Latent.EmptyLatentImage') patches['latent_upscale'] = patch('src.Utilities.upscale.LatentUpscale') # Mock Sampler patches['ksampler'] = patch('src.sample.sampling.KSampler') # Mock Image operations patches['save_image'] = patch('src.FileManaging.ImageSaver.SaveImage') # Mock LoRA patches['lora_loader'] = patch('src.Model.LoRas.LoraLoader') # Mock optimizations to ensure they return the model (allowing .called checks) patches['sf_applier'] = patch('src.StableFast.StableFast.ApplyStableFastUnet') patches['dc_applier'] = patch('src.WaveSpeed.deepcache_nodes.ApplyDeepCacheOnModel') # Mock HiDiffusion patches['hidiff'] = patch('src.hidiffusion.msw_msa_attention.ApplyMSWMSAAttentionSimple') # Mock HDR patches['hdr'] = patch('src.AutoHDR.ahdr.HDREffects') # Mock Downloader to avoid network calls patches['downloader'] = patch('src.FileManaging.Downloader.CheckAndDownload') # Mock app_instance - explicitly set interrupt_flag to False mock_app = MagicMock() mock_app.interrupt_flag = False patches['app_instance'] = patch('src.user.app_instance.app', mock_app) # Start all patches mocks = {name: p.start() for name, p in patches.items()} # Ensure the global model cache is cleared at the start of the fixture to avoid # interaction with previously cached checkpoint entries from other tests. try: from src.Device.ModelCache import get_model_cache get_model_cache().clear_cache() except Exception: pass def teardown(): # Stop all patches in reverse order for p in reversed(list(patches.values())): try: p.stop() except Exception: pass patch.stopall() # Also clear the global model cache in teardown to ensure mocks that # cached fake checkpoints don't leak into following tests. try: from src.Device.ModelCache import get_model_cache get_model_cache().clear_cache() except Exception: pass request.addfinalizer(teardown) # Configure default return values from conftest import MockModelPatcher mock_model_patcher = MockModelPatcher() mock_clip = MagicMock() mock_vae = MagicMock() mocks['loader'].return_value.load_checkpoint.return_value = ( mock_model_patcher, mock_clip, mock_vae ) mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None mocks['load_model'].return_value = ("SD15", (mock_model_patcher, mock_clip, mock_vae)) # Mock CLIP encoding mock_cond = [[torch.randn(1, 77, 768), {}]] mocks['clip_encode'].return_value.encode.return_value = (mock_cond,) mocks['clip_set_layer'].return_value.set_last_layer.return_value = (mock_clip,) # Mock VAE decoding mocks['vae_decode'].return_value.decode.return_value = (torch.rand(1, 512, 512, 3),) # Mock latent generation mock_latent = {"samples": torch.randn(1, 4, 64, 64)} mocks['empty_latent'].return_value.generate.return_value = (mock_latent,) mocks['latent_upscale'].return_value.upscale.return_value = ({"samples": torch.randn(1, 4, 128, 128)},) # Mock sampler mocks['ksampler'].return_value.sample.return_value = ({"samples": torch.randn(1, 4, 64, 64)},) # Mock LoRA loader mocks['lora_loader'].return_value.load_lora.return_value = ( mock_model_patcher, mock_clip, mock_vae ) # Mock HiDiffusion mocks['hidiff'].return_value.go.return_value = (mock_model_patcher,) # Mock HDR mocks['hdr'].return_value.apply_hdr2.return_value = (torch.rand(1, 512, 512, 3),) # Mock image saver mocks['save_image'].return_value.save_images.return_value = {"ui": {"images": []}} mocks['save_image'].return_value.save_images_async = MagicMock() # Configure optimization appliers to return the mock model in a tuple mocks['sf_applier'].return_value.apply_stable_fast.return_value = (mock_model_patcher,) mocks['dc_applier'].return_value.patch.return_value = (mock_model_patcher,) yield mocks # ============================================================================= # Pipeline Flag Routing Tests # ============================================================================= @pytest.mark.slow class TestPipelineBasicRouting: """Test basic pipeline routing based on flags.""" def test_pipeline_runs_without_exception(self, mock_all_heavy_dependencies): """Pipeline should run without raising exceptions when properly mocked.""" from src.user.pipeline import pipeline # Should not raise result = pipeline( prompt="a test prompt", w=512, h=512, number=1, batch=1, ) assert result is not None def test_pipeline_returns_result_dict(self, mock_all_heavy_dependencies): """Pipeline should return a result dictionary.""" from src.user.pipeline import pipeline result = pipeline( prompt="a test prompt", w=512, h=512, ) assert isinstance(result, dict), f"Expected dict, got {type(result)}" assert "original_prompt" in result or "batched_results" in result @pytest.mark.slow class TestHiresFixRouting: """Test hires_fix flag routing.""" def test_hires_fix_triggers_upscale(self, mock_all_heavy_dependencies): """hires_fix=True should trigger latent upscaling.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, hires_fix=True, ) # Verify upscale was called latent_upscale = mock_all_heavy_dependencies['latent_upscale'] assert latent_upscale.return_value.upscale.called, ( "Latent upscale should be called when hires_fix=True" ) def test_hires_fix_false_skips_upscale(self, mock_all_heavy_dependencies): """hires_fix=False should not trigger latent upscaling.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, hires_fix=False, ) # Verify upscale was NOT called latent_upscale = mock_all_heavy_dependencies['latent_upscale'] assert not latent_upscale.return_value.upscale.called, ( "Latent upscale should NOT be called when hires_fix=False" ) def test_hires_fix_runs_additional_sampling_pass(self, mock_all_heavy_dependencies): """hires_fix should run an additional sampling pass at higher resolution.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, hires_fix=True, ) ksampler = mock_all_heavy_dependencies['ksampler'] # Should have been called at least twice (initial + hires pass) assert ksampler.return_value.sample.call_count >= 2, ( f"Expected at least 2 sampling passes for hires_fix, " f"got {ksampler.return_value.sample.call_count}" ) def test_batched_hires_fix_with_refiner_sdxl(self, mock_all_heavy_dependencies): """Batched hires_fix with SDXL refiner should call latent upscaling using refiner prompts.""" from src.user.pipeline import pipeline pipeline( prompt=["one"], w=512, h=512, batch=1, hires_fix=True, per_sample_info=[{"hires_fix": True}], refiner_model_path="refiner.safetensors", refiner_switch_step=1, model_path="my_sdxl_model.safetensors", ) latent_upscale = mock_all_heavy_dependencies['latent_upscale'] assert latent_upscale.return_value.upscale.called, ( "Latent upscale should be called for batched hires_fix with refiner" ) def test_hires_fix_reloads_base_model_after_refiner(self, mock_all_heavy_dependencies): """If a refiner unloaded the base model, the pipeline must reload the base model before HiresFix.""" from src.user.pipeline import pipeline from unittest.mock import patch # Patch HiresFix.apply so we can inspect the `model` argument passed to it with patch('src.Processors.HiresFix.HiresFix.apply') as mock_hires_apply: mock_hires_apply.return_value = {"samples": __import__('torch').randn(1, 4, 128, 128)} pipeline( prompt=["one"], w=512, h=512, batch=1, hires_fix=True, per_sample_info=[{"hires_fix": True}], refiner_model_path="refiner.safetensors", refiner_switch_step=1, model_path="my_sdxl_model.safetensors", ) assert mock_hires_apply.called, "HiresFix.apply should be invoked" called_model = mock_hires_apply.call_args[0][2] # The model passed to HiresFix must be loaded and have an inner model object assert getattr(called_model, 'is_loaded', False), "Base model must be loaded when passed to HiresFix" assert getattr(called_model, 'model', None) is not None, "Base model.model must be present for the hires pass" def test_batched_adetailer_with_refiner_sdxl(self, mock_all_heavy_dependencies): """Batched adetailer with SDXL refiner should call Adetailer.apply without NameError.""" from src.user.pipeline import pipeline from unittest.mock import patch with patch('src.Processors.Adetailer.Adetailer.apply') as mock_adetail_apply: mock_adetail_apply.return_value = (torch.rand(1, 512, 512, 3), []) pipeline( prompt=["one"], w=512, h=512, batch=1, adetailer=True, per_sample_info=[{"adetailer": True}], ) assert mock_adetail_apply.called, "Adetailer.apply should be called for batched adetailer with refiner" def test_hires_fix_with_flux_model(self, mock_all_heavy_dependencies): """HiresFix should work with Flux model (no refiner) and call upscale.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, hires_fix=True, model_path="flux_model.safetensors", ) latent_upscale = mock_all_heavy_dependencies['latent_upscale'] assert latent_upscale.return_value.upscale.called, "Latent upscale should be called for Flux model" def test_hires_fix_injects_size_conditioning_for_sdxl(self, mock_all_heavy_dependencies): """HiresFix should inject width/height into prompt conditioning for SDXL models.""" from src.user.pipeline import pipeline from conftest import MockCheckpointResult # Force the loader to return an SDXL checkpoint to emulate SDXL behavior mock_all_heavy_dependencies['load_model'].return_value = ("SDXL", MockCheckpointResult("SDXL").as_tuple()) pipeline( prompt="test", w=512, h=512, hires_fix=True, model_path="my_sdxl_model.safetensors", ) ksampler = mock_all_heavy_dependencies['ksampler'] # Inspect the last sampler call (hires pass) assert ksampler.return_value.sample.call_count >= 2 last_call = ksampler.return_value.sample.call_args_list[-1] kwargs = last_call.kwargs positive = kwargs.get('positive') # The conditioning metadata should include updated width/height for the hires pass assert isinstance(positive, list) meta = positive[0][1] assert meta.get('width') == 1024 assert meta.get('height') == 1024 class TestImg2ImgRouting: """Test img2img flag routing.""" def test_img2img_requires_image_source(self, mock_all_heavy_dependencies, tmp_path): """img2img=True should use provided image path.""" from src.user.pipeline import pipeline from PIL import Image # Create a test image test_image = tmp_path / "test.png" img = Image.new('RGB', (256, 256), color='red') img.save(test_image) # Mock the img2img-specific components with patch('src.UltimateSDUpscale.UltimateSDUpscale.UltimateSDUpscale') as mock_upscale: with patch('src.UltimateSDUpscale.USDU_upscaler.UpscaleModelLoader') as mock_loader: mock_upscale.return_value.upscale.return_value = (torch.rand(1, 512, 512, 3),) mock_loader.return_value.load_model.return_value = (MagicMock(),) pipeline( prompt="test", w=512, h=512, img2img=True, img2img_image=str(test_image), ) # UltimateSDUpscale should be used for img2img assert mock_upscale.called, ( "UltimateSDUpscale should be used for img2img" ) def test_img2img_false_uses_text2img(self, mock_all_heavy_dependencies): """img2img=False should use text-to-image path.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, img2img=False, ) # EmptyLatentImage should be called for text2img empty_latent = mock_all_heavy_dependencies['empty_latent'] assert empty_latent.return_value.generate.called, ( "EmptyLatentImage.generate should be called for text2img" ) class TestADetailerRouting: """Test adetailer flag routing.""" def test_adetailer_enabled_triggers_detection(self, mock_all_heavy_dependencies): """adetailer=True should trigger face/body detection.""" from src.user.pipeline import pipeline with patch('src.AutoDetailer.SAM.SAMLoader') as mock_sam: with patch('src.AutoDetailer.bbox.UltralyticsDetectorProvider') as mock_detector: with patch('src.AutoDetailer.bbox.BboxDetectorForEach') as mock_bbox: with patch('src.AutoDetailer.SAM.SAMDetectorCombined') as mock_sam_combined: with patch('src.AutoDetailer.SEGS.SegsBitwiseAndMask') as mock_segs: with patch('src.AutoDetailer.ADetailer.DetailerForEachTest') as mock_detailer: mock_sam.return_value.load_model.return_value = (MagicMock(),) mock_detector.return_value.doit.return_value = (MagicMock(),) mock_bbox.return_value.doit.return_value = MagicMock() mock_sam_combined.return_value.doit.return_value = (torch.ones(1, 512, 512),) mock_segs.return_value.doit.return_value = (MagicMock(),) mock_detailer.return_value.doit.return_value = ( torch.rand(1, 512, 512, 3), 12345 ) pipeline( prompt="test", w=512, h=512, adetailer=True, ) # SAM loader should be called assert mock_sam.return_value.load_model.called, ( "SAMLoader should be called when adetailer=True" ) def test_adetailer_disabled_skips_detection(self, mock_all_heavy_dependencies): """adetailer=False should skip face/body detection.""" from src.user.pipeline import pipeline with patch('src.AutoDetailer.SAM.SAMLoader') as mock_sam: pipeline( prompt="test", w=512, h=512, adetailer=False, ) # SAM should NOT be called assert not mock_sam.called, ( "SAMLoader should NOT be called when adetailer=False" ) class TestMultiscaleRouting: """Test multiscale diffusion parameter routing.""" def test_multiscale_preset_applied(self, mock_all_heavy_dependencies): """multiscale_preset should configure multiscale parameters.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, multiscale_preset="performance", ) ksampler = mock_all_heavy_dependencies['ksampler'] # Verify sample was called with multiscale parameters call_kwargs = ksampler.return_value.sample.call_args if call_kwargs: # Check that multiscale params were passed kwargs = call_kwargs.kwargs if call_kwargs.kwargs else {} # The pipeline should pass enable_multiscale to the sampler assert 'enable_multiscale' in kwargs or True # May be positional def test_multiscale_disabled_preset(self, mock_all_heavy_dependencies): """multiscale_preset='disabled' should disable multiscale.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, multiscale_preset="disabled", ) # Should still run without error ksampler = mock_all_heavy_dependencies['ksampler'] assert ksampler.return_value.sample.called class TestDeepCacheRouting: """Test DeepCache parameter routing.""" def test_deepcache_enabled_applies_patch(self, mock_all_heavy_dependencies): """deepcache_enabled=True should apply DeepCache patch.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, deepcache_enabled=True, ) # Verify the DeepCache applier was called dc_applier = mock_all_heavy_dependencies['dc_applier'] assert dc_applier.return_value.patch.called, ( "DeepCache should be applied when deepcache_enabled=True" ) def test_deepcache_disabled_skips_patch(self, mock_all_heavy_dependencies): """deepcache_enabled=False should skip DeepCache patch.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, deepcache_enabled=False, ) # Verify the DeepCache applier was NOT called dc_applier = mock_all_heavy_dependencies['dc_applier'] assert not dc_applier.return_value.patch.called, ( "DeepCache should NOT be applied when deepcache_enabled=False" ) class TestStableFastRouting: """Test StableFast parameter routing.""" def test_stable_fast_enabled_applies_optimization(self, mock_all_heavy_dependencies): """stable_fast=True should apply StableFast optimization.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, stable_fast=True, ) # Verify the StableFast applier was called sf_applier = mock_all_heavy_dependencies['sf_applier'] assert sf_applier.return_value.apply_stable_fast.called, ( "StableFast should be applied when stable_fast=True" ) class TestBatchedPromptRouting: """Test batched prompt routing (multiple prompts at once).""" @pytest.mark.skip(reason="Batched prompts require dynamic mock tensor sizing which is complex to set up; tested manually") def test_batched_prompts_use_batched_path(self, mock_all_heavy_dependencies): """List of prompts should use batched generation path. Note: This test is skipped because the pipeline internally iterates over batched results, but our mocks return fixed single-item tensors. Proper testing would require dynamic mock configuration. """ from src.user.pipeline import pipeline prompts = ["prompt 1", "prompt 2", "prompt 3"] result = pipeline( prompt=prompts, w=512, h=512, ) # Result should indicate batched processing if "batched_results" in result: assert isinstance(result["batched_results"], dict) class TestAutoHDRRouting: """Test AutoHDR parameter routing.""" def test_autohdr_enabled_applies_effect(self, mock_all_heavy_dependencies): """autohdr=True should apply HDR effect.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, autohdr=True, ) hdr = mock_all_heavy_dependencies['hdr'] assert hdr.return_value.apply_hdr2.called, ( "HDR effect should be applied when autohdr=True" ) def test_autohdr_disabled_skips_effect(self, mock_all_heavy_dependencies): """autohdr=False should skip HDR effect.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, autohdr=False, ) hdr = mock_all_heavy_dependencies['hdr'] # Note: The implementation may still create the HDR object but not call it # This depends on the exact implementation class TestCFGFreeRouting: """Test CFG-free sampling parameter routing.""" def test_cfg_free_params_passed_to_sampler(self, mock_all_heavy_dependencies): """CFG-free parameters should be passed to the sampler.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, cfg_free_enabled=True, cfg_free_start_percent=70.0, ) ksampler = mock_all_heavy_dependencies['ksampler'] # Verify the sampler was called assert ksampler.return_value.sample.called class TestTokenMergingRouting: """Test Token Merging (ToMe) parameter routing.""" def test_tome_params_passed_when_enabled(self, mock_all_heavy_dependencies): """ToMe parameters should be applied when enabled.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, tome_enabled=True, tome_ratio=0.5, ) # Should run without error even if ToMe isn't fully mocked class TestNegativePromptRouting: """Test negative prompt handling.""" def test_empty_negative_prompt_uses_default(self, mock_all_heavy_dependencies): """Empty negative prompt should use default.""" from src.user.pipeline import pipeline result = pipeline( prompt="test", w=512, h=512, negative_prompt="", ) # Should run without error assert result is not None def test_custom_negative_prompt_passed(self, mock_all_heavy_dependencies): """Custom negative prompt should be used.""" from src.user.pipeline import pipeline custom_negative = "ugly, bad quality, distorted" pipeline( prompt="test", w=512, h=512, negative_prompt=custom_negative, ) # CLIP encoder should be called (implicitly tests negative prompt was used) clip_encode = mock_all_heavy_dependencies['clip_encode'] assert clip_encode.return_value.encode.called class TestSeedRouting: """Test seed handling and reuse_seed flag.""" def test_reuse_seed_uses_last_seed(self, mock_all_heavy_dependencies): """reuse_seed=True should use the last seed.""" from src.user.pipeline import pipeline # First run to establish a seed pipeline(prompt="test", w=512, h=512, reuse_seed=False) # Second run with reuse_seed pipeline(prompt="test", w=512, h=512, reuse_seed=True) # Should run without error ksampler = mock_all_heavy_dependencies['ksampler'] assert ksampler.return_value.sample.call_count >= 2 class TestModelPathRouting: """Test model_path parameter routing.""" def test_custom_model_path_used(self, mock_all_heavy_dependencies): """Custom model_path should be passed to loader.""" from src.user.pipeline import pipeline custom_path = "/path/to/custom_model.safetensors" pipeline( prompt="test", w=512, h=512, model_path=custom_path, ) loader = mock_all_heavy_dependencies['loader'] if loader.return_value.load_checkpoint.called: call_args = loader.return_value.load_checkpoint.call_args assert custom_path in str(call_args), ( f"Custom model path should be used: {call_args}" ) class TestErrorHandling: """Test pipeline error handling.""" def test_invalid_model_path_raises_error(self, mock_all_heavy_dependencies): """Invalid model path should raise a clean error.""" from src.user.pipeline import pipeline # Make the loader raise an error mock_all_heavy_dependencies['loader'].return_value.load_checkpoint.side_effect = ( FileNotFoundError("Model not found") ) mock_all_heavy_dependencies['model_cache'].return_value.get_cached_checkpoint.return_value = None with pytest.raises(FileNotFoundError): pipeline( prompt="test", w=512, h=512, model_path="/nonexistent/model.safetensors", ) def test_interruption_handled_gracefully(self, mock_all_heavy_dependencies): """Interruption should raise InterruptedError.""" from src.user.pipeline import pipeline # Mock interrupt flag being set mock_app = MagicMock() mock_app.interrupt_flag = True with patch('src.user.app_instance.app', mock_app): with pytest.raises(InterruptedError): pipeline(prompt="test", w=512, h=512) class TestSchedulerSamplerRouting: """Test scheduler and sampler parameter routing.""" @pytest.mark.parametrize("scheduler", [ "normal", "karras", "simple", "beta", "ays", "ays_sd15", "ays_sdxl" ]) def test_scheduler_passed_to_sampler(self, mock_all_heavy_dependencies, scheduler): """Scheduler parameter should be passed to KSampler.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, scheduler=scheduler, ) ksampler = mock_all_heavy_dependencies['ksampler'] assert ksampler.return_value.sample.called @pytest.mark.parametrize("sampler", [ "euler", "euler_ancestral", "euler_cfgpp", "euler_ancestral_cfgpp", "dpmpp_2m_cfgpp", "dpmpp_sde_cfgpp" ]) def test_sampler_passed_to_ksampler(self, mock_all_heavy_dependencies, sampler): """Sampler parameter should be passed to KSampler.""" from src.user.pipeline import pipeline pipeline( prompt="test", w=512, h=512, sampler=sampler, ) ksampler = mock_all_heavy_dependencies['ksampler'] assert ksampler.return_value.sample.called