| """Turn raw DataFrames into the tensors the model + splitter + dataset consume. |
| |
| Pipeline (all pure functions, orchestrated by `preprocess`): |
| 1. Keep only positives (rating >= threshold). |
| 2. Iteratively drop cold users / items until both thresholds are satisfied |
| (dropping items can orphan users and vice versa — one pass isn't enough). |
| 3. Build the Vocab from the *filtered* interaction set. |
| 4. Encode interactions to (user_idx, item_idx, timestamp) int arrays. |
| 5. Encode item side features: multi-hot genres + normalized year. |
| 6. Encode user side features: gender (binary), age-bucket index, occupation |
| index. We DON'T min-max these — the bucketing is meaningful. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import re |
| from dataclasses import dataclass |
| from typing import Final |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| from ..config import DataConfig |
| from ..logging_utils import get_logger |
| from .loader import RawFrames |
| from .vocab import Vocab |
|
|
| _logger = get_logger(__name__) |
|
|
| |
# Categorical age codes as they appear in the users frame (presumably the
# lowest age of each ML-1M-style bucket — TODO confirm against the raw data).
_AGE_BUCKETS: Final[tuple[int, ...]] = (1, 18, 25, 35, 45, 50, 56)
# Age code -> position, used to one-hot the age feature in _encode_user_features.
_AGE_TO_IDX: Final[dict[int, int]] = {v: i for i, v in enumerate(_AGE_BUCKETS)}
_GENDER_TO_IDX: Final[dict[str, int]] = {"M": 0, "F": 1}

# Occupation codes outside [0, 21) are silently skipped by _encode_user_features.
_NUM_OCCUPATIONS: Final[int] = 21


# Trailing "(YYYY)" release year in a movie title; trailing whitespace allowed.
_YEAR_RE: Final[re.Pattern[str]] = re.compile(r"\((\d{4})\)\s*$")
# Placeholder genre string that must never become a real genre feature column.
_NO_GENRES_SENTINEL: Final[str] = "(no genres listed)"
|
|
|
|
@dataclass(frozen=True)
class ProcessedData:
    """Everything the trainer / evaluator / recommender need from preprocessing."""

    # Id <-> index mappings built from the *filtered* interaction set.
    vocab: Vocab

    # (N, 3) int64 array of (user_idx, item_idx, timestamp) rows.
    interactions: np.ndarray

    # Dense per-user / per-item feature matrices; row i = features of index i.
    # user_features has zero columns when the dataset has no demographics.
    user_features: np.ndarray
    item_features: np.ndarray

    # Object array of raw title strings, aligned with item indices.
    item_titles: np.ndarray

    # Cached feature widths (== .shape[1] of the matrices above).
    user_feat_dim: int
    item_feat_dim: int
    # Sorted genre names, aligned with the genre columns of item_features.
    genre_vocab: tuple[str, ...]
|
|
|
|
def preprocess(raw: RawFrames, data_cfg: DataConfig) -> ProcessedData:
    """Run the full preprocessing pipeline.

    Filters the raw ratings down to warm positive interactions, builds the
    vocab from what survives, then encodes interactions and user/item side
    features as dense numpy arrays.
    """
    kept = _iterative_min_interactions_filter(
        _filter_to_positives(raw.ratings, data_cfg.positive_rating_threshold),
        min_user=data_cfg.min_user_interactions,
        min_item=data_cfg.min_item_interactions,
    )

    # The vocab is built from the *filtered* frame so every id it knows about
    # still has at least the minimum number of interactions.
    vocab = Vocab.build(
        user_ids=sorted(kept["user_id"].unique().tolist()),
        item_ids=sorted(kept["movie_id"].unique().tolist()),
    )

    interactions = _encode_interactions(kept, vocab)
    item_features, genre_vocab, item_titles = _encode_item_features(raw.movies, vocab)

    if raw.users is None:
        # Dataset variant without demographics: zero-width user feature matrix.
        user_features = np.zeros((vocab.num_users, 0), dtype=np.float32)
        _logger.info("No user demographics for this variant — user_feat_dim=0")
    else:
        user_features = _encode_user_features(raw.users, vocab)

    _logger.info(
        "Preprocess complete: %d users, %d items, %d interactions, item_feat_dim=%d, user_feat_dim=%d",
        vocab.num_users,
        vocab.num_items,
        len(interactions),
        item_features.shape[1],
        user_features.shape[1],
    )

    return ProcessedData(
        vocab=vocab,
        interactions=interactions,
        user_features=user_features,
        item_features=item_features,
        item_titles=item_titles,
        user_feat_dim=int(user_features.shape[1]),
        item_feat_dim=int(item_features.shape[1]),
        genre_vocab=genre_vocab,
    )
|
|
|
|
| |
|
|
|
|
def _filter_to_positives(ratings: pd.DataFrame, threshold: float) -> pd.DataFrame:
    """Keep only rows whose rating clears *threshold* (implicit positives)."""
    mask = ratings["rating"] >= threshold
    kept = ratings.loc[mask, ["user_id", "movie_id", "timestamp"]].reset_index(drop=True)
    _logger.info(
        "Rating>=%g filter: %d -> %d interactions", threshold, len(ratings), len(kept)
    )
    return kept
|
|
|
|
def _iterative_min_interactions_filter(
    df: pd.DataFrame, *, min_user: int, min_item: int
) -> pd.DataFrame:
    """Repeatedly prune cold users/items until both thresholds hold.

    One pass is not enough: dropping a cold item can push a user below
    ``min_user`` (and vice versa), so we iterate to a fixed point.
    """
    kept = df
    while True:
        size_before = len(kept)
        user_counts = kept.groupby("user_id").size()
        item_counts = kept.groupby("movie_id").size()
        warm_users = set(user_counts.index[user_counts >= min_user])
        warm_items = set(item_counts.index[item_counts >= min_item])
        kept = kept[
            kept["user_id"].isin(warm_users) & kept["movie_id"].isin(warm_items)
        ].reset_index(drop=True)
        if len(kept) == size_before:
            break
    _logger.info(
        "Min-interactions filter (u>=%d, i>=%d): %d -> %d interactions",
        min_user,
        min_item,
        len(df),
        len(kept),
    )
    return kept
|
|
|
|
| def _encode_interactions(df: pd.DataFrame, vocab: Vocab) -> np.ndarray: |
| u = df["user_id"].map(vocab.user_to_idx).to_numpy(dtype=np.int64) |
| i = df["movie_id"].map(vocab.item_to_idx).to_numpy(dtype=np.int64) |
| t = df["timestamp"].to_numpy(dtype=np.int64) |
| return np.stack([u, i, t], axis=1) |
|
|
|
|
def _encode_item_features(
    movies: pd.DataFrame, vocab: Vocab
) -> tuple[np.ndarray, tuple[str, ...], np.ndarray]:
    """Multi-hot genres + min-max normalized release year.

    Year is normalized to [0, 1] based on the observed range; unparseable
    titles get year=0 (and we log a warning count rather than crashing).

    Returns:
        feats: (vocab.num_items, len(genre_vocab) + 1) float32 matrix; the
            last column holds the normalized year.
        genre_vocab: sorted genre names, aligned with the genre columns.
        titles: object array of title strings indexed by item index.
            NOTE(review): items in the vocab but absent from `movies` keep
            the `np.empty` default (None) — presumably callers tolerate that.
    """
    # Work on a copy so the caller's frame is not mutated by the new column.
    movies = movies.copy()
    movies["year_raw"] = movies["title"].map(_parse_year)  # NaN when unparseable
    missing = int(movies["year_raw"].isna().sum())
    if missing > 0:
        _logger.warning("Could not parse year from %d movie titles", missing)

    # Build the genre vocabulary over ALL movies in the frame (not just those
    # surviving the interaction filters); the sentinel is excluded so it never
    # becomes a feature column.
    genres_per_movie = movies["genres"].fillna("").str.split("|")
    all_genres: set[str] = set()
    for genres in genres_per_movie:
        for g in genres:
            if g and g != _NO_GENRES_SENTINEL:
                all_genres.add(g)
    genre_vocab = tuple(sorted(all_genres))
    genre_to_idx = {g: i for i, g in enumerate(genre_vocab)}

    # One genre column per genre, plus a trailing column for the year.
    num_items = vocab.num_items
    item_feat_dim = len(genre_vocab) + 1
    feats = np.zeros((num_items, item_feat_dim), dtype=np.float32)
    titles = np.empty(num_items, dtype=object)

    # Year normalization range is computed only over movies that survived the
    # filters (i.e. are in the vocab) AND have a parseable year.
    valid_years = [
        y for mid, y in zip(movies["movie_id"], movies["year_raw"])
        if mid in vocab.item_to_idx and pd.notna(y)
    ]
    if valid_years:
        y_min, y_max = int(min(valid_years)), int(max(valid_years))
    else:
        y_min, y_max = 0, 1
    # Guard against division by zero when every year is identical.
    y_range = max(y_max - y_min, 1)

    for _, row in movies.iterrows():
        mid = int(row["movie_id"])
        if mid not in vocab.item_to_idx:
            # Movie was filtered out of the interaction set entirely.
            continue
        idx = vocab.item_to_idx[mid]
        titles[idx] = str(row["title"])

        # Multi-hot genres; a NaN genres cell stringifies to "nan", which is
        # not in genre_to_idx and is therefore skipped.
        for g in str(row["genres"]).split("|"):
            if g and g != _NO_GENRES_SENTINEL and g in genre_to_idx:
                feats[idx, genre_to_idx[g]] = 1.0

        # Min-max normalized year in the last column; unparseable -> 0.0.
        year = row["year_raw"]
        if pd.notna(year):
            feats[idx, -1] = float((int(year) - y_min) / y_range)
        else:
            feats[idx, -1] = 0.0

    return feats, genre_vocab, titles
|
|
|
|
def _encode_user_features(users: pd.DataFrame, vocab: Vocab) -> np.ndarray:
    """Gender one-hot (2) + age-bucket one-hot (7) + occupation one-hot (21)."""
    gender_dim = len(_GENDER_TO_IDX)
    age_dim = len(_AGE_BUCKETS)
    total_dim = gender_dim + age_dim + _NUM_OCCUPATIONS
    feats = np.zeros((vocab.num_users, total_dim), dtype=np.float32)

    # Column offsets of each one-hot segment.
    age_offset = gender_dim
    occ_offset = gender_dim + age_dim

    for _, row in users.iterrows():
        user_id = int(row["user_id"])
        row_idx = vocab.user_to_idx.get(user_id)
        if row_idx is None:
            # User was filtered out of the interaction set entirely.
            continue

        gender_idx = _GENDER_TO_IDX.get(str(row["gender"]))
        if gender_idx is not None:
            feats[row_idx, gender_idx] = 1.0

        age_idx = _AGE_TO_IDX.get(int(row["age"]))
        if age_idx is not None:
            feats[row_idx, age_offset + age_idx] = 1.0

        # Unknown codes simply leave the segment all-zero.
        occupation = int(row["occupation"])
        if 0 <= occupation < _NUM_OCCUPATIONS:
            feats[row_idx, occ_offset + occupation] = 1.0

    return feats
|
|
|
|
| def _parse_year(title: object) -> float: |
| """Return the trailing (YYYY) from a title, or NaN if missing / malformed.""" |
| if not isinstance(title, str): |
| return float("nan") |
| m = _YEAR_RE.search(title) |
| return float(m.group(1)) if m else float("nan") |
|
|