Spaces:
Running on Zero
Running on Zero
File size: 30,103 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 | """
Integration tests for pipeline routing logic.
Tests that the pipeline correctly routes execution based on flags like
hires_fix, img2img, adetailer, etc. All model loading
is mocked to avoid loading real weights.
"""
import os
import sys
import pytest
import torch
from pathlib import Path
from unittest.mock import patch, MagicMock, call, ANY
from typing import Tuple
# --- Import-path bootstrap ---------------------------------------------------
# The `src` package lives at the repository root, three directory levels above
# this file; put that root first on sys.path so `src.*` imports resolve.
project_root = Path(__file__).resolve()
for _ in range(3):
    project_root = project_root.parent
sys.path.insert(0, str(project_root))

# Tag every test collected from this module as slow.
pytestmark = pytest.mark.slow
# =============================================================================
# Test Fixtures
# =============================================================================
def _clear_model_cache():
    """Best-effort reset of the global checkpoint cache in ``src.Device.ModelCache``.

    Failures are deliberately ignored: cache hygiene must never be the reason a
    test errors out (for example when the cache module cannot be imported).
    """
    try:
        from src.Device.ModelCache import get_model_cache
        get_model_cache().clear_cache()
    except Exception:
        pass


def _build_heavy_patchers():
    """Create (but do not start) one patcher per heavy dependency.

    Returns:
        dict mapping a short name to an unstarted ``unittest.mock.patch``
        object. Insertion order is the start order; teardown stops them in
        reverse order.
    """
    # The app singleton is replaced by an explicit MagicMock so that
    # interrupt_flag is a real False instead of an auto-created (truthy)
    # MagicMock attribute.
    mock_app = MagicMock()
    mock_app.interrupt_flag = False
    return {
        # Model loading
        'loader': patch('src.FileManaging.Loader.CheckpointLoaderSimple'),
        'model_cache': patch('src.Device.ModelCache.get_model_cache'),
        'load_model': patch('src.user.model_loader.load_model_for_pipeline'),
        # CLIP operations
        'clip_encode': patch('src.clip.Clip.CLIPTextEncode'),
        'clip_set_layer': patch('src.clip.Clip.CLIPSetLastLayer'),
        # VAE operations
        'vae_decode': patch('src.AutoEncoders.VariationalAE.VAEDecode'),
        'vae_loader': patch('src.AutoEncoders.VariationalAE.VAELoader'),
        # Latent operations
        'empty_latent': patch('src.Utilities.Latent.EmptyLatentImage'),
        'latent_upscale': patch('src.Utilities.upscale.LatentUpscale'),
        # Sampler
        'ksampler': patch('src.sample.sampling.KSampler'),
        # Image saving
        'save_image': patch('src.FileManaging.ImageSaver.SaveImage'),
        # LoRA
        'lora_loader': patch('src.Model.LoRas.LoraLoader'),
        # Optimization appliers (configured below to hand back the model so
        # tests can assert on `.called`)
        'sf_applier': patch('src.StableFast.StableFast.ApplyStableFastUnet'),
        'dc_applier': patch('src.WaveSpeed.deepcache_nodes.ApplyDeepCacheOnModel'),
        # HiDiffusion
        'hidiff': patch('src.hidiffusion.msw_msa_attention.ApplyMSWMSAAttentionSimple'),
        # HDR
        'hdr': patch('src.AutoHDR.ahdr.HDREffects'),
        # Downloader, patched to avoid network calls
        'downloader': patch('src.FileManaging.Downloader.CheckAndDownload'),
        # Application singleton
        'app_instance': patch('src.user.app_instance.app', mock_app),
    }


