Spaces:
Sleeping
Sleeping
import argparse
import gc
import os

import torch
from diffusers.utils import load_image, check_min_version

from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
def _configure_torch_runtime():
    """Set process-wide PyTorch flags and the CUDA allocator config for inference."""
    # PYTORCH_CUDA_ALLOC_CONF is parsed when the CUDA caching allocator
    # initialises, so it must be set before anything touches the GPU.
    # setdefault keeps an allocator config the user exported themselves.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")
    # TF32 matmuls/convolutions trade a little precision for speed.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()


def _load_rgb(source):
    """Return ``source`` as an RGB image, loading it first if it is a path/URL."""
    if isinstance(source, (str, os.PathLike)):
        # load_image only recognises ``str`` paths/URLs, so normalise PathLike.
        source = load_image(os.fspath(source))
    return source.convert("RGB")


def _build_pipeline():
    """Load the inpainting ControlNet, FLUX.1-dev transformer and pipeline on CUDA (bfloat16)."""
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # Compute attention in slices to lower peak memory.
    pipe.enable_attention_slicing(1)
    return pipe


def main(image, mask, prompt, *, size=(384, 384), seed=24, num_inference_steps=28):
    """Inpaint the masked region of ``image`` according to ``prompt``.

    Args:
        image: Input image, either an image object with ``.convert`` (e.g. a
            PIL image) or a path/URL that ``diffusers.utils.load_image`` accepts.
        mask: Inpainting mask, in the same accepted forms as ``image``.
        prompt: Text prompt describing what to paint into the masked region.
        size: ``(width, height)`` both inputs are resized to; also passed as
            the output width/height. NOTE(review): FLUX pipelines presumably
            need sides divisible by 16 — confirm before using other sizes.
        seed: Seed for the CUDA random generator (fixed for reproducibility).
        num_inference_steps: Number of denoising steps.

    Returns:
        The first entry of the pipeline's ``.images`` output (presumably a PIL
        image — the CLI calls ``.save`` on it).
    """
    check_min_version("0.30.2")
    _configure_torch_runtime()

    # Load and resize the inputs before the (slow, multi-GB) model load so a
    # bad path fails fast.
    image = _load_rgb(image).resize(size)
    mask = _load_rgb(mask).resize(size)

    pipe = _build_pipeline()
    try:
        generator = torch.Generator(device="cuda").manual_seed(seed)
        # Plain torch.cuda.amp.autocast() (deprecated) defaults to float16 on
        # CUDA, which re-casts the bfloat16 weights to the narrower float16
        # range; pin autocast to the weights' dtype instead.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            result = pipe(
                prompt=prompt,
                height=size[1],
                width=size[0],
                control_image=image,
                control_mask=mask,
                num_inference_steps=num_inference_steps,
                generator=generator,
                controlnet_conditioning_scale=0.9,
                guidance_scale=3.5,
                negative_prompt="",
                true_guidance_scale=1.0,
            ).images[0]
    finally:
        # empty_cache() can only return blocks no tensor still references, so
        # drop the pipeline (and the models it holds) first. Runs on failure
        # too, e.g. after an out-of-memory error.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()
    print("Successfully inpaint image")
    return result
if __name__ == "__main__":
    # Command-line entry point: read the image, mask and prompt, inpaint,
    # and write the result to disk.
    parser = argparse.ArgumentParser(
        description="Inpaint an image using FluxControlNetInpaintingPipeline."
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the input image."
    )
    parser.add_argument(
        "--mask_path", type=str, required=True, help="Path to the mask image."
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the inpainting process."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Where to save the inpainted image (default: output.png).",
    )
    args = parser.parse_args()
    # main() calls .convert() on its image inputs, so hand it loaded images
    # rather than the raw path strings.
    image = load_image(args.image_path)
    mask = load_image(args.mask_path)
    result = main(image, mask, args.prompt)
    result.save(args.output_path)