| | import torch |
| | from torch_scatter import scatter_add |
| |
|
| |
|
# Public API of this module: conversions between segment-level label
# histograms and pointwise labels.
__all__ = [
    "histogram_to_atomic",
    "atomic_to_histogram",
]
| |
|
| |
|
def histogram_to_atomic(gt, pred):
    """Convert ground truth and predictions at a segment level (i.e.
    ground truth is a 2D tensor carrying the histogram of labels in
    each segment) to pointwise 1D ground truth and predictions.

    Each segment `i` with histogram row `gt[i]` is expanded into
    `gt[i].sum()` points: `gt[i, c]` points labeled `c`, each paired
    with a copy of the segment prediction `pred[i]`.

    If `gt` is already 1D, or 2D with a single column, it is treated as
    pointwise labels and returned (squeezed to 1D) with `pred`
    untouched.

    :param gt: 1D or 2D integer torch.Tensor. When 2D with more than
        one column, `gt[i, c]` is the number of points of class `c` in
        segment `i`. Any integer dtype (including uint8) is accepted
    :param pred: 1D or 2D torch.Tensor whose first dimension indexes
        the same segments as the rows of `gt`
    :return: tuple `(point_gt, point_pred)` of pointwise labels and
        pointwise predictions
    """
    assert gt.dim() <= 2

    # Already pointwise labels: nothing to expand
    if gt.dim() == 1:
        return gt, pred
    if gt.shape[1] == 1:
        return gt.squeeze(1), pred

    num_nodes, num_classes = gt.shape
    device = pred.device

    # repeat_interleave only accepts int32/int64 repeats, while
    # histograms may be stored in a smaller dtype such as uint8 (as
    # produced by `atomic_to_histogram` for uint8 inputs). Cast the
    # counts to int64 and place them next to the tensors they repeat
    counts = gt.long().to(device)

    # Label of every point: the class indices are tiled once per
    # segment, then each one is repeated as many times as it is
    # counted in that segment's histogram
    point_gt = torch.arange(
        num_classes, device=device).repeat(num_nodes).repeat_interleave(
        counts.flatten())

    # Prediction of every point: a copy of its segment's prediction,
    # repeated once per point in the segment
    point_pred = pred.repeat_interleave(counts.sum(dim=1), dim=0)

    return point_gt, point_pred
| |
|
| |
|
def atomic_to_histogram(item, idx, n_bins=None):
    """Convert point-level non-negative integer data to histograms of
    segment-level labels, based on idx.

    If `item` is 1D, it holds one integer value per point. The output
    holds, for every segment `s` and bin `c`, the number of points
    with `idx == s` and `item == c`.

    If `item` is 2D, each row is considered an already-computed
    histogram. Rows sharing the same `idx` are summed with
    `scatter_add`. In this case `n_bins` is ignored and the result is
    returned as int64.

    :param item: 1D or 2D torch.Tensor of uint8, int32 or int64
        non-negative values
    :param idx: 1D integer torch.Tensor holding, for each row of
        `item`, the index of the segment it belongs to
    :param n_bins: int, optional. Number of histogram bins (columns)
        for 1D `item`. Defaults to `item.max() + 1` and must be at
        least that value
    :return: torch.Tensor of shape `(idx.max() + 1, n_bins)` for 1D
        `item`. The input dtype is restored when every count fits in
        it; otherwise the int64 counts are returned to avoid silent
        wrap-around (e.g. more than 255 points in one bin for uint8)
    :raises ValueError: if `n_bins` is smaller than `item.max() + 1`
    """
    assert item.ge(0).all(), \
        "Histogram aggregation only supports non-negative integers"
    assert item.dtype in [torch.uint8, torch.int, torch.long], \
        "Histogram aggregation only supports uint8, int32 and int64 " \
        "tensors"
    assert item.ndim <= 2, \
        "Voting and histograms are only supported for 1D and " \
        "2D tensors"

    # Temporarily convert input item to long
    in_dtype = item.dtype
    item = item.long()

    # Values that are already 2D are considered to be histograms: the
    # rows of each segment are simply summed
    if item.ndim == 2:
        return scatter_add(item, idx, dim=0)

    # The largest observed value dictates the minimum number of bins.
    # Empty inputs have no values, hence need no bins
    max_value = int(item.max()) if item.numel() > 0 else -1
    n_bins = max_value + 1 if n_bins is None else int(n_bins)
    if n_bins <= max_value:
        # Must be checked explicitly: with the flat encoding below, an
        # undersized n_bins would silently spill counts into the next
        # segment instead of failing
        raise ValueError(
            f"n_bins={n_bins} is too small for values up to {max_value}")

    # Number of segments, matching scatter_add's default output size
    idx = idx.long()
    num_segments = int(idx.max()) + 1 if idx.numel() > 0 else 0

    # Count (segment, value) pairs by encoding each pair as a single
    # flat index `segment * n_bins + value`. This avoids materializing
    # a (num_points, n_bins) one-hot matrix. Since every value is below
    # n_bins, the largest flat index is below num_segments * n_bins, so
    # bincount returns exactly `minlength` entries
    flat = idx * n_bins + item
    hist = torch.bincount(flat, minlength=num_segments * n_bins)
    hist = hist.view(num_segments, n_bins)

    # Restore the input dtype, unless some count would overflow it
    if hist.numel() == 0 or int(hist.max()) <= torch.iinfo(in_dtype).max:
        hist = hist.to(in_dtype)

    return hist
| |
|