"""Chronological leave-one-out split. For each user, sort interactions by timestamp (ties broken by item index for determinism). The newest becomes the test positive, second-newest the val positive, the rest go to train. The `min_user_interactions` filter in preprocessing guarantees every user has at least 3 interactions. Random LOO leaks future info and is the #1 methodology error in rec-sys tutorials — see Rendle et al. 2020, "Neural Collaborative Filtering vs. Matrix Factorization Revisited." """ from __future__ import annotations from dataclasses import dataclass import numpy as np from ..logging_utils import get_logger _logger = get_logger(__name__) @dataclass(frozen=True) class Split: """A train/val/test split of user-item positives. `train_pairs` is an [N, 2] array of (user_idx, item_idx). `val_pairs` / `test_pairs` are [M, 2] arrays — exactly one row per user that has both a val and a test interaction (they should all, by the min-interactions filter). """ train_pairs: np.ndarray val_pairs: np.ndarray test_pairs: np.ndarray def leave_one_out_split(interactions: np.ndarray) -> Split: """Split the [N, 3] (user_idx, item_idx, timestamp) array chronologically.""" if interactions.ndim != 2 or interactions.shape[1] != 3: raise ValueError( f"expected interactions of shape [N, 3], got {interactions.shape}" ) # Sort by (user_idx asc, timestamp asc, item_idx asc). # lexsort keys are applied last-first, so order is deliberate. order = np.lexsort( (interactions[:, 1], interactions[:, 2], interactions[:, 0]) ) sorted_ia = interactions[order] users = sorted_ia[:, 0] # Index within each user's (already sorted) run, counting from the end. # We compute this by finding group boundaries. boundaries = np.flatnonzero(np.diff(users, prepend=users[0] - 1, append=users[-1] + 1)) # boundaries[k]..boundaries[k+1] is the run for user k. train_idx: list[np.ndarray] = [] val_idx: list[int] = [] test_idx: list[int] = [] for start, end in zip(boundaries[:-1], boundaries[1:]): run_len = end - start if run_len < 3: # Shouldn't happen given the preprocessing filter, but be defensive. _logger.warning( "Skipping user_idx=%d with only %d interactions", int(users[start]), run_len, ) continue # Newest -> test, second-newest -> val, rest -> train. test_idx.append(end - 1) val_idx.append(end - 2) train_idx.append(np.arange(start, end - 2)) train_rows = np.concatenate(train_idx) if train_idx else np.empty(0, dtype=np.int64) val_rows = np.asarray(val_idx, dtype=np.int64) test_rows = np.asarray(test_idx, dtype=np.int64) train_pairs = sorted_ia[train_rows, :2].astype(np.int64) val_pairs = sorted_ia[val_rows, :2].astype(np.int64) test_pairs = sorted_ia[test_rows, :2].astype(np.int64) _logger.info( "LOO split: train=%d, val=%d, test=%d pairs", len(train_pairs), len(val_pairs), len(test_pairs), ) return Split(train_pairs=train_pairs, val_pairs=val_pairs, test_pairs=test_pairs) def build_user_positives( train_pairs: np.ndarray, val_pairs: np.ndarray, test_pairs: np.ndarray, num_users: int, ) -> list[set[int]]: """Return a list indexed by user_idx -> set of item indices the user has interacted with across ALL splits. Used by the negative sampler: a sampled "negative" must not collide with any positive the user has, including held-out val/test positives. Otherwise you'd be training against the very items you're evaluating on. """ seen: list[set[int]] = [set() for _ in range(num_users)] for arr in (train_pairs, val_pairs, test_pairs): for u, i in arr: seen[int(u)].add(int(i)) return seen def build_user_train_positives( train_pairs: np.ndarray, num_users: int ) -> list[set[int]]: """Subset of `build_user_positives` that includes only training positives. Used by the evaluator to mask items the user has already seen in training (so they don't get recommended as if they were novel). """ seen: list[set[int]] = [set() for _ in range(num_users)] for u, i in train_pairs: seen[int(u)].add(int(i)) return seen