"""PyTorch Dataset for BPR training with on-the-fly negative sampling. Each item yields (user_idx, pos_item_idx, neg_item_idxs[K]) where `neg_item_idxs` are uniformly sampled item indices that the user has NOT positively interacted with in any split (train OR val OR test). Rejection sampling is fine here: typical sparsity is |seen_u| / num_items ≪ 1%, so expected retries are ~0.01. """ from __future__ import annotations from typing import Final import numpy as np import torch from torch.utils.data import Dataset _MAX_REJECTION_ATTEMPTS: Final[int] = 50 # safety bound; sparsity makes this fine class BPRDataset(Dataset): """Yields (user, pos, negs) triples for pairwise ranking training.""" def __init__( self, train_pairs: np.ndarray, user_positives: list[set[int]], num_items: int, num_negatives: int, seed: int, ) -> None: if train_pairs.ndim != 2 or train_pairs.shape[1] != 2: raise ValueError( f"train_pairs must have shape [N, 2], got {train_pairs.shape}" ) if num_negatives < 1: raise ValueError(f"num_negatives must be >= 1, got {num_negatives}") self._pairs = train_pairs.astype(np.int64, copy=False) self._user_positives = user_positives self._num_items = int(num_items) self._num_negatives = int(num_negatives) # Per-worker RNGs are re-seeded in `worker_init_fn`. This is only used # in single-worker mode; multi-worker overrides via worker_init_fn. self._rng = np.random.default_rng(seed) def __len__(self) -> int: return self._pairs.shape[0] def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: u, i_pos = self._pairs[idx] seen = self._user_positives[int(u)] negs = np.empty(self._num_negatives, dtype=np.int64) for k in range(self._num_negatives): negs[k] = self._sample_negative(seen) return ( torch.from_numpy(np.asarray(u, dtype=np.int64)), torch.from_numpy(np.asarray(i_pos, dtype=np.int64)), torch.from_numpy(negs), ) def _sample_negative(self, seen: set[int]) -> int: """Rejection-sample a single item index not in `seen`.""" for _ in range(_MAX_REJECTION_ATTEMPTS): j = int(self._rng.integers(0, self._num_items)) if j not in seen: return j # Extremely rare fallback for pathological users who've seen nearly # everything; just return any index and let training absorb the noise. return int(self._rng.integers(0, self._num_items))