from __future__ import annotations

import argparse
import csv
import re
import warnings
from collections.abc import Mapping, Sequence
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import ComponentsManager, DiffusionPipeline, ModularPipeline
from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
from diffusers.modular_pipelines.flux2.before_denoise import Flux2PrepareImageLatentsStep
from diffusers.utils import load_image
from PIL import Image


# CLI dtype names mapped to the torch dtype passed to `load_components`.
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
# Accepted quantization strategy names; "none" disables quantization.
QUANTIZATION_CHOICES = ("none", "float8wo", "int8wo", "int4wo", "float8dyn")
DEFAULT_QUANTIZATION = "none"
# Strategy used when `--quantize` / a bare `--quantization` flag is passed.
DEFAULT_REQUESTED_QUANTIZATION = "float8wo"
WEIGHTING_TYPES = ("none", "linear", "cosine")
# Images with more pixels than this are written as BigTIFF.
BIGTIFF_PIXEL_THRESHOLD = 4096 * 4096
TIFF_SUFFIXES = {".tif", ".tiff"}


def configure_torch(*, allow_tf32: bool) -> None:
    """Switch torch to TF32 / "high" matmul precision when requested.

    Nothing is changed unless `allow_tf32` is true and CUDA is available.
    """
    if allow_tf32 and torch.cuda.is_available():
        torch.backends.fp32_precision = "tf32"
        torch.set_float32_matmul_precision("high")


def build_multidiffusion_blocks(model_config: Mapping):
    """Instantiate the MultiDiffusion block set matching the model config.

    The choice is driven by the `_class_name` entry and, for
    `Flux2KleinPipeline`, by the `is_distilled` entry.

    Raises:
        ValueError: if `_class_name` is not one of the supported pipelines.
    """
    # Imported lazily: `block` is a local module that sits next to this script.
    from block import (
        Flux2KleinBaseMultiDiffusionAutoBlocks,
        Flux2KleinMultiDiffusionAutoBlocks,
        Flux2MultiDiffusionAutoBlocks,
    )

    class_name = model_config.get("_class_name")
    if class_name == "Flux2Pipeline":
        return Flux2MultiDiffusionAutoBlocks()
    if class_name == "Flux2KleinPipeline":
        if model_config.get("is_distilled"):
            return Flux2KleinMultiDiffusionAutoBlocks()
        return Flux2KleinBaseMultiDiffusionAutoBlocks()
    raise ValueError(f"Cannot select MultiDiffusion blocks from model class {class_name!r}.")


def add_input_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the "inputs" argument group (prompt, masks, source images)."""
    inputs = parser.add_argument_group("inputs")
    inputs.add_argument("--prompt", default="A dense renaissance fresco.")
    inputs.add_argument("--masks", type=Path, default=None)
    inputs.add_argument("--image-img2img", default=None)
    inputs.add_argument("--image-conditioning", default=None)
    inputs.add_argument(
        "--strength",
        type=float,
        default=1.0,
        help="The strength of renoising in the img2img setting.",
    )


def add_canvas_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the "canvas" argument group (output size, windows, weighting)."""
    canvas = parser.add_argument_group("canvas")
    canvas.add_argument("--height", type=int, default=4096)
    canvas.add_argument("--width", type=int, default=4096)
    canvas.add_argument("--height-generation", type=int, default=None)
    canvas.add_argument("--width-generation", type=int, default=None)
    canvas.add_argument("--window-stride-height", type=int, default=None)
    canvas.add_argument("--window-stride-width", type=int, default=None)
    canvas.add_argument("--window-stride-height-offset", type=int, default=0)
    canvas.add_argument("--window-stride-width-offset", type=int, default=0)
    canvas.add_argument("--panorama-width", action="store_true")
    canvas.add_argument("--panorama-height", action="store_true")
    canvas.add_argument("--weighting-type", default="cosine", choices=WEIGHTING_TYPES)
    canvas.add_argument("--weighting-range", type=float, default=None)


