import torch


def _average_ranks(values):
    """Return 1-based ranks of a 1-D tensor, giving tied values their mean rank.

    This is the "average" tie-handling method (the default of
    ``scipy.stats.rankdata``).
    """
    # ``inverse`` maps every element to its distinct value (sorted ascending);
    # ``counts`` is how many elements share each distinct value.
    _, inverse, counts = torch.unique(
        values, sorted=True, return_inverse=True, return_counts=True
    )
    last = counts.cumsum(0)      # 1-based rank of the last member of each tie group
    first = last - counts + 1    # 1-based rank of the first member of each tie group
    mean_rank = (first + last).float() / 2
    return mean_rank[inverse]


def spearman_correlation(y_true, y_pred):
    """Spearman rank correlation between two sets of values.

    Tied values receive their average rank, and the coefficient is computed as
    the Pearson correlation of the two rank vectors. When there are no ties this
    equals the classic ``1 - 6 * sum(d**2) / (n * (n**2 - 1))`` formula. Unlike
    that formula, the result does not depend on how ties happen to be ordered.

    Parameters:
        y_true (Tensor): reference values; flattened to 1-D, so ``(N,)`` and
            ``(N, 1)`` inputs are treated the same.
        y_pred (Tensor): predicted values with the same number of elements.

    Returns:
        float: the correlation coefficient (between -1 and 1). The result is
        ``nan`` when fewer than two values are given or either input is
        constant.

    Raises:
        ValueError: if the inputs have different numbers of elements.
    """
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    if y_true.numel() != y_pred.numel():
        raise ValueError(
            "y_true and y_pred must have the same number of elements, got "
            f"{y_true.numel()} and {y_pred.numel()}"
        )

    # Centre the average ranks, then take the normalised dot product
    # (Pearson correlation of the ranks).
    rank_true = _average_ranks(y_true)
    rank_pred = _average_ranks(y_pred)
    rank_true = rank_true - rank_true.mean()
    rank_pred = rank_pred - rank_pred.mean()
    denom = torch.sqrt((rank_true ** 2).sum() * (rank_pred ** 2).sum())
    return ((rank_true * rank_pred).sum() / denom).item()


def f1_score_max(pred, target):
    """
    F1 score with the optimal threshold.

    This function first enumerates all possible thresholds for deciding positive
    and negative samples, and then picks the threshold with the maximal F1 score.

    A single global threshold is swept over all scores. At each threshold:
    - precision is averaged over the rows that have at least one score at or
      above it;
    - recall is averaged over all rows.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        Tensor: 0-d tensor holding the maximal F1 over all thresholds.
    """
    # Sort each row by descending score and reorder its targets to match.
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    # precision[b, k] / recall[b, k]: row b's precision / recall when its top
    # k + 1 items (in sorted order) are predicted positive.
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    # The 1e-10 keeps rows with no positive targets at recall 0 instead of NaN.
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # In the original (unsorted) layout, flag each row's highest-scoring item.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    # Global descending order of all B * N scores: one threshold per step.
    all_order = pred.flatten().argsort(descending=True)
    # Turn per-row sort indices into flat indices of the B * N tensor. Then build
    # inv_order: flat original index -> flat index in the row-sorted layout.
    order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Each global step advances exactly one row's cutoff by one item, so only
    # that row's precision changes. Take the delta from the row's previous value.
    # On the row's first step the previous value is 0. The ``all_order - 1``
    # lookup (which would point into the preceding row) is discarded by the
    # where in that case.
    all_precision = precision[all_order] - \
        torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1])
    # Running sum of deltas = sum of current precisions. Divide by the number
    # of rows seen so far to get the mean over rows with any prediction.
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - \
        torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1])
    # Recall is averaged over all B rows.
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()