# modular-flux2-multidiffusion/tests/test_example_runner.py
# arlaz: Initial Multidiff Modular export (d8bf41d)
from types import SimpleNamespace
import pytest
import torch
from PIL import Image
from examples import example as runner
from examples import example_remote as remote_runner
class FakeImage:
    """Image stand-in exposing ``size`` and recording every ``save`` call instead of writing a file."""

    def __init__(self, size):
        # ``saved`` accumulates one ``(path, kwargs)`` tuple per ``save`` call, in call order.
        self.size, self.saved = size, []

    def save(self, path, **kwargs):
        """Remember the destination path and the keyword options passed by the caller."""
        record = (path, kwargs)
        self.saved.append(record)
def _args(*extra):
    """Parse ``--base-model model`` plus ``extra`` with the runner's parser.

    The parsed namespace is passed through ``runner.apply_window_stride_defaults``
    before being returned.
    """
    argv = ["--base-model", "model"]
    argv.extend(extra)
    parsed = runner.parse_args(argv)
    return runner.apply_window_stride_defaults(parsed)
def test_quantization_parser_defaults_to_none():
    """Without quantization flags the base strategy is "none", overrides are unset and TF32 is off."""
    parsed = _args()
    assert parsed.quantization == "none"
    for override_name in (
        "transformer_quantization",
        "text_encoder_quantization",
        "vae_quantization",
    ):
        assert getattr(parsed, override_name) is None
    assert parsed.allow_tf32 is False
def test_allow_tf32_parser_flag():
    """Passing ``--allow-tf32`` yields ``allow_tf32 is True`` on the parsed namespace."""
    assert _args("--allow-tf32").allow_tf32 is True
def test_configure_torch_skips_tf32_when_not_allowed(monkeypatch):
    """With ``allow_tf32=False`` the matmul precision setter is never called, even if CUDA is reported."""
    precision_requests = []
    # Pretend CUDA is present and capture any precision change instead of applying it.
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", precision_requests.append)

    runner.configure_torch(allow_tf32=False)

    assert precision_requests == []
def test_configure_torch_enables_tf32_when_allowed_on_cuda(monkeypatch):
    """With ``allow_tf32=True`` and CUDA reported, TF32 is enabled.

    Checks that ``torch.backends.fp32_precision`` becomes ``"tf32"`` and that
    ``torch.set_float32_matmul_precision`` is called exactly once with ``"high"``.
    The global ``fp32_precision`` value is put back afterwards so other tests
    are not affected.
    """
    backends = runner.torch.backends
    # A private sentinel distinguishes "attribute absent" from any real value,
    # so the cleanup can also undo an attribute that the call itself created.
    missing = object()
    original_fp32_precision = getattr(backends, "fp32_precision", missing)

    calls = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", calls.append)
    try:
        runner.configure_torch(allow_tf32=True)
        assert backends.fp32_precision == "tf32"
        assert calls == ["high"]
    finally:
        if original_fp32_precision is not missing:
            backends.fp32_precision = original_fp32_precision
        elif hasattr(backends, "fp32_precision"):
            # The attribute did not exist before the test: remove it rather
            # than leaking "tf32" into later tests.
            delattr(backends, "fp32_precision")
def test_configure_torch_skips_tf32_without_cuda(monkeypatch):
    """With CUDA reported unavailable, ``allow_tf32=True`` does not call the matmul precision setter."""
    precision_requests = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", precision_requests.append)

    runner.configure_torch(allow_tf32=True)

    assert precision_requests == []
def test_save_images_defaults_to_png_for_small_images(tmp_path):
    """A 4096x4096 image with an extension-less output path is saved to ``output.png`` with no options."""
    fake = FakeImage((4096, 4096))
    destination = tmp_path / "output"

    runner.save_images([fake], destination, num_images_per_prompt=1)

    expected_call = (tmp_path / "output.png", {})
    assert fake.saved == [expected_call]
def test_save_images_defaults_to_bigtiff_for_large_images(tmp_path):
    """A 4097x4096 image with an extension-less output path is saved to ``output.tif`` as BigTIFF."""
    fake = FakeImage((4097, 4096))
    destination = tmp_path / "output"

    runner.save_images([fake], destination, num_images_per_prompt=1)

    expected_call = (tmp_path / "output.tif", {"format": "TIFF", "big_tiff": True})
    assert fake.saved == [expected_call]
