File size: 2,993 Bytes
4608d57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import cv2
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from .briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
class BG:
    """Remove the background from RGB images with the BriaRMBG segmentation net.

    The network predicts a per-pixel score.  After min-max normalisation, pixels
    whose score is <= ``THRESHOLD`` are treated as background; ``BG_remove``
    makes those pixels fully transparent.
    """

    # Spatial size the image is resized to before inference.
    MODEL_INPUT_SIZE = (1024, 1024)
    # Normalised scores at or below this value are counted as background.
    THRESHOLD = 0.5
    # Morphology kernel side = max(image width, image height) // KERNEL_DIVISOR.
    KERNEL_DIVISOR = 30
    # Kernel size of the Gaussian blur used to feather the alpha edge.
    BLUR_KSIZE = (15, 15)

    def __init__(self, model_path="./models/RMBG-1.4", device="cpu"):
        """Load the segmentation network.

        Args:
            model_path: Path handed to ``BriaRMBG.from_pretrained``; defaults to
                the original local checkpoint directory.
            device: Torch device that both the network and its inputs use.
        """
        self.net = BriaRMBG.from_pretrained(model_path)
        self.device = device
        # Inputs are sent to self.device in _to_model_input, so the weights
        # must live there too; eval() makes inference-time layer behaviour
        # explicit instead of relying on the loader to have set it.
        self.net.to(self.device)
        self.net.eval()

    def _resize_image(self, image):
        """Return the PIL *image* converted to RGB and bilinearly resized to MODEL_INPUT_SIZE."""
        return image.convert("RGB").resize(self.MODEL_INPUT_SIZE, Image.BILINEAR)

    def _to_model_input(self, image_np):
        """Convert an HxWx3 array to a normalised 1x3xHxW float32 tensor on self.device."""
        tensor = torch.tensor(image_np, dtype=torch.float32).permute(2, 0, 1)
        tensor = torch.unsqueeze(tensor, 0) / 255.0
        # Exactly torchvision's normalize(mean=[0.5] * 3, std=[1.0] * 3).
        tensor = (tensor - 0.5) / 1.0
        return tensor.to(self.device)

    @staticmethod
    def _prediction_to_background_mask(prediction, h, w, threshold=0.5):
        """Turn a raw 1x1xH'xW' network output into an HxW uint8 background mask.

        The prediction is bilinearly resized to (h, w), min-max normalised to
        [0, 1], and thresholded.  The returned mask is 1 where the normalised
        score is <= *threshold* (background) and 0 elsewhere.
        """
        result = F.interpolate(prediction, size=(h, w), mode="bilinear")[0, 0]
        lowest, highest = torch.min(result), torch.max(result)
        span = highest - lowest
        if span > 0:
            result = (result - lowest) / span
        else:
            # A constant prediction cannot be normalised (0/0 would give NaN).
            # Mapping it to zeros yields the same all-background mask the
            # NaN comparison used to produce, without the NaNs.
            result = torch.zeros_like(result)
        return (result <= threshold).cpu().numpy().astype(np.uint8)

    @staticmethod
    def _kernel_size(w, h, divisor=30):
        """Return the morphology kernel side for a w x h image; always at least 1."""
        # A zero-sized structuring element makes OpenCV fail its anchor
        # assertion, which used to crash on images under `divisor` pixels.
        return max(1, max(w, h) // divisor)

    def _refine_mask(self, mask, w, h):
        """Clean up a 0/1 uint8 background mask with elliptical morphology.

        The mask goes through opening, closing, two dilations and one erosion.
        The result is then ANDed with the input mask, so refinement can only
        clear background pixels, never mark new ones.
        """
        size = self._kernel_size(w, h, self.KERNEL_DIVISOR)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
        processed = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        processed = cv2.morphologyEx(processed, cv2.MORPH_CLOSE, kernel)
        processed = cv2.dilate(processed, kernel, iterations=2)
        processed = cv2.erode(processed, kernel, iterations=1)
        return cv2.bitwise_and(mask, processed)

    def _BG_mask(self, image_rgb):
        """Compute the refined background mask for *image_rgb*.

        Args:
            image_rgb: Image array accepted by ``PIL.Image.fromarray``.

        Returns:
            An HxW uint8 mask that is 1 on background and 0 on foreground, or
            ``None`` when the thresholded prediction contains no background.
        """
        orig_image = Image.fromarray(image_rgb)
        w, h = orig_image.size
        model_input = self._to_model_input(np.array(self._resize_image(orig_image)))
        with torch.no_grad():
            result = self.net(model_input)
        # result[0][0] is the prediction map used by the original pipeline.
        mask = self._prediction_to_background_mask(result[0][0], h, w, self.THRESHOLD)
        if np.count_nonzero(mask) == 0:
            return None
        return self._refine_mask(mask, w, h)

    def BG_remove(self, image_rgb, gamma=None):
        """Make the background of *image_rgb* transparent.

        Args:
            image_rgb: Image array accepted by ``PIL.Image.fromarray``.  It must
                have three channels, since it is converted with
                ``cv2.COLOR_RGB2RGBA``.
            gamma: Despite the name, this is the Gaussian sigma used to feather
                the alpha edge (with a BLUR_KSIZE kernel).  ``None`` or 0
                disables feathering.

        Returns:
            An RGBA copy whose alpha is 255 on foreground and 0 on background.
            If no background was detected, the input array itself is returned
            unchanged (still 3-channel).
        """
        mask = self._BG_mask(image_rgb)
        if mask is None:
            return image_rgb
        # mask is 1 on background, so the complement of mask*255 is the
        # opaque (foreground) alpha.
        alpha = ~(np.uint8(mask) * 255)
        if gamma:
            alpha = cv2.GaussianBlur(alpha, self.BLUR_KSIZE, gamma)
        image_rgba = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA)
        image_rgba[:, :, 3] = alpha
        return image_rgba

    def __del__(self):
        """Drop the network reference, tolerating a partially initialised instance."""
        # __init__ may have raised before self.net was assigned; a plain
        # `del self.net` would then raise AttributeError during finalisation.
        self.__dict__.pop("net", None)
if __name__ == "__main__":
    # No script behaviour is implemented for this module.
    pass