English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
import torch
from torch_scatter import scatter_add
__all__ = ['histogram_to_atomic', 'atomic_to_histogram']
def histogram_to_atomic(gt, pred):
"""Convert ground truth and predictions at a segment level (i.e.
ground truth is 2D tensor carrying histogram of labels in each
segment), to pointwise 1D ground truth and predictions.
:param gt: 1D or 2D torch.Tensor
:param pred: 1D or 2D torch.Tensor
"""
assert gt.dim() <= 2
# Edge cases where nothing happens
if gt.dim() == 1:
return gt, pred
if gt.shape[1] == 1:
return gt.squeeze(1), pred
# Initialization
num_nodes, num_classes = gt.shape
device = pred.device
# Flatten the pointwise ground truth
point_gt = torch.arange(
num_classes, device=device).repeat(num_nodes).repeat_interleave(
gt.flatten())
# Expand the pointwise ground truth
point_pred = pred.repeat_interleave(gt.sum(dim=1), dim=0)
return point_gt, point_pred
def atomic_to_histogram(item, idx, n_bins=None):
"""Convert point-level positive integer data to histograms of
segment-level labels, based on idx.
:param item: 1D or 2D torch.Tensor
:param idx: 1D torch.Tensor
"""
assert item.ge(0).all(), \
"Mean aggregation only supports positive integers"
assert item.dtype in [torch.uint8, torch.int, torch.long], \
"Mean aggregation only supports positive integers"
assert item.ndim <= 2, \
"Voting and histograms are only supported for 1D and " \
"2D tensors"
# Initialization
n_bins = item.max() + 1 if n_bins is None else n_bins
# Temporarily convert input item to long
in_dtype = item.dtype
item = item.long()
# Important: if values are already 2D, we consider them to
# be histograms and will simply scatter_add them
if item.ndim == 2:
return scatter_add(item, idx, dim=0)
# Convert values to one-hot encoding. Values are temporarily offset
# to 0 to save some memory and compute in one-hot encoding and
# scatter_add
offset = item.min()
item = torch.nn.functional.one_hot(item - offset)
# Count number of occurrence of each value
hist = scatter_add(item, idx, dim=0)
N = hist.shape[0]
device = hist.device
# Prepend 0 columns to the histogram for bins removed due to
# offsetting
bins_before = torch.zeros(
N, offset, device=device, dtype=torch.long)
hist = torch.cat((bins_before, hist), dim=1)
# Append columns to the histogram for unobserved classes/bins
bins_after = torch.zeros(
N, n_bins - hist.shape[1], device=device,
dtype=torch.long)
hist = torch.cat((hist, bins_after), dim=1)
# Restore input dtype
hist = hist.to(in_dtype)
return hist