# Spaces: Runtime error
# File size: 4,192 Bytes
from diffusers.utils import load_image, make_image_grid
import gradio as gr
import cv2
import numpy as np
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
import torch
from PIL import Image
from Unet import UNet
from torchvision import transforms
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Two-class segmentation network over 3-channel images; create_mask() uses it
# to separate the human from the background.
human_segment = UNet(n_classes=2, img_channels=3)
# map_location places the checkpoint tensors on the selected device.
human_segment.load_state_dict(torch.load("./unet_weights.pth", map_location=device))
human_segment.to(device)
# Inference only: switch the model to evaluation mode.
human_segment.eval()
# ControlNet checkpoint that conditions the Stable Diffusion inpainting pipeline.
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", use_safetensors=True)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, use_safetensors=True
).to(device)
# Replace the pipeline's default scheduler with UniPC, built from the existing
# scheduler configuration.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Preprocessing applied to images before segmentation: resize to 256x256,
# then convert to a tensor.
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
def create_mask(img_path, invert=True):
    """Segment the image at ``img_path`` and return a black/white PIL mask.

    The image is preprocessed with the module-level ``transform`` and passed
    through ``human_segment``; the class with the highest score is taken per
    pixel. With ``invert=True`` the pixels of class 0 are painted white (255),
    so the mask targets the background; with ``invert=False`` the pixels of
    any class above 0 are white, so the mask targets the human.
    """
    source = Image.open(img_path).convert("RGB")
    batch = transform(source).unsqueeze(0).to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = human_segment(batch)[0]

    class_map = torch.argmax(scores, dim=0).cpu().numpy()
    selected = (class_map == 0) if invert else (class_map > 0)
    return Image.fromarray(selected.astype(np.uint8) * 255)
def load_and_resize_images(image, target_size=(512, 512)):
    """Return ``image`` as an RGB PIL image resized to ``target_size``.

    ``image`` may be a file path (``str``), which is opened with PIL, or an
    object that already provides ``convert`` (such as a PIL image).
    """
    source = Image.open(image) if isinstance(image, str) else image
    return source.convert("RGB").resize(target_size)
def make_inpaint_condition(image, image_mask):
    """Build the conditioning tensor used for inpainting.

    ``image`` is converted to RGB and ``image_mask`` to single-channel "L";
    both are scaled by 1/255 as float32. Pixels whose mask value exceeds 0.5
    are overwritten with -1.0 in every channel. The result is returned as a
    torch tensor laid out as (1, channels, height, width).

    Raises ``AssertionError`` when image and mask differ in height/width.
    """
    pixels = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    assert pixels.shape[:2] == mask.shape[:2]

    # Flag the region to be repainted with the sentinel value -1.0.
    to_repaint = mask > 0.5
    pixels[to_repaint] = -1.0

    # Add a leading batch axis, then move channels ahead of height/width.
    batched = np.expand_dims(pixels, 0)
    return torch.from_numpy(batched.transpose(0, 3, 1, 2))
def generate_inpainting(init_image, mask_image, prompt, negative_prompt=None):
    """Run the module-level inpainting pipeline and return its first image.

    The control image is derived from ``init_image`` and ``mask_image`` via
    ``make_inpaint_condition``; the pipeline is invoked with 50 inference
    steps and ``eta=1.0``.
    """
    pipeline_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": 50,
        "eta": 1.0,
        "image": init_image,
        "mask_image": mask_image,
        "control_image": make_inpaint_condition(init_image, mask_image),
    }
    result = pipe(**pipeline_kwargs)
    return result.images[0]
def process_with_auto_mask(image_path, prompt, negative_prompt=None, invert_mask=True):
    """Create a segmentation mask for an image and inpaint the masked area.

    Args:
        image_path: Path of the uploaded image file.
        prompt: Text describing what should fill the masked area.
        negative_prompt: Optional text describing what to avoid.
        invert_mask: If True, the background is inpainted instead of the
            human.

    Returns:
        A ``(mask_image, output_image)`` tuple: the mask resized to the
        working resolution, and the inpainted result.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without an upload the image input is empty; fail with a readable message
    # instead of a cryptic error from deep inside PIL.
    if not image_path:
        raise gr.Error("Please upload an image before running inpainting.")

    mask_image = create_mask(image_path, invert=invert_mask)
    init_image = load_and_resize_images(image_path)
    # The mask comes out at the segmentation resolution; bring it to the same
    # size as the image handed to the pipeline.
    mask_image = mask_image.resize(init_image.size)
    output_image = generate_inpainting(init_image, mask_image, prompt, negative_prompt)
    return mask_image, output_image
# Input widgets, listed in the positional order process_with_auto_mask expects.
_image_input = gr.Image(type='filepath', label="Original Image")
_prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="Describe what should replace the masked area...",
)
_negative_prompt_input = gr.Textbox(
    label="Negative Prompt",
    placeholder="Elements to avoid in the generated image...",
    value="low quality, bad anatomy, blurry, pixelated",
)

# Output widgets, matching the (mask, result) tuple returned by the handler.
_mask_output = gr.Image(label="Generated Mask")
_result_output = gr.Image(label="Inpainted Result")

demo = gr.Interface(
    fn=process_with_auto_mask,
    inputs=[_image_input, _prompt_input, _negative_prompt_input],
    outputs=[_mask_output, _result_output],
    title="Automatic Mask & Inpainting Tool",
    description="Upload an image, and the system will automatically create a mask and perform inpainting based on your prompt.",
    allow_flagging="never",
)

# Start the web UI and request a public share link.
demo.launch(share=True)
# End of app script.