Spaces:
Runtime error
Runtime error
| """ | |
| File copied from | |
| https://github.com/nicola-decao/diffmask/blob/master/diffmask/utils/util.py | |
| """ | |
| import torch | |
| from torch import Tensor | |
def accuracy_precision_recall_f1(
    y_pred: Tensor, y_true: Tensor, average: bool = True
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Compute accuracy together with per-class precision, recall and F1.

    The per-class scores are derived from ``confusion_matrix(y_pred, y_true)``:
    the diagonal gives the true positives, precision divides them by the sums
    over dim ``-2`` and recall divides them by the sums over dim ``-1``.

    Zero denominators are replaced by fixed values instead of NaN:
    precision becomes 0, recall becomes 1, and F1 becomes 0 when
    ``precision + recall`` is 0.

    Args:
        y_pred (Tensor): predicted labels
        y_true (Tensor): true labels
        average (bool): if True, precision, recall and F1 are averaged over
            the class dimension; otherwise the per-class values are returned

    Returns:
        a tuple ``(accuracy, precision, recall, f1)``; accuracy is always the
        mean of ``y_pred == y_true`` over the last dimension
    """
    matrix = confusion_matrix(y_pred, y_true)
    true_pos = matrix.diagonal(dim1=-2, dim2=-1).float()
    zeros = torch.zeros_like(true_pos)

    # Precision: true positives over the confusion-matrix sums along dim -2.
    predicted_total = matrix.sum(-2)
    precision = torch.where(
        predicted_total == 0, zeros, true_pos / predicted_total
    )

    # Recall: true positives over the confusion-matrix sums along dim -1.
    # A class with a zero denominator gets recall 1 rather than 0.
    actual_total = matrix.sum(-1)
    recall = torch.where(
        actual_total == 0, torch.ones_like(true_pos), true_pos / actual_total
    )

    # F1: harmonic mean of precision and recall, 0 where both are 0.
    pr_sum = precision + recall
    f1 = torch.where(pr_sum == 0, zeros, 2 * (precision * recall) / pr_sum)

    accuracy = (y_pred == y_true).float().mean(-1)
    if average:
        return accuracy, precision.mean(-1), recall.mean(-1), f1.mean(-1)
    return accuracy, precision, recall, f1
def confusion_matrix(y_pred: Tensor, y_true: Tensor) -> Tensor:
    """Creates a confusion matrix given the predicted and true labels.

    Entry ``[..., i, j]`` of the result counts the positions along the last
    dimension where ``y_true == i`` and ``y_pred == j``. Any leading
    dimensions are treated as batch dimensions and kept in the output.
    The number of classes is ``max(y_pred.max(), y_true.max()) + 1``, computed
    over the whole input. Pairs containing a negative label match no class
    and are ignored.

    Each ``(true, pred)`` pair is encoded as the flat index
    ``true * num_labels + pred`` and accumulated with ``scatter_add_``, so the
    work is proportional to the number of samples plus ``num_labels ** 2``
    rather than to their product.

    Args:
        y_pred (Tensor): predicted labels, integer or boolean dtype,
            at least one dimension
        y_true (Tensor): true labels, same shape as ``y_pred``

    Returns:
        an ``int64`` tensor of shape ``(*batch, num_labels, num_labels)``;
        the class dimensions are empty when no label is non-negative

    Raises:
        RuntimeError: if the two tensors differ in shape
        TypeError: if either tensor has a floating point dtype
        IndexError: if the tensors are zero-dimensional
    """
    if y_pred.shape != y_true.shape:
        raise RuntimeError(
            "confusion_matrix expects y_pred and y_true of equal shape, got "
            f"{tuple(y_pred.shape)} and {tuple(y_true.shape)}"
        )
    if y_pred.is_floating_point() or y_true.is_floating_point():
        raise TypeError(
            "confusion_matrix expects integer or boolean label tensors"
        )
    if y_pred.dim() == 0:
        raise IndexError(
            "confusion_matrix expects label tensors with at least one dimension"
        )

    device = y_pred.device
    batch_shape = tuple(y_true.shape[:-1])
    # NOTE: Tensor.max() is not special-cased for empty inputs here; torch
    # is expected to raise on them.
    num_labels = int(max(y_pred.max().item(), y_true.max().item())) + 1
    if num_labels <= 0:
        # Every label is negative, so there is no class to count.
        return torch.zeros(batch_shape + (0, 0), dtype=torch.long, device=device)

    true_idx = y_true.long()
    pred_idx = y_pred.long()
    # Negative labels belong to no class: route them to slot 0 with weight 0
    # so they add nothing and never index out of bounds.
    valid = (true_idx >= 0) & (pred_idx >= 0)
    flat_index = torch.where(
        valid, true_idx * num_labels + pred_idx, torch.zeros_like(true_idx)
    )

    counts = torch.zeros(
        batch_shape + (num_labels * num_labels,), dtype=torch.long, device=device
    )
    counts.scatter_add_(-1, flat_index, valid.long())
    return counts.view(batch_shape + (num_labels, num_labels))