|
|
import torch |
|
|
import numpy as np |
|
|
import os.path as osp |
|
|
import plotly.graph_objects as go |
|
|
from src.data import Data, NAG, Cluster |
|
|
from src.transforms import GridSampling3D, SaveNodeIndex |
|
|
from src.utils import fast_randperm, to_trimmed |
|
|
from torch_scatter import scatter_mean |
|
|
from src.utils.color import * |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _labels_to_numpy(labels):
    """Return point labels as a 1D numpy array.

    2D label tensors are reduced with an argmax over their second
    dimension; 1D tensors are converted as-is. For 1D tensors, the result
    may share memory with `labels`, so do not modify it in place.
    """
    return labels.argmax(1).numpy() if labels.dim() == 2 else labels.numpy()


def _label_mask(labels, classes):
    """Boolean numpy mask of the entries of `labels` belonging to `classes`.

    Negative labels only match if they are explicitly listed in `classes`
    (a lookup table indexed by the labels would silently wrap around).
    """
    return np.isin(np.asarray(labels), [int(c) for c in classes])


def _class_label_names(class_names, labels):
    """Display names for the classes: `class_names` title-cased if given,
    otherwise 'Class i' for every index up to `labels.max()`."""
    if class_names is None:
        return [f"Class {i}" for i in range(labels.max() + 1)]
    return [str.title(c) for c in class_names]


def _panoptic_colors_and_text(
        obj, y, stuff_classes, void_classes, class_colors, class_names,
        alpha_stuff):
    """Per-point panoptic colors and hover texts.

    Points are colored by their instance index `obj`, except those whose
    semantic label `y` is a stuff or void class: those receive their class
    color blended towards white by `alpha_stuff` and a 'Stuff - <class>' or
    'Void - <class>' hover text.

    :param obj: per-point instance indices
    :param y: per-point semantic labels (array-like of ints)
    :param stuff_classes: stuff class indices, or `None`
    :param void_classes: void class indices
    :param class_colors: numpy array of class colors, or `None`
    :param class_names: class names, or `None`
    :param alpha_stuff: `float` in [0, 1], weight of the class color
        against white for stuff points
    :return: `(colors, text)` with one entry per point
    """
    colors = int_to_plotly_rgb(obj)
    text_thing = np.array([f"Object {o}" for o in obj])

    y = np.asarray(y)
    colors_stuff = class_colors[y] if class_colors is not None \
        else int_to_plotly_rgb(torch.LongTensor(y))
    text_stuff = np.array([
        f"{'Void' if i in void_classes else 'Stuff'} - {name}"
        for i, name in enumerate(_class_label_names(class_names, y))])[y]

    # Whiten stuff colors so that thing instances stand out
    colors_stuff = colors_stuff.astype('float')
    colors_stuff = colors_stuff * alpha_stuff + 255. * (1 - alpha_stuff)
    colors_stuff = colors_stuff.astype('int64')

    # Void classes are displayed like stuff. Build a local list rather than
    # mutating the caller's `stuff_classes`
    stuff_and_void = list(stuff_classes) if stuff_classes is not None else []
    stuff_and_void += list(void_classes)
    is_stuff = _label_mask(y, stuff_and_void)

    colors[is_stuff] = colors_stuff[is_stuff]
    # `np.where` promotes to the wider string dtype; assigning in place into
    # `text_thing` would truncate stuff texts to the width of "Object N"
    text = np.where(is_stuff, text_stuff, text_thing)
    return colors, text


def _error_color_to_plotly(error_color):
    """Normalize `error_color` to a color value plotly accepts.

    `None` maps to 'red', strings are passed through, and an `(r, g, b)`
    triplet is converted to an 'rgb(r, g, b)' string.
    """
    if error_color is None:
        return 'red'
    if isinstance(error_color, str):
        return error_color
    r, g, b = np.asarray(error_color).reshape(-1)
    return f"rgb({int(r)}, {int(g)}, {int(b)})"


