Spaces:
Running
Running
"""
Download and sample Amazon Reviews 2023 — Books, Movies_and_TV, Kindle_Store.

Source: McAuley-Lab/Amazon-Reviews-2023 on HuggingFace.
The dataset is stored as single, large JSONL files (one per category) under:
    raw/review_categories/<Category>.jsonl
    raw/meta_categories/meta_<Category>.jsonl

We stream these files over HTTP, line by line, and cache each phase to disk
so a network hiccup doesn't force us to re-download everything. After a
crash, re-run the script — completed phases are skipped automatically.

Disk cache layout (under data/raw/):
    review_<Category>.jsonl  → raw streamed reviews (per category)
    meta_<Category>.jsonl    → filtered metadata (per category)

Output: data/processed/{reviews,items,users}.parquet
"""
from __future__ import annotations

import argparse
import http.client
import json
import logging
import os
import socket
import time
import urllib.error
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
# Console logging for the whole script: timestamp, level, message.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Dataset categories to download and the base URL their JSONL files live under.
CATEGORIES = ["Books", "Movies_and_TV", "Kindle_Store"]
BASE_URL = "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/resolve/main"

# On-disk layout; the root can be overridden with the DATA_DIR env variable.
DATA_DIR = Path(os.getenv("DATA_DIR", "./data"))
RAW_DIR = DATA_DIR.joinpath("raw")
PROCESSED_DIR = DATA_DIR.joinpath("processed")

# Fields copied from each raw review / metadata record; everything else is dropped.
REVIEW_KEEP_KEYS = (
    "user_id",
    "parent_asin",
    "rating",
    "title",
    "text",
    "helpful_vote",
    "verified_purchase",
    "timestamp",
)
META_KEEP_KEYS = (
    "parent_asin",
    "title",
    "description",
    "features",
    "categories",
    "average_rating",
    "rating_number",
    "price",
)

# Sampling defaults (main() exposes them as command-line options).
DEFAULT_USERS_PER_CATEGORY = 3000
DEFAULT_MIN_REVIEWS = 5
DEFAULT_MAX_REVIEWS = 50
DEFAULT_TEST_HOLDOUT = 2

# Network tolerance
NETWORK_TIMEOUT = 300  # seconds (5 minutes) passed to urlopen as the timeout
RETRY_ATTEMPTS = 4
RETRY_BACKOFF_BASE = 5  # seconds; doubles each retry
# ──────────────────────────────────────────────────────────────────────────────
# Network primitives with retry
# ──────────────────────────────────────────────────────────────────────────────
def _open_url(url: str, byte_offset: int = 0):
    """Issue a GET for *url* and return the open response object.

    The request carries this script's User-Agent and uses NETWORK_TIMEOUT as
    the urlopen timeout. A positive *byte_offset* adds an open-ended Range
    header so the download can start from that byte.
    """
    request = urllib.request.Request(url)
    request.add_header("User-Agent", "naijataste-ai/1.0")
    if byte_offset > 0:
        request.add_header("Range", f"bytes={byte_offset}-")
    return urllib.request.urlopen(request, timeout=NETWORK_TIMEOUT)
def _is_transient(exc: BaseException) -> bool:
    """Return True when *exc* is a network failure that is worth retrying.

    Retried: timeouts, URL/DNS/connection errors, connections dropped while the
    body was being read, and HTTP statuses that signal a temporary condition
    (408, 425, 429 and every 5xx). Any other HTTP status (404, 403, ...) is
    permanent: retrying it would only burn the backoff budget.
    """
    # HTTPError is a subclass of URLError, so it must be examined first;
    # otherwise every HTTP status would be treated as transient.
    if isinstance(exc, urllib.error.HTTPError):
        return exc.code in (408, 425, 429) or exc.code >= 500
    if isinstance(exc, (socket.timeout, TimeoutError)):
        return True
    if isinstance(exc, urllib.error.URLError):
        # Non-HTTP URLErrors wrap transient issues (DNS, conn reset)
        return True
    if isinstance(exc, http.client.IncompleteRead):
        # The peer closed the connection before the announced body was complete.
        return True
    return isinstance(exc, (ConnectionResetError, ConnectionAbortedError,
                            ConnectionRefusedError))
