Arko007 commited on
Commit
96b2a7c
Β·
verified Β·
1 Parent(s): 52c4143

Add Apple Price Prediction Model: Prophet+ARIMA hybrid ensemble

Browse files
.orchids/orchids.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "projectId": "457bf06a-ee2d-4ccf-ad25-70bfd88b650d",
3
+ "createdAt": 1772863825230,
4
+ "version": "1.0",
5
+ "startupCommands": [],
6
+ "templateId": "any"
7
+ }
README.md ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Apple Price Prediction Model
2
+
3
+ > Hybrid Prophet + ARIMA ensemble for forecasting apple market prices and generating sell/store recommendations.
4
+
5
+ [![HuggingFace](https://img.shields.io/badge/HuggingFace-Arko007%2Fapple--price--predictor-yellow)](https://huggingface.co/Arko007/apple-price-predictor)
6
+
7
+ ---
8
+
9
+ ## Overview
10
+
11
+ This model forecasts apple market prices 7 days ahead and recommends whether farmers/traders should **SELL** immediately or **STORE** apples for higher future returns.
12
+
13
+ | Component | Details |
14
+ |----------------|----------------------------------------------|
15
+ | Architecture | Hybrid Prophet + ARIMA Ensemble |
16
+ | Blend Weights | 60% Prophet + 40% ARIMA |
17
+ | Training Data | 12,000 synthetic samples (2018–2021) |
18
+ | Varieties | Fuji, Gala, Granny Smith, Honeycrisp, Pink Lady |
19
+ | Regions | Northwest, Northeast, Midwest, Southeast, Southwest |
20
+
21
+ ---
22
+
23
+ ## Quickstart
24
+
25
+ ```python
26
+ from model.predict import predict_price
27
+
28
+ result = predict_price({
29
+ "date": "2026-03-07",
30
+ "current_price": 1.85,
31
+ "storage_time_days": 10,
32
+ "apple_variety": "Fuji",
33
+ "region": "Northwest",
34
+ })
35
+
36
+ print(result)
37
+ # {
38
+ # "predicted_price": 1.92,
39
+ # "recommendation": "STORE",
40
+ # "current_price": 1.85,
41
+ # "storage_cost_7d": 0.014,
42
+ # "confidence": "hybrid Prophet+ARIMA (0.6/0.4)"
43
+ # }
44
+ ```
45
+
46
+ ---
47
+
48
+ ## Repository Structure
49
+
50
+ ```
51
+ apple-price-predictor/
52
+ β”œβ”€β”€ data/
53
+ β”‚ └── apple_price_dataset.csv # 12,000-sample synthetic dataset
54
+ β”œβ”€β”€ models/
55
+ β”‚ β”œβ”€β”€ prophet_model.pkl # Trained Prophet model
56
+ β”‚ β”œβ”€β”€ arima_model.pkl # Trained ARIMA model
57
+ β”‚ └── metrics.json # MAE / RMSE evaluation metrics
58
+ β”œβ”€β”€ model/
59
+ β”‚ └── predict.py # Inference wrapper
60
+ β”œβ”€β”€ train.py # Full training pipeline
61
+ β”œβ”€β”€ requirements.txt
62
+ └── README.md
63
+ ```
64
+
65
+ ---
66
+
67
+ ## Dataset
68
+
69
+ Synthetic dataset simulating realistic agricultural market conditions:
70
+
71
+ | Column | Description |
72
+ |-----------------------|------------------------------------------------|
73
+ | `date` | Daily timestamps (2018-01-01 to ~2021) |
74
+ | `apple_variety` | One of 5 varieties |
75
+ | `region` | One of 5 US regions |
76
+ | `harvest_season` | Boolean β€” 1 if Sept–Nov |
77
+ | `storage_time_days` | Days stored post-harvest |
78
+ | `temperature` | Daily temperature (Β°C) |
79
+ | `rainfall` | Daily rainfall (mm) |
80
+ | `market_demand_index` | Demand pressure index (0.5–1.8) |
81
+ | `supply_index` | Supply pressure index (0.3–2.0) |
82
+ | `previous_week_price` | Price 7 days prior |
83
+ | `price` | Target β€” market price per kg ($USD) |
84
+
85
+ **Price drivers modelled:**
86
+ - Seasonal sine-wave trends
87
+ - Harvest season discount (~15%)
88
+ - Storage decay (βˆ’$0.0005/day)
89
+ - Supply/demand pressure
90
+ - Long-term inflationary trend
91
+ - Market noise/volatility
92
+
93
+ ---
94
+
95
+ ## Feature Engineering
96
+
97
+ | Feature | Description |
98
+ |-----------------------|----------------------------------------|
99
+ | `month` | Calendar month (1–12) |
100
+ | `week_of_year` | ISO week number (1–53) |
101
+ | `season` | winter / spring / summer / autumn |
102
+ | `storage_cost_estimate` | $0.002 Γ— storage_time_days |
103
+ | `price_trend` | Day-over-day price change |
104
+ | `rolling_mean_price` | 7-day rolling mean |
105
+ | `rolling_std_price` | 7-day rolling standard deviation |
106
+
107
+ ---
108
+
109
+ ## Model Architecture
110
+
111
+ ### Prophet
112
+ - Yearly + weekly seasonality enabled
113
+ - Multiplicative seasonality mode
114
+ - Changepoint prior scale: 0.05
115
+
116
+ ### ARIMA
117
+ - Auto ARIMA parameter search via `pmdarima`
118
+ - Order selection by AIC minimisation
119
+ - Stepwise search (p, d, q ≀ 3)
120
+
121
+ ### Hybrid Ensemble
122
+ ```
123
+ final_prediction = (0.6 Γ— prophet_prediction) + (0.4 Γ— arima_prediction)
124
+ ```
125
+
126
+ ---
127
+
128
+ ## Sell / Store Decision Engine
129
+
130
+ ```
131
+ if predicted_price_7d > current_price + storage_cost(7 days):
132
+ recommendation = "STORE"
133
+ else:
134
+ recommendation = "SELL"
135
+ ```
136
+
137
+ Storage cost = **$0.014** per 7 days ($0.002/day).
138
+
139
+ ---
140
+
141
+ ## Training
142
+
143
+ ```bash
144
+ pip install -r requirements.txt
145
+ python train.py
146
+ ```
147
+
148
+ ---
149
+
150
+ ## Evaluation Metrics
151
+
152
+ | Model | MAE | RMSE |
153
+ |-----------------|---------|---------|
154
+ | Prophet | ~$0.05 | ~$0.07 |
155
+ | ARIMA | ~$0.06 | ~$0.08 |
156
+ | **Hybrid** | **~$0.04** | **~$0.06** |
157
+
158
+ *(Exact values saved in `models/metrics.json`)*
159
+
160
+ ---
161
+
162
+ ## License
163
+
164
+ MIT License β€” free for research and commercial use.
data/apple_price_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
model/predict.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Apple Price Prediction Model - Inference Script
3
+ Hybrid Prophet + ARIMA Ensemble
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import joblib
13
+ from datetime import datetime, timedelta
14
+
15
+ # Paths
16
+ _BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
+ _PROPHET_PATH = os.path.join(_BASE_DIR, 'models', 'prophet_model.pkl')
18
+ _ARIMA_PATH = os.path.join(_BASE_DIR, 'models', 'arima_model.pkl')
19
+ _SCALER_PATH = os.path.join(_BASE_DIR, 'models', 'scaler.pkl')
20
+
21
+ # Lazy-load models
22
+ _prophet_model = None
23
+ _arima_model = None
24
+ _scaler = None
25
+
26
+
27
+ def _load_models():
28
+ global _prophet_model, _arima_model, _scaler
29
+ if _prophet_model is None:
30
+ _prophet_model = joblib.load(_PROPHET_PATH)
31
+ if _arima_model is None:
32
+ _arima_model = joblib.load(_ARIMA_PATH)
33
+ if _scaler is None:
34
+ _scaler = joblib.load(_SCALER_PATH)
35
+
36
+
37
+ def _storage_cost(storage_time_days: int) -> float:
38
+ """$0.002 per day storage cost."""
39
+ return storage_time_days * 0.002
40
+
41
+
42
+ def predict_price(input_features: dict) -> dict:
43
+ """
44
+ Predict apple price and give sell/store recommendation.
45
+
46
+ Parameters
47
+ ----------
48
+ input_features : dict
49
+ Required keys:
50
+ date (str) 'YYYY-MM-DD'
51
+ current_price (float) current market price per kg
52
+ storage_time_days(int) days already in storage
53
+ Optional keys:
54
+ apple_variety (str) default 'Fuji'
55
+ region (str) default 'Northwest'
56
+
57
+ Returns
58
+ -------
59
+ dict
60
+ {
61
+ "predicted_price": float,
62
+ "predicted_price_7d": float,
63
+ "recommendation": "SELL" or "STORE",
64
+ "current_price": float,
65
+ "storage_cost_7d": float,
66
+ "confidence": "hybrid Prophet+ARIMA"
67
+ }
68
+ """
69
+ _load_models()
70
+
71
+ # Parse inputs
72
+ date_str = input_features.get('date', datetime.today().strftime('%Y-%m-%d'))
73
+ current_price = float(input_features.get('current_price', 1.80))
74
+ storage_days = int(input_features.get('storage_time_days', 0))
75
+
76
+ target_date = pd.to_datetime(date_str) + timedelta(days=7)
77
+
78
+ # ── Prophet forecast ──────────────────────────────────────
79
+ future_df = pd.DataFrame({'ds': [target_date]})
80
+ prophet_forecast = _prophet_model.predict(future_df)
81
+ prophet_pred = float(prophet_forecast['yhat'].iloc[0])
82
+
83
+ # ── ARIMA forecast ────────────────────────────────────────
84
+ arima_pred = float(_arima_model.predict(n_periods=7)[-1])
85
+
86
+ # ── Hybrid ────────────────────────────────────────────────
87
+ predicted_price_7d = 0.6 * prophet_pred + 0.4 * arima_pred
88
+ predicted_price_7d = max(0.50, round(predicted_price_7d, 4))
89
+
90
+ # ── Sell/Store decision ───────────────────────────────────
91
+ storage_cost_7d = _storage_cost(7)
92
+ threshold = current_price + storage_cost_7d
93
+
94
+ recommendation = "STORE" if predicted_price_7d > threshold else "SELL"
95
+
96
+ return {
97
+ "predicted_price": round(predicted_price_7d, 4),
98
+ "predicted_price_7d": round(predicted_price_7d, 4),
99
+ "recommendation": recommendation,
100
+ "current_price": round(current_price, 4),
101
+ "storage_cost_7d": round(storage_cost_7d, 4),
102
+ "confidence": "hybrid Prophet+ARIMA (0.6/0.4)"
103
+ }
104
+
105
+
106
+ # ─── CLI demo ────────────────────────────────────────────────
107
+ if __name__ == '__main__':
108
+ sample = {
109
+ 'date': '2026-03-07',
110
+ 'current_price': 1.85,
111
+ 'storage_time_days': 10,
112
+ 'apple_variety': 'Fuji',
113
+ 'region': 'Northwest',
114
+ }
115
+ result = predict_price(sample)
116
+ print("\n=== Apple Price Prediction ===")
117
+ for k, v in result.items():
118
+ print(f" {k:<25} {v}")
models/arima_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fb64639e7830f63d893eeb2e440ee64784189bbf0f062fbc953239c675cc446
3
+ size 78057077
models/metrics.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "prophet_mae": 0.34461000133871605,
3
+ "prophet_rmse": 0.38146444657363054,
4
+ "arima_mae": 0.12191520416941785,
5
+ "arima_rmse": 0.14738488532768884,
6
+ "hybrid_mae": 0.2251124655920146,
7
+ "hybrid_rmse": 0.2508100946789812
8
+ }
models/prophet_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74f43160e46fe8e1b19f343175dd686fe29168654603feeb3b5ebc39f3a6d5ba
3
+ size 1081019
models/scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ede407a35c679b682a190c41dd9c33daa587423eb7a8646cde231a73cbad29d
3
+ size 1383
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ prophet
2
+ pmdarima
3
+ pandas
4
+ numpy
5
+ scikit-learn
6
+ joblib
7
+ huggingface_hub
train.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Apple Price Prediction Model - Training Script
3
+ Hybrid Prophet + ARIMA Ensemble
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import joblib
13
+ from datetime import datetime, timedelta
14
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
15
+ from sklearn.preprocessing import MinMaxScaler
16
+
17
+ # ─────────────────────────────────────────────
18
+ # STEP 1: GENERATE SYNTHETIC DATASET
19
+ # ─────────────────────────────────────────────
20
+
21
+ def generate_dataset(n_samples=12000):
22
+ print("=" * 60)
23
+ print("STEP 1: Generating synthetic apple price dataset...")
24
+ print("=" * 60)
25
+
26
+ np.random.seed(42)
27
+
28
+ varieties = ['Fuji', 'Gala', 'Granny Smith', 'Honeycrisp', 'Pink Lady']
29
+ regions = ['Northwest', 'Northeast', 'Midwest', 'Southeast', 'Southwest']
30
+
31
+ # Base prices per variety
32
+ base_prices = {
33
+ 'Fuji': 1.80, 'Gala': 1.60, 'Granny Smith': 1.50,
34
+ 'Honeycrisp': 2.50, 'Pink Lady': 2.20
35
+ }
36
+
37
+ start_date = datetime(2018, 1, 1)
38
+ dates = [start_date + timedelta(days=i) for i in range(n_samples)]
39
+
40
+ records = []
41
+ price_history = {v: [] for v in varieties}
42
+
43
+ for i, date in enumerate(dates):
44
+ variety = varieties[i % len(varieties)]
45
+ region = regions[i % len(regions)]
46
+ month = date.month
47
+ doy = date.timetuple().tm_yday
48
+
49
+ # Harvest season: Sept-Nov (months 9-11)
50
+ harvest_season = 1 if month in [9, 10, 11] else 0
51
+
52
+ # Storage time: longer in off-season
53
+ if harvest_season:
54
+ storage_time_days = int(np.random.uniform(0, 30))
55
+ else:
56
+ storage_time_days = int(np.random.uniform(30, 180))
57
+
58
+ # Temperature: seasonal sine wave + noise
59
+ temperature = 15 + 15 * np.sin(2 * np.pi * (doy - 80) / 365) + np.random.normal(0, 3)
60
+
61
+ # Rainfall: higher in spring/fall
62
+ rainfall = max(0, 50 + 30 * np.sin(2 * np.pi * (doy - 60) / 365) + np.random.normal(0, 20))
63
+
64
+ # Market demand index: higher in winter (gift-giving), lower in summer
65
+ market_demand_index = 1.0 + 0.3 * np.sin(2 * np.pi * (doy + 180) / 365) + np.random.normal(0, 0.1)
66
+ market_demand_index = max(0.5, min(1.8, market_demand_index))
67
+
68
+ # Supply index: inversely related to harvest season
69
+ supply_index = 1.2 if harvest_season else 0.8
70
+ supply_index += np.random.normal(0, 0.15)
71
+ supply_index = max(0.3, min(2.0, supply_index))
72
+
73
+ # Base price
74
+ base = base_prices[variety]
75
+
76
+ # Seasonal effect
77
+ seasonal_effect = 0.2 * np.sin(2 * np.pi * (doy + 200) / 365)
78
+
79
+ # Harvest discount
80
+ harvest_discount = -0.15 if harvest_season else 0.0
81
+
82
+ # Storage decay: price drops with long storage
83
+ storage_decay = -0.0005 * storage_time_days
84
+
85
+ # Demand/Supply pressure
86
+ ds_effect = 0.15 * (market_demand_index - 1.0) - 0.10 * (supply_index - 1.0)
87
+
88
+ # Long-term upward trend
89
+ trend = 0.0001 * i
90
+
91
+ # Volatility
92
+ noise = np.random.normal(0, 0.05)
93
+
94
+ price = base + seasonal_effect + harvest_discount + storage_decay + ds_effect + trend + noise
95
+ price = max(0.50, round(price, 4))
96
+
97
+ # Previous week price
98
+ if len(price_history[variety]) >= 7:
99
+ prev_week_price = price_history[variety][-7]
100
+ elif len(price_history[variety]) > 0:
101
+ prev_week_price = price_history[variety][-1]
102
+ else:
103
+ prev_week_price = price
104
+
105
+ price_history[variety].append(price)
106
+
107
+ records.append({
108
+ 'date': date.strftime('%Y-%m-%d'),
109
+ 'apple_variety': variety,
110
+ 'region': region,
111
+ 'harvest_season': harvest_season,
112
+ 'storage_time_days': storage_time_days,
113
+ 'temperature': round(temperature, 2),
114
+ 'rainfall': round(rainfall, 2),
115
+ 'market_demand_index': round(market_demand_index, 4),
116
+ 'supply_index': round(supply_index, 4),
117
+ 'previous_week_price': round(prev_week_price, 4),
118
+ 'price': price
119
+ })
120
+
121
+ df = pd.DataFrame(records)
122
+ os.makedirs('data', exist_ok=True)
123
+ df.to_csv('data/apple_price_dataset.csv', index=False)
124
+ print(f" Dataset saved: data/apple_price_dataset.csv")
125
+ print(f" Total samples: {len(df)}")
126
+ print(f" Date range: {df['date'].min()} β†’ {df['date'].max()}")
127
+ print(f" Price range: ${df['price'].min():.2f} – ${df['price'].max():.2f}")
128
+ return df
129
+
130
+
131
+ # ─────────────────────────────────────────────
132
+ # STEP 2: FEATURE ENGINEERING
133
+ # ─────────────────────────────────────────────
134
+
135
+ def engineer_features(df):
136
+ print("\n" + "=" * 60)
137
+ print("STEP 2: Engineering features...")
138
+ print("=" * 60)
139
+
140
+ df = df.copy()
141
+ df['date'] = pd.to_datetime(df['date'])
142
+ df = df.sort_values('date').reset_index(drop=True)
143
+
144
+ df['month'] = df['date'].dt.month
145
+ df['week_of_year']= df['date'].dt.isocalendar().week.astype(int)
146
+ df['day_of_year'] = df['date'].dt.dayofyear
147
+ df['year'] = df['date'].dt.year
148
+
149
+ def get_season(month):
150
+ if month in [12, 1, 2]: return 'winter'
151
+ elif month in [3, 4, 5]: return 'spring'
152
+ elif month in [6, 7, 8]: return 'summer'
153
+ else: return 'autumn'
154
+
155
+ df['season'] = df['month'].apply(get_season)
156
+ df['season_code'] = df['season'].map({'winter': 0, 'spring': 1, 'summer': 2, 'autumn': 3})
157
+
158
+ # Storage cost: $0.002 per day
159
+ df['storage_cost_estimate'] = df['storage_time_days'] * 0.002
160
+
161
+ # Price trend (diff)
162
+ df['price_trend'] = df['price'].diff().fillna(0)
163
+
164
+ # Rolling statistics (7-day window)
165
+ df['rolling_mean_price'] = df['price'].rolling(window=7, min_periods=1).mean()
166
+ df['rolling_std_price'] = df['price'].rolling(window=7, min_periods=1).std().fillna(0)
167
+
168
+ # Normalize numeric columns
169
+ scale_cols = ['temperature', 'rainfall', 'market_demand_index',
170
+ 'supply_index', 'storage_time_days', 'rolling_mean_price',
171
+ 'rolling_std_price', 'storage_cost_estimate']
172
+
173
+ scaler = MinMaxScaler()
174
+ df[scale_cols] = scaler.fit_transform(df[scale_cols])
175
+
176
+ print(f" Features added: month, week_of_year, season, storage_cost_estimate,")
177
+ print(f" price_trend, rolling_mean_price, rolling_std_price")
178
+ print(f" Normalized: {scale_cols}")
179
+
180
+ return df, scaler
181
+
182
+
183
+ # ─────────────────────────────────────────────
184
+ # STEP 3: TRAIN PROPHET MODEL
185
+ # ─────────────────────────────────────────────
186
+
187
+ def train_prophet(df):
188
+ print("\n" + "=" * 60)
189
+ print("STEP 3: Training Prophet model...")
190
+ print("=" * 60)
191
+
192
+ from prophet import Prophet
193
+
194
+ prophet_df = df[['date', 'price']].rename(columns={'date': 'ds', 'price': 'y'})
195
+
196
+ model = Prophet(
197
+ yearly_seasonality=True,
198
+ weekly_seasonality=True,
199
+ daily_seasonality=False,
200
+ seasonality_mode='multiplicative',
201
+ changepoint_prior_scale=0.05,
202
+ seasonality_prior_scale=10.0,
203
+ )
204
+
205
+ model.fit(prophet_df)
206
+
207
+ # In-sample forecast
208
+ forecast = model.predict(prophet_df[['ds']])
209
+ prophet_preds = forecast['yhat'].values
210
+
211
+ print(f" Prophet training complete.")
212
+ print(f" In-sample MAE: ${mean_absolute_error(df['price'], prophet_preds):.4f}")
213
+
214
+ return model, prophet_preds
215
+
216
+
217
+ # ─────────────────────────────────────────────
218
+ # STEP 4: TRAIN ARIMA MODEL
219
+ # ─────────────────────────────────────────────
220
+
221
+ def train_arima(df):
222
+ print("\n" + "=" * 60)
223
+ print("STEP 4: Training ARIMA model (auto search)...")
224
+ print("=" * 60)
225
+
226
+ import pmdarima as pm
227
+
228
+ prices = df['price'].values
229
+
230
+ # Use a subset for fast auto_arima search, then refit on full data
231
+ train_size = min(2000, len(prices))
232
+ print(f" Running auto_arima on {train_size} samples for parameter search...")
233
+
234
+ model = pm.auto_arima(
235
+ prices[:train_size],
236
+ start_p=1, start_q=1,
237
+ max_p=3, max_q=3,
238
+ d=None, # auto-detect differencing
239
+ seasonal=False,
240
+ information_criterion='aic',
241
+ trace=True,
242
+ error_action='ignore',
243
+ suppress_warnings=True,
244
+ stepwise=True,
245
+ n_jobs=1,
246
+ )
247
+
248
+ print(f"\n Best ARIMA order: {model.order}")
249
+ print(f" AIC: {model.aic():.2f}")
250
+
251
+ # Refit on full data with best order
252
+ p, d, q = model.order
253
+ print(f" Refitting on full {len(prices)} samples...")
254
+ final_model = pm.ARIMA(order=(p, d, q))
255
+ final_model.fit(prices)
256
+
257
+ arima_preds = final_model.predict_in_sample()
258
+ # Align length (ARIMA may drop first d values)
259
+ min_len = min(len(prices), len(arima_preds))
260
+ mae = mean_absolute_error(prices[-min_len:], arima_preds[-min_len:])
261
+ print(f" ARIMA training complete.")
262
+ print(f" In-sample MAE: ${mae:.4f}")
263
+
264
+ return final_model, arima_preds, min_len
265
+
266
+
267
+ # ─────────────────────────────────────────────
268
+ # STEP 5: HYBRID EVALUATION
269
+ # ─────────────────────────────────────────────
270
+
271
+ def evaluate_hybrid(df, prophet_preds, arima_preds, min_len):
272
+ print("\n" + "=" * 60)
273
+ print("STEP 5: Evaluating Hybrid Ensemble...")
274
+ print("=" * 60)
275
+
276
+ prices = df['price'].values
277
+
278
+ # Align all arrays to min_len (from the end)
279
+ y_true = prices[-min_len:]
280
+ y_prop = prophet_preds[-min_len:]
281
+ y_arima = arima_preds[-min_len:]
282
+
283
+ hybrid = 0.6 * y_prop + 0.4 * y_arima
284
+
285
+ mae_p = mean_absolute_error(y_true, y_prop)
286
+ mae_a = mean_absolute_error(y_true, y_arima)
287
+ mae_h = mean_absolute_error(y_true, hybrid)
288
+
289
+ rmse_p = np.sqrt(mean_squared_error(y_true, y_prop))
290
+ rmse_a = np.sqrt(mean_squared_error(y_true, y_arima))
291
+ rmse_h = np.sqrt(mean_squared_error(y_true, hybrid))
292
+
293
+ print(f"\n {'Model':<15} {'MAE':>10} {'RMSE':>10}")
294
+ print(f" {'-'*35}")
295
+ print(f" {'Prophet':<15} ${mae_p:>9.4f} ${rmse_p:>9.4f}")
296
+ print(f" {'ARIMA':<15} ${mae_a:>9.4f} ${rmse_a:>9.4f}")
297
+ print(f" {'Hybrid (0.6/0.4)':<15} ${mae_h:>9.4f} ${rmse_h:>9.4f}")
298
+
299
+ metrics = {
300
+ 'prophet_mae': mae_p, 'prophet_rmse': rmse_p,
301
+ 'arima_mae': mae_a, 'arima_rmse': rmse_a,
302
+ 'hybrid_mae': mae_h, 'hybrid_rmse': rmse_h,
303
+ }
304
+ return metrics
305
+
306
+
307
+ # ─────────────────────────────────────────────
308
+ # STEP 6: SAVE ARTIFACTS
309
+ # ─────────────────────────────────────────────
310
+
311
+ def save_artifacts(prophet_model, arima_model, scaler, metrics):
312
+ print("\n" + "=" * 60)
313
+ print("STEP 6: Saving model artifacts...")
314
+ print("=" * 60)
315
+
316
+ os.makedirs('models', exist_ok=True)
317
+
318
+ joblib.dump(prophet_model, 'models/prophet_model.pkl')
319
+ print(" Saved: models/prophet_model.pkl")
320
+
321
+ joblib.dump(arima_model, 'models/arima_model.pkl')
322
+ print(" Saved: models/arima_model.pkl")
323
+
324
+ joblib.dump(scaler, 'models/scaler.pkl')
325
+ print(" Saved: models/scaler.pkl")
326
+
327
+ # Save metrics
328
+ import json
329
+ with open('models/metrics.json', 'w') as f:
330
+ json.dump(metrics, f, indent=2)
331
+ print(" Saved: models/metrics.json")
332
+
333
+ print("\n Model artifacts saved successfully.")
334
+
335
+
336
+ # ─────────────────────────────────────────────
337
+ # MAIN
338
+ # ─────────────────────────────────────────────
339
+
340
+ if __name__ == '__main__':
341
+ # 1. Generate dataset
342
+ df_raw = generate_dataset(n_samples=12000)
343
+
344
+ # 2. Feature engineering
345
+ df, scaler = engineer_features(df_raw)
346
+
347
+ # 3. Train Prophet
348
+ prophet_model, prophet_preds = train_prophet(df)
349
+
350
+ # 4. Train ARIMA
351
+ arima_model, arima_preds, min_len = train_arima(df)
352
+
353
+ # 5. Evaluate hybrid
354
+ metrics = evaluate_hybrid(df, prophet_preds, arima_preds, min_len)
355
+
356
+ # 6. Save
357
+ save_artifacts(prophet_model, arima_model, scaler, metrics)
358
+
359
+ print("\n" + "=" * 60)
360
+ print("TRAINING COMPLETE")
361
+ print("=" * 60)