| | import torch |
| |
|
def _average_ranks(values):
    """Return 1-based ranks of a 1-D tensor; tied values share their mean rank."""
    _, inverse, counts = torch.unique(
        values, sorted=True, return_inverse=True, return_counts=True
    )
    counts = counts.to(torch.float64)
    # A group of `c` equal values whose last ordinal rank is `e` occupies the
    # ranks e-c+1 .. e, whose mean is e - (c - 1) / 2.
    last_rank = counts.cumsum(0)
    mean_rank = last_rank - (counts - 1) / 2
    return mean_rank[inverse]


def spearman_correlation(y_true, y_pred):
    """
    Spearman rank correlation coefficient between two sets of values.

    Both inputs are flattened, converted to ranks (tied values receive the
    average of the ranks they span) and the Pearson correlation of the two
    rank vectors is returned. Without ties this equals the classic
    ``1 - 6 * sum(d^2) / (n * (n^2 - 1))`` formula.

    Parameters:
        y_true (Tensor): ground-truth values; any shape, flattened internally
        y_pred (Tensor): predicted values with the same number of elements

    Returns:
        float: correlation in ``[-1, 1]``, or ``nan`` when it is undefined
        (fewer than two elements, or one of the inputs is constant).

    Raises:
        ValueError: if the inputs do not hold the same number of elements.
    """
    # Ranking is not differentiable, so drop any autograd history up front.
    y_true = torch.as_tensor(y_true).detach().reshape(-1)
    y_pred = torch.as_tensor(y_pred).detach().reshape(-1)

    if y_true.numel() != y_pred.numel():
        raise ValueError(
            "y_true and y_pred must contain the same number of elements, "
            f"got {y_true.numel()} and {y_pred.numel()}"
        )
    if y_true.numel() < 2:
        return float("nan")

    rank_true = _average_ranks(y_true)
    rank_pred = _average_ranks(y_pred)

    dev_true = rank_true - rank_true.mean()
    dev_pred = rank_pred - rank_pred.mean()
    denom = torch.sqrt((dev_true ** 2).sum() * (dev_pred ** 2).sum())
    if denom.item() == 0:
        # At least one input is constant: its ranks have zero variance.
        return float("nan")

    return ((dev_true * dev_pred).sum() / denom).item()
| |
|
| |
|
def f1_score_max(pred, target):
    """
    Best F1 score reachable by any single decision threshold.

    Every prediction value is treated as a candidate threshold shared by all
    rows. For each candidate, precision is averaged over the rows that have at
    least one entry above the threshold, recall is averaged over all rows, and
    the largest resulting F1 value is returned.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        Tensor: scalar tensor holding the maximal F1 score
    """
    # Rank the entries of each row from highest to lowest prediction.
    row_rank = pred.argsort(descending=True, dim=1)
    ranked_target = target.gather(1, row_rank)

    # Per-row precision / recall after admitting the top-k entries of the row.
    hits = ranked_target.cumsum(1)
    admitted = torch.ones_like(ranked_target).cumsum(1)
    row_precision = hits / admitted
    row_recall = hits / (ranked_target.sum(1, keepdim=True) + 1e-10)

    # Mark, at its original column, the highest-scoring entry of every row.
    head_mask = torch.zeros_like(ranked_target).bool()
    head_mask[:, 0] = True
    head_mask = head_mask.scatter(1, row_rank, head_mask)

    # Ranking of all entries of the batch taken together.
    global_rank = pred.flatten().argsort(descending=True)

    # Turn the per-row ranking into flat indices and invert that permutation,
    # so a flat original index can be mapped to its row-sorted flat position.
    row_offset = torch.arange(row_rank.shape[0], device=row_rank.device).unsqueeze(1)
    flat_rank = (row_rank + row_offset * row_rank.shape[1]).flatten()
    sorted_slot = torch.zeros_like(flat_rank)
    sorted_slot[flat_rank] = torch.arange(flat_rank.shape[0], device=flat_rank.device)

    head_mask = head_mask.flatten()[global_rank]
    sorted_pos = sorted_slot[global_rank]

    def _step_change(curve):
        # Change a row's running value undergoes when one more of its entries
        # is admitted; the first entry of a row starts from zero.
        curve = curve.flatten()
        before = torch.where(head_mask, torch.zeros_like(curve), curve[sorted_pos - 1])
        return curve[sorted_pos] - before

    # Accumulating the per-step changes yields the sum over rows at each
    # threshold; divide by the number of contributing rows to get the mean.
    active_rows = head_mask.cumsum(0)
    mean_precision = _step_change(row_precision).cumsum(0) / active_rows
    mean_recall = _step_change(row_recall).cumsum(0) / pred.shape[0]

    f1_curve = 2 * mean_precision * mean_recall / (mean_precision + mean_recall + 1e-10)
    return f1_curve.max()
| |
|