# picpocket2/test/test_rvm_infer.py
# chawin.chen
# init
# 7a6cb13
import time
import cv2
import numpy as np
import torch
from torchvision import transforms
# Single-image matting with RobustVideoMatting (RVM): read one image, run the
# model once, and save the predicted foreground plus alpha as a transparent WebP.

device = "cpu"

# Input / output paths
input_path = "/opt/data/face/yang.webp"
output_path = "/opt/data/face/output_alpha.webp"

# Load the pretrained model (resnet50 variant) through torch.hub, move it to
# the selected device and switch it to eval mode.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50").to(device).eval()

# Start timing. perf_counter() is monotonic, so the measured interval is not
# affected by system clock adjustments. Model loading above is not timed.
start = time.perf_counter()

# Read the image. cv2.imread signals failure by returning None rather than
# raising, so check explicitly instead of crashing on the slice below.
bgr = cv2.imread(input_path)
if bgr is None:
    raise OSError(f"cv2.imread failed to read or decode: {input_path}")

# BGR -> RGB. The reversed slice is a negative-stride view; copy() makes it a
# regular array before it is converted to a tensor.
img = bgr[:, :, ::-1].copy()
src = transforms.ToTensor()(img).unsqueeze(0).to(device)  # add batch dimension

# Inference. `rec` holds the model's four recurrent states; they start as None
# and the returned states are not reused because only one frame is processed.
rec = [None] * 4
with torch.no_grad():
    fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25)

# Convert to uint8 numpy arrays
fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # (H,W,3)
pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)  # (H,W)

# Stack colour and alpha into RGBA
rgba = np.dstack((fgr, pha))  # (H,W,4)

# Save as WebP with transparency. OpenCV expects BGRA channel order.
# cv2.imwrite reports failure through its boolean return value, so check it
# rather than announcing success for a file that was never written.
bgra = rgba[:, :, [2, 1, 0, 3]]
if not cv2.imwrite(output_path, bgra, [cv2.IMWRITE_WEBP_QUALITY, 100]):
    raise OSError(f"cv2.imwrite failed to write: {output_path}")

# Stop timing
elapsed = time.perf_counter() - start

# Console log output
print("✅ RVM 抠图完成 (透明背景)")
print(f" 输入文件: {input_path}")
print(f" 输出文件: {output_path}")
print(f" 耗时: {elapsed:.3f} 秒 (设备: {device})")