| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torchsort |
| from scipy.stats import spearmanr, pearsonr |
| from sklearn.metrics import ndcg_score |
|
|
# Upper bound used when min-max rescaling raw targets into graded relevance
# labels in [0, RELEVANCE_SCALE] before NDCG computation (see ndcg_at_k).
RELEVANCE_SCALE = 5.0
|
|
|
|
| def _pearson_corrcoef_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
|
|
| x_centered = x - x.mean() |
| y_centered = y - y.mean() |
|
|
| cov = (x_centered * y_centered).mean() |
| std_x = x_centered.std(unbiased=False) |
| std_y = y_centered.std(unbiased=False) |
|
|
| return cov / (std_x * std_y + 1e-8) |
|
|
|
|
def pearson_corrcoef(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pearson correlation via scipy (non-differentiable evaluation path).

    Inputs are detached and moved to CPU; the scalar correlation is
    returned as a 0-dim tensor cast to x's dtype.
    """
    corr, _pvalue = pearsonr(
        x.detach().cpu().numpy(),
        y.detach().cpu().numpy(),
    )
    return torch.tensor(corr, dtype=x.dtype)
|
|
|
|
def spearman_corrcoef(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Spearman rank correlation via scipy (non-differentiable evaluation path).

    Inputs are detached and moved to CPU; the scalar correlation is
    returned as a 0-dim tensor cast to x's dtype.
    """
    corr, _pvalue = spearmanr(
        x.detach().cpu().numpy(),
        y.detach().cpu().numpy(),
    )
    return torch.tensor(corr, dtype=x.dtype)
|
|
|
|
def ndcg_at_k(
    preds: torch.Tensor,
    targets: torch.Tensor,
    k: int = 10,
    gain_type: str = 'exp2',
    relevance_scale: float = None,
) -> float:
    """Compute NDCG@k for a single ranked list.

    Targets are min-max rescaled to [0, relevance_scale] so they act as
    graded relevance labels. With gain_type='exp2' the exponential gain
    (2^rel - 1) is used; any other value uses the rescaled relevance
    directly as a linear gain. The discount is 1/log2(position + 1).

    Args:
        preds: 1-D tensor of predicted scores.
        targets: 1-D tensor of ground-truth values, same length as preds.
        k: truncation depth; clipped to the list length.
        gain_type: 'exp2' for exponential gains, anything else for linear.
        relevance_scale: upper bound of the rescaled relevance grades;
            defaults to the module-level RELEVANCE_SCALE.

    Returns:
        NDCG@k in [0, 1]; 1.0 when targets are (numerically) constant,
        since any ordering is then ideal.
    """
    if relevance_scale is None:
        relevance_scale = RELEVANCE_SCALE

    preds_np = preds.detach().cpu().float().numpy()
    targets_np = targets.detach().cpu().float().numpy()

    n = len(preds_np)
    k = min(k, n)

    # Degenerate case: constant targets -> every ordering is perfect.
    target_range = targets_np.max() - targets_np.min()
    if target_range < 1e-8:
        return 1.0

    targets_scaled = (targets_np - targets_np.min()) / target_range * relevance_scale

    # Relevance grades in predicted order vs. ideal (descending) order, top-k.
    rel_by_pred = targets_scaled[np.argsort(-preds_np)][:k]
    rel_ideal = np.sort(targets_scaled)[::-1][:k]

    if gain_type == 'exp2':
        # BUG FIX: this branch previously delegated to sklearn's ndcg_score,
        # which applies *linear* gains (y_true used directly), so 'exp2' was
        # silently ignored. Apply the exponential gain explicitly.
        gains_by_pred = np.power(2.0, rel_by_pred) - 1.0
        gains_ideal = np.power(2.0, rel_ideal) - 1.0
    else:
        gains_by_pred = rel_by_pred
        gains_ideal = rel_ideal

    # Position discounts for ranks 1..k: 1/log2(rank + 1).
    discounts = 1.0 / np.log2(np.arange(k) + 2.0)
    dcg = (gains_by_pred * discounts).sum()
    idcg = (gains_ideal * discounts).sum()
    if idcg < 1e-8:
        return 1.0
    return float(dcg / idcg)
|
|
|
|
def ndcg_at_k_per_prompt(
    preds: torch.Tensor,
    targets: torch.Tensor,
    prompt_ids: torch.Tensor,
    k: int = 5,
    gain_type: str = 'exp2',
    relevance_scale: float = None,
) -> float:
    """Mean NDCG@k computed independently within each prompt group.

    Items sharing a prompt_id form one ranked list; targets are min-max
    rescaled to [0, relevance_scale] per group. Groups with fewer than two
    items or with (numerically) constant targets are skipped because they
    carry no ranking signal.

    Args:
        preds: 1-D tensor of predicted scores.
        targets: 1-D tensor of ground-truth values, same length as preds.
        prompt_ids: 1-D tensor of group ids, same length as preds.
        k: truncation depth, clipped to each group's size.
        gain_type: 'exp2' for exponential gains (2^rel - 1), anything else
            for linear gains.
        relevance_scale: upper bound of the rescaled relevance grades;
            defaults to the module-level RELEVANCE_SCALE.

    Returns:
        Mean NDCG@k over scorable groups, or 0.0 if every group was skipped.
    """
    if relevance_scale is None:
        relevance_scale = RELEVANCE_SCALE

    preds_np = preds.detach().cpu().float().numpy()
    targets_np = targets.detach().cpu().float().numpy()
    pids_np = prompt_ids.detach().cpu().numpy()

    ndcg_scores = []
    for pid in np.unique(pids_np):
        mask = (pids_np == pid)
        p = preds_np[mask]
        t = targets_np[mask]

        n = len(p)
        if n < 2:
            continue  # a single item cannot be ranked

        kk = min(k, n)

        t_range = t.max() - t.min()
        if t_range < 1e-8:
            continue  # constant targets: no ranking signal in this group

        t_scaled = (t - t.min()) / t_range * relevance_scale

        # Relevance grades in predicted order vs. ideal order, top-kk.
        rel_by_pred = t_scaled[np.argsort(-p)][:kk]
        rel_ideal = np.sort(t_scaled)[::-1][:kk]

        if gain_type == 'exp2':
            # BUG FIX: this branch previously delegated to sklearn's
            # ndcg_score, which applies *linear* gains, so 'exp2' was
            # silently ignored. Apply the exponential gain explicitly.
            gains_by_pred = np.power(2.0, rel_by_pred) - 1.0
            gains_ideal = np.power(2.0, rel_ideal) - 1.0
        else:
            gains_by_pred = rel_by_pred
            gains_ideal = rel_ideal

        discounts = 1.0 / np.log2(np.arange(kk) + 2.0)
        dcg = (gains_by_pred * discounts).sum()
        idcg = (gains_ideal * discounts).sum()
        ndcg_scores.append(float(dcg / idcg) if idcg > 1e-8 else 1.0)

    if len(ndcg_scores) == 0:
        return 0.0

    return float(np.mean(ndcg_scores))
|
|
|
|
class SRCCLoss(nn.Module):
    """Differentiable Spearman rank-correlation loss: 1 - soft-SRCC.

    Ranks are made differentiable with torchsort.soft_rank; the loss per
    group is one minus the Pearson correlation of the soft ranks. With
    group_ids, the loss is averaged over groups of size >= 2.
    """

    def __init__(self, regularization_strength: float = 1e-2):
        super().__init__()
        # Passed straight through to torchsort.soft_rank; smaller values
        # make soft ranks closer to hard ranks.
        self.regularization_strength = regularization_strength

    def _srcc_for_group(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return 1 - soft-SRCC for one group of predictions/targets."""
        soft_pred_ranks = torchsort.soft_rank(
            preds.view(1, -1),
            regularization_strength=self.regularization_strength,
        )
        soft_target_ranks = torchsort.soft_rank(
            targets.view(1, -1),
            regularization_strength=self.regularization_strength,
        )
        correlation = _pearson_corrcoef_torch(
            soft_pred_ranks.squeeze(0),
            soft_target_ranks.squeeze(0),
        )
        return 1.0 - correlation

    def forward(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        group_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        """Loss over all items, or the mean of per-group losses when
        group_ids is given. Groups with fewer than two items are skipped;
        if every group is skipped, fall back to a single global loss."""
        if group_ids is None:
            return self._srcc_for_group(preds, targets)

        per_group = []
        for gid in group_ids.unique():
            member = (group_ids == gid)
            if member.sum() < 2:
                continue
            per_group.append(self._srcc_for_group(preds[member], targets[member]))

        if not per_group:
            return self._srcc_for_group(preds, targets)

        return torch.stack(per_group).mean()
|
|
|
|
class LambdaRankLoss(nn.Module):
    """LambdaRank-style pairwise ranking objective applied via gradient injection.

    Unlike a conventional loss module, forward() returns None: it computes
    per-item "lambda" gradients (RankNet pairwise terms scaled by |delta NDCG|,
    in the spirit of Burges' LambdaRank) under no_grad and pushes them into the
    autograd graph with Tensor.backward(gradient=...). Call forward() where you
    would normally call loss.backward().
    """


    def __init__(self, sigma: float = 1.0, gain_type: str = 'exp2'):
        # sigma: steepness of the RankNet sigmoid on score differences.
        # gain_type: 'exp2' uses 2^rel - 1 gains; anything else uses linear.
        super().__init__()
        self.sigma = sigma
        self.gain_type = gain_type


    def _compute_max_dcg(self, targets: torch.Tensor) -> float:
        """Return the ideal (maximum) DCG of `targets`, i.e. DCG when items are
        sorted by target descending. Targets are shifted so the minimum is 0,
        matching the shift applied in _compute_lambdas_for_group."""
        t = targets.view(-1).float()
        n = len(t)


        sorted_targets, _ = torch.sort(t, descending=True)
        shifted = sorted_targets - sorted_targets.min()  # lowest gain becomes 0


        if self.gain_type == 'exp2':
            gains = torch.pow(2.0, shifted) - 1.0
        else:
            gains = shifted


        # Standard log2 position discount 1/log2(rank + 1) for ranks 1..n.
        positions = torch.arange(1, n + 1, device=t.device, dtype=torch.float32)
        discounts = 1.0 / torch.log2(positions + 1.0)


        return (gains * discounts).sum().item()


    def _compute_lambdas_for_group(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the (n, 1) lambda gradient for one group.

        Each entry i is the sum over j of the pairwise RankNet gradient
        sigma * (0.5 * (1 - S_ij) - p_ij), weighted by |delta NDCG| for
        swapping i and j in the ideal ranking. Returns zeros when the group
        has fewer than two items or its ideal DCG is ~0.
        """
        prediction = prediction.view(-1, 1)
        target = target.view(-1, 1)
        target_flat = target.view(-1)


        n = prediction.size(0)
        if n < 2:
            return torch.zeros_like(prediction)


        max_dcg = self._compute_max_dcg(target_flat)
        if max_dcg < 1e-8:
            return torch.zeros_like(prediction)
        N = 1.0 / max_dcg  # normalizer turning DCG differences into NDCG differences


        # rank_order[i] = position (1-based) of item i in the ideal ordering
        # (targets sorted descending).
        sorted_indices = torch.argsort(target_flat, descending=True)
        rank_order = torch.zeros(n, dtype=torch.long, device=prediction.device)
        rank_order[sorted_indices] = torch.arange(1, n + 1, device=prediction.device)
        rank_order = rank_order.float().view(-1, 1)


        # Work in float32 for the pairwise matrices regardless of input dtype.
        pred32 = prediction.float()
        tgt32 = target.float()
        rank32 = rank_order.float()


        # p_ij = sigmoid(-sigma * (s_i - s_j)): RankNet probability-style term.
        p_ij = torch.sigmoid(-self.sigma * (pred32 - pred32.t()))


        # S_ij in {-1, 0, +1}: sign of the relevance difference (0 for ties).
        rel_diff = tgt32 - tgt32.t()
        pos_pairs = (rel_diff > 0).float()
        neg_pairs = (rel_diff < 0).float()
        Sij = pos_pairs - neg_pairs


        # Pairwise gain difference; the shift keeps gains aligned with
        # _compute_max_dcg. For exp2 the constant -1 cancels in the difference.
        tgt_shifted = tgt32 - tgt32.min()
        if self.gain_type == 'exp2':
            gain_diff = torch.pow(2.0, tgt_shifted) - torch.pow(2.0, tgt_shifted.t())
        else:
            gain_diff = tgt_shifted - tgt_shifted.t()


        # Pairwise discount difference between the two ideal-rank positions.
        decay_diff = (1.0 / torch.log2(rank32 + 1.0) -
                      1.0 / torch.log2(rank32.t() + 1.0))


        # |delta NDCG| incurred by swapping items i and j.
        delta_ndcg = torch.abs(N * gain_diff * decay_diff)


        lambda_update = self.sigma * (0.5 * (1 - Sij) - p_ij) * delta_ndcg
        lambda_update = torch.sum(lambda_update, dim=1, keepdim=True)  # sum over partners j


        return lambda_update


    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        weight: float = 1.0,
        retain_graph: bool = False,
        group_ids: torch.Tensor = None,
    ) -> None:
        """Compute lambdas (optionally per group) and backpropagate them.

        Accumulates `weight * lambdas` into the gradients of `prediction`'s
        graph via backward(gradient=...); returns None. With group_ids,
        lambdas are computed independently per group and groups of size < 2
        contribute zero. No-op when there are fewer than two items overall.
        """
        original_dtype = prediction.dtype
        pred_flat = prediction.view(-1, 1)
        tgt_flat = target.view(-1, 1)


        n = pred_flat.size(0)
        if n < 2:
            return


        # Lambdas are gradients, not a differentiable quantity: compute
        # them outside the autograd graph.
        with torch.no_grad():
            if group_ids is None:
                lambda_update = self._compute_lambdas_for_group(pred_flat, tgt_flat)
            else:
                lambda_update = torch.zeros_like(pred_flat)
                for gid in group_ids.unique():
                    mask = (group_ids == gid)
                    if mask.sum() < 2:
                        continue
                    group_lambdas = self._compute_lambdas_for_group(
                        pred_flat[mask], tgt_flat[mask],
                    )
                    lambda_update[mask] = group_lambdas


            lambda_update = weight * lambda_update
            # Cast back so the injected gradient matches prediction's dtype.
            lambda_update = lambda_update.to(original_dtype)


        # Inject the lambdas as the upstream gradient of the predictions.
        prediction.view(-1, 1).backward(lambda_update, retain_graph=retain_graph)
|
|
|
|
class MAESRCCLoss(nn.Module):
    """Combined regression + ranking loss: MAE + srcc_weight * (1 - SRCC).

    The L1 term anchors absolute values while the differentiable Spearman
    term (SRCCLoss) encourages correct ordering, per group when group_ids
    is provided.
    """

    def __init__(self, srcc_weight: float = 1.0, regularization_strength: float = 1e-2):
        super().__init__()
        self.mae = nn.L1Loss()
        self.srcc_loss = SRCCLoss(regularization_strength)
        self.srcc_weight = srcc_weight

    def forward(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        group_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return MAE(preds, targets) + srcc_weight * SRCC loss."""
        regression_term = self.mae(preds, targets)
        ranking_term = self.srcc_loss(preds, targets, group_ids=group_ids)
        return regression_term + self.srcc_weight * ranking_term
|
|
|
|
class MAELambdaRankLoss(nn.Module):
    """L1 regression loss paired with LambdaRank gradient injection.

    forward() returns only the MAE scalar — the LambdaRank component has
    no scalar value and instead writes its gradients directly when
    backward() is called, so use this module's backward() in place of the
    usual loss.backward().
    """

    def __init__(self, lambdarank_weight: float = 1.0, sigma: float = 1.0, gain_type: str = 'exp2'):
        super().__init__()
        self.mae = nn.L1Loss()
        self.lambdarank = LambdaRankLoss(sigma=sigma, gain_type=gain_type)
        self.lambdarank_weight = lambdarank_weight

    def forward(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        group_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return the MAE term only; group_ids is consumed in backward()."""
        return self.mae(preds, targets)

    def backward(self, preds: torch.Tensor, targets: torch.Tensor, mae_loss: torch.Tensor, group_ids: torch.Tensor = None) -> None:
        """Backprop the MAE loss, then inject weighted LambdaRank lambdas.

        retain_graph is used on the MAE pass so the LambdaRank call can
        traverse the same graph afterwards.
        """
        mae_loss.backward(retain_graph=True)
        if self.lambdarank_weight > 0:
            self.lambdarank(preds, targets, weight=self.lambdarank_weight, group_ids=group_ids)
|
|