# ──────────────────────────────────────────────────────────────────────────────
# Streaming with disk cache
# ──────────────────────────────────────────────────────────────────────────────
def _count_cached_rows(cache_path: Path) -> int:
    """Count the lines stored in *cache_path*, closing the file afterwards."""
    with cache_path.open("r", encoding="utf-8") as fin:
        return sum(1 for _ in fin)


def stream_to_cache(url: str, cache_path: Path, max_rows: int,
                    progress_every: int = 25_000) -> int:
    """Stream a JSONL URL to a local cache file. Returns rows written.

    If cache_path already exists and contains >= max_rows lines, this is a
    no-op. Otherwise rows are written line by line; on a transient network
    failure the download is retried with exponential backoff and resumes
    after the rows that are already on disk.

    Args:
        url: Location of the JSONL file.
        cache_path: Local file that receives one row per line.
        max_rows: Stop once the cache holds this many rows.
        progress_every: Log a progress line every this many cached rows.

    Returns:
        Number of rows in the cache when the function returns.

    Raises:
        Exception: the last error, when it is not transient or the retry
            budget (RETRY_ATTEMPTS) is exhausted.
    """
    if cache_path.exists():
        existing = _count_cached_rows(cache_path)
        if existing >= max_rows:
            log.info(f" cache hit: {cache_path.name} has {existing:,} rows ≥ target {max_rows:,}; skipping download")
            return existing
        log.info(f" partial cache: {cache_path.name} has {existing:,} rows; resuming")
        rows_so_far = existing
        mode = "a"
    else:
        rows_so_far = 0
        mode = "w"
        cache_path.parent.mkdir(parents=True, exist_ok=True)

    for attempt in range(1, RETRY_ATTEMPTS + 1):
        # Freeze the number of rows to skip for this attempt. Comparing against
        # rows_so_far (which grows with every written row) would skip every
        # second line of the source.
        already_cached = rows_so_far
        try:
            with _open_url(url) as resp, cache_path.open(mode, encoding="utf-8") as fout:
                # We always re-stream from the start and skip the rows we
                # already have, rather than relying on byte ranges.
                skipped = 0
                for raw in resp:
                    if not raw or raw.isspace():
                        continue
                    if not raw.endswith(b"\n"):
                        # Only the last line of the stream may lack a newline.
                        # If it is not a complete JSON document the connection
                        # was cut mid-row: retry instead of caching garbage.
                        try:
                            json.loads(raw)
                        except ValueError:
                            raise ConnectionResetError(
                                "response ended in the middle of a row") from None
                    if skipped < already_cached:
                        skipped += 1
                        continue
                    text = raw.decode("utf-8", errors="replace")
                    if not text.endswith("\n"):
                        text += "\n"
                    fout.write(text)
                    rows_so_far += 1
                    if rows_so_far % progress_every == 0:
                        log.info(f" cached {rows_so_far:,} rows…")
                    if rows_so_far >= max_rows:
                        break
            log.info(f" ✓ cached {rows_so_far:,} rows to {cache_path.name}")
            return rows_so_far
        except Exception as e:
            if not _is_transient(e) or attempt == RETRY_ATTEMPTS:
                raise
            backoff = RETRY_BACKOFF_BASE * (2 ** (attempt - 1))
            log.warning(f" network error ({type(e).__name__}: {e}); retry {attempt}/{RETRY_ATTEMPTS - 1} in {backoff}s")
            time.sleep(backoff)
            # Recount how much we have on disk before next attempt
            if cache_path.exists():
                rows_so_far = _count_cached_rows(cache_path)
                mode = "a"
            else:
                rows_so_far = 0
                mode = "w"
    raise RuntimeError("unreachable")
