English
File size: 46,030 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
import torch
import numpy as np
import os.path as osp
import plotly.graph_objects as go
from src.data import Data, NAG, Cluster
from src.transforms import GridSampling3D, SaveNodeIndex
from src.utils import fast_randperm, to_trimmed
from torch_scatter import scatter_mean
from src.utils.color import *


# TODO: To go further with ipywidgets, see:
#  - https://plotly.com/python/figurewidget-app/
#  - https://ipywidgets.readthedocs.io/en/stable/


def visualize_3d(
        input,
        keys=None,
        figsize=1000,
        width=None,
        height=None,
        class_names=None,
        class_colors=None,
        stuff_classes=None,
        num_classes=None,
        hide_void_pred=False,
        voxel=-1,
        max_points=50000,
        point_size=3,
        centroid_size=None,
        error_color=None,
        centroids=False,
        h_edge=False,
        h_edge_attr=False,
        h_edge_width=None,
        v_edge=False,
        v_edge_width=None,
        gap=None,
        radius=None,
        center=None,
        select=None,
        alpha=0.1,
        alpha_super=None,
        alpha_stuff=0.2,
        point_symbol='circle',
        centroid_symbol='circle',
        colorscale='Agsunset',
        **kwargs):
    """3D data interactive visualization.

    :param input: `Data` or `NAG` object
    :param keys: `List(str)` or `str`
        By default, the following attributes will be parsed in `input`
        for visualization {`pos`, `rgb`, `y`, `obj`, `semantic_pred`,
        `obj_pred`}. Yet, if `input` contains other attributes that you
        want to visualize, these can be passed as `keys`. This only
        supports point-wise attributes stored as 1D or 2D tensors.
        If the tensor contains only 1 channel, the attribute will be
        represented with a grayscale colormap. If the tensor contains
        2 or 3 channels, these will be represented as RGB, with
        an additional all-1 channel if need be. If the tensor contains
        more than 3 channels, a PCA projection to RGB will be shown. In
        any case, the attribute values will be rescaled with respect to
        their statistics before visualization, meaning that colors may
        not compare between two different plots
    :param figsize: `int`
        Figure dimensions will be `(figsize, figsize/2)` if `width` and
        `height` are not specified
    :param width: `int`
        Figure width
    :param height: `int`
        Figure height
    :param class_names: `List(str)`
        Names for point labels found in attributes `y` and
        `semantic_pred`
    :param class_colors: `List(List(int, int, int))`
        Colors palette for point labels found in attributes `y` and
        `semantic_pred`
    :param stuff_classes: `List(int)`
        Semantic labels of the classes considered as `stuff` for
        instance and panoptic segmentation. If `y` and `obj` are found
        in the point attributes, the stuff annotations will appear
        accordingly. Otherwise, stuff instance labeling will appear as
        any other object
    :param num_classes: `int`
        Number of valid classes. By convention, we assume
        `y ∈ [0, num_classes-1]` are VALID LABELS, while
        `y < 0` AND `y >= num_classes` ARE VOID LABELS
    :param hide_void_pred: `bool`
        Whether predictions on points labeled as VOID be visualized
    :param voxel: `float`
        Voxel size to subsample the point cloud to facilitate
        visualization
    :param max_points: `int`
        Maximum number of points displayed to facilitate visualization
    :param point_size: `int` or `float`
        Size of point markers
    :param centroid_size: `int` or `float`
        Size of superpoint markers
    :param error_color: `List(int, int, int)`
        Color used to identify mis-predicted points
    :param centroids: `bool`
        Whether superpoint centroids should be displayed
    :param h_edge: `bool`
        Whether horizontal edges should be displayed (only if
        `centroids=True`)
    :param h_edge_attr: `bool`
        Whether the edges should be colored by their features found in
        `edge_attr` (only if `h_edge=True`)
    :param h_edge_width: `float`
        Width of the horizontal edges, if `h_edge=True`. Defaults to
        `None`, in which case `point_size` will be used for the edge
        width
    :param v_edge: `bool`
        Whether vertical edges should be displayed (only if
        `centroids=True` and `gap` is not `None`)
    :param v_edge_width: `float`
        Width of the vertical edges, if `v_edge=True`. Defaults to
        `None`, in which case `point_size` will be used for the edge
        width
    :param gap: `List(float, float, float)`
        If `None`, the hierarchical graphs will be overlaid on the points.
        If not `None`, a 3D tensor indicating the offset by which the
        hierarchical graphs should be plotted
    :param radius: `float`
        If not `None`, only visualize a spherical sampling of the input
        data, centered on `center` and with size `radius`. This option
        is not compatible with `select`
    :param center: `List(float, float, float)`
        If `radius` is provided, only visualize a spherical sampling of
        the input data, centered on `center` and with size `radius`. If
        `None`, the center of the scene will be used
    :param select: `Tuple(int, Tensor)`
        If not `None`, will call `Data.select(*select)` or
        `NAG.select(*select)` on the input data (depending on its nature)
        and the coloring schemes will illustrate it. This option is not
        compatible with `radius`
    :param alpha: `float`
        Rules the whitening of selected points, nodes and edges (only if
         select is not `None`)
    :param alpha_super: `float`
        Rules the whitening of superpoints (only if select is not
        `None`). If `None`, alpha will be used as fallback
    :param alpha_stuff: `float`
        Rules the whitening of stuff points (only if the input
        points have `obj` and `semantic_pred` attributes, and
        `stuff_classes` or `num_classes` is specified). If `None`,
        `alpha` will be used as fallback
    :param point_symbol: `str`
        Marker symbol used for points. Must be one of
        `{'circle', 'circle-open', 'square', 'square-open', 'diamond',
        'diamond-open', 'cross', 'x'}`. Defaults to `'circle'`
    :param centroid_symbol: `str`
        Marker symbol used for centroids. Must be one of
        `{'circle', 'circle-open', 'square', 'square-open', 'diamond',
        'diamond-open', 'cross', 'x'}`. Defaults to `'circle'`
    :param colorscale: `str`
        Plotly colorscale used for coloring 1D continuous features. See
        https://plotly.com/python/builtin-colorscales for options
    :param kwargs

    :return:
    """
    # Data attributes plotted by default if found in the input
    _DEFAULT_KEYS = [
        'pos',
        'rgb',
        'y',
        'semantic_pred',
        'obj',
        'obj_pred',
        'x',
        'super_sampling',
        'super_index']

    # assert isinstance(input, (Data, NAG))
    gap = torch.tensor(gap) if gap is not None else gap
    assert gap is None or gap.shape == torch.Size([3])
    assert not (radius and (select is not None)), \
        "Cannot use both a `radius` and `select` at once"

    # We work on copies of the input data, to allow modified in this
    # scope
    input = input.clone().cpu()

    # If the input is a simple Data object, we convert it to a NAG
    input = NAG([input]) if isinstance(input, Data) else input

    # If the last level of the NAG has super_index, we manually
    # construct an additional Data level and append it to the NAG
    if input[input.num_levels - 1].is_sub:
        data_last = input[input.num_levels - 1]
        sub = Cluster(
            data_last.super_index, torch.arange(data_last.num_nodes),
            dense=True)
        obj = data_last.obj.merge(data_last.super_index) \
            if data_last.obj else None
        pos = scatter_mean(data_last.pos, data_last.super_index, dim=0)
        input = NAG(input.to_list() + [Data(pos=pos, sub=sub, obj=obj)])
    is_nag = isinstance(input, NAG)
    num_levels = input.num_levels if is_nag else 1

    # Make sure alpha is in [0, 1]
    alpha = max(0, min(alpha, 1))
    alpha_super = max(0, min(alpha_super, 1)) if alpha_super else alpha
    alpha_stuff = max(0, min(alpha_stuff, 1)) if alpha_stuff else alpha

    # If `radius` is provided, we only visualize a spherical selection
    # of size `radius` around the `center`
    if radius is not None:
        # If no `center` provided, pick the middle of the scene
        if center is None:
            hi = input[0].pos.max(dim=0).values
            lo = input[0].pos.min(dim=0).values
            center = (hi + lo) / 2

            # For Z, we center on the average Z, because the middle
            # value may cause empty samplings for outdoor scenes with
            # some very high objects and most of the interesting stuff
            # happening near the ground 
            center[2] = input[0].pos[:, 2].mean()
        else:
            center = torch.as_tensor(center).cpu()
        center = center.view(1, -1)

        # Create a mask on level-0 (ie points) to be used for indexing
        # the NAG structure
        mask = torch.where(
            torch.linalg.norm(input[0].pos - center, dim=1) < radius)[0]

        # Subselect the hierarchical partition based on the level-0 mask
        input = input.select(0, mask)

    # If `select` is provided, we will call NAG.select on the input data
    # and illustrate the selected/discarded pattern in the figure
    if select is not None and is_nag:

        # Add an ID to the points before applying NAG.select
        nag_temp = input.clone()
        for i in range(nag_temp.num_levels):
            nag_temp._list[i] = SaveNodeIndex()(nag_temp[i])

        # Apply the selection
        nag_temp = nag_temp.select(*select)

        # Indicate, for each node of the hierarchical graph, whether it
        # has been selected
        for i in range(num_levels):
            selected = torch.zeros(input[i].num_nodes, dtype=torch.bool)
            selected[nag_temp[i][SaveNodeIndex.DEFAULT_KEY]] = True
            input[i].selected = selected

        del nag_temp, selected

    elif select is not None and not is_nag:

        # Add an ID to the points before applying NAG.select
        data_temp = SaveNodeIndex()(Data(pos=input.pos.clone()))

        # Apply the selection
        data_temp = data_temp.select(select)[0]

        # Indicate, for each node of the hierarchical graph, whether it
        # has been selected
        selected = torch.zeros(input.num_nodes, dtype=torch.bool)
        selected[data_temp[SaveNodeIndex.DEFAULT_KEY]] = True
        input.selected = selected

        del data_temp, selected

    elif is_nag:
        for i in range(num_levels):
            input[i].selected = torch.ones(
                input[i].num_nodes, dtype=torch.bool)

    else:
        input.selected = torch.ones(input.num_nodes, dtype=torch.bool)

    # Data_0 accounts for the lowest level of hierarchy, the points
    # themselves
    data_0 = input[0] if is_nag else input

    # Subsample to limit the drawing time
    # If the level-0 cloud needs to be voxelized or sampled, a NAG
    # structure will be affected too. To maintain NAG consistency, we
    # only support 'GridSampling3D' with mode='last' and random sampling
    # without replacement. To keep track of the sampled points and index
    # the NAG accordingly, we use 'SaveNodeIndex'
    idx = torch.arange(data_0.num_points)

    # If a voxel size is specified, voxelize the level-0. We first
    # isolate the 'pos' and the input indices of data_0 and apply
    # voxelization on this. We then recover the original grid-sampled
    # points indices to be used with Data.select or NAG.select
    if voxel > 0:
        data_temp = SaveNodeIndex()(Data(pos=data_0.pos.clone()))
        data_temp = GridSampling3D(voxel, mode='last')(data_temp)
        idx = data_temp[SaveNodeIndex.DEFAULT_KEY]
        del data_temp

    # If the cloud is too large with respect to required 'max_points',
    # sample without replacement
    if idx.shape[0] > max_points:
        idx = idx[fast_randperm(idx.shape[0])[:max_points]]

    # If a sampling is needed, apply it to the input Data or NAG,
    # depending on the structure
    if idx.shape[0] < data_0.num_points:
        input = input.select(0, idx) if is_nag else input.select(idx)[0]
        data_0 = input[0] if is_nag else input

    # Round to the cm for cleaner hover info
    data_0.pos = (data_0.pos * 100).round() / 100

    # Class colors initialization
    if class_colors is not None and not isinstance(class_colors[0], str):
        class_colors = np.asarray(class_colors)
    else:
        class_colors = None

    # Prepare figure
    width = width if width and height else figsize
    height = height if width and height else int(figsize / 2)
    margin = int(0.02 * min(width, height))
    layout = go.Layout(
        width=width,
        height=height,
        scene=dict(aspectmode='data', ),  # preserve aspect ratio
        margin=dict(l=margin, r=margin, b=margin, t=margin),
        uirevision=True)
    fig = go.Figure(layout=layout)

    # To keep track of which trace should be seen under which mode
    # (i.e. button), we build trace_modes. This is a list of dictionaries
    # indicating, for each trace (list element), which mode (dict key)
    # it should appear in and with which attributes (values are dict of
    # parameters for plotly figure updates)
    trace_modes = []
    i_point_trace = 0
    i_unselected_point_trace = 1

    # Initialize `void_classes`
    void_classes = [num_classes] if num_classes else []

    # Draw a trace for position-colored 3D point cloud
    mini = data_0.pos.min(dim=0).values
    maxi = data_0.pos.max(dim=0).values
    colors = (data_0.pos - mini) / (maxi - mini + 1e-6)
    colors = rgb_to_plotly_rgb(colors)
    data_0.pos_colors = colors

    fig.add_trace(
        go.Scatter3d(
            x=data_0.pos[data_0.selected, 0],
            y=data_0.pos[data_0.selected, 1],
            z=data_0.pos[data_0.selected, 2],
            mode='markers',
            marker=dict(
                symbol=point_symbol,
                size=point_size,
                color=colors[data_0.selected]),
            hoverinfo='x+y+z+text',
            hovertext=None,
            showlegend=False,
            visible=True, ))
    trace_modes.append({
        'Position RGB': {
            'marker.color': colors[data_0.selected], 'hovertext': None}})

    fig.add_trace(
        go.Scatter3d(
            x=data_0.pos[~data_0.selected, 0],
            y=data_0.pos[~data_0.selected, 1],
            z=data_0.pos[~data_0.selected, 2],
            mode='markers',
            marker=dict(
                symbol=point_symbol,
                size=point_size,
                color=colors[~data_0.selected],
                opacity=alpha),
            hoverinfo='x+y+z+text',
            hovertext=None,
            showlegend=False,
            visible=True, ))
    trace_modes.append({
        'Position RGB': {
            'marker.color': colors[~data_0.selected], 'hovertext': None}})

    # Draw a trace for RGB 3D point cloud
    if data_0.rgb is not None:
        colors = data_0.rgb
        colors = rgb_to_plotly_rgb(colors)
        data_0.rgb_colors = colors
        trace_modes[i_point_trace]['RGB'] = {
            'marker.color': colors[data_0.selected], 'hovertext': None}
        trace_modes[i_unselected_point_trace]['RGB'] = {
            'marker.color': colors[~data_0.selected], 'hovertext': None}

    # Color the points with ground truth semantic labels. If labels are
    # expressed as histograms, keep the most frequent one
    if data_0.y is not None:
        y = data_0.y
        y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy()
        colors = class_colors[y] if class_colors is not None \
            else int_to_plotly_rgb(torch.LongTensor(y))
        data_0.y_colors = colors
        if class_names is None:
            text = np.array([f"Class {i}" for i in range(y.max() + 1)])
        else:
            text = np.array([str.title(c) for c in class_names])
        text = text[y]
        trace_modes[i_point_trace]['Semantic'] = {
            'marker.color': colors[data_0.selected],
            'hovertext': text[data_0.selected]}
        trace_modes[i_unselected_point_trace]['Semantic'] = {
            'marker.color': colors[~data_0.selected],
            'hovertext': text[~data_0.selected]}

    # Color the points with predicted semantic labels. If labels are
    # expressed as histograms, keep the most frequent one
    if data_0.semantic_pred is not None:
        pred = data_0.semantic_pred
        pred = pred.argmax(1).numpy() if pred.dim() == 2 else pred.numpy()

        # If the ground truth labels are available, we use them to
        # identify void points in the predictions
        if data_0.y is not None and hide_void_pred:
            # Get the target label
            y_gt = data_0.y
            y_gt = y_gt.argmax(1) if y_gt.dim() == 2 else y_gt

            # Create a mask over the points identifying those whose
            # ground truth label is void
            is_void = np.zeros(y_gt.max() + 1, dtype='bool')
            for i in void_classes:
                if i < is_void.shape[0]:
                    is_void[i] = True
            is_void = is_void[y_gt]

            # Set the predicted label to void if the ground truth is
            # void, this avoids visualizing predictions on void
            # labels
            pred[is_void] = y_gt[is_void]

        colors = class_colors[pred] if class_colors is not None \
            else int_to_plotly_rgb(torch.LongTensor(pred))
        data_0.pred_colors = colors
        if class_names is None:
            text = np.array([f"Class {i}" for i in range(pred.max() + 1)])
        else:
            text = np.array([str.title(c) for c in class_names])
        text = text[pred]
        trace_modes[i_point_trace]['Semantic Pred.'] = {
            'marker.color': colors[data_0.selected],
            'hovertext': text[data_0.selected]}
        trace_modes[i_unselected_point_trace]['Semantic Pred.'] = {
            'marker.color': colors[~data_0.selected],
            'hovertext': text[~data_0.selected]}

    # Color the points with ground truth instance labels. If semantic
    # labels and stuff_classes/void_classes also passed, the stuff/void
    # annotations will be treated accordingly
    if data_0.obj is not None and (class_names is None or data_0.y is None):
        obj = data_0.obj if isinstance(data_0.obj, torch.Tensor) \
            else data_0.obj.major(num_classes=num_classes)[0]
        colors = int_to_plotly_rgb(obj)
        data_0.obj_colors = colors
        text = np.array([f"Object {o}" for o in obj])
        trace_modes[i_point_trace]['Panoptic'] = {
            'marker.color': colors[data_0.selected],
            'hovertext': text[data_0.selected]}
        trace_modes[i_unselected_point_trace]['Panoptic'] = {
            'marker.color': colors[~data_0.selected],
            'hovertext': text[~data_0.selected]}
    elif data_0.obj is not None:
        # Colors and text for thing points
        obj = data_0.obj.major(num_classes=num_classes)[0]
        colors_thing = int_to_plotly_rgb(obj)
        text_thing = np.array([f"Object {o}" for o in obj])

        # For simplicity, we merge void_classes into the stuff_classes,
        # the expected behavior is the same, except that we will ensure
        # that the hover text distinguishes between stuff and void
        stuff_classes = stuff_classes if stuff_classes is not None else []
        stuff_classes = list(set(stuff_classes).union(set(void_classes)))

        # Colors and text for stuff points
        y = data_0.y
        y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy()
        colors_stuff = class_colors[y] if class_colors is not None \
            else int_to_plotly_rgb(torch.LongTensor(y))
        if class_names is None:
            text_stuff = np.array([
                f"{'Void' if i in void_classes else 'Stuff'} - Class {i}"
                for i in range(y.max() + 1)])
        else:
            text_stuff = np.array([
                f"{'Void' if i in void_classes else 'Stuff'} - {str.title(c)}"
                for i, c in enumerate(class_names)])
        text_stuff = text_stuff[y]

        # Apply alpha-whitening on stuff points
        colors_stuff = colors_stuff.astype('float')
        white = np.full((colors_stuff.shape[0], 3), 255, dtype='float')
        colors_stuff = colors_stuff * alpha_stuff + white * (1 - alpha_stuff)
        colors_stuff = colors_stuff.astype('int64')

        # Compute mask for stuff points
        stuff_classes = np.asarray([i for i in stuff_classes if i <= y.max()])
        is_stuff = np.zeros(y.max() + 1, dtype='bool')
        for i in stuff_classes:
            if i < is_stuff.shape[0]:
                is_stuff[i] = True
        is_stuff = is_stuff[y]

        # Merge thing and stuff colors and text
        colors = colors_thing
        text = text_thing
        colors[is_stuff] = colors_stuff[is_stuff]
        text[is_stuff] = text_stuff[is_stuff]
        data_0.obj_colors = colors

        # Create trace modes
        trace_modes[i_point_trace]['Panoptic'] = {
            'marker.color': colors[data_0.selected],
            'hovertext': text[data_0.selected]}
        trace_modes[i_unselected_point_trace]['Panoptic'] = {
            'marker.color': colors[~data_0.selected],
            'hovertext': text[~data_0.selected]}

    # Color the points with predicted instance labels. If semantic
    # labels and stuff_classes/void_classes also passed, the
    # stuff/void predictions will be treated accordingly. This
    # expects `data_0.obj_pred` to be an InstanceData object
    if getattr(data_0, 'obj_pred', None) is not None and class_names is None:
        obj, _, y = data_0.obj_pred.major(num_classes=num_classes)
        colors = int_to_plotly_rgb(obj)
        data_0.obj_pred_colors = colors
        text = np.array([f"Object {o}" for o in obj])
        trace_modes[i_point_trace]['Panoptic Pred.'] = {
            'marker.color': colors[data_0.selected],
            'hovertext': text[data_0.selected]}
        trace_modes[i_unselected_point_trace]['Panoptic Pred.'] = {
            'marker.color': colors[~data_0.selected],
            'hovertext': text[~data_0.selected]}
    elif getattr(data_0, 'obj_pred', None) is not None:
        # Colors and text for thing points
        obj, _, y = data_0.obj_pred.major(num_classes=num_classes)
        colors_thing = int_to_plotly_rgb(obj)
        text_thing = np.array([f"Object {o}" for o in obj])

        # For simplicity, we merge void_classes into the stuff_classes,
        # the expected behavior is the same, except that we will ensure
        # that the hover text distinguishes between stuff and void
        stuff_classes = stuff_classes if stuff_classes is not None else []
        stuff_classes = list(set(stuff_classes).union(set(void_classes)))

        # If the ground truth labels are available, we use them to
        # identify void points in the predictions
        if data_0.y is not None and hide_void_pred:
            # Get the target label
            y_gt = data_0.y
            y_gt = y_gt.argmax(1) if y_gt.dim() == 2 else y_gt

            # Create a mask over the points identifying those whose
            # ground truth label is void
            is_void = np.zeros(y_gt.max() + 1, dtype='bool')
            for i in void_classes:
                if i < is_void.shape[0]:
                    is_void[i] = True
            is_void = is_void[y_gt]

            # Set the predicted label to void if the ground truth is
            # void, this avoids visualizing predictions on void
            # labels
            y[is_void] = y_gt[is_void]

        # Colors and text for stuff points
        colors_stuff = class_colors[y] if class_colors is not None \
            else int_to_plotly_rgb(torch.LongTensor(y))
        if class_names is None:
            text_stuff = np.array([
                f"{'Void' if i in void_classes else 'Stuff'} - Class {i}"
                for i in range(y.max() + 1)])
        else:
            text_stuff = np.array([
                f"{'Void' if i in void_classes else 'Stuff'} - {str.title(c)}"
                for i, c in enumerate(class_names)])
        text_stuff = text_stuff[y]

        # Apply alpha-whitening on stuff points
        colors_stuff = colors_stuff.astype('float')
        white = np.full((colors_stuff.shape[0], 3), 255, dtype='float')
        colors_stuff = colors_stuff * alpha_stuff + white * (1 - alpha_stuff)
        colors_stuff = colors_stuff.astype('int64')

        # Compute mask for stuff points
        stuff_classes = np.asarray([i for i in stuff_classes if i <= y.max()])
        is_stuff = np.zeros(y.max() + 1, dtype='bool')
        for i in stuff_classes:
            if i < is_stuff.shape[0]:
                is_stuff[i] = True
        is_stuff = is_stuff[y]

        # Merge thing and stuff colors and text
        colors = colors_thing
        text = text_thing
        colors[is_stuff] = colors_stuff[is_stuff]
        text[is_stuff] = text_stuff[is_stuff]
        data_0.obj_pred_colors = colors

        # Create trace modes
        trace_modes[i_point_trace]['Panoptic Pred.'] = {
            'marker.color': colors[data_0.selected],
            'hovertext': text[data_0.selected]}
        trace_modes[i_unselected_point_trace]['Panoptic Pred.'] = {
            'marker.color': colors[~data_0.selected],
            'hovertext': text[~data_0.selected]}

    # Draw a trace for 3D point cloud features
    if data_0.x is not None:
        colors = feats_to_plotly_rgb(
            data_0.x, normalize=True, colorscale=colorscale)
        data_0.x_colors = colors
        trace_modes[i_point_trace]['Features 3D'] = {
            'marker.color': colors[data_0.selected], 'hovertext': None}
        trace_modes[i_unselected_point_trace]['Features 3D'] = {
            'marker.color': colors[~data_0.selected], 'hovertext': None}

    # Draw a trace for each key specified in keys. Only displays
    # point-wise tensor attributes that have not already been plotted
    # (ie not in `_DEFAULT_KEYS`)
    if keys is None:
        keys = []
    elif isinstance(keys, str):
        keys = [keys]
    keys = [k for k in keys if k not in _DEFAULT_KEYS]
    for key in keys:
        val = getattr(data_0, key, None)
        if (val is None or not torch.is_tensor(val)
                or val.shape[0] != data_0.num_points):
            continue
        colors = feats_to_plotly_rgb(
            val, normalize=True, colorscale=colorscale)
        data_0[f"{key}_colors"] = colors
        trace_modes[i_point_trace][str(key).title()] = {
            'marker.color': colors[data_0.selected], 'hovertext': None}
        trace_modes[i_unselected_point_trace][str(key).title()] = {
            'marker.color': colors[~data_0.selected], 'hovertext': None}

    # Draw a trace for 3D point cloud sampling (for sampling debugging)
    if 'super_sampling' in data_0.keys:
        colors = data_0.super_sampling
        colors = int_to_plotly_rgb(colors)
        colors[data_0.super_sampling == -1] = 230
        data_0.super_sampling_colors = colors
        trace_modes[i_point_trace]['Super sampling'] = {
            'marker.color': colors[data_0.selected], 'hovertext': None}
        trace_modes[i_unselected_point_trace]['Super sampling'] = {
            'marker.color': colors[~data_0.selected], 'hovertext': None}

    # Draw a trace for each cluster level
    for i_level, data_i in enumerate(input if is_nag else []):

        # Exit in case the Data has no 'super_index'
        if not data_i.is_sub:
            break

        # 'Data.super_index' are expressed between levels i and i+1, but
        # we need to recover the 'super_index' between level 0 and i+1,
        # to draw clusters on the level-0 points. To this end, we
        # compute the desired 'super_index' iteratively, with a
        # bottom-up approach
        if i_level == 0:
            super_index = data_i.super_index
        else:
            super_index = data_i.super_index[super_index]

        # Note that we update the 'trace_modes' 0th element here, this
        # assumes only it is the trace holding all level-0 points and on
        # which all other colors modes are defined
        colors = int_to_plotly_rgb(super_index)
        data_0[f"{i_level}_level_colors"] = colors
        text = np.array([f"↑: {i}" for i in super_index])
        trace_modes[i_point_trace][f"Level {i_level + 1}"] = {
            'marker.color': colors[data_0.selected],
            'hovertext': text[data_0.selected]}
        trace_modes[i_unselected_point_trace][f"Level {i_level + 1}"] = {
            'marker.color': colors[~data_0.selected],
            'hovertext': text[~data_0.selected]}

        # Skip to the next level if we do not need to draw the cluster
        # centroids
        if not centroids:
            continue

        # To recover centroids of the i+1 level superpoints, we either
        # read them from the next NAG level or compute them using the
        # level i 'super_index' indices
        num_levels = input.num_levels
        is_last_level = i_level == num_levels - 1
        if is_last_level or input[i_level + 1].pos is None:
            super_pos = scatter_mean(data_0.pos, super_index, dim=0)
        else:
            super_pos = input[i_level + 1].pos

        # Add the gap offset, if need be
        if gap is not None:

            super_pos += gap * (i_level + 1)

        # Round to the cm for cleaner hover info
        super_pos = (super_pos * 100).round() / 100

        # Save the drawing position of centroids to facilitate vertical
        # edges drawing later on
        input[i_level + 1].draw_pos = super_pos

        # Draw the level-i+1 cluster centroids
        idx_sp = torch.arange(data_i.super_index.max() + 1)
        colors = int_to_plotly_rgb(idx_sp)
        text = np.array([f"<b>#: {i}</b>" for i in idx_sp])
        ball_size = centroid_size if centroid_size else point_size * 3

        fig.add_trace(
            go.Scatter3d(
                x=super_pos[input[i_level + 1].selected, 0],
                y=super_pos[input[i_level + 1].selected, 1],
                z=super_pos[input[i_level + 1].selected, 2],
                mode='markers+text',
                marker=dict(
                    symbol=centroid_symbol,
                    size=ball_size,
                    color=colors[input[i_level + 1].selected.numpy()],
                    line_width=min(ball_size / 2, 2),
                    line_color='black'),
                textposition="bottom center",
                textfont=dict(size=16),
                hovertext=text,
                hoverinfo='x+y+z+text',
                showlegend=False,
                visible=gap is not None, ))

        fig.add_trace(
            go.Scatter3d(
                x=super_pos[~input[i_level + 1].selected, 0],
                y=super_pos[~input[i_level + 1].selected, 1],
                z=super_pos[~input[i_level + 1].selected, 2],
                mode='markers+text',
                marker=dict(
                    symbol=centroid_symbol,
                    size=ball_size,
                    color=colors[~input[i_level + 1].selected.numpy()],
                    line_width=min(ball_size / 2, 2),
                    line_color='black',
                    opacity=alpha_super),
                textposition="bottom center",
                textfont=dict(size=16),
                hovertext=text,
                hoverinfo='x+y+z+text',
                showlegend=False,
                visible=gap is not None, ))

        keys = [f"Level {i_level + 1}"] if gap is None \
            else trace_modes[i_point_trace].keys()
        trace_modes.append(
            {k: {
                'marker.color': colors[input[i_level + 1].selected.numpy()],
                'hovertext': text[input[i_level + 1].selected.numpy()]}
            for k in keys})
        trace_modes.append(
            {k: {
                'marker.color': colors[~input[i_level + 1].selected.numpy()],
                'hovertext': text[~input[i_level + 1].selected.numpy()]}
            for k in keys})

        if i_level > 0 and v_edge and gap is not None and is_nag:
            # Recover the source and target positions for vertical edges
            # between i_level -> i_level+1
            low_pos = data_i.draw_pos[data_i.selected]
            high_pos = super_pos[data_i.super_index[data_i.selected]]

            # Convert into a plotly-friendly format for 3D lines
            edges = np.full((low_pos.shape[0] * 3, 3), None)
            edges[::3] = low_pos
            edges[1::3] = high_pos

            # Color the vertical edges based on the parent cluster index
            # Plotly is a bit hacky with colors for 3D lines. We cannot
            # directly pass individual edge colors, we must instead give
            # edge color as an int corresponding to a colorscale list
            # holding plotly-friendly colors
            colors = data_i.super_index[data_i.selected]
            colors = np.repeat(colors, 3)
            n_colors = colors.max().item() + 1
            edge_colorscale = int_to_plotly_rgb(torch.arange(n_colors))
            edge_colorscale = [
                [i / (n_colors - 1), f"rgb({x[0]}, {x[1]}, {x[2]})"]
                for i, x in enumerate(edge_colorscale)]

            # Since plotly 3D lines do not support opacity, we draw
            # these edges as super thin to limit clutter
            edge_width = 0.5 if v_edge_width is None else v_edge_width

            # Draw the level i -> i+1 vertical edges. NB we only draw
            # edges that are selected and do not draw the unselected
            # edges. This is because plotly does not handle opacity
            # on lines (yet), which means the unselected edges will tend
            # to clutter the figure. For this reason we choose to simply
            # not show them
            fig.add_trace(
                go.Scatter3d(
                    x=edges[:, 0],
                    y=edges[:, 1],
                    z=edges[:, 2],
                    mode='lines',
                    line=dict(
                        width=edge_width,
                        color=colors,
                        colorscale=edge_colorscale),
                    hoverinfo='skip',
                    showlegend=False,
                    visible=gap is not None, ))

            # NB: at this point, trace_modes contains 'Level i+1' as its
            # last key, but we do not want vertical edges to be seen
            # when 'Level i+1' is selected, because it means 'Level i'
            # nodes are hidden
            keys = list(trace_modes[i_point_trace].keys())[:-1]
            trace_modes.append({k: {} for k in keys})

        # Do not draw superedges if not required or if the i+1 level
        # does not have any
        if not h_edge or is_last_level or not input[i_level + 1].has_edges:
            continue

        # Recover the superedge source and target positions
        se = input[i_level + 1].edge_index
        se_attr = input[i_level + 1].edge_attr

        # Since we can only draw one edge direction (they would overlap
        # otherwise), we can trim the graph to only keep one direction
        # for each undirected edge pair. However, this requires picking
        # one direction for the edge attributes to we ARBITRARILY TAKE
        # THE MAX EDGE FEATURE for each undirected edge
        input[i_level + 1].raise_if_edge_keys()
        if se_attr is not None:
            se, se_attr = to_trimmed(se, edge_attr=se_attr, reduce='max')
        else:
            se = to_trimmed(se)

        # Recover corresponding source and target coordinates using the
        # previously-computed 'super_pos' cluster centroid positions
        s_pos = super_pos[se[0]].numpy()
        t_pos = super_pos[se[1]].numpy()

        # Convert into a plotly-friendly format for 3D lines
        edges = np.full((se.shape[1] * 3, 3), None)
        edges[::3] = s_pos
        edges[1::3] = t_pos

        if h_edge_attr and se_attr is not None:

            # Recover edge features and convert them to RGB colors. NB:
            # edge features are assumed to be in [0, 1] or [-1, 1].
            # Since we only draw edges in one direction, we choose to
            # only represent the absolute value of the features. This
            # implies that features are either direction-independent or
            # that the edge direction only changes the sign of the
            # feature
            colors = feats_to_plotly_rgb(
                se_attr.abs(), normalize=True, colorscale=colorscale)
            colors = np.repeat(colors, 3, axis=0)
            edge_width = point_size if h_edge_width is None else h_edge_width

        else:
            colors = feats_to_plotly_rgb(
                torch.zeros(edges.shape[0]), normalize=True, colorscale=colorscale)
            edge_width = point_size if h_edge_width is None else h_edge_width

        selected_edge = input[i_level + 1].selected[se].all(axis=0)
        selected_edge = selected_edge.repeat_interleave(3).numpy()

        # Draw the level-i+1 superedges. NB we only draw edges that are
        # selected and do not draw the unselected edges. This is because
        # plotly does not handle opacity on lines (yet), which means the
        # unselected edges will tend to clutter the figure. For this
        # reason we choose to simply not show them
        fig.add_trace(
            go.Scatter3d(
                x=edges[selected_edge, 0],
                y=edges[selected_edge, 1],
                z=edges[selected_edge, 2],
                mode='lines',
                line=dict(
                    width=edge_width,
                    color=colors[selected_edge]),
                hoverinfo='skip',
                showlegend=False,
                visible=gap is not None, ))

        keys = [f"Level {i_level + 1}"] if gap is None \
            else trace_modes[i_point_trace].keys()
        trace_modes.append({k: {} for k in keys})

    # Add a trace for prediction errors. NB: it is important that this
    # trace is created last, as the button behavior for this one is
    # particular
    has_error = data_0.y is not None and data_0.semantic_pred is not None
    if has_error:

        # Recover prediction and ground truth and deal with potential
        # histograms
        y = data_0.y
        y = y.argmax(1).numpy() if y.dim() == 2 else y.numpy()
        pred = data_0.semantic_pred
        pred = pred.argmax(1).numpy() if pred.dim() == 2 else pred.numpy()

        # Identify erroneous point indices
        ignore = void_classes if void_classes else []
        ignore = ignore + [-1]
        indices = np.where((pred != y) & (~np.in1d(y, ignore)))[0]

        # Prepare the color for erroneous points
        error_color = 'red' if error_color is None \
            else np.asarray[error_color].squeeze()

        # Draw the erroneous points
        fig.add_trace(
            go.Scatter3d(
                x=data_0.pos[indices, 0],
                y=data_0.pos[indices, 1],
                z=data_0.pos[indices, 2],
                mode='markers',
                marker=dict(
                    symbol=point_symbol,
                    size=int(point_size * 1.5),
                    color=error_color, ),
                showlegend=False,
                visible=False, ))

    # Recover the keys for all visualization modes, as an ordered set,
    # with respect to their order of first appearance
    modes = list(dict.fromkeys([k for m in trace_modes for k in m.keys()]))

    # Traces color for interactive point cloud coloring
    def trace_update(mode):
        # Prepare the output args for the figure update attributes. By
        # default, all traces are non visible, with no color and no
        # hover text
        n_traces = len(trace_modes)
        out = {
            'visible': [False] * (n_traces + has_error),
            'marker.color': [None] * n_traces,
            'hovertext': [''] * n_traces}

        # For each trace in 'trace_modes' see if it contains 'mode' and
        # adapt out accordingly
        for i_trace, t_modes in enumerate(trace_modes):

            # The trace has no action for the mode, skip it and leave
            # the default args for the trace
            if mode not in t_modes:
                continue

            # Note that a trace will only be visible for its modes
            # declared in trace_modes
            out['visible'][i_trace] = True
            for key, val in t_modes[mode].items():
                out[key][i_trace] = val

        return [out, list(range(len(trace_modes)))]

    # Create the buttons that will serve for toggling trace visibility
    updatemenus = [
        dict(
            buttons=[dict(
                label=mode, method='update', args=trace_update(mode))
                for mode in modes if mode.lower() != 'errors'],
            pad={'r': 10, 't': 10},
            showactive=True,
            type='dropdown',
            direction='right',
            xanchor='left',
            x=0.02,
            yanchor='top',
            y=1.02, ),]

    if has_error:
        updatemenus.append(
            dict(
                buttons=[dict(
                    method='restyle',
                    label='Semantic Errors',
                    visible=True,
                    args=[
                        {'visible': True, 'marker.color': error_color},
                        [len(trace_modes)]],
                    args2=[
                        {'visible': False,},
                        [len(trace_modes)]],)],
                pad={'r': 10, 't': 10},
                showactive=False,
                type='buttons',
                xanchor='left',
                x=1.02,
                yanchor='top',
                y=1.02, ),)

    fig.update_layout(updatemenus=updatemenus)

    # Place the legend on the left
    fig.update_layout(
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="right",
            x=0.99))

    # Hide all axes and no background
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            xaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)"),
            yaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)"),
            zaxis=dict(
                autorange=True,
                showgrid=False,
                ticks='',
                showticklabels=False,
                backgroundcolor="rgba(0, 0, 0, 0)")))

    output = {'figure': fig, 'data': data_0}

    return output


