Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ding.torch_utils.network import one_hot | |
class MultiLogitsLoss(nn.Module):
    """
    Overview:
        Loss for a set of predicted logits whose correspondence to a set of ground-truth labels is unknown.
        A cost matrix between every prediction and every label is built, the one-to-one assignment of
        predictions to labels with minimum total cost is found with the Kuhn-Munkres (Hungarian) algorithm,
        and the mean cost of the matched pairs is returned.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None:
        """
        Overview:
            Initialization method, use cross_entropy as default criterion.
        Arguments:
            - criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce'].
                ``None`` selects 'cross_entropy'.
            - smooth_ratio (:obj:`float`): Probability mass moved from the true class to the other classes
                when ``criterion`` is 'label_smooth_ce'; ignored otherwise.
        """
        super(MultiLogitsLoss, self).__init__()
        if criterion is None:
            criterion = 'cross_entropy'
        assert (criterion in ['cross_entropy', 'label_smooth_ce'])
        self.criterion = criterion
        if self.criterion == 'label_smooth_ce':
            self.ratio = smooth_ratio

    def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Turn class indices into per-class target weights according to the criterion.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits of shape (M, N); only their shape, dtype and
                device are used here.
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices of shape (M, ).
        Returns:
            - ret (:obj:`torch.Tensor`): Target weights of shape (M, N). For 'cross_entropy' this is the
                one-hot encoding; for 'label_smooth_ce' the true class gets ``1 - ratio`` and every other
                class gets ``ratio / (N - 1)``, so each row sums to 1.
        """
        N = logits.shape[1]
        if self.criterion == 'cross_entropy':
            return one_hot(labels, num=N)
        elif self.criterion == 'label_smooth_ce':
            if N == 1:
                # A single class leaves nowhere to move probability mass to
                # (and ratio / (N - 1) would divide by zero).
                return torch.ones_like(logits)
            ratio = float(self.ratio)
            ret = torch.full_like(logits, ratio / (N - 1))
            # The true class keeps 1 - ratio (not 1 - ratio / (N - 1)) so that every
            # row is a proper probability distribution.
            ret.scatter_(1, labels.unsqueeze(1), 1.0 - ratio)
            return ret

    def _nll_loss(self, nlls: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Cross entropy between log-probabilities and target weights.
        Arguments:
            - nlls (:obj:`torch.Tensor`): Log-probabilities, class dimension last.
            - labels (:obj:`torch.Tensor`): Target weights, broadcastable against ``nlls``.
        Returns:
            - ret (:obj:`torch.Tensor`): ``-sum(nlls * labels)`` reduced over the last (class) dimension.
        """
        # Targets are treated as constants: gradients only flow into the predictions.
        ret = -nlls * labels.detach()
        return ret.sum(dim=-1)

    def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Score every prediction against every label.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits of shape (M, N).
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices of shape (M, ).
        Returns:
            - metric (:obj:`torch.Tensor`): Cost matrix of shape (M, M); ``metric[i, j]`` is the loss of
                prediction ``i`` evaluated against label ``j``.
        """
        targets = self._label_process(logits, labels)
        log_probs = F.log_softmax(logits, dim=1)
        # (M, 1, N) against (1, M, N) -> (M, M, N), reduced over classes -> (M, M).
        return self._nll_loss(log_probs.unsqueeze(1), targets.unsqueeze(0))

    def _match(self, matrix: torch.Tensor) -> np.ndarray:
        """
        Overview:
            Find the minimum-cost one-to-one assignment between the rows (predictions) and the columns
            (labels) of a square cost matrix with the Kuhn-Munkres (Hungarian) algorithm.
        Arguments:
            - matrix (:obj:`torch.Tensor`): Square cost matrix of shape (M, M).
        Returns:
            - index (:obj:`np.ndarray`): int32 array of length M; ``index[j]`` is the row matched to column ``j``.
        Raises:
            - RuntimeError: If the label update finds no candidate edge (should not happen for finite costs).
        """
        # Work on a detached float64 CPU copy: the tight-edge test below uses an absolute
        # tolerance, which float32 round-off can exceed for costs of large magnitude.
        # Costs are negated because the labelling formulation maximizes total weight.
        mat = -matrix.detach().to(device='cpu', dtype=torch.float64).numpy()
        M = mat.shape[0]
        index = np.full(M, -1, dtype=np.int32)  # -1 marks a column that is not matched yet
        lx = mat.max(axis=1)  # row labels, initially each row's best weight (a feasible labelling)
        ly = np.zeros(M, dtype=np.float64)  # column labels
        visx = np.zeros(M, dtype=np.bool_)
        visy = np.zeros(M, dtype=np.bool_)

        def has_augmented_path(t, binary_distance_matrix):
            # Depth-first search for an augmenting path from row ``t`` along tight edges.
            # On success ``index`` is flipped along the path; visited rows/columns are
            # recorded in ``visx``/``visy`` for the label update below.
            visx[t] = True
            for i in range(M):
                if not visy[i] and binary_distance_matrix[t, i]:
                    visy[i] = True
                    if index[i] == -1 or has_augmented_path(index[i], binary_distance_matrix):
                        index[i] = t
                        return True
            return False

        for i in range(M):
            while True:
                visx.fill(False)
                visy.fill(False)
                distance_matrix = self._get_distance_matrix(lx, ly, mat, M)
                binary_distance_matrix = np.abs(distance_matrix) < 1e-4
                if has_augmented_path(i, binary_distance_matrix):
                    break
                # No augmenting path yet: shift the labels by the smallest slack between a
                # visited row and an unvisited column so that a new tight edge appears.
                masked_distance_matrix = distance_matrix[:, ~visy][visx]
                if 0 in masked_distance_matrix.shape:  # empty matrix
                    raise RuntimeError("match error, matrix: {}".format(matrix))
                d = masked_distance_matrix.min()
                lx[visx] -= d
                ly[visy] += d
        return index

    @staticmethod
    def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray:
        """
        Overview:
            Slack matrix of the current labelling, ``lx[i] + ly[j] - mat[i, j]``; zero entries are tight edges.
        Arguments:
            - lx (:obj:`np.ndarray`): Row labels of shape (M, ).
            - ly (:obj:`np.ndarray`): Column labels of shape (M, ).
            - mat (:obj:`np.ndarray`): Weight matrix of shape (M, M).
            - M (:obj:`int`): Matrix size.
        Returns:
            - slack (:obj:`np.ndarray`): Slack matrix of shape (M, M).
        """
        return lx.reshape(M, 1) + ly.reshape(1, M) - mat

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate multiple logits loss.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N).
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices of shape (B, ).
        Returns:
            - loss (:obj:`torch.Tensor`): Scalar mean cost over the optimally matched prediction/label pairs.
        """
        assert (len(logits.shape) == 2)
        metric_matrix = self._get_metric_matrix(logits, labels)
        index = self._match(metric_matrix)
        # index[j] is the prediction (row) matched to label (column) j.
        rows = torch.as_tensor(index, dtype=torch.long, device=metric_matrix.device)
        cols = torch.arange(metric_matrix.shape[0], device=metric_matrix.device)
        return metric_matrix[rows, cols].mean()