"""
Data pipeline for MARS sequential recommendation.

Supports:
    1. Amazon Reviews 2023 (Books, Movies_and_TV, etc.) — filtered for power users
    2. MovieLens-1M
    3. Synthetic data for testing

Every loader returns ``{user_id: {'item_ids': [...], 'timestamps': [...]}}``:
    - user_id: the dataset's raw user key (int for MovieLens / synthetic data;
      whatever the HF dataset provides for Amazon)
    - item_ids: raw item ids, chronologically sorted (re-indexed to contiguous
      ints ``1..num_items`` by ``ReindexedData``; 0 is reserved for padding)
    - timestamps: List[float] (in seconds)
"""

import json
import os
import random
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


def download_movielens_1m(data_dir: str = './data/ml-1m') -> str:
    """Download and extract MovieLens-1M into ``data_dir`` unless already present.

    Args:
        data_dir: Directory that will contain ``ratings.dat`` and the other
            MovieLens files (the archive's inner ``ml-1m/`` folder is flattened).

    Returns:
        Path to ``ratings.dat``.

    Raises:
        FileNotFoundError: if ``ratings.dat`` is still missing after extraction.
    """
    import urllib.request
    import zipfile

    os.makedirs(data_dir, exist_ok=True)
    ratings_path = os.path.join(data_dir, 'ratings.dat')
    if os.path.exists(ratings_path):
        return ratings_path

    url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'
    zip_path = os.path.join(data_dir, 'ml-1m.zip')
    print(f"Downloading MovieLens-1M from {url}...")
    try:
        urllib.request.urlretrieve(url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(data_dir)
    finally:
        # Clean up the archive even when download/extraction fails, so a
        # partial zip does not linger on disk.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    # The archive nests its files under ml-1m/; move them up one level.
    inner_dir = os.path.join(data_dir, 'ml-1m')
    if os.path.isdir(inner_dir):
        for name in os.listdir(inner_dir):
            # os.replace (unlike os.rename on Windows) overwrites stale files
            # left behind by an earlier interrupted run.
            os.replace(os.path.join(inner_dir, name), os.path.join(data_dir, name))
        os.rmdir(inner_dir)

    if not os.path.exists(ratings_path):
        raise FileNotFoundError(
            f"ratings.dat not found in {data_dir} after extracting {url}"
        )
    return ratings_path


def _sort_and_filter(user_interactions: Dict, min_interactions: int) -> Dict:
    """Sort each user's (item, timestamp) pairs by time; keep users with >= min_interactions."""
    sequences = {}
    for uid, interactions in user_interactions.items():
        if len(interactions) < min_interactions:
            continue
        interactions.sort(key=lambda pair: pair[1])
        sequences[uid] = {
            'item_ids': [item for item, _ in interactions],
            'timestamps': [float(ts) for _, ts in interactions],
        }
    return sequences


def load_movielens_1m(data_dir: str = './data/ml-1m', min_interactions: int = 5):
    """Load MovieLens-1M ratings as chronologically sorted per-user sequences.

    Every rating is kept as an implicit interaction (the rating value is ignored).

    Args:
        data_dir: Directory holding (or receiving) the MovieLens-1M files.
        min_interactions: Users with fewer interactions are dropped.

    Returns:
        ``{user_id: {'item_ids': List[int], 'timestamps': List[float]}}``.
    """
    ratings_path = download_movielens_1m(data_dir)

    user_interactions = defaultdict(list)
    with open(ratings_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank / trailing lines
                continue
            # Each line: user_id::item_id::rating::timestamp
            user_id, item_id, _rating, timestamp = line.split('::')[:4]
            user_interactions[int(user_id)].append((int(item_id), int(timestamp)))

    return _sort_and_filter(user_interactions, min_interactions)


def _to_seconds(ts) -> float:
    """Convert a raw timestamp to float seconds (values > 1e12 are treated as ms).

    Missing or unparseable values become 0.0, matching the loader's default.
    """
    try:
        value = float(ts)
    except (TypeError, ValueError):
        return 0.0
    return value / 1000 if value > 1e12 else value


def load_amazon_reviews(
    category: str = 'Movies_and_TV',
    min_interactions: int = 20,
    max_users: int = 50000,
    data_dir: str = './data/amazon',
    seed: Optional[int] = None,
):
    """Load Amazon Reviews 2023 from the HF Hub as per-user sequences.

    Keeps only users with at least ``min_interactions`` reviews (long-sequence
    modeling) and randomly subsamples down to ``max_users`` users.

    Args:
        category: Amazon Reviews 2023 category name (e.g. 'Movies_and_TV').
        min_interactions: Users with fewer interactions are dropped.
        max_users: Upper bound on the number of returned users.
        data_dir: Currently unused; kept for interface compatibility.
        seed: Optional seed for the user subsampling; ``None`` uses the
            global ``random`` state.

    Returns:
        ``{user_id: {'item_ids': [...], 'timestamps': [...]}}``, or ``{}`` if
        loading fails for any reason (the error is printed).
    """
    try:
        from datasets import load_dataset

        print(f"Loading Amazon Reviews 2023 - {category}...")

        # Try benchmark format first
        try:
            ds = load_dataset(
                "McAuley-Lab/Amazon-Reviews-2023",
                f"0core_rating_only_{category}",
                trust_remote_code=True,
                split="train"
            )
        except Exception:
            # Fallback to raw reviews
            ds = load_dataset(
                "McAuley-Lab/Amazon-Reviews-2023",
                f"raw_review_{category}",
                trust_remote_code=True,
                split="full"
            )

        user_interactions = defaultdict(list)
        for row in ds:
            uid = row.get('user_id', row.get('reviewerID'))
            iid = row.get('parent_asin', row.get('asin'))
            if uid and iid:
                ts = row.get('timestamp', row.get('unixReviewTime', 0))
                # Convert before comparing: string/None timestamps must not abort the load.
                user_interactions[uid].append((iid, _to_seconds(ts)))

        sequences = _sort_and_filter(user_interactions, min_interactions)

        if len(sequences) > max_users:
            sampler = random if seed is None else random.Random(seed)
            keys = sampler.sample(list(sequences.keys()), max_users)
            sequences = {k: sequences[k] for k in keys}

        return sequences

    except Exception as e:
        # Deliberate best-effort: callers fall back to other datasets on {}.
        print(f"Failed to load Amazon Reviews: {e}")
        return {}


def generate_synthetic_data(
    num_users: int = 5000,
    num_items: int = 10000,
    min_seq_len: int = 50,
    max_seq_len: int = 1000,
    seed: int = 42
) -> Dict:
    """Generate synthetic sequential interaction data for testing.

    The simulated patterns are:
        - Item popularity weights drawn from numpy's power-function distribution.
        - 2-5 interest clusters per user, each a window of about 100 item ids
          around a center. At every step the cluster weights are mixed 80/20
          with fresh Dirichlet noise.
        - Exponential inter-event gaps (mean 1 hour), tripled when the current
          UTC hour of day is between 2 and 8.

    Item ids are 1-indexed (0 is reserved for padding). The output is
    deterministic for a given ``seed``.

    Returns:
        ``{user_index: {'item_ids': List[int], 'timestamps': List[float]}}``.
    """
    rng = np.random.RandomState(seed)

    item_popularity = rng.power(0.8, num_items)
    item_popularity /= item_popularity.sum()

    sequences = {}
    base_time = 1600000000  # ~Sep 2020 (Unix seconds)

    for uid in range(num_users):
        seq_len = rng.randint(min_seq_len, max_seq_len + 1)

        # Cap the cluster count so we never need more distinct centers than items.
        num_clusters = min(rng.randint(2, 6), num_items)
        cluster_centers = rng.choice(num_items, num_clusters, replace=False)
        cluster_weights = rng.dirichlet(np.ones(num_clusters))

        items = []
        timestamps = []
        current_time = base_time + rng.randint(0, 86400 * 365)  # Random start

        for _ in range(seq_len):
            # Per-step perturbation of the user's cluster weights.
            drift = rng.dirichlet(np.ones(num_clusters) * 5)
            current_weights = 0.8 * cluster_weights + 0.2 * drift

            # Select a cluster, then a popularity-weighted item near its center.
            cluster = rng.choice(num_clusters, p=current_weights / current_weights.sum())
            center = cluster_centers[cluster]
            local_items = np.arange(
                max(0, center - 50),
                min(num_items, center + 50)
            )
            local_probs = item_popularity[local_items]  # fancy indexing -> copy
            local_probs /= local_probs.sum()
            item = local_items[rng.choice(len(local_items), p=local_probs)]
            items.append(int(item) + 1)  # 1-indexed (0 = padding)

            # Time gap: exponential, slower during UTC hours 2-8.
            gap = rng.exponential(3600)  # avg 1 hour
            hour = (current_time % 86400) / 3600
            if 2 < hour < 8:
                gap *= 3
            current_time += gap
            timestamps.append(current_time)

        sequences[uid] = {
            'item_ids': items,
            'timestamps': timestamps
        }

    return sequences


class ReindexedData:
    """Reindex items to contiguous ids ``1..num_items`` and build leave-one-out splits.

    For every user with at least 3 interactions (after truncation to the most
    recent ``max_seq_len``):
        - test:  history ``items[:-1]`` -> target ``items[-1]``
        - val:   history ``items[:-2]`` -> target ``items[-2]``
        - train: history ``items[:-3]`` -> target ``items[-3]`` (only if that
          history is non-empty, i.e. at least 4 interactions)
    so neither evaluation target is seen during training.

    ``val_ratio`` and ``test_ratio`` are accepted for interface compatibility
    but are not used by the leave-one-out protocol.
    """

    def __init__(
        self,
        sequences: Dict,
        max_seq_len: int = 512,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
    ):
        self.max_seq_len = max_seq_len

        # Collect all items and map them to 1..num_items (0 = padding).
        all_items = set()
        for data in sequences.values():
            all_items.update(data['item_ids'])
        self.item2idx = {item: idx + 1 for idx, item in enumerate(sorted(all_items))}
        self.idx2item = {idx: item for item, idx in self.item2idx.items()}
        self.num_items = len(self.item2idx)

        print(f"Total users: {len(sequences)}, Total items: {self.num_items}")

        self.train_data = []
        self.val_data = []
        self.test_data = []
        seq_lens = []

        for uid, data in sequences.items():
            item_ids = [self.item2idx[i] for i in data['item_ids']]
            timestamps = data['timestamps']

            # Keep only the most recent max_seq_len interactions.
            if len(item_ids) > max_seq_len:
                item_ids = item_ids[-max_seq_len:]
                timestamps = timestamps[-max_seq_len:]

            seq_lens.append(len(item_ids))

            if len(item_ids) < 3:
                continue

            if len(item_ids) >= 4:
                self.train_data.append({
                    'user_id': uid,
                    'item_ids': item_ids[:-3],
                    'timestamps': timestamps[:-3],
                    'next_item': item_ids[-3],
                })
            self.val_data.append({
                'user_id': uid,
                'item_ids': item_ids[:-2],
                'timestamps': timestamps[:-2],
                'next_item': item_ids[-2],
            })
            self.test_data.append({
                'user_id': uid,
                'item_ids': item_ids[:-1],
                'timestamps': timestamps[:-1],
                'next_item': item_ids[-1],
            })

        if seq_lens:
            seq_lens = np.array(seq_lens)
            print(f"Sequence length stats: mean={seq_lens.mean():.1f}, "
                  f"median={np.median(seq_lens):.1f}, "
                  f"max={seq_lens.max()}, min={seq_lens.min()}")
            print(f"Users with 100+ interactions: {(seq_lens >= 100).sum()}")
            print(f"Users with 200+ interactions: {(seq_lens >= 200).sum()}")
        else:
            print("Sequence length stats: no sequences")
        print(f"Train: {len(self.train_data)}, Val: {len(self.val_data)}, "
              f"Test: {len(self.test_data)}")


class SeqRecDataset(Dataset):
    """Sequential recommendation dataset with padding and negative sampling.

    Each element of ``data`` is a dict with ``item_ids``, ``timestamps`` and
    ``next_item`` (as built by ``ReindexedData``). Negatives are drawn uniformly
    from ``1..num_items`` and never equal a stored history item or the
    positive target.
    """

    def __init__(
        self,
        data: List[Dict],
        num_items: int,
        max_seq_len: int = 512,
        num_negatives: int = 1,
        is_training: bool = True,
    ):
        self.data = data
        self.num_items = num_items
        self.max_seq_len = max_seq_len
        self.num_negatives = num_negatives
        # Stored for callers; __getitem__ does not branch on it.
        self.is_training = is_training

    def __len__(self):
        return len(self.data)

    def _sample_negatives(self, exclude: set) -> List[int]:
        """Draw ``num_negatives`` ids uniformly from 1..num_items, skipping ``exclude``.

        Draws are independent, so duplicate negatives are possible.

        Raises:
            ValueError: if every id in 1..num_items is excluded (rejection
                sampling would otherwise never terminate).
        """
        if self.num_negatives <= 0:
            return []
        num_excluded = sum(1 for i in exclude if 1 <= i <= self.num_items)
        if num_excluded >= self.num_items:
            raise ValueError(
                f"Cannot sample negatives: all {self.num_items} items are "
                f"excluded for this sample"
            )
        negatives = []
        for _ in range(self.num_negatives):
            neg = random.randint(1, self.num_items)
            while neg in exclude:
                neg = random.randint(1, self.num_items)
            negatives.append(neg)
        return negatives

    def __getitem__(self, idx):
        sample = self.data[idx]
        item_ids = sample['item_ids'][-self.max_seq_len:]
        timestamps = sample['timestamps'][-self.max_seq_len:]
        next_item = sample['next_item']

        # Right-pad with item 0 / timestamp 0.0; `mask` is True on real positions.
        seq_len = len(item_ids)
        pad_len = self.max_seq_len - seq_len
        padded_items = item_ids + [0] * pad_len
        padded_timestamps = timestamps + [0.0] * pad_len
        mask = [True] * seq_len + [False] * pad_len

        # Exclude the full stored history AND the positive, so the target can
        # never be scored against itself as a "negative".
        negatives = self._sample_negatives(set(sample['item_ids']) | {next_item})

        return {
            'item_ids': torch.tensor(padded_items, dtype=torch.long),
            'timestamps': torch.tensor(padded_timestamps, dtype=torch.float32),
            'mask': torch.tensor(mask, dtype=torch.bool),
            'positive_ids': torch.tensor(next_item, dtype=torch.long),
            'negative_ids': torch.tensor(negatives, dtype=torch.long),
        }


def create_dataloaders(
    data: ReindexedData,
    max_seq_len: int = 512,
    batch_size: int = 128,
    num_negatives: int = 4,
    eval_negatives: int = 999,
    num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create train/val/test dataloaders.

    Training samples carry ``num_negatives`` negatives. Validation and test
    samples carry ``eval_negatives`` (999 by default), so the positive is
    ranked against a sampled candidate set rather than the full catalogue.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader
        shuffles.
    """
    # pin_memory only matters for CUDA transfers; avoid the warning elsewhere.
    pin_memory = torch.cuda.is_available()

    def make_loader(split: List[Dict], negatives: int, training: bool) -> DataLoader:
        dataset = SeqRecDataset(
            split, data.num_items, max_seq_len,
            num_negatives=negatives, is_training=training
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=training,
            num_workers=num_workers,
            pin_memory=pin_memory,
            # Drop the ragged last training batch, but never the only one:
            # a train set smaller than batch_size would otherwise yield 0 batches.
            drop_last=training and len(dataset) >= batch_size,
        )

    train_loader = make_loader(data.train_data, num_negatives, True)
    val_loader = make_loader(data.val_data, eval_negatives, False)
    test_loader = make_loader(data.test_data, eval_negatives, False)
    return train_loader, val_loader, test_loader


def save_data_config(data: ReindexedData, path: str):
    """Write item count and split sizes to ``path`` as JSON and return them.

    Parent directories are created if needed.
    """
    config = {
        'num_items': data.num_items,
        'num_train': len(data.train_data),
        'num_val': len(data.val_data),
        'num_test': len(data.test_data),
    }
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w') as f:
        json.dump(config, f, indent=2)
    return config