"""Unit tests for the ``examples.example`` and ``examples.example_remote`` runner scripts.

Pipelines, blocks, VAEs, image loaders and CUDA hooks are mostly replaced with small fakes or
monkeypatched attributes, so these tests check argument parsing and call wiring without loading
real models.
"""

from types import SimpleNamespace

import pytest
import torch
from PIL import Image

from examples import example as runner
from examples import example_remote as remote_runner


class FakeImage:
    """Stand-in for a PIL image that records ``save()`` calls instead of writing files."""

    def __init__(self, size):
        # ``size`` mirrors PIL's ``Image.size`` attribute (a 2-tuple).
        self.size = size
        # Each entry is ``(path, kwargs)`` for one ``save()`` call.
        self.saved = []

    def save(self, path, **kwargs):
        self.saved.append((path, kwargs))


def _args(*extra):
    """Parse runner CLI args with a dummy ``--base-model`` plus ``extra`` and apply window-stride defaults."""
    return runner.apply_window_stride_defaults(runner.parse_args(["--base-model", "model", *extra]))


def test_quantization_parser_defaults_to_none():
    args = _args()
    assert args.quantization == "none"
    assert args.transformer_quantization is None
    assert args.text_encoder_quantization is None
    assert args.vae_quantization is None
    assert args.allow_tf32 is False


def test_allow_tf32_parser_flag():
    args = _args("--allow-tf32")
    assert args.allow_tf32 is True


def test_configure_torch_skips_tf32_when_not_allowed(monkeypatch):
    calls = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", calls.append)
    runner.configure_torch(allow_tf32=False)
    assert calls == []


def test_configure_torch_enables_tf32_when_allowed_on_cuda(monkeypatch):
    calls = []
    # configure_torch mutates global torch state; remember it so later tests are unaffected.
    original_fp32_precision = getattr(runner.torch.backends, "fp32_precision", None)
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", calls.append)
    try:
        runner.configure_torch(allow_tf32=True)
        assert runner.torch.backends.fp32_precision == "tf32"
        assert calls == ["high"]
    finally:
        if original_fp32_precision is not None:
            runner.torch.backends.fp32_precision = original_fp32_precision


def test_configure_torch_skips_tf32_without_cuda(monkeypatch):
    calls = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", calls.append)
    runner.configure_torch(allow_tf32=True)
    assert calls == []


def test_save_images_defaults_to_png_for_small_images(tmp_path):
    # 4096x4096 is still written as PNG; the tests below use 4097 px to hit the BigTIFF path.
    image = FakeImage((4096, 4096))
    runner.save_images([image], tmp_path / "output", num_images_per_prompt=1)
    assert image.saved == [(tmp_path / "output.png", {})]


def test_save_images_defaults_to_bigtiff_for_large_images(tmp_path):
    image = FakeImage((4097, 4096))
    runner.save_images([image], tmp_path / "output", num_images_per_prompt=1)
    assert image.saved == [(tmp_path / "output.tif", {"format": "TIFF", "big_tiff": True})]


def test_save_images_switches_large_png_output_to_bigtiff(tmp_path):
    image = FakeImage((4097, 4096))
    with pytest.warns(UserWarning, match="saving BigTIFF"):
        runner.save_images([image], tmp_path / "output.png", num_images_per_prompt=1)
    assert image.saved == [(tmp_path / "output.tif", {"format": "TIFF", "big_tiff": True})]


def test_save_images_preserves_numbered_large_bigtiff_stems(tmp_path):
    images = [FakeImage((4097, 4096)), FakeImage((4097, 4096))]
    runner.save_images(images, tmp_path / "output", num_images_per_prompt=2)
    assert images[0].saved == [(tmp_path / "output_0.tif", {"format": "TIFF", "big_tiff": True})]
    assert images[1].saved == [(tmp_path / "output_1.tif", {"format": "TIFF", "big_tiff": True})]


def test_quantization_flag_without_value_uses_float8_weight_only_default():
    args = _args("--quantization")
    assert args.quantization == "float8wo"


def test_quantization_parser_accepts_component_overrides():
    args = _args(
        "--quantization",
        "float8wo",
        "--transformer-quantization",
        "int4wo",
        "--text-encoder-quantization",
        "none",
        "--vae-quantization",
        "float8dyn",
    )
    assert args.quantization == "float8wo"
    assert args.transformer_quantization == "int4wo"
    assert args.text_encoder_quantization == "none"
    assert args.vae_quantization == "float8dyn"


def test_resolve_quantization_mapping_applies_base_strategy_to_all_quantizable_components():
    mapping = runner.resolve_quantization_mapping("float8wo")
    assert mapping == {
        "transformer": "float8wo",
        "text_encoder": "float8wo",
        "vae": "float8wo",
    }


