"""Gradio demo: garment transfer via SDXL inpainting guided by an IP-Adapter.

The person image is segmented with ``SegBody.segment_body`` to obtain a body
mask; SDXL inpainting then repaints the masked region, conditioned on the text
prompt and on the garment image (fed to the pipeline as the IP-Adapter image).
"""

import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting, AutoencoderKL
from PIL import Image

from SegBody import segment_body  # Import the segmentation function

# Working resolution for the person image, garment image and mask.
IMAGE_SIZE = (512, 512)
# Strength with which the garment image steers generation (IP-Adapter scale).
IP_ADAPTER_SCALE = 1.0

# Check if CUDA is available and set the device accordingly.
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU to save memory; half precision is poorly supported on CPU, so fp32 there.
dtype = torch.float16 if device == "cuda" else torch.float32

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=dtype,
    # The fp16 weight files are downloaded on both devices; on CPU, torch_dtype
    # casts them to fp32 after loading.
    variant="fp16",
    use_safetensors=True,
).to(device)

# ``ip_adapter_image`` is only honored once an IP-Adapter (projection weights
# plus its CLIP image encoder) has been loaded into the pipeline. Without this
# step, the pipeline call in ``inpaint`` fails and the garment image is never used.
# Loaded after ``.to(device)`` so the adapter lands on the same device/dtype.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
)
pipeline.set_ip_adapter_scale(IP_ADAPTER_SCALE)


def inpaint(person_image, garment_image, prompt):
    """Repaint the person's body region so it wears the given garment.

    Args:
        person_image: PIL image of the person from the Gradio input, or None
            if nothing was uploaded.
        garment_image: PIL image of the garment, used as the IP-Adapter image,
            or None if nothing was uploaded.
        prompt: Text prompt for the inpainting model.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        gr.Error: If either input image is missing (shown to the user in the UI).
    """
    if person_image is None or garment_image is None:
        raise gr.Error("Please upload both a person image and a garment image.")

    # Normalise both inputs to RGB at the working resolution.
    person_image = person_image.convert("RGB").resize(IMAGE_SIZE)
    garment_image = garment_image.convert("RGB").resize(IMAGE_SIZE)

    # Only the mask is needed; face=False is passed through to segment_body
    # (per the original note, it controls whether the face is removed from the mask).
    _, mask_image = segment_body(person_image, face=False)
    # Ensure the mask matches the image size handed to the pipeline.
    mask_image = mask_image.resize(IMAGE_SIZE)

    results = pipeline(
        prompt=prompt,
        negative_prompt="ugly, bad quality, bad anatomy",
        image=person_image,
        mask_image=mask_image,  # Region to repaint, from segmentation
        ip_adapter_image=garment_image,  # Garment conditioning via the IP-Adapter
        strength=0.99,
        guidance_scale=8.0,
        num_inference_steps=100,
    )
    return results.images[0]


# Set up the Gradio interface.
demo = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Garment Image"),
        gr.Textbox(label="Prompt", placeholder="Enter the prompt for the model"),
    ],
    outputs=gr.Image(type="pil"),
    title="Stable Diffusion Inpainting with Segmentation",
    description="Inpainting model for seamless garment transfer on segmented body image using Stable Diffusion XL.",
)

if __name__ == "__main__":
    # Inference can take a long time (100 steps, especially on CPU). Gradio's
    # request queue is the supported way to serve long-running jobs without
    # request timeouts; the former ``server_timeout`` argument is not a
    # gr.Interface option and had no such effect.
    demo.queue().launch(share=True)  # share=True creates a public link for testing