import os
import zipfile
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
from scipy.sparse import csr_matrix
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
def extract_ziped_data(ziped_data_path: str, extract_path : str):
"""Extracts the contents of a zip file to a specified directory.

    Args:
        ziped_data_path: str, path to the zip file.
        extract_path: str, path to the directory where contents will be extracted.
    """
    # Open the zip file in read mode and extract everything into extract_path.
    with zipfile.ZipFile(ziped_data_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)
print(f"'{ziped_data_path}' has been extracted to '{extract_path}'")
def prepare_data(data_folder='data/', val_days=7, test_days=7):
"""
Loads, preprocesses, and splits the events data into train, validation, and test sets.

    Args:
        data_folder: str, path to the folder containing 'events.csv'.
        val_days: int, number of days for the validation set.
        test_days: int, number of days for the test set.

    Returns:
        (train_df, val_df, test_df) DataFrames, or (None, None, None) if 'events.csv' is missing.
    """
# --- Load Data ---
print(f"Loading events.csv from folder: {data_folder}")
try:
        events_df = pd.read_csv(os.path.join(data_folder, 'events.csv'))
print("Successfully loaded events.csv.")
events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms')
print("\n--- Initial Data Summary ---")
print(f"Data shape: {events_df.shape}")
print(f"Full timeframe: {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}")
print("----------------------------\n")
except FileNotFoundError:
print(f"Error: 'events.csv' not found in '{data_folder}'. Please check the path.")
return None, None, None
# --- Split Data ---
sorted_df = events_df.sort_values('timestamp_dt').reset_index(drop=True)
print(f"Splitting data: {test_days} days for test, {val_days} for validation.")
end_time = sorted_df['timestamp_dt'].max()
test_start_time = end_time - timedelta(days=test_days)
val_start_time = test_start_time - timedelta(days=val_days)
test_df = sorted_df[sorted_df['timestamp_dt'] >= test_start_time]
val_df = sorted_df[(sorted_df['timestamp_dt'] >= val_start_time) & (sorted_df['timestamp_dt'] < test_start_time)]
train_df = sorted_df[sorted_df['timestamp_dt'] < val_start_time]
print("--- Data Splitting Summary ---")
print(f"Training set: {train_df.shape[0]:>8} records | from {train_df['timestamp_dt'].min()} to {train_df['timestamp_dt'].max()}")
print(f"Validation set: {val_df.shape[0]:>8} records | from {val_df['timestamp_dt'].min()} to {val_df['timestamp_dt'].max()}")
print(f"Test set: {test_df.shape[0]:>8} records | from {test_df['timestamp_dt'].min()} to {test_df['timestamp_dt'].max()}")
print("------------------------------")
return train_df, val_df, test_df
class SASRecDataset(Dataset):
"""
SASRec Dataset.
- Precomputes (sequence_id, cutoff_idx) pairs for O(1) __getitem__.
- Supports 'last' or 'all' target modes.
"""
def __init__(self, sequences, max_len, target_mode="last"):
"""
Args:
sequences: list of user sequences (list of item IDs).
max_len: maximum sequence length (padding applied).
target_mode: 'last' (only last prediction) or 'all' (predict at every step).
"""
self.sequences = sequences
self.max_len = max_len
self.target_mode = target_mode
# Build index once
self.index = []
for seq_id, seq in enumerate(sequences):
for i in range(1, len(seq)):
self.index.append((seq_id, i))
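        # Illustration: a sequence [a, b, c] contributes the pairs (seq_id, 1) and
        # (seq_id, 2), i.e. one sample per predictable position, so __getitem__ can
        # slice the prefix and look up its target in O(1).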
def __len__(self):
return len(self.index)
def __getitem__(self, idx):
seq_id, cutoff = self.index[idx]
seq = self.sequences[seq_id][:cutoff]
# Truncate & pad
seq = seq[-self.max_len:]
pad_len = self.max_len - len(seq)
input_seq = np.zeros(self.max_len, dtype=np.int64)
input_seq[pad_len:] = seq
        if self.target_mode == "last":
            target = self.sequences[seq_id][cutoff]
            return torch.LongTensor(input_seq), torch.LongTensor([target])
        elif self.target_mode == "all":
            # Predict the next item at every position of the (shifted) sequence.
            target_seq = self.sequences[seq_id][1:cutoff + 1]
            target_seq = target_seq[-self.max_len:]
            target = np.zeros(self.max_len, dtype=np.int64)
            target[-len(target_seq):] = target_seq
            return torch.LongTensor(input_seq), torch.LongTensor(target)
        else:
            raise ValueError(f"Unknown target_mode: {self.target_mode!r} (expected 'last' or 'all')")
class SASRecDataModule(pl.LightningDataModule):
"""
PyTorch Lightning DataModule for preparing the RetailRocket dataset for the SASRec model.
This class handles all aspects of data preparation, including:
- Filtering out infrequent users and items to reduce noise.
- Building a consistent item vocabulary.
- Converting user event histories into sequential data.
- Creating and providing `DataLoader` instances for training, validation, and testing.
"""
def __init__(self, train_df, val_df, test_df, min_item_interactions=5,
min_user_interactions=5, max_len=50, batch_size=256):
"""
Initializes the DataModule.
Args:
train_df (pd.DataFrame): DataFrame for training.
val_df (pd.DataFrame): DataFrame for validation.
test_df (pd.DataFrame): DataFrame for testing.
min_item_interactions (int): Minimum number of interactions for an item to be kept.
min_user_interactions (int): Minimum number of interactions for a user to be kept.
max_len (int): The maximum length of a user sequence fed to the model.
batch_size (int): The batch size for the DataLoaders.
"""
super().__init__()
self.train_df = train_df
self.val_df = val_df
self.test_df = test_df
self.min_item_interactions = min_item_interactions
self.min_user_interactions = min_user_interactions
self.max_len = max_len
self.batch_size = batch_size
self.item_map = None
self.inverse_item_map = None
self.vocab_size = 0
self.user_history = None
def setup(self, stage=None):
"""
Prepares the data for training, validation, and testing.
This method is called automatically by PyTorch Lightning. It performs the following steps:
1. Determines filtering criteria (which users and items to keep) based on the training set only
to prevent data leakage.
2. Applies these filters to the train, validation, and test sets.
3. Builds an item vocabulary (mapping item IDs to integer indices) from the combined
training and validation sets to ensure consistency for model checkpointing.
4. Converts the event logs into sequences of item indices for each user in each data split.
"""
item_counts = self.train_df['itemid'].value_counts()
user_counts = self.train_df['visitorid'].value_counts()
items_to_keep = item_counts[item_counts >= self.min_item_interactions].index
users_to_keep = user_counts[user_counts >= self.min_user_interactions].index
self.filtered_train_df = self.train_df[
(self.train_df['itemid'].isin(items_to_keep)) &
(self.train_df['visitorid'].isin(users_to_keep))
].copy()
self.filtered_val_df = self.val_df[
(self.val_df['itemid'].isin(items_to_keep)) &
(self.val_df['visitorid'].isin(users_to_keep))
].copy()
self.filtered_test_df = self.test_df[
(self.test_df['itemid'].isin(items_to_keep)) &
(self.test_df['visitorid'].isin(users_to_keep))
].copy()
all_known_items_df = pd.concat([self.filtered_train_df, self.filtered_val_df])
unique_items = all_known_items_df['itemid'].unique()
self.item_map = {item_id: i + 1 for i, item_id in enumerate(unique_items)}
self.inverse_item_map = {i: item_id for item_id, i in self.item_map.items()}
self.vocab_size = len(self.item_map) + 1 # +1 for padding token 0
self.user_history = self.filtered_train_df.groupby('visitorid')['itemid'].apply(list)
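        # user_history maps each visitorid to its chronologically ordered list of raw
        # item IDs from the training split (the splits were sorted by timestamp), kept
        # for downstream use such as filtering already-seen items at recommendation time.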
self.train_sequences = self._create_sequences(self.filtered_train_df)
self.val_sequences = self._create_sequences(self.filtered_val_df)
self.test_sequences = self._create_sequences(self.filtered_test_df)
def _create_sequences(self, df):
"""
Helper function to convert a DataFrame of events into user interaction sequences.
Args:
df (pd.DataFrame): The input DataFrame to process.
Returns:
list[list[int]]: A list of user sequences, where each sequence is a list of item indices.
"""
df_sorted = df.sort_values(['visitorid', 'timestamp_dt'])
sequences = df_sorted.groupby('visitorid')['itemid'].apply(
lambda x: [self.item_map[i] for i in x if i in self.item_map]
).tolist()
return [s for s in sequences if len(s) > 1]
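    # Illustration (hypothetical raw IDs): if visitor 42 viewed items [10500, 8827, 10500]
    # and all three are in item_map, the resulting sequence might be [12, 87, 12]; visitors
    # left with fewer than two mapped items are dropped, as they yield no (input, target) pair.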
def train_dataloader(self):
"""Creates the DataLoader for the training set."""
dataset = SASRecDataset(self.train_sequences, self.max_len)
return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)
def val_dataloader(self):
"""Creates the DataLoader for the validation set."""
dataset = SASRecDataset(self.val_sequences, self.max_len)
return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
def test_dataloader(self):
"""Creates the DataLoader for the test set."""
dataset = SASRecDataset(self.test_sequences, self.max_len)
return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
if __name__ == "__main__":
# --- Configuration ---
DATA_PATH = "data"
ZIPED_DATA_PATH = "data/archive.zip" # change to your zip file path
BATCH_SIZE = 256
MAX_TOKEN_LEN = 50 # 50–100 is standard for SASRec
# extract_ziped_data(ZIPED_DATA_PATH, DATA_PATH) # uncomment this line if you want to extract the data
# --- 1. Prepare the data into train, validation, and test sets ---
    train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)
    if train_set is None:
        raise FileNotFoundError(f"Could not load events.csv from '{DATA_PATH}'. Extract the dataset first (see extract_ziped_data).")
# --- 2. Initialize DataModule ---
print("Initializing DataModule...")
datamodule = SASRecDataModule(
train_df=train_set,
val_df=validation_set,
test_df=test_set,
batch_size=BATCH_SIZE,
max_len=MAX_TOKEN_LEN
)
    datamodule.setup()
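    # Quick sanity check (assumes events.csv was present and setup() succeeded):
    # pull a single training batch and report its shapes.
    print(f"Item vocabulary size (incl. padding index 0): {datamodule.vocab_size}")
    train_loader = datamodule.train_dataloader()
    batch_inputs, batch_targets = next(iter(train_loader))
    print(f"Sample batch -- inputs: {tuple(batch_inputs.shape)}, targets: {tuple(batch_targets.shape)}")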