|
|
import time |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
|
|
|
from config import logger, REMBG_AVAILABLE |
|
|
|
|
|
if REMBG_AVAILABLE: |
|
|
import rembg |
|
|
from rembg import new_session |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
class RembgProcessor:
    """Background-removal processor backed by rembg.

    Holds a rembg session and exposes helpers that cut the subject out of an
    OpenCV (BGR) image, optionally compositing it onto a solid colour.
    When rembg is not installed, or the session cannot be created, the object
    is still constructed but ``is_available()`` returns False.
    """

    # Model names that ``switch_model`` accepts.
    _SUPPORTED_MODELS = ("u2net", "u2net_human_seg", "silueta", "isnet-general-use")

    def __init__(self, model_name: str = "u2net"):
        """Create the rembg session if rembg is installed.

        :param model_name: rembg model to load (defaults to ``"u2net"``).
        """
        start_time = time.perf_counter()
        self.session = None
        self.available = False
        self.model_name = model_name

        if REMBG_AVAILABLE:
            try:
                self.session = new_session(self.model_name)
                self.available = True
                logger.info(f"rembg background removal processor initialized successfully, using model: {self.model_name}")
            except Exception as e:
                # Keep the object usable; callers are expected to check is_available().
                logger.error(f"rembg background removal processor initialization failed: {e}")
                self.available = False
        else:
            logger.warning("rembg is not available, background removal function will be disabled")

        init_time = time.perf_counter() - start_time
        if self.available:
            logger.info(f"RembgProcessor initialized successfully, time: {init_time:.3f}s")
        else:
            logger.info(f"RembgProcessor initialization completed but not available, time: {init_time:.3f}s")

    def is_available(self) -> bool:
        """Return True when a rembg session has been created and can be used."""
        return self.available and self.session is not None

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert an OpenCV image (grayscale, BGR or BGRA) to a 3-channel RGB array."""
        if image is None or image.size == 0:
            raise ValueError("input image is empty")
        channels = 1 if image.ndim == 2 else image.shape[2]
        if channels == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if channels == 4:
            # Any input alpha is discarded; rembg computes a new one.
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _composite_on_color(cutout: "Image.Image", background_color) -> np.ndarray:
        """Paste the cutout onto a solid background given in BGR; return a BGR array."""
        # rembg normally returns RGBA; convert defensively so the alpha mask exists.
        rgba = cutout if cutout.mode == "RGBA" else cutout.convert("RGBA")
        # PIL expects RGB, the caller supplies BGR: reverse the channel order.
        fill = tuple(int(c) for c in reversed(tuple(background_color)))
        background = Image.new("RGB", rgba.size, fill)
        background.paste(rgba, mask=rgba)
        return cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)

    @staticmethod
    def _to_bgr_keep_alpha(cutout: "Image.Image") -> np.ndarray:
        """Convert the cutout to an OpenCV array, keeping transparency when present."""
        if cutout.mode not in ("RGBA", "RGB"):
            cutout = cutout.convert("RGBA")
        result_array = np.array(cutout)
        if result_array.shape[2] == 4:
            return cv2.cvtColor(result_array, cv2.COLOR_RGBA2BGRA)
        return cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)

    def remove_background(self, image: np.ndarray, background_color: Optional[Tuple[int, int, int]] = None) -> np.ndarray:
        """Remove the background of an image.

        :param image: input OpenCV image (BGR; grayscale and BGRA are also accepted)
        :param background_color: replacement background colour (BGR); if None,
            the transparent background is kept and a BGRA image is returned
        :return: processed image (BGR when a background colour is given)
        :raises Exception: if the processor is unavailable or removal fails
        """
        if not self.is_available():
            raise Exception("rembg抠图处理器不可用")

        try:
            pil_image = Image.fromarray(self._to_rgb(image))

            logger.info("Starting to remove background using rembg...")
            output_image = rembg.remove(pil_image, session=self.session)

            if background_color is not None:
                result_bgr = self._composite_on_color(output_image, background_color)
            else:
                result_bgr = self._to_bgr_keep_alpha(output_image)

            logger.info("rembg background removal completed")
            return result_bgr
        except Exception as e:
            logger.error(f"rembg background removal failed: {e}")
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"背景移除失败: {str(e)}") from e

    def create_id_photo(self, image: np.ndarray, background_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
        """Create an ID photo: remove the background and fill it with a solid colour.

        :param image: input OpenCV image (BGR)
        :param background_color: background colour in BGR, white by default
        :return: the resulting BGR ID photo
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")
        id_photo = self.remove_background(image, background_color)
        logger.info("ID photo creation completed")
        return id_photo

    def get_supported_models(self) -> list:
        """Return the model names accepted by ``switch_model`` (empty if rembg is missing)."""
        if not REMBG_AVAILABLE:
            return []
        return list(self._SUPPORTED_MODELS)

    def switch_model(self, model_name: str) -> bool:
        """Switch to a different rembg model.

        The current session is kept if the new one cannot be created.

        :param model_name: name of the model to load
        :return: True if the switch succeeded, False otherwise
        """
        if not REMBG_AVAILABLE:
            return False

        if model_name not in self.get_supported_models():
            logger.error(f"Unsupported model: {model_name}")
            return False

        try:
            session = new_session(model_name)
        except Exception as e:
            logger.error(f"Failed to switch model: {e}")
            return False

        self.session = session
        self.model_name = model_name
        # A working session now exists even if the initial one failed to load.
        self.available = True
        logger.info(f"rembg model switched to: {model_name}")
        return True
|
|
|