Commit 38ae75d (initial commit, no parents)
Initial commit of recommender system project
Files changed:

- .gitattributes +2 -0
- .gitignore +18 -0
- README.md +182 -0
- assets/banner.png +3 -0
- checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt +3 -0
- notebooks/lightning_logs/sasrec/version_0/events.out.tfevents.1757789891.DESKTOP-48K6QDS.24476.0 +0 -0
- notebooks/lightning_logs/sasrec/version_0/hparams.yaml +7 -0
- notebooks/lightning_logs/sasrec/version_1/events.out.tfevents.1757790260.DESKTOP-48K6QDS.24476.1 +0 -0
- notebooks/lightning_logs/sasrec/version_1/hparams.yaml +7 -0
- notebooks/lightning_logs/sasrec/version_2/events.out.tfevents.1757868923.DESKTOP-48K6QDS.22820.0 +0 -0
- notebooks/lightning_logs/sasrec/version_2/hparams.yaml +7 -0
- notebooks/lightning_logs/sasrec/version_3/events.out.tfevents.1757868963.DESKTOP-48K6QDS.22820.1 +0 -0
- notebooks/lightning_logs/sasrec/version_3/hparams.yaml +7 -0
- notebooks/reccomender.ipynb +1419 -0
- requirements.txt +0 -0
- scripts/als_optuna_study.py +53 -0
- scripts/app.py +162 -0
- scripts/data_prepare.py +253 -0
- scripts/main.py +46 -0
- scripts/models.py +408 -0
- scripts/train_and_eval.py +172 -0
- scripts/utils.py +196 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,18 @@
# Python
__pycache__/
*.pyc

# Virtual Environment
venv/
.venv/

# Data and Logs
data/
logs/
notebooks/data/
notebooks/logs/

# IDE files
.vscode/
.idea/
README.md
ADDED
@@ -0,0 +1,182 @@
![banner](assets/banner.png)

[Python](https://www.python.org/) [PyTorch](https://pytorch.org/) [License](LICENSE)

# 🚀 End-to-End Sequential Recommender System

This project implements and evaluates a series of recommender system models, culminating in a state-of-the-art **SASRec (Self-Attentive Sequential Recommendation)** model for Top-N next-item prediction. The system is trained on the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset) and includes an interactive web demo built with Gradio.

<!-- demo screenshot -->

You can find the Gradio app [here](https://www.kaggle.com/datasets/kritanjalijain/amazon-reviews).

---

## 📑 Table of Contents

- [📖 Project Overview](#-project-overview)
- [✨ Key Features](#-key-features)
- [🧩 Models Implemented](#-models-implemented)
- [📊 Final Results](#-final-results)
- [🔍 Qualitative Analysis](#-qualitative-analysis)
- [🚧 Future Improvements](#-future-improvements)
- [📂 Project Structure](#-project-structure)
- [⚙️ Setup and Usage](#️-setup-and-usage)
- [🛠️ Technologies and Models Used](#️-technologies-and-models-used)

---

## 📖 Project Overview

The primary goal of this project is to predict the next item a user is likely to interact with based on their recent session history. This is a common and critical task in e-commerce known as Top-N sequential recommendation.

The project follows a structured approach:

1. **Baseline Models**: Simple, non-sequential models to establish a performance baseline.
2. **Hyperparameter Tuning**: Optuna is used to find the optimal configuration for ALS.
3. **Advanced Sequential Model**: Implementation of **SASRec** with PyTorch Lightning.
4. **Evaluation**: Offline evaluation using ranking metrics (Hit Rate, Precision, and Recall @ 10).
5. **Interactive Demo**: A Gradio web app for real-time personalized and cold-start recommendations.

---

## ✨ Key Features

- 🔹 **Comprehensive Model Comparison**: From popularity to Transformer-based SASRec.
- 🔹 **Robust Evaluation**: Time-based data split for realistic performance measurement (see the sketch after this list).
- 🔹 **Hyperparameter Optimization**: Automated with Optuna for ALS.
- 🔹 **Deep Learning with Attention**: Full PyTorch Lightning implementation of SASRec.
- 🔹 **Interactive Web Demo**: Live Gradio app for recommendations.
- 🔹 **Modular Codebase**: Clean, organized structure.

---
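
A condensed sketch of the time-based split (the full version, with logging and error handling, lives in the notebook's `prepare_data` and in `scripts/data_prepare.py`):

```python
from datetime import timedelta
import pandas as pd

def time_split(df: pd.DataFrame, val_days: int = 7, test_days: int = 7):
    """Hold out the last `test_days` for testing and the `val_days` before that
    for validation; everything earlier is training data."""
    df = df.sort_values("timestamp_dt")
    end = df["timestamp_dt"].max()
    test_start = end - timedelta(days=test_days)
    val_start = test_start - timedelta(days=val_days)
    train = df[df["timestamp_dt"] < val_start]
    val = df[(df["timestamp_dt"] >= val_start) & (df["timestamp_dt"] < test_start)]
    test = df[df["timestamp_dt"] >= test_start]
    return train, val, test
```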

## 🧩 Models Implemented

| Model | Methodology | Key Characteristics |
| :--- | :--- | :--- |
| **Popularity** | Non-personalized | Recommends the most frequently purchased items across all users. |
| **Item-Item CF** | Collaborative Filtering | Recommends items similar to a user’s past interactions. |
| **ALS** | Matrix Factorization | Learns latent embeddings from implicit feedback, tuned with Optuna. |
| **SASRec** | Transformer (Self-Attention) | Sequential model capturing contextual user-item interactions. |

---
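
To make the SASRec row concrete, here is a minimal, illustrative sketch of a SASRec-style scorer. The hyperparameters match the logged `hparams.yaml`; this is *not* the project's `scripts/models.py` implementation:

```python
import torch
import torch.nn as nn

class TinySASRec(nn.Module):
    """Embed a left-padded session, apply causally masked self-attention,
    and score every item against the final position's hidden state."""
    def __init__(self, vocab_size, hidden_dim=128, num_heads=2, num_layers=2,
                 max_len=50, dropout=0.2):
        super().__init__()
        self.item_emb = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, hidden_dim)
        layer = nn.TransformerEncoderLayer(hidden_dim, num_heads,
                                           dim_feedforward=4 * hidden_dim,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, seq):  # seq: (batch, max_len) item indices, 0 = padding
        L = seq.size(1)
        x = self.item_emb(seq) + self.pos_emb(torch.arange(L, device=seq.device))
        # Causal mask: each position may only attend to itself and earlier steps.
        # (A production model would additionally mask padded keys.)
        causal = torch.full((L, L), float("-inf"), device=seq.device).triu(1)
        h = self.encoder(x, mask=causal)
        return h[:, -1, :] @ self.item_emb.weight.T  # (batch, vocab_size) scores

model = TinySASRec(vocab_size=64705)               # vocab_size from hparams.yaml
scores = model(torch.randint(1, 64705, (4, 50)))   # top-10 recs: scores.topk(10).indices
```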

## 📊 Final Results

SASRec significantly outperformed all baselines, with a **~4.7x improvement in Hit Rate@10** over the strongest baseline (Popularity, 0.3069 vs. 0.0651).

| Model | Test Hit Rate@10 | Test Precision@10 | Test Recall@10 |
| :--- | :---: | :---: | :---: |
| Popularity | 0.0651 | 0.0065 | 0.0324 |
| Item-Item CF | 0.0021 | 0.0002 | 0.0011 |
| Tuned ALS | 0.0063 | 0.0006 | 0.0042 |
| **SASRec** | **0.3069** | **0.0307** | **0.3069** |

---
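
The numbers above are averages of per-user quantities; a compact sketch mirroring the notebook's `calculate_metrics` helper:

```python
import numpy as np

def ranking_metrics_at_k(recs, truth, k=10):
    """recs: {user: [ranked item ids]}, truth: {user: set(relevant item ids)}."""
    p, r, h = [], [], []
    for user, true_items in truth.items():
        if not true_items:
            continue
        hits = len(set(recs.get(user, [])[:k]) & true_items)
        p.append(hits / k)                 # Precision@k
        r.append(hits / len(true_items))   # Recall@k
        h.append(1.0 if hits else 0.0)     # HitRate@k
    if not p:
        return {"precision@k": 0.0, "recall@k": 0.0, "hitrate@k": 0.0}
    return {"precision@k": np.mean(p), "recall@k": np.mean(r), "hitrate@k": np.mean(h)}

# One relevant item ranked 3rd: precision@10 = 0.1, recall = 1.0, hitrate = 1.0
print(ranking_metrics_at_k({1: [5, 7, 9]}, {1: {9}}, k=10))
```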

## 🔍 Qualitative Analysis

The SASRec model not only recommends previously viewed items but also discovers **new, contextually relevant items**. For example, for a user browsing **Category 1279**, SASRec suggested new items from the same category, showing strong personalization and discovery.

---

## 🚧 Future Improvements

- 📦 **Incorporate Item Features** (e.g., from `item_properties.csv`).
- 🤖 **Explore Advanced Models**:
  - BERT4Rec (bidirectional Transformers).
  - Graph-based recommender systems.
- 🧪 **Online A/B Testing** for business impact.
- ⚡ **Scalability Enhancements**: Feature stores, inference servers (e.g., Triton), quantization, and distillation.

---

## 📂 Project Structure

```bash
├── checkpoints/              # Saved PyTorch Lightning checkpoints
├── data/                     # RetailRocket dataset
├── notebooks/                # EDA notebooks
├── scripts/
│   ├── als_optuna_study.py   # Optuna tuning for ALS
│   ├── app.py                # Gradio web demo
│   ├── data_prepare.py       # Data loading & preprocessing
│   ├── main.py               # Entry point for demo
│   ├── models.py             # Model definitions
│   ├── train_and_eval.py     # Training & evaluation loop
│   └── utils.py              # Helper functions
├── README.md
└── requirements.txt
```

---

## ⚙️ Setup and Usage

Follow these steps to set up and run the project locally.

### 1. Prerequisites

- Python 3.10.6+
- An NVIDIA GPU is recommended for training the SASRec model.

### 2. Clone the Repository

```bash
git clone <your-repo-url>
cd <your-repo-name>
```

### 3. Install Required Packages

```bash
pip install -r requirements.txt
```

### 4. Download and Place Data

- Download the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset) and place the CSV files in the `data/` folder.

Then run the preparation script:

```bash
python scripts/data_prepare.py
```

### 5. Run the Full Evaluation

To train all models and print the final comparison table, run:

```bash
python scripts/train_and_eval.py
```

### 6. Launch the Demo

```bash
python scripts/main.py
```

---

## 🛠️ Technologies and Models Used

This project leverages a range of modern data science and machine learning technologies to build a robust recommender system from the ground up.

### 🏭 Models

- **Popularity Model**: A non-personalized baseline that recommends the most frequently purchased items.
- **Item-Item Collaborative Filtering**: A classical neighborhood-based model that recommends items based on co-occurrence patterns with a user's interaction history.
- **Alternating Least Squares (ALS)**: A powerful matrix factorization technique for implicit feedback, optimized with hyperparameter tuning.
- **SASRec (Self-Attentive Sequential Recommendation)**: A state-of-the-art sequential model based on the Transformer architecture, designed to capture the order and context of user interactions.

### 👩‍💻 Core Technologies & Libraries

- **Python 3.10**: The primary programming language for the project.
- **Pandas & NumPy**: For efficient data manipulation, preprocessing, and numerical operations.
- **Scikit-learn**: Used for calculating item similarity in the collaborative filtering model.
- **Implicit**: For training the ALS model on implicit feedback data.
- **PyTorch & PyTorch Lightning**: The deep learning framework used to build, train, and evaluate the SASRec model in a structured and scalable way.
- **Optuna**: A hyperparameter optimization framework used to automatically find the best parameters for the ALS model.
- **Gradio**: A fast and simple framework used to build and deploy the interactive web demo.
- **TensorBoard**: For logging and visualizing model training metrics.
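
As an illustration of the Implicit usage, a minimal sketch of the ALS training and recommendation calls (the project's tuned configuration is `factors=25, regularization=0.02, iterations=48`; the matrix below is a toy example):

```python
import implicit
import numpy as np
from scipy.sparse import csr_matrix

# Toy user-item confidence matrix: rows = users, cols = items,
# values = event weights (view=1, addtocart=3, transaction=5).
user_item = csr_matrix(np.array([
    [1, 5, 0, 0],
    [0, 3, 1, 5],
    [5, 0, 1, 0],
], dtype=np.float32))

als = implicit.als.AlternatingLeastSquares(factors=4, regularization=0.02, iterations=15)
als.fit(user_item)

ids, scores = als.recommend(0, user_item[0], N=2)  # top-2 unseen items for user 0
```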
assets/banner.png
ADDED
Git LFS Details
checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af98f72c0c5f325aa0393753db2a7e5865336423b222ced1aeb7f164009d0e06
size 204423643
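
Restoring the saved model from this LFS checkpoint could look like the sketch below; `SASRec` and its import path are assumptions (the real LightningModule lives in `scripts/models.py`):

```python
# `SASRec` is assumed to be the LightningModule defined in scripts/models.py;
# the import path is hypothetical and depends on the package layout.
from scripts.models import SASRec

model = SASRec.load_from_checkpoint(
    "checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt"
)
model.eval()  # switch to inference mode before generating recommendations
```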
notebooks/lightning_logs/sasrec/version_0/events.out.tfevents.1757789891.DESKTOP-48K6QDS.24476.0
ADDED
Binary file (657 Bytes)
notebooks/lightning_logs/sasrec/version_0/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
vocab_size: 64705
max_len: 256
hidden_dim: 128
num_heads: 4
num_layers: 2
dropout: 0.1
learning_rate: 2.0e-05
notebooks/lightning_logs/sasrec/version_1/events.out.tfevents.1757790260.DESKTOP-48K6QDS.24476.1
ADDED
Binary file (657 Bytes)
notebooks/lightning_logs/sasrec/version_1/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
vocab_size: 64705
max_len: 256
hidden_dim: 128
num_heads: 4
num_layers: 2
dropout: 0.1
learning_rate: 2.0e-05
notebooks/lightning_logs/sasrec/version_2/events.out.tfevents.1757868923.DESKTOP-48K6QDS.22820.0
ADDED
Binary file (657 Bytes)
notebooks/lightning_logs/sasrec/version_2/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
vocab_size: 64705
max_len: 50
hidden_dim: 128
num_heads: 2
num_layers: 2
dropout: 0.2
learning_rate: 0.001
notebooks/lightning_logs/sasrec/version_3/events.out.tfevents.1757868963.DESKTOP-48K6QDS.22820.1
ADDED
Binary file (657 Bytes)
notebooks/lightning_logs/sasrec/version_3/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
vocab_size: 64705
max_len: 50
hidden_dim: 128
num_heads: 2
num_layers: 2
dropout: 0.2
learning_rate: 0.001
notebooks/reccomender.ipynb
ADDED
@@ -0,0 +1,1419 @@

# 🚀 End-to-End Sequential Recommender System

This project implements and evaluates a series of recommender system models, culminating in a state-of-the-art **SASRec (Self-Attentive Sequential Recommendation)** model for Top-N next-item prediction. The system is trained on the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset) and includes an interactive web demo built with Gradio.

## EDA

```python
import pandas as pd
import time
from datetime import datetime

# Define the path to your data folder
DATA_FOLDER = 'data/'

# Load the events data
print("Loading events.csv...")
events_df = pd.read_csv(DATA_FOLDER + 'events.csv')

# --- Initial Inspection ---

# See the first few rows
print("Data Head:")
print(events_df.head())

# Get a summary of the dataframe (columns, data types, memory usage)
print("\nData Info:")
events_df.info()

# Check for any missing values
print("\nMissing Values:")
print(events_df.isnull().sum())
```

Output:

```
Loading events.csv...
Data Head:
       timestamp  visitorid event  itemid  transactionid
0  1433221332117     257597  view  355908            NaN
1  1433224214164     992329  view  248676            NaN
2  1433221999827     111016  view  318965            NaN
3  1433221955914     483717  view  253185            NaN
4  1433221337106     951259  view  367447            NaN

Data Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2756101 entries, 0 to 2756100
Data columns (total 5 columns):
 #   Column         Dtype
---  ------         -----
 0   timestamp      int64
 1   visitorid      int64
 2   event          object
 3   itemid         int64
 4   transactionid  float64
dtypes: float64(1), int64(3), object(1)
memory usage: 105.1+ MB

Missing Values:
timestamp              0
visitorid              0
event                  0
itemid                 0
transactionid    2733644
dtype: int64
```

```python
# --- Data Cleaning and Understanding ---

# 1. Convert timestamp to datetime (the raw timestamp is in milliseconds)
events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms')
print(f"\nData timeframe is from {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}")

# 2. Analyze the distribution of event types
print("\nEvent Counts:")
event_counts = events_df['event'].value_counts()
print(event_counts)

# 3. Calculate number of unique users and items
n_users = events_df['visitorid'].nunique()
n_items = events_df['itemid'].nunique()

print(f"\nNumber of unique visitors: {n_users}")
print(f"Number of unique items: {n_items}")
```

Output:

```
Data timeframe is from 2015-05-03 03:00:04.384000 to 2015-09-18 02:59:47.788000

Event Counts:
event
view           2664312
addtocart        69332
transaction      22457
Name: count, dtype: int64

Number of unique visitors: 1407580
Number of unique items: 235061
```

## Preparing the data

```python
import zipfile
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
from scipy.sparse import csr_matrix
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

def prepare_data(data_folder='data/', val_days=7, test_days=7):
    """
    Loads, preprocesses, and splits the events data into train, validation, and test sets.

    Args:
        data_folder: str, path to the folder containing 'events.csv' (must end with '/')
        val_days: int, number of days for the validation set
        test_days: int, number of days for the test set
    """
    # --- Load Data ---
    print(f"Loading events.csv from folder: {data_folder}")
    try:
        events_df = pd.read_csv(data_folder + 'events.csv')
        print("Successfully loaded events.csv.")
        events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms')
        print("\n--- Initial Data Summary ---")
        print(f"Data shape: {events_df.shape}")
        print(f"Full timeframe: {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}")
        print("----------------------------\n")
    except FileNotFoundError:
        print(f"Error: 'events.csv' not found in '{data_folder}'. Please check the path.")
        return None, None, None

    # --- Split Data ---
    sorted_df = events_df.sort_values('timestamp_dt').reset_index(drop=True)
    print(f"Splitting data: {test_days} days for test, {val_days} for validation.")
    end_time = sorted_df['timestamp_dt'].max()
    test_start_time = end_time - timedelta(days=test_days)
    val_start_time = test_start_time - timedelta(days=val_days)

    test_df = sorted_df[sorted_df['timestamp_dt'] >= test_start_time]
    val_df = sorted_df[(sorted_df['timestamp_dt'] >= val_start_time) & (sorted_df['timestamp_dt'] < test_start_time)]
    train_df = sorted_df[sorted_df['timestamp_dt'] < val_start_time]

    print("--- Data Splitting Summary ---")
    print(f"Training set:   {train_df.shape[0]:>8} records | from {train_df['timestamp_dt'].min()} to {train_df['timestamp_dt'].max()}")
    print(f"Validation set: {val_df.shape[0]:>8} records | from {val_df['timestamp_dt'].min()} to {val_df['timestamp_dt'].max()}")
    print(f"Test set:       {test_df.shape[0]:>8} records | from {test_df['timestamp_dt'].min()} to {test_df['timestamp_dt'].max()}")
    print("------------------------------")

    return train_df, val_df, test_df
```

```python
DATA_PATH = "data/"  # trailing slash required: prepare_data builds file paths via string concatenation

train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)
```

```python
class SASRecDataset(Dataset):
    """
    SASRec Dataset.
    - Precomputes (sequence_id, cutoff_idx) pairs for O(1) __getitem__.
    - Supports 'last' or 'all' target modes.
    """
    def __init__(self, sequences, max_len, target_mode="last"):
        """
        Args:
            sequences: list of user sequences (list of item IDs).
            max_len: maximum sequence length (padding applied).
            target_mode: 'last' (only last prediction) or 'all' (predict at every step).
        """
        self.sequences = sequences
        self.max_len = max_len
        self.target_mode = target_mode

        # Build index once
        self.index = []
        for seq_id, seq in enumerate(sequences):
            for i in range(1, len(seq)):
                self.index.append((seq_id, i))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        seq_id, cutoff = self.index[idx]
        seq = self.sequences[seq_id][:cutoff]

        # Truncate & pad (sequences are left-padded with the 0 token)
        seq = seq[-self.max_len:]
        pad_len = self.max_len - len(seq)

        input_seq = np.zeros(self.max_len, dtype=np.int64)
        input_seq[pad_len:] = seq

        if self.target_mode == "last":
            target = self.sequences[seq_id][cutoff]
            return torch.LongTensor(input_seq), torch.LongTensor([target])

        elif self.target_mode == "all":
            # Predict next item at each step
            target_seq = self.sequences[seq_id][1:cutoff + 1]
            target_seq = target_seq[-self.max_len:]
            target = np.zeros(self.max_len, dtype=np.int64)
            target[-len(target_seq):] = target_seq
            return torch.LongTensor(input_seq), torch.LongTensor(target)


class SASRecDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for preparing the RetailRocket dataset for the SASRec model.

    This class handles all aspects of data preparation, including:
    - Filtering out infrequent users and items to reduce noise.
    - Building a consistent item vocabulary.
    - Converting user event histories into sequential data.
    - Creating and providing `DataLoader` instances for training, validation, and testing.
    """
    def __init__(self, train_df, val_df, test_df, min_item_interactions=5,
                 min_user_interactions=5, max_len=50, batch_size=256):
        """
        Initializes the DataModule.

        Args:
            train_df (pd.DataFrame): DataFrame for training.
            val_df (pd.DataFrame): DataFrame for validation.
            test_df (pd.DataFrame): DataFrame for testing.
            min_item_interactions (int): Minimum number of interactions for an item to be kept.
            min_user_interactions (int): Minimum number of interactions for a user to be kept.
            max_len (int): The maximum length of a user sequence fed to the model.
            batch_size (int): The batch size for the DataLoaders.
        """
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.min_item_interactions = min_item_interactions
        self.min_user_interactions = min_user_interactions
        self.max_len = max_len
        self.batch_size = batch_size

        self.item_map = None
        self.inverse_item_map = None
        self.vocab_size = 0
        self.user_history = None

    def setup(self, stage=None):
        """
        Prepares the data for training, validation, and testing.

        This method is called automatically by PyTorch Lightning. It performs the following steps:
        1. Determines filtering criteria (which users and items to keep) based on the training set only
           to prevent data leakage.
        2. Applies these filters to the train, validation, and test sets.
        3. Builds an item vocabulary (mapping item IDs to integer indices) from the combined
           training and validation sets to ensure consistency for model checkpointing.
        4. Converts the event logs into sequences of item indices for each user in each data split.
        """
        item_counts = self.train_df['itemid'].value_counts()
        user_counts = self.train_df['visitorid'].value_counts()
        items_to_keep = item_counts[item_counts >= self.min_item_interactions].index
        users_to_keep = user_counts[user_counts >= self.min_user_interactions].index

        self.filtered_train_df = self.train_df[
            (self.train_df['itemid'].isin(items_to_keep)) &
            (self.train_df['visitorid'].isin(users_to_keep))
        ].copy()
        self.filtered_val_df = self.val_df[
            (self.val_df['itemid'].isin(items_to_keep)) &
            (self.val_df['visitorid'].isin(users_to_keep))
        ].copy()
        self.filtered_test_df = self.test_df[
            (self.test_df['itemid'].isin(items_to_keep)) &
            (self.test_df['visitorid'].isin(users_to_keep))
        ].copy()

        all_known_items_df = pd.concat([self.filtered_train_df, self.filtered_val_df])
        unique_items = all_known_items_df['itemid'].unique()
        self.item_map = {item_id: i + 1 for i, item_id in enumerate(unique_items)}
        self.inverse_item_map = {i: item_id for item_id, i in self.item_map.items()}
        self.vocab_size = len(self.item_map) + 1  # +1 for padding token 0

        self.user_history = self.filtered_train_df.groupby('visitorid')['itemid'].apply(list)

        self.train_sequences = self._create_sequences(self.filtered_train_df)
        self.val_sequences = self._create_sequences(self.filtered_val_df)
        self.test_sequences = self._create_sequences(self.filtered_test_df)

    def _create_sequences(self, df):
        """
        Helper function to convert a DataFrame of events into user interaction sequences.

        Args:
            df (pd.DataFrame): The input DataFrame to process.

        Returns:
            list[list[int]]: A list of user sequences, where each sequence is a list of item indices.
        """
        df_sorted = df.sort_values(['visitorid', 'timestamp_dt'])
        sequences = df_sorted.groupby('visitorid')['itemid'].apply(
            lambda x: [self.item_map[i] for i in x if i in self.item_map]
        ).tolist()
        return [s for s in sequences if len(s) > 1]

    def train_dataloader(self):
        """Creates the DataLoader for the training set."""
        dataset = SASRecDataset(self.train_sequences, self.max_len)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        """Creates the DataLoader for the validation set."""
        dataset = SASRecDataset(self.val_sequences, self.max_len)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def test_dataloader(self):
        """Creates the DataLoader for the test set."""
        dataset = SASRecDataset(self.test_sequences, self.max_len)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
```
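
A quick toy check of the left-padding behavior (item IDs below are illustrative):

```python
# A 4-item history with max_len=5: sample index 2 cuts off after the first
# three items, so the input is [0, 0, 3, 7, 9] and the target is the 4th item.
toy = SASRecDataset([[3, 7, 9, 4]], max_len=5, target_mode="last")
x, y = toy[2]
print(x.tolist())  # [0, 0, 3, 7, 9]
print(y.tolist())  # [4]
```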

```python
BATCH_SIZE = 256
MAX_TOKEN_LEN = 50  # 50-100 is standard for SASRec

# --- 1. Prepare the data into train, validation, and test sets ---
train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)

# --- 2. Initialize DataModule ---
print("Initializing DataModule...")
datamodule = SASRecDataModule(
    train_df=train_set,
    val_df=validation_set,
    test_df=test_set,
    batch_size=BATCH_SIZE,
    max_len=MAX_TOKEN_LEN
)
datamodule.setup()
```
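
Before training, it can help to confirm the DataLoader emits what the model expects; a quick sanity check, assuming the cells above have run (expected shapes shown as comments for a full batch):

```python
# Inputs are left-padded item-index sequences; targets are the single next
# item per sequence (target_mode="last").
inputs, targets = next(iter(datamodule.train_dataloader()))
print(inputs.shape)           # torch.Size([256, 50])
print(targets.shape)          # torch.Size([256, 1])
print(datamodule.vocab_size)  # 64705 per the logged hparams
```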

## Define, train, and evaluate the base models
|
| 421 |
+
{
|
| 422 |
+
"cell_type": "code",
|
| 423 |
+
"execution_count": null,
|
| 424 |
+
"id": "8d899a5a",
|
| 425 |
+
"metadata": {},
|
| 426 |
+
"outputs": [],
|
| 427 |
+
"source": [
|
| 428 |
+
"import pandas as pd\n",
|
| 429 |
+
"from datetime import datetime, timedelta\n",
|
| 430 |
+
"import numpy as np\n",
|
| 431 |
+
"from scipy.sparse import csr_matrix\n",
|
| 432 |
+
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
| 433 |
+
"import implicit\n",
|
| 434 |
+
"import torch\n",
|
| 435 |
+
"import torch.nn as nn\n",
|
| 436 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 437 |
+
"import pytorch_lightning as pl\n",
|
| 438 |
+
"\n",
|
| 439 |
+
"# --- 1. Evaluation Helper Functions ---\n",
|
| 440 |
+
"\n",
|
| 441 |
+
"def prepare_ground_truth(df, mode=\"purchase\", event_weights=None):\n",
|
| 442 |
+
" \"\"\"\n",
|
| 443 |
+
" Prepares ground truth dictionaries for evaluation.\n",
|
| 444 |
+
"\n",
|
| 445 |
+
" Parameters\n",
|
| 446 |
+
" ----------\n",
|
| 447 |
+
" df : pd.DataFrame\n",
|
| 448 |
+
" Test dataframe containing at least ['visitorid', 'itemid', 'event'].\n",
|
| 449 |
+
" mode : str, default=\"purchase\"\n",
|
| 450 |
+
" - \"purchase\" : Only use transactions as ground truth.\n",
|
| 451 |
+
" - \"all\" : Use all events. Optionally weight them.\n",
|
| 452 |
+
" event_weights : dict, optional\n",
|
| 453 |
+
" Example: {\"view\": 1, \"addtocart\": 3, \"transaction\": 5}.\n",
|
| 454 |
+
" Used only if mode == \"all\".\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" Returns\n",
|
| 457 |
+
" -------\n",
|
| 458 |
+
" dict : {user_id: set of item_ids}\n",
|
| 459 |
+
" \"\"\"\n",
|
| 460 |
+
" if mode == \"purchase\":\n",
|
| 461 |
+
" df_filtered = df[df[\"event\"] == \"transaction\"]\n",
|
| 462 |
+
" ground_truth = df_filtered.groupby(\"visitorid\")[\"itemid\"].apply(set).to_dict()\n",
|
| 463 |
+
"\n",
|
| 464 |
+
" elif mode == \"all\":\n",
|
| 465 |
+
" if event_weights is None:\n",
|
| 466 |
+
" # Default: treat all events equally\n",
|
| 467 |
+
" ground_truth = df.groupby(\"visitorid\")[\"itemid\"].apply(set).to_dict()\n",
|
| 468 |
+
" else:\n",
|
| 469 |
+
" # Weighted ground truth (for more advanced eval)\n",
|
| 470 |
+
" ground_truth = {}\n",
|
| 471 |
+
" for uid, user_df in df.groupby(\"visitorid\"):\n",
|
| 472 |
+
" weighted_items = []\n",
|
| 473 |
+
" for _, row in user_df.iterrows():\n",
|
| 474 |
+
" weight = event_weights.get(row[\"event\"], 1)\n",
|
| 475 |
+
" weighted_items.extend([row[\"itemid\"]] * weight)\n",
|
| 476 |
+
" ground_truth[uid] = set(weighted_items)\n",
|
| 477 |
+
" else:\n",
|
| 478 |
+
" raise ValueError(\"mode must be 'purchase' or 'all'\")\n",
|
| 479 |
+
"\n",
|
| 480 |
+
" return ground_truth\n",
|
| 481 |
+
"\n",
|
| 482 |
+
"def calculate_metrics(recommendations_dict, ground_truth_dict, k):\n",
|
| 483 |
+
" \"\"\"\n",
|
| 484 |
+
" Calculates Precision@k, Recall@k, and HitRate@k.\n",
|
| 485 |
+
"\n",
|
| 486 |
+
" args:\n",
|
| 487 |
+
" ----------\n",
|
| 488 |
+
" recommendations_dict : {user_id: [recommended_item_ids]}\n",
|
| 489 |
+
" ground_truth_dict : {user_id: set of ground truth item_ids}\n",
|
| 490 |
+
" k : int\n",
|
| 491 |
+
"\n",
|
| 492 |
+
" Returns\n",
|
| 493 |
+
" -------\n",
|
| 494 |
+
" dict with mean precision, recall, and hit rate\n",
|
| 495 |
+
" \"\"\"\n",
|
| 496 |
+
" all_precisions, all_recalls, all_hits = [], [], []\n",
|
| 497 |
+
"\n",
|
| 498 |
+
" for user_id, true_items in ground_truth_dict.items():\n",
|
| 499 |
+
" recs = recommendations_dict.get(user_id, [])[:k]\n",
|
| 500 |
+
" if not true_items:\n",
|
| 501 |
+
" continue\n",
|
| 502 |
+
" hits = len(set(recs) & true_items)\n",
|
| 503 |
+
"\n",
|
| 504 |
+
" precision = hits / k if k > 0 else 0\n",
|
| 505 |
+
" recall = hits / len(true_items)\n",
|
| 506 |
+
" hit_rate = 1.0 if hits > 0 else 0.0\n",
|
| 507 |
+
"\n",
|
| 508 |
+
" all_precisions.append(precision)\n",
|
| 509 |
+
" all_recalls.append(recall)\n",
|
| 510 |
+
" all_hits.append(hit_rate)\n",
|
| 511 |
+
"\n",
|
| 512 |
+
" if not all_precisions:\n",
|
| 513 |
+
" return {\"mean_precision@k\": 0, \"mean_recall@k\": 0, \"mean_hitrate@k\": 0}\n",
|
| 514 |
+
"\n",
|
| 515 |
+
" return {\n",
|
| 516 |
+
" \"mean_precision@k\": np.mean(all_precisions),\n",
|
| 517 |
+
" \"mean_recall@k\": np.mean(all_recalls),\n",
|
| 518 |
+
" \"mean_hitrate@k\": np.mean(all_hits)\n",
|
| 519 |
+
" }\n",
|
| 520 |
+
"\n",
|
| 521 |
+
"# --- 2. Model Functions (Popularity, Item-Item, ALS) ---\n",
|
| 522 |
+
"\n",
|
| 523 |
+
"def recommend_popular_items_and_evaluate(train_df, test_df, k=10, prepare_ground_truth=None, calculate_metrics=None):\n",
|
| 524 |
+
" \"\"\"\n",
|
| 525 |
+
" Trains a non-personalized Popularity model and evaluates its performance.\n",
|
| 526 |
+
"\n",
|
| 527 |
+
" This model recommends the top-k most frequently transacted items from the training\n",
|
| 528 |
+
" set to every user. It serves as a simple but strong baseline.\n",
|
| 529 |
+
"\n",
|
| 530 |
+
" Args:\n",
|
| 531 |
+
" train_df (pd.DataFrame): The training dataset.\n",
|
| 532 |
+
" test_df (pd.DataFrame): The test dataset for evaluation.\n",
|
| 533 |
+
" k (int): The number of items to recommend.\n",
|
| 534 |
+
" prepare_ground_truth (function): A function to process the test_df into a ground truth dict.\n",
|
| 535 |
+
" calculate_metrics (function): A function to compute ranking metrics.\n",
|
| 536 |
+
"\n",
|
| 537 |
+
" Returns:\n",
|
| 538 |
+
" dict: A dictionary containing the calculated evaluation metrics (e.g., precision, recall).\n",
|
| 539 |
+
" \"\"\"\n",
|
| 540 |
+
" print(f\"\\n--- Evaluating Popularity Model (Top {k} items) ---\")\n",
|
| 541 |
+
" \n",
|
| 542 |
+
" # 1. \"Train\" the model by finding the most popular items based on transactions\n",
|
| 543 |
+
" purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()\n",
|
| 544 |
+
" popular_items = purchase_counts.head(k).index.tolist()\n",
|
| 545 |
+
" print(f\"Top {k} popular items identified from training data.\")\n",
|
| 546 |
+
"\n",
|
| 547 |
+
" # 2. Evaluate the model\n",
|
| 548 |
+
" ground_truth = prepare_ground_truth(test_df)\n",
|
| 549 |
+
" # Every user receives the same list of popular items\n",
|
| 550 |
+
" recommendations = {user_id: popular_items for user_id in ground_truth.keys()}\n",
|
| 551 |
+
" \n",
|
| 552 |
+
" metrics = calculate_metrics(recommendations, ground_truth, k)\n",
|
| 553 |
+
" print(\"Evaluation complete.\")\n",
|
| 554 |
+
" return metrics\n",
|
| 555 |
+
"\n",
|
| 556 |
+
"def recommend_item_item_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5, prepare_ground_truth=None, calculate_metrics=None):\n",
|
| 557 |
+
" \"\"\"\n",
|
| 558 |
+
" Trains an Item-Item Collaborative Filtering model and evaluates its performance.\n",
|
| 559 |
+
"\n",
|
| 560 |
+
" This model recommends items that are similar to items a user has interacted\n",
|
| 561 |
+
" with in the past, based on co-occurrence patterns in the training data.\n",
|
| 562 |
+
"\n",
|
| 563 |
+
" Args:\n",
|
| 564 |
+
" train_df (pd.DataFrame): The training dataset.\n",
|
| 565 |
+
" test_df (pd.DataFrame): The test dataset for evaluation.\n",
|
| 566 |
+
" k (int): The number of items to recommend.\n",
|
| 567 |
+
" min_item_interactions (int): Minimum number of interactions for an item to be kept.\n",
|
| 568 |
+
" min_user_interactions (int): Minimum number of interactions for a user to be kept.\n",
|
| 569 |
+
" prepare_ground_truth (function): A function to process the test_df into a ground truth dict.\n",
|
| 570 |
+
" calculate_metrics (function): A function to compute ranking metrics.\n",
|
| 571 |
+
"\n",
|
| 572 |
+
" Returns:\n",
|
| 573 |
+
" dict: A dictionary containing the calculated evaluation metrics.\n",
|
| 574 |
+
" \"\"\"\n",
|
| 575 |
+
" print(f\"\\n--- Evaluating Item-Item CF Model (Top {k} items) ---\")\n",
|
| 576 |
+
" \n",
|
| 577 |
+
" # 1. Filter out infrequent users and items to reduce noise and computation\n",
|
| 578 |
+
" item_counts = train_df['itemid'].value_counts()\n",
|
| 579 |
+
" user_counts = train_df['visitorid'].value_counts()\n",
|
| 580 |
+
" items_to_keep = item_counts[item_counts >= min_item_interactions].index\n",
|
| 581 |
+
" users_to_keep = user_counts[user_counts >= min_user_interactions].index\n",
|
| 582 |
+
" filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()\n",
|
| 583 |
+
" print(f\"Filtered training data from {len(train_df)} to {len(filtered_df)} records.\")\n",
|
| 584 |
+
"\n",
|
| 585 |
+
" # 2. Create user-item interaction matrix and vocabulary mappings\n",
|
| 586 |
+
" user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}\n",
|
| 587 |
+
" item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}\n",
|
| 588 |
+
" inverse_item_map = {i: iid for iid, i in item_map.items()}\n",
|
| 589 |
+
" user_indices = filtered_df['visitorid'].map(user_map)\n",
|
| 590 |
+
" item_indices = filtered_df['itemid'].map(item_map)\n",
|
| 591 |
+
" user_item_matrix = csr_matrix((np.ones(len(filtered_df)), (user_indices, item_indices)))\n",
|
| 592 |
+
"\n",
|
| 593 |
+
" # 3. Calculate the cosine similarity matrix between all items\n",
|
| 594 |
+
" print(\"Calculating item similarity matrix...\")\n",
|
| 595 |
+
" item_similarity_matrix = cosine_similarity(user_item_matrix.T, dense_output=False)\n",
|
| 596 |
+
" print(\"Similarity matrix calculated.\")\n",
|
| 597 |
+
"\n",
|
| 598 |
+
" # 4. Generate recommendations and evaluate\n",
|
| 599 |
+
" ground_truth = prepare_ground_truth(test_df)\n",
|
| 600 |
+
" recommendations = {}\n",
|
| 601 |
+
" print(\"Generating recommendations for users in test set...\")\n",
|
| 602 |
+
" test_users = [u for u in ground_truth.keys() if u in user_map]\n",
|
| 603 |
+
" \n",
|
| 604 |
+
" for user_id in test_users:\n",
|
| 605 |
+
" user_index = user_map[user_id]\n",
|
| 606 |
+
" user_interactions_indices = user_item_matrix[user_index].indices\n",
|
| 607 |
+
" \n",
|
| 608 |
+
" if len(user_interactions_indices) > 0:\n",
|
| 609 |
+
" # Aggregate scores from items the user has interacted with\n",
|
| 610 |
+
" all_scores = np.asarray(item_similarity_matrix[user_interactions_indices].sum(axis=0)).flatten()\n",
|
| 611 |
+
" # Remove already interacted items from recommendations\n",
|
| 612 |
+
" all_scores[user_interactions_indices] = -1\n",
|
| 613 |
+
" top_indices = np.argsort(all_scores)[::-1][:k]\n",
|
| 614 |
+
" recs = [inverse_item_map[idx] for idx in top_indices if idx in inverse_item_map]\n",
|
| 615 |
+
" recommendations[user_id] = recs\n",
|
| 616 |
+
" \n",
|
| 617 |
+
" metrics = calculate_metrics(recommendations, ground_truth, k)\n",
|
| 618 |
+
" print(\"Evaluation complete.\")\n",
|
| 619 |
+
" return metrics\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"def recommend_als_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5, \n",
|
| 622 |
+
" factors=25, regularization=0.02, iterations=48, prepare_ground_truth=None, calculate_metrics=None):\n",
|
| 623 |
+
" \"\"\"\n",
|
| 624 |
+
" Trains an Alternating Least Squares (ALS) model and evaluates its performance.\n",
|
| 625 |
+
"\n",
|
| 626 |
+
" This model uses matrix factorization to learn latent embeddings for users and\n",
|
| 627 |
+
" items from implicit feedback data. Default hyperparameters are set from a\n",
|
| 628 |
+
" previous Optuna tuning process.\n",
|
| 629 |
+
"\n",
|
| 630 |
+
" Args:\n",
|
| 631 |
+
" train_df (pd.DataFrame): The training dataset.\n",
|
| 632 |
+
" test_df (pd.DataFrame): The test dataset for evaluation.\n",
|
| 633 |
+
" k (int): The number of items to recommend.\n",
|
| 634 |
+
" min_item_interactions (int): Minimum number of interactions for an item to be kept.\n",
|
| 635 |
+
" min_user_interactions (int): Minimum number of interactions for a user to be kept.\n",
|
| 636 |
+
" factors (int): The number of latent factors to compute.\n",
|
| 637 |
+
" regularization (float): The regularization factor.\n",
|
| 638 |
+
" iterations (int): The number of ALS iterations to run.\n",
|
| 639 |
+
" prepare_ground_truth (function): A function to process the test_df into a ground truth dict.\n",
|
| 640 |
+
" calculate_metrics (function): A function to compute ranking metrics.\n",
|
| 641 |
+
"\n",
|
| 642 |
+
" Returns:\n",
|
| 643 |
+
" dict: A dictionary containing the calculated evaluation metrics.\n",
|
| 644 |
+
" \"\"\"\n",
|
| 645 |
+
" print(f\"\\n--- Evaluating ALS Model (Top {k} items) ---\")\n",
|
| 646 |
+
" \n",
|
| 647 |
+
" # 1. Filter data\n",
|
| 648 |
+
" item_counts = train_df['itemid'].value_counts()\n",
|
| 649 |
+
" user_counts = train_df['visitorid'].value_counts()\n",
|
| 650 |
+
" items_to_keep = item_counts[item_counts >= min_item_interactions].index\n",
|
| 651 |
+
" users_to_keep = user_counts[user_counts >= min_user_interactions].index\n",
|
| 652 |
+
" filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()\n",
|
| 653 |
+
" print(f\"Filtered training data from {len(train_df)} to {len(filtered_df)} records.\")\n",
|
| 654 |
+
"\n",
|
| 655 |
+
" # 2. Create mappings and confidence matrix\n",
|
| 656 |
+
" user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}\n",
|
| 657 |
+
" item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}\n",
|
| 658 |
+
" inverse_item_map = {i: iid for iid, i in item_map.items()}\n",
|
| 659 |
+
" user_indices = filtered_df['visitorid'].map(user_map).astype(np.int32)\n",
|
| 660 |
+
" item_indices = filtered_df['itemid'].map(item_map).astype(np.int32)\n",
|
| 661 |
+
" \n",
|
| 662 |
+
" event_weights = {'view': 1, 'addtocart': 3, 'transaction': 5}\n",
|
| 663 |
+
" confidence = filtered_df['event'].map(event_weights).astype(np.float32)\n",
|
| 664 |
+
" user_item_matrix = csr_matrix((confidence, (user_indices, item_indices)))\n",
|
| 665 |
+
"\n",
|
| 666 |
+
" # 3. Train the ALS model\n",
|
| 667 |
+
" print(\"Training ALS model...\")\n",
|
| 668 |
+
" als_model = implicit.als.AlternatingLeastSquares(factors=factors, regularization=regularization, iterations=iterations)\n",
|
| 669 |
+
" als_model.fit(user_item_matrix)\n",
|
| 670 |
+
" print(\"ALS model trained.\")\n",
|
| 671 |
+
"\n",
|
| 672 |
+
" # 4. Generate recommendations and evaluate\n",
|
| 673 |
+
" ground_truth = prepare_ground_truth(test_df)\n",
|
| 674 |
+
" recommendations = {}\n",
|
| 675 |
+
" print(\"Generating recommendations for users in test set...\")\n",
|
| 676 |
+
" test_users_indices = [user_map[u] for u in ground_truth.keys() if u in user_map]\n",
|
| 677 |
+
" \n",
|
| 678 |
+
" if test_users_indices:\n",
|
| 679 |
+
" user_item_matrix_for_recs = user_item_matrix[test_users_indices]\n",
|
| 680 |
+
" ids, _ = als_model.recommend(test_users_indices, user_item_matrix_for_recs, N=k)\n",
|
| 681 |
+
" \n",
|
| 682 |
+
" for i, user_index in enumerate(test_users_indices):\n",
|
| 683 |
+
" original_user_id = list(user_map.keys())[list(user_map.values()).index(user_index)]\n",
|
| 684 |
+
" recs = [inverse_item_map[item_idx] for item_idx in ids[i] if item_idx in inverse_item_map]\n",
|
| 685 |
+
" recommendations[original_user_id] = recs\n",
|
| 686 |
+
" \n",
|
| 687 |
+
" metrics = calculate_metrics(recommendations, ground_truth, k)\n",
|
| 688 |
+
" print(\"Evaluation complete.\")\n",
|
| 689 |
+
" return metrics\n",
|
| 690 |
+
"\n",
|
| 691 |
+
"\n",
|
| 692 |
+
" train_set, validation_set, test_set = prepare_data(data_folder='C:/Users/dania/vsproject/projects/recommernder_system/data/')\n",
|
| 693 |
+
" if train_set is not None:\n",
|
| 694 |
+
" results = {}\n",
|
| 695 |
+
" full_train_set = pd.concat([train_set, validation_set])\n",
|
| 696 |
+
" \n",
|
| 697 |
+
"# # Evaluate classical models\n",
|
| 698 |
+
" print(\"\\n>>> Running evaluations on the VALIDATION set <<<\")\n",
|
| 699 |
+
" results['Popularity (Validation)'] = recommend_popular_items_and_evaluate(train_set, validation_set)\n",
|
| 700 |
+
" results['Item-Item CF (Validation)'] = recommend_item_item_and_evaluate(train_set, validation_set)\n",
|
| 701 |
+
" results['ALS (Validation)'] = recommend_als_and_evaluate(train_set, validation_set)\n",
|
| 702 |
+
" \n",
|
| 703 |
+
" print(\"\\n>>> Running final evaluations on the TEST set <<<\")\n",
|
| 704 |
+
" results['Popularity (Test)'] = recommend_popular_items_and_evaluate(full_train_set, test_set)\n",
|
| 705 |
+
" results['Item-Item CF (Test)'] = recommend_item_item_and_evaluate(full_train_set, test_set)\n",
|
| 706 |
+
" results['ALS (Test)'] = recommend_als_and_evaluate(full_train_set, test_set)\n",
|
| 707 |
+
" \n",
|
| 708 |
+
" print(\"\\n--- Final Evaluation Results ---\")\n",
|
| 709 |
+
" results_df = pd.DataFrame.from_dict(results, orient='index')\n",
|
| 710 |
+
" print(results_df)\n",
|
| 711 |
+
" print(\"--------------------------------\")\n"
|
| 712 |
+
]
|
| 713 |
+
},
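{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d4",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch, not part of the pipeline: the confidence-matrix construction\n",
"# above on a toy interaction log (hypothetical IDs). Note that csr_matrix sums\n",
"# duplicate (user, item) entries, so repeat interactions accumulate confidence.\n",
"toy = pd.DataFrame({'visitorid': [10, 10, 42, 42, 42],\n",
"                    'itemid': [7, 9, 7, 8, 8],\n",
"                    'event': ['view', 'addtocart', 'view', 'view', 'transaction']})\n",
"u_map = {uid: i for i, uid in enumerate(toy['visitorid'].unique())}\n",
"i_map = {iid: i for i, iid in enumerate(toy['itemid'].unique())}\n",
"conf = toy['event'].map({'view': 1, 'addtocart': 3, 'transaction': 5}).astype(np.float32)\n",
"m = csr_matrix((conf, (toy['visitorid'].map(u_map), toy['itemid'].map(i_map))))\n",
"print(m.toarray())  # [[1. 3. 0.], [1. 0. 6.]] -- the (42, 8) view + transaction sum to 6\n"
]
},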
|
| 714 |
+
{
|
| 715 |
+
"cell_type": "code",
|
| 716 |
+
"execution_count": null,
|
| 717 |
+
"id": "b978c458",
|
| 718 |
+
"metadata": {},
|
| 719 |
+
"outputs": [],
|
| 720 |
+
"source": [
|
| 721 |
+
"train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)\n",
|
| 722 |
+
"if train_set is not None:\n",
|
| 723 |
+
" results = {}\n",
|
| 724 |
+
" full_train_set = pd.concat([train_set, validation_set])\n",
|
| 725 |
+
" \n",
|
| 726 |
+
" # Evaluate base models\n",
|
| 727 |
+
" print(\"\\n>>> Running evaluations on the VALIDATION set <<<\")\n",
|
| 728 |
+
" results['Popularity (Validation)'] = recommend_popular_items_and_evaluate(train_set, validation_set)\n",
|
| 729 |
+
" results['Item-Item CF (Validation)'] = recommend_item_item_and_evaluate(train_set, validation_set)\n",
|
| 730 |
+
" results['ALS (Validation)'] = recommend_als_and_evaluate(train_set, validation_set)\n",
|
| 731 |
+
" \n",
|
| 732 |
+
" print(\"\\n>>> Running final evaluations on the TEST set <<<\")\n",
|
| 733 |
+
" results['Popularity (Test)'] = recommend_popular_items_and_evaluate(full_train_set, test_set)\n",
|
| 734 |
+
" results['Item-Item CF (Test)'] = recommend_item_item_and_evaluate(full_train_set, test_set)\n",
|
| 735 |
+
" results['ALS (Test)'] = recommend_als_and_evaluate(full_train_set, test_set)\n",
|
| 736 |
+
" \n",
|
| 737 |
+
" print(\"\\n--- Final Evaluation Results ---\")\n",
|
| 738 |
+
" results_df = pd.DataFrame.from_dict(results, orient='index')\n",
|
| 739 |
+
" print(results_df)\n",
|
| 740 |
+
" print(\"--------------------------------\")"
|
| 741 |
+
]
|
| 742 |
+
},
|
| 743 |
+
{
|
| 744 |
+
"cell_type": "markdown",
|
| 745 |
+
"id": "85d8f78c",
|
| 746 |
+
"metadata": {},
|
| 747 |
+
"source": [
|
| 748 |
+
"## Use Optuna to find the best Hyperparameters for the ALS model"
|
| 749 |
+
]
|
| 750 |
+
},
|
| 751 |
+
{
|
| 752 |
+
"cell_type": "code",
|
| 753 |
+
"execution_count": null,
|
| 754 |
+
"id": "202be1f4",
|
| 755 |
+
"metadata": {},
|
| 756 |
+
"outputs": [],
|
| 757 |
+
"source": [
|
| 758 |
+
"import optuna\n",
|
| 759 |
+
"\n",
|
| 760 |
+
"def objective_als(trial, train_df, val_df):\n",
|
| 761 |
+
" \"\"\"\n",
|
| 762 |
+
" The objective function for Optuna to optimize.\n",
|
| 763 |
+
" \"\"\"\n",
|
| 764 |
+
" # 1. Define the hyperparameter search space\n",
|
| 765 |
+
" params = {\n",
|
| 766 |
+
" 'factors': trial.suggest_int('factors', 20, 200),\n",
|
| 767 |
+
" 'regularization': trial.suggest_float('regularization', 1e-3, 1e-1, log=True),\n",
|
| 768 |
+
" 'iterations': trial.suggest_int('iterations', 10, 50)\n",
|
| 769 |
+
" }\n",
|
| 770 |
+
" \n",
|
| 771 |
+
" # 2. Run an evaluation with the suggested parameters\n",
|
| 772 |
+
" metrics = recommend_als_and_evaluate(train_df, val_df, **params)\n",
|
| 773 |
+
" \n",
|
| 774 |
+
" # 3. Return the metric we want to maximize (precision)\n",
|
| 775 |
+
" return metrics['mean_precision@k']\n",
|
| 776 |
+
"\n",
|
| 777 |
+
"def tune_als_hyperparameters(train_df, val_df, n_trials=25):\n",
|
| 778 |
+
" \"\"\"\n",
|
| 779 |
+
" Orchestrates the Optuna study to find the best hyperparameters for ALS.\n",
|
| 780 |
+
" \"\"\"\n",
|
| 781 |
+
" study = optuna.create_study(direction='maximize')\n",
|
| 782 |
+
" study.optimize(lambda trial: objective_als(trial, train_df, val_df), n_trials=n_trials)\n",
|
| 783 |
+
" \n",
|
| 784 |
+
" print(\"\\n--- Optuna Study Complete ---\")\n",
|
| 785 |
+
" print(f\"Number of finished trials: {len(study.trials)}\")\n",
|
| 786 |
+
" print(\"Best trial:\")\n",
|
| 787 |
+
" trial = study.best_trial\n",
|
| 788 |
+
" print(f\" Value (Precision@10): {trial.value}\")\n",
|
| 789 |
+
" print(\" Params: \")\n",
|
| 790 |
+
" for key, value in trial.params.items():\n",
|
| 791 |
+
" print(f\" {key}: {value}\")\n",
|
| 792 |
+
" \n",
|
| 793 |
+
" return trial.params\n"
|
| 794 |
+
]
|
| 795 |
+
},
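{
"cell_type": "code",
"execution_count": null,
"id": "beefcafe",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: persist the Optuna study so an interrupted search can resume.\n",
"# The SQLite file name 'als_study.db' is an assumption, not something the project ships.\n",
"resumable_study = optuna.create_study(direction='maximize',\n",
"                                      study_name='als_tuning',\n",
"                                      storage='sqlite:///als_study.db',\n",
"                                      load_if_exists=True)\n",
"# resumable_study.optimize(lambda t: objective_als(t, train_set, validation_set), n_trials=25)\n"
]
},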
|
| 796 |
+
{
|
| 797 |
+
"cell_type": "code",
|
| 798 |
+
"execution_count": null,
|
| 799 |
+
"id": "18d48b2e",
|
| 800 |
+
"metadata": {},
|
| 801 |
+
"outputs": [],
|
| 802 |
+
"source": [
|
| 803 |
+
"# 1. Prepare all data\n",
|
| 804 |
+
"train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)\n",
|
| 805 |
+
"\n",
|
| 806 |
+
"\n",
|
| 807 |
+
"# --- Hyperparameter Tuning Step ---\n",
|
| 808 |
+
"print(\"\\n>>> 1. TUNING ALS Hyperparameters on the VALIDATION set <<<\")\n",
|
| 809 |
+
"# You can increase n_trials for a more thorough search, e.g., to 50 or 100\n",
|
| 810 |
+
"best_als_params = tune_als_hyperparameters(train_set, validation_set, n_trials=25) \n"
|
| 811 |
+
]
|
| 812 |
+
},
|
| 813 |
+
{
|
| 814 |
+
"cell_type": "markdown",
|
| 815 |
+
"id": "d9bc9ef8",
|
| 816 |
+
"metadata": {},
|
| 817 |
+
"source": [
|
| 818 |
+
"## Define train and evaluate the SASRec model"
|
| 819 |
+
]
|
| 820 |
+
},
|
| 821 |
+
{
|
| 822 |
+
"cell_type": "code",
|
| 823 |
+
"execution_count": null,
|
| 824 |
+
"id": "4a90d635",
|
| 825 |
+
"metadata": {},
|
| 826 |
+
"outputs": [],
|
| 827 |
+
"source": [
|
| 828 |
+
"class SASRec(pl.LightningModule):\n",
|
| 829 |
+
" \"\"\"\n",
|
| 830 |
+
" A PyTorch Lightning implementation of the SASRec model for sequential recommendation.\n",
|
| 831 |
+
"\n",
|
| 832 |
+
" SASRec (Self-Attentive Sequential Recommendation) uses a Transformer-based\n",
|
| 833 |
+
" architecture to capture the sequential patterns in a user's interaction history\n",
|
| 834 |
+
" to predict the next item they are likely to interact with.\n",
|
| 835 |
+
"\n",
|
| 836 |
+
" Attributes:\n",
|
| 837 |
+
" save_hyperparameters: Automatically saves all constructor arguments as hyperparameters.\n",
|
| 838 |
+
" item_embedding (nn.Embedding): Embedding layer for item IDs.\n",
|
| 839 |
+
" positional_embedding (nn.Embedding): Embedding layer to encode the position of items in a sequence.\n",
|
| 840 |
+
" transformer_encoder (nn.TransformerEncoder): The core self-attention module.\n",
|
| 841 |
+
" fc (nn.Linear): Final fully connected layer to produce logits over the item vocabulary.\n",
|
| 842 |
+
" loss_fn (nn.CrossEntropyLoss): The loss function used for training.\n",
|
| 843 |
+
" \"\"\"\n",
|
| 844 |
+
" def __init__(self, vocab_size, max_len, hidden_dim, num_heads, num_layers,\n",
|
| 845 |
+
" dropout=0.2, learning_rate=1e-3, weight_decay=1e-6, warmup_steps=2000, max_steps=100000):\n",
|
| 846 |
+
" \"\"\"\n",
|
| 847 |
+
" Initializes the SASRec model layers and hyperparameters.\n",
|
| 848 |
+
"\n",
|
| 849 |
+
" Args:\n",
|
| 850 |
+
" vocab_size (int): The total number of unique items in the dataset (+1 for padding).\n",
|
| 851 |
+
" max_len (int): The maximum length of the input sequences.\n",
|
| 852 |
+
" hidden_dim (int): The dimensionality of the item and positional embeddings.\n",
|
| 853 |
+
" num_heads (int): The number of attention heads in the Transformer encoder.\n",
|
| 854 |
+
" num_layers (int): The number of layers in the Transformer encoder.\n",
|
| 855 |
+
" dropout (float): The dropout rate to be applied.\n",
|
| 856 |
+
" learning_rate (float): The learning rate for the optimizer.\n",
|
| 857 |
+
" weight_decay (float): The weight decay (L2 penalty) for the optimizer.\n",
|
| 858 |
+
" warmup_steps (int): The number of linear warmup steps for the learning rate scheduler.\n",
|
| 859 |
+
" max_steps (int): The total number of training steps for the learning rate scheduler's decay phase.\n",
|
| 860 |
+
" \"\"\"\n",
|
| 861 |
+
" super().__init__()\n",
|
| 862 |
+
" # This saves all hyperparameters to self.hparams, making them accessible later\n",
|
| 863 |
+
" self.save_hyperparameters()\n",
|
| 864 |
+
"\n",
|
| 865 |
+
" # Embedding layers\n",
|
| 866 |
+
" self.item_embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)\n",
|
| 867 |
+
" self.positional_embedding = nn.Embedding(max_len, hidden_dim)\n",
|
| 868 |
+
" self.dropout = nn.Dropout(dropout)\n",
|
| 869 |
+
"\n",
|
| 870 |
+
" # Transformer Encoder\n",
|
| 871 |
+
" encoder_layer = nn.TransformerEncoderLayer(\n",
|
| 872 |
+
" d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim * 4,\n",
|
| 873 |
+
" dropout=dropout, batch_first=True, activation='gelu'\n",
|
| 874 |
+
" )\n",
|
| 875 |
+
" self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n",
|
| 876 |
+
"\n",
|
| 877 |
+
" # Output layer\n",
|
| 878 |
+
" self.fc = nn.Linear(hidden_dim, vocab_size)\n",
|
| 879 |
+
"\n",
|
| 880 |
+
" # Loss function, ignoring the padding token\n",
|
| 881 |
+
" self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)\n",
|
| 882 |
+
" \n",
|
| 883 |
+
" # Lists to store outputs from validation and test steps\n",
|
| 884 |
+
" self.validation_step_outputs = []\n",
|
| 885 |
+
" self.test_step_outputs = []\n",
|
| 886 |
+
"\n",
|
| 887 |
+
" def forward(self, x):\n",
|
| 888 |
+
" \"\"\"\n",
|
| 889 |
+
" Defines the forward pass of the model.\n",
|
| 890 |
+
"\n",
|
| 891 |
+
" Args:\n",
|
| 892 |
+
" x (torch.Tensor): A batch of input sequences of shape (batch_size, seq_len).\n",
|
| 893 |
+
"\n",
|
| 894 |
+
" Returns:\n",
|
| 895 |
+
" torch.Tensor: The output logits of shape (batch_size, seq_len, vocab_size).\n",
|
| 896 |
+
" \"\"\"\n",
|
| 897 |
+
" seq_len = x.size(1)\n",
|
| 898 |
+
" # Create positional indices (0, 1, 2, ..., seq_len-1)\n",
|
| 899 |
+
" positions = torch.arange(seq_len, device=self.device).unsqueeze(0)\n",
|
| 900 |
+
"\n",
|
| 901 |
+
" # Create a causal mask to ensure the model doesn't look ahead in the sequence\n",
|
| 902 |
+
" causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=self.device)\n",
|
| 903 |
+
"\n",
|
| 904 |
+
" # Combine item and positional embeddings\n",
|
| 905 |
+
" x = self.item_embedding(x) + self.positional_embedding(positions)\n",
|
| 906 |
+
" x = self.dropout(x)\n",
|
| 907 |
+
" \n",
|
| 908 |
+
" # Pass through the Transformer encoder\n",
|
| 909 |
+
" x = self.transformer_encoder(x, mask=causal_mask)\n",
|
| 910 |
+
" \n",
|
| 911 |
+
" # Get final logits\n",
|
| 912 |
+
" logits = self.fc(x)\n",
|
| 913 |
+
" return logits\n",
|
| 914 |
+
"\n",
|
| 915 |
+
" def training_step(self, batch, batch_idx):\n",
|
| 916 |
+
" \"\"\"\n",
|
| 917 |
+
" Performs a single training step.\n",
|
| 918 |
+
"\n",
|
| 919 |
+
" Args:\n",
|
| 920 |
+
" batch (tuple): A tuple containing input sequences and target items.\n",
|
| 921 |
+
" batch_idx (int): The index of the current batch.\n",
|
| 922 |
+
"\n",
|
| 923 |
+
" Returns:\n",
|
| 924 |
+
" torch.Tensor: The calculated loss for the batch.\n",
|
| 925 |
+
" \"\"\"\n",
|
| 926 |
+
" inputs, targets = batch\n",
|
| 927 |
+
" logits = self.forward(inputs)\n",
|
| 928 |
+
"\n",
|
| 929 |
+
" # We only care about the prediction for the very last item in the input sequence\n",
|
| 930 |
+
" last_logits = logits[:, -1, :]\n",
|
| 931 |
+
" \n",
|
| 932 |
+
" # Calculate loss against the single target item\n",
|
| 933 |
+
" loss = self.loss_fn(last_logits, targets.squeeze())\n",
|
| 934 |
+
" \n",
|
| 935 |
+
" self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True)\n",
|
| 936 |
+
" return loss\n",
|
| 937 |
+
"\n",
|
| 938 |
+
" def validation_step(self, batch, batch_idx):\n",
|
| 939 |
+
" \"\"\"\n",
|
| 940 |
+
" Performs a single validation step.\n",
|
| 941 |
+
" Calculates loss and stores predictions for metric computation at the end of the epoch.\n",
|
| 942 |
+
" \"\"\"\n",
|
| 943 |
+
" inputs, targets = batch\n",
|
| 944 |
+
" logits = self.forward(inputs)\n",
|
| 945 |
+
" last_item_logits = logits[:, -1, :]\n",
|
| 946 |
+
" loss = self.loss_fn(last_item_logits, targets.squeeze())\n",
|
| 947 |
+
" self.log('val_loss', loss, prog_bar=True, on_epoch=True)\n",
|
| 948 |
+
"\n",
|
| 949 |
+
" # Get top-10 predictions for metric calculation\n",
|
| 950 |
+
" top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices\n",
|
| 951 |
+
" self.validation_step_outputs.append({'preds': top_k_preds, 'targets': targets})\n",
|
| 952 |
+
" return loss\n",
|
| 953 |
+
"\n",
|
| 954 |
+
" def on_validation_epoch_end(self):\n",
|
| 955 |
+
" \"\"\"\n",
|
| 956 |
+
" Calculates and logs ranking metrics at the end of the validation epoch.\n",
|
| 957 |
+
" \"\"\"\n",
|
| 958 |
+
" if not self.validation_step_outputs: return\n",
|
| 959 |
+
"\n",
|
| 960 |
+
" # Concatenate all predictions and targets from the epoch\n",
|
| 961 |
+
" preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0)\n",
|
| 962 |
+
" targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0)\n",
|
| 963 |
+
"\n",
|
| 964 |
+
" k = preds.size(1)\n",
|
| 965 |
+
" # Check if the target is in the top-k predictions for each example\n",
|
| 966 |
+
" hits_tensor = (preds == targets).any(dim=1)\n",
|
| 967 |
+
" num_hits = hits_tensor.sum().item()\n",
|
| 968 |
+
" num_targets = len(targets)\n",
|
| 969 |
+
"\n",
|
| 970 |
+
" if num_targets > 0:\n",
|
| 971 |
+
" hit_rate = num_hits / num_targets\n",
|
| 972 |
+
" recall = hit_rate # For next-item prediction, recall@k is the same as hit_rate@k\n",
|
| 973 |
+
" precision = num_hits / (k * num_targets)\n",
|
| 974 |
+
" else:\n",
|
| 975 |
+
" hit_rate, recall, precision = 0.0, 0.0, 0.0\n",
|
| 976 |
+
"\n",
|
| 977 |
+
" self.log('val_hitrate@10', hit_rate, prog_bar=True)\n",
|
| 978 |
+
" self.log('val_precision@10', precision, prog_bar=True)\n",
|
| 979 |
+
" self.log('val_recall@10', recall, prog_bar=True)\n",
|
| 980 |
+
"\n",
|
| 981 |
+
" self.validation_step_outputs.clear() # Free up memory\n",
|
| 982 |
+
"\n",
|
| 983 |
+
" def test_step(self, batch, batch_idx):\n",
|
| 984 |
+
" \"\"\"\n",
|
| 985 |
+
" Performs a single test step.\n",
|
| 986 |
+
" Mirrors the logic of the validation_step.\n",
|
| 987 |
+
" \"\"\"\n",
|
| 988 |
+
" inputs, targets = batch\n",
|
| 989 |
+
" logits = self.forward(inputs)\n",
|
| 990 |
+
" last_item_logits = logits[:, -1, :]\n",
|
| 991 |
+
" loss = self.loss_fn(last_item_logits, targets.squeeze())\n",
|
| 992 |
+
" self.log('test_loss', loss, prog_bar=True)\n",
|
| 993 |
+
"\n",
|
| 994 |
+
" top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices\n",
|
| 995 |
+
" self.test_step_outputs.append({'preds': top_k_preds, 'targets': targets})\n",
|
| 996 |
+
" return loss\n",
|
| 997 |
+
"\n",
|
| 998 |
+
" def on_test_epoch_end(self):\n",
|
| 999 |
+
" \"\"\"\n",
|
| 1000 |
+
" Calculates and logs ranking metrics at the end of the test epoch.\n",
|
| 1001 |
+
" \"\"\"\n",
|
| 1002 |
+
" if not self.test_step_outputs: return\n",
|
| 1003 |
+
"\n",
|
| 1004 |
+
" preds = torch.cat([x['preds'] for x in self.test_step_outputs], dim=0)\n",
|
| 1005 |
+
" targets = torch.cat([x['targets'] for x in self.test_step_outputs], dim=0)\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
" k = preds.size(1)\n",
|
| 1008 |
+
" hits_tensor = (preds == targets).any(dim=1)\n",
|
| 1009 |
+
" num_hits = hits_tensor.sum().item()\n",
|
| 1010 |
+
" num_targets = len(targets)\n",
|
| 1011 |
+
"\n",
|
| 1012 |
+
" if num_targets > 0:\n",
|
| 1013 |
+
" hit_rate = num_hits / num_targets\n",
|
| 1014 |
+
" recall = hit_rate\n",
|
| 1015 |
+
" precision = num_hits / (k * num_targets)\n",
|
| 1016 |
+
" else:\n",
|
| 1017 |
+
" hit_rate, recall, precision = 0.0, 0.0, 0.0\n",
|
| 1018 |
+
"\n",
|
| 1019 |
+
" self.log('test_hitrate@10', hit_rate, prog_bar=True)\n",
|
| 1020 |
+
" self.log('test_precision@10', precision, prog_bar=True)\n",
|
| 1021 |
+
" self.log('test_recall@10', recall, prog_bar=True)\n",
|
| 1022 |
+
"\n",
|
| 1023 |
+
" self.test_step_outputs.clear() # Free up memory\n",
|
| 1024 |
+
"\n",
|
| 1025 |
+
" def configure_optimizers(self):\n",
|
| 1026 |
+
" \"\"\"\n",
|
| 1027 |
+
" Configures the optimizer and learning rate scheduler.\n",
|
| 1028 |
+
" \n",
|
| 1029 |
+
" Uses AdamW optimizer and a linear warmup followed by a cosine decay schedule,\n",
|
| 1030 |
+
" which is a standard practice for training Transformer models.\n",
|
| 1031 |
+
" \"\"\"\n",
|
| 1032 |
+
" optimizer = torch.optim.AdamW(\n",
|
| 1033 |
+
" self.parameters(),\n",
|
| 1034 |
+
" lr=self.hparams.learning_rate,\n",
|
| 1035 |
+
" weight_decay=self.hparams.weight_decay\n",
|
| 1036 |
+
" )\n",
|
| 1037 |
+
" \n",
|
| 1038 |
+
" # Learning rate scheduler: linear warmup and cosine decay\n",
|
| 1039 |
+
" def lr_lambda(current_step: int):\n",
|
| 1040 |
+
" warmup_steps = self.hparams.warmup_steps\n",
|
| 1041 |
+
" max_steps = self.hparams.max_steps\n",
|
| 1042 |
+
" if current_step < warmup_steps:\n",
|
| 1043 |
+
" return float(current_step) / float(max(1, warmup_steps))\n",
|
| 1044 |
+
" progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))\n",
|
| 1045 |
+
" return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
|
| 1046 |
+
"\n",
|
| 1047 |
+
" scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n",
|
| 1048 |
+
"\n",
|
| 1049 |
+
" return {\n",
|
| 1050 |
+
" \"optimizer\": optimizer,\n",
|
| 1051 |
+
" \"lr_scheduler\": {\n",
|
| 1052 |
+
" \"scheduler\": scheduler,\n",
|
| 1053 |
+
" \"interval\": \"step\", # Update the scheduler at every training step\n",
|
| 1054 |
+
" \"frequency\": 1\n",
|
| 1055 |
+
" }\n",
|
| 1056 |
+
" }"
|
| 1057 |
+
]
|
| 1058 |
+
},
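{
"cell_type": "code",
"execution_count": null,
"id": "5a5c0de1",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check with toy sizes and random item IDs (not the tuned configuration):\n",
"# the forward pass should return one logit vector per position, and the causal mask\n",
"# (upper triangle of -inf) is what keeps each position from attending to later items.\n",
"toy_model = SASRec(vocab_size=100, max_len=8, hidden_dim=32, num_heads=2, num_layers=1)\n",
"toy_batch = torch.randint(1, 100, (4, 8))  # (batch_size, seq_len)\n",
"print(toy_model(toy_batch).shape)  # torch.Size([4, 8, 100])\n",
"print(nn.Transformer.generate_square_subsequent_mask(4))\n"
]
},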
|
| 1059 |
+
{
|
| 1060 |
+
"cell_type": "code",
|
| 1061 |
+
"execution_count": null,
|
| 1062 |
+
"id": "20bbc93a",
|
| 1063 |
+
"metadata": {},
|
| 1064 |
+
"outputs": [],
|
| 1065 |
+
"source": [
|
| 1066 |
+
"import pytorch_lightning as pl\n",
|
| 1067 |
+
"from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
|
| 1068 |
+
"from pytorch_lightning.loggers import TensorBoardLogger\n",
|
| 1069 |
+
"import torch\n",
|
| 1070 |
+
"\n",
|
| 1071 |
+
"def train_and_eval_SASRec_model(train_set, validation_set, test_set, checkpoint_dir_path='checkpoints/',\n",
|
| 1072 |
+
" checkpoint_path=None, n_epochs=10, mode='train',\n",
|
| 1073 |
+
" batchsize=256, max_token_len=50, learning_rate=1e-3, hidden_dim=128,\n",
|
| 1074 |
+
" num_heads=2, num_layers=2, dropout=0.2, weight_decay=1e-6):\n",
|
| 1075 |
+
" \"\"\"\n",
|
| 1076 |
+
" Train or evaluate a SASRec sequential recommendation model using PyTorch Lightning.\n",
|
| 1077 |
+
"\n",
|
| 1078 |
+
" This function wraps the entire SASRec pipeline:\n",
|
| 1079 |
+
" - Initializes the SASRecDataModule (handles dataset preprocessing and dataloaders).\n",
|
| 1080 |
+
" - Builds the SASRec Transformer-based model.\n",
|
| 1081 |
+
" - Configures training callbacks (checkpointing, early stopping, LR monitoring).\n",
|
| 1082 |
+
" - Runs either training (`mode='train'`) or evaluation on the test set (`mode='test'`).\n",
|
| 1083 |
+
"\n",
|
| 1084 |
+
" Args\n",
|
| 1085 |
+
" ----------\n",
|
| 1086 |
+
" train_set : pd.DataFrame\n",
|
| 1087 |
+
" Training interactions dataset .\n",
|
| 1088 |
+
" validation_set : pd.DataFrame\n",
|
| 1089 |
+
" Validation dataset with the same structure as `train_set`.\n",
|
| 1090 |
+
" test_set : pd.DataFrame\n",
|
| 1091 |
+
" Test dataset with the same structure as `train_set`.\n",
|
| 1092 |
+
" checkpoint_dir_path : str, optional (default='checkpoints/')\n",
|
| 1093 |
+
" Directory to save model checkpoints.\n",
|
| 1094 |
+
" checkpoint_path : str or None, optional (default=None)\n",
|
| 1095 |
+
" Path to a checkpoint file for resuming training or loading a pretrained model for testing.\n",
|
| 1096 |
+
" n_epochs : int, optional (default=10)\n",
|
| 1097 |
+
" Number of training epochs.\n",
|
| 1098 |
+
" mode : {'train', 'test'}, optional (default='train')\n",
|
| 1099 |
+
" - `'train'`: trains the model on the training/validation data.\n",
|
| 1100 |
+
" - `'test'`: evaluates the model on the test set using a checkpoint.\n",
|
| 1101 |
+
" batchsize : int, optional (default=256)\n",
|
| 1102 |
+
" Batch size for training and evaluation.\n",
|
| 1103 |
+
" max_token_len : int, optional (default=50)\n",
|
| 1104 |
+
" Maximum sequence length per user (recent interactions kept).\n",
|
| 1105 |
+
" learning_rate : float, optional (default=1e-3)\n",
|
| 1106 |
+
" Learning rate for the AdamW optimizer.\n",
|
| 1107 |
+
" hidden_dim : int, optional (default=128)\n",
|
| 1108 |
+
" Dimensionality of item and positional embeddings.\n",
|
| 1109 |
+
" num_heads : int, optional (default=2)\n",
|
| 1110 |
+
" Number of attention heads in each Transformer encoder layer.\n",
|
| 1111 |
+
" num_layers : int, optional (default=2)\n",
|
| 1112 |
+
" Number of Transformer encoder layers.\n",
|
| 1113 |
+
" dropout : float, optional (default=0.2)\n",
|
| 1114 |
+
" Dropout probability applied in embeddings and Transformer layers.\n",
|
| 1115 |
+
" weight_decay : float, optional (default=1e-6)\n",
|
| 1116 |
+
" Weight decay regularization coefficient for AdamW.\n",
|
| 1117 |
+
" \"\"\"\n",
|
| 1118 |
+
" # --- 1. Initialize DataModule ---\n",
|
| 1119 |
+
" print(\"Initializing DataModule...\")\n",
|
| 1120 |
+
" datamodule = SASRecDataModule(\n",
|
| 1121 |
+
" train_df=train_set,\n",
|
| 1122 |
+
" val_df=validation_set,\n",
|
| 1123 |
+
" test_df=test_set,\n",
|
| 1124 |
+
" batch_size=batchsize,\n",
|
| 1125 |
+
" max_len=max_token_len\n",
|
| 1126 |
+
" )\n",
|
| 1127 |
+
" datamodule.setup()\n",
|
| 1128 |
+
"\n",
|
| 1129 |
+
" # --- 2. Initialize Model ---\n",
|
| 1130 |
+
" print(\"Initializing SASRec model...\")\n",
|
| 1131 |
+
" model = SASRec(\n",
|
| 1132 |
+
" vocab_size=datamodule.vocab_size,\n",
|
| 1133 |
+
" max_len=max_token_len,\n",
|
| 1134 |
+
" hidden_dim=hidden_dim,\n",
|
| 1135 |
+
" num_heads=num_heads,\n",
|
| 1136 |
+
" num_layers=num_layers,\n",
|
| 1137 |
+
" dropout=dropout,\n",
|
| 1138 |
+
" learning_rate=learning_rate,\n",
|
| 1139 |
+
" weight_decay=weight_decay\n",
|
| 1140 |
+
" )\n",
|
| 1141 |
+
"\n",
|
| 1142 |
+
" # --- 3. Configure Training Callbacks ---\n",
|
| 1143 |
+
" checkpoint_callback = ModelCheckpoint(\n",
|
| 1144 |
+
" dirpath=checkpoint_dir_path,\n",
|
| 1145 |
+
" filename=\"sasrec-{epoch:02d}-{val_hitrate@10:.4f}\",\n",
|
| 1146 |
+
" save_top_k=1,\n",
|
| 1147 |
+
" verbose=True,\n",
|
| 1148 |
+
" monitor=\"val_hitrate@10\",\n",
|
| 1149 |
+
" mode=\"max\"\n",
|
| 1150 |
+
" )\n",
|
| 1151 |
+
"\n",
|
| 1152 |
+
" early_stopping_callback = EarlyStopping(\n",
|
| 1153 |
+
" monitor=\"val_hitrate@10\", # stop if ranking metric stagnates\n",
|
| 1154 |
+
" patience=5,\n",
|
| 1155 |
+
" mode=\"max\"\n",
|
| 1156 |
+
" )\n",
|
| 1157 |
+
"\n",
|
| 1158 |
+
" lr_monitor = LearningRateMonitor(logging_interval=\"step\")\n",
|
| 1159 |
+
"\n",
|
| 1160 |
+
" logger = TensorBoardLogger(\"lightning_logs\", name=\"sasrec\")\n",
|
| 1161 |
+
"\n",
|
| 1162 |
+
" # --- 4. Initialize Trainer ---\n",
|
| 1163 |
+
" print(\"Initializing PyTorch Lightning Trainer...\")\n",
|
| 1164 |
+
" trainer = pl.Trainer(\n",
|
| 1165 |
+
" logger=logger,\n",
|
| 1166 |
+
" callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],\n",
|
| 1167 |
+
" max_epochs=n_epochs,\n",
|
| 1168 |
+
" accelerator='auto',\n",
|
| 1169 |
+
" devices=1,\n",
|
| 1170 |
+
" gradient_clip_val=1.0, # helps with exploding gradients\n",
|
| 1171 |
+
" )\n",
|
| 1172 |
+
"\n",
|
| 1173 |
+
" if mode == 'train' :\n",
|
| 1174 |
+
" # --- 5. Start Training ---\n",
|
| 1175 |
+
" print(f\"Starting training for up to {n_epochs} epochs...\")\n",
|
| 1176 |
+
" trainer.fit(model, datamodule,\n",
|
| 1177 |
+
" ckpt_path=checkpoint_path\n",
|
| 1178 |
+
" )\n",
|
| 1179 |
+
"\n",
|
| 1180 |
+
" elif mode == 'test':\n",
|
| 1181 |
+
" # --- 6. Test on best checkpoint ---\n",
|
| 1182 |
+
" print(\"Evaluating on test set...\")\n",
|
| 1183 |
+
" trainer.test(model, datamodule,\n",
|
| 1184 |
+
" ckpt_path=checkpoint_path\n",
|
| 1185 |
+
" )\n"
|
| 1186 |
+
]
|
| 1187 |
+
},
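{
"cell_type": "code",
"execution_count": null,
"id": "b00c4a7e",
"metadata": {},
"outputs": [],
"source": [
"# Usage sketch for the checkpoint_path argument (left commented out; 'checkpoints/last.ckpt'\n",
"# is a hypothetical file, while the second path is the checkpoint shipped with this repo).\n",
"# train_and_eval_SASRec_model(train_set, validation_set, test_set, mode='train',\n",
"#                             checkpoint_path='checkpoints/last.ckpt')  # resume training\n",
"# train_and_eval_SASRec_model(train_set, validation_set, test_set, mode='test',\n",
"#                             checkpoint_path='checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt')\n"
]
},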
|
| 1188 |
+
{
|
| 1189 |
+
"cell_type": "code",
|
| 1190 |
+
"execution_count": null,
|
| 1191 |
+
"id": "5d4a2a7b",
|
| 1192 |
+
"metadata": {},
|
| 1193 |
+
"outputs": [],
|
| 1194 |
+
"source": [
|
| 1195 |
+
"# --- Configuration ---\n",
|
| 1196 |
+
"BATCH_SIZE = 256\n",
|
| 1197 |
+
"MAX_TOKEN_LEN = 50 # 50–100 is standard\n",
|
| 1198 |
+
"LEARNING_RATE = 1e-3\n",
|
| 1199 |
+
"HIDDEN_DIM = 128\n",
|
| 1200 |
+
"NUM_HEADS = 2\n",
|
| 1201 |
+
"NUM_LAYERS = 2\n",
|
| 1202 |
+
"DROPOUT = 0.2\n",
|
| 1203 |
+
"WEIGHT_DECAY = 1e-6\n",
|
| 1204 |
+
"N_EPOCHS = 50\n",
|
| 1205 |
+
"MODE = 'train' # 'train' or 'test'\n",
|
| 1206 |
+
"\n",
|
| 1207 |
+
"# Train and evaluate SASRec model\n",
|
| 1208 |
+
"print(\"\\n>>> Training and evaluating SASRec model <<<\")\n",
|
| 1209 |
+
"train_and_eval_SASRec_model(train_set, validation_set, test_set, n_epochs=10, mode='train')\n",
|
| 1210 |
+
"\n",
|
| 1211 |
+
"print(\"\\n>>> Evaluating trained SASRec model on TEST set <<<\")\n",
|
| 1212 |
+
"train_and_eval_SASRec_model(train_set, validation_set, test_set, mode='test')"
|
| 1213 |
+
]
|
| 1214 |
+
},
|
| 1215 |
+
{
|
| 1216 |
+
"cell_type": "markdown",
|
| 1217 |
+
"id": "468e0951",
|
| 1218 |
+
"metadata": {},
|
| 1219 |
+
"source": [
|
| 1220 |
+
"## Main function to run the complete Recommender System"
|
| 1221 |
+
]
|
| 1222 |
+
},
|
| 1223 |
+
{
|
| 1224 |
+
"cell_type": "code",
|
| 1225 |
+
"execution_count": null,
|
| 1226 |
+
"id": "8f810e9a",
|
| 1227 |
+
"metadata": {},
|
| 1228 |
+
"outputs": [],
|
| 1229 |
+
"source": [
|
| 1230 |
+
"def load_item_properties(data_folder='data/'):\n",
|
| 1231 |
+
" \"\"\"\n",
|
| 1232 |
+
" Loads item properties and creates a mapping from item ID to its category ID.\n",
|
| 1233 |
+
" Handles both a single properties file or two split parts.\n",
|
| 1234 |
+
" \n",
|
| 1235 |
+
" Args:\n",
|
| 1236 |
+
" data_folder (str): The path to the folder containing item property files.\n",
|
| 1237 |
+
"\n",
|
| 1238 |
+
" Returns:\n",
|
| 1239 |
+
" dict: A dictionary mapping {itemid: categoryid}.\n",
|
| 1240 |
+
" \"\"\"\n",
|
| 1241 |
+
" print(\"Loading item properties...\")\n",
|
| 1242 |
+
" try:\n",
|
| 1243 |
+
" # First, try to load the two separate parts and combine them.\n",
|
| 1244 |
+
" props_df_part1 = pd.read_csv(data_folder + 'item_properties_part1.csv')\n",
|
| 1245 |
+
" props_df_part2 = pd.read_csv(data_folder + 'item_properties_part2.csv')\n",
|
| 1246 |
+
" props_df = pd.concat([props_df_part1, props_df_part2], ignore_index=True)\n",
|
| 1247 |
+
" print(\"Successfully loaded and combined item_properties_part1.csv and item_properties_part2.csv.\")\n",
|
| 1248 |
+
"\n",
|
| 1249 |
+
" except FileNotFoundError:\n",
|
| 1250 |
+
" try:\n",
|
| 1251 |
+
" # If the parts are not found, try to load a single combined file.\n",
|
| 1252 |
+
" props_df = pd.read_csv(data_folder + 'item_properties.csv')\n",
|
| 1253 |
+
" print(\"Successfully loaded a single item_properties.csv.\")\n",
|
| 1254 |
+
" except FileNotFoundError:\n",
|
| 1255 |
+
" print(f\"Warning: No item properties files found. Cannot display category information.\")\n",
|
| 1256 |
+
" return {}\n",
|
| 1257 |
+
"\n",
|
| 1258 |
+
" category_df = props_df[props_df['property'] == 'categoryid'].copy()\n",
|
| 1259 |
+
" category_df['value'] = pd.to_numeric(category_df['value'], errors='coerce').astype('Int64')\n",
|
| 1260 |
+
" item_to_category_map = category_df.set_index('itemid')['value'].to_dict()\n",
|
| 1261 |
+
" print(\"Item to category mapping created successfully.\")\n",
|
| 1262 |
+
" return item_to_category_map\n",
|
| 1263 |
+
"\n",
|
| 1264 |
+
"def load_category_tree(data_folder='data/'):\n",
|
| 1265 |
+
" \"\"\"\n",
|
| 1266 |
+
" Loads the category tree to map categories to their parent categories.\n",
|
| 1267 |
+
"\n",
|
| 1268 |
+
" Args:\n",
|
| 1269 |
+
" data_folder (str): The path to the folder containing category_tree.csv.\n",
|
| 1270 |
+
"\n",
|
| 1271 |
+
" Returns:\n",
|
| 1272 |
+
" dict: A dictionary mapping {categoryid: parentid}.\n",
|
| 1273 |
+
" \"\"\"\n",
|
| 1274 |
+
" print(\"Loading category tree...\")\n",
|
| 1275 |
+
" try:\n",
|
| 1276 |
+
" tree_df = pd.read_csv(data_folder + 'category_tree.csv')\n",
|
| 1277 |
+
" category_parent_map = tree_df.set_index('categoryid')['parentid'].to_dict()\n",
|
| 1278 |
+
" print(\"Category tree loaded successfully.\")\n",
|
| 1279 |
+
" return category_parent_map\n",
|
| 1280 |
+
" except FileNotFoundError:\n",
|
| 1281 |
+
" print(\"Warning: 'category_tree.csv' not found. Cannot display parent category information.\")\n",
|
| 1282 |
+
" return {}\n",
|
| 1283 |
+
"\n",
|
| 1284 |
+
"def get_popular_items(train_df, k=10):\n",
|
| 1285 |
+
" \"\"\"\n",
|
| 1286 |
+
" Calculates the top-k most popular items based on transaction count.\n",
|
| 1287 |
+
" \"\"\"\n",
|
| 1288 |
+
" purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()\n",
|
| 1289 |
+
" return purchase_counts.head(k).index.tolist()\n",
|
| 1290 |
+
"\n",
|
| 1291 |
+
"def show_user_recommendations(visitor_id, model, datamodule, popular_items, item_category_map, category_parent_map, k=10):\n",
|
| 1292 |
+
" \"\"\"\n",
|
| 1293 |
+
" Displays recommendations for a user, including category and parent category information.\n",
|
| 1294 |
+
" \"\"\"\n",
|
| 1295 |
+
" print(f\"\\n--- Recommendations for Visitor ID: {visitor_id} ---\")\n",
|
| 1296 |
+
" model.eval()\n",
|
| 1297 |
+
"\n",
|
| 1298 |
+
" def format_item_with_category(item_id):\n",
|
| 1299 |
+
" category_id = item_category_map.get(item_id, 'N/A')\n",
|
| 1300 |
+
" parent_id = category_parent_map.get(category_id, 'N/A') if category_id != 'N/A' else 'N/A'\n",
|
| 1301 |
+
" return f\"Item: {item_id} (Category: {category_id}, Parent: {parent_id})\"\n",
|
| 1302 |
+
"\n",
|
| 1303 |
+
" user_history_ids = datamodule.user_history.get(visitor_id)\n",
|
| 1304 |
+
"\n",
|
| 1305 |
+
" if user_history_ids is None:\n",
|
| 1306 |
+
" print(f\"User {visitor_id} not found in training history. Providing popularity-based recommendations.\")\n",
|
| 1307 |
+
" print(f\"\\nTop {k} Popular Items (Fallback):\")\n",
|
| 1308 |
+
" recs_with_cats = [format_item_with_category(item_id) for item_id in popular_items]\n",
|
| 1309 |
+
" print(recs_with_cats)\n",
|
| 1310 |
+
" print(\"-------------------------------------------------\")\n",
|
| 1311 |
+
" return\n",
|
| 1312 |
+
"\n",
|
| 1313 |
+
" history_with_cats = [format_item_with_category(item_id) for item_id in user_history_ids]\n",
|
| 1314 |
+
" print(f\"User's Historical Interactions:\")\n",
|
| 1315 |
+
" print(history_with_cats)\n",
|
| 1316 |
+
"\n",
|
| 1317 |
+
" history_indices = [datamodule.item_map[i] for i in user_history_ids if i in datamodule.item_map]\n",
|
| 1318 |
+
" if not history_indices:\n",
|
| 1319 |
+
" print(\"None of the user's historical items are in the model's vocabulary.\")\n",
|
| 1320 |
+
" return\n",
|
| 1321 |
+
"\n",
|
| 1322 |
+
" max_len = datamodule.max_len\n",
|
| 1323 |
+
" input_seq = history_indices[-max_len:]\n",
|
| 1324 |
+
" padded_input = np.zeros(max_len, dtype=np.int64)\n",
|
| 1325 |
+
" padded_input[-len(input_seq):] = input_seq\n",
|
| 1326 |
+
" \n",
|
| 1327 |
+
" input_tensor = torch.LongTensor(np.array([padded_input]))\n",
|
| 1328 |
+
" input_tensor = input_tensor.to(model.device)\n",
|
| 1329 |
+
"\n",
|
| 1330 |
+
" with torch.no_grad():\n",
|
| 1331 |
+
" logits = model(input_tensor)\n",
|
| 1332 |
+
" last_item_logits = logits[0, -1, :]\n",
|
| 1333 |
+
" top_indices = torch.topk(last_item_logits, k).indices.tolist()\n",
|
| 1334 |
+
"\n",
|
| 1335 |
+
" recommended_item_ids = [datamodule.inverse_item_map[idx] for idx in top_indices if idx in datamodule.inverse_item_map]\n",
|
| 1336 |
+
"\n",
|
| 1337 |
+
" print(f\"\\nTop {k} Recommended Items:\")\n",
|
| 1338 |
+
" recs_with_cats = [format_item_with_category(item_id) for item_id in recommended_item_ids]\n",
|
| 1339 |
+
" print(recs_with_cats)\n",
|
| 1340 |
+
" print(\"-------------------------------------------------\")\n"
|
| 1341 |
+
]
|
| 1342 |
+
},
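{
"cell_type": "code",
"execution_count": null,
"id": "f00dfeed",
"metadata": {},
"outputs": [],
"source": [
"# The left-padding idiom used in show_user_recommendations (and again in scripts/app.py),\n",
"# shown on a toy history with an assumed max_len of 5: padding token 0 goes on the left,\n",
"# so the most recent item lands in the last position -- the one the model predicts from.\n",
"toy_history = [12, 7, 3]\n",
"toy_padded = np.zeros(5, dtype=np.int64)\n",
"toy_padded[-len(toy_history):] = toy_history\n",
"print(toy_padded)  # [ 0  0 12  7  3]\n"
]
},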
|
| 1343 |
+
{
|
| 1344 |
+
"cell_type": "code",
|
| 1345 |
+
"execution_count": null,
|
| 1346 |
+
"id": "735d0f8d",
|
| 1347 |
+
"metadata": {},
|
| 1348 |
+
"outputs": [],
|
| 1349 |
+
"source": [
|
| 1350 |
+
"def main(checkpoint_path=\"checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt\", data_folder=\"data/\"):\n",
|
| 1351 |
+
" \"\"\"\n",
|
| 1352 |
+
" Main function to run the inference and qualitative analysis pipeline.\n",
|
| 1353 |
+
" \"\"\"\n",
|
| 1354 |
+
"\n",
|
| 1355 |
+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 1356 |
+
" print(f\"Using device: {device}\")\n",
|
| 1357 |
+
"\n",
|
| 1358 |
+
" print(\"Loading model from checkpoint...\")\n",
|
| 1359 |
+
" best_model = SASRec.load_from_checkpoint(checkpoint_path)\n",
|
| 1360 |
+
" best_model.to(device)\n",
|
| 1361 |
+
"\n",
|
| 1362 |
+
" print(\"Preparing data...\")\n",
|
| 1363 |
+
" train_set, validation_set, test_set = prepare_data(data_folder=data_folder)\n",
|
| 1364 |
+
" \n",
|
| 1365 |
+
" datamodule = SASRecDataModule(train_set, validation_set, test_set)\n",
|
| 1366 |
+
" datamodule.setup()\n",
|
| 1367 |
+
" \n",
|
| 1368 |
+
" item_category_map = load_item_properties(data_folder=data_folder)\n",
|
| 1369 |
+
" category_parent_map = load_category_tree(data_folder=data_folder)\n",
|
| 1370 |
+
" \n",
|
| 1371 |
+
" print(\"\\nCalculating popular items for cold-start users...\")\n",
|
| 1372 |
+
" popular_items_list = get_popular_items(train_set, k=10)\n",
|
| 1373 |
+
"\n",
|
| 1374 |
+
" users_in_train_history = set(datamodule.user_history.keys())\n",
|
| 1375 |
+
" users_in_test_set = set(datamodule.test_df['visitorid'].unique())\n",
|
| 1376 |
+
" valid_example_users = list(users_in_train_history.intersection(users_in_test_set))\n",
|
| 1377 |
+
"\n",
|
| 1378 |
+
" print(f\"\\nFound {len(valid_example_users)} users for qualitative analysis.\")\n",
|
| 1379 |
+
" \n",
|
| 1380 |
+
" for user_id in valid_example_users[:3]:\n",
|
| 1381 |
+
" show_user_recommendations(user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)\n",
|
| 1382 |
+
" \n",
|
| 1383 |
+
" new_user_id = -999\n",
|
| 1384 |
+
" show_user_recommendations(new_user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)\n"
|
| 1385 |
+
]
|
| 1386 |
+
},
|
| 1387 |
+
{
|
| 1388 |
+
"cell_type": "code",
|
| 1389 |
+
"execution_count": null,
|
| 1390 |
+
"id": "0e7ba5f2",
|
| 1391 |
+
"metadata": {},
|
| 1392 |
+
"outputs": [],
|
| 1393 |
+
"source": [
|
| 1394 |
+
"main()"
|
| 1395 |
+
]
|
| 1396 |
+
}
|
| 1397 |
+
],
|
| 1398 |
+
"metadata": {
|
| 1399 |
+
"kernelspec": {
|
| 1400 |
+
"display_name": "Python 3",
|
| 1401 |
+
"language": "python",
|
| 1402 |
+
"name": "python3"
|
| 1403 |
+
},
|
| 1404 |
+
"language_info": {
|
| 1405 |
+
"codemirror_mode": {
|
| 1406 |
+
"name": "ipython",
|
| 1407 |
+
"version": 3
|
| 1408 |
+
},
|
| 1409 |
+
"file_extension": ".py",
|
| 1410 |
+
"mimetype": "text/x-python",
|
| 1411 |
+
"name": "python",
|
| 1412 |
+
"nbconvert_exporter": "python",
|
| 1413 |
+
"pygments_lexer": "ipython3",
|
| 1414 |
+
"version": "3.10.6"
|
| 1415 |
+
}
|
| 1416 |
+
},
|
| 1417 |
+
"nbformat": 4,
|
| 1418 |
+
"nbformat_minor": 5
|
| 1419 |
+
}
|
requirements.txt
ADDED
|
Binary file (304 Bytes).
|
|
|
scripts/als_optuna_study.py
ADDED
|
@@ -0,0 +1,53 @@
| 1 |
+
import optuna
|
| 2 |
+
import numpy as np
|
| 3 |
+
import math
|
| 4 |
+
from scipy.sparse import csr_matrix
|
| 5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 6 |
+
import implicit
|
| 7 |
+
from utils import prepare_ground_truth, calculate_metrics
|
| 8 |
+
from models import recommend_als_and_evaluate
|
| 9 |
+
from data_prepare import prepare_data
|
| 10 |
+
|
| 11 |
+
def objective_als(trial, train_df, val_df):
|
| 12 |
+
"""
|
| 13 |
+
The objective function for Optuna to optimize.
|
| 14 |
+
"""
|
| 15 |
+
# 1. Define the hyperparameter search space
|
| 16 |
+
params = {
|
| 17 |
+
'factors': trial.suggest_int('factors', 20, 200),
|
| 18 |
+
'regularization': trial.suggest_float('regularization', 1e-3, 1e-1, log=True),
|
| 19 |
+
'iterations': trial.suggest_int('iterations', 10, 50)
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
# 2. Run an evaluation with the suggested parameters
|
| 23 |
+
metrics = recommend_als_and_evaluate(train_df, val_df, **params)
|
| 24 |
+
|
| 25 |
+
# 3. Return the metric we want to maximize (precision)
|
| 26 |
+
return metrics['mean_precision@k']
|
| 27 |
+
|
| 28 |
+
def tune_als_hyperparameters(train_df, val_df, n_trials=25):
|
| 29 |
+
"""
|
| 30 |
+
Orchestrates the Optuna study to find the best hyperparameters for ALS.
|
| 31 |
+
"""
|
| 32 |
+
study = optuna.create_study(direction='maximize')
|
| 33 |
+
study.optimize(lambda trial: objective_als(trial, train_df, val_df), n_trials=n_trials)
|
| 34 |
+
|
| 35 |
+
print("\n--- Optuna Study Complete ---")
|
| 36 |
+
print(f"Number of finished trials: {len(study.trials)}")
|
| 37 |
+
print("Best trial:")
|
| 38 |
+
trial = study.best_trial
|
| 39 |
+
print(f" Value (Precision@10): {trial.value}")
|
| 40 |
+
print(" Params: ")
|
| 41 |
+
for key, value in trial.params.items():
|
| 42 |
+
print(f" {key}: {value}")
|
| 43 |
+
|
| 44 |
+
return trial.params
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
# 1. Prepare all data
|
| 48 |
+
train_set, validation_set, test_set = prepare_data()
|
| 49 |
+
|
| 50 |
+
# --- Hyperparameter Tuning Step ---
|
| 51 |
+
print("\n>>> 1. TUNING ALS Hyperparameters on the VALIDATION set <<<")
|
| 52 |
+
|
| 53 |
+
best_als_params = tune_als_hyperparameters(train_set, validation_set, n_trials=25)
|
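A natural follow-up once the study finishes (a sketch, not part of the script): feed `best_als_params` back into the evaluator on the held-out test split, mirroring the notebook's final-evaluation step.

```python
# Sketch of a final-evaluation step with the tuned parameters (assumes it is appended
# to the __main__ block above, where the train/validation/test splits are in scope).
import pandas as pd

full_train_set = pd.concat([train_set, validation_set])
test_metrics = recommend_als_and_evaluate(full_train_set, test_set, **best_als_params)
print(test_metrics)
```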
scripts/app.py
ADDED
|
@@ -0,0 +1,162 @@
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from datetime import datetime, timedelta
|
| 6 |
+
from models import SASRec
|
| 7 |
+
from data_prepare import SASRecDataModule, prepare_data
|
| 8 |
+
from utils import load_item_properties, load_category_tree, get_popular_items
|
| 9 |
+
|
| 10 |
+
# --- Global variables to hold loaded artifacts ---
|
| 11 |
+
# This prevents reloading the model and data on every prediction.
|
| 12 |
+
MODEL = None
|
| 13 |
+
DATAMODULE = None
|
| 14 |
+
ITEM_CATEGORY_MAP = None
|
| 15 |
+
CATEGORY_PARENT_MAP = None
|
| 16 |
+
POPULAR_ITEMS = None
|
| 17 |
+
|
| 18 |
+
# --- Data Loading and Preparation Functions ---
|
| 19 |
+
|
| 20 |
+
def load_artifacts():
|
| 21 |
+
"""
|
| 22 |
+
Loads all necessary artifacts (model, data, mappings) into global variables.
|
| 23 |
+
This function is called only once when the app starts.
|
| 24 |
+
"""
|
| 25 |
+
global MODEL, DATAMODULE, ITEM_CATEGORY_MAP, CATEGORY_PARENT_MAP, POPULAR_ITEMS
|
| 26 |
+
|
| 27 |
+
print("--- Loading all artifacts for the Gradio app ---")
|
| 28 |
+
|
| 29 |
+
# HF-FRIENDLY: Path is relative, assuming the checkpoint is in the root of the Space repo.
|
| 30 |
+
CHECKPOINT_PATH = "sasrec-epoch=05-val_hitrate@10=0.3614.ckpt"
|
| 31 |
+
DATA_FOLDER = "data/"
|
| 32 |
+
|
| 33 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 34 |
+
print(f"Using device: {device}")
|
| 35 |
+
|
| 36 |
+
print(f"Loading model from checkpoint: {CHECKPOINT_PATH}...")
|
| 37 |
+
MODEL = SASRec.load_from_checkpoint(CHECKPOINT_PATH)
|
| 38 |
+
MODEL.to(device)
|
| 39 |
+
MODEL.eval()
|
| 40 |
+
|
| 41 |
+
print("Preparing data...")
|
| 42 |
+
train_set, validation_set, test_set = prepare_data(data_folder=DATA_FOLDER)
|
| 43 |
+
|
| 44 |
+
DATAMODULE = SASRecDataModule(train_set, validation_set, test_set)
|
| 45 |
+
DATAMODULE.setup()
|
| 46 |
+
|
| 47 |
+
print("Loading item and category maps...")
|
| 48 |
+
ITEM_CATEGORY_MAP = load_item_properties(data_folder=DATA_FOLDER)
|
| 49 |
+
CATEGORY_PARENT_MAP = load_category_tree(data_folder=DATA_FOLDER)
|
| 50 |
+
|
| 51 |
+
print("Calculating popular items for cold-start users...")
|
| 52 |
+
POPULAR_ITEMS = get_popular_items(train_set, k=10)
|
| 53 |
+
|
| 54 |
+
print("--- Artifacts loaded successfully. Ready to serve recommendations. ---")
|
| 55 |
+
|
| 56 |
+
def get_recommendations(visitor_id_str):
|
| 57 |
+
"""
|
| 58 |
+
The main prediction function for the Gradio interface.
|
| 59 |
+
|
| 60 |
+
Takes a visitor ID string, gets recommendations, and formats them for display.
|
| 61 |
+
"""
|
| 62 |
+
try:
|
| 63 |
+
visitor_id = int(visitor_id_str)
|
| 64 |
+
except (ValueError, TypeError):
|
| 65 |
+
return pd.DataFrame(), pd.DataFrame(), "Please enter a valid numerical Visitor ID."
|
| 66 |
+
|
| 67 |
+
user_history_ids = DATAMODULE.user_history.get(visitor_id)
|
| 68 |
+
|
| 69 |
+
def format_to_df(item_list):
|
| 70 |
+
data = []
|
| 71 |
+
for rank, item_id in enumerate(item_list, 1):
|
| 72 |
+
category_id = ITEM_CATEGORY_MAP.get(item_id, 'N/A')
|
| 73 |
+
parent_id = CATEGORY_PARENT_MAP.get(category_id, 'N/A') if pd.notna(category_id) else 'N/A'
|
| 74 |
+
data.append([rank, item_id, category_id, parent_id])
|
| 75 |
+
return pd.DataFrame(data, columns=['Rank', 'Item ID', 'Category ID', 'Parent ID'])
|
| 76 |
+
|
| 77 |
+
# --- Cold-Start User (Fallback to Popularity) ---
|
| 78 |
+
if user_history_ids is None:
|
| 79 |
+
history_df = pd.DataFrame(columns=['Rank', 'Item ID', 'Category ID', 'Parent ID'])
|
| 80 |
+
recs_df = format_to_df(POPULAR_ITEMS)
|
| 81 |
+
message = f"User {visitor_id} is new. Showing Top 10 popular items as a fallback."
|
| 82 |
+
return history_df, recs_df, message
|
| 83 |
+
|
| 84 |
+
# --- Existing User (Use SASRec Model) ---
|
| 85 |
+
history_df = format_to_df(user_history_ids)
|
| 86 |
+
|
| 87 |
+
history_indices = [DATAMODULE.item_map[i] for i in user_history_ids if i in DATAMODULE.item_map]
|
| 88 |
+
|
| 89 |
+
if not history_indices:
|
| 90 |
+
message = "None of this user's historical items are in the model's vocabulary."
|
| 91 |
+
return history_df, pd.DataFrame(), message
|
| 92 |
+
|
| 93 |
+
max_len = DATAMODULE.max_len
|
| 94 |
+
input_seq = history_indices[-max_len:]
|
| 95 |
+
padded_input = np.zeros(max_len, dtype=np.int64)
|
| 96 |
+
padded_input[-len(input_seq):] = input_seq
|
| 97 |
+
|
| 98 |
+
input_tensor = torch.LongTensor(np.array([padded_input]))
|
| 99 |
+
input_tensor = input_tensor.to(MODEL.device)
|
| 100 |
+
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
logits = MODEL(input_tensor)
|
| 103 |
+
last_item_logits = logits[0, -1, :]
|
| 104 |
+
top_indices = torch.topk(last_item_logits, 10).indices.tolist()
|
| 105 |
+
|
| 106 |
+
recommended_item_ids = [DATAMODULE.inverse_item_map[idx] for idx in top_indices if idx in DATAMODULE.inverse_item_map]
|
| 107 |
+
recs_df = format_to_df(recommended_item_ids)
|
| 108 |
+
message = f"Showing personalized SASRec recommendations for user {visitor_id}."
|
| 109 |
+
|
| 110 |
+
return history_df, recs_df, message
|
| 111 |
+
|
| 112 |
+
# --- Main Execution Block ---
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
# Load all artifacts once at startup
|
| 115 |
+
load_artifacts()
|
| 116 |
+
|
| 117 |
+
# Find some valid example users to show in the UI
|
| 118 |
+
users_in_train_history = set(DATAMODULE.user_history.keys())
|
| 119 |
+
users_in_test_set = set(DATAMODULE.test_df['visitorid'].unique())
|
| 120 |
+
valid_example_users = list(users_in_train_history.intersection(users_in_test_set))
|
| 121 |
+
|
| 122 |
+
# Convert numpy types to standard Python int for Gradio compatibility
|
| 123 |
+
example_list = [int(u) for u in valid_example_users[:4]] + [-999]
|
| 124 |
+
|
| 125 |
+
# Create and launch the Gradio interface
|
| 126 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="SASRec Recommender") as iface:
|
| 127 |
+
gr.Markdown(
|
| 128 |
+
"""
|
| 129 |
+
# SASRec Sequential Recommender System
|
| 130 |
+
An interactive demo of a state-of-the-art recommender system trained on the RetailRocket dataset.
|
| 131 |
+
"""
|
| 132 |
+
)
|
| 133 |
+
with gr.Row():
|
| 134 |
+
with gr.Column(scale=1):
|
| 135 |
+
visitor_id_input = gr.Number(
|
| 136 |
+
label="Enter Visitor ID",
|
| 137 |
+
info="Enter a user's numerical ID to get recommendations."
|
| 138 |
+
)
|
| 139 |
+
submit_button = gr.Button("Get Recommendations", variant="primary")
|
| 140 |
+
gr.Examples(
|
| 141 |
+
examples=example_list,
|
| 142 |
+
inputs=visitor_id_input,
|
| 143 |
+
label="Example User IDs (Click to try)"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
with gr.Column(scale=3):
|
| 147 |
+
status_message = gr.Textbox(label="Status", interactive=False)
|
| 148 |
+
with gr.Tabs():
|
| 149 |
+
with gr.TabItem("Top 10 Recommendations"):
|
| 150 |
+
recs_output = gr.DataFrame(label="Recommended Items")
|
| 151 |
+
with gr.TabItem("User's Recent History"):
|
| 152 |
+
history_output = gr.DataFrame(label="Interaction History")
|
| 153 |
+
|
| 154 |
+
submit_button.click(
|
| 155 |
+
fn=get_recommendations,
|
| 156 |
+
inputs=visitor_id_input,
|
| 157 |
+
outputs=[history_output, recs_output, status_message]
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# For local testing, this creates a shareable link.
|
| 161 |
+
# On Hugging Face Spaces, this is not strictly necessary but doesn't hurt.
|
| 162 |
+
iface.launch(share=True)
|
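For debugging outside the Gradio UI, the prediction path can be exercised directly. A minimal sketch (the visitor ID below is hypothetical, and the import assumes you run from the scripts/ directory):

```python
# Smoke test of app.py's prediction function without launching the interface.
from app import load_artifacts, get_recommendations

load_artifacts()  # one-time load of model, data module, and category maps
history_df, recs_df, message = get_recommendations("12345")  # hypothetical visitor ID
print(message)
print(recs_df)
```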
scripts/data_prepare.py
ADDED
|
@@ -0,0 +1,253 @@
| 1 |
+
import zipfile
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from datetime import datetime, timedelta
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scipy.sparse import csr_matrix
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
import pytorch_lightning as pl
|
| 11 |
+
|
| 12 |
+
def extract_ziped_data(ziped_data_path: str, extract_path: str):
|
| 13 |
+
"""Extracts the contents of a zip file to a specified directory.
|
| 14 |
+
|
| 15 |
+
args:
|
| 16 |
+
ziped_data_path: str, path to the zip file
|
| 17 |
+
extract_path: str, path to the directory where contents will be extracted
|
| 18 |
+
"""
|
| 19 |
+
# Extract into the caller-supplied directory (extract_path)
|
| 21 |
+
|
| 22 |
+
# Open the zip file in read mode
|
| 23 |
+
with zipfile.ZipFile(ziped_data_path, 'r') as zip_ref:
|
| 24 |
+
# Extract all the contents into the specified directory
|
| 25 |
+
zip_ref.extractall(extract_path)
|
| 26 |
+
|
| 27 |
+
print(f"'{ziped_data_path}' has been extracted to '{extract_path}'")
|
| 28 |
+
|
| 29 |
+
def prepare_data(data_folder='data/', val_days=7, test_days=7):
|
| 30 |
+
"""
|
| 31 |
+
Loads, preprocesses, and splits the events data into train, validation, and test sets.
|
| 32 |
+
|
| 33 |
+
args:
|
| 34 |
+
data_folder: str, path to the folder containing 'events.csv'
|
| 35 |
+
val_days: int, number of days for the validation set
|
| 36 |
+
test_days: int, number of days for the test set
|
| 37 |
+
"""
|
| 38 |
+
# --- Load Data ---
|
| 39 |
+
print(f"Loading events.csv from folder: {data_folder}")
|
| 40 |
+
try:
|
| 41 |
+
events_df = pd.read_csv(data_folder + 'events.csv')
|
| 42 |
+
print("Successfully loaded events.csv.")
|
| 43 |
+
events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms')
|
| 44 |
+
print("\n--- Initial Data Summary ---")
|
| 45 |
+
print(f"Data shape: {events_df.shape}")
|
| 46 |
+
print(f"Full timeframe: {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}")
|
| 47 |
+
print("----------------------------\n")
|
| 48 |
+
except FileNotFoundError:
|
| 49 |
+
print(f"Error: 'events.csv' not found in '{data_folder}'. Please check the path.")
|
| 50 |
+
return None, None, None
|
| 51 |
+
|
| 52 |
+
# --- Split Data ---
|
| 53 |
+
sorted_df = events_df.sort_values('timestamp_dt').reset_index(drop=True)
|
| 54 |
+
print(f"Splitting data: {test_days} days for test, {val_days} for validation.")
|
| 55 |
+
end_time = sorted_df['timestamp_dt'].max()
|
| 56 |
+
test_start_time = end_time - timedelta(days=test_days)
|
| 57 |
+
val_start_time = test_start_time - timedelta(days=val_days)
|
| 58 |
+
|
| 59 |
+
test_df = sorted_df[sorted_df['timestamp_dt'] >= test_start_time]
|
| 60 |
+
val_df = sorted_df[(sorted_df['timestamp_dt'] >= val_start_time) & (sorted_df['timestamp_dt'] < test_start_time)]
|
| 61 |
+
train_df = sorted_df[sorted_df['timestamp_dt'] < val_start_time]
|
| 62 |
+
|
| 63 |
+
print("--- Data Splitting Summary ---")
|
| 64 |
+
print(f"Training set: {train_df.shape[0]:>8} records | from {train_df['timestamp_dt'].min()} to {train_df['timestamp_dt'].max()}")
|
| 65 |
+
print(f"Validation set: {val_df.shape[0]:>8} records | from {val_df['timestamp_dt'].min()} to {val_df['timestamp_dt'].max()}")
|
| 66 |
+
print(f"Test set: {test_df.shape[0]:>8} records | from {test_df['timestamp_dt'].min()} to {test_df['timestamp_dt'].max()}")
|
| 67 |
+
print("------------------------------")
|
| 68 |
+
|
| 69 |
+
return train_df, val_df, test_df
|
| 70 |
+
|
| 71 |
+
class SASRecDataset(Dataset):
|
| 72 |
+
"""
|
| 73 |
+
SASRec Dataset.
|
| 74 |
+
- Precomputes (sequence_id, cutoff_idx) pairs for O(1) __getitem__.
|
| 75 |
+
- Supports 'last' or 'all' target modes.
|
| 76 |
+
"""
|
| 77 |
+
def __init__(self, sequences, max_len, target_mode="last"):
|
| 78 |
+
"""
|
| 79 |
+
Args:
|
| 80 |
+
sequences: list of user sequences (list of item IDs).
|
| 81 |
+
max_len: maximum sequence length (padding applied).
|
| 82 |
+
target_mode: 'last' (only last prediction) or 'all' (predict at every step).
|
| 83 |
+
"""
|
| 84 |
+
self.sequences = sequences
|
| 85 |
+
self.max_len = max_len
|
| 86 |
+
self.target_mode = target_mode
|
| 87 |
+
|
| 88 |
+
# Build index once
|
| 89 |
+
self.index = []
|
| 90 |
+
for seq_id, seq in enumerate(sequences):
|
| 91 |
+
for i in range(1, len(seq)):
|
| 92 |
+
self.index.append((seq_id, i))
|
| 93 |
+
|
| 94 |
+
def __len__(self):
|
| 95 |
+
return len(self.index)
|
| 96 |
+
|
| 97 |
+
def __getitem__(self, idx):
|
| 98 |
+
seq_id, cutoff = self.index[idx]
|
| 99 |
+
seq = self.sequences[seq_id][:cutoff]
|
| 100 |
+
|
| 101 |
+
# Truncate & pad
|
| 102 |
+
seq = seq[-self.max_len:]
|
| 103 |
+
pad_len = self.max_len - len(seq)
|
| 104 |
+
|
| 105 |
+
input_seq = np.zeros(self.max_len, dtype=np.int64)
|
| 106 |
+
input_seq[pad_len:] = seq
|
| 107 |
+
|
| 108 |
+
if self.target_mode == "last":
|
| 109 |
+
target = self.sequences[seq_id][cutoff]
|
| 110 |
+
return torch.LongTensor(input_seq), torch.LongTensor([target])
|
| 111 |
+
|
| 112 |
+
elif self.target_mode == "all":
|
| 113 |
+
# Predict next item at each step
|
| 114 |
+
target_seq = self.sequences[seq_id][1:cutoff+1]
|
| 115 |
+
target_seq = target_seq[-self.max_len:]
|
| 116 |
+
target = np.zeros(self.max_len, dtype=np.int64)
|
| 117 |
+
target[-len(target_seq):] = target_seq
|
| 118 |
+
return torch.LongTensor(input_seq), torch.LongTensor(target)
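# Worked example of the (sequence_id, cutoff) indexing above, for a toy sequence
# [5, 9, 2]: __init__ stores cutoffs 1 and 2, so in 'last' mode __getitem__ yields
# input [5] -> target 9, then input [5, 9] -> target 2 -- every prefix of every
# user sequence becomes one training example.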
|
| 119 |
+
|
| 120 |
+
class SASRecDataModule(pl.LightningDataModule):
|
| 121 |
+
"""
|
| 122 |
+
PyTorch Lightning DataModule for preparing the RetailRocket dataset for the SASRec model.
|
| 123 |
+
|
| 124 |
+
This class handles all aspects of data preparation, including:
|
| 125 |
+
- Filtering out infrequent users and items to reduce noise.
|
| 126 |
+
- Building a consistent item vocabulary.
|
| 127 |
+
- Converting user event histories into sequential data.
|
| 128 |
+
- Creating and providing `DataLoader` instances for training, validation, and testing.
|
| 129 |
+
"""
|
| 130 |
+
def __init__(self, train_df, val_df, test_df, min_item_interactions=5,
|
| 131 |
+
min_user_interactions=5, max_len=50, batch_size=256):
|
| 132 |
+
"""
|
| 133 |
+
Initializes the DataModule.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
train_df (pd.DataFrame): DataFrame for training.
|
| 137 |
+
val_df (pd.DataFrame): DataFrame for validation.
|
| 138 |
+
test_df (pd.DataFrame): DataFrame for testing.
|
| 139 |
+
min_item_interactions (int): Minimum number of interactions for an item to be kept.
|
| 140 |
+
min_user_interactions (int): Minimum number of interactions for a user to be kept.
|
| 141 |
+
max_len (int): The maximum length of a user sequence fed to the model.
|
| 142 |
+
batch_size (int): The batch size for the DataLoaders.
|
| 143 |
+
"""
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.train_df = train_df
|
| 146 |
+
self.val_df = val_df
|
| 147 |
+
self.test_df = test_df
|
| 148 |
+
self.min_item_interactions = min_item_interactions
|
| 149 |
+
self.min_user_interactions = min_user_interactions
|
| 150 |
+
self.max_len = max_len
|
| 151 |
+
self.batch_size = batch_size
|
| 152 |
+
|
| 153 |
+
self.item_map = None
|
| 154 |
+
self.inverse_item_map = None
|
| 155 |
+
self.vocab_size = 0
|
| 156 |
+
self.user_history = None
|
| 157 |
+
|
| 158 |
+
def setup(self, stage=None):
|
| 159 |
+
"""
|
| 160 |
+
Prepares the data for training, validation, and testing.
|
| 161 |
+
|
| 162 |
+
This method is called automatically by PyTorch Lightning. It performs the following steps:
|
| 163 |
+
1. Determines filtering criteria (which users and items to keep) based on the training set only
|
| 164 |
+
to prevent data leakage.
|
| 165 |
+
2. Applies these filters to the train, validation, and test sets.
|
| 166 |
+
3. Builds an item vocabulary (mapping item IDs to integer indices) from the combined
|
| 167 |
+
training and validation sets to ensure consistency for model checkpointing.
|
| 168 |
+
4. Converts the event logs into sequences of item indices for each user in each data split.
|
| 169 |
+
"""
|
| 170 |
+
item_counts = self.train_df['itemid'].value_counts()
|
| 171 |
+
user_counts = self.train_df['visitorid'].value_counts()
|
| 172 |
+
items_to_keep = item_counts[item_counts >= self.min_item_interactions].index
|
| 173 |
+
users_to_keep = user_counts[user_counts >= self.min_user_interactions].index
|
| 174 |
+
|
| 175 |
+
self.filtered_train_df = self.train_df[
|
| 176 |
+
(self.train_df['itemid'].isin(items_to_keep)) &
|
| 177 |
+
(self.train_df['visitorid'].isin(users_to_keep))
|
| 178 |
+
].copy()
|
| 179 |
+
self.filtered_val_df = self.val_df[
|
| 180 |
+
(self.val_df['itemid'].isin(items_to_keep)) &
|
| 181 |
+
(self.val_df['visitorid'].isin(users_to_keep))
|
| 182 |
+
].copy()
|
| 183 |
+
self.filtered_test_df = self.test_df[
|
| 184 |
+
(self.test_df['itemid'].isin(items_to_keep)) &
|
| 185 |
+
(self.test_df['visitorid'].isin(users_to_keep))
|
| 186 |
+
].copy()
|
| 187 |
+
|
| 188 |
+
all_known_items_df = pd.concat([self.filtered_train_df, self.filtered_val_df])
|
| 189 |
+
unique_items = all_known_items_df['itemid'].unique()
|
| 190 |
+
self.item_map = {item_id: i + 1 for i, item_id in enumerate(unique_items)}
|
| 191 |
+
self.inverse_item_map = {i: item_id for item_id, i in self.item_map.items()}
|
| 192 |
+
self.vocab_size = len(self.item_map) + 1 # +1 for padding token 0
|
| 193 |
+
|
| 194 |
+
self.user_history = self.filtered_train_df.groupby('visitorid')['itemid'].apply(list)
|
| 195 |
+
|
| 196 |
+
self.train_sequences = self._create_sequences(self.filtered_train_df)
|
| 197 |
+
self.val_sequences = self._create_sequences(self.filtered_val_df)
|
| 198 |
+
self.test_sequences = self._create_sequences(self.filtered_test_df)
|
| 199 |
+
|
| 200 |
+
def _create_sequences(self, df):
|
| 201 |
+
"""
|
| 202 |
+
Helper function to convert a DataFrame of events into user interaction sequences.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
df (pd.DataFrame): The input DataFrame to process.
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
list[list[int]]: A list of user sequences, where each sequence is a list of item indices.
|
| 209 |
+
"""
|
| 210 |
+
df_sorted = df.sort_values(['visitorid', 'timestamp_dt'])
|
| 211 |
+
sequences = df_sorted.groupby('visitorid')['itemid'].apply(
|
| 212 |
+
lambda x: [self.item_map[i] for i in x if i in self.item_map]
|
| 213 |
+
).tolist()
|
| 214 |
+
return [s for s in sequences if len(s) > 1]
|
| 215 |
+
|
| 216 |
+
def train_dataloader(self):
|
| 217 |
+
"""Creates the DataLoader for the training set."""
|
| 218 |
+
dataset = SASRecDataset(self.train_sequences, self.max_len)
|
| 219 |
+
return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)
|
| 220 |
+
|
| 221 |
+
def val_dataloader(self):
|
| 222 |
+
"""Creates the DataLoader for the validation set."""
|
| 223 |
+
dataset = SASRecDataset(self.val_sequences, self.max_len)
|
| 224 |
+
return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
|
| 225 |
+
|
| 226 |
+
def test_dataloader(self):
|
| 227 |
+
"""Creates the DataLoader for the test set."""
|
| 228 |
+
dataset = SASRecDataset(self.test_sequences, self.max_len)
|
| 229 |
+
return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
|
| 230 |
+
|
| 231 |
+
if __name__ == "__main__":
|
| 232 |
+
|
| 233 |
+
# --- Configuration ---
|
| 234 |
+
DATA_PATH = "data"
|
| 235 |
+
ZIPED_DATA_PATH = "data/archive.zip" # change to your zip file path
|
| 236 |
+
BATCH_SIZE = 256
|
| 237 |
+
MAX_TOKEN_LEN = 50 # 50–100 is standard for SASRec
|
| 238 |
+
|
| 239 |
+
# extract_ziped_data(ZIPED_DATA_PATH, DATA_PATH) # uncomment this line if you want to extract the data
|
| 240 |
+
|
| 241 |
+
# --- 1. Prepare the data into train, validation, and test sets ---
|
| 242 |
+
train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)
|
| 243 |
+
|
| 244 |
+
# --- 2. Initialize DataModule ---
|
| 245 |
+
print("Initializing DataModule...")
|
| 246 |
+
datamodule = SASRecDataModule(
|
| 247 |
+
train_df=train_set,
|
| 248 |
+
val_df=validation_set,
|
| 249 |
+
test_df=test_set,
|
| 250 |
+
batch_size=BATCH_SIZE,
|
| 251 |
+
max_len=MAX_TOKEN_LEN
|
| 252 |
+
)
|
| 253 |
+
datamodule.setup()
|
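To make the indexing scheme concrete, here is a standalone toy run (item indices invented for illustration, not part of the file) showing how `SASRecDataset` expands one user sequence into left-padded (input, target) pairs in the default `'last'` mode:

# Example (not part of scripts/data_prepare.py): toy sequence through SASRecDataset.
toy_sequences = [[3, 7, 7, 2]]            # one user, already-mapped item indices (0 = padding)
ds = SASRecDataset(toy_sequences, max_len=5)
for inputs, target in ds:                 # len(ds) == 3: one pair per cutoff 1, 2, 3
    print(inputs.tolist(), "->", target.tolist())
# [0, 0, 0, 0, 3] -> [7]
# [0, 0, 0, 3, 7] -> [7]
# [0, 0, 3, 7, 7] -> [2]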
scripts/main.py
ADDED
@@ -0,0 +1,46 @@
import pandas as pd
import numpy as np
import torch
import datetime
from models import SASRec
from utils import prepare_ground_truth, calculate_metrics, load_item_properties, load_category_tree, get_popular_items, show_user_recommendations
from data_prepare import prepare_data, SASRecDataset, SASRecDataModule

def main(checkpoint_path="checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt", data_folder="data/"):
    """
    Main function to run the inference and qualitative analysis pipeline.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading model from checkpoint...")
    best_model = SASRec.load_from_checkpoint(checkpoint_path)
    best_model.to(device)

    print("Preparing data...")
    train_set, validation_set, test_set = prepare_data(data_folder=data_folder)

    datamodule = SASRecDataModule(train_set, validation_set, test_set)
    datamodule.setup()

    item_category_map = load_item_properties(data_folder=data_folder)
    category_parent_map = load_category_tree(data_folder=data_folder)

    print("\nCalculating popular items for cold-start users...")
    popular_items_list = get_popular_items(train_set, k=10)

    users_in_train_history = set(datamodule.user_history.keys())
    users_in_test_set = set(datamodule.test_df['visitorid'].unique())
    valid_example_users = list(users_in_train_history.intersection(users_in_test_set))

    print(f"\nFound {len(valid_example_users)} users for qualitative analysis.")

    for user_id in valid_example_users[:3]:
        show_user_recommendations(user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)

    new_user_id = -999
    show_user_recommendations(new_user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)

if __name__ == "__main__":
    main()
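A sketch of how the pipeline is invoked with explicit paths, assuming the working directory is the repo root and `scripts/` is importable (both paths are the defaults shown above):

# Example (not part of scripts/main.py): run the full pipeline with explicit paths.
from main import main
main(checkpoint_path="checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt",
     data_folder="data/")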
scripts/models.py
ADDED
@@ -0,0 +1,408 @@
import pandas as pd
import numpy as np
import math
import torch
import torch.nn as nn
import pytorch_lightning as pl
from scipy.sparse import csr_matrix
from sklearn.metrics.pairwise import cosine_similarity
import implicit
from utils import prepare_ground_truth as _prepare_ground_truth, calculate_metrics as _calculate_metrics

def recommend_popular_items_and_evaluate(train_df, test_df, k=10, prepare_ground_truth=None, calculate_metrics=None):
    """
    Trains a non-personalized Popularity model and evaluates its performance.

    This model recommends the top-k most frequently transacted items from the training
    set to every user. It serves as a simple but strong baseline.

    Args:
        train_df (pd.DataFrame): The training dataset.
        test_df (pd.DataFrame): The test dataset for evaluation.
        k (int): The number of items to recommend.
        prepare_ground_truth (function): A function to process the test_df into a ground truth dict.
        calculate_metrics (function): A function to compute ranking metrics.

    Returns:
        dict: A dictionary containing the calculated evaluation metrics (e.g., precision, recall).
    """
    # Fall back to the shared utils implementations when callers omit these;
    # the bare None defaults would otherwise shadow the module imports and crash.
    prepare_ground_truth = prepare_ground_truth or _prepare_ground_truth
    calculate_metrics = calculate_metrics or _calculate_metrics

    print(f"\n--- Evaluating Popularity Model (Top {k} items) ---")

    # 1. "Train" the model by finding the most popular items based on transactions
    purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()
    popular_items = purchase_counts.head(k).index.tolist()
    print(f"Top {k} popular items identified from training data.")

    # 2. Evaluate the model
    ground_truth = prepare_ground_truth(test_df)
    # Every user receives the same list of popular items
    recommendations = {user_id: popular_items for user_id in ground_truth.keys()}

    metrics = calculate_metrics(recommendations, ground_truth, k)
    print("Evaluation complete.")
    return metrics

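As a sanity check of the popularity logic, a toy DataFrame (hypothetical IDs, not part of the file) reproduces the `value_counts`-based ranking; only `'transaction'` rows count:

# Example (not part of scripts/models.py): popularity ranking on toy events.
import pandas as pd
toy = pd.DataFrame({
    "visitorid": [1, 1, 2, 2, 3],
    "itemid":    [10, 20, 20, 30, 20],
    "event":     ["transaction", "view", "transaction", "transaction", "transaction"],
})
top2 = toy[toy["event"] == "transaction"]["itemid"].value_counts().head(2).index.tolist()
print(top2)  # item 20 (2 transactions) comes first; the second slot is a tie between 10 and 30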
def recommend_item_item_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5, prepare_ground_truth=None, calculate_metrics=None):
    """
    Trains an Item-Item Collaborative Filtering model and evaluates its performance.

    This model recommends items that are similar to items a user has interacted
    with in the past, based on co-occurrence patterns in the training data.

    Args:
        train_df (pd.DataFrame): The training dataset.
        test_df (pd.DataFrame): The test dataset for evaluation.
        k (int): The number of items to recommend.
        min_item_interactions (int): Minimum number of interactions for an item to be kept.
        min_user_interactions (int): Minimum number of interactions for a user to be kept.
        prepare_ground_truth (function): A function to process the test_df into a ground truth dict.
        calculate_metrics (function): A function to compute ranking metrics.

    Returns:
        dict: A dictionary containing the calculated evaluation metrics.
    """
    # Default to the shared utils implementations (see note in the popularity baseline).
    prepare_ground_truth = prepare_ground_truth or _prepare_ground_truth
    calculate_metrics = calculate_metrics or _calculate_metrics

    print(f"\n--- Evaluating Item-Item CF Model (Top {k} items) ---")

    # 1. Filter out infrequent users and items to reduce noise and computation
    item_counts = train_df['itemid'].value_counts()
    user_counts = train_df['visitorid'].value_counts()
    items_to_keep = item_counts[item_counts >= min_item_interactions].index
    users_to_keep = user_counts[user_counts >= min_user_interactions].index
    filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()
    print(f"Filtered training data from {len(train_df)} to {len(filtered_df)} records.")

    # 2. Create user-item interaction matrix and vocabulary mappings
    user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}
    item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}
    inverse_item_map = {i: iid for iid, i in item_map.items()}
    user_indices = filtered_df['visitorid'].map(user_map)
    item_indices = filtered_df['itemid'].map(item_map)
    user_item_matrix = csr_matrix((np.ones(len(filtered_df)), (user_indices, item_indices)))

    # 3. Calculate the cosine similarity matrix between all items
    print("Calculating item similarity matrix...")
    item_similarity_matrix = cosine_similarity(user_item_matrix.T, dense_output=False)
    print("Similarity matrix calculated.")

    # 4. Generate recommendations and evaluate
    ground_truth = prepare_ground_truth(test_df)
    recommendations = {}
    print("Generating recommendations for users in test set...")
    test_users = [u for u in ground_truth.keys() if u in user_map]

    for user_id in test_users:
        user_index = user_map[user_id]
        user_interactions_indices = user_item_matrix[user_index].indices

        if len(user_interactions_indices) > 0:
            # Aggregate scores from items the user has interacted with
            all_scores = np.asarray(item_similarity_matrix[user_interactions_indices].sum(axis=0)).flatten()
            # Remove already interacted items from recommendations
            all_scores[user_interactions_indices] = -1
            top_indices = np.argsort(all_scores)[::-1][:k]
            recs = [inverse_item_map[idx] for idx in top_indices if idx in inverse_item_map]
            recommendations[user_id] = recs

    metrics = calculate_metrics(recommendations, ground_truth, k)
    print("Evaluation complete.")
    return metrics

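The scoring step above sums the similarity-matrix rows of the user's items and masks already-seen items; a self-contained toy sketch of that aggregation (invented interactions, not part of the file):

# Example (not part of scripts/models.py): item-item score aggregation.
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.metrics.pairwise import cosine_similarity

ui = csr_matrix(np.array([          # 3 users x 4 items
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [1, 0, 1, 1],
], dtype=np.float64))
sim = cosine_similarity(ui.T, dense_output=False)   # 4x4 item-item similarities

user_items = ui[0].indices                           # user 0 interacted with items 0 and 1
scores = np.asarray(sim[user_items].sum(axis=0)).flatten()
scores[user_items] = -1                              # never re-recommend seen items
print(np.argsort(scores)[::-1][:2])                  # [2 3]: unseen items ranked by summed similarity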
def recommend_als_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5,
                               factors=25, regularization=0.02, iterations=48, prepare_ground_truth=None, calculate_metrics=None):
    """
    Trains an Alternating Least Squares (ALS) model and evaluates its performance.

    This model uses matrix factorization to learn latent embeddings for users and
    items from implicit feedback data. Default hyperparameters are set from a
    previous Optuna tuning process.

    Args:
        train_df (pd.DataFrame): The training dataset.
        test_df (pd.DataFrame): The test dataset for evaluation.
        k (int): The number of items to recommend.
        min_item_interactions (int): Minimum number of interactions for an item to be kept.
        min_user_interactions (int): Minimum number of interactions for a user to be kept.
        factors (int): The number of latent factors to compute.
        regularization (float): The regularization factor.
        iterations (int): The number of ALS iterations to run.
        prepare_ground_truth (function): A function to process the test_df into a ground truth dict.
        calculate_metrics (function): A function to compute ranking metrics.

    Returns:
        dict: A dictionary containing the calculated evaluation metrics.
    """
    prepare_ground_truth = prepare_ground_truth or _prepare_ground_truth
    calculate_metrics = calculate_metrics or _calculate_metrics

    print(f"\n--- Evaluating ALS Model (Top {k} items) ---")

    # 1. Filter data
    item_counts = train_df['itemid'].value_counts()
    user_counts = train_df['visitorid'].value_counts()
    items_to_keep = item_counts[item_counts >= min_item_interactions].index
    users_to_keep = user_counts[user_counts >= min_user_interactions].index
    filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()
    print(f"Filtered training data from {len(train_df)} to {len(filtered_df)} records.")

    # 2. Create mappings and confidence matrix
    user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}
    item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}
    inverse_item_map = {i: iid for iid, i in item_map.items()}
    inverse_user_map = {i: uid for uid, i in user_map.items()}
    user_indices = filtered_df['visitorid'].map(user_map).astype(np.int32)
    item_indices = filtered_df['itemid'].map(item_map).astype(np.int32)

    event_weights = {'view': 1, 'addtocart': 3, 'transaction': 5}
    confidence = filtered_df['event'].map(event_weights).astype(np.float32)
    user_item_matrix = csr_matrix((confidence, (user_indices, item_indices)))

    # 3. Train the ALS model
    print("Training ALS model...")
    als_model = implicit.als.AlternatingLeastSquares(factors=factors, regularization=regularization, iterations=iterations)
    als_model.fit(user_item_matrix)
    print("ALS model trained.")

    # 4. Generate recommendations and evaluate
    ground_truth = prepare_ground_truth(test_df)
    recommendations = {}
    print("Generating recommendations for users in test set...")
    test_users_indices = [user_map[u] for u in ground_truth.keys() if u in user_map]

    if test_users_indices:
        user_item_matrix_for_recs = user_item_matrix[test_users_indices]
        ids, _ = als_model.recommend(test_users_indices, user_item_matrix_for_recs, N=k)

        for i, user_index in enumerate(test_users_indices):
            # O(1) reverse lookup instead of scanning user_map for every user
            original_user_id = inverse_user_map[user_index]
            recs = [inverse_item_map[item_idx] for item_idx in ids[i] if item_idx in inverse_item_map]
            recommendations[original_user_id] = recs

    metrics = calculate_metrics(recommendations, ground_truth, k)
    print("Evaluation complete.")
    return metrics

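The confidence-matrix construction can be checked in isolation; `csr_matrix` sums duplicate (user, item) pairs, so repeated events accumulate confidence. A toy sketch (invented IDs, not part of the file):

# Example (not part of scripts/models.py): building the confidence matrix.
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

df = pd.DataFrame({
    "visitorid": [7, 7, 9],
    "itemid":    [42, 42, 13],
    "event":     ["view", "transaction", "addtocart"],
})
event_weights = {"view": 1, "addtocart": 3, "transaction": 5}

user_map = {uid: i for i, uid in enumerate(df["visitorid"].unique())}
item_map = {iid: i for i, iid in enumerate(df["itemid"].unique())}
confidence = df["event"].map(event_weights).astype(np.float32).to_numpy()

m = csr_matrix((confidence,
                (df["visitorid"].map(user_map).to_numpy(),
                 df["itemid"].map(item_map).to_numpy())))
print(m.toarray())  # [[6. 0.], [0. 3.]] -- view (1) + transaction (5) sum to 6 for user 7, item 42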
class SASRec(pl.LightningModule):
    """
    A PyTorch Lightning implementation of the SASRec model for sequential recommendation.

    SASRec (Self-Attentive Sequential Recommendation) uses a Transformer-based
    architecture to capture the sequential patterns in a user's interaction history
    to predict the next item they are likely to interact with.

    Attributes:
        save_hyperparameters: Automatically saves all constructor arguments as hyperparameters.
        item_embedding (nn.Embedding): Embedding layer for item IDs.
        positional_embedding (nn.Embedding): Embedding layer to encode the position of items in a sequence.
        transformer_encoder (nn.TransformerEncoder): The core self-attention module.
        fc (nn.Linear): Final fully connected layer to produce logits over the item vocabulary.
        loss_fn (nn.CrossEntropyLoss): The loss function used for training.
    """
    def __init__(self, vocab_size, max_len, hidden_dim, num_heads, num_layers,
                 dropout=0.2, learning_rate=1e-3, weight_decay=1e-6, warmup_steps=2000, max_steps=100000):
        """
        Initializes the SASRec model layers and hyperparameters.

        Args:
            vocab_size (int): The total number of unique items in the dataset (+1 for padding).
            max_len (int): The maximum length of the input sequences.
            hidden_dim (int): The dimensionality of the item and positional embeddings.
            num_heads (int): The number of attention heads in the Transformer encoder.
            num_layers (int): The number of layers in the Transformer encoder.
            dropout (float): The dropout rate to be applied.
            learning_rate (float): The learning rate for the optimizer.
            weight_decay (float): The weight decay (L2 penalty) for the optimizer.
            warmup_steps (int): The number of linear warmup steps for the learning rate scheduler.
            max_steps (int): The total number of training steps for the learning rate scheduler's decay phase.
        """
        super().__init__()
        # This saves all hyperparameters to self.hparams, making them accessible later
        self.save_hyperparameters()

        # Embedding layers
        self.item_embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
        self.positional_embedding = nn.Embedding(max_len, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim * 4,
            dropout=dropout, batch_first=True, activation='gelu'
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output layer
        self.fc = nn.Linear(hidden_dim, vocab_size)

        # Loss function, ignoring the padding token
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)

        # Lists to store outputs from validation and test steps
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, x):
        """
        Defines the forward pass of the model.

        Args:
            x (torch.Tensor): A batch of input sequences of shape (batch_size, seq_len).

        Returns:
            torch.Tensor: The output logits of shape (batch_size, seq_len, vocab_size).
        """
        seq_len = x.size(1)
        # Create positional indices (0, 1, 2, ..., seq_len-1)
        positions = torch.arange(seq_len, device=self.device).unsqueeze(0)

        # Create a causal mask to ensure the model doesn't look ahead in the sequence
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=self.device)

        # Combine item and positional embeddings
        x = self.item_embedding(x) + self.positional_embedding(positions)
        x = self.dropout(x)

        # Pass through the Transformer encoder
        x = self.transformer_encoder(x, mask=causal_mask)

        # Get final logits
        logits = self.fc(x)
        return logits

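The causal mask used above is additive: `-inf` above the diagonal blocks attention to future positions, `0` elsewhere. A quick standalone check (not part of the file):

# Example (not part of scripts/models.py): the causal mask for seq_len=4.
import torch.nn as nn
print(nn.Transformer.generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])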
    def training_step(self, batch, batch_idx):
        """
        Performs a single training step.

        Args:
            batch (tuple): A tuple containing input sequences and target items.
            batch_idx (int): The index of the current batch.

        Returns:
            torch.Tensor: The calculated loss for the batch.
        """
        inputs, targets = batch
        logits = self.forward(inputs)

        # We only care about the prediction for the very last item in the input sequence
        last_logits = logits[:, -1, :]

        # Calculate loss against the single target item; squeeze only the item
        # dimension so a final batch of size 1 keeps its batch dimension
        loss = self.loss_fn(last_logits, targets.squeeze(1))

        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Performs a single validation step.
        Calculates loss and stores predictions for metric computation at the end of the epoch.
        """
        inputs, targets = batch
        logits = self.forward(inputs)
        last_item_logits = logits[:, -1, :]
        loss = self.loss_fn(last_item_logits, targets.squeeze(1))
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)

        # Get top-10 predictions for metric calculation
        top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices
        self.validation_step_outputs.append({'preds': top_k_preds, 'targets': targets})
        return loss

    def on_validation_epoch_end(self):
        """
        Calculates and logs ranking metrics at the end of the validation epoch.
        """
        if not self.validation_step_outputs:
            return

        # Concatenate all predictions and targets from the epoch
        preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0)
        targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0)

        k = preds.size(1)
        # Check if the target is in the top-k predictions for each example
        hits_tensor = (preds == targets).any(dim=1)
        num_hits = hits_tensor.sum().item()
        num_targets = len(targets)

        if num_targets > 0:
            hit_rate = num_hits / num_targets
            recall = hit_rate  # For next-item prediction, recall@k is the same as hit_rate@k
            precision = num_hits / (k * num_targets)
        else:
            hit_rate, recall, precision = 0.0, 0.0, 0.0

        self.log('val_hitrate@10', hit_rate, prog_bar=True)
        self.log('val_precision@10', precision, prog_bar=True)
        self.log('val_recall@10', recall, prog_bar=True)

        self.validation_step_outputs.clear()  # Free up memory

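The hit-rate computation relies on broadcasting the `(N, 1)` targets against the `(N, k)` predictions; a small standalone check (not part of the file):

# Example (not part of scripts/models.py): hit detection via broadcasting.
import torch
preds = torch.tensor([[4, 9, 2], [7, 1, 3]])   # (N, k) top-k item indices
targets = torch.tensor([[9], [5]])             # (N, 1) true next items
hits = (preds == targets).any(dim=1)           # each target compared to its own row
print(hits)                 # tensor([ True, False])
print(hits.float().mean())  # hit rate = 0.5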
    def test_step(self, batch, batch_idx):
        """
        Performs a single test step.
        Mirrors the logic of the validation_step.
        """
        inputs, targets = batch
        logits = self.forward(inputs)
        last_item_logits = logits[:, -1, :]
        loss = self.loss_fn(last_item_logits, targets.squeeze(1))
        self.log('test_loss', loss, prog_bar=True)

        top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices
        self.test_step_outputs.append({'preds': top_k_preds, 'targets': targets})
        return loss

    def on_test_epoch_end(self):
        """
        Calculates and logs ranking metrics at the end of the test epoch.
        """
        if not self.test_step_outputs:
            return

        preds = torch.cat([x['preds'] for x in self.test_step_outputs], dim=0)
        targets = torch.cat([x['targets'] for x in self.test_step_outputs], dim=0)

        k = preds.size(1)
        hits_tensor = (preds == targets).any(dim=1)
        num_hits = hits_tensor.sum().item()
        num_targets = len(targets)

        if num_targets > 0:
            hit_rate = num_hits / num_targets
            recall = hit_rate
            precision = num_hits / (k * num_targets)
        else:
            hit_rate, recall, precision = 0.0, 0.0, 0.0

        self.log('test_hitrate@10', hit_rate, prog_bar=True)
        self.log('test_precision@10', precision, prog_bar=True)
        self.log('test_recall@10', recall, prog_bar=True)

        self.test_step_outputs.clear()  # Free up memory

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Uses the AdamW optimizer with a linear warmup followed by cosine decay,
        which is standard practice for training Transformer models.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

        # Learning rate scheduler: linear warmup and cosine decay
        def lr_lambda(current_step: int):
            warmup_steps = self.hparams.warmup_steps
            max_steps = self.hparams.max_steps
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # Update the scheduler at every training step
                "frequency": 1
            }
        }
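A standalone evaluation of the same `lr_lambda` with the default `warmup_steps=2000` and `max_steps=100000` shows the multiplier rising linearly to 1.0 and then decaying to 0 along a cosine (not part of the file):

# Example (not part of scripts/models.py): the warmup + cosine multiplier.
import math

def lr_lambda(step, warmup_steps=2000, max_steps=100000):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

for s in (0, 1000, 2000, 51000, 100000):
    print(s, round(lr_lambda(s), 3))
# 0 0.0 / 1000 0.5 / 2000 1.0 / 51000 0.5 / 100000 0.0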
scripts/train_and_eval.py
ADDED
@@ -0,0 +1,172 @@
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from utils import prepare_ground_truth, calculate_metrics
from data_prepare import prepare_data, SASRecDataset, SASRecDataModule
from models import recommend_popular_items_and_evaluate, recommend_item_item_and_evaluate, recommend_als_and_evaluate, SASRec

def train_and_eval_SASRec_model(train_set, validation_set, test_set, checkpoint_dir_path='checkpoints/',
                                checkpoint_path=None, n_epochs=10, mode='train',
                                batchsize=256, max_token_len=50, learning_rate=1e-3, hidden_dim=128,
                                num_heads=2, num_layers=2, dropout=0.2, weight_decay=1e-6):
    """
    Train or evaluate a SASRec sequential recommendation model using PyTorch Lightning.

    This function wraps the entire SASRec pipeline:
    - Initializes the SASRecDataModule (handles dataset preprocessing and dataloaders).
    - Builds the SASRec Transformer-based model.
    - Configures training callbacks (checkpointing, early stopping, LR monitoring).
    - Runs either training (`mode='train'`) or evaluation on the test set (`mode='test'`).

    Parameters
    ----------
    train_set : pd.DataFrame
        Training interactions dataset.
    validation_set : pd.DataFrame
        Validation dataset with the same structure as `train_set`.
    test_set : pd.DataFrame
        Test dataset with the same structure as `train_set`.
    checkpoint_dir_path : str, optional (default='checkpoints/')
        Directory to save model checkpoints.
    checkpoint_path : str or None, optional (default=None)
        Path to a checkpoint file for resuming training or loading a pretrained model for testing.
    n_epochs : int, optional (default=10)
        Number of training epochs.
    mode : {'train', 'test'}, optional (default='train')
        - 'train': trains the model on the training/validation data.
        - 'test': evaluates the model on the test set using a checkpoint.
    batchsize : int, optional (default=256)
        Batch size for training and evaluation.
    max_token_len : int, optional (default=50)
        Maximum sequence length per user (recent interactions kept).
    learning_rate : float, optional (default=1e-3)
        Learning rate for the AdamW optimizer.
    hidden_dim : int, optional (default=128)
        Dimensionality of item and positional embeddings.
    num_heads : int, optional (default=2)
        Number of attention heads in each Transformer encoder layer.
    num_layers : int, optional (default=2)
        Number of Transformer encoder layers.
    dropout : float, optional (default=0.2)
        Dropout probability applied in embeddings and Transformer layers.
    weight_decay : float, optional (default=1e-6)
        Weight decay regularization coefficient for AdamW.
    """
    # --- 1. Initialize DataModule ---
    print("Initializing DataModule...")
    datamodule = SASRecDataModule(
        train_df=train_set,
        val_df=validation_set,
        test_df=test_set,
        batch_size=batchsize,
        max_len=max_token_len
    )
    datamodule.setup()

    # --- 2. Initialize Model ---
    print("Initializing SASRec model...")
    model = SASRec(
        vocab_size=datamodule.vocab_size,
        max_len=max_token_len,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        dropout=dropout,
        learning_rate=learning_rate,
        weight_decay=weight_decay
    )

    # --- 3. Configure Training Callbacks ---
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir_path,
        filename="sasrec-{epoch:02d}-{val_hitrate@10:.4f}",
        save_top_k=1,
        verbose=True,
        monitor="val_hitrate@10",
        mode="max"
    )

    early_stopping_callback = EarlyStopping(
        monitor="val_hitrate@10",  # stop if the ranking metric stagnates
        patience=5,
        mode="max"
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    logger = TensorBoardLogger("lightning_logs", name="sasrec")

    # --- 4. Initialize Trainer ---
    print("Initializing PyTorch Lightning Trainer...")
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
        max_epochs=n_epochs,
        accelerator='auto',
        devices=1,
        gradient_clip_val=1.0,  # helps with exploding gradients
    )

    if mode == 'train':
        # --- 5. Start Training ---
        print(f"Starting training for up to {n_epochs} epochs...")
        trainer.fit(model, datamodule, ckpt_path=checkpoint_path)

    elif mode == 'test':
        # --- 6. Test on best checkpoint ---
        print("Evaluating on test set...")
        trainer.test(model, datamodule, ckpt_path=checkpoint_path)

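For example, to evaluate the committed checkpoint on the test split without retraining (a sketch assuming the data folder layout used elsewhere in this repo, not part of the file):

# Example (not part of scripts/train_and_eval.py): test-only run on the shipped checkpoint.
train_set, validation_set, test_set = prepare_data(data_folder='data/')
train_and_eval_SASRec_model(
    train_set, validation_set, test_set,
    mode='test',
    checkpoint_path='checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt',
)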
# --- Main Execution Block ---
if __name__ == "__main__":

    # --- Configuration ---
    BATCH_SIZE = 256
    MAX_TOKEN_LEN = 50  # 50–100 is standard
    LEARNING_RATE = 1e-3
    HIDDEN_DIM = 128
    NUM_HEADS = 2
    NUM_LAYERS = 2
    DROPOUT = 0.2
    WEIGHT_DECAY = 1e-6
    N_EPOCHS = 50
    CHECKPOINT_SAVE_PATH = 'checkpoints/'
    CHECKPOINT_LOAD_PATH = None  # or specify a path to a checkpoint file
    MODE = 'train'  # 'train' or 'test'

    train_set, validation_set, test_set = prepare_data(data_folder='data/')
    if train_set is not None:
        results = {}
        full_train_set = pd.concat([train_set, validation_set])

        # Evaluate classical models
        print("\n>>> Running evaluations on the VALIDATION set <<<")
        results['Popularity (Validation)'] = recommend_popular_items_and_evaluate(train_set, validation_set)
        results['Item-Item CF (Validation)'] = recommend_item_item_and_evaluate(train_set, validation_set)
        results['ALS (Validation)'] = recommend_als_and_evaluate(train_set, validation_set)

        print("\n>>> Running final evaluations on the TEST set <<<")
        results['Popularity (Test)'] = recommend_popular_items_and_evaluate(full_train_set, test_set)
        results['Item-Item CF (Test)'] = recommend_item_item_and_evaluate(full_train_set, test_set)
        results['ALS (Test)'] = recommend_als_and_evaluate(full_train_set, test_set)

        print("\n--- Final Evaluation Results ---")
        results_df = pd.DataFrame.from_dict(results, orient='index')
        print(results_df)
        print("--------------------------------")

        # Train and evaluate SASRec model
        print("\n>>> Training and evaluating SASRec model <<<")
        train_and_eval_SASRec_model(train_set, validation_set, test_set, n_epochs=10, mode='train')

        print("\n>>> Evaluating trained SASRec model on TEST set <<<")
        train_and_eval_SASRec_model(train_set, validation_set, test_set, mode='test')
scripts/utils.py
ADDED
@@ -0,0 +1,196 @@
import pandas as pd
import numpy as np
import torch
import datetime

def calculate_metrics(recommendations_dict, ground_truth_dict, k):
    """
    Calculates Precision@k, Recall@k, and HitRate@k.

    Parameters
    ----------
    recommendations_dict : {user_id: [recommended_item_ids]}
    ground_truth_dict : {user_id: set of ground truth item_ids}
    k : int

    Returns
    -------
    dict with mean precision, recall, and hit rate
    """
    all_precisions, all_recalls, all_hits = [], [], []

    for user_id, true_items in ground_truth_dict.items():
        recs = recommendations_dict.get(user_id, [])[:k]
        if not true_items:
            continue
        hits = len(set(recs) & true_items)

        precision = hits / k if k > 0 else 0
        recall = hits / len(true_items)
        hit_rate = 1.0 if hits > 0 else 0.0

        all_precisions.append(precision)
        all_recalls.append(recall)
        all_hits.append(hit_rate)

    if not all_precisions:
        return {"mean_precision@k": 0, "mean_recall@k": 0, "mean_hitrate@k": 0}

    return {
        "mean_precision@k": np.mean(all_precisions),
        "mean_recall@k": np.mean(all_recalls),
        "mean_hitrate@k": np.mean(all_hits)
    }

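A toy check of these definitions (hypothetical user and item IDs, not part of the file): one user with a hit at rank 2 and one with none gives a mean hit rate of 0.5:

# Example (not part of scripts/utils.py): two users, k = 3.
recs = {1: [10, 20, 30], 2: [40, 50, 60]}
truth = {1: {20}, 2: {70}}
print(calculate_metrics(recs, truth, k=3))
# {'mean_precision@k': 0.1666..., 'mean_recall@k': 0.5, 'mean_hitrate@k': 0.5}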
def prepare_ground_truth(df, mode="purchase", event_weights=None):
    """
    Prepares ground truth dictionaries for evaluation.

    Parameters
    ----------
    df : pd.DataFrame
        Test dataframe containing at least ['visitorid', 'itemid', 'event'].
    mode : str, default="purchase"
        - "purchase" : Only use transactions as ground truth.
        - "all" : Use all events. Optionally weight them.
    event_weights : dict, optional
        Example: {"view": 1, "addtocart": 3, "transaction": 5}.
        Used only if mode == "all".

    Returns
    -------
    dict : {user_id: set of item_ids}
    """
    if mode == "purchase":
        df_filtered = df[df["event"] == "transaction"]
        ground_truth = df_filtered.groupby("visitorid")["itemid"].apply(set).to_dict()

    elif mode == "all":
        if event_weights is None:
            # Default: treat all events equally
            ground_truth = df.groupby("visitorid")["itemid"].apply(set).to_dict()
        else:
            # Weighted ground truth (for more advanced eval). Note: the final
            # set() collapses the duplicated items again, so with the current
            # set-based metrics the weights do not change the result; they are
            # kept as a hook for multiset-aware metrics.
            ground_truth = {}
            for uid, user_df in df.groupby("visitorid"):
                weighted_items = []
                for _, row in user_df.iterrows():
                    weight = event_weights.get(row["event"], 1)
                    weighted_items.extend([row["itemid"]] * weight)
                ground_truth[uid] = set(weighted_items)
    else:
        raise ValueError("mode must be 'purchase' or 'all'")

    return ground_truth

def load_item_properties(data_folder='data/'):
    """
    Loads item properties and creates a mapping from item ID to its category ID.
    Handles both a single properties file or two split parts.

    Args:
        data_folder (str): The path to the folder containing item property files.

    Returns:
        dict: A dictionary mapping {itemid: categoryid}.
    """
    print("Loading item properties...")
    try:
        # First, try to load the two separate parts and combine them.
        props_df_part1 = pd.read_csv(data_folder + 'item_properties_part1.csv')
        props_df_part2 = pd.read_csv(data_folder + 'item_properties_part2.csv')
        props_df = pd.concat([props_df_part1, props_df_part2], ignore_index=True)
        print("Successfully loaded and combined item_properties_part1.csv and item_properties_part2.csv.")

    except FileNotFoundError:
        try:
            # If the parts are not found, try to load a single combined file.
            props_df = pd.read_csv(data_folder + 'item_properties.csv')
            print("Successfully loaded a single item_properties.csv.")
        except FileNotFoundError:
            print("Warning: No item properties files found. Cannot display category information.")
            return {}

    category_df = props_df[props_df['property'] == 'categoryid'].copy()
    category_df['value'] = pd.to_numeric(category_df['value'], errors='coerce').astype('Int64')
    item_to_category_map = category_df.set_index('itemid')['value'].to_dict()
    print("Item to category mapping created successfully.")
    return item_to_category_map

def load_category_tree(data_folder='data/'):
    """
    Loads the category tree to map categories to their parent categories.

    Args:
        data_folder (str): The path to the folder containing category_tree.csv.

    Returns:
        dict: A dictionary mapping {categoryid: parentid}.
    """
    print("Loading category tree...")
    try:
        tree_df = pd.read_csv(data_folder + 'category_tree.csv')
        category_parent_map = tree_df.set_index('categoryid')['parentid'].to_dict()
        print("Category tree loaded successfully.")
        return category_parent_map
    except FileNotFoundError:
        print("Warning: 'category_tree.csv' not found. Cannot display parent category information.")
        return {}

def get_popular_items(train_df, k=10):
    """
    Calculates the top-k most popular items based on transaction count.
    """
    purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()
    return purchase_counts.head(k).index.tolist()

def show_user_recommendations(visitor_id, model, datamodule, popular_items, item_category_map, category_parent_map, k=10):
    """
    Displays recommendations for a user, including category and parent category information.
    """
    print(f"\n--- Recommendations for Visitor ID: {visitor_id} ---")
    model.eval()

    def format_item_with_category(item_id):
        category_id = item_category_map.get(item_id, 'N/A')
        parent_id = category_parent_map.get(category_id, 'N/A') if category_id != 'N/A' else 'N/A'
        return f"Item: {item_id} (Category: {category_id}, Parent: {parent_id})"

    user_history_ids = datamodule.user_history.get(visitor_id)

    if user_history_ids is None:
        print(f"User {visitor_id} not found in training history. Providing popularity-based recommendations.")
        print(f"\nTop {k} Popular Items (Fallback):")
        recs_with_cats = [format_item_with_category(item_id) for item_id in popular_items]
        print(recs_with_cats)
        print("-------------------------------------------------")
        return

    history_with_cats = [format_item_with_category(item_id) for item_id in user_history_ids]
    print("User's Historical Interactions:")
    print(history_with_cats)

    history_indices = [datamodule.item_map[i] for i in user_history_ids if i in datamodule.item_map]
    if not history_indices:
        print("None of the user's historical items are in the model's vocabulary.")
        return

    max_len = datamodule.max_len
    input_seq = history_indices[-max_len:]
    padded_input = np.zeros(max_len, dtype=np.int64)
    padded_input[-len(input_seq):] = input_seq

    input_tensor = torch.LongTensor(np.array([padded_input]))
    input_tensor = input_tensor.to(model.device)

    with torch.no_grad():
        logits = model(input_tensor)
        last_item_logits = logits[0, -1, :]
        top_indices = torch.topk(last_item_logits, k).indices.tolist()

    recommended_item_ids = [datamodule.inverse_item_map[idx] for idx in top_indices if idx in datamodule.inverse_item_map]

    print(f"\nTop {k} Recommended Items:")
    recs_with_cats = [format_item_with_category(item_id) for item_id in recommended_item_ids]
    print(recs_with_cats)
    print("-------------------------------------------------")