def stream_filter_to_cache(url: str, cache_path: Path, target_asins: set[str],
                           max_scan: int, progress_every: int = 100_000) -> int:
    """Stream metadata, keep only rows whose parent_asin is in target, cache.

    An existing cache file is used as-is, whatever target_asins contains.
    Otherwise the URL is scanned from the top, the first row seen for each
    wanted parent_asin is kept, and scanning stops once every target was found
    or max_scan rows were examined. Transient network errors restart the scan
    with exponential backoff.

    Rows are written to a sibling "<name>.part" file that is renamed onto
    cache_path only after a scan completes, so an interrupted run can never
    leave a partial file that a later run would mistake for a finished cache.

    Args:
        url: Location of the metadata JSONL file.
        cache_path: Local file that receives the matching rows.
        target_asins: parent_asin values to keep.
        max_scan: Maximum number of non-blank source rows to examine.
        progress_every: Log a progress line every this many scanned rows.

    Returns:
        Number of rows in the cache.

    Raises:
        Exception: the last error, when it is not transient or the retry
            budget (RETRY_ATTEMPTS) is exhausted.
    """
    if cache_path.exists():
        with cache_path.open("r", encoding="utf-8") as fin:
            kept_existing = sum(1 for _ in fin)
        log.info(f" cache hit: {cache_path.name} has {kept_existing:,} rows; using as-is")
        return kept_existing

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    if not target_asins:
        # Nothing can match: record an empty cache without touching the network.
        cache_path.touch()
        log.info(f" no target items; wrote empty {cache_path.name}")
        return 0

    part_path = cache_path.with_name(cache_path.name + ".part")
    for attempt in range(1, RETRY_ATTEMPTS + 1):
        try:
            # Every attempt restarts the scan from the top with a fresh file.
            kept = 0
            scanned = 0
            found_asins: set[str] = set()
            with _open_url(url) as resp, part_path.open("w", encoding="utf-8") as fout:
                for raw in resp:
                    if not raw or raw.isspace():
                        continue
                    scanned += 1
                    try:
                        row = json.loads(raw)
                    except ValueError:
                        # Malformed JSON or undecodable bytes: skip the row.
                        row = None
                    asin = row.get("parent_asin") if isinstance(row, dict) else None
                    if asin in target_asins and asin not in found_asins:
                        text = raw.decode("utf-8", errors="replace") \
                            if isinstance(raw, bytes) else raw
                        if not text.endswith("\n"):
                            text += "\n"
                        fout.write(text)
                        kept += 1
                        found_asins.add(asin)
                        if kept >= len(target_asins):
                            break
                    if scanned % progress_every == 0:
                        log.info(f" scanned {scanned:,}, kept {kept:,}")
                    if scanned >= max_scan:
                        break
            # Publish only after the file is complete and closed.
            part_path.replace(cache_path)
            log.info(f" ✓ scanned {scanned:,}, cached {kept:,} matching rows to {cache_path.name}")
            return kept
        except Exception as e:
            if not _is_transient(e) or attempt == RETRY_ATTEMPTS:
                raise
            backoff = RETRY_BACKOFF_BASE * (2 ** (attempt - 1))
            log.warning(f" network error ({type(e).__name__}: {e}); retry {attempt}/{RETRY_ATTEMPTS - 1} in {backoff}s")
            time.sleep(backoff)
    raise RuntimeError("unreachable")
# ──────────────────────────────────────────────────────────────────────────────
# Cache → DataFrame loaders
# ──────────────────────────────────────────────────────────────────────────────
def load_reviews_from_cache(cache_path: Path, category: str) -> pd.DataFrame:
    """Load cached review rows into a DataFrame tagged with their category.

    Only the REVIEW_KEEP_KEYS fields are retained (absent fields become None)
    and a "domain" column holding *category* is appended. Lines that are not
    valid JSON objects are skipped.

    The returned frame always has the REVIEW_KEEP_KEYS columns plus "domain",
    even when the cache holds no usable row, so later column-based filtering
    does not fail on an empty category.
    """
    rows = []
    with cache_path.open("r", encoding="utf-8") as f:
        for raw in f:
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(record, dict):
                # Valid JSON but not an object (e.g. a bare number): not a review.
                continue
            rows.append({key: record.get(key) for key in REVIEW_KEEP_KEYS})
    # Passing the columns explicitly keeps the schema stable for empty input.
    df = pd.DataFrame(rows, columns=list(REVIEW_KEEP_KEYS))
    df["domain"] = category
    return df