def _configure_heavy_mocks(mocks):
    """Give the started mocks default return values used by the pipeline tests.

    Args:
        mocks: dict of started mocks keyed like ``_build_heavy_patchers()``.
    """
    from conftest import MockModelPatcher

    mock_model_patcher = MockModelPatcher()
    mock_clip = MagicMock()
    mock_vae = MagicMock()
    model_triple = (mock_model_patcher, mock_clip, mock_vae)

    # Checkpoint loading; the cache always reports a miss (returns None).
    mocks['loader'].return_value.load_checkpoint.return_value = model_triple
    mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None
    mocks['load_model'].return_value = ("SD15", model_triple)

    # CLIP: one conditioning entry of [tensor of shape (1, 77, 768), metadata dict].
    mock_cond = [[torch.randn(1, 77, 768), {}]]
    mocks['clip_encode'].return_value.encode.return_value = (mock_cond,)
    mocks['clip_set_layer'].return_value.set_last_layer.return_value = (mock_clip,)

    # VAE decode returns a (1, 512, 512, 3) image tensor.
    mocks['vae_decode'].return_value.decode.return_value = (torch.rand(1, 512, 512, 3),)

    # Latents: generation yields (1, 4, 64, 64); upscaling yields (1, 4, 128, 128).
    mocks['empty_latent'].return_value.generate.return_value = (
        {"samples": torch.randn(1, 4, 64, 64)},
    )
    mocks['latent_upscale'].return_value.upscale.return_value = (
        {"samples": torch.randn(1, 4, 128, 128)},
    )

    # Sampler
    mocks['ksampler'].return_value.sample.return_value = (
        {"samples": torch.randn(1, 4, 64, 64)},
    )

    # LoRA loading hands back the same model / clip / vae triple.
    mocks['lora_loader'].return_value.load_lora.return_value = model_triple

    # HiDiffusion
    mocks['hidiff'].return_value.go.return_value = (mock_model_patcher,)

    # HDR
    mocks['hdr'].return_value.apply_hdr2.return_value = (torch.rand(1, 512, 512, 3),)

    # Image saver
    mocks['save_image'].return_value.save_images.return_value = {"ui": {"images": []}}
    mocks['save_image'].return_value.save_images_async = MagicMock()

    # Optimization appliers return the mock model wrapped in a tuple.
    mocks['sf_applier'].return_value.apply_stable_fast.return_value = (mock_model_patcher,)
    mocks['dc_applier'].return_value.patch.return_value = (mock_model_patcher,)


@pytest.fixture
def mock_all_heavy_dependencies(request):
    """
    Comprehensive mock that patches all heavy dependencies to allow
    testing pipeline routing logic without loading real models.

    Yields:
        dict mapping each short patch name (e.g. 'ksampler', 'latent_upscale')
        to the mock installed for it, so tests can assert on calls.
    """
    # Clear the real cache *before* patching. Once the 'model_cache' patch is
    # active, importing get_model_cache returns the MagicMock, so clearing at
    # that point would be a no-op and stale entries from earlier tests survive.
    _clear_model_cache()

    patchers = _build_heavy_patchers()
    started = []

    def teardown():
        # Stop only what was actually started, in reverse start order.
        for patcher in reversed(started):
            try:
                patcher.stop()
            except Exception:
                pass
        # Kept from the existing cleanup: also stops any other patcher that is
        # still active from a `patch(...).start()` call.
        patch.stopall()
        # The patches are gone now, so this reaches the real cache and keeps
        # fake checkpoints cached during the test from leaking into later tests.
        _clear_model_cache()

    # Register cleanup before starting anything, so a patch whose target fails
    # to import cannot leave the already-started ones active.
    request.addfinalizer(teardown)

    mocks = {}
    for name, patcher in patchers.items():
        mocks[name] = patcher.start()
        started.append(patcher)

    _configure_heavy_mocks(mocks)
    yield mocks
