File size: 5,627 Bytes
cd5aabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import cv2
import numpy as np
import torch
from torchvision import transforms

import config
from config import logger


class RVMProcessor:
    """Background-removal (matting) processor backed by RVM (Robust Video Matting).

    The model is built from a local RobustVideoMatting checkout plus a local
    weights file; nothing is downloaded. If initialization fails the instance
    is still created, but ``is_available()`` returns False.
    """

    def __init__(self, device: str = "cpu"):
        """
        Build the RVM model from the locations configured in ``config``.

        Reads ``config.RVM_LOCAL_REPO`` (a RobustVideoMatting checkout that
        contains hubconf.py), ``config.RVM_WEIGHTS_PATH`` (the weights file) and
        ``config.RVM_MODEL`` (the hub entry-point name). Any failure is logged
        and leaves the processor unavailable instead of raising.

        :param device: torch device used for the model and for inference,
                       e.g. "cpu" (the default) or "cuda".
        """
        self.model = None
        self.available = False
        self.device = device

        try:
            # Local-only loading: never reach out to the network.
            local_repo = getattr(config, 'RVM_LOCAL_REPO', '')
            weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '')

            if not local_repo or not os.path.isdir(local_repo):
                raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)")

            if not weights_path or not os.path.isfile(weights_path):
                raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path")

            logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}")
            # Build the architecture from the local repo; pretrained=False avoids
            # any attempt to fetch pretrained weights over the network.
            self.model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False)

            # Load the local weights; some checkpoints nest them under 'state_dict'.
            state = torch.load(weights_path, map_location=self.device)
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            # strict=False tolerates key mismatches; they are reported below.
            missing, unexpected = self.model.load_state_dict(state, strict=False)

            # Move to the target device and switch to evaluation mode.
            self.model = self.model.to(self.device).eval()
            self.available = True
            logger.info("RVM background removal processor initialized successfully (local mode)")
            if missing:
                logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}")
            if unexpected:
                logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... total={len(unexpected)}")

        except Exception as e:
            logger.error(f"RVM background removal processor initialization failed: {e}")
            # Release any partially-built model instead of keeping it alive.
            self.model = None
            self.available = False

    def is_available(self) -> bool:
        """Return True when the model was loaded successfully and can run inference."""
        return self.available and self.model is not None

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert an OpenCV grayscale / BGR / BGRA image to 3-channel RGB."""
        channels = 1 if image.ndim == 2 else image.shape[2]
        conversions = {
            1: cv2.COLOR_GRAY2RGB,
            3: cv2.COLOR_BGR2RGB,
            4: cv2.COLOR_BGRA2RGB,  # alpha channel of the input is dropped
        }
        if channels not in conversions:
            raise ValueError(f"Unsupported image with {channels} channels; expected 1, 3 or 4")
        return cv2.cvtColor(image, conversions[channels])

    @staticmethod
    def _to_uint8(values: np.ndarray) -> np.ndarray:
        """Round float values to the nearest integer, clip to [0, 255] and cast to uint8."""
        return np.clip(np.rint(values), 0, 255).astype(np.uint8)

    def remove_background(self, image: np.ndarray, background_color: tuple = None,
                          downsample_ratio: float = 0.25) -> np.ndarray:
        """
        Remove the background of an image with RVM.

        :param image: input OpenCV image — BGR (H, W, 3), BGRA (H, W, 4) or
                      grayscale (H, W) / (H, W, 1); expected to be uint8 so that
                      ToTensor scales it to [0, 1].
        :param background_color: BGR color to composite the foreground onto.
                                 If None, the result keeps a transparent background.
        :param downsample_ratio: RVM's internal downsample ratio, in (0, 1].
                                 0.25 (the previous fixed value) is the default;
                                 RVM's docs recommend adjusting it to the input
                                 resolution.
        :return: uint8 BGR (H, W, 3) image when ``background_color`` is given,
                 otherwise uint8 BGRA (H, W, 4) with the predicted alpha.
        :raises Exception: if the processor is unavailable, or inference fails
                           (the original error is chained as ``__cause__``).
        :raises ValueError: if ``downsample_ratio`` is outside (0, 1].
        """
        if not self.is_available():
            raise Exception("RVM抠图处理器不可用")
        if not 0 < downsample_ratio <= 1:
            raise ValueError(f"downsample_ratio must be in (0, 1], got {downsample_ratio}")

        try:
            logger.info("Starting to remove background using RVM...")

            original_height, original_width = image.shape[:2]

            image_rgb = self._to_rgb(image)

            # HWC image -> CHW float tensor, plus a batch dimension.
            src = transforms.ToTensor()(image_rgb).unsqueeze(0).to(self.device)

            # Single-frame inference: start from empty recurrent states.
            rec = [None] * 4
            with torch.no_grad():
                fgr, pha, *rec = self.model(src, *rec, downsample_ratio=downsample_ratio)

            # Keep the outputs as contiguous float32 until the final step so that
            # blending happens at full precision and is quantized only once.
            fgr = fgr[0].permute(1, 2, 0).float().cpu().contiguous().numpy()  # (H, W, 3) RGB
            pha = pha[0, 0].float().cpu().contiguous().numpy()                # (H, W)

            # Resize back to the input size if the model output differs.
            if fgr.shape[:2] != (original_height, original_width):
                fgr = cv2.resize(fgr, (original_width, original_height))
                pha = cv2.resize(pha, (original_width, original_height))

            fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR) * 255.0

            if background_color is not None:
                # Alpha-composite over a solid color; (H, W, 1) alpha broadcasts
                # across the three color channels.
                alpha = pha[..., np.newaxis]
                background = np.asarray(background_color, dtype=np.float32)
                result = self._to_uint8(fgr_bgr * alpha + background * (1.0 - alpha))
            else:
                # Transparent background: BGRA with the predicted alpha channel.
                result = self._to_uint8(np.dstack((fgr_bgr, pha * 255.0)))

            logger.info("RVM background removal completed")
            return result

        except Exception as e:
            logger.error(f"RVM background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}") from e

    def create_id_photo(self, image: np.ndarray, background_color: tuple = (255, 255, 255),
                        downsample_ratio: float = 0.25) -> np.ndarray:
        """
        Create an ID photo: remove the background and replace it with a solid color.

        :param image: input OpenCV image (see ``remove_background``).
        :param background_color: BGR background color; defaults to white.
        :param downsample_ratio: forwarded to ``remove_background``.
        :return: uint8 BGR ID photo.
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        id_photo = self.remove_background(image, background_color, downsample_ratio=downsample_ratio)

        logger.info("ID photo creation completed")
        return id_photo