def add_inference_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the "inference" argument group (steps, guidance, seed, batching)."""
    inference = parser.add_argument_group("inference")
    inference.add_argument("--num-inference-steps", type=int, default=None)
    inference.add_argument("--guidance-scale", type=float, default=None)
    inference.add_argument("--terra-scale", type=float, default=None)
    inference.add_argument("--seed", type=int, default=42)
    inference.add_argument("--num-images-per-prompt", type=int, default=1)
    inference.add_argument("--batch-size", type=int, default=1)


def add_runtime_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the "runtime" argument group (dtype, device, output, quantization)."""
    runtime = parser.add_argument_group("runtime")
    runtime.add_argument("--dtype", default="bfloat16", choices=tuple(DTYPE_MAP))
    runtime.add_argument("--device", default="cuda")
    runtime.add_argument("--output", default="output.png")
    runtime.add_argument("--local-files-only", action="store_true")
    runtime.add_argument("--allow-tf32", action="store_true")
    runtime.add_argument("--compile", action="store_true")
    # Hidden alias writing to the same destination as `--quantization`.
    # `default=SUPPRESS` keeps it from overriding that option's default.
    runtime.add_argument(
        "--quantize",
        dest="quantization",
        action="store_const",
        const=DEFAULT_REQUESTED_QUANTIZATION,
        default=argparse.SUPPRESS,
        help=argparse.SUPPRESS,
    )
    runtime.add_argument(
        "--quantization",
        nargs="?",
        const=DEFAULT_REQUESTED_QUANTIZATION,
        default=DEFAULT_QUANTIZATION,
        choices=QUANTIZATION_CHOICES,
        help=(
            "TorchAO quantization strategy for transformer, text_encoder, and vae. "
            f"Passing the flag without a value uses {DEFAULT_REQUESTED_QUANTIZATION}."
        ),
    )
    runtime.add_argument(
        "--transformer-quantization",
        default=None,
        choices=QUANTIZATION_CHOICES,
        help="Override the quantization strategy for the image transformer component.",
    )
    runtime.add_argument(
        "--text-encoder-quantization",
        default=None,
        choices=QUANTIZATION_CHOICES,
        help="Override the quantization strategy for the text_encoder component.",
    )
    runtime.add_argument(
        "--vae-quantization",
        default=None,
        choices=QUANTIZATION_CHOICES,
        help="Override the quantization strategy for the VAE component.",
    )
    runtime.add_argument("--enable-tiling", action="store_true")
    runtime.add_argument("--enable-slicing", action="store_true")


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with all argument groups attached."""
    parser = argparse.ArgumentParser(description="Run Flux2-Klein MultiDiffusion with local modular blocks.")
    model = parser.add_argument_group("model")
    model.add_argument("--base-model", required=True)
    model.add_argument("--lora-path", "--lora_path", type=Path)
    add_input_arguments(parser)
    add_canvas_arguments(parser)
    add_inference_arguments(parser)
    add_runtime_arguments(parser)
    return parser


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse `argv` (or `sys.argv[1:]` when `None`) with the script's parser."""
    return build_parser().parse_args(argv)


def load_model_config(base_model: str, *, local_files_only: bool) -> dict:
    """Return the pipeline config of `base_model` as a plain `dict`."""
    return dict(
        DiffusionPipeline.load_config(
            base_model,
            local_files_only=local_files_only,
        )
    )


def apply_window_stride_defaults(args: argparse.Namespace) -> argparse.Namespace:
    """Default each unset window stride to half of its generation size.

    A stride is only filled in when the matching `*_generation` size is set.
    `args` is modified in place and also returned.
    """
    if args.height_generation is not None and args.window_stride_height is None:
        args.window_stride_height = args.height_generation // 2
    if args.width_generation is not None and args.window_stride_width is None:
        args.window_stride_width = args.width_generation // 2
    return args


def default_num_inference_steps(model_config: Mapping) -> int | None:
    """Return 4 for distilled Flux2-Klein configs, otherwise `None`."""
    if model_config.get("_class_name") in {"Flux2KleinPipeline", "Flux2KleinModularPipeline"} and model_config.get(
        "is_distilled"
    ):
        return 4
    return None


