# LightDiffusion-Next: tests/integration/test_pipeline_routing.py
# Author: Aatricks
# Commit b701455: "Deploy ZeroGPU Gradio Space snapshot"
"""
Integration tests for pipeline routing logic.
Tests that the pipeline correctly routes execution based on flags like
hires_fix, img2img, adetailer, etc. All model loading
is mocked to avoid loading real weights.
"""
import os
import sys
import pytest
import torch
from pathlib import Path
from unittest.mock import patch, MagicMock, call, ANY
from typing import Tuple
# Put the directory three levels above this file first on sys.path so the
# ``src`` package imports without being installed.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, os.fspath(project_root))
# Apply the ``slow`` marker to every test collected from this module.
pytestmark = pytest.mark.slow
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def mock_all_heavy_dependencies(request):
    """
    Patch every heavy dependency so pipeline routing can be tested without
    loading real models.

    Yields a dict mapping a short name (e.g. ``'ksampler'``, ``'latent_upscale'``)
    to the object returned by the matching ``patch(...).start()``. Tests use it
    to tweak return values / side effects and to inspect calls. On teardown all
    started patches are stopped and the global model cache is cleared.
    """

    def clear_model_cache():
        # Best-effort: a missing or failing cache must not mask the test result.
        try:
            from src.Device.ModelCache import get_model_cache
            get_model_cache().clear_cache()
        except Exception:
            pass

    # Clear the *real* global model cache before anything is patched. Doing it
    # after ``get_model_cache`` is patched (as before) only cleared the mock and
    # left checkpoint entries cached by earlier tests in place.
    clear_model_cache()

    patches = {}
    # Mock model loading
    patches['loader'] = patch('src.FileManaging.Loader.CheckpointLoaderSimple')
    patches['model_cache'] = patch('src.Device.ModelCache.get_model_cache')
    patches['load_model'] = patch('src.user.model_loader.load_model_for_pipeline')
    # Mock CLIP operations
    patches['clip_encode'] = patch('src.clip.Clip.CLIPTextEncode')
    patches['clip_set_layer'] = patch('src.clip.Clip.CLIPSetLastLayer')
    # Mock VAE operations
    patches['vae_decode'] = patch('src.AutoEncoders.VariationalAE.VAEDecode')
    patches['vae_loader'] = patch('src.AutoEncoders.VariationalAE.VAELoader')
    # Mock Latent operations
    patches['empty_latent'] = patch('src.Utilities.Latent.EmptyLatentImage')
    patches['latent_upscale'] = patch('src.Utilities.upscale.LatentUpscale')
    # Mock Sampler
    patches['ksampler'] = patch('src.sample.sampling.KSampler')
    # Mock Image operations
    patches['save_image'] = patch('src.FileManaging.ImageSaver.SaveImage')
    # Mock LoRA
    patches['lora_loader'] = patch('src.Model.LoRas.LoraLoader')
    # Mock optimizations so they return the model (allowing .called checks)
    patches['sf_applier'] = patch('src.StableFast.StableFast.ApplyStableFastUnet')
    patches['dc_applier'] = patch('src.WaveSpeed.deepcache_nodes.ApplyDeepCacheOnModel')
    # Mock HiDiffusion
    patches['hidiff'] = patch('src.hidiffusion.msw_msa_attention.ApplyMSWMSAAttentionSimple')
    # Mock HDR
    patches['hdr'] = patch('src.AutoHDR.ahdr.HDREffects')
    # Mock Downloader to avoid network calls
    patches['downloader'] = patch('src.FileManaging.Downloader.CheckAndDownload')
    # Mock app_instance - explicitly set interrupt_flag to False
    mock_app = MagicMock()
    mock_app.interrupt_flag = False
    patches['app_instance'] = patch('src.user.app_instance.app', mock_app)

    started = []

    def teardown():
        # Stop the patches that were actually started, in reverse order.
        for p in reversed(started):
            try:
                p.stop()
            except Exception:
                pass
        patch.stopall()
        # Patches are stopped now, so this clears the real cache and drops any
        # fake checkpoints cached while the mocks were active.
        clear_model_cache()

    # Register teardown BEFORE starting any patch: if one ``start()`` fails
    # (e.g. an unimportable target), the patches already started are still
    # stopped instead of leaking into later tests.
    request.addfinalizer(teardown)

    mocks = {}
    for name, p in patches.items():
        mocks[name] = p.start()
        started.append(p)

    # Configure default return values
    from conftest import MockModelPatcher
    mock_model_patcher = MockModelPatcher()
    mock_clip = MagicMock()
    mock_vae = MagicMock()
    mocks['loader'].return_value.load_checkpoint.return_value = (
        mock_model_patcher, mock_clip, mock_vae
    )
    mocks['model_cache'].return_value.get_cached_checkpoint.return_value = None
    mocks['load_model'].return_value = ("SD15", (mock_model_patcher, mock_clip, mock_vae))
    # Mock CLIP encoding
    mock_cond = [[torch.randn(1, 77, 768), {}]]
    mocks['clip_encode'].return_value.encode.return_value = (mock_cond,)
    mocks['clip_set_layer'].return_value.set_last_layer.return_value = (mock_clip,)
    # Mock VAE decoding
    mocks['vae_decode'].return_value.decode.return_value = (torch.rand(1, 512, 512, 3),)
    # Mock latent generation
    mock_latent = {"samples": torch.randn(1, 4, 64, 64)}
    mocks['empty_latent'].return_value.generate.return_value = (mock_latent,)
    mocks['latent_upscale'].return_value.upscale.return_value = ({"samples": torch.randn(1, 4, 128, 128)},)
    # Mock sampler
    mocks['ksampler'].return_value.sample.return_value = ({"samples": torch.randn(1, 4, 64, 64)},)
    # Mock LoRA loader
    mocks['lora_loader'].return_value.load_lora.return_value = (
        mock_model_patcher, mock_clip, mock_vae
    )
    # Mock HiDiffusion
    mocks['hidiff'].return_value.go.return_value = (mock_model_patcher,)
    # Mock HDR
    mocks['hdr'].return_value.apply_hdr2.return_value = (torch.rand(1, 512, 512, 3),)
    # Mock image saver
    mocks['save_image'].return_value.save_images.return_value = {"ui": {"images": []}}
    mocks['save_image'].return_value.save_images_async = MagicMock()
    # Configure optimization appliers to return the mock model in a tuple
    mocks['sf_applier'].return_value.apply_stable_fast.return_value = (mock_model_patcher,)
    mocks['dc_applier'].return_value.patch.return_value = (mock_model_patcher,)
    yield mocks
