Spaces:
Running on Zero
Running on Zero
| """ | |
| Integration tests for pipeline routing logic. | |
| Tests that the pipeline correctly routes execution based on flags like | |
| hires_fix, img2img, adetailer, etc. All model loading | |
| is mocked to avoid loading real weights. | |
| """ | |
| import os | |
| import sys | |
| import pytest | |
| import torch | |
| from pathlib import Path | |
| from unittest.mock import patch, MagicMock, call, ANY | |
| from typing import Tuple | |
# Put the directory three levels above this file on the import path so the
# ``src.*`` imports used by the tests below can be resolved.
project_root = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(project_root))

# Tag every test collected from this module with the ``slow`` marker.
pytestmark = pytest.mark.slow
| # ============================================================================= | |
| # Test Fixtures | |
| # ============================================================================= | |
@pytest.fixture
def mock_all_heavy_dependencies(request):
    """Patch the heavy pipeline dependencies and yield the started mocks.

    Every target listed below is replaced for the duration of one test, and
    default return values are configured on the mocks so the tests can check
    which mocked components were invoked. ``app_instance`` is replaced with a
    preconfigured mock whose ``interrupt_flag`` is ``False``.

    Args:
        request: pytest's built-in ``request`` object, used to register the
            teardown finalizer.

    Yields:
        dict: Maps a short name (for example ``'ksampler'`` or
        ``'latent_upscale'``) to the object returned by starting the
        corresponding patch.
    """

    def _clear_model_cache():
        """Best-effort reset of the global model cache."""
        try:
            from src.Device.ModelCache import get_model_cache
            get_model_cache().clear_cache()
        except Exception:
            # Deliberately best-effort: a missing or failing cache must not
            # prevent the tests from running.
            pass

    patches = {}
    # Mock model loading
    patches['loader'] = patch('src.FileManaging.Loader.CheckpointLoaderSimple')
    patches['model_cache'] = patch('src.Device.ModelCache.get_model_cache')
    patches['load_model'] = patch('src.user.model_loader.load_model_for_pipeline')
    # Mock CLIP operations
    patches['clip_encode'] = patch('src.clip.Clip.CLIPTextEncode')
    patches['clip_set_layer'] = patch('src.clip.Clip.CLIPSetLastLayer')
    # Mock VAE operations
    patches['vae_decode'] = patch('src.AutoEncoders.VariationalAE.VAEDecode')
    patches['vae_loader'] = patch('src.AutoEncoders.VariationalAE.VAELoader')
    # Mock Latent operations
    patches['empty_latent'] = patch('src.Utilities.Latent.EmptyLatentImage')
    patches['latent_upscale'] = patch('src.Utilities.upscale.LatentUpscale')
    # Mock Sampler
    patches['ksampler'] = patch('src.sample.sampling.KSampler')
    # Mock Image operations
    patches['save_image'] = patch('src.FileManaging.ImageSaver.SaveImage')
    # Mock LoRA
    patches['lora_loader'] = patch('src.Model.LoRas.LoraLoader')
    # Mock optimizations to ensure they return the model (allowing .called checks)
    patches['sf_applier'] = patch('src.StableFast.StableFast.ApplyStableFastUnet')
    patches['dc_applier'] = patch('src.WaveSpeed.deepcache_nodes.ApplyDeepCacheOnModel')
    # Mock HiDiffusion
    patches['hidiff'] = patch('src.hidiffusion.msw_msa_attention.ApplyMSWMSAAttentionSimple')
    # Mock HDR
    patches['hdr'] = patch('src.AutoHDR.ahdr.HDREffects')
    # Mock Downloader to avoid network calls
    patches['downloader'] = patch('src.FileManaging.Downloader.CheckAndDownload')
    # Mock app_instance - explicitly set interrupt_flag to False
    mock_app = MagicMock()
    mock_app.interrupt_flag = False
    patches['app_instance'] = patch('src.user.app_instance.app', mock_app)

    # Clear the global model cache BEFORE patching: once ``get_model_cache``
    # is replaced by a mock, calling ``clear_cache()`` on it clears nothing.
    _clear_model_cache()

    def teardown():
        # Stop all patches in reverse order; a patch that never started (or
        # was already stopped) must not abort the rest of the cleanup.
        for p in reversed(list(patches.values())):
            try:
                p.stop()
            except Exception:
                pass
        patch.stopall()
        # The patches are gone at this point, so this reaches the real cache
        # and keeps fake checkpoints from leaking into following tests.
        _clear_model_cache()

    # Register cleanup before starting anything so that a failure while
    # starting a patch (or while configuring the mocks) still undoes the
    # patches that were already active.
    request.addfinalizer(teardown)

    # Start all patches
    mocks = {}
    for name, p in patches.items():
        mocks[name] = p.start()

    # Configure default return values
    from conftest import MockModelPatcher
    mock_model_patcher = MockModelPatcher()
    mock_clip = MagicMock()
    mock_vae = MagicMock()
    mocks['loader'].return_value.load_checkpoint.return_value = (
        mock_model_patcher, mock_clip, mock_vae
    )
    mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None
    mocks['load_model'].return_value = ("SD15", (mock_model_patcher, mock_clip, mock_vae))
    # Mock CLIP encoding
    mock_cond = [[torch.randn(1, 77, 768), {}]]
    mocks['clip_encode'].return_value.encode.return_value = (mock_cond,)
    mocks['clip_set_layer'].return_value.set_last_layer.return_value = (mock_clip,)
    # Mock VAE decoding
    mocks['vae_decode'].return_value.decode.return_value = (torch.rand(1, 512, 512, 3),)
    # Mock latent generation
    mock_latent = {"samples": torch.randn(1, 4, 64, 64)}
    mocks['empty_latent'].return_value.generate.return_value = (mock_latent,)
    mocks['latent_upscale'].return_value.upscale.return_value = ({"samples": torch.randn(1, 4, 128, 128)},)
    # Mock sampler
    mocks['ksampler'].return_value.sample.return_value = ({"samples": torch.randn(1, 4, 64, 64)},)
    # Mock LoRA loader
    mocks['lora_loader'].return_value.load_lora.return_value = (
        mock_model_patcher, mock_clip, mock_vae
    )
    # Mock HiDiffusion
    mocks['hidiff'].return_value.go.return_value = (mock_model_patcher,)
    # Mock HDR
    mocks['hdr'].return_value.apply_hdr2.return_value = (torch.rand(1, 512, 512, 3),)
    # Mock image saver
    mocks['save_image'].return_value.save_images.return_value = {"ui": {"images": []}}
    mocks['save_image'].return_value.save_images_async = MagicMock()
    # Configure optimization appliers to return the mock model in a tuple
    mocks['sf_applier'].return_value.apply_stable_fast.return_value = (mock_model_patcher,)
    mocks['dc_applier'].return_value.patch.return_value = (mock_model_patcher,)

    yield mocks
