File size: 3,726 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
This file contains utility functions for visualizing image observations in the training pipeline.
These functions can be a useful debugging tool.
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils


def image_tensor_to_numpy(image):
    """
    Turns processed image tensors back into uint8 numpy arrays, e.g. so frames can
    be written to disk or video while debugging in the middle of training.

    Args:
        image (torch.Tensor): images of shape [..., C, H, W]

    Returns:
        image (np.array): converted images of shape [..., H, W, C] and type uint8
    """
    # reverse the observation processing first, then hand the result over to numpy
    unprocessed = ObsUtils.unprocess_image(image)
    as_numpy = TensorUtils.to_numpy(unprocessed)
    return as_numpy.astype(np.uint8)


def image_to_disk(image, fname):
    """
    Writes an image to disk.

    Args:
        image (np.array): image of shape [H, W, 3]
        fname (str): path to save image to
    """
    # Imported here because the module-level imports never bring PIL's Image
    # into scope; without this binding the call below raises NameError.
    from PIL import Image

    Image.fromarray(image).save(fname)


def image_tensor_to_disk(image, fname):
    """
    Saves a single image from an image tensor to disk. If the tensor carries
    leading (batch-like) dimensions, only the first element along each of them
    is kept.

    Args:
        image (torch.Tensor): image of shape [..., C, H, W]. All leading dimensions
            will be indexed out with the first element
        fname (str): path to save image to
    """
    # peel off one leading dimension at a time until only [C, H, W] remains
    while len(image.shape) > 3:
        image = image[0]
    image_to_disk(image_tensor_to_numpy(image), fname)


def visualize_image_randomizer(original_image, randomized_image, randomizer_name=None):
    """
    A function that visualizes the before and after of an image-based input randomizer.
    Each batch element gets one row: the original image followed by its N randomized
    versions. The figure is displayed with plt.show().

    Args:
        original_image: batch of original image shaped [B, H, W, 3]
        randomized_image: randomized image shaped [B, N, H, W, 3]. N is the number of randomization per input sample
        randomizer_name: (Optional) name of the randomizer, used as the figure title.
            Defaults to "Randomized" when None.

    Returns:
        None
    """

    # Unpacking all five dims also rejects inputs that are not 5-dimensional.
    B, N, _, _, _ = randomized_image.shape

    # Create a grid of subplots with B rows and N+1 columns (1 for the original image, N for the randomized images).
    # squeeze=False keeps @axes two-dimensional even when B == 1 or N == 0; by default
    # matplotlib collapses singleton dimensions, which breaks the [row, col] indexing below.
    fig, axes = plt.subplots(B, N + 1, figsize=(4 * (N + 1), 4 * B), squeeze=False)

    for i in range(B):
        # Display the original image in the first column of each row
        axes[i, 0].imshow(original_image[i])
        axes[i, 0].set_title("Original")
        axes[i, 0].axis("off")

        # Display the randomized images in the remaining columns of each row
        for j in range(N):
            axes[i, j + 1].imshow(randomized_image[i, j])
            axes[i, j + 1].axis("off")

    title = randomizer_name if randomizer_name is not None else "Randomized"
    fig.suptitle(title, fontsize=16)

    # Adjust the space between subplots for better visualization
    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    # Show the entire grid of subplots
    plt.show()


def depth_to_rgb(depth_map, depth_min=None, depth_max=None):
    """
    Convert depth map to rgb array by computing normalized depth values in [0, 1]
    and passing them through matplotlib's "hot" colormap.

    Args:
        depth_map (np.array): depth values of shape [H, W] or [H, W, 1]
        depth_min (float or None): depth value that is mapped to 0. If None, the
            minimum of @depth_map is used
        depth_max (float or None): depth value that is mapped to 1. If None, the
            maximum of @depth_map is used

    Returns:
        rgb (np.array): colorized depth map of shape [H, W, 3] and type uint8
    """
    # normalize depth map into [0, 1]
    if depth_min is None:
        depth_min = depth_map.min()
    if depth_max is None:
        depth_max = depth_map.max()

    # A constant depth map (or depth_min == depth_max) would divide by zero and
    # fill the result with NaN / inf. Fall back to a unit range in that case so the
    # normalized values stay finite.
    depth_range = depth_max - depth_min
    if depth_range == 0:
        depth_range = 1.
    depth_map = (depth_map - depth_min) / depth_range

    if len(depth_map.shape) == 3:
        assert depth_map.shape[-1] == 1, "expected a single-channel depth map"
        depth_map = depth_map[..., 0]
    assert len(depth_map.shape) == 2, "expected depth map of shape [H, W]"

    # The colormap returns RGBA floats in [0, 1]; scale to bytes and drop alpha.
    # No alpha argument is passed since the alpha channel is discarded anyway.
    return (255. * cm.hot(depth_map)).astype(np.uint8)[..., :3]