Upload 17 files
Browse files- .gitattributes +4 -0
- README.md +564 -0
- app.py +628 -0
- data/sales.csv +0 -0
- generate_dataset.py +128 -0
- models/all_models_metadata.json +26 -0
- models/best_model.joblib +3 -0
- models/model_metadata.json +21 -0
- models/preprocessing.joblib +3 -0
- plots/demand_trends.png +3 -0
- plots/feature_importance.png +3 -0
- plots/model_comparison.png +3 -0
- plots/monthly_demand.png +3 -0
- predict.py +403 -0
- requirements.txt +10 -0
- setup_env.bat +26 -0
- setup_env.sh +25 -0
- train_model.py +877 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
plots/demand_trends.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
plots/feature_importance.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
plots/model_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
plots/monthly_demand.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demand Prediction System for E-commerce
|
| 2 |
+
|
| 3 |
+
A complete machine learning and time-series forecasting system for predicting product demand (sales quantity) in e-commerce. Compares both supervised learning regression models and time-series models (ARIMA, Prophet) to find the best approach.
|
| 4 |
+
|
| 5 |
+
## 📋 Table of Contents
|
| 6 |
+
|
| 7 |
+
- [Overview](#overview)
|
| 8 |
+
- [Features](#features)
|
| 9 |
+
- [Project Structure](#project-structure)
|
| 10 |
+
- [Installation](#installation)
|
| 11 |
+
- [Dataset](#dataset)
|
| 12 |
+
- [Usage](#usage)
|
| 13 |
+
- [Model Details](#model-details)
|
| 14 |
+
- [Evaluation Metrics](#evaluation-metrics)
|
| 15 |
+
- [Visualizations](#visualizations)
|
| 16 |
+
- [Example Predictions](#example-predictions)
|
| 17 |
+
- [Technical Details](#technical-details)
|
| 18 |
+
|
| 19 |
+
## 🎯 Overview
|
| 20 |
+
|
| 21 |
+
This project implements a demand prediction system that uses historical sales data to forecast future product demand. The system compares two approaches:
|
| 22 |
+
|
| 23 |
+
1. **Machine Learning Models**: Treat demand prediction as a regression problem using product features (price, discount, category, date features)
|
| 24 |
+
2. **Time-Series Models**: Treat demand prediction as a time-series problem using historical patterns (ARIMA, Prophet)
|
| 25 |
+
|
| 26 |
+
The system automatically selects the best performing model across both approaches.
|
| 27 |
+
|
| 28 |
+
**Key Capabilities:**
|
| 29 |
+
- Predicts sales quantity for products on future dates (ML models)
|
| 30 |
+
- Predicts overall daily demand (Time-series models)
|
| 31 |
+
- Handles temporal patterns and seasonality
|
| 32 |
+
- Considers price, discount, category, and date features (ML models)
|
| 33 |
+
- Captures time-series patterns and trends (Time-series models)
|
| 34 |
+
- Automatically selects the best model from multiple candidates
|
| 35 |
+
- Provides comprehensive evaluation metrics
|
| 36 |
+
- Compares ML vs Time-Series approaches
|
| 37 |
+
|
| 38 |
+
## ✨ Features
|
| 39 |
+
|
| 40 |
+
- **Data Preprocessing**: Automatic handling of missing values, date feature extraction
|
| 41 |
+
- **Feature Engineering**:
|
| 42 |
+
- Date features (day, month, day_of_week, weekend, year, quarter)
|
| 43 |
+
- Categorical encoding (product_id, category)
|
| 44 |
+
- Feature scaling
|
| 45 |
+
- **Multiple Models**:
|
| 46 |
+
- **Machine Learning Models:**
|
| 47 |
+
- Linear Regression
|
| 48 |
+
- Random Forest Regressor
|
| 49 |
+
- XGBoost Regressor (optional)
|
| 50 |
+
- **Time-Series Models:**
|
| 51 |
+
- ARIMA (AutoRegressive Integrated Moving Average)
|
| 52 |
+
- Prophet (Facebook's time-series forecasting tool)
|
| 53 |
+
- **Model Selection**: Automatic best model selection based on R2 score
|
| 54 |
+
- **Evaluation Metrics**: MAE, RMSE, and R2 Score
|
| 55 |
+
- **Visualizations**:
|
| 56 |
+
- Demand trends over time
|
| 57 |
+
- Monthly average demand
|
| 58 |
+
- Feature importance
|
| 59 |
+
- Model comparison
|
| 60 |
+
- **Model Persistence**: Save and load trained models using joblib
|
| 61 |
+
- **Future Predictions**: Predict demand for any product on any future date
|
| 62 |
+
|
| 63 |
+
## 📁 Project Structure
|
| 64 |
+
|
| 65 |
+
```
|
| 66 |
+
demand_prediction/
|
| 67 |
+
│
|
| 68 |
+
├── data/
|
| 69 |
+
│ └── sales.csv # Sales dataset
|
| 70 |
+
│
|
| 71 |
+
├── models/ # Generated during training
|
| 72 |
+
│ ├── best_model.joblib # Best ML model (if ML is best)
|
| 73 |
+
│ ├── best_timeseries_model.joblib # Best time-series model (if TS is best)
|
| 74 |
+
│ ├── preprocessing.joblib # Encoders and scaler (for ML models)
|
| 75 |
+
│ ├── model_metadata.json # Model metadata (legacy)
|
| 76 |
+
│ └── all_models_metadata.json # All models comparison metadata
|
| 77 |
+
│
|
| 78 |
+
├── plots/ # Generated during training
|
| 79 |
+
│ ├── demand_trends.png # Time series plot
|
| 80 |
+
│ ├── monthly_demand.png # Monthly averages
|
| 81 |
+
│ ├── feature_importance.png # Feature importance (ML models)
|
| 82 |
+
│ ├── model_comparison.png # Model metrics comparison (all models)
|
| 83 |
+
│ └── timeseries_predictions.png # Time-series model predictions
|
| 84 |
+
│
|
| 85 |
+
├── generate_dataset.py # Script to generate synthetic dataset
|
| 86 |
+
├── train_model.py # Main training script
|
| 87 |
+
├── predict.py # Prediction script
|
| 88 |
+
├── app.py # Streamlit dashboard (interactive web app)
|
| 89 |
+
├── requirements.txt # Python dependencies
|
| 90 |
+
└── README.md # This file
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
## 🚀 Installation
|
| 94 |
+
|
| 95 |
+
### Prerequisites
|
| 96 |
+
|
| 97 |
+
- Python 3.8 or higher
|
| 98 |
+
- pip (Python package manager)
|
| 99 |
+
|
| 100 |
+
### Step 1: Navigate to Project Directory
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
cd demand_prediction
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
### Step 2: Create Virtual Environment (Recommended)
|
| 107 |
+
|
| 108 |
+
**Why use a virtual environment?**
|
| 109 |
+
- Keeps project dependencies isolated from your system Python
|
| 110 |
+
- Prevents conflicts with other projects
|
| 111 |
+
- Makes it easier to manage package versions
|
| 112 |
+
- Best practice for Python projects
|
| 113 |
+
|
| 114 |
+
**Quick Setup (Recommended):**
|
| 115 |
+
|
| 116 |
+
**Windows:**
|
| 117 |
+
```bash
|
| 118 |
+
setup_env.bat
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
**Linux/Mac:**
|
| 122 |
+
```bash
|
| 123 |
+
chmod +x setup_env.sh
|
| 124 |
+
./setup_env.sh
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
**Manual Setup:**
|
| 128 |
+
|
| 129 |
+
**Windows:**
|
| 130 |
+
```bash
|
| 131 |
+
python -m venv venv
|
| 132 |
+
venv\Scripts\activate
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
**Linux/Mac:**
|
| 136 |
+
```bash
|
| 137 |
+
python3 -m venv venv
|
| 138 |
+
source venv/bin/activate
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
After activation, you should see `(venv)` in your terminal prompt.
|
| 142 |
+
|
| 143 |
+
**To deactivate later:**
|
| 144 |
+
```bash
|
| 145 |
+
deactivate
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
### Step 3: Install Dependencies
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
pip install -r requirements.txt
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
**Note**: If you don't want to use XGBoost, you can remove it from `requirements.txt`. The system will work fine without it, just skipping XGBoost model training.
|
| 155 |
+
|
| 156 |
+
**Alternative (without virtual environment):**
|
| 157 |
+
If you prefer not to use a virtual environment, you can install directly:
|
| 158 |
+
```bash
|
| 159 |
+
pip install -r requirements.txt
|
| 160 |
+
```
|
| 161 |
+
However, this is **not recommended** as it may cause conflicts with other Python projects.
|
| 162 |
+
|
| 163 |
+
### Step 4: Generate Dataset
|
| 164 |
+
|
| 165 |
+
If you don't have a dataset, generate a synthetic one:
|
| 166 |
+
|
| 167 |
+
```bash
|
| 168 |
+
python generate_dataset.py
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
This will create `data/sales.csv` with realistic e-commerce sales data.
|
| 172 |
+
|
| 173 |
+
## 📊 Dataset
|
| 174 |
+
|
| 175 |
+
The dataset should contain the following columns:
|
| 176 |
+
|
| 177 |
+
- **product_id**: Unique identifier for each product (integer)
|
| 178 |
+
- **date**: Date of sale (YYYY-MM-DD format)
|
| 179 |
+
- **price**: Product price (float)
|
| 180 |
+
- **discount**: Discount percentage (0-100, float)
|
| 181 |
+
- **category**: Product category (string)
|
| 182 |
+
- **sales_quantity**: Target variable - number of units sold (integer)
|
| 183 |
+
|
| 184 |
+
### Dataset Format Example
|
| 185 |
+
|
| 186 |
+
```csv
|
| 187 |
+
product_id,date,price,discount,category,sales_quantity
|
| 188 |
+
1,2020-01-01,499.99,10,Electronics,45
|
| 189 |
+
2,2020-01-01,29.99,0,Clothing,120
|
| 190 |
+
...
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
## 💻 Usage
|
| 194 |
+
|
| 195 |
+
### Step 1: Train the Model
|
| 196 |
+
|
| 197 |
+
Train the model using the sales dataset:
|
| 198 |
+
|
| 199 |
+
```bash
|
| 200 |
+
python train_model.py
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
This will:
|
| 204 |
+
1. Load and preprocess the data
|
| 205 |
+
2. Extract features from dates
|
| 206 |
+
3. Encode categorical variables
|
| 207 |
+
4. Train multiple ML models (Linear Regression, Random Forest, XGBoost)
|
| 208 |
+
5. Prepare time-series data (aggregate daily sales)
|
| 209 |
+
6. Train time-series models (ARIMA, Prophet)
|
| 210 |
+
7. Evaluate each model using MAE, RMSE, and R2 Score
|
| 211 |
+
8. Compare ML vs Time-Series models
|
| 212 |
+
9. Select the best model automatically (across all model types)
|
| 213 |
+
10. Save the model and preprocessing objects
|
| 214 |
+
11. Generate visualizations
|
| 215 |
+
|
| 216 |
+
**Output:**
|
| 217 |
+
- Best model saved to `models/best_model.joblib` (ML) or `models/best_timeseries_model.joblib` (TS)
|
| 218 |
+
- Preprocessing objects saved to `models/preprocessing.joblib` (for ML models)
|
| 219 |
+
- Visualizations saved to `plots/` directory
|
| 220 |
+
- All models metadata saved to `models/all_models_metadata.json`
|
| 221 |
+
|
| 222 |
+
### Step 2: Make Predictions
|
| 223 |
+
|
| 224 |
+
**For ML Models (product-specific predictions):**
|
| 225 |
+
|
| 226 |
+
Predict demand for a specific product on a date:
|
| 227 |
+
|
| 228 |
+
```bash
|
| 229 |
+
python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
**Parameters for ML Models:**
|
| 233 |
+
- `--product_id`: Product ID (integer, required)
|
| 234 |
+
- `--date`: Date in YYYY-MM-DD format (required)
|
| 235 |
+
- `--price`: Product price (float, required)
|
| 236 |
+
- `--discount`: Discount percentage 0-100 (float, default: 0)
|
| 237 |
+
- `--category`: Product category (string, required)
|
| 238 |
+
- `--model_type`: Model type - `auto` (default), `ml`, or `timeseries`
|
| 239 |
+
|
| 240 |
+
**For Time-Series Models (overall daily demand):**
|
| 241 |
+
|
| 242 |
+
Predict total daily demand across all products:
|
| 243 |
+
|
| 244 |
+
```bash
|
| 245 |
+
python predict.py --date 2024-01-15 --model_type timeseries
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
**Parameters for Time-Series Models:**
|
| 249 |
+
- `--date`: Date in YYYY-MM-DD format (required)
|
| 250 |
+
- `--model_type`: Set to `timeseries` to use time-series models
|
| 251 |
+
|
| 252 |
+
**Example Predictions:**
|
| 253 |
+
|
| 254 |
+
```bash
|
| 255 |
+
# ML Model - Electronics product with discount
|
| 256 |
+
python predict.py --product_id 1 --date 2024-06-15 --price 500 --discount 20 --category Electronics
|
| 257 |
+
|
| 258 |
+
# ML Model - Clothing product without discount
|
| 259 |
+
python predict.py --product_id 5 --date 2024-12-25 --price 50 --discount 0 --category Clothing
|
| 260 |
+
|
| 261 |
+
# Time-Series Model - Overall daily demand
|
| 262 |
+
python predict.py --date 2024-07-06 --model_type timeseries
|
| 263 |
+
|
| 264 |
+
# Auto-detect best model (default)
|
| 265 |
+
python predict.py --product_id 10 --date 2024-07-06 --price 75 --discount 15 --category Sports
|
| 266 |
+
```
|
| 267 |
+
|
| 268 |
+
### Step 3: Launch Interactive Dashboard (Optional)
|
| 269 |
+
|
| 270 |
+
Launch the Streamlit dashboard for interactive visualization and predictions:
|
| 271 |
+
|
| 272 |
+
```bash
|
| 273 |
+
streamlit run app.py
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
The dashboard will open in your default web browser (usually at `http://localhost:8501`).
|
| 277 |
+
|
| 278 |
+
**Dashboard Features:**
|
| 279 |
+
|
| 280 |
+
1. **📈 Sales Trends Page**
|
| 281 |
+
- Interactive filters (category, product, date range)
|
| 282 |
+
- Daily sales trends visualization
|
| 283 |
+
- Monthly sales trends
|
| 284 |
+
- Category-wise analysis
|
| 285 |
+
- Price vs demand relationship
|
| 286 |
+
- Real-time statistics and metrics
|
| 287 |
+
|
| 288 |
+
2. **🔮 Demand Prediction Page**
|
| 289 |
+
- Interactive prediction interface
|
| 290 |
+
- Select model type (Auto/ML/Time-Series)
|
| 291 |
+
- For ML models:
|
| 292 |
+
- Product selection dropdown
|
| 293 |
+
- Category selection
|
| 294 |
+
- Price and discount sliders
|
| 295 |
+
- Date picker
|
| 296 |
+
- Product statistics display
|
| 297 |
+
- For Time-Series models:
|
| 298 |
+
- Date picker for future predictions
|
| 299 |
+
- Overall daily demand forecast
|
| 300 |
+
- Prediction insights and recommendations
|
| 301 |
+
|
| 302 |
+
3. **📊 Model Comparison Page**
|
| 303 |
+
- Side-by-side model performance comparison
|
| 304 |
+
- MAE, RMSE, and R2 Score metrics
|
| 305 |
+
- Visual charts comparing all models
|
| 306 |
+
- Best model highlighting
|
| 307 |
+
- Model type indicators (ML vs Time-Series)
|
| 308 |
+
|
| 309 |
+
**Dashboard Screenshots:**
|
| 310 |
+
- Interactive widgets for easy data exploration
|
| 311 |
+
- Real-time predictions with visual feedback
|
| 312 |
+
- Comprehensive model comparison charts
|
| 313 |
+
|
| 314 |
+
## 🤖 Model Details
|
| 315 |
+
|
| 316 |
+
### Models Trained
|
| 317 |
+
|
| 318 |
+
1. **Linear Regression**
|
| 319 |
+
- Simple linear model
|
| 320 |
+
- Fast training and prediction
|
| 321 |
+
- Good baseline model
|
| 322 |
+
|
| 323 |
+
2. **Random Forest Regressor**
|
| 324 |
+
- Ensemble of decision trees
|
| 325 |
+
- Handles non-linear relationships
|
| 326 |
+
- Provides feature importance
|
| 327 |
+
- Hyperparameters:
|
| 328 |
+
- n_estimators: 100
|
| 329 |
+
- max_depth: 15
|
| 330 |
+
- min_samples_split: 5
|
| 331 |
+
- min_samples_leaf: 2
|
| 332 |
+
|
| 333 |
+
3. **XGBoost Regressor** (Optional)
|
| 334 |
+
- Gradient boosting algorithm
|
| 335 |
+
- Often provides best performance
|
| 336 |
+
- Handles complex patterns
|
| 337 |
+
- Hyperparameters:
|
| 338 |
+
- n_estimators: 100
|
| 339 |
+
- max_depth: 6
|
| 340 |
+
- learning_rate: 0.1
|
| 341 |
+
|
| 342 |
+
4. **ARIMA** (AutoRegressive Integrated Moving Average)
|
| 343 |
+
- Classic time-series forecasting model
|
| 344 |
+
- Captures trends and seasonality
|
| 345 |
+
- Automatically selects best order (p, d, q)
|
| 346 |
+
- Works on aggregated daily sales data
|
| 347 |
+
- Uses chronological train/validation split
|
| 348 |
+
|
| 349 |
+
5. **Prophet** (Facebook's Time-Series Forecasting)
|
| 350 |
+
- Designed for business time series
|
| 351 |
+
- Handles seasonality (weekly, yearly)
|
| 352 |
+
- Robust to missing data and outliers
|
| 353 |
+
- Works on aggregated daily sales data
|
| 354 |
+
- Uses chronological train/validation split
|
| 355 |
+
|
| 356 |
+
### Model Comparison: ML vs Time-Series
|
| 357 |
+
|
| 358 |
+
**Machine Learning Models:**
|
| 359 |
+
- ✅ Predict per-product demand
|
| 360 |
+
- ✅ Use product features (price, discount, category)
|
| 361 |
+
- ✅ Can handle new products with similar features
|
| 362 |
+
- ❌ May not capture long-term temporal patterns as well
|
| 363 |
+
|
| 364 |
+
**Time-Series Models:**
|
| 365 |
+
- ✅ Capture temporal patterns and trends
|
| 366 |
+
- ✅ Handle seasonality automatically
|
| 367 |
+
- ✅ Good for overall demand forecasting
|
| 368 |
+
- ❌ Predict aggregate demand, not per-product
|
| 369 |
+
- ❌ Don't use product-specific features
|
| 370 |
+
|
| 371 |
+
**The system automatically selects the best model based on R2 score across all model types.**
|
| 372 |
+
|
| 373 |
+
### Feature Engineering
|
| 374 |
+
|
| 375 |
+
**For ML Models:**
|
| 376 |
+
|
| 377 |
+
The system extracts the following features from the input data:
|
| 378 |
+
|
| 379 |
+
**Date Features:**
|
| 380 |
+
- `day`: Day of month (1-31)
|
| 381 |
+
- `month`: Month (1-12)
|
| 382 |
+
- `day_of_week`: Day of week (0=Monday, 6=Sunday)
|
| 383 |
+
- `weekend`: Binary indicator (1 if weekend, 0 otherwise)
|
| 384 |
+
- `year`: Year
|
| 385 |
+
- `quarter`: Quarter of year (1-4)
|
| 386 |
+
|
| 387 |
+
**Original Features:**
|
| 388 |
+
- `product_id`: Encoded as categorical
|
| 389 |
+
- `price`: Numerical (scaled)
|
| 390 |
+
- `discount`: Numerical (scaled)
|
| 391 |
+
- `category`: Encoded as categorical
|
| 392 |
+
|
| 393 |
+
**Total Features**: 10 features after encoding and scaling
|
| 394 |
+
|
| 395 |
+
**For Time-Series Models:**
|
| 396 |
+
|
| 397 |
+
- Data is aggregated by date (total daily sales)
|
| 398 |
+
- Uses chronological split (80% train, 20% validation)
|
| 399 |
+
- Prophet automatically handles:
|
| 400 |
+
- Weekly seasonality
|
| 401 |
+
- Yearly seasonality
|
| 402 |
+
- Trend components
|
| 403 |
+
|
| 404 |
+
## 📈 Evaluation Metrics
|
| 405 |
+
|
| 406 |
+
The system evaluates models using three metrics:
|
| 407 |
+
|
| 408 |
+
1. **MAE (Mean Absolute Error)**
|
| 409 |
+
- Average absolute difference between predicted and actual values
|
| 410 |
+
- Lower is better
|
| 411 |
+
- Units: same as target variable (sales quantity)
|
| 412 |
+
|
| 413 |
+
2. **RMSE (Root Mean Squared Error)**
|
| 414 |
+
- Square root of average squared differences
|
| 415 |
+
- Penalizes large errors more than MAE
|
| 416 |
+
- Lower is better
|
| 417 |
+
- Units: same as target variable (sales quantity)
|
| 418 |
+
|
| 419 |
+
3. **R2 Score (Coefficient of Determination)**
|
| 420 |
+
- Proportion of variance explained by the model
|
| 421 |
+
- Range: -∞ to 1 (1 is perfect prediction)
|
| 422 |
+
- Higher is better
|
| 423 |
+
- Used for model selection
|
| 424 |
+
|
| 425 |
+
**Model Selection**: The model with the highest R2 score is selected as the best model.
|
| 426 |
+
|
| 427 |
+
## 📊 Visualizations
|
| 428 |
+
|
| 429 |
+
The training script generates several visualizations:
|
| 430 |
+
|
| 431 |
+
1. **Demand Trends Over Time** (`plots/demand_trends.png`)
|
| 432 |
+
- Shows total daily sales quantity over the entire time period
|
| 433 |
+
- Helps identify overall trends and patterns
|
| 434 |
+
|
| 435 |
+
2. **Monthly Average Demand** (`plots/monthly_demand.png`)
|
| 436 |
+
- Bar chart showing average sales by month
|
| 437 |
+
- Reveals seasonal patterns (e.g., holiday season spikes)
|
| 438 |
+
|
| 439 |
+
3. **Feature Importance** (`plots/feature_importance.png`)
|
| 440 |
+
- Shows which features are most important for predictions
|
| 441 |
+
- Only available for tree-based models (Random Forest, XGBoost)
|
| 442 |
+
|
| 443 |
+
4. **Model Comparison** (`plots/model_comparison.png`)
|
| 444 |
+
- Side-by-side comparison of all models (ML and Time-Series)
|
| 445 |
+
- Color-coded: ML models (blue) vs Time-Series models (orange/red)
|
| 446 |
+
- Shows MAE, RMSE, and R2 Score for each model
|
| 447 |
+
|
| 448 |
+
5. **Time-Series Predictions** (`plots/timeseries_predictions.png`)
|
| 449 |
+
- Actual vs predicted plots for ARIMA and Prophet models
|
| 450 |
+
- Shows how well time-series models capture temporal patterns
|
| 451 |
+
- Only generated if time-series models are available
|
| 452 |
+
|
| 453 |
+
## 🔮 Example Predictions
|
| 454 |
+
|
| 455 |
+
Here are some example predictions to demonstrate the system:
|
| 456 |
+
|
| 457 |
+
```bash
|
| 458 |
+
# Example 1: Electronics on a weekday
|
| 459 |
+
python predict.py --product_id 1 --date 2024-03-15 --price 500 --discount 10 --category Electronics
|
| 460 |
+
# Expected: Moderate demand (weekday, some discount)
|
| 461 |
+
|
| 462 |
+
# Example 2: Clothing on weekend
|
| 463 |
+
python predict.py --product_id 5 --date 2024-07-06 --price 50 --discount 20 --category Clothing
|
| 464 |
+
# Expected: Higher demand (weekend, good discount)
|
| 465 |
+
|
| 466 |
+
# Example 3: Holiday season prediction
|
| 467 |
+
python predict.py --product_id 10 --date 2024-12-20 --price 100 --discount 25 --category Toys
|
| 468 |
+
# Expected: High demand (holiday season, good discount)
|
| 469 |
+
```
|
| 470 |
+
|
| 471 |
+
## 🔧 Technical Details
|
| 472 |
+
|
| 473 |
+
### Data Preprocessing Pipeline
|
| 474 |
+
|
| 475 |
+
1. **Date Conversion**: Convert date strings to datetime objects
|
| 476 |
+
2. **Feature Extraction**: Extract temporal features from dates
|
| 477 |
+
3. **Missing Value Handling**: Fill missing values with median (if any)
|
| 478 |
+
4. **Categorical Encoding**: Label encode product_id and category
|
| 479 |
+
5. **Feature Scaling**: Standardize numerical features using StandardScaler
|
| 480 |
+
|
| 481 |
+
### Model Training Pipeline
|
| 482 |
+
|
| 483 |
+
1. **Data Splitting**: 80% training, 20% validation
|
| 484 |
+
2. **Model Training**: Train all available models
|
| 485 |
+
3. **Evaluation**: Calculate MAE, RMSE, and R2 for each model
|
| 486 |
+
4. **Selection**: Choose model with highest R2 score
|
| 487 |
+
5. **Persistence**: Save model, encoders, and scaler
|
| 488 |
+
|
| 489 |
+
### Prediction Pipeline
|
| 490 |
+
|
| 491 |
+
1. **Load Model**: Load trained model and preprocessing objects
|
| 492 |
+
2. **Feature Preparation**: Extract features from input parameters
|
| 493 |
+
3. **Encoding**: Encode categorical variables using saved encoders
|
| 494 |
+
4. **Scaling**: Scale features using saved scaler
|
| 495 |
+
5. **Prediction**: Make prediction using loaded model
|
| 496 |
+
6. **Post-processing**: Ensure non-negative predictions
|
| 497 |
+
|
| 498 |
+
### Handling Unseen Data
|
| 499 |
+
|
| 500 |
+
The prediction script handles cases where:
|
| 501 |
+
- Product ID was not seen during training (uses default encoding)
|
| 502 |
+
- Category was not seen during training (uses default encoding)
|
| 503 |
+
|
| 504 |
+
Warnings are displayed in such cases.
|
| 505 |
+
|
| 506 |
+
## 🎓 Learning Points
|
| 507 |
+
|
| 508 |
+
This project demonstrates:
|
| 509 |
+
|
| 510 |
+
1. **Supervised Learning**: Regression problem solving
|
| 511 |
+
2. **Feature Engineering**: Creating meaningful features from raw data
|
| 512 |
+
3. **Model Comparison**: Training and evaluating multiple models
|
| 513 |
+
4. **Model Selection**: Automatic best model selection
|
| 514 |
+
5. **Model Persistence**: Saving and loading trained models
|
| 515 |
+
6. **Production-Ready Code**: Clean, modular, well-documented code
|
| 516 |
+
7. **Time Series Features**: Extracting temporal patterns
|
| 517 |
+
8. **Categorical Encoding**: Handling categorical variables
|
| 518 |
+
9. **Feature Scaling**: Normalizing features for better performance
|
| 519 |
+
10. **Evaluation Metrics**: Understanding different regression metrics
|
| 520 |
+
|
| 521 |
+
## 🐛 Troubleshooting
|
| 522 |
+
|
| 523 |
+
### Issue: "Model not found"
|
| 524 |
+
**Solution**: Run `python train_model.py` first to train and save the model.
|
| 525 |
+
|
| 526 |
+
### Issue: "XGBoost not available"
|
| 527 |
+
**Solution**: Install XGBoost with `pip install xgboost`; otherwise the system will still run, simply skipping the XGBoost model.
|
| 528 |
+
|
| 529 |
+
### Issue: "Category not seen during training"
|
| 530 |
+
**Solution**: This is handled automatically with a warning. The system uses a default encoding.
|
| 531 |
+
|
| 532 |
+
### Issue: Poor prediction accuracy
|
| 533 |
+
**Solutions**:
|
| 534 |
+
- Ensure you have sufficient training data
|
| 535 |
+
- Check that input features are in the same range as training data
|
| 536 |
+
- Try retraining with different hyperparameters
|
| 537 |
+
- Consider adding more features or more training data
|
| 538 |
+
|
| 539 |
+
## 📝 Notes
|
| 540 |
+
|
| 541 |
+
- The synthetic dataset generator creates realistic patterns including:
|
| 542 |
+
- Weekend effects (higher sales on weekends)
|
| 543 |
+
- Seasonal patterns (holiday season spikes)
|
| 544 |
+
- Price and discount effects
|
| 545 |
+
- Category-specific base prices
|
| 546 |
+
|
| 547 |
+
- For production use, consider:
|
| 548 |
+
- Using real historical data
|
| 549 |
+
- Retraining models periodically
|
| 550 |
+
- Adding more features (promotions, weather, etc.)
|
| 551 |
+
- Implementing model versioning
|
| 552 |
+
- Adding prediction confidence intervals
|
| 553 |
+
|
| 554 |
+
## 📄 License
|
| 555 |
+
|
| 556 |
+
This project is provided as-is for educational purposes.
|
| 557 |
+
|
| 558 |
+
## 👤 Author
|
| 559 |
+
|
| 560 |
+
Created as a complete machine learning project demonstrating demand prediction for e-commerce.
|
| 561 |
+
|
| 562 |
+
---
|
| 563 |
+
|
| 564 |
+
**Happy Predicting! 🚀**
|
app.py
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Demand Prediction System - Streamlit Dashboard
|
| 3 |
+
|
| 4 |
+
Interactive dashboard for visualizing sales trends and making demand predictions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import streamlit as st
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
import joblib
|
| 13 |
+
import json
|
| 14 |
+
from datetime import datetime, timedelta, date as dt_date
|
| 15 |
+
import os
|
| 16 |
+
import warnings
|
| 17 |
+
warnings.filterwarnings('ignore')
|
| 18 |
+
|
| 19 |
+
# Page configuration
|
| 20 |
+
st.set_page_config(
|
| 21 |
+
page_title="Demand Prediction Dashboard",
|
| 22 |
+
page_icon="📊",
|
| 23 |
+
layout="wide",
|
| 24 |
+
initial_sidebar_state="expanded"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Custom CSS for better styling
|
| 28 |
+
st.markdown("""
|
| 29 |
+
<style>
|
| 30 |
+
.main-header {
|
| 31 |
+
font-size: 2.5rem;
|
| 32 |
+
font-weight: bold;
|
| 33 |
+
color: #1f77b4;
|
| 34 |
+
text-align: center;
|
| 35 |
+
margin-bottom: 2rem;
|
| 36 |
+
}
|
| 37 |
+
.metric-card {
|
| 38 |
+
background-color: #f0f2f6;
|
| 39 |
+
padding: 1rem;
|
| 40 |
+
border-radius: 0.5rem;
|
| 41 |
+
margin: 0.5rem 0;
|
| 42 |
+
}
|
| 43 |
+
.stButton>button {
|
| 44 |
+
width: 100%;
|
| 45 |
+
background-color: #1f77b4;
|
| 46 |
+
color: white;
|
| 47 |
+
}
|
| 48 |
+
</style>
|
| 49 |
+
""", unsafe_allow_html=True)
|
| 50 |
+
|
| 51 |
+
# Configuration
|
| 52 |
+
DATA_PATH = 'data/sales.csv'
|
| 53 |
+
MODEL_DIR = 'models'
|
| 54 |
+
MODEL_PATH = f'{MODEL_DIR}/best_model.joblib'
|
| 55 |
+
TS_MODEL_PATH = f'{MODEL_DIR}/best_timeseries_model.joblib'
|
| 56 |
+
PREPROCESSING_PATH = f'{MODEL_DIR}/preprocessing.joblib'
|
| 57 |
+
ALL_MODELS_METADATA_PATH = f'{MODEL_DIR}/all_models_metadata.json'
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@st.cache_data
def load_data():
    """Read the sales CSV (if present) and parse its date column.

    Returns:
        DataFrame with a datetime 'date' column, or None when the CSV
        has not been generated yet.
    """
    if not os.path.exists(DATA_PATH):
        return None
    sales = pd.read_csv(DATA_PATH)
    sales['date'] = pd.to_datetime(sales['date'])
    return sales
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@st.cache_resource
def load_models():
    """Load trained models and preprocessing artifacts with caching.

    Returns:
        dict with keys:
            ml_model: fitted ML regressor, or None if unavailable.
            ts_model: fitted time-series model, or None if unavailable.
            preprocessing: dict of encoders/scaler/feature names, or None.
            model_name: name of the best model from metadata, or None.
            is_timeseries: True when the best model is ARIMA/Prophet.
            metadata: parsed all_models_metadata.json, or None.
    """
    models = {
        'ml_model': None,
        'ts_model': None,
        'preprocessing': None,
        'model_name': None,
        'is_timeseries': False,
        'metadata': None
    }

    # Load metadata describing the best model and its metrics.
    if os.path.exists(ALL_MODELS_METADATA_PATH):
        with open(ALL_MODELS_METADATA_PATH, 'r') as f:
            models['metadata'] = json.load(f)
        models['model_name'] = models['metadata'].get('best_model', 'Unknown')
        models['is_timeseries'] = models['model_name'] in ['ARIMA', 'Prophet']

    # Each artifact is optional, so loading stays best-effort — but we catch
    # Exception (not a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit) and surface the failure instead of
    # hiding it.

    # Load ML model
    if os.path.exists(MODEL_PATH):
        try:
            models['ml_model'] = joblib.load(MODEL_PATH)
        except Exception as exc:
            st.warning(f"Could not load ML model: {exc}")

    # Load time-series model
    if os.path.exists(TS_MODEL_PATH):
        try:
            models['ts_model'] = joblib.load(TS_MODEL_PATH)
        except Exception as exc:
            st.warning(f"Could not load time-series model: {exc}")

    # Load preprocessing
    if os.path.exists(PREPROCESSING_PATH):
        try:
            models['preprocessing'] = joblib.load(PREPROCESSING_PATH)
        except Exception as exc:
            st.warning(f"Could not load preprocessing artifacts: {exc}")

    return models
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def prepare_features_ml(product_id, date, price, discount, category, preprocessing_data):
    """Build the scaled feature row the ML model expects for one prediction.

    Args:
        product_id: product identifier (ideally one seen during training).
        date: datetime.date / datetime / Timestamp / string prediction date.
        price: product price.
        discount: discount percentage.
        category: product category label.
        preprocessing_data: dict with 'encoders', 'scaler' and
            'feature_names' saved at training time, or None.

    Returns:
        2-D array of shape (1, n_features) scaled like the training data,
        or None when preprocessing_data is unavailable.
    """
    if preprocessing_data is None:
        return None

    # Normalize the date to a pandas Timestamp regardless of input type
    # (datetime.date instances first, since pd.to_datetime handles the rest).
    if isinstance(date, dt_date):
        ts = pd.Timestamp(date)
    elif not isinstance(date, pd.Timestamp):
        ts = pd.to_datetime(date)
    else:
        ts = date

    # Encode categorical variables with the encoders fitted at training time.
    encoders = preprocessing_data['encoders']
    try:
        cat_code = encoders['category'].transform([category])[0]
    except ValueError:
        # Unseen category: fall back to code 0.
        cat_code = 0

    try:
        prod_code = encoders['product_id'].transform([product_id])[0]
    except ValueError:
        # Unseen product: fall back to the first product seen in training.
        first_known = encoders['product_id'].classes_[0]
        prod_code = encoders['product_id'].transform([first_known])[0]

    # Assemble the calendar and pricing features.
    row = {
        'price': price,
        'discount': discount,
        'day': ts.day,
        'month': ts.month,
        'day_of_week': ts.weekday(),
        'weekend': 1 if ts.weekday() >= 5 else 0,
        'year': ts.year,
        'quarter': ts.quarter,
        'category_encoded': cat_code,
        'product_id_encoded': prod_code,
    }

    # Order columns exactly as during training, then apply the fitted scaler.
    ordered = np.array([[row[name] for name in preprocessing_data['feature_names']]])
    return preprocessing_data['scaler'].transform(ordered)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def predict_ml(product_id, date, price, discount, category, model, preprocessing_data):
    """Predict demand for one product/date using the ML model.

    Returns:
        Non-negative predicted quantity, or None when the preprocessing
        artifacts are unavailable.
    """
    row = prepare_features_ml(product_id, date, price, discount, category, preprocessing_data)
    if row is None:
        return None
    # Clamp at zero: a regressor can extrapolate to negative demand.
    return max(0, model.predict(row)[0])
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def predict_timeseries(date, model, model_name):
    """Predict total daily demand for *date* with a time-series model.

    Args:
        date: datetime.date / datetime / Timestamp / string prediction date.
        model: fitted ARIMA results object or Prophet model.
        model_name: 'ARIMA' or 'Prophet'; any other name returns None.

    Returns:
        Non-negative forecast value, or None on failure / unknown model.
    """
    # Normalize the date to a pandas Timestamp regardless of input type.
    if isinstance(date, dt_date):
        date = pd.Timestamp(date)
    elif not isinstance(date, pd.Timestamp):
        date = pd.to_datetime(date)

    if model_name == 'ARIMA':
        try:
            forecast = model.forecast(steps=1)
            # forecast may be a pandas Series (possibly date-indexed, where
            # forecast[0] would raise KeyError), an ndarray/list, or a bare
            # scalar — handle each shape.
            if hasattr(forecast, 'iloc'):
                prediction = forecast.iloc[0]
            elif hasattr(forecast, '__iter__'):
                prediction = forecast[0]
            else:
                prediction = forecast
            return max(0, prediction)
        except Exception:
            # Narrow except: a bare `except:` would also swallow
            # KeyboardInterrupt/SystemExit.
            return None

    elif model_name == 'Prophet':
        try:
            future = pd.DataFrame({'ds': [date]})
            forecast = model.predict(future)
            prediction = forecast['yhat'].iloc[0]
            return max(0, prediction)
        except Exception:
            return None

    return None
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def main():
    """Render the dashboard: header, sidebar navigation, selected page."""

    # Header
    st.markdown('<h1 class="main-header">📊 Demand Prediction Dashboard</h1>', unsafe_allow_html=True)

    # Data is a hard requirement; bail out early with guidance if missing.
    df = load_data()
    if df is None:
        st.error("❌ Sales data not found. Please run generate_dataset.py first.")
        return

    models = load_models()

    # Sidebar: page selector plus a short summary of the best trained model.
    with st.sidebar:
        st.header("⚙️ Navigation")
        page = st.radio(
            "Select Page",
            ["📈 Sales Trends", "🔮 Demand Prediction", "📊 Model Comparison"],
            index=0
        )

        st.markdown("---")
        st.header("ℹ️ Information")
        metadata = models['metadata']
        if metadata:
            best_model = metadata.get('best_model', 'Unknown')
            st.info(f"**Best Model:** {best_model}")
            if best_model in metadata.get('all_models', {}):
                metrics = metadata['all_models'][best_model]
                st.metric("R2 Score", f"{metrics.get('r2', 0):.4f}")

    # Dispatch to the renderer for the page the user picked.
    renderers = {
        "📈 Sales Trends": lambda: show_sales_trends(df),
        "🔮 Demand Prediction": lambda: show_prediction_interface(df, models),
        "📊 Model Comparison": lambda: show_model_comparison(models),
    }
    renderer = renderers.get(page)
    if renderer is not None:
        renderer()
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def show_sales_trends(df):
    """Render the Sales Trends page: filters plus four analysis tabs.

    Args:
        df: sales DataFrame with columns date, product_id, category,
            price, discount and sales_quantity.
    """
    st.header("📈 Sales Trends Analysis")

    # Filters
    col1, col2, col3 = st.columns(3)

    with col1:
        categories = ['All'] + sorted(df['category'].unique().tolist())
        selected_category = st.selectbox("Select Category", categories)

    with col2:
        products = ['All'] + sorted(df['product_id'].unique().tolist())
        selected_product = st.selectbox("Select Product", products)

    with col3:
        date_range = st.date_input(
            "Select Date Range",
            value=(df['date'].min(), df['date'].max()),
            min_value=df['date'].min(),
            max_value=df['date'].max()
        )

    # Apply the selected filters on a copy so the cached frame stays intact.
    filtered_df = df.copy()

    if selected_category != 'All':
        filtered_df = filtered_df[filtered_df['category'] == selected_category]

    if selected_product != 'All':
        filtered_df = filtered_df[filtered_df['product_id'] == int(selected_product)]

    # st.date_input returns a single date while the user is mid-selection;
    # only filter once a complete (start, end) tuple is available.
    if isinstance(date_range, tuple) and len(date_range) == 2:
        filtered_df = filtered_df[
            (filtered_df['date'] >= pd.to_datetime(date_range[0])) &
            (filtered_df['date'] <= pd.to_datetime(date_range[1]))
        ]

    if len(filtered_df) == 0:
        st.warning("No data available for selected filters.")
        return

    # Visualizations
    tab1, tab2, tab3, tab4 = st.tabs(["📅 Daily Trends", "📆 Monthly Trends", "📦 Category Analysis", "💰 Price vs Demand"])

    with tab1:
        st.subheader("Daily Sales Trends")
        daily_sales = filtered_df.groupby('date')['sales_quantity'].sum().reset_index()

        fig, ax = plt.subplots(figsize=(14, 6))
        ax.plot(daily_sales['date'], daily_sales['sales_quantity'], linewidth=2, alpha=0.7)
        ax.set_title('Total Daily Sales Quantity', fontsize=16, fontweight='bold')
        ax.set_xlabel('Date', fontsize=12)
        ax.set_ylabel('Sales Quantity', fontsize=12)
        ax.grid(True, alpha=0.3)
        plt.xticks(rotation=45)
        plt.tight_layout()
        st.pyplot(fig)
        # Close the figure after rendering: Streamlit keeps the image, and
        # open figures otherwise accumulate in matplotlib across reruns.
        plt.close(fig)

        # Statistics
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Sales", f"{daily_sales['sales_quantity'].sum():,.0f}")
        with col2:
            st.metric("Average Daily", f"{daily_sales['sales_quantity'].mean():.1f}")
        with col3:
            st.metric("Max Daily", f"{daily_sales['sales_quantity'].max():,.0f}")
        with col4:
            st.metric("Min Daily", f"{daily_sales['sales_quantity'].min():,.0f}")

    with tab2:
        st.subheader("Monthly Sales Trends")
        filtered_df['month_year'] = filtered_df['date'].dt.to_period('M')
        monthly_sales = filtered_df.groupby('month_year')['sales_quantity'].sum().reset_index()
        monthly_sales['month_year'] = monthly_sales['month_year'].astype(str)

        fig, ax = plt.subplots(figsize=(14, 6))
        ax.bar(range(len(monthly_sales)), monthly_sales['sales_quantity'], alpha=0.7, color='steelblue')
        ax.set_title('Monthly Sales Quantity', fontsize=16, fontweight='bold')
        ax.set_xlabel('Month', fontsize=12)
        ax.set_ylabel('Sales Quantity', fontsize=12)
        ax.set_xticks(range(len(monthly_sales)))
        ax.set_xticklabels(monthly_sales['month_year'], rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)  # prevent figure accumulation across reruns

    with tab3:
        st.subheader("Sales by Category")
        category_sales = filtered_df.groupby('category')['sales_quantity'].sum().sort_values(ascending=False)

        fig, ax = plt.subplots(figsize=(12, 6))
        category_sales.plot(kind='barh', ax=ax, color='coral', alpha=0.7)
        ax.set_title('Total Sales by Category', fontsize=16, fontweight='bold')
        ax.set_xlabel('Total Sales Quantity', fontsize=12)
        ax.set_ylabel('Category', fontsize=12)
        ax.grid(True, alpha=0.3, axis='x')
        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)  # prevent figure accumulation across reruns

        # Category statistics table
        category_stats = filtered_df.groupby('category').agg({
            'sales_quantity': ['sum', 'mean', 'std', 'min', 'max']
        }).round(2)
        category_stats.columns = ['Total', 'Average', 'Std Dev', 'Min', 'Max']
        st.dataframe(category_stats, use_container_width=True)

    with tab4:
        st.subheader("Price vs Demand Relationship")

        # Scatter plot
        fig, ax = plt.subplots(figsize=(12, 6))
        scatter = ax.scatter(filtered_df['price'], filtered_df['sales_quantity'],
                             c=filtered_df['discount'], cmap='viridis', alpha=0.6, s=50)
        ax.set_title('Price vs Sales Quantity (colored by discount)', fontsize=16, fontweight='bold')
        ax.set_xlabel('Price', fontsize=12)
        ax.set_ylabel('Sales Quantity', fontsize=12)
        ax.grid(True, alpha=0.3)
        plt.colorbar(scatter, ax=ax, label='Discount %')
        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)  # prevent figure accumulation across reruns

        # Correlation
        correlation = filtered_df['price'].corr(filtered_df['sales_quantity'])
        st.metric("Price-Demand Correlation", f"{correlation:.3f}")
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def show_prediction_interface(df, models):
    """Render the Demand Prediction page.

    Shows either an overall daily forecast (time-series model) or a
    product-specific forecast (ML model), depending on the selected model
    type and on which model was trained as the best.

    Args:
        df: sales DataFrame (used for dropdown choices and date bounds).
        models: dict returned by load_models().
    """
    st.header("🔮 Demand Prediction")

    # Check if models are available
    if models['ml_model'] is None and models['ts_model'] is None:
        st.error("❌ No trained models found. Please run train_model.py first.")
        return

    # Model selection
    model_type = st.radio(
        "Select Model Type",
        ["Auto (Best Model)", "Machine Learning", "Time-Series"],
        horizontal=True
    )

    st.markdown("---")

    if model_type == "Time-Series" or (model_type == "Auto (Best Model)" and models['is_timeseries']):
        # Time-series prediction: one aggregate number for the whole catalog.
        st.subheader("Overall Daily Demand Prediction")

        col1, col2 = st.columns(2)
        with col1:
            # Only allow dates strictly after the last observed sales date.
            prediction_date = st.date_input(
                "Select Date for Prediction",
                value=datetime.now().date() + timedelta(days=30),
                min_value=df['date'].max().date() + timedelta(days=1)
            )

        with col2:
            st.write("")  # Spacing
            st.write("")  # Spacing

        if st.button("🔮 Predict Demand", type="primary"):
            if models['ts_model'] is None:
                st.error("Time-series model not available.")
            else:
                with st.spinner("Making prediction..."):
                    prediction = predict_timeseries(
                        prediction_date,
                        models['ts_model'],
                        models['model_name']
                    )

                    if prediction is not None:
                        st.success("✅ Prediction Complete!")

                        col1, col2, col3 = st.columns(3)
                        with col1:
                            st.metric("Predicted Daily Demand", f"{prediction:,.0f} units")
                        with col2:
                            day_name = pd.to_datetime(prediction_date).strftime('%A')
                            st.metric("Day of Week", day_name)
                        with col3:
                            is_weekend = "Yes" if pd.to_datetime(prediction_date).weekday() >= 5 else "No"
                            st.metric("Weekend", is_weekend)

                        st.info("💡 This prediction represents the total daily demand across all products.")
                    else:
                        st.error("Failed to make prediction.")

    else:
        # ML model prediction: per-product, per-date forecast.
        st.subheader("Product-Specific Demand Prediction")

        # Get unique values for dropdowns
        categories = sorted(df['category'].unique().tolist())
        products = sorted(df['product_id'].unique().tolist())

        col1, col2 = st.columns(2)

        with col1:
            selected_category = st.selectbox("Select Category", categories)
            selected_product = st.selectbox("Select Product ID", products)
            prediction_date = st.date_input(
                "Select Date for Prediction",
                value=datetime.now().date() + timedelta(days=30),
                min_value=df['date'].max().date() + timedelta(days=1)
            )

        with col2:
            price = st.number_input(
                "Product Price ($)",
                min_value=0.01,
                value=100.0,
                step=1.0,
                format="%.2f"
            )
            discount = st.slider(
                "Discount (%)",
                min_value=0,
                max_value=100,
                value=0,
                step=5
            )

        # Show historical statistics for the chosen product, if any.
        product_data = df[df['product_id'] == selected_product]
        if len(product_data) > 0:
            with st.expander("📊 Product Statistics"):
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Avg Price", f"${product_data['price'].mean():.2f}")
                with col2:
                    st.metric("Avg Sales", f"{product_data['sales_quantity'].mean():.1f}")
                with col3:
                    st.metric("Total Sales", f"{product_data['sales_quantity'].sum():,.0f}")
                with col4:
                    st.metric("Category", selected_category)

        if st.button("🔮 Predict Demand", type="primary"):
            if models['ml_model'] is None or models['preprocessing'] is None:
                st.error("ML model or preprocessing not available.")
            else:
                with st.spinner("Making prediction..."):
                    prediction = predict_ml(
                        selected_product,
                        prediction_date,
                        price,
                        discount,
                        selected_category,
                        models['ml_model'],
                        models['preprocessing']
                    )

                    if prediction is not None:
                        st.success("✅ Prediction Complete!")

                        col1, col2, col3, col4 = st.columns(4)
                        with col1:
                            st.metric("Predicted Demand", f"{prediction:,.0f} units")
                        with col2:
                            st.metric("Price", f"${price:.2f}")
                        with col3:
                            st.metric("Discount", f"{discount}%")
                        with col4:
                            day_name = pd.to_datetime(prediction_date).strftime('%A')
                            st.metric("Day", day_name)

                        # Additional insights derived from the chosen inputs.
                        st.markdown("### 📈 Prediction Insights")
                        date_obj = pd.to_datetime(prediction_date)
                        is_weekend = date_obj.weekday() >= 5
                        month = date_obj.month

                        insights = []
                        if is_weekend:
                            insights.append("📅 Weekend - typically higher demand")
                        if month in [11, 12]:
                            insights.append("🎄 Holiday season - peak sales period")
                        if discount > 0:
                            insights.append(f"💰 {discount}% discount - may increase demand")

                        if insights:
                            for insight in insights:
                                st.info(insight)
                    else:
                        st.error("Failed to make prediction.")
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def show_model_comparison(models):
    """Render the Model Comparison page from the saved training metadata.

    Args:
        models: dict returned by load_models(); only 'metadata' is used.
    """
    st.header("📊 Model Comparison")

    if models['metadata'] is None:
        st.warning("Model metadata not available. Please run train_model.py first.")
        return

    metadata = models['metadata']
    all_models = metadata.get('all_models', {})
    best_model = metadata.get('best_model', 'Unknown')

    if not all_models:
        st.warning("No model comparison data available.")
        return

    # Model metrics table
    st.subheader("Model Performance Metrics")

    comparison_data = []
    for model_name, metrics in all_models.items():
        comparison_data.append({
            'Model': model_name,
            'Type': 'Time-Series' if model_name in ['ARIMA', 'Prophet'] else 'Machine Learning',
            'MAE': metrics.get('mae', 0),
            'RMSE': metrics.get('rmse', 0),
            'R2 Score': metrics.get('r2', 0)
        })

    comparison_df = pd.DataFrame(comparison_data)

    # Highlight the best model's row in the table.
    def highlight_best(row):
        if row['Model'] == best_model:
            return ['background-color: #90EE90'] * len(row)
        return [''] * len(row)

    st.dataframe(
        comparison_df.style.apply(highlight_best, axis=1),
        use_container_width=True
    )

    # Visualizations
    st.subheader("Performance Comparison Charts")

    col1, col2 = st.columns(2)

    with col1:
        fig, ax = plt.subplots(figsize=(10, 6))
        model_names = comparison_df['Model'].tolist()
        mae_scores = comparison_df['MAE'].tolist()

        colors = ['coral' if name in ['ARIMA', 'Prophet'] else 'skyblue' for name in model_names]
        ax.bar(model_names, mae_scores, color=colors, alpha=0.7)
        ax.set_title('MAE Comparison (Lower is Better)', fontsize=14, fontweight='bold')
        ax.set_ylabel('MAE', fontsize=12)
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        st.pyplot(fig)
        # Close after rendering so figures don't accumulate across reruns.
        plt.close(fig)

    with col2:
        fig, ax = plt.subplots(figsize=(10, 6))
        r2_scores = comparison_df['R2 Score'].tolist()

        colors = ['coral' if name in ['ARIMA', 'Prophet'] else 'skyblue' for name in model_names]
        ax.bar(model_names, r2_scores, color=colors, alpha=0.7)
        ax.set_title('R2 Score Comparison (Higher is Better)', fontsize=14, fontweight='bold')
        ax.set_ylabel('R2 Score', fontsize=12)
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)  # prevent figure accumulation across reruns

    # Best model info
    st.markdown("---")
    st.success(f"🏆 **Best Model: {best_model}**")
    if best_model in all_models:
        best_metrics = all_models[best_model]
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("MAE", f"{best_metrics.get('mae', 0):.2f}")
        with col2:
            st.metric("RMSE", f"{best_metrics.get('rmse', 0):.2f}")
        with col3:
            st.metric("R2 Score", f"{best_metrics.get('r2', 0):.4f}")
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
if __name__ == "__main__":
|
| 628 |
+
main()
|
data/sales.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generate_dataset.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate Synthetic E-commerce Sales Dataset
|
| 3 |
+
|
| 4 |
+
This script creates a realistic synthetic dataset for demand prediction.
|
| 5 |
+
The dataset includes temporal patterns, seasonality, and realistic relationships
|
| 6 |
+
between features and sales quantity.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
from datetime import datetime, timedelta
|
| 12 |
+
|
| 13 |
+
# Set random seed for reproducibility
|
| 14 |
+
np.random.seed(42)
|
| 15 |
+
|
| 16 |
+
# Configuration
|
| 17 |
+
NUM_PRODUCTS = 50
|
| 18 |
+
START_DATE = datetime(2020, 1, 1)
|
| 19 |
+
END_DATE = datetime(2023, 12, 31)
|
| 20 |
+
CATEGORIES = ['Electronics', 'Clothing', 'Home & Garden', 'Sports', 'Books',
|
| 21 |
+
'Toys', 'Beauty', 'Automotive', 'Food & Beverages', 'Health']
|
| 22 |
+
|
| 23 |
+
# Generate date range
|
| 24 |
+
date_range = pd.date_range(start=START_DATE, end=END_DATE, freq='D')
|
| 25 |
+
num_days = len(date_range)
|
| 26 |
+
|
| 27 |
+
# Initialize lists to store data
|
| 28 |
+
data = []
|
| 29 |
+
|
| 30 |
+
# Generate data for each product
|
| 31 |
+
for product_id in range(1, NUM_PRODUCTS + 1):
|
| 32 |
+
# Assign category randomly
|
| 33 |
+
category = np.random.choice(CATEGORIES)
|
| 34 |
+
|
| 35 |
+
# Base price varies by category
|
| 36 |
+
category_base_prices = {
|
| 37 |
+
'Electronics': 500,
|
| 38 |
+
'Clothing': 50,
|
| 39 |
+
'Home & Garden': 100,
|
| 40 |
+
'Sports': 150,
|
| 41 |
+
'Books': 20,
|
| 42 |
+
'Toys': 30,
|
| 43 |
+
'Beauty': 40,
|
| 44 |
+
'Automotive': 300,
|
| 45 |
+
'Food & Beverages': 25,
|
| 46 |
+
'Health': 60
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
base_price = category_base_prices[category] * (0.8 + np.random.random() * 0.4)
|
| 50 |
+
|
| 51 |
+
# Generate daily records
|
| 52 |
+
for date in date_range:
|
| 53 |
+
# Day of week effect (weekends have higher sales)
|
| 54 |
+
day_of_week = date.weekday()
|
| 55 |
+
weekend_multiplier = 1.3 if day_of_week >= 5 else 1.0
|
| 56 |
+
|
| 57 |
+
# Monthly seasonality (higher sales in Nov-Dec, lower in Jan-Feb)
|
| 58 |
+
month = date.month
|
| 59 |
+
if month in [11, 12]: # Holiday season
|
| 60 |
+
seasonality = 1.5
|
| 61 |
+
elif month in [1, 2]: # Post-holiday slump
|
| 62 |
+
seasonality = 0.7
|
| 63 |
+
elif month in [6, 7, 8]: # Summer
|
| 64 |
+
seasonality = 1.2
|
| 65 |
+
else:
|
| 66 |
+
seasonality = 1.0
|
| 67 |
+
|
| 68 |
+
# Random discount (0-30%)
|
| 69 |
+
discount = np.random.choice([0, 5, 10, 15, 20, 25, 30], p=[0.4, 0.2, 0.15, 0.1, 0.08, 0.05, 0.02])
|
| 70 |
+
|
| 71 |
+
# Price with discount
|
| 72 |
+
price = base_price * (1 - discount / 100)
|
| 73 |
+
|
| 74 |
+
# Base demand varies by product
|
| 75 |
+
base_demand = np.random.randint(10, 100)
|
| 76 |
+
|
| 77 |
+
# Calculate sales quantity with multiple factors
|
| 78 |
+
# Higher discount -> higher sales
|
| 79 |
+
discount_effect = 1 + (discount / 100) * 0.5
|
| 80 |
+
|
| 81 |
+
# Lower price -> higher sales (inverse relationship)
|
| 82 |
+
price_effect = 1 / (1 + (price / 1000) * 0.1)
|
| 83 |
+
|
| 84 |
+
# Add some randomness
|
| 85 |
+
noise = np.random.normal(1, 0.15)
|
| 86 |
+
|
| 87 |
+
# Calculate final sales quantity
|
| 88 |
+
sales_quantity = int(
|
| 89 |
+
base_demand *
|
| 90 |
+
weekend_multiplier *
|
| 91 |
+
seasonality *
|
| 92 |
+
discount_effect *
|
| 93 |
+
price_effect *
|
| 94 |
+
noise
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Ensure non-negative
|
| 98 |
+
sales_quantity = max(0, sales_quantity)
|
| 99 |
+
|
| 100 |
+
data.append({
|
| 101 |
+
'product_id': product_id,
|
| 102 |
+
'date': date.strftime('%Y-%m-%d'),
|
| 103 |
+
'price': round(price, 2),
|
| 104 |
+
'discount': discount,
|
| 105 |
+
'category': category,
|
| 106 |
+
'sales_quantity': sales_quantity
|
| 107 |
+
})
|
| 108 |
+
|
| 109 |
+
# Create DataFrame
|
| 110 |
+
df = pd.DataFrame(data)
|
| 111 |
+
|
| 112 |
+
# Shuffle the data
|
| 113 |
+
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
|
| 114 |
+
|
| 115 |
+
# Save to CSV
|
| 116 |
+
output_path = 'data/sales.csv'
|
| 117 |
+
df.to_csv(output_path, index=False)
|
| 118 |
+
|
| 119 |
+
print(f"Dataset generated successfully!")
|
| 120 |
+
print(f"Total records: {len(df)}")
|
| 121 |
+
print(f"Date range: {df['date'].min()} to {df['date'].max()}")
|
| 122 |
+
print(f"Number of products: {df['product_id'].nunique()}")
|
| 123 |
+
print(f"Categories: {df['category'].nunique()}")
|
| 124 |
+
print(f"\nDataset saved to: {output_path}")
|
| 125 |
+
print(f"\nFirst few rows:")
|
| 126 |
+
print(df.head(10))
|
| 127 |
+
print(f"\nDataset statistics:")
|
| 128 |
+
print(df.describe())
|
models/all_models_metadata.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_model": "XGBoost",
|
| 3 |
+
"best_metrics": {
|
| 4 |
+
"mae": 28.266036987304688,
|
| 5 |
+
"rmse": 34.58608132593768,
|
| 6 |
+
"r2": 0.19758522510528564
|
| 7 |
+
},
|
| 8 |
+
"all_models": {
|
| 9 |
+
"Linear Regression": {
|
| 10 |
+
"mae": 28.94336905682285,
|
| 11 |
+
"rmse": 35.499695024759994,
|
| 12 |
+
"r2": 0.15463271181982996
|
| 13 |
+
},
|
| 14 |
+
"Random Forest": {
|
| 15 |
+
"mae": 28.52530939054232,
|
| 16 |
+
"rmse": 34.98799112718141,
|
| 17 |
+
"r2": 0.17882785374399268
|
| 18 |
+
},
|
| 19 |
+
"XGBoost": {
|
| 20 |
+
"mae": 28.266036987304688,
|
| 21 |
+
"rmse": 34.58608132593768,
|
| 22 |
+
"r2": 0.19758522510528564
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
"saved_at": "2026-02-06 17:29:16"
|
| 26 |
+
}
|
models/best_model.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:815655a55967c29f9302ce2d52cd6707ae584cbcb25532722d4a2415acd246a1
|
| 3 |
+
size 495387
|
models/model_metadata.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "XGBoost",
|
| 3 |
+
"metrics": {
|
| 4 |
+
"mae": 28.266036987304688,
|
| 5 |
+
"rmse": 34.58608132593768,
|
| 6 |
+
"r2": 0.19758522510528564
|
| 7 |
+
},
|
| 8 |
+
"feature_names": [
|
| 9 |
+
"price",
|
| 10 |
+
"discount",
|
| 11 |
+
"day",
|
| 12 |
+
"month",
|
| 13 |
+
"day_of_week",
|
| 14 |
+
"weekend",
|
| 15 |
+
"year",
|
| 16 |
+
"quarter",
|
| 17 |
+
"category_encoded",
|
| 18 |
+
"product_id_encoded"
|
| 19 |
+
],
|
| 20 |
+
"saved_at": "2026-02-06 17:29:16"
|
| 21 |
+
}
|
models/preprocessing.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e093b6491538c4afa8a7dfcf6bc7e82118d4f5fc72fa024e2079f98a5b57cc5
|
| 3 |
+
size 2252
|
plots/demand_trends.png
ADDED
|
Git LFS Details
|
plots/feature_importance.png
ADDED
|
Git LFS Details
|
plots/model_comparison.png
ADDED
|
Git LFS Details
|
plots/monthly_demand.png
ADDED
|
Git LFS Details
|
predict.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Demand Prediction System - Prediction Script
|
| 3 |
+
|
| 4 |
+
This script loads a trained model and makes demand predictions for products
|
| 5 |
+
on future dates. Supports both ML models and time-series models (ARIMA, Prophet).
|
| 6 |
+
|
| 7 |
+
Usage (ML Models):
|
| 8 |
+
python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
|
| 9 |
+
|
| 10 |
+
Usage (Time-Series Models - overall demand):
|
| 11 |
+
python predict.py --date 2024-01-15 --model_type timeseries
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import numpy as np
|
| 16 |
+
import joblib
|
| 17 |
+
import json
|
| 18 |
+
import argparse
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
import os
|
| 21 |
+
import warnings
|
| 22 |
+
warnings.filterwarnings('ignore')
|
| 23 |
+
|
| 24 |
+
# Configuration
# All artifacts below are produced by train_model.py (the error messages in
# this script direct the user there when a file is missing).
MODEL_DIR = 'models'
MODEL_PATH = f'{MODEL_DIR}/best_model.joblib'                        # best ML regressor
TS_MODEL_PATH = f'{MODEL_DIR}/best_timeseries_model.joblib'          # best ARIMA/Prophet model
PREPROCESSING_PATH = f'{MODEL_DIR}/preprocessing.joblib'             # encoders + scaler + feature order
METADATA_PATH = f'{MODEL_DIR}/model_metadata.json'                   # best-ML-model name/metrics
ALL_MODELS_METADATA_PATH = f'{MODEL_DIR}/all_models_metadata.json'   # metrics for every trained model
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_model_and_preprocessing(model_type='auto'):
    """
    Load the trained model and preprocessing objects.

    Args:
        model_type: 'ml', 'timeseries', or 'auto' (auto-detect best model)

    Returns:
        tuple: (model, preprocessing_data, model_name, is_timeseries).
            preprocessing_data is None for time-series models.

    Raises:
        FileNotFoundError: if a required artifact file is missing.
    """
    # Load metadata to determine best model. Default to an empty dict so
    # later metric lookups are safe even when the metadata file is absent.
    all_metadata = {}
    if os.path.exists(ALL_MODELS_METADATA_PATH):
        with open(ALL_MODELS_METADATA_PATH, 'r') as f:
            all_metadata = json.load(f)
        best_model_name = all_metadata.get('best_model', 'Unknown')
    else:
        best_model_name = None

    # Resolve 'auto' to a concrete family based on which model won training.
    if model_type == 'auto':
        if best_model_name in ['ARIMA', 'Prophet']:
            model_type = 'timeseries'
        else:
            model_type = 'ml'

    is_timeseries = (model_type == 'timeseries')

    if is_timeseries:
        # Load time-series model
        if not os.path.exists(TS_MODEL_PATH):
            raise FileNotFoundError(
                f"Time-series model not found at {TS_MODEL_PATH}. Please run train_model.py first."
            )

        print("Loading time-series model...")
        model = joblib.load(TS_MODEL_PATH)
        preprocessing_data = None

        if best_model_name:
            print(f"Model: {best_model_name}")
            if best_model_name in all_metadata.get('all_models', {}):
                # BUG FIX: the original formatted metrics.get('r2', 'N/A')
                # with :.4f, which raises ValueError when 'r2' is missing.
                r2 = all_metadata['all_models'][best_model_name].get('r2')
                if isinstance(r2, (int, float)):
                    print(f"R2 Score: {r2:.4f}")
                else:
                    print("R2 Score: N/A")

        return model, preprocessing_data, best_model_name or 'Time-Series', True
    else:
        # Load ML model
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(
                f"ML model not found at {MODEL_PATH}. Please run train_model.py first."
            )

        if not os.path.exists(PREPROCESSING_PATH):
            raise FileNotFoundError(
                f"Preprocessing objects not found at {PREPROCESSING_PATH}. Please run train_model.py first."
            )

        print("Loading ML model and preprocessing objects...")
        model = joblib.load(MODEL_PATH)
        preprocessing_data = joblib.load(PREPROCESSING_PATH)

        # Load metadata if available
        if os.path.exists(METADATA_PATH):
            with open(METADATA_PATH, 'r') as f:
                metadata = json.load(f)
            model_name = metadata.get('model_name', 'ML Model')
            print(f"Model: {model_name}")
            # Same :.4f guard as above for a missing 'r2' metric.
            r2 = metadata.get('metrics', {}).get('r2')
            if isinstance(r2, (int, float)):
                print(f"R2 Score: {r2:.4f}")
            else:
                print("R2 Score: N/A")
        else:
            model_name = best_model_name or 'ML Model'

        return model, preprocessing_data, model_name, False
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def prepare_features(product_id, date, price, discount, category, preprocessing_data):
    """
    Prepare a single scaled feature row, mirroring the training pipeline.

    Args:
        product_id: Product ID
        date: Date string (YYYY-MM-DD), datetime object, or pandas Timestamp
        price: Product price
        discount: Discount percentage (0-100)
        category: Product category
        preprocessing_data: Dict with 'encoders' (category / product_id
            label encoders), 'scaler', and 'feature_names' (training order)

    Returns:
        numpy array of shape (1, n_features): scaled feature row.
    """
    # Normalize to a pandas Timestamp. The original only converted strings,
    # which broke for plain datetime.datetime inputs (no .quarter attribute);
    # pd.to_datetime handles str, datetime, and Timestamp alike.
    date = pd.to_datetime(date)

    # Extract date features (same as in training)
    day = date.day
    month = date.month
    day_of_week = date.weekday()  # 0=Monday, 6=Sunday
    weekend = 1 if day_of_week >= 5 else 0
    year = date.year
    quarter = date.quarter

    # Encode categorical variables
    category_encoder = preprocessing_data['encoders']['category']
    product_encoder = preprocessing_data['encoders']['product_id']

    # Handle unseen categories/products: LabelEncoder.transform raises
    # ValueError for labels it has never seen.
    try:
        category_encoded = category_encoder.transform([category])[0]
    except ValueError:
        # If category not seen during training, fall back to encoding 0.
        print(f"Warning: Category '{category}' not seen during training. Using default encoding.")
        category_encoded = 0

    try:
        product_id_encoded = product_encoder.transform([product_id])[0]
    except ValueError:
        # If product_id not seen during training, use the first known product.
        print(f"Warning: Product ID '{product_id}' not seen during training. Using default encoding.")
        product_id_encoded = product_encoder.transform([product_encoder.classes_[0]])[0]

    # Create feature dictionary
    feature_dict = {
        'price': price,
        'discount': discount,
        'day': day,
        'month': month,
        'day_of_week': day_of_week,
        'weekend': weekend,
        'year': year,
        'quarter': quarter,
        'category_encoded': category_encoded,
        'product_id_encoded': product_id_encoded
    }

    # Order features exactly as at training time, then scale.
    feature_names = preprocessing_data['feature_names']
    features = np.array([[feature_dict[name] for name in feature_names]])

    scaler = preprocessing_data['scaler']
    features_scaled = scaler.transform(features)

    return features_scaled
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def predict_demand_ml(product_id, date, price, discount, category, model, preprocessing_data):
    """
    Predict the sales quantity for one product on one date with an ML model.

    Args:
        product_id: Product ID
        date: Date string (YYYY-MM-DD) or datetime object
        price: Product price
        discount: Discount percentage (0-100)
        category: Product category
        model: Fitted regressor exposing .predict()
        preprocessing_data: Dict of encoders/scaler produced at training time

    Returns:
        float: Predicted sales quantity, clamped to be non-negative.
    """
    # Encode and scale the inputs exactly as during training, score the
    # single row, and never report a negative demand.
    feature_row = prepare_features(product_id, date, price, discount, category, preprocessing_data)
    raw_prediction = model.predict(feature_row)[0]
    return max(0, raw_prediction)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def predict_demand_timeseries(date, model, model_name):
    """
    Predict the overall daily demand with a fitted time-series model.

    Args:
        date: Date string (YYYY-MM-DD) or datetime object
        model: Fitted ARIMA or Prophet model object
        model_name: Name of the model ('ARIMA' or 'Prophet')

    Returns:
        float: Non-negative predicted total daily quantity, or None when the
        model name is unknown or the underlying library raises.
    """
    # Accept either a string or a datetime-like object.
    if isinstance(date, str):
        date = pd.to_datetime(date)

    if model_name == 'ARIMA':
        # NOTE(review): the requested date is not actually used here — the
        # model just forecasts one step past its training window. Computing
        # the true step count would require the training end date.
        try:
            step_ahead = model.forecast(steps=1)
            value = step_ahead[0] if hasattr(step_ahead, '__iter__') else step_ahead
            return max(0, value)
        except Exception as e:
            print(f"Error in ARIMA prediction: {e}")
            return None

    if model_name == 'Prophet':
        # Prophet predicts from a dataframe with a 'ds' (datestamp) column.
        try:
            frame = pd.DataFrame({'ds': [date]})
            value = model.predict(frame)['yhat'].iloc[0]
            return max(0, value)
        except Exception as e:
            print(f"Error in Prophet prediction: {e}")
            return None

    print(f"Unknown time-series model: {model_name}")
    return None
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def predict_batch(predictions_data, model, preprocessing_data):
    """
    Predict demand for multiple products/dates at once.

    Args:
        predictions_data: List of dictionaries, each containing:
            - product_id
            - date
            - price
            - discount
            - category
        model: Trained ML model
        preprocessing_data: Dictionary containing encoders and scaler

    Returns:
        list: Predicted sales quantities, one per input dictionary.
    """
    predictions = []

    for data in predictions_data:
        # BUG FIX: the original called the undefined name `predict_demand`,
        # which raised NameError at runtime; the ML predictor defined in this
        # module is `predict_demand_ml`.
        pred = predict_demand_ml(
            data['product_id'],
            data['date'],
            data['price'],
            data['discount'],
            data['category'],
            model,
            preprocessing_data
        )
        predictions.append(pred)

    return predictions
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def main():
    """
    Command-line entry point: parse arguments, load the best (or requested)
    model, print a demand prediction, and show calendar info for the date.
    """
    parser = argparse.ArgumentParser(
        description='Predict product demand for a given date and product details',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples (ML Models):
  python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
  python predict.py --product_id 5 --date 2024-06-20 --price 50 --discount 0 --category Clothing

Examples (Time-Series Models - overall daily demand):
  python predict.py --date 2024-01-15 --model_type timeseries
"""
    )

    # Product-specific flags default to None; they are only validated as
    # required once we know the ML path was selected.
    parser.add_argument('--product_id', type=int, default=None,
                        help='Product ID (required for ML models)')
    parser.add_argument('--date', type=str, required=True,
                        help='Date in YYYY-MM-DD format')
    parser.add_argument('--price', type=float, default=None,
                        help='Product price (required for ML models)')
    parser.add_argument('--discount', type=float, default=0,
                        help='Discount percentage (0-100), default: 0 (for ML models)')
    parser.add_argument('--category', type=str, default=None,
                        help='Product category (required for ML models)')
    parser.add_argument('--model_type', type=str, default='auto',
                        choices=['auto', 'ml', 'timeseries'],
                        help='Model type to use: auto (best model), ml, or timeseries')

    args = parser.parse_args()

    # Validate date format
    try:
        date_obj = pd.to_datetime(args.date)
    except ValueError:
        print(f"Error: Invalid date format '{args.date}'. Please use YYYY-MM-DD format.")
        return

    # Load model and preprocessing; missing artifacts are reported, not raised.
    try:
        model, preprocessing_data, model_name, is_timeseries = load_model_and_preprocessing(args.model_type)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return

    # Validate arguments based on model type
    if not is_timeseries:
        # ML model requires product details
        if args.product_id is None or args.price is None or args.category is None:
            print("Error: ML models require --product_id, --price, and --category arguments.")
            return

        # Validate discount range (clamped rather than rejected)
        if args.discount < 0 or args.discount > 100:
            print(f"Warning: Discount {args.discount}% is outside 0-100 range. Clamping to valid range.")
            args.discount = max(0, min(100, args.discount))

    # Make prediction
    print("\n" + "="*60)
    print("MAKING PREDICTION")
    print("="*60)
    print(f"Model: {model_name}")
    print(f"Model Type: {'Time-Series' if is_timeseries else 'Machine Learning'}")
    print(f"Date: {args.date}")

    if not is_timeseries:
        print(f"Product ID: {args.product_id}")
        print(f"Price: ${args.price:.2f}")
        print(f"Discount: {args.discount}%")
        print(f"Category: {args.category}")

    print("-"*60)

    if is_timeseries:
        # Time-series path: one aggregate figure for the whole day.
        predicted_demand = predict_demand_timeseries(
            args.date,
            model,
            model_name
        )

        if predicted_demand is None:
            print("Error: Failed to make prediction.")
            return

        print(f"\nPredicted Total Daily Sales Quantity: {predicted_demand:.0f} units")
        print("(This is the predicted total demand across all products for this date)")
    else:
        # ML path: per-product figure from the feature-based regressor.
        predicted_demand = predict_demand_ml(
            args.product_id,
            args.date,
            args.price,
            args.discount,
            args.category,
            model,
            preprocessing_data
        )

        print(f"\nPredicted Sales Quantity: {predicted_demand:.0f} units")
        print("(This is the predicted demand for this specific product)")

    print("="*60)

    # Additional information about the requested date
    date_obj = pd.to_datetime(args.date)
    day_name = date_obj.strftime('%A')
    is_weekend = "Yes" if date_obj.weekday() >= 5 else "No"

    print(f"\nDate Information:")
    print(f"  Day of week: {day_name}")
    print(f"  Weekend: {is_weekend}")
    print(f"  Month: {date_obj.strftime('%B')}")
    print(f"  Quarter: Q{date_obj.quarter}")
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas>=1.5.0
|
| 2 |
+
numpy>=1.23.0
|
| 3 |
+
scikit-learn>=1.2.0
|
| 4 |
+
matplotlib>=3.6.0
|
| 5 |
+
seaborn>=0.12.0
|
| 6 |
+
joblib>=1.2.0
|
| 7 |
+
xgboost>=1.7.0
|
| 8 |
+
statsmodels>=0.14.0
|
| 9 |
+
prophet>=1.1.0
|
| 10 |
+
streamlit>=1.28.0
|
setup_env.bat
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
REM One-time Windows setup: create a virtual environment, activate it for
REM this session, and install the project dependencies from requirements.txt.
echo Creating virtual environment...
python -m venv venv

echo.
echo Activating virtual environment...
call venv\Scripts\activate.bat

echo.
echo Installing dependencies...
pip install --upgrade pip
pip install -r requirements.txt

echo.
echo ========================================
echo Setup complete!
echo ========================================
echo.
echo To activate the virtual environment in the future, run:
echo venv\Scripts\activate
echo.
echo To deactivate, run:
echo deactivate
echo.

pause
|
setup_env.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# One-time Unix setup: create a virtual environment, activate it for this
# script's shell, and install the project dependencies from requirements.txt.
# NOTE: activation here affects only this script's process; users must
# re-activate in their own shell afterwards (instructions printed below).

echo "Creating virtual environment..."
python3 -m venv venv

echo ""
echo "Activating virtual environment..."
source venv/bin/activate

echo ""
echo "Installing dependencies..."
pip install --upgrade pip
pip install -r requirements.txt

echo ""
echo "========================================"
echo "Setup complete!"
echo "========================================"
echo ""
echo "To activate the virtual environment in the future, run:"
echo "  source venv/bin/activate"
echo ""
echo "To deactivate, run:"
echo "  deactivate"
echo ""
|
train_model.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Demand Prediction System - Model Training Script
|
| 3 |
+
|
| 4 |
+
This script trains multiple machine learning and time-series models to predict
|
| 5 |
+
product demand (sales quantity) for an e-commerce platform.
|
| 6 |
+
|
| 7 |
+
Features:
|
| 8 |
+
- Data preprocessing and feature engineering
|
| 9 |
+
- Date feature extraction (day, month, day_of_week, weekend)
|
| 10 |
+
- Categorical encoding
|
| 11 |
+
- Feature scaling
|
| 12 |
+
- Multiple ML models (Linear Regression, Random Forest, XGBoost)
|
| 13 |
+
- Time-series models (ARIMA, Prophet)
|
| 14 |
+
- Model evaluation (MAE, RMSE, R2 Score)
|
| 15 |
+
- Automatic best model selection
|
| 16 |
+
- Model persistence using joblib
|
| 17 |
+
- Visualization of results
|
| 18 |
+
- Comparison between ML and time-series approaches
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import pandas as pd
|
| 22 |
+
import numpy as np
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
import seaborn as sns
|
| 25 |
+
from datetime import datetime
|
| 26 |
+
import joblib
|
| 27 |
+
import os
|
| 28 |
+
import warnings
|
| 29 |
+
warnings.filterwarnings('ignore')
|
| 30 |
+
|
| 31 |
+
# Machine Learning imports
|
| 32 |
+
from sklearn.model_selection import train_test_split
|
| 33 |
+
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
| 34 |
+
from sklearn.linear_model import LinearRegression
|
| 35 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 36 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
| 37 |
+
|
| 38 |
+
# Try to import XGBoost (optional)
# Each optional dependency sets an *_AVAILABLE flag so training can skip the
# corresponding model family instead of crashing on a missing package.
try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("XGBoost not available. Install with: pip install xgboost")

# Try to import time-series libraries
try:
    from statsmodels.tsa.arima.model import ARIMA
    from statsmodels.tsa.stattools import adfuller
    ARIMA_AVAILABLE = True
except ImportError:
    ARIMA_AVAILABLE = False
    print("statsmodels not available. Install with: pip install statsmodels")

try:
    from prophet import Prophet
    PROPHET_AVAILABLE = True
except ImportError:
    PROPHET_AVAILABLE = False
    print("Prophet not available. Install with: pip install prophet")

# Set random seeds for reproducibility
np.random.seed(42)

# Configuration
DATA_PATH = 'data/sales.csv'   # input sales history (CSV)
MODEL_DIR = 'models'           # where trained model artifacts are saved
PLOTS_DIR = 'plots'            # where evaluation/EDA figures are saved

# Create directories if they don't exist
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_data(file_path):
    """
    Read the sales dataset from disk.

    Args:
        file_path: Location of the CSV file to read.

    Returns:
        DataFrame: The raw sales records.
    """
    print(f"Loading data from {file_path}...")
    frame = pd.read_csv(file_path)
    print(f"Data loaded successfully! Shape: {frame.shape}")
    return frame
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def preprocess_data(df):
    """
    Clean the raw sales data and derive calendar features.

    Converts the 'date' column to datetime, adds day / month / day_of_week /
    weekend / year / quarter columns, fills missing numeric values with the
    column median, and prints a dataset summary. The caller's frame is left
    unmodified.

    Args:
        df: Raw DataFrame with at least a 'date' column.

    Returns:
        DataFrame: A new frame with the added calendar features.
    """
    print("\n" + "="*60)
    print("PREPROCESSING DATA")
    print("="*60)

    # Work on a copy so the caller's frame stays untouched.
    data = df.copy()

    # Parse dates once, then derive all calendar columns from the accessor.
    data['date'] = pd.to_datetime(data['date'])

    print("Extracting date features...")
    calendar = data['date'].dt
    data['day'] = calendar.day
    data['month'] = calendar.month
    data['day_of_week'] = calendar.dayofweek  # 0=Monday, 6=Sunday
    data['weekend'] = (data['day_of_week'] >= 5).astype(int)  # 1 on Sat/Sun
    data['year'] = calendar.year
    data['quarter'] = calendar.quarter

    # Report and repair missing values (median imputation, numeric only).
    print("\nMissing values:")
    null_counts = data.isnull().sum()
    print(null_counts[null_counts > 0])

    if null_counts.sum() > 0:
        print("Filling missing values...")
        data = data.fillna(data.median(numeric_only=True))

    # Summarize the resulting frame for the training log.
    print("\nDataset Info:")
    print(f"Shape: {data.shape}")
    print(f"\nColumns: {data.columns.tolist()}")
    print(f"\nData types:\n{data.dtypes}")
    print(f"\nBasic statistics:\n{data.describe()}")

    return data
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def feature_engineering(df):
    """
    Perform feature engineering: encode categorical variables, scale features.

    Args:
        df: Preprocessed DataFrame

    Returns:
        tuple: (X_features, y_target, feature_names, encoders, scaler)
    """
    print("\n" + "="*60)
    print("FEATURE ENGINEERING")
    print("="*60)

    # Columns fed to the models. The raw 'date' column is excluded because
    # its components were already extracted during preprocessing; product_id
    # is kept here and label-encoded below.
    feature_columns = ['product_id', 'price', 'discount', 'category',
                       'day', 'month', 'day_of_week', 'weekend', 'year', 'quarter']

    features = df[feature_columns].copy()
    target = df['sales_quantity'].copy()

    print("Encoding categorical variables...")

    # Label-encode the product category.
    category_encoder = LabelEncoder()
    features['category_encoded'] = category_encoder.fit_transform(features['category'])

    # Label-encode product_id, treating it as a categorical variable.
    product_encoder = LabelEncoder()
    features['product_id_encoded'] = product_encoder.fit_transform(features['product_id'])

    # The raw categorical columns are replaced by their encoded versions.
    features = features.drop(['category', 'product_id'], axis=1)

    feature_names = features.columns.tolist()
    print(f"Features after encoding: {feature_names}")
    print(f"Number of features: {len(feature_names)}")

    # Standardise every column (zero mean, unit variance).
    print("\nScaling numerical features...")
    scaler = StandardScaler()
    scaled = pd.DataFrame(scaler.fit_transform(features), columns=feature_names)

    # Bundle the fitted transformers so predict-time code can reuse them.
    encoders = {
        'category': category_encoder,
        'product_id': product_encoder,
        'scaler': scaler
    }

    return scaled, target, feature_names, encoders, scaler
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def train_models(X_train, y_train, X_val, y_val):
    """
    Train multiple models and return their performance metrics.

    The fit/predict/score sequence was triplicated for each estimator and a
    `models` dict was built but never returned; both are replaced by a single
    `_fit_and_score` helper.

    Args:
        X_train: Training features
        y_train: Training target
        X_val: Validation features
        y_val: Validation target

    Returns:
        dict: Dictionary containing models and their metrics
              (name -> {'model', 'mae', 'rmse', 'r2', 'predictions'})
    """
    print("\n" + "="*60)
    print("TRAINING MODELS")
    print("="*60)

    def _fit_and_score(model):
        # Fit on the training split, evaluate on the held-out validation split.
        model.fit(X_train, y_train)
        pred = model.predict(X_val)
        mae = mean_absolute_error(y_val, pred)
        rmse = np.sqrt(mean_squared_error(y_val, pred))
        r2 = r2_score(y_val, pred)
        print(f"  MAE: {mae:.2f}, RMSE: {rmse:.2f}, R2: {r2:.4f}")
        return {'model': model, 'mae': mae, 'rmse': rmse, 'r2': r2, 'predictions': pred}

    results = {}

    # 1. Linear Regression (simple baseline)
    print("\n1. Training Linear Regression...")
    results['Linear Regression'] = _fit_and_score(LinearRegression())

    # 2. Random Forest Regressor
    print("\n2. Training Random Forest Regressor...")
    results['Random Forest'] = _fit_and_score(RandomForestRegressor(
        n_estimators=100,
        max_depth=15,
        min_samples_split=5,
        min_samples_leaf=2,
        random_state=42,
        n_jobs=-1
    ))

    # 3. XGBoost (optional dependency, guarded by the import-time flag)
    if XGBOOST_AVAILABLE:
        print("\n3. Training XGBoost Regressor...")
        results['XGBoost'] = _fit_and_score(xgb.XGBRegressor(
            n_estimators=100,
            max_depth=6,
            learning_rate=0.1,
            random_state=42,
            n_jobs=-1
        ))
    else:
        print("\n3. XGBoost skipped (not available)")

    return results
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def prepare_time_series_data(df):
    """
    Prepare time-series data by aggregating daily sales.

    Args:
        df: DataFrame with date and sales_quantity columns
            (df['date'] is coerced to datetime in place)

    Returns:
        tuple: (ts_data, train_size) - time series data and training size
    """
    print("\n" + "="*60)
    print("PREPARING TIME-SERIES DATA")
    print("="*60)

    # Total sales per calendar day, in chronological order.
    df['date'] = pd.to_datetime(df['date'])
    daily = (
        df.groupby('date')['sales_quantity']
        .sum()
        .reset_index()
        .sort_values('date')
        .reset_index(drop=True)
    )
    daily.columns = ['ds', 'y']  # Prophet expects 'ds' and 'y'

    print(f"Time-series data shape: {daily.shape}")
    print(f"Date range: {daily['ds'].min()} to {daily['ds'].max()}")
    print(f"Total days: {len(daily)}")

    # Chronological 80/20 split — no shuffling for time-series.
    train_size = int(len(daily) * 0.8)

    return daily, train_size
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def train_arima(ts_data, train_size):
    """
    Train ARIMA model on time-series data.

    Fits a small grid of common (p, d, q) orders on the chronological training
    slice and keeps the fit with the lowest AIC, then forecasts the holdout.

    Args:
        ts_data: Time-series DataFrame with 'ds' and 'y' columns
        train_size: Number of samples for training

    Returns:
        dict: Model results dictionary, or None if unavailable / all fits fail
    """
    if not ARIMA_AVAILABLE:
        return None

    print("\n" + "="*60)
    print("TRAINING ARIMA MODEL")
    print("="*60)

    try:
        # Chronological split: earlier samples train, later samples validate.
        y_fit = ts_data['y'].iloc[:train_size].values
        y_holdout = ts_data['y'].iloc[train_size:].values
        holdout_dates = ts_data['ds'].iloc[train_size:].values

        print(f"Training on {len(y_fit)} samples")
        print(f"Validating on {len(y_holdout)} samples")

        # Common ARIMA orders to try (auto_arima-like manual sweep).
        candidate_orders = [
            (1, 1, 1),  # Standard ARIMA(1,1,1)
            (2, 1, 2),  # ARIMA(2,1,2)
            (1, 1, 0),  # ARIMA(1,1,0) - AR model
            (0, 1, 1),  # ARIMA(0,1,1) - MA model
            (2, 1, 1),  # ARIMA(2,1,1)
            (1, 1, 2),  # ARIMA(1,1,2)
        ]

        best_aic, best_order, best_model = np.inf, None, None

        print("Trying different ARIMA orders...")
        for order in candidate_orders:
            try:
                fitted = ARIMA(y_fit, order=order).fit()
            except Exception as e:
                # Some orders fail to converge on some series; just skip them.
                print(f"  Order {order}: Failed - {str(e)[:50]}")
                continue
            aic = fitted.aic
            if aic < best_aic:
                best_aic, best_order, best_model = aic, order, fitted
                print(f"  Order {order}: AIC = {aic:.2f} (best so far)")
            else:
                print(f"  Order {order}: AIC = {aic:.2f}")

        if best_model is None:
            print("Failed to fit ARIMA model with any order")
            return None

        print(f"\nBest ARIMA order: {best_order} (AIC: {best_aic:.2f})")

        # Forecast the full validation horizon; demand cannot be negative.
        forecast = np.maximum(best_model.forecast(steps=len(y_holdout)), 0)

        mae = mean_absolute_error(y_holdout, forecast)
        rmse = np.sqrt(mean_squared_error(y_holdout, forecast))
        r2 = r2_score(y_holdout, forecast)
        print(f"  MAE: {mae:.2f}, RMSE: {rmse:.2f}, R2: {r2:.4f}")

        return {
            'model': best_model,
            'order': best_order,
            'mae': mae,
            'rmse': rmse,
            'r2': r2,
            'predictions': forecast,
            'actual': y_holdout,
            'dates': holdout_dates
        }

    except Exception as e:
        print(f"Error training ARIMA: {str(e)}")
        return None
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def train_prophet(ts_data, train_size):
    """
    Train Prophet model on time-series data.

    Args:
        ts_data: Time-series DataFrame with 'ds' and 'y' columns
        train_size: Number of samples for training

    Returns:
        dict: Model results dictionary, or None if unavailable / fit fails
    """
    if not PROPHET_AVAILABLE:
        return None

    print("\n" + "="*60)
    print("TRAINING PROPHET MODEL")
    print("="*60)

    try:
        # Chronological split keeps the validation window strictly after training.
        train_part = ts_data.iloc[:train_size].copy()
        holdout = ts_data.iloc[train_size:].copy()

        print(f"Training on {len(train_part)} samples")
        print(f"Validating on {len(holdout)} samples")

        # Weekly/yearly seasonality enabled; daily seasonality disabled
        # because the data is already at daily grain.
        model = Prophet(
            daily_seasonality=False,
            weekly_seasonality=True,
            yearly_seasonality=True,
            seasonality_mode='multiplicative',
            changepoint_prior_scale=0.05
        )

        print("Fitting Prophet model...")
        model.fit(train_part)

        # Extend the frame to cover the validation horizon, then predict.
        future = model.make_future_dataframe(periods=len(holdout), freq='D')
        forecast = model.predict(future)

        # Rows past train_size correspond to the validation period.
        predicted = np.maximum(forecast.iloc[train_size:]['yhat'].values, 0)
        observed = holdout['y'].values

        mae = mean_absolute_error(observed, predicted)
        rmse = np.sqrt(mean_squared_error(observed, predicted))
        r2 = r2_score(observed, predicted)
        print(f"  MAE: {mae:.2f}, RMSE: {rmse:.2f}, R2: {r2:.4f}")

        return {
            'model': model,
            'mae': mae,
            'rmse': rmse,
            'r2': r2,
            'predictions': predicted,
            'actual': observed,
            'dates': holdout['ds'].values,
            'full_forecast': forecast
        }

    except Exception as e:
        print(f"Error training Prophet: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def select_best_model(results):
    """
    Select the best model based on R2 score (higher is better).

    Args:
        results: Dictionary containing model results

    Returns:
        tuple: (best_model_name, best_model, best_metrics)
    """
    print("\n" + "="*60)
    print("MODEL COMPARISON")
    print("="*60)

    # Tabulate every model's scores for a side-by-side view.
    comparison_df = pd.DataFrame([
        {'Model': name, 'MAE': entry['mae'], 'RMSE': entry['rmse'], 'R2 Score': entry['r2']}
        for name, entry in results.items()
    ])
    print("\nModel Performance Comparison:")
    print(comparison_df.to_string(index=False))

    # Highest R2 wins.
    winner = max(results, key=lambda name: results[name]['r2'])
    winner_entry = results[winner]
    best_metrics = {key: winner_entry[key] for key in ('mae', 'rmse', 'r2')}

    print(f"\n{'='*60}")
    print(f"BEST MODEL: {winner}")
    print(f"MAE: {best_metrics['mae']:.2f}")
    print(f"RMSE: {best_metrics['rmse']:.2f}")
    print(f"R2 Score: {best_metrics['r2']:.4f}")
    print(f"{'='*60}")

    return winner, winner_entry['model'], best_metrics
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def visualize_results(df, results, best_model_name, feature_names):
    """
    Create visualizations: demand trends, feature importance, model comparison.

    Saves PNG files under PLOTS_DIR; each figure is closed after saving.
    Fix: removed a dead `df['month_name'] = ...strftime('%B')` assignment that
    was never read (the monthly plot builds its own abbreviated names) and
    mutated the caller's DataFrame as an unintended side effect.

    NOTE(review): df['date'] is still coerced to datetime in place — callers
    appear to rely on this being harmless (it is already datetime after
    preprocessing); confirm before changing.

    Args:
        df: Original DataFrame (needs 'date', 'month', 'sales_quantity')
        results: Model results dictionary
        best_model_name: Name of the best model
        feature_names: List of feature names
    """
    print("\n" + "="*60)
    print("GENERATING VISUALIZATIONS")
    print("="*60)

    # Set style
    sns.set_style("whitegrid")
    plt.rcParams['figure.figsize'] = (12, 6)

    # 1. Demand trends over time
    print("1. Plotting demand trends over time...")
    df['date'] = pd.to_datetime(df['date'])
    daily_demand = df.groupby('date')['sales_quantity'].sum().reset_index()

    plt.figure(figsize=(14, 6))
    plt.plot(daily_demand['date'], daily_demand['sales_quantity'], linewidth=1, alpha=0.7)
    plt.title('Total Daily Sales Quantity Over Time', fontsize=16, fontweight='bold')
    plt.xlabel('Date', fontsize=12)
    plt.ylabel('Total Sales Quantity', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{PLOTS_DIR}/demand_trends.png', dpi=300, bbox_inches='tight')
    print(f"  Saved: {PLOTS_DIR}/demand_trends.png")
    plt.close()

    # 2. Monthly average demand
    print("2. Plotting monthly average demand...")
    monthly_avg = df.groupby('month')['sales_quantity'].mean().reset_index()
    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    # Map month number (1-12) onto its abbreviated name for the x-axis.
    monthly_avg['month_name'] = monthly_avg['month'].apply(lambda x: month_names[x-1])

    plt.figure(figsize=(12, 6))
    plt.bar(monthly_avg['month_name'], monthly_avg['sales_quantity'], color='steelblue', alpha=0.7)
    plt.title('Average Sales Quantity by Month', fontsize=16, fontweight='bold')
    plt.xlabel('Month', fontsize=12)
    plt.ylabel('Average Sales Quantity', fontsize=12)
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.savefig(f'{PLOTS_DIR}/monthly_demand.png', dpi=300, bbox_inches='tight')
    print(f"  Saved: {PLOTS_DIR}/monthly_demand.png")
    plt.close()

    # 3. Feature importance (for tree-based models)
    print("3. Plotting feature importance...")
    best_model = results[best_model_name]['model']

    if hasattr(best_model, 'feature_importances_'):
        importances = best_model.feature_importances_
        feature_importance_df = pd.DataFrame({
            'feature': feature_names,
            'importance': importances
        }).sort_values('importance', ascending=False)

        plt.figure(figsize=(10, 6))
        plt.barh(feature_importance_df['feature'], feature_importance_df['importance'], color='coral', alpha=0.7)
        plt.title(f'Feature Importance - {best_model_name}', fontsize=16, fontweight='bold')
        plt.xlabel('Importance', fontsize=12)
        plt.ylabel('Feature', fontsize=12)
        plt.gca().invert_yaxis()
        plt.grid(True, alpha=0.3, axis='x')
        plt.tight_layout()
        plt.savefig(f'{PLOTS_DIR}/feature_importance.png', dpi=300, bbox_inches='tight')
        print(f"  Saved: {PLOTS_DIR}/feature_importance.png")
        plt.close()
    else:
        # Linear and time-series models expose no feature_importances_.
        print("  Feature importance not available for this model type")

    # 4. Model comparison
    print("4. Plotting model comparison...")
    model_names = list(results.keys())
    mae_scores = [results[m]['mae'] for m in model_names]
    rmse_scores = [results[m]['rmse'] for m in model_names]
    r2_scores = [results[m]['r2'] for m in model_names]

    # Separate ML and time-series models for visualization
    ml_models = [m for m in model_names if m not in ['ARIMA', 'Prophet']]
    ts_models = [m for m in model_names if m in ['ARIMA', 'Prophet']]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Color code: ML models in blue tones, TS models in orange/red tones
    colors = []
    for m in model_names:
        if m in ts_models:
            colors.append('coral' if m == 'ARIMA' else 'salmon')
        else:
            colors.append('skyblue')

    # MAE comparison
    axes[0].bar(model_names, mae_scores, color=colors, alpha=0.7)
    axes[0].set_title('MAE Comparison (Lower is Better)', fontsize=14, fontweight='bold')
    axes[0].set_ylabel('MAE', fontsize=12)
    axes[0].tick_params(axis='x', rotation=45)
    axes[0].grid(True, alpha=0.3, axis='y')
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='skyblue', alpha=0.7, label='ML Models'),
        Patch(facecolor='coral', alpha=0.7, label='Time-Series Models')
    ]
    axes[0].legend(handles=legend_elements, loc='upper right')

    # RMSE comparison
    axes[1].bar(model_names, rmse_scores, color=colors, alpha=0.7)
    axes[1].set_title('RMSE Comparison (Lower is Better)', fontsize=14, fontweight='bold')
    axes[1].set_ylabel('RMSE', fontsize=12)
    axes[1].tick_params(axis='x', rotation=45)
    axes[1].grid(True, alpha=0.3, axis='y')

    # R2 comparison
    axes[2].bar(model_names, r2_scores, color=colors, alpha=0.7)
    axes[2].set_title('R2 Score Comparison (Higher is Better)', fontsize=14, fontweight='bold')
    axes[2].set_ylabel('R2 Score', fontsize=12)
    axes[2].tick_params(axis='x', rotation=45)
    axes[2].grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f'{PLOTS_DIR}/model_comparison.png', dpi=300, bbox_inches='tight')
    print(f"  Saved: {PLOTS_DIR}/model_comparison.png")
    plt.close()

    # 5. Time-series predictions plot (if time-series models available)
    if ts_models:
        print("5. Plotting time-series model predictions...")
        fig, axes = plt.subplots(len(ts_models), 1, figsize=(14, 6*len(ts_models)))
        if len(ts_models) == 1:
            # subplots() returns a bare Axes for a single row; normalise to list.
            axes = [axes]

        for idx, model_name in enumerate(ts_models):
            if model_name in results and 'dates' in results[model_name]:
                dates = pd.to_datetime(results[model_name]['dates'])
                actual = results[model_name]['actual']
                predictions = results[model_name]['predictions']

                axes[idx].plot(dates, actual, label='Actual', linewidth=2, alpha=0.7)
                axes[idx].plot(dates, predictions, label='Predicted', linewidth=2, alpha=0.7, linestyle='--')
                axes[idx].set_title(f'{model_name} - Actual vs Predicted', fontsize=14, fontweight='bold')
                axes[idx].set_xlabel('Date', fontsize=12)
                axes[idx].set_ylabel('Sales Quantity', fontsize=12)
                axes[idx].legend()
                axes[idx].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{PLOTS_DIR}/timeseries_predictions.png', dpi=300, bbox_inches='tight')
        print(f"  Saved: {PLOTS_DIR}/timeseries_predictions.png")
        plt.close()

    print("  Visualization complete!")
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def save_model(model, encoders, scaler, feature_names, best_model_name, best_metrics):
    """
    Save the trained model and preprocessing objects.

    Writes three artifacts into MODEL_DIR: the estimator (joblib), the fitted
    encoders/scaler/feature names (joblib), and human-readable JSON metadata.

    Args:
        model: Trained model
        encoders: Dictionary of encoders
        scaler: Fitted scaler
        feature_names: List of feature names
        best_model_name: Name of the best model
        best_metrics: Dictionary of metrics
    """
    import json

    print("\n" + "="*60)
    print("SAVING MODEL")
    print("="*60)

    # Persist the estimator itself.
    model_path = f'{MODEL_DIR}/best_model.joblib'
    joblib.dump(model, model_path)
    print(f"Model saved to: {model_path}")

    # Persist everything needed to reproduce the feature pipeline at predict time.
    preprocessing_path = f'{MODEL_DIR}/preprocessing.joblib'
    joblib.dump(
        {
            'encoders': encoders,
            'scaler': scaler,
            'feature_names': feature_names,
        },
        preprocessing_path,
    )
    print(f"Preprocessing objects saved to: {preprocessing_path}")

    # Human-readable metadata alongside the binary artifacts.
    metadata_path = f'{MODEL_DIR}/model_metadata.json'
    metadata = {
        'model_name': best_model_name,
        'metrics': best_metrics,
        'feature_names': feature_names,
        'saved_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=4)
    print(f"Model metadata saved to: {metadata_path}")
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def main():
    """
    Main function to orchestrate the training pipeline.

    End-to-end flow: load the CSV at DATA_PATH -> preprocess -> feature
    engineering -> train ML models on a random 80/20 split -> train
    time-series models (ARIMA/Prophet) on a chronological 80/20 split ->
    pick the best model by R2 -> generate plots -> persist artifacts.

    Relies on module-level DATA_PATH / MODEL_DIR / PLOTS_DIR and the
    ARIMA_AVAILABLE / PROPHET_AVAILABLE flags set at import time.
    """
    print("\n" + "="*60)
    print("DEMAND PREDICTION SYSTEM - MODEL TRAINING")
    print("ML Models vs Time-Series Models Comparison")
    print("="*60)

    # Step 1: Load data
    df = load_data(DATA_PATH)

    # Step 2: Preprocess data
    df_processed = preprocess_data(df)

    # Step 3: Feature engineering for ML models
    X, y, feature_names, encoders, scaler = feature_engineering(df_processed)

    # Step 4: Split data for ML models (random split)
    # NOTE(review): a random split on daily sales rows lets validation rows
    # precede training rows in time — acceptable for the per-row regressors,
    # but not comparable to the chronological split used for ARIMA/Prophet.
    print("\n" + "="*60)
    print("SPLITTING DATA FOR ML MODELS")
    print("="*60)
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Validation set: {X_val.shape[0]} samples")

    # Step 5: Train ML models
    print("\n" + "="*70)
    print("TRAINING MACHINE LEARNING MODELS")
    print("="*70)
    results = train_models(X_train, y_train, X_val, y_val)

    # Step 6: Prepare time-series data (daily aggregation, chronological split)
    ts_data, train_size = prepare_time_series_data(df_processed)

    # Step 7: Train time-series models
    print("\n" + "="*70)
    print("TRAINING TIME-SERIES MODELS")
    print("="*70)

    # Train ARIMA (optional dependency; skipped when statsmodels is missing)
    if ARIMA_AVAILABLE:
        arima_results = train_arima(ts_data, train_size)
        if arima_results:
            results['ARIMA'] = arima_results
    else:
        print("\nARIMA skipped (statsmodels not available)")

    # Train Prophet (optional dependency)
    if PROPHET_AVAILABLE:
        prophet_results = train_prophet(ts_data, train_size)
        if prophet_results:
            results['Prophet'] = prophet_results
    else:
        print("\nProphet skipped (prophet not available)")

    # Step 8: Select best model (across all model types)
    best_model_name, best_model, best_metrics = select_best_model(results)

    # Step 9: Visualize results
    visualize_results(df_processed, results, best_model_name, feature_names)

    # Step 10: Save model (only ML models can be saved with preprocessing)
    # For time-series models, save separately
    if best_model_name not in ['ARIMA', 'Prophet']:
        save_model(best_model, encoders, scaler, feature_names, best_model_name, best_metrics)
    else:
        # Save time-series model separately
        print("\n" + "="*60)
        print("SAVING TIME-SERIES MODEL")
        print("="*60)
        ts_model_path = f'{MODEL_DIR}/best_timeseries_model.joblib'
        joblib.dump(best_model, ts_model_path)
        print(f"Time-series model saved to: {ts_model_path}")

        # Also save preprocessing for ML models (in case user wants to use them)
        preprocessing_path = f'{MODEL_DIR}/preprocessing.joblib'
        preprocessing_data = {
            'encoders': encoders,
            'scaler': scaler,
            'feature_names': feature_names
        }
        joblib.dump(preprocessing_data, preprocessing_path)
        print(f"ML preprocessing objects saved to: {preprocessing_path}")

    # Save all results metadata (scores of every trained model) as JSON
    import json
    all_models_metadata = {
        'best_model': best_model_name,
        'best_metrics': best_metrics,
        'all_models': {}
    }
    for model_name, model_results in results.items():
        all_models_metadata['all_models'][model_name] = {
            'mae': model_results['mae'],
            'rmse': model_results['rmse'],
            'r2': model_results['r2']
        }
    all_models_metadata['saved_at'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    metadata_path = f'{MODEL_DIR}/all_models_metadata.json'
    with open(metadata_path, 'w') as f:
        json.dump(all_models_metadata, f, indent=4)
    print(f"All models metadata saved to: {metadata_path}")

    print("\n" + "="*60)
    print("TRAINING COMPLETE!")
    print("="*60)
    print(f"\nBest model: {best_model_name}")
    print(f"Model type: {'Time-Series' if best_model_name in ['ARIMA', 'Prophet'] else 'Machine Learning'}")
    print(f"Model saved to: {MODEL_DIR}/")
    print(f"Visualizations saved to: {PLOTS_DIR}/")
    print("\nYou can now use predict.py to make predictions!")
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()
|