File size: 2,994 Bytes
93871a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import cv2
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from .briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple

class BG:
    """Background remover built around a pretrained ``BriaRMBG`` network.

    The network predicts a saliency map for an RGB image; the map is
    thresholded into a binary background mask, cleaned with morphological
    operations, and finally turned into the alpha channel of an RGBA copy
    of the input.

    Args:
        model_path: Location handed to ``BriaRMBG.from_pretrained``.
        device: Torch device the network and the input tensors are placed on.
    """

    # (width, height) every image is resized to before inference.
    MODEL_INPUT_SIZE = (1024, 1024)
    # Normalised prediction above which a pixel is treated as foreground.
    MASK_THRESHOLD = 0.5
    # The structuring element diameter is max(width, height) // KERNEL_DIVISOR.
    KERNEL_DIVISOR = 30
    # Kernel size of the optional Gaussian feathering of the alpha channel.
    BLUR_KERNEL = (15, 15)

    def __init__(self, model_path="./models/RMBG-1.4", device="cpu"):
        self.device = device
        self.net = BriaRMBG.from_pretrained(model_path)
        # Keep the weights on the same device as the inputs (a no-op for the
        # default "cpu") and make sure inference-mode layers are active.
        self.net.to(self.device)
        self.net.eval()

    def _resize_image(self, image: "Image.Image") -> "Image.Image":
        """Return *image* converted to RGB and resized to the model input size."""
        image = image.convert('RGB')
        return image.resize(self.MODEL_INPUT_SIZE, Image.BILINEAR)

    def _BG_mask(self, image_rgb: np.ndarray) -> "np.ndarray | None":
        """Compute a binary background mask for *image_rgb*.

        Args:
            image_rgb: Image array accepted by ``Image.fromarray``
                (presumably H x W x 3 uint8 RGB - not validated here).

        Returns:
            A ``uint8`` array at the input's resolution where 1 marks
            background and 0 marks foreground, or ``None`` when the raw
            thresholded mask contains no background pixel at all.
        """
        source = Image.fromarray(image_rgb)
        width, height = source.size

        # Pre-process: HWC pixels -> 1 x C x H x W float tensor, scaled by
        # 1/255 and shifted by the 0.5 mean (std 1.0 leaves the scale alone).
        pixels = np.array(self._resize_image(source))
        batch = torch.tensor(pixels, dtype=torch.float32).permute(2, 0, 1)
        batch = torch.unsqueeze(batch, 0)
        batch = torch.divide(batch, 255.0)
        batch = normalize(batch, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        batch = batch.to(self.device)

        # Inference
        with torch.no_grad():
            result = self.net(batch)

        # Post-process: bring the first output back to the original resolution
        # (assumes result[0][0] is an N x 1 x H x W map - TODO confirm against
        # BriaRMBG) and min-max normalise it.
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(height, width), mode='bilinear'), 0
        )
        highest = torch.max(result)
        lowest = torch.min(result)
        # NOTE: a constant prediction makes this 0/0 = NaN; every comparison
        # below is then False, so the whole image is classified as background.
        result = (result - lowest) / (highest - lowest)

        # Background is everything that is NOT above the foreground threshold.
        background = ~(result > self.MASK_THRESHOLD)
        mask_np = background.squeeze(0).cpu().numpy().astype(np.uint8)

        if np.count_nonzero(mask_np) == 0:
            return None

        # Scale the kernel with the image, but never below 1 pixel: a 0 x 0
        # size (images smaller than KERNEL_DIVISOR px) is rejected by OpenCV.
        kernel_size = max(1, max(width, height) // self.KERNEL_DIVISOR)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
        )

        # Opening drops small specks, closing fills small holes.
        cleaned = cv2.morphologyEx(mask_np, cv2.MORPH_OPEN, kernel)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel)

        # Extra dilation followed by a weaker erosion to bridge remaining gaps.
        cleaned = cv2.dilate(cleaned, kernel, iterations=2)
        cleaned = cv2.erode(cleaned, kernel, iterations=1)

        # Keep only raw background pixels that also lie in the cleaned region.
        return cv2.bitwise_and(mask_np, cleaned)

    def BG_remove(self, image_rgb: np.ndarray, gamma=None) -> np.ndarray:
        """Return *image_rgb* with its background made transparent.

        Args:
            image_rgb: RGB image array (3 channels, as required by the
                ``COLOR_RGB2RGBA`` conversion).
            gamma: Optional Gaussian sigma used to feather the alpha channel.
                Falsy values (``None``, ``0``) disable the blur.

        Returns:
            An RGBA array whose alpha channel is 0 on background pixels and
            255 on foreground pixels (smoothed when *gamma* is given). When
            no background is detected the input is returned unchanged, i.e.
            still with 3 channels.
        """
        mask = self._BG_mask(image_rgb)
        if mask is None:
            return image_rgb

        # Opaque (255) where the mask is foreground, transparent on background.
        alpha = np.where(mask != 0, 0, 255).astype(np.uint8)
        if gamma:
            alpha = cv2.GaussianBlur(alpha, self.BLUR_KERNEL, gamma)

        image_rgba = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA)
        image_rgba[:, :, 3] = alpha
        return image_rgba

    def __del__(self):
        # Drop the network reference. Must not raise when __init__ failed
        # before ``net`` was assigned or when the finaliser runs twice.
        self.__dict__.pop("net", None)

# Import-only module: running it directly as a script performs no action.
if __name__ == "__main__":
    pass