# =============================================================================
# Pipeline Flag Routing Tests
# =============================================================================
@pytest.mark.slow
class TestPipelineBasicRouting:
    """Smoke tests for the default pipeline route with all heavy work mocked."""

    def test_pipeline_runs_without_exception(self, mock_all_heavy_dependencies):
        """With every heavy dependency mocked, the pipeline completes and returns a value."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="a test prompt", w=512, h=512, number=1, batch=1)
        assert outcome is not None

    def test_pipeline_returns_result_dict(self, mock_all_heavy_dependencies):
        """The pipeline result is a dict carrying the original prompt or batched results."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="a test prompt", w=512, h=512)
        assert isinstance(outcome, dict), f"Expected dict, got {type(outcome)}"
        assert any(key in outcome for key in ("original_prompt", "batched_results"))
@pytest.mark.slow
class TestHiresFixRouting:
    """Test hires_fix flag routing, including its interaction with SDXL refiners."""

    def test_hires_fix_triggers_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=True should trigger latent upscaling."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )
        # Verify upscale was called
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called when hires_fix=True"
        )

    def test_hires_fix_false_skips_upscale(self, mock_all_heavy_dependencies):
        """hires_fix=False should not trigger latent upscaling."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=False,
        )
        # Verify upscale was NOT called
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert not latent_upscale.return_value.upscale.called, (
            "Latent upscale should NOT be called when hires_fix=False"
        )

    def test_hires_fix_runs_additional_sampling_pass(self, mock_all_heavy_dependencies):
        """hires_fix should run an additional sampling pass at higher resolution."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
        )
        ksampler = mock_all_heavy_dependencies['ksampler']
        # Should have been called at least twice (initial + hires pass)
        assert ksampler.return_value.sample.call_count >= 2, (
            f"Expected at least 2 sampling passes for hires_fix, "
            f"got {ksampler.return_value.sample.call_count}"
        )

    def test_batched_hires_fix_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched hires_fix with SDXL refiner should call latent upscaling using refiner prompts."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt=["one"],
            w=512,
            h=512,
            batch=1,
            hires_fix=True,
            per_sample_info=[{"hires_fix": True}],
            refiner_model_path="refiner.safetensors",
            refiner_switch_step=1,
            model_path="my_sdxl_model.safetensors",
        )
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, (
            "Latent upscale should be called for batched hires_fix with refiner"
        )

    def test_hires_fix_reloads_base_model_after_refiner(self, mock_all_heavy_dependencies):
        """If a refiner unloaded the base model, the pipeline must reload the base model before HiresFix."""
        from src.user.pipeline import pipeline
        # Patch HiresFix.apply so we can inspect the `model` argument passed to it.
        # ``patch`` and ``torch`` come from the module-level imports.
        with patch('src.Processors.HiresFix.HiresFix.apply') as mock_hires_apply:
            mock_hires_apply.return_value = {"samples": torch.randn(1, 4, 128, 128)}
            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                hires_fix=True,
                per_sample_info=[{"hires_fix": True}],
                refiner_model_path="refiner.safetensors",
                refiner_switch_step=1,
                model_path="my_sdxl_model.safetensors",
            )
        assert mock_hires_apply.called, "HiresFix.apply should be invoked"
        # NOTE(review): assumes the model is the third positional argument of
        # HiresFix.apply - TODO confirm against its signature.
        called_model = mock_hires_apply.call_args[0][2]
        # The model passed to HiresFix must be loaded and have an inner model object
        assert getattr(called_model, 'is_loaded', False), "Base model must be loaded when passed to HiresFix"
        assert getattr(called_model, 'model', None) is not None, "Base model.model must be present for the hires pass"

    def test_batched_adetailer_with_refiner_sdxl(self, mock_all_heavy_dependencies):
        """Batched adetailer with SDXL refiner should call Adetailer.apply without NameError."""
        from src.user.pipeline import pipeline
        with patch('src.Processors.Adetailer.Adetailer.apply') as mock_adetail_apply:
            mock_adetail_apply.return_value = (torch.rand(1, 512, 512, 3), [])
            # The refiner/SDXL arguments are required for this test to exercise
            # the refiner path its name and docstring describe; without them it
            # only covered plain batched adetailer.
            pipeline(
                prompt=["one"],
                w=512,
                h=512,
                batch=1,
                adetailer=True,
                per_sample_info=[{"adetailer": True}],
                refiner_model_path="refiner.safetensors",
                refiner_switch_step=1,
                model_path="my_sdxl_model.safetensors",
            )
        assert mock_adetail_apply.called, "Adetailer.apply should be called for batched adetailer with refiner"

    def test_hires_fix_with_flux_model(self, mock_all_heavy_dependencies):
        """HiresFix should work with Flux model (no refiner) and call upscale."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="flux_model.safetensors",
        )
        latent_upscale = mock_all_heavy_dependencies['latent_upscale']
        assert latent_upscale.return_value.upscale.called, "Latent upscale should be called for Flux model"

    def test_hires_fix_injects_size_conditioning_for_sdxl(self, mock_all_heavy_dependencies):
        """HiresFix should inject width/height into prompt conditioning for SDXL models."""
        from src.user.pipeline import pipeline
        from conftest import MockCheckpointResult
        # Force the loader to return an SDXL checkpoint to emulate SDXL behavior
        mock_all_heavy_dependencies['load_model'].return_value = ("SDXL", MockCheckpointResult("SDXL").as_tuple())
        pipeline(
            prompt="test",
            w=512,
            h=512,
            hires_fix=True,
            model_path="my_sdxl_model.safetensors",
        )
        ksampler = mock_all_heavy_dependencies['ksampler']
        # Inspect the last sampler call (hires pass)
        assert ksampler.return_value.sample.call_count >= 2
        last_call = ksampler.return_value.sample.call_args_list[-1]
        kwargs = last_call.kwargs
        positive = kwargs.get('positive')
        # The conditioning metadata should include updated width/height for the hires pass
        assert isinstance(positive, list)
        meta = positive[0][1]
        assert meta.get('width') == 1024
        assert meta.get('height') == 1024
class TestImg2ImgRouting:
    """Routing checks for the img2img flag."""

    def test_img2img_requires_image_source(self, mock_all_heavy_dependencies, tmp_path):
        """With img2img=True the pipeline reads the supplied image path and upscales it."""
        from src.user.pipeline import pipeline
        from PIL import Image

        # Write a small solid-colour PNG for the pipeline to consume.
        source_png = tmp_path / "test.png"
        Image.new('RGB', (256, 256), color='red').save(source_png)

        usdu_patch = patch('src.UltimateSDUpscale.UltimateSDUpscale.UltimateSDUpscale')
        loader_patch = patch('src.UltimateSDUpscale.USDU_upscaler.UpscaleModelLoader')
        with usdu_patch as usdu_mock, loader_patch as loader_mock:
            usdu_mock.return_value.upscale.return_value = (torch.rand(1, 512, 512, 3),)
            loader_mock.return_value.load_model.return_value = (MagicMock(),)
            pipeline(
                prompt="test",
                w=512,
                h=512,
                img2img=True,
                img2img_image=str(source_png),
            )
            # The img2img route must instantiate UltimateSDUpscale.
            assert usdu_mock.called, (
                "UltimateSDUpscale should be used for img2img"
            )

    def test_img2img_false_uses_text2img(self, mock_all_heavy_dependencies):
        """With img2img=False the text-to-image route starts from an empty latent."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, img2img=False)

        generate = mock_all_heavy_dependencies['empty_latent'].return_value.generate
        assert generate.called, (
            "EmptyLatentImage.generate should be called for text2img"
        )
