Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor | |
| from pytorch_lightning.loggers import TensorBoardLogger | |
| import torch | |
| from utils import prepare_ground_truth, calculate_metrics | |
| from data_prepare import prepare_data, SASRecDataset, SASRecDataModule | |
| from models import recommend_popular_items_and_evaluate, recommend_item_item_and_evaluate, recommend_als_and_evaluate, SASRec | |
def train_and_eval_SASRec_model(train_set, validation_set, test_set, checkpoint_dir_path='checkpoints/',
                                checkpoint_path=None, n_epochs=10, mode='train',
                                batchsize=256, max_token_len=50, learning_rate=1e-3, hidden_dim=128,
                                num_heads=2, num_layers=2, dropout=0.2, weight_decay=1e-6):
    """
    Train or evaluate a SASRec model with a PyTorch Lightning ``Trainer``.

    The function builds a ``SASRecDataModule`` from the three datasets, builds a
    ``SASRec`` model sized from ``datamodule.vocab_size``, configures
    checkpointing / early stopping / learning-rate monitoring (all keyed on the
    ``val_hitrate@10`` metric, maximised), and then either fits the model
    (``mode='train'``) or runs the test loop (``mode='test'``).

    Parameters
    ----------
    train_set : pd.DataFrame
        Training interactions, passed to the datamodule as ``train_df``.
    validation_set : pd.DataFrame
        Validation interactions, passed to the datamodule as ``val_df``.
    test_set : pd.DataFrame
        Test interactions, passed to the datamodule as ``test_df``.
    checkpoint_dir_path : str, optional (default='checkpoints/')
        Directory in which ``ModelCheckpoint`` stores the best checkpoint.
    checkpoint_path : str or None, optional (default=None)
        Checkpoint forwarded as ``ckpt_path`` to ``trainer.fit`` (resume) or
        ``trainer.test`` (weights to evaluate). With ``None`` in test mode the
        freshly initialised model is evaluated and a warning is issued.
    n_epochs : int, optional (default=10)
        Upper bound on training epochs (``max_epochs``); early stopping with a
        patience of 5 may end training sooner.
    mode : {'train', 'test'}, optional (default='train')
        ``'train'`` fits the model; ``'test'`` runs the test loop.
    batchsize : int, optional (default=256)
        Batch size handed to the datamodule.
    max_token_len : int, optional (default=50)
        Maximum sequence length, handed to both the datamodule and the model
        as ``max_len``.
    learning_rate, hidden_dim, num_heads, num_layers, dropout, weight_decay
        Hyper-parameters forwarded unchanged to the ``SASRec`` constructor.

    Returns
    -------
    str or list
        In train mode, ``checkpoint_callback.best_model_path`` (the path of the
        best checkpoint written during this run). In test mode, whatever
        ``trainer.test`` returns.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    # Reject unknown modes before doing any expensive setup; previously an
    # unrecognised mode built everything and then silently did nothing.
    if mode not in ('train', 'test'):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    if mode == 'test' and checkpoint_path is None:
        import warnings
        # Without a checkpoint the test loop scores the newly constructed
        # model below, i.e. weights that have never been trained.
        warnings.warn(
            "mode='test' was requested without checkpoint_path; the test metrics "
            "will describe a freshly initialised (untrained) SASRec model.",
            stacklevel=2,
        )

    # --- 1. Initialize DataModule ---
    print("Initializing DataModule...")
    datamodule = SASRecDataModule(
        train_df=train_set,
        val_df=validation_set,
        test_df=test_set,
        batch_size=batchsize,
        max_len=max_token_len,
    )
    # setup() is called eagerly because vocab_size is read right below.
    datamodule.setup()

    # --- 2. Initialize Model ---
    print("Initializing SASRec model...")
    model = SASRec(
        vocab_size=datamodule.vocab_size,
        max_len=max_token_len,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        dropout=dropout,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )

    # --- 3. Configure Training Callbacks ---
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir_path,
        filename="sasrec-{epoch:02d}-{val_hitrate@10:.4f}",
        save_top_k=1,
        verbose=True,
        monitor="val_hitrate@10",
        mode="max",
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_hitrate@10",  # stop if ranking metric stagnates
        patience=5,
        mode="max",
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    logger = TensorBoardLogger("lightning_logs", name="sasrec")

    # --- 4. Initialize Trainer ---
    print("Initializing PyTorch Lightning Trainer...")
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
        max_epochs=n_epochs,
        accelerator='auto',
        devices=1,
        gradient_clip_val=1.0,  # helps with exploding gradients
    )

    if mode == 'train':
        # --- 5. Start Training ---
        print(f"Starting training for up to {n_epochs} epochs...")
        trainer.fit(model, datamodule, ckpt_path=checkpoint_path)
        # Report where the best checkpoint ended up so callers can test it.
        return checkpoint_callback.best_model_path

    # --- 6. Test using the supplied checkpoint (if any) ---
    print("Evaluating on test set...")
    return trainer.test(model, datamodule, ckpt_path=checkpoint_path)
# --- Main Execution Block ---
# Runs the classical baselines on validation and test data, then trains SASRec
# and evaluates the checkpoint that training produced.
if __name__ == "__main__":
    from pathlib import Path

    # --- Configuration ---
    BATCH_SIZE = 256
    MAX_TOKEN_LEN = 50  # 50–100 is standard
    LEARNING_RATE = 1e-3
    HIDDEN_DIM = 128
    NUM_HEADS = 2
    NUM_LAYERS = 2
    DROPOUT = 0.2
    WEIGHT_DECAY = 1e-6
    N_EPOCHS = 50
    CHECKPOINT_SAVE_PATH = 'checkpoints/'
    CHECKPOINT_LOAD_PATH = None  # or specify a path to a checkpoint file
    MODE = 'train'  # 'train' or 'test'

    # Fail fast on an unusable configuration, before loading any data.
    if MODE not in ('train', 'test'):
        raise ValueError(f"MODE must be 'train' or 'test', got {MODE!r}")
    if MODE == 'test' and CHECKPOINT_LOAD_PATH is None:
        raise ValueError("MODE='test' requires CHECKPOINT_LOAD_PATH to point at a checkpoint file")

    train_set, validation_set, test_set = prepare_data(data_folder='data/')
    if train_set is not None:
        results = {}
        # Final test-set baselines are fitted on train + validation together.
        full_train_set = pd.concat([train_set, validation_set])

        # Evaluate classical models
        print("\n>>> Running evaluations on the VALIDATION set <<<")
        results['Popularity (Validation)'] = recommend_popular_items_and_evaluate(train_set, validation_set)
        results['Item-Item CF (Validation)'] = recommend_item_item_and_evaluate(train_set, validation_set)
        results['ALS (Validation)'] = recommend_als_and_evaluate(train_set, validation_set)

        print("\n>>> Running final evaluations on the TEST set <<<")
        results['Popularity (Test)'] = recommend_popular_items_and_evaluate(full_train_set, test_set)
        results['Item-Item CF (Test)'] = recommend_item_item_and_evaluate(full_train_set, test_set)
        results['ALS (Test)'] = recommend_als_and_evaluate(full_train_set, test_set)

        print("\n--- Final Evaluation Results ---")
        results_df = pd.DataFrame.from_dict(results, orient='index')
        print(results_df)
        print("--------------------------------")

        # One place for every SASRec setting, so the training and the test run
        # are guaranteed to build the same architecture from the config above.
        sasrec_settings = dict(
            checkpoint_dir_path=CHECKPOINT_SAVE_PATH,
            n_epochs=N_EPOCHS,
            batchsize=BATCH_SIZE,
            max_token_len=MAX_TOKEN_LEN,
            learning_rate=LEARNING_RATE,
            hidden_dim=HIDDEN_DIM,
            num_heads=NUM_HEADS,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT,
            weight_decay=WEIGHT_DECAY,
        )

        # In 'test' mode the configured checkpoint is evaluated directly.
        test_checkpoint = CHECKPOINT_LOAD_PATH

        if MODE == 'train':
            # Train and evaluate SASRec model
            print("\n>>> Training and evaluating SASRec model <<<")
            train_and_eval_SASRec_model(
                train_set, validation_set, test_set,
                checkpoint_path=CHECKPOINT_LOAD_PATH,  # resume when a checkpoint is configured
                mode='train',
                **sasrec_settings,
            )

            # The test call below builds a new model, so it must be pointed at
            # the weights that training just saved. Checkpoints are written as
            # "sasrec-…" files into CHECKPOINT_SAVE_PATH; the most recently
            # modified one belongs to the run that has just finished.
            newest = max(
                Path(CHECKPOINT_SAVE_PATH).glob('sasrec-*.ckpt'),
                key=lambda ckpt: ckpt.stat().st_mtime,
                default=None,
            )
            if newest is None:
                raise FileNotFoundError(
                    f"No 'sasrec-*.ckpt' checkpoint found in {CHECKPOINT_SAVE_PATH!r} after training"
                )
            test_checkpoint = str(newest)

        print("\n>>> Evaluating trained SASRec model on TEST set <<<")
        train_and_eval_SASRec_model(
            train_set, validation_set, test_set,
            checkpoint_path=test_checkpoint,
            mode='test',
            **sasrec_settings,
        )