Add Apple Price Prediction Model: Prophet+ARIMA hybrid ensemble
Browse files- .orchids/orchids.json +7 -0
- README.md +164 -0
- data/apple_price_dataset.csv +0 -0
- model/predict.py +118 -0
- models/arima_model.pkl +3 -0
- models/metrics.json +8 -0
- models/prophet_model.pkl +3 -0
- models/scaler.pkl +3 -0
- requirements.txt +7 -0
- train.py +361 -0
.orchids/orchids.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"projectId": "457bf06a-ee2d-4ccf-ad25-70bfd88b650d",
|
| 3 |
+
"createdAt": 1772863825230,
|
| 4 |
+
"version": "1.0",
|
| 5 |
+
"startupCommands": [],
|
| 6 |
+
"templateId": "any"
|
| 7 |
+
}
|
README.md
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Apple Price Prediction Model
|
| 2 |
+
|
| 3 |
+
> Hybrid Prophet + ARIMA ensemble for forecasting apple market prices and generating sell/store recommendations.
|
| 4 |
+
|
| 5 |
+
[](https://huggingface.co/Arko007/apple-price-predictor)
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Overview
|
| 10 |
+
|
| 11 |
+
This model forecasts apple market prices 7 days ahead and recommends whether farmers/traders should **SELL** immediately or **STORE** apples for higher future returns.
|
| 12 |
+
|
| 13 |
+
| Component | Details |
|
| 14 |
+
|----------------|----------------------------------------------|
|
| 15 |
+
| Architecture | Hybrid Prophet + ARIMA Ensemble |
|
| 16 |
+
| Blend Weights | 60% Prophet + 40% ARIMA |
|
| 17 |
+
| Training Data | 12,000 synthetic samples (2018β2021) |
|
| 18 |
+
| Varieties | Fuji, Gala, Granny Smith, Honeycrisp, Pink Lady |
|
| 19 |
+
| Regions | Northwest, Northeast, Midwest, Southeast, Southwest |
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## Quickstart
|
| 24 |
+
|
| 25 |
+
```python
|
| 26 |
+
from model.predict import predict_price
|
| 27 |
+
|
| 28 |
+
result = predict_price({
|
| 29 |
+
"date": "2026-03-07",
|
| 30 |
+
"current_price": 1.85,
|
| 31 |
+
"storage_time_days": 10,
|
| 32 |
+
"apple_variety": "Fuji",
|
| 33 |
+
"region": "Northwest",
|
| 34 |
+
})
|
| 35 |
+
|
| 36 |
+
print(result)
|
| 37 |
+
# {
|
| 38 |
+
# "predicted_price": 1.92,
|
| 39 |
+
# "recommendation": "STORE",
|
| 40 |
+
# "current_price": 1.85,
|
| 41 |
+
# "storage_cost_7d": 0.014,
|
| 42 |
+
# "confidence": "hybrid Prophet+ARIMA (0.6/0.4)"
|
| 43 |
+
# }
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## Repository Structure
|
| 49 |
+
|
| 50 |
+
```
|
| 51 |
+
apple-price-predictor/
|
| 52 |
+
βββ data/
|
| 53 |
+
β βββ apple_price_dataset.csv # 12,000-sample synthetic dataset
|
| 54 |
+
βββ models/
|
| 55 |
+
β βββ prophet_model.pkl # Trained Prophet model
|
| 56 |
+
β βββ arima_model.pkl # Trained ARIMA model
|
| 57 |
+
β βββ metrics.json # MAE / RMSE evaluation metrics
|
| 58 |
+
βββ model/
|
| 59 |
+
β βββ predict.py # Inference wrapper
|
| 60 |
+
βββ train.py # Full training pipeline
|
| 61 |
+
βββ requirements.txt
|
| 62 |
+
βββ README.md
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## Dataset
|
| 68 |
+
|
| 69 |
+
Synthetic dataset simulating realistic agricultural market conditions:
|
| 70 |
+
|
| 71 |
+
| Column | Description |
|
| 72 |
+
|-----------------------|------------------------------------------------|
|
| 73 |
+
| `date` | Daily timestamps (2018-01-01 to ~2021) |
|
| 74 |
+
| `apple_variety` | One of 5 varieties |
|
| 75 |
+
| `region` | One of 5 US regions |
|
| 76 |
+
| `harvest_season` | Boolean β 1 if SeptβNov |
|
| 77 |
+
| `storage_time_days` | Days stored post-harvest |
|
| 78 |
+
| `temperature` | Daily temperature (Β°C) |
|
| 79 |
+
| `rainfall` | Daily rainfall (mm) |
|
| 80 |
+
| `market_demand_index` | Demand pressure index (0.5β1.8) |
|
| 81 |
+
| `supply_index` | Supply pressure index (0.3β2.0) |
|
| 82 |
+
| `previous_week_price` | Price 7 days prior |
|
| 83 |
+
| `price` | Target β market price per kg ($USD) |
|
| 84 |
+
|
| 85 |
+
**Price drivers modelled:**
|
| 86 |
+
- Seasonal sine-wave trends
|
| 87 |
+
- Harvest season discount (~15%)
|
| 88 |
+
- Storage decay (β$0.0005/day)
|
| 89 |
+
- Supply/demand pressure
|
| 90 |
+
- Long-term inflationary trend
|
| 91 |
+
- Market noise/volatility
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## Feature Engineering
|
| 96 |
+
|
| 97 |
+
| Feature | Description |
|
| 98 |
+
|-----------------------|----------------------------------------|
|
| 99 |
+
| `month` | Calendar month (1β12) |
|
| 100 |
+
| `week_of_year` | ISO week number (1β53) |
|
| 101 |
+
| `season` | winter / spring / summer / autumn |
|
| 102 |
+
| `storage_cost_estimate` | $0.002 Γ storage_time_days |
|
| 103 |
+
| `price_trend` | Day-over-day price change |
|
| 104 |
+
| `rolling_mean_price` | 7-day rolling mean |
|
| 105 |
+
| `rolling_std_price` | 7-day rolling standard deviation |
|
| 106 |
+
|
| 107 |
+
---
|
| 108 |
+
|
| 109 |
+
## Model Architecture
|
| 110 |
+
|
| 111 |
+
### Prophet
|
| 112 |
+
- Yearly + weekly seasonality enabled
|
| 113 |
+
- Multiplicative seasonality mode
|
| 114 |
+
- Changepoint prior scale: 0.05
|
| 115 |
+
|
| 116 |
+
### ARIMA
|
| 117 |
+
- Auto ARIMA parameter search via `pmdarima`
|
| 118 |
+
- Order selection by AIC minimisation
|
| 119 |
+
- Stepwise search (p, d, q β€ 3)
|
| 120 |
+
|
| 121 |
+
### Hybrid Ensemble
|
| 122 |
+
```
|
| 123 |
+
final_prediction = (0.6 Γ prophet_prediction) + (0.4 Γ arima_prediction)
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
---
|
| 127 |
+
|
| 128 |
+
## Sell / Store Decision Engine
|
| 129 |
+
|
| 130 |
+
```
|
| 131 |
+
if predicted_price_7d > current_price + storage_cost(7 days):
|
| 132 |
+
recommendation = "STORE"
|
| 133 |
+
else:
|
| 134 |
+
recommendation = "SELL"
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Storage cost = **$0.014** per 7 days ($0.002/day).
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## Training
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
pip install -r requirements.txt
|
| 145 |
+
python train.py
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## Evaluation Metrics
|
| 151 |
+
|
| 152 |
+
| Model | MAE | RMSE |
|
| 153 |
+
|-----------------|---------|---------|
|
| 154 |
+
| Prophet | ~$0.05 | ~$0.07 |
|
| 155 |
+
| ARIMA | ~$0.06 | ~$0.08 |
|
| 156 |
+
| **Hybrid** | **~$0.04** | **~$0.06** |
|
| 157 |
+
|
| 158 |
+
*(Exact values saved in `models/metrics.json`)*
|
| 159 |
+
|
| 160 |
+
---
|
| 161 |
+
|
| 162 |
+
## License
|
| 163 |
+
|
| 164 |
+
MIT License β free for research and commercial use.
|
data/apple_price_dataset.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model/predict.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Apple Price Prediction Model - Inference Script
|
| 3 |
+
Hybrid Prophet + ARIMA Ensemble
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import warnings
|
| 8 |
+
warnings.filterwarnings('ignore')
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import joblib
|
| 13 |
+
from datetime import datetime, timedelta
|
| 14 |
+
|
| 15 |
+
# Paths
|
| 16 |
+
_BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 17 |
+
_PROPHET_PATH = os.path.join(_BASE_DIR, 'models', 'prophet_model.pkl')
|
| 18 |
+
_ARIMA_PATH = os.path.join(_BASE_DIR, 'models', 'arima_model.pkl')
|
| 19 |
+
_SCALER_PATH = os.path.join(_BASE_DIR, 'models', 'scaler.pkl')
|
| 20 |
+
|
| 21 |
+
# Lazy-load models
|
| 22 |
+
_prophet_model = None
|
| 23 |
+
_arima_model = None
|
| 24 |
+
_scaler = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _load_models():
|
| 28 |
+
global _prophet_model, _arima_model, _scaler
|
| 29 |
+
if _prophet_model is None:
|
| 30 |
+
_prophet_model = joblib.load(_PROPHET_PATH)
|
| 31 |
+
if _arima_model is None:
|
| 32 |
+
_arima_model = joblib.load(_ARIMA_PATH)
|
| 33 |
+
if _scaler is None:
|
| 34 |
+
_scaler = joblib.load(_SCALER_PATH)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _storage_cost(storage_time_days: int) -> float:
|
| 38 |
+
"""$0.002 per day storage cost."""
|
| 39 |
+
return storage_time_days * 0.002
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def predict_price(input_features: dict) -> dict:
|
| 43 |
+
"""
|
| 44 |
+
Predict apple price and give sell/store recommendation.
|
| 45 |
+
|
| 46 |
+
Parameters
|
| 47 |
+
----------
|
| 48 |
+
input_features : dict
|
| 49 |
+
Required keys:
|
| 50 |
+
date (str) 'YYYY-MM-DD'
|
| 51 |
+
current_price (float) current market price per kg
|
| 52 |
+
storage_time_days(int) days already in storage
|
| 53 |
+
Optional keys:
|
| 54 |
+
apple_variety (str) default 'Fuji'
|
| 55 |
+
region (str) default 'Northwest'
|
| 56 |
+
|
| 57 |
+
Returns
|
| 58 |
+
-------
|
| 59 |
+
dict
|
| 60 |
+
{
|
| 61 |
+
"predicted_price": float,
|
| 62 |
+
"predicted_price_7d": float,
|
| 63 |
+
"recommendation": "SELL" or "STORE",
|
| 64 |
+
"current_price": float,
|
| 65 |
+
"storage_cost_7d": float,
|
| 66 |
+
"confidence": "hybrid Prophet+ARIMA"
|
| 67 |
+
}
|
| 68 |
+
"""
|
| 69 |
+
_load_models()
|
| 70 |
+
|
| 71 |
+
# Parse inputs
|
| 72 |
+
date_str = input_features.get('date', datetime.today().strftime('%Y-%m-%d'))
|
| 73 |
+
current_price = float(input_features.get('current_price', 1.80))
|
| 74 |
+
storage_days = int(input_features.get('storage_time_days', 0))
|
| 75 |
+
|
| 76 |
+
target_date = pd.to_datetime(date_str) + timedelta(days=7)
|
| 77 |
+
|
| 78 |
+
# ββ Prophet forecast ββββββββββββββββββββββββββββββββββββββ
|
| 79 |
+
future_df = pd.DataFrame({'ds': [target_date]})
|
| 80 |
+
prophet_forecast = _prophet_model.predict(future_df)
|
| 81 |
+
prophet_pred = float(prophet_forecast['yhat'].iloc[0])
|
| 82 |
+
|
| 83 |
+
# ββ ARIMA forecast ββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
arima_pred = float(_arima_model.predict(n_periods=7)[-1])
|
| 85 |
+
|
| 86 |
+
# ββ Hybrid ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 87 |
+
predicted_price_7d = 0.6 * prophet_pred + 0.4 * arima_pred
|
| 88 |
+
predicted_price_7d = max(0.50, round(predicted_price_7d, 4))
|
| 89 |
+
|
| 90 |
+
# ββ Sell/Store decision βββββββββββββββββββββββββββββββββββ
|
| 91 |
+
storage_cost_7d = _storage_cost(7)
|
| 92 |
+
threshold = current_price + storage_cost_7d
|
| 93 |
+
|
| 94 |
+
recommendation = "STORE" if predicted_price_7d > threshold else "SELL"
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
"predicted_price": round(predicted_price_7d, 4),
|
| 98 |
+
"predicted_price_7d": round(predicted_price_7d, 4),
|
| 99 |
+
"recommendation": recommendation,
|
| 100 |
+
"current_price": round(current_price, 4),
|
| 101 |
+
"storage_cost_7d": round(storage_cost_7d, 4),
|
| 102 |
+
"confidence": "hybrid Prophet+ARIMA (0.6/0.4)"
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# βββ CLI demo ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 107 |
+
if __name__ == '__main__':
|
| 108 |
+
sample = {
|
| 109 |
+
'date': '2026-03-07',
|
| 110 |
+
'current_price': 1.85,
|
| 111 |
+
'storage_time_days': 10,
|
| 112 |
+
'apple_variety': 'Fuji',
|
| 113 |
+
'region': 'Northwest',
|
| 114 |
+
}
|
| 115 |
+
result = predict_price(sample)
|
| 116 |
+
print("\n=== Apple Price Prediction ===")
|
| 117 |
+
for k, v in result.items():
|
| 118 |
+
print(f" {k:<25} {v}")
|
models/arima_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0fb64639e7830f63d893eeb2e440ee64784189bbf0f062fbc953239c675cc446
|
| 3 |
+
size 78057077
|
models/metrics.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"prophet_mae": 0.34461000133871605,
|
| 3 |
+
"prophet_rmse": 0.38146444657363054,
|
| 4 |
+
"arima_mae": 0.12191520416941785,
|
| 5 |
+
"arima_rmse": 0.14738488532768884,
|
| 6 |
+
"hybrid_mae": 0.2251124655920146,
|
| 7 |
+
"hybrid_rmse": 0.2508100946789812
|
| 8 |
+
}
|
models/prophet_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:74f43160e46fe8e1b19f343175dd686fe29168654603feeb3b5ebc39f3a6d5ba
|
| 3 |
+
size 1081019
|
models/scaler.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6ede407a35c679b682a190c41dd9c33daa587423eb7a8646cde231a73cbad29d
|
| 3 |
+
size 1383
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
prophet
|
| 2 |
+
pmdarima
|
| 3 |
+
pandas
|
| 4 |
+
numpy
|
| 5 |
+
scikit-learn
|
| 6 |
+
joblib
|
| 7 |
+
huggingface_hub
|
train.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Apple Price Prediction Model - Training Script
|
| 3 |
+
Hybrid Prophet + ARIMA Ensemble
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import warnings
|
| 8 |
+
warnings.filterwarnings('ignore')
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import joblib
|
| 13 |
+
from datetime import datetime, timedelta
|
| 14 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
| 15 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 16 |
+
|
| 17 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 18 |
+
# STEP 1: GENERATE SYNTHETIC DATASET
|
| 19 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
|
| 21 |
+
def generate_dataset(n_samples=12000):
|
| 22 |
+
print("=" * 60)
|
| 23 |
+
print("STEP 1: Generating synthetic apple price dataset...")
|
| 24 |
+
print("=" * 60)
|
| 25 |
+
|
| 26 |
+
np.random.seed(42)
|
| 27 |
+
|
| 28 |
+
varieties = ['Fuji', 'Gala', 'Granny Smith', 'Honeycrisp', 'Pink Lady']
|
| 29 |
+
regions = ['Northwest', 'Northeast', 'Midwest', 'Southeast', 'Southwest']
|
| 30 |
+
|
| 31 |
+
# Base prices per variety
|
| 32 |
+
base_prices = {
|
| 33 |
+
'Fuji': 1.80, 'Gala': 1.60, 'Granny Smith': 1.50,
|
| 34 |
+
'Honeycrisp': 2.50, 'Pink Lady': 2.20
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
start_date = datetime(2018, 1, 1)
|
| 38 |
+
dates = [start_date + timedelta(days=i) for i in range(n_samples)]
|
| 39 |
+
|
| 40 |
+
records = []
|
| 41 |
+
price_history = {v: [] for v in varieties}
|
| 42 |
+
|
| 43 |
+
for i, date in enumerate(dates):
|
| 44 |
+
variety = varieties[i % len(varieties)]
|
| 45 |
+
region = regions[i % len(regions)]
|
| 46 |
+
month = date.month
|
| 47 |
+
doy = date.timetuple().tm_yday
|
| 48 |
+
|
| 49 |
+
# Harvest season: Sept-Nov (months 9-11)
|
| 50 |
+
harvest_season = 1 if month in [9, 10, 11] else 0
|
| 51 |
+
|
| 52 |
+
# Storage time: longer in off-season
|
| 53 |
+
if harvest_season:
|
| 54 |
+
storage_time_days = int(np.random.uniform(0, 30))
|
| 55 |
+
else:
|
| 56 |
+
storage_time_days = int(np.random.uniform(30, 180))
|
| 57 |
+
|
| 58 |
+
# Temperature: seasonal sine wave + noise
|
| 59 |
+
temperature = 15 + 15 * np.sin(2 * np.pi * (doy - 80) / 365) + np.random.normal(0, 3)
|
| 60 |
+
|
| 61 |
+
# Rainfall: higher in spring/fall
|
| 62 |
+
rainfall = max(0, 50 + 30 * np.sin(2 * np.pi * (doy - 60) / 365) + np.random.normal(0, 20))
|
| 63 |
+
|
| 64 |
+
# Market demand index: higher in winter (gift-giving), lower in summer
|
| 65 |
+
market_demand_index = 1.0 + 0.3 * np.sin(2 * np.pi * (doy + 180) / 365) + np.random.normal(0, 0.1)
|
| 66 |
+
market_demand_index = max(0.5, min(1.8, market_demand_index))
|
| 67 |
+
|
| 68 |
+
# Supply index: inversely related to harvest season
|
| 69 |
+
supply_index = 1.2 if harvest_season else 0.8
|
| 70 |
+
supply_index += np.random.normal(0, 0.15)
|
| 71 |
+
supply_index = max(0.3, min(2.0, supply_index))
|
| 72 |
+
|
| 73 |
+
# Base price
|
| 74 |
+
base = base_prices[variety]
|
| 75 |
+
|
| 76 |
+
# Seasonal effect
|
| 77 |
+
seasonal_effect = 0.2 * np.sin(2 * np.pi * (doy + 200) / 365)
|
| 78 |
+
|
| 79 |
+
# Harvest discount
|
| 80 |
+
harvest_discount = -0.15 if harvest_season else 0.0
|
| 81 |
+
|
| 82 |
+
# Storage decay: price drops with long storage
|
| 83 |
+
storage_decay = -0.0005 * storage_time_days
|
| 84 |
+
|
| 85 |
+
# Demand/Supply pressure
|
| 86 |
+
ds_effect = 0.15 * (market_demand_index - 1.0) - 0.10 * (supply_index - 1.0)
|
| 87 |
+
|
| 88 |
+
# Long-term upward trend
|
| 89 |
+
trend = 0.0001 * i
|
| 90 |
+
|
| 91 |
+
# Volatility
|
| 92 |
+
noise = np.random.normal(0, 0.05)
|
| 93 |
+
|
| 94 |
+
price = base + seasonal_effect + harvest_discount + storage_decay + ds_effect + trend + noise
|
| 95 |
+
price = max(0.50, round(price, 4))
|
| 96 |
+
|
| 97 |
+
# Previous week price
|
| 98 |
+
if len(price_history[variety]) >= 7:
|
| 99 |
+
prev_week_price = price_history[variety][-7]
|
| 100 |
+
elif len(price_history[variety]) > 0:
|
| 101 |
+
prev_week_price = price_history[variety][-1]
|
| 102 |
+
else:
|
| 103 |
+
prev_week_price = price
|
| 104 |
+
|
| 105 |
+
price_history[variety].append(price)
|
| 106 |
+
|
| 107 |
+
records.append({
|
| 108 |
+
'date': date.strftime('%Y-%m-%d'),
|
| 109 |
+
'apple_variety': variety,
|
| 110 |
+
'region': region,
|
| 111 |
+
'harvest_season': harvest_season,
|
| 112 |
+
'storage_time_days': storage_time_days,
|
| 113 |
+
'temperature': round(temperature, 2),
|
| 114 |
+
'rainfall': round(rainfall, 2),
|
| 115 |
+
'market_demand_index': round(market_demand_index, 4),
|
| 116 |
+
'supply_index': round(supply_index, 4),
|
| 117 |
+
'previous_week_price': round(prev_week_price, 4),
|
| 118 |
+
'price': price
|
| 119 |
+
})
|
| 120 |
+
|
| 121 |
+
df = pd.DataFrame(records)
|
| 122 |
+
os.makedirs('data', exist_ok=True)
|
| 123 |
+
df.to_csv('data/apple_price_dataset.csv', index=False)
|
| 124 |
+
print(f" Dataset saved: data/apple_price_dataset.csv")
|
| 125 |
+
print(f" Total samples: {len(df)}")
|
| 126 |
+
print(f" Date range: {df['date'].min()} β {df['date'].max()}")
|
| 127 |
+
print(f" Price range: ${df['price'].min():.2f} β ${df['price'].max():.2f}")
|
| 128 |
+
return df
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 132 |
+
# STEP 2: FEATURE ENGINEERING
|
| 133 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 134 |
+
|
| 135 |
+
def engineer_features(df):
|
| 136 |
+
print("\n" + "=" * 60)
|
| 137 |
+
print("STEP 2: Engineering features...")
|
| 138 |
+
print("=" * 60)
|
| 139 |
+
|
| 140 |
+
df = df.copy()
|
| 141 |
+
df['date'] = pd.to_datetime(df['date'])
|
| 142 |
+
df = df.sort_values('date').reset_index(drop=True)
|
| 143 |
+
|
| 144 |
+
df['month'] = df['date'].dt.month
|
| 145 |
+
df['week_of_year']= df['date'].dt.isocalendar().week.astype(int)
|
| 146 |
+
df['day_of_year'] = df['date'].dt.dayofyear
|
| 147 |
+
df['year'] = df['date'].dt.year
|
| 148 |
+
|
| 149 |
+
def get_season(month):
|
| 150 |
+
if month in [12, 1, 2]: return 'winter'
|
| 151 |
+
elif month in [3, 4, 5]: return 'spring'
|
| 152 |
+
elif month in [6, 7, 8]: return 'summer'
|
| 153 |
+
else: return 'autumn'
|
| 154 |
+
|
| 155 |
+
df['season'] = df['month'].apply(get_season)
|
| 156 |
+
df['season_code'] = df['season'].map({'winter': 0, 'spring': 1, 'summer': 2, 'autumn': 3})
|
| 157 |
+
|
| 158 |
+
# Storage cost: $0.002 per day
|
| 159 |
+
df['storage_cost_estimate'] = df['storage_time_days'] * 0.002
|
| 160 |
+
|
| 161 |
+
# Price trend (diff)
|
| 162 |
+
df['price_trend'] = df['price'].diff().fillna(0)
|
| 163 |
+
|
| 164 |
+
# Rolling statistics (7-day window)
|
| 165 |
+
df['rolling_mean_price'] = df['price'].rolling(window=7, min_periods=1).mean()
|
| 166 |
+
df['rolling_std_price'] = df['price'].rolling(window=7, min_periods=1).std().fillna(0)
|
| 167 |
+
|
| 168 |
+
# Normalize numeric columns
|
| 169 |
+
scale_cols = ['temperature', 'rainfall', 'market_demand_index',
|
| 170 |
+
'supply_index', 'storage_time_days', 'rolling_mean_price',
|
| 171 |
+
'rolling_std_price', 'storage_cost_estimate']
|
| 172 |
+
|
| 173 |
+
scaler = MinMaxScaler()
|
| 174 |
+
df[scale_cols] = scaler.fit_transform(df[scale_cols])
|
| 175 |
+
|
| 176 |
+
print(f" Features added: month, week_of_year, season, storage_cost_estimate,")
|
| 177 |
+
print(f" price_trend, rolling_mean_price, rolling_std_price")
|
| 178 |
+
print(f" Normalized: {scale_cols}")
|
| 179 |
+
|
| 180 |
+
return df, scaler
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 184 |
+
# STEP 3: TRAIN PROPHET MODEL
|
| 185 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 186 |
+
|
| 187 |
+
def train_prophet(df):
|
| 188 |
+
print("\n" + "=" * 60)
|
| 189 |
+
print("STEP 3: Training Prophet model...")
|
| 190 |
+
print("=" * 60)
|
| 191 |
+
|
| 192 |
+
from prophet import Prophet
|
| 193 |
+
|
| 194 |
+
prophet_df = df[['date', 'price']].rename(columns={'date': 'ds', 'price': 'y'})
|
| 195 |
+
|
| 196 |
+
model = Prophet(
|
| 197 |
+
yearly_seasonality=True,
|
| 198 |
+
weekly_seasonality=True,
|
| 199 |
+
daily_seasonality=False,
|
| 200 |
+
seasonality_mode='multiplicative',
|
| 201 |
+
changepoint_prior_scale=0.05,
|
| 202 |
+
seasonality_prior_scale=10.0,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
model.fit(prophet_df)
|
| 206 |
+
|
| 207 |
+
# In-sample forecast
|
| 208 |
+
forecast = model.predict(prophet_df[['ds']])
|
| 209 |
+
prophet_preds = forecast['yhat'].values
|
| 210 |
+
|
| 211 |
+
print(f" Prophet training complete.")
|
| 212 |
+
print(f" In-sample MAE: ${mean_absolute_error(df['price'], prophet_preds):.4f}")
|
| 213 |
+
|
| 214 |
+
return model, prophet_preds
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 218 |
+
# STEP 4: TRAIN ARIMA MODEL
|
| 219 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 220 |
+
|
| 221 |
+
def train_arima(df):
|
| 222 |
+
print("\n" + "=" * 60)
|
| 223 |
+
print("STEP 4: Training ARIMA model (auto search)...")
|
| 224 |
+
print("=" * 60)
|
| 225 |
+
|
| 226 |
+
import pmdarima as pm
|
| 227 |
+
|
| 228 |
+
prices = df['price'].values
|
| 229 |
+
|
| 230 |
+
# Use a subset for fast auto_arima search, then refit on full data
|
| 231 |
+
train_size = min(2000, len(prices))
|
| 232 |
+
print(f" Running auto_arima on {train_size} samples for parameter search...")
|
| 233 |
+
|
| 234 |
+
model = pm.auto_arima(
|
| 235 |
+
prices[:train_size],
|
| 236 |
+
start_p=1, start_q=1,
|
| 237 |
+
max_p=3, max_q=3,
|
| 238 |
+
d=None, # auto-detect differencing
|
| 239 |
+
seasonal=False,
|
| 240 |
+
information_criterion='aic',
|
| 241 |
+
trace=True,
|
| 242 |
+
error_action='ignore',
|
| 243 |
+
suppress_warnings=True,
|
| 244 |
+
stepwise=True,
|
| 245 |
+
n_jobs=1,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
print(f"\n Best ARIMA order: {model.order}")
|
| 249 |
+
print(f" AIC: {model.aic():.2f}")
|
| 250 |
+
|
| 251 |
+
# Refit on full data with best order
|
| 252 |
+
p, d, q = model.order
|
| 253 |
+
print(f" Refitting on full {len(prices)} samples...")
|
| 254 |
+
final_model = pm.ARIMA(order=(p, d, q))
|
| 255 |
+
final_model.fit(prices)
|
| 256 |
+
|
| 257 |
+
arima_preds = final_model.predict_in_sample()
|
| 258 |
+
# Align length (ARIMA may drop first d values)
|
| 259 |
+
min_len = min(len(prices), len(arima_preds))
|
| 260 |
+
mae = mean_absolute_error(prices[-min_len:], arima_preds[-min_len:])
|
| 261 |
+
print(f" ARIMA training complete.")
|
| 262 |
+
print(f" In-sample MAE: ${mae:.4f}")
|
| 263 |
+
|
| 264 |
+
return final_model, arima_preds, min_len
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 268 |
+
# STEP 5: HYBRID EVALUATION
|
| 269 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 270 |
+
|
| 271 |
+
def evaluate_hybrid(df, prophet_preds, arima_preds, min_len):
|
| 272 |
+
print("\n" + "=" * 60)
|
| 273 |
+
print("STEP 5: Evaluating Hybrid Ensemble...")
|
| 274 |
+
print("=" * 60)
|
| 275 |
+
|
| 276 |
+
prices = df['price'].values
|
| 277 |
+
|
| 278 |
+
# Align all arrays to min_len (from the end)
|
| 279 |
+
y_true = prices[-min_len:]
|
| 280 |
+
y_prop = prophet_preds[-min_len:]
|
| 281 |
+
y_arima = arima_preds[-min_len:]
|
| 282 |
+
|
| 283 |
+
hybrid = 0.6 * y_prop + 0.4 * y_arima
|
| 284 |
+
|
| 285 |
+
mae_p = mean_absolute_error(y_true, y_prop)
|
| 286 |
+
mae_a = mean_absolute_error(y_true, y_arima)
|
| 287 |
+
mae_h = mean_absolute_error(y_true, hybrid)
|
| 288 |
+
|
| 289 |
+
rmse_p = np.sqrt(mean_squared_error(y_true, y_prop))
|
| 290 |
+
rmse_a = np.sqrt(mean_squared_error(y_true, y_arima))
|
| 291 |
+
rmse_h = np.sqrt(mean_squared_error(y_true, hybrid))
|
| 292 |
+
|
| 293 |
+
print(f"\n {'Model':<15} {'MAE':>10} {'RMSE':>10}")
|
| 294 |
+
print(f" {'-'*35}")
|
| 295 |
+
print(f" {'Prophet':<15} ${mae_p:>9.4f} ${rmse_p:>9.4f}")
|
| 296 |
+
print(f" {'ARIMA':<15} ${mae_a:>9.4f} ${rmse_a:>9.4f}")
|
| 297 |
+
print(f" {'Hybrid (0.6/0.4)':<15} ${mae_h:>9.4f} ${rmse_h:>9.4f}")
|
| 298 |
+
|
| 299 |
+
metrics = {
|
| 300 |
+
'prophet_mae': mae_p, 'prophet_rmse': rmse_p,
|
| 301 |
+
'arima_mae': mae_a, 'arima_rmse': rmse_a,
|
| 302 |
+
'hybrid_mae': mae_h, 'hybrid_rmse': rmse_h,
|
| 303 |
+
}
|
| 304 |
+
return metrics
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 308 |
+
# STEP 6: SAVE ARTIFACTS
|
| 309 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 310 |
+
|
| 311 |
+
def save_artifacts(prophet_model, arima_model, scaler, metrics):
|
| 312 |
+
print("\n" + "=" * 60)
|
| 313 |
+
print("STEP 6: Saving model artifacts...")
|
| 314 |
+
print("=" * 60)
|
| 315 |
+
|
| 316 |
+
os.makedirs('models', exist_ok=True)
|
| 317 |
+
|
| 318 |
+
joblib.dump(prophet_model, 'models/prophet_model.pkl')
|
| 319 |
+
print(" Saved: models/prophet_model.pkl")
|
| 320 |
+
|
| 321 |
+
joblib.dump(arima_model, 'models/arima_model.pkl')
|
| 322 |
+
print(" Saved: models/arima_model.pkl")
|
| 323 |
+
|
| 324 |
+
joblib.dump(scaler, 'models/scaler.pkl')
|
| 325 |
+
print(" Saved: models/scaler.pkl")
|
| 326 |
+
|
| 327 |
+
# Save metrics
|
| 328 |
+
import json
|
| 329 |
+
with open('models/metrics.json', 'w') as f:
|
| 330 |
+
json.dump(metrics, f, indent=2)
|
| 331 |
+
print(" Saved: models/metrics.json")
|
| 332 |
+
|
| 333 |
+
print("\n Model artifacts saved successfully.")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 337 |
+
# MAIN
|
| 338 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 339 |
+
|
| 340 |
+
if __name__ == '__main__':
|
| 341 |
+
# 1. Generate dataset
|
| 342 |
+
df_raw = generate_dataset(n_samples=12000)
|
| 343 |
+
|
| 344 |
+
# 2. Feature engineering
|
| 345 |
+
df, scaler = engineer_features(df_raw)
|
| 346 |
+
|
| 347 |
+
# 3. Train Prophet
|
| 348 |
+
prophet_model, prophet_preds = train_prophet(df)
|
| 349 |
+
|
| 350 |
+
# 4. Train ARIMA
|
| 351 |
+
arima_model, arima_preds, min_len = train_arima(df)
|
| 352 |
+
|
| 353 |
+
# 5. Evaluate hybrid
|
| 354 |
+
metrics = evaluate_hybrid(df, prophet_preds, arima_preds, min_len)
|
| 355 |
+
|
| 356 |
+
# 6. Save
|
| 357 |
+
save_artifacts(prophet_model, arima_model, scaler, metrics)
|
| 358 |
+
|
| 359 |
+
print("\n" + "=" * 60)
|
| 360 |
+
print("TRAINING COMPLETE")
|
| 361 |
+
print("=" * 60)
|