|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
from diffusers import AutoPipelineForInpainting, AutoencoderKL |
|
|
import torch |
|
|
from SegBody import segment_body |
|
|
|
|
|
|
|
|
# Run on the GPU when one is available. Half precision is used only on CUDA;
# the CPU path falls back to float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# SDXL VAE checkpoint from the "sdxl-vae-fp16-fix" repo, loaded in the chosen dtype.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)

# SDXL inpainting pipeline. The fp16 weight variant is requested on both
# devices; torch_dtype controls the precision the weights are loaded into.
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
).to(device)

# inpaint() passes the garment picture as ``ip_adapter_image``. The pipeline
# can only consume that argument once IP-Adapter weights are attached;
# without this, every request fails while preparing the image embeddings.
# Loaded after .to(device) so the adapter's image encoder follows the
# pipeline's device/dtype.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
)
# Full strength conditioning on the garment image.
pipeline.set_ip_adapter_scale(1.0)
|
|
|
|
|
|
|
|
def inpaint(person_image, garment_image, prompt):
    """Inpaint the body region of ``person_image`` guided by ``garment_image``.

    Args:
        person_image: PIL image of the person (first ``gr.Image`` input).
        garment_image: PIL image of the garment; passed to the pipeline as
            the IP-Adapter reference image.
        prompt: Text prompt for the inpainting model.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        gr.Error: If either image input is missing.
    """
    # gr.Image(type="pil") delivers None when nothing was uploaded; report
    # that to the user instead of failing with AttributeError on .convert().
    if person_image is None or garment_image is None:
        raise gr.Error("Please upload both a person image and a garment image.")

    size = (512, 512)
    person_image = person_image.convert("RGB").resize(size)
    garment_image = garment_image.convert("RGB").resize(size)

    # segment_body returns (segmented image, mask); only the mask is used.
    # NOTE(review): face=False presumably keeps the face out of the mask so
    # it is preserved — confirm against SegBody.
    _, mask_image = segment_body(person_image, face=False)
    # Make sure the mask matches the resized person image.
    mask_image = mask_image.resize(size)

    results = pipeline(
        prompt=prompt,
        negative_prompt="ugly, bad quality, bad anatomy",
        image=person_image,
        mask_image=mask_image,
        ip_adapter_image=garment_image,
        strength=0.99,
        guidance_scale=8.0,
        num_inference_steps=100,
    )

    return results.images[0]
|
|
|
|
|
|
|
|
# Gradio UI: two image uploads plus a prompt box, wired to inpaint().
# (No server_timeout argument: it is not a gr.Interface parameter and newer
# Gradio versions reject unknown keyword arguments.)
demo = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Garment Image"),
        gr.Textbox(label="Prompt", placeholder="Enter the prompt for the model"),
    ],
    outputs=gr.Image(type="pil"),
    title="Stable Diffusion Inpainting with Segmentation",
    description="Inpainting model for seamless garment transfer on segmented body image using Stable Diffusion XL.",
)

# Only start the server when run as a script, so importing this module
# (e.g. for reload mode or tests) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)
|
|
|