Spaces:
Running
Running
| import gradio as gr | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import torch | |
| from torchvision import transforms | |
| from diffusers import AutoencoderKL, LCMScheduler | |
| from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline | |
| from controlnet import ControlNetModel | |
# Define helper functions
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout ``requests`` can block indefinitely.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an error status instead of handing an error page to PIL,
    # which would otherwise surface as an unhelpful "cannot identify image".
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def load_model(device="cuda"):
    """Assemble the BRIA-2.3 SDXL ControlNet pipeline in float16.

    Loads three pieces and combines them:
      * ControlNet weights from ``briaai/DEV-ControlNetInpaintingFast``;
      * the VAE from ``madebyollin/sdxl-vae-fp16-fix``;
      * the ``briaai/BRIA-2.3`` base pipeline, with the two above plugged in.

    Args:
        device: Device the finished pipeline is moved to. Defaults to
            ``"cuda"``.

    Returns:
        The ``StableDiffusionXLControlNetPipeline`` placed on ``device``.
    """
    # Call ``from_pretrained`` on the class itself. Calling it on a fresh
    # ``ControlNetModel()`` instance first builds a full default-configured
    # model that is immediately discarded.
    controlnet = ControlNetModel.from_pretrained(
        "briaai/DEV-ControlNetInpaintingFast", torch_dtype=torch.float16
    )
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "briaai/BRIA-2.3",
        # Explicit cast kept as a guarantee; it is a no-op if the weights
        # were already loaded as float16.
        controlnet=controlnet.to(dtype=torch.float16),
        torch_dtype=torch.float16,
        vae=vae,
    )
    # NOTE(review): LCMScheduler is imported at file level but never set up
    # here; confirm whether the "Fast" ControlNet expects it.
    pipe.to(device)
    return pipe


# Load once at import time so every request reuses the same pipeline.
pipe = load_model()
# Define the inpainting function
def inpaint(image, mask, prompt="A park bench", num_inference_steps=50):
    """Inpaint the masked area of ``image`` using the module-level ``pipe``.

    Args:
        image: Source picture as a PIL image. It is resized to 1024x1024 and
            converted to RGB.
        mask: Mask picture as a PIL image. It is resized to 1024x1024,
            converted to grayscale and binarised: pixels brighter than half
            intensity become 1, all others 0.
        prompt: Text prompt passed to the pipeline. Defaults to
            ``"A park bench"``.
        num_inference_steps: Number of denoising steps. Defaults to 50.

    Returns:
        The first generated image as a PIL image.

    Raises:
        gr.Error: If either ``image`` or ``mask`` is missing (``None``).
    """
    if image is None or mask is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide both an image and a mask.")

    # Process image and mask
    image = image.resize((1024, 1024)).convert("RGB")
    mask = mask.resize((1024, 1024)).convert("L")

    # Transform to tensors with a leading batch dimension.
    to_tensor = transforms.ToTensor()
    image_tensor = to_tensor(image).unsqueeze(0).to("cuda")
    mask_tensor = to_tensor(mask).unsqueeze(0).to("cuda")
    mask_tensor = (mask_tensor > 0.5).float()  # binarize mask

    # NOTE(review): no ControlNet conditioning ``image=`` argument is passed;
    # confirm that the custom pipeline accepts this call as intended.
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            init_image=image_tensor,
            mask_image=mask_tensor,
            num_inference_steps=num_inference_steps,
        ).images[0]

    # With diffusers' default ``output_type="pil"`` the result is already a
    # PIL image, which has no ``.squeeze``. Only convert when a tensor is
    # returned.
    if isinstance(result, torch.Tensor):
        result = transforms.ToPILImage()(result.squeeze(0).cpu())
    return result
# Define the interface.
# ``gr.inputs`` / ``gr.outputs`` were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level ``gr.Image`` component works in both versions.
interface = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Mask Image"),
    ],
    outputs=gr.Image(type="pil", label="Inpainted Image"),
    title="Stable Diffusion XL ControlNet Inpainting",
    description="Upload an image and its corresponding mask to inpaint the specified area.",
)

if __name__ == "__main__":
    interface.launch()