Spaces:
Paused
Paused
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torchvision import transforms | |
| import config | |
| from config import logger | |
class RVMProcessor:
    """Background remover ("matting") built on RVM (Robust Video Matting).

    The network is constructed from a *local* RobustVideoMatting checkout via
    ``torch.hub.load(..., source='local', pretrained=False)``. Its weights are
    read from a local file, so no network access is performed.

    Configuration is read from the project ``config`` module:
      * ``RVM_LOCAL_REPO``   - directory of the RVM repo (must contain hubconf.py)
      * ``RVM_WEIGHTS_PATH`` - path to the weights file
      * ``RVM_MODEL``        - hub entry-point name passed to ``torch.hub.load``

    Construction never raises. On any failure the error is logged and the
    instance is left unusable (``is_available()`` returns False).
    """

    def __init__(self, device: str = "cpu"):
        """Build the model and load the local weights.

        :param device: torch device used for inference. Defaults to "cpu";
            pass e.g. "cuda" to run on a GPU.
        """
        self.model = None
        self.available = False
        self.device = device
        try:
            # Load strictly from local files; never touch the network.
            local_repo = getattr(config, 'RVM_LOCAL_REPO', '')
            weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '')
            if not local_repo or not os.path.isdir(local_repo):
                raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)")
            if not weights_path or not os.path.isfile(weights_path):
                raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path")

            logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}")
            # pretrained=False: build the architecture only, so the hub entry
            # point does not try to download weights.
            model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False)

            # NOTE: torch.load unpickles the file; only point RVM_WEIGHTS_PATH
            # at trusted files.
            state = torch.load(weights_path, map_location=self.device)
            # Some checkpoints wrap the weights in a {'state_dict': ...} dict.
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            missing, unexpected = model.load_state_dict(state, strict=False)

            # strict=False tolerates small mismatches. A checkpoint that matches
            # no key at all would silently leave random weights, so refuse it.
            expected_keys = model.state_dict().keys()
            if expected_keys and len(missing) == len(expected_keys):
                raise RuntimeError("RVM weights file does not match the model: none of the model's parameters were loaded")
            if missing:
                logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}")
            if unexpected:
                logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... total={len(unexpected)}")

            # Move to the target device and switch to inference mode.
            self.model = model.to(self.device).eval()
            self.available = True
            logger.info("RVM background removal processor initialized successfully (local mode)")
        except Exception as e:
            logger.error(f"RVM background removal processor initialization failed: {e}")
            # Drop any partially built model so it does not linger in memory.
            self.model = None
            self.available = False

    def is_available(self) -> bool:
        """Return True when the model was loaded successfully and can be used."""
        return self.available and self.model is not None

    def remove_background(self, image: np.ndarray, background_color: tuple = None,
                          downsample_ratio: float = 0.25) -> np.ndarray:
        """Remove the background of an image with RVM.

        :param image: OpenCV image. It may be BGR (H, W, 3), BGRA (H, W, 4;
            the alpha channel is ignored) or grayscale (H, W) / (H, W, 1).
        :param background_color: BGR color to composite onto. If None, the
            result keeps a transparent background (BGRA, alpha = matte).
        :param downsample_ratio: RVM's internal downsampling factor, in (0, 1].
            Defaults to 0.25; see the RVM README for values suited to the
            input resolution.
        :return: uint8 BGR image of the input size when ``background_color``
            is given, otherwise a uint8 BGRA image.
        :raises Exception: if the processor is unavailable or processing fails.
        """
        if not self.is_available():
            raise Exception("RVM抠图处理器不可用")
        try:
            logger.info("Starting to remove background using RVM...")
            if image is None or image.size == 0:
                raise ValueError("input image is empty")
            if not 0 < downsample_ratio <= 1:
                raise ValueError(f"downsample_ratio must be in (0, 1], got {downsample_ratio}")

            original_height, original_width = image.shape[:2]
            fgr, pha = self._infer(self._to_rgb(image), downsample_ratio)

            # RVM should return the input resolution; resize defensively if not.
            if fgr.shape[:2] != (original_height, original_width):
                fgr = cv2.resize(fgr, (original_width, original_height))
                pha = cv2.resize(pha, (original_width, original_height))

            fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)
            if background_color is not None:
                result = self._blend(fgr_bgr, pha, background_color)
            else:
                # Keep transparency: stack the matte as the alpha channel (BGRA).
                result = np.dstack((fgr_bgr, pha))
            logger.info("RVM background removal completed")
            return result
        except Exception as e:
            logger.error(f"RVM background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}") from e

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert a gray, BGR or BGRA OpenCV image to a 3-channel RGB image."""
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _infer(self, image_rgb: np.ndarray, downsample_ratio: float) -> tuple:
        """Run one RVM forward pass and return (foreground RGB, alpha) as uint8 arrays."""
        src = transforms.ToTensor()(image_rgb).unsqueeze(0).to(self.device)
        # Single image: start with empty recurrent states and discard the new ones.
        rec = [None] * 4
        with torch.no_grad():
            fgr, pha, *_ = self.model(src, *rec, downsample_ratio=downsample_ratio)
        # fgr[0] is (3, H, W) -> (H, W, 3); pha[0, 0] is (H, W).
        return self._to_uint8(fgr[0].permute(1, 2, 0)), self._to_uint8(pha[0, 0])

    @staticmethod
    def _to_uint8(tensor: torch.Tensor) -> np.ndarray:
        """Map a float tensor with values in [0, 1] to uint8, clipping and rounding."""
        array = tensor.detach().cpu().numpy()
        return (np.clip(array, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)

    @staticmethod
    def _blend(fgr_bgr: np.ndarray, pha: np.ndarray, background_color) -> np.ndarray:
        """Alpha-composite a BGR foreground over a solid background color.

        :param fgr_bgr: uint8 (H, W, 3) foreground.
        :param pha: uint8 (H, W) matte, where 255 means fully foreground.
        :param background_color: BGR triple (or scalar) for the background.
        :return: uint8 (H, W, 3) composite, rounded and clipped to [0, 255].
        """
        alpha = pha.astype(np.float32)[..., None] / 255.0
        background = np.asarray(background_color, dtype=np.float32)
        blended = fgr_bgr.astype(np.float32) * alpha + background * (1.0 - alpha)
        return np.clip(blended + 0.5, 0, 255).astype(np.uint8)

    def create_id_photo(self, image: np.ndarray, background_color: tuple = (255, 255, 255)) -> np.ndarray:
        """Create an ID photo by replacing the background with a solid color.

        :param image: input OpenCV image.
        :param background_color: BGR background color; white by default.
        :return: uint8 BGR ID photo of the input size.
        :raises Exception: propagated from ``remove_background``.
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")
        # Remove the background and composite onto the requested color.
        id_photo = self.remove_background(image, background_color)
        logger.info("ID photo creation completed")
        return id_photo