# =============================================================================
# Pipeline Flag Routing Tests
# =============================================================================
@pytest.mark.slow
class TestPipelineBasicRouting:
    """Smoke tests: a plain text-to-image call completes and reports a result."""

    def test_pipeline_runs_without_exception(self, mock_all_heavy_dependencies):
        """With every heavy dependency mocked, a default call must not raise."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="a test prompt", w=512, h=512, number=1, batch=1)

        assert outcome is not None

    def test_pipeline_returns_result_dict(self, mock_all_heavy_dependencies):
        """The pipeline reports its result as a dict with a recognizable key."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="a test prompt", w=512, h=512)

        assert isinstance(outcome, dict), f"Expected dict, got {type(outcome)}"
        recognized_keys = ("original_prompt", "batched_results")
        assert any(key in outcome for key in recognized_keys)
@pytest.mark.slow
class TestHiresFixRouting:
    """Test hires_fix flag routing (including its interaction with a refiner)."""

    def test_hires_fix_triggers_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=True should trigger latent upscaling."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called when hires_fix=True"
        )

    def test_hires_fix_false_skips_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=False should not trigger latent upscaling."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=False,
        )
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert not latent_upscale.return_value.upscale.called, (
            "Latent upscale should NOT be called when hires_fix=False"
        )

    def test_hires_fix_runs_additional_sampling_pass(self, mock_all_heavy_dependencies):
        """hires_fix should run an additional sampling pass at higher resolution."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )
        ksampler = mock_all_heavy_dependencies['ksampler']
        # At least two calls are expected: the initial pass plus the hires pass.
        assert ksampler.return_value.sample.call_count >= 2, (
            f"Expected at least 2 sampling passes for hires_fix, "
            f"got {ksampler.return_value.sample.call_count}"
        )

    def test_batched_hires_fix_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched hires_fix with SDXL refiner should call latent upscaling using refiner prompts."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt=["one"],
            w=512,
            h=512,
            batch=1,
            hires_fix=True,
            per_sample_info=[{"hires_fix": True}],
            refiner_model_path="refiner.safetensors",
            refiner_switch_step=1,
            model_path="my_sdxl_model.safetensors",
        )
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called for batched hires_fix with refiner"
        )

    def test_hires_fix_reloads_base_model_after_refiner(self, mock_all_heavy_dependencies):
        """If a refiner unloaded the base model, the pipeline must reload the base model before HiresFix."""
        from src.user.pipeline import pipeline
        # Patch HiresFix.apply so the `model` argument passed to it can be inspected.
        with patch('src.Processors.HiresFix.HiresFix.apply') as mock_hires_apply:
            mock_hires_apply.return_value = {"samples": torch.randn(1, 4, 128, 128)}
            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                hires_fix=True,
                per_sample_info=[{"hires_fix": True}],
                refiner_model_path="refiner.safetensors",
                refiner_switch_step=1,
                model_path="my_sdxl_model.safetensors",
            )
        assert mock_hires_apply.called, "HiresFix.apply should be invoked"
        # The model is taken from the third positional argument of HiresFix.apply.
        called_model = mock_hires_apply.call_args[0][2]
        # The model passed to HiresFix must be loaded and have an inner model object.
        assert getattr(called_model, 'is_loaded', False), "Base model must be loaded when passed to HiresFix"
        assert getattr(called_model, 'model', None) is not None, "Base model.model must be present for the hires pass"

    def test_batched_adetailer_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched adetailer with SDXL refiner should call Adetailer.apply without NameError."""
        from src.user.pipeline import pipeline
        with patch('src.Processors.Adetailer.Adetailer.apply') as mock_adetail_apply:
            mock_adetail_apply.return_value = (torch.rand(1, 512, 512, 3), [])
            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                adetailer=True,
                per_sample_info=[{"adetailer": True}],
                # Same refiner settings as test_batched_hires_fix_with_refiner_sdxl.
                # Without them the refiner branch this test is named after never runs.
                refiner_model_path="refiner.safetensors",
                refiner_switch_step=1,
                model_path="my_sdxl_model.safetensors",
            )
        assert mock_adetail_apply.called, "Adetailer.apply should be called for batched adetailer with refiner"

    def test_hires_fix_with_flux_model(self, mock_all_heavy_dependencies):
        """HiresFix should work with Flux model (no refiner) and call upscale."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="flux_model.safetensors",
        )
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, "Latent upscale should be called for Flux model"

    def test_hires_fix_injects_size_conditioning_for_sdxl(self, mock_all_heavy_dependencies):
        """HiresFix should inject width/height into prompt conditioning for SDXL models."""
        from src.user.pipeline import pipeline
        from conftest import MockCheckpointResult
        # Force the loader to return an SDXL checkpoint to emulate SDXL behavior.
        mock_all_heavy_dependencies['load_model'].return_value = ("SDXL", MockCheckpointResult("SDXL").as_tuple())
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="my_sdxl_model.safetensors",
        )
        ksampler = mock_all_heavy_dependencies['ksampler']
        # Inspect the last sampler call (the hires pass).
        assert ksampler.return_value.sample.call_count >= 2
        last_call = ksampler.return_value.sample.call_args_list[-1]
        positive = last_call.kwargs.get('positive')
        # The conditioning metadata should carry the updated size for the hires pass.
        assert isinstance(positive, list)
        meta = positive[0][1]
        assert meta.get('width') == 1024
        assert meta.get('height') == 1024
class TestImg2ImgRouting:
    """Routing between the image-to-image and text-to-image paths."""

    def test_img2img_requires_image_source(self, mock_all_heavy_dependencies, tmp_path):
        """With img2img=True and an input image path, UltimateSDUpscale is used."""
        from src.user.pipeline import pipeline
        from PIL import Image

        # A small solid-red PNG serves as the img2img source.
        source_path = tmp_path / "test.png"
        Image.new('RGB', (256, 256), color='red').save(source_path)

        usdu_target = 'src.UltimateSDUpscale.UltimateSDUpscale.UltimateSDUpscale'
        upscale_loader_target = 'src.UltimateSDUpscale.USDU_upscaler.UpscaleModelLoader'
        with patch(usdu_target) as usdu_cls, patch(upscale_loader_target) as upscale_loader_cls:
            usdu_cls.return_value.upscale.return_value = (torch.rand(1, 512, 512, 3),)
            upscale_loader_cls.return_value.load_model.return_value = (MagicMock(),)

            pipeline(
                prompt="test",
                w=512,
                h=512,
                img2img=True,
                img2img_image=str(source_path),
            )

            assert usdu_cls.called, (
                "UltimateSDUpscale should be used for img2img"
            )

    def test_img2img_false_uses_text2img(self, mock_all_heavy_dependencies):
        """With img2img=False the text-to-image path generates an empty latent."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, img2img=False)

        generate = mock_all_heavy_dependencies['empty_latent'].return_value.generate
        assert generate.called, (
            "EmptyLatentImage.generate should be called for text2img"
        )
