import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from huggingface_hub import hf_hub_download
import cv2

# ---------------- Download and load the official LaMa weights ----------------
repo_id = "JosephCatrambone/big-lama-torchscript"
model_path = hf_hub_download(repo_id=repo_id, filename="lama.pt")
lama_model = torch.jit.load(model_path, map_location="cpu")
lama_model.eval()

# ---------------- Load the segmentation model (CPU) ----------------
device = torch.device("cpu")
weights = DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet50(weights=weights).to(device).eval()
preprocess = weights.transforms()

MAX_SIDE = 1024  # cap the longest image side, for speed and memory


def _resize_if_needed(pil_img: Image.Image, max_side: int = MAX_SIDE) -> Image.Image:
    """Downscale *pil_img* so its longest side is at most *max_side* (aspect kept)."""
    w, h = pil_img.size
    if max(w, h) <= max_side:
        return pil_img
    r = max_side / float(max(w, h))
    return pil_img.resize((int(w * r), int(h * r)), Image.BILINEAR)


def _pad_to_multiple(arr: np.ndarray, mult: int = 8) -> np.ndarray:
    """Reflect-pad the H/W axes of *arr* (HW or HWC) up to the next multiple of *mult*.

    NOTE(review): the TorchScript Big-LaMa net downsamples by 8, so it expects
    H and W divisible by 8; the original code already cropped the output back
    to (H, W), which only makes sense if the input was meant to be padded.
    """
    h, w = arr.shape[:2]
    ph = (mult - h % mult) % mult
    pw = (mult - w % mult) % mult
    if ph == 0 and pw == 0:
        return arr
    pad = ((0, ph), (0, pw)) + ((0, 0),) * (arr.ndim - 2)
    return np.pad(arr, pad, mode="reflect")


def segment(image: Image.Image):
    """Segment the foreground with DeepLabV3, then remove it with LaMa inpainting.

    Args:
        image: input image (PIL Image or ndarray from Gradio).

    Returns:
        (blended, mask_img, inpainted_img): red-overlay visualization,
        binary foreground mask (L mode, foreground=255), and the inpainted
        result — all PIL images at the (possibly downscaled) working size.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    image = _resize_if_needed(image)

    # ---- segmentation inference ----
    x = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
    x = x.unsqueeze(0).to(device)  # [1,3,H,W]
    with torch.no_grad():
        out = model(x)["out"][0]  # [C,H,W], C=21
    pred = out.argmax(0).cpu().numpy()  # [H,W]

    # Foreground = everything that is not background (class 0 under the
    # COCO_WITH_VOC_LABELS weights).
    fg = (pred != 0).astype(np.uint8)

    # ---- mask: dilate so inpainting also covers object borders ----
    kernel = np.ones((15, 15), np.uint8)
    fg_dilated = cv2.dilate(fg, kernel, iterations=1)
    mask_img = Image.fromarray((fg_dilated * 255).astype(np.uint8), mode="L")

    # ---- semi-transparent red overlay for visualization ----
    base = image.convert("RGBA")
    overlay = Image.new("RGBA", base.size, (255, 0, 0, 0))
    alpha = Image.fromarray((fg_dilated * 120).astype(np.uint8))
    overlay.putalpha(alpha)
    blended = Image.alpha_composite(base, overlay).convert("RGB")

    # ---- LaMa inpainting ----
    img_np = np.array(image)      # HWC, uint8
    mask_np = np.array(mask_img)  # HW, values 0/255
    H, W = img_np.shape[:2]

    # Fix: pad both inputs to a multiple of 8 before calling the model;
    # the output is cropped back to (H, W) below.
    img_pad = _pad_to_multiple(img_np)
    mask_pad = _pad_to_multiple(mask_np)

    img_t = torch.from_numpy(img_pad).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    mask_t = torch.from_numpy(mask_pad).unsqueeze(0).unsqueeze(0).float() / 255.0
    with torch.no_grad():
        inpainted_t = lama_model(img_t, mask_t)  # [1,3,H,W]

    # Fix: clamp to [0,1] before the uint8 cast — values that overshoot would
    # otherwise wrap around (e.g. 1.01 * 255 -> 257 -> 1) and leave speckles.
    inpainted_np = (
        inpainted_t[0].clamp(0.0, 1.0).permute(1, 2, 0).numpy() * 255
    ).astype(np.uint8)

    # ---- crop back to the pre-padding size ----
    inpainted_np = inpainted_np[:H, :W, :]
    inpainted_img = Image.fromarray(inpainted_np)

    return blended, mask_img, inpainted_img


# ---- Gradio UI ----
demo = gr.Interface(
    fn=segment,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Overlay (foreground)"),
        gr.Image(type="pil", label="Binary Mask (foreground=white)"),
        gr.Image(type="pil", label="inpaint result"),
    ],
    title="Semantic Segmentation + LaMa Inpainting",
    description="DeepLabV3 segmentation + LaMa inpainting。",
    examples=[
        ["./9F27E2C4-5662-4AA7-A14A-2DE6627EBE8E-14319-000010D528167C0B.PNG"]
    ],
)

if __name__ == "__main__":
    demo.launch()