import torch import numpy as np import os.path as osp import plotly.graph_objects as go from src.data import Data, NAG, Cluster from src.transforms import GridSampling3D, SaveNodeIndex from src.utils import fast_randperm, to_trimmed from torch_scatter import scatter_mean from src.utils.color import * # TODO: To go further with ipwidgets : # - https://plotly.com/python/figurewidget-app/ # - https://ipywidgets.readthedocs.io/en/stable/ def visualize_3d( input, keys=None, figsize=1000, width=None, height=None, class_names=None, class_colors=None, stuff_classes=None, num_classes=None, hide_void_pred=False, voxel=-1, max_points=50000, point_size=3, centroid_size=None, error_color=None, centroids=False, h_edge=False, h_edge_attr=False, h_edge_width=None, v_edge=False, v_edge_width=None, gap=None, radius=None, center=None, select=None, alpha=0.1, alpha_super=None, alpha_stuff=0.2, point_symbol='circle', centroid_symbol='circle', colorscale='Agsunset', **kwargs): """3D data interactive visualization. :param input: `Data` or `NAG` object :param keys: `List(str)` or `str` By default, the following attributes will be parsed in `input` for visualization {`pos`, `rgb`, `y`, `obj`, `semantic_pred`, `obj_pred`}. Yet, if `input` contains other attributes that you want to visualize, these can be passed as `keys`. This only supports point-wise attributes stored as 1D or 2D tensors. If the tensor contains only 1 channel, the attribute will be represented with a grayscale colormap. If the tensor contains 2 or 3 channels, these will be represented as RGB, with an additional all-1 channel if need be. If the tensor contains more than 3 channels, a PCA projection to RGB will be shown. In any case, the attribute values will be rescaled with respect to their statistics before visualization, meaning that colors may not compare between two different plots :param figsize: `int` Figure dimensions will be `(figsize, figsize/2)` if `width` and `height` are not specified :param width: `int` Figure width :param height: `int` Figure height :param class_names: `List(str)` Names for point labels found in attributes `y` and `semantic_pred` :param class_colors: `List(List(int, int, int))` Colors palette for point labels found in attributes `y` and `semantic_pred` :param stuff_classes: `List(int)` Semantic labels of the classes considered as `stuff` for instance and panoptic segmentation. If `y` and `obj` are found in the point attributes, the stuff annotations will appear accordingly. Otherwise, stuff instance labeling will appear as any other object :param num_classes: `int` Number of valid classes. By convention, we assume `y ∈ [0, num_classes-1]` are VALID LABELS, while `y < 0` AND `y >= num_classes` ARE VOID LABELS :param hide_void_pred: `bool` Whether predictions on points labeled as VOID be visualized :param voxel: `float` Voxel size to subsample the point cloud to facilitate visualization :param max_points: `int` Maximum number of points displayed to facilitate visualization :param point_size: `int` or `float` Size of point markers :param centroid_size: `int` or `float` Size of superpoint markers :param error_color: `List(int, int, int)` Color used to identify mis-predicted points :param centroids: `bool` Whether superpoint centroids should be displayed :param h_edge: `bool` Whether horizontal edges should be displayed (only if `centroids=True`) :param h_edge_attr: `bool` Whether the edges should be colored by their features found in `edge_attr` (only if `h_edge=True`) :param h_edge_width: `float` Width of the horizontal edges, if `h_edge=True`. Defaults to `None`, in which case `point_size` will be used for the edge width :param v_edge: `bool` Whether vertical edges should be displayed (only if `centroids=True` and `gap` is not `None`) :param v_edge_width: `float` Width of the vertical edges, if `v_edge=True`. Defaults to `None`, in which case `point_size` will be used for the edge width :param gap: `List(float, float, float)` If `None`, the hierarchical graphs will be overlaid on the points. If not `None`, a 3D tensor indicating the offset by which the hierarchical graphs should be plotted :param radius: `float` If not `None`, only visualize a spherical sampling of the input data, centered on `center` and with size `radius`. This option is not compatible with `select` :param center: `List(float, float, float)` If `radius` is provided, only visualize a spherical sampling of the input data, centered on `center` and with size `radius`. If `None`, the center of the scene will be used :param select: `Tuple(int, Tensor)` If not `None`, will call `Data.select(*select)` or `NAG.select(*select)` on the input data (depending on its nature) and the coloring schemes will illustrate it. This option is not compatible with `radius` :param alpha: `float` Rules the whitening of selected points, nodes and edges (only if select is not `None`) :param alpha_super: `float` Rules the whitening of superpoints (only if select is not `None`). If `None`, alpha will be used as fallback :param alpha_stuff: `float` Rules the whitening of stuff points (only if the input points have `obj` and `semantic_pred` attributes, and `stuff_classes` or `num_classes` is specified). If `None`, `alpha` will be used as fallback :param point_symbol: `str` Marker symbol used for points. Must be one of `{'circle', 'circle-open', 'square', 'square-open', 'diamond', 'diamond-open', 'cross', 'x'}`. Defaults to `'circle'` :param centroid_symbol: `str` Marker symbol used for centroids. Must be one of `{'circle', 'circle-open', 'square', 'square-open', 'diamond', 'diamond-open', 'cross', 'x'}`. Defaults to `'circle'` :param colorscale: `str` Plotly colorscale used for coloring 1D continuous features. See https://plotly.com/python/builtin-colorscales for options :param kwargs :return: """ # Data attributes plotted by default if found in the input _DEFAULT_KEYS = [ 'pos', 'rgb', 'y', 'semantic_pred', 'obj', 'obj_pred', 'x', 'super_sampling', 'super_index'] # assert isinstance(input, (Data, NAG)) gap = torch.tensor(gap) if gap is not None else gap assert gap is None or gap.shape == torch.Size([3]) assert not (radius and (select is not None)), \ "Cannot use both a `radius` and `select` at once" # We work on copies of the input data, to allow modified in this # scope input = input.clone().cpu() # If the input is a simple Data object, we convert it to a NAG input = NAG([input]) if isinstance(input, Data) else input # If the last level of the NAG has super_index, we manually # construct an additional Data level and append it to the NAG if input[input.num_levels - 1].is_sub: data_last = input[input.num_levels - 1] sub = Cluster( data_last.super_index, torch.arange(data_last.num_nodes), dense=True) obj = data_last.obj.merge(data_last.super_index) \ if data_last.obj else None pos = scatter_mean(data_last.pos, data_last.super_index, dim=0) input = NAG(input.to_list() + [Data(pos=pos, sub=sub, obj=obj)]) is_nag = isinstance(input, NAG) num_levels = input.num_levels if is_nag else 1 # Make sure alpha is in [0, 1] alpha = max(0, min(alpha, 1)) alpha_super = max(0, min(alpha_super, 1)) if alpha_super else alpha alpha_stuff = max(0, min(alpha_stuff, 1)) if alpha_stuff else alpha # If `radius` is provided, we only visualize a spherical selection # of size `radius` around the `center` if radius is not None: # If no `center` provided, pick the middle of the scene if center is None: hi = input[0].pos.max(dim=0).values lo = input[0].pos.min(dim=0).values center = (hi + lo) / 2 # For Z, we center on the average Z, because the middle # value may cause empty samplings for outdoor scenes with # some very high objects and most of the interesting stuff # happening near the ground center[2] = input[0].pos[:, 2].mean() else: center = torch.as_tensor(center).cpu() center = center.view(1, -1) # Create a mask on level-0 (ie points) to be used for indexing # the NAG structure mask = torch.where( torch.linalg.norm(input[0].pos - center, dim=1) < radius)[0] # Subselect the hierarchical partition based on the level-0 mask input = input.select(0, mask) # If `select` is provided, we will call NAG.select on the input data # and illustrate the selected/discarded pattern in the figure if select is not None and is_nag: # Add an ID to the points before applying NAG.select nag_temp = input.clone() for i in range(nag_temp.num_levels): nag_temp._list[i] = SaveNodeIndex()(nag_temp[i]) # Apply the selection nag_temp = nag_temp.select(*select) # Indicate, for each node of the hierarchical graph, whether it # has been selected for i in range(num_levels): selected = torch.zeros(input[i].num_nodes, dtype=torch.bool) selected[nag_temp[i][SaveNodeIndex.DEFAULT_KEY]] = True input[i].selected = selected del nag_temp, selected elif select is not None and not is_nag: # Add an ID to the points before applying NAG.select data_temp = SaveNodeIndex()(Data(pos=input.pos.clone())) # Apply the selection data_temp = data_temp.select(select)[0] # Indicate, for each node of the hierarchical graph, whether it # has been selected selected = torch.zeros(input.num_nodes, dtype=torch.bool) selected[data_temp[SaveNodeIndex.DEFAULT_KEY]] = True input.selected = selected del data_temp, selected elif is_nag: for i in range(num_levels): input[i].selected = torch.ones( input[i].num_nodes, dtype=torch.bool) else: input.selected = torch.ones(input.num_nodes, dtype=torch.bool) # Data_0 accounts for the lowest level of hierarchy, the points # themselves data_0 = input[0] if is_nag else input # Subsample to limit the drawing time # If the level-0 cloud needs to be voxelized or sampled, a NAG # structure will be affected too. To maintain NAG consistency, we # only support 'GridSampling3D' with mode='last' and random sampling # without replacement. To keep track of the sampled points and index # the NAG accordingly, we use 'SaveNodeIndex' idx = torch.arange(data_0.num_points) # If a voxel size is specified, voxelize the level-0. We first # isolate the 'pos' and the input indices of data_0 and apply # voxelization on this. We then recover the original grid-sampled # points indices to be used with Data.select or NAG.select if voxel > 0: data_temp = SaveNodeIndex()(Data(pos=data_0.pos.clone())) data_temp = GridSampling3D(voxel, mode='last')(data_temp) idx = data_temp[SaveNodeIndex.DEFAULT_KEY] del data_temp # If the cloud is too large with respect to required 'max_points', # sample without replacement if idx.shape[0] > max_points: idx = idx[fast_randperm(idx.shape[0])[:max_points]] # If a sampling is needed, apply it to the input Data or NAG, # depending on the structure if idx.shape[0] < data_0.num_points: input = input.select(0, idx) if is_nag else input.select(idx)[0] data_0 = input[0] if is_nag else input # Round to the cm for cleaner hover info data_0.pos = (data_0.pos * 100).round() / 100 # Class colors initialization if class_colors is not None and not isinstance(class_colors[0], str): class_colors = np.asarray(class_colors) else: class_colors = None # Prepare figure width = width if width and height else figsize height = height if width and height else int(figsize / 2) margin = int(0.02 * min(width, height)) layout = go.Layout( width=width, height=height, scene=dict(aspectmode='data', ), # preserve aspect ratio margin=dict(l=margin, r=margin, b=margin, t=margin), uirevision=True) fig = go.Figure(layout=layout) # To keep track of which trace should be seen under which mode # (i.e. button), we build trace_modes. This is a list of dictionaries # indicating, for each trace (list element), which mode (dict key) # it should appear in and with which attributes (values are dict of # parameters for plotly figure updates) trace_modes = [] i_point_trace = 0 i_unselected_point_trace = 1 # Initialize `void_classes` void_classes = [num_classes] if num_classes else [] # Draw a trace for position-colored 3D point cloud mini = data_0.pos.min(dim=0).values maxi = data_0.pos.max(dim=0).values colors = (data_0.pos - mini) / (maxi - mini + 1e-6) colors = rgb_to_plotly_rgb(colors) data_0.pos_colors = colors fig.add_trace( go.Scatter3d( x=data_0.pos[data_0.selected, 0], y=data_0.pos[data_0.selected, 1], z=data_0.pos[data_0.selected, 2], mode='markers', marker=dict( symbol=point_symbol, size=point_size, color=colors[data_0.selected]), hoverinfo='x+y+z+text', hovertext=None, showlegend=False, visible=True, )) trace_modes.append({ 'Position RGB': { 'marker.color': colors[data_0.selected], 'hovertext': None}}) fig.add_trace( go.Scatter3d( x=data_0.pos[~data_0.selected, 0], y=data_0.pos[~data_0.selected, 1], z=data_0.pos[~data_0.selected, 2], mode='markers', marker=dict( symbol=point_symbol, size=point_size, color=colors[~data_0.selected], opacity=alpha), hoverinfo='x+y+z+text', hovertext=None, showlegend=False, visible=True, )) trace_modes.append({ 'Position RGB': { 'marker.color': colors[~data_0.selected], 'hovertext': None}}) # Draw a trace for RGB 3D point cloud if data_0.rgb is not None: colors = data_0.rgb colors = rgb_to_plotly_rgb(colors) data_0.rgb_colors = colors trace_modes[i_point_trace]['RGB'] = { 'marker.color': colors[data_0.selected], 'hovertext': None} trace_modes[i_unselected_point_trace]['RGB'] = { 'marker.color': colors[~data_0.selected], 'hovertext': None} # Color the points with ground truth semantic labels. If labels are # expressed as histograms, keep the most frequent one if data_0.y is not None: y = data_0.y y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy() colors = class_colors[y] if class_colors is not None \ else int_to_plotly_rgb(torch.LongTensor(y)) data_0.y_colors = colors if class_names is None: text = np.array([f"Class {i}" for i in range(y.max() + 1)]) else: text = np.array([str.title(c) for c in class_names]) text = text[y] trace_modes[i_point_trace]['Semantic'] = { 'marker.color': colors[data_0.selected], 'hovertext': text[data_0.selected]} trace_modes[i_unselected_point_trace]['Semantic'] = { 'marker.color': colors[~data_0.selected], 'hovertext': text[~data_0.selected]} # Color the points with predicted semantic labels. If labels are # expressed as histograms, keep the most frequent one if data_0.semantic_pred is not None: pred = data_0.semantic_pred pred = pred.argmax(1).numpy() if pred.dim() == 2 else pred.numpy() # If the ground truth labels are available, we use them to # identify void points in the predictions if data_0.y is not None and hide_void_pred: # Get the target label y_gt = data_0.y y_gt = y_gt.argmax(1) if y_gt.dim() == 2 else y_gt # Create a mask over the points identifying those whose # ground truth label is void is_void = np.zeros(y_gt.max() + 1, dtype='bool') for i in void_classes: if i < is_void.shape[0]: is_void[i] = True is_void = is_void[y_gt] # Set the predicted label to void if the ground truth is # void, this avoids visualizing predictions on void # labels pred[is_void] = y_gt[is_void] colors = class_colors[pred] if class_colors is not None \ else int_to_plotly_rgb(torch.LongTensor(pred)) data_0.pred_colors = colors if class_names is None: text = np.array([f"Class {i}" for i in range(pred.max() + 1)]) else: text = np.array([str.title(c) for c in class_names]) text = text[pred] trace_modes[i_point_trace]['Semantic Pred.'] = { 'marker.color': colors[data_0.selected], 'hovertext': text[data_0.selected]} trace_modes[i_unselected_point_trace]['Semantic Pred.'] = { 'marker.color': colors[~data_0.selected], 'hovertext': text[~data_0.selected]} # Color the points with ground truth instance labels. If semantic # labels and stuff_classes/void_classes also passed, the stuff/void # annotations will be treated accordingly if data_0.obj is not None and (class_names is None or data_0.y is None): obj = data_0.obj if isinstance(data_0.obj, torch.Tensor) \ else data_0.obj.major(num_classes=num_classes)[0] colors = int_to_plotly_rgb(obj) data_0.obj_colors = colors text = np.array([f"Object {o}" for o in obj]) trace_modes[i_point_trace]['Panoptic'] = { 'marker.color': colors[data_0.selected], 'hovertext': text[data_0.selected]} trace_modes[i_unselected_point_trace]['Panoptic'] = { 'marker.color': colors[~data_0.selected], 'hovertext': text[~data_0.selected]} elif data_0.obj is not None: # Colors and text for thing points obj = data_0.obj.major(num_classes=num_classes)[0] colors_thing = int_to_plotly_rgb(obj) text_thing = np.array([f"Object {o}" for o in obj]) # For simplicity, we merge void_classes into the stuff_classes, # the expected behavior is the same, except that we will ensure # that the hover text distinguishes between stuff and void stuff_classes = stuff_classes if stuff_classes is not None else [] stuff_classes = list(set(stuff_classes).union(set(void_classes))) # Colors and text for stuff points y = data_0.y y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy() colors_stuff = class_colors[y] if class_colors is not None \ else int_to_plotly_rgb(torch.LongTensor(y)) if class_names is None: text_stuff = np.array([ f"{'Void' if i in void_classes else 'Stuff'} - Class {i}" for i in range(y.max() + 1)]) else: text_stuff = np.array([ f"{'Void' if i in void_classes else 'Stuff'} - {str.title(c)}" for i, c in enumerate(class_names)]) text_stuff = text_stuff[y] # Apply alpha-whitening on stuff points colors_stuff = colors_stuff.astype('float') white = np.full((colors_stuff.shape[0], 3), 255, dtype='float') colors_stuff = colors_stuff * alpha_stuff + white * (1 - alpha_stuff) colors_stuff = colors_stuff.astype('int64') # Compute mask for stuff points stuff_classes = np.asarray([i for i in stuff_classes if i <= y.max()]) is_stuff = np.zeros(y.max() + 1, dtype='bool') for i in stuff_classes: if i < is_stuff.shape[0]: is_stuff[i] = True is_stuff = is_stuff[y] # Merge thing and stuff colors and text colors = colors_thing text = text_thing colors[is_stuff] = colors_stuff[is_stuff] text[is_stuff] = text_stuff[is_stuff] data_0.obj_colors = colors # Create trace modes trace_modes[i_point_trace]['Panoptic'] = { 'marker.color': colors[data_0.selected], 'hovertext': text[data_0.selected]} trace_modes[i_unselected_point_trace]['Panoptic'] = { 'marker.color': colors[~data_0.selected], 'hovertext': text[~data_0.selected]} # Color the points with predicted instance labels. If semantic # labels and stuff_classes/void_classes also passed, the # stuff/void predictions will be treated accordingly. This # expects `data_0.obj_pred` to be an InstanceData object if getattr(data_0, 'obj_pred', None) is not None and class_names is None: obj, _, y = data_0.obj_pred.major(num_classes=num_classes) colors = int_to_plotly_rgb(obj) data_0.obj_pred_colors = colors text = np.array([f"Object {o}" for o in obj]) trace_modes[i_point_trace]['Panoptic Pred.'] = { 'marker.color': colors[data_0.selected], 'hovertext': text[data_0.selected]} trace_modes[i_unselected_point_trace]['Panoptic Pred.'] = { 'marker.color': colors[~data_0.selected], 'hovertext': text[~data_0.selected]} elif getattr(data_0, 'obj_pred', None) is not None: # Colors and text for thing points obj, _, y = data_0.obj_pred.major(num_classes=num_classes) colors_thing = int_to_plotly_rgb(obj) text_thing = np.array([f"Object {o}" for o in obj]) # For simplicity, we merge void_classes into the stuff_classes, # the expected behavior is the same, except that we will ensure # that the hover text distinguishes between stuff and void stuff_classes = stuff_classes if stuff_classes is not None else [] stuff_classes = list(set(stuff_classes).union(set(void_classes))) # If the ground truth labels are available, we use them to # identify void points in the predictions if data_0.y is not None and hide_void_pred: # Get the target label y_gt = data_0.y y_gt = y_gt.argmax(1) if y_gt.dim() == 2 else y_gt # Create a mask over the points identifying those whose # ground truth label is void is_void = np.zeros(y_gt.max() + 1, dtype='bool') for i in void_classes: if i < is_void.shape[0]: is_void[i] = True is_void = is_void[y_gt] # Set the predicted label to void if the ground truth is # void, this avoids visualizing predictions on void # labels y[is_void] = y_gt[is_void] # Colors and text for stuff points colors_stuff = class_colors[y] if class_colors is not None \ else int_to_plotly_rgb(torch.LongTensor(y)) if class_names is None: text_stuff = np.array([ f"{'Void' if i in void_classes else 'Stuff'} - Class {i}" for i in range(y.max() + 1)]) else: text_stuff = np.array([ f"{'Void' if i in void_classes else 'Stuff'} - {str.title(c)}" for i, c in enumerate(class_names)]) text_stuff = text_stuff[y] # Apply alpha-whitening on stuff points colors_stuff = colors_stuff.astype('float') white = np.full((colors_stuff.shape[0], 3), 255, dtype='float') colors_stuff = colors_stuff * alpha_stuff + white * (1 - alpha_stuff) colors_stuff = colors_stuff.astype('int64') # Compute mask for stuff points stuff_classes = np.asarray([i for i in stuff_classes if i <= y.max()]) is_stuff = np.zeros(y.max() + 1, dtype='bool') for i in stuff_classes: if i < is_stuff.shape[0]: is_stuff[i] = True is_stuff = is_stuff[y] # Merge thing and stuff colors and text colors = colors_thing text = text_thing colors[is_stuff] = colors_stuff[is_stuff] text[is_stuff] = text_stuff[is_stuff] data_0.obj_pred_colors = colors # Create trace modes trace_modes[i_point_trace]['Panoptic Pred.'] = { 'marker.color': colors[data_0.selected], 'hovertext': text[data_0.selected]} trace_modes[i_unselected_point_trace]['Panoptic Pred.'] = { 'marker.color': colors[~data_0.selected], 'hovertext': text[~data_0.selected]} # Draw a trace for 3D point cloud features if data_0.x is not None: colors = feats_to_plotly_rgb( data_0.x, normalize=True, colorscale=colorscale) data_0.x_colors = colors trace_modes[i_point_trace]['Features 3D'] = { 'marker.color': colors[data_0.selected], 'hovertext': None} trace_modes[i_unselected_point_trace]['Features 3D'] = { 'marker.color': colors[~data_0.selected], 'hovertext': None} # Draw a trace for each key specified in keys. Only displays # point-wise tensor attributes that have not already been plotted # (ie not in `_DEFAULT_KEYS`) if keys is None: keys = [] elif isinstance(keys, str): keys = [keys] keys = [k for k in keys if k not in _DEFAULT_KEYS] for key in keys: val = getattr(data_0, key, None) if (val is None or not torch.is_tensor(val) or val.shape[0] != data_0.num_points): continue colors = feats_to_plotly_rgb( val, normalize=True, colorscale=colorscale) data_0[f"{key}_colors"] = colors trace_modes[i_point_trace][str(key).title()] = { 'marker.color': colors[data_0.selected], 'hovertext': None} trace_modes[i_unselected_point_trace][str(key).title()] = { 'marker.color': colors[~data_0.selected], 'hovertext': None} # Draw a trace for 3D point cloud sampling (for sampling debugging) if 'super_sampling' in data_0.keys: colors = data_0.super_sampling colors = int_to_plotly_rgb(colors) colors[data_0.super_sampling == -1] = 230 data_0.super_sampling_colors = colors trace_modes[i_point_trace]['Super sampling'] = { 'marker.color': colors[data_0.selected], 'hovertext': None} trace_modes[i_unselected_point_trace]['Super sampling'] = { 'marker.color': colors[~data_0.selected], 'hovertext': None} # Draw a trace for each cluster level for i_level, data_i in enumerate(input if is_nag else []): # Exit in case the Data has no 'super_index' if not data_i.is_sub: break # 'Data.super_index' are expressed between levels i and i+1, but # we need to recover the 'super_index' between level 0 and i+1, # to draw clusters on the level-0 points. To this end, we # compute the desired 'super_index' iteratively, with a # bottom-up approach if i_level == 0: super_index = data_i.super_index else: super_index = data_i.super_index[super_index] # Note that we update the 'trace_modes' 0th element here, this # assumes only it is the trace holding all level-0 points and on # which all other colors modes are defined colors = int_to_plotly_rgb(super_index) data_0[f"{i_level}_level_colors"] = colors text = np.array([f"↑: {i}" for i in super_index]) trace_modes[i_point_trace][f"Level {i_level + 1}"] = { 'marker.color': colors[data_0.selected], 'hovertext': text[data_0.selected]} trace_modes[i_unselected_point_trace][f"Level {i_level + 1}"] = { 'marker.color': colors[~data_0.selected], 'hovertext': text[~data_0.selected]} # Skip to the next level if we do not need to draw the cluster # centroids if not centroids: continue # To recover centroids of the i+1 level superpoints, we either # read them from the next NAG level or compute them using the # level i 'super_index' indices num_levels = input.num_levels is_last_level = i_level == num_levels - 1 if is_last_level or input[i_level + 1].pos is None: super_pos = scatter_mean(data_0.pos, super_index, dim=0) else: super_pos = input[i_level + 1].pos # Add the gap offset, if need be if gap is not None: super_pos += gap * (i_level + 1) # Round to the cm for cleaner hover info super_pos = (super_pos * 100).round() / 100 # Save the drawing position of centroids to facilitate vertical # edges drawing later on input[i_level + 1].draw_pos = super_pos # Draw the level-i+1 cluster centroids idx_sp = torch.arange(data_i.super_index.max() + 1) colors = int_to_plotly_rgb(idx_sp) text = np.array([f"#: {i}" for i in idx_sp]) ball_size = centroid_size if centroid_size else point_size * 3 fig.add_trace( go.Scatter3d( x=super_pos[input[i_level + 1].selected, 0], y=super_pos[input[i_level + 1].selected, 1], z=super_pos[input[i_level + 1].selected, 2], mode='markers+text', marker=dict( symbol=centroid_symbol, size=ball_size, color=colors[input[i_level + 1].selected.numpy()], line_width=min(ball_size / 2, 2), line_color='black'), textposition="bottom center", textfont=dict(size=16), hovertext=text, hoverinfo='x+y+z+text', showlegend=False, visible=gap is not None, )) fig.add_trace( go.Scatter3d( x=super_pos[~input[i_level + 1].selected, 0], y=super_pos[~input[i_level + 1].selected, 1], z=super_pos[~input[i_level + 1].selected, 2], mode='markers+text', marker=dict( symbol=centroid_symbol, size=ball_size, color=colors[~input[i_level + 1].selected.numpy()], line_width=min(ball_size / 2, 2), line_color='black', opacity=alpha_super), textposition="bottom center", textfont=dict(size=16), hovertext=text, hoverinfo='x+y+z+text', showlegend=False, visible=gap is not None, )) keys = [f"Level {i_level + 1}"] if gap is None \ else trace_modes[i_point_trace].keys() trace_modes.append( {k: { 'marker.color': colors[input[i_level + 1].selected.numpy()], 'hovertext': text[input[i_level + 1].selected.numpy()]} for k in keys}) trace_modes.append( {k: { 'marker.color': colors[~input[i_level + 1].selected.numpy()], 'hovertext': text[~input[i_level + 1].selected.numpy()]} for k in keys}) if i_level > 0 and v_edge and gap is not None and is_nag: # Recover the source and target positions for vertical edges # between i_level -> i_level+1 low_pos = data_i.draw_pos[data_i.selected] high_pos = super_pos[data_i.super_index[data_i.selected]] # Convert into a plotly-friendly format for 3D lines edges = np.full((low_pos.shape[0] * 3, 3), None) edges[::3] = low_pos edges[1::3] = high_pos # Color the vertical edges based on the parent cluster index # Plotly is a bit hacky with colors for 3D lines. We cannot # directly pass individual edge colors, we must instead give # edge color as an int corresponding to a colorscale list # holding plotly-friendly colors colors = data_i.super_index[data_i.selected] colors = np.repeat(colors, 3) n_colors = colors.max().item() + 1 edge_colorscale = int_to_plotly_rgb(torch.arange(n_colors)) edge_colorscale = [ [i / (n_colors - 1), f"rgb({x[0]}, {x[1]}, {x[2]})"] for i, x in enumerate(edge_colorscale)] # Since plotly 3D lines do not support opacity, we draw # these edges as super thin to limit clutter edge_width = 0.5 if v_edge_width is None else v_edge_width # Draw the level i -> i+1 vertical edges. NB we only draw # edges that are selected and do not draw the unselected # edges. This is because plotly does not handle opacity # on lines (yet), which means the unselected edges will tend # to clutter the figure. For this reason we choose to simply # not show them fig.add_trace( go.Scatter3d( x=edges[:, 0], y=edges[:, 1], z=edges[:, 2], mode='lines', line=dict( width=edge_width, color=colors, colorscale=edge_colorscale), hoverinfo='skip', showlegend=False, visible=gap is not None, )) # NB: at this point, trace_modes contains 'Level i+1' as its # last key, but we do not want vertical edges to be seen # when 'Level i+1' is selected, because it means 'Level i' # nodes are hidden keys = list(trace_modes[i_point_trace].keys())[:-1] trace_modes.append({k: {} for k in keys}) # Do not draw superedges if not required or if the i+1 level # does not have any if not h_edge or is_last_level or not input[i_level + 1].has_edges: continue # Recover the superedge source and target positions se = input[i_level + 1].edge_index se_attr = input[i_level + 1].edge_attr # Since we can only draw one edge direction (they would overlap # otherwise), we can trim the graph to only keep one direction # for each undirected edge pair. However, this requires picking # one direction for the edge attributes to we ARBITRARILY TAKE # THE MAX EDGE FEATURE for each undirected edge input[i_level + 1].raise_if_edge_keys() if se_attr is not None: se, se_attr = to_trimmed(se, edge_attr=se_attr, reduce='max') else: se = to_trimmed(se) # Recover corresponding source and target coordinates using the # previously-computed 'super_pos' cluster centroid positions s_pos = super_pos[se[0]].numpy() t_pos = super_pos[se[1]].numpy() # Convert into a plotly-friendly format for 3D lines edges = np.full((se.shape[1] * 3, 3), None) edges[::3] = s_pos edges[1::3] = t_pos if h_edge_attr and se_attr is not None: # Recover edge features and convert them to RGB colors. NB: # edge features are assumed to be in [0, 1] or [-1, 1]. # Since we only draw edges in one direction, we choose to # only represent the absolute value of the features. This # implies that features are either direction-independent or # that the edge direction only changes the sign of the # feature colors = feats_to_plotly_rgb( se_attr.abs(), normalize=True, colorscale=colorscale) colors = np.repeat(colors, 3, axis=0) edge_width = point_size if h_edge_width is None else h_edge_width else: colors = feats_to_plotly_rgb( torch.zeros(edges.shape[0]), normalize=True, colorscale=colorscale) edge_width = point_size if h_edge_width is None else h_edge_width selected_edge = input[i_level + 1].selected[se].all(axis=0) selected_edge = selected_edge.repeat_interleave(3).numpy() # Draw the level-i+1 superedges. NB we only draw edges that are # selected and do not draw the unselected edges. This is because # plotly does not handle opacity on lines (yet), which means the # unselected edges will tend to clutter the figure. For this # reason we choose to simply not show them fig.add_trace( go.Scatter3d( x=edges[selected_edge, 0], y=edges[selected_edge, 1], z=edges[selected_edge, 2], mode='lines', line=dict( width=edge_width, color=colors[selected_edge]), hoverinfo='skip', showlegend=False, visible=gap is not None, )) keys = [f"Level {i_level + 1}"] if gap is None \ else trace_modes[i_point_trace].keys() trace_modes.append({k: {} for k in keys}) # Add a trace for prediction errors. NB: it is important that this # trace is created last, as the button behavior for this one is # particular has_error = data_0.y is not None and data_0.semantic_pred is not None if has_error: # Recover prediction and ground truth and deal with potential # histograms y = data_0.y y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy() pred = data_0.semantic_pred pred = pred.argmax(1).numpy() if pred.dim() == 2 else pred.numpy() # Identify erroneous point indices ignore = void_classes if void_classes else [] ignore = ignore + [-1] indices = np.where((pred != y) & (~np.in1d(y, ignore)))[0] # Prepare the color for erroneous points error_color = 'red' if error_color is None \ else np.asarray[error_color].squeeze() # Draw the erroneous points fig.add_trace( go.Scatter3d( x=data_0.pos[indices, 0], y=data_0.pos[indices, 1], z=data_0.pos[indices, 2], mode='markers', marker=dict( symbol=point_symbol, size=int(point_size * 1.5), color=error_color, ), showlegend=False, visible=False, )) # Recover the keys for all visualization modes, as an ordered set, # with respect to their order of first appearance modes = list(dict.fromkeys([k for m in trace_modes for k in m.keys()])) # Traces color for interactive point cloud coloring def trace_update(mode): # Prepare the output args for the figure update attributes. By # default, all traces are non visible, with no color and no # hover text n_traces = len(trace_modes) out = { 'visible': [False] * (n_traces + has_error), 'marker.color': [None] * n_traces, 'hovertext': [''] * n_traces} # For each trace in 'trace_modes' see if it contains 'mode' and # adapt out accordingly for i_trace, t_modes in enumerate(trace_modes): # The trace has no action for the mode, skip it and leave # the default args for the trace if mode not in t_modes: continue # Note that a trace will only be visible for its modes # declared in trace_modes out['visible'][i_trace] = True for key, val in t_modes[mode].items(): out[key][i_trace] = val return [out, list(range(len(trace_modes)))] # Create the buttons that will serve for toggling trace visibility updatemenus = [ dict( buttons=[dict( label=mode, method='update', args=trace_update(mode)) for mode in modes if mode.lower() != 'errors'], pad={'r': 10, 't': 10}, showactive=True, type='dropdown', direction='right', xanchor='left', x=0.02, yanchor='top', y=1.02, ),] if has_error: updatemenus.append( dict( buttons=[dict( method='restyle', label='Semantic Errors', visible=True, args=[ {'visible': True, 'marker.color': error_color}, [len(trace_modes)]], args2=[ {'visible': False,}, [len(trace_modes)]],)], pad={'r': 10, 't': 10}, showactive=False, type='buttons', xanchor='left', x=1.02, yanchor='top', y=1.02, ),) fig.update_layout(updatemenus=updatemenus) # Place the legend on the left fig.update_layout( legend=dict( yanchor="middle", y=0.5, xanchor="right", x=0.99)) # Hide all axes and no background fig.update_layout( scene=dict( xaxis_title='', yaxis_title='', zaxis_title='', xaxis=dict( autorange=True, showgrid=False, ticks='', showticklabels=False, backgroundcolor="rgba(0, 0, 0, 0)"), yaxis=dict( autorange=True, showgrid=False, ticks='', showticklabels=False, backgroundcolor="rgba(0, 0, 0, 0)"), zaxis=dict( autorange=True, showgrid=False, ticks='', showticklabels=False, backgroundcolor="rgba(0, 0, 0, 0)"))) output = {'figure': fig, 'data': data_0} return output def figure_html(fig): # Save plotly figure to temp HTML fig.write_html( '/tmp/fig.html', config={'displayModeBar': False}, include_plotlyjs='cdn', full_html=False) # Read the HTML with open("/tmp/fig.html", "r") as f: fig_html = f.read() # Center the figure div for cleaner display fig_html = fig_html.replace( 'class="plotly-graph-div" style="', 'class="plotly-graph-div" style="margin:0 auto;') return fig_html def show(input, path=None, title=None, no_output=True, pt_path=None, **kwargs): """Interactive data visualization. :param input: Data or NAG object :param path: str Path to save the visualization into a sharable HTML :param title: str Figure title :param no_output: bool Set to True if you want to return the 3D Plotly figure objects :param pt_path:str Path to save the visualization-ready `Data` object as a `*.pt`. In this `Data` object, the `pos` and all `*color*` attributes will be saved, the rest is discarded. This is typically useful for exporting the visualization layers to another visualization tool :param kwargs: :return: """ # Sanitize title and path if title is None: title = "Large-scale point cloud" if path is not None: if osp.isdir(path): path = osp.join(path, f"{title}.html") else: path = osp.splitext(path)[0] + '.html' fig_html = f'