|
|
import torch |
|
|
import numpy as np |
|
|
from torch.nn.functional import one_hot |
|
|
from typing import List, Tuple, Union |
|
|
from torch_geometric.nn.pool.consecutive import consecutive_cluster |
|
|
from torch_scatter import scatter_max, scatter_sum |
|
|
|
|
|
from src.data.csr import CSRData, CSRBatch |
|
|
from src.utils import tensor_idx, is_dense, has_duplicates, to_trimmed |
|
|
|
|
|
|
|
|
__all__ = ['InstanceData', 'InstanceBatch'] |
|
|
|
|
|
|
|
|
class InstanceData(CSRData): |
|
|
"""Child class of CSRData to simplify some common operations |
|
|
dedicated to instance labels clustering. In particular, this data |
|
|
structure stores the cluster-object overlaps: for each cluster (i.e. |
|
|
segment, superpoint, node in the superpoint graph, etc), we store |
|
|
all the object instances with which it overlaps. Concretely, for |
|
|
each cluster-object pair, we store: |
|
|
- `obj`: the object's index |
|
|
- `count`: the number of points in the cluster-object overlap |
|
|
- `y`: the object's semantic label |
|
|
|
|
|
Importantly, each object in the InstanceData is expected to be |
|
|
    described by a unique index in `obj`, regardless of its actual
|
|
semantic class. It is not required for the object instances to be |
|
|
contiguous in `[0, obj_max]`, although enforcing it may have |
|
|
beneficial downstream effects on memory and I/O times. Finally, |
|
|
    when two InstanceData are batched in an InstanceBatch, the `obj`
|
|
indices will be updated to avoid collision between the batch items. |
|
|
|
|
|
:param pointers: torch.LongTensor |
|
|
Pointers to address the data in the associated value tensors. |
|
|
`values[Pointers[i]:Pointers[i+1]]` hold the values for the ith |
|
|
cluster. If `dense=True`, the `pointers` are actually the dense |
|
|
indices to be converted to pointer format. |
|
|
:param obj: torch.LongTensor |
|
|
Object index for each cluster-object pair. Assumes there are |
|
|
NO DUPLICATE CLUSTER-OBJECT pairs in the input data, unless |
|
|
'dense=True'. |
|
|
:param count: torch.LongTensor |
|
|
Number of points in the overlap for each cluster-object pair. |
|
|
:param y: torch.LongTensor |
|
|
        Semantic label of the object for each cluster-object pair. By
|
|
definition, we assume the objects to be SEMANTICALLY PURE. For |
|
|
that reason, we only store a single semantic label for objects, |
|
|
as opposed to superpoints, for which we want to maintain a |
|
|
histogram of labels. |
|
|
:param dense: bool |
|
|
If `dense=True`, the `pointers` are actually the dense indices |
|
|
to be converted to pointer format. Besides, any duplicate |
|
|
cluster-obj pairs will be merged and the corresponding `count` |
|
|
will be updated. |
|
|
:param kwargs: |
|
|
Other kwargs will be ignored. |
|
|
""" |
|
|
|
|
|
    # Names of the CSR value tensors, in the order they are stored in
    # `self.values`: see the `obj`, `count`, and `y` properties below
    __value_keys__ = ['obj', 'count', 'y']

    # NOTE(review): presumably tells CSRData serialization that no value
    # key needs index-aware treatment — confirm against CSRData
    __is_index_value_serialization_key__ = None
