---
title: SASRec Sequential Recommender
colorFrom: blue
colorTo: indigo
sdk: gradio
app_file: scripts/app.py
---
End-to-End Sequential Recommender System
This project implements and evaluates a series of recommender system models, culminating in a state-of-the-art SASRec (Self-Attentive Sequential Recommendation) model for Top-N next-item prediction. The system is trained on the RetailRocket e-commerce dataset and includes an interactive web demo built with Gradio.
The Gradio app is available here.
Table of Contents
- Project Overview
- Key Features
- Models Implemented
- Final Results
- Qualitative Analysis
- Future Improvements
- Project Structure
- Setup and Usage
- Technologies and Models Used
Project Overview
The primary goal of this project is to predict the next item a user is likely to interact with based on their recent session history. This is a common and critical task in e-commerce known as Top-N sequential recommendation.
The project follows a structured approach:
- Baseline Models: Simple, non-sequential models to establish a performance baseline.
- Hyperparameter Tuning: Optuna is used to search for the best ALS hyperparameters.
- Advanced Sequential Model: Implementation of SASRec with PyTorch Lightning.
- Evaluation: Offline evaluation using ranking metrics (Hit Rate@10, Precision@10, Recall@10).
- Interactive Demo: A Gradio web app for real-time personalized and cold-start recommendations.
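The ranking metrics listed above are computed per user and then averaged. A minimal sketch using the standard definitions, assuming recommendations and held-out test items are available as plain dictionaries keyed by user id; `metrics_at_k` and `average_metrics` are hypothetical helper names, and the project's own evaluation code may handle details (such as users without test items) differently.

```python
from typing import Dict, Iterable, List, Sequence


def metrics_at_k(recommended: Sequence[int], relevant: Iterable[int], k: int = 10) -> Dict[str, float]:
    """Hit Rate, Precision and Recall @k for one user."""
    top_k = list(recommended)[:k]
    relevant_set = set(relevant)
    hits = sum(1 for item in top_k if item in relevant_set)
    return {
        "hit_rate": 1.0 if hits > 0 else 0.0,  # at least one relevant item in the top k
        "precision": hits / k,  # share of the k slots filled with relevant items
        "recall": hits / len(relevant_set) if relevant_set else 0.0,  # share of relevant items retrieved
    }


def average_metrics(
    recs: Dict[int, List[int]], ground_truth: Dict[int, List[int]], k: int = 10
) -> Dict[str, float]:
    """Average the per-user metrics over users that have at least one ground-truth item."""
    users = [u for u, items in ground_truth.items() if items]
    totals = {"hit_rate": 0.0, "precision": 0.0, "recall": 0.0}
    if not users:
        return totals
    for u in users:
        for name, value in metrics_at_k(recs.get(u, []), ground_truth[u], k).items():
            totals[name] += value
    return {name: value / len(users) for name, value in totals.items()}


if __name__ == "__main__":
    recs = {1: [5, 3, 9], 2: [1, 2, 4]}
    truth = {1: [9], 2: [7, 8]}
    print(average_metrics(recs, truth, k=3))
```

Precision divides by k even when fewer than k items are returned, which is the usual convention; Recall divides by the number of relevant test items for that user.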
Key Features
- Comprehensive Model Comparison: From a popularity baseline to Transformer-based SASRec.
- Robust Evaluation: Time-based data split for realistic performance measurement.
- Hyperparameter Optimization: Automated with Optuna for ALS.
- Deep Learning with Attention: Full PyTorch Lightning implementation of SASRec.
- Interactive Web Demo: Live Gradio app for recommendations.
- Modular Codebase: Clean, organized structure.
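A time-based split keeps evaluation realistic by ensuring the model never trains on interactions that happen after the ones it is tested on. A minimal sketch, assuming events are `(visitor_id, item_id, timestamp)` tuples and a single global cutoff; the tuple layout, the hypothetical `time_based_split` helper and the 20% default are illustrative, not the project's actual settings.

```python
from typing import List, Tuple

# One interaction: (visitor_id, item_id, timestamp). This layout is an assumption for illustration.
Event = Tuple[int, int, int]


def time_based_split(events: List[Event], test_fraction: float = 0.2) -> Tuple[List[Event], List[Event]]:
    """Split at one global timestamp so every test event is strictly later than every training event."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    if not events:
        return [], []
    ordered = sorted(events, key=lambda e: e[2])
    # Timestamp at which the most recent `test_fraction` of events begins.
    cutoff_ts = ordered[int(len(ordered) * (1.0 - test_fraction))][2]
    train = [e for e in ordered if e[2] < cutoff_ts]
    test = [e for e in ordered if e[2] >= cutoff_ts]
    return train, test
```

Events that share the cutoff timestamp all go to the test side, so the realized test share can be slightly above `test_fraction`. One would typically also drop test events for items never seen in training, since ID-based models such as ALS and SASRec cannot score them.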
Models Implemented
| Model | Methodology | Key Characteristics |
|---|---|---|
| Popularity | Non-personalized | Recommends the most frequently purchased items across all users. |
| Item-Item CF | Collaborative Filtering | Recommends items similar to a user's past interactions. |
| ALS | Matrix Factorization | Learns latent embeddings from implicit feedback, tuned with Optuna. |
| SASRec | Transformer (Self-Attention) | Sequential model that attends over the user's ordered interaction history to predict the next item. |
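SASRec learns from the order of interactions: each position in a user's sequence is trained to predict the item that comes next, and a causal mask prevents a position from attending to later items. The sketch below prepares one training example in the way described in the original SASRec paper (keep the most recent items, left-pad, shift targets by one). The helper names, `PAD = 0`, and `max_len = 50` are assumptions, and the project's `models.py` may organize this differently.

```python
from typing import List, Tuple

PAD = 0  # padding id; real item ids are assumed to be re-indexed to start at 1


def build_sasrec_example(item_seq: List[int], max_len: int = 50) -> Tuple[List[int], List[int]]:
    """Turn one user's chronologically ordered items into (inputs, targets) for next-item training.

    targets[t] is the item that follows inputs[t], so every position provides a training signal.
    Only the most recent max_len + 1 items are kept, and both lists are left-padded with PAD.
    """
    seq = item_seq[-(max_len + 1):]
    inputs, targets = seq[:-1], seq[1:]
    pad = [PAD] * (max_len - len(inputs))
    return pad + inputs, pad + targets


def causal_mask(length: int) -> List[List[bool]]:
    """mask[i][j] is True when position i may attend to position j (only j <= i)."""
    return [[j <= i for j in range(length)] for i in range(length)]
```

Positions whose target is `PAD` should be excluded from the loss. Attention APIs disagree on whether `True` in a boolean mask means "may attend" or "blocked", so this mask may need inverting before it is passed to a particular layer.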
Final Results
SASRec outperformed all baselines on every metric, with a Hit Rate@10 about 4.7x that of the strongest baseline (Popularity).
| Model | Test Hit Rate@10 | Test Precision@10 | Test Recall@10 |
|---|---|---|---|
| Popularity | 0.0651 | 0.0065 | 0.0324 |
| Item-Item CF | 0.0021 | 0.0002 | 0.0011 |
| Tuned ALS | 0.0063 | 0.0006 | 0.0042 |
| SASRec | 0.3069 | 0.0307 | 0.3069 |
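Two quick checks on the table. The ~4.7x figure is the ratio of SASRec's Hit Rate@10 to that of the strongest baseline, Popularity. For SASRec, Recall@10 equals Hit Rate@10 and Precision@10 is one tenth of it, which is exactly what happens if each evaluated user has a single held-out next item (an assumption; the evaluation protocol is not spelled out here): a hit then counts as 1 toward recall and 1/10 toward precision. Here $\mathcal{R}_u$ denotes user $u$'s set of relevant test items.

$$\frac{\text{HR@10}_{\text{SASRec}}}{\text{HR@10}_{\text{Popularity}}} = \frac{0.3069}{0.0651} \approx 4.71$$

$$|\mathcal{R}_u| = 1 \;\Rightarrow\; \text{Recall@10} = \text{HR@10} = 0.3069, \qquad \text{Precision@10} = \frac{\text{HR@10}}{10} \approx 0.0307$$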
Qualitative Analysis
Beyond re-recommending previously viewed items, SASRec also surfaces new, contextually relevant items.
For example, for a user browsing Category 1279, SASRec suggested items from the same category that the user had not yet viewed, combining personalization with discovery.
Future Improvements
- Incorporate Item Features (e.g., from item_properties.csv).
- Explore Advanced Models:
  - BERT4Rec (bidirectional Transformers).
  - Graph-based recommender systems.
- Online A/B Testing to measure business impact.
- Scalability Enhancements: feature stores, inference servers (e.g., Triton), quantization, and distillation.
Project Structure
```
├── checkpoints/             # Saved PyTorch Lightning checkpoints
├── data/                    # RetailRocket dataset
├── notebooks/               # EDA notebooks
├── scripts/
│   ├── als_optuna_study.py  # Optuna tuning for ALS
│   ├── app.py               # Gradio web demo
│   ├── data_prepare.py      # Data loading & preprocessing
│   ├── main.py              # Entry point for demo
│   ├── models.py            # Model definitions
│   ├── train_and_eval.py    # Training & evaluation loop
│   └── utils.py             # Helper functions
├── README.md
└── requirements.txt
```
Setup and Usage
Follow these steps to set up and run the project locally.
1. Prerequisites
- Python 3.10.6+
- An NVIDIA GPU is recommended for training the SASRec model.
2. Clone the Repository
```bash
git clone <your-repo-url>
cd <your-repo-name>
```
3. Install the Required Packages
```bash
pip install -r requirements.txt
```
4. Download and Place the Data
- Download the RetailRocket e-commerce dataset and place its files in the data/ directory.

Then run the preprocessing script:
```bash
python scripts/data_prepare.py
```
5. Run the Full Evaluation
To train all models and print the final comparison table, run the training and evaluation script:
```bash
python scripts/train_and_eval.py
```
6. Launch the Demo
```bash
python scripts/main.py
```
Technologies and Models Used
This project uses the following models, libraries, and tools.
Models
- Popularity Model: A non-personalized baseline that recommends the most frequently purchased items.
- Item-Item Collaborative Filtering: A classical neighborhood-based model that recommends items based on co-occurrence patterns with a user's interaction history.
- Alternating Least Squares (ALS): A matrix factorization technique for implicit feedback, with hyperparameters tuned by Optuna.
- SASRec (Self-Attentive Sequential Recommendation): A state-of-the-art sequential model based on the Transformer architecture, designed to capture the order and context of user interactions.
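Item-item collaborative filtering scores a candidate item by how similar it is to the items already in a user's history. The project uses scikit-learn for the similarity step; the sketch below instead computes cosine similarity on binary user-item vectors in plain Python so it runs without dependencies, and pads the list from a popularity ranking when a user has too little history. The function names and the popularity fallback are assumptions for illustration, not the project's code.

```python
import math
from collections import Counter, defaultdict
from typing import Dict, List, Set


def item_cosine_similarities(user_items: Dict[int, Set[int]]) -> Dict[int, Dict[int, float]]:
    """Cosine similarity between binary item vectors built from which users touched each item."""
    item_counts: Counter = Counter()
    co_counts: Dict[int, Counter] = defaultdict(Counter)
    for items in user_items.values():
        for i in items:
            item_counts[i] += 1
            for j in items:
                if i != j:
                    co_counts[i][j] += 1  # number of users who interacted with both i and j
    return {
        i: {j: c / math.sqrt(item_counts[i] * item_counts[j]) for j, c in neighbours.items()}
        for i, neighbours in co_counts.items()
    }


def recommend(
    history: Set[int], sims: Dict[int, Dict[int, float]], popular: List[int], n: int = 10
) -> List[int]:
    """Score unseen items by summed similarity to the history, then pad with popular items."""
    scores: Counter = Counter()
    for i in history:
        for j, s in sims.get(i, {}).items():
            if j not in history:
                scores[j] += s
    ranked = [item for item, _ in scores.most_common(n)]
    # Cold-start users (empty history) end up with a pure popularity list.
    for item in popular:
        if len(ranked) >= n:
            break
        if item not in history and item not in ranked:
            ranked.append(item)
    return ranked
```

Here `popular` could be built as `[item for item, _ in Counter(purchased_items).most_common()]` (with `purchased_items` a hypothetical list of purchased item ids), which mirrors the ranking the Popularity baseline uses.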
Core Technologies & Libraries
- Python 3.10: The primary programming language for the project.
- Pandas & NumPy: For efficient data manipulation, preprocessing, and numerical operations.
- Scikit-learn: Used for calculating item similarity in the collaborative filtering model.
- Implicit: Provides the ALS implementation.
- PyTorch & PyTorch Lightning: The deep learning framework used to build, train, and evaluate the SASRec model in a structured and scalable way.
- Optuna: A hyperparameter optimization framework used to automatically find the best parameters for the ALS model.
- Gradio: A fast and simple framework used to build and deploy the interactive web demo.
- TensorBoard: For logging and visualizing model training metrics.