def test_save_images_switches_large_png_output_to_bigtiff(tmp_path):
    """An explicit ``output.png`` path for a 4097x4096 image warns and is saved to ``output.tif`` as BigTIFF."""
    fake = FakeImage((4097, 4096))
    requested_path = tmp_path / "output.png"

    with pytest.warns(UserWarning, match="saving BigTIFF"):
        runner.save_images([fake], requested_path, num_images_per_prompt=1)

    expected_call = (tmp_path / "output.tif", {"format": "TIFF", "big_tiff": True})
    assert fake.saved == [expected_call]
def test_save_images_preserves_numbered_large_bigtiff_stems(tmp_path):
    """Two large images are saved to ``output_0.tif`` and ``output_1.tif``, each as BigTIFF."""
    first, second = FakeImage((4097, 4096)), FakeImage((4097, 4096))

    runner.save_images([first, second], tmp_path / "output", num_images_per_prompt=2)

    expectations = (
        (first, "output_0.tif"),
        (second, "output_1.tif"),
    )
    for fake, file_name in expectations:
        assert fake.saved == [(tmp_path / file_name, {"format": "TIFF", "big_tiff": True})]
def test_quantization_flag_without_value_uses_float8_weight_only_default():
    """A bare ``--quantization`` flag (no value) parses to ``"float8wo"``."""
    assert _args("--quantization").quantization == "float8wo"
def test_quantization_parser_accepts_component_overrides():
    """The base strategy and each per-component override are parsed to the values given on the CLI."""
    cli = [
        "--quantization",
        "float8wo",
        "--transformer-quantization",
        "int4wo",
        "--text-encoder-quantization",
        "none",
        "--vae-quantization",
        "float8dyn",
    ]
    parsed = _args(*cli)

    expected = {
        "quantization": "float8wo",
        "transformer_quantization": "int4wo",
        "text_encoder_quantization": "none",
        "vae_quantization": "float8dyn",
    }
    for attribute, value in expected.items():
        assert getattr(parsed, attribute) == value
def test_resolve_quantization_mapping_applies_base_strategy_to_all_quantizable_components():
    """With only a base strategy, transformer, text encoder and VAE all map to that strategy."""
    resolved = runner.resolve_quantization_mapping("float8wo")
    expected = dict.fromkeys(("transformer", "text_encoder", "vae"), "float8wo")
    assert resolved == expected
def test_resolve_quantization_mapping_uses_component_overrides():
    """Per-component overrides replace the base strategy; a ``"none"`` override leaves that component out."""
    overrides = {
        "transformer_quantization": "int4wo",
        "text_encoder_quantization": "none",
        "vae_quantization": "float8dyn",
    }
    resolved = runner.resolve_quantization_mapping("float8wo", **overrides)
    assert resolved == {"transformer": "int4wo", "vae": "float8dyn"}
def test_build_quantization_config_returns_none_when_disabled():
    """The ``"none"`` strategy produces no quantization config at all."""
    config = runner.build_quantization_config("none")
    assert config is None
def test_build_quantization_config_uses_component_specific_torchao_configs():
    """Each component's ``quant_type`` is an instance of the class named for its strategy.

    Skipped when ``torchao`` cannot be imported.
    """
    pytest.importorskip("torchao")

    config = runner.build_quantization_config(
        "float8wo",
        transformer_quantization="int4wo",
        text_encoder_quantization="int8wo",
        vae_quantization="float8dyn",
    )

    expected_class_names = {
        "transformer": "Int4WeightOnlyConfig",
        "text_encoder": "Int8WeightOnlyConfig",
        "vae": "Float8DynamicActivationFloat8WeightConfig",
    }
    assert set(config.quant_mapping) == {"transformer", "text_encoder", "vae"}
    for component, class_name in expected_class_names.items():
        assert type(config.quant_mapping[component].quant_type).__name__ == class_name
@pytest.mark.parametrize("weighting_type", ["none", "linear", "cosine"])
def test_call_kwargs_forward_all_weighting_types(weighting_type):
    """The chosen weighting type and the generator reach the call kwargs; no step count is added."""
    sentinel_generator = object()
    parsed = _args("--weighting-type", weighting_type)

    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(), sentinel_generator)

    assert kwargs["weighting_type"] == weighting_type
    assert "num_inference_steps" not in kwargs
    assert kwargs["generator"] is sentinel_generator
def test_call_kwargs_forward_offset_and_panorama_options():
    """Stride offsets arrive as integers and both panorama flags arrive as ``True`` in the call kwargs."""
    parsed = _args(
        "--window-stride-height-offset",
        "128",
        "--window-stride-width-offset",
        "256",
        "--panorama-width",
        "--panorama-height",
    )

    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(), object())

    assert kwargs["window_stride_height_offset"] == 128
    assert kwargs["window_stride_width_offset"] == 256
    assert kwargs["panorama_width"] is True
    assert kwargs["panorama_height"] is True
