backup / scripts /SKY.py
killbill007's picture
Upload 754 files
93871a1 verified
raw
history blame
1.83 kB
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
import numpy as np
import cv2
import os
import torch
from skimage.metrics import structural_similarity as ssim
class SKY:
    """Sky segmentation and removal built on a Mask2Former semantic model.

    The model predicts one class id per pixel; pixels whose id equals
    ``SKY_CLASS_ID`` are treated as sky and made transparent by
    :meth:`SKY_remove`.
    """

    # Class id compared against the predicted semantic map. Presumably the
    # "sky" label of the ADE20K label set used by the default checkpoint;
    # verify against the checkpoint's ``id2label`` before swapping models.
    SKY_CLASS_ID = 2

    # Kernel size used when feathering the alpha channel in ``SKY_remove``.
    BLUR_KERNEL = (15, 15)

    def __init__(self, model_path="models/mask2former-swin-large-ade-semantic",
                 device=None):
        """Load the image processor and the segmentation model.

        Args:
            model_path: Directory or hub id passed to ``from_pretrained`` for
                both the processor and the model.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            model_path
        ).to(self.device)

    def _SKY_mask(self, image_rgb: np.ndarray) -> "np.ndarray | None":
        """Return a boolean sky mask for *image_rgb*, or ``None`` if no sky.

        Args:
            image_rgb: Image array; its first two dimensions are used as the
                (height, width) target size of the predicted map.

        Returns:
            Boolean array with ``True`` where the predicted class equals
            ``SKY_CLASS_ID``, or ``None`` when no pixel matches.
        """
        inputs = self.processor(images=image_rgb, return_tensors="pt").to(self.device)
        # Inference only: disabling autograd avoids keeping activations alive
        # for a backward pass that never happens.
        with torch.no_grad():
            outputs = self.model(**inputs)
        del inputs

        semantic_map = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image_rgb.shape[:2]]
        )[0]
        class_ids = semantic_map.cpu().numpy()
        del semantic_map, outputs

        sky_mask = class_ids == self.SKY_CLASS_ID
        if not sky_mask.any():
            return None
        return sky_mask

    def SKY_remove(self, image_rgb: np.ndarray, gamma=None) -> np.ndarray:
        """Make the sky region of *image_rgb* transparent.

        Args:
            image_rgb: RGB image array accepted by ``cv2.cvtColor`` with
                ``COLOR_RGB2RGBA``.
            gamma: Optional Gaussian sigma used to feather the alpha edge.
                Falsy values (``None`` or ``0``) leave a hard edge.

        Returns:
            The input image unchanged (still 3 channels) when no sky is
            detected; otherwise an RGBA copy whose alpha channel is 0 on sky
            pixels and 255 elsewhere, blurred when *gamma* is given.
        """
        sky_mask = self._SKY_mask(image_rgb)
        if sky_mask is None:
            return image_rgb

        # Alpha is the inverse of the sky mask: opaque where there is no sky.
        alpha = np.where(sky_mask, 0, 255).astype(np.uint8)
        if gamma:
            alpha = cv2.GaussianBlur(alpha, self.BLUR_KERNEL, gamma)

        image_rgba = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA)
        image_rgba[:, :, 3] = alpha
        return image_rgba

    def __del__(self):
        """Drop the model and processor and release cached GPU memory."""
        self.__dict__.pop("model", None)
        self.__dict__.pop("processor", None)
        try:
            import gc

            # Collect first so freed tensors are returned to the caching
            # allocator before its cache is emptied.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except (ImportError, AttributeError, TypeError):
            # During interpreter shutdown modules may already be torn down;
            # there is nothing left worth releasing at that point.
            pass
# Script entry point: running this file directly currently does nothing.
if __name__ == "__main__":
    pass