Instructions to use arlaz/modular-flux2-multidiffusion with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use arlaz/modular-flux2-multidiffusion with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("arlaz/modular-flux2-multidiffusion", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- Draw Things
- DiffusionBee
| from types import SimpleNamespace | |
| import pytest | |
| import torch | |
| from PIL import Image | |
| from examples import example as runner | |
| from examples import example_remote as remote_runner | |
class FakeImage:
    """Minimal stand-in for an image object that records ``save`` calls instead of writing files."""

    def __init__(self, size):
        # Size tuple exposed the same way as the real image's ``size`` attribute.
        self.size = size
        # Each entry is a ``(path, kwargs)`` pair captured from a ``save`` call.
        self.saved = []

    def save(self, path, **kwargs):
        """Record the requested destination and save options without touching disk."""
        record = (path, kwargs)
        self.saved.append(record)
def _args(*extra):
    """Parse CLI tokens for a dummy base model, then apply the window-stride defaults.

    ``extra`` holds additional command-line tokens appended after ``--base-model model``.
    """
    argv = ["--base-model", "model"]
    argv.extend(extra)
    parsed = runner.parse_args(argv)
    return runner.apply_window_stride_defaults(parsed)
def test_quantization_parser_defaults_to_none():
    """With no quantization flags, quantization is "none", overrides are unset and TF32 is off."""
    parsed = _args()
    assert parsed.quantization == "none"
    for override_attr in ("transformer_quantization", "text_encoder_quantization", "vae_quantization"):
        assert getattr(parsed, override_attr) is None
    assert parsed.allow_tf32 is False
def test_allow_tf32_parser_flag():
    """``--allow-tf32`` is a boolean switch that sets ``allow_tf32``."""
    assert _args("--allow-tf32").allow_tf32 is True
def test_configure_torch_skips_tf32_when_not_allowed(monkeypatch):
    """With ``allow_tf32=False`` no matmul precision is requested, even when CUDA is available."""
    precision_requests = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", precision_requests.append)

    runner.configure_torch(allow_tf32=False)

    assert not precision_requests
def test_configure_torch_enables_tf32_when_allowed_on_cuda(monkeypatch):
    """With ``allow_tf32=True`` on CUDA, TF32 is enabled and "high" matmul precision is requested.

    ``torch.backends.fp32_precision`` is process-global state, so it is restored afterwards.
    If the attribute did not exist before the call, whatever ``configure_torch`` created is
    removed again so later tests see an unmodified module.
    """
    calls = []
    backends = runner.torch.backends
    # Sentinel distinguishes "attribute absent" from any real value (including None).
    missing = object()
    original_fp32_precision = getattr(backends, "fp32_precision", missing)
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", calls.append)
    try:
        runner.configure_torch(allow_tf32=True)
        assert backends.fp32_precision == "tf32"
        assert calls == ["high"]
    finally:
        if original_fp32_precision is not missing:
            backends.fp32_precision = original_fp32_precision
        elif hasattr(backends, "fp32_precision"):
            # The attribute was created by configure_torch; drop it rather than leak it.
            delattr(backends, "fp32_precision")
def test_configure_torch_skips_tf32_without_cuda(monkeypatch):
    """Even with ``allow_tf32=True``, no matmul precision is requested when CUDA is unavailable."""
    precision_requests = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", precision_requests.append)

    runner.configure_torch(allow_tf32=True)

    assert not precision_requests
def test_save_images_defaults_to_png_for_small_images(tmp_path):
    """A 4096x4096 image counts as small, so an extension-less stem is saved as plain ``.png``."""
    small = FakeImage((4096, 4096))
    output_stem = tmp_path / "output"

    runner.save_images([small], output_stem, num_images_per_prompt=1)

    expected = [(tmp_path / "output.png", {})]
    assert small.saved == expected
def test_save_images_defaults_to_bigtiff_for_large_images(tmp_path):
    """One pixel past 4096 in the first size dimension switches the default output to BigTIFF."""
    large = FakeImage((4097, 4096))

    runner.save_images([large], tmp_path / "output", num_images_per_prompt=1)

    bigtiff_options = {"format": "TIFF", "big_tiff": True}
    assert large.saved == [(tmp_path / "output.tif", bigtiff_options)]
def test_save_images_switches_large_png_output_to_bigtiff(tmp_path):
    """An explicit ``.png`` target for an oversized image is redirected to ``.tif`` with a warning."""
    large = FakeImage((4097, 4096))
    requested_path = tmp_path / "output.png"

    with pytest.warns(UserWarning, match="saving BigTIFF"):
        runner.save_images([large], requested_path, num_images_per_prompt=1)

    expected_path = tmp_path / "output.tif"
    assert large.saved == [(expected_path, {"format": "TIFF", "big_tiff": True})]
def test_save_images_preserves_numbered_large_bigtiff_stems(tmp_path):
    """With several images per prompt, each large image keeps its ``_<index>`` suffix as BigTIFF."""
    batch = [FakeImage((4097, 4096)) for _ in range(2)]

    runner.save_images(batch, tmp_path / "output", num_images_per_prompt=2)

    for index, image in enumerate(batch):
        assert image.saved == [(tmp_path / f"output_{index}.tif", {"format": "TIFF", "big_tiff": True})]
def test_quantization_flag_without_value_uses_float8_weight_only_default():
    """A bare ``--quantization`` flag (no value) selects ``float8wo``."""
    assert _args("--quantization").quantization == "float8wo"
def test_quantization_parser_accepts_component_overrides():
    """The base strategy and every per-component override are parsed independently."""
    flag_values = {
        "--quantization": "float8wo",
        "--transformer-quantization": "int4wo",
        "--text-encoder-quantization": "none",
        "--vae-quantization": "float8dyn",
    }
    # Flatten {flag: value} into the token sequence the parser expects.
    argv = [token for pair in flag_values.items() for token in pair]

    parsed = _args(*argv)

    assert parsed.quantization == "float8wo"
    assert parsed.transformer_quantization == "int4wo"
    assert parsed.text_encoder_quantization == "none"
    assert parsed.vae_quantization == "float8dyn"
def test_resolve_quantization_mapping_applies_base_strategy_to_all_quantizable_components():
    """Without overrides, the base strategy is applied to transformer, text encoder and VAE."""
    expected = dict.fromkeys(("transformer", "text_encoder", "vae"), "float8wo")
    assert runner.resolve_quantization_mapping("float8wo") == expected
def test_resolve_quantization_mapping_uses_component_overrides():
    """Per-component overrides win over the base strategy; an override of "none" drops that component."""
    resolved = runner.resolve_quantization_mapping(
        "float8wo",
        vae_quantization="float8dyn",
        transformer_quantization="int4wo",
        text_encoder_quantization="none",
    )
    assert resolved == {"vae": "float8dyn", "transformer": "int4wo"}
def test_build_quantization_config_returns_none_when_disabled():
    """The "none" strategy produces no quantization config object."""
    config = runner.build_quantization_config("none")
    assert config is None
def test_build_quantization_config_uses_component_specific_torchao_configs():
    """Each component gets the torchao config class matching its own strategy (skipped without torchao)."""
    pytest.importorskip("torchao")
    config = runner.build_quantization_config(
        "float8wo",
        transformer_quantization="int4wo",
        text_encoder_quantization="int8wo",
        vae_quantization="float8dyn",
    )
    expected_type_names = {
        "transformer": "Int4WeightOnlyConfig",
        "text_encoder": "Int8WeightOnlyConfig",
        "vae": "Float8DynamicActivationFloat8WeightConfig",
    }
    assert set(config.quant_mapping) == set(expected_type_names)
    for component, type_name in expected_type_names.items():
        assert type(config.quant_mapping[component].quant_type).__name__ == type_name
def test_call_kwargs_forward_all_weighting_types(weighting_type):
    """A ``--weighting-type`` value is forwarded verbatim, together with the generator.

    NOTE(review): ``weighting_type`` must come from a ``pytest.mark.parametrize`` decorator
    or a fixture; neither is visible here, so confirm it was not lost.
    """
    rng = object()
    cli = _args("--weighting-type", weighting_type)

    kwargs = runner.build_call_kwargs(cli, SimpleNamespace(), rng)

    assert kwargs["weighting_type"] == weighting_type
    assert "num_inference_steps" not in kwargs
    assert kwargs["generator"] is rng
def test_call_kwargs_forward_offset_and_panorama_options():
    """Window-stride offsets and both panorama switches reach the pipeline call kwargs."""
    cli = _args(
        "--window-stride-height-offset", "128",
        "--window-stride-width-offset", "256",
        "--panorama-width",
        "--panorama-height",
    )

    kwargs = runner.build_call_kwargs(cli, SimpleNamespace(), object())

    assert kwargs["window_stride_height_offset"] == 128
    assert kwargs["window_stride_width_offset"] == 256
    assert kwargs["panorama_width"] is True
    assert kwargs["panorama_height"] is True
def test_call_kwargs_forward_window_batch_size():
    """``--batch-size`` is forwarded to the pipeline as ``window_batch_size``."""
    kwargs = runner.build_call_kwargs(_args("--batch-size", "6"), SimpleNamespace(), object())
    assert kwargs["window_batch_size"] == 6
def test_call_kwargs_rejects_invalid_window_batch_size():
    """A zero ``--batch-size`` raises a ValueError whose message mentions the flag."""
    with pytest.raises(ValueError, match="batch-size"):
        cli = _args("--batch-size", "0")
        runner.build_call_kwargs(cli, SimpleNamespace(), object())
def test_csv_prompt_loads_regional_masks_in_csv_order(tmp_path):
    """Prompts from a CSV file follow the CSV row order, preceded by an empty prompt entry.

    Both 32x32 masks are fully white, so each comes back as a 2x2 block of ones.
    """
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    for mask_name in ("00.png", "08.png"):
        Image.new("L", (32, 32), 255).save(masks_dir / mask_name)
    prompts_csv = tmp_path / "prompts.csv"
    prompts_csv.write_text('mask,prompt\n08.png,"a forest"\n00.png,"a mountain"\n', encoding="utf-8")

    cli = _args("--prompt", str(prompts_csv), "--masks", str(masks_dir), "--height", "32", "--width", "32")
    kwargs = runner.build_call_kwargs(cli, SimpleNamespace(vae_scale_factor=8), object())

    assert kwargs["prompt"] == ["", "a forest", "a mountain"]
    assert torch.equal(kwargs["regional_masks"], torch.ones((2, 2, 2)))
def test_csv_prompt_requires_masks_folder(tmp_path):
    """Passing a CSV prompt file without ``--masks`` is rejected with a ValueError."""
    prompts_csv = tmp_path / "prompts.csv"
    prompts_csv.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    fake_pipe = SimpleNamespace(vae_scale_factor=8)

    with pytest.raises(ValueError, match="--masks"):
        cli = _args("--prompt", str(prompts_csv), "--height", "32", "--width", "32")
        runner.build_call_kwargs(cli, fake_pipe, object())
def test_regional_mask_loader_warns_for_non_binary_masks(tmp_path):
    """A grey (value 128) mask triggers a "not binary" warning and is kept as a soft weight.

    The loaded mask is compared with a tolerance rather than exact equality: 128/255 is not
    exactly representable, so any resampling or dtype conversion in the loader can shift
    it by an ulp without being a real regression.
    """
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    Image.new("L", (32, 32), 128).save(masks_dir / "00.png")
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a foggy valley"\n', encoding="utf-8")
    with pytest.warns(UserWarning, match="not binary"):
        call_kwargs = runner.build_call_kwargs(
            _args("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32"),
            SimpleNamespace(vae_scale_factor=8),
            object(),
        )
    masks = call_kwargs["regional_masks"]
    assert masks.shape == (1, 2, 2)
    # full_like matches the loader's dtype/device, so only values are compared.
    assert torch.allclose(masks, torch.full_like(masks, 128 / 255))
def test_regional_mask_loader_rejects_wrong_size(tmp_path):
    """A 16x32 mask for a 32x32 render is rejected with a "must have size" ValueError."""
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    Image.new("L", (16, 32), 255).save(masks_dir / "00.png")
    prompts_csv = tmp_path / "prompts.csv"
    prompts_csv.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    fake_pipe = SimpleNamespace(vae_scale_factor=8)

    with pytest.raises(ValueError, match="must have size"):
        cli = _args("--prompt", str(prompts_csv), "--masks", str(masks_dir), "--height", "32", "--width", "32")
        runner.build_call_kwargs(cli, fake_pipe, object())
def test_multidiffusion_blocks_are_selected_from_loaded_pipeline_config(config, expected_class_name):
    """The blocks class chosen by ``build_multidiffusion_blocks`` depends on the pipeline config.

    NOTE(review): ``config`` and ``expected_class_name`` must come from a
    ``pytest.mark.parametrize`` decorator or fixtures that are not visible here; confirm.
    """
    blocks = runner.build_multidiffusion_blocks(config)
    actual_name = blocks.__class__.__name__
    assert actual_name == expected_class_name
def test_explicit_num_inference_steps_are_forwarded():
    """An explicit ``--num-inference-steps`` is forwarded even for a distilled Klein config."""
    distilled_klein = {"_class_name": "Flux2KleinPipeline", "is_distilled": True}
    cli = _args("--num-inference-steps", "7")

    kwargs = runner.build_call_kwargs(cli, SimpleNamespace(), object(), distilled_klein)

    assert kwargs["num_inference_steps"] == 7
def test_distilled_klein_default_num_inference_steps_is_forwarded():
    """Without an explicit step count, a distilled Klein config gets 4 inference steps."""
    distilled_klein = {"_class_name": "Flux2KleinPipeline", "is_distilled": True}

    kwargs = runner.build_call_kwargs(_args(), SimpleNamespace(), object(), distilled_klein)

    assert kwargs["num_inference_steps"] == 4
def test_non_distilled_models_keep_pipeline_num_inference_steps_default(config):
    """For non-distilled configs no step count is injected; the pipeline default applies.

    NOTE(review): ``config`` must come from a ``pytest.mark.parametrize`` decorator or a
    fixture that is not visible here; confirm.
    """
    kwargs = runner.build_call_kwargs(_args(), SimpleNamespace(), object(), config)
    assert "num_inference_steps" not in kwargs
def test_enable_tiling_and_slicing_call_vae_methods():
    """Both VAE memory switches invoke the matching ``enable_*`` method on ``pipe.vae``."""

    class FakeVae:
        def __init__(self):
            self.tiling_enabled = False
            self.slicing_enabled = False

        def enable_tiling(self):
            self.tiling_enabled = True

        def enable_slicing(self):
            self.slicing_enabled = True

    fake_pipe = SimpleNamespace(vae=FakeVae())

    runner.apply_vae_memory_options(fake_pipe, enable_tiling=True, enable_slicing=True)

    assert fake_pipe.vae.tiling_enabled
    assert fake_pipe.vae.slicing_enabled
def test_init_selects_blocks_from_model_config_and_keeps_model_guidance_default(monkeypatch):
    """Blocks are built from the model config, and ``guidance_scale=None`` leaves the guider alone."""
    recorded = []

    class FakePipe:
        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def update_components(self, **components):
            self.updated_components = components

    class FakeBlocks:
        def init_pipeline(self, base_model, components_manager):
            recorded.append((base_model, components_manager))
            return FakePipe()

    def fake_build_blocks(model_config):
        # First recorded entry: the config used to choose the blocks class.
        recorded.append(model_config)
        return FakeBlocks()

    monkeypatch.setattr(runner, "build_multidiffusion_blocks", fake_build_blocks)

    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    selection_config, init_call = recorded[0], recorded[1]
    assert selection_config == {"_class_name": "Flux2KleinPipeline"}
    assert init_call[0] == "black-forest-labs/FLUX.2-klein-base-4B"
    assert pipe.updated_components is None
def test_init_applies_explicit_guidance_scale(monkeypatch):
    """An explicit guidance scale replaces the pipeline's guider component.

    The fake pipeline only accepts a spec request for ``"guider"``. The test expects
    ``update_components`` to receive ``guider`` built by the spec's ``create`` with
    the given scale.
    """

    class FakeGuiderSpec:
        # ``create`` takes no ``self``: without @staticmethod, calling it on the instance
        # returned by ``get_component_spec`` would bind that instance to ``guidance_scale``
        # and raise TypeError.
        @staticmethod
        def create(guidance_scale):
            return {"guidance_scale": guidance_scale}

    class FakePipe:
        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def get_component_spec(self, name):
            assert name == "guider"
            return FakeGuiderSpec()

        def update_components(self, **components):
            self.updated_components = components

    monkeypatch.setattr(
        runner,
        "build_multidiffusion_blocks",
        lambda model_config: SimpleNamespace(init_pipeline=lambda base_model, components_manager: FakePipe()),
    )
    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=2.5,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )
    assert pipe.updated_components == {"guider": {"guidance_scale": 2.5}}