class TestADetailerRouting:
    """Test adetailer flag routing."""

    def test_adetailer_enabled_triggers_detection(self, mock_all_heavy_dependencies):
        """adetailer=True should trigger face/body detection (SAM model load)."""
        from contextlib import ExitStack
        from src.user.pipeline import pipeline

        # ExitStack replaces six nested `with` blocks; contexts are still
        # exited in reverse order of entry, exactly like the nested form.
        with ExitStack() as stack:
            mock_sam = stack.enter_context(patch('src.AutoDetailer.SAM.SAMLoader'))
            mock_detector = stack.enter_context(
                patch('src.AutoDetailer.bbox.UltralyticsDetectorProvider')
            )
            mock_bbox = stack.enter_context(patch('src.AutoDetailer.bbox.BboxDetectorForEach'))
            mock_sam_combined = stack.enter_context(
                patch('src.AutoDetailer.SAM.SAMDetectorCombined')
            )
            mock_segs = stack.enter_context(patch('src.AutoDetailer.SEGS.SegsBitwiseAndMask'))
            mock_detailer = stack.enter_context(
                patch('src.AutoDetailer.ADetailer.DetailerForEachTest')
            )

            mock_sam.return_value.load_model.return_value = (MagicMock(),)
            mock_detector.return_value.doit.return_value = (MagicMock(),)
            mock_bbox.return_value.doit.return_value = MagicMock()
            mock_sam_combined.return_value.doit.return_value = (torch.ones(1, 512, 512),)
            mock_segs.return_value.doit.return_value = (MagicMock(),)
            mock_detailer.return_value.doit.return_value = (
                torch.rand(1, 512, 512, 3),
                12345,
            )

            pipeline(
                prompt="test",
                w=512,
                h=512,
                adetailer=True,
            )

            assert mock_sam.return_value.load_model.called, (
                "SAMLoader should be called when adetailer=True"
            )

    def test_adetailer_disabled_skips_detection(self, mock_all_heavy_dependencies):
        """adetailer=False should skip face/body detection (SAMLoader never constructed)."""
        from src.user.pipeline import pipeline
        with patch('src.AutoDetailer.SAM.SAMLoader') as mock_sam:
            pipeline(
                prompt="test",
                w=512,
                h=512,
                adetailer=False,
            )
            assert not mock_sam.called, (
                "SAMLoader should NOT be called when adetailer=False"
            )