def test_resolve_quantization_mapping_uses_component_overrides():
    mapping = runner.resolve_quantization_mapping(
        "float8wo",
        transformer_quantization="int4wo",
        text_encoder_quantization="none",
        vae_quantization="float8dyn",
    )
    # A per-component "none" override drops that component from the mapping entirely.
    assert mapping == {"transformer": "int4wo", "vae": "float8dyn"}


def test_build_quantization_config_returns_none_when_disabled():
    assert runner.build_quantization_config("none") is None


def test_build_quantization_config_uses_component_specific_torchao_configs():
    pytest.importorskip("torchao")
    config = runner.build_quantization_config(
        "float8wo",
        transformer_quantization="int4wo",
        text_encoder_quantization="int8wo",
        vae_quantization="float8dyn",
    )
    assert set(config.quant_mapping) == {"transformer", "text_encoder", "vae"}
    assert type(config.quant_mapping["transformer"].quant_type).__name__ == "Int4WeightOnlyConfig"
    assert type(config.quant_mapping["text_encoder"].quant_type).__name__ == "Int8WeightOnlyConfig"
    assert type(config.quant_mapping["vae"].quant_type).__name__ == "Float8DynamicActivationFloat8WeightConfig"


@pytest.mark.parametrize("weighting_type", ["none", "linear", "cosine"])
def test_call_kwargs_forward_all_weighting_types(weighting_type):
    args = _args("--weighting-type", weighting_type)
    generator = object()
    call_kwargs = runner.build_call_kwargs(args, SimpleNamespace(), generator)
    assert call_kwargs["weighting_type"] == weighting_type
    assert "num_inference_steps" not in call_kwargs
    assert call_kwargs["generator"] is generator


def test_call_kwargs_forward_offset_and_panorama_options():
    call_kwargs = runner.build_call_kwargs(
        _args(
            "--window-stride-height-offset",
            "128",
            "--window-stride-width-offset",
            "256",
            "--panorama-width",
            "--panorama-height",
        ),
        SimpleNamespace(),
        object(),
    )
    assert call_kwargs["window_stride_height_offset"] == 128
    assert call_kwargs["window_stride_width_offset"] == 256
    assert call_kwargs["panorama_width"] is True
    assert call_kwargs["panorama_height"] is True


def test_call_kwargs_forward_window_batch_size():
    call_kwargs = runner.build_call_kwargs(
        _args("--batch-size", "6"),
        SimpleNamespace(),
        object(),
    )
    assert call_kwargs["window_batch_size"] == 6


def test_call_kwargs_rejects_invalid_window_batch_size():
    with pytest.raises(ValueError, match="batch-size"):
        runner.build_call_kwargs(
            _args("--batch-size", "0"),
            SimpleNamespace(),
            object(),
        )


def test_csv_prompt_loads_regional_masks_in_csv_order(tmp_path):
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    # The two masks must differ, otherwise the tensor assertion cannot tell whether masks follow
    # the CSV order (08.png first) or e.g. file-name order (00.png first). A uniform 128 fill is
    # used because test_regional_mask_loader_warns_for_non_binary_masks shows it loads as exactly
    # 128 / 255 and emits a "not binary" warning.
    Image.new("L", (32, 32), 128).save(masks_dir / "00.png")
    Image.new("L", (32, 32), 255).save(masks_dir / "08.png")
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n08.png,"a forest"\n00.png,"a mountain"\n', encoding="utf-8")
    with pytest.warns(UserWarning, match="not binary"):
        call_kwargs = runner.build_call_kwargs(
            _args("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32"),
            SimpleNamespace(vae_scale_factor=8),
            object(),
        )
    assert call_kwargs["prompt"] == ["", "a forest", "a mountain"]
    # One 2x2 mask per CSV row, in CSV order: 08.png (white) then 00.png (gray).
    expected_masks = torch.stack([torch.ones((2, 2)), torch.full((2, 2), 128 / 255)])
    assert torch.equal(call_kwargs["regional_masks"], expected_masks)


def test_csv_prompt_requires_masks_folder(tmp_path):
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    with pytest.raises(ValueError, match="--masks"):
        runner.build_call_kwargs(
            _args("--prompt", str(csv_path), "--height", "32", "--width", "32"),
            SimpleNamespace(vae_scale_factor=8),
            object(),
        )


