Instructions to use arlaz/modular-flux2-multidiffusion with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use arlaz/modular-flux2-multidiffusion with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```

```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("arlaz/modular-flux2-multidiffusion", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```

- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- Draw Things
- DiffusionBee
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import re | |
| import warnings | |
| from collections.abc import Mapping, Sequence | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from diffusers import ComponentsManager, DiffusionPipeline, ModularPipeline | |
| from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE | |
| from diffusers.modular_pipelines.flux2.before_denoise import Flux2PrepareImageLatentsStep | |
| from diffusers.utils import load_image | |
| from PIL import Image | |
# Map of ``--dtype`` CLI choices to the torch dtype used when loading components.
DTYPE_MAP = dict(
    float16=torch.float16,
    bfloat16=torch.bfloat16,
    float32=torch.float32,
)
# Strategies accepted by ``--quantization`` and the per-component overrides;
# "none" means the component is not quantized.
QUANTIZATION_CHOICES = ("none", "float8wo", "int8wo", "int4wo", "float8dyn")
# Used when ``--quantization`` is not passed at all.
DEFAULT_QUANTIZATION = "none"
# Used when ``--quantization`` (or the hidden ``--quantize``) is passed without a value.
DEFAULT_REQUESTED_QUANTIZATION = "float8wo"
# Values accepted by ``--weighting-type``.
WEIGHTING_TYPES = ("none", "linear", "cosine")
# Images with strictly more pixels than this are written as BigTIFF.
BIGTIFF_PIXEL_THRESHOLD = 4096 * 4096
# Lower-cased output suffixes that are already TIFF.
TIFF_SUFFIXES = {".tif", ".tiff"}
def configure_torch(*, allow_tf32: bool) -> None:
    """Opt into TF32 float32 precision when requested and CUDA is available.

    Does nothing unless ``allow_tf32`` is true *and* ``torch.cuda.is_available()``
    reports a CUDA device. Otherwise it sets ``torch.backends.fp32_precision`` to
    ``"tf32"`` and the float32 matmul precision to ``"high"``.
    """
    if not (allow_tf32 and torch.cuda.is_available()):
        return
    torch.backends.fp32_precision = "tf32"
    torch.set_float32_matmul_precision("high")
def build_multidiffusion_blocks(model_config: Mapping):
    """Return the MultiDiffusion auto-blocks instance matching a base model config.

    The choice is driven by ``model_config["_class_name"]``:

    * ``Flux2Pipeline`` / ``Flux2ModularPipeline`` -> ``Flux2MultiDiffusionAutoBlocks``
    * ``Flux2KleinPipeline`` / ``Flux2KleinModularPipeline`` -> the distilled Klein
      blocks when ``is_distilled`` is truthy, otherwise the base Klein blocks.

    The ``*ModularPipeline`` names are accepted for consistency with
    ``default_num_inference_steps``, which already treats
    ``Flux2KleinModularPipeline`` as a Klein model.

    Raises:
        ValueError: the class name is not a supported Flux2 pipeline. This is
            checked before the local ``block`` module is imported.
    """
    class_name = model_config.get("_class_name")
    flux2_names = {"Flux2Pipeline", "Flux2ModularPipeline"}
    klein_names = {"Flux2KleinPipeline", "Flux2KleinModularPipeline"}
    if class_name not in flux2_names and class_name not in klein_names:
        raise ValueError(f"Cannot select MultiDiffusion blocks from model class {class_name!r}.")

    # Deferred import of the project-local blocks module; only needed once the
    # class name has been validated.
    from block import (
        Flux2KleinBaseMultiDiffusionAutoBlocks,
        Flux2KleinMultiDiffusionAutoBlocks,
        Flux2MultiDiffusionAutoBlocks,
    )

    if class_name in flux2_names:
        return Flux2MultiDiffusionAutoBlocks()
    if model_config.get("is_distilled"):
        return Flux2KleinMultiDiffusionAutoBlocks()
    return Flux2KleinBaseMultiDiffusionAutoBlocks()
def add_input_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the ``inputs`` option group on ``parser``.

    ``--prompt`` is either a literal prompt or a path to a ``.csv`` file of
    regional prompts (which then also requires ``--masks``). The two image
    options are left as raw strings for ``load_image``.
    """
    group = parser.add_argument_group("inputs")
    group.add_argument("--prompt", default="A dense renaissance fresco.")
    group.add_argument("--masks", type=Path, default=None)
    for image_flag in ("--image-img2img", "--image-conditioning"):
        group.add_argument(image_flag, default=None)
    group.add_argument(
        "--strength",
        type=float,
        default=1.0,
        help="The strength of renoising in the img2img setting.",
    )
def add_canvas_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the ``canvas`` option group on ``parser``.

    Covers the output canvas size, the optional generation-window size,
    window strides and their offsets, the panorama flags, and the weighting
    options (``--weighting-type`` is restricted to ``WEIGHTING_TYPES``).
    """
    group = parser.add_argument_group("canvas")
    integer_options = (
        ("--height", 4096),
        ("--width", 4096),
        ("--height-generation", None),
        ("--width-generation", None),
        ("--window-stride-height", None),
        ("--window-stride-width", None),
        ("--window-stride-height-offset", 0),
        ("--window-stride-width-offset", 0),
    )
    for flag, default in integer_options:
        group.add_argument(flag, type=int, default=default)
    for flag in ("--panorama-width", "--panorama-height"):
        group.add_argument(flag, action="store_true")
    group.add_argument("--weighting-type", default="cosine", choices=WEIGHTING_TYPES)
    group.add_argument("--weighting-range", type=float, default=None)
def add_inference_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the ``inference`` option group (steps, guidance, seed, batching).

    A ``None`` default means "let the pipeline/model decide".
    """
    group = parser.add_argument_group("inference")
    option_specs = (
        ("--num-inference-steps", int, None),
        ("--guidance-scale", float, None),
        ("--terra-scale", float, None),
        ("--seed", int, 42),
        ("--num-images-per-prompt", int, 1),
        ("--batch-size", int, 1),
    )
    for flag, converter, default in option_specs:
        group.add_argument(flag, type=converter, default=default)
def add_runtime_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the ``runtime`` option group on ``parser``.

    Covers dtype/device/output, loading and compilation switches, TorchAO
    quantization (a global strategy plus per-component overrides) and VAE
    memory options.
    """
    group = parser.add_argument_group("runtime")
    group.add_argument("--dtype", default="bfloat16", choices=tuple(DTYPE_MAP))
    group.add_argument("--device", default="cuda")
    group.add_argument("--output", default="output.png")
    for switch in ("--local-files-only", "--allow-tf32", "--compile"):
        group.add_argument(switch, action="store_true")

    # Hidden alias: ``--quantize`` stores DEFAULT_REQUESTED_QUANTIZATION into the
    # same ``quantization`` dest. Its SUPPRESS default contributes no default value,
    # so the ``--quantization`` default below still applies.
    group.add_argument(
        "--quantize",
        dest="quantization",
        action="store_const",
        const=DEFAULT_REQUESTED_QUANTIZATION,
        default=argparse.SUPPRESS,
        help=argparse.SUPPRESS,
    )
    quantization_help = (
        "TorchAO quantization strategy for transformer, text_encoder, and vae. "
        f"Passing the flag without a value uses {DEFAULT_REQUESTED_QUANTIZATION}."
    )
    group.add_argument(
        "--quantization",
        nargs="?",
        const=DEFAULT_REQUESTED_QUANTIZATION,
        default=DEFAULT_QUANTIZATION,
        choices=QUANTIZATION_CHOICES,
        help=quantization_help,
    )

    component_overrides = (
        ("--transformer-quantization", "Override the quantization strategy for the image transformer component."),
        ("--text-encoder-quantization", "Override the quantization strategy for the text_encoder component."),
        ("--vae-quantization", "Override the quantization strategy for the VAE component."),
    )
    for flag, help_text in component_overrides:
        group.add_argument(flag, default=None, choices=QUANTIZATION_CHOICES, help=help_text)

    for switch in ("--enable-tiling", "--enable-slicing"):
        group.add_argument(switch, action="store_true")
def build_parser() -> argparse.ArgumentParser:
    """Assemble the full CLI parser: model, inputs, canvas, inference and runtime groups."""
    parser = argparse.ArgumentParser(description="Run Flux2-Klein MultiDiffusion with local modular blocks.")
    model_group = parser.add_argument_group("model")
    model_group.add_argument("--base-model", required=True)
    # Both dash and underscore spellings are accepted for the LoRA path.
    model_group.add_argument("--lora-path", "--lora_path", type=Path)
    registrars = (
        add_input_arguments,
        add_canvas_arguments,
        add_inference_arguments,
        add_runtime_arguments,
    )
    for register in registrars:
        register(parser)
    return parser
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse ``argv`` with the script's CLI parser (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = build_parser()
    namespace = parser.parse_args(argv)
    return namespace
def load_model_config(base_model: str, *, local_files_only: bool) -> dict:
    """Load ``base_model``'s pipeline config via ``DiffusionPipeline.load_config`` as a plain dict."""
    raw_config = DiffusionPipeline.load_config(base_model, local_files_only=local_files_only)
    return dict(raw_config)
def apply_window_stride_defaults(args: argparse.Namespace) -> argparse.Namespace:
    """Default each unset window stride to half of its generation size.

    For each axis (height, then width), ``window_stride_<axis>`` is filled with
    ``<axis>_generation // 2`` only when the generation size was given and the
    stride was not. ``args`` is mutated in place and returned.
    """
    for axis in ("height", "width"):
        generation_size = getattr(args, f"{axis}_generation")
        stride_attr = f"window_stride_{axis}"
        if generation_size is not None and getattr(args, stride_attr) is None:
            setattr(args, stride_attr, generation_size // 2)
    return args
| def default_num_inference_steps(model_config: Mapping) -> int | None: | |
| if model_config.get("_class_name") in {"Flux2KleinPipeline", "Flux2KleinModularPipeline"} and model_config.get( | |
| "is_distilled" | |
| ): | |
| return 4 | |
| return None | |
def apply_vae_memory_options(pipe, *, enable_tiling: bool, enable_slicing: bool) -> None:
    """Enable the requested VAE memory savers on ``pipe.vae`` (tiling first, then slicing).

    ``pipe.vae`` is only touched when at least one option is requested.
    """
    vae_toggles = (
        (enable_tiling, "enable_tiling"),
        (enable_slicing, "enable_slicing"),
    )
    for requested, method_name in vae_toggles:
        if requested:
            getattr(pipe.vae, method_name)()
def _validate_quantization_strategy(strategy: str, *, name: str) -> str:
    """Return ``strategy`` unchanged when it is one of ``QUANTIZATION_CHOICES``.

    Raises:
        ValueError: ``strategy`` is not a known choice; ``name`` identifies the option.
    """
    if strategy in QUANTIZATION_CHOICES:
        return strategy
    raise ValueError(f"`{name}` must be one of {QUANTIZATION_CHOICES}, got {strategy!r}.")
def resolve_quantization_mapping(
    quantization: str | None = DEFAULT_QUANTIZATION,
    *,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> dict[str, str]:
    """Resolve the effective quantization strategy per component.

    A falsy ``quantization`` falls back to ``DEFAULT_QUANTIZATION``. Each
    per-component override, when not ``None``, replaces the global strategy for
    that component. Components whose strategy is ``"none"`` are left out of the
    result, so an empty dict means "quantize nothing".

    Raises:
        ValueError: the global strategy or any non-"none" component strategy is
            not in ``QUANTIZATION_CHOICES``.
    """
    global_strategy = _validate_quantization_strategy(
        quantization or DEFAULT_QUANTIZATION,
        name="quantization",
    )
    overrides = {
        "transformer": transformer_quantization,
        "text_encoder": text_encoder_quantization,
        "vae": vae_quantization,
    }
    effective = {
        component: global_strategy if override is None else override
        for component, override in overrides.items()
    }
    resolved: dict[str, str] = {}
    for component, strategy in effective.items():
        if strategy == "none":
            continue
        resolved[component] = _validate_quantization_strategy(strategy, name=f"{component}_quantization")
    return resolved
def _build_torchao_quant_type(strategy: str):
    """Build the TorchAO config object for a non-"none" quantization strategy.

    ``int4wo`` uses ``group_size=128``; the other strategies use the TorchAO
    config defaults. torchao is imported lazily so it is only required when
    quantization is actually requested.

    Raises:
        ValueError: ``strategy`` has no TorchAO config mapping.
    """
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        Float8WeightOnlyConfig,
        Int4WeightOnlyConfig,
        Int8WeightOnlyConfig,
    )

    # Factories (not instances) so only the requested config is constructed.
    factories = {
        "float8wo": Float8WeightOnlyConfig,
        "int8wo": Int8WeightOnlyConfig,
        "int4wo": lambda: Int4WeightOnlyConfig(group_size=128),
        "float8dyn": Float8DynamicActivationFloat8WeightConfig,
    }
    factory = factories.get(strategy)
    if factory is None:
        raise ValueError(f"Cannot build TorchAO quantization config for {strategy!r}.")
    return factory()
def build_quantization_config(
    quantization: str | None = DEFAULT_QUANTIZATION,
    *,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
):
    """Build a ``PipelineQuantizationConfig`` for the requested strategies.

    Returns ``None`` when no component ends up quantized. Otherwise, each
    quantized component gets a TorchAO config wrapper: the transformers
    ``TorchAoConfig`` for ``text_encoder`` and the diffusers one for
    ``transformer`` and ``vae``.
    """
    strategies = resolve_quantization_mapping(
        quantization,
        transformer_quantization=transformer_quantization,
        text_encoder_quantization=text_encoder_quantization,
        vae_quantization=vae_quantization,
    )
    if not strategies:
        return None

    from diffusers import PipelineQuantizationConfig
    from diffusers import TorchAoConfig as DiffusersTorchAoConfig
    from transformers import TorchAoConfig as TransformersTorchAoConfig

    wrapper_by_component = {
        "transformer": DiffusersTorchAoConfig,
        "text_encoder": TransformersTorchAoConfig,
        "vae": DiffusersTorchAoConfig,
    }
    per_component = {}
    for component, strategy in strategies.items():
        wrapper = wrapper_by_component[component]
        per_component[component] = wrapper(_build_torchao_quant_type(strategy))
    return PipelineQuantizationConfig(quant_mapping=per_component)
def build_components_manager(device: torch.device) -> tuple[ComponentsManager, bool]:
    """Create a ``ComponentsManager`` and report whether auto CPU offload is on.

    Auto CPU offload is enabled only for CUDA devices. For MPS, a notice is
    printed and components are expected to be moved to the device directly.
    """
    manager = ComponentsManager()
    device_kind = device.type
    if device_kind != "cuda":
        if device_kind == "mps":
            print("MPS does not support ComponentsManager auto CPU offload; loading components directly on MPS.")
        return manager, False
    manager.enable_auto_cpu_offload(device=device)
    return manager, True
def init_modular_pipeline(
    *,
    base_model: str,
    model_config: Mapping | None = None,
    guidance_scale=None,
    dtype: str,
    device: str,
    local_files_only: bool,
    compile: bool,
    quantization: str | None = DEFAULT_QUANTIZATION,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> ModularPipeline:
    """Create and load the MultiDiffusion modular pipeline for ``base_model``.

    Steps:
    1. Build a components manager; it uses auto CPU offload on CUDA.
    2. Load ``model_config`` if it was not supplied (or is empty), and use it to
       pick the MultiDiffusion blocks.
    3. Load the components with ``DTYPE_MAP[dtype]`` and the optional TorchAO
       quantization config.
    4. Move the pipeline to ``device`` when offload is off and the device is not CPU.
    5. Recreate the guider with ``guidance_scale`` when given and a guider exists.
    6. Optionally compile the transformer's repeated blocks.
    """
    target_device = torch.device(device)
    manager, auto_offload = build_components_manager(target_device)
    if not model_config:
        model_config = load_model_config(base_model, local_files_only=local_files_only)

    blocks = build_multidiffusion_blocks(model_config)
    pipe = blocks.init_pipeline(base_model, components_manager=manager)

    torch_dtype = DTYPE_MAP[dtype]
    quantization_config = build_quantization_config(
        quantization,
        transformer_quantization=transformer_quantization,
        text_encoder_quantization=text_encoder_quantization,
        vae_quantization=vae_quantization,
    )
    pipe.load_components(
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        quantization_config=quantization_config,
    )

    # With auto offload the manager handles device placement itself.
    needs_explicit_move = not auto_offload and target_device.type != "cpu"
    if needs_explicit_move:
        pipe.to(target_device)

    if guidance_scale is not None and hasattr(pipe, "guider") and pipe.guider is not None:
        new_guider = pipe.get_component_spec("guider").create(guidance_scale=guidance_scale)
        pipe.update_components(guider=new_guider)

    if compile:
        pipe.transformer.compile_repeated_blocks(fullgraph=True, dynamic=True)
    return pipe
def prepare_img2img_latents(pipe, image, *, height: int, width: int, generator):
    """Encode ``image`` with the pipeline's VAE into packed img2img latents.

    ``height`` and ``width`` are rounded down to a multiple of
    ``pipe.vae_scale_factor * 2`` before preprocessing. The image is encoded by
    a standalone pipeline built from the ``vae_encoder -> img_conditioning ->
    encode`` sub-block (sharing ``pipe.vae``). The first returned latent is then
    passed through ``Flux2PrepareImageLatentsStep._pack_latents``.
    """
    encode_block = (
        pipe.blocks.sub_blocks["vae_encoder"]
        .sub_blocks["img_conditioning"]
        .sub_blocks["encode"]
    )
    encoder_pipe = encode_block.init_pipeline()
    encoder_pipe.update_components(vae=pipe.vae)

    granularity = pipe.vae_scale_factor * 2
    snapped_height = height - height % granularity
    snapped_width = width - width % granularity

    # Do not use the stock vae_encoder block here: its preprocess step clamps to ~1024x1024.
    pixels = pipe.image_processor.preprocess(
        image,
        height=snapped_height,
        width=snapped_width,
        resize_mode="default",
    )
    encoded = encoder_pipe(condition_images=[pixels], generator=generator)
    first_latents = encoded.image_latents[0]
    return Flux2PrepareImageLatentsStep._pack_latents(first_latents)
| def _natural_sort_key(path: Path): | |
| return [int(part) if part.isdigit() else part.lower() for part in re.split(r"(\d+)", path.name)] | |
| def _is_csv_prompt(prompt: str) -> bool: | |
| return Path(prompt).suffix.lower() == ".csv" | |
| def _read_regional_prompt_csv(prompt_csv: Path) -> list[dict[str, str]]: | |
| if not prompt_csv.is_file(): | |
| raise FileNotFoundError(f"Prompt CSV does not exist: {prompt_csv}") | |
| with prompt_csv.open(newline="", encoding="utf-8") as handle: | |
| reader = csv.DictReader(handle) | |
| if reader.fieldnames != ["mask", "prompt"]: | |
| raise ValueError(f"Prompt CSV must have exact headers 'mask,prompt', got {reader.fieldnames!r}.") | |
| rows = [{"mask": row["mask"], "prompt": row["prompt"]} for row in reader] | |
| if not rows: | |
| raise ValueError(f"Prompt CSV must contain at least one regional prompt row: {prompt_csv}") | |
| return rows | |
def _load_grayscale_mask(path: Path, *, width: int, height: int) -> torch.Tensor:
    """Load ``path`` as an 8-bit grayscale mask and scale it to ``[0, 1]`` float32.

    The image must be exactly ``(width, height)`` pixels. Values other than 0
    and 255 trigger a warning (attributed to the caller) and are kept as
    fractional weights.

    Raises:
        ValueError: the image size does not match ``(width, height)``.
    """
    with Image.open(path) as mask_image:
        if mask_image.size != (width, height):
            raise ValueError(f"Mask {path.name!r} must have size {(width, height)}, got {mask_image.size}.")
        gray_pixels = np.array(mask_image.convert("L"), dtype=np.uint8)
    weights = torch.from_numpy(gray_pixels).float()
    is_binary = torch.logical_or(weights == 0, weights == 255).all()
    if not is_binary:
        warnings.warn(f"Mask {path.name!r} is not binary; values will be used as fractional weights.", stacklevel=2)
    return weights / 255.0
def _load_regional_masks(
    *,
    rows: list[dict[str, str]],
    masks_dir: Path,
    height: int,
    width: int,
    vae_scale_factor: int,
) -> torch.Tensor:
    """Load one mask per CSV row and area-downsample them to the packed latent grid.

    Mask names are looked up among the regular files directly inside
    ``masks_dir``. The result has shape
    ``(len(rows), height // (2 * vae_scale_factor), width // (2 * vae_scale_factor))``.

    Raises:
        FileNotFoundError: ``masks_dir`` is not a directory, or a row names a
            mask that is not in it.
        ValueError: some latent cell has zero total mask weight (raised from
            here; size/format errors come from ``_load_grayscale_mask``).
    """
    if not masks_dir.is_dir():
        raise FileNotFoundError(f"Mask folder does not exist: {masks_dir}")

    masks_by_name: dict[str, Path] = {}
    for candidate in sorted(masks_dir.iterdir(), key=_natural_sort_key):
        if candidate.is_file():
            masks_by_name[candidate.name] = candidate

    loaded = []
    for row in rows:
        mask_name = row["mask"]
        mask_path = masks_by_name.get(mask_name)
        if mask_path is None:
            raise FileNotFoundError(f"Mask {mask_name!r} from prompt CSV was not found in {masks_dir}.")
        loaded.append(_load_grayscale_mask(mask_path, width=width, height=height))

    packed_factor = vae_scale_factor * 2
    latent_size = (height // packed_factor, width // packed_factor)
    # (N, H, W) -> (N, 1, H, W) for interpolate, then drop the channel axis again.
    downsampled = F.interpolate(torch.stack(loaded)[:, None], size=latent_size, mode="area")
    regional_masks = downsampled[:, 0]

    coverage = regional_masks.sum(dim=0)
    if (coverage <= 0).any():
        raise ValueError("Regional masks must cover every packed latent cell.")
    return regional_masks
def prepare_regional_prompt_inputs(args: argparse.Namespace, pipe) -> tuple[str | list[str], torch.Tensor | None]:
    """Resolve the prompt argument into pipeline prompt(s) and optional regional masks.

    * Plain prompt: returned as-is with ``None`` masks; ``--masks`` is rejected.
    * ``.csv`` prompt: ``--masks`` is required. Returns ``["", *row prompts]``
      plus the downsampled masks (one per CSV row).

    Raises:
        ValueError: ``--masks`` is given without a CSV prompt, or missing with one.
    """
    masks_given = args.masks is not None
    if _is_csv_prompt(args.prompt):
        if not masks_given:
            raise ValueError("`--masks` is required when `--prompt` points to a CSV file.")
    else:
        if masks_given:
            raise ValueError("`--masks` is only valid when `--prompt` points to a CSV file.")
        return args.prompt, None

    rows = _read_regional_prompt_csv(Path(args.prompt))
    regional_masks = _load_regional_masks(
        rows=rows,
        masks_dir=args.masks,
        height=args.height,
        width=args.width,
        vae_scale_factor=pipe.vae_scale_factor,
    )
    # Slot 0 is an empty prompt placed before the per-region prompts
    # (NOTE(review): presumably a global/base slot in block.py — confirm).
    prompts = [""]
    prompts.extend(row["prompt"] for row in rows)
    return prompts, regional_masks
def build_call_kwargs(args: argparse.Namespace, pipe, generator, model_config: Mapping | None = None) -> dict:
    """Translate parsed CLI arguments into keyword arguments for ``pipe(...)``.

    Optional keys are only included when relevant: ``regional_masks`` for CSV
    prompts, ``num_inference_steps`` when given or implied by the model config,
    ``image_img2img``/``strength`` for img2img, and ``image`` for conditioning.

    Raises:
        ValueError: ``--batch-size`` is not positive (checked before any file I/O).
    """
    model_config = model_config or {}
    if args.batch_size <= 0:
        raise ValueError(f"`--batch-size` must be a positive integer, got {args.batch_size}.")

    prompt, regional_masks = prepare_regional_prompt_inputs(args, pipe)

    kwargs = {"prompt": prompt}
    # These CLI options are forwarded under the same names.
    passthrough = (
        "height",
        "width",
        "height_generation",
        "width_generation",
        "window_stride_height",
        "window_stride_width",
        "window_stride_height_offset",
        "window_stride_width_offset",
        "panorama_width",
        "panorama_height",
        "weighting_type",
        "weighting_range",
    )
    for key in passthrough:
        kwargs[key] = getattr(args, key)
    kwargs["generator"] = generator
    kwargs["num_images_per_prompt"] = args.num_images_per_prompt
    kwargs["window_batch_size"] = args.batch_size

    if regional_masks is not None:
        kwargs["regional_masks"] = regional_masks

    steps = args.num_inference_steps
    if steps is None:
        steps = default_num_inference_steps(model_config)
    if steps is not None:
        kwargs["num_inference_steps"] = steps

    if args.image_img2img is not None:
        source_image = load_image(args.image_img2img)
        kwargs["image_img2img"] = prepare_img2img_latents(
            pipe,
            source_image,
            height=args.height,
            width=args.width,
            generator=generator,
        )
        kwargs["strength"] = args.strength

    if args.image_conditioning is not None:
        kwargs["image"] = load_image(args.image_conditioning)
    return kwargs
| def load_lora_adapter(pipe, *, lora_path: Path | None, terra_scale: float | None, num_images_per_prompt: int) -> None: | |
| if lora_path is None: | |
| return | |
| pipe.transformer.load_lora_adapter( | |
| lora_path, | |
| weight_name=LORA_WEIGHT_NAME_SAFE, | |
| ) | |
| print(f"Loaded LoRA weights from {lora_path}") | |
| print(pipe.transformer) | |
| if terra_scale is not None: | |
| pipe.transformer.set_terra_t( | |
| [terra_scale] * num_images_per_prompt, | |
| adapter=None, | |
| ) | |
def save_images(images, output: str | Path, *, num_images_per_prompt: int) -> None:
    """Write every image in ``images`` to disk, deriving file names from ``output``.

    When more than one image is written, either because
    ``num_images_per_prompt > 1`` or because ``images`` itself holds several
    entries, each file gets an ``_<index>`` stem suffix. Previously, several
    images with ``num_images_per_prompt == 1`` were all written to the same
    path, silently keeping only the last one.

    Images with more than ``BIGTIFF_PIXEL_THRESHOLD`` pixels are saved as
    BigTIFF. ``_resolve_output_path_for_image`` adjusts the suffix and warns
    when the requested suffix cannot hold BigTIFF.
    """
    images = list(images)
    output_path = Path(output)
    add_index = num_images_per_prompt > 1 or len(images) > 1
    for index, image in enumerate(images):
        width, height = image.size
        use_bigtiff = width * height > BIGTIFF_PIXEL_THRESHOLD
        save_path = output_path.with_stem(f"{output_path.stem}_{index}") if add_index else output_path
        save_path = _resolve_output_path_for_image(save_path, image=image, use_bigtiff=use_bigtiff)
        if use_bigtiff:
            image.save(save_path, format="TIFF", big_tiff=True)
        else:
            image.save(save_path)
| def _resolve_output_path_for_image(output_path: Path, *, image, use_bigtiff: bool) -> Path: | |
| if output_path.suffix == "": | |
| return output_path.with_suffix(".tif" if use_bigtiff else ".png") | |
| if use_bigtiff and output_path.suffix.lower() not in TIFF_SUFFIXES: | |
| bigtiff_path = output_path.with_suffix(".tif") | |
| warnings.warn( | |
| ( | |
| f"Image size {image.size[0]}x{image.size[1]} exceeds {BIGTIFF_PIXEL_THRESHOLD} pixels; " | |
| f"saving BigTIFF to {bigtiff_path} instead of {output_path}." | |
| ), | |
| stacklevel=2, | |
| ) | |
| return bigtiff_path | |
| return output_path | |
def main(argv: Sequence[str] | None = None) -> None:
    """Run one MultiDiffusion generation from command-line arguments.

    Steps: parse args and fill stride defaults, configure torch, load the base
    model config, build the pipeline, apply VAE/LoRA options, run the pipeline
    with a seeded CPU generator, and save the resulting images.
    """
    args = parse_args(argv)
    args = apply_window_stride_defaults(args)
    configure_torch(allow_tf32=args.allow_tf32)

    model_config = load_model_config(args.base_model, local_files_only=args.local_files_only)
    print("Initializing modular pipeline...")
    print(args)

    quantization_options = {
        "quantization": args.quantization,
        "transformer_quantization": args.transformer_quantization,
        "text_encoder_quantization": args.text_encoder_quantization,
        "vae_quantization": args.vae_quantization,
    }
    pipe = init_modular_pipeline(
        base_model=args.base_model,
        model_config=model_config,
        guidance_scale=args.guidance_scale,
        dtype=args.dtype,
        device=args.device,
        local_files_only=args.local_files_only,
        compile=args.compile,
        **quantization_options,
    )
    apply_vae_memory_options(pipe, enable_tiling=args.enable_tiling, enable_slicing=args.enable_slicing)
    load_lora_adapter(
        pipe,
        lora_path=args.lora_path,
        terra_scale=args.terra_scale,
        num_images_per_prompt=args.num_images_per_prompt,
    )

    generator = torch.Generator().manual_seed(args.seed)
    call_kwargs = build_call_kwargs(args, pipe, generator, model_config)
    result = pipe(**call_kwargs)
    save_images(result.images, args.output, num_images_per_prompt=args.num_images_per_prompt)
if __name__ == "__main__":
    # Script entry point: parse sys.argv and run a single generation.
    main()