| # ============================================================================= | |
| # Pipeline Flag Routing Tests | |
| # ============================================================================= | |
class TestPipelineBasicRouting:
    """Smoke tests for the default pipeline call with all heavy parts mocked."""

    def test_pipeline_runs_without_exception(self, mock_all_heavy_dependencies):
        """A fully mocked call completes and yields something other than None."""
        from src.user.pipeline import pipeline

        call_kwargs = dict(prompt="a test prompt", w=512, h=512, number=1, batch=1)
        # Any exception raised here fails the test.
        outcome = pipeline(**call_kwargs)

        assert outcome is not None

    def test_pipeline_returns_result_dict(self, mock_all_heavy_dependencies):
        """The return value is a dict carrying one of the expected keys."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="a test prompt", w=512, h=512)

        assert isinstance(outcome, dict), f"Expected dict, got {type(outcome)}"
        expected_keys = ("original_prompt", "batched_results")
        assert any(key in outcome for key in expected_keys)
class TestHiresFixRouting:
    """Test hires_fix flag routing.

    Each test calls the mocked pipeline and then inspects the mocks yielded
    by ``mock_all_heavy_dependencies`` to see which components were invoked.
    """

    def test_hires_fix_triggers_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=True should trigger latent upscaling."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )

        # Verify upscale was called
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called when hires_fix=True"
        )

    def test_hires_fix_false_skips_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=False should not trigger latent upscaling."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=False,
        )

        # Verify upscale was NOT called
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert not latent_upscale.return_value.upscale.called, (
            "Latent upscale should NOT be called when hires_fix=False"
        )

    def test_hires_fix_runs_additional_sampling_pass(self, mock_all_heavy_dependencies):
        """hires_fix should run an additional sampling pass at higher resolution."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )

        ksampler = mock_all_heavy_dependencies['ksampler']
        # Should have been called at least twice (initial + hires pass)
        assert ksampler.return_value.sample.call_count >= 2, (
            f"Expected at least 2 sampling passes for hires_fix, "
            f"got {ksampler.return_value.sample.call_count}"
        )

    def test_batched_hires_fix_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched hires_fix with SDXL refiner should call latent upscaling using refiner prompts."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt=["one"],
            w=512,
            h=512,
            batch=1,
            hires_fix=True,
            per_sample_info=[{"hires_fix": True}],
            refiner_model_path="refiner.safetensors",
            refiner_switch_step=1,
            model_path="my_sdxl_model.safetensors",
        )

        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called for batched hires_fix with refiner"
        )

    def test_hires_fix_reloads_base_model_after_refiner(self, mock_all_heavy_dependencies):
        """If a refiner unloaded the base model, the pipeline must reload the base model before HiresFix."""
        from src.user.pipeline import pipeline

        # Patch HiresFix.apply so we can inspect the `model` argument passed to it
        with patch('src.Processors.HiresFix.HiresFix.apply') as mock_hires_apply:
            mock_hires_apply.return_value = {"samples": torch.randn(1, 4, 128, 128)}
            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                hires_fix=True,
                per_sample_info=[{"hires_fix": True}],
                refiner_model_path="refiner.safetensors",
                refiner_switch_step=1,
                model_path="my_sdxl_model.safetensors",
            )

            assert mock_hires_apply.called, "HiresFix.apply should be invoked"
            # The model is read from the third positional argument of the call.
            called_model = mock_hires_apply.call_args[0][2]
            # The model passed to HiresFix must be loaded and have an inner model object
            assert getattr(called_model, 'is_loaded', False), "Base model must be loaded when passed to HiresFix"
            assert getattr(called_model, 'model', None) is not None, "Base model.model must be present for the hires pass"

    def test_batched_adetailer_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched adetailer with SDXL refiner should call Adetailer.apply without NameError."""
        from src.user.pipeline import pipeline

        # NOTE(review): despite the test name, no refiner_model_path /
        # refiner_switch_step is passed here, so the refiner path does not
        # appear to be exercised - confirm whether that is intended.
        with patch('src.Processors.Adetailer.Adetailer.apply') as mock_adetail_apply:
            mock_adetail_apply.return_value = (torch.rand(1, 512, 512, 3), [])
            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                adetailer=True,
                per_sample_info=[{"adetailer": True}],
            )

            assert mock_adetail_apply.called, "Adetailer.apply should be called for batched adetailer with refiner"

    def test_hires_fix_with_flux_model(self, mock_all_heavy_dependencies):
        """HiresFix should work with Flux model (no refiner) and call upscale."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="flux_model.safetensors",
        )

        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, "Latent upscale should be called for Flux model"

    def test_hires_fix_injects_size_conditioning_for_sdxl(self, mock_all_heavy_dependencies):
        """HiresFix should inject width/height into prompt conditioning for SDXL models."""
        from src.user.pipeline import pipeline
        from conftest import MockCheckpointResult

        # Force the loader to return an SDXL checkpoint to emulate SDXL behavior
        mock_all_heavy_dependencies['load_model'].return_value = ("SDXL", MockCheckpointResult("SDXL").as_tuple())

        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="my_sdxl_model.safetensors",
        )

        ksampler = mock_all_heavy_dependencies['ksampler']
        # Inspect the last sampler call (hires pass)
        assert ksampler.return_value.sample.call_count >= 2
        last_call = ksampler.return_value.sample.call_args_list[-1]
        kwargs = last_call.kwargs
        positive = kwargs.get('positive')
        # The conditioning metadata should include updated width/height for the hires pass
        assert isinstance(positive, list)
        meta = positive[0][1]
        assert meta.get('width') == 1024
        assert meta.get('height') == 1024
