"""Background removal built on a pretrained ``BriaRMBG`` segmentation network."""

import os
from typing import Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import PIL
from PIL import Image

from .briarmbg import BriaRMBG


class BG:
    """Estimate a background mask for an RGB image and make it transparent.

    The public entry point is :meth:`BG_remove`; the remaining methods are
    internal helpers.

    Args:
        model_path: Location handed to ``BriaRMBG.from_pretrained``.
        device: Torch device string on which the network and its input are
            placed.
    """

    # Resolution every image is resized to before it is fed to the network.
    _MODEL_INPUT_SIZE: Tuple[int, int] = (1024, 1024)
    # Min-max normalised score above which a pixel is kept (not background).
    _THRESHOLD = 0.5
    # Morphology kernel diameter = longer image side // this divisor.
    _KERNEL_DIVISOR = 30
    # Gaussian kernel used to feather the alpha channel when ``gamma`` is set.
    _BLUR_KERNEL: Tuple[int, int] = (15, 15)

    def __init__(self, model_path: str = "./models/RMBG-1.4", device: str = "cpu"):
        self.device = device
        self.net = BriaRMBG.from_pretrained(model_path)
        # The input tensor is moved to ``self.device`` in ``_BG_mask``, so the
        # weights must live there too. The network is only ever used under
        # ``torch.no_grad()``, hence evaluation mode.
        # NOTE(review): assumes BriaRMBG is a torch.nn.Module - confirm.
        self.net.to(self.device)
        self.net.eval()

    def _resize_image(self, image):
        """Return the PIL *image* converted to RGB at the model input size."""
        return image.convert("RGB").resize(self._MODEL_INPUT_SIZE, Image.BILINEAR)

    def _to_model_input(self, image):
        """Turn a PIL image into a normalised ``(1, 3, H, W)`` float tensor."""
        pixels = np.array(self._resize_image(image))
        batch = torch.tensor(pixels, dtype=torch.float32).permute(2, 0, 1)
        batch = torch.unsqueeze(batch, 0)
        # Scale to [0, 1], then shift by the 0.5 mean (std 1.0) -> [-0.5, 0.5].
        batch = torch.divide(batch, 255.0)
        batch = normalize(batch, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return batch.to(self.device)

    def _BG_mask(self, image_rgb) -> Optional[np.ndarray]:
        """Compute a cleaned-up background mask for *image_rgb*.

        Args:
            image_rgb: Array accepted by ``PIL.Image.fromarray``
                (presumably ``uint8`` of shape ``(H, W, 3)`` - TODO confirm).

        Returns:
            A ``uint8`` array at the input's height and width in which ``1``
            marks pixels to be removed and ``0`` pixels to keep, or ``None``
            when no pixel was classified as background.
        """
        orig_image = Image.fromarray(image_rgb)
        width, height = orig_image.size

        batch = self._to_model_input(orig_image)
        with torch.no_grad():
            output = self.net(batch)

        # Bilinear interpolation with a 2-element size needs a 4-D tensor, so
        # ``output[0][0]`` is taken to be (N, C, H, W); resize it back to the
        # original resolution and drop the batch dimension.
        scores = F.interpolate(output[0][0], size=(height, width), mode="bilinear")
        scores = torch.squeeze(scores, 0)

        # Min-max normalise to [0, 1]. A constant prediction has zero range;
        # treat it as "nothing above threshold" instead of dividing by zero.
        lowest = torch.min(scores)
        span = torch.max(scores) - lowest
        if span > 0:
            scores = (scores - lowest) / span
        else:
            scores = torch.zeros_like(scores)

        # Invert the thresholded prediction: 1 = background, 0 = kept pixel.
        background = ~(scores > self._THRESHOLD)
        mask_np = background.squeeze(0).cpu().numpy().astype(np.uint8)
        if np.count_nonzero(mask_np) == 0:
            return None

        # Kernel size scales with the image. Clamp to 1 because the integer
        # division is 0 for images whose longer side is under the divisor,
        # which would request a degenerate 0x0 structuring element.
        kernel_size = max(1, max(width, height) // self._KERNEL_DIVISOR)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
        )

        # Open then close to drop specks and bridge gaps in the mask.
        cleaned = cv2.morphologyEx(mask_np, cv2.MORPH_OPEN, kernel)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel)

        # Extra dilation and erosion to fill small holes within the mask.
        cleaned = cv2.dilate(cleaned, kernel, iterations=2)
        cleaned = cv2.erode(cleaned, kernel, iterations=1)

        # Keep only pixels flagged both by the raw and by the cleaned mask.
        return cv2.bitwise_and(mask_np, cleaned)

    def BG_remove(self, image_rgb, gamma=None):
        """Return *image_rgb* with its background made transparent.

        Args:
            image_rgb: RGB image array (see :meth:`_BG_mask`).
            gamma: Optional Gaussian sigma used to feather the alpha channel
                with a 15x15 kernel. Falsy values (``None``, ``0``) disable
                the blur.

        Returns:
            An RGBA array whose alpha channel is 0 on background pixels and
            255 on kept pixels (blurred when *gamma* is given). When no
            background pixel is detected the input is returned unchanged,
            i.e. without an added alpha channel.
        """
        mask = self._BG_mask(image_rgb)
        if mask is None:
            return image_rgb

        # mask is 1 on background, so inverting mask * 255 gives the alpha
        # channel directly: 0 on background, 255 elsewhere.
        alpha = ~(np.uint8(mask) * 255)
        if gamma:
            alpha = cv2.GaussianBlur(alpha, self._BLUR_KERNEL, gamma)

        image_rgba = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA)
        image_rgba[:, :, 3] = alpha
        return image_rgba

    def __del__(self):
        # ``net`` is missing when __init__ raised while loading the model;
        # an unguarded ``del`` would then fail during finalisation.
        if hasattr(self, "net"):
            del self.net


if __name__ == "__main__":
    pass