import zipfile import pandas as pd from datetime import datetime, timedelta import numpy as np from scipy.sparse import csr_matrix import math import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader import pytorch_lightning as pl def extract_ziped_data(ziped_data_path: str, extract_path : str): """Extracts the contents of a zip file to a specified directory. args: ziped_data_path: str, path to the zip file extract_path: str, path to the directory where contents will be extracted """ # The directory where you want to extract the contents extract_path = 'data' # Open the zip file in read mode with zipfile.ZipFile(ziped_data_path, 'r') as zip_ref: # Extract all the contents into the specified directory zip_ref.extractall(extract_path) print(f"'{ziped_data_path}' has been extracted to '{extract_path}'") def prepare_data(data_folder='data/', val_days=7, test_days=7): """ Loads, preprocesses, and splits the events data into train, validation, and test sets. args: data_folder: str, path to the folder containing 'events.csv' val_days: int, number of days for the validation set test_days: int, number of days for the test set """ # --- Load Data --- print(f"Loading events.csv from folder: {data_folder}") try: events_df = pd.read_csv(data_folder + 'events.csv') print("Successfully loaded events.csv.") events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms') print("\n--- Initial Data Summary ---") print(f"Data shape: {events_df.shape}") print(f"Full timeframe: {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}") print("----------------------------\n") except FileNotFoundError: print(f"Error: 'events.csv' not found in '{data_folder}'. Please check the path.") return None, None, None # --- Split Data --- sorted_df = events_df.sort_values('timestamp_dt').reset_index(drop=True) print(f"Splitting data: {test_days} days for test, {val_days} for validation.") end_time = sorted_df['timestamp_dt'].max() test_start_time = end_time - timedelta(days=test_days) val_start_time = test_start_time - timedelta(days=val_days) test_df = sorted_df[sorted_df['timestamp_dt'] >= test_start_time] val_df = sorted_df[(sorted_df['timestamp_dt'] >= val_start_time) & (sorted_df['timestamp_dt'] < test_start_time)] train_df = sorted_df[sorted_df['timestamp_dt'] < val_start_time] print("--- Data Splitting Summary ---") print(f"Training set: {train_df.shape[0]:>8} records | from {train_df['timestamp_dt'].min()} to {train_df['timestamp_dt'].max()}") print(f"Validation set: {val_df.shape[0]:>8} records | from {val_df['timestamp_dt'].min()} to {val_df['timestamp_dt'].max()}") print(f"Test set: {test_df.shape[0]:>8} records | from {test_df['timestamp_dt'].min()} to {test_df['timestamp_dt'].max()}") print("------------------------------") return train_df, val_df, test_df class SASRecDataset(Dataset): """ SASRec Dataset. - Precomputes (sequence_id, cutoff_idx) pairs for O(1) __getitem__. - Supports 'last' or 'all' target modes. """ def __init__(self, sequences, max_len, target_mode="last"): """ Args: sequences: list of user sequences (list of item IDs). max_len: maximum sequence length (padding applied). target_mode: 'last' (only last prediction) or 'all' (predict at every step). """ self.sequences = sequences self.max_len = max_len self.target_mode = target_mode # Build index once self.index = [] for seq_id, seq in enumerate(sequences): for i in range(1, len(seq)): self.index.append((seq_id, i)) def __len__(self): return len(self.index) def __getitem__(self, idx): seq_id, cutoff = self.index[idx] seq = self.sequences[seq_id][:cutoff] # Truncate & pad seq = seq[-self.max_len:] pad_len = self.max_len - len(seq) input_seq = np.zeros(self.max_len, dtype=np.int64) input_seq[pad_len:] = seq if self.target_mode == "last": target = self.sequences[seq_id][cutoff] return torch.LongTensor(input_seq), torch.LongTensor([target]) elif self.target_mode == "all": # Predict next item at each step target_seq = self.sequences[seq_id][1:cutoff+1] target_seq = target_seq[-self.max_len:] target = np.zeros(self.max_len, dtype=np.int64) target[-len(target_seq):] = target_seq return torch.LongTensor(input_seq), torch.LongTensor(target) class SASRecDataModule(pl.LightningDataModule): """ PyTorch Lightning DataModule for preparing the RetailRocket dataset for the SASRec model. This class handles all aspects of data preparation, including: - Filtering out infrequent users and items to reduce noise. - Building a consistent item vocabulary. - Converting user event histories into sequential data. - Creating and providing `DataLoader` instances for training, validation, and testing. """ def __init__(self, train_df, val_df, test_df, min_item_interactions=5, min_user_interactions=5, max_len=50, batch_size=256): """ Initializes the DataModule. Args: train_df (pd.DataFrame): DataFrame for training. val_df (pd.DataFrame): DataFrame for validation. test_df (pd.DataFrame): DataFrame for testing. min_item_interactions (int): Minimum number of interactions for an item to be kept. min_user_interactions (int): Minimum number of interactions for a user to be kept. max_len (int): The maximum length of a user sequence fed to the model. batch_size (int): The batch size for the DataLoaders. """ super().__init__() self.train_df = train_df self.val_df = val_df self.test_df = test_df self.min_item_interactions = min_item_interactions self.min_user_interactions = min_user_interactions self.max_len = max_len self.batch_size = batch_size self.item_map = None self.inverse_item_map = None self.vocab_size = 0 self.user_history = None def setup(self, stage=None): """ Prepares the data for training, validation, and testing. This method is called automatically by PyTorch Lightning. It performs the following steps: 1. Determines filtering criteria (which users and items to keep) based on the training set only to prevent data leakage. 2. Applies these filters to the train, validation, and test sets. 3. Builds an item vocabulary (mapping item IDs to integer indices) from the combined training and validation sets to ensure consistency for model checkpointing. 4. Converts the event logs into sequences of item indices for each user in each data split. """ item_counts = self.train_df['itemid'].value_counts() user_counts = self.train_df['visitorid'].value_counts() items_to_keep = item_counts[item_counts >= self.min_item_interactions].index users_to_keep = user_counts[user_counts >= self.min_user_interactions].index self.filtered_train_df = self.train_df[ (self.train_df['itemid'].isin(items_to_keep)) & (self.train_df['visitorid'].isin(users_to_keep)) ].copy() self.filtered_val_df = self.val_df[ (self.val_df['itemid'].isin(items_to_keep)) & (self.val_df['visitorid'].isin(users_to_keep)) ].copy() self.filtered_test_df = self.test_df[ (self.test_df['itemid'].isin(items_to_keep)) & (self.test_df['visitorid'].isin(users_to_keep)) ].copy() all_known_items_df = pd.concat([self.filtered_train_df, self.filtered_val_df]) unique_items = all_known_items_df['itemid'].unique() self.item_map = {item_id: i + 1 for i, item_id in enumerate(unique_items)} self.inverse_item_map = {i: item_id for item_id, i in self.item_map.items()} self.vocab_size = len(self.item_map) + 1 # +1 for padding token 0 self.user_history = self.filtered_train_df.groupby('visitorid')['itemid'].apply(list) self.train_sequences = self._create_sequences(self.filtered_train_df) self.val_sequences = self._create_sequences(self.filtered_val_df) self.test_sequences = self._create_sequences(self.filtered_test_df) def _create_sequences(self, df): """ Helper function to convert a DataFrame of events into user interaction sequences. Args: df (pd.DataFrame): The input DataFrame to process. Returns: list[list[int]]: A list of user sequences, where each sequence is a list of item indices. """ df_sorted = df.sort_values(['visitorid', 'timestamp_dt']) sequences = df_sorted.groupby('visitorid')['itemid'].apply( lambda x: [self.item_map[i] for i in x if i in self.item_map] ).tolist() return [s for s in sequences if len(s) > 1] def train_dataloader(self): """Creates the DataLoader for the training set.""" dataset = SASRecDataset(self.train_sequences, self.max_len) return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0) def val_dataloader(self): """Creates the DataLoader for the validation set.""" dataset = SASRecDataset(self.val_sequences, self.max_len) return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0) def test_dataloader(self): """Creates the DataLoader for the test set.""" dataset = SASRecDataset(self.test_sequences, self.max_len) return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0) if __name__ == "__main__": # --- Configuration --- DATA_PATH = "data" ZIPED_DATA_PATH = "data/archive.zip" # change to your zip file path BATCH_SIZE = 256 MAX_TOKEN_LEN = 50 # 50–100 is standard for SASRec # extract_ziped_data(ZIPED_DATA_PATH, DATA_PATH) # uncomment this line if you want to extract the data # --- 1. Prepare the data into train, validation, and test sets --- train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH) # --- 2. Initialize DataModule --- print("Initializing DataModule...") datamodule = SASRecDataModule( train_df=train_set, val_df=validation_set, test_df=test_set, batch_size=BATCH_SIZE, max_len=MAX_TOKEN_LEN ) datamodule.setup()