class TestImg2ImgRouting:
    """Checks which generation path the ``img2img`` flag selects."""

    def test_img2img_requires_image_source(self, mock_all_heavy_dependencies, tmp_path):
        """With img2img=True and an image path, UltimateSDUpscale is used."""
        from src.user.pipeline import pipeline
        from PIL import Image

        # Write a small solid-colour PNG to act as the input image.
        source_png = tmp_path / "test.png"
        Image.new('RGB', (256, 256), color='red').save(source_png)

        # Replace the img2img-specific components for the duration of the call.
        with patch('src.UltimateSDUpscale.UltimateSDUpscale.UltimateSDUpscale') as mock_upscale, \
                patch('src.UltimateSDUpscale.USDU_upscaler.UpscaleModelLoader') as mock_loader:
            upscaled = (torch.rand(1, 512, 512, 3),)
            mock_upscale.return_value.upscale.return_value = upscaled
            mock_loader.return_value.load_model.return_value = (MagicMock(),)

            pipeline(
                prompt="test",
                w=512,
                h=512,
                img2img=True,
                img2img_image=str(source_png),
            )

            # The upscaler class must have been instantiated.
            assert mock_upscale.called, (
                "UltimateSDUpscale should be used for img2img"
            )

    def test_img2img_false_uses_text2img(self, mock_all_heavy_dependencies):
        """With img2img=False an empty latent is generated (text-to-image)."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, img2img=False)

        generate = mock_all_heavy_dependencies['empty_latent'].return_value.generate
        assert generate.called, (
            "EmptyLatentImage.generate should be called for text2img"
        )
class TestADetailerRouting:
    """Checks that the ``adetailer`` flag gates the detection components."""

    def test_adetailer_enabled_triggers_detection(self, mock_all_heavy_dependencies):
        """With adetailer=True the SAM loader's ``load_model`` is invoked."""
        from src.user.pipeline import pipeline

        # All six detailer components are patched in a single ``with``; they
        # are entered in the order listed and released in reverse.
        with patch('src.AutoDetailer.SAM.SAMLoader') as mock_sam, \
                patch('src.AutoDetailer.bbox.UltralyticsDetectorProvider') as mock_detector, \
                patch('src.AutoDetailer.bbox.BboxDetectorForEach') as mock_bbox, \
                patch('src.AutoDetailer.SAM.SAMDetectorCombined') as mock_sam_combined, \
                patch('src.AutoDetailer.SEGS.SegsBitwiseAndMask') as mock_segs, \
                patch('src.AutoDetailer.ADetailer.DetailerForEachTest') as mock_detailer:
            mock_sam.return_value.load_model.return_value = (MagicMock(),)
            mock_detector.return_value.doit.return_value = (MagicMock(),)
            mock_bbox.return_value.doit.return_value = MagicMock()
            mock_sam_combined.return_value.doit.return_value = (torch.ones(1, 512, 512),)
            mock_segs.return_value.doit.return_value = (MagicMock(),)
            detailer_output = (torch.rand(1, 512, 512, 3), 12345)
            mock_detailer.return_value.doit.return_value = detailer_output

            pipeline(prompt="test", w=512, h=512, adetailer=True)

            load_model = mock_sam.return_value.load_model
            assert load_model.called, (
                "SAMLoader should be called when adetailer=True"
            )

    def test_adetailer_disabled_skips_detection(self, mock_all_heavy_dependencies):
        """With adetailer=False the SAM loader is never instantiated."""
        from src.user.pipeline import pipeline

        with patch('src.AutoDetailer.SAM.SAMLoader') as mock_sam:
            pipeline(prompt="test", w=512, h=512, adetailer=False)

            sam_was_used = mock_sam.called
            assert not sam_was_used, (
                "SAMLoader should NOT be called when adetailer=False"
            )
