|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
from .base_data_element import BaseDataElement |
|
|
|
|
|
|
|
|
class LabelData(BaseDataElement):
    """Data structure for label-level annotations or predictions.

    Besides what it inherits from ``BaseDataElement``, the class offers two
    static helpers that convert between a 1-D one-hot (or multi-hot) tensor
    and a tensor of class indices.
    """

    @staticmethod
    def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
        """Convert the one-hot input to label.

        Args:
            onehot (torch.Tensor): The one-hot input. It must be a 1-D
                tensor whose values all lie in ``[0, 1]``. Every non-zero
                entry is reported as a label.

        Returns:
            torch.Tensor: A 1-D tensor holding the indices of the non-zero
            entries of ``onehot``. It is empty when ``onehot`` is empty or
            all zeros.

        Raises:
            ValueError: If ``onehot`` is not 1-D or contains a value outside
                ``[0, 1]`` (NaN included).
        """
        assert isinstance(onehot, torch.Tensor)
        if onehot.ndim == 1:
            # ``max()``/``min()`` raise on an empty tensor, so an empty 1-D
            # input is accepted as the one-hot vector that has no classes.
            # The comparisons are written so that NaN fails the check.
            if onehot.numel() == 0 or (onehot.max().item() <= 1
                                       and onehot.min().item() >= 0):
                return onehot.nonzero().squeeze(-1)
        raise ValueError(
            'input is not one-hot and can not convert to label')

    @staticmethod
    def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Convert the label-format input to one-hot.

        Args:
            label (torch.Tensor): The label-format input, i.e. a tensor of
                class indices in ``[0, num_classes)``. A 0-d tensor holding
                a single index is accepted as well.
            num_classes (int): The number of classes.

        Returns:
            torch.Tensor: A tensor of shape ``(num_classes, )`` with the
            same dtype and device as ``label``, set to 1 at every index
            listed in ``label`` and 0 elsewhere.

        Raises:
            AssertionError: If ``label`` is not a tensor, or if any index is
                negative or not smaller than ``num_classes``.
        """
        assert isinstance(label, torch.Tensor)
        onehot = label.new_zeros((num_classes, ))
        if label.numel() > 0:
            # Tensor reductions avoid the element-by-element Python
            # iteration of the builtin ``max`` and also work on 0-d tensors.
            lowest = label.min().item()
            highest = label.max().item()
        else:
            # An empty label is treated as if its largest index were 0.
            lowest = highest = 0
        # Negative indices would silently wrap around and mark the wrong
        # class, so they are rejected together with too-large ones.
        assert lowest >= 0 and highest < num_classes, (
            f'label values must be in [0, {num_classes}), '
            f'but got values in [{lowest}, {highest}]')
        onehot[label] = 1
        return onehot
|
|
|