libbeyfox commited on
Commit
e3534aa
·
verified ·
1 Parent(s): 76c205c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +579 -0
app.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio app for time-series forecasting with statistical, ML and foundation models."""

import os
import tempfile
from datetime import datetime

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import gradio as gr

from statsforecast import StatsForecast
from statsforecast.models import (
    HistoricAverage,
    Naive,
    SeasonalNaive,
    WindowAverage,
    SeasonalWindowAverage,
    AutoETS,
    AutoARIMA,
    AutoCES,
    AutoTheta,
    DynamicOptimizedTheta,
    MSTL
)

from utilsforecast.evaluation import evaluate
from utilsforecast.losses import *

# Imports for MLForecast.
# NOTE: mlforecast must be installed via requirements.txt / `pip install mlforecast`;
# the original `!pip install mlforecast` line is IPython/notebook-only syntax and is
# a SyntaxError in a plain Python module, so it was removed here.
from mlforecast import MLForecast
from lightgbm import LGBMRegressor
31
+
32
# Function to generate and return a plot for validation results
def create_forecast_plot(forecast_df, original_df, title="Forecasting Results", horizon=None, freq='D'):
    """Plot actuals and model forecasts for every series in ``forecast_df``.

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Forecast output with 'unique_id' and 'ds' columns plus one column per
        model; may also carry 'cutoff' (from cross-validation) and 'y'.
    original_df : pd.DataFrame
        Historical data with 'unique_id', 'ds' and 'y' columns.
    title : str
        Figure title.
    horizon : int | None
        Forecast horizon; together with ``freq`` it trims the x-axis to show
        roughly one horizon of history before the earliest cutoff.
    freq : str
        Frequency code understood by ``calculate_date_offset``.

    Returns
    -------
    matplotlib.figure.Figure
        The current pyplot figure containing the combined plot.
    """
    plt.figure(figsize=(12, 7))
    unique_ids = forecast_df['unique_id'].unique()
    # Every non-key column is treated as one model's forecast series.
    forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]

    colors = plt.cm.tab10.colors
    min_cutoff = None  # earliest cross-validation cutoff seen across all series

    for i, unique_id in enumerate(unique_ids):
        original_data = original_df[original_df['unique_id'] == unique_id]
        plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')

        forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]

        # Mark each cross-validation cutoff with a dashed vertical line.
        if 'cutoff' in forecast_data.columns:
            cutoffs = pd.to_datetime(forecast_data['cutoff'].unique())
            if len(cutoffs) > 0:
                earliest_cutoff = cutoffs.min()
                if min_cutoff is None or earliest_cutoff < min_cutoff:
                    min_cutoff = earliest_cutoff

                for cutoff in cutoffs:
                    plt.axvline(x=cutoff, color='gray', linestyle='--', alpha=0.4)

        for j, col in enumerate(forecast_cols):
            if col in forecast_data.columns:
                model_name = col.replace('_', ' ').title()
                plt.plot(forecast_data['ds'], forecast_data[col],
                         color=colors[j % len(colors)],
                         linestyle='--',
                         linewidth=1.5,
                         label=f'{model_name}')

    plt.title(title, fontsize=16)
    plt.xlabel('Date', fontsize=12)
    plt.ylabel('Value', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])

    # Zoom the x-axis so only ~one horizon of history precedes the first cutoff.
    if min_cutoff is not None and horizon is not None:
        date_offset = calculate_date_offset(freq, horizon)
        start_date = min_cutoff - date_offset
        max_date = forecast_df['ds'].max()
        plt.xlim(start_date, max_date)

        plt.annotate('Training | Test',
                     xy=(min_cutoff, plt.ylim()[0]),
                     xytext=(0, -40),
                     textcoords='offset points',
                     horizontalalignment='center',
                     fontsize=10)

    fig = plt.gcf()
    fig.autofmt_xdate()  # rotate/shift date tick labels (unused plt.gca() removed)

    return fig
91
+
92
+
93
+
94
# Foundation Models are optional dependencies: probe for them at import time
# and record availability flags instead of letting the whole app fail to load.
try:
    from chronos import ChronosPipeline
    import torch
    CHRONOS_AVAILABLE = True
except ImportError:
    # Narrowed from a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit and mask unrelated errors.
    CHRONOS_AVAILABLE = False

try:
    from uni2ts.model.moirai import MoiraiForecast
    MOIRAI_AVAILABLE = True
except ImportError:
    MOIRAI_AVAILABLE = False
107
+
108
# Function to load and process uploaded CSV
def load_data(file):
    """Read an uploaded CSV and validate it for forecasting.

    The file must contain the Nixtla-style columns 'unique_id', 'ds', 'y'.
    Returns ``(DataFrame, message)`` on success, ``(None, message)`` on any
    validation or read failure.
    """
    if file is None:
        return None, "Please upload a CSV file"
    try:
        frame = pd.read_csv(file)

        # Every required column must be present before we touch the data.
        missing_cols = [name for name in ('unique_id', 'ds', 'y') if name not in frame.columns]
        if missing_cols:
            return None, f"Missing required columns: {', '.join(missing_cols)}"

        # Normalize dtypes and ordering: parsed timestamps, sorted per series.
        frame['ds'] = pd.to_datetime(frame['ds'])
        frame = frame.sort_values(['unique_id', 'ds']).reset_index(drop=True)

        # Reject series with gaps in the target column.
        if frame['y'].isna().any():
            return None, "Data contains missing values in the 'y' column"

        return frame, "Data loaded successfully!"
    except Exception as e:
        return None, f"Error loading data: {str(e)}"
129
+
130
+
131
# Helper function to calculate date offset based on frequency and horizon
def calculate_date_offset(freq, horizon):
    """Translate a frequency code plus horizon into an approximate Timedelta.

    Month/quarter/year codes use fixed-length approximations (30/90/365 days
    per step); business days are scaled by ~1.4 to account for weekends.
    Unknown codes fall back to one calendar day per step.
    """
    if freq == 'H':
        return pd.Timedelta(hours=horizon)
    if freq == 'B':
        # ~1.4 calendar days per business day (5 workdays span 7 days).
        return pd.Timedelta(days=int(horizon * 1.4))
    if freq == 'WS':
        return pd.Timedelta(weeks=horizon)

    days_per_step = {'MS': 30, 'QS': 90, 'YS': 365}.get(freq)
    if days_per_step is not None:
        return pd.Timedelta(days=horizon * days_per_step)

    # 'D' and any unrecognized frequency: one calendar day per step.
    return pd.Timedelta(days=horizon)
150
+
151
+
152
# Main forecasting function
def run_forecast(
    file, frequency, eval_strategy, horizon, step_size, num_windows,
    use_historical_avg, use_naive, use_seasonal_naive, seasonality,
    use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
    use_autoets, use_autoarima, use_autoces, use_autotheta,
    use_lgbm, use_chronos, use_moirai,
    future_horizon
):
    """
    Main function to run forecasting with all selected models.
    Now includes proper handling of models that don't support predictors.

    Returns the 7-tuple consumed directly by the Gradio outputs:
    (metrics DataFrame, validation DataFrame, validation Figure,
     future-forecast DataFrame, future Figure, list of export file paths,
     status message). On failure every data slot is None/[] and the message
    carries the traceback.

    NOTE(review): `use_moirai` is accepted but never used in this body.
    """
    try:
        # Load data
        df, message = load_data(file)
        if df is None:
            return None, None, None, None, None, [], message

        # Prepare data - only required columns for models without predictors
        df_basic = df[['unique_id', 'ds', 'y']].copy()

        # For models that need predictors, prepare full feature set
        # (This would be expanded based on your feature engineering)

        # Initialize models list
        models = []
        models_need_predictors = []

        # Basic models (no predictors needed); each checkbox adds one model.
        if use_historical_avg:
            models.append(HistoricAverage())
        if use_naive:
            models.append(Naive())
        if use_seasonal_naive:
            models.append(SeasonalNaive(season_length=int(seasonality)))
        if use_window_avg:
            models.append(WindowAverage(window_size=int(window_size)))
        if use_seasonal_window_avg:
            models.append(SeasonalWindowAverage(season_length=int(seasonality), window_size=int(seasonal_window_size)))
        if use_autoets:
            models.append(AutoETS(season_length=int(seasonality)))
        if use_autoces:
            models.append(AutoCES(season_length=int(seasonality)))
        if use_autotheta:
            models.append(AutoTheta(season_length=int(seasonality)))

        # Models that can use predictors
        if use_autoarima:
            models_need_predictors.append(AutoARIMA(season_length=int(seasonality)))

        # Run cross-validation or fixed window
        if eval_strategy == "Cross Validation":
            h = horizon
            validation_results = []

            # Run models without predictors
            if models:
                sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
                cv_df = sf.cross_validation(
                    df=df_basic,
                    h=int(h),
                    step_size=int(step_size),
                    n_windows=int(num_windows)
                )
                validation_results.append(cv_df)

            # Run models with predictors (if needed, add predictor handling here)
            # For now, we'll run them without predictors
            if models_need_predictors:
                sf_pred = StatsForecast(models=models_need_predictors, freq=frequency, n_jobs=-1)
                cv_df_pred = sf_pred.cross_validation(
                    df=df_basic,  # Use df with predictors when implemented
                    h=int(h),
                    step_size=int(step_size),
                    n_windows=int(num_windows)
                )
                validation_results.append(cv_df_pred)

            # Combine results
            # NOTE(review): axis=1 concat assumes both CV frames share an
            # identical row index/order — confirm before relying on it.
            if validation_results:
                validation_df = pd.concat(validation_results, axis=1)
                # Drop the duplicated key columns the side-by-side concat produces.
                validation_df = validation_df.loc[:,~validation_df.columns.duplicated()]
            else:
                return None, None, None, None, None, [], "No models selected"

        else:  # Fixed Window
            # Similar logic for fixed window
            # Split data: hold out the last `horizon` rows of each series.
            train_df = []
            for uid in df_basic['unique_id'].unique():
                uid_data = df_basic[df_basic['unique_id'] == uid].iloc[:-int(horizon)]
                train_df.append(uid_data)
            train_df = pd.concat(train_df)

            # Fit and predict
            all_models = models + models_need_predictors
            if all_models:
                sf = StatsForecast(models=all_models, freq=frequency, n_jobs=-1)
                sf.fit(train_df)
                validation_df = sf.predict(h=int(horizon), level=[90, 95])
            else:
                return None, None, None, None, None, [], "No models selected"

        # Add ML model forecasts if selected
        if use_lgbm:
            mlf = MLForecast(
                models={'LightGBM': LGBMRegressor(verbose=-1)},
                freq=frequency,
                lags=[1, 7, 14],
                num_threads=1
            )

            if eval_strategy == "Cross Validation":
                ml_cv = mlf.cross_validation(
                    df=df_basic,
                    h=int(horizon),
                    step_size=int(step_size),
                    n_windows=int(num_windows)
                )
                validation_df = validation_df.merge(ml_cv, on=['unique_id', 'ds', 'cutoff'], how='outer')
            else:
                mlf.fit(train_df)
                ml_pred = mlf.predict(h=int(horizon))
                validation_df = validation_df.merge(ml_pred, on=['unique_id', 'ds'], how='outer')

        # Add foundation model forecasts
        if use_chronos and CHRONOS_AVAILABLE:
            # NOTE(review): `train_df` is only defined on the Fixed Window
            # path above; under Cross Validation this block raises NameError,
            # which the broad except below reduces to a printed
            # "Chronos error" — confirm whether that is intended.
            try:
                pipeline = ChronosPipeline.from_pretrained(
                    "amazon/chronos-t5-tiny",
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                )

                chronos_forecasts = []
                for uid in df_basic['unique_id'].unique():
                    uid_data = train_df[train_df['unique_id'] == uid]['y'].values
                    context = torch.tensor(uid_data)
                    forecast = pipeline.predict(context, prediction_length=int(horizon))
                    # Collapse the sample dimension to a single median path.
                    forecast_median = np.median(forecast[0].numpy(), axis=0)

                    # NOTE(review): the `+ 1 day` start offset assumes daily
                    # data even when `frequency` is hourly/weekly/etc. — verify.
                    uid_forecast = pd.DataFrame({
                        'unique_id': uid,
                        'ds': pd.date_range(
                            start=train_df[train_df['unique_id'] == uid]['ds'].max() + pd.Timedelta(days=1),
                            periods=int(horizon),
                            freq=frequency
                        ),
                        'Chronos': forecast_median
                    })
                    chronos_forecasts.append(uid_forecast)

                chronos_df = pd.concat(chronos_forecasts)
                validation_df = validation_df.merge(chronos_df, on=['unique_id', 'ds'], how='outer')
            except Exception as e:
                print(f"Chronos error: {e}")

        # Evaluate models: every non-key column is scored as one model.
        eval_cols = [col for col in validation_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]

        if 'y' not in validation_df.columns:
            # Merge with actual values (Fixed Window predictions carry no 'y').
            validation_df = validation_df.merge(
                df_basic[['unique_id', 'ds', 'y']],
                on=['unique_id', 'ds'],
                how='left'
            )

        # Calculate metrics
        # NOTE(review): mae/rmse/mape come from the `utilsforecast.losses`
        # star-import, whose functions take DataFrames rather than bare
        # arrays — confirm these calls resolve to array-based implementations.
        metrics_list = []
        for col in eval_cols:
            if col in validation_df.columns and not validation_df[col].isna().all():
                y_true = validation_df['y'].values
                y_pred = validation_df[col].values

                # Score only rows where both actual and prediction exist.
                mask = ~(np.isnan(y_true) | np.isnan(y_pred))
                if mask.sum() > 0:
                    y_true_clean = y_true[mask]
                    y_pred_clean = y_pred[mask]

                    metrics_list.append({
                        'Model': col,
                        'MAE': mae(y_true_clean, y_pred_clean),
                        'RMSE': rmse(y_true_clean, y_pred_clean),
                        'MAPE': mape(y_true_clean, y_pred_clean)
                    })

        eval_metrics = pd.DataFrame(metrics_list)

        # Create validation plot
        validation_plot = create_forecast_plot(
            validation_df.reset_index() if 'index' not in validation_df.columns else validation_df,
            df_basic,
            "Validation Results",
            horizon,
            frequency
        )

        # Future forecast: refit the statistical models on the FULL history.
        future_models = models + models_need_predictors
        if future_models:
            sf_future = StatsForecast(models=future_models, freq=frequency, n_jobs=-1)
            sf_future.fit(df_basic)
            future_df = sf_future.predict(h=int(future_horizon), level=[90, 95])
        else:
            future_df = pd.DataFrame()

        # Create future forecast plot
        future_plot = create_forecast_plot(
            future_df.reset_index() if not future_df.empty else pd.DataFrame(),
            df_basic,
            "Future Forecast",
            future_horizon,
            frequency
        )

        # Export files
        export_files = []

        # Save to temp files; delete=False so Gradio can serve them afterwards.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv') as f:
            eval_metrics.to_csv(f, index=False)
            export_files.append(f.name)

        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv') as f:
            validation_df.to_csv(f, index=False)
            export_files.append(f.name)

        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv') as f:
            future_df.to_csv(f, index=False)
            export_files.append(f.name)

        return (
            eval_metrics,
            validation_df,
            validation_plot,
            future_df,
            future_plot,
            export_files,
            "✓ Forecasting completed successfully!"
        )

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, None, None, None, None, [], error_msg
399
+
400
+
401
+
402
# Gradio Interface
# Layout: left column = upload + forecast/model configuration,
# right column = status, results tabs, and downloads.
with gr.Blocks(title="Duke Energy Forecasting App") as app:
    gr.Markdown("""
    # 🔮 Duke Energy Time Series Forecasting

    Upload your time series data and select models to generate forecasts.
    Supports StatsForecast, MLForecast, and Foundation Models (Chronos, Moirai).
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload CSV File", file_types=['.csv'])

            with gr.Accordion("Forecast Configuration", open=True):
                # Frequency codes match calculate_date_offset and the
                # `freq` argument passed to StatsForecast/MLForecast.
                frequency = gr.Dropdown(
                    choices=[
                        ("Hourly", "H"),
                        ("Daily", "D"),
                        ("Business Day", "B"),
                        ("Weekly", "WS"),
                        ("Monthly", "MS"),
                        ("Quarterly", "QS"),
                        ("Yearly", "YS")
                    ],
                    label="Data Frequency",
                    value="D"
                )

                eval_strategy = gr.Radio(
                    choices=["Fixed Window", "Cross Validation"],
                    label="Evaluation Strategy",
                    value="Cross Validation"
                )

                # Only one of these two groups is relevant at a time;
                # update_eval_boxes below toggles their visibility.
                with gr.Group(visible=True) as fixed_window_box:
                    gr.Markdown("### Fixed Window Settings")
                    horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon")

                with gr.Group(visible=True) as cv_box:
                    gr.Markdown("### Cross Validation Settings")
                    with gr.Row():
                        step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size")
                        num_windows = gr.Slider(1, 20, value=5, step=1, label="Number of Windows")

                with gr.Group():
                    gr.Markdown("### Future Forecast Settings")
                    future_horizon = gr.Slider(1, 100, value=10, step=1, label="Future Forecast Horizon")

            with gr.Accordion("Model Configuration", open=True):
                with gr.Tabs():
                    with gr.TabItem("Statistical Models"):
                        gr.Markdown("## Basic Models")
                        with gr.Row():
                            use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
                            use_naive = gr.Checkbox(label="Naive", value=True)

                        with gr.Group():
                            gr.Markdown("### Seasonality Configuration")
                            # Season length shared by every seasonal model below.
                            seasonality = gr.Number(label="Seasonality Period", value=7)

                            gr.Markdown("### Seasonal Models")
                            use_seasonal_naive = gr.Checkbox(label="Seasonal Naive", value=True)

                            gr.Markdown("### Window-based Models")
                            with gr.Row():
                                use_window_avg = gr.Checkbox(label="Window Average", value=False)
                                window_size = gr.Number(label="Window Size", value=10)

                            with gr.Row():
                                use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average", value=False)
                                seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)

                            gr.Markdown("### Advanced Models")
                            with gr.Row():
                                use_autoets = gr.Checkbox(label="AutoETS", value=False)
                                use_autoarima = gr.Checkbox(label="AutoARIMA", value=False)
                            with gr.Row():
                                use_autoces = gr.Checkbox(label="AutoCES", value=False)
                                use_autotheta = gr.Checkbox(label="AutoTheta", value=False)

                    with gr.TabItem("Machine Learning"):
                        gr.Markdown("## Gradient Boosting Models")
                        use_lgbm = gr.Checkbox(label="LightGBM", value=True)

                    with gr.TabItem("Foundation Models"):
                        gr.Markdown("## State-of-the-Art Foundation Models")

                        with gr.Row():
                            # Checkboxes are disabled when the library failed
                            # to import (see the availability flags above).
                            use_chronos = gr.Checkbox(
                                label="Chronos (Amazon)",
                                value=CHRONOS_AVAILABLE,
                                interactive=CHRONOS_AVAILABLE
                            )
                            use_moirai = gr.Checkbox(
                                label="Moirai (Salesforce)",
                                value=False,
                                interactive=MOIRAI_AVAILABLE
                            )

                        if not CHRONOS_AVAILABLE:
                            gr.Markdown("⚠️ Chronos not available. Install: `pip install chronos-forecasting`")
                        if not MOIRAI_AVAILABLE:
                            gr.Markdown("⚠️ Moirai not available. Install: `pip install uni2ts`")

        with gr.Column(scale=3):
            message_output = gr.Textbox(label="Status Message")

            with gr.Tabs():
                with gr.TabItem("Validation Results"):
                    eval_output = gr.Dataframe(label="Evaluation Metrics")
                    validation_plot = gr.Plot(label="Validation Plot")
                    # Hidden until the user clicks "Show Validation Data".
                    validation_output = gr.Dataframe(label="Validation Data", visible=False)

                    with gr.Row():
                        show_data_btn = gr.Button("Show Validation Data")
                        hide_data_btn = gr.Button("Hide Validation Data", visible=False)

                with gr.TabItem("Future Forecast"):
                    forecast_plot = gr.Plot(label="Future Forecast Plot")
                    forecast_output = gr.Dataframe(label="Future Forecast Data", visible=False)

                    with gr.Row():
                        show_forecast_btn = gr.Button("Show Forecast Data")
                        hide_forecast_btn = gr.Button("Hide Forecast Data", visible=False)

                with gr.TabItem("Export Results"):
                    export_files = gr.Files(label="Download Results")

    with gr.Row():
        submit_btn = gr.Button("Run Validation and Forecast", variant="primary", size="lg")

    # Event handlers
    def update_eval_boxes(strategy):
        # Show only the settings group matching the chosen evaluation strategy.
        return (
            gr.update(visible=strategy == "Fixed Window"),
            gr.update(visible=strategy == "Cross Validation")
        )

    eval_strategy.change(
        fn=update_eval_boxes,
        inputs=[eval_strategy],
        outputs=[fixed_window_box, cv_box]
    )

    # Shared show/hide toggles: reveal a data table and swap which of the
    # paired Show/Hide buttons is visible.
    def show_data():
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)

    def hide_data():
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    show_data_btn.click(fn=show_data, outputs=[validation_output, hide_data_btn, show_data_btn])
    hide_data_btn.click(fn=hide_data, outputs=[validation_output, hide_data_btn, show_data_btn])
    show_forecast_btn.click(fn=show_data, outputs=[forecast_output, hide_forecast_btn, show_forecast_btn])
    hide_forecast_btn.click(fn=hide_data, outputs=[forecast_output, hide_forecast_btn, show_forecast_btn])

    # Main action: input order must match run_forecast's parameter order,
    # output order must match its 7-tuple return.
    submit_btn.click(
        fn=run_forecast,
        inputs=[
            file_input, frequency, eval_strategy, horizon, step_size, num_windows,
            use_historical_avg, use_naive, use_seasonal_naive, seasonality,
            use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
            use_autoets, use_autoarima, use_autoces, use_autotheta,
            use_lgbm, use_chronos, use_moirai,
            future_horizon
        ],
        outputs=[
            eval_output,
            validation_output,
            validation_plot,
            forecast_output,
            forecast_plot,
            export_files,
            message_output
        ]
    )
577
+
578
# Launch the app locally (no public share link) when run as a script.
if __name__ == "__main__":
    app.launch(share=False)