nas / PFMBench /src /utils /metrics.py
yuccaaa's picture
Add files using upload-large-folder tool
9627ce0 verified
import torch
def _average_ranks(values):
    """Return 0-based ranks of a 1-D tensor; tied values share the mean of their ranks.

    Ranks are float64 so half-integer average ranks are represented exactly.
    """
    sorted_values, order = torch.sort(values)
    # Runs of equal values in the sorted sequence form the tie groups.
    _, group_ids, counts = torch.unique_consecutive(
        sorted_values, return_inverse=True, return_counts=True
    )
    group_end = counts.cumsum(0)  # exclusive end position of each tie group
    group_start = group_end - counts  # inclusive start position of each tie group
    group_rank = (group_start + group_end - 1).to(torch.float64) / 2
    ranks = torch.empty(values.shape[0], dtype=torch.float64, device=values.device)
    ranks[order] = group_rank[group_ids]
    return ranks


def spearman_correlation(y_true, y_pred):
    """Spearman rank correlation coefficient between two tensors.

    Both inputs are flattened, so e.g. ``(N,)`` and ``(N, 1)`` tensors are treated
    the same way. Tied values receive the average of the ranks they span, which
    matches the standard (and ``scipy.stats.spearmanr``) definition.

    Parameters:
        y_true (Tensor): reference values, N elements.
        y_pred (Tensor): predicted values, N elements.

    Returns:
        float: the correlation in [-1, 1]; NaN when N < 2 or when either input is
        constant (the correlation is undefined there).
    """
    rank_true = _average_ranks(y_true.flatten())
    rank_pred = _average_ranks(y_pred.flatten())
    # Spearman's rho is the Pearson correlation of the rank vectors. Unlike
    # 1 - 6 * sum(d^2) / (n (n^2 - 1)), this form stays exact when ties exist.
    dev_true = rank_true - rank_true.mean()
    dev_pred = rank_pred - rank_pred.mean()
    corr = (dev_true * dev_pred).sum() / torch.sqrt(
        (dev_true ** 2).sum() * (dev_pred ** 2).sum()
    )
    return corr.item()
def f1_score_max(pred, target):
    """
    F1 score with the optimal threshold (protein-centric Fmax).

    This function enumerates every distinct prediction value as a threshold for deciding
    positive and negative samples (an entry is positive when its score is at or above the
    threshold), and returns the maximal F1 score. For a given threshold, precision is
    averaged over the rows that have at least one positive prediction, recall is averaged
    over all rows, and F1 is the harmonic mean of the two averages.

    Entries with equal scores always fall on the same side of a threshold, so F1 is only
    evaluated at cut-offs between distinct score values; the result therefore does not
    depend on how ties are ordered by the sort.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        Tensor: 0-dim tensor holding the maximal F1 score.
    """
    # Per-row sweep: sort each row by descending score; precision[i, k] / recall[i, k] are
    # row i's values when its top k+1 entries are predicted positive.
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # Flag, in the original (unsorted) layout, the top-scoring entry of each row.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    # Global sweep: visit all B*N entries by descending score (i.e. decreasing threshold).
    flat_pred = pred.flatten()
    all_order = flat_pred.argsort(descending=True)
    # Only the last entry of each run of equal scores is a valid cut-off. Stopping inside
    # a run would split tied scores, which no threshold can do.
    sorted_pred = flat_pred[all_order]
    is_cut = torch.ones_like(sorted_pred, dtype=torch.bool)
    is_cut[:-1] = sorted_pred[:-1] != sorted_pred[1:]

    # Map each global position to its index in the row-sorted, flattened precision/recall.
    order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Each step adds the change in its row's precision/recall (a row's first entry adds its
    # full value), so a running sum yields the per-threshold sum over rows. For a row's first
    # entry, all_order - 1 points outside that row (or wraps to -1); torch.where discards it.
    all_precision = precision[all_order] - \
        torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1])
    # Average precision only over rows that already have at least one positive prediction.
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - \
        torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1])
    # Recall is averaged over all B rows.
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1[is_cut].max()