# NOTE: the lines "Spaces: / Sleeping / Sleeping" were Hugging Face Spaces
# page chrome captured with this file; they are not part of the module.
| from typing import List, Optional, Tuple, Union | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms.functional as TF | |
| from PIL import Image | |
| from torch import Tensor, nn | |
| import torch | |
| from skimage.filters import threshold_otsu | |
| from s_multimae.da.base_da import BaseDataAugmentation | |
| from s_multimae.model_pl import ModelPL | |
| from s_multimae.visualizer import apply_vis_to_image | |
| from .base_model import BaseRGBDModel | |
| from .app_utils import get_size, normalize | |
| from .depth_model import BaseDepthModel | |
# Environment
# Inference-only application: disable autograd globally at import time to
# avoid building computation graphs (saves memory and compute).
torch.set_grad_enabled(False)

# Imported after autograd is disabled; selects the compute device for the app.
from .device import device

print(f"device: {device}")
def post_processing_depth(depth: np.ndarray) -> np.ndarray:
    """Render a raw depth map as a color image for display.

    The depth values are min-max normalized, scaled to the [0, 255] uint8
    range, and color-mapped with OpenCV's OCEAN palette.
    """
    scaled = normalize(depth) * 255
    gray = scaled.astype(np.uint8)
    return cv2.applyColorMap(gray, cv2.COLORMAP_OCEAN)
def base_inference(
    depth_model: BaseDepthModel,
    sod_model: BaseRGBDModel,
    da: BaseDataAugmentation,
    raw_image: Union[Image.Image, np.ndarray],
    raw_depth: Optional[Union[Image.Image, np.ndarray]] = None,
    color: Optional[np.ndarray] = None,
    num_sets_of_salient_objects: int = 1,
) -> Tuple[np.ndarray, List[np.ndarray], List[np.ndarray]]:
    """Run RGB-D salient-object detection on one image/depth pair.

    If ``raw_depth`` is not provided, ``depth_model`` predicts a depth map
    from ``raw_image`` first.

    Args:
        depth_model: model used to estimate depth when ``raw_depth`` is None.
        sod_model: RGB-D salient-object-detection model.
        da: preprocessing pipeline; called with ``is_transform=False`` so no
            random augmentation is applied at inference time.
        raw_image: input RGB image.
        raw_depth: optional depth image aligned with ``raw_image``.
        color: optional overlay color forwarded to ``apply_vis_to_image``.
        num_sets_of_salient_objects: number of saliency-map sets to request
            from ``sod_model``.

    Returns:
        A tuple ``(depth_vis, sods, saliency_maps)`` where ``depth_vis`` is
        the color-mapped depth image resized to the original image size,
        ``sods`` is a list of copies of ``raw_image`` with the binarized
        saliency mask overlaid, and ``saliency_maps`` is the list of raw
        saliency maps rescaled to [0, 1].
    """
    origin_size = get_size(raw_image)

    # Predict depth only when no depth input was supplied.
    image = TF.to_tensor(raw_image)
    origin_shape = image.shape
    if raw_depth is None:
        depth: Tensor = depth_model.forward(image)
    else:
        depth = TF.to_tensor(raw_depth)

    # Preprocessing (deterministic: is_transform=False disables augmentation).
    image, depth = da.forward(
        raw_image, depth.cpu().detach().squeeze(0).numpy(), is_transform=False
    )

    # Inference
    sms = sod_model.inference(image, depth, origin_shape, num_sets_of_salient_objects)

    # Postprocessing: binarize each saliency map with Otsu's threshold and
    # overlay the mask on the original image.
    sods = []
    for sm in sms:
        binary_mask = np.array(sm)  # copy, so the raw map in `sms` stays untouched
        t = threshold_otsu(binary_mask)
        binary_mask[binary_mask < t] = 0.0
        binary_mask[binary_mask >= t] = 1.0
        sod = apply_vis_to_image(np.array(raw_image), binary_mask, color)
        sods.append(sod)

    # Depth visualization at the original resolution.
    depth = depth.permute(1, 2, 0).detach().cpu().numpy()
    depth = cv2.resize(depth, origin_size)
    depth = post_processing_depth(depth)

    # NOTE(review): dividing by 255 assumes sod_model.inference returns maps
    # in the [0, 255] range — confirm against the model implementation.
    return depth, sods, [e / 255.0 for e in sms]
def transform_images(inputs: List[Image.Image], transform: nn.Module) -> Tensor:
    """Apply ``transform`` to each image and stack the results into a batch.

    Args:
        inputs: images to transform; must be non-empty (``torch.stack`` raises
            on an empty list, matching the previous ``torch.cat`` behavior).
        transform: module/callable mapping one image to a (C, H, W) tensor.

    Returns:
        A tensor of shape ``(len(inputs), C, H, W)``.
    """
    # torch.stack adds the batch dimension itself, which covers both the
    # single-image and multi-image cases previously special-cased with
    # unsqueeze + cat; `img` also avoids shadowing the builtin `input`.
    return torch.stack([transform(img) for img in inputs])