DanielKiani committed on
Commit
38ae75d
·
0 Parent(s):

Initial commit of recommender system project

.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
2
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+
5
+ # Virtual Environment
6
+ venv/
7
+ .venv/
8
+
9
+ # Data and Logs
10
+ data/
11
+ logs/
12
+ notebooks/data/
13
+ notebooks/logs/
14
+
15
+ # IDE files
16
+ .vscode/
17
+ .idea/
18
+
README.md ADDED
@@ -0,0 +1,182 @@
1
+ ![Recommender system banner](assets/banner.png)
2
+ [![Python](https://img.shields.io/badge/Python-3.10-blue?logo=python)](https://www.python.org/) [![PyTorch](https://img.shields.io/badge/PyTorch-2.7.1-EE4C2C?logo=pytorch)](https://pytorch.org/) ![Made with ML](https://img.shields.io/badge/Made%20with-ML-blueviolet?logo=openai) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
3
+
4
+ # 🚀 End-to-End Sequential Recommender System
5
+
6
+ This project implements and evaluates a series of recommender system models, culminating in a state-of-the-art **SASRec (Self-Attentive Sequential Recommendation)** model for Top-N next-item prediction. The system is trained on the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset) and includes an interactive web demo built with Gradio.
7
+
8
+ ![Gradio app](assets/gradio.png)
9
+ You can try the Gradio app [here](https://www.kaggle.com/datasets/kritanjalijain/amazon-reviews).
10
+
11
+ ---
12
+
13
+ ## 📑 Table of Contents
14
+
15
+ - [📖 Project Overview](#-project-overview)
16
+ - [✨ Key Features](#-key-features)
17
+ - [🧩 Models Implemented](#-models-implemented)
18
+ - [📊 Final Results](#-final-results)
19
+ - [🔍 Qualitative Analysis](#-qualitative-analysis)
20
+ - [🚧 Future Improvements](#-future-improvements)
21
+ - [📂 Project Structure](#-project-structure)
22
+ - [⚙️ Setup and Usage](#️-setup-and-usage)
23
+ - [🛠️ Technologies and Models Used](#️-technologies-and-models-used)
24
+
25
+ ---
26
+
27
+ ## 📖 Project Overview
28
+
29
+ The primary goal of this project is to predict the next item a user is likely to interact with based on their recent session history. This is a common and critical task in e-commerce known as Top-N sequential recommendation.
30
+
31
+ The project follows a structured approach:
32
+
33
+ 1. **Baseline Models**: Simple, non-sequential models to establish a performance baseline.
34
+ 2. **Hyperparameter Tuning**: Optuna is used to find the optimal configuration for ALS.
35
+ 3. **Advanced Sequential Model**: Implementation of **SASRec** with PyTorch Lightning.
36
+ 4. **Evaluation**: Offline evaluation using ranking metrics (Hit Rate, Precision, Recall @ 10).
37
+ 5. **Interactive Demo**: A Gradio web app for real-time personalized and cold-start recommendations.
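The offline evaluation step (Hit Rate, Precision, and Recall @ 10) can be sketched in plain Python. This is a simplified, self-contained version of the metric computation; the function name and signature here are illustrative, not the repo's actual API:

```python
def ranking_metrics(recommendations, ground_truth, k=10):
    """Mean Precision@k, Recall@k, and HitRate@k over users with ground truth.

    recommendations: {user_id: ranked list of item ids}
    ground_truth:    {user_id: set of held-out item ids}
    """
    precisions, recalls, hits = [], [], []
    for user, true_items in ground_truth.items():
        if not true_items:
            continue  # skip users with no held-out interactions
        recs = recommendations.get(user, [])[:k]
        n_hits = len(set(recs) & true_items)
        precisions.append(n_hits / k)
        recalls.append(n_hits / len(true_items))
        hits.append(1.0 if n_hits else 0.0)
    n = len(precisions)
    if n == 0:
        return {"precision@k": 0.0, "recall@k": 0.0, "hitrate@k": 0.0}
    return {
        "precision@k": sum(precisions) / n,
        "recall@k": sum(recalls) / n,
        "hitrate@k": sum(hits) / n,
    }
```

Each metric is averaged per-user: a "hit" means at least one held-out item appears in the top-k list.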
38
+
39
+ ---
40
+
41
+ ## ✨ Key Features
42
+
43
+ - 🔹 **Comprehensive Model Comparison**: From popularity to Transformer-based SASRec.
44
+ - 🔹 **Robust Evaluation**: Time-based data split for realistic performance measurement.
45
+ - 🔹 **Hyperparameter Optimization**: Automated with Optuna for ALS.
46
+ - 🔹 **Deep Learning with Attention**: Full PyTorch Lightning implementation of SASRec.
47
+ - 🔹 **Interactive Web Demo**: Live Gradio app for recommendations.
48
+ - 🔹 **Modular Codebase**: Clean, organized structure.
49
+
50
+ ---
51
+
52
+ ## 🧩 Models Implemented
53
+
54
+ | Model | Methodology | Key Characteristics |
55
+ | :--- | :--- | :--- |
56
+ | **Popularity** | Non-personalized | Recommends the most frequently purchased items across all users. |
57
+ | **Item-Item CF** | Collaborative Filtering | Recommends items similar to a user’s past interactions. |
58
+ | **ALS** | Matrix Factorization | Learns latent embeddings from implicit feedback, tuned with Optuna. |
59
+ | **SASRec** | Transformer (Self-Attention) | Sequential model capturing contextual user-item interactions. |
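The popularity baseline from the table is simple enough to sketch in a few lines; the function below is a hypothetical minimal version, not the repo's `models.py` implementation:

```python
from collections import Counter

def popularity_recommender(train_events, k=10):
    """Non-personalized baseline: count item interactions in the training
    data and recommend the same top-k list to every user."""
    counts = Counter(item for _, _, item in train_events)
    top_k = [item for item, _ in counts.most_common(k)]
    return lambda user_id: top_k  # same list regardless of the user
```

Despite ignoring the user entirely, this baseline beat the personalized Item-Item CF and ALS models on Hit Rate@10 in the results above, which is common on sparse implicit-feedback data.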
60
+
61
+ ---
62
+
63
+ ## 📊 Final Results
64
+
65
+ SASRec significantly outperformed all baselines, with a **~4.7x improvement in Hit Rate**.
66
+
67
+ | Model | Test Hit Rate@10 | Test Precision@10 | Test Recall@10 |
68
+ | :--- | :---: | :---: | :---: |
69
+ | Popularity | 0.0651 | 0.0065 | 0.0324 |
70
+ | Item-Item CF | 0.0021 | 0.0002 | 0.0011 |
71
+ | Tuned ALS | 0.0063 | 0.0006 | 0.0042 |
72
+ | **SASRec** | **0.3069** | **0.0307** | **0.3069** |
73
+
74
+ ---
75
+
76
+ ## 🔍 Qualitative Analysis
77
+
78
+ The SASRec model not only recommends previously viewed items but also discovers **new, contextually relevant items**.
79
+ For example, for a user browsing **Category 1279**, SASRec suggested new items from the same category — showing strong personalization and discovery.
80
+
81
+ ---
82
+
83
+ ## 🚧 Future Improvements
84
+
85
+ - 📦 **Incorporate Item Features** (e.g., from `item_properties.csv`).
86
+ - 🤖 **Explore Advanced Models**:
87
+ - BERT4Rec (bidirectional Transformers).
88
+ - Graph-based recommender systems.
89
+ - 🧪 **Online A/B Testing** for business impact.
90
+ - ⚡ **Scalability Enhancements**: Feature stores, inference servers (Triton), quantization, distillation.
91
+
92
+ ---
93
+
94
+ ## 📂 Project Structure
95
+
96
+ ```bash
97
+ ├── checkpoints/ # Saved PyTorch Lightning checkpoints
98
+ ├── data/ # RetailRocket dataset
99
+ ├── notebooks/ # EDA notebooks
100
+ ├── scripts/
101
+ │   ├── als_optuna_study.py   # Optuna tuning for ALS
102
+ │   ├── app.py                # Gradio web demo
103
+ │   ├── data_prepare.py       # Data loading & preprocessing
104
+ │   ├── main.py               # Entry point for demo
105
+ │   ├── models.py             # Model definitions
106
+ │   ├── train_and_eval.py     # Training & evaluation loop
107
+ │   └── utils.py              # Helper functions
108
+ ├── README.md
109
+ └── requirements.txt
110
+ ```
111
+
112
+ ---
113
+
114
+ ## ⚙️ Setup and Usage
115
+
116
+ Follow these steps to set up and run the project locally.
117
+
118
+ ### 1. Prerequisites
119
+
120
+ - Python 3.10.6+
121
+ - An NVIDIA GPU is recommended for training the SASRec model.
122
+
123
+ ### 2. Clone the Repository
124
+
125
+ ```bash
126
+ git clone <your-repo-url>
127
+ cd <your-repo-name>
128
+ ```
129
+
130
+ ### 3. Install Dependencies
131
+
132
+ ```bash
133
+ pip install -r requirements.txt
134
+ ```
135
+
136
+ ### 4. Download and Place Data
137
+
138
+ - Download the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset).
139
+
140
+ Place the extracted CSV files in the `data/` folder, then run the preprocessing script:
141
+
142
+ ```bash
143
+ python data_prepare.py
144
+ ```
145
+
146
+ ### 5. Run the Full Evaluation
147
+
148
+ To train all models and see the final comparison table, run the training and evaluation script:
149
+
150
+ ```bash
151
+ python train_and_eval.py
152
+ ```
153
+
154
+ ### 6. Launch the Gradio Demo
155
+
156
+ ```bash
157
+ python main.py
158
+ ```
159
+
160
+ ---
161
+
162
+ ## 🛠️ Technologies and Models Used
163
+
164
+ This project leverages a range of modern data science and machine learning technologies to build a robust recommender system from the ground up.
165
+
166
+ ### 🏭 Models
167
+
168
+ - **Popularity Model**: A non-personalized baseline that recommends the most frequently purchased items.
169
+ - **Item-Item Collaborative Filtering**: A classical neighborhood-based model that recommends items based on co-occurrence patterns with a user's interaction history.
170
+ - **Alternating Least Squares (ALS)**: A powerful matrix factorization technique for implicit feedback, optimized with hyperparameter tuning.
171
+ - **SASRec (Self-Attentive Sequential Recommendation)**: A state-of-the-art sequential model based on the Transformer architecture, designed to capture the order and context of user interactions.
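Before a user history reaches the SASRec model it is turned into a fixed-length, left-padded input with the next item as target. A simplified sketch of that step (mirroring the dataset logic in the notebook; the function name is illustrative):

```python
def make_training_pair(sequence, cutoff, max_len=50, pad=0):
    """Build one SASRec-style training example: the items before `cutoff`
    (left-padded or truncated to `max_len`) form the input, and the item
    at `cutoff` is the next-item prediction target."""
    history = sequence[:cutoff][-max_len:]          # keep only the most recent items
    input_seq = [pad] * (max_len - len(history)) + history  # left-pad with token 0
    return input_seq, sequence[cutoff]
```

Index 0 is reserved as the padding token, which is why the item vocabulary in the notebook maps real items to indices starting at 1.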
172
+
173
+ ### 👩‍💻 Core Technologies & Libraries
174
+
175
+ - **Python 3.10**: The primary programming language for the project.
176
+ - **Pandas & NumPy**: For efficient data manipulation, preprocessing, and numerical operations.
177
+ - **Scikit-learn**: Used for calculating item similarity in the collaborative filtering model.
178
+ - **Implicit**: Provides the Alternating Least Squares (ALS) matrix factorization implementation for implicit-feedback data.
179
+ - **PyTorch & PyTorch Lightning**: The deep learning framework used to build, train, and evaluate the SASRec model in a structured and scalable way.
180
+ - **Optuna**: A hyperparameter optimization framework used to automatically find the best parameters for the ALS model.
181
+ - **Gradio**: A fast and simple framework used to build and deploy the interactive web demo.
182
+ - **TensorBoard**: For logging and visualizing model training metrics.
assets/banner.png ADDED

Git LFS Details

  • SHA256: 5d374ef2d90cdf1206fb3adde8ad7d0e355dfcd1f9f73c803162658041f4d480
  • Pointer size: 132 Bytes
  • Size of remote file: 2.04 MB
checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af98f72c0c5f325aa0393753db2a7e5865336423b222ced1aeb7f164009d0e06
3
+ size 204423643
notebooks/lightning_logs/sasrec/version_0/events.out.tfevents.1757789891.DESKTOP-48K6QDS.24476.0 ADDED
Binary file (657 Bytes). View file
 
notebooks/lightning_logs/sasrec/version_0/hparams.yaml ADDED
@@ -0,0 +1,7 @@
1
+ vocab_size: 64705
2
+ max_len: 256
3
+ hidden_dim: 128
4
+ num_heads: 4
5
+ num_layers: 2
6
+ dropout: 0.1
7
+ learning_rate: 2.0e-05
notebooks/lightning_logs/sasrec/version_1/events.out.tfevents.1757790260.DESKTOP-48K6QDS.24476.1 ADDED
Binary file (657 Bytes). View file
 
notebooks/lightning_logs/sasrec/version_1/hparams.yaml ADDED
@@ -0,0 +1,7 @@
1
+ vocab_size: 64705
2
+ max_len: 256
3
+ hidden_dim: 128
4
+ num_heads: 4
5
+ num_layers: 2
6
+ dropout: 0.1
7
+ learning_rate: 2.0e-05
notebooks/lightning_logs/sasrec/version_2/events.out.tfevents.1757868923.DESKTOP-48K6QDS.22820.0 ADDED
Binary file (657 Bytes). View file
 
notebooks/lightning_logs/sasrec/version_2/hparams.yaml ADDED
@@ -0,0 +1,7 @@
1
+ vocab_size: 64705
2
+ max_len: 50
3
+ hidden_dim: 128
4
+ num_heads: 2
5
+ num_layers: 2
6
+ dropout: 0.2
7
+ learning_rate: 0.001
notebooks/lightning_logs/sasrec/version_3/events.out.tfevents.1757868963.DESKTOP-48K6QDS.22820.1 ADDED
Binary file (657 Bytes). View file
 
notebooks/lightning_logs/sasrec/version_3/hparams.yaml ADDED
@@ -0,0 +1,7 @@
1
+ vocab_size: 64705
2
+ max_len: 50
3
+ hidden_dim: 128
4
+ num_heads: 2
5
+ num_layers: 2
6
+ dropout: 0.2
7
+ learning_rate: 0.001
notebooks/reccomender.ipynb ADDED
@@ -0,0 +1,1419 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "ae798d43",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 🚀 End-to-End Sequential Recommender System \n",
9
+ "\n",
10
+ "This project implements and evaluates a series of recommender system models, culminating in a state-of-the-art **SASRec (Self-Attentive Sequential Recommendation)** model for Top-N next-item prediction. The system is trained on the [RetailRocket e-commerce dataset](https://www.kaggle.com/datasets/retailrocket/ecommerce-dataset) and includes an interactive web demo built with Gradio. "
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "338759e6",
16
+ "metadata": {},
17
+ "source": [
18
+ "## EDA"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "dcc5a23b",
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "Loading events.csv...\n",
32
+ "Data Head:\n",
33
+ " timestamp visitorid event itemid transactionid\n",
34
+ "0 1433221332117 257597 view 355908 NaN\n",
35
+ "1 1433224214164 992329 view 248676 NaN\n",
36
+ "2 1433221999827 111016 view 318965 NaN\n",
37
+ "3 1433221955914 483717 view 253185 NaN\n",
38
+ "4 1433221337106 951259 view 367447 NaN\n",
39
+ "\n",
40
+ "Data Info:\n",
41
+ "<class 'pandas.core.frame.DataFrame'>\n",
42
+ "RangeIndex: 2756101 entries, 0 to 2756100\n",
43
+ "Data columns (total 5 columns):\n",
44
+ " # Column Dtype \n",
45
+ "--- ------ ----- \n",
46
+ " 0 timestamp int64 \n",
47
+ " 1 visitorid int64 \n",
48
+ " 2 event object \n",
49
+ " 3 itemid int64 \n",
50
+ " 4 transactionid float64\n",
51
+ "dtypes: float64(1), int64(3), object(1)\n",
52
+ "memory usage: 105.1+ MB\n",
53
+ "\n",
54
+ "Missing Values:\n",
55
+ "timestamp 0\n",
56
+ "visitorid 0\n",
57
+ "event 0\n",
58
+ "itemid 0\n",
59
+ "transactionid 2733644\n",
60
+ "dtype: int64\n"
61
+ ]
62
+ }
63
+ ],
64
+ "source": [
65
+ "import pandas as pd\n",
66
+ "import time\n",
67
+ "from datetime import datetime\n",
68
+ "\n",
69
+ "# Define the path to your data folder\n",
70
+ "DATA_FOLDER = 'data/'\n",
71
+ "\n",
72
+ "# Load the events data\n",
73
+ "print(\"Loading events.csv...\")\n",
74
+ "events_df = pd.read_csv(DATA_FOLDER + 'events.csv')\n",
75
+ "\n",
76
+ "# --- Initial Inspection ---\n",
77
+ "\n",
78
+ "# See the first few rows\n",
79
+ "print(\"Data Head:\")\n",
80
+ "print(events_df.head())\n",
81
+ "\n",
82
+ "# Get a summary of the dataframe (columns, data types, memory usage)\n",
83
+ "print(\"\\nData Info:\")\n",
84
+ "events_df.info()\n",
85
+ "\n",
86
+ "# Check for any missing values\n",
87
+ "print(\"\\nMissing Values:\")\n",
88
+ "print(events_df.isnull().sum())"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 2,
94
+ "id": "dd89bf40",
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "name": "stdout",
99
+ "output_type": "stream",
100
+ "text": [
101
+ "\n",
102
+ "Data timeframe is from 2015-05-03 03:00:04.384000 to 2015-09-18 02:59:47.788000\n",
103
+ "\n",
104
+ "Event Counts:\n",
105
+ "event\n",
106
+ "view 2664312\n",
107
+ "addtocart 69332\n",
108
+ "transaction 22457\n",
109
+ "Name: count, dtype: int64\n",
110
+ "\n",
111
+ "Number of unique visitors: 1407580\n",
112
+ "Number of unique items: 235061\n"
113
+ ]
114
+ }
115
+ ],
116
+ "source": [
117
+ "# --- Data Cleaning and Understanding ---\n",
118
+ "\n",
119
+ "# 1. Convert timestamp to datetime\n",
120
+ "# The timestamp is in milliseconds, so we divide by 1000\n",
121
+ "events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms')\n",
122
+ "print(f\"\\nData timeframe is from {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}\")\n",
123
+ "\n",
124
+ "\n",
125
+ "# 2. Analyze the distribution of event types\n",
126
+ "print(\"\\nEvent Counts:\")\n",
127
+ "event_counts = events_df['event'].value_counts()\n",
128
+ "print(event_counts)\n",
129
+ "\n",
130
+ "\n",
131
+ "# 3. Calculate number of unique users and items\n",
132
+ "n_users = events_df['visitorid'].nunique()\n",
133
+ "n_items = events_df['itemid'].nunique()\n",
134
+ "\n",
135
+ "print(f\"\\nNumber of unique visitors: {n_users}\")\n",
136
+ "print(f\"Number of unique items: {n_items}\")"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "id": "90fbfb19",
142
+ "metadata": {},
143
+ "source": [
144
+ "## Preparing the data"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "id": "8639638d",
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "import zipfile\n",
155
+ "import pandas as pd\n",
156
+ "from datetime import datetime, timedelta\n",
157
+ "import numpy as np\n",
158
+ "from scipy.sparse import csr_matrix\n",
159
+ "import math\n",
160
+ "import torch\n",
161
+ "import torch.nn as nn\n",
162
+ "from torch.utils.data import Dataset, DataLoader\n",
163
+ "import pytorch_lightning as pl\n",
164
+ "\n",
165
+ "def prepare_data(data_folder='data/', val_days=7, test_days=7):\n",
166
+ " \"\"\"\n",
167
+ " Loads, preprocesses, and splits the events data into train, validation, and test sets.\n",
168
+ " \n",
169
+ "    Args:\n",
170
+ " data_folder: str, path to the folder containing 'events.csv'\n",
171
+ " val_days: int, number of days for the validation set\n",
172
+ " test_days: int, number of days for the test set\n",
173
+ " \"\"\"\n",
174
+ " # --- Load Data ---\n",
175
+ " print(f\"Loading events.csv from folder: {data_folder}\")\n",
176
+ " try:\n",
177
+ " events_df = pd.read_csv(data_folder + 'events.csv')\n",
178
+ " print(\"Successfully loaded events.csv.\")\n",
179
+ " events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms')\n",
180
+ " print(\"\\n--- Initial Data Summary ---\")\n",
181
+ " print(f\"Data shape: {events_df.shape}\")\n",
182
+ " print(f\"Full timeframe: {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}\")\n",
183
+ " print(\"----------------------------\\n\")\n",
184
+ " except FileNotFoundError:\n",
185
+ " print(f\"Error: 'events.csv' not found in '{data_folder}'. Please check the path.\")\n",
186
+ " return None, None, None\n",
187
+ "\n",
188
+ " # --- Split Data ---\n",
189
+ " sorted_df = events_df.sort_values('timestamp_dt').reset_index(drop=True)\n",
190
+ " print(f\"Splitting data: {test_days} days for test, {val_days} for validation.\")\n",
191
+ " end_time = sorted_df['timestamp_dt'].max()\n",
192
+ " test_start_time = end_time - timedelta(days=test_days)\n",
193
+ " val_start_time = test_start_time - timedelta(days=val_days)\n",
194
+ "\n",
195
+ " test_df = sorted_df[sorted_df['timestamp_dt'] >= test_start_time]\n",
196
+ " val_df = sorted_df[(sorted_df['timestamp_dt'] >= val_start_time) & (sorted_df['timestamp_dt'] < test_start_time)]\n",
197
+ " train_df = sorted_df[sorted_df['timestamp_dt'] < val_start_time]\n",
198
+ "\n",
199
+ " print(\"--- Data Splitting Summary ---\")\n",
200
+ " print(f\"Training set: {train_df.shape[0]:>8} records | from {train_df['timestamp_dt'].min()} to {train_df['timestamp_dt'].max()}\")\n",
201
+ " print(f\"Validation set: {val_df.shape[0]:>8} records | from {val_df['timestamp_dt'].min()} to {val_df['timestamp_dt'].max()}\")\n",
202
+ " print(f\"Test set: {test_df.shape[0]:>8} records | from {test_df['timestamp_dt'].min()} to {test_df['timestamp_dt'].max()}\")\n",
203
+ " print(\"------------------------------\")\n",
204
+ " \n",
205
+ " return train_df, val_df, test_df\n"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "id": "f99e4498",
212
+ "metadata": {},
213
+ "outputs": [],
214
+ "source": [
215
+ "DATA_PATH = \"data/\"  # trailing slash required: prepare_data concatenates 'events.csv'\n",
216
+ "\n",
217
+ "train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "id": "96b137cb",
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "class SASRecDataset(Dataset):\n",
228
+ " \"\"\"\n",
229
+ " SASRec Dataset.\n",
230
+ " - Precomputes (sequence_id, cutoff_idx) pairs for O(1) __getitem__.\n",
231
+ " - Supports 'last' or 'all' target modes.\n",
232
+ " \"\"\"\n",
233
+ " def __init__(self, sequences, max_len, target_mode=\"last\"):\n",
234
+ " \"\"\"\n",
235
+ " Args:\n",
236
+ " sequences: list of user sequences (list of item IDs).\n",
237
+ " max_len: maximum sequence length (padding applied).\n",
238
+ " target_mode: 'last' (only last prediction) or 'all' (predict at every step).\n",
239
+ " \"\"\"\n",
240
+ " self.sequences = sequences\n",
241
+ " self.max_len = max_len\n",
242
+ " self.target_mode = target_mode\n",
243
+ "\n",
244
+ " # Build index once\n",
245
+ " self.index = []\n",
246
+ " for seq_id, seq in enumerate(sequences):\n",
247
+ " for i in range(1, len(seq)):\n",
248
+ " self.index.append((seq_id, i))\n",
249
+ "\n",
250
+ " def __len__(self):\n",
251
+ " return len(self.index)\n",
252
+ "\n",
253
+ " def __getitem__(self, idx):\n",
254
+ " seq_id, cutoff = self.index[idx]\n",
255
+ " seq = self.sequences[seq_id][:cutoff]\n",
256
+ "\n",
257
+ " # Truncate & pad\n",
258
+ " seq = seq[-self.max_len:]\n",
259
+ " pad_len = self.max_len - len(seq)\n",
260
+ "\n",
261
+ " input_seq = np.zeros(self.max_len, dtype=np.int64)\n",
262
+ " input_seq[pad_len:] = seq\n",
263
+ "\n",
264
+ " if self.target_mode == \"last\":\n",
265
+ " target = self.sequences[seq_id][cutoff]\n",
266
+ " return torch.LongTensor(input_seq), torch.LongTensor([target])\n",
267
+ "\n",
268
+ " elif self.target_mode == \"all\":\n",
269
+ " # Predict next item at each step\n",
270
+ " target_seq = self.sequences[seq_id][1:cutoff+1]\n",
271
+ " target_seq = target_seq[-self.max_len:]\n",
272
+ " target = np.zeros(self.max_len, dtype=np.int64)\n",
273
+ " target[-len(target_seq):] = target_seq\n",
274
+ " return torch.LongTensor(input_seq), torch.LongTensor(target)\n",
275
+ "\n",
276
+ "class SASRecDataModule(pl.LightningDataModule):\n",
277
+ " \"\"\"\n",
278
+ " PyTorch Lightning DataModule for preparing the RetailRocket dataset for the SASRec model.\n",
279
+ "\n",
280
+ " This class handles all aspects of data preparation, including:\n",
281
+ " - Filtering out infrequent users and items to reduce noise.\n",
282
+ " - Building a consistent item vocabulary.\n",
283
+ " - Converting user event histories into sequential data.\n",
284
+ " - Creating and providing `DataLoader` instances for training, validation, and testing.\n",
285
+ " \"\"\"\n",
286
+ " def __init__(self, train_df, val_df, test_df, min_item_interactions=5, \n",
287
+ " min_user_interactions=5, max_len=50, batch_size=256):\n",
288
+ " \"\"\"\n",
289
+ " Initializes the DataModule.\n",
290
+ "\n",
291
+ " Args:\n",
292
+ " train_df (pd.DataFrame): DataFrame for training.\n",
293
+ " val_df (pd.DataFrame): DataFrame for validation.\n",
294
+ " test_df (pd.DataFrame): DataFrame for testing.\n",
295
+ " min_item_interactions (int): Minimum number of interactions for an item to be kept.\n",
296
+ " min_user_interactions (int): Minimum number of interactions for a user to be kept.\n",
297
+ " max_len (int): The maximum length of a user sequence fed to the model.\n",
298
+ " batch_size (int): The batch size for the DataLoaders.\n",
299
+ " \"\"\"\n",
300
+ " super().__init__()\n",
301
+ " self.train_df = train_df\n",
302
+ " self.val_df = val_df\n",
303
+ " self.test_df = test_df\n",
304
+ " self.min_item_interactions = min_item_interactions\n",
305
+ " self.min_user_interactions = min_user_interactions\n",
306
+ " self.max_len = max_len\n",
307
+ " self.batch_size = batch_size\n",
308
+ "\n",
309
+ " self.item_map = None\n",
310
+ " self.inverse_item_map = None\n",
311
+ " self.vocab_size = 0\n",
312
+ " self.user_history = None\n",
313
+ "\n",
314
+ " def setup(self, stage=None):\n",
315
+ " \"\"\"\n",
316
+ " Prepares the data for training, validation, and testing.\n",
317
+ "\n",
318
+ " This method is called automatically by PyTorch Lightning. It performs the following steps:\n",
319
+ " 1. Determines filtering criteria (which users and items to keep) based on the training set only\n",
320
+ " to prevent data leakage.\n",
321
+ " 2. Applies these filters to the train, validation, and test sets.\n",
322
+ " 3. Builds an item vocabulary (mapping item IDs to integer indices) from the combined\n",
323
+ " training and validation sets to ensure consistency for model checkpointing.\n",
324
+ " 4. Converts the event logs into sequences of item indices for each user in each data split.\n",
325
+ " \"\"\"\n",
326
+ " item_counts = self.train_df['itemid'].value_counts()\n",
327
+ " user_counts = self.train_df['visitorid'].value_counts()\n",
328
+ " items_to_keep = item_counts[item_counts >= self.min_item_interactions].index\n",
329
+ " users_to_keep = user_counts[user_counts >= self.min_user_interactions].index\n",
330
+ "\n",
331
+ " self.filtered_train_df = self.train_df[\n",
332
+ " (self.train_df['itemid'].isin(items_to_keep)) & \n",
333
+ " (self.train_df['visitorid'].isin(users_to_keep))\n",
334
+ " ].copy()\n",
335
+ " self.filtered_val_df = self.val_df[\n",
336
+ " (self.val_df['itemid'].isin(items_to_keep)) & \n",
337
+ " (self.val_df['visitorid'].isin(users_to_keep))\n",
338
+ " ].copy()\n",
339
+ " self.filtered_test_df = self.test_df[\n",
340
+ " (self.test_df['itemid'].isin(items_to_keep)) & \n",
341
+ " (self.test_df['visitorid'].isin(users_to_keep))\n",
342
+ " ].copy()\n",
343
+ "\n",
344
+ " all_known_items_df = pd.concat([self.filtered_train_df, self.filtered_val_df])\n",
345
+ " unique_items = all_known_items_df['itemid'].unique()\n",
346
+ " self.item_map = {item_id: i + 1 for i, item_id in enumerate(unique_items)}\n",
347
+ " self.inverse_item_map = {i: item_id for item_id, i in self.item_map.items()}\n",
348
+ " self.vocab_size = len(self.item_map) + 1 # +1 for padding token 0\n",
349
+ "\n",
350
+ " self.user_history = self.filtered_train_df.groupby('visitorid')['itemid'].apply(list)\n",
351
+ " \n",
352
+ " self.train_sequences = self._create_sequences(self.filtered_train_df)\n",
353
+ " self.val_sequences = self._create_sequences(self.filtered_val_df)\n",
354
+ " self.test_sequences = self._create_sequences(self.filtered_test_df)\n",
355
+ "\n",
356
+ " def _create_sequences(self, df):\n",
357
+ " \"\"\"\n",
358
+ " Helper function to convert a DataFrame of events into user interaction sequences.\n",
359
+ " \n",
360
+ " Args:\n",
361
+ " df (pd.DataFrame): The input DataFrame to process.\n",
362
+ "\n",
363
+ " Returns:\n",
364
+ " list[list[int]]: A list of user sequences, where each sequence is a list of item indices.\n",
365
+ " \"\"\"\n",
366
+ " df_sorted = df.sort_values(['visitorid', 'timestamp_dt'])\n",
367
+ " sequences = df_sorted.groupby('visitorid')['itemid'].apply(\n",
368
+ " lambda x: [self.item_map[i] for i in x if i in self.item_map]\n",
369
+ " ).tolist()\n",
370
+ " return [s for s in sequences if len(s) > 1]\n",
371
+ "\n",
372
+ " def train_dataloader(self):\n",
373
+ " \"\"\"Creates the DataLoader for the training set.\"\"\"\n",
374
+ " dataset = SASRecDataset(self.train_sequences, self.max_len)\n",
375
+ " return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)\n",
376
+ "\n",
377
+ " def val_dataloader(self):\n",
378
+ " \"\"\"Creates the DataLoader for the validation set.\"\"\"\n",
379
+ " dataset = SASRecDataset(self.val_sequences, self.max_len)\n",
380
+ " return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)\n",
381
+ " \n",
382
+ " def test_dataloader(self):\n",
383
+ " \"\"\"Creates the DataLoader for the test set.\"\"\"\n",
384
+ " dataset = SASRecDataset(self.test_sequences, self.max_len)\n",
385
+ " return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": null,
391
+ "id": "56bdc81c",
392
+ "metadata": {},
393
+ "outputs": [],
394
+ "source": [
395
+ "BATCH_SIZE = 256 \n",
396
+ "MAX_TOKEN_LEN = 50 # 50–100 is standard for SASRec\n",
397
+ "\n",
398
+ "# --- 1. Prepare the data into train, validation, and test sets ---\n",
399
+ "train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)\n",
400
+ "\n",
401
+ "# --- 2. Initialize DataModule ---\n",
402
+ "print(\"Initializing DataModule...\")\n",
403
+ "datamodule = SASRecDataModule(\n",
404
+ " train_df=train_set,\n",
405
+ " val_df=validation_set,\n",
406
+ " test_df=test_set,\n",
407
+ " batch_size=BATCH_SIZE,\n",
408
+ " max_len=MAX_TOKEN_LEN\n",
409
+ ")\n",
410
+ "datamodule.setup()"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "markdown",
415
+ "id": "0529207a",
416
+ "metadata": {},
417
+ "source": [
418
+ "## Define, train, and evaluate the baseline models"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "id": "8d899a5a",
425
+ "metadata": {},
426
+ "outputs": [],
427
+ "source": [
428
+ "import pandas as pd\n",
429
+ "from datetime import datetime, timedelta\n",
430
+ "import numpy as np\n",
431
+ "from scipy.sparse import csr_matrix\n",
432
+ "from sklearn.metrics.pairwise import cosine_similarity\n",
433
+ "import implicit\n",
434
+ "import torch\n",
435
+ "import torch.nn as nn\n",
436
+ "from torch.utils.data import Dataset, DataLoader\n",
437
+ "import pytorch_lightning as pl\n",
438
+ "\n",
439
+ "# --- 1. Evaluation Helper Functions ---\n",
440
+ "\n",
441
+ "def prepare_ground_truth(df, mode=\"purchase\", event_weights=None):\n",
442
+ " \"\"\"\n",
443
+ " Prepares ground truth dictionaries for evaluation.\n",
444
+ "\n",
445
+ " Parameters\n",
446
+ " ----------\n",
447
+ " df : pd.DataFrame\n",
448
+ " Test dataframe containing at least ['visitorid', 'itemid', 'event'].\n",
449
+ " mode : str, default=\"purchase\"\n",
450
+ " - \"purchase\" : Only use transactions as ground truth.\n",
451
+ " - \"all\" : Use all events. Optionally weight them.\n",
452
+ " event_weights : dict, optional\n",
453
+ " Example: {\"view\": 1, \"addtocart\": 3, \"transaction\": 5}.\n",
454
+ " Used only if mode == \"all\".\n",
455
+ "\n",
456
+ " Returns\n",
457
+ " -------\n",
458
+ " dict : {user_id: set of item_ids}\n",
459
+ " \"\"\"\n",
460
+ " if mode == \"purchase\":\n",
461
+ " df_filtered = df[df[\"event\"] == \"transaction\"]\n",
462
+ " ground_truth = df_filtered.groupby(\"visitorid\")[\"itemid\"].apply(set).to_dict()\n",
463
+ "\n",
464
+ " elif mode == \"all\":\n",
465
+ " if event_weights is None:\n",
466
+ " # Default: treat all events equally\n",
467
+ " ground_truth = df.groupby(\"visitorid\")[\"itemid\"].apply(set).to_dict()\n",
468
+ " else:\n",
469
+ " # Weighted ground truth (for more advanced eval)\n",
470
+ " ground_truth = {}\n",
471
+ " for uid, user_df in df.groupby(\"visitorid\"):\n",
472
+ " weighted_items = []\n",
473
+ " for _, row in user_df.iterrows():\n",
474
+ " weight = event_weights.get(row[\"event\"], 1)\n",
475
+ " weighted_items.extend([row[\"itemid\"]] * weight)\n",
476
+ " ground_truth[uid] = set(weighted_items)\n",
477
+ " else:\n",
478
+ " raise ValueError(\"mode must be 'purchase' or 'all'\")\n",
479
+ "\n",
480
+ " return ground_truth\n",
481
+ "\n",
482
+ "def calculate_metrics(recommendations_dict, ground_truth_dict, k):\n",
483
+ " \"\"\"\n",
484
+ " Calculates Precision@k, Recall@k, and HitRate@k.\n",
485
+ "\n",
486
+ "    Parameters\n",
487
+ " ----------\n",
488
+ " recommendations_dict : {user_id: [recommended_item_ids]}\n",
489
+ " ground_truth_dict : {user_id: set of ground truth item_ids}\n",
490
+ " k : int\n",
491
+ "\n",
492
+ " Returns\n",
493
+ " -------\n",
494
+ " dict with mean precision, recall, and hit rate\n",
495
+ " \"\"\"\n",
496
+ " all_precisions, all_recalls, all_hits = [], [], []\n",
497
+ "\n",
498
+ " for user_id, true_items in ground_truth_dict.items():\n",
499
+ " recs = recommendations_dict.get(user_id, [])[:k]\n",
500
+ " if not true_items:\n",
501
+ " continue\n",
502
+ " hits = len(set(recs) & true_items)\n",
503
+ "\n",
504
+ " precision = hits / k if k > 0 else 0\n",
505
+ " recall = hits / len(true_items)\n",
506
+ " hit_rate = 1.0 if hits > 0 else 0.0\n",
507
+ "\n",
508
+ " all_precisions.append(precision)\n",
509
+ " all_recalls.append(recall)\n",
510
+ " all_hits.append(hit_rate)\n",
511
+ "\n",
512
+ " if not all_precisions:\n",
513
+ " return {\"mean_precision@k\": 0, \"mean_recall@k\": 0, \"mean_hitrate@k\": 0}\n",
514
+ "\n",
515
+ " return {\n",
516
+ " \"mean_precision@k\": np.mean(all_precisions),\n",
517
+ " \"mean_recall@k\": np.mean(all_recalls),\n",
518
+ " \"mean_hitrate@k\": np.mean(all_hits)\n",
519
+ " }\n",
520
+ "\n",
521
+ "# --- 2. Model Functions (Popularity, Item-Item, ALS) ---\n",
522
+ "\n",
523
+ "def recommend_popular_items_and_evaluate(train_df, test_df, k=10, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics):\n",
524
+ " \"\"\"\n",
525
+ " Trains a non-personalized Popularity model and evaluates its performance.\n",
526
+ "\n",
527
+ " This model recommends the top-k most frequently transacted items from the training\n",
528
+ " set to every user. It serves as a simple but strong baseline.\n",
529
+ "\n",
530
+ " Args:\n",
531
+ " train_df (pd.DataFrame): The training dataset.\n",
532
+ " test_df (pd.DataFrame): The test dataset for evaluation.\n",
533
+ " k (int): The number of items to recommend.\n",
534
+ " prepare_ground_truth (function): A function to process the test_df into a ground truth dict.\n",
535
+ " calculate_metrics (function): A function to compute ranking metrics.\n",
536
+ "\n",
537
+ " Returns:\n",
538
+ " dict: A dictionary containing the calculated evaluation metrics (e.g., precision, recall).\n",
539
+ " \"\"\"\n",
540
+ " print(f\"\\n--- Evaluating Popularity Model (Top {k} items) ---\")\n",
541
+ " \n",
542
+ " # 1. \"Train\" the model by finding the most popular items based on transactions\n",
543
+ " purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()\n",
544
+ " popular_items = purchase_counts.head(k).index.tolist()\n",
545
+ " print(f\"Top {k} popular items identified from training data.\")\n",
546
+ "\n",
547
+ " # 2. Evaluate the model\n",
548
+ " ground_truth = prepare_ground_truth(test_df)\n",
549
+ " # Every user receives the same list of popular items\n",
550
+ " recommendations = {user_id: popular_items for user_id in ground_truth.keys()}\n",
551
+ " \n",
552
+ " metrics = calculate_metrics(recommendations, ground_truth, k)\n",
553
+ " print(\"Evaluation complete.\")\n",
554
+ " return metrics\n",
555
+ "\n",
556
+ "def recommend_item_item_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics):\n",
557
+ " \"\"\"\n",
558
+ " Trains an Item-Item Collaborative Filtering model and evaluates its performance.\n",
559
+ "\n",
560
+ " This model recommends items that are similar to items a user has interacted\n",
561
+ " with in the past, based on co-occurrence patterns in the training data.\n",
562
+ "\n",
563
+ " Args:\n",
564
+ " train_df (pd.DataFrame): The training dataset.\n",
565
+ " test_df (pd.DataFrame): The test dataset for evaluation.\n",
566
+ " k (int): The number of items to recommend.\n",
567
+ " min_item_interactions (int): Minimum number of interactions for an item to be kept.\n",
568
+ " min_user_interactions (int): Minimum number of interactions for a user to be kept.\n",
569
+ " prepare_ground_truth (function): A function to process the test_df into a ground truth dict.\n",
570
+ " calculate_metrics (function): A function to compute ranking metrics.\n",
571
+ "\n",
572
+ " Returns:\n",
573
+ " dict: A dictionary containing the calculated evaluation metrics.\n",
574
+ " \"\"\"\n",
575
+ " print(f\"\\n--- Evaluating Item-Item CF Model (Top {k} items) ---\")\n",
576
+ " \n",
577
+ " # 1. Filter out infrequent users and items to reduce noise and computation\n",
578
+ " item_counts = train_df['itemid'].value_counts()\n",
579
+ " user_counts = train_df['visitorid'].value_counts()\n",
580
+ " items_to_keep = item_counts[item_counts >= min_item_interactions].index\n",
581
+ " users_to_keep = user_counts[user_counts >= min_user_interactions].index\n",
582
+ " filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()\n",
583
+ " print(f\"Filtered training data from {len(train_df)} to {len(filtered_df)} records.\")\n",
584
+ "\n",
585
+ " # 2. Create user-item interaction matrix and vocabulary mappings\n",
586
+ " user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}\n",
587
+ " item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}\n",
588
+ " inverse_item_map = {i: iid for iid, i in item_map.items()}\n",
589
+ " user_indices = filtered_df['visitorid'].map(user_map)\n",
590
+ " item_indices = filtered_df['itemid'].map(item_map)\n",
591
+ " user_item_matrix = csr_matrix((np.ones(len(filtered_df)), (user_indices, item_indices)))\n",
592
+ "\n",
593
+ " # 3. Calculate the cosine similarity matrix between all items\n",
594
+ " print(\"Calculating item similarity matrix...\")\n",
595
+ " item_similarity_matrix = cosine_similarity(user_item_matrix.T, dense_output=False)\n",
596
+ " print(\"Similarity matrix calculated.\")\n",
597
+ "\n",
598
+ " # 4. Generate recommendations and evaluate\n",
599
+ " ground_truth = prepare_ground_truth(test_df)\n",
600
+ " recommendations = {}\n",
601
+ " print(\"Generating recommendations for users in test set...\")\n",
602
+ " test_users = [u for u in ground_truth.keys() if u in user_map]\n",
603
+ " \n",
604
+ " for user_id in test_users:\n",
605
+ " user_index = user_map[user_id]\n",
606
+ " user_interactions_indices = user_item_matrix[user_index].indices\n",
607
+ " \n",
608
+ " if len(user_interactions_indices) > 0:\n",
609
+ " # Aggregate scores from items the user has interacted with\n",
610
+ " all_scores = np.asarray(item_similarity_matrix[user_interactions_indices].sum(axis=0)).flatten()\n",
611
+ " # Remove already interacted items from recommendations\n",
612
+ " all_scores[user_interactions_indices] = -1\n",
613
+ " top_indices = np.argsort(all_scores)[::-1][:k]\n",
614
+ " recs = [inverse_item_map[idx] for idx in top_indices if idx in inverse_item_map]\n",
615
+ " recommendations[user_id] = recs\n",
616
+ " \n",
617
+ " metrics = calculate_metrics(recommendations, ground_truth, k)\n",
618
+ " print(\"Evaluation complete.\")\n",
619
+ " return metrics\n",
620
+ "\n",
621
+ "def recommend_als_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5, \n",
622
+ "                               factors=25, regularization=0.02, iterations=48, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics):\n",
623
+ " \"\"\"\n",
624
+ " Trains an Alternating Least Squares (ALS) model and evaluates its performance.\n",
625
+ "\n",
626
+ " This model uses matrix factorization to learn latent embeddings for users and\n",
627
+ " items from implicit feedback data. Default hyperparameters are set from a\n",
628
+ " previous Optuna tuning process.\n",
629
+ "\n",
630
+ " Args:\n",
631
+ " train_df (pd.DataFrame): The training dataset.\n",
632
+ " test_df (pd.DataFrame): The test dataset for evaluation.\n",
633
+ " k (int): The number of items to recommend.\n",
634
+ " min_item_interactions (int): Minimum number of interactions for an item to be kept.\n",
635
+ " min_user_interactions (int): Minimum number of interactions for a user to be kept.\n",
636
+ " factors (int): The number of latent factors to compute.\n",
637
+ " regularization (float): The regularization factor.\n",
638
+ " iterations (int): The number of ALS iterations to run.\n",
639
+ " prepare_ground_truth (function): A function to process the test_df into a ground truth dict.\n",
640
+ " calculate_metrics (function): A function to compute ranking metrics.\n",
641
+ "\n",
642
+ " Returns:\n",
643
+ " dict: A dictionary containing the calculated evaluation metrics.\n",
644
+ " \"\"\"\n",
645
+ " print(f\"\\n--- Evaluating ALS Model (Top {k} items) ---\")\n",
646
+ " \n",
647
+ " # 1. Filter data\n",
648
+ " item_counts = train_df['itemid'].value_counts()\n",
649
+ " user_counts = train_df['visitorid'].value_counts()\n",
650
+ " items_to_keep = item_counts[item_counts >= min_item_interactions].index\n",
651
+ " users_to_keep = user_counts[user_counts >= min_user_interactions].index\n",
652
+ " filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()\n",
653
+ " print(f\"Filtered training data from {len(train_df)} to {len(filtered_df)} records.\")\n",
654
+ "\n",
655
+ " # 2. Create mappings and confidence matrix\n",
656
+ " user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}\n",
657
+ " item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}\n",
658
+ " inverse_item_map = {i: iid for iid, i in item_map.items()}\n",
659
+ " user_indices = filtered_df['visitorid'].map(user_map).astype(np.int32)\n",
660
+ " item_indices = filtered_df['itemid'].map(item_map).astype(np.int32)\n",
661
+ " \n",
662
+ " event_weights = {'view': 1, 'addtocart': 3, 'transaction': 5}\n",
663
+ " confidence = filtered_df['event'].map(event_weights).astype(np.float32)\n",
664
+ " user_item_matrix = csr_matrix((confidence, (user_indices, item_indices)))\n",
665
+ "\n",
666
+ " # 3. Train the ALS model\n",
667
+ " print(\"Training ALS model...\")\n",
668
+ " als_model = implicit.als.AlternatingLeastSquares(factors=factors, regularization=regularization, iterations=iterations)\n",
669
+ " als_model.fit(user_item_matrix)\n",
670
+ " print(\"ALS model trained.\")\n",
671
+ "\n",
672
+ " # 4. Generate recommendations and evaluate\n",
673
+ " ground_truth = prepare_ground_truth(test_df)\n",
674
+ " recommendations = {}\n",
675
+ " print(\"Generating recommendations for users in test set...\")\n",
676
+ " test_users_indices = [user_map[u] for u in ground_truth.keys() if u in user_map]\n",
677
+ " \n",
678
+ " if test_users_indices:\n",
679
+ " user_item_matrix_for_recs = user_item_matrix[test_users_indices]\n",
680
+ " ids, _ = als_model.recommend(test_users_indices, user_item_matrix_for_recs, N=k)\n",
681
+ " \n",
682
+ "        inverse_user_map = {i: uid for uid, i in user_map.items()}\n",
+ "        for i, user_index in enumerate(test_users_indices):\n",
+ "            # O(1) lookup instead of scanning user_map for every user\n",
+ "            original_user_id = inverse_user_map[user_index]\n",
+ "            recs = [inverse_item_map[item_idx] for item_idx in ids[i] if item_idx in inverse_item_map]\n",
+ "            recommendations[original_user_id] = recs\n",
686
+ " \n",
687
+ " metrics = calculate_metrics(recommendations, ground_truth, k)\n",
688
+ " print(\"Evaluation complete.\")\n",
689
+ "    return metrics\n"
712
+ ]
713
+ },
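As a sanity check on the metric definitions above, here is a self-contained re-implementation of the `calculate_metrics` helper run on a toy example (the users and items are invented for illustration, not taken from the RetailRocket data):

```python
import numpy as np

def calculate_metrics(recommendations_dict, ground_truth_dict, k):
    # Mirrors the notebook's helper: Precision@k, Recall@k, HitRate@k averaged over users
    precisions, recalls, hit_rates = [], [], []
    for user_id, true_items in ground_truth_dict.items():
        recs = recommendations_dict.get(user_id, [])[:k]
        if not true_items:
            continue
        hits = len(set(recs) & true_items)
        precisions.append(hits / k)
        recalls.append(hits / len(true_items))
        hit_rates.append(1.0 if hits > 0 else 0.0)
    return {
        "mean_precision@k": float(np.mean(precisions)),
        "mean_recall@k": float(np.mean(recalls)),
        "mean_hitrate@k": float(np.mean(hit_rates)),
    }

# Two users: u1 gets one hit in the top-2, u2 gets none
recs = {"u1": [10, 20], "u2": [30, 40]}
truth = {"u1": {20, 99}, "u2": {77}}
m = calculate_metrics(recs, truth, k=2)
print(m)  # precision = (0.5 + 0) / 2 = 0.25, hitrate = 0.5
```

Note that a hit counts once per user for the hit rate, while precision and recall are averaged per-user before taking the mean.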
714
+ {
715
+ "cell_type": "code",
716
+ "execution_count": null,
717
+ "id": "b978c458",
718
+ "metadata": {},
719
+ "outputs": [],
720
+ "source": [
721
+ "train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)\n",
722
+ "if train_set is not None:\n",
723
+ " results = {}\n",
724
+ " full_train_set = pd.concat([train_set, validation_set])\n",
725
+ " \n",
726
+ " # Evaluate base models\n",
727
+ " print(\"\\n>>> Running evaluations on the VALIDATION set <<<\")\n",
728
+ " results['Popularity (Validation)'] = recommend_popular_items_and_evaluate(train_set, validation_set)\n",
729
+ " results['Item-Item CF (Validation)'] = recommend_item_item_and_evaluate(train_set, validation_set)\n",
730
+ " results['ALS (Validation)'] = recommend_als_and_evaluate(train_set, validation_set)\n",
731
+ " \n",
732
+ " print(\"\\n>>> Running final evaluations on the TEST set <<<\")\n",
733
+ " results['Popularity (Test)'] = recommend_popular_items_and_evaluate(full_train_set, test_set)\n",
734
+ " results['Item-Item CF (Test)'] = recommend_item_item_and_evaluate(full_train_set, test_set)\n",
735
+ " results['ALS (Test)'] = recommend_als_and_evaluate(full_train_set, test_set)\n",
736
+ " \n",
737
+ " print(\"\\n--- Final Evaluation Results ---\")\n",
738
+ " results_df = pd.DataFrame.from_dict(results, orient='index')\n",
739
+ " print(results_df)\n",
740
+ " print(\"--------------------------------\")"
741
+ ]
742
+ },
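The core of the item-item scoring step (sum the similarity rows of the user's history, mask already-seen items, take the top-k) can be sketched with plain NumPy on a toy matrix — an illustration with invented data, not the notebook's sparse pipeline:

```python
import numpy as np

# Toy user-item matrix: 3 users x 4 items (implicit 0/1 interactions)
ui = np.array([
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [1, 0, 0, 1],
], dtype=float)

# Item-item cosine similarity (dense equivalent of cosine_similarity(ui.T))
norms = np.linalg.norm(ui, axis=0)
sim = (ui.T @ ui) / np.outer(norms, norms)

user = 0
seen = np.flatnonzero(ui[user])      # items the user already interacted with
scores = sim[seen].sum(axis=0)       # aggregate similarity to the user's history
scores[seen] = -1                    # exclude already-seen items from the ranking
top_k = np.argsort(scores)[::-1][:2]
print(top_k.tolist())
```

For user 0 (who interacted with items 0 and 1), the only candidates left after masking are items 2 and 3.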
743
+ {
744
+ "cell_type": "markdown",
745
+ "id": "85d8f78c",
746
+ "metadata": {},
747
+ "source": [
748
+ "## Use Optuna to find the best hyperparameters for the ALS model"
749
+ ]
750
+ },
751
+ {
752
+ "cell_type": "code",
753
+ "execution_count": null,
754
+ "id": "202be1f4",
755
+ "metadata": {},
756
+ "outputs": [],
757
+ "source": [
758
+ "import optuna\n",
759
+ "\n",
760
+ "def objective_als(trial, train_df, val_df):\n",
761
+ " \"\"\"\n",
762
+ " The objective function for Optuna to optimize.\n",
763
+ " \"\"\"\n",
764
+ " # 1. Define the hyperparameter search space\n",
765
+ " params = {\n",
766
+ " 'factors': trial.suggest_int('factors', 20, 200),\n",
767
+ " 'regularization': trial.suggest_float('regularization', 1e-3, 1e-1, log=True),\n",
768
+ " 'iterations': trial.suggest_int('iterations', 10, 50)\n",
769
+ " }\n",
770
+ " \n",
771
+ " # 2. Run an evaluation with the suggested parameters\n",
772
+ " metrics = recommend_als_and_evaluate(train_df, val_df, **params)\n",
773
+ " \n",
774
+ " # 3. Return the metric we want to maximize (precision)\n",
775
+ " return metrics['mean_precision@k']\n",
776
+ "\n",
777
+ "def tune_als_hyperparameters(train_df, val_df, n_trials=25):\n",
778
+ " \"\"\"\n",
779
+ " Orchestrates the Optuna study to find the best hyperparameters for ALS.\n",
780
+ " \"\"\"\n",
781
+ " study = optuna.create_study(direction='maximize')\n",
782
+ " study.optimize(lambda trial: objective_als(trial, train_df, val_df), n_trials=n_trials)\n",
783
+ " \n",
784
+ " print(\"\\n--- Optuna Study Complete ---\")\n",
785
+ " print(f\"Number of finished trials: {len(study.trials)}\")\n",
786
+ " print(\"Best trial:\")\n",
787
+ " trial = study.best_trial\n",
788
+ " print(f\" Value (Precision@10): {trial.value}\")\n",
789
+ " print(\" Params: \")\n",
790
+ " for key, value in trial.params.items():\n",
791
+ " print(f\" {key}: {value}\")\n",
792
+ " \n",
793
+ " return trial.params\n"
794
+ ]
795
+ },
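For intuition, the loop Optuna automates is similar in spirit to the random search below — a stdlib-only sketch where the synthetic `objective` merely stands in for `recommend_als_and_evaluate`, and the parameter ranges mirror the search space defined above:

```python
import random

def objective(params):
    # Synthetic stand-in for the real evaluation: peaks near factors=100, reg=0.01
    return -((params["factors"] - 100) ** 2) / 1e4 - abs(params["regularization"] - 0.01)

def random_search(n_trials=50, seed=0):
    rng = random.Random(seed)
    best_params, best_value = None, float("-inf")
    for _ in range(n_trials):
        params = {
            "factors": rng.randint(20, 200),
            "regularization": 10 ** rng.uniform(-3, -1),  # log-uniform, like suggest_float(log=True)
        }
        value = objective(params)
        if value > best_value:
            best_params, best_value = params, value
    return best_params, best_value

params, value = random_search()
print(params, value)
```

Optuna's TPE sampler improves on this by modeling which regions of the space look promising, but the interface (suggest parameters, score them, keep the best trial) is the same.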
796
+ {
797
+ "cell_type": "code",
798
+ "execution_count": null,
799
+ "id": "18d48b2e",
800
+ "metadata": {},
801
+ "outputs": [],
802
+ "source": [
803
+ "# 1. Prepare all data\n",
804
+ "train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)\n",
805
+ "\n",
806
+ "\n",
807
+ "# --- Hyperparameter Tuning Step ---\n",
808
+ "print(\"\\n>>> 1. TUNING ALS Hyperparameters on the VALIDATION set <<<\")\n",
809
+ "# You can increase n_trials for a more thorough search, e.g., to 50 or 100\n",
810
+ "best_als_params = tune_als_hyperparameters(train_set, validation_set, n_trials=25) \n"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "markdown",
815
+ "id": "d9bc9ef8",
816
+ "metadata": {},
817
+ "source": [
818
+ "## Define, train, and evaluate the SASRec model"
819
+ ]
820
+ },
821
+ {
822
+ "cell_type": "code",
823
+ "execution_count": null,
824
+ "id": "4a90d635",
825
+ "metadata": {},
826
+ "outputs": [],
827
+ "source": [
828
+ "import math\n",
+ "\n",
+ "class SASRec(pl.LightningModule):\n",
829
+ " \"\"\"\n",
830
+ " A PyTorch Lightning implementation of the SASRec model for sequential recommendation.\n",
831
+ "\n",
832
+ " SASRec (Self-Attentive Sequential Recommendation) uses a Transformer-based\n",
833
+ " architecture to capture the sequential patterns in a user's interaction history\n",
834
+ " to predict the next item they are likely to interact with.\n",
835
+ "\n",
836
+ " Attributes:\n",
837
+ " save_hyperparameters: Automatically saves all constructor arguments as hyperparameters.\n",
838
+ " item_embedding (nn.Embedding): Embedding layer for item IDs.\n",
839
+ " positional_embedding (nn.Embedding): Embedding layer to encode the position of items in a sequence.\n",
840
+ " transformer_encoder (nn.TransformerEncoder): The core self-attention module.\n",
841
+ " fc (nn.Linear): Final fully connected layer to produce logits over the item vocabulary.\n",
842
+ " loss_fn (nn.CrossEntropyLoss): The loss function used for training.\n",
843
+ " \"\"\"\n",
844
+ " def __init__(self, vocab_size, max_len, hidden_dim, num_heads, num_layers,\n",
845
+ " dropout=0.2, learning_rate=1e-3, weight_decay=1e-6, warmup_steps=2000, max_steps=100000):\n",
846
+ " \"\"\"\n",
847
+ " Initializes the SASRec model layers and hyperparameters.\n",
848
+ "\n",
849
+ " Args:\n",
850
+ " vocab_size (int): The total number of unique items in the dataset (+1 for padding).\n",
851
+ " max_len (int): The maximum length of the input sequences.\n",
852
+ " hidden_dim (int): The dimensionality of the item and positional embeddings.\n",
853
+ " num_heads (int): The number of attention heads in the Transformer encoder.\n",
854
+ " num_layers (int): The number of layers in the Transformer encoder.\n",
855
+ " dropout (float): The dropout rate to be applied.\n",
856
+ " learning_rate (float): The learning rate for the optimizer.\n",
857
+ " weight_decay (float): The weight decay (L2 penalty) for the optimizer.\n",
858
+ " warmup_steps (int): The number of linear warmup steps for the learning rate scheduler.\n",
859
+ " max_steps (int): The total number of training steps for the learning rate scheduler's decay phase.\n",
860
+ " \"\"\"\n",
861
+ " super().__init__()\n",
862
+ " # This saves all hyperparameters to self.hparams, making them accessible later\n",
863
+ " self.save_hyperparameters()\n",
864
+ "\n",
865
+ " # Embedding layers\n",
866
+ " self.item_embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)\n",
867
+ " self.positional_embedding = nn.Embedding(max_len, hidden_dim)\n",
868
+ " self.dropout = nn.Dropout(dropout)\n",
869
+ "\n",
870
+ " # Transformer Encoder\n",
871
+ " encoder_layer = nn.TransformerEncoderLayer(\n",
872
+ " d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim * 4,\n",
873
+ " dropout=dropout, batch_first=True, activation='gelu'\n",
874
+ " )\n",
875
+ " self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n",
876
+ "\n",
877
+ " # Output layer\n",
878
+ " self.fc = nn.Linear(hidden_dim, vocab_size)\n",
879
+ "\n",
880
+ " # Loss function, ignoring the padding token\n",
881
+ " self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)\n",
882
+ " \n",
883
+ " # Lists to store outputs from validation and test steps\n",
884
+ " self.validation_step_outputs = []\n",
885
+ " self.test_step_outputs = []\n",
886
+ "\n",
887
+ " def forward(self, x):\n",
888
+ " \"\"\"\n",
889
+ " Defines the forward pass of the model.\n",
890
+ "\n",
891
+ " Args:\n",
892
+ " x (torch.Tensor): A batch of input sequences of shape (batch_size, seq_len).\n",
893
+ "\n",
894
+ " Returns:\n",
895
+ " torch.Tensor: The output logits of shape (batch_size, seq_len, vocab_size).\n",
896
+ " \"\"\"\n",
897
+ " seq_len = x.size(1)\n",
898
+ " # Create positional indices (0, 1, 2, ..., seq_len-1)\n",
899
+ " positions = torch.arange(seq_len, device=self.device).unsqueeze(0)\n",
900
+ "\n",
901
+ " # Create a causal mask to ensure the model doesn't look ahead in the sequence\n",
902
+ " causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=self.device)\n",
903
+ "\n",
904
+ " # Combine item and positional embeddings\n",
905
+ " x = self.item_embedding(x) + self.positional_embedding(positions)\n",
906
+ " x = self.dropout(x)\n",
907
+ " \n",
908
+ " # Pass through the Transformer encoder\n",
909
+ " x = self.transformer_encoder(x, mask=causal_mask)\n",
910
+ " \n",
911
+ " # Get final logits\n",
912
+ " logits = self.fc(x)\n",
913
+ " return logits\n",
914
+ "\n",
915
+ " def training_step(self, batch, batch_idx):\n",
916
+ " \"\"\"\n",
917
+ " Performs a single training step.\n",
918
+ "\n",
919
+ " Args:\n",
920
+ " batch (tuple): A tuple containing input sequences and target items.\n",
921
+ " batch_idx (int): The index of the current batch.\n",
922
+ "\n",
923
+ " Returns:\n",
924
+ " torch.Tensor: The calculated loss for the batch.\n",
925
+ " \"\"\"\n",
926
+ " inputs, targets = batch\n",
927
+ " logits = self.forward(inputs)\n",
928
+ "\n",
929
+ " # We only care about the prediction for the very last item in the input sequence\n",
930
+ " last_logits = logits[:, -1, :]\n",
931
+ " \n",
932
+ " # Calculate loss against the single target item\n",
933
+ "        # squeeze(-1): keep the batch dim intact even when batch_size == 1\n",
+ "        loss = self.loss_fn(last_logits, targets.squeeze(-1))\n",
934
+ " \n",
935
+ " self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True)\n",
936
+ " return loss\n",
937
+ "\n",
938
+ " def validation_step(self, batch, batch_idx):\n",
939
+ " \"\"\"\n",
940
+ " Performs a single validation step.\n",
941
+ " Calculates loss and stores predictions for metric computation at the end of the epoch.\n",
942
+ " \"\"\"\n",
943
+ " inputs, targets = batch\n",
944
+ " logits = self.forward(inputs)\n",
945
+ " last_item_logits = logits[:, -1, :]\n",
946
+ "        loss = self.loss_fn(last_item_logits, targets.squeeze(-1))\n",
947
+ " self.log('val_loss', loss, prog_bar=True, on_epoch=True)\n",
948
+ "\n",
949
+ " # Get top-10 predictions for metric calculation\n",
950
+ " top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices\n",
951
+ " self.validation_step_outputs.append({'preds': top_k_preds, 'targets': targets})\n",
952
+ " return loss\n",
953
+ "\n",
954
+ " def on_validation_epoch_end(self):\n",
955
+ " \"\"\"\n",
956
+ " Calculates and logs ranking metrics at the end of the validation epoch.\n",
957
+ " \"\"\"\n",
958
+ " if not self.validation_step_outputs: return\n",
959
+ "\n",
960
+ " # Concatenate all predictions and targets from the epoch\n",
961
+ " preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0)\n",
962
+ " targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0)\n",
963
+ "\n",
964
+ " k = preds.size(1)\n",
965
+ " # Check if the target is in the top-k predictions for each example\n",
966
+ " hits_tensor = (preds == targets).any(dim=1)\n",
967
+ " num_hits = hits_tensor.sum().item()\n",
968
+ " num_targets = len(targets)\n",
969
+ "\n",
970
+ " if num_targets > 0:\n",
971
+ " hit_rate = num_hits / num_targets\n",
972
+ " recall = hit_rate # For next-item prediction, recall@k is the same as hit_rate@k\n",
973
+ " precision = num_hits / (k * num_targets)\n",
974
+ " else:\n",
975
+ " hit_rate, recall, precision = 0.0, 0.0, 0.0\n",
976
+ "\n",
977
+ " self.log('val_hitrate@10', hit_rate, prog_bar=True)\n",
978
+ " self.log('val_precision@10', precision, prog_bar=True)\n",
979
+ " self.log('val_recall@10', recall, prog_bar=True)\n",
980
+ "\n",
981
+ " self.validation_step_outputs.clear() # Free up memory\n",
982
+ "\n",
983
+ " def test_step(self, batch, batch_idx):\n",
984
+ " \"\"\"\n",
985
+ " Performs a single test step.\n",
986
+ " Mirrors the logic of the validation_step.\n",
987
+ " \"\"\"\n",
988
+ " inputs, targets = batch\n",
989
+ " logits = self.forward(inputs)\n",
990
+ " last_item_logits = logits[:, -1, :]\n",
991
+ "        loss = self.loss_fn(last_item_logits, targets.squeeze(-1))\n",
992
+ " self.log('test_loss', loss, prog_bar=True)\n",
993
+ "\n",
994
+ " top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices\n",
995
+ " self.test_step_outputs.append({'preds': top_k_preds, 'targets': targets})\n",
996
+ " return loss\n",
997
+ "\n",
998
+ " def on_test_epoch_end(self):\n",
999
+ " \"\"\"\n",
1000
+ " Calculates and logs ranking metrics at the end of the test epoch.\n",
1001
+ " \"\"\"\n",
1002
+ " if not self.test_step_outputs: return\n",
1003
+ "\n",
1004
+ " preds = torch.cat([x['preds'] for x in self.test_step_outputs], dim=0)\n",
1005
+ " targets = torch.cat([x['targets'] for x in self.test_step_outputs], dim=0)\n",
1006
+ "\n",
1007
+ " k = preds.size(1)\n",
1008
+ " hits_tensor = (preds == targets).any(dim=1)\n",
1009
+ " num_hits = hits_tensor.sum().item()\n",
1010
+ " num_targets = len(targets)\n",
1011
+ "\n",
1012
+ " if num_targets > 0:\n",
1013
+ " hit_rate = num_hits / num_targets\n",
1014
+ " recall = hit_rate\n",
1015
+ " precision = num_hits / (k * num_targets)\n",
1016
+ " else:\n",
1017
+ " hit_rate, recall, precision = 0.0, 0.0, 0.0\n",
1018
+ "\n",
1019
+ " self.log('test_hitrate@10', hit_rate, prog_bar=True)\n",
1020
+ " self.log('test_precision@10', precision, prog_bar=True)\n",
1021
+ " self.log('test_recall@10', recall, prog_bar=True)\n",
1022
+ "\n",
1023
+ " self.test_step_outputs.clear() # Free up memory\n",
1024
+ "\n",
1025
+ " def configure_optimizers(self):\n",
1026
+ " \"\"\"\n",
1027
+ " Configures the optimizer and learning rate scheduler.\n",
1028
+ " \n",
1029
+ " Uses AdamW optimizer and a linear warmup followed by a cosine decay schedule,\n",
1030
+ " which is a standard practice for training Transformer models.\n",
1031
+ " \"\"\"\n",
1032
+ " optimizer = torch.optim.AdamW(\n",
1033
+ " self.parameters(),\n",
1034
+ " lr=self.hparams.learning_rate,\n",
1035
+ " weight_decay=self.hparams.weight_decay\n",
1036
+ " )\n",
1037
+ " \n",
1038
+ " # Learning rate scheduler: linear warmup and cosine decay\n",
1039
+ " def lr_lambda(current_step: int):\n",
1040
+ " warmup_steps = self.hparams.warmup_steps\n",
1041
+ " max_steps = self.hparams.max_steps\n",
1042
+ " if current_step < warmup_steps:\n",
1043
+ " return float(current_step) / float(max(1, warmup_steps))\n",
1044
+ " progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))\n",
1045
+ " return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
1046
+ "\n",
1047
+ " scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n",
1048
+ "\n",
1049
+ " return {\n",
1050
+ " \"optimizer\": optimizer,\n",
1051
+ " \"lr_scheduler\": {\n",
1052
+ " \"scheduler\": scheduler,\n",
1053
+ " \"interval\": \"step\", # Update the scheduler at every training step\n",
1054
+ " \"frequency\": 1\n",
1055
+ " }\n",
1056
+ " }"
1057
+ ]
1058
+ },
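The warmup-plus-cosine schedule in `configure_optimizers` can be checked standalone. This reproduces the `lr_lambda` closure with its default `warmup_steps=2000` and `max_steps=100000`; the returned value is the multiplier applied to the base learning rate:

```python
import math

def lr_lambda(step, warmup_steps=2000, max_steps=100000):
    # Linear warmup to 1.0, then cosine decay to 0.0
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

print(lr_lambda(0))        # 0.0 (start of warmup)
print(lr_lambda(1000))     # 0.5 (halfway through warmup)
print(lr_lambda(2000))     # 1.0 (warmup finished, peak learning rate)
print(lr_lambda(100000))   # 0.0 (fully decayed)
```

Because the scheduler is configured with `"interval": "step"`, this multiplier is recomputed at every optimizer step rather than once per epoch.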
1059
+ {
1060
+ "cell_type": "code",
1061
+ "execution_count": null,
1062
+ "id": "20bbc93a",
1063
+ "metadata": {},
1064
+ "outputs": [],
1065
+ "source": [
1066
+ "import pytorch_lightning as pl\n",
1067
+ "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor\n",
1068
+ "from pytorch_lightning.loggers import TensorBoardLogger\n",
1069
+ "import torch\n",
1070
+ "\n",
1071
+ "def train_and_eval_SASRec_model(train_set, validation_set, test_set, checkpoint_dir_path='checkpoints/',\n",
1072
+ " checkpoint_path=None, n_epochs=10, mode='train',\n",
1073
+ " batchsize=256, max_token_len=50, learning_rate=1e-3, hidden_dim=128,\n",
1074
+ " num_heads=2, num_layers=2, dropout=0.2, weight_decay=1e-6):\n",
1075
+ " \"\"\"\n",
1076
+ " Train or evaluate a SASRec sequential recommendation model using PyTorch Lightning.\n",
1077
+ "\n",
1078
+ " This function wraps the entire SASRec pipeline:\n",
1079
+ " - Initializes the SASRecDataModule (handles dataset preprocessing and dataloaders).\n",
1080
+ " - Builds the SASRec Transformer-based model.\n",
1081
+ " - Configures training callbacks (checkpointing, early stopping, LR monitoring).\n",
1082
+ " - Runs either training (`mode='train'`) or evaluation on the test set (`mode='test'`).\n",
1083
+ "\n",
1084
+ "    Parameters\n",
1085
+ " ----------\n",
1086
+ " train_set : pd.DataFrame\n",
1087
+ "        Training interactions dataset.\n",
1088
+ " validation_set : pd.DataFrame\n",
1089
+ " Validation dataset with the same structure as `train_set`.\n",
1090
+ " test_set : pd.DataFrame\n",
1091
+ " Test dataset with the same structure as `train_set`.\n",
1092
+ " checkpoint_dir_path : str, optional (default='checkpoints/')\n",
1093
+ " Directory to save model checkpoints.\n",
1094
+ " checkpoint_path : str or None, optional (default=None)\n",
1095
+ " Path to a checkpoint file for resuming training or loading a pretrained model for testing.\n",
1096
+ " n_epochs : int, optional (default=10)\n",
1097
+ " Number of training epochs.\n",
1098
+ " mode : {'train', 'test'}, optional (default='train')\n",
1099
+ " - `'train'`: trains the model on the training/validation data.\n",
1100
+ " - `'test'`: evaluates the model on the test set using a checkpoint.\n",
1101
+ " batchsize : int, optional (default=256)\n",
1102
+ " Batch size for training and evaluation.\n",
1103
+ " max_token_len : int, optional (default=50)\n",
1104
+ " Maximum sequence length per user (recent interactions kept).\n",
1105
+ " learning_rate : float, optional (default=1e-3)\n",
1106
+ " Learning rate for the AdamW optimizer.\n",
1107
+ " hidden_dim : int, optional (default=128)\n",
1108
+ " Dimensionality of item and positional embeddings.\n",
1109
+ " num_heads : int, optional (default=2)\n",
1110
+ " Number of attention heads in each Transformer encoder layer.\n",
1111
+ " num_layers : int, optional (default=2)\n",
1112
+ " Number of Transformer encoder layers.\n",
1113
+ " dropout : float, optional (default=0.2)\n",
1114
+ " Dropout probability applied in embeddings and Transformer layers.\n",
1115
+ " weight_decay : float, optional (default=1e-6)\n",
1116
+ " Weight decay regularization coefficient for AdamW.\n",
1117
+ " \"\"\"\n",
1118
+ " # --- 1. Initialize DataModule ---\n",
1119
+ " print(\"Initializing DataModule...\")\n",
1120
+ " datamodule = SASRecDataModule(\n",
1121
+ " train_df=train_set,\n",
1122
+ " val_df=validation_set,\n",
1123
+ " test_df=test_set,\n",
1124
+ " batch_size=batchsize,\n",
1125
+ " max_len=max_token_len\n",
1126
+ " )\n",
1127
+ " datamodule.setup()\n",
1128
+ "\n",
1129
+ " # --- 2. Initialize Model ---\n",
1130
+ " print(\"Initializing SASRec model...\")\n",
1131
+ " model = SASRec(\n",
1132
+ " vocab_size=datamodule.vocab_size,\n",
1133
+ " max_len=max_token_len,\n",
1134
+ " hidden_dim=hidden_dim,\n",
1135
+ " num_heads=num_heads,\n",
1136
+ " num_layers=num_layers,\n",
1137
+ " dropout=dropout,\n",
1138
+ " learning_rate=learning_rate,\n",
1139
+ " weight_decay=weight_decay\n",
1140
+ " )\n",
1141
+ "\n",
1142
+ " # --- 3. Configure Training Callbacks ---\n",
1143
+ " checkpoint_callback = ModelCheckpoint(\n",
1144
+ " dirpath=checkpoint_dir_path,\n",
1145
+ " filename=\"sasrec-{epoch:02d}-{val_hitrate@10:.4f}\",\n",
1146
+ " save_top_k=1,\n",
1147
+ " verbose=True,\n",
1148
+ " monitor=\"val_hitrate@10\",\n",
1149
+ " mode=\"max\"\n",
1150
+ " )\n",
1151
+ "\n",
1152
+ " early_stopping_callback = EarlyStopping(\n",
1153
+ " monitor=\"val_hitrate@10\", # stop if ranking metric stagnates\n",
1154
+ " patience=5,\n",
1155
+ " mode=\"max\"\n",
1156
+ " )\n",
1157
+ "\n",
1158
+ " lr_monitor = LearningRateMonitor(logging_interval=\"step\")\n",
1159
+ "\n",
1160
+ " logger = TensorBoardLogger(\"lightning_logs\", name=\"sasrec\")\n",
1161
+ "\n",
1162
+ " # --- 4. Initialize Trainer ---\n",
1163
+ " print(\"Initializing PyTorch Lightning Trainer...\")\n",
1164
+ " trainer = pl.Trainer(\n",
1165
+ " logger=logger,\n",
1166
+ " callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],\n",
1167
+ " max_epochs=n_epochs,\n",
1168
+ " accelerator='auto',\n",
1169
+ " devices=1,\n",
1170
+ " gradient_clip_val=1.0, # helps with exploding gradients\n",
1171
+ " )\n",
1172
+ "\n",
1173
+ " if mode == 'train':\n",
1174
+ " # --- 5. Start Training ---\n",
1175
+ " print(f\"Starting training for up to {n_epochs} epochs...\")\n",
1176
+ " trainer.fit(model, datamodule,\n",
1177
+ " ckpt_path=checkpoint_path\n",
1178
+ " )\n",
1179
+ "\n",
1180
+ " elif mode == 'test':\n",
1181
+ " # --- 6. Test on best checkpoint ---\n",
1182
+ " print(\"Evaluating on test set...\")\n",
1183
+ " trainer.test(model, datamodule,\n",
1184
+ " ckpt_path=checkpoint_path\n",
1185
+ " )\n"
1186
+ ]
1187
+ },
1188
+ {
1189
+ "cell_type": "code",
1190
+ "execution_count": null,
1191
+ "id": "5d4a2a7b",
1192
+ "metadata": {},
1193
+ "outputs": [],
1194
+ "source": [
1195
+ "# --- Configuration ---\n",
1196
+ "BATCH_SIZE = 256\n",
1197
+ "MAX_TOKEN_LEN = 50 # 50–100 is standard\n",
1198
+ "LEARNING_RATE = 1e-3\n",
1199
+ "HIDDEN_DIM = 128\n",
1200
+ "NUM_HEADS = 2\n",
1201
+ "NUM_LAYERS = 2\n",
1202
+ "DROPOUT = 0.2\n",
1203
+ "WEIGHT_DECAY = 1e-6\n",
1204
+ "N_EPOCHS = 50\n",
1205
+ "MODE = 'train' # 'train' or 'test'\n",
1206
+ "\n",
1207
+ "# Train and evaluate SASRec model\n",
1208
+ "print(\"\\n>>> Training and evaluating SASRec model <<<\")\n",
1209
+ "train_and_eval_SASRec_model(train_set, validation_set, test_set, batchsize=BATCH_SIZE, max_token_len=MAX_TOKEN_LEN, learning_rate=LEARNING_RATE, hidden_dim=HIDDEN_DIM, num_heads=NUM_HEADS, num_layers=NUM_LAYERS, dropout=DROPOUT, weight_decay=WEIGHT_DECAY, n_epochs=N_EPOCHS, mode=MODE)\n",
1210
+ "\n",
1211
+ "print(\"\\n>>> Evaluating trained SASRec model on TEST set <<<\")\n",
1212
+ "train_and_eval_SASRec_model(train_set, validation_set, test_set, mode='test')  # pass checkpoint_path= to evaluate a saved checkpoint"
1213
+ ]
1214
+ },
1215
+ {
1216
+ "cell_type": "markdown",
1217
+ "id": "468e0951",
1218
+ "metadata": {},
1219
+ "source": [
1220
+ "## Main function to run the complete Recommender System"
1221
+ ]
1222
+ },
1223
+ {
1224
+ "cell_type": "code",
1225
+ "execution_count": null,
1226
+ "id": "8f810e9a",
1227
+ "metadata": {},
1228
+ "outputs": [],
1229
+ "source": [
1230
+ "def load_item_properties(data_folder='data/'):\n",
1231
+ " \"\"\"\n",
1232
+ " Loads item properties and creates a mapping from item ID to its category ID.\n",
1233
+ " Handles both a single properties file or two split parts.\n",
1234
+ " \n",
1235
+ " Args:\n",
1236
+ " data_folder (str): The path to the folder containing item property files.\n",
1237
+ "\n",
1238
+ " Returns:\n",
1239
+ " dict: A dictionary mapping {itemid: categoryid}.\n",
1240
+ " \"\"\"\n",
1241
+ " print(\"Loading item properties...\")\n",
1242
+ " try:\n",
1243
+ " # First, try to load the two separate parts and combine them.\n",
1244
+ " props_df_part1 = pd.read_csv(data_folder + 'item_properties_part1.csv')\n",
1245
+ " props_df_part2 = pd.read_csv(data_folder + 'item_properties_part2.csv')\n",
1246
+ " props_df = pd.concat([props_df_part1, props_df_part2], ignore_index=True)\n",
1247
+ " print(\"Successfully loaded and combined item_properties_part1.csv and item_properties_part2.csv.\")\n",
1248
+ "\n",
1249
+ " except FileNotFoundError:\n",
1250
+ " try:\n",
1251
+ " # If the parts are not found, try to load a single combined file.\n",
1252
+ " props_df = pd.read_csv(data_folder + 'item_properties.csv')\n",
1253
+ " print(\"Successfully loaded a single item_properties.csv.\")\n",
1254
+ " except FileNotFoundError:\n",
1255
+ " print(\"Warning: No item properties files found. Cannot display category information.\")\n",
1256
+ " return {}\n",
1257
+ "\n",
1258
+ " category_df = props_df[props_df['property'] == 'categoryid'].copy()\n",
1259
+ " category_df['value'] = pd.to_numeric(category_df['value'], errors='coerce').astype('Int64')\n",
1260
+ " item_to_category_map = category_df.set_index('itemid')['value'].to_dict()\n",
1261
+ " print(\"Item to category mapping created successfully.\")\n",
1262
+ " return item_to_category_map\n",
1263
+ "\n",
1264
+ "def load_category_tree(data_folder='data/'):\n",
1265
+ " \"\"\"\n",
1266
+ " Loads the category tree to map categories to their parent categories.\n",
1267
+ "\n",
1268
+ " Args:\n",
1269
+ " data_folder (str): The path to the folder containing category_tree.csv.\n",
1270
+ "\n",
1271
+ " Returns:\n",
1272
+ " dict: A dictionary mapping {categoryid: parentid}.\n",
1273
+ " \"\"\"\n",
1274
+ " print(\"Loading category tree...\")\n",
1275
+ " try:\n",
1276
+ " tree_df = pd.read_csv(data_folder + 'category_tree.csv')\n",
1277
+ " category_parent_map = tree_df.set_index('categoryid')['parentid'].to_dict()\n",
1278
+ " print(\"Category tree loaded successfully.\")\n",
1279
+ " return category_parent_map\n",
1280
+ " except FileNotFoundError:\n",
1281
+ " print(\"Warning: 'category_tree.csv' not found. Cannot display parent category information.\")\n",
1282
+ " return {}\n",
1283
+ "\n",
1284
+ "def get_popular_items(train_df, k=10):\n",
1285
+ " \"\"\"\n",
1286
+ " Calculates the top-k most popular items based on transaction count.\n",
1287
+ " \"\"\"\n",
1288
+ " purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()\n",
1289
+ " return purchase_counts.head(k).index.tolist()\n",
1290
+ "\n",
1291
+ "def show_user_recommendations(visitor_id, model, datamodule, popular_items, item_category_map, category_parent_map, k=10):\n",
1292
+ " \"\"\"\n",
1293
+ " Displays recommendations for a user, including category and parent category information.\n",
1294
+ " \"\"\"\n",
1295
+ " print(f\"\\n--- Recommendations for Visitor ID: {visitor_id} ---\")\n",
1296
+ " model.eval()\n",
1297
+ "\n",
1298
+ " def format_item_with_category(item_id):\n",
1299
+ " category_id = item_category_map.get(item_id, 'N/A')\n",
1300
+ " parent_id = category_parent_map.get(category_id, 'N/A') if category_id != 'N/A' else 'N/A'\n",
1301
+ " return f\"Item: {item_id} (Category: {category_id}, Parent: {parent_id})\"\n",
1302
+ "\n",
1303
+ " user_history_ids = datamodule.user_history.get(visitor_id)\n",
1304
+ "\n",
1305
+ " if user_history_ids is None:\n",
1306
+ " print(f\"User {visitor_id} not found in training history. Providing popularity-based recommendations.\")\n",
1307
+ " print(f\"\\nTop {k} Popular Items (Fallback):\")\n",
1308
+ " recs_with_cats = [format_item_with_category(item_id) for item_id in popular_items]\n",
1309
+ " print(recs_with_cats)\n",
1310
+ " print(\"-------------------------------------------------\")\n",
1311
+ " return\n",
1312
+ "\n",
1313
+ " history_with_cats = [format_item_with_category(item_id) for item_id in user_history_ids]\n",
1314
+ " print(f\"User's Historical Interactions:\")\n",
1315
+ " print(history_with_cats)\n",
1316
+ "\n",
1317
+ " history_indices = [datamodule.item_map[i] for i in user_history_ids if i in datamodule.item_map]\n",
1318
+ " if not history_indices:\n",
1319
+ " print(\"None of the user's historical items are in the model's vocabulary.\")\n",
1320
+ " return\n",
1321
+ "\n",
1322
+ " max_len = datamodule.max_len\n",
1323
+ " input_seq = history_indices[-max_len:]\n",
1324
+ " padded_input = np.zeros(max_len, dtype=np.int64)\n",
1325
+ " padded_input[-len(input_seq):] = input_seq\n",
1326
+ " \n",
1327
+ " input_tensor = torch.LongTensor(np.array([padded_input]))\n",
1328
+ " input_tensor = input_tensor.to(model.device)\n",
1329
+ "\n",
1330
+ " with torch.no_grad():\n",
1331
+ " logits = model(input_tensor)\n",
1332
+ " last_item_logits = logits[0, -1, :]\n",
1333
+ " top_indices = torch.topk(last_item_logits, k).indices.tolist()\n",
1334
+ "\n",
1335
+ " recommended_item_ids = [datamodule.inverse_item_map[idx] for idx in top_indices if idx in datamodule.inverse_item_map]\n",
1336
+ "\n",
1337
+ " print(f\"\\nTop {k} Recommended Items:\")\n",
1338
+ " recs_with_cats = [format_item_with_category(item_id) for item_id in recommended_item_ids]\n",
1339
+ " print(recs_with_cats)\n",
1340
+ " print(\"-------------------------------------------------\")\n"
1341
+ ]
1342
+ },
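The cell above falls back to `get_popular_items` for cold-start users: items are ranked purely by transaction count. A self-contained toy example of that fallback (the event log below is made up for illustration):

```python
import pandas as pd

# Hypothetical miniature event log in the RetailRocket schema
events = pd.DataFrame({
    'itemid': [1, 2, 2, 3, 3, 3, 1],
    'event':  ['view', 'transaction', 'transaction', 'transaction',
               'transaction', 'transaction', 'transaction'],
})

def get_popular_items(train_df, k=10):
    """Top-k items by number of 'transaction' events (views are ignored)."""
    purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()
    return purchase_counts.head(k).index.tolist()

print(get_popular_items(events, k=2))  # item 3 has 3 purchases, item 2 has 2
```

Because `value_counts` sorts descending, the head of the series is already the popularity ranking.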
1343
+ {
1344
+ "cell_type": "code",
1345
+ "execution_count": null,
1346
+ "id": "735d0f8d",
1347
+ "metadata": {},
1348
+ "outputs": [],
1349
+ "source": [
1350
+ "def main(checkpoint_path=\"checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt\", data_folder=\"data/\"):\n",
1351
+ " \"\"\"\n",
1352
+ " Main function to run the inference and qualitative analysis pipeline.\n",
1353
+ " \"\"\"\n",
1354
+ "\n",
1355
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1356
+ " print(f\"Using device: {device}\")\n",
1357
+ "\n",
1358
+ " print(\"Loading model from checkpoint...\")\n",
1359
+ " best_model = SASRec.load_from_checkpoint(checkpoint_path)\n",
1360
+ " best_model.to(device)\n",
1361
+ "\n",
1362
+ " print(\"Preparing data...\")\n",
1363
+ " train_set, validation_set, test_set = prepare_data(data_folder=data_folder)\n",
1364
+ " \n",
1365
+ " datamodule = SASRecDataModule(train_set, validation_set, test_set)\n",
1366
+ " datamodule.setup()\n",
1367
+ " \n",
1368
+ " item_category_map = load_item_properties(data_folder=data_folder)\n",
1369
+ " category_parent_map = load_category_tree(data_folder=data_folder)\n",
1370
+ " \n",
1371
+ " print(\"\\nCalculating popular items for cold-start users...\")\n",
1372
+ " popular_items_list = get_popular_items(train_set, k=10)\n",
1373
+ "\n",
1374
+ " users_in_train_history = set(datamodule.user_history.keys())\n",
1375
+ " users_in_test_set = set(datamodule.test_df['visitorid'].unique())\n",
1376
+ " valid_example_users = list(users_in_train_history.intersection(users_in_test_set))\n",
1377
+ "\n",
1378
+ " print(f\"\\nFound {len(valid_example_users)} users for qualitative analysis.\")\n",
1379
+ " \n",
1380
+ " for user_id in valid_example_users[:3]:\n",
1381
+ " show_user_recommendations(user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)\n",
1382
+ " \n",
1383
+ " new_user_id = -999\n",
1384
+ " show_user_recommendations(new_user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)\n"
1385
+ ]
1386
+ },
1387
+ {
1388
+ "cell_type": "code",
1389
+ "execution_count": null,
1390
+ "id": "0e7ba5f2",
1391
+ "metadata": {},
1392
+ "outputs": [],
1393
+ "source": [
1394
+ "main()"
1395
+ ]
1396
+ }
1397
+ ],
1398
+ "metadata": {
1399
+ "kernelspec": {
1400
+ "display_name": "Python 3",
1401
+ "language": "python",
1402
+ "name": "python3"
1403
+ },
1404
+ "language_info": {
1405
+ "codemirror_mode": {
1406
+ "name": "ipython",
1407
+ "version": 3
1408
+ },
1409
+ "file_extension": ".py",
1410
+ "mimetype": "text/x-python",
1411
+ "name": "python",
1412
+ "nbconvert_exporter": "python",
1413
+ "pygments_lexer": "ipython3",
1414
+ "version": "3.10.6"
1415
+ }
1416
+ },
1417
+ "nbformat": 4,
1418
+ "nbformat_minor": 5
1419
+ }
requirements.txt ADDED
Binary file (304 Bytes). View file
 
scripts/als_optuna_study.py ADDED
@@ -0,0 +1,53 @@
1
+ import optuna
2
+ import numpy as np
3
+ import math
4
+ from scipy.sparse import csr_matrix
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ import implicit
7
+ from utils import prepare_ground_truth, calculate_metrics
8
+ from models import recommend_als_and_evaluate
9
+ from data_prepare import prepare_data
10
+
11
+ def objective_als(trial, train_df, val_df):
12
+ """
13
+ The objective function for Optuna to optimize.
14
+ """
15
+ # 1. Define the hyperparameter search space
16
+ params = {
17
+ 'factors': trial.suggest_int('factors', 20, 200),
18
+ 'regularization': trial.suggest_float('regularization', 1e-3, 1e-1, log=True),
19
+ 'iterations': trial.suggest_int('iterations', 10, 50)
20
+ }
21
+
22
+ # 2. Run an evaluation with the suggested parameters
23
+ metrics = recommend_als_and_evaluate(train_df, val_df, **params)
24
+
25
+ # 3. Return the metric we want to maximize (precision)
26
+ return metrics['mean_precision@k']
27
+
28
+ def tune_als_hyperparameters(train_df, val_df, n_trials=25):
29
+ """
30
+ Orchestrates the Optuna study to find the best hyperparameters for ALS.
31
+ """
32
+ study = optuna.create_study(direction='maximize')
33
+ study.optimize(lambda trial: objective_als(trial, train_df, val_df), n_trials=n_trials)
34
+
35
+ print("\n--- Optuna Study Complete ---")
36
+ print(f"Number of finished trials: {len(study.trials)}")
37
+ print("Best trial:")
38
+ trial = study.best_trial
39
+ print(f" Value (Precision@10): {trial.value}")
40
+ print(" Params: ")
41
+ for key, value in trial.params.items():
42
+ print(f" {key}: {value}")
43
+
44
+ return trial.params
45
+
46
+ if __name__ == "__main__":
47
+ # 1. Prepare all data
48
+ train_set, validation_set, test_set = prepare_data()
49
+
50
+ # --- Hyperparameter Tuning Step ---
51
+ print("\n>>> 1. TUNING ALS Hyperparameters on the VALIDATION set <<<")
52
+
53
+ best_als_params = tune_als_hyperparameters(train_set, validation_set, n_trials=25)
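The study above maximizes `mean_precision@k` by repeatedly sampling hyperparameters from the search space and keeping the best trial. Stripped of Optuna, that sample–evaluate–keep-best loop can be sketched as plain random search; the `objective` below is a toy stand-in, not the real `recommend_als_and_evaluate` (which needs the dataset):

```python
import math
import random

random.seed(0)

def objective(params):
    # Toy stand-in for recommend_als_and_evaluate: scores best near
    # factors=128 and regularization=1e-2 (both peak values are arbitrary).
    return (-abs(params['factors'] - 128) / 200.0
            - abs(math.log10(params['regularization']) + 2.0))

def random_search(n_trials=50):
    best_score, best_params = -math.inf, None
    for _ in range(n_trials):
        params = {
            'factors': random.randint(20, 200),
            'regularization': 10 ** random.uniform(-3, -1),  # log-uniform, like suggest_float(..., log=True)
            'iterations': random.randint(10, 50),
        }
        score = objective(params)
        if score > best_score:
            best_score, best_params = score, params
    return best_score, best_params

best_score, best_params = random_search()
print(best_params)
```

Optuna's default TPE sampler improves on this by modeling which regions of the space look promising, but the overall trial loop is the same.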
scripts/app.py ADDED
@@ -0,0 +1,162 @@
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from datetime import datetime, timedelta
6
+ from models import SASRec
7
+ from data_prepare import SASRecDataset, SASRecDataModule, prepare_data
8
+ from utils import load_item_properties, load_category_tree, get_popular_items
9
+
10
+ # --- Global variables to hold loaded artifacts ---
11
+ # This prevents reloading the model and data on every prediction.
12
+ MODEL = None
13
+ DATAMODULE = None
14
+ ITEM_CATEGORY_MAP = None
15
+ CATEGORY_PARENT_MAP = None
16
+ POPULAR_ITEMS = None
17
+
18
+ # --- Data Loading and Preparation Functions ---
19
+
20
+ def load_artifacts():
21
+ """
22
+ Loads all necessary artifacts (model, data, mappings) into global variables.
23
+ This function is called only once when the app starts.
24
+ """
25
+ global MODEL, DATAMODULE, ITEM_CATEGORY_MAP, CATEGORY_PARENT_MAP, POPULAR_ITEMS
26
+
27
+ print("--- Loading all artifacts for the Gradio app ---")
28
+
29
+ # HF-FRIENDLY: Path is relative, assuming the checkpoint is in the root of the Space repo.
30
+ CHECKPOINT_PATH = "sasrec-epoch=05-val_hitrate@10=0.3614.ckpt"
31
+ DATA_FOLDER = "data/"
32
+
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ print(f"Using device: {device}")
35
+
36
+ print(f"Loading model from checkpoint: {CHECKPOINT_PATH}...")
37
+ MODEL = SASRec.load_from_checkpoint(CHECKPOINT_PATH)
38
+ MODEL.to(device)
39
+ MODEL.eval()
40
+
41
+ print("Preparing data...")
42
+ train_set, validation_set, test_set = prepare_data(data_folder=DATA_FOLDER)
43
+
44
+ DATAMODULE = SASRecDataModule(train_set, validation_set, test_set)
45
+ DATAMODULE.setup()
46
+
47
+ print("Loading item and category maps...")
48
+ ITEM_CATEGORY_MAP = load_item_properties(data_folder=DATA_FOLDER)
49
+ CATEGORY_PARENT_MAP = load_category_tree(data_folder=DATA_FOLDER)
50
+
51
+ print("Calculating popular items for cold-start users...")
52
+ POPULAR_ITEMS = get_popular_items(train_set, k=10)
53
+
54
+ print("--- Artifacts loaded successfully. Ready to serve recommendations. ---")
55
+
56
+ def get_recommendations(visitor_id_str):
57
+ """
58
+ The main prediction function for the Gradio interface.
59
+
60
+ Takes a visitor ID string, gets recommendations, and formats them for display.
61
+ """
62
+ try:
63
+ visitor_id = int(visitor_id_str)
64
+ except (ValueError, TypeError):
65
+ return pd.DataFrame(), pd.DataFrame(), "Please enter a valid numerical Visitor ID."
66
+
67
+ user_history_ids = DATAMODULE.user_history.get(visitor_id)
68
+
69
+ def format_to_df(item_list):
70
+ data = []
71
+ for rank, item_id in enumerate(item_list, 1):
72
+ category_id = ITEM_CATEGORY_MAP.get(item_id, 'N/A')
73
+ parent_id = CATEGORY_PARENT_MAP.get(category_id, 'N/A') if pd.notna(category_id) else 'N/A'
74
+ data.append([rank, item_id, category_id, parent_id])
75
+ return pd.DataFrame(data, columns=['Rank', 'Item ID', 'Category ID', 'Parent ID'])
76
+
77
+ # --- Cold-Start User (Fallback to Popularity) ---
78
+ if user_history_ids is None:
79
+ history_df = pd.DataFrame(columns=['Rank', 'Item ID', 'Category ID', 'Parent ID'])
80
+ recs_df = format_to_df(POPULAR_ITEMS)
81
+ message = f"User {visitor_id} is new. Showing Top 10 popular items as a fallback."
82
+ return history_df, recs_df, message
83
+
84
+ # --- Existing User (Use SASRec Model) ---
85
+ history_df = format_to_df(user_history_ids)
86
+
87
+ history_indices = [DATAMODULE.item_map[i] for i in user_history_ids if i in DATAMODULE.item_map]
88
+
89
+ if not history_indices:
90
+ message = "None of this user's historical items are in the model's vocabulary."
91
+ return history_df, pd.DataFrame(), message
92
+
93
+ max_len = DATAMODULE.max_len
94
+ input_seq = history_indices[-max_len:]
95
+ padded_input = np.zeros(max_len, dtype=np.int64)
96
+ padded_input[-len(input_seq):] = input_seq
97
+
98
+ input_tensor = torch.LongTensor(np.array([padded_input]))
99
+ input_tensor = input_tensor.to(MODEL.device)
100
+
101
+ with torch.no_grad():
102
+ logits = MODEL(input_tensor)
103
+ last_item_logits = logits[0, -1, :]
104
+ top_indices = torch.topk(last_item_logits, 10).indices.tolist()
105
+
106
+ recommended_item_ids = [DATAMODULE.inverse_item_map[idx] for idx in top_indices if idx in DATAMODULE.inverse_item_map]
107
+ recs_df = format_to_df(recommended_item_ids)
108
+ message = f"Showing personalized SASRec recommendations for user {visitor_id}."
109
+
110
+ return history_df, recs_df, message
111
+
112
+ # --- Main Execution Block ---
113
+ if __name__ == "__main__":
114
+ # Load all artifacts once at startup
115
+ load_artifacts()
116
+
117
+ # Find some valid example users to show in the UI
118
+ users_in_train_history = set(DATAMODULE.user_history.keys())
119
+ users_in_test_set = set(DATAMODULE.test_df['visitorid'].unique())
120
+ valid_example_users = list(users_in_train_history.intersection(users_in_test_set))
121
+
122
+ # Convert numpy types to standard Python int for Gradio compatibility
123
+ example_list = [int(u) for u in valid_example_users[:4]] + [-999]
124
+
125
+ # Create and launch the Gradio interface
126
+ with gr.Blocks(theme=gr.themes.Soft(), title="SASRec Recommender") as iface:
127
+ gr.Markdown(
128
+ """
129
+ # SASRec Sequential Recommender System
130
+ An interactive demo of a state-of-the-art recommender system trained on the RetailRocket dataset.
131
+ """
132
+ )
133
+ with gr.Row():
134
+ with gr.Column(scale=1):
135
+ visitor_id_input = gr.Number(
136
+ label="Enter Visitor ID",
137
+ info="Enter a user's numerical ID to get recommendations."
138
+ )
139
+ submit_button = gr.Button("Get Recommendations", variant="primary")
140
+ gr.Examples(
141
+ examples=example_list,
142
+ inputs=visitor_id_input,
143
+ label="Example User IDs (Click to try)"
144
+ )
145
+
146
+ with gr.Column(scale=3):
147
+ status_message = gr.Textbox(label="Status", interactive=False)
148
+ with gr.Tabs():
149
+ with gr.TabItem("Top 10 Recommendations"):
150
+ recs_output = gr.DataFrame(label="Recommended Items")
151
+ with gr.TabItem("User's Recent History"):
152
+ history_output = gr.DataFrame(label="Interaction History")
153
+
154
+ submit_button.click(
155
+ fn=get_recommendations,
156
+ inputs=visitor_id_input,
157
+ outputs=[history_output, recs_output, status_message]
158
+ )
159
+
160
+ # For local testing, this creates a shareable link.
161
+ # On Hugging Face Spaces, this is not strictly necessary but doesn't hurt.
162
+ iface.launch(share=True)
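`format_to_df` above enriches each recommended item with category metadata before display. A self-contained sketch of that formatting with hypothetical ID maps (in the app, the real maps come from `load_item_properties` and `load_category_tree`):

```python
import pandas as pd

ITEM_CATEGORY_MAP = {101: 7, 102: 7, 103: 9}  # hypothetical itemid -> categoryid
CATEGORY_PARENT_MAP = {7: 1, 9: 2}            # hypothetical categoryid -> parentid

def format_to_df(item_list):
    """Rank items and attach category/parent IDs, falling back to 'N/A' for unknown items."""
    rows = []
    for rank, item_id in enumerate(item_list, 1):
        category_id = ITEM_CATEGORY_MAP.get(item_id, 'N/A')
        parent_id = CATEGORY_PARENT_MAP.get(category_id, 'N/A')
        rows.append([rank, item_id, category_id, parent_id])
    return pd.DataFrame(rows, columns=['Rank', 'Item ID', 'Category ID', 'Parent ID'])

df = format_to_df([102, 103, 999])  # 999 is unseen, so both category columns fall back to 'N/A'
print(df)
```

Returning a DataFrame (rather than strings) is what lets Gradio render the result directly in a `gr.DataFrame` component.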
scripts/data_prepare.py ADDED
@@ -0,0 +1,253 @@
1
+ import zipfile
2
+ import pandas as pd
3
+ from datetime import datetime, timedelta
4
+ import numpy as np
5
+ from scipy.sparse import csr_matrix
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import Dataset, DataLoader
10
+ import pytorch_lightning as pl
11
+
12
+ def extract_ziped_data(ziped_data_path: str, extract_path : str):
13
+ """Extracts the contents of a zip file to a specified directory.
14
+
15
+ Args:
16
+ ziped_data_path: str, path to the zip file
17
+ extract_path: str, path to the directory where contents will be extracted
18
+ """
19
+ # Extract the archive into the directory given by `extract_path`
21
+
22
+ # Open the zip file in read mode
23
+ with zipfile.ZipFile(ziped_data_path, 'r') as zip_ref:
24
+ # Extract all the contents into the specified directory
25
+ zip_ref.extractall(extract_path)
26
+
27
+ print(f"'{ziped_data_path}' has been extracted to '{extract_path}'")
28
+
29
+ def prepare_data(data_folder='data/', val_days=7, test_days=7):
30
+ """
31
+ Loads, preprocesses, and splits the events data into train, validation, and test sets.
32
+
33
+ args:
34
+ data_folder: str, path to the folder containing 'events.csv'
35
+ val_days: int, number of days for the validation set
36
+ test_days: int, number of days for the test set
37
+ """
38
+ # --- Load Data ---
39
+ print(f"Loading events.csv from folder: {data_folder}")
40
+ try:
41
+ events_df = pd.read_csv(data_folder + 'events.csv')
42
+ print("Successfully loaded events.csv.")
43
+ events_df['timestamp_dt'] = pd.to_datetime(events_df['timestamp'], unit='ms')
44
+ print("\n--- Initial Data Summary ---")
45
+ print(f"Data shape: {events_df.shape}")
46
+ print(f"Full timeframe: {events_df['timestamp_dt'].min()} to {events_df['timestamp_dt'].max()}")
47
+ print("----------------------------\n")
48
+ except FileNotFoundError:
49
+ print(f"Error: 'events.csv' not found in '{data_folder}'. Please check the path.")
50
+ return None, None, None
51
+
52
+ # --- Split Data ---
53
+ sorted_df = events_df.sort_values('timestamp_dt').reset_index(drop=True)
54
+ print(f"Splitting data: {test_days} days for test, {val_days} days for validation.")
55
+ end_time = sorted_df['timestamp_dt'].max()
56
+ test_start_time = end_time - timedelta(days=test_days)
57
+ val_start_time = test_start_time - timedelta(days=val_days)
58
+
59
+ test_df = sorted_df[sorted_df['timestamp_dt'] >= test_start_time]
60
+ val_df = sorted_df[(sorted_df['timestamp_dt'] >= val_start_time) & (sorted_df['timestamp_dt'] < test_start_time)]
61
+ train_df = sorted_df[sorted_df['timestamp_dt'] < val_start_time]
62
+
63
+ print("--- Data Splitting Summary ---")
64
+ print(f"Training set: {train_df.shape[0]:>8} records | from {train_df['timestamp_dt'].min()} to {train_df['timestamp_dt'].max()}")
65
+ print(f"Validation set: {val_df.shape[0]:>8} records | from {val_df['timestamp_dt'].min()} to {val_df['timestamp_dt'].max()}")
66
+ print(f"Test set: {test_df.shape[0]:>8} records | from {test_df['timestamp_dt'].min()} to {test_df['timestamp_dt'].max()}")
67
+ print("------------------------------")
68
+
69
+ return train_df, val_df, test_df
70
+
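`prepare_data` splits chronologically: the last `test_days` of events form the test set, the `val_days` before them the validation set, and everything earlier the training set. The cutoff arithmetic can be sketched on a synthetic timeline (30 one-a-day events, dates made up):

```python
import pandas as pd
from datetime import timedelta

# Synthetic 30-day event log, one event per day
df = pd.DataFrame({'timestamp_dt': pd.date_range('2015-08-01', periods=30, freq='D')})

end_time = df['timestamp_dt'].max()
test_start = end_time - timedelta(days=7)
val_start = test_start - timedelta(days=7)

test_df = df[df['timestamp_dt'] >= test_start]
val_df = df[(df['timestamp_dt'] >= val_start) & (df['timestamp_dt'] < test_start)]
train_df = df[df['timestamp_dt'] < val_start]

print(len(train_df), len(val_df), len(test_df))
```

A time-based split like this mirrors deployment (train on the past, evaluate on the future) and avoids the leakage a random split would introduce.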
71
+ class SASRecDataset(Dataset):
72
+ """
73
+ SASRec Dataset.
74
+ - Precomputes (sequence_id, cutoff_idx) pairs for O(1) __getitem__.
75
+ - Supports 'last' or 'all' target modes.
76
+ """
77
+ def __init__(self, sequences, max_len, target_mode="last"):
78
+ """
79
+ Args:
80
+ sequences: list of user sequences (list of item IDs).
81
+ max_len: maximum sequence length (padding applied).
82
+ target_mode: 'last' (only last prediction) or 'all' (predict at every step).
83
+ """
84
+ self.sequences = sequences
85
+ self.max_len = max_len
86
+ self.target_mode = target_mode
87
+
88
+ # Build index once
89
+ self.index = []
90
+ for seq_id, seq in enumerate(sequences):
91
+ for i in range(1, len(seq)):
92
+ self.index.append((seq_id, i))
93
+
94
+ def __len__(self):
95
+ return len(self.index)
96
+
97
+ def __getitem__(self, idx):
98
+ seq_id, cutoff = self.index[idx]
99
+ seq = self.sequences[seq_id][:cutoff]
100
+
101
+ # Truncate & pad
102
+ seq = seq[-self.max_len:]
103
+ pad_len = self.max_len - len(seq)
104
+
105
+ input_seq = np.zeros(self.max_len, dtype=np.int64)
106
+ input_seq[pad_len:] = seq
107
+
108
+ if self.target_mode == "last":
109
+ target = self.sequences[seq_id][cutoff]
110
+ return torch.LongTensor(input_seq), torch.LongTensor([target])
111
+
112
+ elif self.target_mode == "all":
113
+ # Predict next item at each step
114
+ target_seq = self.sequences[seq_id][1:cutoff+1]
115
+ target_seq = target_seq[-self.max_len:]
116
+ target = np.zeros(self.max_len, dtype=np.int64)
117
+ target[-len(target_seq):] = target_seq
118
+ return torch.LongTensor(input_seq), torch.LongTensor(target)
119
+
120
+ class SASRecDataModule(pl.LightningDataModule):
121
+ """
122
+ PyTorch Lightning DataModule for preparing the RetailRocket dataset for the SASRec model.
123
+
124
+ This class handles all aspects of data preparation, including:
125
+ - Filtering out infrequent users and items to reduce noise.
126
+ - Building a consistent item vocabulary.
127
+ - Converting user event histories into sequential data.
128
+ - Creating and providing `DataLoader` instances for training, validation, and testing.
129
+ """
130
+ def __init__(self, train_df, val_df, test_df, min_item_interactions=5,
131
+ min_user_interactions=5, max_len=50, batch_size=256):
132
+ """
133
+ Initializes the DataModule.
134
+
135
+ Args:
136
+ train_df (pd.DataFrame): DataFrame for training.
137
+ val_df (pd.DataFrame): DataFrame for validation.
138
+ test_df (pd.DataFrame): DataFrame for testing.
139
+ min_item_interactions (int): Minimum number of interactions for an item to be kept.
140
+ min_user_interactions (int): Minimum number of interactions for a user to be kept.
141
+ max_len (int): The maximum length of a user sequence fed to the model.
142
+ batch_size (int): The batch size for the DataLoaders.
143
+ """
144
+ super().__init__()
145
+ self.train_df = train_df
146
+ self.val_df = val_df
147
+ self.test_df = test_df
148
+ self.min_item_interactions = min_item_interactions
149
+ self.min_user_interactions = min_user_interactions
150
+ self.max_len = max_len
151
+ self.batch_size = batch_size
152
+
153
+ self.item_map = None
154
+ self.inverse_item_map = None
155
+ self.vocab_size = 0
156
+ self.user_history = None
157
+
158
+ def setup(self, stage=None):
159
+ """
160
+ Prepares the data for training, validation, and testing.
161
+
162
+ This method is called automatically by PyTorch Lightning. It performs the following steps:
163
+ 1. Determines filtering criteria (which users and items to keep) based on the training set only
164
+ to prevent data leakage.
165
+ 2. Applies these filters to the train, validation, and test sets.
166
+ 3. Builds an item vocabulary (mapping item IDs to integer indices) from the combined
167
+ training and validation sets to ensure consistency for model checkpointing.
168
+ 4. Converts the event logs into sequences of item indices for each user in each data split.
169
+ """
170
+ item_counts = self.train_df['itemid'].value_counts()
171
+ user_counts = self.train_df['visitorid'].value_counts()
172
+ items_to_keep = item_counts[item_counts >= self.min_item_interactions].index
173
+ users_to_keep = user_counts[user_counts >= self.min_user_interactions].index
174
+
175
+ self.filtered_train_df = self.train_df[
176
+ (self.train_df['itemid'].isin(items_to_keep)) &
177
+ (self.train_df['visitorid'].isin(users_to_keep))
178
+ ].copy()
179
+ self.filtered_val_df = self.val_df[
180
+ (self.val_df['itemid'].isin(items_to_keep)) &
181
+ (self.val_df['visitorid'].isin(users_to_keep))
182
+ ].copy()
183
+         self.filtered_test_df = self.test_df[
+             (self.test_df['itemid'].isin(items_to_keep)) &
+             (self.test_df['visitorid'].isin(users_to_keep))
+         ].copy()
+
+         all_known_items_df = pd.concat([self.filtered_train_df, self.filtered_val_df])
+         unique_items = all_known_items_df['itemid'].unique()
+         self.item_map = {item_id: i + 1 for i, item_id in enumerate(unique_items)}
+         self.inverse_item_map = {i: item_id for item_id, i in self.item_map.items()}
+         self.vocab_size = len(self.item_map) + 1  # +1 for padding token 0
+
+         self.user_history = self.filtered_train_df.groupby('visitorid')['itemid'].apply(list)
+
+         self.train_sequences = self._create_sequences(self.filtered_train_df)
+         self.val_sequences = self._create_sequences(self.filtered_val_df)
+         self.test_sequences = self._create_sequences(self.filtered_test_df)
+
+     def _create_sequences(self, df):
+         """
+         Helper function to convert a DataFrame of events into user interaction sequences.
+
+         Args:
+             df (pd.DataFrame): The input DataFrame to process.
+
+         Returns:
+             list[list[int]]: A list of user sequences, where each sequence is a list of item indices.
+         """
+         df_sorted = df.sort_values(['visitorid', 'timestamp_dt'])
+         sequences = df_sorted.groupby('visitorid')['itemid'].apply(
+             lambda x: [self.item_map[i] for i in x if i in self.item_map]
+         ).tolist()
+         return [s for s in sequences if len(s) > 1]
+
+     def train_dataloader(self):
+         """Creates the DataLoader for the training set."""
+         dataset = SASRecDataset(self.train_sequences, self.max_len)
+         return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)
+
+     def val_dataloader(self):
+         """Creates the DataLoader for the validation set."""
+         dataset = SASRecDataset(self.val_sequences, self.max_len)
+         return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
+
+     def test_dataloader(self):
+         """Creates the DataLoader for the test set."""
+         dataset = SASRecDataset(self.test_sequences, self.max_len)
+         return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
+
+ if __name__ == "__main__":
+
+     # --- Configuration ---
+     DATA_PATH = "data"
+     ZIPED_DATA_PATH = "data/archive.zip"  # change to your zip file path
+     BATCH_SIZE = 256
+     MAX_TOKEN_LEN = 50  # 50–100 is standard for SASRec
+
+     # extract_ziped_data(ZIPED_DATA_PATH, DATA_PATH)  # uncomment this line if you want to extract the data
+
+     # --- 1. Prepare the data into train, validation, and test sets ---
+     train_set, validation_set, test_set = prepare_data(data_folder=DATA_PATH)
+
+     # --- 2. Initialize DataModule ---
+     print("Initializing DataModule...")
+     datamodule = SASRecDataModule(
+         train_df=train_set,
+         val_df=validation_set,
+         test_df=test_set,
+         batch_size=BATCH_SIZE,
+         max_len=MAX_TOKEN_LEN
+     )
+     datamodule.setup()
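The DataModule above sorts each user's events chronologically, maps raw item IDs to contiguous indices (0 reserved for padding), and drops single-item users; the `SASRecDataset` defined earlier in this file then truncates and left-pads each sequence to `max_len`. A minimal pure-Python sketch of that windowing, where the helper names `build_sequences` and `pad_left` are illustrative only, not part of the project:

```python
def build_sequences(events, item_map):
    """Group (user, item) events into per-user index sequences, keeping users with >1 known item."""
    by_user = {}
    for user, item in events:  # events assumed pre-sorted by (user, timestamp)
        if item in item_map:   # unknown items are skipped, mirroring the item_map filter above
            by_user.setdefault(user, []).append(item_map[item])
    return [seq for seq in by_user.values() if len(seq) > 1]

def pad_left(seq, max_len, pad=0):
    """Keep the most recent max_len items and left-pad with the padding token 0."""
    seq = seq[-max_len:]
    return [pad] * (max_len - len(seq)) + seq

events = [("u1", "a"), ("u1", "b"), ("u1", "c"), ("u2", "a"), ("u3", "zzz")]
item_map = {"a": 1, "b": 2, "c": 3}  # index 0 is reserved for padding
print(build_sequences(events, item_map))  # [[1, 2, 3]]  (u2 has one item, u3's item is unknown)
print(pad_left([1, 2, 3], 5))             # [0, 0, 1, 2, 3]
```

Left-padding keeps the most recent interactions adjacent to the prediction position, which is why the model reads only the last time step's logits during training and evaluation.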
scripts/main.py ADDED
@@ -0,0 +1,46 @@
+ import pandas as pd
+ import numpy as np
+ import torch
+ import datetime
+ from models import SASRec
+ from utils import prepare_ground_truth, calculate_metrics, load_item_properties, load_category_tree, get_popular_items, show_user_recommendations
+ from data_prepare import prepare_data, SASRecDataset, SASRecDataModule
+
+ def main(checkpoint_path="checkpoints/sasrec-epoch=06-val_hitrate@10=0.3629.ckpt", data_folder="data/"):
+     """
+     Main function to run the inference and qualitative analysis pipeline.
+     """
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     print("Loading model from checkpoint...")
+     best_model = SASRec.load_from_checkpoint(checkpoint_path)
+     best_model.to(device)
+
+     print("Preparing data...")
+     train_set, validation_set, test_set = prepare_data(data_folder=data_folder)
+
+     datamodule = SASRecDataModule(train_set, validation_set, test_set)
+     datamodule.setup()
+
+     item_category_map = load_item_properties(data_folder=data_folder)
+     category_parent_map = load_category_tree(data_folder=data_folder)
+
+     print("\nCalculating popular items for cold-start users...")
+     popular_items_list = get_popular_items(train_set, k=10)
+
+     users_in_train_history = set(datamodule.user_history.keys())
+     users_in_test_set = set(datamodule.test_df['visitorid'].unique())
+     valid_example_users = list(users_in_train_history.intersection(users_in_test_set))
+
+     print(f"\nFound {len(valid_example_users)} users for qualitative analysis.")
+
+     # Qualitative check: personalized recommendations for a few known users
+     for user_id in valid_example_users[:3]:
+         show_user_recommendations(user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)
+
+     # Cold-start check: an unseen user ID should fall back to popular items
+     new_user_id = -999
+     show_user_recommendations(new_user_id, best_model, datamodule, popular_items_list, item_category_map, category_parent_map)
+
+ if __name__ == "__main__":
+     main()
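The `main()` pipeline above serves personalized recommendations to users with training history and falls back to the popularity list for cold-start users (the `new_user_id = -999` check). The dispatch logic can be sketched in miniature; `get_popular_items_simple` and `recommend` are toy stand-ins (the real SASRec forward pass is replaced here by a trivial filter over the popularity list):

```python
from collections import Counter

def get_popular_items_simple(interactions, k=3):
    """Most frequent items across all interactions (non-personalized baseline)."""
    counts = Counter(item for _, item in interactions)
    return [item for item, _ in counts.most_common(k)]

def recommend(user_id, user_history, popular_items):
    """Personalized path for known users; popularity fallback for cold-start users."""
    if user_id not in user_history:
        return popular_items  # cold start: no history to feed a sequential model
    seen = set(user_history[user_id])
    return [i for i in popular_items if i not in seen]  # stand-in for the model's ranking

interactions = [("u1", "a"), ("u2", "a"), ("u2", "b"), ("u3", "b"), ("u3", "a")]
popular = get_popular_items_simple(interactions, k=2)
print(popular)                                   # ['a', 'b']
print(recommend(-999, {"u1": ["a"]}, popular))   # ['a', 'b']  (fallback path)
print(recommend("u1", {"u1": ["a"]}, popular))   # ['b']       (seen item filtered)
```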
scripts/models.py ADDED
@@ -0,0 +1,408 @@
+ import pandas as pd
+ import numpy as np
+ import math
+ import torch
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ from scipy.sparse import csr_matrix
+ from sklearn.metrics.pairwise import cosine_similarity
+ import implicit
+ from utils import prepare_ground_truth, calculate_metrics
+
+ def recommend_popular_items_and_evaluate(train_df, test_df, k=10, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics):
+     """
+     Trains a non-personalized Popularity model and evaluates its performance.
+
+     This model recommends the top-k most frequently transacted items from the training
+     set to every user. It serves as a simple but strong baseline.
+
+     Args:
+         train_df (pd.DataFrame): The training dataset.
+         test_df (pd.DataFrame): The test dataset for evaluation.
+         k (int): The number of items to recommend.
+         prepare_ground_truth (function): Processes test_df into a ground-truth dict (defaults to the utils implementation).
+         calculate_metrics (function): Computes ranking metrics (defaults to the utils implementation).
+
+     Returns:
+         dict: A dictionary containing the calculated evaluation metrics (e.g., precision, recall).
+     """
+     print(f"\n--- Evaluating Popularity Model (Top {k} items) ---")
+
+     # 1. "Train" the model by finding the most popular items based on transactions
+     purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()
+     popular_items = purchase_counts.head(k).index.tolist()
+     print(f"Top {k} popular items identified from training data.")
+
+     # 2. Evaluate the model
+     ground_truth = prepare_ground_truth(test_df)
+     # Every user receives the same list of popular items
+     recommendations = {user_id: popular_items for user_id in ground_truth.keys()}
+
+     metrics = calculate_metrics(recommendations, ground_truth, k)
+     print("Evaluation complete.")
+     return metrics
+
+ def recommend_item_item_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics):
+     """
+     Trains an Item-Item Collaborative Filtering model and evaluates its performance.
+
+     This model recommends items that are similar to items a user has interacted
+     with in the past, based on co-occurrence patterns in the training data.
+
+     Args:
+         train_df (pd.DataFrame): The training dataset.
+         test_df (pd.DataFrame): The test dataset for evaluation.
+         k (int): The number of items to recommend.
+         min_item_interactions (int): Minimum number of interactions for an item to be kept.
+         min_user_interactions (int): Minimum number of interactions for a user to be kept.
+         prepare_ground_truth (function): Processes test_df into a ground-truth dict (defaults to the utils implementation).
+         calculate_metrics (function): Computes ranking metrics (defaults to the utils implementation).
+
+     Returns:
+         dict: A dictionary containing the calculated evaluation metrics.
+     """
+     print(f"\n--- Evaluating Item-Item CF Model (Top {k} items) ---")
+
+     # 1. Filter out infrequent users and items to reduce noise and computation
+     item_counts = train_df['itemid'].value_counts()
+     user_counts = train_df['visitorid'].value_counts()
+     items_to_keep = item_counts[item_counts >= min_item_interactions].index
+     users_to_keep = user_counts[user_counts >= min_user_interactions].index
+     filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()
+     print(f"Filtered training data from {len(train_df)} to {len(filtered_df)} records.")
+
+     # 2. Create user-item interaction matrix and vocabulary mappings
+     user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}
+     item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}
+     inverse_item_map = {i: iid for iid, i in item_map.items()}
+     user_indices = filtered_df['visitorid'].map(user_map)
+     item_indices = filtered_df['itemid'].map(item_map)
+     user_item_matrix = csr_matrix((np.ones(len(filtered_df)), (user_indices, item_indices)))
+
+     # 3. Calculate the cosine similarity matrix between all items
+     print("Calculating item similarity matrix...")
+     item_similarity_matrix = cosine_similarity(user_item_matrix.T, dense_output=False)
+     print("Similarity matrix calculated.")
+
+     # 4. Generate recommendations and evaluate
+     ground_truth = prepare_ground_truth(test_df)
+     recommendations = {}
+     print("Generating recommendations for users in test set...")
+     test_users = [u for u in ground_truth.keys() if u in user_map]
+
+     for user_id in test_users:
+         user_index = user_map[user_id]
+         user_interactions_indices = user_item_matrix[user_index].indices
+
+         if len(user_interactions_indices) > 0:
+             # Aggregate scores from items the user has interacted with
+             all_scores = np.asarray(item_similarity_matrix[user_interactions_indices].sum(axis=0)).flatten()
+             # Remove already interacted items from recommendations
+             all_scores[user_interactions_indices] = -1
+             top_indices = np.argsort(all_scores)[::-1][:k]
+             recs = [inverse_item_map[idx] for idx in top_indices if idx in inverse_item_map]
+             recommendations[user_id] = recs
+
+     metrics = calculate_metrics(recommendations, ground_truth, k)
+     print("Evaluation complete.")
+     return metrics
+
+ def recommend_als_and_evaluate(train_df, test_df, k=10, min_item_interactions=5, min_user_interactions=5,
+                                factors=25, regularization=0.02, iterations=48, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics):
+     """
+     Trains an Alternating Least Squares (ALS) model and evaluates its performance.
+
+     This model uses matrix factorization to learn latent embeddings for users and
+     items from implicit feedback data. Default hyperparameters are set from a
+     previous Optuna tuning process.
+
+     Args:
+         train_df (pd.DataFrame): The training dataset.
+         test_df (pd.DataFrame): The test dataset for evaluation.
+         k (int): The number of items to recommend.
+         min_item_interactions (int): Minimum number of interactions for an item to be kept.
+         min_user_interactions (int): Minimum number of interactions for a user to be kept.
+         factors (int): The number of latent factors to compute.
+         regularization (float): The regularization factor.
+         iterations (int): The number of ALS iterations to run.
+         prepare_ground_truth (function): Processes test_df into a ground-truth dict (defaults to the utils implementation).
+         calculate_metrics (function): Computes ranking metrics (defaults to the utils implementation).
+
+     Returns:
+         dict: A dictionary containing the calculated evaluation metrics.
+     """
+     print(f"\n--- Evaluating ALS Model (Top {k} items) ---")
+
+     # 1. Filter data
+     item_counts = train_df['itemid'].value_counts()
+     user_counts = train_df['visitorid'].value_counts()
+     items_to_keep = item_counts[item_counts >= min_item_interactions].index
+     users_to_keep = user_counts[user_counts >= min_user_interactions].index
+     filtered_df = train_df[(train_df['itemid'].isin(items_to_keep)) & (train_df['visitorid'].isin(users_to_keep))].copy()
+     print(f"Filtered training data from {len(train_df)} to {len(filtered_df)} records.")
+
+     # 2. Create mappings and confidence matrix
+     user_map = {uid: i for i, uid in enumerate(filtered_df['visitorid'].unique())}
+     item_map = {iid: i for i, iid in enumerate(filtered_df['itemid'].unique())}
+     inverse_item_map = {i: iid for iid, i in item_map.items()}
+     inverse_user_map = {i: uid for uid, i in user_map.items()}
+     user_indices = filtered_df['visitorid'].map(user_map).astype(np.int32)
+     item_indices = filtered_df['itemid'].map(item_map).astype(np.int32)
+
+     event_weights = {'view': 1, 'addtocart': 3, 'transaction': 5}
+     confidence = filtered_df['event'].map(event_weights).astype(np.float32)
+     user_item_matrix = csr_matrix((confidence, (user_indices, item_indices)))
+
+     # 3. Train the ALS model
+     print("Training ALS model...")
+     als_model = implicit.als.AlternatingLeastSquares(factors=factors, regularization=regularization, iterations=iterations)
+     als_model.fit(user_item_matrix)
+     print("ALS model trained.")
+
+     # 4. Generate recommendations and evaluate
+     ground_truth = prepare_ground_truth(test_df)
+     recommendations = {}
+     print("Generating recommendations for users in test set...")
+     test_users_indices = [user_map[u] for u in ground_truth.keys() if u in user_map]
+
+     if test_users_indices:
+         user_item_matrix_for_recs = user_item_matrix[test_users_indices]
+         ids, _ = als_model.recommend(test_users_indices, user_item_matrix_for_recs, N=k)
+
+         for i, user_index in enumerate(test_users_indices):
+             original_user_id = inverse_user_map[user_index]  # O(1) lookup instead of scanning user_map twice
+             recs = [inverse_item_map[item_idx] for item_idx in ids[i] if item_idx in inverse_item_map]
+             recommendations[original_user_id] = recs
+
+     metrics = calculate_metrics(recommendations, ground_truth, k)
+     print("Evaluation complete.")
+     return metrics
+
+ class SASRec(pl.LightningModule):
+     """
+     A PyTorch Lightning implementation of the SASRec model for sequential recommendation.
+
+     SASRec (Self-Attentive Sequential Recommendation) uses a Transformer-based
+     architecture to capture the sequential patterns in a user's interaction history
+     to predict the next item they are likely to interact with.
+
+     Attributes:
+         hparams: All constructor arguments, saved automatically via save_hyperparameters().
+         item_embedding (nn.Embedding): Embedding layer for item IDs.
+         positional_embedding (nn.Embedding): Embedding layer to encode the position of items in a sequence.
+         transformer_encoder (nn.TransformerEncoder): The core self-attention module.
+         fc (nn.Linear): Final fully connected layer to produce logits over the item vocabulary.
+         loss_fn (nn.CrossEntropyLoss): The loss function used for training.
+     """
+     def __init__(self, vocab_size, max_len, hidden_dim, num_heads, num_layers,
+                  dropout=0.2, learning_rate=1e-3, weight_decay=1e-6, warmup_steps=2000, max_steps=100000):
+         """
+         Initializes the SASRec model layers and hyperparameters.
+
+         Args:
+             vocab_size (int): The total number of unique items in the dataset (+1 for padding).
+             max_len (int): The maximum length of the input sequences.
+             hidden_dim (int): The dimensionality of the item and positional embeddings.
+             num_heads (int): The number of attention heads in the Transformer encoder.
+             num_layers (int): The number of layers in the Transformer encoder.
+             dropout (float): The dropout rate to be applied.
+             learning_rate (float): The learning rate for the optimizer.
+             weight_decay (float): The weight decay (L2 penalty) for the optimizer.
+             warmup_steps (int): The number of linear warmup steps for the learning rate scheduler.
+             max_steps (int): The total number of training steps for the learning rate scheduler's decay phase.
+         """
+         super().__init__()
+         # This saves all hyperparameters to self.hparams, making them accessible later
+         self.save_hyperparameters()
+
+         # Embedding layers
+         self.item_embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
+         self.positional_embedding = nn.Embedding(max_len, hidden_dim)
+         self.dropout = nn.Dropout(dropout)
+
+         # Transformer Encoder
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim * 4,
+             dropout=dropout, batch_first=True, activation='gelu'
+         )
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+         # Output layer
+         self.fc = nn.Linear(hidden_dim, vocab_size)
+
+         # Loss function, ignoring the padding token
+         self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)
+
+         # Lists to store outputs from validation and test steps
+         self.validation_step_outputs = []
+         self.test_step_outputs = []
+
+     def forward(self, x):
+         """
+         Defines the forward pass of the model.
+
+         Args:
+             x (torch.Tensor): A batch of input sequences of shape (batch_size, seq_len).
+
+         Returns:
+             torch.Tensor: The output logits of shape (batch_size, seq_len, vocab_size).
+         """
+         seq_len = x.size(1)
+         # Create positional indices (0, 1, 2, ..., seq_len-1)
+         positions = torch.arange(seq_len, device=self.device).unsqueeze(0)
+
+         # Create a causal mask to ensure the model doesn't look ahead in the sequence
+         causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=self.device)
+
+         # Combine item and positional embeddings
+         x = self.item_embedding(x) + self.positional_embedding(positions)
+         x = self.dropout(x)
+
+         # Pass through the Transformer encoder
+         x = self.transformer_encoder(x, mask=causal_mask)
+
+         # Get final logits
+         logits = self.fc(x)
+         return logits
+
+     def training_step(self, batch, batch_idx):
+         """
+         Performs a single training step.
+
+         Args:
+             batch (tuple): A tuple containing input sequences and target items.
+             batch_idx (int): The index of the current batch.
+
+         Returns:
+             torch.Tensor: The calculated loss for the batch.
+         """
+         inputs, targets = batch
+         logits = self.forward(inputs)
+
+         # We only care about the prediction for the very last item in the input sequence
+         last_logits = logits[:, -1, :]
+
+         # Calculate loss against the single target item
+         loss = self.loss_fn(last_logits, targets.squeeze())
+
+         self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """
+         Performs a single validation step.
+         Calculates loss and stores predictions for metric computation at the end of the epoch.
+         """
+         inputs, targets = batch
+         logits = self.forward(inputs)
+         last_item_logits = logits[:, -1, :]
+         loss = self.loss_fn(last_item_logits, targets.squeeze())
+         self.log('val_loss', loss, prog_bar=True, on_epoch=True)
+
+         # Get top-10 predictions for metric calculation
+         top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices
+         self.validation_step_outputs.append({'preds': top_k_preds, 'targets': targets})
+         return loss
+
+     def on_validation_epoch_end(self):
+         """
+         Calculates and logs ranking metrics at the end of the validation epoch.
+         """
+         if not self.validation_step_outputs:
+             return
+
+         # Concatenate all predictions and targets from the epoch
+         preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0)
+         targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0)
+
+         k = preds.size(1)
+         # Check if the target is in the top-k predictions for each example
+         hits_tensor = (preds == targets).any(dim=1)
+         num_hits = hits_tensor.sum().item()
+         num_targets = len(targets)
+
+         if num_targets > 0:
+             hit_rate = num_hits / num_targets
+             recall = hit_rate  # For next-item prediction, recall@k is the same as hit_rate@k
+             precision = num_hits / (k * num_targets)
+         else:
+             hit_rate, recall, precision = 0.0, 0.0, 0.0
+
+         self.log('val_hitrate@10', hit_rate, prog_bar=True)
+         self.log('val_precision@10', precision, prog_bar=True)
+         self.log('val_recall@10', recall, prog_bar=True)
+
+         self.validation_step_outputs.clear()  # Free up memory
+
+     def test_step(self, batch, batch_idx):
+         """
+         Performs a single test step.
+         Mirrors the logic of the validation_step.
+         """
+         inputs, targets = batch
+         logits = self.forward(inputs)
+         last_item_logits = logits[:, -1, :]
+         loss = self.loss_fn(last_item_logits, targets.squeeze())
+         self.log('test_loss', loss, prog_bar=True)
+
+         top_k_preds = torch.topk(last_item_logits, 10, dim=-1).indices
+         self.test_step_outputs.append({'preds': top_k_preds, 'targets': targets})
+         return loss
+
+     def on_test_epoch_end(self):
+         """
+         Calculates and logs ranking metrics at the end of the test epoch.
+         """
+         if not self.test_step_outputs:
+             return
+
+         preds = torch.cat([x['preds'] for x in self.test_step_outputs], dim=0)
+         targets = torch.cat([x['targets'] for x in self.test_step_outputs], dim=0)
+
+         k = preds.size(1)
+         hits_tensor = (preds == targets).any(dim=1)
+         num_hits = hits_tensor.sum().item()
+         num_targets = len(targets)
+
+         if num_targets > 0:
+             hit_rate = num_hits / num_targets
+             recall = hit_rate
+             precision = num_hits / (k * num_targets)
+         else:
+             hit_rate, recall, precision = 0.0, 0.0, 0.0
+
+         self.log('test_hitrate@10', hit_rate, prog_bar=True)
+         self.log('test_precision@10', precision, prog_bar=True)
+         self.log('test_recall@10', recall, prog_bar=True)
+
+         self.test_step_outputs.clear()  # Free up memory
+
+     def configure_optimizers(self):
+         """
+         Configures the optimizer and learning rate scheduler.
+
+         Uses AdamW optimizer and a linear warmup followed by a cosine decay schedule,
+         which is a standard practice for training Transformer models.
+         """
+         optimizer = torch.optim.AdamW(
+             self.parameters(),
+             lr=self.hparams.learning_rate,
+             weight_decay=self.hparams.weight_decay
+         )
+
+         # Learning rate scheduler: linear warmup and cosine decay
+         def lr_lambda(current_step: int):
+             warmup_steps = self.hparams.warmup_steps
+             max_steps = self.hparams.max_steps
+             if current_step < warmup_steps:
+                 return float(current_step) / float(max(1, warmup_steps))
+             progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))
+             return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": scheduler,
+                 "interval": "step",  # Update the scheduler at every training step
+                 "frequency": 1
+             }
+         }
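The schedule in `configure_optimizers` is a linear warmup from 0 to the base learning rate, followed by a cosine decay to 0. The same multiplier function can be checked in isolation (pure Python, using the model's default `warmup_steps=2000` and `max_steps=100000`):

```python
import math

def lr_lambda(current_step, warmup_steps=2000, max_steps=100000):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if current_step < warmup_steps:
        return current_step / max(1, warmup_steps)  # linear warmup: 0 -> 1
    progress = (current_step - warmup_steps) / max(1, max_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))  # cosine decay: 1 -> 0

print(lr_lambda(0))       # 0.0  (start of warmup)
print(lr_lambda(2000))    # 1.0  (warmup done, full base LR)
print(lr_lambda(100000))  # 0.0  (fully decayed)
```

Because the schedule is stepped with `"interval": "step"`, the multiplier advances every optimizer step rather than once per epoch, which is what makes the 2000-step warmup meaningful.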
scripts/train_and_eval.py ADDED
@@ -0,0 +1,172 @@
+ import pandas as pd
+ import numpy as np
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
+ from pytorch_lightning.loggers import TensorBoardLogger
+ import torch
+
+ from utils import prepare_ground_truth, calculate_metrics
+ from data_prepare import prepare_data, SASRecDataset, SASRecDataModule
+ from models import recommend_popular_items_and_evaluate, recommend_item_item_and_evaluate, recommend_als_and_evaluate, SASRec
+
+
+ def train_and_eval_SASRec_model(train_set, validation_set, test_set, checkpoint_dir_path='checkpoints/',
+                                 checkpoint_path=None, n_epochs=10, mode='train',
+                                 batchsize=256, max_token_len=50, learning_rate=1e-3, hidden_dim=128,
+                                 num_heads=2, num_layers=2, dropout=0.2, weight_decay=1e-6):
+     """
+     Train or evaluate a SASRec sequential recommendation model using PyTorch Lightning.
+
+     This function wraps the entire SASRec pipeline:
+     - Initializes the SASRecDataModule (handles dataset preprocessing and dataloaders).
+     - Builds the SASRec Transformer-based model.
+     - Configures training callbacks (checkpointing, early stopping, LR monitoring).
+     - Runs either training (`mode='train'`) or evaluation on the test set (`mode='test'`).
+
+     Parameters
+     ----------
+     train_set : pd.DataFrame
+         Training interactions dataset.
+     validation_set : pd.DataFrame
+         Validation dataset with the same structure as `train_set`.
+     test_set : pd.DataFrame
+         Test dataset with the same structure as `train_set`.
+     checkpoint_dir_path : str, optional (default='checkpoints/')
+         Directory to save model checkpoints.
+     checkpoint_path : str or None, optional (default=None)
+         Path to a checkpoint file for resuming training or loading a pretrained model for testing.
+     n_epochs : int, optional (default=10)
+         Number of training epochs.
+     mode : {'train', 'test'}, optional (default='train')
+         - 'train': trains the model on the training/validation data.
+         - 'test': evaluates the model on the test set using a checkpoint.
+     batchsize : int, optional (default=256)
+         Batch size for training and evaluation.
+     max_token_len : int, optional (default=50)
+         Maximum sequence length per user (most recent interactions kept).
+     learning_rate : float, optional (default=1e-3)
+         Learning rate for the AdamW optimizer.
+     hidden_dim : int, optional (default=128)
+         Dimensionality of item and positional embeddings.
+     num_heads : int, optional (default=2)
+         Number of attention heads in each Transformer encoder layer.
+     num_layers : int, optional (default=2)
+         Number of Transformer encoder layers.
+     dropout : float, optional (default=0.2)
+         Dropout probability applied in embeddings and Transformer layers.
+     weight_decay : float, optional (default=1e-6)
+         Weight decay regularization coefficient for AdamW.
+     """
+     # --- 1. Initialize DataModule ---
+     print("Initializing DataModule...")
+     datamodule = SASRecDataModule(
+         train_df=train_set,
+         val_df=validation_set,
+         test_df=test_set,
+         batch_size=batchsize,
+         max_len=max_token_len
+     )
+     datamodule.setup()
+
+     # --- 2. Initialize Model ---
+     print("Initializing SASRec model...")
+     model = SASRec(
+         vocab_size=datamodule.vocab_size,
+         max_len=max_token_len,
+         hidden_dim=hidden_dim,
+         num_heads=num_heads,
+         num_layers=num_layers,
+         dropout=dropout,
+         learning_rate=learning_rate,
+         weight_decay=weight_decay
+     )
+
+     # --- 3. Configure Training Callbacks ---
+     checkpoint_callback = ModelCheckpoint(
+         dirpath=checkpoint_dir_path,
+         filename="sasrec-{epoch:02d}-{val_hitrate@10:.4f}",
+         save_top_k=1,
+         verbose=True,
+         monitor="val_hitrate@10",
+         mode="max"
+     )
+
+     early_stopping_callback = EarlyStopping(
+         monitor="val_hitrate@10",  # stop if the ranking metric stagnates
+         patience=5,
+         mode="max"
+     )
+
+     lr_monitor = LearningRateMonitor(logging_interval="step")
+
+     logger = TensorBoardLogger("lightning_logs", name="sasrec")
+
+     # --- 4. Initialize Trainer ---
+     print("Initializing PyTorch Lightning Trainer...")
+     trainer = pl.Trainer(
+         logger=logger,
+         callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
+         max_epochs=n_epochs,
+         accelerator='auto',
+         devices=1,
+         gradient_clip_val=1.0,  # helps with exploding gradients
+     )
+
+     if mode == 'train':
+         # --- 5. Start Training ---
+         print(f"Starting training for up to {n_epochs} epochs...")
+         trainer.fit(model, datamodule, ckpt_path=checkpoint_path)
+
+     elif mode == 'test':
+         # --- 6. Test on best checkpoint ---
+         print("Evaluating on test set...")
+         trainer.test(model, datamodule, ckpt_path=checkpoint_path)
+
+ # --- Main Execution Block ---
+ if __name__ == "__main__":
+
+     # --- Configuration ---
+     BATCH_SIZE = 256
+     MAX_TOKEN_LEN = 50  # 50–100 is standard
+     LEARNING_RATE = 1e-3
+     HIDDEN_DIM = 128
+     NUM_HEADS = 2
+     NUM_LAYERS = 2
+     DROPOUT = 0.2
+     WEIGHT_DECAY = 1e-6
+     N_EPOCHS = 50
+     CHECKPOINT_SAVE_PATH = 'checkpoints/'
+     CHECKPOINT_LOAD_PATH = None  # or specify a path to a checkpoint file
+     MODE = 'train'  # 'train' or 'test'
+
+     train_set, validation_set, test_set = prepare_data(data_folder='data/')
+     if train_set is not None:
+         results = {}
+         full_train_set = pd.concat([train_set, validation_set])
+
+         # Evaluate classical baselines, passing the shared metric helpers explicitly
+         print("\n>>> Running evaluations on the VALIDATION set <<<")
+         results['Popularity (Validation)'] = recommend_popular_items_and_evaluate(train_set, validation_set, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics)
+         results['Item-Item CF (Validation)'] = recommend_item_item_and_evaluate(train_set, validation_set, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics)
+         results['ALS (Validation)'] = recommend_als_and_evaluate(train_set, validation_set, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics)
+
+         print("\n>>> Running final evaluations on the TEST set <<<")
+         results['Popularity (Test)'] = recommend_popular_items_and_evaluate(full_train_set, test_set, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics)
+         results['Item-Item CF (Test)'] = recommend_item_item_and_evaluate(full_train_set, test_set, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics)
+         results['ALS (Test)'] = recommend_als_and_evaluate(full_train_set, test_set, prepare_ground_truth=prepare_ground_truth, calculate_metrics=calculate_metrics)
+
+         print("\n--- Final Evaluation Results ---")
+         results_df = pd.DataFrame.from_dict(results, orient='index')
+         print(results_df)
+         print("--------------------------------")
+
+         # Train and evaluate SASRec model
+         print("\n>>> Training and evaluating SASRec model <<<")
+         train_and_eval_SASRec_model(train_set, validation_set, test_set, n_epochs=10, mode='train')
+
+         print("\n>>> Evaluating trained SASRec model on TEST set <<<")
+         train_and_eval_SASRec_model(train_set, validation_set, test_set, mode='test')
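All the baselines and SASRec are compared with the same top-N ranking metrics (implemented in `scripts/utils.py` as `calculate_metrics`). The per-user arithmetic can be sketched as follows; `metrics_at_k` is an illustrative single-user version, not the project function:

```python
def metrics_at_k(recommended, relevant, k):
    """Precision@k, Recall@k, and HitRate@k for a single user."""
    top_k = recommended[:k]
    hits = len(set(top_k) & set(relevant))  # recommended items that are actually relevant
    return {
        "precision": hits / k,                               # fraction of the top-k that hit
        "recall": hits / len(relevant) if relevant else 0.0, # fraction of relevant items found
        "hitrate": 1.0 if hits > 0 else 0.0,                 # did we hit at least once?
    }

m = metrics_at_k(["a", "b", "c", "d"], {"b", "x"}, k=4)
print(m)  # {'precision': 0.25, 'recall': 0.5, 'hitrate': 1.0}
```

Averaging these per-user values over all evaluated users gives the `mean_precision@k`, `mean_recall@k`, and `mean_hitrate@k` figures reported in the results table.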
scripts/utils.py ADDED
@@ -0,0 +1,196 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ import datetime
5
+
6
+ def calculate_metrics(recommendations_dict, ground_truth_dict, k):
7
+ """
8
+ Calculates Precision@k, Recall@k, and HitRate@k.
9
+
10
+ args:
11
+ ----------
12
+ recommendations_dict : {user_id: [recommended_item_ids]}
13
+ ground_truth_dict : {user_id: set of ground truth item_ids}
14
+ k : int
15
+
16
+ Returns
17
+ -------
18
+ dict with mean precision, recall, and hit rate
19
+ """
20
+ all_precisions, all_recalls, all_hits = [], [], []
21
+
22
+ for user_id, true_items in ground_truth_dict.items():
23
+ recs = recommendations_dict.get(user_id, [])[:k]
24
+ if not true_items:
25
+ continue
26
+ hits = len(set(recs) & true_items)
27
+
28
+ precision = hits / k if k > 0 else 0
29
+ recall = hits / len(true_items)
30
+ hit_rate = 1.0 if hits > 0 else 0.0
31
+
32
+ all_precisions.append(precision)
33
+ all_recalls.append(recall)
34
+ all_hits.append(hit_rate)
35
+
36
+ if not all_precisions:
37
+ return {"mean_precision@k": 0, "mean_recall@k": 0, "mean_hitrate@k": 0}
38
+
39
+ return {
40
+ "mean_precision@k": np.mean(all_precisions),
41
+ "mean_recall@k": np.mean(all_recalls),
42
+ "mean_hitrate@k": np.mean(all_hits)
43
+ }
44
+
45
+ def prepare_ground_truth(df, mode="purchase", event_weights=None):
46
+     """
47
+     Prepares ground truth dictionaries for evaluation.
48
+
49
+     Parameters
50
+     ----------
51
+     df : pd.DataFrame
52
+         Test dataframe containing at least ['visitorid', 'itemid', 'event'].
53
+     mode : str, default="purchase"
54
+         - "purchase" : Only use transactions as ground truth.
55
+         - "all" : Use all events. Optionally weight them.
56
+     event_weights : dict, optional
57
+         Example: {"view": 1, "addtocart": 3, "transaction": 5}.
58
+         Used only if mode == "all".
59
+
60
+     Returns
61
+     -------
62
+     dict : {user_id: set of item_ids}
63
+     """
64
+     if mode == "purchase":
65
+         df_filtered = df[df["event"] == "transaction"]
66
+         ground_truth = df_filtered.groupby("visitorid")["itemid"].apply(set).to_dict()
67
+
68
+     elif mode == "all":
69
+         if event_weights is None:
70
+             # Default: treat all events equally
71
+             ground_truth = df.groupby("visitorid")["itemid"].apply(set).to_dict()
72
+         else:
73
+             # Weighted ground truth: a set cannot carry counts, so weights act as an inclusion filter -- only positive-weight events are kept.
74
+             ground_truth = {}
75
+             for uid, user_df in df.groupby("visitorid"):
76
+                 kept_items = set()
77
+                 for _, row in user_df.iterrows():
78
+                     if event_weights.get(row["event"], 1) > 0:
79
+                         kept_items.add(row["itemid"])
80
+                 ground_truth[uid] = kept_items
81
+     else:
82
+         raise ValueError("mode must be 'purchase' or 'all'")
83
+
84
+     return ground_truth
85
+
86
+ def load_item_properties(data_folder='data/'):
87
+     """
88
+     Loads item properties and creates a mapping from item ID to its category ID.
89
+     Handles either a single properties file or two split parts.
90
+
91
+     Args:
92
+         data_folder (str): The path to the folder containing item property files.
93
+
94
+     Returns:
95
+         dict: A dictionary mapping {itemid: categoryid}.
96
+     """
97
+     print("Loading item properties...")
98
+     try:
99
+         # First, try to load the two separate parts and combine them.
100
+         props_df_part1 = pd.read_csv(data_folder + 'item_properties_part1.csv')
101
+         props_df_part2 = pd.read_csv(data_folder + 'item_properties_part2.csv')
102
+         props_df = pd.concat([props_df_part1, props_df_part2], ignore_index=True)
103
+         print("Successfully loaded and combined item_properties_part1.csv and item_properties_part2.csv.")
104
+
105
+     except FileNotFoundError:
106
+         try:
107
+             # If the parts are not found, try to load a single combined file.
108
+             props_df = pd.read_csv(data_folder + 'item_properties.csv')
109
+             print("Successfully loaded a single item_properties.csv.")
110
+         except FileNotFoundError:
111
+             print("Warning: No item properties files found. Cannot display category information.")
112
+             return {}
113
+
114
+     category_df = props_df[props_df['property'] == 'categoryid'].copy()
115
+     category_df['value'] = pd.to_numeric(category_df['value'], errors='coerce').astype('Int64')
116
+     item_to_category_map = category_df.set_index('itemid')['value'].to_dict()
117
+     print("Item to category mapping created successfully.")
118
+     return item_to_category_map
119
+
120
+ def load_category_tree(data_folder='data/'):
121
+     """
122
+     Loads the category tree to map categories to their parent categories.
123
+
124
+     Args:
125
+         data_folder (str): The path to the folder containing category_tree.csv.
126
+
127
+     Returns:
128
+         dict: A dictionary mapping {categoryid: parentid}.
129
+     """
130
+     print("Loading category tree...")
131
+     try:
132
+         tree_df = pd.read_csv(data_folder + 'category_tree.csv')
133
+         category_parent_map = tree_df.set_index('categoryid')['parentid'].to_dict()
134
+         print("Category tree loaded successfully.")
135
+         return category_parent_map
136
+     except FileNotFoundError:
137
+         print("Warning: 'category_tree.csv' not found. Cannot display parent category information.")
138
+         return {}
139
+
140
+ def get_popular_items(train_df, k=10):
141
+     """
142
+     Calculates the top-k most popular items based on transaction count.
143
+     """
144
+     purchase_counts = train_df[train_df['event'] == 'transaction']['itemid'].value_counts()
145
+     return purchase_counts.head(k).index.tolist()
146
+
147
+ def show_user_recommendations(visitor_id, model, datamodule, popular_items, item_category_map, category_parent_map, k=10):
148
+     """
149
+     Displays recommendations for a user, including category and parent category information.
150
+     """
151
+     print(f"\n--- Recommendations for Visitor ID: {visitor_id} ---")
152
+     model.eval()
153
+
154
+     def format_item_with_category(item_id):
155
+         category_id = item_category_map.get(item_id, 'N/A')
156
+         parent_id = category_parent_map.get(category_id, 'N/A') if category_id != 'N/A' else 'N/A'
157
+         return f"Item: {item_id} (Category: {category_id}, Parent: {parent_id})"
158
+
159
+     user_history_ids = datamodule.user_history.get(visitor_id)
160
+
161
+     if user_history_ids is None:
162
+         print(f"User {visitor_id} not found in training history. Providing popularity-based recommendations.")
163
+         print(f"\nTop {k} Popular Items (Fallback):")
164
+         recs_with_cats = [format_item_with_category(item_id) for item_id in popular_items]
165
+         print(recs_with_cats)
166
+         print("-------------------------------------------------")
167
+         return
168
+
169
+     history_with_cats = [format_item_with_category(item_id) for item_id in user_history_ids]
170
+     print("User's Historical Interactions:")
171
+     print(history_with_cats)
172
+
173
+     history_indices = [datamodule.item_map[i] for i in user_history_ids if i in datamodule.item_map]
174
+     if not history_indices:
175
+         print("None of the user's historical items are in the model's vocabulary.")
176
+         return
177
+
178
+     max_len = datamodule.max_len
179
+     input_seq = history_indices[-max_len:]
180
+     padded_input = np.zeros(max_len, dtype=np.int64)
181
+     padded_input[-len(input_seq):] = input_seq
182
+
183
+     input_tensor = torch.LongTensor(np.array([padded_input]))
184
+     input_tensor = input_tensor.to(model.device)
185
+
186
+     with torch.no_grad():
187
+         logits = model(input_tensor)
188
+         last_item_logits = logits[0, -1, :]
189
+         top_indices = torch.topk(last_item_logits, k).indices.tolist()
190
+
191
+     recommended_item_ids = [datamodule.inverse_item_map[idx] for idx in top_indices if idx in datamodule.inverse_item_map]
192
+
193
+     print(f"\nTop {k} Recommended Items:")
194
+     recs_with_cats = [format_item_with_category(item_id) for item_id in recommended_item_ids]
195
+     print(recs_with_cats)
196
+     print("-------------------------------------------------")
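A minimal, self-contained usage sketch of `calculate_metrics` on toy user/item IDs (not taken from the RetailRocket data); the function body is repeated here so the snippet runs on its own:

```python
import numpy as np

def calculate_metrics(recommendations_dict, ground_truth_dict, k):
    # Copied from the module above so this sketch runs standalone.
    all_precisions, all_recalls, all_hits = [], [], []
    for user_id, true_items in ground_truth_dict.items():
        recs = recommendations_dict.get(user_id, [])[:k]
        if not true_items:
            continue
        hits = len(set(recs) & true_items)
        all_precisions.append(hits / k if k > 0 else 0)
        all_recalls.append(hits / len(true_items))
        all_hits.append(1.0 if hits > 0 else 0.0)
    if not all_precisions:
        return {"mean_precision@k": 0, "mean_recall@k": 0, "mean_hitrate@k": 0}
    return {
        "mean_precision@k": np.mean(all_precisions),
        "mean_recall@k": np.mean(all_recalls),
        "mean_hitrate@k": np.mean(all_hits),
    }

# Toy example: user 1 gets one of their two true items in the top 3; user 2 gets none.
recommendations = {1: [10, 20, 30], 2: [50, 60, 70]}
ground_truth = {1: {20, 40}, 2: {99}}
metrics = calculate_metrics(recommendations, ground_truth, k=3)
print(metrics)  # precision = (1/3 + 0)/2, recall = (1/2 + 0)/2 = 0.25, hit rate = 0.5
```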