def test_call_kwargs_forward_window_batch_size():
    """``--batch-size 6`` is forwarded as ``window_batch_size == 6``."""
    parsed = _args("--batch-size", "6")
    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(), object())
    assert kwargs["window_batch_size"] == 6
def test_call_kwargs_rejects_invalid_window_batch_size():
    """``--batch-size 0`` ends in a ``ValueError`` whose message mentions ``batch-size``."""
    argv = ("--batch-size", "0")
    # Argument parsing stays inside the context manager so an error raised at
    # either stage (parsing or kwargs building) satisfies the expectation.
    with pytest.raises(ValueError, match="batch-size"):
        runner.build_call_kwargs(_args(*argv), SimpleNamespace(), object())
def test_csv_prompt_loads_regional_masks_in_csv_order(tmp_path):
    """Prompts follow CSV row order after a leading empty prompt; two all-white masks load as ones.

    The expected mask tensor has shape ``(2, 2, 2)`` for 32x32 masks with
    ``vae_scale_factor=8``.
    """
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    for mask_name in ("00.png", "08.png"):
        Image.new("L", (32, 32), 255).save(masks_dir / mask_name)

    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n08.png,"a forest"\n00.png,"a mountain"\n', encoding="utf-8")

    parsed = _args("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32")
    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(vae_scale_factor=8), object())

    assert kwargs["prompt"] == ["", "a forest", "a mountain"]
    expected_masks = torch.ones((2, 2, 2))
    assert torch.equal(kwargs["regional_masks"], expected_masks)
def test_csv_prompt_requires_masks_folder(tmp_path):
    """A CSV prompt without ``--masks`` ends in a ``ValueError`` that mentions ``--masks``."""
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    argv = ["--prompt", str(csv_path), "--height", "32", "--width", "32"]
    fake_pipe = SimpleNamespace(vae_scale_factor=8)

    with pytest.raises(ValueError, match="--masks"):
        runner.build_call_kwargs(_args(*argv), fake_pipe, object())
def test_regional_mask_loader_warns_for_non_binary_masks(tmp_path):
    """A uniformly grey (128) mask triggers a "not binary" warning and is loaded as ``128 / 255``."""
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    grey_mask = Image.new("L", (32, 32), 128)
    grey_mask.save(masks_dir / "00.png")

    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a foggy valley"\n', encoding="utf-8")
    argv = ["--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32"]

    with pytest.warns(UserWarning, match="not binary"):
        kwargs = runner.build_call_kwargs(_args(*argv), SimpleNamespace(vae_scale_factor=8), object())

    expected_masks = torch.full((1, 2, 2), 128 / 255)
    assert torch.equal(kwargs["regional_masks"], expected_masks)
def test_regional_mask_loader_rejects_wrong_size(tmp_path):
    """A 16x32 mask for a 32x32 request ends in a ``ValueError`` matching "must have size"."""
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    undersized_mask = Image.new("L", (16, 32), 255)
    undersized_mask.save(masks_dir / "00.png")

    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    argv = ["--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32"]

    with pytest.raises(ValueError, match="must have size"):
        runner.build_call_kwargs(_args(*argv), SimpleNamespace(vae_scale_factor=8), object())
@pytest.mark.parametrize(
    ("config", "expected_class_name"),
    [
        ({"_class_name": "Flux2Pipeline"}, "Flux2MultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline", "is_distilled": True}, "Flux2KleinMultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline", "is_distilled": False}, "Flux2KleinBaseMultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline"}, "Flux2KleinBaseMultiDiffusionAutoBlocks"),
    ],
)
def test_multidiffusion_blocks_are_selected_from_loaded_pipeline_config(config, expected_class_name):
    """The class of the built blocks object depends on ``_class_name`` and ``is_distilled`` in the config."""
    blocks = runner.build_multidiffusion_blocks(config)
    actual_class_name = blocks.__class__.__name__
    assert actual_class_name == expected_class_name
def test_explicit_num_inference_steps_are_forwarded():
    """An explicit ``--num-inference-steps 7`` is forwarded as 7 for a distilled Klein config."""
    model_config = {"_class_name": "Flux2KleinPipeline", "is_distilled": True}
    parsed = _args("--num-inference-steps", "7")

    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(), object(), model_config)

    assert kwargs["num_inference_steps"] == 7
