vikaswebdev committed on
Commit 7f90ea0 · verified · 1 Parent(s): 7637a10

Upload 17 files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+plots/demand_trends.png filter=lfs diff=lfs merge=lfs -text
+plots/feature_importance.png filter=lfs diff=lfs merge=lfs -text
+plots/model_comparison.png filter=lfs diff=lfs merge=lfs -text
+plots/monthly_demand.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
# Demand Prediction System for E-commerce

A complete machine learning and time-series forecasting system for predicting product demand (sales quantity) in e-commerce. It compares supervised regression models with time-series models (ARIMA, Prophet) to find the best approach.

## 📋 Table of Contents

- [Overview](#overview)
- [Features](#features)
- [Project Structure](#project-structure)
- [Installation](#installation)
- [Dataset](#dataset)
- [Usage](#usage)
- [Model Details](#model-details)
- [Evaluation Metrics](#evaluation-metrics)
- [Visualizations](#visualizations)
- [Example Predictions](#example-predictions)
- [Technical Details](#technical-details)

## 🎯 Overview

This project implements a demand prediction system that uses historical sales data to forecast future product demand. The system compares two approaches:

1. **Machine Learning Models**: Treat demand prediction as a regression problem using product features (price, discount, category, date features)
2. **Time-Series Models**: Treat demand prediction as a time-series problem using historical patterns (ARIMA, Prophet)

The system automatically selects the best-performing model across both approaches.

**Key Capabilities:**
- Predicts sales quantity for products on future dates (ML models)
- Predicts overall daily demand (time-series models)
- Handles temporal patterns and seasonality
- Considers price, discount, category, and date features (ML models)
- Captures time-series patterns and trends (time-series models)
- Automatically selects the best model from multiple candidates
- Provides comprehensive evaluation metrics
- Compares ML vs. time-series approaches

## ✨ Features

- **Data Preprocessing**: Automatic handling of missing values, date feature extraction
- **Feature Engineering**:
  - Date features (day, month, day_of_week, weekend, year, quarter)
  - Categorical encoding (product_id, category)
  - Feature scaling
- **Multiple Models**:
  - **Machine Learning Models:**
    - Linear Regression
    - Random Forest Regressor
    - XGBoost Regressor (optional)
  - **Time-Series Models:**
    - ARIMA (AutoRegressive Integrated Moving Average)
    - Prophet (Facebook's time-series forecasting tool)
- **Model Selection**: Automatic best-model selection based on R2 score
- **Evaluation Metrics**: MAE, RMSE, and R2 score
- **Visualizations**:
  - Demand trends over time
  - Monthly average demand
  - Feature importance
  - Model comparison
- **Model Persistence**: Save and load trained models using joblib
- **Future Predictions**: Predict demand for any product on any future date

## 📁 Project Structure

```
demand_prediction/
├── data/
│   └── sales.csv                      # Sales dataset
├── models/                            # Generated during training
│   ├── best_model.joblib              # Best ML model (if ML is best)
│   ├── best_timeseries_model.joblib   # Best time-series model (if TS is best)
│   ├── preprocessing.joblib           # Encoders and scaler (for ML models)
│   ├── model_metadata.json            # Model metadata (legacy)
│   └── all_models_metadata.json       # All-models comparison metadata
├── plots/                             # Generated during training
│   ├── demand_trends.png              # Time-series plot
│   ├── monthly_demand.png             # Monthly averages
│   ├── feature_importance.png         # Feature importance (ML models)
│   ├── model_comparison.png           # Model metrics comparison (all models)
│   └── timeseries_predictions.png     # Time-series model predictions
├── generate_dataset.py                # Script to generate a synthetic dataset
├── train_model.py                     # Main training script
├── predict.py                         # Prediction script
├── app.py                             # Streamlit dashboard (interactive web app)
├── requirements.txt                   # Python dependencies
└── README.md                          # This file
```

## 🚀 Installation

### Prerequisites

- Python 3.8 or higher
- pip (Python package manager)

### Step 1: Navigate to the Project Directory

```bash
cd demand_prediction
```

### Step 2: Create a Virtual Environment (Recommended)

**Why use a virtual environment?**
- Keeps project dependencies isolated from your system Python
- Prevents conflicts with other projects
- Makes it easier to manage package versions
- Is standard practice for Python projects

**Quick Setup (Recommended):**

**Windows:**
```bash
setup_env.bat
```

**Linux/Mac:**
```bash
chmod +x setup_env.sh
./setup_env.sh
```

**Manual Setup:**

**Windows:**
```bash
python -m venv venv
venv\Scripts\activate
```

**Linux/Mac:**
```bash
python3 -m venv venv
source venv/bin/activate
```

After activation, you should see `(venv)` in your terminal prompt.

**To deactivate later:**
```bash
deactivate
```

### Step 3: Install Dependencies

```bash
pip install -r requirements.txt
```

**Note**: If you don't want to use XGBoost, you can remove it from `requirements.txt`. The system works fine without it and simply skips XGBoost model training.

**Alternative (without a virtual environment):** you can run the same `pip install -r requirements.txt` directly against your system Python, but this is **not recommended** because it may cause conflicts with other Python projects.

### Step 4: Generate the Dataset

If you don't have a dataset, generate a synthetic one:

```bash
python generate_dataset.py
```

This creates `data/sales.csv` with realistic e-commerce sales data.

## 📊 Dataset

The dataset should contain the following columns:

- **product_id**: Unique identifier for each product (integer)
- **date**: Date of sale (YYYY-MM-DD format)
- **price**: Product price (float)
- **discount**: Discount percentage (0-100, float)
- **category**: Product category (string)
- **sales_quantity**: Target variable - number of units sold (integer)

### Dataset Format Example

```csv
product_id,date,price,discount,category,sales_quantity
1,2020-01-01,499.99,10,Electronics,45
2,2020-01-01,29.99,0,Clothing,120
...
```

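Before training, it can help to confirm that a CSV actually matches this schema. A minimal sketch (the rows below are illustrative, not taken from the real dataset):

```python
import pandas as pd

# Example rows matching the expected schema (values are illustrative)
df = pd.DataFrame({
    "product_id": [1, 2],
    "date": ["2020-01-01", "2020-01-01"],
    "price": [499.99, 29.99],
    "discount": [10.0, 0.0],
    "category": ["Electronics", "Clothing"],
    "sales_quantity": [45, 120],
})

# Parse dates and check for the required columns before training
df["date"] = pd.to_datetime(df["date"])
required = {"product_id", "date", "price", "discount", "category", "sales_quantity"}
missing = required - set(df.columns)
assert not missing, f"Dataset is missing columns: {missing}"
```

The same check works on the real file by replacing the inline frame with `pd.read_csv("data/sales.csv")`.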
## 💻 Usage

### Step 1: Train the Model

Train the model using the sales dataset:

```bash
python train_model.py
```

This will:
1. Load and preprocess the data
2. Extract features from dates
3. Encode categorical variables
4. Train multiple ML models (Linear Regression, Random Forest, XGBoost)
5. Prepare time-series data (aggregate daily sales)
6. Train time-series models (ARIMA, Prophet)
7. Evaluate each model using MAE, RMSE, and R2 score
8. Compare ML vs. time-series models
9. Select the best model automatically (across all model types)
10. Save the model and preprocessing objects
11. Generate visualizations

**Output:**
- Best model saved to `models/best_model.joblib` (ML) or `models/best_timeseries_model.joblib` (TS)
- Preprocessing objects saved to `models/preprocessing.joblib` (for ML models)
- Visualizations saved to the `plots/` directory
- All-models metadata saved to `models/all_models_metadata.json`

### Step 2: Make Predictions

**For ML Models (product-specific predictions):**

Predict demand for a specific product on a date:

```bash
python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
```

**Parameters for ML Models:**
- `--product_id`: Product ID (integer, required)
- `--date`: Date in YYYY-MM-DD format (required)
- `--price`: Product price (float, required)
- `--discount`: Discount percentage 0-100 (float, default: 0)
- `--category`: Product category (string, required)
- `--model_type`: Model type - `auto` (default), `ml`, or `timeseries`

**For Time-Series Models (overall daily demand):**

Predict total daily demand across all products:

```bash
python predict.py --date 2024-01-15 --model_type timeseries
```

**Parameters for Time-Series Models:**
- `--date`: Date in YYYY-MM-DD format (required)
- `--model_type`: Set to `timeseries` to use time-series models

**Example Predictions:**

```bash
# ML model - Electronics product with discount
python predict.py --product_id 1 --date 2024-06-15 --price 500 --discount 20 --category Electronics

# ML model - Clothing product without discount
python predict.py --product_id 5 --date 2024-12-25 --price 50 --discount 0 --category Clothing

# Time-series model - overall daily demand
python predict.py --date 2024-07-06 --model_type timeseries

# Auto-detect best model (default)
python predict.py --product_id 10 --date 2024-07-06 --price 75 --discount 15 --category Sports
```

### Step 3: Launch the Interactive Dashboard (Optional)

Launch the Streamlit dashboard for interactive visualization and predictions:

```bash
streamlit run app.py
```

The dashboard opens in your default web browser (usually at `http://localhost:8501`).

**Dashboard Features:**

1. **📈 Sales Trends Page**
   - Interactive filters (category, product, date range)
   - Daily sales trends visualization
   - Monthly sales trends
   - Category-wise analysis
   - Price vs. demand relationship
   - Real-time statistics and metrics

2. **🔮 Demand Prediction Page**
   - Interactive prediction interface
   - Select model type (Auto/ML/Time-Series)
   - For ML models:
     - Product selection dropdown
     - Category selection
     - Price and discount sliders
     - Date picker
     - Product statistics display
   - For time-series models:
     - Date picker for future predictions
     - Overall daily demand forecast
   - Prediction insights and recommendations

3. **📊 Model Comparison Page**
   - Side-by-side model performance comparison
   - MAE, RMSE, and R2 score metrics
   - Visual charts comparing all models
   - Best-model highlighting
   - Model type indicators (ML vs. Time-Series)

**Dashboard Highlights:**
- Interactive widgets for easy data exploration
- Real-time predictions with visual feedback
- Comprehensive model comparison charts

## 🤖 Model Details

### Models Trained

1. **Linear Regression**
   - Simple linear model
   - Fast training and prediction
   - Good baseline model

2. **Random Forest Regressor**
   - Ensemble of decision trees
   - Handles non-linear relationships
   - Provides feature importance
   - Hyperparameters:
     - n_estimators: 100
     - max_depth: 15
     - min_samples_split: 5
     - min_samples_leaf: 2

3. **XGBoost Regressor** (Optional)
   - Gradient boosting algorithm
   - Often provides the best performance
   - Handles complex patterns
   - Hyperparameters:
     - n_estimators: 100
     - max_depth: 6
     - learning_rate: 0.1

4. **ARIMA** (AutoRegressive Integrated Moving Average)
   - Classic time-series forecasting model
   - Captures trends and seasonality
   - Automatically selects the best order (p, d, q)
   - Works on aggregated daily sales data
   - Uses a chronological train/validation split

5. **Prophet** (Facebook's Time-Series Forecasting)
   - Designed for business time series
   - Handles seasonality (weekly, yearly)
   - Robust to missing data and outliers
   - Works on aggregated daily sales data
   - Uses a chronological train/validation split

### Model Comparison: ML vs. Time-Series

**Machine Learning Models:**
- ✅ Predict per-product demand
- ✅ Use product features (price, discount, category)
- ✅ Can handle new products with similar features
- ❌ May not capture long-term temporal patterns as well

**Time-Series Models:**
- ✅ Capture temporal patterns and trends
- ✅ Handle seasonality automatically
- ✅ Good for overall demand forecasting
- ❌ Predict aggregate demand, not per-product
- ❌ Don't use product-specific features

**The system automatically selects the best model based on R2 score across all model types.**

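The selection rule can be sketched in a few lines. The dictionary below mirrors the shape of `all_models_metadata.json`, but the model names and metric values are made up for illustration:

```python
# Hypothetical per-model metrics in the shape produced by train_model.py's
# evaluation step (the numbers here are invented for illustration).
all_models = {
    "LinearRegression": {"mae": 12.4, "rmse": 16.9, "r2": 0.61},
    "RandomForest":     {"mae": 8.1,  "rmse": 11.2, "r2": 0.83},
    "ARIMA":            {"mae": 9.7,  "rmse": 13.0, "r2": 0.74},
}

# Pick the model (ML or time-series) with the highest R2 score
best_name = max(all_models, key=lambda name: all_models[name]["r2"])
print(best_name)
```

In this example `RandomForest` wins because it has the highest `r2` value.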
### Feature Engineering

**For ML Models:**

The system extracts the following features from the input data:

**Date Features:**
- `day`: Day of month (1-31)
- `month`: Month (1-12)
- `day_of_week`: Day of week (0=Monday, 6=Sunday)
- `weekend`: Binary indicator (1 if weekend, 0 otherwise)
- `year`: Year
- `quarter`: Quarter of year (1-4)

**Original Features:**
- `product_id`: Encoded as categorical
- `price`: Numerical (scaled)
- `discount`: Numerical (scaled)
- `category`: Encoded as categorical

**Total Features**: 10 features after encoding and scaling

**For Time-Series Models:**

- Data is aggregated by date (total daily sales)
- Uses a chronological split (80% train, 20% validation)
- Prophet automatically handles:
  - Weekly seasonality
  - Yearly seasonality
  - Trend components

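The date features listed above can be derived with pandas. A minimal sketch (not the actual `train_model.py` code):

```python
import pandas as pd

df = pd.DataFrame({"date": pd.to_datetime(["2024-07-06", "2024-12-25"])})

# Extract the temporal features described above
df["day"] = df["date"].dt.day
df["month"] = df["date"].dt.month
df["day_of_week"] = df["date"].dt.dayofweek          # 0=Monday, 6=Sunday
df["weekend"] = (df["day_of_week"] >= 5).astype(int)  # Saturday/Sunday flag
df["year"] = df["date"].dt.year
df["quarter"] = df["date"].dt.quarter
```

For example, 2024-07-06 is a Saturday, so its `weekend` flag is 1 and its `quarter` is 3.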
## 📈 Evaluation Metrics

The system evaluates models using three metrics:

1. **MAE (Mean Absolute Error)**
   - Average absolute difference between predicted and actual values
   - Lower is better
   - Units: same as the target variable (sales quantity)

2. **RMSE (Root Mean Squared Error)**
   - Square root of the average squared differences
   - Penalizes large errors more than MAE
   - Lower is better
   - Units: same as the target variable (sales quantity)

3. **R2 Score (Coefficient of Determination)**
   - Proportion of variance explained by the model
   - Range: -∞ to 1 (1 is perfect prediction)
   - Higher is better
   - Used for model selection

**Model Selection**: The model with the highest R2 score is selected as the best model.

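All three metrics are available in scikit-learn. A small self-contained example with made-up demand values:

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Invented actual vs. predicted sales quantities for illustration
y_true = np.array([40, 55, 30, 70])
y_pred = np.array([42, 50, 33, 65])

mae = mean_absolute_error(y_true, y_pred)            # average |error|
rmse = np.sqrt(mean_squared_error(y_true, y_pred))   # sqrt of average squared error
r2 = r2_score(y_true, y_pred)                        # variance explained
```

Here `mae` is 3.75, `rmse` is about 3.97, and `r2` is about 0.93.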
## 📊 Visualizations

The training script generates several visualizations:

1. **Demand Trends Over Time** (`plots/demand_trends.png`)
   - Shows total daily sales quantity over the entire time period
   - Helps identify overall trends and patterns

2. **Monthly Average Demand** (`plots/monthly_demand.png`)
   - Bar chart showing average sales by month
   - Reveals seasonal patterns (e.g., holiday-season spikes)

3. **Feature Importance** (`plots/feature_importance.png`)
   - Shows which features are most important for predictions
   - Only available for tree-based models (Random Forest, XGBoost)

4. **Model Comparison** (`plots/model_comparison.png`)
   - Side-by-side comparison of all models (ML and time-series)
   - Color-coded: ML models (blue) vs. time-series models (orange/red)
   - Shows MAE, RMSE, and R2 score for each model

5. **Time-Series Predictions** (`plots/timeseries_predictions.png`)
   - Actual vs. predicted plots for the ARIMA and Prophet models
   - Shows how well the time-series models capture temporal patterns
   - Only generated if time-series models are available

## 🔮 Example Predictions

Here are some example predictions to demonstrate the system:

```bash
# Example 1: Electronics on a weekday
python predict.py --product_id 1 --date 2024-03-15 --price 500 --discount 10 --category Electronics
# Expected: moderate demand (weekday, some discount)

# Example 2: Clothing on a weekend
python predict.py --product_id 5 --date 2024-07-06 --price 50 --discount 20 --category Clothing
# Expected: higher demand (weekend, good discount)

# Example 3: Holiday-season prediction
python predict.py --product_id 10 --date 2024-12-20 --price 100 --discount 25 --category Toys
# Expected: high demand (holiday season, good discount)
```

## 🔧 Technical Details

### Data Preprocessing Pipeline

1. **Date Conversion**: Convert date strings to datetime objects
2. **Feature Extraction**: Extract temporal features from dates
3. **Missing Value Handling**: Fill missing values with the median (if any)
4. **Categorical Encoding**: Label-encode product_id and category
5. **Feature Scaling**: Standardize numerical features using StandardScaler

### Model Training Pipeline

1. **Data Splitting**: 80% training, 20% validation
2. **Model Training**: Train all available models
3. **Evaluation**: Calculate MAE, RMSE, and R2 for each model
4. **Selection**: Choose the model with the highest R2 score
5. **Persistence**: Save the model, encoders, and scaler

### Prediction Pipeline

1. **Load Model**: Load the trained model and preprocessing objects
2. **Feature Preparation**: Extract features from the input parameters
3. **Encoding**: Encode categorical variables using the saved encoders
4. **Scaling**: Scale features using the saved scaler
5. **Prediction**: Make a prediction using the loaded model
6. **Post-processing**: Ensure non-negative predictions

### Handling Unseen Data

The prediction script handles cases where:
- The product ID was not seen during training (uses a default encoding)
- The category was not seen during training (uses a default encoding)

Warnings are displayed in such cases.

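One way to implement the fallback described above is to wrap the encoder's `transform` call. This is a sketch only; the category names and default code are illustrative, not necessarily what `predict.py` does:

```python
from sklearn.preprocessing import LabelEncoder

# Encoder fitted on the categories seen during training (illustrative values)
encoder = LabelEncoder().fit(["Clothing", "Electronics", "Toys"])

def encode_category(value, encoder, default=0):
    """Encode a category, falling back to a default code for unseen labels."""
    try:
        return int(encoder.transform([value])[0])
    except ValueError:
        # LabelEncoder raises ValueError for labels it has never seen
        print(f"Warning: category '{value}' not seen during training; "
              f"using default encoding {default}")
        return default

known = encode_category("Toys", encoder)    # a label seen during training
unseen = encode_category("Books", encoder)  # falls back to the default code
```

Because `LabelEncoder` sorts its classes alphabetically, `"Toys"` encodes to 2 here, while the unseen `"Books"` falls back to 0.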
## 🎓 Learning Points

This project demonstrates:

1. **Supervised Learning**: Regression problem solving
2. **Feature Engineering**: Creating meaningful features from raw data
3. **Model Comparison**: Training and evaluating multiple models
4. **Model Selection**: Automatic best-model selection
5. **Model Persistence**: Saving and loading trained models
6. **Production-Ready Code**: Clean, modular, well-documented code
7. **Time-Series Features**: Extracting temporal patterns
8. **Categorical Encoding**: Handling categorical variables
9. **Feature Scaling**: Normalizing features for better performance
10. **Evaluation Metrics**: Understanding different regression metrics

## 🐛 Troubleshooting

### Issue: "Model not found"
**Solution**: Run `python train_model.py` first to train and save the model.

### Issue: "XGBoost not available"
**Solution**: Install XGBoost with `pip install xgboost`; otherwise the system works without it (skipping the XGBoost model).

### Issue: "Category not seen during training"
**Solution**: This is handled automatically with a warning. The system uses a default encoding.

### Issue: Poor prediction accuracy
**Solutions**:
- Ensure you have sufficient training data
- Check that input features are in the same range as the training data
- Try retraining with different hyperparameters
- Consider adding more features or more training data

## 📝 Notes

- The synthetic dataset generator creates realistic patterns including:
  - Weekend effects (higher sales on weekends)
  - Seasonal patterns (holiday-season spikes)
  - Price and discount effects
  - Category-specific base prices

- For production use, consider:
  - Using real historical data
  - Retraining models periodically
  - Adding more features (promotions, weather, etc.)
  - Implementing model versioning
  - Adding prediction confidence intervals

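A generator with the patterns listed above can be sketched as follows. This is not the actual `generate_dataset.py`; the effect magnitudes are invented to illustrate the weekend and holiday lifts:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
dates = pd.date_range("2023-01-01", periods=365, freq="D")

# Base demand plus a weekend lift, a December holiday spike, and noise
base = 100.0
weekend_lift = np.where(dates.dayofweek >= 5, 25.0, 0.0)
holiday_lift = np.where(dates.month == 12, 40.0, 0.0)
noise = rng.normal(0, 10, size=len(dates))

sales = np.clip(base + weekend_lift + holiday_lift + noise, 0, None).round()
df = pd.DataFrame({"date": dates, "sales_quantity": sales.astype(int)})
```

Aggregating the result by weekday shows clearly higher average demand on Saturdays and Sundays, matching the weekend effect described above.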
## 📄 License

This project is provided as-is for educational purposes.

## 👤 Author

Created as a complete machine learning project demonstrating demand prediction for e-commerce.

---

**Happy Predicting! 🚀**

app.py ADDED
"""
Demand Prediction System - Streamlit Dashboard

Interactive dashboard for visualizing sales trends and making demand predictions.
"""

import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import json
from datetime import datetime, timedelta, date as dt_date
import os
import warnings

warnings.filterwarnings('ignore')

# Page configuration
st.set_page_config(
    page_title="Demand Prediction Dashboard",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .metric-card {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    .stButton>button {
        width: 100%;
        background-color: #1f77b4;
        color: white;
    }
</style>
""", unsafe_allow_html=True)

# Configuration
DATA_PATH = 'data/sales.csv'
MODEL_DIR = 'models'
MODEL_PATH = f'{MODEL_DIR}/best_model.joblib'
TS_MODEL_PATH = f'{MODEL_DIR}/best_timeseries_model.joblib'
PREPROCESSING_PATH = f'{MODEL_DIR}/preprocessing.joblib'
ALL_MODELS_METADATA_PATH = f'{MODEL_DIR}/all_models_metadata.json'


@st.cache_data
def load_data():
    """Load sales data with caching."""
    if os.path.exists(DATA_PATH):
        df = pd.read_csv(DATA_PATH)
        df['date'] = pd.to_datetime(df['date'])
        return df
    return None


@st.cache_resource
def load_models():
    """Load trained models with caching."""
    models = {
        'ml_model': None,
        'ts_model': None,
        'preprocessing': None,
        'model_name': None,
        'is_timeseries': False,
        'metadata': None
    }

    # Load metadata
    if os.path.exists(ALL_MODELS_METADATA_PATH):
        with open(ALL_MODELS_METADATA_PATH, 'r') as f:
            models['metadata'] = json.load(f)
        models['model_name'] = models['metadata'].get('best_model', 'Unknown')
        models['is_timeseries'] = models['model_name'] in ['ARIMA', 'Prophet']

    # Load ML model
    if os.path.exists(MODEL_PATH):
        try:
            models['ml_model'] = joblib.load(MODEL_PATH)
        except Exception:
            pass

    # Load time-series model
    if os.path.exists(TS_MODEL_PATH):
        try:
            models['ts_model'] = joblib.load(TS_MODEL_PATH)
        except Exception:
            pass

    # Load preprocessing objects
    if os.path.exists(PREPROCESSING_PATH):
        try:
            models['preprocessing'] = joblib.load(PREPROCESSING_PATH)
        except Exception:
            pass

    return models


def prepare_features_ml(product_id, date, price, discount, category, preprocessing_data):
    """Prepare features for ML model prediction."""
    if preprocessing_data is None:
        return None

    # Convert date to a pandas Timestamp (handles date, datetime, and string)
    if isinstance(date, dt_date):
        date = pd.Timestamp(date)
    elif not isinstance(date, pd.Timestamp):
        date = pd.to_datetime(date)

    # Extract date features
    day = date.day
    month = date.month
    day_of_week = date.weekday()
    weekend = 1 if day_of_week >= 5 else 0
    year = date.year
    quarter = date.quarter

    # Encode categorical variables, falling back to defaults for unseen values
    category_encoder = preprocessing_data['encoders']['category']
    product_encoder = preprocessing_data['encoders']['product_id']

    try:
        category_encoded = category_encoder.transform([category])[0]
    except ValueError:
        category_encoded = 0

    try:
        product_id_encoded = product_encoder.transform([product_id])[0]
    except ValueError:
        product_id_encoded = product_encoder.transform([product_encoder.classes_[0]])[0]

    # Create feature dictionary
    feature_dict = {
        'price': price,
        'discount': discount,
        'day': day,
        'month': month,
        'day_of_week': day_of_week,
        'weekend': weekend,
        'year': year,
        'quarter': quarter,
        'category_encoded': category_encoded,
        'product_id_encoded': product_id_encoded
    }

    # Create the feature array in the same order as training
    feature_names = preprocessing_data['feature_names']
    features = np.array([[feature_dict[name] for name in feature_names]])

    # Scale features
    scaler = preprocessing_data['scaler']
    features_scaled = scaler.transform(features)

    return features_scaled


def predict_ml(product_id, date, price, discount, category, model, preprocessing_data):
    """Make a prediction using the ML model."""
    features = prepare_features_ml(product_id, date, price, discount, category, preprocessing_data)
    if features is None:
        return None
    prediction = model.predict(features)[0]
    return max(0, prediction)


def predict_timeseries(date, model, model_name):
    """Make a prediction using the time-series model."""
    # Convert date to a pandas Timestamp (handles date, datetime, and string)
    if isinstance(date, dt_date):
        date = pd.Timestamp(date)
    elif not isinstance(date, pd.Timestamp):
        date = pd.to_datetime(date)

    if model_name == 'ARIMA':
        try:
            # ARIMA forecasts the next step after the training window;
            # the requested date is not used directly here.
            forecast = model.forecast(steps=1)
            prediction = forecast.iloc[0] if hasattr(forecast, 'iloc') else float(forecast)
            return max(0, prediction)
        except Exception:
            return None

    elif model_name == 'Prophet':
        try:
            future = pd.DataFrame({'ds': [date]})
            forecast = model.predict(future)
            prediction = forecast['yhat'].iloc[0]
            return max(0, prediction)
        except Exception:
            return None

    return None


def main():
    """Main dashboard function."""

    # Header
    st.markdown('<h1 class="main-header">📊 Demand Prediction Dashboard</h1>', unsafe_allow_html=True)

    # Load data
    df = load_data()
    if df is None:
        st.error("❌ Sales data not found. Please run generate_dataset.py first.")
        return

    # Load models
    models = load_models()

    # Sidebar
    with st.sidebar:
        st.header("⚙️ Navigation")
        page = st.radio(
            "Select Page",
            ["📈 Sales Trends", "🔮 Demand Prediction", "📊 Model Comparison"],
            index=0
        )

        st.markdown("---")
        st.header("ℹ️ Information")
        if models['metadata']:
            best_model = models['metadata'].get('best_model', 'Unknown')
            st.info(f"**Best Model:** {best_model}")
            if best_model in models['metadata'].get('all_models', {}):
                metrics = models['metadata']['all_models'][best_model]
                st.metric("R2 Score", f"{metrics.get('r2', 0):.4f}")

    # Main content based on the selected page
    if page == "📈 Sales Trends":
        show_sales_trends(df)
    elif page == "🔮 Demand Prediction":
        show_prediction_interface(df, models)
    elif page == "📊 Model Comparison":
        show_model_comparison(models)


def show_sales_trends(df):
    """Display sales trends visualizations."""
    st.header("📈 Sales Trends Analysis")

    # Filters
    col1, col2, col3 = st.columns(3)

    with col1:
        categories = ['All'] + sorted(df['category'].unique().tolist())
        selected_category = st.selectbox("Select Category", categories)

    with col2:
        products = ['All'] + sorted(df['product_id'].unique().tolist())
        selected_product = st.selectbox("Select Product", products)

    with col3:
        date_range = st.date_input(
            "Select Date Range",
            value=(df['date'].min(), df['date'].max()),
            min_value=df['date'].min(),
            max_value=df['date'].max()
        )

    # Filter data
    filtered_df = df.copy()

    if selected_category != 'All':
        filtered_df = filtered_df[filtered_df['category'] == selected_category]

    if selected_product != 'All':
        filtered_df = filtered_df[filtered_df['product_id'] == int(selected_product)]

    if isinstance(date_range, tuple) and len(date_range) == 2:
        filtered_df = filtered_df[
            (filtered_df['date'] >= pd.to_datetime(date_range[0])) &
            (filtered_df['date'] <= pd.to_datetime(date_range[1]))
        ]

    if len(filtered_df) == 0:
        st.warning("No data available for selected filters.")
        return

    # Visualizations
    tab1, tab2, tab3, tab4 = st.tabs(["📅 Daily Trends", "📆 Monthly Trends", "📦 Category Analysis", "💰 Price vs Demand"])

    with tab1:
        st.subheader("Daily Sales Trends")
        daily_sales = filtered_df.groupby('date')['sales_quantity'].sum().reset_index()

        fig, ax = plt.subplots(figsize=(14, 6))
        ax.plot(daily_sales['date'], daily_sales['sales_quantity'], linewidth=2, alpha=0.7)
        ax.set_title('Total Daily Sales Quantity', fontsize=16, fontweight='bold')
        ax.set_xlabel('Date', fontsize=12)
        ax.set_ylabel('Sales Quantity', fontsize=12)
        ax.grid(True, alpha=0.3)
        plt.xticks(rotation=45)
        plt.tight_layout()
        st.pyplot(fig)

        # Statistics
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Sales", f"{daily_sales['sales_quantity'].sum():,.0f}")
        with col2:
            st.metric("Average Daily", f"{daily_sales['sales_quantity'].mean():.1f}")
        with col3:
            st.metric("Max Daily", f"{daily_sales['sales_quantity'].max():,.0f}")
        with col4:
            st.metric("Min Daily", f"{daily_sales['sales_quantity'].min():,.0f}")

    with tab2:
        st.subheader("Monthly Sales Trends")
        filtered_df['month_year'] = filtered_df['date'].dt.to_period('M')
        monthly_sales = filtered_df.groupby('month_year')['sales_quantity'].sum().reset_index()
325
+ monthly_sales['month_year'] = monthly_sales['month_year'].astype(str)
326
+
327
+ fig, ax = plt.subplots(figsize=(14, 6))
328
+ ax.bar(range(len(monthly_sales)), monthly_sales['sales_quantity'], alpha=0.7, color='steelblue')
329
+ ax.set_title('Monthly Sales Quantity', fontsize=16, fontweight='bold')
330
+ ax.set_xlabel('Month', fontsize=12)
331
+ ax.set_ylabel('Sales Quantity', fontsize=12)
332
+ ax.set_xticks(range(len(monthly_sales)))
333
+ ax.set_xticklabels(monthly_sales['month_year'], rotation=45, ha='right')
334
+ ax.grid(True, alpha=0.3, axis='y')
335
+ plt.tight_layout()
336
+ st.pyplot(fig)
337
+
338
+ with tab3:
339
+ st.subheader("Sales by Category")
340
+ category_sales = filtered_df.groupby('category')['sales_quantity'].sum().sort_values(ascending=False)
341
+
342
+ fig, ax = plt.subplots(figsize=(12, 6))
343
+ category_sales.plot(kind='barh', ax=ax, color='coral', alpha=0.7)
344
+ ax.set_title('Total Sales by Category', fontsize=16, fontweight='bold')
345
+ ax.set_xlabel('Total Sales Quantity', fontsize=12)
346
+ ax.set_ylabel('Category', fontsize=12)
347
+ ax.grid(True, alpha=0.3, axis='x')
348
+ plt.tight_layout()
349
+ st.pyplot(fig)
350
+
351
+ # Category statistics table
352
+ category_stats = filtered_df.groupby('category').agg({
353
+ 'sales_quantity': ['sum', 'mean', 'std', 'min', 'max']
354
+ }).round(2)
355
+ category_stats.columns = ['Total', 'Average', 'Std Dev', 'Min', 'Max']
356
+ st.dataframe(category_stats, use_container_width=True)
357
+
358
+ with tab4:
359
+ st.subheader("Price vs Demand Relationship")
360
+
361
+ # Scatter plot
362
+ fig, ax = plt.subplots(figsize=(12, 6))
363
+ scatter = ax.scatter(filtered_df['price'], filtered_df['sales_quantity'],
364
+ c=filtered_df['discount'], cmap='viridis', alpha=0.6, s=50)
365
+ ax.set_title('Price vs Sales Quantity (colored by discount)', fontsize=16, fontweight='bold')
366
+ ax.set_xlabel('Price', fontsize=12)
367
+ ax.set_ylabel('Sales Quantity', fontsize=12)
368
+ ax.grid(True, alpha=0.3)
369
+ plt.colorbar(scatter, ax=ax, label='Discount %')
370
+ plt.tight_layout()
371
+ st.pyplot(fig)
372
+
373
+ # Correlation
374
+ correlation = filtered_df['price'].corr(filtered_df['sales_quantity'])
375
+ st.metric("Price-Demand Correlation", f"{correlation:.3f}")
376
+
377
+
378
+ def show_prediction_interface(df, models):
379
+ """Display interactive prediction interface."""
380
+ st.header("🔮 Demand Prediction")
381
+
382
+ # Check if models are available
383
+ if models['ml_model'] is None and models['ts_model'] is None:
384
+ st.error("❌ No trained models found. Please run train_model.py first.")
385
+ return
386
+
387
+ # Model selection
388
+ model_type = st.radio(
389
+ "Select Model Type",
390
+ ["Auto (Best Model)", "Machine Learning", "Time-Series"],
391
+ horizontal=True
392
+ )
393
+
394
+ st.markdown("---")
395
+
396
+ if model_type == "Time-Series" or (model_type == "Auto (Best Model)" and models['is_timeseries']):
397
+ # Time-series prediction
398
+ st.subheader("Overall Daily Demand Prediction")
399
+
400
+ col1, col2 = st.columns(2)
401
+ with col1:
402
+ prediction_date = st.date_input(
403
+ "Select Date for Prediction",
404
+ value=datetime.now().date() + timedelta(days=30),
405
+ min_value=df['date'].max().date() + timedelta(days=1)
406
+ )
407
+
408
+ with col2:
409
+ st.write("") # Spacing
410
+ st.write("") # Spacing
411
+
412
+ if st.button("🔮 Predict Demand", type="primary"):
413
+ if models['ts_model'] is None:
414
+ st.error("Time-series model not available.")
415
+ else:
416
+ with st.spinner("Making prediction..."):
417
+ prediction = predict_timeseries(
418
+ prediction_date,
419
+ models['ts_model'],
420
+ models['model_name']
421
+ )
422
+
423
+ if prediction is not None:
424
+ st.success("✅ Prediction Complete!")
425
+
426
+ col1, col2, col3 = st.columns(3)
427
+ with col1:
428
+ st.metric("Predicted Daily Demand", f"{prediction:,.0f} units")
429
+ with col2:
430
+ day_name = pd.to_datetime(prediction_date).strftime('%A')
431
+ st.metric("Day of Week", day_name)
432
+ with col3:
433
+ is_weekend = "Yes" if pd.to_datetime(prediction_date).weekday() >= 5 else "No"
434
+ st.metric("Weekend", is_weekend)
435
+
436
+ st.info("💡 This prediction represents the total daily demand across all products.")
437
+ else:
438
+ st.error("Failed to make prediction.")
439
+
440
+ else:
441
+ # ML model prediction
442
+ st.subheader("Product-Specific Demand Prediction")
443
+
444
+ # Get unique values for dropdowns
445
+ categories = sorted(df['category'].unique().tolist())
446
+ products = sorted(df['product_id'].unique().tolist())
447
+
448
+ col1, col2 = st.columns(2)
449
+
450
+ with col1:
451
+ selected_category = st.selectbox("Select Category", categories)
452
+ selected_product = st.selectbox("Select Product ID", products)
453
+ prediction_date = st.date_input(
454
+ "Select Date for Prediction",
455
+ value=datetime.now().date() + timedelta(days=30),
456
+ min_value=df['date'].max().date() + timedelta(days=1)
457
+ )
458
+
459
+ with col2:
460
+ price = st.number_input(
461
+ "Product Price ($)",
462
+ min_value=0.01,
463
+ value=100.0,
464
+ step=1.0,
465
+ format="%.2f"
466
+ )
467
+ discount = st.slider(
468
+ "Discount (%)",
469
+ min_value=0,
470
+ max_value=100,
471
+ value=0,
472
+ step=5
473
+ )
474
+
475
+ # Show product statistics
476
+ product_data = df[df['product_id'] == selected_product]
477
+ if len(product_data) > 0:
478
+ with st.expander("📊 Product Statistics"):
479
+ col1, col2, col3, col4 = st.columns(4)
480
+ with col1:
481
+ st.metric("Avg Price", f"${product_data['price'].mean():.2f}")
482
+ with col2:
483
+ st.metric("Avg Sales", f"{product_data['sales_quantity'].mean():.1f}")
484
+ with col3:
485
+ st.metric("Total Sales", f"{product_data['sales_quantity'].sum():,.0f}")
486
+ with col4:
487
+ st.metric("Category", selected_category)
488
+
489
+ if st.button("🔮 Predict Demand", type="primary"):
490
+ if models['ml_model'] is None or models['preprocessing'] is None:
491
+ st.error("ML model or preprocessing not available.")
492
+ else:
493
+ with st.spinner("Making prediction..."):
494
+ prediction = predict_ml(
495
+ selected_product,
496
+ prediction_date,
497
+ price,
498
+ discount,
499
+ selected_category,
500
+ models['ml_model'],
501
+ models['preprocessing']
502
+ )
503
+
504
+ if prediction is not None:
505
+ st.success("✅ Prediction Complete!")
506
+
507
+ col1, col2, col3, col4 = st.columns(4)
508
+ with col1:
509
+ st.metric("Predicted Demand", f"{prediction:,.0f} units")
510
+ with col2:
511
+ st.metric("Price", f"${price:.2f}")
512
+ with col3:
513
+ st.metric("Discount", f"{discount}%")
514
+ with col4:
515
+ day_name = pd.to_datetime(prediction_date).strftime('%A')
516
+ st.metric("Day", day_name)
517
+
518
+ # Additional insights
519
+ st.markdown("### 📈 Prediction Insights")
520
+ date_obj = pd.to_datetime(prediction_date)
521
+ is_weekend = date_obj.weekday() >= 5
522
+ month = date_obj.month
523
+
524
+ insights = []
525
+ if is_weekend:
526
+ insights.append("📅 Weekend - typically higher demand")
527
+ if month in [11, 12]:
528
+ insights.append("🎄 Holiday season - peak sales period")
529
+ if discount > 0:
530
+ insights.append(f"💰 {discount}% discount - may increase demand")
531
+
532
+ if insights:
533
+ for insight in insights:
534
+ st.info(insight)
535
+ else:
536
+ st.error("Failed to make prediction.")
537
+
538
+
539
+ def show_model_comparison(models):
540
+ """Display model comparison."""
541
+ st.header("📊 Model Comparison")
542
+
543
+ if models['metadata'] is None:
544
+ st.warning("Model metadata not available. Please run train_model.py first.")
545
+ return
546
+
547
+ metadata = models['metadata']
548
+ all_models = metadata.get('all_models', {})
549
+ best_model = metadata.get('best_model', 'Unknown')
550
+
551
+ if not all_models:
552
+ st.warning("No model comparison data available.")
553
+ return
554
+
555
+ # Model metrics table
556
+ st.subheader("Model Performance Metrics")
557
+
558
+ comparison_data = []
559
+ for model_name, metrics in all_models.items():
560
+ comparison_data.append({
561
+ 'Model': model_name,
562
+ 'Type': 'Time-Series' if model_name in ['ARIMA', 'Prophet'] else 'Machine Learning',
563
+ 'MAE': metrics.get('mae', 0),
564
+ 'RMSE': metrics.get('rmse', 0),
565
+ 'R2 Score': metrics.get('r2', 0)
566
+ })
567
+
568
+ comparison_df = pd.DataFrame(comparison_data)
569
+
570
+ # Highlight best model
571
+ def highlight_best(row):
572
+ if row['Model'] == best_model:
573
+ return ['background-color: #90EE90'] * len(row)
574
+ return [''] * len(row)
575
+
576
+ st.dataframe(
577
+ comparison_df.style.apply(highlight_best, axis=1),
578
+ use_container_width=True
579
+ )
580
+
581
+ # Visualizations
582
+ st.subheader("Performance Comparison Charts")
583
+
584
+ col1, col2 = st.columns(2)
585
+
586
+ with col1:
587
+ fig, ax = plt.subplots(figsize=(10, 6))
588
+ model_names = comparison_df['Model'].tolist()
589
+ mae_scores = comparison_df['MAE'].tolist()
590
+
591
+ colors = ['coral' if name in ['ARIMA', 'Prophet'] else 'skyblue' for name in model_names]
592
+ ax.bar(model_names, mae_scores, color=colors, alpha=0.7)
593
+ ax.set_title('MAE Comparison (Lower is Better)', fontsize=14, fontweight='bold')
594
+ ax.set_ylabel('MAE', fontsize=12)
595
+ ax.tick_params(axis='x', rotation=45)
596
+ ax.grid(True, alpha=0.3, axis='y')
597
+ plt.tight_layout()
598
+ st.pyplot(fig)
599
+
600
+ with col2:
601
+ fig, ax = plt.subplots(figsize=(10, 6))
602
+ r2_scores = comparison_df['R2 Score'].tolist()
603
+
604
+ colors = ['coral' if name in ['ARIMA', 'Prophet'] else 'skyblue' for name in model_names]
605
+ ax.bar(model_names, r2_scores, color=colors, alpha=0.7)
606
+ ax.set_title('R2 Score Comparison (Higher is Better)', fontsize=14, fontweight='bold')
607
+ ax.set_ylabel('R2 Score', fontsize=12)
608
+ ax.tick_params(axis='x', rotation=45)
609
+ ax.grid(True, alpha=0.3, axis='y')
610
+ plt.tight_layout()
611
+ st.pyplot(fig)
612
+
613
+ # Best model info
614
+ st.markdown("---")
615
+ st.success(f"🏆 **Best Model: {best_model}**")
616
+ if best_model in all_models:
617
+ best_metrics = all_models[best_model]
618
+ col1, col2, col3 = st.columns(3)
619
+ with col1:
620
+ st.metric("MAE", f"{best_metrics.get('mae', 0):.2f}")
621
+ with col2:
622
+ st.metric("RMSE", f"{best_metrics.get('rmse', 0):.2f}")
623
+ with col3:
624
+ st.metric("R2 Score", f"{best_metrics.get('r2', 0):.4f}")
625
+
626
+
627
+ if __name__ == "__main__":
628
+ main()
data/sales.csv ADDED
The diff for this file is too large to render. See raw diff
 
generate_dataset.py ADDED
@@ -0,0 +1,128 @@
1
+ """
2
+ Generate Synthetic E-commerce Sales Dataset
3
+
4
+ This script creates a realistic synthetic dataset for demand prediction.
5
+ The dataset includes temporal patterns, seasonality, and realistic relationships
6
+ between features and sales quantity.
7
+ """
8
+
9
+ import pandas as pd
10
+ import numpy as np
11
+ from datetime import datetime, timedelta
12
+
13
+ # Set random seed for reproducibility
14
+ np.random.seed(42)
15
+
16
+ # Configuration
17
+ NUM_PRODUCTS = 50
18
+ START_DATE = datetime(2020, 1, 1)
19
+ END_DATE = datetime(2023, 12, 31)
20
+ CATEGORIES = ['Electronics', 'Clothing', 'Home & Garden', 'Sports', 'Books',
21
+ 'Toys', 'Beauty', 'Automotive', 'Food & Beverages', 'Health']
22
+
23
+ # Generate date range
24
+ date_range = pd.date_range(start=START_DATE, end=END_DATE, freq='D')
25
+ num_days = len(date_range)
26
+
27
+ # Initialize lists to store data
28
+ data = []
29
+
30
+ # Generate data for each product
31
+ for product_id in range(1, NUM_PRODUCTS + 1):
32
+ # Assign category randomly
33
+ category = np.random.choice(CATEGORIES)
34
+
35
+ # Base price varies by category
36
+ category_base_prices = {
37
+ 'Electronics': 500,
38
+ 'Clothing': 50,
39
+ 'Home & Garden': 100,
40
+ 'Sports': 150,
41
+ 'Books': 20,
42
+ 'Toys': 30,
43
+ 'Beauty': 40,
44
+ 'Automotive': 300,
45
+ 'Food & Beverages': 25,
46
+ 'Health': 60
47
+ }
48
+
49
+ base_price = category_base_prices[category] * (0.8 + np.random.random() * 0.4)
50
+
51
+ # Generate daily records
52
+ for date in date_range:
53
+ # Day of week effect (weekends have higher sales)
54
+ day_of_week = date.weekday()
55
+ weekend_multiplier = 1.3 if day_of_week >= 5 else 1.0
56
+
57
+ # Monthly seasonality (higher sales in Nov-Dec, lower in Jan-Feb)
58
+ month = date.month
59
+ if month in [11, 12]: # Holiday season
60
+ seasonality = 1.5
61
+ elif month in [1, 2]: # Post-holiday slump
62
+ seasonality = 0.7
63
+ elif month in [6, 7, 8]: # Summer
64
+ seasonality = 1.2
65
+ else:
66
+ seasonality = 1.0
67
+
68
+ # Random discount (0-30%)
69
+ discount = np.random.choice([0, 5, 10, 15, 20, 25, 30], p=[0.4, 0.2, 0.15, 0.1, 0.08, 0.05, 0.02])
70
+
71
+ # Price with discount
72
+ price = base_price * (1 - discount / 100)
73
+
74
+ # Base demand is re-drawn for each product-day record (it varies across days as well as products)
75
+ base_demand = np.random.randint(10, 100)
76
+
77
+ # Calculate sales quantity with multiple factors
78
+ # Higher discount -> higher sales
79
+ discount_effect = 1 + (discount / 100) * 0.5
80
+
81
+ # Lower price -> higher sales (inverse relationship)
82
+ price_effect = 1 / (1 + (price / 1000) * 0.1)
83
+
84
+ # Add some randomness
85
+ noise = np.random.normal(1, 0.15)
86
+
87
+ # Calculate final sales quantity
88
+ sales_quantity = int(
89
+ base_demand *
90
+ weekend_multiplier *
91
+ seasonality *
92
+ discount_effect *
93
+ price_effect *
94
+ noise
95
+ )
96
+
97
+ # Ensure non-negative
98
+ sales_quantity = max(0, sales_quantity)
99
+
100
+ data.append({
101
+ 'product_id': product_id,
102
+ 'date': date.strftime('%Y-%m-%d'),
103
+ 'price': round(price, 2),
104
+ 'discount': discount,
105
+ 'category': category,
106
+ 'sales_quantity': sales_quantity
107
+ })
108
+
109
+ # Create DataFrame
110
+ df = pd.DataFrame(data)
111
+
112
+ # Shuffle the data
113
+ df = df.sample(frac=1, random_state=42).reset_index(drop=True)
114
+
115
+ # Save to CSV
116
+ output_path = 'data/sales.csv'
117
+ df.to_csv(output_path, index=False)
118
+
119
+ print("Dataset generated successfully!")
120
+ print(f"Total records: {len(df)}")
121
+ print(f"Date range: {df['date'].min()} to {df['date'].max()}")
122
+ print(f"Number of products: {df['product_id'].nunique()}")
123
+ print(f"Categories: {df['category'].nunique()}")
124
+ print(f"\nDataset saved to: {output_path}")
125
+ print(f"\nFirst few rows:")
126
+ print(df.head(10))
127
+ print(f"\nDataset statistics:")
128
+ print(df.describe())
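The script above multiplies a base demand by a weekend multiplier, a monthly seasonality factor, a discount effect, and an inverse price effect. A worked sketch with illustrative values (the script draws `base_demand` and `noise` randomly; here they are fixed so the arithmetic is visible):

```python
# Worked example of the synthetic demand formula from generate_dataset.py.
base_demand = 50          # script draws this from randint(10, 100)
weekend_multiplier = 1.3  # Saturday/Sunday boost
seasonality = 1.5         # November/December holiday peak
discount = 20             # percent
price = 400.0             # discounted price in dollars

discount_effect = 1 + (discount / 100) * 0.5   # 20% off -> 1.10
price_effect = 1 / (1 + (price / 1000) * 0.1)  # $400 -> ~0.962
noise = 1.0                                    # script uses N(1, 0.15); fixed here

sales_quantity = int(base_demand * weekend_multiplier * seasonality
                     * discount_effect * price_effect * noise)
print(sales_quantity)  # 103
```

A holiday-season weekend day with a 20% discount roughly doubles the base demand of 50 units in this setup.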
models/all_models_metadata.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "best_model": "XGBoost",
3
+ "best_metrics": {
4
+ "mae": 28.266036987304688,
5
+ "rmse": 34.58608132593768,
6
+ "r2": 0.19758522510528564
7
+ },
8
+ "all_models": {
9
+ "Linear Regression": {
10
+ "mae": 28.94336905682285,
11
+ "rmse": 35.499695024759994,
12
+ "r2": 0.15463271181982996
13
+ },
14
+ "Random Forest": {
15
+ "mae": 28.52530939054232,
16
+ "rmse": 34.98799112718141,
17
+ "r2": 0.17882785374399268
18
+ },
19
+ "XGBoost": {
20
+ "mae": 28.266036987304688,
21
+ "rmse": 34.58608132593768,
22
+ "r2": 0.19758522510528564
23
+ }
24
+ },
25
+ "saved_at": "2026-02-06 17:29:16"
26
+ }
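The metadata above ranks three regressors by held-out metrics; `best_model` is the one with the highest R². A minimal sketch of that selection, with the metric values copied (rounded) from the JSON:

```python
# Selecting the best model by R^2 from metadata shaped like
# models/all_models_metadata.json (values copied from the file above).
all_models = {
    "Linear Regression": {"mae": 28.943, "rmse": 35.500, "r2": 0.1546},
    "Random Forest":     {"mae": 28.525, "rmse": 34.988, "r2": 0.1788},
    "XGBoost":           {"mae": 28.266, "rmse": 34.586, "r2": 0.1976},
}

best_name = max(all_models, key=lambda name: all_models[name]["r2"])
print(best_name)  # XGBoost, matching "best_model" in the JSON
```

MAE and RMSE point the same way here (lower is better for both), so any of the three metrics would pick XGBoost.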
models/best_model.joblib ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:815655a55967c29f9302ce2d52cd6707ae584cbcb25532722d4a2415acd246a1
3
+ size 495387
models/model_metadata.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "model_name": "XGBoost",
3
+ "metrics": {
4
+ "mae": 28.266036987304688,
5
+ "rmse": 34.58608132593768,
6
+ "r2": 0.19758522510528564
7
+ },
8
+ "feature_names": [
9
+ "price",
10
+ "discount",
11
+ "day",
12
+ "month",
13
+ "day_of_week",
14
+ "weekend",
15
+ "year",
16
+ "quarter",
17
+ "category_encoded",
18
+ "product_id_encoded"
19
+ ],
20
+ "saved_at": "2026-02-06 17:29:16"
21
+ }
models/preprocessing.joblib ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e093b6491538c4afa8a7dfcf6bc7e82118d4f5fc72fa024e2079f98a5b57cc5
3
+ size 2252
plots/demand_trends.png ADDED

Git LFS Details

  • SHA256: 4e025111de45c39ee92e9d92c676e7f815e6adce2d5766c99f9a89462e3d75a8
  • Pointer size: 131 Bytes
  • Size of remote file: 663 kB
plots/feature_importance.png ADDED

Git LFS Details

  • SHA256: d635f5cdaf028e35844fe657f83509ff0696a4dbaed6bf8d216c9bd59d1791e6
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
plots/model_comparison.png ADDED

Git LFS Details

  • SHA256: 35e5f71b23c88502e8cb873477bf74eeca319cf7376ddea3f123660a7e5a4c4e
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
plots/monthly_demand.png ADDED

Git LFS Details

  • SHA256: 23bb6eb77daf4871727835e5006b23a6751f9ac1a9b25faec70915e1fd665d86
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
predict.py ADDED
@@ -0,0 +1,403 @@
1
+ """
2
+ Demand Prediction System - Prediction Script
3
+
4
+ This script loads a trained model and makes demand predictions for products
5
+ on future dates. Supports both ML models and time-series models (ARIMA, Prophet).
6
+
7
+ Usage (ML Models):
8
+ python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
9
+
10
+ Usage (Time-Series Models - overall demand):
11
+ python predict.py --date 2024-01-15 --model_type timeseries
12
+ """
13
+
14
+ import pandas as pd
15
+ import numpy as np
16
+ import joblib
17
+ import json
18
+ import argparse
19
+ from datetime import datetime
20
+ import os
21
+ import warnings
22
+ warnings.filterwarnings('ignore')
23
+
24
+ # Configuration
25
+ MODEL_DIR = 'models'
26
+ MODEL_PATH = f'{MODEL_DIR}/best_model.joblib'
27
+ TS_MODEL_PATH = f'{MODEL_DIR}/best_timeseries_model.joblib'
28
+ PREPROCESSING_PATH = f'{MODEL_DIR}/preprocessing.joblib'
29
+ METADATA_PATH = f'{MODEL_DIR}/model_metadata.json'
30
+ ALL_MODELS_METADATA_PATH = f'{MODEL_DIR}/all_models_metadata.json'
31
+
32
+
33
+ def load_model_and_preprocessing(model_type='auto'):
34
+ """
35
+ Load the trained model and preprocessing objects.
36
+
37
+ Args:
38
+ model_type: 'ml', 'timeseries', or 'auto' (auto-detect best model)
39
+
40
+ Returns:
41
+ tuple: (model, preprocessing_data, model_name, is_timeseries)
42
+ """
43
+ # Load metadata to determine best model
44
+ if os.path.exists(ALL_MODELS_METADATA_PATH):
45
+ with open(ALL_MODELS_METADATA_PATH, 'r') as f:
46
+ all_metadata = json.load(f)
47
+ best_model_name = all_metadata.get('best_model', 'Unknown')
48
+ else:
49
+ best_model_name = None
50
+
51
+ # Determine which model to use
52
+ if model_type == 'auto':
53
+ if best_model_name in ['ARIMA', 'Prophet']:
54
+ model_type = 'timeseries'
55
+ else:
56
+ model_type = 'ml'
57
+
58
+ is_timeseries = (model_type == 'timeseries')
59
+
60
+ if is_timeseries:
61
+ # Load time-series model
62
+ if not os.path.exists(TS_MODEL_PATH):
63
+ raise FileNotFoundError(
64
+ f"Time-series model not found at {TS_MODEL_PATH}. Please run train_model.py first."
65
+ )
66
+
67
+ print("Loading time-series model...")
68
+ model = joblib.load(TS_MODEL_PATH)
69
+ preprocessing_data = None
70
+
71
+ if best_model_name:
72
+ print(f"Model: {best_model_name}")
73
+ if best_model_name in all_metadata.get('all_models', {}):
74
+ metrics = all_metadata['all_models'][best_model_name]
75
+ print(f"R2 Score: {metrics.get('r2', 'N/A'):.4f}")
76
+
77
+ return model, preprocessing_data, best_model_name or 'Time-Series', True
78
+ else:
79
+ # Load ML model
80
+ if not os.path.exists(MODEL_PATH):
81
+ raise FileNotFoundError(
82
+ f"ML model not found at {MODEL_PATH}. Please run train_model.py first."
83
+ )
84
+
85
+ if not os.path.exists(PREPROCESSING_PATH):
86
+ raise FileNotFoundError(
87
+ f"Preprocessing objects not found at {PREPROCESSING_PATH}. Please run train_model.py first."
88
+ )
89
+
90
+ print("Loading ML model and preprocessing objects...")
91
+ model = joblib.load(MODEL_PATH)
92
+ preprocessing_data = joblib.load(PREPROCESSING_PATH)
93
+
94
+ # Load metadata if available
95
+ if os.path.exists(METADATA_PATH):
96
+ with open(METADATA_PATH, 'r') as f:
97
+ metadata = json.load(f)
98
+ model_name = metadata.get('model_name', 'ML Model')
99
+ print(f"Model: {model_name}")
100
+ print(f"R2 Score: {metadata.get('metrics', {}).get('r2', 'N/A'):.4f}")
101
+ else:
102
+ model_name = best_model_name or 'ML Model'
103
+
104
+ return model, preprocessing_data, model_name, False
105
+
106
+
107
+ def prepare_features(product_id, date, price, discount, category, preprocessing_data):
108
+ """
109
+ Prepare features for prediction using the same preprocessing pipeline.
110
+
111
+ Args:
112
+ product_id: Product ID
113
+ date: Date string (YYYY-MM-DD) or datetime object
114
+ price: Product price
115
+ discount: Discount percentage (0-100)
116
+ category: Product category
117
+ preprocessing_data: Dictionary containing encoders and scaler
118
+
119
+ Returns:
120
+ numpy array: Prepared features for prediction
121
+ """
122
+ # Convert date to datetime if string
123
+ if isinstance(date, str):
124
+ date = pd.to_datetime(date)
125
+
126
+ # Extract date features (same as in training)
127
+ day = date.day
128
+ month = date.month
129
+ day_of_week = date.weekday() # 0=Monday, 6=Sunday
130
+ weekend = 1 if day_of_week >= 5 else 0
131
+ year = date.year
132
+ quarter = date.quarter
133
+
134
+ # Encode categorical variables
135
+ category_encoder = preprocessing_data['encoders']['category']
136
+ product_encoder = preprocessing_data['encoders']['product_id']
137
+
138
+ # Handle unseen categories/products
139
+ try:
140
+ category_encoded = category_encoder.transform([category])[0]
141
+ except ValueError:
142
+ # If category not seen during training, use most common category
143
+ print(f"Warning: Category '{category}' not seen during training. Using default encoding.")
144
+ category_encoded = 0
145
+
146
+ try:
147
+ product_id_encoded = product_encoder.transform([product_id])[0]
148
+ except ValueError:
149
+ # If product_id not seen during training, use mean encoding
150
+ print(f"Warning: Product ID '{product_id}' not seen during training. Using default encoding.")
151
+ product_id_encoded = product_encoder.transform([product_encoder.classes_[0]])[0]
152
+
153
+ # Create feature dictionary
154
+ feature_dict = {
155
+ 'price': price,
156
+ 'discount': discount,
157
+ 'day': day,
158
+ 'month': month,
159
+ 'day_of_week': day_of_week,
160
+ 'weekend': weekend,
161
+ 'year': year,
162
+ 'quarter': quarter,
163
+ 'category_encoded': category_encoded,
164
+ 'product_id_encoded': product_id_encoded
165
+ }
166
+
167
+ # Create feature array in the same order as training
168
+ feature_names = preprocessing_data['feature_names']
169
+ features = np.array([[feature_dict[name] for name in feature_names]])
170
+
171
+ # Scale features
172
+ scaler = preprocessing_data['scaler']
173
+ features_scaled = scaler.transform(features)
174
+
175
+ return features_scaled
176
+
177
+
178
+ def predict_demand_ml(product_id, date, price, discount, category, model, preprocessing_data):
179
+ """
180
+ Predict demand for a product on a given date using ML model.
181
+
182
+ Args:
183
+ product_id: Product ID
184
+ date: Date string (YYYY-MM-DD) or datetime object
185
+ price: Product price
186
+ discount: Discount percentage (0-100)
187
+ category: Product category
188
+ model: Trained ML model
189
+ preprocessing_data: Dictionary containing encoders and scaler
190
+
191
+ Returns:
192
+ float: Predicted sales quantity
193
+ """
194
+ # Prepare features
195
+ features = prepare_features(product_id, date, price, discount, category, preprocessing_data)
196
+
197
+ # Make prediction
198
+ prediction = model.predict(features)[0]
199
+
200
+ # Ensure non-negative prediction
201
+ prediction = max(0, prediction)
202
+
203
+ return prediction
204
+
205
+
206
+ def predict_demand_timeseries(date, model, model_name):
207
+ """
208
+ Predict overall daily demand using time-series model.
209
+
210
+ Args:
211
+ date: Date string (YYYY-MM-DD) or datetime object
212
+ model: Trained time-series model (ARIMA or Prophet)
213
+ model_name: Name of the model ('ARIMA' or 'Prophet')
214
+
215
+ Returns:
216
+ float: Predicted total daily sales quantity
217
+ """
218
+ # Convert date to datetime if string
219
+ if isinstance(date, str):
220
+ date = pd.to_datetime(date)
221
+
222
+ if model_name == 'ARIMA':
223
+ # For ARIMA, we need to calculate how many steps ahead
224
+ # This is a simplified approach - in practice, you'd need the training end date
225
+ # For now, predict 1 step ahead
226
+ try:
227
+ forecast = model.forecast(steps=1)
228
+ prediction = forecast[0] if hasattr(forecast, '__iter__') else forecast
229
+ prediction = max(0, prediction)
230
+ return prediction
231
+ except Exception as e:
232
+ print(f"Error in ARIMA prediction: {e}")
233
+ return None
234
+
235
+ elif model_name == 'Prophet':
236
+ # For Prophet, create a future dataframe
237
+ try:
238
+ future = pd.DataFrame({'ds': [date]})
239
+ forecast = model.predict(future)
240
+ prediction = forecast['yhat'].iloc[0]
241
+ prediction = max(0, prediction)
242
+ return prediction
243
+ except Exception as e:
244
+ print(f"Error in Prophet prediction: {e}")
245
+ return None
246
+
247
+ else:
248
+ print(f"Unknown time-series model: {model_name}")
249
+ return None
250
+
251
+
252
+ def predict_batch(predictions_data, model, preprocessing_data):
253
+ """
254
+ Predict demand for multiple products/dates at once.
255
+
256
+ Args:
257
+ predictions_data: List of dictionaries, each containing:
258
+ - product_id
259
+ - date
260
+ - price
261
+ - discount
262
+ - category
263
+ model: Trained model
264
+ preprocessing_data: Dictionary containing encoders and scaler
265
+
266
+ Returns:
267
+ list: List of predicted sales quantities
268
+ """
269
+ predictions = []
270
+
271
+ for data in predictions_data:
272
+ pred = predict_demand(
273
+ data['product_id'],
274
+ data['date'],
275
+ data['price'],
276
+ data['discount'],
277
+ data['category'],
278
+ model,
279
+ preprocessing_data
280
+ )
281
+ predictions.append(pred)
282
+
283
+ return predictions
284
+
285
+
286
+ def main():
287
+ """
288
+ Main function for command-line interface.
289
+ """
290
+ parser = argparse.ArgumentParser(
291
+ description='Predict product demand for a given date and product details',
292
+ formatter_class=argparse.RawDescriptionHelpFormatter,
293
+ epilog="""
294
+ Examples (ML Models):
295
+ python predict.py --product_id 1 --date 2024-01-15 --price 100 --discount 10 --category Electronics
296
+ python predict.py --product_id 5 --date 2024-06-20 --price 50 --discount 0 --category Clothing
297
+
298
+ Examples (Time-Series Models - overall daily demand):
299
+ python predict.py --date 2024-01-15 --model_type timeseries
300
+ """
301
+ )
302
+
303
+     parser.add_argument('--product_id', type=int, default=None,
+                         help='Product ID (required for ML models)')
+     parser.add_argument('--date', type=str, required=True,
+                         help='Date in YYYY-MM-DD format')
+     parser.add_argument('--price', type=float, default=None,
+                         help='Product price (required for ML models)')
+     parser.add_argument('--discount', type=float, default=0,
+                         help='Discount percentage (0-100), default: 0 (for ML models)')
+     parser.add_argument('--category', type=str, default=None,
+                         help='Product category (required for ML models)')
+     parser.add_argument('--model_type', type=str, default='auto',
+                         choices=['auto', 'ml', 'timeseries'],
+                         help='Model type to use: auto (best model), ml, or timeseries')
+
+     args = parser.parse_args()
+
+     # Validate date format
+     try:
+         date_obj = pd.to_datetime(args.date)
+     except ValueError:
+         print(f"Error: Invalid date format '{args.date}'. Please use YYYY-MM-DD format.")
+         return
+
+     # Load model and preprocessing
+     try:
+         model, preprocessing_data, model_name, is_timeseries = load_model_and_preprocessing(args.model_type)
+     except FileNotFoundError as e:
+         print(f"Error: {e}")
+         return
+
+     # Validate arguments based on model type
+     if not is_timeseries:
+         # ML model requires product details
+         if args.product_id is None or args.price is None or args.category is None:
+             print("Error: ML models require --product_id, --price, and --category arguments.")
+             return
+
+         # Validate discount range
+         if args.discount < 0 or args.discount > 100:
+             print(f"Warning: Discount {args.discount}% is outside 0-100 range. Clamping to valid range.")
+             args.discount = max(0, min(100, args.discount))
+
+     # Make prediction
+     print("\n" + "="*60)
+     print("MAKING PREDICTION")
+     print("="*60)
+     print(f"Model: {model_name}")
+     print(f"Model Type: {'Time-Series' if is_timeseries else 'Machine Learning'}")
+     print(f"Date: {args.date}")
+
+     if not is_timeseries:
+         print(f"Product ID: {args.product_id}")
+         print(f"Price: ${args.price:.2f}")
+         print(f"Discount: {args.discount}%")
+         print(f"Category: {args.category}")
+
+     print("-"*60)
+
+     if is_timeseries:
+         predicted_demand = predict_demand_timeseries(
+             args.date,
+             model,
+             model_name
+         )
+
+         if predicted_demand is None:
+             print("Error: Failed to make prediction.")
+             return
+
+         print(f"\nPredicted Total Daily Sales Quantity: {predicted_demand:.0f} units")
+         print("(This is the predicted total demand across all products for this date)")
+     else:
+         predicted_demand = predict_demand_ml(
+             args.product_id,
+             args.date,
+             args.price,
+             args.discount,
+             args.category,
+             model,
+             preprocessing_data
+         )
+
+         print(f"\nPredicted Sales Quantity: {predicted_demand:.0f} units")
+         print("(This is the predicted demand for this specific product)")
+
+     print("="*60)
+
+     # Additional information
+     date_obj = pd.to_datetime(args.date)
+     day_name = date_obj.strftime('%A')
+     is_weekend = "Yes" if date_obj.weekday() >= 5 else "No"
+
+     print(f"\nDate Information:")
+     print(f"  Day of week: {day_name}")
+     print(f"  Weekend: {is_weekend}")
+     print(f"  Month: {date_obj.strftime('%B')}")
+     print(f"  Quarter: Q{date_obj.quarter}")
+
+
+ if __name__ == "__main__":
+     main()
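The CLI above validates `--date` and clamps `--discount` before predicting. A minimal stand-alone sketch of those two checks (plain `datetime` instead of pandas; function names are illustrative, not part of the script):

```python
from datetime import datetime

def clamp_discount(discount: float) -> float:
    """Clamp a discount percentage into 0-100, mirroring the
    warning-and-clamp behaviour in main()."""
    return max(0.0, min(100.0, discount))

def is_valid_date(text: str) -> bool:
    """Return True if text parses as YYYY-MM-DD, the format
    the --date argument expects."""
    try:
        datetime.strptime(text, "%Y-%m-%d")
        return True
    except ValueError:
        return False

print(clamp_discount(150))           # 100.0
print(is_valid_date("2024-02-30"))   # False (no Feb 30)
```

Note that `strptime` also rejects impossible calendar dates, not just malformed strings, which is why a simple regex would be a weaker check here.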
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ pandas>=1.5.0
+ numpy>=1.23.0
+ scikit-learn>=1.2.0
+ matplotlib>=3.6.0
+ seaborn>=0.12.0
+ joblib>=1.2.0
+ xgboost>=1.7.0
+ statsmodels>=0.14.0
+ prophet>=1.1.0
+ streamlit>=1.28.0
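The requirements file pins minimum versions (`>=` floors). The comparison pip performs is numeric per component, not lexicographic; a simplified sketch of that rule (a toy helper, not pip's actual resolver, which also handles pre-releases and epochs):

```python
def meets_minimum(installed: str, required: str) -> bool:
    """Compare dotted version strings component-by-component,
    e.g. '1.10.2' >= '1.5.0'. String comparison would get this
    wrong because '10' < '5' lexicographically."""
    to_tuple = lambda v: tuple(int(part) for part in v.split("."))
    return to_tuple(installed) >= to_tuple(required)

print(meets_minimum("1.10.2", "1.5.0"))  # True
print("1.10.2" >= "1.5.0")              # False -- the lexicographic trap
```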
setup_env.bat ADDED
@@ -0,0 +1,26 @@
+ @echo off
+ echo Creating virtual environment...
+ python -m venv venv
+
+ echo.
+ echo Activating virtual environment...
+ call venv\Scripts\activate.bat
+
+ echo.
+ echo Installing dependencies...
+ pip install --upgrade pip
+ pip install -r requirements.txt
+
+ echo.
+ echo ========================================
+ echo Setup complete!
+ echo ========================================
+ echo.
+ echo To activate the virtual environment in the future, run:
+ echo   venv\Scripts\activate
+ echo.
+ echo To deactivate, run:
+ echo   deactivate
+ echo.
+
+ pause
setup_env.sh ADDED
@@ -0,0 +1,25 @@
+ #!/bin/bash
+
+ echo "Creating virtual environment..."
+ python3 -m venv venv
+
+ echo ""
+ echo "Activating virtual environment..."
+ source venv/bin/activate
+
+ echo ""
+ echo "Installing dependencies..."
+ pip install --upgrade pip
+ pip install -r requirements.txt
+
+ echo ""
+ echo "========================================"
+ echo "Setup complete!"
+ echo "========================================"
+ echo ""
+ echo "To activate the virtual environment in the future, run:"
+ echo "  source venv/bin/activate"
+ echo ""
+ echo "To deactivate, run:"
+ echo "  deactivate"
+ echo ""
train_model.py ADDED
@@ -0,0 +1,877 @@
+ """
+ Demand Prediction System - Model Training Script
+
+ This script trains multiple machine learning and time-series models to predict
+ product demand (sales quantity) for an e-commerce platform.
+
+ Features:
+ - Data preprocessing and feature engineering
+ - Date feature extraction (day, month, day_of_week, weekend)
+ - Categorical encoding
+ - Feature scaling
+ - Multiple ML models (Linear Regression, Random Forest, XGBoost)
+ - Time-series models (ARIMA, Prophet)
+ - Model evaluation (MAE, RMSE, R2 Score)
+ - Automatic best model selection
+ - Model persistence using joblib
+ - Visualization of results
+ - Comparison between ML and time-series approaches
+ """
+
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from datetime import datetime
+ import joblib
+ import os
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # Machine Learning imports
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
+ from sklearn.linear_model import LinearRegression
+ from sklearn.ensemble import RandomForestRegressor
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
+
+ # Try to import XGBoost (optional)
+ try:
+     import xgboost as xgb
+     XGBOOST_AVAILABLE = True
+ except ImportError:
+     XGBOOST_AVAILABLE = False
+     print("XGBoost not available. Install with: pip install xgboost")
+
+ # Try to import time-series libraries
+ try:
+     from statsmodels.tsa.arima.model import ARIMA
+     from statsmodels.tsa.stattools import adfuller
+     ARIMA_AVAILABLE = True
+ except ImportError:
+     ARIMA_AVAILABLE = False
+     print("statsmodels not available. Install with: pip install statsmodels")
+
+ try:
+     from prophet import Prophet
+     PROPHET_AVAILABLE = True
+ except ImportError:
+     PROPHET_AVAILABLE = False
+     print("Prophet not available. Install with: pip install prophet")
+
+ # Set random seeds for reproducibility
+ np.random.seed(42)
+
+ # Configuration
+ DATA_PATH = 'data/sales.csv'
+ MODEL_DIR = 'models'
+ PLOTS_DIR = 'plots'
+
+ # Create directories if they don't exist
+ os.makedirs(MODEL_DIR, exist_ok=True)
+ os.makedirs(PLOTS_DIR, exist_ok=True)
+
+
+ def load_data(file_path):
+     """
+     Load the sales dataset from CSV file.
+
+     Args:
+         file_path: Path to the CSV file
+
+     Returns:
+         DataFrame: Loaded dataset
+     """
+     print(f"Loading data from {file_path}...")
+     df = pd.read_csv(file_path)
+     print(f"Data loaded successfully! Shape: {df.shape}")
+     return df
+
+
+ def preprocess_data(df):
+     """
+     Preprocess the data: convert date, extract features, handle missing values.
+
+     Args:
+         df: Raw DataFrame
+
+     Returns:
+         DataFrame: Preprocessed DataFrame
+     """
+     print("\n" + "="*60)
+     print("PREPROCESSING DATA")
+     print("="*60)
+
+     # Create a copy to avoid modifying original
+     df = df.copy()
+
+     # Convert date column to datetime
+     df['date'] = pd.to_datetime(df['date'])
+
+     # Extract date features
+     print("Extracting date features...")
+     df['day'] = df['date'].dt.day
+     df['month'] = df['date'].dt.month
+     df['day_of_week'] = df['date'].dt.dayofweek  # 0=Monday, 6=Sunday
+     df['weekend'] = (df['day_of_week'] >= 5).astype(int)  # 1 if weekend, 0 otherwise
+     df['year'] = df['date'].dt.year
+     df['quarter'] = df['date'].dt.quarter
+
+     # Check for missing values
+     print("\nMissing values:")
+     missing = df.isnull().sum()
+     print(missing[missing > 0])
+
+     if missing.sum() > 0:
+         print("Filling missing values...")
+         df = df.fillna(df.median(numeric_only=True))
+
+     # Display basic statistics
+     print("\nDataset Info:")
+     print(f"Shape: {df.shape}")
+     print(f"\nColumns: {df.columns.tolist()}")
+     print(f"\nData types:\n{df.dtypes}")
+     print(f"\nBasic statistics:\n{df.describe()}")
+
+     return df
+
+
+ def feature_engineering(df):
+     """
+     Perform feature engineering: encode categorical variables, scale features.
+
+     Args:
+         df: Preprocessed DataFrame
+
+     Returns:
+         tuple: (X_features, y_target, feature_names, encoders, scaler)
+     """
+     print("\n" + "="*60)
+     print("FEATURE ENGINEERING")
+     print("="*60)
+
+     # Separate features and target
+     # Drop original date column (we have extracted features from it)
+     # Keep product_id for now (we'll encode it)
+     feature_columns = ['product_id', 'price', 'discount', 'category',
+                        'day', 'month', 'day_of_week', 'weekend', 'year', 'quarter']
+
+     X = df[feature_columns].copy()
+     y = df['sales_quantity'].copy()
+
+     # Encode categorical variables
+     print("Encoding categorical variables...")
+
+     # Label encode category
+     category_encoder = LabelEncoder()
+     X['category_encoded'] = category_encoder.fit_transform(X['category'])
+
+     # Label encode product_id (treating it as categorical)
+     product_encoder = LabelEncoder()
+     X['product_id_encoded'] = product_encoder.fit_transform(X['product_id'])
+
+     # Drop original categorical columns
+     X = X.drop(['category', 'product_id'], axis=1)
+
+     # Get feature names
+     feature_names = X.columns.tolist()
+
+     print(f"Features after encoding: {feature_names}")
+     print(f"Number of features: {len(feature_names)}")
+
+     # Scale numerical features
+     print("\nScaling numerical features...")
+     scaler = StandardScaler()
+     X_scaled = scaler.fit_transform(X)
+     X_scaled = pd.DataFrame(X_scaled, columns=feature_names)
+
+     # Store encoders and scaler for later use
+     encoders = {
+         'category': category_encoder,
+         'product_id': product_encoder,
+         'scaler': scaler
+     }
+
+     return X_scaled, y, feature_names, encoders, scaler
+
+
+ def train_models(X_train, y_train, X_val, y_val):
+     """
+     Train multiple models and return their performance metrics.
+
+     Args:
+         X_train: Training features
+         y_train: Training target
+         X_val: Validation features
+         y_val: Validation target
+
+     Returns:
+         dict: Dictionary containing models and their metrics
+     """
+     print("\n" + "="*60)
+     print("TRAINING MODELS")
+     print("="*60)
+
+     models = {}
+     results = {}
+
+     # 1. Linear Regression
+     print("\n1. Training Linear Regression...")
+     lr_model = LinearRegression()
+     lr_model.fit(X_train, y_train)
+     lr_pred = lr_model.predict(X_val)
+
+     lr_mae = mean_absolute_error(y_val, lr_pred)
+     lr_rmse = np.sqrt(mean_squared_error(y_val, lr_pred))
+     lr_r2 = r2_score(y_val, lr_pred)
+
+     models['Linear Regression'] = lr_model
+     results['Linear Regression'] = {
+         'model': lr_model,
+         'mae': lr_mae,
+         'rmse': lr_rmse,
+         'r2': lr_r2,
+         'predictions': lr_pred
+     }
+
+     print(f"   MAE: {lr_mae:.2f}, RMSE: {lr_rmse:.2f}, R2: {lr_r2:.4f}")
+
+     # 2. Random Forest Regressor
+     print("\n2. Training Random Forest Regressor...")
+     rf_model = RandomForestRegressor(
+         n_estimators=100,
+         max_depth=15,
+         min_samples_split=5,
+         min_samples_leaf=2,
+         random_state=42,
+         n_jobs=-1
+     )
+     rf_model.fit(X_train, y_train)
+     rf_pred = rf_model.predict(X_val)
+
+     rf_mae = mean_absolute_error(y_val, rf_pred)
+     rf_rmse = np.sqrt(mean_squared_error(y_val, rf_pred))
+     rf_r2 = r2_score(y_val, rf_pred)
+
+     models['Random Forest'] = rf_model
+     results['Random Forest'] = {
+         'model': rf_model,
+         'mae': rf_mae,
+         'rmse': rf_rmse,
+         'r2': rf_r2,
+         'predictions': rf_pred
+     }
+
+     print(f"   MAE: {rf_mae:.2f}, RMSE: {rf_rmse:.2f}, R2: {rf_r2:.4f}")
+
+     # 3. XGBoost (if available)
+     if XGBOOST_AVAILABLE:
+         print("\n3. Training XGBoost Regressor...")
+         xgb_model = xgb.XGBRegressor(
+             n_estimators=100,
+             max_depth=6,
+             learning_rate=0.1,
+             random_state=42,
+             n_jobs=-1
+         )
+         xgb_model.fit(X_train, y_train)
+         xgb_pred = xgb_model.predict(X_val)
+
+         xgb_mae = mean_absolute_error(y_val, xgb_pred)
+         xgb_rmse = np.sqrt(mean_squared_error(y_val, xgb_pred))
+         xgb_r2 = r2_score(y_val, xgb_pred)
+
+         models['XGBoost'] = xgb_model
+         results['XGBoost'] = {
+             'model': xgb_model,
+             'mae': xgb_mae,
+             'rmse': xgb_rmse,
+             'r2': xgb_r2,
+             'predictions': xgb_pred
+         }
+
+         print(f"   MAE: {xgb_mae:.2f}, RMSE: {xgb_rmse:.2f}, R2: {xgb_r2:.4f}")
+     else:
+         print("\n3. XGBoost skipped (not available)")
+
+     return results
+
+
+ def prepare_time_series_data(df):
+     """
+     Prepare time-series data by aggregating daily sales.
+
+     Args:
+         df: DataFrame with date and sales_quantity columns
+
+     Returns:
+         tuple: (ts_data, train_size) - time series data and training size
+     """
+     print("\n" + "="*60)
+     print("PREPARING TIME-SERIES DATA")
+     print("="*60)
+
+     # Aggregate by date
+     df['date'] = pd.to_datetime(df['date'])
+     ts_data = df.groupby('date')['sales_quantity'].sum().reset_index()
+     ts_data = ts_data.sort_values('date').reset_index(drop=True)
+     ts_data.columns = ['ds', 'y']  # Prophet expects 'ds' and 'y'
+
+     print(f"Time-series data shape: {ts_data.shape}")
+     print(f"Date range: {ts_data['ds'].min()} to {ts_data['ds'].max()}")
+     print(f"Total days: {len(ts_data)}")
+
+     # Use 80% for training (chronological split for time-series)
+     train_size = int(len(ts_data) * 0.8)
+
+     return ts_data, train_size
+
+
+ def train_arima(ts_data, train_size):
+     """
+     Train ARIMA model on time-series data.
+
+     Args:
+         ts_data: Time-series DataFrame with 'ds' and 'y' columns
+         train_size: Number of samples for training
+
+     Returns:
+         dict: Model results dictionary
+     """
+     if not ARIMA_AVAILABLE:
+         return None
+
+     print("\n" + "="*60)
+     print("TRAINING ARIMA MODEL")
+     print("="*60)
+
+     try:
+         # Split data chronologically
+         train_data = ts_data['y'].iloc[:train_size].values
+         val_data = ts_data['y'].iloc[train_size:].values
+         val_dates = ts_data['ds'].iloc[train_size:].values
+
+         print(f"Training on {len(train_data)} samples")
+         print(f"Validating on {len(val_data)} samples")
+
+         # Try different ARIMA orders (p, d, q)
+         # Start with auto_arima-like approach - try common orders
+         best_aic = np.inf
+         best_order = None
+         best_model = None
+
+         # Common ARIMA orders to try
+         orders_to_try = [
+             (1, 1, 1),  # Standard ARIMA(1,1,1)
+             (2, 1, 2),  # ARIMA(2,1,2)
+             (1, 1, 0),  # ARIMA(1,1,0) - AR model
+             (0, 1, 1),  # ARIMA(0,1,1) - MA model
+             (2, 1, 1),  # ARIMA(2,1,1)
+             (1, 1, 2),  # ARIMA(1,1,2)
+         ]
+
+         print("Trying different ARIMA orders...")
+         for order in orders_to_try:
+             try:
+                 model = ARIMA(train_data, order=order)
+                 fitted_model = model.fit()
+                 aic = fitted_model.aic
+
+                 if aic < best_aic:
+                     best_aic = aic
+                     best_order = order
+                     best_model = fitted_model
+                     print(f"   Order {order}: AIC = {aic:.2f} (best so far)")
+                 else:
+                     print(f"   Order {order}: AIC = {aic:.2f}")
+             except Exception as e:
+                 print(f"   Order {order}: Failed - {str(e)[:50]}")
+                 continue
+
+         if best_model is None:
+             print("Failed to fit ARIMA model with any order")
+             return None
+
+         print(f"\nBest ARIMA order: {best_order} (AIC: {best_aic:.2f})")
+
+         # Make predictions
+         forecast_steps = len(val_data)
+         forecast = best_model.forecast(steps=forecast_steps)
+
+         # Ensure predictions are non-negative
+         forecast = np.maximum(forecast, 0)
+
+         # Calculate metrics
+         mae = mean_absolute_error(val_data, forecast)
+         rmse = np.sqrt(mean_squared_error(val_data, forecast))
+         r2 = r2_score(val_data, forecast)
+
+         print(f"   MAE: {mae:.2f}, RMSE: {rmse:.2f}, R2: {r2:.4f}")
+
+         return {
+             'model': best_model,
+             'order': best_order,
+             'mae': mae,
+             'rmse': rmse,
+             'r2': r2,
+             'predictions': forecast,
+             'actual': val_data,
+             'dates': val_dates
+         }
+
+     except Exception as e:
+         print(f"Error training ARIMA: {str(e)}")
+         return None
+
+
+ def train_prophet(ts_data, train_size):
+     """
+     Train Prophet model on time-series data.
+
+     Args:
+         ts_data: Time-series DataFrame with 'ds' and 'y' columns
+         train_size: Number of samples for training
+
+     Returns:
+         dict: Model results dictionary
+     """
+     if not PROPHET_AVAILABLE:
+         return None
+
+     print("\n" + "="*60)
+     print("TRAINING PROPHET MODEL")
+     print("="*60)
+
+     try:
+         # Split data chronologically
+         train_data = ts_data.iloc[:train_size].copy()
+         val_data = ts_data.iloc[train_size:].copy()
+
+         print(f"Training on {len(train_data)} samples")
+         print(f"Validating on {len(val_data)} samples")
+
+         # Initialize and fit Prophet model
+         # Enable weekly and yearly seasonality; daily seasonality stays off
+         # because the series is aggregated to one observation per day
+         model = Prophet(
+             daily_seasonality=False,
+             weekly_seasonality=True,
+             yearly_seasonality=True,
+             seasonality_mode='multiplicative',
+             changepoint_prior_scale=0.05
+         )
+
+         print("Fitting Prophet model...")
+         model.fit(train_data)
+
+         # Create future dataframe for validation period
+         future = model.make_future_dataframe(periods=len(val_data), freq='D')
+
+         # Make predictions
+         forecast = model.predict(future)
+
+         # Get predictions for validation period
+         val_forecast = forecast.iloc[train_size:]['yhat'].values
+         val_actual = val_data['y'].values
+
+         # Ensure predictions are non-negative
+         val_forecast = np.maximum(val_forecast, 0)
+
+         # Calculate metrics
+         mae = mean_absolute_error(val_actual, val_forecast)
+         rmse = np.sqrt(mean_squared_error(val_actual, val_forecast))
+         r2 = r2_score(val_actual, val_forecast)
+
+         print(f"   MAE: {mae:.2f}, RMSE: {rmse:.2f}, R2: {r2:.4f}")
+
+         return {
+             'model': model,
+             'mae': mae,
+             'rmse': rmse,
+             'r2': r2,
+             'predictions': val_forecast,
+             'actual': val_actual,
+             'dates': val_data['ds'].values,
+             'full_forecast': forecast
+         }
+
+     except Exception as e:
+         print(f"Error training Prophet: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         return None
+
+
+ def select_best_model(results):
+     """
+     Select the best model based on R2 score (higher is better).
+
+     Args:
+         results: Dictionary containing model results
+
+     Returns:
+         tuple: (best_model_name, best_model, best_metrics)
+     """
+     print("\n" + "="*60)
+     print("MODEL COMPARISON")
+     print("="*60)
+
+     # Create comparison DataFrame
+     comparison_data = []
+     for model_name, metrics in results.items():
+         comparison_data.append({
+             'Model': model_name,
+             'MAE': metrics['mae'],
+             'RMSE': metrics['rmse'],
+             'R2 Score': metrics['r2']
+         })
+
+     comparison_df = pd.DataFrame(comparison_data)
+     print("\nModel Performance Comparison:")
+     print(comparison_df.to_string(index=False))
+
+     # Select best model based on R2 score
+     best_model_name = max(results.keys(), key=lambda x: results[x]['r2'])
+     best_model = results[best_model_name]['model']
+     best_metrics = {
+         'mae': results[best_model_name]['mae'],
+         'rmse': results[best_model_name]['rmse'],
+         'r2': results[best_model_name]['r2']
+     }
+
+     print(f"\n{'='*60}")
+     print(f"BEST MODEL: {best_model_name}")
+     print(f"MAE: {best_metrics['mae']:.2f}")
+     print(f"RMSE: {best_metrics['rmse']:.2f}")
+     print(f"R2 Score: {best_metrics['r2']:.4f}")
+     print(f"{'='*60}")
+
+     return best_model_name, best_model, best_metrics
+
+
+ def visualize_results(df, results, best_model_name, feature_names):
+     """
+     Create visualizations: demand trends, feature importance, model comparison.
+
+     Args:
+         df: Original DataFrame
+         results: Model results dictionary
+         best_model_name: Name of the best model
+         feature_names: List of feature names
+     """
+     print("\n" + "="*60)
+     print("GENERATING VISUALIZATIONS")
+     print("="*60)
+
+     # Set style
+     sns.set_style("whitegrid")
+     plt.rcParams['figure.figsize'] = (12, 6)
+
+     # 1. Demand trends over time
+     print("1. Plotting demand trends over time...")
+     df['date'] = pd.to_datetime(df['date'])
+     daily_demand = df.groupby('date')['sales_quantity'].sum().reset_index()
+
+     plt.figure(figsize=(14, 6))
+     plt.plot(daily_demand['date'], daily_demand['sales_quantity'], linewidth=1, alpha=0.7)
+     plt.title('Total Daily Sales Quantity Over Time', fontsize=16, fontweight='bold')
+     plt.xlabel('Date', fontsize=12)
+     plt.ylabel('Total Sales Quantity', fontsize=12)
+     plt.grid(True, alpha=0.3)
+     plt.tight_layout()
+     plt.savefig(f'{PLOTS_DIR}/demand_trends.png', dpi=300, bbox_inches='tight')
+     print(f"   Saved: {PLOTS_DIR}/demand_trends.png")
+     plt.close()
+
+     # 2. Monthly average demand
+     print("2. Plotting monthly average demand...")
+     df['month_name'] = pd.to_datetime(df['date']).dt.strftime('%B')
+     monthly_avg = df.groupby('month')['sales_quantity'].mean().reset_index()
+     month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
+                    'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+     monthly_avg['month_name'] = monthly_avg['month'].apply(lambda x: month_names[x-1])
+
+     plt.figure(figsize=(12, 6))
+     plt.bar(monthly_avg['month_name'], monthly_avg['sales_quantity'], color='steelblue', alpha=0.7)
+     plt.title('Average Sales Quantity by Month', fontsize=16, fontweight='bold')
+     plt.xlabel('Month', fontsize=12)
+     plt.ylabel('Average Sales Quantity', fontsize=12)
+     plt.xticks(rotation=45)
+     plt.grid(True, alpha=0.3, axis='y')
+     plt.tight_layout()
+     plt.savefig(f'{PLOTS_DIR}/monthly_demand.png', dpi=300, bbox_inches='tight')
+     print(f"   Saved: {PLOTS_DIR}/monthly_demand.png")
+     plt.close()
+
+     # 3. Feature importance (for tree-based models)
+     print("3. Plotting feature importance...")
+     best_model = results[best_model_name]['model']
+
+     if hasattr(best_model, 'feature_importances_'):
+         importances = best_model.feature_importances_
+         feature_importance_df = pd.DataFrame({
+             'feature': feature_names,
+             'importance': importances
+         }).sort_values('importance', ascending=False)
+
+         plt.figure(figsize=(10, 6))
+         plt.barh(feature_importance_df['feature'], feature_importance_df['importance'], color='coral', alpha=0.7)
+         plt.title(f'Feature Importance - {best_model_name}', fontsize=16, fontweight='bold')
+         plt.xlabel('Importance', fontsize=12)
+         plt.ylabel('Feature', fontsize=12)
+         plt.gca().invert_yaxis()
+         plt.grid(True, alpha=0.3, axis='x')
+         plt.tight_layout()
+         plt.savefig(f'{PLOTS_DIR}/feature_importance.png', dpi=300, bbox_inches='tight')
+         print(f"   Saved: {PLOTS_DIR}/feature_importance.png")
+         plt.close()
+     else:
+         print("   Feature importance not available for this model type")
+
+     # 4. Model comparison
+     print("4. Plotting model comparison...")
+     model_names = list(results.keys())
+     mae_scores = [results[m]['mae'] for m in model_names]
+     rmse_scores = [results[m]['rmse'] for m in model_names]
+     r2_scores = [results[m]['r2'] for m in model_names]
+
+     # Separate ML and time-series models for visualization
+     ml_models = [m for m in model_names if m not in ['ARIMA', 'Prophet']]
+     ts_models = [m for m in model_names if m in ['ARIMA', 'Prophet']]
+
+     fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+     # Color code: ML models in blue tones, TS models in orange/red tones
+     colors = []
+     for m in model_names:
+         if m in ts_models:
+             colors.append('coral' if m == 'ARIMA' else 'salmon')
+         else:
+             colors.append('skyblue')
+
+     # MAE comparison
+     axes[0].bar(model_names, mae_scores, color=colors, alpha=0.7)
+     axes[0].set_title('MAE Comparison (Lower is Better)', fontsize=14, fontweight='bold')
+     axes[0].set_ylabel('MAE', fontsize=12)
+     axes[0].tick_params(axis='x', rotation=45)
+     axes[0].grid(True, alpha=0.3, axis='y')
+     # Add legend
+     from matplotlib.patches import Patch
+     legend_elements = [
+         Patch(facecolor='skyblue', alpha=0.7, label='ML Models'),
+         Patch(facecolor='coral', alpha=0.7, label='Time-Series Models')
+     ]
+     axes[0].legend(handles=legend_elements, loc='upper right')
+
+     # RMSE comparison
+     axes[1].bar(model_names, rmse_scores, color=colors, alpha=0.7)
+     axes[1].set_title('RMSE Comparison (Lower is Better)', fontsize=14, fontweight='bold')
+     axes[1].set_ylabel('RMSE', fontsize=12)
+     axes[1].tick_params(axis='x', rotation=45)
+     axes[1].grid(True, alpha=0.3, axis='y')
+
+     # R2 comparison
+     axes[2].bar(model_names, r2_scores, color=colors, alpha=0.7)
+     axes[2].set_title('R2 Score Comparison (Higher is Better)', fontsize=14, fontweight='bold')
+     axes[2].set_ylabel('R2 Score', fontsize=12)
+     axes[2].tick_params(axis='x', rotation=45)
+     axes[2].grid(True, alpha=0.3, axis='y')
+
+     plt.tight_layout()
+     plt.savefig(f'{PLOTS_DIR}/model_comparison.png', dpi=300, bbox_inches='tight')
+     print(f"   Saved: {PLOTS_DIR}/model_comparison.png")
+     plt.close()
+
+     # 5. Time-series predictions plot (if time-series models available)
+     if ts_models:
+         print("5. Plotting time-series model predictions...")
+         fig, axes = plt.subplots(len(ts_models), 1, figsize=(14, 6*len(ts_models)))
+         if len(ts_models) == 1:
+             axes = [axes]
+
+         for idx, model_name in enumerate(ts_models):
+             if model_name in results and 'dates' in results[model_name]:
+                 dates = pd.to_datetime(results[model_name]['dates'])
+                 actual = results[model_name]['actual']
+                 predictions = results[model_name]['predictions']
+
+                 axes[idx].plot(dates, actual, label='Actual', linewidth=2, alpha=0.7)
+                 axes[idx].plot(dates, predictions, label='Predicted', linewidth=2, alpha=0.7, linestyle='--')
+                 axes[idx].set_title(f'{model_name} - Actual vs Predicted', fontsize=14, fontweight='bold')
+                 axes[idx].set_xlabel('Date', fontsize=12)
+                 axes[idx].set_ylabel('Sales Quantity', fontsize=12)
+                 axes[idx].legend()
+                 axes[idx].grid(True, alpha=0.3)
+
+         plt.tight_layout()
+         plt.savefig(f'{PLOTS_DIR}/timeseries_predictions.png', dpi=300, bbox_inches='tight')
+         print(f"   Saved: {PLOTS_DIR}/timeseries_predictions.png")
+         plt.close()
+
+     print("   Visualization complete!")
+
+
+ def save_model(model, encoders, scaler, feature_names, best_model_name, best_metrics):
+     """
+     Save the trained model and preprocessing objects.
+
+     Args:
+         model: Trained model
+         encoders: Dictionary of encoders
+         scaler: Fitted scaler
+         feature_names: List of feature names
+         best_model_name: Name of the best model
+         best_metrics: Dictionary of metrics
+     """
+     print("\n" + "="*60)
+     print("SAVING MODEL")
+     print("="*60)
+
+     # Save model
+     model_path = f'{MODEL_DIR}/best_model.joblib'
+     joblib.dump(model, model_path)
+     print(f"Model saved to: {model_path}")
+
+     # Save encoders and scaler
+     preprocessing_path = f'{MODEL_DIR}/preprocessing.joblib'
+     preprocessing_data = {
+         'encoders': encoders,
+         'scaler': scaler,
+         'feature_names': feature_names
+     }
+     joblib.dump(preprocessing_data, preprocessing_path)
+     print(f"Preprocessing objects saved to: {preprocessing_path}")
+
+     # Save model metadata
+     metadata = {
+         'model_name': best_model_name,
+         'metrics': best_metrics,
+         'feature_names': feature_names,
+         'saved_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+     }
+
+     import json
+     metadata_path = f'{MODEL_DIR}/model_metadata.json'
+     with open(metadata_path, 'w') as f:
+         json.dump(metadata, f, indent=4)
+     print(f"Model metadata saved to: {metadata_path}")
+
+
+ def main():
+     """
+     Main function to orchestrate the training pipeline.
+     """
+     print("\n" + "="*60)
+     print("DEMAND PREDICTION SYSTEM - MODEL TRAINING")
+     print("ML Models vs Time-Series Models Comparison")
+     print("="*60)
+
+     # Step 1: Load data
+     df = load_data(DATA_PATH)
+
+     # Step 2: Preprocess data
+     df_processed = preprocess_data(df)
+
+     # Step 3: Feature engineering for ML models
+     X, y, feature_names, encoders, scaler = feature_engineering(df_processed)
+
+     # Step 4: Split data for ML models (random split)
+     # Note: a random split can leak temporal information into validation;
+     # the time-series models below use a chronological split instead.
+     print("\n" + "="*60)
+     print("SPLITTING DATA FOR ML MODELS")
+     print("="*60)
+     X_train, X_val, y_train, y_val = train_test_split(
+         X, y, test_size=0.2, random_state=42
+     )
+     print(f"Training set: {X_train.shape[0]} samples")
+     print(f"Validation set: {X_val.shape[0]} samples")
+
+     # Step 5: Train ML models
+     print("\n" + "="*70)
+     print("TRAINING MACHINE LEARNING MODELS")
+     print("="*70)
+     results = train_models(X_train, y_train, X_val, y_val)
+
+     # Step 6: Prepare time-series data
+     ts_data, train_size = prepare_time_series_data(df_processed)
+
+     # Step 7: Train time-series models
+     print("\n" + "="*70)
+     print("TRAINING TIME-SERIES MODELS")
+     print("="*70)
+
+     # Train ARIMA
+     if ARIMA_AVAILABLE:
+         arima_results = train_arima(ts_data, train_size)
+         if arima_results:
+             results['ARIMA'] = arima_results
+     else:
+         print("\nARIMA skipped (statsmodels not available)")
+
+     # Train Prophet
+     if PROPHET_AVAILABLE:
+         prophet_results = train_prophet(ts_data, train_size)
+         if prophet_results:
+             results['Prophet'] = prophet_results
+     else:
+         print("\nProphet skipped (prophet not available)")
+
+     # Step 8: Select best model (across all model types)
+     best_model_name, best_model, best_metrics = select_best_model(results)
+
+     # Step 9: Visualize results
+     visualize_results(df_processed, results, best_model_name, feature_names)
+
+     # Step 10: Save model (only ML models can be saved with preprocessing)
+     # For time-series models, save separately
+     if best_model_name not in ['ARIMA', 'Prophet']:
+         save_model(best_model, encoders, scaler, feature_names, best_model_name, best_metrics)
+     else:
+         # Save time-series model separately
+         print("\n" + "="*60)
+         print("SAVING TIME-SERIES MODEL")
+         print("="*60)
+         ts_model_path = f'{MODEL_DIR}/best_timeseries_model.joblib'
+         joblib.dump(best_model, ts_model_path)
+         print(f"Time-series model saved to: {ts_model_path}")
+
+         # Also save preprocessing for ML models (in case user wants to use them)
+         preprocessing_path = f'{MODEL_DIR}/preprocessing.joblib'
+         preprocessing_data = {
+             'encoders': encoders,
+             'scaler': scaler,
+             'feature_names': feature_names
+         }
+         joblib.dump(preprocessing_data, preprocessing_path)
+         print(f"ML preprocessing objects saved to: {preprocessing_path}")
+
+     # Save all results metadata
+     import json
+     all_models_metadata = {
+         'best_model': best_model_name,
+         'best_metrics': best_metrics,
+         'all_models': {}
+     }
+     for model_name, model_results in results.items():
+         all_models_metadata['all_models'][model_name] = {
+             'mae': model_results['mae'],
+             'rmse': model_results['rmse'],
+             'r2': model_results['r2']
+         }
+     all_models_metadata['saved_at'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+     metadata_path = f'{MODEL_DIR}/all_models_metadata.json'
+     with open(metadata_path, 'w') as f:
+         json.dump(all_models_metadata, f, indent=4)
+     print(f"All models metadata saved to: {metadata_path}")
+
+     print("\n" + "="*60)
+     print("TRAINING COMPLETE!")
+     print("="*60)
+     print(f"\nBest model: {best_model_name}")
+     print(f"Model type: {'Time-Series' if best_model_name in ['ARIMA', 'Prophet'] else 'Machine Learning'}")
871
+ print(f"Model saved to: {MODEL_DIR}/")
872
+ print(f"Visualizations saved to: {PLOTS_DIR}/")
873
+ print("\nYou can now use predict.py to make predictions!")
874
+
875
+
876
+ if __name__ == "__main__":
877
+ main()
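
Downstream code does not have to unpickle any model just to find out which one won: the `all_models_metadata.json` file written by `main()` already records the best model name and per-model metrics. The sketch below is not part of the uploaded script; `best_model_from_metadata` is a hypothetical helper, and the sample file only mimics the schema `main()` produces.

```python
import json

def best_model_from_metadata(path):
    """Return (best model name, its metrics) from all_models_metadata.json."""
    with open(path) as f:
        meta = json.load(f)
    return meta['best_model'], meta['all_models'][meta['best_model']]

# Minimal demo: write a file matching the schema main() produces, then read it back.
sample = {
    'best_model': 'ARIMA',
    'best_metrics': {'mae': 1.2, 'rmse': 1.8, 'r2': 0.91},
    'all_models': {'ARIMA': {'mae': 1.2, 'rmse': 1.8, 'r2': 0.91}},
    'saved_at': '2024-01-01 00:00:00',
}
with open('all_models_metadata_demo.json', 'w') as f:
    json.dump(sample, f, indent=4)

name, metrics = best_model_from_metadata('all_models_metadata_demo.json')
print(name, metrics['r2'])  # → ARIMA 0.91
```

If the recorded best model is an ML model, load `best_timeseries_model.joblib` is not needed; instead the saved model plus `preprocessing.joblib` (encoders, scaler, feature names) together reproduce the training-time pipeline.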