File size: 2,478 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
"""Utility functions for visualization.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import numpy as np
import torch
from matplotlib import pyplot as plt
# Presumably an upper clamp for metric depth values, in meters (inferred from
# the name only). It is not referenced in the visible part of this file, so
# confirm its use at callers before relying on it.
METRIC_DEPTH_MAX_CLAMP_METER = 50.0
def colorize_depth(depth: torch.Tensor, val_max: float = 10.0) -> torch.Tensor:
    """Colorize a depth or disparity map with the "turbo" colormap.

    Values are displayed over the fixed range ``[0, val_max]``.

    Args:
        depth: Tensor whose third-from-last axis indexes layers, e.g. ``(B, L, H, W)``
            or ``(L, H, W)`` (assumed layout; confirm against callers).
        val_max: Value mapped to the high end of the colormap.

    Returns:
        The uint8 color image produced by ``colorize_scalar_map``. With more than
        one layer, each layer is colorized on its own and the results are
        concatenated along the last axis, i.e. placed side by side horizontally.
    """
    num_layers = depth.shape[-3]

    # Single-layer map: drop the layer axis and colorize it directly.
    if num_layers == 1:
        return colorize_scalar_map(
            depth.squeeze(-3), val_min=0.0, val_max=val_max, color_map="turbo"
        )

    # Multi-layer map: colorize every layer separately, then tile horizontally.
    layer_images = [
        colorize_scalar_map(
            depth[..., layer, :, :], val_min=0.0, val_max=val_max, color_map="turbo"
        )
        for layer in range(num_layers)
    ]
    return torch.cat(layer_images, dim=-1)
def colorize_alpha(alpha: torch.Tensor) -> torch.Tensor:
    """Colorize an alpha map with the "coolwarm" colormap over the range [0, 1].

    Args:
        alpha: Alpha tensor. Its third-from-last axis is squeezed away before
            colorization; this is a no-op unless that axis has size 1.

    Returns:
        The uint8 color image produced by ``colorize_scalar_map``.
    """
    squeezed_alpha = alpha.squeeze(-3)
    return colorize_scalar_map(
        squeezed_alpha, val_min=0.0, val_max=1.0, color_map="coolwarm"
    )
def colorize_scalar_map(
    scalar_map: torch.Tensor,
    val_min: float = 0.0,
    val_max: float = 1.0,
    color_map: str = "jet",
) -> torch.Tensor:
    """Colorize a scalar map with a matplotlib colormap.

    Values are linearly rescaled from ``[val_min, val_max]`` to ``[0, 1]``,
    clipped to that range, and looked up in ``color_map``. The colormap's alpha
    channel is dropped, and the resulting RGB axis is placed in front of the two
    trailing axes (channels-first).

    Args:
        scalar_map: Scalar map with 2, 3 or 4 dimensions, e.g. ``(H, W)`` or
            ``(B, H, W)``. The last two axes are treated as spatial axes.
        val_min: Value mapped to the low end of the colormap.
        val_max: Value mapped to the high end of the colormap. It must differ
            from ``val_min``; equal values divide by zero during rescaling.
        color_map: Colormap name, passed to ``matplotlib.pyplot.get_cmap``.

    Returns:
        A ``torch.uint8`` tensor with values in ``[0, 255]`` and an RGB axis of
        size 3 inserted before the last two axes:
        ``(..., H, W) -> (..., 3, H, W)``, e.g. ``BHW -> B3HW``.

    Raises:
        ValueError: If ``scalar_map`` does not have 2, 3 or 4 dimensions.
    """
    if scalar_map.ndim not in (2, 3, 4):
        raise ValueError("Only scalar maps of 2 or 3 or 4 dimensions supported.")

    cmap = plt.get_cmap(color_map)

    scalar_map_np = scalar_map.detach().cpu().float().numpy()
    scalar_map_np = (scalar_map_np - val_min) / (val_max - val_min)
    scalar_map_np = np.clip(scalar_map_np, a_min=0.0, a_max=1.0)

    # The colormap appends a trailing RGBA axis; keep only RGB.
    color_map_np = cmap(scalar_map_np)[..., :3]
    # NOTE: the float -> uint8 cast truncates rather than rounds.
    tensor = torch.as_tensor(color_map_np * 255.0, dtype=torch.uint8)

    # Move the trailing RGB axis in front of the two spatial axes. This is
    # equivalent to permute(2, 0, 1) / (0, 3, 1, 2) / (0, 1, 4, 2, 3) for 3-, 4-
    # and 5-D tensors, and needs no unreachable fallback branch.
    return tensor.movedim(-1, -3)
|