File size: 4,369 Bytes
188f0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Parse MovieLens dataset files into raw DataFrames. Variant-aware.

ml-1m: legacy `.dat` files with `::` separator, latin-1 encoding; ships user demographics.
ml-25m, ml-32m, ml-latest, ml-latest-small: `.csv` files with a header row, comma
separator, UTF-8 encoding; no user demographics.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Final

import pandas as pd

from ..logging_utils import get_logger

# Logger obtained from the project's logging helper, keyed by this module's name.
_logger = get_logger(__name__)

# ---------- ml-1m schema (legacy `.dat`) ----------
# Separator and text encoding handed to pandas when reading the `.dat` files.
_ML1M_SEP: Final[str] = "::"
_ML1M_ENCODING: Final[str] = "latin-1"
# Column names for each table. The `.dat` files are read without a header row,
# so these tuples supply the names positionally.
_ML1M_RATINGS_COLS: Final[tuple[str, ...]] = ("user_id", "movie_id", "rating", "timestamp")
_ML1M_USERS_COLS: Final[tuple[str, ...]] = ("user_id", "gender", "age", "occupation", "zip")
_ML1M_MOVIES_COLS: Final[tuple[str, ...]] = ("movie_id", "title", "genres")

# ---------- csv schema (ml-25m / ml-32m / ml-latest / ml-latest-small) ----------
# Rename maps from the csv header names (`userId`, `movieId`) to the snake_case
# names used by the ml-1m column tuples, so both loaders expose the same columns.
_CSV_RATINGS_RENAME: Final[dict[str, str]] = {"userId": "user_id", "movieId": "movie_id"}
_CSV_MOVIES_RENAME: Final[dict[str, str]] = {"movieId": "movie_id"}


@dataclass(frozen=True)
class RawFrames:
    """Raw MovieLens tables as parsed from disk.

    The loaders in this module only normalise column names and pin dtypes; no
    other transformation is applied. The dataclass is frozen, so its fields
    cannot be rebound after construction (the DataFrames themselves are still
    ordinary mutable pandas objects).

    `users` is None on variants that don't ship user demographics (ml-25m, ml-32m).
    """

    # Rating events; the loaders produce columns user_id, movie_id, rating, timestamp.
    ratings: pd.DataFrame
    # Movie catalogue; the loaders produce columns movie_id, title, genres.
    movies: pd.DataFrame
    # User demographics (user_id, gender, age, occupation, zip) for ml-1m, else None.
    users: pd.DataFrame | None = None


def load_raw(dataset_dir: Path | str, variant: str) -> RawFrames:
    """Load the raw tables of one MovieLens release from `dataset_dir`.

    `variant` picks the parser: "ml-1m" uses the legacy `.dat` loader, while
    "ml-25m", "ml-32m", "ml-latest" and "ml-latest-small" use the csv loader.
    Any other value raises ValueError.
    """
    root = Path(dataset_dir)
    csv_variants = frozenset(("ml-25m", "ml-32m", "ml-latest", "ml-latest-small"))

    # Choose the loader first, then call it exactly once at the end.
    if variant == "ml-1m":
        loader = _load_ml1m
    elif variant in csv_variants:
        loader = _load_csv_variant
    else:
        raise ValueError(f"unsupported dataset variant: {variant!r}")
    return loader(root)


# ---------- ml-1m loader ----------

def _load_ml1m(dataset_dir: Path) -> RawFrames:
    """Read the three ml-1m `.dat` tables from `dataset_dir` and pin their dtypes.

    Returns a `RawFrames` with `ratings`, `users` and `movies` all populated.
    """
    # Read every file before casting, so a missing file is reported first.
    raw_ratings = _read_dat(dataset_dir / "ratings.dat", _ML1M_RATINGS_COLS)
    raw_users = _read_dat(dataset_dir / "users.dat", _ML1M_USERS_COLS)
    raw_movies = _read_dat(dataset_dir / "movies.dat", _ML1M_MOVIES_COLS)

    # Every ml-1m ratings column is cast to int64.
    rating_dtypes = dict.fromkeys(_ML1M_RATINGS_COLS, "int64")
    # `gender` is deliberately absent: it keeps whatever dtype pandas inferred.
    user_dtypes = {
        "user_id": "int64",
        "age": "int64",
        "occupation": "int64",
        "zip": "string",
    }
    movie_dtypes = {"movie_id": "int64", "title": "string", "genres": "string"}

    frames = RawFrames(
        ratings=raw_ratings.astype(rating_dtypes),
        users=raw_users.astype(user_dtypes),
        movies=raw_movies.astype(movie_dtypes),
    )

    _logger.info(
        "Loaded ml-1m: %d ratings, %d users, %d movies",
        len(frames.ratings), len(frames.users), len(frames.movies),
    )
    return frames


def _read_dat(path: Path, columns: tuple[str, ...]) -> pd.DataFrame:
    """Parse one headerless ml-1m `.dat` file, naming its columns from `columns`.

    Raises FileNotFoundError if `path` is not an existing regular file.
    """
    if path.is_file():
        read_options = {
            "sep": _ML1M_SEP,
            "names": list(columns),
            # A multi-character separator needs pandas' python parsing engine.
            "engine": "python",
            "encoding": _ML1M_ENCODING,
            "header": None,
        }
        return pd.read_csv(path, **read_options)
    raise FileNotFoundError(f"expected dataset file missing: {path}")


# ---------- ml-25m / ml-32m / ml-latest loader ----------

def _load_csv_variant(dataset_dir: Path) -> RawFrames:
    """Read `ratings.csv` and `movies.csv` from `dataset_dir` for a csv release.

    ml-25m, ml-32m and ml-latest all share the same CSV schema. The returned
    `RawFrames` has `users=None`; raises FileNotFoundError if either file is
    missing.
    """
    ratings_path = dataset_dir / "ratings.csv"
    movies_path = dataset_dir / "movies.csv"
    # Check both files before parsing; ratings is checked (and reported) first.
    for required in (ratings_path, movies_path):
        if not required.is_file():
            raise FileNotFoundError(f"expected dataset file missing: {required}")

    # Note: ratings can be half-stars (0.5–5.0) in modern variants, so `rating`
    # is a float column here, not an integer one.
    rating_dtypes = {
        "user_id": "int64",
        "movie_id": "int64",
        "rating": "float32",
        "timestamp": "int64",
    }
    movie_dtypes = {"movie_id": "int64", "title": "string", "genres": "string"}

    # Modern csv variants are UTF-8 with header rows. Titles can contain commas
    # (escaped via double-quotes) — the default csv parser handles that.
    # Read and rename stay chained so the pre-rename frame is not kept alive.
    ratings = pd.read_csv(ratings_path).rename(columns=_CSV_RATINGS_RENAME)
    movies = pd.read_csv(movies_path).rename(columns=_CSV_MOVIES_RENAME)

    ratings = ratings.astype(rating_dtypes)
    movies = movies.astype(movie_dtypes)

    n_ratings, n_movies = len(ratings), len(movies)
    _logger.info(
        "Loaded csv variant: %d ratings, %d movies (no user demographics)",
        n_ratings, n_movies,
    )
    return RawFrames(ratings=ratings, movies=movies, users=None)