class TestMultiscaleRouting:
    """Test multiscale diffusion parameter routing."""

    def test_multiscale_preset_applied(self, mock_all_heavy_dependencies):
        """Selecting a multiscale preset must still reach the sampler."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="performance",
        )

        ksampler = mock_all_heavy_dependencies['ksampler']
        # Fail loudly if sampling never happened instead of passing silently.
        assert ksampler.return_value.sample.called, (
            "Sampler should be called when a multiscale preset is selected"
        )
        # NOTE(review): how the multiscale settings reach ``sample`` (keyword
        # vs positional) is not visible from this file, so the presence of
        # ``enable_multiscale`` is not asserted here - tighten this once the
        # sampler call signature is confirmed.

    def test_multiscale_disabled_preset(self, mock_all_heavy_dependencies):
        """multiscale_preset='disabled' should disable multiscale."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="disabled",
        )

        # Should still run without error
        ksampler = mock_all_heavy_dependencies['ksampler']
        assert ksampler.return_value.sample.called
class TestDeepCacheRouting:
    """Checks that ``deepcache_enabled`` gates the DeepCache applier."""

    def test_deepcache_enabled_applies_patch(self, mock_all_heavy_dependencies):
        """With deepcache_enabled=True the applier's ``patch`` is invoked."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=True)

        apply_patch = mock_all_heavy_dependencies['dc_applier'].return_value.patch
        assert apply_patch.called, (
            "DeepCache should be applied when deepcache_enabled=True"
        )

    def test_deepcache_disabled_skips_patch(self, mock_all_heavy_dependencies):
        """With deepcache_enabled=False the applier's ``patch`` is not invoked."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=False)

        apply_patch = mock_all_heavy_dependencies['dc_applier'].return_value.patch
        assert not apply_patch.called, (
            "DeepCache should NOT be applied when deepcache_enabled=False"
        )