def _mode_update_args(trace_modes, mode, has_error):
    """Build the `args` of a plotly 'update' button activating `mode`.

    Every trace registered in `trace_modes` is shown if it supports `mode`
    and hidden otherwise. The marker colors and hover texts stored for that
    mode are applied to it.
    """
    n_traces = len(trace_modes)
    out = {
        'visible': [False] * (n_traces + has_error),
        'marker.color': [None] * n_traces,
        'hovertext': [''] * n_traces}
    for i_trace, t_modes in enumerate(trace_modes):
        if mode not in t_modes:
            continue
        out['visible'][i_trace] = True
        for key, val in t_modes[mode].items():
            out[key][i_trace] = val
    return [out, list(range(n_traces))]


def visualize_3d(
        input,
        keys=None,
        figsize=1000,
        width=None,
        height=None,
        class_names=None,
        class_colors=None,
        stuff_classes=None,
        num_classes=None,
        hide_void_pred=False,
        voxel=-1,
        max_points=50000,
        point_size=3,
        centroid_size=None,
        error_color=None,
        centroids=False,
        h_edge=False,
        h_edge_attr=False,
        h_edge_width=None,
        v_edge=False,
        v_edge_width=None,
        gap=None,
        radius=None,
        center=None,
        select=None,
        alpha=0.1,
        alpha_super=None,
        alpha_stuff=0.2,
        point_symbol='circle',
        centroid_symbol='circle',
        colorscale='Agsunset',
        **kwargs):
    """3D data interactive visualization.

    :param input: `Data` or `NAG` object
    :param keys: `List(str)` or `str`
        By default, the following attributes will be parsed in `input`
        for visualization {`pos`, `rgb`, `y`, `obj`, `semantic_pred`,
        `obj_pred`}. Yet, if `input` contains other attributes that you
        want to visualize, these can be passed as `keys`. This only
        supports point-wise attributes stored as 1D or 2D tensors.
        If the tensor contains only 1 channel, the attribute will be
        colored with `colorscale`. If the tensor contains 2 or 3
        channels, these will be represented as RGB, with an additional
        all-1 channel if need be. If the tensor contains more than 3
        channels, a PCA projection to RGB will be shown. In any case,
        the attribute values will be rescaled with respect to their
        statistics before visualization, meaning that colors may not
        compare between two different plots
    :param figsize: `int`
        Figure dimensions will be `(figsize, figsize / 2)` unless both
        `width` and `height` are specified
    :param width: `int`
        Figure width, only used if `height` is also specified
    :param height: `int`
        Figure height, only used if `width` is also specified
    :param class_names: `List(str)`
        Names for point labels found in attributes `y` and
        `semantic_pred`
    :param class_colors: `List(List(int, int, int))`
        Colors palette for point labels found in attributes `y` and
        `semantic_pred`
    :param stuff_classes: `List(int)`
        Semantic labels of the classes considered as `stuff` for
        instance and panoptic segmentation. If `y` and `obj` are found
        in the point attributes, the stuff annotations will appear
        accordingly. Otherwise, stuff instance labeling will appear as
        any other object
    :param num_classes: `int`
        Number of valid classes. By convention, we assume
        `y ∈ [0, num_classes-1]` are VALID LABELS, while
        `y < 0` AND `y >= num_classes` ARE VOID LABELS. For display,
        label `num_classes` is treated as the void class, and labels
        `num_classes` and `-1` are excluded from the semantic errors
    :param hide_void_pred: `bool`
        If True and `y` is available, predictions on points whose
        ground truth is void are replaced by that void label, so they
        are not displayed as actual predictions
    :param voxel: `float`
        Voxel size to subsample the point cloud to facilitate
        visualization
    :param max_points: `int`
        Maximum number of points displayed to facilitate visualization
    :param point_size: `int` or `float`
        Size of point markers
    :param centroid_size: `int` or `float`
        Size of superpoint markers. Defaults to `3 * point_size`
    :param error_color: `List(int, int, int)` or `str`
        Color used to identify mis-predicted points. Defaults to 'red'
    :param centroids: `bool`
        Whether superpoint centroids should be displayed
    :param h_edge: `bool`
        Whether horizontal edges should be displayed (only if
        `centroids=True`)
    :param h_edge_attr: `bool`
        Whether the edges should be colored by their features found in
        `edge_attr` (only if `h_edge=True`)
    :param h_edge_width: `float`
        Width of the horizontal edges, if `h_edge=True`. Defaults to
        `None`, in which case `point_size` will be used for the edge
        width
    :param v_edge: `bool`
        Whether vertical edges should be displayed (only if
        `centroids=True` and `gap` is not `None`)
    :param v_edge_width: `float`
        Width of the vertical edges, if `v_edge=True`. Defaults to
        `None`, in which case a width of 0.5 is used
    :param gap: `List(float, float, float)`
        If `None`, the hierarchical graphs will be overlaid on the points.
        If not `None`, a 3D offset by which each hierarchical level
        should be shifted with respect to the previous one
    :param radius: `float`
        If not `None`, only visualize a spherical sampling of the input
        data, centered on `center` and with size `radius`. This option
        is not compatible with `select`
    :param center: `List(float, float, float)`
        If `radius` is provided, only visualize a spherical sampling of
        the input data, centered on `center` and with size `radius`. If
        `None`, the center of the scene will be used
    :param select: `Tuple(int, Tensor)`
        If not `None`, will call `NAG.select(*select)` on the input
        data and the coloring schemes will illustrate it. This option is
        not compatible with `radius`
    :param alpha: `float`
        Opacity of the points that are not selected (only relevant if
        `select` is not `None`). Clamped to [0, 1]
    :param alpha_super: `float`
        Opacity of the superpoint centroids that are not selected (only
        relevant if `select` is not `None`). If `None`, `alpha` will be
        used as fallback
    :param alpha_stuff: `float`
        Weight of the class color against white for stuff points in the
        panoptic modes (1: class color, 0: white). If `None`, `alpha`
        will be used as fallback
    :param point_symbol: `str`
        Marker symbol used for points. Must be one of
        `{'circle', 'circle-open', 'square', 'square-open', 'diamond',
        'diamond-open', 'cross', 'x'}`. Defaults to `'circle'`
    :param centroid_symbol: `str`
        Marker symbol used for centroids. Must be one of
        `{'circle', 'circle-open', 'square', 'square-open', 'diamond',
        'diamond-open', 'cross', 'x'}`. Defaults to `'circle'`
    :param colorscale: `str`
        Plotly colorscale used for coloring 1D continuous features. See
        https://plotly.com/python/builtin-colorscales for options
    :param kwargs: unused
    :return: `dict` with the plotly figure under `'figure'` and the
        (possibly subsampled, CPU) level-0 `Data` carrying the computed
        `*_colors` attributes under `'data'`
    """
    _DEFAULT_KEYS = [
        'pos', 'rgb', 'y', 'semantic_pred', 'obj', 'obj_pred', 'x',
        'super_sampling', 'super_index']

    gap = torch.as_tensor(gap).cpu() if gap is not None else gap
    assert gap is None or gap.shape == torch.Size([3])
    assert not (radius and (select is not None)), \
        "Cannot use both a `radius` and `select` at once"

    # Work on a CPU copy, the caller's data is never modified
    input = input.clone().cpu()

    # Everything below operates on a NAG: a `Data` becomes a 1-level NAG
    input = NAG([input]) if isinstance(input, Data) else input

    # If the last level still points to a parent level (`super_index`),
    # build that parent level. This guarantees that the top level is not
    # `is_sub`, which the level loop below relies on to stop
    if input[input.num_levels - 1].is_sub:
        data_last = input[input.num_levels - 1]
        sub = Cluster(
            data_last.super_index, torch.arange(data_last.num_nodes),
            dense=True)
        obj = data_last.obj.merge(data_last.super_index) \
            if data_last.obj else None
        pos = scatter_mean(data_last.pos, data_last.super_index, dim=0)
        input = NAG(input.to_list() + [Data(pos=pos, sub=sub, obj=obj)])
    num_levels = input.num_levels

    # Only `None` triggers the fallback: 0 is a legitimate value
    alpha = max(0, min(alpha, 1))
    alpha_super = max(0, min(alpha_super, 1)) if alpha_super is not None \
        else alpha
    alpha_stuff = max(0, min(alpha_stuff, 1)) if alpha_stuff is not None \
        else alpha

    # Spherical sampling around `center`
    if radius is not None:
        if center is None:
            hi = input[0].pos.max(dim=0).values
            lo = input[0].pos.min(dim=0).values
            center = (hi + lo) / 2
            center[2] = input[0].pos[:, 2].mean()
        else:
            center = torch.as_tensor(center).cpu()
        center = center.view(1, -1)
        mask = torch.where(
            torch.linalg.norm(input[0].pos - center, dim=1) < radius)[0]
        input = input.select(0, mask)

    # Flag, at each level, the nodes that are part of the selection. The
    # node indices are saved before selecting so that they can be
    # recovered afterwards
    if select is not None:
        nag_temp = input.clone()
        for i in range(nag_temp.num_levels):
            nag_temp._list[i] = SaveNodeIndex()(nag_temp[i])
        nag_temp = nag_temp.select(*select)
        for i in range(num_levels):
            selected = torch.zeros(input[i].num_nodes, dtype=torch.bool)
            selected[nag_temp[i][SaveNodeIndex.DEFAULT_KEY]] = True
            input[i].selected = selected
        del nag_temp
    else:
        for i in range(num_levels):
            input[i].selected = torch.ones(
                input[i].num_nodes, dtype=torch.bool)

    # Subsample the level-0 points by voxelization and/or random sampling
    data_0 = input[0]
    idx = torch.arange(data_0.num_points)
    if voxel > 0:
        data_temp = SaveNodeIndex()(Data(pos=data_0.pos.clone()))
        data_temp = GridSampling3D(voxel, mode='last')(data_temp)
        idx = data_temp[SaveNodeIndex.DEFAULT_KEY]
        del data_temp
    if idx.shape[0] > max_points:
        idx = idx[fast_randperm(idx.shape[0])[:max_points]]
    if idx.shape[0] < data_0.num_points:
        input = input.select(0, idx)
        data_0 = input[0]

    # Round positions to lighten the exported figure
    data_0.pos = (data_0.pos * 100).round() / 100

    # Only numeric color palettes are supported
    if class_colors is not None and len(class_colors) > 0 \
            and not isinstance(class_colors[0], str):
        class_colors = np.asarray(class_colors)
    else:
        class_colors = None

    if not (width and height):
        width, height = figsize, int(figsize / 2)
    margin = int(0.02 * min(width, height))
    layout = go.Layout(
        width=width,
        height=height,
        scene=dict(aspectmode='data', ),
        margin=dict(l=margin, r=margin, b=margin, t=margin),
        uirevision=True)
    fig = go.Figure(layout=layout)

    # `trace_modes[i]` maps each coloring mode supported by the i-th trace
    # of `fig` to the trace properties to set when that mode is activated.
    # The first two traces are the selected and unselected level-0 points
    trace_modes = []
    i_point_trace = 0
    i_unselected_point_trace = 1
    point_selected = data_0.selected

    void_classes = [num_classes] if num_classes else []

    def add_point_mode(mode, colors, text=None):
        """Register `mode` on both level-0 point traces."""
        for i_trace, mask in (
                (i_point_trace, point_selected),
                (i_unselected_point_trace, ~point_selected)):
            trace_modes[i_trace][mode] = {
                'marker.color': colors[mask],
                'hovertext': None if text is None else text[mask]}

    # Default coloring: normalized position as RGB
    mini = data_0.pos.min(dim=0).values
    maxi = data_0.pos.max(dim=0).values
    colors = rgb_to_plotly_rgb((data_0.pos - mini) / (maxi - mini + 1e-6))
    data_0.pos_colors = colors

    fig.add_trace(
        go.Scatter3d(
            x=data_0.pos[point_selected, 0],
            y=data_0.pos[point_selected, 1],
            z=data_0.pos[point_selected, 2],
            mode='markers',
            marker=dict(
                symbol=point_symbol,
                size=point_size,
                color=colors[point_selected]),
            hoverinfo='x+y+z+text',
            hovertext=None,
            showlegend=False,
            visible=True, ))
    fig.add_trace(
        go.Scatter3d(
            x=data_0.pos[~point_selected, 0],
            y=data_0.pos[~point_selected, 1],
            z=data_0.pos[~point_selected, 2],
            mode='markers',
            marker=dict(
                symbol=point_symbol,
                size=point_size,
                color=colors[~point_selected],
                opacity=alpha),
            hoverinfo='x+y+z+text',
            hovertext=None,
            showlegend=False,
            visible=True, ))
    trace_modes.extend([{}, {}])
    add_point_mode('Position RGB', colors)

    if data_0.rgb is not None:
        colors = rgb_to_plotly_rgb(data_0.rgb)
        data_0.rgb_colors = colors
        add_point_mode('RGB', colors)

    if data_0.y is not None:
        y = _labels_to_numpy(data_0.y)
        colors = class_colors[y] if class_colors is not None \
            else int_to_plotly_rgb(torch.LongTensor(y))
        data_0.y_colors = colors
        text = np.array(_class_label_names(class_names, y))[y]
        add_point_mode('Semantic', colors, text)

    if data_0.semantic_pred is not None:
        pred = _labels_to_numpy(data_0.semantic_pred)

        # Show the void ground truth instead of the prediction on void
        # points. `np.where` builds a new array: `pred` may share memory
        # with `data_0.semantic_pred`
        if data_0.y is not None and hide_void_pred:
            y_gt = _labels_to_numpy(data_0.y)
            pred = np.where(_label_mask(y_gt, void_classes), y_gt, pred)

        colors = class_colors[pred] if class_colors is not None \
            else int_to_plotly_rgb(torch.LongTensor(pred))
        data_0.pred_colors = colors
        text = np.array(_class_label_names(class_names, pred))[pred]
        add_point_mode('Semantic Pred.', colors, text)

    panoptic_kwargs = dict(
        stuff_classes=stuff_classes,
        void_classes=void_classes,
        class_colors=class_colors,
        class_names=class_names,
        alpha_stuff=alpha_stuff)

    # Panoptic ground truth: plain instance coloring, unless both class
    # names and semantic labels are available to flag stuff points
    if data_0.obj is not None:
        if class_names is None or data_0.y is None:
            obj = data_0.obj if isinstance(data_0.obj, torch.Tensor) \
                else data_0.obj.major(num_classes=num_classes)[0]
            colors = int_to_plotly_rgb(obj)
            text = np.array([f"Object {o}" for o in obj])
        else:
            obj = data_0.obj.major(num_classes=num_classes)[0]
            colors, text = _panoptic_colors_and_text(
                obj, _labels_to_numpy(data_0.y), **panoptic_kwargs)
        data_0.obj_colors = colors
        add_point_mode('Panoptic', colors, text)

    # Panoptic prediction
    obj_pred = getattr(data_0, 'obj_pred', None)
    if obj_pred is not None:
        obj, _, y = obj_pred.major(num_classes=num_classes)
        if class_names is None:
            colors = int_to_plotly_rgb(obj)
            text = np.array([f"Object {o}" for o in obj])
        else:
            y = np.asarray(y)
            if data_0.y is not None and hide_void_pred:
                y_gt = _labels_to_numpy(data_0.y)
                y = np.where(_label_mask(y_gt, void_classes), y_gt, y)
            colors, text = _panoptic_colors_and_text(
                obj, y, **panoptic_kwargs)
        data_0.obj_pred_colors = colors
        add_point_mode('Panoptic Pred.', colors, text)

    if data_0.x is not None:
        colors = feats_to_plotly_rgb(
            data_0.x, normalize=True, colorscale=colorscale)
        data_0.x_colors = colors
        add_point_mode('Features 3D', colors)

    # Additional user-requested point-wise attributes
    if keys is None:
        keys = []
    elif isinstance(keys, str):
        keys = [keys]
    for key in keys:
        if key in _DEFAULT_KEYS:
            continue
        val = getattr(data_0, key, None)
        if (val is None or not torch.is_tensor(val)
                or val.shape[0] != data_0.num_points):
            continue
        colors = feats_to_plotly_rgb(
            val, normalize=True, colorscale=colorscale)
        data_0[f"{key}_colors"] = colors
        add_point_mode(str(key).title(), colors)

    if 'super_sampling' in data_0.keys:
        colors = int_to_plotly_rgb(data_0.super_sampling)
        # Points with no super sampling (-1) are shown in light gray
        colors[data_0.super_sampling == -1] = 230
        data_0.super_sampling_colors = colors
        add_point_mode('Super sampling', colors)

    # Hierarchical partition levels
    num_levels = input.num_levels
    for i_level, data_i in enumerate(input):

        # Stop at the first level with no parent level
        if not data_i.is_sub:
            break

        # Index of the level-(i_level + 1) ancestor of each level-0 point,
        # composed level by level
        if i_level == 0:
            super_index = data_i.super_index
        else:
            super_index = data_i.super_index[super_index]

        colors = int_to_plotly_rgb(super_index)
        data_0[f"{i_level}_level_colors"] = colors
        text = np.array([f"↑: {i}" for i in super_index])
        add_point_mode(f"Level {i_level + 1}", colors, text)

        if not centroids:
            continue

        # Centroid positions: the parent level's own `pos` if available,
        # otherwise the mean position of its level-0 points
        is_last_level = i_level == num_levels - 1
        if is_last_level or input[i_level + 1].pos is None:
            super_pos = scatter_mean(data_0.pos, super_index, dim=0)
        else:
            super_pos = input[i_level + 1].pos
        data_sup = input[i_level + 1]

        # Out-of-place shift so the level's own `pos` is left untouched
        if gap is not None:
            super_pos = super_pos + gap.to(super_pos.dtype) * (i_level + 1)
        super_pos = (super_pos * 100).round() / 100

        # Remembered for the vertical edges drawn at the next level
        data_sup.draw_pos = super_pos

        sp_selected = data_sup.selected
        idx_sp = torch.arange(data_i.super_index.max() + 1)
        colors = int_to_plotly_rgb(idx_sp)
        text = np.array([f"<b>#: {i}</b>" for i in idx_sp])
        ball_size = centroid_size if centroid_size else point_size * 3

        # One trace for selected and one for unselected centroids. Hover
        # texts are masked like the markers so that they stay aligned
        for is_selected in (True, False):
            mask = sp_selected if is_selected else ~sp_selected
            mask_np = mask.numpy()
            marker = dict(
                symbol=centroid_symbol,
                size=ball_size,
                color=colors[mask_np],
                line_width=min(ball_size / 2, 2),
                line_color='black')
            if not is_selected:
                marker['opacity'] = alpha_super
            fig.add_trace(
                go.Scatter3d(
                    x=super_pos[mask, 0],
                    y=super_pos[mask, 1],
                    z=super_pos[mask, 2],
                    mode='markers+text',
                    marker=marker,
                    textposition="bottom center",
                    textfont=dict(size=16),
                    hovertext=text[mask_np],
                    hoverinfo='x+y+z+text',
                    showlegend=False,
                    visible=gap is not None, ))

        # Overlaid centroids only show in their own level mode; shifted
        # ones show in every mode registered so far
        mode_keys = [f"Level {i_level + 1}"] if gap is None \
            else trace_modes[i_point_trace].keys()
        sp_selected_np = sp_selected.numpy()
        for mask_np in (sp_selected_np, ~sp_selected_np):
            trace_modes.append({
                k: {
                    'marker.color': colors[mask_np],
                    'hovertext': text[mask_np]}
                for k in mode_keys})

        # Vertical edges between level-i centroids and their parents
        if i_level > 0 and v_edge and gap is not None \
                and data_i.selected.any():
            low_pos = data_i.draw_pos[data_i.selected]
            high_pos = super_pos[data_i.super_index[data_i.selected]]

            # Segments separated by `None` rows, as plotly expects
            edges = np.full((low_pos.shape[0] * 3, 3), None)
            edges[::3] = low_pos
            edges[1::3] = high_pos

            # Color each edge by its parent superpoint. The colorscale has
            # one stop per superpoint index and is anchored with cmin/cmax
            # so that stop i is superpoint i. Plotly needs at least 2 stops
            colors = np.repeat(
                data_i.super_index[data_i.selected].numpy(), 3)
            n_stops = max(int(colors.max()) + 1, 2)
            edge_colorscale = [
                [i / (n_stops - 1), f"rgb({x[0]}, {x[1]}, {x[2]})"]
                for i, x in enumerate(
                    int_to_plotly_rgb(torch.arange(n_stops)))]

            edge_width = 0.5 if v_edge_width is None else v_edge_width

            fig.add_trace(
                go.Scatter3d(
                    x=edges[:, 0],
                    y=edges[:, 1],
                    z=edges[:, 2],
                    mode='lines',
                    line=dict(
                        width=edge_width,
                        color=colors,
                        colorscale=edge_colorscale,
                        cmin=0,
                        cmax=n_stops - 1),
                    hoverinfo='skip',
                    showlegend=False,
                    visible=gap is not None, ))

            # Shown in every registered mode except the last one added
            mode_keys = list(trace_modes[i_point_trace].keys())[:-1]
            trace_modes.append({k: {} for k in mode_keys})

        # Horizontal edges of the parent level's graph
        if not h_edge or is_last_level or not data_sup.has_edges:
            continue

        se = data_sup.edge_index
        se_attr = data_sup.edge_attr
        data_sup.raise_if_edge_keys()
        if se_attr is not None:
            se, se_attr = to_trimmed(se, edge_attr=se_attr, reduce='max')
        else:
            se = to_trimmed(se)

        edges = np.full((se.shape[1] * 3, 3), None)
        edges[::3] = super_pos[se[0]].numpy()
        edges[1::3] = super_pos[se[1]].numpy()

        if h_edge_attr and se_attr is not None:
            colors = feats_to_plotly_rgb(
                se_attr.abs(), normalize=True, colorscale=colorscale)
            colors = np.repeat(colors, 3, axis=0)
        else:
            colors = feats_to_plotly_rgb(
                torch.zeros(edges.shape[0]), normalize=True,
                colorscale=colorscale)
        edge_width = point_size if h_edge_width is None else h_edge_width

        # An edge is displayed only if both its endpoints are selected
        selected_edge = data_sup.selected[se].all(dim=0)
        selected_edge = selected_edge.repeat_interleave(3).numpy()

        fig.add_trace(
            go.Scatter3d(
                x=edges[selected_edge, 0],
                y=edges[selected_edge, 1],
                z=edges[selected_edge, 2],
                mode='lines',
                line=dict(
                    width=edge_width,
                    color=colors[selected_edge]),
                hoverinfo='skip',
                showlegend=False,
                visible=gap is not None, ))

        mode_keys = [f"Level {i_level + 1}"] if gap is None \
            else trace_modes[i_point_trace].keys()
        trace_modes.append({k: {} for k in mode_keys})

    # Semantic errors trace, toggled by its own button. It is the trace
    # right after those registered in `trace_modes`
    has_error = data_0.y is not None and data_0.semantic_pred is not None
    if has_error:
        y = _labels_to_numpy(data_0.y)
        pred = _labels_to_numpy(data_0.semantic_pred)
        ignore = list(void_classes) + [-1]
        indices = np.where((pred != y) & ~np.isin(y, ignore))[0]
        error_color = _error_color_to_plotly(error_color)
        fig.add_trace(
            go.Scatter3d(
                x=data_0.pos[indices, 0],
                y=data_0.pos[indices, 1],
                z=data_0.pos[indices, 2],
                mode='markers',
                marker=dict(
                    symbol=point_symbol,
                    size=int(point_size * 1.5),
                    color=error_color, ),
                showlegend=False,
                visible=False, ))

    # All registered modes, in order of first appearance
    modes = list(dict.fromkeys(k for m in trace_modes for k in m))

    updatemenus = [
        dict(
            buttons=[
                dict(
                    label=mode,
                    method='update',
                    args=_mode_update_args(trace_modes, mode, has_error))
                for mode in modes if mode.lower() != 'errors'],
            pad={'r': 10, 't': 10},
            showactive=True,
            type='dropdown',
            direction='right',
            xanchor='left',
            x=0.02,
            yanchor='top',
            y=1.02, ), ]

    if has_error:
        updatemenus.append(
            dict(
                buttons=[dict(
                    method='restyle',
                    label='Semantic Errors',
                    visible=True,
                    args=[
                        {'visible': True, 'marker.color': error_color},
                        [len(trace_modes)]],
                    args2=[
                        {'visible': False, },
                        [len(trace_modes)]], )],
                pad={'r': 10, 't': 10},
                showactive=False,
                type='buttons',
                xanchor='left',
                x=1.02,
                yanchor='top',
                y=1.02, ), )

    fig.update_layout(updatemenus=updatemenus)

    fig.update_layout(
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="right",
            x=0.99))

    # Hide axes, grids and backgrounds
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            xaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)"),
            yaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)"),
            zaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)")))

    output = {'figure': fig, 'data': data_0}

    return output
