# Initial Multidiff Modular export (arlaz, d8bf41d)
from __future__ import annotations
import argparse
from collections.abc import Sequence
from pathlib import Path
import torch
from diffusers import ModularPipeline
from examples.example import (
DTYPE_MAP,
add_canvas_arguments,
add_inference_arguments,
add_input_arguments,
add_runtime_arguments,
apply_vae_memory_options,
apply_window_stride_defaults,
build_call_kwargs,
build_components_manager,
build_quantization_config,
configure_torch,
load_lora_adapter,
save_images,
)
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the remote MultiDiffusion runner.

    The parser contains a ``model`` group (the required ``--repo-id`` and an
    optional ``--lora-path``), followed by the shared input, canvas,
    inference, and runtime argument groups provided by ``examples.example``.
    """
    cli = argparse.ArgumentParser(
        description="Run Flux.2 MultiDiffusion from a remote Hugging Face Modular Diffusers repo."
    )

    model_group = cli.add_argument_group("model")
    model_group.add_argument(
        "--repo-id",
        required=True,
        help="Hub repo containing the exported Modular Diffusers files and remote block.py code.",
    )
    # Both spellings are accepted; argparse stores the value as ``lora_path``.
    model_group.add_argument("--lora-path", "--lora_path", type=Path)

    # Shared argument groups, registered in a fixed order so that --help output
    # (and any option-conflict resolution) stays deterministic.
    for register_group in (
        add_input_arguments,
        add_canvas_arguments,
        add_inference_arguments,
        add_runtime_arguments,
    ):
        register_group(cli)

    return cli
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Arguments to parse; when ``None``, argparse falls back to
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = build_parser()
    namespace = parser.parse_args(argv)
    return namespace
def init_remote_modular_pipeline(
    *,
    repo_id: str,
    guidance_scale=None,
    dtype: str,
    device: str,
    local_files_only: bool,
    compile: bool,
    quantization: str | None = None,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> ModularPipeline:
    """Load a remote-code Modular Diffusers pipeline from the Hub and prepare it.

    Steps:
      1. Build a components manager for ``device``; the helper also reports
         whether automatic offloading is in use.
      2. Load the pipeline definition from ``repo_id`` with
         ``trust_remote_code=True`` (the repo ships its own ``block.py``).
      3. Load all components in ``dtype`` with the requested quantization.
      4. Move the pipeline to ``device`` unless auto-offload is active or the
         target is the CPU.
      5. If ``guidance_scale`` is given and the pipeline has a guider, swap in
         a fresh guider created from the guider's component spec.
      6. Optionally compile the transformer's repeated blocks.

    Args:
        repo_id: Hub repo id of the exported Modular Diffusers pipeline.
        guidance_scale: Optional guidance-scale override for the guider.
        dtype: Key into ``DTYPE_MAP`` selecting the torch dtype.
        device: Torch device string passed to ``torch.device``.
        local_files_only: Forwarded to both ``from_pretrained`` and
            ``load_components``.
        compile: If true, call ``compile_repeated_blocks`` on the transformer
            with ``fullgraph=True, dynamic=True``.
        quantization: Global quantization setting for
            ``build_quantization_config``.
        transformer_quantization: Per-component override for the transformer.
        text_encoder_quantization: Per-component override for the text encoder.
        vae_quantization: Per-component override for the VAE.

    Returns:
        The loaded and configured ``ModularPipeline``.
    """
    target = torch.device(device)
    manager, use_auto_offload = build_components_manager(target)

    pipe = ModularPipeline.from_pretrained(
        repo_id,
        trust_remote_code=True,
        components_manager=manager,
        local_files_only=local_files_only,
    )

    # Resolve the dtype before building the quantization config, matching the
    # left-to-right evaluation order of the keyword arguments below.
    torch_dtype = DTYPE_MAP[dtype]
    quant_config = build_quantization_config(
        quantization,
        transformer_quantization=transformer_quantization,
        text_encoder_quantization=text_encoder_quantization,
        vae_quantization=vae_quantization,
    )
    pipe.load_components(
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        quantization_config=quant_config,
    )

    # Skip the explicit move when auto-offload is enabled (presumably the
    # components manager handles placement then) or when targeting the CPU.
    if not (use_auto_offload or target.type == "cpu"):
        pipe.to(target)

    # The guider is only looked up when an override was actually requested.
    if guidance_scale is not None and getattr(pipe, "guider", None) is not None:
        spec = pipe.get_component_spec("guider")
        pipe.update_components(guider=spec.create(guidance_scale=guidance_scale))

    if compile:
        pipe.transformer.compile_repeated_blocks(fullgraph=True, dynamic=True)

    return pipe
def main(argv: Sequence[str] | None = None) -> None:
    """Command-line entry point.

    Parses arguments, builds the remote modular pipeline, applies VAE memory
    options and an optional LoRA adapter, runs generation with a seeded
    ``torch.Generator``, and saves the resulting images.

    Args:
        argv: Arguments to parse; ``None`` means use ``sys.argv``.
    """
    args = apply_window_stride_defaults(parse_args(argv))
    configure_torch(allow_tf32=args.allow_tf32)

    print("Initializing remote modular pipeline...")
    print(args)

    # These CLI attribute names match init_remote_modular_pipeline's keyword
    # parameters one-to-one, so they can be forwarded by name.
    forwarded_names = (
        "repo_id",
        "guidance_scale",
        "dtype",
        "device",
        "local_files_only",
        "compile",
        "quantization",
        "transformer_quantization",
        "text_encoder_quantization",
        "vae_quantization",
    )
    pipe = init_remote_modular_pipeline(
        **{name: getattr(args, name) for name in forwarded_names}
    )

    apply_vae_memory_options(
        pipe,
        enable_tiling=args.enable_tiling,
        enable_slicing=args.enable_slicing,
    )
    load_lora_adapter(
        pipe,
        lora_path=args.lora_path,
        terra_scale=args.terra_scale,
        num_images_per_prompt=args.num_images_per_prompt,
    )

    # A default-constructed torch.Generator lives on the CPU.
    seeded_generator = torch.Generator().manual_seed(args.seed)
    pipeline_config = getattr(pipe, "config", {})
    call_kwargs = build_call_kwargs(args, pipe, seeded_generator, pipeline_config)

    result = pipe(**call_kwargs)
    save_images(
        result.images,
        args.output,
        num_images_per_prompt=args.num_images_per_prompt,
    )
if __name__ == "__main__":
    # Script entry: argv=None lets argparse read the process's command line.
    main(argv=None)