arlaz's picture
Initial Multidiff Modular export
d8bf41d
from __future__ import annotations
import argparse
import csv
import re
import warnings
from collections.abc import Mapping, Sequence
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import ComponentsManager, DiffusionPipeline, ModularPipeline
from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
from diffusers.modular_pipelines.flux2.before_denoise import Flux2PrepareImageLatentsStep
from diffusers.utils import load_image
from PIL import Image
# CLI ``--dtype`` choice -> torch dtype used when loading pipeline components.
DTYPE_MAP = {name: getattr(torch, name) for name in ("float16", "bfloat16", "float32")}
# TorchAO strategies accepted by ``--quantization`` and the per-component overrides.
QUANTIZATION_CHOICES = ("none", "float8wo", "int8wo", "int4wo", "float8dyn")
# Used when ``--quantization`` is not given at all ("none" disables quantization).
DEFAULT_QUANTIZATION = QUANTIZATION_CHOICES[0]
# Used when ``--quantization`` (or the hidden ``--quantize`` alias) is given without a value.
DEFAULT_REQUESTED_QUANTIZATION = "float8wo"
# Window-blending profiles accepted by ``--weighting-type``.
WEIGHTING_TYPES = ("none", "linear", "cosine")
# Images with strictly more pixels than this are written as BigTIFF.
BIGTIFF_PIXEL_THRESHOLD = 4096**2
# Output suffixes (compared lower-cased) that already denote a TIFF file.
TIFF_SUFFIXES = {".tif", ".tiff"}
def configure_torch(*, allow_tf32: bool) -> None:
    """Enable TF32 float32 math on CUDA when requested.

    Does nothing unless ``allow_tf32`` is true and a CUDA device is available.
    """
    if not allow_tf32:
        return
    if not torch.cuda.is_available():
        return
    torch.backends.fp32_precision = "tf32"
    torch.set_float32_matmul_precision("high")
def build_multidiffusion_blocks(model_config: Mapping):
    """Instantiate the MultiDiffusion auto-blocks matching a pipeline config.

    Args:
        model_config: Loaded pipeline config. ``_class_name`` selects the model
            family. For Flux2-Klein, ``is_distilled`` selects the distilled
            blocks versus the base blocks.

    Returns:
        A freshly constructed auto-blocks instance from the local ``block`` module.

    Raises:
        ValueError: If ``_class_name`` is not a supported Flux2 pipeline class.
    """
    from block import (
        Flux2KleinBaseMultiDiffusionAutoBlocks,
        Flux2KleinMultiDiffusionAutoBlocks,
        Flux2MultiDiffusionAutoBlocks,
    )

    class_name = model_config.get("_class_name")
    # Accept the modular pipeline class names as well. This keeps the builder
    # consistent with `default_num_inference_steps`, which already treats
    # "Flux2KleinModularPipeline" as a Klein config.
    if class_name in {"Flux2Pipeline", "Flux2ModularPipeline"}:
        return Flux2MultiDiffusionAutoBlocks()
    if class_name in {"Flux2KleinPipeline", "Flux2KleinModularPipeline"}:
        if model_config.get("is_distilled"):
            return Flux2KleinMultiDiffusionAutoBlocks()
        return Flux2KleinBaseMultiDiffusionAutoBlocks()
    raise ValueError(f"Cannot select MultiDiffusion blocks from model class {class_name!r}.")
def add_input_arguments(parser: argparse.ArgumentParser) -> None:
    """Register prompt, mask, and source-image flags under an "inputs" group.

    ``--prompt`` is either literal prompt text or a path to a ``mask,prompt`` CSV.
    ``--masks`` is parsed as a ``Path``. Both image flags default to ``None``
    (feature disabled).
    """
    group = parser.add_argument_group("inputs")
    group.add_argument("--prompt", default="A dense renaissance fresco.")
    group.add_argument("--masks", type=Path, default=None)
    for image_flag in ("--image-img2img", "--image-conditioning"):
        group.add_argument(image_flag, default=None)
    group.add_argument("--strength", type=float, default=1.0, help="The strength of renoising in the img2img setting.")