def apply_vae_memory_options(pipe, *, enable_tiling: bool, enable_slicing: bool) -> None:
    """Turn on VAE tiling and/or slicing on `pipe.vae` as requested."""
    if enable_tiling:
        pipe.vae.enable_tiling()
    if enable_slicing:
        pipe.vae.enable_slicing()


def _validate_quantization_strategy(strategy: str, *, name: str) -> str:
    """Return `strategy` unchanged, or raise `ValueError` if it is not a known choice.

    `name` is only used to label the offending option in the error message.
    """
    if strategy not in QUANTIZATION_CHOICES:
        raise ValueError(f"`{name}` must be one of {QUANTIZATION_CHOICES}, got {strategy!r}.")
    return strategy


def resolve_quantization_mapping(
    quantization: str | None = DEFAULT_QUANTIZATION,
    *,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> dict[str, str]:
    """Resolve the effective quantization strategy per component.

    Each per-component override wins over the shared `quantization` value;
    a falsy `quantization` falls back to `DEFAULT_QUANTIZATION`.

    Returns:
        A mapping from component name ("transformer", "text_encoder", "vae")
        to strategy name. Components resolved to "none" are omitted.

    Raises:
        ValueError: if any strategy is not in `QUANTIZATION_CHOICES`.
    """
    base_quantization = _validate_quantization_strategy(
        quantization or DEFAULT_QUANTIZATION,
        name="quantization",
    )
    component_strategies = {
        "transformer": transformer_quantization if transformer_quantization is not None else base_quantization,
        "text_encoder": text_encoder_quantization if text_encoder_quantization is not None else base_quantization,
        "vae": vae_quantization if vae_quantization is not None else base_quantization,
    }
    return {
        component: _validate_quantization_strategy(strategy, name=f"{component}_quantization")
        for component, strategy in component_strategies.items()
        if strategy != "none"
    }


def _build_torchao_quant_type(strategy: str):
    """Create the TorchAO config object for a (non-"none") strategy name.

    Raises:
        ValueError: if `strategy` has no corresponding TorchAO config.
    """
    # Imported lazily so torchao is only required when quantization is used.
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        Float8WeightOnlyConfig,
        Int4WeightOnlyConfig,
        Int8WeightOnlyConfig,
    )

    if strategy == "float8wo":
        return Float8WeightOnlyConfig()
    if strategy == "int8wo":
        return Int8WeightOnlyConfig()
    if strategy == "int4wo":
        return Int4WeightOnlyConfig(group_size=128)
    if strategy == "float8dyn":
        return Float8DynamicActivationFloat8WeightConfig()
    raise ValueError(f"Cannot build TorchAO quantization config for {strategy!r}.")


def build_quantization_config(
    quantization: str | None = DEFAULT_QUANTIZATION,
    *,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
):
    """Build a `PipelineQuantizationConfig`, or `None` when nothing is quantized.

    Arguments are resolved with `resolve_quantization_mapping`. The
    text encoder gets the transformers `TorchAoConfig` wrapper; the transformer
    and VAE get the diffusers one.
    """
    quant_mapping = resolve_quantization_mapping(
        quantization,
        transformer_quantization=transformer_quantization,
        text_encoder_quantization=text_encoder_quantization,
        vae_quantization=vae_quantization,
    )
    if not quant_mapping:
        return None

    from diffusers import PipelineQuantizationConfig
    from diffusers import TorchAoConfig as DiffusersTorchAoConfig
    from transformers import TorchAoConfig as TransformersTorchAoConfig

    component_config_classes = {
        "transformer": DiffusersTorchAoConfig,
        "text_encoder": TransformersTorchAoConfig,
        "vae": DiffusersTorchAoConfig,
    }
    return PipelineQuantizationConfig(
        quant_mapping={
            component: component_config_classes[component](_build_torchao_quant_type(strategy))
            for component, strategy in quant_mapping.items()
        }
    )


def build_components_manager(device: torch.device) -> tuple[ComponentsManager, bool]:
    """Create a `ComponentsManager`, enabling auto CPU offload on CUDA devices.

    Returns:
        The manager and a flag telling whether auto CPU offload was enabled.
    """
    manager = ComponentsManager()
    use_auto_offload = device.type == "cuda"
    if use_auto_offload:
        manager.enable_auto_cpu_offload(device=device)
    elif device.type == "mps":
        print("MPS does not support ComponentsManager auto CPU offload; loading components directly on MPS.")
    return manager, use_auto_offload