class TestADetailerRouting:
    """Routing checks for the adetailer flag."""

    def test_adetailer_enabled_triggers_detection(self, mock_all_heavy_dependencies):
        """With adetailer=True the detection stack, starting with SAMLoader, is exercised."""
        from src.user.pipeline import pipeline

        sam_loader_p = patch('src.AutoDetailer.SAM.SAMLoader')
        detector_p = patch('src.AutoDetailer.bbox.UltralyticsDetectorProvider')
        bbox_p = patch('src.AutoDetailer.bbox.BboxDetectorForEach')
        sam_combined_p = patch('src.AutoDetailer.SAM.SAMDetectorCombined')
        segs_p = patch('src.AutoDetailer.SEGS.SegsBitwiseAndMask')
        detailer_p = patch('src.AutoDetailer.ADetailer.DetailerForEachTest')

        # Managers are entered left to right, in the same order as a nested stack.
        with sam_loader_p as sam, detector_p as detector, bbox_p as bbox:
            with sam_combined_p as sam_combined, segs_p as segs, detailer_p as detailer:
                sam.return_value.load_model.return_value = (MagicMock(),)
                detector.return_value.doit.return_value = (MagicMock(),)
                bbox.return_value.doit.return_value = MagicMock()
                sam_combined.return_value.doit.return_value = (torch.ones(1, 512, 512),)
                segs.return_value.doit.return_value = (MagicMock(),)
                detailer.return_value.doit.return_value = (
                    torch.rand(1, 512, 512, 3),
                    12345
                )

                pipeline(prompt="test", w=512, h=512, adetailer=True)

                assert sam.return_value.load_model.called, (
                    "SAMLoader should be called when adetailer=True"
                )

    def test_adetailer_disabled_skips_detection(self, mock_all_heavy_dependencies):
        """With adetailer=False the SAM loader is never instantiated."""
        from src.user.pipeline import pipeline

        with patch('src.AutoDetailer.SAM.SAMLoader') as sam_loader_mock:
            pipeline(prompt="test", w=512, h=512, adetailer=False)
            assert not sam_loader_mock.called, (
                "SAMLoader should NOT be called when adetailer=False"
            )