def add_canvas_arguments(parser: argparse.ArgumentParser) -> None:
    """Register output-canvas and sliding-window flags under a "canvas" group.

    Generation-window sizes and strides default to ``None``. Stride offsets
    default to 0. The panorama flags are boolean switches.
    """
    group = parser.add_argument_group("canvas")
    for size_flag in ("--height", "--width"):
        group.add_argument(size_flag, type=int, default=4096)
    for optional_flag in (
        "--height-generation",
        "--width-generation",
        "--window-stride-height",
        "--window-stride-width",
    ):
        group.add_argument(optional_flag, type=int, default=None)
    for offset_flag in ("--window-stride-height-offset", "--window-stride-width-offset"):
        group.add_argument(offset_flag, type=int, default=0)
    for switch in ("--panorama-width", "--panorama-height"):
        group.add_argument(switch, action="store_true")
    group.add_argument("--weighting-type", default="cosine", choices=WEIGHTING_TYPES)
    group.add_argument("--weighting-range", type=float, default=None)
def add_inference_arguments(parser: argparse.ArgumentParser) -> None:
    """Register sampling flags (steps, guidance, seed, batching) under an "inference" group.

    Steps, guidance scale, and terra scale default to ``None`` so that
    model-specific defaults can apply downstream.
    """
    group = parser.add_argument_group("inference")
    for flag, value_type in (
        ("--num-inference-steps", int),
        ("--guidance-scale", float),
        ("--terra-scale", float),
    ):
        group.add_argument(flag, type=value_type, default=None)
    group.add_argument("--seed", type=int, default=42)
    for count_flag in ("--num-images-per-prompt", "--batch-size"):
        group.add_argument(count_flag, type=int, default=1)
def add_runtime_arguments(parser: argparse.ArgumentParser) -> None:
    """Register device, output, performance, and quantization flags under a "runtime" group.

    ``--quantize`` is a hidden alias that writes into the same ``quantization``
    destination as ``--quantization``. Its default is ``argparse.SUPPRESS`` so
    that it never overwrites the ``--quantization`` default.
    """
    group = parser.add_argument_group("runtime")
    group.add_argument("--dtype", default="bfloat16", choices=tuple(DTYPE_MAP))
    group.add_argument("--device", default="cuda")
    group.add_argument("--output", default="output.png")
    for switch in ("--local-files-only", "--allow-tf32", "--compile"):
        group.add_argument(switch, action="store_true")

    group.add_argument(
        "--quantize",
        dest="quantization",
        action="store_const",
        const=DEFAULT_REQUESTED_QUANTIZATION,
        default=argparse.SUPPRESS,
        help=argparse.SUPPRESS,
    )
    group.add_argument(
        "--quantization",
        nargs="?",
        const=DEFAULT_REQUESTED_QUANTIZATION,
        default=DEFAULT_QUANTIZATION,
        choices=QUANTIZATION_CHOICES,
        help=(
            "TorchAO quantization strategy for transformer, text_encoder, and vae. "
            f"Passing the flag without a value uses {DEFAULT_REQUESTED_QUANTIZATION}."
        ),
    )

    # Per-component overrides. ``None`` means "inherit --quantization".
    component_overrides = (
        ("--transformer-quantization", "Override the quantization strategy for the image transformer component."),
        ("--text-encoder-quantization", "Override the quantization strategy for the text_encoder component."),
        ("--vae-quantization", "Override the quantization strategy for the VAE component."),
    )
    for flag, help_text in component_overrides:
        group.add_argument(flag, default=None, choices=QUANTIZATION_CHOICES, help=help_text)

    for switch in ("--enable-tiling", "--enable-slicing"):
        group.add_argument(switch, action="store_true")
