# picpocket / rembg_processor.py
# chawin.chen
# init
# cd5aabe
import time
from typing import Optional, Tuple
import cv2
import numpy as np
from config import logger, REMBG_AVAILABLE
if REMBG_AVAILABLE:
import rembg
from rembg import new_session
from PIL import Image
class RembgProcessor:
    """Background-removal ("matting") processor backed by rembg.

    Holds a single rembg session and converts between OpenCV-style numpy
    arrays (BGR / BGRA channel order) and the PIL images rembg works on.
    """

    def __init__(self):
        start_time = time.perf_counter()
        self.session = None
        self.available = False
        # Default to the u2net model, suited to portrait matting.
        self.model_name = "u2net"

        if REMBG_AVAILABLE:
            try:
                # Create the rembg session for the default model.
                self.session = new_session(self.model_name)
                self.available = True
                logger.info(f"rembg background removal processor initialized successfully, using model: {self.model_name}")
            except Exception as e:
                logger.error(f"rembg background removal processor initialization failed: {e}")
                self.available = False
        else:
            logger.warning("rembg is not available, background removal function will be disabled")

        init_time = time.perf_counter() - start_time
        if self.available:
            logger.info(f"RembgProcessor initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"RembgProcessor initialization completed but not available, time: {init_time:.3f}s")

    def is_available(self) -> bool:
        """Return True when the processor is usable (flag set and a session exists)."""
        return self.available and self.session is not None

    @staticmethod
    def _bgr_to_rgb_fill(color) -> Tuple[int, ...]:
        """Convert a BGR colour (tuple, list or numpy array) to a plain-int RGB tuple.

        Pillow's ``Image.new`` only accepts an int or a *tuple* as the fill
        colour, so lists (e.g. JSON-decoded) and numpy arrays must be converted.
        """
        return tuple(int(c) for c in color)[::-1]

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert an OpenCV image (grayscale, BGR or BGRA) to a 3-channel RGB array."""
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            # Alpha of the input is dropped; rembg computes its own.
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def remove_background(self, image: np.ndarray, background_color: Optional[Tuple[int, int, int]] = None) -> np.ndarray:
        """Remove the background of an image.

        :param image: OpenCV image in BGR order (BGRA and single-channel
            grayscale inputs are also accepted).
        :param background_color: BGR colour used to fill the removed background;
            if None the cut-out is returned with its transparency.
        :return: a BGR array when ``background_color`` is given; otherwise a
            BGRA array (or BGR if rembg returned no alpha channel).
        :raises Exception: if the processor is unavailable or processing fails
            (the original error is chained as ``__cause__``).
        """
        if not self.is_available():
            raise Exception("rembg抠图处理器不可用")
        try:
            # OpenCV arrays are BGR(A); PIL/rembg expect RGB.
            pil_image = Image.fromarray(self._to_rgb(image))

            logger.info("Starting to remove background using rembg...")
            output_image = rembg.remove(pil_image, session=self.session)

            if background_color is not None:
                # Composite the cut-out onto a solid colour; the RGBA alpha
                # band is used as the paste mask.
                cutout = output_image if output_image.mode == "RGBA" else output_image.convert("RGBA")
                background = Image.new('RGB', cutout.size, self._bgr_to_rgb_fill(background_color))
                background.paste(cutout, mask=cutout)
                result_bgr = cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)
            else:
                # Keep transparency: return BGRA when rembg produced an alpha band.
                result_array = np.array(output_image)
                if result_array.ndim == 3 and result_array.shape[2] == 4:
                    result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGBA2BGRA)
                else:
                    result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)

            logger.info("rembg background removal completed")
            return result_bgr
        except Exception as e:
            logger.error(f"rembg background removal failed: {e}")
            raise Exception(f"背景移除失败: {str(e)}") from e

    def create_id_photo(self, image: np.ndarray, background_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
        """Create an ID photo: remove the background and fill it with a solid colour.

        :param image: OpenCV image (BGR order).
        :param background_color: BGR fill colour, white by default.
        :return: the resulting BGR image.
        :raises Exception: propagated from :meth:`remove_background`.
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")
        id_photo = self.remove_background(image, background_color)
        logger.info("ID photo creation completed")
        return id_photo

    def get_supported_models(self) -> list:
        """Return the rembg model names this processor allows (empty if rembg is missing)."""
        if not REMBG_AVAILABLE:
            return []
        return [
            "u2net",              # general-purpose model, suited to portraits
            "u2net_human_seg",    # model dedicated to human segmentation
            "silueta",            # suited to object cut-outs
            "isnet-general-use",  # more precise general-purpose model
        ]

    def switch_model(self, model_name: str) -> bool:
        """Switch the rembg model used by this processor.

        The current session and model are kept unchanged when the model is
        unsupported or its session cannot be created.

        :param model_name: one of the names returned by :meth:`get_supported_models`.
        :return: True if the switch succeeded, False otherwise.
        """
        if not REMBG_AVAILABLE:
            return False
        if model_name not in self.get_supported_models():
            logger.error(f"Unsupported model: {model_name}")
            return False
        try:
            session = new_session(model_name)
        except Exception as e:
            logger.error(f"Failed to switch model: {e}")
            return False
        self.session = session
        self.model_name = model_name
        # A working session now exists, even if the initial load in __init__ failed.
        self.available = True
        logger.info(f"rembg model switched to: {model_name}")
        return True