Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import tempfile
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gradio as gr

# NOTE(review): "!pip install mlforecast" is IPython/notebook magic and is a
# SyntaxError in a plain .py module. Declare the dependency in
# requirements.txt instead of installing at import time.
# !pip install mlforecast

from statsforecast import StatsForecast
from statsforecast.models import (
    HistoricAverage,
    Naive,
    SeasonalNaive,
    WindowAverage,
    SeasonalWindowAverage,
    AutoETS,
    AutoARIMA,
    AutoCES,
    AutoTheta,
    DynamicOptimizedTheta,
    MSTL
)

from utilsforecast.evaluation import evaluate
from utilsforecast.losses import *

# Import for MLForecast
from mlforecast import MLForecast
from lightgbm import LGBMRegressor
|
| 31 |
+
|
| 32 |
+
# Function to generate and return a matplotlib Figure for validation/forecast results
def create_forecast_plot(forecast_df, original_df, title="Forecasting Results", horizon=None, freq='D'):
    """Plot actual history and per-model forecasts for every series.

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Long-format forecasts with 'unique_id' and 'ds'; every other column
        (except 'cutoff' and 'y') is treated as one model's forecast. A
        'cutoff' column (cross-validation output) is drawn as vertical
        dashed lines.
    original_df : pd.DataFrame
        Historical data with 'unique_id', 'ds', 'y'.
    title : str
        Figure title.
    horizon : int or None
        When given together with a detected cutoff, the x-axis is narrowed
        to start one horizon before the earliest cutoff.
    freq : str
        Frequency code passed to calculate_date_offset() for the zoom.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt.figure(figsize=(12, 7))
    unique_ids = forecast_df['unique_id'].unique()
    # Every non-key column is assumed to be a model's point forecast.
    forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]

    colors = plt.cm.tab10.colors
    min_cutoff = None

    for unique_id in unique_ids:
        original_data = original_df[original_df['unique_id'] == unique_id]
        plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')

        forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]

        if 'cutoff' in forecast_data.columns:
            cutoffs = pd.to_datetime(forecast_data['cutoff'].unique())
            if len(cutoffs) > 0:
                earliest_cutoff = cutoffs.min()
                # Track the globally earliest cutoff for the axis zoom below.
                if min_cutoff is None or earliest_cutoff < min_cutoff:
                    min_cutoff = earliest_cutoff

                for cutoff in cutoffs:
                    plt.axvline(x=cutoff, color='gray', linestyle='--', alpha=0.4)

        for j, col in enumerate(forecast_cols):
            if col in forecast_data.columns:
                model_name = col.replace('_', ' ').title()
                plt.plot(forecast_data['ds'], forecast_data[col],
                         color=colors[j % len(colors)],
                         linestyle='--',
                         linewidth=1.5,
                         label=f'{model_name}')

    plt.title(title, fontsize=16)
    plt.xlabel('Date', fontsize=12)
    plt.ylabel('Value', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])

    # Zoom: show one horizon of history before the earliest cutoff so the
    # train/test boundary sits near the left of the plot.
    if min_cutoff is not None and horizon is not None:
        date_offset = calculate_date_offset(freq, horizon)
        start_date = min_cutoff - date_offset
        max_date = forecast_df['ds'].max()
        plt.xlim(start_date, max_date)

        plt.annotate('Training | Test',
                     xy=(min_cutoff, plt.ylim()[0]),
                     xytext=(0, -40),
                     textcoords='offset points',
                     horizontalalignment='center',
                     fontsize=10)

    fig = plt.gcf()
    fig.autofmt_xdate()

    return fig
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Foundation Models — optional heavyweight dependencies. Each import is
# attempted independently so the app still runs when a package is absent;
# the availability flags drive (and disable) the UI checkboxes below.
try:
    from chronos import ChronosPipeline
    import torch
    CHRONOS_AVAILABLE = True
except Exception:  # was a bare `except:`, which also traps SystemExit/KeyboardInterrupt
    CHRONOS_AVAILABLE = False

try:
    from uni2ts.model.moirai import MoiraiForecast
    MOIRAI_AVAILABLE = True
except Exception:
    MOIRAI_AVAILABLE = False
|
| 107 |
+
|
| 108 |
+
# Load and validate an uploaded CSV, returning (DataFrame-or-None, status message)
def load_data(file):
    """Read *file* as CSV and validate it for forecasting.

    The file must contain 'unique_id', 'ds' and 'y' columns. 'ds' is parsed
    to datetimes and rows are sorted per series. Returns a ``(df, message)``
    pair where ``df`` is ``None`` on any failure.
    """
    if file is None:
        return None, "Please upload a CSV file"
    try:
        frame = pd.read_csv(file)

        absent = [c for c in ['unique_id', 'ds', 'y'] if c not in frame.columns]
        if absent:
            return None, f"Missing required columns: {', '.join(absent)}"

        frame['ds'] = pd.to_datetime(frame['ds'])
        frame = frame.sort_values(['unique_id', 'ds']).reset_index(drop=True)

        # Reject incomplete targets up front instead of failing mid-forecast.
        if frame['y'].isna().any():
            return None, "Data contains missing values in the 'y' column"

        return frame, "Data loaded successfully!"
    except Exception as e:
        return None, f"Error loading data: {str(e)}"
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Translate a (frequency code, horizon) pair into an approximate pd.Timedelta
def calculate_date_offset(freq, horizon):
    """Calculate a timedelta based on frequency code and horizon.

    Month/quarter/year steps are approximated with fixed day counts
    (30/90/365); business days are scaled by ~1.4 to account for weekends.
    Unknown codes fall back to one calendar day per step.
    """
    if freq == 'H':
        return pd.Timedelta(hours=horizon)
    if freq == 'B':
        # Roughly 7 calendar days for every 5 business days.
        return pd.Timedelta(days=int(horizon * 1.4))
    if freq == 'WS':
        return pd.Timedelta(weeks=horizon)

    days_per_step = {'MS': 30, 'QS': 90, 'YS': 365}
    if freq in days_per_step:
        return pd.Timedelta(days=horizon * days_per_step[freq])

    # 'D' and any unrecognized frequency code.
    return pd.Timedelta(days=horizon)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Main forecasting function
def run_forecast(
    file, frequency, eval_strategy, horizon, step_size, num_windows,
    use_historical_avg, use_naive, use_seasonal_naive, seasonality,
    use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
    use_autoets, use_autoarima, use_autoces, use_autotheta,
    use_lgbm, use_chronos, use_moirai,
    future_horizon
):
    """
    Main function to run forecasting with all selected models.
    Now includes proper handling of models that don't support predictors.

    Returns a 7-tuple:
        (eval_metrics, validation_df, validation_plot,
         future_df, future_plot, export_files, status_message)
    On any failure the first five entries are None, export_files is [] and
    the message carries the error text (Gradio displays it as-is).

    NOTE(review): `use_moirai` is accepted but never used in this body —
    confirm whether Moirai support is pending or the flag should be dropped.
    """
    try:
        # Load data
        df, message = load_data(file)
        if df is None:
            return None, None, None, None, None, [], message

        # Prepare data - only required columns for models without predictors
        df_basic = df[['unique_id', 'ds', 'y']].copy()

        # For models that need predictors, prepare full feature set
        # (This would be expanded based on your feature engineering)

        # Initialize models list
        models = []
        models_need_predictors = []

        # Basic models (no predictors needed)
        if use_historical_avg:
            models.append(HistoricAverage())
        if use_naive:
            models.append(Naive())
        if use_seasonal_naive:
            models.append(SeasonalNaive(season_length=int(seasonality)))
        if use_window_avg:
            models.append(WindowAverage(window_size=int(window_size)))
        if use_seasonal_window_avg:
            models.append(SeasonalWindowAverage(season_length=int(seasonality), window_size=int(seasonal_window_size)))
        if use_autoets:
            models.append(AutoETS(season_length=int(seasonality)))
        if use_autoces:
            models.append(AutoCES(season_length=int(seasonality)))
        if use_autotheta:
            models.append(AutoTheta(season_length=int(seasonality)))

        # Models that can use predictors
        if use_autoarima:
            models_need_predictors.append(AutoARIMA(season_length=int(seasonality)))

        # Run cross-validation or fixed window
        if eval_strategy == "Cross Validation":
            h = horizon
            validation_results = []

            # Run models without predictors
            if models:
                sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
                cv_df = sf.cross_validation(
                    df=df_basic,
                    h=int(h),
                    step_size=int(step_size),
                    n_windows=int(num_windows)
                )
                validation_results.append(cv_df)

            # Run models with predictors (if needed, add predictor handling here)
            # For now, we'll run them without predictors
            if models_need_predictors:
                sf_pred = StatsForecast(models=models_need_predictors, freq=frequency, n_jobs=-1)
                cv_df_pred = sf_pred.cross_validation(
                    df=df_basic,  # Use df with predictors when implemented
                    h=int(h),
                    step_size=int(step_size),
                    n_windows=int(num_windows)
                )
                validation_results.append(cv_df_pred)

            # Combine results
            # NOTE(review): axis=1 concat assumes both CV frames have an
            # identical row index/order — verify alignment when both model
            # groups are enabled simultaneously.
            if validation_results:
                validation_df = pd.concat(validation_results, axis=1)
                validation_df = validation_df.loc[:,~validation_df.columns.duplicated()]
            else:
                return None, None, None, None, None, [], "No models selected"

        else:  # Fixed Window
            # Similar logic for fixed window
            # Split data: hold out the last `horizon` rows of each series.
            train_df = []
            for uid in df_basic['unique_id'].unique():
                uid_data = df_basic[df_basic['unique_id'] == uid].iloc[:-int(horizon)]
                train_df.append(uid_data)
            train_df = pd.concat(train_df)

            # Fit and predict
            all_models = models + models_need_predictors
            if all_models:
                sf = StatsForecast(models=all_models, freq=frequency, n_jobs=-1)
                sf.fit(train_df)
                validation_df = sf.predict(h=int(horizon), level=[90, 95])
            else:
                return None, None, None, None, None, [], "No models selected"

        # Add ML model forecasts if selected
        if use_lgbm:
            mlf = MLForecast(
                models={'LightGBM': LGBMRegressor(verbose=-1)},
                freq=frequency,
                lags=[1, 7, 14],
                num_threads=1
            )

            if eval_strategy == "Cross Validation":
                ml_cv = mlf.cross_validation(
                    df=df_basic,
                    h=int(horizon),
                    step_size=int(step_size),
                    n_windows=int(num_windows)
                )
                validation_df = validation_df.merge(ml_cv, on=['unique_id', 'ds', 'cutoff'], how='outer')
            else:
                mlf.fit(train_df)
                ml_pred = mlf.predict(h=int(horizon))
                validation_df = validation_df.merge(ml_pred, on=['unique_id', 'ds'], how='outer')

        # Add foundation model forecasts
        if use_chronos and CHRONOS_AVAILABLE:
            try:
                pipeline = ChronosPipeline.from_pretrained(
                    "amazon/chronos-t5-tiny",
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                )

                chronos_forecasts = []
                for uid in df_basic['unique_id'].unique():
                    # NOTE(review): `train_df` is only defined on the Fixed
                    # Window path; in Cross Validation mode this raises
                    # NameError, which the outer except converts into an
                    # error message — likely a bug to fix.
                    uid_data = train_df[train_df['unique_id'] == uid]['y'].values
                    context = torch.tensor(uid_data)
                    forecast = pipeline.predict(context, prediction_length=int(horizon))
                    # Chronos returns sample paths; the median is used as the
                    # point forecast.
                    forecast_median = np.median(forecast[0].numpy(), axis=0)

                    uid_forecast = pd.DataFrame({
                        'unique_id': uid,
                        # NOTE(review): the fixed +1-day start assumes daily
                        # data; for other frequencies the first forecast
                        # timestamp will be wrong — confirm.
                        'ds': pd.date_range(
                            start=train_df[train_df['unique_id'] == uid]['ds'].max() + pd.Timedelta(days=1),
                            periods=int(horizon),
                            freq=frequency
                        ),
                        'Chronos': forecast_median
                    })
                    chronos_forecasts.append(uid_forecast)

                chronos_df = pd.concat(chronos_forecasts)
                validation_df = validation_df.merge(chronos_df, on=['unique_id', 'ds'], how='outer')
            except Exception as e:
                print(f"Chronos error: {e}")

        # Evaluate models
        eval_cols = [col for col in validation_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]

        if 'y' not in validation_df.columns:
            # Merge with actual values
            validation_df = validation_df.merge(
                df_basic[['unique_id', 'ds', 'y']],
                on=['unique_id', 'ds'],
                how='left'
            )

        # Calculate metrics
        metrics_list = []
        for col in eval_cols:
            if col in validation_df.columns and not validation_df[col].isna().all():
                y_true = validation_df['y'].values
                y_pred = validation_df[col].values

                # Score only rows where both the actual and the prediction exist.
                mask = ~(np.isnan(y_true) | np.isnan(y_pred))
                if mask.sum() > 0:
                    y_true_clean = y_true[mask]
                    y_pred_clean = y_pred[mask]

                    # NOTE(review): mae/rmse/mape come from the star import of
                    # utilsforecast.losses; confirm they accept raw arrays
                    # rather than the library's DataFrame interface.
                    metrics_list.append({
                        'Model': col,
                        'MAE': mae(y_true_clean, y_pred_clean),
                        'RMSE': rmse(y_true_clean, y_pred_clean),
                        'MAPE': mape(y_true_clean, y_pred_clean)
                    })

        eval_metrics = pd.DataFrame(metrics_list)

        # Create validation plot
        validation_plot = create_forecast_plot(
            validation_df.reset_index() if 'index' not in validation_df.columns else validation_df,
            df_basic,
            "Validation Results",
            horizon,
            frequency
        )

        # Future forecast: refit the statistical models on the full history.
        future_models = models + models_need_predictors
        if future_models:
            sf_future = StatsForecast(models=future_models, freq=frequency, n_jobs=-1)
            sf_future.fit(df_basic)
            future_df = sf_future.predict(h=int(future_horizon), level=[90, 95])
        else:
            future_df = pd.DataFrame()

        # Create future forecast plot
        future_plot = create_forecast_plot(
            future_df.reset_index() if not future_df.empty else pd.DataFrame(),
            df_basic,
            "Future Forecast",
            future_horizon,
            frequency
        )

        # Export files
        export_files = []

        # Save to temp files (delete=False so Gradio can serve them afterwards)
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv') as f:
            eval_metrics.to_csv(f, index=False)
            export_files.append(f.name)

        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv') as f:
            validation_df.to_csv(f, index=False)
            export_files.append(f.name)

        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv') as f:
            future_df.to_csv(f, index=False)
            export_files.append(f.name)

        return (
            eval_metrics,
            validation_df,
            validation_plot,
            future_df,
            future_plot,
            export_files,
            "✓ Forecasting completed successfully!"
        )

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, None, None, None, None, [], error_msg
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# Gradio Interface
# Layout: left column (scale=1) holds the upload widget and all configuration;
# right column (scale=3) holds the status box, result tabs, and downloads.
with gr.Blocks(title="Duke Energy Forecasting App") as app:
    gr.Markdown("""
    # 🔮 Duke Energy Time Series Forecasting

    Upload your time series data and select models to generate forecasts.
    Supports StatsForecast, MLForecast, and Foundation Models (Chronos, Moirai).
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload CSV File", file_types=['.csv'])

            with gr.Accordion("Forecast Configuration", open=True):
                # Frequency codes feed the `freq` argument of StatsForecast /
                # MLForecast and calculate_date_offset().
                frequency = gr.Dropdown(
                    choices=[
                        ("Hourly", "H"),
                        ("Daily", "D"),
                        ("Business Day", "B"),
                        ("Weekly", "WS"),
                        ("Monthly", "MS"),
                        ("Quarterly", "QS"),
                        ("Yearly", "YS")
                    ],
                    label="Data Frequency",
                    value="D"
                )

                eval_strategy = gr.Radio(
                    choices=["Fixed Window", "Cross Validation"],
                    label="Evaluation Strategy",
                    value="Cross Validation"
                )

                # NOTE(review): both settings boxes start visible=True even
                # though the default strategy is "Cross Validation";
                # update_eval_boxes only reconciles visibility after the
                # first change event — confirm the initial state is intended.
                with gr.Group(visible=True) as fixed_window_box:
                    gr.Markdown("### Fixed Window Settings")
                    # `horizon` also serves as the CV horizon `h` in run_forecast.
                    horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon")

                with gr.Group(visible=True) as cv_box:
                    gr.Markdown("### Cross Validation Settings")
                    with gr.Row():
                        step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size")
                        num_windows = gr.Slider(1, 20, value=5, step=1, label="Number of Windows")

                with gr.Group():
                    gr.Markdown("### Future Forecast Settings")
                    future_horizon = gr.Slider(1, 100, value=10, step=1, label="Future Forecast Horizon")

            with gr.Accordion("Model Configuration", open=True):
                with gr.Tabs():
                    with gr.TabItem("Statistical Models"):
                        gr.Markdown("## Basic Models")
                        with gr.Row():
                            use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
                            use_naive = gr.Checkbox(label="Naive", value=True)

                        with gr.Group():
                            gr.Markdown("### Seasonality Configuration")
                            # Shared season_length for every seasonal model below.
                            seasonality = gr.Number(label="Seasonality Period", value=7)

                        gr.Markdown("### Seasonal Models")
                        use_seasonal_naive = gr.Checkbox(label="Seasonal Naive", value=True)

                        gr.Markdown("### Window-based Models")
                        with gr.Row():
                            use_window_avg = gr.Checkbox(label="Window Average", value=False)
                            window_size = gr.Number(label="Window Size", value=10)

                        with gr.Row():
                            use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average", value=False)
                            seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)

                        gr.Markdown("### Advanced Models")
                        with gr.Row():
                            use_autoets = gr.Checkbox(label="AutoETS", value=False)
                            use_autoarima = gr.Checkbox(label="AutoARIMA", value=False)
                        with gr.Row():
                            use_autoces = gr.Checkbox(label="AutoCES", value=False)
                            use_autotheta = gr.Checkbox(label="AutoTheta", value=False)

                    with gr.TabItem("Machine Learning"):
                        gr.Markdown("## Gradient Boosting Models")
                        use_lgbm = gr.Checkbox(label="LightGBM", value=True)

                    with gr.TabItem("Foundation Models"):
                        gr.Markdown("## State-of-the-Art Foundation Models")

                        # Checkboxes are disabled when the optional package
                        # failed to import (CHRONOS_AVAILABLE / MOIRAI_AVAILABLE).
                        with gr.Row():
                            use_chronos = gr.Checkbox(
                                label="Chronos (Amazon)",
                                value=CHRONOS_AVAILABLE,
                                interactive=CHRONOS_AVAILABLE
                            )
                            use_moirai = gr.Checkbox(
                                label="Moirai (Salesforce)",
                                value=False,
                                interactive=MOIRAI_AVAILABLE
                            )

                        if not CHRONOS_AVAILABLE:
                            gr.Markdown("⚠️ Chronos not available. Install: `pip install chronos-forecasting`")
                        if not MOIRAI_AVAILABLE:
                            gr.Markdown("⚠️ Moirai not available. Install: `pip install uni2ts`")

        with gr.Column(scale=3):
            message_output = gr.Textbox(label="Status Message")

            with gr.Tabs():
                with gr.TabItem("Validation Results"):
                    eval_output = gr.Dataframe(label="Evaluation Metrics")
                    validation_plot = gr.Plot(label="Validation Plot")
                    # Raw validation rows stay hidden until "Show" is clicked.
                    validation_output = gr.Dataframe(label="Validation Data", visible=False)

                    with gr.Row():
                        show_data_btn = gr.Button("Show Validation Data")
                        hide_data_btn = gr.Button("Hide Validation Data", visible=False)

                with gr.TabItem("Future Forecast"):
                    forecast_plot = gr.Plot(label="Future Forecast Plot")
                    forecast_output = gr.Dataframe(label="Future Forecast Data", visible=False)

                    with gr.Row():
                        show_forecast_btn = gr.Button("Show Forecast Data")
                        hide_forecast_btn = gr.Button("Hide Forecast Data", visible=False)

                with gr.TabItem("Export Results"):
                    export_files = gr.Files(label="Download Results")

    with gr.Row():
        submit_btn = gr.Button("Run Validation and Forecast", variant="primary", size="lg")

    # Event handlers
    def update_eval_boxes(strategy):
        # Show only the settings box matching the chosen evaluation strategy.
        return (
            gr.update(visible=strategy == "Fixed Window"),
            gr.update(visible=strategy == "Cross Validation")
        )

    eval_strategy.change(
        fn=update_eval_boxes,
        inputs=[eval_strategy],
        outputs=[fixed_window_box, cv_box]
    )

    # Shared show/hide togglers; outputs are always wired as
    # (data table, hide button, show button) in that order.
    def show_data():
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)

    def hide_data():
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    show_data_btn.click(fn=show_data, outputs=[validation_output, hide_data_btn, show_data_btn])
    hide_data_btn.click(fn=hide_data, outputs=[validation_output, hide_data_btn, show_data_btn])
    show_forecast_btn.click(fn=show_data, outputs=[forecast_output, hide_forecast_btn, show_forecast_btn])
    hide_forecast_btn.click(fn=hide_data, outputs=[forecast_output, hide_forecast_btn, show_forecast_btn])

    # Main action: inputs and outputs must stay aligned with run_forecast's
    # parameter order and its 7-tuple return.
    submit_btn.click(
        fn=run_forecast,
        inputs=[
            file_input, frequency, eval_strategy, horizon, step_size, num_windows,
            use_historical_avg, use_naive, use_seasonal_naive, seasonality,
            use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
            use_autoets, use_autoarima, use_autoces, use_autotheta,
            use_lgbm, use_chronos, use_moirai,
            future_horizon
        ],
        outputs=[
            eval_output,
            validation_output,
            validation_plot,
            forecast_output,
            forecast_plot,
            export_files,
            message_output
        ]
    )

if __name__ == "__main__":
    # share=False: serve locally only (no public Gradio tunnel).
    app.launch(share=False)
|