"""PyTorch Dataset for BPR training with on-the-fly negative sampling.

Each item yields (user_idx, pos_item_idx, neg_item_idxs[K]) where `neg_item_idxs`
are uniformly sampled item indices that the user has NOT positively interacted
with in any split (train OR val OR test). Rejection sampling is fine here:
typical sparsity is |seen_u| / num_items ≪ 1%, so the expected number of
retries per draw is ~0.01 or less.
"""
|
|
| from __future__ import annotations |
|
|
| from typing import Final |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
# Upper bound on uniform draws per negative before switching to an exact draw
# from the user's unseen items (keeps the common, sparse case cheap).
_MAX_REJECTION_ATTEMPTS: Final[int] = 50


class BPRDataset(Dataset):
    """Yield ``(user, pos, negs)`` triples for pairwise ranking training.

    Each index maps to one row of ``train_pairs``; the negatives are sampled
    on the fly, uniformly among item indices in ``[0, num_items)`` that are
    absent from that user's entry in ``user_positives``.

    Randomness:
        * In the main process a single generator seeded with ``seed`` is
          used, so two datasets built with the same arguments and indexed in
          the same order produce the same negatives.
        * Inside a ``DataLoader`` worker a separate generator is derived from
          ``seed`` and the worker's own seed. Without this, every worker
          would receive a copy of the same generator state and emit the same
          sequence of negatives.
    """

    def __init__(
        self,
        train_pairs: np.ndarray,
        user_positives: list[set[int]],
        num_items: int,
        num_negatives: int,
        seed: int,
    ) -> None:
        """Validate the inputs and set up the sampler state.

        Args:
            train_pairs: Integer array of shape ``[N, 2]`` whose rows are
                ``(user_idx, pos_item_idx)``.
            user_positives: Indexed by ``user_idx``; each entry is the set of
                item indices that must never be returned as a negative.
            num_items: Size of the item index space; negatives are drawn from
                ``[0, num_items)``.
            num_negatives: Number of negatives (K) returned per example.
            seed: Seed for the NumPy generator used for sampling.

        Raises:
            ValueError: If ``train_pairs`` is not ``[N, 2]`` or
                ``num_negatives`` is smaller than 1.
        """
        if train_pairs.ndim != 2 or train_pairs.shape[1] != 2:
            raise ValueError(
                f"train_pairs must have shape [N, 2], got {train_pairs.shape}"
            )
        if num_negatives < 1:
            raise ValueError(f"num_negatives must be >= 1, got {num_negatives}")

        self._pairs = train_pairs.astype(np.int64, copy=False)
        self._user_positives = user_positives
        self._num_items = int(num_items)
        self._num_negatives = int(num_negatives)

        self._seed = seed
        # Generator used whenever we are NOT inside a DataLoader worker.
        self._rng = np.random.default_rng(seed)
        # Lazily created per-worker generator and the worker seed it was
        # built for (see `_active_rng`).
        self._worker_rng = None
        self._worker_rng_seed = None

    def __len__(self) -> int:
        """Return the number of (user, positive item) training pairs."""
        return self._pairs.shape[0]

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(user, pos_item, negs)`` for the pair at row ``idx``.

        ``user`` and ``pos_item`` are 0-d int64 tensors; ``negs`` is an int64
        tensor of shape ``[num_negatives]``. Negatives are drawn independently,
        so the same item may appear more than once.
        """
        u, i_pos = self._pairs[idx]
        seen = self._user_positives[int(u)]
        negs = np.fromiter(
            (self._sample_negative(seen) for _ in range(self._num_negatives)),
            dtype=np.int64,
            count=self._num_negatives,
        )
        return (
            torch.from_numpy(np.asarray(u, dtype=np.int64)),
            torch.from_numpy(np.asarray(i_pos, dtype=np.int64)),
            torch.from_numpy(negs),
        )

    def _active_rng(self) -> np.random.Generator:
        """Return the generator to use in the current process.

        In the main process this is the generator created in ``__init__``.
        In a ``DataLoader`` worker the dataset object is a copy of the
        parent's, so reusing that generator would make all workers draw the
        same numbers; instead a generator keyed on the dataset seed and the
        worker seed is created on first use and then reused.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            return self._rng
        if self._worker_rng is None or self._worker_rng_seed != info.seed:
            worker_seed = int(info.seed)
            entropy = (
                [worker_seed]
                if self._seed is None
                else [int(self._seed), worker_seed]
            )
            self._worker_rng = np.random.default_rng(entropy)
            self._worker_rng_seed = info.seed
        return self._worker_rng

    def _sample_negative(self, seen: set[int]) -> int:
        """Sample one item index that is not in ``seen``.

        Tries up to ``_MAX_REJECTION_ATTEMPTS`` uniform draws first. If all of
        them hit ``seen`` (a dense user), the unseen items are enumerated and
        one is drawn uniformly, so a true negative is returned whenever one
        exists. Only when ``seen`` covers every item is an unchecked uniform
        draw returned, as there is no valid negative to give.
        """
        rng = self._active_rng()
        for _ in range(_MAX_REJECTION_ATTEMPTS):
            j = int(rng.integers(0, self._num_items))
            if j not in seen:
                return j

        # Rejection budget exhausted: do an exact draw from the complement.
        # This costs O(num_items) but is only reached for very dense users.
        unseen = np.setdiff1d(
            np.arange(self._num_items, dtype=np.int64),
            np.fromiter(seen, dtype=np.int64, count=len(seen)),
        )
        if unseen.size:
            return int(unseen[int(rng.integers(0, unseen.size))])

        # The user has interacted with every item; best effort only.
        return int(rng.integers(0, self._num_items))
|
|