class TestMultiscaleRouting:
    """Test multiscale diffusion parameter routing."""

    def test_multiscale_preset_applied(self, mock_all_heavy_dependencies):
        """A multiscale preset must not prevent the sampler from running."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="performance",
        )
        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        # Require at least that sampling happened; a tautological
        # `assert ... or True` could never fail.
        assert sample.called, (
            "KSampler.sample should run with multiscale_preset='performance'"
        )
        # NOTE(review): the exact multiscale arguments forwarded to the sampler
        # are not pinned here; tighten once their names/values are confirmed.

    def test_multiscale_disabled_preset(self, mock_all_heavy_dependencies):
        """multiscale_preset='disabled' should disable multiscale."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="disabled",
        )
        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.called
        # If the flag is forwarded as a keyword, it must not enable multiscale.
        kwargs = sample.call_args.kwargs
        if 'enable_multiscale' in kwargs:
            assert not kwargs['enable_multiscale'], (
                "enable_multiscale must be falsy when multiscale_preset='disabled'"
            )
class TestDeepCacheRouting:
    """DeepCache toggling: the applier's patch step runs only when requested."""

    @staticmethod
    def _deepcache_patch(mocks):
        """Return the mocked ``ApplyDeepCacheOnModel().patch`` callable."""
        return mocks['dc_applier'].return_value.patch

    def test_deepcache_enabled_applies_patch(self, mock_all_heavy_dependencies):
        """With deepcache_enabled=True the DeepCache patch is applied."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=True)

        dc_patch = self._deepcache_patch(mock_all_heavy_dependencies)
        assert dc_patch.called, (
            "DeepCache should be applied when deepcache_enabled=True"
        )

    def test_deepcache_disabled_skips_patch(self, mock_all_heavy_dependencies):
        """With deepcache_enabled=False the DeepCache patch is never applied."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=False)

        dc_patch = self._deepcache_patch(mock_all_heavy_dependencies)
        assert not dc_patch.called, (
            "DeepCache should NOT be applied when deepcache_enabled=False"
        )
