"""Run Flux.2 MultiDiffusion from a remote Hugging Face Modular Diffusers repo.

Command-line wrapper that does four things:

1. Parses arguments.
2. Loads a ``ModularPipeline`` from a Hub repo with ``trust_remote_code=True``.
3. Runs the pipeline once.
4. Saves the resulting images using the shared helpers in ``examples.example``.
"""

from __future__ import annotations

import argparse
from collections.abc import Sequence
from pathlib import Path

import torch
from diffusers import ModularPipeline

from examples.example import (
    DTYPE_MAP,
    add_canvas_arguments,
    add_inference_arguments,
    add_input_arguments,
    add_runtime_arguments,
    apply_vae_memory_options,
    apply_window_stride_defaults,
    build_call_kwargs,
    build_components_manager,
    build_quantization_config,
    configure_torch,
    load_lora_adapter,
    save_images,
)


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser.

    Registers a ``model`` group (``--repo-id``, ``--lora-path``). It then lets
    the shared helpers add the input, canvas, inference and runtime options,
    in that order.
    """
    parser = argparse.ArgumentParser(
        description="Run Flux.2 MultiDiffusion from a remote Hugging Face Modular Diffusers repo."
    )

    model_group = parser.add_argument_group("model")
    model_group.add_argument(
        "--repo-id",
        required=True,
        help="Hub repo containing the exported Modular Diffusers files and remote block.py code.",
    )
    # Both spellings are accepted. argparse derives dest="lora_path" from the
    # first long option, which is the attribute main() reads.
    model_group.add_argument("--lora-path", "--lora_path", type=Path)

    # Shared option groups from examples.example, registered in a fixed order.
    for register in (
        add_input_arguments,
        add_canvas_arguments,
        add_inference_arguments,
        add_runtime_arguments,
    ):
        register(parser)

    return parser


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse ``argv`` into a namespace.

    When ``argv`` is ``None``, argparse falls back to ``sys.argv[1:]``.
    """
    parser = build_parser()
    return parser.parse_args(argv)


def init_remote_modular_pipeline(
    *,
    repo_id: str,
    guidance_scale: float | None = None,
    dtype: str,
    device: str,
    local_files_only: bool,
    compile: bool,
    quantization: str | None = None,
    transformer_quantization: str | None = None,
    text_encoder_quantization: str | None = None,
    vae_quantization: str | None = None,
) -> ModularPipeline:
    """Load and configure a ``ModularPipeline`` whose block code lives on the Hub.

    Args:
        repo_id: Hub repo containing the Modular Diffusers export. Its remote
            code is loaded with ``trust_remote_code=True``.
        guidance_scale: When not ``None`` and the pipeline has a non-``None``
            ``guider``, the guider is recreated from its component spec with
            this scale.
        dtype: Key into ``DTYPE_MAP`` that selects the torch dtype for the
            loaded components.
        device: Torch device string. Also passed to ``build_components_manager``.
        local_files_only: Forwarded to both ``from_pretrained`` and
            ``load_components``.
        compile: When true, calls
            ``transformer.compile_repeated_blocks(fullgraph=True, dynamic=True)``.
            The name shadows the builtin; it is kept for keyword compatibility.
        quantization: Forwarded to ``build_quantization_config``.
        transformer_quantization: Forwarded to ``build_quantization_config``.
        text_encoder_quantization: Forwarded to ``build_quantization_config``.
        vae_quantization: Forwarded to ``build_quantization_config``.

    Returns:
        The configured pipeline.

    Raises:
        KeyError: If ``dtype`` is not a key of ``DTYPE_MAP``.
    """
    target = torch.device(device)
    manager, use_auto_offload = build_components_manager(target)

    pipe = ModularPipeline.from_pretrained(
        repo_id,
        trust_remote_code=True,
        components_manager=manager,
        local_files_only=local_files_only,
    )

    torch_dtype = DTYPE_MAP[dtype]
    quant_config = build_quantization_config(
        quantization,
        transformer_quantization=transformer_quantization,
        text_encoder_quantization=text_encoder_quantization,
        vae_quantization=vae_quantization,
    )
    pipe.load_components(
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        quantization_config=quant_config,
    )

    # Explicit placement is skipped on CPU and whenever the components manager
    # reports auto offload. Presumably the manager then handles devices
    # itself — confirm in build_components_manager.
    manual_placement = not (use_auto_offload or target.type == "cpu")
    if manual_placement:
        pipe.to(target)

    if guidance_scale is not None and getattr(pipe, "guider", None) is not None:
        spec = pipe.get_component_spec("guider")
        pipe.update_components(guider=spec.create(guidance_scale=guidance_scale))

    if compile:
        pipe.transformer.compile_repeated_blocks(fullgraph=True, dynamic=True)

    return pipe


def main(argv: Sequence[str] | None = None) -> None:
    """Build the pipeline from CLI options, run one generation and save the images."""
    opts = apply_window_stride_defaults(parse_args(argv))
    configure_torch(allow_tf32=opts.allow_tf32)

    print("Initializing remote modular pipeline...")
    print(opts)
    pipe = init_remote_modular_pipeline(
        repo_id=opts.repo_id,
        guidance_scale=opts.guidance_scale,
        dtype=opts.dtype,
        device=opts.device,
        local_files_only=opts.local_files_only,
        compile=opts.compile,
        quantization=opts.quantization,
        transformer_quantization=opts.transformer_quantization,
        text_encoder_quantization=opts.text_encoder_quantization,
        vae_quantization=opts.vae_quantization,
    )

    apply_vae_memory_options(
        pipe,
        enable_tiling=opts.enable_tiling,
        enable_slicing=opts.enable_slicing,
    )
    load_lora_adapter(
        pipe,
        lora_path=opts.lora_path,
        terra_scale=opts.terra_scale,
        num_images_per_prompt=opts.num_images_per_prompt,
    )

    # torch.Generator() with no device argument is a CPU generator, seeded from --seed.
    generator = torch.Generator().manual_seed(opts.seed)
    pipe_config = getattr(pipe, "config", {})
    call_kwargs = build_call_kwargs(opts, pipe, generator, pipe_config)

    result = pipe(**call_kwargs)
    save_images(
        result.images,
        opts.output,
        num_images_per_prompt=opts.num_images_per_prompt,
    )


if __name__ == "__main__":
    main()