# coconutiscoding — "Update app.py" (commit 7bd64ab, verified)
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from huggingface_hub import hf_hub_download
import cv2
# ---------------- Download and load the official LaMa weights ----------------
repo_id = "JosephCatrambone/big-lama-torchscript"
# Fetch the TorchScript LaMa checkpoint from the Hugging Face Hub (cached locally).
model_path = hf_hub_download(repo_id=repo_id, filename="lama.pt")
# TorchScript module: loaded once at import time, CPU-only, inference mode.
lama_model = torch.jit.load(model_path, map_location="cpu")
lama_model.eval()
# ---- Load the segmentation model (CPU) ----
device = torch.device("cpu")
weights = DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet50(weights=weights).to(device).eval()
# NOTE(review): `preprocess` is built here but never applied in `segment` below,
# even though these weights document an ImageNet-normalization preprocessing
# pipeline — confirm the raw [0,1] input used downstream is intentional.
preprocess = weights.transforms()
MAX_SIDE = 1024  # Cap the longest input side, for speed and memory on CPU
def _resize_if_needed(pil_img: Image.Image, max_side=MAX_SIDE) -> Image.Image:
    """Downscale *pil_img* so that its longest side is at most *max_side*.

    Images already within the limit are returned unchanged. Otherwise both
    dimensions are scaled by the same factor (bilinear resampling), so the
    aspect ratio is preserved.
    """
    width, height = pil_img.size
    longest = max(width, height)
    if longest <= max_side:
        return pil_img
    scale = max_side / float(longest)
    new_size = (int(width * scale), int(height * scale))
    return pil_img.resize(new_size, Image.BILINEAR)
# ImageNet statistics used by the DeepLabV3 COCO/VOC weights' preprocessing.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _foreground_mask(image: Image.Image) -> np.ndarray:
    """Run DeepLabV3 on *image*; return a dilated binary foreground mask.

    Returns an (H, W) uint8 array with values in {0, 1}, where 1 marks
    (dilated) foreground pixels.
    """
    x = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
    # BUGFIX: the pretrained weights expect ImageNet-normalized input
    # (see weights.transforms()); the original fed raw [0,1] pixels,
    # which degrades segmentation quality.
    x = (x - _IMAGENET_MEAN) / _IMAGENET_STD
    x = x.unsqueeze(0).to(device)  # [1,3,H,W]
    with torch.no_grad():
        out = model(x)["out"][0]  # [C,H,W], C=21
    pred = out.argmax(0).cpu().numpy()  # [H,W] class indices
    # Foreground = everything that is not class 0 (background in VOC labels).
    fg = (pred != 0).astype(np.uint8)
    # Dilate so the inpainting mask fully covers object boundaries.
    kernel = np.ones((15, 15), np.uint8)
    return cv2.dilate(fg, kernel, iterations=1)


def _overlay_foreground(image: Image.Image, fg: np.ndarray) -> Image.Image:
    """Blend a semi-transparent red mask over *image* where *fg* == 1."""
    base = image.convert("RGBA")
    overlay = Image.new("RGBA", base.size, (255, 0, 0, 0))
    overlay.putalpha(Image.fromarray((fg * 120).astype(np.uint8)))
    return Image.alpha_composite(base, overlay).convert("RGB")


def _lama_inpaint(img_np: np.ndarray, mask_np: np.ndarray) -> np.ndarray:
    """Inpaint the masked region of *img_np* (HWC uint8) with LaMa.

    *mask_np* is (H, W) uint8 with 0/255 values. LaMa requires spatial
    dimensions divisible by 8, so the inputs are reflect-padded up to the
    next multiple of 8 and the result is cropped back to (H, W).
    """
    H, W = img_np.shape[:2]
    # BUGFIX: the original cropped back to [:H, :W] but never padded, so
    # inputs whose sides are not multiples of 8 could fail or misalign.
    pad_h = (8 - H % 8) % 8
    pad_w = (8 - W % 8) % 8
    img_p = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
    mask_p = np.pad(mask_np, ((0, pad_h), (0, pad_w)), mode="reflect")
    img_t = torch.from_numpy(img_p).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    mask_t = torch.from_numpy(mask_p).unsqueeze(0).unsqueeze(0).float() / 255.0
    with torch.no_grad():
        out = lama_model(img_t, mask_t)  # [1,3,H',W'], presumably in [0,1]
    out_np = out[0].permute(1, 2, 0).numpy() * 255.0
    # BUGFIX: clip before the uint8 cast — values slightly outside [0,1]
    # would otherwise wrap around and produce speckle artifacts.
    out_np = np.clip(out_np, 0, 255).astype(np.uint8)
    return out_np[:H, :W, :]


def segment(image: Image.Image):
    """Segment the foreground of *image* and erase it with LaMa inpainting.

    Accepts a PIL image or a numpy array. Returns a 3-tuple of PIL images:
    (red-overlay visualization, binary foreground mask (white=foreground),
    LaMa inpainting result).
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    image = _resize_if_needed(image)

    fg = _foreground_mask(image)
    mask_img = Image.fromarray((fg * 255).astype(np.uint8), mode="L")
    blended = _overlay_foreground(image, fg)
    inpainted_np = _lama_inpaint(np.array(image), np.array(mask_img))
    inpainted_img = Image.fromarray(inpainted_np)
    return blended, mask_img, inpainted_img
# ---- Gradio UI ----
# Single-function interface: one image in, three images out
# (overlay visualization, binary mask, inpainted result).
demo = gr.Interface(
    fn=segment,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Overlay (foreground)"),
        gr.Image(type="pil", label="Binary Mask (foreground=white)"),
        gr.Image(type="pil", label="inpaint result"),
    ],
    title="Semantic Segmentation + LaMa Inpainting",
    description="DeepLabV3 segmentation + LaMa inpainting。",
    examples=[
        # NOTE(review): assumes this example image is shipped in the repo root
        ["./9F27E2C4-5662-4AA7-A14A-2DE6627EBE8E-14319-000010D528167C0B.PNG"]
    ]
)
if __name__ == "__main__":
    demo.launch()