|
|
|
|
|
|
|
|
def figure_html(fig):
    """Render a plotly figure as an embeddable, horizontally centered HTML
    snippet.

    The snippet loads plotly.js from a CDN, hides the mode bar and is not a
    full HTML document (no `<html>`/`<body>`), so it can be concatenated
    with other HTML content.

    :param fig: plotly `Figure`
    :return: `str` HTML snippet
    """
    # Render directly to a string rather than round-tripping through a
    # hard-coded temporary file, which is not portable, races between
    # concurrent calls and was read back with a platform-default encoding
    fig_html = fig.to_html(
        config={'displayModeBar': False},
        include_plotlyjs='cdn',
        full_html=False)

    # Center the plot div in its container
    fig_html = fig_html.replace(
        'class="plotly-graph-div" style="',
        'class="plotly-graph-div" style="margin:0 auto;')

    return fig_html
|
|
|
|
|
|
|
|
def show(input, path=None, title=None, no_output=True, pt_path=None, **kwargs):
    """Interactive data visualization.

    :param input: Data or NAG object
    :param path: str
        Path to save the visualization into a sharable HTML file. If
        `path` is a directory, the file is named `<title>.html` inside
        it; otherwise the extension of `path` is replaced with `.html`.
        If `None`, nothing is saved
    :param title: str
        Figure title. Defaults to "Large-scale point cloud"
    :param no_output: bool
        If True (default), nothing is returned and, when `path` is
        `None`, the figure is displayed right away. Set to False to get
        the output of `visualize_3d` (plotly figure and data) back
    :param pt_path: str
        Path to save the visualization-ready data as a `*.pt` file. A
        dict holding the `pos` and all `*color*` attributes of the
        visualized level-0 data is saved; the rest is discarded. This is
        typically useful for exporting the visualization layers to
        another visualization tool. If `pt_path` is a directory, the file
        is named `viz_data.pt` inside it; otherwise the extension of
        `pt_path` is replaced with `.pt`
    :param kwargs:
        Passed to `visualize_3d`
    :return: `None` if `no_output`, otherwise the dict returned by
        `visualize_3d`
    """
    if title is None:
        title = "Large-scale point cloud"

    out_3d = visualize_3d(input, **kwargs)

    if path is None:
        if no_output:
            out_3d['figure'].show(config={'displayModeBar': False})
    else:
        if osp.isdir(path):
            path = osp.join(path, f"{title}.html")
        else:
            path = osp.splitext(path)[0] + '.html'

        # The figure is always embedded in the saved page, regardless of
        # `no_output`. UTF-8 is required for non-ASCII hover texts
        fig_html = f'<h1 style="text-align: center;">{title}</h1>'
        fig_html += figure_html(out_3d['figure'])
        with open(path, "w", encoding="utf-8") as f:
            f.write(fig_html)

    if pt_path is not None:
        if osp.isdir(pt_path):
            pt_path = osp.join(pt_path, "viz_data.pt")
        else:
            pt_path = osp.splitext(pt_path)[0] + '.pt'

        # Only keep positions and the computed color layers
        data_out = out_3d['data']
        data = {
            key: data_out[key] for key in data_out.keys
            if key == 'pos' or 'color' in key}
        torch.save(data, pt_path)

    if not no_output:
        return out_3d

    return
|
|
|