"""Turn raw DataFrames into the tensors the model + splitter + dataset consume. Pipeline (all pure functions, orchestrated by `preprocess`): 1. Keep only positives (rating >= threshold). 2. Iteratively drop cold users / items until both thresholds are satisfied (dropping items can orphan users and vice versa — one pass isn't enough). 3. Build the Vocab from the *filtered* interaction set. 4. Encode interactions to (user_idx, item_idx, timestamp) int arrays. 5. Encode item side features: multi-hot genres + normalized year. 6. Encode user side features: gender (binary), age-bucket index, occupation index. We DON'T min-max these — the bucketing is meaningful. """ from __future__ import annotations import re from dataclasses import dataclass from typing import Final import numpy as np import pandas as pd from ..config import DataConfig from ..logging_utils import get_logger from .loader import RawFrames from .vocab import Vocab _logger = get_logger(__name__) # ml-1m age buckets per the dataset README. _AGE_BUCKETS: Final[tuple[int, ...]] = (1, 18, 25, 35, 45, 50, 56) _AGE_TO_IDX: Final[dict[int, int]] = {v: i for i, v in enumerate(_AGE_BUCKETS)} _GENDER_TO_IDX: Final[dict[str, int]] = {"M": 0, "F": 1} # ml-1m occupation codes are in [0, 20]. We keep the raw id as the index. _NUM_OCCUPATIONS: Final[int] = 21 _YEAR_RE: Final[re.Pattern[str]] = re.compile(r"\((\d{4})\)\s*$") _NO_GENRES_SENTINEL: Final[str] = "(no genres listed)" @dataclass(frozen=True) class ProcessedData: """Everything the trainer / evaluator / recommender need from preprocessing.""" vocab: Vocab # Positives — each row is (user_idx, item_idx, timestamp). interactions: np.ndarray # shape [N, 3], dtype int64 # Side-feature tables, indexed by user_idx / item_idx. user_features: np.ndarray # shape [num_users, user_feat_dim], float32 item_features: np.ndarray # shape [num_items, item_feat_dim], float32 # For display / inference — title per item index. item_titles: np.ndarray # shape [num_items], dtype object (str) # Metadata for side-feature encoding (dim breakdowns). user_feat_dim: int item_feat_dim: int genre_vocab: tuple[str, ...] def preprocess(raw: RawFrames, data_cfg: DataConfig) -> ProcessedData: """Run the full preprocessing pipeline.""" positives = _filter_to_positives(raw.ratings, data_cfg.positive_rating_threshold) positives = _iterative_min_interactions_filter( positives, min_user=data_cfg.min_user_interactions, min_item=data_cfg.min_item_interactions, ) # Build vocab in a stable order — sorted by raw id to make the mapping # deterministic across runs. user_ids = sorted(positives["user_id"].unique().tolist()) item_ids = sorted(positives["movie_id"].unique().tolist()) vocab = Vocab.build(user_ids=user_ids, item_ids=item_ids) interactions = _encode_interactions(positives, vocab) item_features, genre_vocab, item_titles = _encode_item_features( raw.movies, vocab ) # Some variants (ml-25m, ml-32m, ml-latest) don't ship user demographics. # In that case the user tower runs on its ID embedding only. 
    if raw.users is not None:
        user_features = _encode_user_features(raw.users, vocab)
    else:
        user_features = np.zeros((vocab.num_users, 0), dtype=np.float32)
        _logger.info("No user demographics for this variant — user_feat_dim=0")

    _logger.info(
        "Preprocess complete: %d users, %d items, %d interactions, "
        "item_feat_dim=%d, user_feat_dim=%d",
        vocab.num_users,
        vocab.num_items,
        len(interactions),
        item_features.shape[1],
        user_features.shape[1],
    )
    return ProcessedData(
        vocab=vocab,
        interactions=interactions,
        user_features=user_features,
        item_features=item_features,
        item_titles=item_titles,
        user_feat_dim=int(user_features.shape[1]),
        item_feat_dim=int(item_features.shape[1]),
        genre_vocab=genre_vocab,
    )


# ---------- internals ----------


def _filter_to_positives(ratings: pd.DataFrame, threshold: float) -> pd.DataFrame:
    """Keep rows with rating >= threshold; project to (user, item, timestamp)."""
    out = ratings.loc[
        ratings["rating"] >= threshold, ["user_id", "movie_id", "timestamp"]
    ]
    _logger.info(
        "Rating>=%g filter: %d -> %d interactions", threshold, len(ratings), len(out)
    )
    return out.reset_index(drop=True)


def _iterative_min_interactions_filter(
    df: pd.DataFrame, *, min_user: int, min_item: int
) -> pd.DataFrame:
    """Drop cold users and cold items repeatedly until both thresholds hold."""
    prev_len = -1
    out = df
    while len(out) != prev_len:
        prev_len = len(out)
        u_counts = out.groupby("user_id").size()
        i_counts = out.groupby("movie_id").size()
        keep_users = set(u_counts[u_counts >= min_user].index)
        keep_items = set(i_counts[i_counts >= min_item].index)
        out = out[out["user_id"].isin(keep_users) & out["movie_id"].isin(keep_items)]
    out = out.reset_index(drop=True)
    _logger.info(
        "Min-interactions filter (u>=%d, i>=%d): %d -> %d interactions",
        min_user,
        min_item,
        len(df),
        len(out),
    )
    return out


def _encode_interactions(df: pd.DataFrame, vocab: Vocab) -> np.ndarray:
    """Map raw ids to contiguous indices and stack into an [N, 3] int64 array."""
    u = df["user_id"].map(vocab.user_to_idx).to_numpy(dtype=np.int64)
    i = df["movie_id"].map(vocab.item_to_idx).to_numpy(dtype=np.int64)
    t = df["timestamp"].to_numpy(dtype=np.int64)
    return np.stack([u, i, t], axis=1)


def _encode_item_features(
    movies: pd.DataFrame, vocab: Vocab
) -> tuple[np.ndarray, tuple[str, ...], np.ndarray]:
    """Multi-hot genres + min-max normalized release year.

    Year is normalized to [0, 1] based on the observed range; unparseable
    titles get year=0 (and we log a warning count rather than crashing).
    """
    movies = movies.copy()
    movies["year_raw"] = movies["title"].map(_parse_year)
    missing = int(movies["year_raw"].isna().sum())
    if missing > 0:
        _logger.warning("Could not parse year from %d movie titles", missing)

    # Build the genre vocabulary deterministically from the dataset.
    genres_per_movie = movies["genres"].fillna("").str.split("|")
    all_genres: set[str] = set()
    for genres in genres_per_movie:
        for g in genres:
            if g and g != _NO_GENRES_SENTINEL:
                all_genres.add(g)
    genre_vocab = tuple(sorted(all_genres))
    genre_to_idx = {g: i for i, g in enumerate(genre_vocab)}

    num_items = vocab.num_items
    item_feat_dim = len(genre_vocab) + 1  # +1 for year
    feats = np.zeros((num_items, item_feat_dim), dtype=np.float32)
    titles = np.empty(num_items, dtype=object)

    # Compute year normalization over the items we actually keep (those in vocab).
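    # Worked example of the min-max below (illustrative years, not from the
    # data): with observed years 1919..2000, year' = (year - 1919) / 81, so
    # 1919 -> 0.0, 1960 -> ~0.506, 2000 -> 1.0. A missing year stays at the
    # 0.0 sentinel assigned in the loop.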
    valid_years = [
        y
        for mid, y in zip(movies["movie_id"], movies["year_raw"])
        if mid in vocab.item_to_idx and pd.notna(y)
    ]
    if valid_years:
        y_min, y_max = int(min(valid_years)), int(max(valid_years))
    else:
        y_min, y_max = 0, 1  # degenerate; avoid divide-by-zero
    y_range = max(y_max - y_min, 1)

    for _, row in movies.iterrows():
        mid = int(row["movie_id"])
        if mid not in vocab.item_to_idx:
            continue
        idx = vocab.item_to_idx[mid]
        titles[idx] = str(row["title"])
        # Genres -> multi-hot.
        for g in str(row["genres"]).split("|"):
            if g and g != _NO_GENRES_SENTINEL and g in genre_to_idx:
                feats[idx, genre_to_idx[g]] = 1.0
        # Year -> normalized scalar. Missing -> 0 (a documented sentinel).
        year = row["year_raw"]
        if pd.notna(year):
            feats[idx, -1] = float((int(year) - y_min) / y_range)
        else:
            feats[idx, -1] = 0.0

    return feats, genre_vocab, titles


def _encode_user_features(users: pd.DataFrame, vocab: Vocab) -> np.ndarray:
    """Gender one-hot (2) + age-bucket one-hot (7) + occupation one-hot (21)."""
    num_users = vocab.num_users
    gender_dim = len(_GENDER_TO_IDX)
    age_dim = len(_AGE_BUCKETS)
    occ_dim = _NUM_OCCUPATIONS
    dim = gender_dim + age_dim + occ_dim
    feats = np.zeros((num_users, dim), dtype=np.float32)

    for _, row in users.iterrows():
        uid = int(row["user_id"])
        if uid not in vocab.user_to_idx:
            continue
        idx = vocab.user_to_idx[uid]
        g = str(row["gender"])
        if g in _GENDER_TO_IDX:
            feats[idx, _GENDER_TO_IDX[g]] = 1.0
        age = int(row["age"])
        if age in _AGE_TO_IDX:
            feats[idx, gender_dim + _AGE_TO_IDX[age]] = 1.0
        occ = int(row["occupation"])
        if 0 <= occ < occ_dim:
            feats[idx, gender_dim + age_dim + occ] = 1.0

    return feats


def _parse_year(title: object) -> float:
    """Return the trailing (YYYY) from a title, or NaN if missing / malformed."""
    if not isinstance(title, str):
        return float("nan")
    m = _YEAR_RE.search(title)
    return float(m.group(1)) if m else float("nan")
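

# A minimal smoke test of the pipeline on synthetic frames — a sketch under
# stated assumptions, not part of the library's surface. `RawFrames` is assumed
# to be a keyword-constructible container of the three frames used above, and
# `DataConfig` to accept the three fields this module reads; both are defined
# elsewhere in this repo and may differ. Run via `python -m <package>...` so
# the relative imports resolve.
if __name__ == "__main__":
    _ratings = pd.DataFrame(
        {
            "user_id": [1, 1, 2, 2, 3],
            "movie_id": [10, 20, 10, 20, 10],
            "rating": [5, 4, 4, 5, 2],
            "timestamp": [100, 200, 300, 400, 500],
        }
    )
    _movies = pd.DataFrame(
        {
            "movie_id": [10, 20],
            "title": ["Toy Story (1995)", "Heat (1995)"],
            "genres": ["Animation|Children's|Comedy", "Action|Crime|Thriller"],
        }
    )
    _users = pd.DataFrame(
        {
            "user_id": [1, 2, 3],
            "gender": ["F", "M", "M"],
            "age": [25, 35, 18],
            "occupation": [4, 7, 0],
        }
    )
    _raw = RawFrames(ratings=_ratings, movies=_movies, users=_users)  # assumed ctor
    _cfg = DataConfig(  # assumed ctor; only these three fields are read here
        positive_rating_threshold=4.0,
        min_user_interactions=2,
        min_item_interactions=2,
    )
    _processed = preprocess(_raw, _cfg)
    # Rating >= 4 drops user 3's single interaction; the remaining 2x2 block
    # survives the min-interactions filter, so interactions has shape (4, 3).
    print(_processed.interactions.shape)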