class TestStableFastRouting:
    """StableFast toggling."""

    def test_stable_fast_enabled_applies_optimization(self, mock_all_heavy_dependencies):
        """With stable_fast=True the StableFast UNet optimization is applied."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, stable_fast=True)

        sf_instance = mock_all_heavy_dependencies['sf_applier'].return_value
        assert sf_instance.apply_stable_fast.called, (
            "StableFast should be applied when stable_fast=True"
        )
class TestBatchedPromptRouting:
    """Routing for several prompts submitted in one call."""

    @pytest.mark.skip(reason="Batched prompts require dynamic mock tensor sizing which is complex to set up; tested manually")
    def test_batched_prompts_use_batched_path(self, mock_all_heavy_dependencies):
        """A list of prompts should go through the batched generation path.

        Skipped: the pipeline iterates over batched results internally, while
        the shared fixture's mocks always return fixed single-item tensors. A
        faithful test would need mocks whose output size follows the batch.
        """
        from src.user.pipeline import pipeline

        outcome = pipeline(
            prompt=["prompt 1", "prompt 2", "prompt 3"],
            w=512,
            h=512,
        )

        # Only inspect the batched section when the pipeline reports one.
        if "batched_results" not in outcome:
            return
        assert isinstance(outcome["batched_results"], dict)
class TestAutoHDRRouting:
    """Test AutoHDR parameter routing."""

    def test_autohdr_enabled_applies_effect(self, mock_all_heavy_dependencies):
        """autohdr=True should apply HDR effect."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=True,
        )
        hdr = mock_all_heavy_dependencies['hdr']
        assert hdr.return_value.apply_hdr2.called, (
            "HDR effect should be applied when autohdr=True"
        )

    def test_autohdr_disabled_skips_effect(self, mock_all_heavy_dependencies):
        """autohdr=False should skip HDR effect."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=False,
        )
        hdr = mock_all_heavy_dependencies['hdr']
        # Constructing the HDREffects object is tolerated; only applying the
        # effect is forbidden when autohdr is off.
        assert not hdr.return_value.apply_hdr2.called, (
            "HDR effect should NOT be applied when autohdr=False"
        )
class TestCFGFreeRouting:
    """CFG-free sampling options."""

    def test_cfg_free_params_passed_to_sampler(self, mock_all_heavy_dependencies):
        """Enabling CFG-free sampling (starting at 70%) still reaches the sampler.

        Only the fact that sampling ran is asserted; the forwarded CFG-free
        arguments themselves are not inspected here.
        """
        from src.user.pipeline import pipeline

        cfg_free_options = dict(cfg_free_enabled=True, cfg_free_start_percent=70.0)
        pipeline(prompt="test", w=512, h=512, **cfg_free_options)

        sampler_instance = mock_all_heavy_dependencies['ksampler'].return_value
        assert sampler_instance.sample.called
class TestTokenMergingRouting:
    """Test Token Merging (ToMe) parameter routing."""

    def test_tome_params_passed_when_enabled(self, mock_all_heavy_dependencies):
        """Enabling ToMe (ratio 0.5) must not prevent the pipeline from sampling."""
        from src.user.pipeline import pipeline
        result = pipeline(
            prompt="test",
            w=512,
            h=512,
            tome_enabled=True,
            tome_ratio=0.5,
        )
        # Previously this test asserted nothing; require a result and at least
        # one sampling pass, like the other smoke tests in this module.
        assert result is not None
        ksampler = mock_all_heavy_dependencies['ksampler']
        assert ksampler.return_value.sample.called, (
            "KSampler.sample should run when tome_enabled=True"
        )
class TestNegativePromptRouting:
    """Handling of the negative_prompt argument."""

    def test_empty_negative_prompt_uses_default(self, mock_all_heavy_dependencies):
        """An empty negative prompt is accepted.

        The pipeline is expected to fall back to a default negative prompt;
        only successful completion is asserted here.
        """
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="test", w=512, h=512, negative_prompt="")

        assert outcome is not None

    def test_custom_negative_prompt_passed(self, mock_all_heavy_dependencies):
        """A custom negative prompt run still goes through CLIP text encoding."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            negative_prompt="ugly, bad quality, distorted",
        )

        encoder = mock_all_heavy_dependencies['clip_encode'].return_value
        assert encoder.encode.called
