File size: 1,825 Bytes
93871a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
import numpy as np
import cv2
import os
import torch
from skimage.metrics import structural_similarity as ssim

class SKY:
    """Sky segmentation and removal using a Mask2Former semantic-segmentation model.

    The model and processor are loaded once in ``__init__``. Every subsequent
    call to :meth:`SKY_remove` reuses them.
    """

    def __init__(self, model_path="models/mask2former-swin-large-ade-semantic", sky_label=2):
        """Load the processor and model and move the model to GPU when available.

        Args:
            model_path: Path (or hub id) handed to ``from_pretrained`` for both
                the image processor and the model. The default is the
                original hard-coded local checkpoint directory.
            sky_label: Semantic class id treated as "sky". The default 2 is
                the value the original code compared against (expected to be
                "sky" in the ADE20K label map). Check ``model.config.id2label``
                if you swap checkpoints.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.sky_label = sky_label
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(model_path).to(self.device)

    def _SKY_mask(self, image_rgb):
        """Return a boolean per-pixel sky mask, or ``None`` if no pixel is sky.

        Args:
            image_rgb: Image array whose first two dimensions are (height,
                width). Assumed to be RGB ordered, per the parameter name.

        Returns:
            A numpy bool array with shape ``image_rgb.shape[:2]`` that is True
            where the predicted class equals ``self.sky_label``, or ``None``
            when no pixel is predicted as sky.
        """
        inputs = self.processor(images=image_rgb, return_tensors="pt").to(self.device)

        # Inference only. Without no_grad, autograd would keep every backbone
        # activation alive until `outputs` is released.
        with torch.no_grad():
            outputs = self.model(**inputs)
            del inputs  # free the device copy of the pixel tensors early
            semantic_map = self.processor.post_process_semantic_segmentation(
                outputs, target_sizes=[image_rgb.shape[:2]]
            )[0]
            del outputs

        mask = semantic_map.cpu().numpy() == self.sky_label
        del semantic_map

        return mask if mask.any() else None

    def SKY_remove(self, image_rgb, gamma=None):
        """Make the sky region of an RGB image transparent.

        Args:
            image_rgb: Image array suitable for ``cv2.COLOR_RGB2RGBA``
                (assumed 3-channel RGB uint8; confirm with callers).
            gamma: When truthy, used as the Gaussian sigma for a 15x15 blur
                that feathers the sky/non-sky boundary. ``None`` or ``0``
                disables blurring.

        Returns:
            An RGBA copy whose alpha is 0 on sky and 255 elsewhere (blended
            along the edge when ``gamma`` is set). NOTE: when no sky is
            detected, ``image_rgb`` itself is returned unchanged, without an
            alpha channel.
        """
        mask = self._SKY_mask(image_rgb)
        if mask is None:
            return image_rgb
        binary_mask = np.uint8(mask) * 255  # 255 on sky, 0 elsewhere

        if gamma:
            # For uint8, ~x == 255 - x: blur the inverted mask, then invert back.
            binary_mask = ~cv2.GaussianBlur(~binary_mask, (15, 15), gamma)

        image_rgba = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA)
        image_rgba[:, :, 3] = ~binary_mask  # alpha 0 (transparent) where sky
        return image_rgba

    def __del__(self):
        """Best-effort release of the model and cached CUDA memory."""
        self.model = None
        self.processor = None
        try:
            import gc

            # Collect first so tensors held in reference cycles are freed;
            # only then can empty_cache hand their blocks back to the driver.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except (AttributeError, ImportError, RuntimeError, TypeError):
            # __del__ may run during interpreter shutdown, when module globals
            # (e.g. `torch`) are already gone. Cleanup is best-effort only.
            pass

if __name__ == "__main__":
    # Library module: importing SKY is the intended use; running the file does nothing.
    ...