import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from huggingface_hub import hf_hub_download
import cv2

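# Download the LaMa inpainting model (TorchScript export) from the Hugging Face Hub and load it on CPU.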
repo_id = "JosephCatrambone/big-lama-torchscript"
model_path = hf_hub_download(repo_id=repo_id, filename="lama.pt")
lama_model = torch.jit.load(model_path, map_location="cpu")
lama_model.eval()

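# Load DeepLabV3 (ResNet-50) with pretrained COCO/VOC weights for semantic segmentation.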
device = torch.device("cpu")
weights = DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet50(weights=weights).to(device).eval()
preprocess = weights.transforms()  # reference transforms for these weights (not applied; the input is scaled manually below)

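# Upper bound on the longest image side; larger inputs are downscaled before inference.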
MAX_SIDE = 1024


def _resize_if_needed(pil_img: Image.Image, max_side=MAX_SIDE) -> Image.Image:
    """Downscale the image so its longest side does not exceed max_side."""
    w, h = pil_img.size
    if max(w, h) <= max_side:
        return pil_img
    r = max_side / float(max(w, h))
    return pil_img.resize((int(w * r), int(h * r)), Image.BILINEAR)


def segment(image: Image.Image):
    # Accept either a PIL image or a NumPy array.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    image = image.convert("RGB")
    image = _resize_if_needed(image)

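    # Segmentation forward pass: scale to [0, 1], add a batch dimension, and take the per-pixel argmax.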
    x = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
    x = x.unsqueeze(0).to(device)

    with torch.no_grad():
        out = model(x)["out"][0]
    pred = out.argmax(0).cpu().numpy()

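    # Treat every non-background class (index 0 is background) as foreground.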
    fg = (pred != 0).astype(np.uint8)

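    # Dilate the mask so the region to inpaint extends slightly past object boundaries.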
    kernel = np.ones((15, 15), np.uint8)
    fg_dilated = cv2.dilate(fg, kernel, iterations=1)

    mask_img = Image.fromarray((fg_dilated * 255).astype(np.uint8), mode="L")

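    # Build a semi-transparent red overlay of the dilated foreground for visualization.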
    base = image.convert("RGBA")
    overlay = Image.new("RGBA", base.size, (255, 0, 0, 0))
    alpha = Image.fromarray((fg_dilated * 120).astype(np.uint8))
    overlay.putalpha(alpha)
    blended = Image.alpha_composite(base, overlay).convert("RGB")

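    # Prepare LaMa inputs: image as [1, 3, H, W] and mask as [1, 1, H, W], both scaled to [0, 1].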
    img_np = np.array(image)
    mask_np = np.array(mask_img)
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    mask_t = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).float() / 255.0
    with torch.no_grad():
        inpainted_t = lama_model(img_t, mask_t)
    inpainted_np = (inpainted_t[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

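    # Crop back to the original size in case the model padded the input internally.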
    H, W = img_np.shape[:2]
    inpainted_np = inpainted_np[:H, :W, :]
    inpainted_img = Image.fromarray(inpainted_np)

    return blended, mask_img, inpainted_img


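# Gradio interface: one input image, three outputs (overlay, binary mask, inpainted result).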
demo = gr.Interface(
    fn=segment,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Overlay (foreground)"),
        gr.Image(type="pil", label="Binary Mask (foreground=white)"),
        gr.Image(type="pil", label="Inpainting Result"),
    ],
    title="Semantic Segmentation + LaMa Inpainting",
    description="DeepLabV3 foreground segmentation followed by LaMa inpainting.",
    examples=[
        ["./9F27E2C4-5662-4AA7-A14A-2DE6627EBE8E-14319-000010D528167C0B.PNG"]
    ],
)


if __name__ == "__main__":
    demo.launch()