|
|
import cv2 |
|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torchvision.transforms.functional import normalize |
|
|
from .briarmbg import BriaRMBG |
|
|
import PIL |
|
|
from PIL import Image |
|
|
from typing import Tuple |
|
|
|
|
|
class BG:
    """Background remover built on the BriaRMBG segmentation network.

    The network produces a per-pixel score map for an RGB image. Scores are
    min-max normalised, thresholded into a binary background mask, cleaned up
    with morphological operations, and finally written into the alpha channel
    so that background pixels become transparent.
    """

    # (width, height) every image is resized to before inference.
    MODEL_INPUT_SIZE = (1024, 1024)

    def __init__(self, model_path="./models/RMBG-1.4", device="cpu"):
        """Load the segmentation network.

        Args:
            model_path: Location passed to ``BriaRMBG.from_pretrained``. A
                relative path such as the default depends on the current
                working directory.
            device: Torch device that both the network and its inputs are
                placed on.
        """
        self.net = BriaRMBG.from_pretrained(model_path)
        self.device = device
        # NOTE(review): assumes BriaRMBG is a torch.nn.Module (it is called on
        # a tensor in _BG_mask). Moving it explicitly keeps the weights on the
        # same device as the inputs, and eval() disables training-only
        # behaviour such as dropout / batch-norm statistic updates.
        self.net.to(self.device)
        self.net.eval()

    def _resize_image(self, image):
        """Return the PIL ``image`` converted to RGB and resized to MODEL_INPUT_SIZE.

        The aspect ratio is not preserved. ``_BG_mask`` interpolates the
        prediction back to the original size afterwards.
        """
        image = image.convert("RGB")
        return image.resize(self.MODEL_INPUT_SIZE, Image.BILINEAR)

    @staticmethod
    def _background_mask(pred, threshold=0.5):
        """Turn a raw prediction tensor into a uint8 background mask.

        ``pred`` is min-max normalised to [0, 1]. Entries above ``threshold``
        count as foreground (0) and all others as background (1). A leading
        dimension of size 1 is squeezed away, and the result is a NumPy array
        on the CPU.
        """
        lo = torch.min(pred)
        hi = torch.max(pred)
        span = hi - lo
        if span > 0:
            foreground = (pred - lo) / span > threshold
        else:
            # Constant (or NaN) prediction: normalising would compute 0/0. The
            # resulting NaNs never compare greater than the threshold, so the
            # old code classified everything as background. Keep that outcome
            # explicitly instead of relying on NaN.
            foreground = torch.zeros_like(pred, dtype=torch.bool)
        return (~foreground).squeeze(0).cpu().numpy().astype(np.uint8)

    @staticmethod
    def _kernel_size(w, h):
        """Return the side of the morphology kernel.

        The size is about 1/30 of the longer image side, and at least 1 so
        that ``cv2.getStructuringElement`` never receives an empty (0x0)
        size on small images.
        """
        return max(1, max(w, h) // 30)

    def _BG_mask(self, image_rgb, threshold=0.5):
        """Compute a binary background mask for an image array.

        Args:
            image_rgb: Image array accepted by ``Image.fromarray``. It is
                converted to RGB before inference. NOTE(review): callers
                appear to pass HxWx3 uint8 RGB data; confirm.
            threshold: Cut-off on the min-max normalised prediction. Values
                above it are foreground.

        Returns:
            ``None`` when no pixel is classified as background. Otherwise a
            uint8 array that is 1 on background and 0 on foreground, at the
            original image size: (H, W), assuming the network emits a
            single-channel map.
        """
        orig_image = Image.fromarray(image_rgb)
        w, h = orig_image.size

        model_input = np.array(self._resize_image(orig_image))
        im_tensor = torch.tensor(model_input, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0) / 255.0
        # Mean 0.5, std 1.0 shifts [0, 1] pixel values to [-0.5, 0.5].
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        im_tensor = im_tensor.to(self.device)

        with torch.no_grad():
            result = self.net(im_tensor)

        # NOTE(review): result[0][0] is assumed to be a 4-D (N, C, H, W) map,
        # because bilinear interpolation requires a 4-D input.
        pred = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
        )
        mask_np = self._background_mask(pred, threshold)

        if np.count_nonzero(mask_np) == 0:
            return None

        kernel_size = self._kernel_size(w, h)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
        )
        # The mask's 1-pixels are background. Opening drops small isolated
        # background blobs and closing fills small holes in the background.
        processed_mask = cv2.morphologyEx(mask_np, cv2.MORPH_OPEN, kernel)
        processed_mask = cv2.morphologyEx(processed_mask, cv2.MORPH_CLOSE, kernel)
        # Net growth of the refined background (two dilations vs one erosion),
        # so that it still covers the raw background near object edges.
        processed_mask = cv2.dilate(processed_mask, kernel, iterations=2)
        processed_mask = cv2.erode(processed_mask, kernel, iterations=1)

        # Keep a pixel as background only if both the raw and the refined
        # masks agree. The result is always a subset of the raw background.
        return cv2.bitwise_and(mask_np, processed_mask)

    def BG_remove(self, image_rgb, gamma=None):
        """Make the background of ``image_rgb`` transparent.

        Args:
            image_rgb: RGB image array (see ``_BG_mask``).
            gamma: If truthy, the alpha edge is softened with a 15x15 Gaussian
                blur that uses ``gamma`` as sigma. ``None`` or 0 disables it.

        Returns:
            An RGBA array whose alpha is 255 on foreground and 0 on background
            (with soft edges when blurred). When no background is detected,
            ``image_rgb`` is returned unchanged and still has 3 channels.
        """
        mask = self._BG_mask(image_rgb)
        if mask is None:
            return image_rgb

        background = np.uint8(mask) * 255
        # Inverting gives alpha: opaque (255) on foreground, 0 on background.
        alpha = ~background
        if gamma:
            alpha = cv2.GaussianBlur(alpha, (15, 15), gamma)

        image_rgba = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA)
        image_rgba[:, :, 3] = alpha
        return image_rgba

    def __del__(self):
        """Drop the reference to the network when the object is collected."""
        # __init__ may have raised before self.net was assigned. Avoid an
        # AttributeError being reported from the finalizer in that case.
        self.__dict__.pop("net", None)
|
|
|
|
|
if __name__ == "__main__":
    # Import-only module: it just defines BG, so running it as a script
    # intentionally does nothing.
    ...