File size: 5,340 Bytes
cd5aabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import time
from typing import Optional, Tuple

import cv2
import numpy as np

from config import logger, REMBG_AVAILABLE

if REMBG_AVAILABLE:
    import rembg
    from rembg import new_session
    from PIL import Image


class RembgProcessor:
    """Background-removal (matting) processor backed by rembg.

    The processor is usable only when the rembg dependency was importable
    (``REMBG_AVAILABLE``) and a model session was created successfully;
    check :meth:`is_available` before calling the processing methods.
    """

    def __init__(self, model_name: str = "u2net"):
        """
        Create the processor and try to load a rembg model session.

        Initialization failures are logged, not raised; the processor is then
        simply reported as unavailable.

        :param model_name: rembg model to load. Defaults to ``"u2net"``, a
            general-purpose model suited to portrait matting.
        """
        start_time = time.perf_counter()
        self.session = None
        self.available = False
        self.model_name = model_name

        if REMBG_AVAILABLE:
            try:
                # Create the rembg inference session for the chosen model.
                self.session = new_session(self.model_name)
                self.available = True
                logger.info(f"rembg background removal processor initialized successfully, using model: {self.model_name}")
            except Exception as e:
                logger.error(f"rembg background removal processor initialization failed: {e}")
                self.available = False
        else:
            logger.warning("rembg is not available, background removal function will be disabled")
        init_time = time.perf_counter() - start_time
        if self.available:
            logger.info(f"RembgProcessor initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"RembgProcessor initialization completed but not available, time: {init_time:.3f}s")

    def is_available(self) -> bool:
        """Return True when a model session is loaded and ready for use."""
        return self.available and self.session is not None

    @staticmethod
    def _opencv_to_pil_rgb(image: np.ndarray):
        """Convert an OpenCV image (grayscale, BGR or BGRA) to an RGB PIL image."""
        if image is None or image.size == 0:
            raise ValueError("输入图像为空")
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            # Drop the input's own alpha channel; rembg computes a new one.
            rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        else:
            rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)

    @staticmethod
    def _pil_to_opencv(output_image, background_color: Optional[Tuple[int, int, int]]) -> np.ndarray:
        """Convert rembg's PIL output back to OpenCV channel order.

        With ``background_color`` (BGR) the cut-out is composited onto a solid
        background and a 3-channel BGR array is returned; otherwise the
        transparency is kept (BGRA when the output has an alpha channel).
        """
        if background_color is not None:
            # Normalise to RGBA so the alpha band can always serve as the mask.
            rgba = output_image.convert("RGBA")
            # PIL expects an RGB tuple; reverse the caller's BGR order.
            fill = tuple(int(c) for c in background_color)[::-1]
            background = Image.new('RGB', rgba.size, fill)
            background.paste(rgba, mask=rgba.getchannel("A"))
            return cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)

        result_array = np.array(output_image)
        if result_array.ndim == 2:  # single-channel output
            return cv2.cvtColor(result_array, cv2.COLOR_GRAY2BGR)
        if result_array.shape[2] == 4:  # RGBA -> BGRA, keeping transparency
            return cv2.cvtColor(result_array, cv2.COLOR_RGBA2BGRA)
        return cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)

    def remove_background(self, image: np.ndarray, background_color: Optional[Tuple[int, int, int]] = None) -> np.ndarray:
        """
        Remove the background of an image.

        :param image: input OpenCV image (BGR; grayscale and BGRA are also accepted)
        :param background_color: replacement background colour in BGR order;
            if None the transparent background is kept
        :return: the processed image (BGR with a background colour, otherwise
            BGRA when rembg returns an alpha channel)
        :raises RuntimeError: if the processor is unavailable or processing fails
        """
        if not self.is_available():
            raise RuntimeError("rembg抠图处理器不可用")

        try:
            pil_image = self._opencv_to_pil_rgb(image)

            logger.info("Starting to remove background using rembg...")
            output_image = rembg.remove(pil_image, session=self.session)

            result_bgr = self._pil_to_opencv(output_image, background_color)
        except Exception as e:
            logger.error(f"rembg background removal failed: {e}")
            # Chain the cause so the underlying traceback is preserved.
            raise RuntimeError(f"背景移除失败: {str(e)}") from e

        logger.info("rembg background removal completed")
        return result_bgr

    def create_id_photo(self, image: np.ndarray, background_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
        """
        Create an ID photo: remove the background and fill it with a solid colour.

        :param image: input OpenCV image
        :param background_color: background colour in BGR order, white by default
        :return: the resulting ID photo (BGR)
        :raises RuntimeError: propagated from :meth:`remove_background`
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        id_photo = self.remove_background(image, background_color)

        logger.info("ID photo creation completed")
        return id_photo

    def get_supported_models(self) -> list:
        """Return the rembg model names this processor accepts (empty if rembg is unavailable)."""
        if not REMBG_AVAILABLE:
            return []

        return [
            "u2net",              # general-purpose model, suited to portraits
            "u2net_human_seg",    # model dedicated to human segmentation
            "silueta",            # suited to object matting
            "isnet-general-use",  # more accurate general-purpose model
        ]

    def switch_model(self, model_name: str) -> bool:
        """
        Switch to a different rembg model.

        On failure the current session and model are kept unchanged. On
        success the processor is marked available, even if the initial model
        failed to load in ``__init__``.

        :param model_name: name of the model to load
        :return: True if the switch succeeded, False otherwise
        """
        if not REMBG_AVAILABLE:
            return False

        if model_name not in self.get_supported_models():
            logger.error(f"Unsupported model: {model_name}")
            return False

        try:
            session = new_session(model_name)
        except Exception as e:
            logger.error(f"Failed to switch model: {e}")
            return False

        self.session = session
        self.model_name = model_name
        # A working session now exists, so the processor is usable again.
        self.available = True
        logger.info(f"rembg model switched to: {model_name}")
        return True