English
File size: 9,878 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import torch
import logging
import src

# Module-level logger, named after this module's import path
log = logging.getLogger(__name__)


# Names exported by a star-import of this module
__all__ = ['SemanticSegmentationOutput']


class SemanticSegmentationOutput:
    """A simple holder for semantic segmentation model output, with a
    few helper methods for manipulating the predictions and targets
    (if any).

    :param logits: Tensor or sequence of Tensors
        Predicted logits. A single Tensor is treated as single-stage
        output; anything else is treated as multi-stage output, in
        which case the first element holds the level-1 logits. Logits
        are indexed with nodes along dim 0 and classes along dim 1
    :param y_hist: Tensor or sequence of Tensors, optional
        Target label histograms, following the same single/multi-stage
        layout as `logits`. Void points are counted in the last column
    """

    def __init__(self, logits, y_hist=None):
        self.logits = logits
        self.y_hist = y_hist
        if src.is_debug_enabled():
            self.debug()

    def debug(self):
        """Runs a series of sanity checks on the attributes of self.

        :raises AssertionError: if `logits` is neither a Tensor nor an
            iterable of Tensors, or if the targets and the logits do
            not have matching sizes along their first dimension
        """
        assert isinstance(self.logits, torch.Tensor) \
               or all(isinstance(x, torch.Tensor) for x in self.logits)
        if self.has_target:
            if self.multi_stage:
                # One target histogram is expected per stage
                assert len(self.y_hist) == len(self.logits)
                assert all(
                    y.shape[0] == x.shape[0]
                    for y, x in zip(self.y_hist, self.logits))
            else:
                assert self.y_hist.shape[0] == self.logits.shape[0]

    @property
    def _level1_logits(self):
        """Logits of the level-1 nodes: the first element of `logits`
        for multi-stage output, `logits` itself otherwise.
        """
        return self.logits[0] if self.multi_stage else self.logits

    @staticmethod
    def _resolve_super_index(super_index, sub, super_index_name, sub_name):
        """Returns `super_index` if provided, else builds it from `sub`.

        :param super_index: LongTensor or None
            Parent index of each child element
        :param sub: Cluster or None
            Object exposing `to_super_index()`, only used when
            `super_index` is None
        :param super_index_name: str
            Caller-side name of `super_index`, for the error message
        :param sub_name: str
            Caller-side name of `sub`, for the error message
        :raises AssertionError: if both `super_index` and `sub` are None
        """
        assert super_index is not None or sub is not None, \
            f"Must provide either `{super_index_name}` or `{sub_name}`"
        return sub.to_super_index() if super_index is None else super_index

    @property
    def device(self):
        """Returns the device on which the logits are stored, assuming
        all other output variables held by the object are also on the
        same device.
        """
        return self._level1_logits.device

    @property
    def has_target(self):
        """Check whether `self` contains target data for semantic
        segmentation.
        """
        return self.y_hist is not None

    @property
    def multi_stage(self):
        """If the semantic segmentation `logits` are stored in an
        enumerable, then the model output is multi-stage.
        """
        return not isinstance(self.logits, torch.Tensor)

    @property
    def num_classes(self):
        """Number of semantic classes in the output predictions.
        """
        return self._level1_logits.shape[1]

    @property
    def num_nodes(self):
        """Number of nodes/superpoints in the output predictions. By
        default, for a hierarchical partition, this means counting the
        number of level-1 nodes/superpoints.
        """
        return self._level1_logits.shape[0]

    def semantic_pred(self):
        """Semantic predictions on the level-1 superpoint.

        Final semantic segmentation predictions are the argmax of the
        first-level partition logits.

        :return: LongTensor holding one class index per level-1 node
        """
        return torch.argmax(self._level1_logits, dim=1)

    @property
    def semantic_target(self):
        """Semantic target on the level-1 superpoint.

        Final semantic segmentation target is the label histogram of
        the first-level partition. Returns None if `self` holds no
        target, for both single-stage and multi-stage outputs.
        """
        if not self.has_target:
            return None
        return self.y_hist[0] if self.multi_stage else self.y_hist

    @property
    def void_mask(self):
        """Returns a mask on the level-1 nodes indicating which is void.
        By convention, nodes/superpoints are void if they contain
        more than 50% void points. By convention in this project, void
        points have the label `num_classes`. In label histograms, void
        points are counted in the last column.

        :return: boolean Tensor with one value per level-1 node, or
            None if `self` holds no target
        """
        if not self.has_target:
            return None

        # For simplicity, we only return the mask for the level-1
        y_hist = self.semantic_target
        total_count = y_hist.sum(dim=1)
        void_count = y_hist[:, -1]
        return void_count / total_count > 0.5

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def voxel_semantic_pred(self, super_index=None, sub=None):
        """Semantic predictions on the level-0 voxels.

        Final semantic segmentation predictions are the argmax of the
        first-level partition logits. This function then distributes
        these predictions to each level-0 point (ie voxel in our
        framework).

        :param super_index: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param sub: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains
        :return: LongTensor holding one class index per level-0 point
        :raises AssertionError: if both `super_index` and `sub` are None
        """
        super_index = self._resolve_super_index(
            super_index, sub, 'super_index', 'sub')

        # Distribute the level-1 superpoint predictions to the voxels
        return self.semantic_pred()[super_index]

    def voxel_logits_pred(self, super_index=None, sub=None):
        """Logits on the level-0 voxels.

        This function distributes the first-level partition logits to
        each level-0 point (ie voxel in our framework), without taking
        the argmax.

        :param super_index: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param sub: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains
        :return: Tensor holding one row of level-1 logits per level-0
            point
        :raises AssertionError: if both `super_index` and `sub` are None
        """
        super_index = self._resolve_super_index(
            super_index, sub, 'super_index', 'sub')

        # Distribute the level-1 superpoint logits to the voxels. The
        # level-1 logits must be selected with respect to `multi_stage`:
        # blindly using `self.logits[0]` on a single-stage output would
        # pick the first row of the logits instead
        return self._level1_logits[super_index]

    def full_res_semantic_pred(
            self,
            super_index_level0_to_level1=None,
            super_index_raw_to_level0=None,
            sub_level1_to_level0=None,
            sub_level0_to_raw=None):
        """Semantic predictions on the full-resolution input point
        cloud.

        Final semantic segmentation predictions are the argmax of the
        first-level partition logits. This function then distributes
        these predictions to each raw point (ie full-resolution point
        cloud before voxelization in our framework).

        :param super_index_level0_to_level1: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param super_index_raw_to_level0: LongTensor
            Tensor holding, for each raw full-resolution point, the
            index of the level-0 point (ie voxel) it belongs to
        :param sub_level1_to_level0: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains
        :param sub_level0_to_raw: Cluster
            Cluster object indicating, for each level-0 point (ie
            voxel), the indices of the raw full-resolution points it
            contains
        :return: LongTensor holding one class index per raw point
        :raises AssertionError: if, for either of the two mappings,
            neither the super index nor the Cluster is provided
        """
        # If super_index are not provided, build them from sub
        super_index_level0_to_level1 = self._resolve_super_index(
            super_index_level0_to_level1, sub_level1_to_level0,
            'super_index_level0_to_level1', 'sub_level1_to_level0')
        super_index_raw_to_level0 = self._resolve_super_index(
            super_index_raw_to_level0, sub_level0_to_raw,
            'super_index_raw_to_level0', 'sub_level0_to_raw')

        # Distribute the level-1 superpoint predictions to the voxels,
        # then the voxel predictions to the full-resolution points
        voxel_pred = self.semantic_pred()[super_index_level0_to_level1]
        return voxel_pred[super_index_raw_to_level0]

    def full_res_logits_pred(
            self,
            super_index_level0_to_level1=None,
            super_index_raw_to_level0=None,
            sub_level1_to_level0=None,
            sub_level0_to_raw=None):
        """Logits on the full-resolution input point cloud.

        This function propagates the level-1 superpoint logits to each
        raw point (ie full-resolution point cloud before voxelization).

        :param super_index_level0_to_level1: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param super_index_raw_to_level0: LongTensor
            Tensor holding, for each raw full-resolution point, the
            index of the level-0 point (ie voxel) it belongs to
        :param sub_level1_to_level0: Cluster
            Optional. Used to build `super_index_level0_to_level1` if
            not given
        :param sub_level0_to_raw: Cluster
            Optional. Used to build `super_index_raw_to_level0` if not
            given
        :return: Tensor of shape (N_raw, C), where N_raw is the number
            of raw points, and C is the number of classes
        :raises AssertionError: if, for either of the two mappings,
            neither the super index nor the Cluster is provided
        """
        # If super_index are not provided, build them from sub
        super_index_level0_to_level1 = self._resolve_super_index(
            super_index_level0_to_level1, sub_level1_to_level0,
            'super_index_level0_to_level1', 'sub_level1_to_level0')
        super_index_raw_to_level0 = self._resolve_super_index(
            super_index_raw_to_level0, sub_level0_to_raw,
            'super_index_raw_to_level0', 'sub_level0_to_raw')

        # Distribute the level-1 superpoint logits to the voxels, then
        # the voxel logits to the full-resolution points. The level-1
        # logits are selected with respect to `multi_stage` so that
        # single-stage outputs are handled too
        voxel_logits = self._level1_logits[super_index_level0_to_level1]
        return voxel_logits[super_index_raw_to_level0]