"""Remove the background from a single image with Robust Video Matting (RVM).

Loads the pretrained ``resnet50`` RVM model through ``torch.hub``, runs it on
one image, and writes the predicted foreground plus the predicted alpha matte
as a 4-channel WebP file. Progress information is printed to the console.
"""

import time

import cv2
import numpy as np
import torch
from torchvision import transforms

device = "cpu"

# Input and output paths.
input_path = "/opt/data/face/yang.webp"
output_path = "/opt/data/face/output_alpha.webp"

# Load the pretrained model (resnet50 variant) and switch it to eval mode.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50").to(device).eval()

# Start timing. Model loading above is not part of the measured interval.
# perf_counter() is monotonic, so the elapsed value cannot be skewed by
# wall-clock adjustments.
start = time.perf_counter()

# Read the image. cv2.imread signals failure by returning None instead of
# raising, so check explicitly to fail with a message that names the file.
img_bgr = cv2.imread(input_path)
if img_bgr is None:
    raise OSError(f"Failed to read image: {input_path}")

# OpenCV decodes to BGR; the tensor fed to the model is built from RGB.
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
src = transforms.ToTensor()(img).unsqueeze(0).to(device)

# Inference. The four recurrent-state inputs start out as None.
rec = [None] * 4
with torch.no_grad():
    fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25)

# Convert to uint8 numpy arrays. Round to nearest (plain astype truncates,
# which biases every value downward) and clip so that an out-of-range value
# cannot wrap around when cast to uint8.
fgr = np.rint(fgr[0].permute(1, 2, 0).cpu().numpy() * 255)
fgr = fgr.clip(0, 255).astype(np.uint8)  # (H,W,3)
pha = np.rint(pha[0, 0].cpu().numpy() * 255)
pha = pha.clip(0, 255).astype(np.uint8)  # (H,W)

# Stack colour and alpha into one RGBA image.
rgba = np.dstack((fgr, pha))  # (H,W,4)

# Save as WebP with transparency. OpenCV expects BGRA channel order, and
# cv2.imwrite reports failure through its return value rather than raising.
bgra = cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)
if not cv2.imwrite(output_path, bgra, [cv2.IMWRITE_WEBP_QUALITY, 100]):
    raise OSError(f"Failed to write image: {output_path}")

# Stop timing.
elapsed = time.perf_counter() - start

# Console log output.
print("✅ RVM 抠图完成 (透明背景)")
print(f" 输入文件: {input_path}")
print(f" 输出文件: {output_path}")
print(f" 耗时: {elapsed:.3f} 秒 (设备: {device})")