def test_regional_mask_loader_warns_for_non_binary_masks(tmp_path):
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    Image.new("L", (32, 32), 128).save(masks_dir / "00.png")
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a foggy valley"\n', encoding="utf-8")
    with pytest.warns(UserWarning, match="not binary"):
        call_kwargs = runner.build_call_kwargs(
            _args("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32"),
            SimpleNamespace(vae_scale_factor=8),
            object(),
        )
    assert torch.equal(call_kwargs["regional_masks"], torch.full((1, 2, 2), 128 / 255))


def test_regional_mask_loader_rejects_wrong_size(tmp_path):
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    # PIL sizes are (width, height): this mask is 16 wide, but the run requests 32x32.
    Image.new("L", (16, 32), 255).save(masks_dir / "00.png")
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    with pytest.raises(ValueError, match="must have size"):
        runner.build_call_kwargs(
            _args("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32"),
            SimpleNamespace(vae_scale_factor=8),
            object(),
        )


@pytest.mark.parametrize(
    ("config", "expected_class_name"),
    [
        ({"_class_name": "Flux2Pipeline"}, "Flux2MultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline", "is_distilled": True}, "Flux2KleinMultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline", "is_distilled": False}, "Flux2KleinBaseMultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline"}, "Flux2KleinBaseMultiDiffusionAutoBlocks"),
    ],
)
def test_multidiffusion_blocks_are_selected_from_loaded_pipeline_config(config, expected_class_name):
    assert runner.build_multidiffusion_blocks(config).__class__.__name__ == expected_class_name


def test_explicit_num_inference_steps_are_forwarded():
    call_kwargs = runner.build_call_kwargs(
        _args("--num-inference-steps", "7"),
        SimpleNamespace(),
        object(),
        {"_class_name": "Flux2KleinPipeline", "is_distilled": True},
    )
    assert call_kwargs["num_inference_steps"] == 7


def test_distilled_klein_default_num_inference_steps_is_forwarded():
    call_kwargs = runner.build_call_kwargs(
        _args(),
        SimpleNamespace(),
        object(),
        {"_class_name": "Flux2KleinPipeline", "is_distilled": True},
    )
    assert call_kwargs["num_inference_steps"] == 4


@pytest.mark.parametrize(
    "config",
    [
        {"_class_name": "Flux2Pipeline"},
        {"_class_name": "Flux2KleinPipeline", "is_distilled": False},
        {"_class_name": "Flux2KleinModularPipeline", "is_distilled": False},
        {"_class_name": "Flux2KleinPipeline"},
        {},
    ],
)
def test_non_distilled_models_keep_pipeline_num_inference_steps_default(config):
    call_kwargs = runner.build_call_kwargs(
        _args(),
        SimpleNamespace(),
        object(),
        config,
    )
    assert "num_inference_steps" not in call_kwargs


def test_enable_tiling_and_slicing_call_vae_methods():
    class FakeVae:
        def __init__(self):
            self.tiling_enabled = False
            self.slicing_enabled = False

        def enable_tiling(self):
            self.tiling_enabled = True

        def enable_slicing(self):
            self.slicing_enabled = True

    pipe = SimpleNamespace(vae=FakeVae())
    runner.apply_vae_memory_options(pipe, enable_tiling=True, enable_slicing=True)
    assert pipe.vae.tiling_enabled
    assert pipe.vae.slicing_enabled


def test_init_selects_blocks_from_model_config_and_keeps_model_guidance_default(monkeypatch):
    # Shared call log: calls[0] is the model_config given to build_multidiffusion_blocks,
    # calls[1] is the (base_model, components_manager) pair given to init_pipeline.
    calls = []

    class FakeBlocks:
        def init_pipeline(self, base_model, components_manager):
            calls.append((base_model, components_manager))
            # FakePipe is resolved at call time, after it has been defined below.
            return FakePipe()

    class FakePipe:
        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def update_components(self, **components):
            self.updated_components = components

    # list.append returns None, so ``or`` makes the lambda log the config and return FakeBlocks().
    monkeypatch.setattr(
        runner, "build_multidiffusion_blocks", lambda model_config: calls.append(model_config) or FakeBlocks()
    )
    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )
    assert calls[0] == {"_class_name": "Flux2KleinPipeline"}
    assert calls[1][0] == "black-forest-labs/FLUX.2-klein-base-4B"
    # With guidance_scale=None the guider must be left untouched.
    assert pipe.updated_components is None


def test_init_applies_explicit_guidance_scale(monkeypatch):
    class FakeGuiderSpec:
        @staticmethod
        def create(guidance_scale):
            return {"guidance_scale": guidance_scale}

    class FakePipe:
        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def get_component_spec(self, name):
            assert name == "guider"
            return FakeGuiderSpec()

        def update_components(self, **components):
            self.updated_components = components

    monkeypatch.setattr(
        runner,
        "build_multidiffusion_blocks",
        lambda model_config: SimpleNamespace(init_pipeline=lambda base_model, components_manager: FakePipe()),
    )
    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=2.5,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )
    assert pipe.updated_components == {"guider": {"guidance_scale": 2.5}}


