File size: 8,394 Bytes
7734c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
from typing import Optional

from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures import Pointclouds
from pytorch3d.transforms import quaternion_to_matrix

from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform
from sam3d_objects.utils.visualization.plotly.plot_scene import plot_tdfy_scene
from sam3d_objects.utils.visualization.image_mesh import (
    mesh_from_pointmap,
    create_textured_mesh,
)

from sam3d_objects.utils.visualization.plotly.plot_scene import NO_BACKGROUND, default_axisargs
from sam3d_objects.utils.visualization.plotly.save_scene import make_video as make_scene_video
import seaborn as sns
import copy


class SceneVisualizer:
    """Static helpers that assemble PyTorch3D structures and plot them.

    Every method is a ``staticmethod``; the class is used purely as a
    namespace. Plotting is delegated to ``plot_tdfy_scene``.
    """

    # Wrapped in ``staticmethod`` so that access through an instance does not
    # bind ``self`` as the first positional argument.
    make_video_from_fig = staticmethod(make_scene_video)

    @staticmethod
    def plot_scene(
        points_local: torch.Tensor,
        instance_quaternions_l2c: torch.Tensor,
        instance_positions_l2c: torch.Tensor,
        instance_scales_l2c: torch.Tensor,
        pointmap: Optional[torch.Tensor] = None,
        image: Optional[torch.Tensor] = None,
        title: str = "Tdfy Scene",
        height: int = 1000,
        show_pointmap_as_mesh: bool = True,
        clip_pointmap_colors_for_vis: bool = False,
        filter_pointmap_edges: bool = True,
    ):
        """Plot one object point cloud, a default camera and an optional pointmap.

        Args:
            points_local: object points in the local frame; a leading batch
                dimension is added here before transforming.
            instance_quaternions_l2c: local-to-camera rotation quaternion(s).
            instance_positions_l2c: local-to-camera translation(s).
            instance_scales_l2c: local-to-camera scale(s).
            pointmap: optional pointmap; when ``None`` no pointmap is drawn.
            image: optional colors for the pointmap.
            title: key of the single subplot passed to ``plot_tdfy_scene``.
            height: forwarded to ``plot_tdfy_scene``.
            show_pointmap_as_mesh: draw the pointmap as a mesh (``True``) or
                as a point cloud (``False``).
            clip_pointmap_colors_for_vis: clamp mesh colors, see
                ``pointmap_to_mesh``.
            filter_pointmap_edges: forwarded to the mesh construction.

        Returns:
            Whatever ``plot_tdfy_scene`` returns (presumably a plotly figure).
        """
        cam = SceneVisualizer.camera()

        object_points = SceneVisualizer.object_pointcloud(
            points_local=points_local.unsqueeze(0),
            quat_l2c=instance_quaternions_l2c,
            trans_l2c=instance_positions_l2c,
            scale_l2c=instance_scales_l2c,
        )

        pointmap_struct_dict = SceneVisualizer._create_pointmap_structure(
            pointmap=pointmap,
            image=image,
            show_pointmap_as_mesh=show_pointmap_as_mesh,
            clip_pointmap_colors_for_vis=clip_pointmap_colors_for_vis,
            filter_pointmap_edges=filter_pointmap_edges,
        )
        return plot_tdfy_scene(
            {
                title: {
                    "camera": cam,
                    "object_points": object_points,
                    **pointmap_struct_dict,
                }
            },
            height=height,
        )

    @staticmethod
    def plot_multi_objects(
        pose_targets,
        mask_names=None,
        pointmap=None,
        pointmap_colors=None,
        mask_colors=None,
        plot_tdfy_kwargs=None,
        title="Tdfy Scene",
    ):
        """Plot several posed objects, a default camera and an optional pointmap.

        Args:
            pose_targets: indexable collection; entry ``i`` must provide the
                keys ``"xyz_local"``, ``"rotation"``, ``"translation"`` and
                ``"scale"``.
            mask_names: one name per object, used as the trace key. Entries
                that are ``None`` are skipped. Defaults to ``"0"``, ``"1"``, ...
                for every entry of ``pose_targets``.
            pointmap: optional pointmap tensor; NaNs are replaced by 0 on a
                copy, the caller's tensor is left untouched.
            pointmap_colors: optional colors for the pointmap.
            mask_colors: one color per entry of ``mask_names``. Defaults to a
                seaborn ``"husl"`` palette of matching length.
            plot_tdfy_kwargs: extra keyword arguments for ``plot_tdfy_scene``.
                Defaults to a copy of ``NO_BACKGROUND``; ``height`` and
                ``width`` default to 1000 when absent. The passed dict is not
                modified.
            title: key of the single subplot passed to ``plot_tdfy_scene``.

        Returns:
            Whatever ``plot_tdfy_scene`` returns (presumably a plotly figure).
        """
        # Names must be resolved first: the default palette size depends on them.
        if mask_names is None:
            mask_names = [str(i) for i in range(len(pose_targets))]
        if mask_colors is None:
            mask_colors = sns.color_palette("husl", len(mask_names))

        cam = SceneVisualizer.camera()
        objects = {}
        for i, mask_name in enumerate(mask_names):
            if mask_name is None:
                continue

            target = pose_targets[i]
            objects[mask_name] = SceneVisualizer.object_pointcloud(
                points_local=target["xyz_local"].unsqueeze(0),
                quat_l2c=target["rotation"],
                trans_l2c=target["translation"],
                scale_l2c=target["scale"],
                colors=mask_colors[i],
            )

        pointmap_dict = {}
        if pointmap is not None:
            # Out-of-place so the caller's pointmap is not silently modified.
            pointmap = pointmap.masked_fill(pointmap.isnan(), 0)
            pointmap_dict = SceneVisualizer._create_pointmap_structure(
                pointmap=pointmap,
                image=pointmap_colors,
                filter_pointmap_edges=True,
            )

        if plot_tdfy_kwargs is None:
            plot_tdfy_kwargs = copy.deepcopy(NO_BACKGROUND)
        else:
            # Shallow copy: the size defaults below must not leak to the caller.
            plot_tdfy_kwargs = dict(plot_tdfy_kwargs)
        plot_tdfy_kwargs.setdefault("height", 1000)
        plot_tdfy_kwargs.setdefault("width", 1000)

        fig = plot_tdfy_scene(
            {
                title: {
                    "camera": cam,
                    **objects,
                    **pointmap_dict,
                }
            },
            **plot_tdfy_kwargs,
        )
        return fig

    @staticmethod
    def _create_pointmap_structure(
        pointmap: Optional[torch.Tensor],
        image: Optional[torch.Tensor],
        show_pointmap_as_mesh: bool = True,
        clip_pointmap_colors_for_vis: bool = True,
        filter_pointmap_edges: bool = True,
    ):
        """Build the plot entry for a pointmap.

        Args:
            pointmap: pointmap tensor, or ``None`` for "nothing to draw".
            image: optional colors for the pointmap.
            show_pointmap_as_mesh: choose the mesh or point-cloud representation.
            clip_pointmap_colors_for_vis: forwarded to ``pointmap_to_mesh``.
            filter_pointmap_edges: forwarded to ``pointmap_to_mesh``.

        Returns:
            ``{}`` when ``pointmap`` is ``None``; otherwise a one-entry dict
            keyed ``"Pointmap mesh"`` or ``"Pointmap pointcloud"``.
        """
        if pointmap is None:
            return {}

        if show_pointmap_as_mesh:
            # A missing image is replaced by zeros inside ``pointmap_to_mesh``.
            struct = SceneVisualizer.pointmap_to_mesh(
                pointmap=pointmap,
                image=image,
                clip_pointmap_colors_for_vis=clip_pointmap_colors_for_vis,
                filter_edges=filter_pointmap_edges,
            )
            return {"Pointmap mesh": struct}

        struct = SceneVisualizer.pointmap_to_pointcloud(
            pointmap=pointmap, image=image
        )
        return {"Pointmap pointcloud": struct}

    @staticmethod
    def camera(
        quaternion: Optional[torch.Tensor] = None,
        translation: Optional[torch.Tensor] = None,
    ):
        """Create a ``PerspectiveCameras`` from a rotation quaternion and translation.

        Args:
            quaternion: rotation quaternion(s), converted with
                ``quaternion_to_matrix``. Defaults to ``[1, 0, 0, 0]`` with a
                leading batch dimension of 1.
            translation: camera translation(s). Defaults to ``[0, 0, 0]`` with
                a leading batch dimension of 1.

        Returns:
            A ``PerspectiveCameras`` built with ``R`` and ``T``.
        """
        if quaternion is None:
            quaternion = torch.tensor([1, 0, 0, 0]).unsqueeze(0)
        if translation is None:
            translation = torch.tensor([0, 0, 0]).unsqueeze(0)
        R = quaternion_to_matrix(quaternion)
        return PerspectiveCameras(R=R, T=translation)

    @staticmethod
    def object_pointcloud(
        points_local: torch.Tensor,
        quat_l2c: torch.Tensor,
        trans_l2c: torch.Tensor,
        scale_l2c: torch.Tensor,
        colors: Optional[torch.Tensor] = None,
    ):
        """Transform local points with a scale/rotation/translation and wrap them.

        Args:
            points_local: point coordinates in the local frame; callers in
                this class pass a batched tensor (``unsqueeze(0)`` applied).
            quat_l2c: local-to-camera rotation quaternion(s).
            trans_l2c: local-to-camera translation(s).
            scale_l2c: local-to-camera scale(s).
            colors: per-point colors used as point-cloud features. ``None``
                means uniform red; a plain tuple is treated as one RGB color
                broadcast to every point; anything else is passed through
                unchanged.

        Returns:
            A ``Pointclouds`` holding the transformed points and their colors.
        """
        if colors is None:
            colors = (1.0, 0.0, 0.0)
        if isinstance(colors, tuple):
            # Broadcast the single RGB triple over all points.
            colors = torch.ones_like(points_local) * torch.tensor(
                colors, device=points_local.device
            )

        R_l2c = quaternion_to_matrix(quat_l2c)
        l2c_transform = compose_transform(
            scale=scale_l2c, rotation=R_l2c, translation=trans_l2c
        )
        points_world = l2c_transform.transform_points(points_local)
        return Pointclouds(points=points_world, features=colors)

    @staticmethod
    def pointmap_to_pointcloud(
        pointmap: torch.Tensor, image: Optional[torch.Tensor]
    ):
        """Flatten a pointmap (and optional image) into a single-batch point cloud.

        Args:
            pointmap: (H, W, 3) tensor of point coordinates.
            image: (H, W, 3) or (3, H, W) tensor of colors, or ``None`` for a
                point cloud without features.

        Returns:
            A ``Pointclouds`` with one cloud of ``H * W`` points.
        """
        if image is not None:
            # A leading dimension of 3 is taken to mean channels-first.
            # NOTE(review): ambiguous for an (H, W, 3) image with H == 3.
            if image.shape[0] == 3:
                image = image.permute(1, 2, 0)
            image = image.reshape(-1, 3).unsqueeze(0).float()

        return Pointclouds(
            points=pointmap.reshape(-1, 3).unsqueeze(0),
            features=image,
        )

    @staticmethod
    def pointmap_to_mesh(
        pointmap: torch.Tensor,
        image: Optional[torch.Tensor],
        clip_pointmap_colors_for_vis: bool = True,
        filter_edges: bool = True,
        clamp_eps: float = 1 / 254,
    ):
        """Turn a pointmap and its colors into a textured mesh.

        Args:
            pointmap: (H, W, 3) tensor of point coordinates.
            image: (H, W, 3) or (3, H, W) tensor of colors; ``None`` yields an
                all-zero image shaped like ``pointmap``.
            clip_pointmap_colors_for_vis: clamp colors into
                ``[clamp_eps, 1 - clamp_eps]`` before meshing.
            filter_edges: forwarded to ``mesh_from_pointmap``.
            clamp_eps: margin used by the color clamp.

        Returns:
            The result of ``create_textured_mesh`` on the mesh vertices, faces
            and vertex colors.
        """
        # Build the default image while ``pointmap`` is still a tensor;
        # ``torch.zeros_like`` cannot be applied to the NumPy array made below.
        if image is None:
            image = torch.zeros_like(pointmap)
        elif image.shape[0] == 3:
            # A leading dimension of 3 is taken to mean channels-first.
            image = image.permute(1, 2, 0)

        if clip_pointmap_colors_for_vis:
            # Empirically needed to avoid underflow in the visualization; the
            # upper bound guards against overflow. Visualization only.
            image = image.clamp(clamp_eps, 1 - clamp_eps)

        mesh = mesh_from_pointmap(
            pointmap.cpu().numpy(), image.cpu().numpy(), filter_edges=filter_edges
        )
        vertices = torch.from_numpy(mesh.vertices)
        faces = torch.from_numpy(mesh.faces)
        vertex_colors = torch.from_numpy(mesh.vertex_colors)
        return create_textured_mesh(vertices, faces, vertex_colors)