def init_modular_pipeline(
    *,
    base_model: str,
    model_config: Mapping | None = None,
    guidance_scale=None,
    dtype: str,
    device: str,
    local_files_only: bool,
    compile: bool,
    quantization: str | None = DEFAULT_QUANTIZATION,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> ModularPipeline:
    """Create the MultiDiffusion modular pipeline and load its components.

    Args:
        base_model: Model identifier/path handed to `init_pipeline`.
        model_config: Pre-loaded pipeline config; loaded from `base_model`
            when missing or empty.
        guidance_scale: When set and the pipeline has a guider, the guider is
            re-created from its component spec with this value.
        dtype: Key into `DTYPE_MAP`.
        device: Torch device string.
        local_files_only: Forwarded to config and component loading.
        compile: Whether to call `compile_repeated_blocks` on the transformer.
        quantization, transformer_quantization, text_encoder_quantization,
        vae_quantization: See `build_quantization_config`.
    """
    device_obj = torch.device(device)
    manager, use_auto_offload = build_components_manager(device_obj)
    model_config = model_config or load_model_config(
        base_model,
        local_files_only=local_files_only,
    )
    pipe = build_multidiffusion_blocks(model_config).init_pipeline(
        base_model,
        components_manager=manager,
    )
    pipe.load_components(
        torch_dtype=DTYPE_MAP[dtype],
        local_files_only=local_files_only,
        quantization_config=build_quantization_config(
            quantization,
            transformer_quantization=transformer_quantization,
            text_encoder_quantization=text_encoder_quantization,
            vae_quantization=vae_quantization,
        ),
    )
    # Move the pipeline explicitly only when auto offload is off and the
    # target device is not the CPU.
    if not use_auto_offload and device_obj.type != "cpu":
        pipe.to(device_obj)
    if guidance_scale is not None and hasattr(pipe, "guider") and pipe.guider is not None:
        guider_spec = pipe.get_component_spec("guider")
        pipe.update_components(
            guider=guider_spec.create(guidance_scale=guidance_scale),
        )
    if compile:
        pipe.transformer.compile_repeated_blocks(
            fullgraph=True,
            dynamic=True,
        )
    return pipe


def prepare_img2img_latents(pipe, image, *, height: int, width: int, generator):
    """Encode `image` with the pipeline VAE and return packed image latents.

    The image is preprocessed to `height` x `width`, each rounded down to a
    multiple of `pipe.vae_scale_factor * 2`, run through the pipeline's own
    VAE "encode" sub-block, and packed with
    `Flux2PrepareImageLatentsStep._pack_latents`.
    """
    multiple_of = pipe.vae_scale_factor * 2
    vae_encoder = pipe.blocks.sub_blocks["vae_encoder"]
    img_conditioning = vae_encoder.sub_blocks["img_conditioning"]
    encode_block = img_conditioning.sub_blocks["encode"]
    vae_encoder_pipe = encode_block.init_pipeline()
    vae_encoder_pipe.update_components(vae=pipe.vae)

    target_height = (height // multiple_of) * multiple_of
    target_width = (width // multiple_of) * multiple_of
    # Do not use the stock vae_encoder block here: its preprocess step clamps to ~1024x1024.
    image_tensor = pipe.image_processor.preprocess(
        image,
        height=target_height,
        width=target_width,
        resize_mode="default",
    )
    image_latents = vae_encoder_pipe(
        condition_images=[image_tensor],
        generator=generator,
    ).image_latents[0]
    return Flux2PrepareImageLatentsStep._pack_latents(image_latents)


def _natural_sort_key(path: Path) -> list[int | str]:
    """Return a sort key that orders digit runs in the file name numerically.

    The name is split on runs of digits; digit runs become `int`, everything
    else is lower-cased text.
    """
    # `isdecimal` (not `isdigit`): `isdigit` is also true for characters such
    # as "²" that `\d` never matches and `int()` cannot parse.
    return [int(part) if part.isdecimal() else part.lower() for part in re.split(r"(\d+)", path.name)]


def _is_csv_prompt(prompt: str) -> bool:
    """Tell whether `prompt` looks like a path to a `.csv` file (by suffix only)."""
    return Path(prompt).suffix.lower() == ".csv"


def _read_regional_prompt_csv(prompt_csv: Path) -> list[dict[str, str]]:
    """Read the regional prompt CSV into a list of `{"mask", "prompt"}` rows.

    Raises:
        FileNotFoundError: if `prompt_csv` is not an existing file.
        ValueError: if the header is not exactly `mask,prompt`, if a row is
            missing one of the two fields, or if there are no data rows.
    """
    if not prompt_csv.is_file():
        raise FileNotFoundError(f"Prompt CSV does not exist: {prompt_csv}")
    # "utf-8-sig" strips an optional byte-order mark, which would otherwise
    # become part of the first header name and fail the header check below.
    with prompt_csv.open(newline="", encoding="utf-8-sig") as handle:
        reader = csv.DictReader(handle)
        if reader.fieldnames != ["mask", "prompt"]:
            raise ValueError(f"Prompt CSV must have exact headers 'mask,prompt', got {reader.fieldnames!r}.")
        rows = []
        for row in reader:
            mask_name, prompt = row["mask"], row["prompt"]
            # DictReader fills the fields missing from a short row with None.
            if mask_name is None or prompt is None:
                raise ValueError(
                    f"Prompt CSV row at line {reader.line_num} must have both 'mask' and 'prompt' values: "
                    f"{prompt_csv}"
                )
            rows.append({"mask": mask_name, "prompt": prompt})
    if not rows:
        raise ValueError(f"Prompt CSV must contain at least one regional prompt row: {prompt_csv}")
    return rows


def _load_grayscale_mask(path: Path, *, width: int, height: int) -> torch.Tensor:
    """Load a mask image as a float32 tensor scaled to the range 0..1.

    The image is converted to 8-bit grayscale. A warning is emitted when it
    contains values other than 0 and 255.

    Raises:
        ValueError: if the image size differs from `(width, height)`.
    """
    with Image.open(path) as image:
        if image.size != (width, height):
            raise ValueError(f"Mask {path.name!r} must have size {(width, height)}, got {image.size}.")
        mask = torch.from_numpy(np.asarray(image.convert("L"), dtype=np.uint8).copy()).to(torch.float32)
    if not torch.all((mask == 0) | (mask == 255)):
        warnings.warn(f"Mask {path.name!r} is not binary; values will be used as fractional weights.", stacklevel=2)
    return mask / 255.0


def _load_regional_masks(
    *,
    rows: list[dict[str, str]],
    masks_dir: Path,
    height: int,
    width: int,
    vae_scale_factor: int,
) -> torch.Tensor:
    """Load one mask per CSV row and downsample them to the latent grid.

    Masks are looked up by file name among the regular files directly inside
    `masks_dir`, then area-interpolated to
    `(height // (vae_scale_factor * 2), width // (vae_scale_factor * 2))`.

    Returns:
        A tensor of shape `(len(rows), latent_height, latent_width)`.

    Raises:
        FileNotFoundError: if `masks_dir` is not a directory or a row names a
            mask that is not in it.
        ValueError: if some latent cell is covered by no mask, or a mask has
            the wrong size.
    """
    if not masks_dir.is_dir():
        raise FileNotFoundError(f"Mask folder does not exist: {masks_dir}")
    available_masks = {path.name: path for path in sorted(masks_dir.iterdir(), key=_natural_sort_key) if path.is_file()}
    masks = []
    for row in rows:
        mask_name = row["mask"]
        if mask_name not in available_masks:
            raise FileNotFoundError(f"Mask {mask_name!r} from prompt CSV was not found in {masks_dir}.")
        masks.append(_load_grayscale_mask(available_masks[mask_name], width=width, height=height))

    latent_height = height // (vae_scale_factor * 2)
    latent_width = width // (vae_scale_factor * 2)
    # interpolate() needs a channel axis: (N, H, W) -> (N, 1, H, W) and back.
    regional_masks = F.interpolate(
        torch.stack(masks).unsqueeze(1),
        size=(latent_height, latent_width),
        mode="area",
    ).squeeze(1)
    if torch.any(regional_masks.sum(dim=0) <= 0):
        raise ValueError("Regional masks must cover every packed latent cell.")
    return regional_masks


def prepare_regional_prompt_inputs(args: argparse.Namespace, pipe) -> tuple[str | list[str], torch.Tensor | None]:
    """Turn `--prompt` / `--masks` into the prompt and regional-mask inputs.

    For a plain-text prompt the prompt is returned as is, with no masks. For a
    CSV prompt the result is a list holding an empty string followed by the
    CSV prompts, together with the masks tensor.
    NOTE(review): the leading "" is presumably a base/background prompt
    expected by the MultiDiffusion blocks - confirm.

    Raises:
        ValueError: if `--masks` is given without a CSV prompt, or a CSV
            prompt is given without `--masks`.
    """
    if not _is_csv_prompt(args.prompt):
        if args.masks is not None:
            raise ValueError("`--masks` is only valid when `--prompt` points to a CSV file.")
        return args.prompt, None
    if args.masks is None:
        raise ValueError("`--masks` is required when `--prompt` points to a CSV file.")

    rows = _read_regional_prompt_csv(Path(args.prompt))
    regional_masks = _load_regional_masks(
        rows=rows,
        masks_dir=args.masks,
        height=args.height,
        width=args.width,
        vae_scale_factor=pipe.vae_scale_factor,
    )
    return ["", *[row["prompt"] for row in rows]], regional_masks


def build_call_kwargs(args: argparse.Namespace, pipe, generator, model_config: Mapping | None = None) -> dict:
    """Assemble the keyword arguments for the pipeline call from parsed args.

    Optional entries (`regional_masks`, `num_inference_steps`,
    `image_img2img` + `strength`, `image`) are only added when they apply.
    Without `--num-inference-steps`, the model-specific default from
    `default_num_inference_steps` is used if there is one.

    Raises:
        ValueError: if `--batch-size` is not positive, or the prompt/mask
            combination is invalid.
    """
    model_config = model_config or {}
    if args.batch_size <= 0:
        raise ValueError(f"`--batch-size` must be a positive integer, got {args.batch_size}.")

    prompt, regional_masks = prepare_regional_prompt_inputs(args, pipe)
    call_kwargs = {
        "prompt": prompt,
        "height": args.height,
        "width": args.width,
        "height_generation": args.height_generation,
        "width_generation": args.width_generation,
        "window_stride_height": args.window_stride_height,
        "window_stride_width": args.window_stride_width,
        "window_stride_height_offset": args.window_stride_height_offset,
        "window_stride_width_offset": args.window_stride_width_offset,
        "panorama_width": args.panorama_width,
        "panorama_height": args.panorama_height,
        "weighting_type": args.weighting_type,
        "weighting_range": args.weighting_range,
        "generator": generator,
        "num_images_per_prompt": args.num_images_per_prompt,
        "window_batch_size": args.batch_size,
    }
    if regional_masks is not None:
        call_kwargs["regional_masks"] = regional_masks
    if args.num_inference_steps is not None:
        call_kwargs["num_inference_steps"] = args.num_inference_steps
    elif (steps := default_num_inference_steps(model_config)) is not None:
        call_kwargs["num_inference_steps"] = steps
    if args.image_img2img is not None:
        call_kwargs["image_img2img"] = prepare_img2img_latents(
            pipe,
            load_image(args.image_img2img),
            height=args.height,
            width=args.width,
            generator=generator,
        )
        call_kwargs["strength"] = args.strength
    if args.image_conditioning is not None:
        call_kwargs["image"] = load_image(args.image_conditioning)
    return call_kwargs


def load_lora_adapter(pipe, *, lora_path: Path | None, terra_scale: float | None, num_images_per_prompt: int) -> None:
    """Load LoRA weights into the transformer; does nothing without `lora_path`.

    When `terra_scale` is given, `set_terra_t` is called on the transformer
    with the scale repeated `num_images_per_prompt` times.
    """
    if lora_path is None:
        return
    pipe.transformer.load_lora_adapter(
        lora_path,
        weight_name=LORA_WEIGHT_NAME_SAFE,
    )
    print(f"Loaded LoRA weights from {lora_path}")
    print(pipe.transformer)
    if terra_scale is not None:
        pipe.transformer.set_terra_t(
            [terra_scale] * num_images_per_prompt,
            adapter=None,
        )


def save_images(images, output: str | Path, *, num_images_per_prompt: int) -> None:
    """Write the generated images to disk.

    An `_<index>` suffix is appended to the file stem when
    `num_images_per_prompt > 1` or when more than one image was produced, so
    that images never overwrite each other. Images with more than
    `BIGTIFF_PIXEL_THRESHOLD` pixels are written as BigTIFF; the final path is
    chosen by `_resolve_output_path_for_image`.
    """
    output_path = Path(output)
    images = list(images)
    use_index_suffix = num_images_per_prompt > 1 or len(images) > 1
    for i, image in enumerate(images):
        use_bigtiff = image.size[0] * image.size[1] > BIGTIFF_PIXEL_THRESHOLD
        save_path = output_path
        if use_index_suffix:
            save_path = output_path.with_stem(f"{output_path.stem}_{i}")
        save_path = _resolve_output_path_for_image(save_path, image=image, use_bigtiff=use_bigtiff)
        if use_bigtiff:
            image.save(save_path, format="TIFF", big_tiff=True)
        else:
            image.save(save_path)


def _resolve_output_path_for_image(output_path: Path, *, image, use_bigtiff: bool) -> Path:
    """Pick the file path (mainly the suffix) an image is saved under.

    A path without suffix gets `.tif` (BigTIFF) or `.png`. A BigTIFF image
    with a non-TIFF suffix is redirected to `.tif` with a warning. Otherwise
    the path is returned unchanged.
    """
    if output_path.suffix == "":
        return output_path.with_suffix(".tif" if use_bigtiff else ".png")
    if use_bigtiff and output_path.suffix.lower() not in TIFF_SUFFIXES:
        bigtiff_path = output_path.with_suffix(".tif")
        warnings.warn(
            (
                f"Image size {image.size[0]}x{image.size[1]} exceeds {BIGTIFF_PIXEL_THRESHOLD} pixels; "
                f"saving BigTIFF to {bigtiff_path} instead of {output_path}."
            ),
            stacklevel=2,
        )
        return bigtiff_path
    return output_path


def main(argv: Sequence[str] | None = None) -> None:
    """Script entry point: parse arguments, build the pipeline, generate, save."""
    args = apply_window_stride_defaults(parse_args(argv))
    configure_torch(allow_tf32=args.allow_tf32)
    model_config = load_model_config(
        args.base_model,
        local_files_only=args.local_files_only,
    )

    print("Initializing modular pipeline...")
    print(args)
    pipe = init_modular_pipeline(
        base_model=args.base_model,
        model_config=model_config,
        guidance_scale=args.guidance_scale,
        dtype=args.dtype,
        device=args.device,
        local_files_only=args.local_files_only,
        compile=args.compile,
        quantization=args.quantization,
        transformer_quantization=args.transformer_quantization,
        text_encoder_quantization=args.text_encoder_quantization,
        vae_quantization=args.vae_quantization,
    )
    apply_vae_memory_options(
        pipe,
        enable_tiling=args.enable_tiling,
        enable_slicing=args.enable_slicing,
    )
    load_lora_adapter(
        pipe,
        lora_path=args.lora_path,
        terra_scale=args.terra_scale,
        num_images_per_prompt=args.num_images_per_prompt,
    )

    generator = torch.Generator().manual_seed(args.seed)
    call_kwargs = build_call_kwargs(
        args,
        pipe,
        generator,
        model_config,
    )
    output = pipe(**call_kwargs)
    save_images(
        output.images,
        args.output,
        num_images_per_prompt=args.num_images_per_prompt,
    )


if __name__ == "__main__":
    main()