def build_parser() -> argparse.ArgumentParser:
    """Assemble the full CLI parser: model flags, then the input, canvas, inference, and runtime groups."""
    parser = argparse.ArgumentParser(description="Run Flux2-Klein MultiDiffusion with local modular blocks.")

    model_group = parser.add_argument_group("model")
    model_group.add_argument("--base-model", required=True)
    # Both spellings are accepted for backwards compatibility.
    model_group.add_argument("--lora-path", "--lora_path", type=Path)

    registrars = (
        add_input_arguments,
        add_canvas_arguments,
        add_inference_arguments,
        add_runtime_arguments,
    )
    for register in registrars:
        register(parser)
    return parser
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse ``argv`` with the script's parser. ``None`` lets argparse read ``sys.argv[1:]``."""
    parser = build_parser()
    namespace = parser.parse_args(argv)
    return namespace
def load_model_config(base_model: str, *, local_files_only: bool) -> dict:
    """Load the pipeline config for ``base_model`` via ``DiffusionPipeline.load_config``.

    The result is copied into a plain ``dict``.
    """
    raw_config = DiffusionPipeline.load_config(base_model, local_files_only=local_files_only)
    return dict(raw_config)
def apply_window_stride_defaults(args: argparse.Namespace) -> argparse.Namespace:
    """Default each unset window stride to half of the matching generation window size.

    A stride is filled in only when its window size is given and the stride
    itself is ``None``. ``args`` is mutated in place and also returned.
    """
    axis_pairs = (
        ("height_generation", "window_stride_height"),
        ("width_generation", "window_stride_width"),
    )
    for window_attr, stride_attr in axis_pairs:
        window = getattr(args, window_attr)
        if window is not None and getattr(args, stride_attr) is None:
            setattr(args, stride_attr, window // 2)
    return args
def default_num_inference_steps(model_config: Mapping) -> int | None:
if model_config.get("_class_name") in {"Flux2KleinPipeline", "Flux2KleinModularPipeline"} and model_config.get(
"is_distilled"
):
return 4
return None
def apply_vae_memory_options(pipe, *, enable_tiling: bool, enable_slicing: bool) -> None:
    """Turn on VAE tiling and/or slicing on ``pipe.vae`` as requested.

    ``pipe.vae`` is only touched when at least one option is enabled.
    """
    if not (enable_tiling or enable_slicing):
        return
    if enable_tiling:
        pipe.vae.enable_tiling()
    if enable_slicing:
        pipe.vae.enable_slicing()
def _validate_quantization_strategy(strategy: str, *, name: str) -> str:
    """Return ``strategy`` unchanged if it is a known choice.

    Otherwise raise ``ValueError``, naming the offending option ``name``.
    """
    if strategy in QUANTIZATION_CHOICES:
        return strategy
    raise ValueError(f"`{name}` must be one of {QUANTIZATION_CHOICES}, got {strategy!r}.")
def resolve_quantization_mapping(
    quantization: str | None = DEFAULT_QUANTIZATION,
    *,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> dict[str, str]:
    """Resolve the effective quantization strategy for each pipeline component.

    Each component uses its override when given, otherwise the base
    ``quantization`` (falsy values fall back to ``DEFAULT_QUANTIZATION``).
    Components whose effective strategy is ``"none"`` are omitted.

    Raises:
        ValueError: If the base strategy or any non-"none" effective strategy is unknown.
    """
    fallback = _validate_quantization_strategy(quantization or DEFAULT_QUANTIZATION, name="quantization")
    overrides = {
        "transformer": transformer_quantization,
        "text_encoder": text_encoder_quantization,
        "vae": vae_quantization,
    }
    resolved: dict[str, str] = {}
    for component, override in overrides.items():
        effective = fallback if override is None else override
        if effective == "none":
            continue
        resolved[component] = _validate_quantization_strategy(effective, name=f"{component}_quantization")
    return resolved
def _build_torchao_quant_type(strategy: str):
    """Build the TorchAO quantization config object for a strategy name.

    torchao is imported lazily, so it is only required when quantization is used.

    Raises:
        ValueError: If ``strategy`` has no TorchAO config (e.g. ``"none"``).
    """
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        Float8WeightOnlyConfig,
        Int4WeightOnlyConfig,
        Int8WeightOnlyConfig,
    )

    factories = {
        "float8wo": Float8WeightOnlyConfig,
        "int8wo": Int8WeightOnlyConfig,
        "int4wo": lambda: Int4WeightOnlyConfig(group_size=128),
        "float8dyn": Float8DynamicActivationFloat8WeightConfig,
    }
    factory = factories.get(strategy)
    if factory is None:
        raise ValueError(f"Cannot build TorchAO quantization config for {strategy!r}.")
    return factory()
def build_quantization_config(
    quantization: str | None = DEFAULT_QUANTIZATION,
    *,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
):
    """Build a ``PipelineQuantizationConfig`` from the CLI quantization options.

    Returns ``None`` when no component is quantized. In that case the
    diffusers/transformers quantization classes are never imported.
    """
    strategies = resolve_quantization_mapping(
        quantization,
        transformer_quantization=transformer_quantization,
        text_encoder_quantization=text_encoder_quantization,
        vae_quantization=vae_quantization,
    )
    if not strategies:
        return None

    from diffusers import PipelineQuantizationConfig
    from diffusers import TorchAoConfig as DiffusersTorchAoConfig
    from transformers import TorchAoConfig as TransformersTorchAoConfig

    # text_encoder is wrapped with the transformers TorchAoConfig; transformer and vae with the diffusers one.
    config_class_for = {
        "transformer": DiffusersTorchAoConfig,
        "text_encoder": TransformersTorchAoConfig,
        "vae": DiffusersTorchAoConfig,
    }
    per_component = {}
    for component, strategy in strategies.items():
        config_cls = config_class_for[component]
        per_component[component] = config_cls(_build_torchao_quant_type(strategy))
    return PipelineQuantizationConfig(quant_mapping=per_component)
def build_components_manager(device: torch.device) -> tuple[ComponentsManager, bool]:
    """Create a ``ComponentsManager``, enabling auto CPU offload on CUDA devices.

    Returns:
        ``(manager, use_auto_offload)``. ``use_auto_offload`` is True only for CUDA.
    """
    manager = ComponentsManager()
    if device.type == "cuda":
        manager.enable_auto_cpu_offload(device=device)
        return manager, True
    if device.type == "mps":
        print("MPS does not support ComponentsManager auto CPU offload; loading components directly on MPS.")
    return manager, False
def init_modular_pipeline(
    *,
    base_model: str,
    model_config: Mapping | None = None,
    guidance_scale=None,
    dtype: str,
    device: str,
    local_files_only: bool,
    compile: bool,
    quantization: str | None = DEFAULT_QUANTIZATION,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> ModularPipeline:
    """Build and load the MultiDiffusion modular pipeline for ``base_model``.

    Steps, in order:
    1. Create a components manager for ``device`` (auto-offload on CUDA).
    2. Load the config if none is given, or if the given one is empty.
    3. Select the blocks and load components with the requested dtype and quantization.
    4. Move the pipeline to ``device`` when not offloading and not on CPU.
    5. Optionally rebuild the guider with ``guidance_scale``.
    6. Optionally compile the transformer's repeated blocks.
    """
    target = torch.device(device)
    manager, offloading = build_components_manager(target)
    if not model_config:
        model_config = load_model_config(base_model, local_files_only=local_files_only)

    blocks = build_multidiffusion_blocks(model_config)
    pipe = blocks.init_pipeline(base_model, components_manager=manager)

    torch_dtype = DTYPE_MAP[dtype]
    quantization_config = build_quantization_config(
        quantization,
        transformer_quantization=transformer_quantization,
        text_encoder_quantization=text_encoder_quantization,
        vae_quantization=vae_quantization,
    )
    pipe.load_components(
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        quantization_config=quantization_config,
    )

    # With auto-offload the manager handles placement; CPU needs no explicit move.
    if not (offloading or target.type == "cpu"):
        pipe.to(target)

    has_guider = hasattr(pipe, "guider") and pipe.guider is not None
    if guidance_scale is not None and has_guider:
        guider_spec = pipe.get_component_spec("guider")
        pipe.update_components(guider=guider_spec.create(guidance_scale=guidance_scale))

    if compile:
        pipe.transformer.compile_repeated_blocks(fullgraph=True, dynamic=True)
    return pipe
def prepare_img2img_latents(pipe, image, *, height: int, width: int, generator):
    """Encode an init image into packed latents for the img2img path.

    The image is resized to the canvas size, rounded down to a multiple of
    ``2 * pipe.vae_scale_factor``. It is then encoded with the "encode"
    sub-block of the pipeline's own vae_encoder (sharing ``pipe.vae``), and the
    result is packed with ``Flux2PrepareImageLatentsStep._pack_latents``.
    """
    alignment = pipe.vae_scale_factor * 2
    encode_block = (
        pipe.blocks.sub_blocks["vae_encoder"]
        .sub_blocks["img_conditioning"]
        .sub_blocks["encode"]
    )
    encoder = encode_block.init_pipeline()
    encoder.update_components(vae=pipe.vae)

    # Round down to the nearest multiple of `alignment`.
    aligned_height = height - height % alignment
    aligned_width = width - width % alignment

    # Preprocess here rather than via the stock vae_encoder block: its preprocess
    # step clamps to ~1024x1024, which would discard the full canvas resolution.
    pixels = pipe.image_processor.preprocess(
        image,
        height=aligned_height,
        width=aligned_width,
        resize_mode="default",
    )
    encoded = encoder(condition_images=[pixels], generator=generator)
    latents = encoded.image_latents[0]
    return Flux2PrepareImageLatentsStep._pack_latents(latents)
def _natural_sort_key(path: Path):
    """Sort key ordering file names naturally (``img2`` before ``img10``), case-insensitively.

    ``re.split`` with a capturing group alternates text chunks (even indices)
    with digit runs matched by ``\\d+`` (odd indices). Only the odd chunks are
    converted to ``int``. Testing ``str.isdigit()`` instead would also accept
    text chunks made of characters such as ``"²"``, which ``\\d`` does not match
    and ``int()`` cannot parse, so such names would crash the sort.
    """
    chunks = re.split(r"(\d+)", path.name)
    return [int(chunk) if index % 2 else chunk.lower() for index, chunk in enumerate(chunks)]
def _is_csv_prompt(prompt: str) -> bool:
    """Return True when ``prompt`` looks like a path to a CSV file (case-insensitive ``.csv`` suffix)."""
    extension = Path(prompt).suffix
    return extension.lower() == ".csv"
def _read_regional_prompt_csv(prompt_csv: Path) -> list[dict[str, str]]:
    """Parse a regional-prompt CSV whose header is exactly ``mask,prompt``.

    Args:
        prompt_csv: Path to the CSV file.

    Returns:
        One ``{"mask": ..., "prompt": ...}`` dict per data row, in file order.

    Raises:
        FileNotFoundError: If ``prompt_csv`` is not a file.
        ValueError: If the header is not exactly ``mask,prompt``, if a row does
            not have exactly two fields, or if there are no data rows.
    """
    if not prompt_csv.is_file():
        raise FileNotFoundError(f"Prompt CSV does not exist: {prompt_csv}")
    # "utf-8-sig" strips a leading BOM (common in spreadsheet exports), so the
    # header check sees "mask" rather than "\ufeffmask". Plain UTF-8 reads unchanged.
    with prompt_csv.open(newline="", encoding="utf-8-sig") as handle:
        reader = csv.DictReader(handle)
        if reader.fieldnames != ["mask", "prompt"]:
            raise ValueError(f"Prompt CSV must have exact headers 'mask,prompt', got {reader.fieldnames!r}.")
        rows = []
        for row in reader:
            # DictReader pads short rows with None and collects surplus cells under
            # the None key. Surplus cells usually mean an unquoted comma inside a
            # prompt, which would otherwise silently truncate it.
            if row["mask"] is None or row["prompt"] is None or None in row:
                raise ValueError(
                    f"Prompt CSV row on line {reader.line_num} must have exactly two fields "
                    f"'mask,prompt' (quote prompts containing commas): {prompt_csv}"
                )
            rows.append({"mask": row["mask"], "prompt": row["prompt"]})
    if not rows:
        raise ValueError(f"Prompt CSV must contain at least one regional prompt row: {prompt_csv}")
    return rows
def _load_grayscale_mask(path: Path, *, width: int, height: int) -> torch.Tensor:
    """Load a mask image as a float32 tensor scaled to [0, 1].

    The image must be exactly ``width`` x ``height``. It is converted to 8-bit
    grayscale ("L"). Non-binary masks (any value other than 0/255) trigger a
    warning and are kept as fractional weights.

    Raises:
        ValueError: If the image size does not match ``(width, height)``.
    """
    expected_size = (width, height)
    with Image.open(path) as image:
        if image.size != expected_size:
            raise ValueError(f"Mask {path.name!r} must have size {expected_size}, got {image.size}.")
        pixels = np.asarray(image.convert("L"), dtype=np.uint8).copy()
    mask = torch.from_numpy(pixels).to(torch.float32)
    is_binary = torch.all((mask == 0) | (mask == 255))
    if not is_binary:
        warnings.warn(f"Mask {path.name!r} is not binary; values will be used as fractional weights.", stacklevel=2)
    return mask / 255.0
def _load_regional_masks(
    *,
    rows: list[dict[str, str]],
    masks_dir: Path,
    height: int,
    width: int,
    vae_scale_factor: int,
) -> torch.Tensor:
    """Load one mask per CSV row and area-downsample them to the packed latent grid.

    Returns:
        A float32 tensor of shape ``(len(rows), height // (2 * vae_scale_factor),
        width // (2 * vae_scale_factor))``.

    Raises:
        FileNotFoundError: If ``masks_dir`` is missing or a row names an unknown mask file.
        ValueError: If a mask has the wrong size, or if some latent cell receives zero total weight.
    """
    if not masks_dir.is_dir():
        raise FileNotFoundError(f"Mask folder does not exist: {masks_dir}")

    lookup: dict[str, Path] = {}
    for entry in sorted(masks_dir.iterdir(), key=_natural_sort_key):
        if entry.is_file():
            lookup[entry.name] = entry

    loaded = []
    for row in rows:
        name = row["mask"]
        if name not in lookup:
            raise FileNotFoundError(f"Mask {name!r} from prompt CSV was not found in {masks_dir}.")
        loaded.append(_load_grayscale_mask(lookup[name], width=width, height=height))

    packing = vae_scale_factor * 2
    stacked = torch.stack(loaded).unsqueeze(1)  # (N, 1, H, W) as required by F.interpolate
    downsampled = F.interpolate(stacked, size=(height // packing, width // packing), mode="area")
    regional_masks = downsampled.squeeze(1)
    if (regional_masks.sum(dim=0) <= 0).any():
        raise ValueError("Regional masks must cover every packed latent cell.")
    return regional_masks
def prepare_regional_prompt_inputs(args: argparse.Namespace, pipe) -> tuple[str | list[str], torch.Tensor | None]:
    """Resolve ``--prompt``/``--masks`` into the pipeline's prompt and regional masks.

    A plain-text prompt is returned as-is with ``None`` masks. A CSV prompt
    returns ``["", *row_prompts]`` together with the stacked regional masks.
    Presumably the leading empty prompt is a global/background prompt; confirm
    against the pipeline blocks.

    Raises:
        ValueError: If ``--masks`` is given without a CSV prompt, or a CSV prompt is given without ``--masks``.
    """
    uses_csv = _is_csv_prompt(args.prompt)
    has_masks = args.masks is not None
    if not uses_csv:
        if has_masks:
            raise ValueError("`--masks` is only valid when `--prompt` points to a CSV file.")
        return args.prompt, None
    if not has_masks:
        raise ValueError("`--masks` is required when `--prompt` points to a CSV file.")

    rows = _read_regional_prompt_csv(Path(args.prompt))
    masks = _load_regional_masks(
        rows=rows,
        masks_dir=args.masks,
        height=args.height,
        width=args.width,
        vae_scale_factor=pipe.vae_scale_factor,
    )
    prompts = [""] + [row["prompt"] for row in rows]
    return prompts, masks
def build_call_kwargs(args: argparse.Namespace, pipe, generator, model_config: Mapping | None = None) -> dict:
    """Translate parsed CLI args into keyword arguments for ``pipe(...)``.

    Optional entries are added only when relevant:
    - ``regional_masks`` for CSV prompts;
    - ``num_inference_steps`` when given or defaulted by the model config;
    - ``image_img2img``/``strength`` for img2img;
    - ``image`` for image conditioning.

    Raises:
        ValueError: If ``--batch-size`` is not positive (plus any error from prompt/mask preparation).
    """
    config = model_config or {}
    if args.batch_size <= 0:
        raise ValueError(f"`--batch-size` must be a positive integer, got {args.batch_size}.")

    prompt, regional_masks = prepare_regional_prompt_inputs(args, pipe)

    forwarded_args = (
        "height",
        "width",
        "height_generation",
        "width_generation",
        "window_stride_height",
        "window_stride_width",
        "window_stride_height_offset",
        "window_stride_width_offset",
        "panorama_width",
        "panorama_height",
        "weighting_type",
        "weighting_range",
    )
    kwargs: dict = {"prompt": prompt}
    kwargs.update({name: getattr(args, name) for name in forwarded_args})
    kwargs["generator"] = generator
    kwargs["num_images_per_prompt"] = args.num_images_per_prompt
    kwargs["window_batch_size"] = args.batch_size

    if regional_masks is not None:
        kwargs["regional_masks"] = regional_masks

    steps = args.num_inference_steps
    if steps is None:
        steps = default_num_inference_steps(config)
    if steps is not None:
        kwargs["num_inference_steps"] = steps

    if args.image_img2img is not None:
        init_image = load_image(args.image_img2img)
        kwargs["image_img2img"] = prepare_img2img_latents(
            pipe,
            init_image,
            height=args.height,
            width=args.width,
            generator=generator,
        )
        kwargs["strength"] = args.strength

    if args.image_conditioning is not None:
        kwargs["image"] = load_image(args.image_conditioning)
    return kwargs
def load_lora_adapter(pipe, *, lora_path: Path | None, terra_scale: float | None, num_images_per_prompt: int) -> None:
if lora_path is None:
return
pipe.transformer.load_lora_adapter(
lora_path,
weight_name=LORA_WEIGHT_NAME_SAFE,
)
print(f"Loaded LoRA weights from {lora_path}")
print(pipe.transformer)
if terra_scale is not None:
pipe.transformer.set_terra_t(
[terra_scale] * num_images_per_prompt,
adapter=None,
)
def save_images(images, output: str | Path, *, num_images_per_prompt: int) -> None:
    """Write generated images to disk, using BigTIFF for very large images.

    Args:
        images: Images returned by the pipeline. Each must expose ``.size``
            (width, height) and ``.save``.
        output: Target path. A missing suffix gets ``.png`` (or ``.tif`` for
            BigTIFF-sized images).
        num_images_per_prompt: Requested images per prompt. Values above 1
            force indexed file names.

    When more than one image is written, each file name gets an ``_{index}``
    stem suffix so that images never overwrite each other.
    """
    images = list(images)
    output_path = Path(output)
    # Index names whenever several images come back, not only when
    # num_images_per_prompt > 1. Otherwise every image would be written to the
    # same path and only the last one would survive.
    add_index = num_images_per_prompt > 1 or len(images) > 1
    for i, image in enumerate(images):
        width, height = image.size
        use_bigtiff = width * height > BIGTIFF_PIXEL_THRESHOLD
        save_path = output_path.with_stem(f"{output_path.stem}_{i}") if add_index else output_path
        save_path = _resolve_output_path_for_image(save_path, image=image, use_bigtiff=use_bigtiff)
        if use_bigtiff:
            image.save(save_path, format="TIFF", big_tiff=True)
        else:
            image.save(save_path)
def _resolve_output_path_for_image(output_path: Path, *, image, use_bigtiff: bool) -> Path:
    """Pick the final file path for one image.

    - No suffix: append ``.tif`` for BigTIFF images, ``.png`` otherwise.
    - BigTIFF image with a non-TIFF suffix: switch to ``.tif`` and warn.
    - Otherwise: return ``output_path`` unchanged.
    """
    if not output_path.suffix:
        default_suffix = ".tif" if use_bigtiff else ".png"
        return output_path.with_suffix(default_suffix)
    if not use_bigtiff or output_path.suffix.lower() in TIFF_SUFFIXES:
        return output_path
    bigtiff_path = output_path.with_suffix(".tif")
    message = (
        f"Image size {image.size[0]}x{image.size[1]} exceeds {BIGTIFF_PIXEL_THRESHOLD} pixels; "
        f"saving BigTIFF to {bigtiff_path} instead of {output_path}."
    )
    warnings.warn(message, stacklevel=2)
    return bigtiff_path
def main(argv: Sequence[str] | None = None) -> None:
    """Run the CLI end to end.

    Steps: parse args, configure torch, load the config, build the pipeline,
    apply VAE/LoRA options, generate, and save the outputs.
    """
    parsed = parse_args(argv)
    args = apply_window_stride_defaults(parsed)
    configure_torch(allow_tf32=args.allow_tf32)
    model_config = load_model_config(args.base_model, local_files_only=args.local_files_only)

    print("Initializing modular pipeline...")
    print(args)
    quantization_options = {
        "quantization": args.quantization,
        "transformer_quantization": args.transformer_quantization,
        "text_encoder_quantization": args.text_encoder_quantization,
        "vae_quantization": args.vae_quantization,
    }
    pipe = init_modular_pipeline(
        base_model=args.base_model,
        model_config=model_config,
        guidance_scale=args.guidance_scale,
        dtype=args.dtype,
        device=args.device,
        local_files_only=args.local_files_only,
        compile=args.compile,
        **quantization_options,
    )
    apply_vae_memory_options(pipe, enable_tiling=args.enable_tiling, enable_slicing=args.enable_slicing)
    load_lora_adapter(
        pipe,
        lora_path=args.lora_path,
        terra_scale=args.terra_scale,
        num_images_per_prompt=args.num_images_per_prompt,
    )

    generator = torch.Generator().manual_seed(args.seed)
    call_kwargs = build_call_kwargs(args, pipe, generator, model_config)
    result = pipe(**call_kwargs)
    save_images(result.images, args.output, num_images_per_prompt=args.num_images_per_prompt)
# Script entry point: main() with no argv, so argparse reads sys.argv[1:].
if __name__ == "__main__":
    main()