Spaces:
Runtime error
Runtime error
| from diffusers.utils import load_image, make_image_grid | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler | |
| import torch | |
| from PIL import Image | |
| from Unet import UNet | |
| from torchvision import transforms | |
# Pick the compute device once: CUDA when available, otherwise CPU. The models
# and the segmentation input tensor in this file are moved to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Project-local segmentation network: 2 output classes, 3-channel (RGB) input.
# create_mask treats predicted class 0 as background and class > 0 as the human.
human_segment = UNet(n_classes=2, img_channels=3)
# map_location lets weights saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load trusted weight files.
human_segment.load_state_dict(torch.load("./unet_weights.pth", map_location=device))
human_segment.to(device)
# Inference only: switch to eval mode (affects dropout/batch-norm layers, if the
# UNet has any).
human_segment.eval()
# ControlNet checkpoint for inpainting conditioning; its control image is built
# by make_inpaint_condition (masked pixels set to -1).
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", use_safetensors=True)
# The original "runwayml/stable-diffusion-v1-5" repository has been removed from
# the Hugging Face Hub, so loading that id fails at startup. The
# community-maintained mirror below hosts the same Stable Diffusion 1.5 weights.
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, use_safetensors=True
).to(device)
# Replace the default scheduler with UniPC, reusing the pipeline's scheduler config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Preprocessing for the segmentation UNet. Resize to a fixed 256x256 without
# preserving aspect ratio (presumably the UNet's input resolution; confirm
# against training). Then convert the PIL image to a channels-first float
# tensor in [0, 1].
transform = transforms.Compose(
    [transforms.Resize(size=(256, 256)), transforms.ToTensor()]
)
def create_mask(img_path, invert=True):
    """Generate a binary mask using the custom segmentation model.

    Args:
        img_path: Path of the image file to segment.
        invert: When True (default), the mask targets the background
            (predicted class 0) instead of the human (predicted class > 0).

    Returns:
        A PIL image built from a uint8 array that is 255 inside the selected
        region and 0 elsewhere, at the resolution produced by ``transform``.
    """
    # Use a context manager so the file handle is always released; Pillow can
    # keep it open after loading (e.g. for multi-frame formats).
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        # [0] drops the batch dimension added by unsqueeze(0).
        logits = human_segment(img_tensor)[0]
    # Per-pixel predicted class. Assumes logits are laid out (classes, H, W);
    # TODO confirm against the UNet implementation.
    pred_class = torch.argmax(logits, dim=0).cpu().numpy()
    selected = (pred_class == 0) if invert else (pred_class > 0)
    mask = selected.astype(np.uint8) * 255
    return Image.fromarray(mask)
def load_and_resize_images(image, target_size=(512, 512)):
    """Load and resize an image for inpainting.

    Args:
        image: Either a file path (``str``) or an already-loaded PIL image.
        target_size: ``(width, height)`` passed to ``resize``; the aspect
            ratio is not preserved.

    Returns:
        The image converted to RGB and resized to ``target_size``.
    """
    if isinstance(image, str):
        # Close the file deterministically. convert() returns a fully loaded
        # copy, so the result stays valid after the handle is closed.
        with Image.open(image) as src:
            return src.convert("RGB").resize(target_size)
    return image.convert("RGB").resize(target_size)
def make_inpaint_condition(image, image_mask):
    """Prepare the ControlNet condition image for inpainting.

    Args:
        image: Source image. It is converted to RGB and scaled to [0, 1].
        image_mask: Mask with the same height/width as ``image``. It is
            converted to single-channel; pixels above half intensity are
            treated as masked.

    Returns:
        A float32 tensor of shape (1, 3, H, W) in which masked pixels are -1.0.

    Raises:
        ValueError: If the image and mask spatial sizes differ.
    """
    pixels = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if pixels.shape[:2] != mask.shape[:2]:
        raise ValueError(
            f"image and mask sizes differ: {pixels.shape[:2]} vs {mask.shape[:2]}"
        )
    # -1.0 lies outside the [0, 1] pixel range and marks the region to inpaint.
    pixels[mask > 0.5] = -1.0
    # HWC -> NCHW with a batch dimension of 1.
    pixels = np.expand_dims(pixels, 0).transpose(0, 3, 1, 2)
    return torch.from_numpy(pixels)
def generate_inpainting(init_image, mask_image, prompt, negative_prompt=None,
                        *, num_inference_steps=50, eta=1.0, generator=None):
    """Generate the inpainted image with the module-level ControlNet pipeline.

    Args:
        init_image: Image to inpaint.
        mask_image: Mask of the region to regenerate, same size as ``init_image``.
        prompt: Text describing what should fill the masked area.
        negative_prompt: Optional text describing what to avoid.
        num_inference_steps: Number of denoising steps (default 50, as before).
        eta: Forwarded to the pipeline. NOTE(review): diffusers only passes it
            to schedulers whose ``step`` accepts ``eta``, so the UniPC
            scheduler set at module level presumably ignores it. Verify.
        generator: Optional ``torch.Generator`` for reproducible results;
            ``None`` keeps the pipeline's default (unseeded) behavior.

    Returns:
        The first generated image from the pipeline output.
    """
    control_image = make_inpaint_condition(init_image, mask_image)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        eta=eta,
        generator=generator,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
    )
    return result.images[0]
def process_with_auto_mask(image_path, prompt, negative_prompt=None, invert_mask=True):
    """Process an input image with automatic mask generation and inpainting.

    Args:
        image_path: Path of the uploaded image, as provided by the
            ``gr.Image(type='filepath')`` input.
        prompt: Text describing what should replace the masked area.
        negative_prompt: Optional text describing what to avoid.
        invert_mask: If True (default), the background is inpainted instead of
            the human.

    Returns:
        A ``(mask_image, output_image)`` tuple for the two Gradio outputs.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None when the user submits without uploading an image.
    # Surface a readable message in the UI instead of an opaque PIL error.
    if not image_path:
        raise gr.Error("Please upload an image before running the tool.")
    mask_image = create_mask(image_path, invert=invert_mask)
    init_image = load_and_resize_images(image_path)
    # Nearest-neighbour upscaling keeps the mask strictly 0/255. A smoothing
    # filter would add grey values along the mask edges.
    mask_image = mask_image.resize(init_image.size, resample=Image.NEAREST)
    output_image = generate_inpainting(init_image, mask_image, prompt, negative_prompt)
    return mask_image, output_image
# Gradio UI: one image plus two text inputs; shows the generated mask and the
# inpainted result.
demo = gr.Interface(
    fn=process_with_auto_mask,
    inputs=[
        gr.Image(type='filepath', label="Original Image"),
        gr.Textbox(label="Prompt", placeholder="Describe what should replace the masked area..."),
        gr.Textbox(label="Negative Prompt", placeholder="Elements to avoid in the generated image...",
                   value="low quality, bad anatomy, blurry, pixelated"),
    ],
    outputs=[
        gr.Image(label="Generated Mask"),
        gr.Image(label="Inpainted Result"),
    ],
    title="Automatic Mask & Inpainting Tool",
    description="Upload an image, and the system will automatically create a mask and perform inpainting based on your prompt.",
    allow_flagging="never",
)

# Start the server only when run as a script, so importing this module (e.g.
# from tests or another app) does not launch it.
if __name__ == "__main__":
    demo.launch(share=True)