def test_remote_parser_requires_hub_repo_id():
    # NOTE(review): the name says "requires", but only the happy path is checked. If --repo-id is
    # declared required, consider also asserting that parse_args([]) raises SystemExit.
    args = remote_runner.parse_args(["--repo-id", "user/multidiff-modular"])
    assert args.repo_id == "user/multidiff-modular"


def test_remote_init_uses_modular_from_pretrained_with_remote_code(monkeypatch):
    calls = {}

    class FakePipe:
        def __init__(self):
            self.guider = None

        def load_components(self, **kwargs):
            calls["load_components"] = kwargs

    def fake_from_pretrained(repo_id, **kwargs):
        calls["repo_id"] = repo_id
        calls["from_pretrained"] = kwargs
        return FakePipe()

    # Set as a plain function on the class, so ``ModularPipeline.from_pretrained(repo_id, ...)``
    # receives repo_id as its first positional argument.
    monkeypatch.setattr(remote_runner.ModularPipeline, "from_pretrained", fake_from_pretrained)
    pipe = remote_runner.init_remote_modular_pipeline(
        repo_id="user/multidiff-modular",
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )
    assert pipe.guider is None
    assert calls["repo_id"] == "user/multidiff-modular"
    assert calls["from_pretrained"]["trust_remote_code"] is True
    assert calls["from_pretrained"]["local_files_only"] is True
    assert calls["load_components"]["torch_dtype"] is runner.DTYPE_MAP["float32"]
    assert calls["load_components"]["local_files_only"] is True


def test_image_conditioning_loads_and_forwards_image(monkeypatch):
    loaded_image = object()
    seen_paths = []

    def fake_load_image(path):
        seen_paths.append(path)
        return loaded_image

    monkeypatch.setattr(runner, "load_image", fake_load_image)
    call_kwargs = runner.build_call_kwargs(
        _args("--image-conditioning", "conditioning.png"),
        SimpleNamespace(config={}),
        object(),
    )
    assert seen_paths == ["conditioning.png"]
    assert call_kwargs["image"] is loaded_image


def test_image_img2img_prepares_latents_and_forwards_strength(monkeypatch):
    loaded_image = object()
    raw_latents = object()
    generator = object()
    vae = object()

    class FakeImageProcessor:
        def __init__(self):
            self.calls = []

        def preprocess(self, image, *, height, width, resize_mode):
            self.calls.append((image, height, width, resize_mode))
            return "preprocessed-image"

    class FakeEncoderPipe:
        def __init__(self):
            self.updated_components = None
            self.calls = []

        def update_components(self, **components):
            self.updated_components = components

        def __call__(self, *, condition_images, generator):
            self.calls.append((condition_images, generator))
            return SimpleNamespace(image_latents=[raw_latents])

    # Mirror the nested block path the runner walks:
    # pipe.blocks.sub_blocks["vae_encoder"].sub_blocks["img_conditioning"].sub_blocks["encode"].
    encoder_pipe = FakeEncoderPipe()
    encoder_block = SimpleNamespace(init_pipeline=lambda: encoder_pipe)
    img_conditioning_block = SimpleNamespace(sub_blocks={"encode": encoder_block})
    vae_encoder_block = SimpleNamespace(sub_blocks={"img_conditioning": img_conditioning_block})
    image_processor = FakeImageProcessor()
    pipe = SimpleNamespace(
        config={},
        vae=vae,
        vae_scale_factor=8,
        blocks=SimpleNamespace(sub_blocks={"vae_encoder": vae_encoder_block}),
        image_processor=image_processor,
    )
    monkeypatch.setattr(runner, "load_image", lambda path: loaded_image)
    monkeypatch.setattr(
        runner.Flux2PrepareImageLatentsStep,
        "_pack_latents",
        staticmethod(lambda latents: ("packed", latents)),
    )
    call_kwargs = runner.build_call_kwargs(
        _args("--image-img2img", "init.png", "--height", "65", "--width", "81", "--strength", "0.42"),
        pipe,
        generator,
    )
    # 65x81 is expected to be reduced to 64x80 before preprocessing (pinned by this assertion).
    assert image_processor.calls == [(loaded_image, 64, 80, "default")]
    assert encoder_pipe.updated_components == {"vae": vae}
    assert encoder_pipe.calls == [(["preprocessed-image"], generator)]
    assert call_kwargs["image_img2img"] == ("packed", raw_latents)
    assert call_kwargs["strength"] == 0.42