def figure_html(fig):
    """Convert a plotly figure into an embeddable HTML `<div>` string.

    The figure is rendered without the mode bar. plotly.js is loaded
    from the CDN, and the result is not wrapped in a full `<html>`
    document. The figure `<div>` is horizontally centered.

    :param fig: plotly.graph_objects.Figure
        Figure to convert
    :return: str
        HTML snippet holding the figure
    """
    # Render directly to a string rather than round-tripping through a
    # hard-coded '/tmp/fig.html'. This is portable (there is no '/tmp'
    # on Windows), safe for concurrent calls, avoids a predictable path
    # in a world-writable directory, and leaves no stray file behind
    fig_html = fig.to_html(
        config={'displayModeBar': False},
        include_plotlyjs='cdn',
        full_html=False)

    # Center the figure div for cleaner display
    return fig_html.replace(
        'class="plotly-graph-div" style="',
        'class="plotly-graph-div" style="margin:0 auto;')


def show(input, path=None, title=None, no_output=True, pt_path=None, **kwargs):
    """Interactive data visualization.

    :param input: Data or NAG object
    :param path: str
        Path to save the visualization into a sharable HTML. If `path`
        is an existing directory, the file is saved in it as
        `<title>.html`. Otherwise, the extension of `path` is replaced
        with `.html`. When `path` is given, the HTML file is written
        whatever the value of `no_output`, and the figure is not
        displayed interactively
    :param title: str
        Figure title. Defaults to "Large-scale point cloud"
    :param no_output: bool
        If True (default), nothing is returned and, when `path` is
        None, the figure is displayed. Set to False to have the
        `visualize_3d` output returned instead of displayed
    :param pt_path: str
        Path to save the visualization-ready `Data` object as a `*.pt`.
        In this `Data` object, the `pos` and all `*color*` attributes
        will be saved, the rest is discarded. This is typically useful
        for exporting the visualization layers to another visualization
        tool. If `pt_path` is an existing directory, the file is saved
        in it as `viz_data.pt`
    :param kwargs:
        Passed on to `visualize_3d`
    :return: dict or None
        The `visualize_3d` output, holding the 'figure' and 'data'
        keys, if `no_output` is False. None otherwise
    """
    # Sanitize title and path
    if title is None:
        title = "Large-scale point cloud"
    if path is not None:
        if osp.isdir(path):
            path = osp.join(path, f"{title}.html")
        else:
            path = osp.splitext(path)[0] + '.html'

    # Draw a figure for 3D data visualization
    out_3d = visualize_3d(input, **kwargs)

    if path is not None:
        # Save the titled figure into a sharable HTML file. NB: the
        # figure is always included, even when the output is also
        # returned to the caller. Previously, a title-only page was
        # written in that case
        fig_html = f'<h1 style="text-align: center;">{title}</h1>'
        fig_html += figure_html(out_3d['figure'])
        with open(path, "w", encoding="utf-8") as f:
            f.write(fig_html)
    elif no_output:
        # Nothing to save and nothing returned: display the figure
        out_3d['figure'].show(config={'displayModeBar': False})

    # Save to a .pt file for other downstream tasks.
    # NB: we only save the 'pos' and all data attributes containing
    # 'color', the rest is discarded
    # NB: we save a dictionary, to limit dependencies
    if pt_path is not None:
        if osp.isdir(pt_path):
            pt_path = osp.join(pt_path, "viz_data.pt")
        else:
            pt_path = osp.splitext(pt_path)[0] + '.pt'

        viz_data = out_3d['data']
        data = {
            key: viz_data[key]
            for key in viz_data.keys
            if key == 'pos' or 'color' in key}

        torch.save(data, pt_path)

    return None if no_output else out_3d