RecommenderSystem / README.md
Daniel kiani

metadata
title: SASRec Sequential Recommender
emoji: 🛍️
colorFrom: blue
colorTo: indigo
sdk: gradio
app_file: scripts/app.py

Python | PyTorch | Made with ML | License: MIT

🚀 End-to-End Sequential Recommender System

This project implements and evaluates a series of recommender system models, culminating in a state-of-the-art SASRec (Self-Attentive Sequential Recommendation) model for Top-N next-item prediction. The system is trained on the RetailRocket e-commerce dataset and includes an interactive web demo built with Gradio.

Gradio app: a live demo of the recommender is hosted as a Gradio app.



📖 Project Overview

The primary goal of this project is to predict the next item a user is likely to interact with based on their recent session history. This is a common and critical task in e-commerce known as Top-N sequential recommendation.

The project follows a structured approach:

  1. Baseline Models: Simple, non-sequential models to establish a performance baseline.
  2. Hyperparameter Tuning: Optuna is used to search for a well-performing configuration for ALS.
  3. Advanced Sequential Model: Implementation of SASRec with PyTorch Lightning.
  4. Evaluation: Offline evaluation using ranking metrics (Hit Rate@10, Precision@10, Recall@10).
  5. Interactive Demo: A Gradio web app for real-time personalized and cold-start recommendations.

✨ Key Features

  • 🔹 Comprehensive Model Comparison: From popularity to Transformer-based SASRec.
  • 🔹 Robust Evaluation: Time-based data split for realistic performance measurement.
  • 🔹 Hyperparameter Optimization: Automated with Optuna for ALS.
  • 🔹 Deep Learning with Attention: Full PyTorch Lightning implementation of SASRec.
  • 🔹 Interactive Web Demo: Live Gradio app for recommendations.
  • 🔹 Modular Codebase: Clean, organized structure.
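
The time-based split mentioned above can be illustrated with a minimal sketch. It assumes each interaction is a `(timestamp, visitor_id, item_id)` record and that the most recent fraction of events is held out; the `Event` type, the function name, and `test_fraction` are hypothetical, and the project's actual split in `data_prepare.py` may differ.

```python
from typing import NamedTuple


class Event(NamedTuple):
    timestamp: int   # e.g. Unix time; only the ordering matters here
    visitor_id: int
    item_id: int


def time_based_split(events: list[Event], test_fraction: float = 0.2):
    """Split chronologically: the most recent events become the test set."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    ordered = sorted(events, key=lambda e: e.timestamp)
    n_test = max(1, round(len(ordered) * test_fraction))
    cutoff = len(ordered) - n_test
    return ordered[:cutoff], ordered[cutoff:]


events = [
    Event(50, 1, 101),
    Event(10, 1, 100),
    Event(30, 2, 200),
    Event(20, 2, 100),
    Event(40, 1, 300),
]
train, test = time_based_split(events, test_fraction=0.2)
print(len(train), len(test))  # 4 1
```

Splitting by time rather than at random keeps future interactions out of the training set, which is what makes the offline numbers closer to what a deployed model would see.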

🧩 Models Implemented

| Model | Methodology | Key Characteristics |
| --- | --- | --- |
| Popularity | Non-personalized | Recommends the most frequently purchased items across all users. |
| Item-Item CF | Collaborative Filtering | Recommends items similar to a user's past interactions. |
| ALS | Matrix Factorization | Learns latent embeddings from implicit feedback, tuned with Optuna. |
| SASRec | Transformer (Self-Attention) | Sequential model capturing contextual user-item interactions. |
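
To make the sequential setup concrete, here is a minimal sketch of how a session history is commonly turned into SASRec-style training inputs: the sequence is truncated to the most recent `max_len` items, left-padded with a padding ID, and the target at each position is the item that follows it. The function name, `max_len`, and `PAD_ID = 0` are illustrative assumptions, not taken from this project's `models.py`.

```python
PAD_ID = 0  # assumed padding index; real item IDs are assumed to start at 1


def build_training_pair(history: list[int], max_len: int) -> tuple[list[int], list[int]]:
    """Return (inputs, targets), where targets[i] is the item that follows inputs[i]."""
    if len(history) < 2:
        raise ValueError("need at least two interactions to form an input/target pair")
    inputs = history[:-1][-max_len:]   # everything except the last item
    targets = history[1:][-max_len:]   # the same sequence shifted by one step
    pad = [PAD_ID] * (max_len - len(inputs))
    return pad + inputs, pad + targets


inputs, targets = build_training_pair([11, 42, 7, 93], max_len=5)
print(inputs)   # [0, 0, 11, 42, 7]
print(targets)  # [0, 0, 42, 7, 93]
```

In a SASRec model, a causal attention mask then ensures that each position only attends to earlier positions, and padded positions are typically excluded from the loss.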

📊 Final Results

SASRec outperformed all baselines on every metric. Its test Hit Rate@10 of 0.3069 is about 4.7x that of the strongest baseline, Popularity (0.0651).

| Model | Test Hit Rate@10 | Test Precision@10 | Test Recall@10 |
| --- | --- | --- | --- |
| Popularity | 0.0651 | 0.0065 | 0.0324 |
| Item-Item CF | 0.0021 | 0.0002 | 0.0011 |
| Tuned ALS | 0.0063 | 0.0006 | 0.0042 |
| SASRec | 0.3069 | 0.0307 | 0.3069 |
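
For reference, the three metrics in the table can be computed per user as sketched below and then averaged over users. This is a minimal sketch using the standard definitions; the function name and inputs are illustrative, and the project's `train_and_eval.py` may differ in details.

```python
def ranking_metrics_at_k(recommended: list[int], relevant: set[int], k: int = 10) -> dict[str, float]:
    """Hit Rate, Precision, and Recall at K for a single user.

    Assumes `recommended` is ranked best-first and contains no duplicates.
    """
    if not relevant:
        raise ValueError("relevant must contain at least one item")
    top_k = recommended[:k]
    hits = sum(1 for item in top_k if item in relevant)
    return {
        "hit_rate": 1.0 if hits > 0 else 0.0,
        "precision": hits / k,
        "recall": hits / len(relevant),
    }


# One held-out item, ranked third in the recommendations.
m = ranking_metrics_at_k([5, 8, 3, 9, 1, 4, 7, 2, 6, 0], relevant={3}, k=10)
print(m)  # {'hit_rate': 1.0, 'precision': 0.1, 'recall': 1.0}
```

When each user has exactly one held-out item, Hit Rate@K equals Recall@K and Precision@K equals Hit Rate@K / K, which is consistent with the pattern in the SASRec row.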

πŸ” Qualitative Analysis

The SASRec model not only recommends previously viewed items but also discovers new, contextually relevant items.
For example, for a user browsing Category 1279, SASRec suggested new items from the same category, showing strong personalization and discovery.


🚧 Future Improvements

  • 📦 Incorporate Item Features (e.g., from item_properties.csv).
  • 🤖 Explore Advanced Models:
    • BERT4Rec (bidirectional Transformers).
    • Graph-based recommender systems.
  • 🧪 Online A/B Testing for business impact.
  • ⚑ Scalability Enhancements: Feature stores, inference servers (Triton), quantization, distillation.

📂 Project Structure

├── checkpoints/              # Saved PyTorch Lightning checkpoints
├── data/                     # RetailRocket dataset
├── notebooks/                # EDA notebooks
├── scripts/
│   ├── als_optuna_study.py   # Optuna tuning for ALS
│   ├── app.py                # Gradio web demo
│   ├── data_prepare.py       # Data loading & preprocessing
│   ├── main.py               # Entry point for demo
│   ├── models.py             # Model definitions
│   ├── train_and_eval.py     # Training & evaluation loop
│   └── utils.py              # Helper functions
├── README.md
└── requirements.txt

βš™οΈ Setup and Usage

Follow these steps to set up and run the project locally.

1. Prerequisites

  • Python 3.10.6+
  • An NVIDIA GPU is recommended for training the SASRec model.

2. Clone the Repository

git clone <your-repo-url>
cd <your-repo-name>

3. Install Dependencies

pip install -r requirements.txt

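4. Download and Place Data

Download the RetailRocket dataset and place the files in the data/ directory. Then run the preprocessing script (located in scripts/):

python data_prepare.py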

5. Run the Full Evaluation

To train all models and see the final comparison table, run the training and evaluation script:

python train_and_eval.py

6. Launch the Demo

python main.py

🛠️ Technologies and Models Used

This project leverages a range of modern data science and machine learning technologies to build a robust recommender system from the ground up.

🏭 Models

  • Popularity Model: A non-personalized baseline that recommends the most frequently purchased items.
  • Item-Item Collaborative Filtering: A classical neighborhood-based model that recommends items based on co-occurrence patterns with a user's interaction history.
  • Alternating Least Squares (ALS): A powerful matrix factorization technique for implicit feedback, optimized with hyperparameter tuning.
  • SASRec (Self-Attentive Sequential Recommendation): A state-of-the-art sequential model based on the Transformer architecture, designed to capture the order and context of user interactions.
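
As a rough illustration of the neighborhood idea behind the Item-Item model, the sketch below scores unseen items by cosine similarity over co-occurrence counts. It is written in plain Python for readability and only illustrates the general technique; the project itself uses scikit-learn for the similarity computation, and every name here is hypothetical.

```python
from collections import defaultdict
from itertools import combinations
from math import sqrt


def item_similarities(user_items: dict[int, set[int]]) -> dict[int, dict[int, float]]:
    """Cosine similarity between items, based on how many users interacted with both."""
    item_count: dict[int, int] = defaultdict(int)
    co_count: dict[tuple[int, int], int] = defaultdict(int)
    for items in user_items.values():
        for item in items:
            item_count[item] += 1
        for a, b in combinations(sorted(items), 2):
            co_count[(a, b)] += 1
    sims: dict[int, dict[int, float]] = defaultdict(dict)
    for (a, b), c in co_count.items():
        s = c / sqrt(item_count[a] * item_count[b])
        sims[a][b] = s
        sims[b][a] = s
    return sims


def recommend(history: set[int], sims: dict[int, dict[int, float]], n: int = 10) -> list[int]:
    """Rank unseen items by their summed similarity to the items in the history."""
    scores: dict[int, float] = defaultdict(float)
    for item in history:
        for other, s in sims.get(item, {}).items():
            if other not in history:
                scores[other] += s
    # Sort by descending score, breaking ties by item ID for determinism.
    return sorted(scores, key=lambda i: (-scores[i], i))[:n]


interactions = {1: {10, 20}, 2: {10, 20, 30}, 3: {20, 30}, 4: {10, 40}}
sims = item_similarities(interactions)
print(recommend({10}, sims, n=2))  # [20, 40]
```

Unlike SASRec, this scoring ignores the order of interactions: a history of `{10, 20}` yields the same recommendations regardless of which item was viewed first.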

👩‍💻 Core Technologies & Libraries

  • Python 3.10: The primary programming language for the project.
  • Pandas & NumPy: For efficient data manipulation, preprocessing, and numerical operations.
  • Scikit-learn: Used for calculating item similarity in the collaborative filtering model.
  • Implicit: Provides the ALS implementation used for implicit-feedback matrix factorization.
  • PyTorch & PyTorch Lightning: The deep learning framework used to build, train, and evaluate the SASRec model in a structured and scalable way.
  • Optuna: A hyperparameter optimization framework used to automatically find the best parameters for the ALS model.
  • Gradio: A fast and simple framework used to build and deploy the interactive web demo.
  • TensorBoard: For logging and visualizing model training metrics.