class TestStableFastRouting:
    """Checks that ``stable_fast`` routes through the StableFast applier."""

    def test_stable_fast_enabled_applies_optimization(self, mock_all_heavy_dependencies):
        """With stable_fast=True the applier's ``apply_stable_fast`` is invoked."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, stable_fast=True)

        applier = mock_all_heavy_dependencies['sf_applier'].return_value
        assert applier.apply_stable_fast.called, (
            "StableFast should be applied when stable_fast=True"
        )
class TestBatchedPromptRouting:
    """Test batched prompt routing (multiple prompts at once)."""

    @pytest.mark.skip(
        reason=(
            "Pipeline iterates over batched results, but the mocks return "
            "fixed single-item tensors; needs dynamic mock configuration."
        )
    )
    def test_batched_prompts_use_batched_path(self, mock_all_heavy_dependencies):
        """List of prompts should use batched generation path.

        Note: This test is skipped because the pipeline internally iterates
        over batched results, but our mocks return fixed single-item tensors.
        Proper testing would require dynamic mock configuration.
        """
        from src.user.pipeline import pipeline

        prompts = ["prompt 1", "prompt 2", "prompt 3"]
        result = pipeline(
            prompt=prompts,
            w=512,
            h=512,
        )

        # Result should indicate batched processing
        if "batched_results" in result:
            assert isinstance(result["batched_results"], dict)
class TestAutoHDRRouting:
    """Test AutoHDR parameter routing."""

    def test_autohdr_enabled_applies_effect(self, mock_all_heavy_dependencies):
        """autohdr=True should apply HDR effect."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=True,
        )

        hdr = mock_all_heavy_dependencies['hdr']
        assert hdr.return_value.apply_hdr2.called, (
            "HDR effect should be applied when autohdr=True"
        )

    def test_autohdr_disabled_skips_effect(self, mock_all_heavy_dependencies):
        """autohdr=False should skip HDR effect."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=False,
        )

        hdr = mock_all_heavy_dependencies['hdr']
        # The implementation may still create the HDR object, so only the
        # effect call itself is checked, not the constructor.
        assert not hdr.return_value.apply_hdr2.called, (
            "HDR effect should NOT be applied when autohdr=False"
        )
class TestCFGFreeRouting:
    """Runs the pipeline with the CFG-free sampling options switched on."""

    def test_cfg_free_params_passed_to_sampler(self, mock_all_heavy_dependencies):
        """With CFG-free options set, the sampler is still invoked."""
        from src.user.pipeline import pipeline

        cfg_free_options = dict(cfg_free_enabled=True, cfg_free_start_percent=70.0)
        pipeline(prompt="test", w=512, h=512, **cfg_free_options)

        # Only the fact that sampling happened is checked here.
        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.called
class TestTokenMergingRouting:
    """Runs the pipeline with the Token Merging (ToMe) options switched on."""

    def test_tome_params_passed_when_enabled(self, mock_all_heavy_dependencies):
        """A call with ToMe options completes without raising."""
        from src.user.pipeline import pipeline

        tome_options = dict(tome_enabled=True, tome_ratio=0.5)
        # Smoke test only: there is no assertion, so the test passes as long
        # as the call returns, even though ToMe itself is not fully mocked.
        pipeline(prompt="test", w=512, h=512, **tome_options)
class TestNegativePromptRouting:
    """Runs the pipeline with empty and custom negative prompts."""

    def test_empty_negative_prompt_uses_default(self, mock_all_heavy_dependencies):
        """An empty negative prompt still produces a non-None result."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="test", w=512, h=512, negative_prompt="")

        # Completing without an exception is the main expectation.
        assert outcome is not None

    def test_custom_negative_prompt_passed(self, mock_all_heavy_dependencies):
        """With a custom negative prompt the CLIP encoder is invoked."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            negative_prompt="ugly, bad quality, distorted",
        )

        # Indirect check: only the encoder being called is verified.
        encode = mock_all_heavy_dependencies['clip_encode'].return_value.encode
        assert encode.called
class TestSeedRouting:
    """Runs the pipeline twice to exercise the ``reuse_seed`` flag."""

    def test_reuse_seed_uses_last_seed(self, mock_all_heavy_dependencies):
        """Two consecutive runs (fresh seed, then reuse) both reach the sampler."""
        from src.user.pipeline import pipeline

        # The first run uses reuse_seed=False; the second asks for reuse.
        for reuse in (False, True):
            pipeline(prompt="test", w=512, h=512, reuse_seed=reuse)

        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.call_count >= 2
class TestModelPathRouting:
    """Checks how the ``model_path`` argument reaches the checkpoint loader."""

    def test_custom_model_path_used(self, mock_all_heavy_dependencies):
        """If the checkpoint loader is called, its arguments mention the custom path."""
        from src.user.pipeline import pipeline

        custom_path = "/path/to/custom_model.safetensors"
        pipeline(prompt="test", w=512, h=512, model_path=custom_path)

        load_checkpoint = mock_all_heavy_dependencies['loader'].return_value.load_checkpoint
        if not load_checkpoint.called:
            # Nothing to verify when the loader was never reached.
            return

        call_args = load_checkpoint.call_args
        assert custom_path in str(call_args), (
            f"Custom model path should be used: {call_args}"
        )
class TestErrorHandling:
    """Checks which exceptions the pipeline lets propagate."""

    def test_invalid_model_path_raises_error(self, mock_all_heavy_dependencies):
        """A FileNotFoundError from the checkpoint loader propagates to the caller."""
        from src.user.pipeline import pipeline

        mocks = mock_all_heavy_dependencies
        # Make the loader fail and make sure the cache reports a miss.
        load_checkpoint = mocks['loader'].return_value.load_checkpoint
        load_checkpoint.side_effect = FileNotFoundError("Model not found")
        mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None

        with pytest.raises(FileNotFoundError):
            pipeline(
                prompt="test",
                w=512,
                h=512,
                model_path="/nonexistent/model.safetensors",
            )

    def test_interruption_handled_gracefully(self, mock_all_heavy_dependencies):
        """With the app's interrupt flag set, the pipeline raises InterruptedError."""
        from src.user.pipeline import pipeline

        # Swap in an app object whose interrupt flag is already raised.
        interrupted_app = MagicMock()
        interrupted_app.interrupt_flag = True

        with patch('src.user.app_instance.app', interrupted_app), \
                pytest.raises(InterruptedError):
            pipeline(prompt="test", w=512, h=512)
class TestSchedulerSamplerRouting:
    """Runs the pipeline with explicit ``scheduler`` / ``sampler`` choices.

    NOTE(review): the ``scheduler`` and ``sampler`` arguments are not defined
    anywhere in the visible part of this file; presumably they come from a
    parametrization or a fixture declared elsewhere - confirm.
    """

    def test_scheduler_passed_to_sampler(self, mock_all_heavy_dependencies, scheduler):
        """With a scheduler supplied, the sampler is invoked."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, scheduler=scheduler)

        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.called

    def test_sampler_passed_to_ksampler(self, mock_all_heavy_dependencies, sampler):
        """With a sampler supplied, the sampler mock is invoked."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, sampler=sampler)

        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.called