# RecommenderSystem/scripts/train_and_eval.py
# Author: DanielKiani
# Initial commit of recommender system project (38ae75d)
from pathlib import Path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from utils import prepare_ground_truth, calculate_metrics
from data_prepare import prepare_data, SASRecDataset, SASRecDataModule
from models import recommend_popular_items_and_evaluate, recommend_item_item_and_evaluate, recommend_als_and_evaluate, SASRec
def train_and_eval_SASRec_model(train_set, validation_set, test_set, checkpoint_dir_path='checkpoints/',
                                checkpoint_path=None, n_epochs=10, mode='train',
                                batchsize=256, max_token_len=50, learning_rate=1e-3, hidden_dim=128,
                                num_heads=2, num_layers=2, dropout=0.2, weight_decay=1e-6):
    """
    Train or evaluate a SASRec sequential recommendation model using PyTorch Lightning.

    This function wraps the entire SASRec pipeline:
      - Initializes the SASRecDataModule (handles dataset preprocessing and dataloaders).
      - Builds the SASRec Transformer-based model.
      - Configures training callbacks (checkpointing, early stopping, LR monitoring).
      - Runs either training (`mode='train'`) or evaluation on the test set (`mode='test'`).

    Parameters
    ----------
    train_set : pd.DataFrame
        Training interactions dataset.
    validation_set : pd.DataFrame
        Validation dataset with the same structure as `train_set`.
    test_set : pd.DataFrame
        Test dataset with the same structure as `train_set`.
    checkpoint_dir_path : str, optional (default='checkpoints/')
        Directory where model checkpoints are saved (train mode) and, when no
        explicit `checkpoint_path` is given, searched for one (test mode).
    checkpoint_path : str or None, optional (default=None)
        Train mode: checkpoint to resume training from (None = train from scratch).
        Test mode: checkpoint to evaluate. If None, the most recently modified
        ``sasrec-*.ckpt`` file in `checkpoint_dir_path` is used.
    n_epochs : int, optional (default=10)
        Maximum number of training epochs (early stopping may end training sooner).
    mode : {'train', 'test'}, optional (default='train')
        - `'train'`: trains the model on the training/validation data.
        - `'test'`: evaluates a trained checkpoint on the test set.
    batchsize : int, optional (default=256)
        Batch size for training and evaluation.
    max_token_len : int, optional (default=50)
        Maximum sequence length per user (recent interactions kept).
    learning_rate : float, optional (default=1e-3)
        Learning rate passed to the SASRec model's optimizer.
    hidden_dim : int, optional (default=128)
        Dimensionality of item and positional embeddings.
    num_heads : int, optional (default=2)
        Number of attention heads in each Transformer encoder layer.
    num_layers : int, optional (default=2)
        Number of Transformer encoder layers.
    dropout : float, optional (default=0.2)
        Dropout probability passed to the SASRec model.
    weight_decay : float, optional (default=1e-6)
        Weight decay regularization coefficient passed to the SASRec model.

    Returns
    -------
    str or None
        In `'train'` mode: path of the best checkpoint recorded by the
        ModelCheckpoint callback, or None if no checkpoint was saved.
    list of dict
        In `'test'` mode: the metrics returned by ``Trainer.test``.

    Raises
    ------
    ValueError
        If `mode` is neither `'train'` nor `'test'`.
    FileNotFoundError
        In `'test'` mode, if `checkpoint_path` is None and no ``sasrec-*.ckpt``
        file exists in `checkpoint_dir_path`.
    """
    # Validate cheaply before the (expensive) data/model setup below; previously an
    # unknown mode built everything and then silently did nothing.
    if mode not in ('train', 'test'):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    if mode == 'test':
        # A new Trainer given a freshly built model and ckpt_path=None evaluates the
        # randomly initialised weights, so always resolve a real checkpoint first.
        checkpoint_path = _resolve_test_checkpoint(checkpoint_dir_path, checkpoint_path)

    # --- 1. Initialize DataModule ---
    print("Initializing DataModule...")
    datamodule = SASRecDataModule(
        train_df=train_set,
        val_df=validation_set,
        test_df=test_set,
        batch_size=batchsize,
        max_len=max_token_len
    )
    # Called eagerly because the model below needs datamodule.vocab_size.
    datamodule.setup()

    # --- 2. Initialize Model ---
    print("Initializing SASRec model...")
    model = SASRec(
        vocab_size=datamodule.vocab_size,
        max_len=max_token_len,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        dropout=dropout,
        learning_rate=learning_rate,
        weight_decay=weight_decay
    )

    # --- 3. Configure Training Callbacks ---
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir_path,
        filename="sasrec-{epoch:02d}-{val_hitrate@10:.4f}",
        save_top_k=1,
        verbose=True,
        monitor="val_hitrate@10",
        mode="max"
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_hitrate@10",  # stop if ranking metric stagnates
        patience=5,
        mode="max"
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    logger = TensorBoardLogger("lightning_logs", name="sasrec")

    # --- 4. Initialize Trainer ---
    print("Initializing PyTorch Lightning Trainer...")
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
        max_epochs=n_epochs,
        accelerator='auto',
        devices=1,
        gradient_clip_val=1.0,  # helps with exploding gradients
    )

    if mode == 'train':
        # --- 5. Start Training ---
        print(f"Starting training for up to {n_epochs} epochs...")
        trainer.fit(model, datamodule, ckpt_path=checkpoint_path)
        # best_model_path is "" when nothing was saved; normalise that to None.
        return checkpoint_callback.best_model_path or None

    # --- 6. Test on the resolved checkpoint ---
    print("Evaluating on test set...")
    print(f"Using checkpoint: {checkpoint_path}")
    return trainer.test(model, datamodule, ckpt_path=checkpoint_path)