class TestMultiscaleRouting:
    """Test multiscale diffusion parameter routing."""

    def test_multiscale_preset_applied(self, mock_all_heavy_dependencies):
        """multiscale_preset should configure multiscale parameters and still sample."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="performance",
        )
        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        # The previous check (``'enable_multiscale' in kwargs or True`` behind an
        # ``if call_args:`` guard) could never fail, even if the sampler was never
        # reached. Require that sampling actually happened with this preset.
        # The multiscale options may be passed positionally, so their presence is
        # not asserted by keyword name here.
        assert sample.called, (
            "KSampler.sample should be called when a multiscale preset is set"
        )

    def test_multiscale_disabled_preset(self, mock_all_heavy_dependencies):
        """multiscale_preset='disabled' should disable multiscale."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            multiscale_preset="disabled",
        )
        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        # Should still run without error
        assert sample.called
        # If the pipeline forwards ``enable_multiscale`` by keyword, the
        # 'disabled' preset must turn it off.
        kwargs = sample.call_args.kwargs
        if 'enable_multiscale' in kwargs:
            assert not kwargs['enable_multiscale'], (
                "enable_multiscale should be falsy when multiscale_preset='disabled'"
            )
class TestDeepCacheRouting:
    """Routing checks for the deepcache_enabled flag."""

    def test_deepcache_enabled_applies_patch(self, mock_all_heavy_dependencies):
        """Enabling DeepCache routes the model through the DeepCache applier."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=True)

        patched = mock_all_heavy_dependencies['dc_applier'].return_value.patch.called
        assert patched, (
            "DeepCache should be applied when deepcache_enabled=True"
        )

    def test_deepcache_disabled_skips_patch(self, mock_all_heavy_dependencies):
        """Disabling DeepCache leaves the applier untouched."""
        from src.user.pipeline import pipeline

        pipeline(prompt="test", w=512, h=512, deepcache_enabled=False)

        patched = mock_all_heavy_dependencies['dc_applier'].return_value.patch.called
        assert not patched, (
            "DeepCache should NOT be applied when deepcache_enabled=False"
        )
class TestStableFastRouting:
    """Routing checks for the stable_fast flag."""

    def test_stable_fast_enabled_applies_optimization(self, mock_all_heavy_dependencies):
        """Enabling StableFast routes the model through the StableFast applier."""
        from src.user.pipeline import pipeline

        options = {"prompt": "test", "w": 512, "h": 512, "stable_fast": True}
        pipeline(**options)

        apply_stable_fast = mock_all_heavy_dependencies['sf_applier'].return_value.apply_stable_fast
        assert apply_stable_fast.called, (
            "StableFast should be applied when stable_fast=True"
        )
class TestBatchedPromptRouting:
    """Routing checks for several prompts submitted in a single call."""

    @pytest.mark.skip(reason="Batched prompts require dynamic mock tensor sizing which is complex to set up; tested manually")
    def test_batched_prompts_use_batched_path(self, mock_all_heavy_dependencies):
        """A list of prompts should go through the batched generation path.

        Skipped: the pipeline iterates over per-prompt batched results, while
        the fixture's mocks return fixed single-item tensors. Exercising this
        properly would need mocks configured dynamically per batch.
        """
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt=["prompt 1", "prompt 2", "prompt 3"], w=512, h=512)

        # Only check the shape of batched results when the route reports them.
        if "batched_results" not in outcome:
            return
        assert isinstance(outcome["batched_results"], dict)
class TestAutoHDRRouting:
    """Test AutoHDR parameter routing."""

    def test_autohdr_enabled_applies_effect(self, mock_all_heavy_dependencies):
        """autohdr=True should apply HDR effect."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=True,
        )
        hdr = mock_all_heavy_dependencies['hdr']
        assert hdr.return_value.apply_hdr2.called, (
            "HDR effect should be applied when autohdr=True"
        )

    def test_autohdr_disabled_skips_effect(self, mock_all_heavy_dependencies):
        """autohdr=False should skip HDR effect."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            autohdr=False,
        )
        hdr = mock_all_heavy_dependencies['hdr']
        # The implementation may still create the HDR object, so only the effect
        # call itself is checked. (Previously this test asserted nothing.)
        assert not hdr.return_value.apply_hdr2.called, (
            "HDR effect should NOT be applied when autohdr=False"
        )
class TestCFGFreeRouting:
    """Routing checks for the CFG-free sampling options."""

    def test_cfg_free_params_passed_to_sampler(self, mock_all_heavy_dependencies):
        """With CFG-free sampling enabled, the pipeline still reaches the sampler."""
        from src.user.pipeline import pipeline

        cfg_free_options = dict(cfg_free_enabled=True, cfg_free_start_percent=70.0)
        pipeline(prompt="test", w=512, h=512, **cfg_free_options)

        sample = mock_all_heavy_dependencies['ksampler'].return_value.sample
        assert sample.called
class TestTokenMergingRouting:
    """Test Token Merging (ToMe) parameter routing."""

    def test_tome_params_passed_when_enabled(self, mock_all_heavy_dependencies):
        """With ToMe enabled, the pipeline should still complete a sampling pass."""
        from src.user.pipeline import pipeline
        pipeline(
            prompt="test",
            w=512,
            h=512,
            tome_enabled=True,
            tome_ratio=0.5,
        )
        # Previously this test asserted nothing. Like the other smoke tests,
        # require that the sampler was actually reached on the ToMe route.
        ksampler = mock_all_heavy_dependencies['ksampler']
        assert ksampler.return_value.sample.called, (
            "KSampler.sample should be called when tome_enabled=True"
        )
class TestNegativePromptRouting:
    """Routing checks for negative-prompt handling."""

    def test_empty_negative_prompt_uses_default(self, mock_all_heavy_dependencies):
        """An empty negative prompt is accepted and the pipeline completes."""
        from src.user.pipeline import pipeline

        outcome = pipeline(prompt="test", w=512, h=512, negative_prompt="")
        assert outcome is not None

    def test_custom_negative_prompt_passed(self, mock_all_heavy_dependencies):
        """A custom negative prompt run still goes through CLIP text encoding."""
        from src.user.pipeline import pipeline

        pipeline(
            prompt="test",
            w=512,
            h=512,
            negative_prompt="ugly, bad quality, distorted",
        )

        encode = mock_all_heavy_dependencies['clip_encode'].return_value.encode
        assert encode.called
class TestSeedRouting:
    """Routing checks for seed handling and the reuse_seed flag."""

    def test_reuse_seed_uses_last_seed(self, mock_all_heavy_dependencies):
        """A reuse_seed=True run after a normal run still samples successfully."""
        from src.user.pipeline import pipeline

        # First establish a seed, then ask the pipeline to reuse it.
        for reuse in (False, True):
            pipeline(prompt="test", w=512, h=512, reuse_seed=reuse)

        passes = mock_all_heavy_dependencies['ksampler'].return_value.sample.call_count
        assert passes >= 2
class TestModelPathRouting:
    """Routing checks for the model_path parameter."""

    def test_custom_model_path_used(self, mock_all_heavy_dependencies):
        """When the checkpoint loader is used, it receives the custom model path."""
        from src.user.pipeline import pipeline

        custom_path = "/path/to/custom_model.safetensors"
        pipeline(prompt="test", w=512, h=512, model_path=custom_path)

        load_checkpoint = mock_all_heavy_dependencies['loader'].return_value.load_checkpoint
        if not load_checkpoint.called:
            return
        observed = load_checkpoint.call_args
        assert custom_path in str(observed), (
            f"Custom model path should be used: {observed}"
        )
class TestErrorHandling:
    """Checks that pipeline failures surface as the expected exceptions."""

    def test_invalid_model_path_raises_error(self, mock_all_heavy_dependencies):
        """A checkpoint-loader failure propagates out of the pipeline as FileNotFoundError."""
        from src.user.pipeline import pipeline

        loader_instance = mock_all_heavy_dependencies['loader'].return_value
        cache_instance = mock_all_heavy_dependencies['model_cache'].return_value
        # Make loading fail and make sure nothing is served from the cache.
        loader_instance.load_checkpoint.side_effect = FileNotFoundError("Model not found")
        cache_instance.get_cached_checkpoint.return_value = None

        with pytest.raises(FileNotFoundError):
            pipeline(
                prompt="test",
                w=512,
                h=512,
                model_path="/nonexistent/model.safetensors",
            )

    def test_interruption_handled_gracefully(self, mock_all_heavy_dependencies):
        """A set interrupt flag on the app instance aborts the run with InterruptedError."""
        from src.user.pipeline import pipeline

        interrupted_app = MagicMock()
        interrupted_app.interrupt_flag = True
        with patch('src.user.app_instance.app', interrupted_app), pytest.raises(InterruptedError):
            pipeline(prompt="test", w=512, h=512)
class TestSchedulerSamplerRouting:
    """Routing checks for the scheduler and sampler selection parameters."""

    @staticmethod
    def _run_and_check_sampled(mocks, **overrides):
        # Run a 512x512 generation with the given option; sampling must happen.
        from src.user.pipeline import pipeline
        pipeline(prompt="test", w=512, h=512, **overrides)
        assert mocks['ksampler'].return_value.sample.called

    @pytest.mark.parametrize("scheduler", [
        "normal", "karras", "simple", "beta", "ays", "ays_sd15", "ays_sdxl"
    ])
    def test_scheduler_passed_to_sampler(self, mock_all_heavy_dependencies, scheduler):
        """Every supported scheduler name runs through to KSampler."""
        self._run_and_check_sampled(mock_all_heavy_dependencies, scheduler=scheduler)

    @pytest.mark.parametrize("sampler", [
        "euler", "euler_ancestral", "euler_cfgpp",
        "euler_ancestral_cfgpp", "dpmpp_2m_cfgpp", "dpmpp_sde_cfgpp"
    ])
    def test_sampler_passed_to_ksampler(self, mock_all_heavy_dependencies, sampler):
        """Every supported sampler name runs through to KSampler."""
        self._run_and_check_sampled(mock_all_heavy_dependencies, sampler=sampler)