File size: 4,419 Bytes
188f0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""Chronological leave-one-out split.

For each user, sort interactions by timestamp (ties broken by item index for
determinism). The newest becomes the test positive, second-newest the val
positive, the rest go to train. The `min_user_interactions` filter in
preprocessing guarantees every user has at least 3 interactions.

Random LOO leaks future info and is the #1 methodology error in rec-sys
tutorials — see Rendle et al. 2020, "Neural Collaborative Filtering vs.
Matrix Factorization Revisited."
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from ..logging_utils import get_logger

_logger = get_logger(__name__)


@dataclass(frozen=True)
class Split:
    """Frozen container holding the three positive sets of one split.

    Every field is a two-column array whose rows are (user_idx, item_idx).
    Fields cannot be reassigned after construction.

    * `train_pairs` -- shape [N, 2], the training positives.
    * `val_pairs` / `test_pairs` -- shape [M, 2], one held-out row per user
      that has both a val and a test interaction (expected to be every user,
      given the min-interactions filter).
    """

    # Rows of (user_idx, item_idx) used for fitting.
    train_pairs: np.ndarray
    # One held-out (user_idx, item_idx) row per user, for validation.
    val_pairs: np.ndarray
    # One held-out (user_idx, item_idx) row per user, for testing.
    test_pairs: np.ndarray


def leave_one_out_split(interactions: np.ndarray) -> Split:
    """Split the [N, 3] (user_idx, item_idx, timestamp) array chronologically.

    Rows are ordered by (user_idx, timestamp, item_idx), all ascending. Within
    each user's run the last row becomes the test positive, the row before it
    the val positive, and every earlier row goes to train. A user with fewer
    than 3 rows is reported with a warning and left out of every split.

    Args:
        interactions: 2-D array with columns (user_idx, item_idx, timestamp).
            An empty [0, 3] array is accepted and yields three empty splits.

    Returns:
        A `Split` whose three arrays each have two columns, (user_idx,
        item_idx), cast to int64.

    Raises:
        ValueError: if `interactions` is not 2-D with exactly 3 columns.
    """
    if interactions.ndim != 2 or interactions.shape[1] != 3:
        raise ValueError(
            f"expected interactions of shape [N, 3], got {interactions.shape}"
        )

    # Sort by (user_idx asc, timestamp asc, item_idx asc).
    # lexsort treats its LAST key as the primary one, hence the reversed order.
    order = np.lexsort(
        (interactions[:, 1], interactions[:, 2], interactions[:, 0])
    )
    sorted_ia = interactions[order]

    users = sorted_ia[:, 0]
    num_rows = users.shape[0]

    # boundaries[k]..boundaries[k+1] is the half-open row range of the k-th
    # user run. Runs are detected by comparing neighbouring ids rather than by
    # arithmetic on the ids, so the result does not depend on the id dtype and
    # an empty input simply produces no runs instead of indexing users[0].
    if num_rows:
        run_starts = np.flatnonzero(users[1:] != users[:-1]) + 1
        boundaries = np.concatenate(([0], run_starts, [num_rows]))
    else:
        boundaries = np.zeros(1, dtype=np.int64)

    train_idx: list[np.ndarray] = []
    val_idx: list[int] = []
    test_idx: list[int] = []

    for start, end in zip(boundaries[:-1], boundaries[1:]):
        run_len = int(end - start)
        if run_len < 3:
            # Shouldn't happen given the preprocessing filter, but be defensive.
            _logger.warning(
                "Skipping user_idx=%d with only %d interactions",
                int(users[start]),
                run_len,
            )
            continue
        # Newest -> test, second-newest -> val, rest -> train.
        test_idx.append(int(end) - 1)
        val_idx.append(int(end) - 2)
        train_idx.append(np.arange(start, end - 2))

    train_rows = np.concatenate(train_idx) if train_idx else np.empty(0, dtype=np.int64)
    val_rows = np.asarray(val_idx, dtype=np.int64)
    test_rows = np.asarray(test_idx, dtype=np.int64)

    # Drop the timestamp column; only (user_idx, item_idx) is kept.
    train_pairs = sorted_ia[train_rows, :2].astype(np.int64)
    val_pairs = sorted_ia[val_rows, :2].astype(np.int64)
    test_pairs = sorted_ia[test_rows, :2].astype(np.int64)

    _logger.info(
        "LOO split: train=%d, val=%d, test=%d pairs",
        len(train_pairs),
        len(val_pairs),
        len(test_pairs),
    )

    return Split(train_pairs=train_pairs, val_pairs=val_pairs, test_pairs=test_pairs)


def build_user_positives(
    train_pairs: np.ndarray,
    val_pairs: np.ndarray,
    test_pairs: np.ndarray,
    num_users: int,
) -> list[set[int]]:
    """Collect, per user, every item index that appears in ANY split.

    The result is a list of length `num_users`; entry `u` is the set of item
    indices paired with user `u` in train, val, or test.

    The negative sampler relies on this: a drawn "negative" has to avoid all
    of a user's positives, held-out val/test ones included, or training would
    push down the exact items evaluation later looks for.
    """
    positives_by_user: list[set[int]] = [set() for _ in range(num_users)]
    for split_pairs in (train_pairs, val_pairs, test_pairs):
        for pair in split_pairs:
            user, item = pair
            # Cast to plain Python ints so the sets hold builtin ints.
            positives_by_user[int(user)].add(int(item))
    return positives_by_user


def build_user_train_positives(
    train_pairs: np.ndarray, num_users: int
) -> list[set[int]]:
    """Collect, per user, the item indices seen in training only.

    Same layout as `build_user_positives` (a list of length `num_users` whose
    entry `u` is a set of item indices) but fed from `train_pairs` alone.

    The evaluator uses it as a mask so that items a user already interacted
    with during training are not recommended as though they were new.
    """
    train_items_by_user: list[set[int]] = [set() for _ in range(num_users)]
    for pair in train_pairs:
        user, item = pair
        # Cast to plain Python ints so the sets hold builtin ints.
        train_items_by_user[int(user)].add(int(item))
    return train_items_by_user