# MARS-SeqRec / data.py
# Author: CyberDancer
# MARS: Multi-scale Adaptive Recurrence with State compression
# Commit: 2319f81 (verified)
"""
Data pipeline for MARS sequential recommendation.
Supports:
1. Amazon Reviews 2023 (Books, Movies_and_TV, etc.) — filtered for power users
2. MovieLens-1M
3. Synthetic data for testing
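Every loader returns the same structure, a dict keyed by user id:
    {user_id: {'item_ids': [...], 'timestamps': [...]}}
- user_id: int for MovieLens-1M and synthetic data; raw dataset identifiers
  for Amazon
- item_ids: raw item ids, chronologically sorted (ints for MovieLens-1M and
  synthetic data, ASINs for Amazon); ReindexedData later maps them to
  contiguous ints starting at 1 (0 is reserved for padding)
- timestamps: List[float] (in seconds), aligned with item_ids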
"""
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import json
def download_movielens_1m(data_dir: str = './data/ml-1m') -> str:
    """Download and extract MovieLens-1M unless ``ratings.dat`` already exists.

    The archive unpacks into an ``ml-1m/`` sub-directory. Its files are moved
    up one level so that ``ratings.dat`` sits directly inside ``data_dir``.

    Args:
        data_dir: Directory that should end up containing ``ratings.dat``.

    Returns:
        Path to ``ratings.dat`` inside ``data_dir``.

    Raises:
        FileNotFoundError: if extraction completed but ``ratings.dat`` is
            still missing.
    """
    import urllib.request
    import zipfile

    os.makedirs(data_dir, exist_ok=True)
    ratings_path = os.path.join(data_dir, 'ratings.dat')
    if os.path.exists(ratings_path):
        return ratings_path

    url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'
    zip_path = os.path.join(data_dir, 'ml-1m.zip')
    print(f"Downloading MovieLens-1M from {url}...")
    try:
        urllib.request.urlretrieve(url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(data_dir)
    finally:
        # Never leave a partial or stale archive behind, even on failure.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    # Move files up one level.
    inner_dir = os.path.join(data_dir, 'ml-1m')
    if os.path.isdir(inner_dir):
        for name in os.listdir(inner_dir):
            # os.replace overwrites an existing destination on all platforms.
            # os.rename raises on Windows if e.g. a previous run was interrupted.
            os.replace(os.path.join(inner_dir, name), os.path.join(data_dir, name))
        os.rmdir(inner_dir)

    if not os.path.exists(ratings_path):
        raise FileNotFoundError(
            f"MovieLens-1M archive did not contain ratings.dat (expected at {ratings_path})"
        )
    return ratings_path
def load_movielens_1m(data_dir: str = './data/ml-1m', min_interactions: int = 5):
    """Load MovieLens-1M ratings as chronologically sorted per-user sequences.

    Every rating is kept regardless of its value (implicit-feedback style).
    The dataset is downloaded first if it is not already present.

    Args:
        data_dir: Directory holding (or receiving) ``ratings.dat``.
        min_interactions: Users with fewer ratings than this are dropped.

    Returns:
        Dict mapping the MovieLens user id (int) to
        ``{'item_ids': List[int], 'timestamps': List[float]}``. Both lists are
        sorted by timestamp; ties keep their file order because the sort is
        stable.
    """
    ratings_path = download_movielens_1m(data_dir)

    user_interactions = defaultdict(list)
    # Each line of ratings.dat is "UserID::MovieID::Rating::Timestamp".
    # latin-1 decodes any byte, so parsing never fails on a stray byte.
    with open(ratings_path, 'r', encoding='latin-1') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank lines
            parts = line.split('::')
            user_id = int(parts[0])
            item_id = int(parts[1])
            timestamp = int(parts[3])
            # The rating (parts[2]) is intentionally ignored: any rating counts
            # as an interaction.
            user_interactions[user_id].append((item_id, timestamp))

    sequences = {}
    for uid, interactions in user_interactions.items():
        if len(interactions) < min_interactions:
            continue
        interactions.sort(key=lambda x: x[1])
        sequences[uid] = {
            'item_ids': [x[0] for x in interactions],
            'timestamps': [float(x[1]) for x in interactions],
        }
    return sequences
def load_amazon_reviews(
    category: str = 'Movies_and_TV',
    min_interactions: int = 20,
    max_users: int = 50000,
    data_dir: str = './data/amazon',
    seed: Optional[int] = None,
):
    """Load Amazon Reviews 2023 from the HF Hub as per-user sequences.

    Tries the ``0core_rating_only_{category}`` config first and falls back to
    ``raw_review_{category}``. Only users with at least ``min_interactions``
    reviews are kept (long-sequence modeling). If more than ``max_users``
    remain, a random subset of ``max_users`` users is returned.

    Args:
        category: Amazon category name, e.g. ``'Movies_and_TV'``.
        min_interactions: Minimum number of reviews per kept user.
        max_users: Upper bound on the number of users returned.
        data_dir: Currently unused by this function.
        seed: Optional seed for the user subsample. ``None`` (default) uses
            the global ``random`` state, as before.

    Returns:
        Dict ``user_id -> {'item_ids': [...], 'timestamps': [...]}``, sorted
        by time, with timestamps in seconds. Returns ``{}`` if anything fails
        (e.g. ``datasets`` not installed, no network, unknown category).
    """
    try:
        from datasets import load_dataset
        print(f"Loading Amazon Reviews 2023 - {category}...")
        # Try benchmark format first
        try:
            ds = load_dataset(
                "McAuley-Lab/Amazon-Reviews-2023",
                f"0core_rating_only_{category}",
                trust_remote_code=True,
                split="train"
            )
        except Exception:
            # Fallback to raw reviews
            ds = load_dataset(
                "McAuley-Lab/Amazon-Reviews-2023",
                f"raw_review_{category}",
                trust_remote_code=True,
                split="full"
            )

        # Build user sequences
        user_interactions = defaultdict(list)
        for row in ds:
            uid = row.get('user_id', row.get('reviewerID'))
            iid = row.get('parent_asin', row.get('asin'))
            if not (uid and iid):
                continue
            raw_ts = row.get('timestamp', row.get('unixReviewTime', 0))
            # Convert BEFORE comparing. The field may arrive as a string
            # (depending on the config), and `str > float` raises TypeError,
            # which the outer handler would silently turn into {}.
            ts = float(raw_ts) if raw_ts is not None else 0.0
            # Values above 1e12 are treated as milliseconds since the epoch.
            if ts > 1e12:
                ts /= 1000.0
            user_interactions[uid].append((iid, ts))

        # Filter and sort
        sequences = {}
        for uid, interactions in user_interactions.items():
            if len(interactions) < min_interactions:
                continue
            interactions.sort(key=lambda x: x[1])
            sequences[uid] = {
                'item_ids': [x[0] for x in interactions],
                'timestamps': [x[1] for x in interactions],
            }

        # Limit users
        if len(sequences) > max_users:
            sampler = random if seed is None else random.Random(seed)
            keys = sampler.sample(list(sequences), max_users)
            sequences = {k: sequences[k] for k in keys}
        return sequences
    except Exception as e:
        # Deliberate best-effort: callers receive {} and can fall back to
        # another data source.
        print(f"Failed to load Amazon Reviews: {e}")
        return {}
def generate_synthetic_data(
    num_users: int = 5000,
    num_items: int = 10000,
    min_seq_len: int = 50,
    max_seq_len: int = 1000,
    seed: int = 42,
) -> Dict:
    """Generate synthetic sequential interaction data for testing.

    Simulated patterns:
    - Skewed item popularity. Per-item weights are drawn from
      ``rng.power(0.8)``, a power-function distribution on [0, 1]; this is
      skewed but not a strict Zipf law.
    - Clustered interests. Each user has 2-5 cluster centres (capped at
      ``num_items``). Each item is drawn from a +/-50-id window around a
      chosen centre, weighted by popularity.
    - Per-step interest noise. The user's fixed cluster weights are mixed
      80/20 with a fresh Dirichlet sample at every step. This is noise around
      stable preferences, not accumulating drift.
    - Daily rhythm. Inter-event gaps are exponential with a mean of 1 hour.
      A gap is tripled when the previous event's epoch-based hour of day is
      strictly between 2 and 8. There is no weekly pattern.

    Args:
        num_users: Number of users to generate (ids ``0..num_users-1``).
        num_items: Size of the item catalogue; must be >= 1.
        min_seq_len, max_seq_len: Inclusive bounds on each user's sequence length.
        seed: Seed for the private ``np.random.RandomState``. Output is fully
            deterministic for a given set of arguments.

    Returns:
        Dict ``user_id -> {'item_ids': List[int], 'timestamps': List[float]}``.
        Item ids are 1-based (0 is reserved for padding), and timestamps are
        non-decreasing seconds.

    Raises:
        ValueError: if ``num_items < 1``.
    """
    if num_items < 1:
        raise ValueError(f"num_items must be >= 1, got {num_items}")

    rng = np.random.RandomState(seed)
    item_popularity = rng.power(0.8, num_items)
    item_popularity /= item_popularity.sum()

    sequences = {}
    base_time = 1600000000  # ~Sep 2020
    for uid in range(num_users):
        seq_len = rng.randint(min_seq_len, max_seq_len + 1)
        # Cap the cluster count at num_items. Otherwise choice(..., replace=False)
        # fails on tiny catalogues. The RNG draw itself is unchanged, so
        # seeded output is identical whenever num_items >= 5.
        num_clusters = min(rng.randint(2, 6), num_items)
        cluster_centers = rng.choice(num_items, num_clusters, replace=False)
        cluster_weights = rng.dirichlet(np.ones(num_clusters))

        items = []
        timestamps = []
        current_time = base_time + rng.randint(0, 86400 * 365)  # random start within a year
        for _ in range(seq_len):
            # Per-step noise around the user's fixed cluster preferences.
            drift = rng.dirichlet(np.ones(num_clusters) * 5)
            current_weights = 0.8 * cluster_weights + 0.2 * drift
            cluster = rng.choice(num_clusters, p=current_weights / current_weights.sum())
            center = cluster_centers[cluster]

            # Pick an item near the cluster centre, weighted by popularity.
            # Fancy indexing copies, so the in-place division below leaves
            # item_popularity intact.
            local_items = np.arange(max(0, center - 50), min(num_items, center + 50))
            local_probs = item_popularity[local_items]
            local_probs /= local_probs.sum()
            item = local_items[rng.choice(len(local_items), p=local_probs)]
            items.append(int(item) + 1)  # 1-indexed (0 = padding)

            # Exponential gap (mean 1 h), tripled in the epoch-based 02:00-08:00 window.
            gap = rng.exponential(3600)
            hour = (current_time % 86400) / 3600
            if 2 < hour < 8:
                gap *= 3
            current_time += gap
            timestamps.append(current_time)

        sequences[uid] = {
            'item_ids': items,
            'timestamps': timestamps,
        }
    return sequences
class ReindexedData:
    """Reindex items to contiguous integers and build leave-one-out splits.

    Raw item ids are mapped to ``1..num_items`` in sorted order, leaving 0
    free for padding. Each user sequence ``s`` is first truncated to its last
    ``max_seq_len`` items. It then yields three samples, each predicting one
    held-out item from all items before it:

    - train: ``s[:-3]`` -> ``s[-3]``
    - val:   ``s[:-2]`` -> ``s[-2]``
    - test:  ``s[:-1]`` -> ``s[-1]``

    Validation and test therefore evaluate different items. The validation
    target never appears in the training sample, and the test target never
    appears in the training or validation samples. Users with fewer than 4
    items after truncation are skipped, because the training input must hold
    at least one item.

    Attributes:
        max_seq_len: Truncation length applied to each user sequence.
        item2idx / idx2item: Raw item id <-> contiguous index mappings.
        num_items: Number of distinct items (equal to the largest index).
        train_data / val_data / test_data: Lists of dicts with keys
            'user_id', 'item_ids', 'timestamps', 'next_item'.
    """

    # Smallest sequence that leaves a non-empty training input (s[:-3]).
    _MIN_SEQ_LEN = 4

    def __init__(
        self,
        sequences: Dict,
        max_seq_len: int = 512,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
    ):
        """Reindex ``sequences`` and build the three splits.

        Args:
            sequences: Mapping ``user_id -> {'item_ids': [...], 'timestamps': [...]}``
                with both lists aligned and chronologically ordered.
            max_seq_len: Keep only each user's most recent ``max_seq_len``
                interactions.
            val_ratio, test_ratio: Accepted for backward compatibility but
                unused; the split is per-user leave-one-out, not ratio-based.
        """
        self.max_seq_len = max_seq_len

        # Collect all items and reindex
        all_items = set()
        for data in sequences.values():
            all_items.update(data['item_ids'])
        self.item2idx = {item: idx + 1 for idx, item in enumerate(sorted(all_items))}
        self.idx2item = {idx: item for item, idx in self.item2idx.items()}
        self.num_items = len(self.item2idx)
        print(f"Total users: {len(sequences)}, Total items: {self.num_items}")

        # Reindex and split
        self.train_data = []
        self.val_data = []
        self.test_data = []
        seq_lens = []
        for uid, data in sequences.items():
            item_ids = [self.item2idx[i] for i in data['item_ids']]
            timestamps = data['timestamps']
            # Truncate to max_seq_len
            if len(item_ids) > max_seq_len:
                item_ids = item_ids[-max_seq_len:]
                timestamps = timestamps[-max_seq_len:]
            seq_lens.append(len(item_ids))
            if len(item_ids) < self._MIN_SEQ_LEN:
                continue
            self.train_data.append(self._holdout(uid, item_ids, timestamps, -3))
            self.val_data.append(self._holdout(uid, item_ids, timestamps, -2))
            self.test_data.append(self._holdout(uid, item_ids, timestamps, -1))

        self._report(np.array(seq_lens))

    @staticmethod
    def _holdout(uid, item_ids, timestamps, cut):
        """Build one sample predicting ``item_ids[cut]`` from everything before it."""
        return {
            'user_id': uid,
            'item_ids': item_ids[:cut],
            'timestamps': timestamps[:cut],
            'next_item': item_ids[cut],
        }

    def _report(self, seq_lens):
        """Print sequence-length statistics and split sizes.

        ``seq_lens`` covers every user, including users skipped for being too short.
        """
        if seq_lens.size == 0:
            # mean()/max() on an empty array would warn / raise.
            print("Sequence length stats: no sequences")
        else:
            print(f"Sequence length stats: mean={seq_lens.mean():.1f}, "
                  f"median={np.median(seq_lens):.1f}, "
                  f"max={seq_lens.max()}, min={seq_lens.min()}")
            print(f"Users with 100+ interactions: {(seq_lens >= 100).sum()}")
            print(f"Users with 200+ interactions: {(seq_lens >= 200).sum()}")
        print(f"Train: {len(self.train_data)}, Val: {len(self.val_data)}, "
              f"Test: {len(self.test_data)}")
class SeqRecDataset(Dataset):
    """Sequential recommendation dataset with uniform negative sampling.

    Each example contains:
    - the user's most recent ``max_seq_len`` interactions, right-padded with 0;
    - the held-out positive (``next_item``);
    - ``num_negatives`` negatives drawn uniformly from ``1..num_items``.

    Negatives never equal the positive or any item in the sample's full
    (untruncated) history.
    """

    def __init__(
        self,
        data: List[Dict],
        num_items: int,
        max_seq_len: int = 512,
        num_negatives: int = 1,
        is_training: bool = True,
    ):
        """
        Args:
            data: Samples with keys 'item_ids', 'timestamps', 'next_item'
                (e.g. ``ReindexedData.train_data``).
            num_items: Largest valid item index; negatives are drawn from
                ``1..num_items``.
            max_seq_len: Length every sequence is truncated/padded to.
            num_negatives: Negatives sampled per example.
            is_training: Stored for callers; not used by this class.
        """
        self.data = data
        self.num_items = num_items
        self.max_seq_len = max_seq_len
        self.num_negatives = num_negatives
        self.is_training = is_training

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        history = sample['item_ids']
        item_ids = history[-self.max_seq_len:]
        timestamps = sample['timestamps'][-self.max_seq_len:]
        next_item = sample['next_item']

        # Padding
        seq_len = len(item_ids)
        pad_len = self.max_seq_len - seq_len
        # Right-padding: real items first, then id 0. The original author notes
        # this layout is needed for causal attention in the model.
        padded_items = item_ids + [0] * pad_len
        padded_timestamps = timestamps + [0.0] * pad_len
        mask = [True] * seq_len + [False] * pad_len

        # Exclude the positive as well. Otherwise the target can show up among
        # its own negatives and distort both training and ranking metrics.
        negatives = self._sample_negatives(set(history) | {next_item})

        return {
            'item_ids': torch.tensor(padded_items, dtype=torch.long),
            'timestamps': torch.tensor(padded_timestamps, dtype=torch.float32),
            'mask': torch.tensor(mask, dtype=torch.bool),
            'positive_ids': torch.tensor(next_item, dtype=torch.long),
            'negative_ids': torch.tensor(negatives, dtype=torch.long),
        }

    def _sample_negatives(self, excluded):
        """Draw ``num_negatives`` ids uniformly from 1..num_items, skipping ``excluded``.

        Raises:
            ValueError: if every id in 1..num_items is excluded. The rejection
                loop would otherwise never terminate.
        """
        if self.num_negatives <= 0:
            return []
        num_blocked = sum(1 for i in excluded if 1 <= i <= self.num_items)
        if num_blocked >= self.num_items:
            raise ValueError(
                f"No negative candidates left: all {self.num_items} items are excluded"
            )
        negatives = []
        for _ in range(self.num_negatives):
            neg = random.randint(1, self.num_items)
            while neg in excluded:
                neg = random.randint(1, self.num_items)
            negatives.append(neg)
        return negatives
def create_dataloaders(
    data: ReindexedData,
    max_seq_len: int = 512,
    batch_size: int = 128,
    num_negatives: int = 4,
    eval_negatives: int = 999,
    num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create train/val/test dataloaders.

    Training examples carry ``num_negatives`` sampled negatives. Validation
    and test examples carry ``eval_negatives`` (default 999), against which
    the target is ranked. The SASRec paper ranked against 100 sampled
    negatives; pass ``eval_negatives=100`` to reproduce that protocol.

    Args:
        data: Prepared splits (``train_data`` / ``val_data`` / ``test_data``)
            and ``num_items``.
        max_seq_len: Length every input sequence is truncated/padded to.
        batch_size: Batch size for all three loaders.
        num_negatives: Negatives per training example.
        eval_negatives: Negatives per validation/test example.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.
    """
    train_dataset = SeqRecDataset(
        data.train_data, data.num_items, max_seq_len,
        num_negatives=num_negatives, is_training=True
    )
    val_dataset = SeqRecDataset(
        data.val_data, data.num_items, max_seq_len,
        num_negatives=eval_negatives, is_training=False
    )
    test_dataset = SeqRecDataset(
        data.test_data, data.num_items, max_seq_len,
        num_negatives=eval_negatives, is_training=False
    )

    # Pinned host memory only speeds up host-to-GPU copies. Skip it when no
    # CUDA device exists, where PyTorch may warn that pinning is unavailable.
    pin_memory = torch.cuda.is_available()
    # drop_last on a training set smaller than one batch would yield zero
    # training batches. Keep the partial batch in that case.
    drop_last = len(train_dataset) >= batch_size

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    return train_loader, val_loader, test_loader
def save_data_config(data: ReindexedData, path: str):
    """Persist the dataset's size metadata as pretty-printed JSON.

    Args:
        data: Prepared dataset whose item count and split sizes are recorded.
        path: Destination file. Its parent directory is created if needed.

    Returns:
        The dict that was written, with keys 'num_items', 'num_train',
        'num_val' and 'num_test'.
    """
    config = dict(
        num_items=data.num_items,
        num_train=len(data.train_data),
        num_val=len(data.val_data),
        num_test=len(data.test_data),
    )
    # A bare filename has an empty dirname; treat that as the current directory.
    target_dir = os.path.dirname(path) or '.'
    os.makedirs(target_dir, exist_ok=True)
    with open(path, 'w') as out:
        json.dump(config, out, indent=2)
    return config