def test_distilled_klein_default_num_inference_steps_is_forwarded():
    """With no explicit step count, a distilled Klein config results in ``num_inference_steps == 4``."""
    model_config = {"_class_name": "Flux2KleinPipeline", "is_distilled": True}
    parsed = _args()

    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(), object(), model_config)

    assert kwargs["num_inference_steps"] == 4
@pytest.mark.parametrize(
    "config",
    [
        {"_class_name": "Flux2Pipeline"},
        {"_class_name": "Flux2KleinPipeline", "is_distilled": False},
        {"_class_name": "Flux2KleinModularPipeline", "is_distilled": False},
        {"_class_name": "Flux2KleinPipeline"},
        {},
    ],
)
def test_non_distilled_models_keep_pipeline_num_inference_steps_default(config):
    """For configs not marked distilled, no ``num_inference_steps`` entry is added to the call kwargs."""
    parsed = _args()
    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(), object(), config)
    assert "num_inference_steps" not in kwargs
def test_enable_tiling_and_slicing_call_vae_methods():
    """Requesting tiling and slicing calls ``enable_tiling`` and ``enable_slicing`` on ``pipe.vae``."""

    class FakeVae:
        """VAE double whose enable methods only flip a flag."""

        def __init__(self):
            self.tiling_enabled = self.slicing_enabled = False

        def enable_tiling(self):
            self.tiling_enabled = True

        def enable_slicing(self):
            self.slicing_enabled = True

    pipe = SimpleNamespace(vae=FakeVae())

    runner.apply_vae_memory_options(pipe, enable_tiling=True, enable_slicing=True)

    assert pipe.vae.tiling_enabled
    assert pipe.vae.slicing_enabled
def test_init_selects_blocks_from_model_config_and_keeps_model_guidance_default(monkeypatch):
    """Blocks are built from the model config, then the pipeline is initialised from the base model.

    With ``guidance_scale=None`` the fake pipeline's ``update_components`` is
    never called, so ``updated_components`` stays ``None``.
    """
    calls = []

    class FakePipe:
        """Pipeline double; it deliberately has no ``get_component_spec`` method."""

        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def update_components(self, **components):
            self.updated_components = components

    class FakeBlocks:
        """Blocks double that logs ``init_pipeline`` arguments into ``calls``."""

        def init_pipeline(self, base_model, components_manager):
            calls.append((base_model, components_manager))
            return FakePipe()

    def fake_build_blocks(model_config):
        # Log the config first so the test can check the order of the two calls.
        calls.append(model_config)
        return FakeBlocks()

    monkeypatch.setattr(runner, "build_multidiffusion_blocks", fake_build_blocks)

    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    config_seen, init_call = calls[0], calls[1]
    assert config_seen == {"_class_name": "Flux2KleinPipeline"}
    assert init_call[0] == "black-forest-labs/FLUX.2-klein-base-4B"
    assert pipe.updated_components is None
def test_init_applies_explicit_guidance_scale(monkeypatch):
    """An explicit ``guidance_scale`` updates the pipeline with a guider made by the guider spec."""

    class FakeGuiderSpec:
        """Spec double whose ``create`` echoes the requested scale back in a dict."""

        @staticmethod
        def create(guidance_scale):
            return {"guidance_scale": guidance_scale}

    class FakePipe:
        """Pipeline double that records component updates and only knows the ``guider`` spec."""

        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def get_component_spec(self, name):
            assert name == "guider"
            return FakeGuiderSpec()

        def update_components(self, **components):
            self.updated_components = components

    def fake_init_pipeline(base_model, components_manager):
        return FakePipe()

    def fake_build_blocks(model_config):
        return SimpleNamespace(init_pipeline=fake_init_pipeline)

    monkeypatch.setattr(runner, "build_multidiffusion_blocks", fake_build_blocks)

    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=2.5,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    expected_update = {"guider": {"guidance_scale": 2.5}}
    assert pipe.updated_components == expected_update
def test_remote_parser_requires_hub_repo_id():
    """The remote runner's parser exposes the ``--repo-id`` value as ``repo_id``."""
    # NOTE(review): only acceptance of --repo-id is checked here; the case where
    # the option is omitted is not exercised by this test.
    repo_id = "user/multidiff-modular"
    parsed = remote_runner.parse_args(["--repo-id", repo_id])
    assert parsed.repo_id == "user/multidiff-modular"