class TestSeedRouting:
    """Seed handling, including the reuse_seed flag."""

    def test_reuse_seed_uses_last_seed(self, mock_all_heavy_dependencies):
        """A reuse_seed=True run after a fresh-seed run completes, and both runs sample."""
        from src.user.pipeline import pipeline

        # First establish a seed, then ask the pipeline to reuse it.
        for reuse in (False, True):
            pipeline(prompt="test", w=512, h=512, reuse_seed=reuse)

        sample_mock = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample_mock.call_count >= 2
class TestModelPathRouting:
    """Test model_path parameter routing."""

    def test_custom_model_path_used(self, mock_all_heavy_dependencies):
        """A custom model_path must reach the checkpoint-loading entry point that ran.

        The fixture mocks two entry points: ``CheckpointLoaderSimple().load_checkpoint``
        and ``load_model_for_pipeline``. At least one must be called, and the
        custom path must appear in the arguments of a called one.
        """
        from src.user.pipeline import pipeline
        custom_path = "/path/to/custom_model.safetensors"
        pipeline(
            prompt="test",
            w=512,
            h=512,
            model_path=custom_path,
        )
        mocks = mock_all_heavy_dependencies
        entry_points = {
            'CheckpointLoaderSimple.load_checkpoint': mocks['loader'].return_value.load_checkpoint,
            'load_model_for_pipeline': mocks['load_model'],
        }
        called = {name: m for name, m in entry_points.items() if m.called}
        assert called, (
            f"No checkpoint-loading entry point was called; checked {sorted(entry_points)}"
        )
        seen_args = {name: str(m.call_args) for name, m in called.items()}
        assert any(custom_path in args for args in seen_args.values()), (
            f"Custom model path should be used: {seen_args}"
        )
class TestErrorHandling:
    """Failure modes: missing checkpoints and user interruption."""

    def test_invalid_model_path_raises_error(self, mock_all_heavy_dependencies):
        """A checkpoint load failure surfaces to the caller as FileNotFoundError."""
        from src.user.pipeline import pipeline

        mocks = mock_all_heavy_dependencies
        checkpoint_loader = mocks['loader'].return_value
        checkpoint_loader.load_checkpoint.side_effect = FileNotFoundError("Model not found")
        mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None

        with pytest.raises(FileNotFoundError):
            pipeline(
                prompt="test",
                w=512,
                h=512,
                model_path="/nonexistent/model.safetensors",
            )

    def test_interruption_handled_gracefully(self, mock_all_heavy_dependencies):
        """With the app's interrupt_flag already set, the pipeline raises InterruptedError."""
        from src.user.pipeline import pipeline

        interrupted_app = MagicMock()
        interrupted_app.interrupt_flag = True
        with patch('src.user.app_instance.app', interrupted_app), pytest.raises(InterruptedError):
            pipeline(prompt="test", w=512, h=512)
class TestSchedulerSamplerRouting:
    """Every supported scheduler / sampler name can be requested and sampling runs."""

    @staticmethod
    def _sampler_ran(mocks):
        """Report whether the mocked ``KSampler().sample`` was invoked."""
        return mocks['ksampler'].return_value.sample.called

    @pytest.mark.parametrize("scheduler", [
        "normal", "karras", "simple", "beta", "ays", "ays_sd15", "ays_sdxl"
    ])
    def test_scheduler_passed_to_sampler(self, mock_all_heavy_dependencies, scheduler):
        """Requesting this scheduler still results in a KSampler run."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, scheduler=scheduler)

        assert self._sampler_ran(mock_all_heavy_dependencies)

    @pytest.mark.parametrize("sampler", [
        "euler", "euler_ancestral", "euler_cfgpp",
        "euler_ancestral_cfgpp", "dpmpp_2m_cfgpp", "dpmpp_sde_cfgpp"
    ])
    def test_sampler_passed_to_ksampler(self, mock_all_heavy_dependencies, sampler):
        """Requesting this sampler still results in a KSampler run."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, sampler=sampler)

        assert self._sampler_ran(mock_all_heavy_dependencies)
|