|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from XMem2.util.palette import davis_palette
|
|
|
from XMem2.util.range_transform import im_normalization
|
|
|
|
|
|
|
|
|
def image_to_torch(frame: np.ndarray, device='cuda'):
|
|
|
|
|
|
frame = frame.transpose(2, 0, 1)
|
|
|
frame = torch.from_numpy(frame).float().to(device) / 255
|
|
|
frame_norm = im_normalization(frame)
|
|
|
return frame_norm, frame
|
|
|
|
|
|
|
|
|
def torch_prob_to_numpy_mask(prob):
|
|
|
mask = torch.argmax(prob, dim=0)
|
|
|
mask = mask.cpu().numpy().astype(np.uint8)
|
|
|
return mask
|
|
|
|
|
|
|
|
|
def index_numpy_to_one_hot_torch(mask, num_classes):
|
|
|
mask = torch.from_numpy(mask).long()
|
|
|
return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float()
|
|
|
|
|
|
|
|
|
"""
|
|
|
Some constants fro visualization
|
|
|
"""
|
|
|
color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
|
|
|
|
|
|
color_map_np = (color_map_np.astype(np.float32) * 1.5).clip(0, 255).astype(np.uint8)
|
|
|
color_map = color_map_np.tolist()
|
|
|
if torch.cuda.is_available():
|
|
|
color_map_torch = torch.from_numpy(color_map_np).cuda() / 255
|
|
|
|
|
|
grayscale_weights = np.array([[0.3, 0.59, 0.11]]).astype(np.float32)
|
|
|
if torch.cuda.is_available():
|
|
|
grayscale_weights_torch = torch.from_numpy(grayscale_weights).cuda().unsqueeze(0)
|
|
|
|
|
|
|
|
|
def get_visualization(mode, image, mask, layer, target_object):
|
|
|
if mode == 'fade':
|
|
|
return overlay_davis(image, mask, fade=True)
|
|
|
elif mode == 'davis':
|
|
|
return overlay_davis(image, mask)
|
|
|
elif mode == 'light':
|
|
|
return overlay_davis(image, mask, 0.9)
|
|
|
elif mode == 'popup':
|
|
|
return overlay_popup(image, mask, target_object)
|
|
|
elif mode == 'layered':
|
|
|
if layer is None:
|
|
|
print('Layer file not given. Defaulting to DAVIS.')
|
|
|
return overlay_davis(image, mask)
|
|
|
else:
|
|
|
return overlay_layer(image, mask, layer, target_object)
|
|
|
else:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
def get_visualization_torch(mode, image, prob, layer, target_object):
|
|
|
if mode == 'fade':
|
|
|
return overlay_davis_torch(image, prob, fade=True)
|
|
|
elif mode == 'davis':
|
|
|
return overlay_davis_torch(image, prob)
|
|
|
elif mode == 'light':
|
|
|
return overlay_davis_torch(image, prob, 0.9)
|
|
|
elif mode == 'popup':
|
|
|
return overlay_popup_torch(image, prob, target_object)
|
|
|
elif mode == 'layered':
|
|
|
if layer is None:
|
|
|
print('Layer file not given. Defaulting to DAVIS.')
|
|
|
return overlay_davis_torch(image, prob)
|
|
|
else:
|
|
|
return overlay_layer_torch(image, prob, layer, target_object)
|
|
|
else:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
def overlay_davis(image, mask, alpha=0.5, fade=False):
|
|
|
"""Overlay segmentation on top of RGB image. from davis official"""
|
|
|
im_overlay = image.copy()
|
|
|
|
|
|
colored_mask = color_map_np[mask]
|
|
|
foreground = image * alpha + (1 - alpha) * colored_mask
|
|
|
binary_mask = mask > 0
|
|
|
|
|
|
im_overlay[binary_mask] = foreground[binary_mask]
|
|
|
if fade:
|
|
|
im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
|
|
|
return im_overlay.astype(image.dtype)
|
|
|
|
|
|
|
|
|
def overlay_popup(image, mask, target_object):
|
|
|
|
|
|
im_overlay = image.copy()
|
|
|
|
|
|
binary_mask = ~(np.isin(mask, target_object))
|
|
|
colored_region = (im_overlay[binary_mask] * grayscale_weights).sum(-1, keepdims=-1)
|
|
|
im_overlay[binary_mask] = colored_region
|
|
|
return im_overlay.astype(image.dtype)
|
|
|
|
|
|
|
|
|
def overlay_layer(image, mask, layer, target_object):
|
|
|
|
|
|
|
|
|
|
|
|
obj_mask = (np.isin(mask, target_object)).astype(np.float32)
|
|
|
layer_alpha = layer[:, :, 3].astype(np.float32) / 255
|
|
|
layer_rgb = layer[:, :, :3]
|
|
|
background_alpha = np.maximum(obj_mask, layer_alpha)[:, :, np.newaxis]
|
|
|
obj_mask = obj_mask[:, :, np.newaxis]
|
|
|
im_overlay = (
|
|
|
image * (1 - background_alpha) + layer_rgb * (1 - obj_mask) + image * obj_mask
|
|
|
).clip(0, 255)
|
|
|
return im_overlay.astype(image.dtype)
|
|
|
|
|
|
|
|
|
def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
|
|
|
"""Overlay segmentation on top of RGB image. from davis official"""
|
|
|
|
|
|
image = image.permute(1, 2, 0)
|
|
|
im_overlay = image
|
|
|
mask = torch.argmax(mask, dim=0)
|
|
|
|
|
|
colored_mask = color_map_torch[mask]
|
|
|
foreground = image * alpha + (1 - alpha) * colored_mask
|
|
|
binary_mask = mask > 0
|
|
|
|
|
|
im_overlay[binary_mask] = foreground[binary_mask]
|
|
|
if fade:
|
|
|
im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
|
|
|
|
|
|
im_overlay = (im_overlay * 255).cpu().numpy()
|
|
|
im_overlay = im_overlay.astype(np.uint8)
|
|
|
|
|
|
return im_overlay
|
|
|
|
|
|
|
|
|
def overlay_popup_torch(image, mask, target_object):
|
|
|
|
|
|
image = image.permute(1, 2, 0)
|
|
|
|
|
|
if len(target_object) == 0:
|
|
|
obj_mask = torch.zeros_like(mask[0]).unsqueeze(2)
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
obj_mask = mask[np.array(target_object, dtype=np.int32)].sum(0).unsqueeze(2)
|
|
|
gray_image = (image * grayscale_weights_torch).sum(-1, keepdim=True)
|
|
|
im_overlay = obj_mask * image + (1 - obj_mask) * gray_image
|
|
|
|
|
|
im_overlay = (im_overlay * 255).cpu().numpy()
|
|
|
im_overlay = im_overlay.astype(np.uint8)
|
|
|
|
|
|
return im_overlay
|
|
|
|
|
|
|
|
|
def overlay_layer_torch(image, mask, layer, target_object):
|
|
|
|
|
|
|
|
|
|
|
|
image = image.permute(1, 2, 0)
|
|
|
|
|
|
if len(target_object) == 0:
|
|
|
obj_mask = torch.zeros_like(mask[0])
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
obj_mask = mask[np.array(target_object, dtype=np.int32)].sum(0)
|
|
|
layer_alpha = layer[:, :, 3]
|
|
|
layer_rgb = layer[:, :, :3]
|
|
|
background_alpha = torch.maximum(obj_mask, layer_alpha).unsqueeze(2)
|
|
|
obj_mask = obj_mask.unsqueeze(2)
|
|
|
im_overlay = (
|
|
|
image * (1 - background_alpha) + layer_rgb * (1 - obj_mask) + image * obj_mask
|
|
|
).clip(0, 1)
|
|
|
|
|
|
im_overlay = (im_overlay * 255).cpu().numpy()
|
|
|
im_overlay = im_overlay.astype(np.uint8)
|
|
|
|
|
|
return im_overlay
|
|
|
|