|
|
import torch |
|
|
import logging |
|
|
import src |
|
|
|
|
|
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)


# Names exported by `from <this module> import *`: only the output
# holder class is part of the public API.
__all__ = ['SemanticSegmentationOutput']
|
|
|
|
|
|
|
|
class SemanticSegmentationOutput:
    """A simple holder for semantic segmentation model output, with a
    few helper methods for manipulating the predictions and targets
    (if any).

    :param logits: Tensor or sequence of Tensors
        Predicted logits. A single Tensor is treated as a single-stage
        output. Anything else is treated as a multi-stage output whose
        first element holds the level-1 logits. The level-1 logits are
        read as `(num_nodes, num_classes)`.
    :param y_hist: Tensor or sequence of Tensors, optional
        Label histograms used as targets, following the same
        single-stage/multi-stage layout as `logits`. The last column
        counts void points.
    """

    def __init__(self, logits, y_hist=None):
        self.logits = logits
        self.y_hist = y_hist

        # Sanity checks are only run when the project-wide debug mode
        # is enabled
        if src.is_debug_enabled():
            self.debug()

    def debug(self):
        """Runs a series of sanity checks on the attributes of self.

        Checks that `logits` is a Tensor or an iterable of Tensors
        and, if targets are present, that targets and logits agree on
        their first dimension (and on the number of stages for a
        multi-stage output).
        """
        assert isinstance(self.logits, torch.Tensor) \
               or all(isinstance(x, torch.Tensor) for x in self.logits)

        if not self.has_target:
            return

        if self.multi_stage:
            assert len(self.y_hist) == len(self.logits)
            assert all(
                y.shape[0] == x.shape[0]
                for y, x in zip(self.y_hist, self.logits))
        else:
            assert self.y_hist.shape[0] == self.logits.shape[0]

    @property
    def _level1_logits(self):
        """Logits of the first-level partition: the first element of
        `logits` for a multi-stage output, `logits` itself otherwise.
        """
        return self.logits[0] if self.multi_stage else self.logits

    @staticmethod
    def _resolve_super_index(super_index, sub):
        """Returns `super_index` if provided, otherwise builds it from
        `sub` with `sub.to_super_index()`.
        """
        if super_index is None:
            return sub.to_super_index()
        return super_index

    @property
    def device(self):
        """Returns the device on which the logits are stored, assuming
        all other output variables held by the object are also on the
        same device.
        """
        return self._level1_logits.device

    @property
    def has_target(self):
        """Check whether `self` contains target data for semantic
        segmentation.
        """
        return self.y_hist is not None

    @property
    def multi_stage(self):
        """If the semantic segmentation `logits` are stored in an
        enumerable, then the model output is multi-stage.
        """
        return not isinstance(self.logits, torch.Tensor)

    @property
    def num_classes(self):
        """Number of semantic classes in the output predictions."""
        return self._level1_logits.shape[1]

    @property
    def num_nodes(self):
        """Number of nodes/superpoints in the output predictions. By
        default, for a hierarchical partition, this means counting the
        number of level-1 nodes/superpoints.
        """
        return self._level1_logits.shape[0]

    def semantic_pred(self):
        """Semantic predictions on the level-1 superpoint.

        Final semantic segmentation predictions are the argmax of the
        first-level partition logits.
        """
        return torch.argmax(self._level1_logits, dim=1)

    @property
    def semantic_target(self):
        """Semantic target on the level-1 superpoint.

        Final semantic segmentation targets are the label histograms
        of the first-level partition. Returns None if `self` holds no
        target.
        """
        # Guard against indexing into None for a multi-stage output
        # without targets
        if not self.has_target:
            return None
        return self.y_hist[0] if self.multi_stage else self.y_hist

    @property
    def void_mask(self):
        """Returns a mask on the level-1 nodes indicating which is void.
        By convention, nodes/superpoints are void if they contain
        more than 50% void points. By convention in this project, void
        points have the label `num_classes`. In label histograms, void
        points are counted in the last column.

        Returns None if `self` holds no target.
        """
        if not self.has_target:
            return None

        y_hist = self.semantic_target
        total_count = y_hist.sum(dim=1)
        void_count = y_hist[:, -1]

        # NB: a node with an all-zero histogram yields 0 / 0 = NaN,
        # which compares False and is hence not flagged as void
        return void_count / total_count > 0.5

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def voxel_semantic_pred(self, super_index=None, sub=None):
        """Semantic predictions on the level-0 voxels.

        Final semantic segmentation predictions are the argmax of the
        first-level partition logits. This function then distributes
        these predictions to each level-0 point (ie voxel in our
        framework).

        :param super_index: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param sub: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains.
            Only used if `super_index` is not provided
        """
        assert super_index is not None or sub is not None, \
            "Must provide either `super_index` or `sub`"

        super_index = self._resolve_super_index(super_index, sub)

        return self.semantic_pred()[super_index]

    def voxel_logits_pred(self, super_index=None, sub=None):
        """Logits on the level-0 voxels.

        This function distributes the first-level partition logits to
        each level-0 point (ie voxel in our framework).

        :param super_index: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param sub: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains.
            Only used if `super_index` is not provided
        :return: Tensor holding one row of level-1 logits per level-0
            point
        """
        assert super_index is not None or sub is not None, \
            "Must provide either `super_index` or `sub`"

        super_index = self._resolve_super_index(super_index, sub)

        # Use the level-1 logits for both single-stage and multi-stage
        # outputs. Blindly taking `self.logits[0]` on a single-stage
        # output would pick the first row instead of the whole matrix
        return self._level1_logits[super_index]

    def full_res_semantic_pred(
            self,
            super_index_level0_to_level1=None,
            super_index_raw_to_level0=None,
            sub_level1_to_level0=None,
            sub_level0_to_raw=None):
        """Semantic predictions on the full-resolution input point
        cloud.

        Final semantic segmentation predictions are the argmax of the
        first-level partition logits. This function then distributes
        these predictions to each raw point (ie full-resolution point
        cloud before voxelization in our framework).

        :param super_index_level0_to_level1: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param super_index_raw_to_level0: LongTensor
            Tensor holding, for each raw full-resolution point, the
            index of the level-0 point (ie voxel) it belongs to
        :param sub_level1_to_level0: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains.
            Only used if `super_index_level0_to_level1` is not provided
        :param sub_level0_to_raw: Cluster
            Cluster object indicating, for each level-0 point (ie
            voxel), the indices of the raw full-resolution points it
            contains. Only used if `super_index_raw_to_level0` is not
            provided
        """
        assert super_index_level0_to_level1 is not None or sub_level1_to_level0 is not None, \
            "Must provide either `super_index_level0_to_level1` or `sub_level1_to_level0`"

        assert super_index_raw_to_level0 is not None or sub_level0_to_raw is not None, \
            "Must provide either `super_index_raw_to_level0` or `sub_level0_to_raw`"

        super_index_level0_to_level1 = self._resolve_super_index(
            super_index_level0_to_level1, sub_level1_to_level0)
        super_index_raw_to_level0 = self._resolve_super_index(
            super_index_raw_to_level0, sub_level0_to_raw)

        # Level-1 predictions -> level-0 voxels -> raw points
        voxel_pred = self.semantic_pred()[super_index_level0_to_level1]
        return voxel_pred[super_index_raw_to_level0]

    def full_res_logits_pred(
            self,
            super_index_level0_to_level1=None,
            super_index_raw_to_level0=None,
            sub_level1_to_level0=None,
            sub_level0_to_raw=None):
        """Logits on the full-resolution input point cloud.

        This function propagates the level-1 superpoint logits to each
        raw point (ie full-resolution point cloud before voxelization).

        :param super_index_level0_to_level1: LongTensor
            For each level-0 point (voxel), the index of the level-1
            superpoint it belongs to.
        :param super_index_raw_to_level0: LongTensor
            For each raw point, the index of the level-0 point it
            belongs to.
        :param sub_level1_to_level0: Cluster
            Optional. Used to build `super_index_level0_to_level1` if
            not given.
        :param sub_level0_to_raw: Cluster
            Optional. Used to build `super_index_raw_to_level0` if not
            given.
        :return: Tensor of shape (N_raw, C), where N_raw is the number
            of raw points, and C is the number of classes.
        """
        assert super_index_level0_to_level1 is not None or sub_level1_to_level0 is not None, \
            "Must provide either `super_index_level0_to_level1` or `sub_level1_to_level0`"

        assert super_index_raw_to_level0 is not None or sub_level0_to_raw is not None, \
            "Must provide either `super_index_raw_to_level0` or `sub_level0_to_raw`"

        super_index_level0_to_level1 = self._resolve_super_index(
            super_index_level0_to_level1, sub_level1_to_level0)
        super_index_raw_to_level0 = self._resolve_super_index(
            super_index_raw_to_level0, sub_level0_to_raw)

        # Compose the two index maps first (raw point -> level-1
        # superpoint) so that no intermediate voxel-level logits tensor
        # needs to be materialized. For integer index tensors, this is
        # equivalent to indexing the logits twice in a row
        super_index_raw_to_level1 = \
            super_index_level0_to_level1[super_index_raw_to_level0]

        # Use the level-1 logits for both single-stage and multi-stage
        # outputs. Blindly taking `self.logits[0]` on a single-stage
        # output would pick the first row instead of the whole matrix
        return self._level1_logits[super_index_raw_to_level1]
|
|
|