|
|
|
|
|
    def __init__(
            self,
            pointers: torch.Tensor,
            obj: torch.Tensor,
            count: torch.Tensor,
            y: torch.Tensor,
            dense: bool = False,
            **kwargs):
        # If the input is in dense format, the same (cluster, object)
        # pair may appear multiple times. These duplicates must be
        # merged before building the CSR structure
        if dense:
            # Build a unique id for each (cluster, object) pair. In
            # dense mode, `pointers` actually holds per-pair cluster
            # indices, not CSR pointers
            cluster_obj_idx = pointers * (obj.max() + 1) + obj

            # Compress pair ids to contiguous values. `perm` points to
            # one representative element per unique (cluster, object)
            # pair, used to deduplicate `pointers`, `obj` and `y`
            cluster_obj_idx, perm = consecutive_cluster(cluster_obj_idx)
            pointers = pointers[perm]
            obj = obj[perm]
            y = y[perm]

            # Sum the point counts of duplicate pairs into a single
            # count per unique pair
            count = scatter_sum(count, cluster_obj_idx)

        # `is_index_value=[True, False, False]` marks `obj` as an
        # index-like value so CSRData can offset it when batching
        super().__init__(
            pointers, obj, count, y, dense=dense,
            is_index_value=[True, False, False])
