---
title: SASRec Sequential Recommender
emoji: πŸ›οΈ
colorFrom: blue
colorTo: indigo
sdk: gradio
app_file: scripts/app.py
---

![Project banner](assets/banner.png)
[![Python](https://img.shields.io/badge/Python-3.10-blue?logo=python)](https://www.python.org/) [![PyTorch](https://img.shields.io/badge/PyTorch-2.7.1-EE4C2C?logo=pytorch)](https://pytorch.org/) ![Made with ML](https://img.shields.io/badge/Made%20with-ML-blueviolet?logo=openai) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)

# πŸš€ End-to-End Sequential Recommender System  

This project implements and evaluates a series of recommender system models, culminating in a state-of-the-art **SASRec (Self-Attentive Sequential Recommendation)** model for Top-N next-item prediction. The system is trained on the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset) and includes an interactive web demo built with Gradio.  

![Gradio app](assets/gradio.png)
You can try the Gradio app [here](https://www.kaggle.com/datasets/kritanjalijain/amazon-reviews).

---

## πŸ“‘ Table of Contents  

- [πŸ“– Project Overview](#-project-overview)  
- [✨ Key Features](#-key-features)  
- [🧩 Models Implemented](#-models-implemented)  
- [πŸ“Š Final Results](#-final-results)  
- [πŸ” Qualitative Analysis](#-qualitative-analysis)  
- [🚧 Future Improvements](#-future-improvements)  
- [πŸ“‚ Project Structure](#-project-structure)  
- [βš™οΈ Setup and Usage](#️-setup-and-usage)  
- [πŸ› οΈ Technologies and Models Used](#️-technologies-and-models-used)  

---

## πŸ“– Project Overview  

The primary goal of this project is to predict the next item a user is likely to interact with based on their recent session history. This is a common and critical task in e-commerce known as Top-N sequential recommendation.  

The project follows a structured approach:  

1. **Baseline Models**: Simple, non-sequential models to establish a performance baseline.  
2. **Hyperparameter Tuning**: Optuna is used to find the optimal configuration for ALS.  
3. **Advanced Sequential Model**: Implementation of **SASRec** with PyTorch Lightning.  
4. **Evaluation**: Offline evaluation using ranking metrics (Hit Rate, Precision, Recall @ 10).  
5. **Interactive Demo**: A Gradio web app for real-time personalized and cold-start recommendations.  
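To make the evaluation step concrete, here is a minimal sketch of how the @10 ranking metrics can be computed per user (the function name is illustrative, not taken from this repository):

```python
def ranking_metrics_at_k(recommended, relevant, k=10):
    """Compute Hit Rate, Precision, and Recall @ k for one user.

    recommended: ranked list of item IDs produced by a model.
    relevant: set of ground-truth items from the held-out test window.
    """
    top_k = recommended[:k]
    hits = len(set(top_k) & set(relevant))
    hit_rate = 1.0 if hits > 0 else 0.0  # did at least one relevant item appear?
    precision = hits / k                 # fraction of the top-k that was relevant
    recall = hits / len(relevant)        # fraction of relevant items retrieved
    return hit_rate, precision, recall

ranking_metrics_at_k([1, 2, 3], {2, 5}, k=10)  # (1.0, 0.1, 0.5)
```

Averaging these per-user values over all test users produces the numbers reported in the results table.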

---

## ✨ Key Features  

- πŸ”Ή **Comprehensive Model Comparison**: From popularity to Transformer-based SASRec.  
- πŸ”Ή **Robust Evaluation**: Time-based data split for realistic performance measurement.  
- πŸ”Ή **Hyperparameter Optimization**: Automated with Optuna for ALS.  
- πŸ”Ή **Deep Learning with Attention**: Full PyTorch Lightning implementation of SASRec.  
- πŸ”Ή **Interactive Web Demo**: Live Gradio app for recommendations.  
- πŸ”Ή **Modular Codebase**: Clean, organized structure.  

---

## 🧩 Models Implemented  

| Model | Methodology | Key Characteristics |
| :--- | :--- | :--- |
| **Popularity** | Non-personalized | Recommends the most frequently purchased items across all users. |
| **Item-Item CF** | Collaborative Filtering | Recommends items similar to a user’s past interactions. |
| **ALS** | Matrix Factorization | Learns latent embeddings from implicit feedback, tuned with Optuna. |
| **SASRec** | Transformer (Self-Attention) | Sequential model capturing contextual user-item interactions. |
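For intuition, the non-personalized popularity baseline amounts to a simple frequency count. A minimal sketch (not the actual `scripts/models.py` implementation):

```python
from collections import Counter

def popularity_top_n(interactions, n=10):
    """Recommend the n most frequently interacted-with items to every user.

    interactions: iterable of (user_id, item_id) pairs from the training window.
    """
    counts = Counter(item for _, item in interactions)
    return [item for item, _ in counts.most_common(n)]

events = [(1, "a"), (2, "a"), (2, "b"), (3, "c"), (3, "a")]
popularity_top_n(events, n=2)  # "a" ranks first with 3 interactions
```

Because the same list is served to all users, this baseline is cheap but blind to individual history, which is exactly the gap the sequential models address.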

---

## πŸ“Š Final Results  

SASRec significantly outperformed all baselines, with a **~4.7x improvement in Hit Rate**.  

| Model | Test Hit Rate@10 | Test Precision@10 | Test Recall@10 |
| :--- | :---: | :---: | :---: |
| Popularity | 0.0651 | 0.0065 | 0.0324 |
| Item-Item CF | 0.0021 | 0.0002 | 0.0011 |
| Tuned ALS | 0.0063 | 0.0006 | 0.0042 |
| **SASRec** | **0.3069** | **0.0307** | **0.3069** |

---

## πŸ” Qualitative Analysis  

The SASRec model not only recommends previously viewed items but also discovers **new, contextually relevant items**.  
For example, for a user browsing **Category 1279**, SASRec suggested new items from the same category β€” showing strong personalization and discovery.  

---

## 🚧 Future Improvements  

- πŸ“¦ **Incorporate Item Features** (e.g., from `item_properties.csv`).  
- πŸ€– **Explore Advanced Models**:  
  - BERT4Rec (bidirectional Transformers).  
  - Graph-based recommender systems.  
- πŸ§ͺ **Online A/B Testing** for business impact.  
- ⚑ **Scalability Enhancements**: Feature stores, inference servers (Triton), quantization, distillation.  

---

## πŸ“‚ Project Structure  

```bash
β”œβ”€β”€ checkpoints/              # Saved PyTorch Lightning checkpoints
β”œβ”€β”€ data/                     # RetailRocket dataset
β”œβ”€β”€ notebooks/                # EDA notebooks
β”œβ”€β”€ scripts/
β”‚   β”œβ”€β”€ als_optuna_study.py   # Optuna tuning for ALS
β”‚   β”œβ”€β”€ app.py                # Gradio web demo
β”‚   β”œβ”€β”€ data_prepare.py       # Data loading & preprocessing
β”‚   β”œβ”€β”€ main.py               # Entry point for demo
β”‚   β”œβ”€β”€ models.py             # Model definitions
β”‚   β”œβ”€β”€ train_and_eval.py     # Training & evaluation loop
β”‚   └── utils.py              # Helper functions
β”œβ”€β”€ README.md
└── requirements.txt
```

---

## βš™οΈ Setup and Usage

Follow these steps to set up and run the project locally.

### 1. Prerequisites

- Python 3.10.6+
- An NVIDIA GPU is recommended for training the SASRec model.

### 2. Clone the Repository

```bash
git clone <your-repo-url>
cd <your-repo-name>
```

### 3. Install all required packages

```bash
pip install -r requirements.txt
```

### 4. Download and Place Data

- Download the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset) and place its CSV files in the `data/` directory.

Then run the preprocessing script:

```bash
python scripts/data_prepare.py
```

### 5. Run the Full Evaluation

To train all models and print the final comparison table, run:

```bash
python scripts/train_and_eval.py
```

### 6. Launch the Demo

Start the interactive Gradio demo via the entry-point script:

```bash
python scripts/main.py
```

---

## πŸ› οΈ Technologies and Models Used

This project leverages a range of modern data science and machine learning technologies to build a robust recommender system from the ground up.

### 🏭 Models

- **Popularity Model**: A non-personalized baseline that recommends the most frequently purchased items.
- **Item-Item Collaborative Filtering**: A classical neighborhood-based model that recommends items based on co-occurrence patterns with a user's interaction history.
- **Alternating Least Squares (ALS)**: A powerful matrix factorization technique for implicit feedback, optimized with hyperparameter tuning.
- **SASRec (Self-Attentive Sequential Recommendation)**: A state-of-the-art sequential model based on the Transformer architecture, designed to capture the order and context of user interactions.

### πŸ‘©β€πŸ’» Core Technologies & Libraries

- **Python 3.10**: The primary programming language for the project.
- **Pandas & NumPy**: For efficient data manipulation, preprocessing, and numerical operations.
- **Scikit-learn**: Used for calculating item similarity in the collaborative filtering model.
- **Implicit**: Provides the ALS implementation for implicit-feedback matrix factorization.
- **PyTorch & PyTorch Lightning**: The deep learning framework used to build, train, and evaluate the SASRec model in a structured and scalable way.
- **Optuna**: A hyperparameter optimization framework used to automatically find the best parameters for the ALS model.
- **Gradio**: A fast and simple framework used to build and deploy the interactive web demo.
- **TensorBoard**: For logging and visualizing model training metrics.