def _resolve_test_checkpoint(checkpoint_dir_path, checkpoint_path):
    """Return the checkpoint to evaluate: the explicit path, else the newest sasrec-*.ckpt in the directory."""
    if checkpoint_path is not None:
        return checkpoint_path
    # Pattern matches the ModelCheckpoint filename template used for training.
    # Path.glob on a missing directory yields nothing, which is handled below.
    candidates = list(Path(checkpoint_dir_path).glob("sasrec-*.ckpt"))
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint_path given and no 'sasrec-*.ckpt' file found in {checkpoint_dir_path!r}; "
            "train the model first or pass checkpoint_path explicitly."
        )
    return str(max(candidates, key=lambda p: p.stat().st_mtime))
# --- Main Execution Block ---
if __name__ == "__main__":
    # --- Configuration ---
    BATCH_SIZE = 256
    MAX_TOKEN_LEN = 50  # 50–100 is standard
    LEARNING_RATE = 1e-3
    HIDDEN_DIM = 128
    NUM_HEADS = 2
    NUM_LAYERS = 2
    DROPOUT = 0.2
    WEIGHT_DECAY = 1e-6
    N_EPOCHS = 50
    CHECKPOINT_SAVE_PATH = 'checkpoints/'
    CHECKPOINT_LOAD_PATH = None  # train: checkpoint to resume from; test: checkpoint to evaluate
    MODE = 'train'  # 'train' (train, then test the best checkpoint) or 'test' (test only)

    # Fail before the slow data preparation if the configuration is invalid.
    if MODE not in ('train', 'test'):
        raise ValueError(f"MODE must be 'train' or 'test', got {MODE!r}")

    train_set, validation_set, test_set = prepare_data(data_folder='data/')
    if train_set is None:
        print("Data preparation returned no training set; skipping evaluation.")
    else:
        results = {}
        # Classical baselines are refit on train + validation before the final test evaluation.
        full_train_set = pd.concat([train_set, validation_set])

        # Evaluate classical models
        print("\n>>> Running evaluations on the VALIDATION set <<<")
        results['Popularity (Validation)'] = recommend_popular_items_and_evaluate(train_set, validation_set)
        results['Item-Item CF (Validation)'] = recommend_item_item_and_evaluate(train_set, validation_set)
        results['ALS (Validation)'] = recommend_als_and_evaluate(train_set, validation_set)

        print("\n>>> Running final evaluations on the TEST set <<<")
        results['Popularity (Test)'] = recommend_popular_items_and_evaluate(full_train_set, test_set)
        results['Item-Item CF (Test)'] = recommend_item_item_and_evaluate(full_train_set, test_set)
        results['ALS (Test)'] = recommend_als_and_evaluate(full_train_set, test_set)

        print("\n--- Final Evaluation Results ---")
        results_df = pd.DataFrame.from_dict(results, orient='index')
        print(results_df)
        print("--------------------------------")

        # Hyperparameters shared by the SASRec train and test runs: the test run
        # rebuilds the model, so its architecture must match the checkpoint's.
        sasrec_config = dict(
            checkpoint_dir_path=CHECKPOINT_SAVE_PATH,
            batchsize=BATCH_SIZE,
            max_token_len=MAX_TOKEN_LEN,
            learning_rate=LEARNING_RATE,
            hidden_dim=HIDDEN_DIM,
            num_heads=NUM_HEADS,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT,
            weight_decay=WEIGHT_DECAY,
        )

        test_checkpoint = CHECKPOINT_LOAD_PATH
        if MODE == 'train':
            # Train and evaluate SASRec model
            print("\n>>> Training and evaluating SASRec model <<<")
            best_checkpoint = train_and_eval_SASRec_model(
                train_set, validation_set, test_set,
                checkpoint_path=CHECKPOINT_LOAD_PATH,  # resume from here if set
                n_epochs=N_EPOCHS,
                mode='train',
                **sasrec_config,
            )
            # Evaluate the model that was just trained, not the resume checkpoint.
            test_checkpoint = best_checkpoint or None

        print("\n>>> Evaluating trained SASRec model on TEST set <<<")
        train_and_eval_SASRec_model(
            train_set, validation_set, test_set,
            checkpoint_path=test_checkpoint,
            mode='test',
            **sasrec_config,
        )