def test_remote_init_uses_modular_from_pretrained_with_remote_code(monkeypatch):
    """Remote init calls ``ModularPipeline.from_pretrained`` with remote code trusted.

    Also checks that ``local_files_only`` reaches both ``from_pretrained`` and
    ``load_components``, and that ``load_components`` gets the float32 entry of
    ``runner.DTYPE_MAP`` as ``torch_dtype``.
    """
    recorded = {}

    class FakePipe:
        """Pipeline double that stores the kwargs given to ``load_components``."""

        def __init__(self):
            self.guider = None

        def load_components(self, **kwargs):
            recorded["load_components"] = kwargs

    def fake_from_pretrained(repo_id, **kwargs):
        recorded["repo_id"] = repo_id
        recorded["from_pretrained"] = kwargs
        return FakePipe()

    monkeypatch.setattr(remote_runner.ModularPipeline, "from_pretrained", fake_from_pretrained)

    pipe = remote_runner.init_remote_modular_pipeline(
        repo_id="user/multidiff-modular",
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    assert pipe.guider is None
    assert recorded["repo_id"] == "user/multidiff-modular"

    from_pretrained_kwargs = recorded["from_pretrained"]
    assert from_pretrained_kwargs["trust_remote_code"] is True
    assert from_pretrained_kwargs["local_files_only"] is True

    load_kwargs = recorded["load_components"]
    assert load_kwargs["torch_dtype"] is runner.DTYPE_MAP["float32"]
    assert load_kwargs["local_files_only"] is True
def test_image_conditioning_loads_and_forwards_image(monkeypatch):
    """``--image-conditioning`` loads the path via ``runner.load_image`` and forwards the result as ``image``."""
    sentinel_image = object()
    requested_paths = []

    def fake_load_image(path):
        requested_paths.append(path)
        return sentinel_image

    monkeypatch.setattr(runner, "load_image", fake_load_image)

    parsed = _args("--image-conditioning", "conditioning.png")
    kwargs = runner.build_call_kwargs(parsed, SimpleNamespace(config={}), object())

    assert requested_paths == ["conditioning.png"]
    assert kwargs["image"] is sentinel_image
def test_image_img2img_prepares_latents_and_forwards_strength(monkeypatch):
    """``--image-img2img`` preprocesses, encodes and packs the init image, and forwards ``strength``.

    For a 65x81 request the image processor is expected to be called with
    64x80 and resize mode ``"default"``. The nested encoder pipeline must
    receive the pipe's VAE and be called with the preprocessed image and the
    generator. The forwarded ``image_img2img`` value is the packed form of
    the first encoded latent.
    """
    init_image = object()
    encoded_latents = object()
    sentinel_generator = object()
    fake_vae = object()

    class FakeImageProcessor:
        """Records ``preprocess`` arguments and returns a fixed placeholder."""

        def __init__(self):
            self.calls = []

        def preprocess(self, image, *, height, width, resize_mode):
            self.calls.append((image, height, width, resize_mode))
            return "preprocessed-image"

    class FakeEncoderPipe:
        """Encoder pipeline double that records component updates and invocations."""

        def __init__(self):
            self.updated_components = None
            self.calls = []

        def update_components(self, **components):
            self.updated_components = components

        def __call__(self, *, condition_images, generator):
            self.calls.append((condition_images, generator))
            return SimpleNamespace(image_latents=[encoded_latents])

    encoder_pipe = FakeEncoderPipe()
    image_processor = FakeImageProcessor()

    def nested(name, child):
        # Wrap ``child`` in a block exposing it under ``sub_blocks[name]``.
        return SimpleNamespace(sub_blocks={name: child})

    # Block hierarchy: vae_encoder -> img_conditioning -> encode -> encoder pipeline.
    encode_block = SimpleNamespace(init_pipeline=lambda: encoder_pipe)
    blocks = nested("vae_encoder", nested("img_conditioning", nested("encode", encode_block)))

    pipe = SimpleNamespace(
        config={},
        vae=fake_vae,
        vae_scale_factor=8,
        blocks=blocks,
        image_processor=image_processor,
    )

    def fake_pack_latents(latents):
        return ("packed", latents)

    monkeypatch.setattr(runner, "load_image", lambda path: init_image)
    monkeypatch.setattr(
        runner.Flux2PrepareImageLatentsStep,
        "_pack_latents",
        staticmethod(fake_pack_latents),
    )

    parsed = _args("--image-img2img", "init.png", "--height", "65", "--width", "81", "--strength", "0.42")
    kwargs = runner.build_call_kwargs(parsed, pipe, sentinel_generator)

    assert image_processor.calls == [(init_image, 64, 80, "default")]
    assert encoder_pipe.updated_components == {"vae": fake_vae}
    assert encoder_pipe.calls == [(["preprocessed-image"], sentinel_generator)]
    assert kwargs["image_img2img"] == ("packed", encoded_latents)
    assert kwargs["strength"] == 0.42