File size: 5,627 Bytes
cd5aabe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import cv2
import numpy as np
import torch
from torchvision import transforms
import config
from config import logger
class RVMProcessor:
    """Background-removal (matting) processor backed by RVM (Robust Video Matting)."""

    def __init__(self, device: str = "cpu"):
        """Build the RVM model from a local repo and load local weights (no network).

        The model is constructed with ``torch.hub.load(..., source='local',
        pretrained=False)`` from ``config.RVM_LOCAL_REPO``, then weights are read
        from ``config.RVM_WEIGHTS_PATH``. Any failure is logged and leaves the
        processor unavailable (``is_available()`` returns False); it does not raise.

        :param device: torch device used for the weights and inference. Defaults to
            "cpu"; pass e.g. "cuda" when a GPU is available.
        """
        self.model = None
        self.available = False
        self.device = device
        try:
            # Local-only loading: never fall back to downloading anything.
            local_repo = getattr(config, 'RVM_LOCAL_REPO', '')
            weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '')
            if not local_repo or not os.path.isdir(local_repo):
                raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)")
            if not weights_path or not os.path.isfile(weights_path):
                raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path")
            logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}")
            # pretrained=False keeps hubconf from trying to fetch weights online.
            model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False)
            # NOTE: torch.load unpickles the file, so RVM_WEIGHTS_PATH must point
            # at a trusted weights file.
            state = torch.load(weights_path, map_location=self.device)
            # Some checkpoints wrap the weights under a 'state_dict' key.
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            missing, unexpected = model.load_state_dict(state, strict=False)
            # strict=False tolerates partial key mismatches. If *no* model key was
            # matched, though, the network would silently run with random weights.
            if len(missing) == len(model.state_dict()):
                raise RuntimeError("RVM weights file does not match the model: no parameters were loaded")
            if missing:
                logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}")
            if unexpected:
                logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... total={len(unexpected)}")
            # Move to the target device and switch to inference mode.
            self.model = model.to(self.device).eval()
            self.available = True
            logger.info("RVM background removal processor initialized successfully (local mode)")
        except Exception as e:
            logger.error(f"RVM background removal processor initialization failed: {e}")
            # Do not keep a half-initialised model around.
            self.model = None
            self.available = False

    def is_available(self) -> bool:
        """Return True when the model was loaded successfully."""
        return self.available and self.model is not None

    @staticmethod
    def _ensure_bgr(image: np.ndarray) -> np.ndarray:
        """Normalise an OpenCV image to 3 channels.

        Grayscale (H, W) or (H, W, 1) input is replicated to three channels. A
        4th (alpha) channel is dropped, which matches what ``cv2.COLOR_BGR2RGB``
        does with 4-channel input. 3-channel input is returned unchanged.

        :raises ValueError: for any other shape.
        """
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        if image.ndim != 3:
            raise ValueError(f"unsupported image shape: {image.shape}")
        channels = image.shape[2]
        if channels == 3:
            return image
        if channels == 1:
            return np.repeat(image, 3, axis=2)
        if channels == 4:
            return np.ascontiguousarray(image[:, :, :3])
        raise ValueError(f"unsupported number of channels: {channels}")

    def _infer(self, image_rgb: np.ndarray, downsample_ratio: float):
        """Run one RVM forward pass on an RGB uint8 image.

        :return: ``(fgr, pha)``, where ``fgr`` is an RGB uint8 array (H, W, 3) and
            ``pha`` is a uint8 alpha matte (H, W), both at model output resolution.
        """
        src = transforms.ToTensor()(image_rgb).unsqueeze(0).to(self.device)
        # Single still image: start with empty recurrent states and discard the
        # returned ones.
        with torch.no_grad():
            fgr, pha, *_ = self.model(src, *([None] * 4), downsample_ratio=downsample_ratio)
        fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)
        return fgr, pha

    @staticmethod
    def _composite(fgr_bgr: np.ndarray, pha: np.ndarray, background_color) -> np.ndarray:
        """Combine foreground and alpha.

        Returns BGRA (H, W, 4) when ``background_color`` is None. Otherwise returns
        BGR (H, W, 3), alpha-blended over a solid ``background_color`` (BGR).
        """
        if background_color is None:
            return np.dstack((fgr_bgr, pha))
        height, width = pha.shape[:2]
        background = np.full((height, width, 3), background_color, dtype=np.uint8)
        # (H, W, 1) broadcasts across the colour channels; no need for 3 copies.
        alpha = (pha.astype(np.float32) / 255.0)[:, :, np.newaxis]
        return (fgr_bgr * alpha + background * (1 - alpha)).astype(np.uint8)

    def remove_background(self, image: np.ndarray, background_color: tuple = None,
                          downsample_ratio: float = 0.25) -> np.ndarray:
        """Remove the background from an image with RVM.

        :param image: OpenCV image (BGR). Grayscale and BGRA inputs are also accepted.
        :param background_color: replacement background colour (BGR). If None, the
            result keeps a transparent background and is returned as BGRA.
        :param downsample_ratio: RVM internal downsampling factor in (0, 1]. Smaller
            values suit larger inputs; 0.25 keeps the previous behaviour.
        :return: BGR image (with ``background_color``) or BGRA image (without it),
            at the input's height/width.
        :raises ValueError: if ``downsample_ratio`` is outside (0, 1].
        :raises Exception: if the processor is unavailable or processing fails.
        """
        if not self.is_available():
            raise Exception("RVM抠图处理器不可用")
        if not 0 < downsample_ratio <= 1:
            raise ValueError(f"downsample_ratio must be in (0, 1], got {downsample_ratio}")
        try:
            logger.info("Starting to remove background using RVM...")
            if image is None or image.size == 0:
                raise ValueError("input image is empty")
            bgr = self._ensure_bgr(image)
            original_height, original_width = bgr.shape[:2]
            image_rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            fgr, pha = self._infer(image_rgb, downsample_ratio)
            # Guard against the model returning a different resolution.
            if fgr.shape[:2] != (original_height, original_width):
                fgr = cv2.resize(fgr, (original_width, original_height))
                pha = cv2.resize(pha, (original_width, original_height))
            fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)
            result = self._composite(fgr_bgr, pha, background_color)
            logger.info("RVM background removal completed")
            return result
        except Exception as e:
            logger.error(f"RVM background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}") from e

    def create_id_photo(self, image: np.ndarray, background_color: tuple = (255, 255, 255),
                        downsample_ratio: float = 0.25) -> np.ndarray:
        """Create an ID photo: remove the background and fill it with a solid colour.

        :param image: OpenCV image (BGR).
        :param background_color: background colour (BGR); defaults to white.
        :param downsample_ratio: forwarded to ``remove_background``.
        :return: the BGR ID photo.
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")
        id_photo = self.remove_background(image, background_color, downsample_ratio=downsample_ratio)
        logger.info("ID photo creation completed")
        return id_photo
|