"""Full-catalog (or sampled) ranking evaluator.

Default strategy `full_catalog`: for each user, rank the held-out positive
against all items in the catalog. Items the user saw during training (and,
for the test split, during validation) are masked to -inf before ranking.

We prefer full-catalog over the legacy 1+100 sampled-negatives protocol
because sampled metrics are inconsistent with full-catalog metrics — a
model that is worse under full-catalog can look better under sampled
(Krichene & Rendle, KDD 2020). Sampled is still exposed via config for
reproducing NCF-era benchmark numbers, but the evaluator logs a warning.

Tie-breaking: the positive's rank is one plus the number of items whose
score is STRICTLY greater than the positive's score, so items that merely
tie with the positive do not push it down. This gives the positive its
most favorable possible rank among ties — matches the convention used in
most papers.
"""
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| from torch import Tensor, nn |
|
|
| from ..config import Config |
| from ..logging_utils import get_logger |
| from ..models.base import BaseRecommender |
| from . import metrics |
|
|
| _logger = get_logger(__name__) |
|
|
|
|
def _unwrap(model: nn.Module) -> nn.Module:
    """Return the underlying module of a data-parallel wrapper.

    ``nn.DataParallel`` and ``nn.parallel.DistributedDataParallel`` only
    forward ``forward()``; custom methods such as ``score_all_items`` live on
    the wrapped module, which both wrappers expose as ``.module``.

    Args:
        model: A plain module or a (Distributed)DataParallel wrapper.

    Returns:
        ``model.module`` for a recognised wrapper, otherwise ``model`` itself.
    """
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, wrappers) else model
|
|
|
|
class Evaluator:
    """Rank held-out positives and report top-K metrics for a recommender.

    Two strategies are supported, selected by ``cfg.evaluation.strategy``:

    * ``"full_catalog"`` scores every item for each user, masks the items the
      user has already seen, and ranks the held-out positive among the rest.
    * ``"sampled"`` ranks the positive against
      ``cfg.evaluation.num_sampled_negatives`` randomly drawn unseen items.

    Ranks are 1-based and ties are broken in the positive's favour: the rank
    is one plus the number of candidates scoring strictly higher.
    """

    _STRATEGIES = ("full_catalog", "sampled")

    def __init__(
        self,
        *,
        cfg: Config,
        val_pairs: np.ndarray,
        test_pairs: np.ndarray,
        user_train_positives: list[set[int]],
        user_val_positives: list[set[int]] | None,
        num_items: int,
        device: torch.device,
    ) -> None:
        """Store evaluation data and validate the configured strategy.

        Args:
            cfg: Project config; ``cfg.evaluation.k`` and
                ``cfg.evaluation.strategy`` are read here, and
                ``cfg.evaluation.eval_batch_size``,
                ``cfg.evaluation.num_sampled_negatives`` and ``cfg.seed``
                are read during evaluation.
            val_pairs: 2-D array whose column 0 is the user id and column 1
                the held-out item id for the validation split.
            test_pairs: Same layout as ``val_pairs`` for the test split.
            user_train_positives: Indexed by user id; the item ids each user
                interacted with in training.
            user_val_positives: Indexed by user id; validation items that are
                additionally masked on the test split. ``None`` disables
                that extra mask.
            num_items: Catalog size; sampled negatives are drawn from
                ``[0, num_items)``.
            device: Device the user/item id tensors are moved to.

        Raises:
            ValueError: If ``cfg.evaluation.strategy`` is not a known strategy.
        """
        strategy = cfg.evaluation.strategy
        if strategy not in self._STRATEGIES:
            # Fail fast: a typo would otherwise silently run the sampled
            # protocol without the comparability warning below.
            raise ValueError(
                "evaluation.strategy must be 'full_catalog' or 'sampled', "
                f"got {strategy!r}"
            )

        self.cfg = cfg
        self.k = cfg.evaluation.k
        self.strategy = strategy
        self.val_pairs = val_pairs.astype(np.int64, copy=False)
        self.test_pairs = test_pairs.astype(np.int64, copy=False)
        self.user_train_positives = user_train_positives
        self.user_val_positives = user_val_positives
        self.num_items = int(num_items)
        self.device = device

        if self.strategy == "sampled":
            _logger.warning(
                "evaluation.strategy='sampled' — results are NOT comparable to "
                "full_catalog metrics. Use full_catalog for modern benchmarks."
            )

    @torch.no_grad()
    def evaluate(self, model: BaseRecommender, split: str) -> dict[str, float]:
        """Compute HR/NDCG/Recall/MAP @ K on the given split.

        The model is switched to eval mode for the duration of the call and
        put back into training mode afterwards if it was training before.

        Args:
            model: Recommender to evaluate (may be wrapped in a data-parallel
                container).
            split: ``"val"`` or ``"test"``. On ``"val"`` only training items
                are masked; on ``"test"`` validation items are masked as well
                when they were provided.

        Returns:
            Mapping with keys ``hr@K``, ``ndcg@K``, ``recall@K`` and ``map@K``.

        Raises:
            ValueError: If ``split`` is neither ``"val"`` nor ``"test"``.
        """
        if split == "val":
            pairs = self.val_pairs
            mask_sources: list[list[set[int]]] = [self.user_train_positives]
        elif split == "test":
            pairs = self.test_pairs
            mask_sources = [self.user_train_positives]
            if self.user_val_positives is not None:
                mask_sources.append(self.user_val_positives)
        else:
            raise ValueError(f"split must be 'val' or 'test', got {split!r}")

        was_training = model.training
        model.eval()
        try:
            ranks = self._compute_ranks(model, pairs, mask_sources)
        finally:
            # Restore the caller's mode even if scoring raised.
            if was_training:
                model.train()

        k = self.k
        return {
            f"hr@{k}": metrics.hit_rate_at_k(ranks, k),
            f"ndcg@{k}": metrics.ndcg_at_k(ranks, k),
            f"recall@{k}": metrics.recall_at_k(ranks, k),
            f"map@{k}": metrics.map_at_k(ranks, k),
        }

    def _compute_ranks(
        self,
        model: BaseRecommender,
        pairs: np.ndarray,
        mask_sources: list[list[set[int]]],
    ) -> np.ndarray:
        """Return the 1-based rank of every held-out positive in ``pairs``.

        Pairs are processed in chunks of ``cfg.evaluation.eval_batch_size``.
        For the sampled strategy a single generator seeded from ``cfg.seed``
        is shared by all batches, so the negatives are reproducible across
        calls without every batch replaying the same random stream.
        """
        batch_size = self.cfg.evaluation.eval_batch_size
        num = pairs.shape[0]
        ranks = np.empty(num, dtype=np.int64)

        full_catalog = self.strategy == "full_catalog"
        rng = None if full_catalog else np.random.default_rng(self.cfg.seed)

        for start in range(0, num, batch_size):
            end = min(start + batch_size, num)
            batch = pairs[start:end]
            users_np = batch[:, 0]
            users_t = torch.from_numpy(users_np).to(self.device)
            pos_items_t = torch.from_numpy(batch[:, 1]).to(self.device)

            if full_catalog:
                batch_ranks = self._rank_full_catalog(
                    model, users_t, pos_items_t, users_np, mask_sources
                )
            else:
                batch_ranks = self._rank_sampled(
                    model, users_t, pos_items_t, users_np, mask_sources, rng=rng
                )

            ranks[start:end] = batch_ranks

        return ranks

    def _rank_full_catalog(
        self,
        model: BaseRecommender,
        users_t: Tensor,
        pos_items_t: Tensor,
        users_np: np.ndarray,
        mask_sources: list[list[set[int]]],
    ) -> np.ndarray:
        """Rank each positive against every item the user has not seen.

        NOTE(review): assumes ``score_all_items`` returns a fresh
        (batch, num_items) float tensor — it is masked in place below;
        confirm it is not a cached buffer.
        """
        scores = _unwrap(model).score_all_items(users_t)
        self._apply_seen_mask(
            scores, users_np, pos_items_t.cpu().numpy(), mask_sources
        )
        pos_scores = scores.gather(1, pos_items_t.unsqueeze(1))
        return self._rank_from_scores(
            scores, pos_scores, worst_rank=scores.shape[1]
        )

    def _rank_sampled(
        self,
        model: BaseRecommender,
        users_t: Tensor,
        pos_items_t: Tensor,
        users_np: np.ndarray,
        mask_sources: list[list[set[int]]],
        rng: np.random.Generator | None = None,
    ) -> np.ndarray:
        """Rank each positive against randomly sampled unseen items.

        Args:
            rng: Generator to draw negatives from. When omitted, a new one is
                seeded from ``cfg.seed`` (callers evaluating several batches
                should pass a shared generator instead).

        Raises:
            ValueError: If some user has no item left to sample as a negative.
        """
        num_neg = self.cfg.evaluation.num_sampled_negatives
        if rng is None:
            rng = np.random.default_rng(self.cfg.seed)

        # One host copy for the whole batch instead of a sync per row.
        pos_items_np = pos_items_t.cpu().numpy()
        neg_arr = self._sample_negatives(
            users_np, pos_items_np, mask_sources, num_neg, rng
        )

        neg_t = torch.from_numpy(neg_arr).to(self.device)
        m = _unwrap(model)
        neg_scores = m.score(users_t.unsqueeze(-1).expand_as(neg_t), neg_t)
        pos_scores = m.score(users_t, pos_items_t).unsqueeze(1)
        return self._rank_from_scores(
            neg_scores, pos_scores, worst_rank=num_neg + 1
        )

    def _sample_negatives(
        self,
        users_np: np.ndarray,
        pos_items_np: np.ndarray,
        mask_sources: list[list[set[int]]],
        num_neg: int,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Draw ``num_neg`` negatives per row, uniformly with replacement.

        Items in any of the user's mask sets, and the row's own positive,
        are never returned.

        Returns:
            int64 array of shape ``(len(users_np), num_neg)``.

        Raises:
            ValueError: If a user's forbidden set covers the whole catalog,
                which would otherwise make rejection sampling loop forever.
        """
        negatives = np.empty((users_np.shape[0], num_neg), dtype=np.int64)
        if num_neg == 0:
            return negatives

        rows = zip(users_np.tolist(), pos_items_np.tolist())
        for row, (user, pos_item) in enumerate(rows):
            forbidden = {pos_item}
            for src in mask_sources:
                forbidden.update(src[user])
            forbidden_arr = np.fromiter(
                forbidden, dtype=np.int64, count=len(forbidden)
            )

            # Only ids inside the catalog reduce the pool of candidates.
            blocked = np.count_nonzero(
                (forbidden_arr >= 0) & (forbidden_arr < self.num_items)
            )
            if blocked >= self.num_items:
                raise ValueError(
                    f"user {user} has no unseen items left to sample "
                    f"negatives from (num_items={self.num_items})"
                )

            # Vectorised rejection sampling: redraw only the missing slots.
            filled = 0
            while filled < num_neg:
                draws = rng.integers(0, self.num_items, size=num_neg - filled)
                kept = draws[~np.isin(draws, forbidden_arr)]
                negatives[row, filled:filled + kept.size] = kept
                filled += kept.size

        return negatives

    @staticmethod
    def _rank_from_scores(
        candidate_scores: Tensor,
        pos_scores: Tensor,
        *,
        worst_rank: int,
    ) -> np.ndarray:
        """Convert scores to 1-based ranks with optimistic tie-breaking.

        Args:
            candidate_scores: (batch, num_candidates) scores to compare with.
            pos_scores: (batch, 1) score of each row's positive.
            worst_rank: Rank assigned to rows whose positive score is NaN.

        Returns:
            int64 array of shape ``(batch,)``.
        """
        ranks = (candidate_scores > pos_scores).sum(dim=1) + 1
        # NaN compares False against everything, which would hand a diverged
        # model rank 1 on every row; give those rows the worst rank instead.
        nan_pos = torch.isnan(pos_scores).reshape(-1)
        ranks = torch.where(nan_pos, torch.full_like(ranks, worst_rank), ranks)
        return ranks.cpu().numpy().astype(np.int64)

    def _apply_seen_mask(
        self,
        scores: Tensor,
        users_np: np.ndarray,
        pos_items_np: np.ndarray,
        mask_sources: list[list[set[int]]],
    ) -> None:
        """In-place mask: set scores[b, i] = -inf for i in user_b's seen set
        (excluding the positive item we're trying to rank).

        All (row, item) coordinates are gathered on the host first so the
        tensor is written with a single indexed assignment.
        """
        row_idx: list[int] = []
        col_idx: list[int] = []
        rows = zip(users_np.tolist(), pos_items_np.tolist())
        for b, (u, pos_i) in enumerate(rows):
            for src in mask_sources:
                for i in src[u]:
                    if i != pos_i:
                        row_idx.append(b)
                        col_idx.append(i)

        if not row_idx:
            return

        rows_t = torch.from_numpy(np.asarray(row_idx, dtype=np.int64))
        cols_t = torch.from_numpy(np.asarray(col_idx, dtype=np.int64))
        scores[rows_t.to(scores.device), cols_t.to(scores.device)] = float("-inf")
|
|