def load_meta_from_cache(cache_path: Path, category: str) -> pd.DataFrame:
    """Load cached item-metadata rows into a DataFrame tagged with their category.

    Only the META_KEEP_KEYS fields are retained. List-valued fields are
    flattened into one space-separated string (None entries dropped), and a
    "domain" column holding *category* is added. Lines that fail to parse as
    JSON are skipped. In a non-empty result, "description" and "features" are
    cast to str and cut to their first 2000 characters.
    """
    def flatten(value):
        # Lists become a single string; every other value passes through.
        if isinstance(value, list):
            return " ".join(str(part) for part in value if part is not None)
        return value

    records = []
    with cache_path.open("r", encoding="utf-8") as handle:
        for line in handle:
            try:
                parsed = json.loads(line)
            except json.JSONDecodeError:
                continue
            record = {key: flatten(parsed.get(key)) for key in META_KEEP_KEYS}
            record["domain"] = category
            records.append(record)

    frame = pd.DataFrame(records)
    if frame.empty:
        return frame
    for column in ("description", "features"):
        if column in frame.columns:
            frame[column] = frame[column].astype(str).str.slice(0, 2000)
    return frame
# ──────────────────────────────────────────────────────────────────────────────
# Sampling, splits, normalization (unchanged from v3)
# ──────────────────────────────────────────────────────────────────────────────
def sample_users(reviews: pd.DataFrame, min_reviews: int, max_reviews: int,
                 target_users: int) -> pd.DataFrame:
    """Pick a reproducible sample of users by review count and domain spread.

    Users with between min_reviews and max_reviews reviews (inclusive) are
    eligible. Up to a third of target_users is drawn from users active in two
    or more domains; the remainder is drawn from single-domain users. The
    draw is seeded, so repeated runs on the same input select the same users.

    Returns:
        One row per sampled user with columns user_id, n_reviews, n_domains
        (cross-domain users first).
    """
    per_user = (
        reviews.groupby("user_id")
        .agg(n_reviews=("rating", "size"), n_domains=("domain", "nunique"))
        .reset_index()
    )
    in_range = per_user["n_reviews"].between(min_reviews, max_reviews)
    eligible = per_user.loc[in_range]
    log.info(f"{len(eligible):,} users in [{min_reviews},{max_reviews}] reviews")

    cross = eligible.loc[eligible["n_domains"] >= 2]
    single = eligible.loc[eligible["n_domains"] == 1]
    n_cross = min(len(cross), target_users // 3)
    n_single = min(len(single), target_users - n_cross)
    log.info(f"Sampling {n_cross:,} cross-domain + {n_single:,} single-domain users")

    rng = np.random.default_rng(42)

    def draw(pool: pd.DataFrame, how_many: int) -> pd.DataFrame:
        # A seed is consumed from rng only when something is actually drawn.
        if not how_many:
            return pool.head(0)
        return pool.sample(n=how_many, random_state=rng.integers(1e9))

    picked = [draw(cross, n_cross), draw(single, n_single)]
    return pd.concat(picked, ignore_index=True)
def build_train_test_splits(reviews: pd.DataFrame, holdout: int) -> pd.DataFrame:
    """Label each review "train" or "test" using a per-user temporal holdout.

    Reviews are ordered by user and then by timestamp; the `holdout` latest
    reviews of every user are marked "test" and all earlier ones "train".

    Returns:
        A sorted copy of *reviews* with an added "split" column.
    """
    ordered = reviews.sort_values(by=["user_id", "timestamp"])
    # 0 for a user's latest review, 1 for the one before it, and so on.
    ordered["rank_within_user"] = ordered.groupby("user_id").cumcount(ascending=False)
    labels = np.where(ordered["rank_within_user"] < holdout, "test", "train")
    return ordered.assign(split=labels).drop(columns=["rank_within_user"])
def normalize_items_for_parquet(items: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *items* with parquet-friendly column types.

    price, average_rating and rating_number lose a leading "$" and are parsed
    as numbers (unparseable values become NaN). The identifier and text
    columns are cast to str, with the literal strings "None" and "nan" turned
    into "". Columns that are absent are ignored, and an empty frame is handed
    back untouched.
    """
    if items.empty:
        return items

    numeric_columns = ("price", "average_rating", "rating_number")
    text_columns = ("parent_asin", "title", "description", "features",
                    "categories", "domain")
    blank_out = {"None": "", "nan": ""}

    cleaned = items.copy()
    for name in numeric_columns:
        if name not in cleaned.columns:
            continue
        without_dollar = cleaned[name].astype(str).str.replace(r"^\$", "", regex=True)
        cleaned[name] = pd.to_numeric(without_dollar, errors="coerce")
    for name in text_columns:
        if name not in cleaned.columns:
            continue
        cleaned[name] = cleaned[name].astype(str).replace(blank_out)
    return cleaned
def build_user_stats(reviews_train: pd.DataFrame) -> pd.DataFrame:
    """Summarise each user's training reviews into one row of statistics.

    Columns: user_id, n_reviews, avg_rating, std_rating, avg_review_length and
    std_review_length (whitespace-separated word counts, missing text counted
    as zero words), verified_rate, domains (list of distinct domains) and
    n_domains. Undefined standard deviations (e.g. a single review) are
    reported as 0.
    """
    # Word count of every review, computed once for the whole frame.
    word_counts = reviews_train["text"].fillna("").astype(str).str.split().str.len()
    grouped = reviews_train.assign(_n_words=word_counts).groupby("user_id")

    stats = grouped.agg(
        n_reviews=("rating", "size"),
        avg_rating=("rating", "mean"),
        std_rating=("rating", "std"),
        avg_review_length=("_n_words", lambda counts: counts.mean()),
        std_review_length=("_n_words", lambda counts: counts.std()),
        verified_rate=("verified_purchase", "mean"),
        domains=("domain", lambda s: list(s.unique())),
        n_domains=("domain", "nunique"),
    ).reset_index()
    return stats.fillna({"std_rating": 0, "std_review_length": 0})
# ──────────────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────────────
def _title_catalog(reviews: pd.DataFrame) -> pd.DataFrame:
    """Build placeholder item rows (title only) from the given reviews."""
    catalog = (reviews.groupby(["parent_asin", "domain"])
               .agg(title=("title", "first"))
               .reset_index())
    for col in ("description", "features", "categories"):
        catalog[col] = ""
    for col in ("average_rating", "rating_number", "price"):
        catalog[col] = None
    return catalog


def main():
    """Run the pipeline: download reviews, sample users, fetch metadata, write parquet.

    Phase 1 caches up to --rows-per-category raw reviews per category.
    Phase 2 cleans the reviews, samples users and assigns train/test splits.
    Phase 3 builds the item catalog from cached metadata (or from review
    titles with --skip-meta), falling back to review titles for items that
    have no metadata. Phase 4 writes reviews, items and users parquet files
    to PROCESSED_DIR.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--rows-per-category", type=int, default=150_000)
    ap.add_argument("--meta-scan-cap", type=int, default=600_000,
                    help="Max metadata rows to scan per category (smaller=faster)")
    ap.add_argument("--target-users", type=int, default=DEFAULT_USERS_PER_CATEGORY * 3)
    ap.add_argument("--min-reviews", type=int, default=DEFAULT_MIN_REVIEWS)
    ap.add_argument("--max-reviews", type=int, default=DEFAULT_MAX_REVIEWS)
    ap.add_argument("--test-holdout", type=int, default=DEFAULT_TEST_HOLDOUT)
    ap.add_argument("--skip-meta", action="store_true",
                    help="Skip metadata download; use review titles as item info")
    args = ap.parse_args()

    PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
    RAW_DIR.mkdir(parents=True, exist_ok=True)

    # ── Phase 1: reviews (cached per category) ──────────────────────────────
    log.info("=" * 70)
    log.info("PHASE 1: downloading review files (resumable)")
    log.info("=" * 70)
    for cat in CATEGORIES:
        cache_path = RAW_DIR / f"review_{cat}.jsonl"
        url = f"{BASE_URL}/raw/review_categories/{cat}.jsonl"
        log.info(f"[{cat}] reviews → {cache_path.name}")
        stream_to_cache(url, cache_path, max_rows=args.rows_per_category)

    all_reviews = []
    for cat in CATEGORIES:
        cache_path = RAW_DIR / f"review_{cat}.jsonl"
        df = load_reviews_from_cache(cache_path, cat)
        log.info(f"[{cat}] loaded {len(df):,} reviews from cache")
        all_reviews.append(df)
    reviews = pd.concat(all_reviews, ignore_index=True)
    log.info(f"Total raw reviews: {len(reviews):,}")

    # ── Phase 2: clean + sample users + splits ──────────────────────────────
    log.info("=" * 70)
    log.info("PHASE 2: filtering, sampling, splits")
    log.info("=" * 70)
    reviews = reviews.dropna(subset=["user_id", "parent_asin", "rating", "text"])
    reviews = reviews[reviews["text"].astype(str).str.len() > 20]
    log.info(f"After cleaning: {len(reviews):,} reviews")
    user_sample = sample_users(reviews, args.min_reviews, args.max_reviews,
                               args.target_users)
    keep_users = set(user_sample["user_id"])
    reviews = reviews[reviews["user_id"].isin(keep_users)].reset_index(drop=True)
    log.info(f"After user filter: {len(reviews):,} reviews / {len(keep_users):,} users")
    reviews = build_train_test_splits(reviews, holdout=args.test_holdout)
    n_train = (reviews["split"] == "train").sum()
    n_test = (reviews["split"] == "test").sum()
    log.info(f"Train: {n_train:,} | Test: {n_test:,}")

    # ── Phase 3: metadata (cached per category) ─────────────────────────────
    log.info("=" * 70)
    log.info("PHASE 3: downloading item metadata (resumable)")
    log.info("=" * 70)
    if args.skip_meta:
        log.info("--skip-meta set; building minimal catalog from review titles")
        items = _title_catalog(reviews)
    else:
        # NOTE(review): an existing meta_<Category>.jsonl is reused as-is even
        # if the sampled items changed since it was written; delete it to refresh.
        for cat in CATEGORIES:
            cache_path = RAW_DIR / f"meta_{cat}.jsonl"
            url = f"{BASE_URL}/raw/meta_categories/meta_{cat}.jsonl"
            cat_asins = set(reviews.loc[reviews["domain"] == cat, "parent_asin"])
            log.info(f"[{cat}] metadata → {cache_path.name} (target {len(cat_asins):,} items)")
            stream_filter_to_cache(url, cache_path, cat_asins,
                                   max_scan=args.meta_scan_cap)
        all_items = []
        for cat in CATEGORIES:
            cache_path = RAW_DIR / f"meta_{cat}.jsonl"
            df = load_meta_from_cache(cache_path, cat)
            log.info(f"[{cat}] loaded {len(df):,} metadata rows from cache")
            all_items.append(df)
        items = pd.concat(all_items, ignore_index=True)
        # Fallback for items without metadata: use review title
        found = set(items["parent_asin"]) if not items.empty else set()
        missing = _title_catalog(reviews[~reviews["parent_asin"].isin(found)])
        if not missing.empty:
            items = pd.concat([items, missing], ignore_index=True)
            log.info(f"Added {len(missing):,} items from review-title fallback")

    # Keep exactly one row per item on every path. The title fallback is
    # grouped by (parent_asin, domain), so an item reviewed under two domains
    # would otherwise appear twice. Metadata rows come first and therefore win.
    if not items.empty:
        items = items.drop_duplicates(subset=["parent_asin"], ignore_index=True)

    # ── Phase 4: write parquet outputs ──────────────────────────────────────
    log.info("=" * 70)
    log.info("PHASE 4: writing processed parquet files")
    log.info("=" * 70)
    user_stats = build_user_stats(reviews[reviews["split"] == "train"])
    items = normalize_items_for_parquet(items)
    reviews.to_parquet(PROCESSED_DIR / "reviews.parquet", index=False)
    items.to_parquet(PROCESSED_DIR / "items.parquet", index=False)
    user_stats.to_parquet(PROCESSED_DIR / "users.parquet", index=False)
    log.info(f"Wrote processed files to {PROCESSED_DIR}/")
    log.info(f" reviews.parquet: {len(reviews):,} rows")
    log.info(f" items.parquet: {len(items):,} rows")
    log.info(f" users.parquet: {len(user_stats):,} rows")
    log.info("Done.")


if __name__ == "__main__":
    main()