| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ Helper functions for visualizing outputs """ |
|
|
from dataclasses import dataclass
from typing import Literal, Optional

import matplotlib
import torch
from jaxtyping import Bool, Float
from torch import Tensor

import colors
|
|
| |
# Names accepted wherever a colormap is selected. "default" is resolved to "turbo" and
# "gray" replicates the channel; the remaining names are looked up in matplotlib.colormaps.
Colormaps = Literal["default", "turbo", "viridis", "magma", "inferno", "cividis", "gray"]
|
|
@dataclass(frozen=True)
class ColormapOptions:
    """Immutable settings that control how single-channel images are colorized."""

    # Which colormap to use; "default" is resolved to a concrete colormap when it is applied.
    colormap: Colormaps = "default"
    # When True, shift/scale the input so it spans (approximately) [0, 1] before mapping.
    normalize: bool = False
    # Lower end of the output colormap range that [0, 1] inputs are mapped onto.
    colormap_min: float = 0
    # Upper end of the output colormap range that [0, 1] inputs are mapped onto.
    colormap_max: float = 1
    # When True, flip the colormap (a value v is colorized as 1 - v).
    invert: bool = False
|
|
|
|
def apply_colormap(
    image: Float[Tensor, "*bs channels"],
    colormap_options: ColormapOptions = ColormapOptions(),
    eps: float = 1e-9,
):
    """
    Applies a colormap to a tensor image.

    Dispatch is based on the size of the last (channel) dimension and on the dtype,
    checked in this order:
      - 3 channels: treated as RGB and returned unchanged (regardless of dtype).
      - 1 channel, floating point: optionally normalized, remapped onto
        [colormap_min, colormap_max], clipped to [0, 1], optionally inverted, then
        colorized with ``apply_float_colormap``.
      - boolean dtype: colorized with ``apply_boolean_colormap``.
      - more than 3 channels: reduced to 3 channels with ``apply_pca_colormap``.

    Args:
        image: Input tensor image, channels in the last dimension.
        colormap_options: Options controlling colorization of single-channel float images.
        eps: Epsilon added to the normalization denominator for numerical stability.

    Returns:
        Tensor with the colormap applied.

    Raises:
        NotImplementedError: If the image matches none of the cases above
            (e.g. 2 channels, or a single non-float, non-bool channel).
    """
    # Already RGB: nothing to do.
    if image.shape[-1] == 3:
        return image

    # Single-channel floating point image.
    if image.shape[-1] == 1 and torch.is_floating_point(image):
        output = image
        if colormap_options.normalize:
            output = output - torch.min(output)
            output = output / (torch.max(output) + eps)
        # Map the [0, 1] input range onto [colormap_min, colormap_max].
        output = (
            output * (colormap_options.colormap_max - colormap_options.colormap_min) + colormap_options.colormap_min
        )
        output = torch.clip(output, 0, 1)
        if colormap_options.invert:
            output = 1 - output
        return apply_float_colormap(output, colormap=colormap_options.colormap)

    # Boolean masks.
    if image.dtype == torch.bool:
        return apply_boolean_colormap(image)

    # High-dimensional feature images.
    if image.shape[-1] > 3:
        return apply_pca_colormap(image)

    raise NotImplementedError(
        f"apply_colormap does not support images with {image.shape[-1]} channel(s) and dtype {image.dtype}"
    )
|
|
|
|
def apply_float_colormap(image: Float[Tensor, "*bs 1"], colormap: Colormaps = "viridis"):
    """Convert single channel to a color image.

    Args:
        image: Single channel image. Values are expected in [0, 1]; NaNs are replaced with 0.
        colormap: Colormap for image. "default" is treated as "turbo". "gray" replicates the
            channel three times. Any other name is looked up in ``matplotlib.colormaps`` and its
            ``colors`` table is indexed with the value quantized to 0..255.

    Returns:
        Tensor: Colored image with colors in [0, 1]. The last dimension is 3 for "gray"; for
        matplotlib colormaps it is the width of the colormap's ``colors`` table.

    Raises:
        AssertionError: If a non-gray colormap is requested and the quantized values fall outside 0..255.
    """
    if colormap == "default":
        colormap = "turbo"

    image = torch.nan_to_num(image, 0)
    if colormap == "gray":
        # Repeat only the trailing channel dimension so any number of batch dims is supported
        # (a fixed repeat(1, 1, 3) is only correct for 3-D inputs).
        return image.repeat(*([1] * (image.dim() - 1)), 3)
    image_long = (image * 255).long()
    image_long_min = torch.min(image_long)
    image_long_max = torch.max(image_long)
    assert image_long_min >= 0, f"the min value is {image_long_min}"
    assert image_long_max <= 255, f"the max value is {image_long_max}"
    # Look up each quantized value in the colormap's color table.
    return torch.tensor(matplotlib.colormaps[colormap].colors, device=image.device)[image_long[..., 0]]
|
|
|
|
def apply_depth_colormap(
    depth: Float[Tensor, "*bs 1"],
    accumulation: Optional[Float[Tensor, "*bs 1"]] = None,
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    colormap_options: ColormapOptions = ColormapOptions(),
):
    """Converts a depth image to color for easier analysis.

    Depth is linearly rescaled so that ``near_plane`` maps to 0 and ``far_plane`` maps to 1,
    clipped to [0, 1], and colorized with ``apply_colormap``.

    Args:
        depth: Depth image.
        accumulation: Ray accumulation used for masking vis. The colors are blended toward 1.0
            (white) as accumulation decreases; presumably accumulation lies in [0, 1] — confirm with callers.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.
        colormap_options: Options forwarded to ``apply_colormap``.

    Returns:
        Colored depth image with colors in [0, 1]
    """
    # Check for None explicitly: `near_plane or ...` would silently replace a valid
    # explicit value of 0.0 with the image minimum.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))

    # Tiny epsilon guards against division by zero when near_plane == far_plane.
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0, 1)

    colored_image = apply_colormap(depth, colormap_options=colormap_options)

    if accumulation is not None:
        # Blend toward white where accumulation is low.
        colored_image = colored_image * accumulation + (1 - accumulation)

    return colored_image
|
|
|
|
def apply_boolean_colormap(
    image: Bool[Tensor, "*bs 1"],
    true_color=colors.WHITE,
    false_color=colors.BLACK,
):
    """Converts a boolean image to color for easier analysis.

    Only the first channel of ``image`` is used as the mask.

    Args:
        image: Boolean image.
        true_color: Color to use for True. Converted with ``torch.as_tensor`` and broadcast
            across the 3 output channels.
        false_color: Color to use for False. Same handling as ``true_color``.

    Returns:
        Colored boolean image of shape ``(*bs, 3)`` in the default floating-point dtype,
        on the same device as ``image``.
    """
    # Allocate on the mask's device: a CPU buffer cannot be indexed with a CUDA mask.
    colored_image = torch.ones(image.shape[:-1] + (3,), device=image.device)
    mask = image[..., 0]
    # Move/cast the colors so the assignment does not mix devices.
    colored_image[mask, :] = torch.as_tensor(true_color, dtype=colored_image.dtype, device=image.device)
    colored_image[~mask, :] = torch.as_tensor(false_color, dtype=colored_image.dtype, device=image.device)
    return colored_image
|
|
|
|
def apply_pca_colormap(image: Float[Tensor, "*bs dim"], outlier_threshold: float = 3.0):
    """Convert feature image to 3-channel RGB via PCA. The first three principal
    components are used for the color channels, with outlier rejection per-channel.

    For each projected channel, a value is an inlier when
    ``|x - median| / MAD < outlier_threshold``. Here MAD is the median absolute deviation from
    the channel median. Each channel is shifted and scaled so that its inliers span [0, 1],
    and the result is clamped to [0, 1].

    Args:
        image: image of arbitrary vectors, with the feature dimension last.
        outlier_threshold: Inlier cutoff, in units of the per-channel median absolute deviation.

    Returns:
        Tensor: Colored image of shape ``(*bs, 3)``. If any channel has no inliers, an
        all-zero image of that shape is returned instead.
    """
    original_shape = image.shape
    # reshape (rather than view) also accepts non-contiguous inputs, e.g. permuted feature maps.
    image = image.reshape(-1, image.shape[-1])
    # NOTE: torch.pca_lowrank is randomized, so the projection may differ slightly between calls.
    _, _, v = torch.pca_lowrank(image)
    image = torch.matmul(image, v[..., :3])
    d = torch.abs(image - torch.median(image, dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    rins = image[s[:, 0] < outlier_threshold, 0]
    gins = image[s[:, 1] < outlier_threshold, 1]
    bins = image[s[:, 2] < outlier_threshold, 2]

    # No inliers in some channel (e.g. MAD == 0 makes every ratio NaN/inf): nothing to normalize against.
    if len(rins) == 0 or len(gins) == 0 or len(bins) == 0:
        return image.new_zeros(*original_shape[:-1], 3)

    image[:, 0] -= rins.min()
    image[:, 1] -= gins.min()
    image[:, 2] -= bins.min()

    image[:, 0] /= rins.max() - rins.min()
    image[:, 1] /= gins.max() - gins.min()
    image[:, 2] /= bins.max() - bins.min()

    image = torch.clamp(image, 0, 1)
    # Sanity check: values that survived clamping as NaN would quantize out of range here.
    image_long = (image * 255).long()
    image_long_min = torch.min(image_long)
    image_long_max = torch.max(image_long)
    assert image_long_min >= 0, f"the min value is {image_long_min}"
    assert image_long_max <= 255, f"the max value is {image_long_max}"
    return image.view(*original_shape[:-1], 3)
|
|