|
|
|
|
|
@classmethod |
|
|
def get_base_class(cls) -> type: |
|
|
"""Helps `self.from_list()` and `self.to_list()` identify which |
|
|
classes to use for batch collation and un-collation. |
|
|
""" |
|
|
return InstanceData |
|
|
|
|
|
@classmethod |
|
|
def get_batch_class(cls) -> type: |
|
|
"""Helps `self.from_list()` and `self.to_list()` identify which |
|
|
classes to use for batch collation and un-collation. |
|
|
""" |
|
|
return InstanceBatch |
|
|
|
|
|
@property |
|
|
def obj(self) -> torch.Tensor: |
|
|
return self.values[0] |
|
|
|
|
|
@obj.setter |
|
|
def obj(self, obj: torch.Tensor): |
|
|
assert obj.device == self.device, \ |
|
|
f"obj is on {obj.device} while self is on {self.device}" |
|
|
self.values[0] = obj |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
def count(self) -> torch.Tensor: |
|
|
return self.values[1] |
|
|
|
|
|
@count.setter |
|
|
def count(self, count: torch.Tensor): |
|
|
assert count.device == self.device, \ |
|
|
f"count is on {count.device} while self is on {self.device}" |
|
|
self.values[1] = count |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
def y(self) -> torch.Tensor: |
|
|
return self.values[2] |
|
|
|
|
|
@y.setter |
|
|
def y(self, y: torch.Tensor): |
|
|
assert y.device == self.device, \ |
|
|
f"y is on {y.device} while self is on {self.device}" |
|
|
self.values[2] = y |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
def num_clusters(self): |
|
|
return self.num_groups |
|
|
|
|
|
@property |
|
|
def num_overlaps(self): |
|
|
return self.num_items |
|
|
|
|
|
@property |
|
|
def num_obj(self): |
|
|
return self.obj.unique().numel() |
|
|
|
|
|
    def major(
            self,
            num_classes: int = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the obj, count, and y of the majority instance in each
        cluster (i.e. the object with which it has the highest overlap).

        :param num_classes: int
            Number of classes in the dataset. Specifying `num_classes`
            allows identifying 'void' labels. By convention, we assume
            `y ∈ [0, self.num_classes-1]` ARE ALL VALID LABELS (i.e. not
            'ignored', 'void', 'unknown', etc), while `y < 0` AND
            `y >= self.num_classes` ARE VOID LABELS. Void data is dealt
            with following https://arxiv.org/abs/1801.00868 and
            https://arxiv.org/abs/1905.01220

        :return obj, count, y
            Per-cluster index, overlap size, and semantic label of the
            majority object
        """
        # Default to the largest label + 1 when num_classes is not given
        num_classes = num_classes if num_classes else self.y.max() + 1

        # Cluster index of each cluster-object pair
        cluster_idx = self.indices

        # Pairs whose object carries a void label
        pair_is_void = (self.y < 0) | (self.y >= num_classes)

        # Two-column counts: all pairs, and pairs with void overlaps
        # zeroed-out. A single scatter_max then yields, per cluster,
        # both the overall majority and the non-void majority
        x = torch.stack((self.count, self.count * ~pair_is_void)).T
        res = scatter_max(x, cluster_idx, dim=0)
        count = res[0][:, 0]
        argmax = res[1][:, 0]
        obj = self.obj[argmax]
        y = self.y[argmax]

        # If no cluster has a void majority object, we are done
        is_major_void = (y < 0) | (y >= num_classes)
        if (~is_major_void).all():
            return obj, count, y

        # If every cluster with a void majority has that majority
        # covering strictly more than 50% of its points, keep the void
        # assignments (such clusters are handled by downstream
        # void-removal, see `search_void`)
        total_count = scatter_sum(self.count, cluster_idx, dim=0)
        major_50_plus = (count / total_count) > 0.5
        if major_50_plus[is_major_void].all():
            return obj, count, y

        # Otherwise, fall back to the non-void majority for the clusters
        # whose majority object is void
        count_no_void = res[0][:, 1]
        argmax_no_void = res[1][:, 1]
        count[is_major_void] = count_no_void[is_major_void]
        obj[is_major_void] = self.obj[argmax_no_void][is_major_void]
        y[is_major_void] = self.y[argmax_no_void][is_major_void]

        return obj, count, y
|
|
|
|
|
def merge( |
|
|
self, |
|
|
idx: Union[int, List[int], torch.Tensor, np.ndarray] |
|
|
) -> 'InstanceData': |
|
|
"""Merge clusters based on `idx` and return the result in a new |
|
|
InstanceData object. |
|
|
|
|
|
:param idx: 1D torch.LongTensor or numpy.NDArray |
|
|
Indices of the parent cluster each cluster should be merged |
|
|
into. Must have the same size as `self.num_clusters` and |
|
|
indices must start at 0 and be contiguous. |
|
|
""" |
|
|
|
|
|
|
|
|
idx = tensor_idx(idx) |
|
|
assert idx.shape == torch.Size([self.num_clusters]), \ |
|
|
f"Expected indices of shape {torch.Size([self.num_clusters])}, " \ |
|
|
f"but received shape {idx.shape} instead" |
|
|
assert is_dense(idx), f"Expected contiguous indices in [0, max]" |
|
|
|
|
|
|
|
|
merged_idx = idx[self.indices].long() |
|
|
|
|
|
|
|
|
|
|
|
return self.__class__( |
|
|
merged_idx, self.obj, self.count, self.y, dense=True) |
|
|
|
|
|
    def iou_and_size(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the Intersection over Union (IoU) and the individual
        size for each cluster-object pair in the data. This is typically
        needed for computing the Average Precision.

        :return iou, a_size, b_size
            iou: per-pair IoU between the cluster ('a') and object ('b')
            a_size: per-pair size of the cluster
            b_size: per-pair size of the object
        """
        # Index each pair by its cluster ('a') and by its object ('b').
        # Object indices are compressed to contiguous values so they can
        # be scattered over
        a_idx = self.indices
        b_idx = consecutive_cluster(self.obj)[0]

        # Total size of each cluster and object, distributed back to the
        # pairs
        a_size = scatter_sum(self.count, a_idx)[a_idx]
        b_size = scatter_sum(self.count, b_idx)[b_idx]

        # If some target points were cropped out by void-removal (see
        # `remove_void`), restore them in the object sizes so the IoU
        # denominators stay correct
        if getattr(self, 'pair_cropped_count', None) is not None:
            b_size += self.pair_cropped_count

        # IoU = |A ∩ B| / (|A| + |B| - |A ∩ B|)
        iou = self.count / (a_size + b_size - self.count)

        return iou, a_size, b_size
|
|
|
|
|
def estimate_centroid( |
|
|
self, |
|
|
cluster_pos: torch.Tensor, |
|
|
mode: str = 'iou' |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Estimate the centroid position of each object, based on the |
|
|
position of the clusters. |
|
|
|
|
|
Based on the hypothesis that clusters are relatively |
|
|
instance-pure, we can approximate the centroid of each object by |
|
|
taking the barycenter of the centroids of the clusters |
|
|
overlapping with each object, weighed down by their respective |
|
|
IoUs. |
|
|
|
|
|
NB: This is a proxy and one could design failure cases, when |
|
|
clusters are not pure enough. |
|
|
|
|
|
:param cluster_pos: Tensor of size [num_clusters, D] |
|
|
Centroid position of each cluster |
|
|
:param mode: str |
|
|
Method used to estimate the centroids. 'iou' will weigh down |
|
|
the centroids of the clusters overlapping each instance by |
|
|
their IoU. 'ratio-product' will use the product of the size |
|
|
ratios of the overlap wrt the cluster and wrt the instance. |
|
|
'overlap' will use the size of the overlap between the |
|
|
cluster and the instance. |
|
|
|
|
|
:return obj_pos, obj_idx |
|
|
obj_pos: Tensor |
|
|
Estimated position for each object |
|
|
obj_idx: Tensor |
|
|
Corresponding object indices |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
a_idx = self.indices |
|
|
b_idx, perm = consecutive_cluster(self.obj) |
|
|
obj_idx = self.obj[perm] |
|
|
|
|
|
|
|
|
a_pos = cluster_pos[a_idx] |
|
|
|
|
|
|
|
|
mode = mode.lower() |
|
|
if mode == 'iou': |
|
|
iou, _, _ = self.iou_and_size() |
|
|
w = iou |
|
|
elif mode == 'product-iou': |
|
|
_, a_size, b_size = self.iou_and_size() |
|
|
w = self.count**2 / (a_size * b_size) |
|
|
elif mode == 'overlap': |
|
|
w = self.count |
|
|
else: |
|
|
raise NotImplementedError |
|
|
w = w.view(-1, 1) |
|
|
|
|
|
|
|
|
|
|
|
a_wpos = torch.cat((a_pos * w, w), dim=1) |
|
|
res = scatter_sum(a_wpos, b_idx, dim=0) |
|
|
obj_pos = res[:, :-1] / res[:, -1].view(-1, 1) |
|
|
|
|
|
return obj_pos, obj_idx |
|
|
|
|
|
def instance_graph( |
|
|
self, |
|
|
edge_index: torch.Tensor, |
|
|
num_classes: int = None, |
|
|
smooth_affinity: bool = True |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Compute instance graph and per-edge affinity scores. |
|
|
|
|
|
:param edge_index: Tensor of size [2, num_edges] |
|
|
Edges connecting the clusters in of the instance graph. The |
|
|
output instance graph will be a trimmed version of this |
|
|
graph, where only (i, j) edges with (i < j) are preserved. |
|
|
:param num_classes: int |
|
|
Number of classes in the dataset. Specifying `num_classes` |
|
|
allows identifying 'void' labels. By convention, we assume |
|
|
`y ∈ [0, self.num_classes-1]` ARE ALL VALID LABELS (i.e. not |
|
|
'ignored', 'void', 'unknown', etc), while `y < 0` AND |
|
|
`y >= self.num_classes` ARE VOID LABELS. Void data is dealt |
|
|
with following https://arxiv.org/abs/1801.00868 and |
|
|
https://arxiv.org/abs/1905.01220 |
|
|
:param smooth_affinity: bool |
|
|
If True, the affinity score computed for each edge will |
|
|
follow the 'smooth' formulation: |
|
|
`(overlap_i_obj_j / size_i + overlap_j_obj_i / size_j) / 2` |
|
|
for the edge `(i, j)`, where `obj_i` designates the target |
|
|
instance of `i`. If False, the affinity will be computed |
|
|
with the simpler formulation: `obj_i == obj_j` |
|
|
|
|
|
:return obj_edge_index, obj_edge_affinity |
|
|
obj_edge_index: Tensor of size [2, num_trimmed_edges] |
|
|
Edges of the trimmed instance graph |
|
|
obj_edge_affinity: Tensor |
|
|
Affinity for each edge |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
obj_edge_index = to_trimmed(edge_index.to(self.device)) |
|
|
|
|
|
|
|
|
if obj_edge_index.numel() == 0: |
|
|
return obj_edge_index, torch.zeros(0, device=self.device) |
|
|
|
|
|
|
|
|
|
|
|
sp_obj_idx = self.major(num_classes=num_classes)[0] |
|
|
|
|
|
|
|
|
|
|
|
i_obj_idx = sp_obj_idx[obj_edge_index[0]] |
|
|
j_obj_idx = sp_obj_idx[obj_edge_index[1]] |
|
|
|
|
|
|
|
|
|
|
|
if not smooth_affinity: |
|
|
return obj_edge_index, (i_obj_idx == j_obj_idx).float() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
base = self.obj.max() + 1 |
|
|
A = self.indices * base + self.obj |
|
|
B = obj_edge_index[0] * base + j_obj_idx |
|
|
C = obj_edge_index[1] * base + i_obj_idx |
|
|
|
|
|
|
|
|
all_uid_raw = torch.cat((A, B, C)) |
|
|
uid, perm = consecutive_cluster(all_uid_raw) |
|
|
uid_raw = all_uid_raw[perm] |
|
|
num_uid = uid.max() + 1 |
|
|
A_uid = uid[:A.shape[0]] |
|
|
B_uid = uid[A.shape[0]:A.shape[0] + B.shape[0]] |
|
|
C_uid = uid[-C.shape[0]:] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
overlaps = torch.zeros(num_uid, device=self.device) |
|
|
overlaps[A_uid] = self.count.float() |
|
|
overlap_i_obj_j = overlaps[B_uid] |
|
|
overlap_j_obj_i = overlaps[C_uid] |
|
|
|
|
|
|
|
|
sp_size = scatter_sum(self.count, self.indices) |
|
|
size_i = sp_size[obj_edge_index[0]].float() |
|
|
size_j = sp_size[obj_edge_index[1]].float() |
|
|
|
|
|
|
|
|
affinity = (overlap_i_obj_j / size_i + overlap_j_obj_i / size_j) / 2 |
|
|
|
|
|
return obj_edge_index, affinity |
|
|
|
|
|
    def search_void(
            self,
            num_classes: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Search for clusters and objects with 'void' semantic labels.

        IMPORTANT:
            By convention, we assume `y ∈ [0, num_classes-1]` ARE ALL
            VALID LABELS (i.e. not 'void', 'ignored', 'unknown', etc),
            while `y < 0` AND `y >= num_classes` ARE VOID LABELS.
            This applies to both `Data.y` and `Data.obj.y`.

        Points with 'void' labels are handled following the procedure
        proposed in:
          - https://arxiv.org/abs/1801.00868
          - https://arxiv.org/abs/1905.01220

        More precisely, we remove from IoU and metrics computation:
          - predictions (i.e. clusters here) containing more than 50% of
            'void' points
          - targets (i.e. objects here) containing more than 50% of
            'void' points. In our case, we assume targets to be
            SEMANTICALLY PURE, so we remove a target even if it contains
            a single 'void' point

        To this end, the present function returns:
          - `cluster_mask`: boolean mask of the clusters containing more
            than 50% points with `void` labels
          - `pair_mask`: boolean mask of the cluster-object pairs whose
            object (i.e. target) has an `void` label
          - `pair_cropped_count`: tensor of cropped target size, for
            each pair. Indeed, blindly removing the predictions with 50%
            or more void points will affect downstream IoU computation.
            To account for this, `pair_cropped_count` is intended
            to be used at IoU computation time, when assessing the
            prediction and target sizes

        NB: by construction, removing pairs in `pair_mask` from the
            InstanceData will also remove all target objects containing
            'void' points. Importantly, this assumes, however, that the
            raw instance annotations in the datasets are semantically
            pure: all annotated instances contain points of the same
            class. Said otherwise: IF AN INSTANCE CONTAINS A SINGLE
            'VOID' POINT, THEN ALL OF ITS POINTS ARE 'VOID'.
        """
        # Pairs whose object ('b') carries a void label
        is_pair_b_void = (self.y < 0) | (self.y >= num_classes)

        # Cluster ('a') index of each pair
        pair_a_idx = self.indices

        # Total size of each cluster
        a_size = scatter_sum(self.count, pair_a_idx)

        # Clusters overlapping with at least one void object
        void_a_idx = pair_a_idx[is_pair_b_void].unique()

        # Among those, find the clusters whose void points account for
        # strictly more than 50% of their total size
        void_a_total_size = a_size[void_a_idx]
        void_a_void_size = scatter_sum(
            self.count[is_pair_b_void], pair_a_idx[is_pair_b_void])[void_a_idx]
        void_a_50_plus = (void_a_void_size / void_a_total_size.float()) > 0.5
        void_a_50_plus_idx = void_a_idx[void_a_50_plus]

        # Boolean mask over all clusters: True for the 50%+ void ones
        is_a_void = torch.zeros(
            self.num_clusters, dtype=torch.bool, device=self.device)
        is_a_void[void_a_50_plus_idx] = True

        # For each object, count the points it has inside void clusters.
        # These are the points 'cropped' away when void clusters are
        # discarded, to be restored in target sizes at IoU time
        b_idx = consecutive_cluster(self.obj)[0]
        pair_cropped_count = scatter_sum(
            self.count * is_a_void[pair_a_idx], b_idx)[b_idx]

        # A pair is void if its object is void or its cluster is void
        is_pair_void = is_pair_b_void | is_a_void[pair_a_idx]

        return is_a_void, is_pair_void, pair_cropped_count
|
|
|
|
|
    def remove_void(
            self,
            num_classes: int
    ) -> Tuple['InstanceData', torch.Tensor]:
        """Return a new InstanceData with void clusters, objects and
        pairs removed.

        IMPORTANT:
            By convention, we assume `y ∈ [0, num_classes-1]` ARE ALL
            VALID LABELS (i.e. not 'void', 'ignored', 'unknown', etc),
            while `y < 0` AND `y >= num_classes` ARE VOID LABELS.
            This applies to both `Data.y` and `Data.obj.y`.

        Points with 'void' labels are handled following the procedure
        proposed in:
          - https://arxiv.org/abs/1801.00868
          - https://arxiv.org/abs/1905.01220

        More precisely:
          - predictions (i.e. clusters here) containing more than 50% of
            'void' points are removed from the metrics computation
          - targets (i.e. objects here) containing more than 50% of
            'void' points are removed from the metrics computation
          - the remaining 'void' points are ignored when computing the
            prediction-target (i.e. cluster-object here) IoUs

        To this end, the present function returns:
          - `instance_data`: a new InstanceData object with all void
            clusters, objects, and pairs removed
          - `non_void_mask`: boolean mask spanning the clusters,
            indicating the clusters that were preserved in the
            `instance_data`. This mask can be used outside of this
            function to subsample cluster-wise information after
            void-removal

        NB: by construction, removing pairs in `pair_mask` from the
            InstanceData will also remove all target objects containing
            'void' points. Importantly, this assumes, however, that the
            raw instance annotations in the datasets are semantically
            pure: all annotated instances contain points of the same
            class. Said otherwise: IF AN INSTANCE CONTAINS A SINGLE
            'VOID' POINT, THEN ALL OF ITS POINTS ARE 'VOID'.
        """
        # Identify void clusters, void pairs, and the per-pair cropped
        # target sizes
        is_cluster_void, is_pair_void, pair_cropped_count = \
            self.search_void(num_classes)

        # Drop all void pairs and re-index the remaining clusters to be
        # contiguous before rebuilding the InstanceData
        idx = self.indices
        idx = idx[~is_pair_void]
        idx = consecutive_cluster(idx)[0]
        obj = self.obj[~is_pair_void]
        count = self.count[~is_pair_void]
        y = self.y[~is_pair_void]
        pair_cropped_count = pair_cropped_count[~is_pair_void]
        instance_data = self.__class__(idx, obj, count, y, dense=True)

        # Attach the cropped counts so `iou_and_size` can restore them
        # in the target sizes when computing IoUs
        instance_data.pair_cropped_count = pair_cropped_count

        return instance_data, ~is_cluster_void
|
|
|
|
|
def debug(self): |
|
|
super().debug() |
|
|
|
|
|
|
|
|
cluster_obj_idx = self.indices * (self.obj.max() + 1) + self.obj |
|
|
assert not has_duplicates(cluster_obj_idx) |
|
|
|
|
|
def __repr__(self): |
|
|
info = [ |
|
|
f"{key}={getattr(self, key)}" |
|
|
for key in ['num_clusters', 'num_overlaps', 'num_obj', 'device']] |
|
|
return f"{self.__class__.__name__}({', '.join(info)})" |
|
|
|
|
|
def target_label_histogram(self, num_classes: int) -> torch.Tensor: |
|
|
"""Compute the target histogram for semantic segmentation. That |
|
|
is, for each cluster, the histogram of pointwise labels of its |
|
|
overlaps. When joined with cluster-wise semantic predictions, |
|
|
this histogram can be passed to a ConfusionMatrix metric. |
|
|
|
|
|
:param num_classes: int |
|
|
Number of valid classes. By convention, we assume |
|
|
`y ∈ [0, num_classes-1]` are VALID LABELS, while |
|
|
`y < 0` AND `y >= num_classes` ARE VOID LABELS |
|
|
|
|
|
:return: Tensor of shape [num_clusters, num_classes + 1] |
|
|
""" |
|
|
|
|
|
y = self.y.clone() |
|
|
y[(y < 0) | (y > num_classes)] = num_classes |
|
|
|
|
|
|
|
|
y_hist = one_hot(y, num_classes=num_classes + 1) * self.count.view(-1, 1) |
|
|
return scatter_sum(y_hist, self.indices, dim=0) |
|
|
|
|
|
    def semantic_segmentation_oracle(
            self,
            num_classes: int,
            *metric_args,
            **metric_kwargs
    ) -> 'SemanticMetricResults':
        """Compute the oracle performance for semantic segmentation,
        when all clusters predict the dominant label among their points.
        This corresponds to the highest achievable performance with the
        partition at hand.

        :param num_classes: int
            Number of valid classes. By convention, we assume
            `y ∈ [0, num_classes-1]` are VALID LABELS, while
            `y < 0` AND `y >= num_classes` ARE VOID LABELS
        :param metric_args:
            Args for the metrics computation
        :param metric_kwargs:
            Kwargs for the metrics computation

        :return: SemanticMetricResults
        """
        # Per-cluster histogram of target labels, with the void labels
        # gathered in the last column
        y_hist = self.target_label_histogram(num_classes)

        # Oracle prediction: the dominant VALID label of each cluster.
        # The last (void) column is excluded from the argmax
        pred = y_hist[:, :num_classes].argmax(dim=1)
        target = y_hist

        # Local import to avoid circular imports
        from src.metrics import ConfusionMatrix
        cm = ConfusionMatrix(num_classes, *metric_args, **metric_kwargs)
        cm(pred.cpu(), target.cpu())
        metrics = cm.all_metrics()

        return metrics
|
|
|
|
|
    def oracle(
            self,
            num_classes: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the oracle predictions for instance and panoptic
        segmentation. This is a proxy for the highest achievable
        performance with the cluster partition at hand. The output data
        can be passed to the relevant metrics in `src.metrics` for
        performance computation.

        More precisely, for the oracle prediction:
          - each cluster is assigned to the instance it shares the most
            points with
          - clusters assigned to the same instance are merged into a
            single prediction
          - each predicted instance has a score equal to its IoU with
            the assigned target instance

        :param num_classes: int
            Number of valid classes. By convention, we assume
            `y ∈ [0, num_classes-1]` are VALID LABELS, while
            `y < 0` AND `y >= num_classes` ARE VOID LABELS
        :return: oracle_scores, oracle_y, oracle_instance_data
        """
        # Assign each cluster to its majority object and compress the
        # assigned object indices to contiguous values
        obj, count, y = self.major(num_classes=num_classes)
        idx, perm = consecutive_cluster(obj)

        # Merge the clusters assigned to the same object into a single
        # predicted instance
        oracle = self.merge(idx)

        # Semantic label of each predicted instance, taken from its
        # assigned target object
        oracle_y = y[perm]

        # Score of each predicted instance: its IoU with the target
        # object it overlaps the most
        iou = oracle.iou_and_size()[0]
        argmax = scatter_max(oracle.count, oracle.indices)[1]
        oracle_scores = iou[argmax]

        return oracle_scores, oracle_y, oracle
|
|
|
|
|
def instance_segmentation_oracle( |
|
|
self, |
|
|
num_classes: int, |
|
|
**metric_kwargs |
|
|
) -> 'InstanceMetricResults': |
|
|
"""Compute the oracle performance for instance segmentation. |
|
|
This is a proxy for the highest achievable performance with the |
|
|
cluster partition at hand. |
|
|
|
|
|
More precisely, for the oracle prediction: |
|
|
- each cluster is assigned to the instance it shares the most |
|
|
points with |
|
|
- clusters assigned to the same instance are merged into a |
|
|
single prediction |
|
|
- each predicted instance has a score equal to its IoU with |
|
|
the assigned target instance |
|
|
|
|
|
:param num_classes: int |
|
|
Number of valid classes. By convention, we assume |
|
|
`y ∈ [0, num_classes-1]` are VALID LABELS, while |
|
|
`y < 0` AND `y >= num_classes` ARE VOID LABELS |
|
|
:param metric_kwargs: |
|
|
Kwargs for the metrics computation |
|
|
|
|
|
:return: InstanceMetricResults |
|
|
""" |
|
|
|
|
|
oracle_scores, oracle_y, oracle = self.oracle(num_classes) |
|
|
|
|
|
|
|
|
from src.metrics import MeanAveragePrecision3D |
|
|
metric = MeanAveragePrecision3D(num_classes, **metric_kwargs) |
|
|
metric.update(oracle_scores, oracle_y, oracle) |
|
|
results = metric.compute() |
|
|
|
|
|
return results |
|
|
|
|
|
def panoptic_segmentation_oracle( |
|
|
self, |
|
|
num_classes: int, |
|
|
**metric_kwargs |
|
|
) -> 'PanopticMetricResults': |
|
|
"""Compute the oracle performance for panoptic segmentation. |
|
|
This is a proxy for the highest achievable performance with the |
|
|
cluster partition at hand. |
|
|
|
|
|
More precisely, for the oracle prediction: |
|
|
- each cluster is assigned to the instance it shares the most |
|
|
points with |
|
|
- clusters assigned to the same instance are merged into a |
|
|
single prediction |
|
|
|
|
|
:param num_classes: int |
|
|
Number of valid classes. By convention, we assume |
|
|
`y ∈ [0, num_classes-1]` are VALID LABELS, while |
|
|
`y < 0` AND `y >= num_classes` ARE VOID LABELS |
|
|
:param metric_kwargs: |
|
|
Kwargs for the metrics computation |
|
|
|
|
|
:return: PanopticMetricResults |
|
|
""" |
|
|
|
|
|
oracle_scores, oracle_y, oracle = self.oracle(num_classes) |
|
|
|
|
|
|
|
|
from src.metrics import PanopticQuality3D |
|
|
metric = PanopticQuality3D(num_classes, **metric_kwargs) |
|
|
metric.update(oracle_y, oracle) |
|
|
results = metric.compute() |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
class InstanceBatch(InstanceData, CSRBatch):
    """Wrapper for InstanceData batching. Importantly, the instance
    labels in `obj` will be updated to avoid collisions between the
    different batch items.
    """
|
|
|