|
|
"""Conversions between point-level labels and segment-level label
histograms.
"""

import torch
from torch_scatter import scatter_add


# Public API of this module
__all__ = ['histogram_to_atomic', 'atomic_to_histogram']
|
|
|
|
|
|
|
|
def histogram_to_atomic(gt, pred):
    """Convert ground truth and predictions at a segment level (i.e.
    ground truth is 2D tensor carrying histogram of labels in each
    segment), to pointwise 1D ground truth and predictions.

    Each segment is unrolled into one entry per counted point: the
    pointwise ground truth holds the class index (i.e. histogram column)
    of each counted point, and the prediction of the segment is repeated
    once for each of its points.

    :param gt: 1D or 2D torch.Tensor. A 1D tensor, or a 2D tensor with
        a single column, is considered to already carry one label per
        element and is returned as a 1D tensor without any expansion.
        Otherwise, `gt` is read as a histogram of non-negative integer
        counts with one row per segment and one column per class
    :param pred: 1D or 2D torch.Tensor whose first dimension has one
        entry per row of `gt`
    :return: tuple `(point_gt, point_pred)`
    """
    assert gt.dim() <= 2

    # Edge cases where nothing needs to be expanded
    if gt.dim() == 1:
        return gt, pred
    if gt.shape[1] == 1:
        return gt.squeeze(1), pred

    # Initialization
    num_nodes, num_classes = gt.shape
    device = pred.device

    # torch.repeat_interleave only accepts int32/int64 repeats, so
    # integer histograms stored in another dtype (e.g. the uint8
    # histograms atomic_to_histogram returns for uint8 labels) are
    # promoted to long. Floating point histograms are deliberately left
    # untouched, so that torch still rejects them instead of having
    # their values silently truncated
    counts = gt
    if not (counts.is_floating_point() or counts.is_complex()):
        counts = counts.long()

    # Flatten the histograms: the class indices [0, ..., C-1] are tiled
    # once per segment, then each index is repeated as many times as
    # the corresponding histogram bin counts
    class_ids = torch.arange(num_classes, device=device).repeat(num_nodes)
    point_gt = class_ids.repeat_interleave(counts.flatten())

    # Expand the predictions: each segment's prediction is repeated
    # once per point counted in the segment
    point_pred = pred.repeat_interleave(counts.sum(dim=1), dim=0)

    return point_gt, point_pred
|
|
|
|
|
|
|
|
def atomic_to_histogram(item, idx, n_bins=None):
    """Convert point-level non-negative integer data to histograms of
    segment-level labels, based on idx.

    :param item: 1D or 2D torch.Tensor of dtype uint8, int or long,
        holding non-negative values. A 1D tensor is read as one label
        per point. A 2D tensor is considered to already be a histogram
        for each point and its rows are simply summed per segment
    :param idx: 1D torch.Tensor used as the `scatter_add` index along
        the first dimension of `item`, i.e. the index of the segment
        each point is accumulated into
    :param n_bins: optional number of bins (columns) of the output
        histogram for 1D `item`. Defaults to `item.max() + 1`. Must be
        at least `item.max() + 1`. Ignored for 2D `item`
    :return: 2D torch.Tensor with one row per segment. For 1D `item`,
        it has `n_bins` columns and the dtype of `item`. For 2D `item`,
        it is the per-segment sum of the long-cast input rows (the
        input dtype is not restored)
    """
    assert item.ge(0).all(), \
        "Histograms only support non-negative integers"
    assert item.dtype in [torch.uint8, torch.int, torch.long], \
        "Histograms only support non-negative integers"
    assert item.ndim <= 2, \
        "Voting and histograms are only supported for 1D and " \
        "2D tensors"

    # Temporarily convert input item to long
    in_dtype = item.dtype
    item = item.long()

    # Important: if input item is 2D, it is considered to already be a
    # histogram, so the rows are simply accumulated per segment. The
    # number of bins is not needed here
    if item.ndim == 2:
        return scatter_add(item, idx, dim=0)

    # The default number of bins is computed on the long-cast item.
    # Computing it in the input dtype would wrap around for narrow
    # dtypes (e.g. uint8 labels containing 255 would yield 0 bins)
    n_bins = int(item.max()) + 1 if n_bins is None else int(n_bins)

    # Convert values to one-hot encoding. Values are temporarily offset
    # so that the smallest one maps to 0, which avoids materializing
    # one-hot columns for the unobserved leading bins
    offset = int(item.min())
    one_hot = torch.nn.functional.one_hot(item - offset)

    # Count the number of occurrences of each value in each segment
    hist = scatter_add(one_hot, idx, dim=0)
    N = hist.shape[0]
    device = hist.device

    # Prepend 0 columns to the histogram for the bins removed by the
    # offsetting
    bins_before = torch.zeros(N, offset, device=device, dtype=torch.long)
    hist = torch.cat((bins_before, hist), dim=1)

    # At this point the histogram has `item.max() + 1` columns. Fail
    # with an explicit message rather than an obscure negative-size
    # error if the requested number of bins cannot hold all the values
    n_missing = n_bins - hist.shape[1]
    assert n_missing >= 0, \
        f"n_bins={n_bins} is too small to hold the values in item, " \
        f"expected at least {hist.shape[1]}"

    # Append 0 columns to the histogram for unobserved trailing bins
    bins_after = torch.zeros(
        N, n_missing, device=device, dtype=torch.long)
    hist = torch.cat((hist, bins_after), dim=1)

    # Restore input dtype.
    # NOTE: counts are accumulated in long but cast back to the input
    # dtype, so counts exceeding the range of a narrow input dtype (e.g.
    # above 255 for uint8) do not survive this conversion
    hist = hist.to(in_dtype)

    return hist
|
|
|