def test_remote_parser_requires_hub_repo_id():
    """The remote runner's parser stores ``--repo-id`` unchanged."""
    hub_repo = "user/multidiff-modular"
    assert remote_runner.parse_args(["--repo-id", hub_repo]).repo_id == hub_repo
def test_remote_init_uses_modular_from_pretrained_with_remote_code(monkeypatch):
    """The remote init loads from the Hub with remote code trusted and forwards dtype/offline options.

    With ``guidance_scale=None`` the guider on the returned pipeline is left untouched.
    """
    seen = {}

    class FakePipe:
        def __init__(self):
            self.guider = None

        def load_components(self, **kwargs):
            seen["load_components"] = kwargs

    def fake_from_pretrained(repo_id, **kwargs):
        seen["repo_id"] = repo_id
        seen["from_pretrained"] = kwargs
        return FakePipe()

    monkeypatch.setattr(remote_runner.ModularPipeline, "from_pretrained", fake_from_pretrained)

    pipe = remote_runner.init_remote_modular_pipeline(
        repo_id="user/multidiff-modular",
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    assert pipe.guider is None
    assert seen["repo_id"] == "user/multidiff-modular"
    pretrained_kwargs = seen["from_pretrained"]
    assert pretrained_kwargs["trust_remote_code"] is True
    assert pretrained_kwargs["local_files_only"] is True
    load_kwargs = seen["load_components"]
    assert load_kwargs["torch_dtype"] is runner.DTYPE_MAP["float32"]
    assert load_kwargs["local_files_only"] is True
def test_image_conditioning_loads_and_forwards_image(monkeypatch):
    """``--image-conditioning`` loads the image once from its path and forwards it as ``image``."""
    sentinel_image = object()
    requested_paths = []

    def fake_load_image(path):
        requested_paths.append(path)
        return sentinel_image

    monkeypatch.setattr(runner, "load_image", fake_load_image)
    cli = _args("--image-conditioning", "conditioning.png")

    kwargs = runner.build_call_kwargs(cli, SimpleNamespace(config={}), object())

    assert requested_paths == ["conditioning.png"]
    assert kwargs["image"] is sentinel_image
def test_image_img2img_prepares_latents_and_forwards_strength(monkeypatch):
    """``--image-img2img`` encodes the init image into packed latents and forwards ``--strength``.

    The requested 65x81 size is expected to reach the image processor as 64x80. The
    encode sub-pipeline found under ``pipe.blocks.sub_blocks`` must be given the
    pipeline's VAE and the caller's generator.
    """
    init_image = object()
    encoded_latents = object()
    rng = object()
    fake_vae = object()

    class FakeImageProcessor:
        def __init__(self):
            self.calls = []

        def preprocess(self, image, *, height, width, resize_mode):
            self.calls.append((image, height, width, resize_mode))
            return "preprocessed-image"

    class FakeEncoderPipe:
        def __init__(self):
            self.updated_components = None
            self.calls = []

        def update_components(self, **components):
            self.updated_components = components

        def __call__(self, *, condition_images, generator):
            self.calls.append((condition_images, generator))
            return SimpleNamespace(image_latents=[encoded_latents])

    encoder_pipe = FakeEncoderPipe()
    processor = FakeImageProcessor()

    # Nested fakes: blocks.sub_blocks["vae_encoder"].sub_blocks["img_conditioning"].sub_blocks["encode"].
    encode_block = SimpleNamespace(init_pipeline=lambda: encoder_pipe)
    conditioning_block = SimpleNamespace(sub_blocks={"encode": encode_block})
    vae_encoder_block = SimpleNamespace(sub_blocks={"img_conditioning": conditioning_block})
    fake_pipe = SimpleNamespace(
        config={},
        vae=fake_vae,
        vae_scale_factor=8,
        blocks=SimpleNamespace(sub_blocks={"vae_encoder": vae_encoder_block}),
        image_processor=processor,
    )

    monkeypatch.setattr(runner, "load_image", lambda path: init_image)
    # Tag packed latents so the test can see they went through _pack_latents.
    monkeypatch.setattr(
        runner.Flux2PrepareImageLatentsStep,
        "_pack_latents",
        staticmethod(lambda latents: ("packed", latents)),
    )

    cli = _args("--image-img2img", "init.png", "--height", "65", "--width", "81", "--strength", "0.42")
    kwargs = runner.build_call_kwargs(cli, fake_pipe, rng)

    assert processor.calls == [(init_image, 64, 80, "default")]
    assert encoder_pipe.updated_components == {"vae": fake_vae}
    assert encoder_pipe.calls == [(["preprocessed-image"], rng)]
    assert kwargs["image_img2img"] == ("packed", encoded_latents)
    assert kwargs["strength"] == 0.42