fmegahed committed on
Commit
183818f
·
verified ·
1 Parent(s): bc89d99

Updated the logo

Browse files
Files changed (1) hide show
  1. app.py +637 -637
app.py CHANGED
@@ -1,637 +1,637 @@
1
- """
2
- Cross-Validation Visualizer
3
- ============================
4
- Visualize time-series cross-validation strategies (expanding window and
5
- rolling/sliding window) with animated fold progression and per-fold
6
- accuracy metrics using a naive forecast.
7
-
8
- Part of ISA 444: Business Forecasting — Spring 2026, Miami University.
9
- Deployed to HuggingFace Spaces as fmegahed/cv-visualizer.
10
- """
11
-
12
- import io
13
- import time
14
- import threading
15
-
16
- import gradio as gr
17
- import numpy as np
18
- import pandas as pd
19
- import matplotlib
20
- matplotlib.use("Agg")
21
- import matplotlib.pyplot as plt
22
- import matplotlib.patches as mpatches
23
- from matplotlib.lines import Line2D
24
-
25
- # ---------------------------------------------------------------------------
26
- # Color palette
27
- # ---------------------------------------------------------------------------
28
TEAL = "#84d6d3"       # training segments and training bars
RED = "#C3142D"        # test / validation segments and test bars
GRAY = "#CCCCCC"       # full series drawn behind the highlighted fold
DARK_GRAY = "#888888"  # vertical train/test boundary lines
WHITE = "#FFFFFF"      # figure background
33
-
34
- # ---------------------------------------------------------------------------
35
- # Dataset generators
36
- # ---------------------------------------------------------------------------
37
-
38
- def _airline_passengers() -> pd.DataFrame:
39
- """Classic Box-Jenkins airline passengers (1949-1960, 144 obs)."""
40
- # Reproduce the well-known series with a multiplicative seasonal pattern.
41
- np.random.seed(42)
42
- n = 144
43
- t = np.arange(n)
44
- trend = 132 + 2.4 * t
45
- seasonal_period = 12
46
- seasonal = 40 * np.sin(2 * np.pi * t / seasonal_period)
47
- # Multiplicative-style growth in amplitude
48
- amplitude_growth = 1 + 0.006 * t
49
- y = trend * amplitude_growth + seasonal * amplitude_growth
50
- # Add a touch of noise
51
- y += np.random.normal(0, 5, n)
52
- dates = pd.date_range("1949-01-01", periods=n, freq="MS")
53
- return pd.DataFrame({"ds": dates, "y": np.round(y, 1)})
54
-
55
-
56
- def _ohio_employment() -> pd.DataFrame:
57
- """Synthetic Ohio monthly employment (2010-2024, 180 obs)."""
58
- np.random.seed(123)
59
- n = 180
60
- t = np.arange(n)
61
- trend = 5200 + 3.5 * t
62
- seasonal = 120 * np.sin(2 * np.pi * t / 12) + 60 * np.cos(2 * np.pi * t / 6)
63
- # Covid dip around index 120-130 (~ early 2020)
64
- dip = np.zeros(n)
65
- dip[120:132] = -np.array([200, 800, 1100, 900, 600, 400, 300, 200, 150, 100, 60, 30])
66
- noise = np.random.normal(0, 40, n)
67
- y = trend + seasonal + dip + noise
68
- dates = pd.date_range("2010-01-01", periods=n, freq="MS")
69
- return pd.DataFrame({"ds": dates, "y": np.round(y, 1)})
70
-
71
-
72
- def _simple_trend() -> pd.DataFrame:
73
- """Simple linear trend + noise (120 obs) for pedagogical clarity."""
74
- np.random.seed(7)
75
- n = 120
76
- t = np.arange(n)
77
- y = 0.5 * t + np.random.normal(0, 2, n)
78
- dates = pd.date_range("2015-01-01", periods=n, freq="MS")
79
- return pd.DataFrame({"ds": dates, "y": np.round(y, 2)})
80
-
81
-
82
# Maps each built-in dataset's display name to the zero-argument function
# that generates it; the frame is built on demand when the entry is called.
DATASETS = {
    "Airline Passengers": _airline_passengers,
    "Ohio Employment": _ohio_employment,
    "Simple Trend + Noise": _simple_trend,
}
87
-
88
- # ---------------------------------------------------------------------------
89
- # Fold computation
90
- # ---------------------------------------------------------------------------
91
-
92
def compute_folds(n, initial, horizon, step, strategy, window_size=None):
    """Return the train/test index ranges of every fold that fits in ``n`` rows.

    Args:
        n: Number of observations in the series.
        initial: Training size of the first fold (expanding window), and the
            fallback window length for the rolling strategy.
        horizon: Number of observations in each test window.
        step: Number of observations the split point advances between folds.
        strategy: ``"Expanding Window"`` keeps the training start at 0; any
            other value is treated as a rolling/sliding window.
        window_size: Fixed training length for the rolling strategy; when
            ``None`` the value of ``initial`` is used.

    Returns:
        List of dicts with keys ``fold`` (1-based), ``train_start``,
        ``train_end``, ``test_start`` and ``test_end``. End indices are
        exclusive. The list is empty when no fold fits or when the training
        length, horizon or step is not positive.
    """
    # UI sliders may deliver floats; slice indices must be integers.
    n, initial, horizon, step = int(n), int(initial), int(horizon), int(step)
    expanding = strategy == "Expanding Window"
    if expanding or window_size is None:
        train_len = initial
    else:
        train_len = int(window_size)

    # Non-positive sizes would give empty windows or folds that never advance.
    if train_len < 1 or horizon < 1 or step < 1:
        return []

    folds = []
    # train_end walks forward by ``step``; the last admissible value is the
    # one whose test window still ends at or before ``n``.
    for k, train_end in enumerate(range(train_len, n - horizon + 1, step)):
        folds.append({
            "fold": k + 1,
            "train_start": 0 if expanding else train_end - train_len,
            "train_end": train_end,
            "test_start": train_end,
            "test_end": train_end + horizon,
        })
    return folds
128
-
129
- # ---------------------------------------------------------------------------
130
- # Naive forecast & metrics
131
- # ---------------------------------------------------------------------------
132
-
133
def naive_metrics(y_series, folds):
    """Score a naive (last-training-value) forecast on every fold.

    Args:
        y_series: The observed series (pandas Series, array or list); it is
            converted to a float array.
        folds: Fold dicts as produced by ``compute_folds`` (exclusive ends).

    Returns:
        DataFrame with one row per fold and the columns ``Fold``,
        ``Train Start``, ``Train End``, ``Test Start``, ``Test End``
        (inclusive ends), ``Train Size``, ``MAE``, ``RMSE`` and ``MAPE (%)``.
        Metrics are rounded to two decimals; MAPE is ``"N/A"`` when every
        test value is (near) zero. The columns are present even when
        ``folds`` is empty.

    Raises:
        ValueError: If a fold has an empty training window, because there is
            then no last value to carry forward.
    """
    columns = [
        "Fold", "Train Start", "Train End", "Test Start", "Test End",
        "Train Size", "MAE", "RMSE", "MAPE (%)",
    ]
    y = np.asarray(y_series, dtype=float)
    records = []
    for f in folds:
        train_vals = y[f["train_start"]:f["train_end"]]
        test_vals = y[f["test_start"]:f["test_end"]]
        if train_vals.size == 0:
            raise ValueError(f"Fold {f['fold']} has an empty training window.")

        # Naive forecast: repeat the last training value over the horizon.
        errors = test_vals - train_vals[-1]
        mae = np.mean(np.abs(errors))
        rmse = np.sqrt(np.mean(errors ** 2))

        # MAPE is undefined for zero actuals; average over the rest only.
        nonzero = np.abs(test_vals) > 1e-8
        if nonzero.any():
            mape = np.mean(np.abs(errors[nonzero] / test_vals[nonzero])) * 100
        else:
            mape = np.nan

        records.append({
            "Fold": f["fold"],
            "Train Start": f["train_start"],
            "Train End": f["train_end"] - 1,
            "Test Start": f["test_start"],
            "Test End": f["test_end"] - 1,
            "Train Size": f["train_end"] - f["train_start"],
            "MAE": round(mae, 2),
            "RMSE": round(rmse, 2),
            "MAPE (%)": round(mape, 2) if not np.isnan(mape) else "N/A",
        })
    # Explicit columns keep the schema stable when there are no folds.
    return pd.DataFrame(records, columns=columns)
163
-
164
- # ---------------------------------------------------------------------------
165
- # Plotting
166
- # ---------------------------------------------------------------------------
167
-
168
def _make_figure(df, folds, current_fold, show_all, strategy_label):
    """Render the cross-validation figure for one fold or for all folds.

    With ``show_all`` set, a single Gantt-style panel of every fold is drawn.
    Otherwise the figure has two stacked panels: the series with the selected
    fold's training and test segments highlighted, above a fold map in which
    that fold is emphasised. ``current_fold`` is 1-based and is clamped to the
    available folds. Returns the matplotlib figure.
    """
    series = df["y"].values
    n_obs = len(series)
    index = np.arange(n_obs)
    n_folds = len(folds)

    if show_all:
        fig, gantt_ax = plt.subplots(figsize=(12, 5), facecolor=WHITE)
        _draw_gantt(gantt_ax, folds, current_fold=None, n=n_obs, highlight=False)
        gantt_ax.set_title(
            f"All {n_folds} Folds — {strategy_label}",
            fontsize=14, fontweight="bold", pad=10,
        )
        fig.tight_layout(pad=2.0)
        return fig

    # Two stacked panels: series on top, fold map below.
    fig, (series_ax, gantt_ax) = plt.subplots(
        2, 1, figsize=(12, 7.5),
        gridspec_kw={"height_ratios": [2, 1.2]},
        facecolor=WHITE,
    )

    # Clamp the 1-based fold number into the valid 0-based index range.
    position = min(current_fold - 1, n_folds - 1)
    if position < 0:
        position = 0
    fold = folds[position]
    train = slice(fold["train_start"], fold["train_end"])
    test = slice(fold["test_start"], fold["test_end"])

    # Top panel: whole series in the background, fold segments on top.
    series_ax.plot(index, series, color=GRAY, linewidth=1.2, zorder=1, label="Full series")
    series_ax.plot(index[train], series[train], color=TEAL, linewidth=2.4, zorder=3, label="Training")
    series_ax.plot(index[test], series[test], color=RED, linewidth=2.4, zorder=3, label="Test / Validation")

    # Dashed line at the train/test split; dotted line where a rolling
    # window's training data begins (only when it does not start at 0).
    series_ax.axvline(
        fold["train_end"] - 0.5,
        color=DARK_GRAY, linestyle="--", linewidth=1, zorder=2, alpha=0.7,
    )
    if fold["train_start"] > 0:
        series_ax.axvline(
            fold["train_start"] - 0.5,
            color=DARK_GRAY, linestyle=":", linewidth=1, zorder=2, alpha=0.5,
        )

    series_ax.set_title(
        f"Fold {fold['fold']} of {n_folds} — {strategy_label}",
        fontsize=14, fontweight="bold", pad=10,
    )
    series_ax.set_xlabel("Time Index", fontsize=11)
    series_ax.set_ylabel("y", fontsize=11)
    series_ax.legend(loc="upper left", fontsize=9, framealpha=0.9)
    series_ax.set_xlim(-1, n_obs + 1)

    # Bottom panel: fold map with the selected fold emphasised.
    _draw_gantt(gantt_ax, folds, current_fold=fold["fold"], n=n_obs, highlight=True)

    fig.tight_layout(pad=2.0)
    return fig
227
-
228
-
229
def _draw_gantt(ax, folds, current_fold, n, highlight):
    """Draw one horizontal training bar and one test bar per fold on ``ax``.

    When ``highlight`` is true, the fold whose number equals ``current_fold``
    is drawn taller, with a thicker black edge and above the other bars.
    Fold 1 is placed at the top of the axis.
    """
    n_folds = len(folds)

    for fold in folds:
        number = fold["fold"]
        if highlight and (number == current_fold):
            style = {"height": 0.85, "edgecolor": "black", "linewidth": 1.8, "zorder": 3}
        else:
            style = {"height": 0.6, "edgecolor": "#666666", "linewidth": 0.5, "zorder": 2}

        # Training bar
        ax.barh(
            number, fold["train_end"] - fold["train_start"],
            left=fold["train_start"], color=TEAL, **style,
        )
        # Test bar
        ax.barh(
            number, fold["test_end"] - fold["test_start"],
            left=fold["test_start"], color=RED, **style,
        )

    ax.set_xlabel("Time Index", fontsize=11)
    ax.set_ylabel("Fold", fontsize=11)
    ax.set_xlim(-1, n + 1)
    ax.set_ylim(0.2, n_folds + 0.8)
    ax.set_yticks(range(1, n_folds + 1))
    ax.invert_yaxis()

    # Legend built from proxy patches so it shows one entry per colour.
    legend_entries = [
        mpatches.Patch(facecolor=colour, edgecolor="#333", label=text)
        for colour, text in ((TEAL, "Training"), (RED, "Test"))
    ]
    ax.legend(handles=legend_entries, loc="upper right", fontsize=9, framealpha=0.9)
266
-
267
- # ---------------------------------------------------------------------------
268
- # Summary text
269
- # ---------------------------------------------------------------------------
270
-
271
def build_summary(folds, strategy, initial, step, metrics_df):
    """Build the Markdown summary shown beneath the metrics table.

    Args:
        folds: Fold dicts as produced by ``compute_folds``.
        strategy: ``"Expanding Window"`` selects the expanding description;
            any other value selects the rolling description.
        initial: Kept for interface compatibility; training sizes are read
            from ``folds`` so the text always matches the folds drawn.
        step: Kept for interface compatibility (see ``initial``).
        metrics_df: Per-fold metrics with ``MAE``, ``RMSE`` and ``MAPE (%)``
            columns; non-numeric MAPE entries are ignored in the average.

    Returns:
        Markdown text, or a "no valid folds" notice when ``folds`` is empty.
    """
    K = len(folds)
    if K == 0:
        return "**No valid folds.** Adjust the parameters so that at least one fold fits within the data."

    avg_mae = metrics_df["MAE"].mean()
    avg_rmse = metrics_df["RMSE"].mean()
    # "N/A" entries become NaN and are skipped by mean().
    avg_mape = pd.to_numeric(metrics_df["MAPE (%)"], errors="coerce").mean()
    if pd.isna(avg_mape):
        mape_line = "- **Average MAPE:** N/A"
    else:
        mape_line = f"- **Average MAPE:** {avg_mape:.2f}%"

    lines = [
        "### Summary",
        f"- **Total folds:** {K}",
        f"- **Average MAE:** {avg_mae:.2f}",
        f"- **Average RMSE:** {avg_rmse:.2f}",
        mape_line,
        "",
    ]

    first_train = folds[0]["train_end"] - folds[0]["train_start"]
    if strategy == "Expanding Window":
        last_train = folds[-1]["train_end"] - folds[-1]["train_start"]
        lines.append(
            f"*Expanding window*: training set grows from **{first_train}** to "
            f"**{last_train}** observations across {K} folds."
        )
    else:
        lines.append(
            f"*Rolling / sliding window*: fixed training size of **{first_train}** "
            f"observations slides forward across {K} folds."
        )
    lines.append("")
    lines.append(
        "Forecasts use a **naive model** (last training value repeated over "
        "the horizon) to keep focus on the CV visualization concept."
    )
    return "\n".join(lines)
308
-
309
- # ---------------------------------------------------------------------------
310
- # Main update callback
311
- # ---------------------------------------------------------------------------
312
-
313
def load_dataset(name, file_obj):
    """Return the DataFrame for the selected dataset or the uploaded CSV.

    Args:
        name: Dataset selector value. ``"Upload CSV"`` together with a file
            reads the upload; a key of ``DATASETS`` builds that dataset.
        file_obj: Uploaded file (an object with a ``name`` path attribute, or
            a path / buffer accepted by ``pandas.read_csv``), or ``None``.

    Returns:
        DataFrame with columns ``ds`` and ``y``. Any other selection,
        including ``"Upload CSV"`` before a file has been provided, falls
        back to the "Simple Trend + Noise" dataset.

    Raises:
        gr.Error: If the upload lacks a ``ds`` or ``y`` column, or if ``y``
            contains no numeric values.
    """
    if name == "Upload CSV" and file_obj is not None:
        raw = pd.read_csv(file_obj.name if hasattr(file_obj, "name") else file_obj)
        if "ds" not in raw.columns or "y" not in raw.columns:
            raise gr.Error("Uploaded CSV must contain columns named 'ds' and 'y'.")
        data = raw[["ds", "y"]].copy()
        # Metrics and plotting need numbers; text cells become NaN rather
        # than failing later with an opaque arithmetic error.
        data["y"] = pd.to_numeric(data["y"], errors="coerce")
        if not data["y"].notna().any():
            raise gr.Error("Uploaded CSV has no numeric values in column 'y'.")
        return data
    if name in DATASETS:
        return DATASETS[name]()
    # Fallback
    return DATASETS["Simple Trend + Noise"]()
324
-
325
-
326
def update_total_folds(dataset_name, file_obj, strategy, initial, horizon, step_size, window_size):
    """Build a slider update whose maximum is the current number of folds.

    The maximum is never below 1 and the slider value is reset to 1. If the
    dataset cannot be loaded, the slider collapses to a single position.
    """
    try:
        data = load_dataset(dataset_name, file_obj)
    except Exception:
        return gr.update(maximum=1, value=1)
    fold_list = compute_folds(
        len(data), initial, horizon, step_size, strategy, window_size,
    )
    # At least one slider position even when no fold fits.
    upper = len(fold_list) if fold_list else 1
    return gr.update(maximum=upper, value=1)
336
-
337
-
338
def run_visualizer(dataset_name, file_obj, strategy, initial, horizon, step_size, window_size, current_fold, show_all):
    """Core callback: return ``(figure, metrics_df, summary_md)``.

    Loads the selected data, computes the folds, draws the figure and scores
    a naive forecast on every fold. The metrics table gets a final "Avg" row.
    When no fold fits, a placeholder figure, an empty table and a hint are
    returned instead.

    Raises:
        gr.Error: If the dataset cannot be loaded.
    """
    try:
        df = load_dataset(dataset_name, file_obj)
    except gr.Error:
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays in the server log.
        raise gr.Error(f"Could not load data: {exc}") from exc

    n = len(df)
    folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)
    K = len(folds)

    if K == 0:
        fig, ax = plt.subplots(figsize=(12, 4), facecolor=WHITE)
        ax.text(0.5, 0.5, "No valid folds — adjust parameters.",
                ha="center", va="center", fontsize=14, transform=ax.transAxes)
        ax.axis("off")
        # Deregister from pyplot so repeated calls do not accumulate figures;
        # the Figure object itself remains valid for the caller.
        plt.close(fig)
        empty_df = pd.DataFrame(columns=[
            "Fold", "Train Start", "Train End", "Test Start", "Test End",
            "Train Size", "MAE", "RMSE", "MAPE (%)"
        ])
        summary = "**No valid folds.** Reduce `initial` + `horizon` or increase data length."
        return fig, empty_df, summary

    strategy_label = strategy
    fig = _make_figure(df, folds, current_fold, show_all, strategy_label)
    # Close only this call's figure (not every open figure), and do it
    # before the metrics step so a failure there cannot leak it.
    plt.close(fig)
    metrics_df = naive_metrics(df["y"], folds)

    # Append average row
    avg_row = {
        "Fold": "Avg",
        "Train Start": "",
        "Train End": "",
        "Test Start": "",
        "Test End": "",
        "Train Size": "",
        "MAE": round(metrics_df["MAE"].mean(), 2),
        "RMSE": round(metrics_df["RMSE"].mean(), 2),
    }
    mape_vals = pd.to_numeric(metrics_df["MAPE (%)"], errors="coerce")
    avg_row["MAPE (%)"] = round(mape_vals.mean(), 2) if not mape_vals.isna().all() else "N/A"
    avg_df = pd.concat([metrics_df, pd.DataFrame([avg_row])], ignore_index=True)

    summary = build_summary(folds, strategy, initial, step_size, metrics_df)
    return fig, avg_df, summary
385
-
386
- # ---------------------------------------------------------------------------
387
- # Gradio UI
388
- # ---------------------------------------------------------------------------
389
-
390
def build_app():
    """Assemble the Gradio Blocks interface and return it without launching.

    The page has a branded header, a control column (dataset, strategy,
    parameter sliders, fold slider, animation buttons) and an output column
    (plot, metrics table, summary). Changing any control re-runs the
    visualizer; the Play button steps through the folds one per second.
    """
    # Theme colours mirror the plot palette: teal primary, red secondary.
    theme = gr.themes.Soft(
        primary_hue=gr.themes.Color(
            c50="#eafaf9", c100="#d4f5f3", c200="#aaecea",
            c300="#84d6d3", c400="#5ec4c0", c500="#3eaea9",
            c600="#2e938e", c700="#237873", c800="#1a5d59",
            c900="#12423f", c950="#0a2725",
        ),
        secondary_hue=gr.themes.Color(
            c50="#fef2f3", c100="#fde6e8", c200="#fbd0d5",
            c300="#f7a4ae", c400="#f17182", c500="#C3142D",
            c600="#b01228", c700="#8B0E1E", c800="#6e0b18",
            c900="#5c0d17", c950="#33040a",
        ),
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    )

    with gr.Blocks(
        title="Cross-Validation Visualizer v1.0",
        theme=theme,
        css="""
        .gradio-container { max-width: 1280px !important; margin: auto; }
        footer { display: none !important; }
        .gr-button-primary { background: #C3142D !important; border: none !important; }
        .gr-button-primary:hover { background: #8B0E1E !important; }
        .gr-button-secondary { border-color: #84d6d3 !important; color: #84d6d3 !important; }
        .gr-button-secondary:hover { background: #84d6d3 !important; color: white !important; }
        .gr-input:focus { border-color: #84d6d3 !important; box-shadow: 0 0 0 2px rgba(132,214,211,0.2) !important; }
        """,
    ) as demo:
        # Header banner with logo and course line.
        gr.HTML("""
        <div style="display: flex; align-items: center; gap: 16px; padding: 16px 24px;
        background: linear-gradient(135deg, #C3142D 0%, #8B0E1E 100%);
        border-radius: 12px; margin-bottom: 16px; box-shadow: 0 4px 12px rgba(0,0,0,0.15);">
        <img src="https://miamioh.edu/miami-brand/_files/images/system/logo-usage/minimum-size/beveled-m-min-size.png"
        alt="Miami University" style="height: 56px; filter: brightness(0) invert(1);">
        <div>
        <h1 style="margin: 0; color: white; font-size: 24px; font-weight: 700; letter-spacing: -0.5px;">
        Cross-Validation Visualizer v1.0
        </h1>
        <p style="margin: 4px 0 0; color: rgba(255,255,255,0.85); font-size: 14px;">
        ISA 444: Business Forecasting &middot; Farmer School of Business &middot; Miami University
        </p>
        </div>
        </div>
        """)

        # Short description of what the app shows.
        gr.HTML("""
        <div style="background: #f8f9fa; border-left: 4px solid #84d6d3; padding: 12px 16px;
        border-radius: 0 8px 8px 0; margin-bottom: 16px; font-size: 14px; color: #585E60;">
        Visualize time-series cross-validation strategies (expanding window and rolling/sliding window)
        with animated fold progression and per-fold accuracy metrics using a naive forecast.
        Understand how forecast accuracy is evaluated across folds.
        </div>
        """)

        with gr.Row():
            # ---- Left column: controls ----
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("### Data")
                dataset_dd = gr.Dropdown(
                    choices=["Airline Passengers", "Ohio Employment",
                             "Simple Trend + Noise", "Upload CSV"],
                    value="Simple Trend + Noise",
                    label="Dataset",
                )
                # Hidden until "Upload CSV" is selected (see toggle below).
                csv_upload = gr.File(
                    label="Upload CSV (columns: ds, y)",
                    file_types=[".csv"],
                    visible=False,
                )

                gr.Markdown("### CV Strategy")
                strategy_radio = gr.Radio(
                    choices=["Expanding Window", "Rolling/Sliding Window"],
                    value="Expanding Window",
                    label="Strategy",
                )

                gr.Markdown("### Parameters")
                initial_slider = gr.Slider(
                    minimum=12, maximum=120, value=60, step=1,
                    label="initial (initial training size)",
                )
                horizon_slider = gr.Slider(
                    minimum=1, maximum=24, value=12, step=1,
                    label="horizon (forecast horizon)",
                )
                step_slider = gr.Slider(
                    minimum=1, maximum=12, value=1, step=1,
                    label="step (step size between folds)",
                )
                # Hidden until the rolling strategy is selected (see toggle below).
                window_slider = gr.Slider(
                    minimum=12, maximum=120, value=60, step=1,
                    label="window_size (rolling window only)",
                    visible=False,
                )

                gr.Markdown("### Animation Controls")
                # The maximum starts at 1 and is raised once folds are computed.
                fold_slider = gr.Slider(
                    minimum=1, maximum=1, value=1, step=1,
                    label="Current Fold",
                )
                with gr.Row():
                    play_btn = gr.Button("Play Animation", variant="primary")
                    stop_btn = gr.Button("Stop", variant="stop")
                show_all_cb = gr.Checkbox(label="Show All Folds", value=False)

            # ---- Right column: outputs ----
            with gr.Column(scale=2, min_width=500):
                plot_output = gr.Plot(label="Visualization")
                metrics_output = gr.Dataframe(
                    label="Per-Fold Metrics (Naive Forecast)",
                    interactive=False,
                    wrap=True,
                )
                summary_output = gr.Markdown(label="Summary")

        # ---- Visibility toggles ----
        def toggle_csv_upload(name):
            """Show the upload widget only for the "Upload CSV" choice."""
            return gr.update(visible=(name == "Upload CSV"))

        dataset_dd.change(toggle_csv_upload, inputs=[dataset_dd], outputs=[csv_upload])

        def toggle_window_slider(strategy):
            """Show the window-size slider only for the rolling strategy."""
            return gr.update(visible=(strategy == "Rolling/Sliding Window"))

        strategy_radio.change(toggle_window_slider, inputs=[strategy_radio], outputs=[window_slider])

        # ---- Gather all control inputs ----
        # Order matches the positional parameters of run_visualizer.
        all_inputs = [
            dataset_dd, csv_upload, strategy_radio,
            initial_slider, horizon_slider, step_slider,
            window_slider, fold_slider, show_all_cb,
        ]
        all_outputs = [plot_output, metrics_output, summary_output]

        # Helper to also refresh the fold slider range
        # NOTE(review): fold_range_inputs is not referenced again in the
        # visible code — confirm whether it is still needed.
        fold_range_inputs = [
            dataset_dd, csv_upload, strategy_radio,
            initial_slider, horizon_slider, step_slider, window_slider,
        ]

        def refresh_and_run(dataset_name, file_obj, strategy, initial, horizon,
                            step_size, window_size, current_fold, show_all):
            """Update fold slider range, clamp current_fold, then run."""
            try:
                df = load_dataset(dataset_name, file_obj)
            except Exception:
                # Only the fold count is taken from this fallback frame;
                # run_visualizer below loads the data again itself.
                df = DATASETS["Simple Trend + Noise"]()
            n = len(df)
            folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)
            K = max(len(folds), 1)
            current_fold = max(1, min(current_fold, K))
            fig, metrics, summary = run_visualizer(
                dataset_name, file_obj, strategy, initial, horizon,
                step_size, window_size, current_fold, show_all,
            )
            return gr.update(maximum=K, value=current_fold), fig, metrics, summary

        combined_outputs = [fold_slider] + all_outputs

        # Trigger on any parameter change
        for ctrl in [dataset_dd, csv_upload, strategy_radio, initial_slider,
                     horizon_slider, step_slider, window_slider, show_all_cb]:
            ctrl.change(
                refresh_and_run,
                inputs=all_inputs,
                outputs=combined_outputs,
            )

        # Fold slider change (just re-render, no range update needed)
        fold_slider.release(
            run_visualizer,
            inputs=all_inputs,
            outputs=all_outputs,
        )

        # ---- Animation via a background thread ----
        # We use a gr.State to hold the "playing" flag
        # NOTE(review): the handler below is a generator and no thread is
        # started in the visible code — the heading above may be outdated.
        animation_state = gr.State({"playing": False})

        def start_animation(state, dataset_name, file_obj, strategy, initial,
                            horizon, step_size, window_size, current_fold, show_all):
            """Yield one rendered fold per second until the last fold or a stop."""
            state["playing"] = True
            try:
                df = load_dataset(dataset_name, file_obj)
            except Exception:
                df = DATASETS["Simple Trend + Noise"]()
            n = len(df)
            folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)
            K = max(len(folds), 1)

            for k in range(1, K + 1):
                # NOTE(review): presumably relies on the Stop handler mutating
                # the same dict object this generator holds — confirm.
                if not state.get("playing", False):
                    break
                # The animation always renders the single-fold view
                # (show_all is passed as False).
                fig, metrics, summary = run_visualizer(
                    dataset_name, file_obj, strategy, initial, horizon,
                    step_size, window_size, k, False,
                )
                yield state, gr.update(maximum=K, value=k), fig, metrics, summary
                time.sleep(1.0)
            state["playing"] = False
            # Final yield stores the cleared flag and leaves the slider as is.
            yield state, gr.update(), fig, metrics, summary

        def stop_animation(state):
            """Clear the playing flag so the animation loop exits."""
            state["playing"] = False
            return state

        play_btn.click(
            start_animation,
            inputs=[animation_state] + all_inputs,
            outputs=[animation_state, fold_slider] + all_outputs,
        )
        stop_btn.click(stop_animation, inputs=[animation_state], outputs=[animation_state])

        # ---- Initial render on load ----
        demo.load(
            refresh_and_run,
            inputs=all_inputs,
            outputs=combined_outputs,
        )

        # Footer with author credit and links.
        gr.HTML("""
        <div style="margin-top: 24px; padding: 16px; background: #f8f9fa; border-radius: 8px;
        text-align: center; font-size: 13px; color: #585E60; border-top: 2px solid #84d6d3;">
        <div style="margin-bottom: 4px;">
        <strong style="color: #C3142D;">Developed by</strong>
        <a href="https://miamioh.edu/fsb/directory/?up=/directory/megahefm"
        style="color: #84d6d3; text-decoration: none; font-weight: 600;">
        Fadel M. Megahed
        </a>
        &middot; Gloss Professor of Analytics &middot; Miami University
        </div>
        <div style="font-size: 12px; color: #888;">
        Version 1.0.0 &middot; Spring 2026 &middot;
        <a href="https://github.com/fmegahed" style="color: #84d6d3; text-decoration: none;">GitHub</a> &middot;
        <a href="https://www.linkedin.com/in/fmegahed/" style="color: #84d6d3; text-decoration: none;">LinkedIn</a>
        </div>
        </div>
        """)

    return demo
633
-
634
-
635
# Build and launch the app only when this file is executed as a script.
if __name__ == "__main__":
    app = build_app()
    app.launch()
 
1
+ """
2
+ Cross-Validation Visualizer
3
+ ============================
4
+ Visualize time-series cross-validation strategies (expanding window and
5
+ rolling/sliding window) with animated fold progression and per-fold
6
+ accuracy metrics using a naive forecast.
7
+
8
+ Part of ISA 444: Business Forecasting — Spring 2026, Miami University.
9
+ Deployed to HuggingFace Spaces as fmegahed/cv-visualizer.
10
+ """
11
+
12
+ import io
13
+ import time
14
+ import threading
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+ import pandas as pd
19
+ import matplotlib
20
+ matplotlib.use("Agg")
21
+ import matplotlib.pyplot as plt
22
+ import matplotlib.patches as mpatches
23
+ from matplotlib.lines import Line2D
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Color palette
27
+ # ---------------------------------------------------------------------------
28
TEAL = "#84d6d3"       # training segments and training bars
RED = "#C3142D"        # test / validation segments and test bars
GRAY = "#CCCCCC"       # full series drawn behind the highlighted fold
DARK_GRAY = "#888888"  # vertical train/test boundary lines
WHITE = "#FFFFFF"      # figure background
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Dataset generators
36
+ # ---------------------------------------------------------------------------
37
+
38
+ def _airline_passengers() -> pd.DataFrame:
39
+ """Classic Box-Jenkins airline passengers (1949-1960, 144 obs)."""
40
+ # Reproduce the well-known series with a multiplicative seasonal pattern.
41
+ np.random.seed(42)
42
+ n = 144
43
+ t = np.arange(n)
44
+ trend = 132 + 2.4 * t
45
+ seasonal_period = 12
46
+ seasonal = 40 * np.sin(2 * np.pi * t / seasonal_period)
47
+ # Multiplicative-style growth in amplitude
48
+ amplitude_growth = 1 + 0.006 * t
49
+ y = trend * amplitude_growth + seasonal * amplitude_growth
50
+ # Add a touch of noise
51
+ y += np.random.normal(0, 5, n)
52
+ dates = pd.date_range("1949-01-01", periods=n, freq="MS")
53
+ return pd.DataFrame({"ds": dates, "y": np.round(y, 1)})
54
+
55
+
56
+ def _ohio_employment() -> pd.DataFrame:
57
+ """Synthetic Ohio monthly employment (2010-2024, 180 obs)."""
58
+ np.random.seed(123)
59
+ n = 180
60
+ t = np.arange(n)
61
+ trend = 5200 + 3.5 * t
62
+ seasonal = 120 * np.sin(2 * np.pi * t / 12) + 60 * np.cos(2 * np.pi * t / 6)
63
+ # Covid dip around index 120-130 (~ early 2020)
64
+ dip = np.zeros(n)
65
+ dip[120:132] = -np.array([200, 800, 1100, 900, 600, 400, 300, 200, 150, 100, 60, 30])
66
+ noise = np.random.normal(0, 40, n)
67
+ y = trend + seasonal + dip + noise
68
+ dates = pd.date_range("2010-01-01", periods=n, freq="MS")
69
+ return pd.DataFrame({"ds": dates, "y": np.round(y, 1)})
70
+
71
+
72
+ def _simple_trend() -> pd.DataFrame:
73
+ """Simple linear trend + noise (120 obs) for pedagogical clarity."""
74
+ np.random.seed(7)
75
+ n = 120
76
+ t = np.arange(n)
77
+ y = 0.5 * t + np.random.normal(0, 2, n)
78
+ dates = pd.date_range("2015-01-01", periods=n, freq="MS")
79
+ return pd.DataFrame({"ds": dates, "y": np.round(y, 2)})
80
+
81
+
82
# Maps each built-in dataset's display name to the zero-argument function
# that generates it; the frame is built on demand when the entry is called.
DATASETS = {
    "Airline Passengers": _airline_passengers,
    "Ohio Employment": _ohio_employment,
    "Simple Trend + Noise": _simple_trend,
}
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Fold computation
90
+ # ---------------------------------------------------------------------------
91
+
92
def compute_folds(n, initial, horizon, step, strategy, window_size=None):
    """Return the train/test index ranges of every fold that fits in ``n`` rows.

    Args:
        n: Number of observations in the series.
        initial: Training size of the first fold (expanding window), and the
            fallback window length for the rolling strategy.
        horizon: Number of observations in each test window.
        step: Number of observations the split point advances between folds.
        strategy: ``"Expanding Window"`` keeps the training start at 0; any
            other value is treated as a rolling/sliding window.
        window_size: Fixed training length for the rolling strategy; when
            ``None`` the value of ``initial`` is used.

    Returns:
        List of dicts with keys ``fold`` (1-based), ``train_start``,
        ``train_end``, ``test_start`` and ``test_end``. End indices are
        exclusive. The list is empty when no fold fits or when the training
        length, horizon or step is not positive.
    """
    # UI sliders may deliver floats; slice indices must be integers.
    n, initial, horizon, step = int(n), int(initial), int(horizon), int(step)
    expanding = strategy == "Expanding Window"
    if expanding or window_size is None:
        train_len = initial
    else:
        train_len = int(window_size)

    # Non-positive sizes would give empty windows or folds that never advance.
    if train_len < 1 or horizon < 1 or step < 1:
        return []

    folds = []
    # train_end walks forward by ``step``; the last admissible value is the
    # one whose test window still ends at or before ``n``.
    for k, train_end in enumerate(range(train_len, n - horizon + 1, step)):
        folds.append({
            "fold": k + 1,
            "train_start": 0 if expanding else train_end - train_len,
            "train_end": train_end,
            "test_start": train_end,
            "test_end": train_end + horizon,
        })
    return folds
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # Naive forecast & metrics
131
+ # ---------------------------------------------------------------------------
132
+
133
def naive_metrics(y_series, folds):
    """Score a naive (last-training-value) forecast on every fold.

    Args:
        y_series: The observed series (pandas Series, array or list); it is
            converted to a float array.
        folds: Fold dicts as produced by ``compute_folds`` (exclusive ends).

    Returns:
        DataFrame with one row per fold and the columns ``Fold``,
        ``Train Start``, ``Train End``, ``Test Start``, ``Test End``
        (inclusive ends), ``Train Size``, ``MAE``, ``RMSE`` and ``MAPE (%)``.
        Metrics are rounded to two decimals; MAPE is ``"N/A"`` when every
        test value is (near) zero. The columns are present even when
        ``folds`` is empty.

    Raises:
        ValueError: If a fold has an empty training window, because there is
            then no last value to carry forward.
    """
    columns = [
        "Fold", "Train Start", "Train End", "Test Start", "Test End",
        "Train Size", "MAE", "RMSE", "MAPE (%)",
    ]
    y = np.asarray(y_series, dtype=float)
    records = []
    for f in folds:
        train_vals = y[f["train_start"]:f["train_end"]]
        test_vals = y[f["test_start"]:f["test_end"]]
        if train_vals.size == 0:
            raise ValueError(f"Fold {f['fold']} has an empty training window.")

        # Naive forecast: repeat the last training value over the horizon.
        errors = test_vals - train_vals[-1]
        mae = np.mean(np.abs(errors))
        rmse = np.sqrt(np.mean(errors ** 2))

        # MAPE is undefined for zero actuals; average over the rest only.
        nonzero = np.abs(test_vals) > 1e-8
        if nonzero.any():
            mape = np.mean(np.abs(errors[nonzero] / test_vals[nonzero])) * 100
        else:
            mape = np.nan

        records.append({
            "Fold": f["fold"],
            "Train Start": f["train_start"],
            "Train End": f["train_end"] - 1,
            "Test Start": f["test_start"],
            "Test End": f["test_end"] - 1,
            "Train Size": f["train_end"] - f["train_start"],
            "MAE": round(mae, 2),
            "RMSE": round(rmse, 2),
            "MAPE (%)": round(mape, 2) if not np.isnan(mape) else "N/A",
        })
    # Explicit columns keep the schema stable when there are no folds.
    return pd.DataFrame(records, columns=columns)
163
+
164
+ # ---------------------------------------------------------------------------
165
+ # Plotting
166
+ # ---------------------------------------------------------------------------
167
+
168
def _make_figure(df, folds, current_fold, show_all, strategy_label):
    """Render the cross-validation figure for one fold or for all folds.

    With ``show_all`` set, a single Gantt-style panel of every fold is drawn.
    Otherwise the figure has two stacked panels: the series with the selected
    fold's training and test segments highlighted, above a fold map in which
    that fold is emphasised. ``current_fold`` is 1-based and is clamped to the
    available folds. Returns the matplotlib figure.
    """
    series = df["y"].values
    n_obs = len(series)
    index = np.arange(n_obs)
    n_folds = len(folds)

    if show_all:
        fig, gantt_ax = plt.subplots(figsize=(12, 5), facecolor=WHITE)
        _draw_gantt(gantt_ax, folds, current_fold=None, n=n_obs, highlight=False)
        gantt_ax.set_title(
            f"All {n_folds} Folds — {strategy_label}",
            fontsize=14, fontweight="bold", pad=10,
        )
        fig.tight_layout(pad=2.0)
        return fig

    # Two stacked panels: series on top, fold map below.
    fig, (series_ax, gantt_ax) = plt.subplots(
        2, 1, figsize=(12, 7.5),
        gridspec_kw={"height_ratios": [2, 1.2]},
        facecolor=WHITE,
    )

    # Clamp the 1-based fold number into the valid 0-based index range.
    position = min(current_fold - 1, n_folds - 1)
    if position < 0:
        position = 0
    fold = folds[position]
    train = slice(fold["train_start"], fold["train_end"])
    test = slice(fold["test_start"], fold["test_end"])

    # Top panel: whole series in the background, fold segments on top.
    series_ax.plot(index, series, color=GRAY, linewidth=1.2, zorder=1, label="Full series")
    series_ax.plot(index[train], series[train], color=TEAL, linewidth=2.4, zorder=3, label="Training")
    series_ax.plot(index[test], series[test], color=RED, linewidth=2.4, zorder=3, label="Test / Validation")

    # Dashed line at the train/test split; dotted line where a rolling
    # window's training data begins (only when it does not start at 0).
    series_ax.axvline(
        fold["train_end"] - 0.5,
        color=DARK_GRAY, linestyle="--", linewidth=1, zorder=2, alpha=0.7,
    )
    if fold["train_start"] > 0:
        series_ax.axvline(
            fold["train_start"] - 0.5,
            color=DARK_GRAY, linestyle=":", linewidth=1, zorder=2, alpha=0.5,
        )

    series_ax.set_title(
        f"Fold {fold['fold']} of {n_folds} — {strategy_label}",
        fontsize=14, fontweight="bold", pad=10,
    )
    series_ax.set_xlabel("Time Index", fontsize=11)
    series_ax.set_ylabel("y", fontsize=11)
    series_ax.legend(loc="upper left", fontsize=9, framealpha=0.9)
    series_ax.set_xlim(-1, n_obs + 1)

    # Bottom panel: fold map with the selected fold emphasised.
    _draw_gantt(gantt_ax, folds, current_fold=fold["fold"], n=n_obs, highlight=True)

    fig.tight_layout(pad=2.0)
    return fig
227
+
228
+
229
def _draw_gantt(ax, folds, current_fold, n, highlight):
    """Draw one horizontal training/test bar pair per fold on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        folds: List of fold dicts with keys ``"fold"``, ``"train_start"``,
            ``"train_end"``, ``"test_start"`` and ``"test_end"``.
        current_fold: Fold number to emphasize (used only if ``highlight``).
        n: Series length, used to set the x-axis limits.
        highlight: Whether the bar pair of ``current_fold`` is drawn taller
            with a heavier black outline.
    """
    n_folds = len(folds)
    bar_specs = (
        ("train_start", "train_end", TEAL),
        ("test_start", "test_end", RED),
    )

    for fold in folds:
        number = fold["fold"]
        if highlight and number == current_fold:
            style = {"height": 0.85, "edgecolor": "black", "linewidth": 1.8, "zorder": 3}
        else:
            style = {"height": 0.6, "edgecolor": "#666666", "linewidth": 0.5, "zorder": 2}

        # Training bar first, then test bar, both on the fold's own row.
        for start_key, end_key, color in bar_specs:
            ax.barh(
                number,
                fold[end_key] - fold[start_key],
                left=fold[start_key],
                color=color,
                **style,
            )

    ax.set_xlabel("Time Index", fontsize=11)
    ax.set_ylabel("Fold", fontsize=11)
    ax.set_xlim(-1, n + 1)
    ax.set_ylim(0.2, n_folds + 0.8)
    ax.set_yticks(range(1, n_folds + 1))
    # Flip the axis so fold 1 is the top row.
    ax.invert_yaxis()

    # Legend built from proxy patches rather than from the individual bars.
    legend_entries = [
        mpatches.Patch(facecolor=color, edgecolor="#333", label=label)
        for color, label in ((TEAL, "Training"), (RED, "Test"))
    ]
    ax.legend(handles=legend_entries, loc="upper right", fontsize=9, framealpha=0.9)
266
+
267
+ # ---------------------------------------------------------------------------
268
+ # Summary text
269
+ # ---------------------------------------------------------------------------
270
+
271
def build_summary(folds, strategy, initial, step, metrics_df):
    """Compose the Markdown summary shown beneath the metrics table.

    Args:
        folds: List of fold dicts; the first fold's ``"train_start"`` and
            ``"train_end"`` are read for the rolling-window description.
        strategy: ``"Expanding Window"`` selects the expanding description;
            any other value selects the rolling/sliding description.
        initial: Initial training size, used in the expanding description.
        step: Step size between folds, used in the expanding description.
        metrics_df: Per-fold metrics with ``"MAE"``, ``"RMSE"`` and
            ``"MAPE (%)"`` columns. Not read when ``folds`` is empty.

    Returns:
        A Markdown string.
    """
    n_folds = len(folds)
    if not n_folds:
        return "**No valid folds.** Adjust the parameters so that at least one fold fits within the data."

    mean_mae = metrics_df["MAE"].mean()
    mean_rmse = metrics_df["RMSE"].mean()
    # The MAPE column may hold the text "N/A"; coerce those entries to NaN
    # so they are left out of the average.
    mean_mape = pd.to_numeric(metrics_df["MAPE (%)"], errors="coerce").mean()

    if np.isnan(mean_mape):
        mape_line = "- **Average MAPE:** N/A"
    else:
        mape_line = f"- **Average MAPE:** {mean_mape:.2f}%"

    if strategy == "Expanding Window":
        final_train_size = initial + (n_folds - 1) * step
        strategy_note = (
            f"*Expanding window*: training set grows from **{initial}** to "
            f"**{final_train_size}** observations across {n_folds} folds."
        )
    else:
        first = folds[0]
        window = first["train_end"] - first["train_start"]
        strategy_note = (
            f"*Rolling / sliding window*: fixed training size of **{window}** "
            f"observations slides forward across {n_folds} folds."
        )

    parts = [
        "### Summary",
        f"- **Total folds:** {n_folds}",
        f"- **Average MAE:** {mean_mae:.2f}",
        f"- **Average RMSE:** {mean_rmse:.2f}",
        mape_line,
        "",
        strategy_note,
        "",
        "Forecasts use a **naive model** (last training value repeated over "
        "the horizon) to keep focus on the CV visualization concept.",
    ]
    return "\n".join(parts)
308
+
309
+ # ---------------------------------------------------------------------------
310
+ # Main update callback
311
+ # ---------------------------------------------------------------------------
312
+
313
def load_dataset(name, file_obj):
    """Return the DataFrame for the selected dataset or uploaded CSV.

    Args:
        name: Value of the dataset selector. ``"Upload CSV"`` reads the
            uploaded file; a key of ``DATASETS`` calls that generator.
        file_obj: Uploaded file (an object with a ``name`` attribute, or a
            path), or ``None`` when nothing has been uploaded.

    Returns:
        A DataFrame. For uploads it contains exactly the ``ds`` and ``y``
        columns, with ``y`` converted to a numeric dtype.

    Raises:
        gr.Error: If the uploaded CSV lacks a ``ds`` or ``y`` column, or if
            ``y`` contains values that are not numeric.
    """
    if name == "Upload CSV" and file_obj is not None:
        source = file_obj.name if hasattr(file_obj, "name") else file_obj
        raw = pd.read_csv(source)
        if "ds" not in raw.columns or "y" not in raw.columns:
            raise gr.Error("Uploaded CSV must contain columns named 'ds' and 'y'.")

        # Reject text in 'y' here with a clear message; otherwise it only
        # fails later, inside plotting or the metric computation.
        y_numeric = pd.to_numeric(raw["y"], errors="coerce")
        if (y_numeric.isna() & raw["y"].notna()).any():
            raise gr.Error("Column 'y' in the uploaded CSV must contain only numeric values.")

        uploaded = raw[["ds", "y"]].copy()
        uploaded["y"] = y_numeric
        return uploaded

    if name in DATASETS:
        return DATASETS[name]()

    # Fallback: also reached when "Upload CSV" is selected but no file has
    # been provided yet.
    return DATASETS["Simple Trend + Noise"]()
324
+
325
+
326
def update_total_folds(dataset_name, file_obj, strategy, initial, horizon, step_size, window_size):
    """Build a slider update whose maximum is the number of available folds.

    The maximum is never below 1 and the slider value is reset to 1. If the
    dataset cannot be loaded, the slider collapses to a single position.
    """
    try:
        data = load_dataset(dataset_name, file_obj)
    except Exception:
        return gr.update(maximum=1, value=1)

    fold_count = len(
        compute_folds(len(data), initial, horizon, step_size, strategy, window_size)
    )
    # Keep the slider usable even when no fold fits.
    upper = fold_count if fold_count > 1 else 1
    return gr.update(maximum=upper, value=1)
336
+
337
+
338
def run_visualizer(dataset_name, file_obj, strategy, initial, horizon, step_size, window_size, current_fold, show_all):
    """Core callback: build the figure, metrics table and summary text.

    Args:
        dataset_name: Dataset selector value, passed to ``load_dataset``.
        file_obj: Optional uploaded CSV, passed to ``load_dataset``.
        strategy: CV strategy label; also used as the figure title suffix.
        initial, horizon, step_size, window_size: Fold parameters passed to
            ``compute_folds``.
        current_fold: 1-based fold number to highlight.
        show_all: When true, draw the all-folds overview instead.

    Returns:
        ``(figure, metrics_df, summary_md)``. The metrics table ends with an
        ``"Avg"`` row. With no valid folds, a placeholder figure, an empty
        table and an explanatory message are returned.

    Raises:
        gr.Error: If the dataset cannot be loaded.
    """
    try:
        df = load_dataset(dataset_name, file_obj)
    except gr.Error:
        raise
    except Exception as exc:
        # Surface any loader failure as a user-visible Gradio error, keeping
        # the original exception as the cause for debugging.
        raise gr.Error(f"Could not load data: {exc}") from exc

    n = len(df)
    folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)

    if not folds:
        fig, ax = plt.subplots(figsize=(12, 4), facecolor=WHITE)
        ax.text(0.5, 0.5, "No valid folds — adjust parameters.",
                ha="center", va="center", fontsize=14, transform=ax.transAxes)
        ax.axis("off")
        # Deregister from pyplot so repeated calls do not accumulate open
        # figures; the Figure object itself stays valid for the caller.
        plt.close(fig)
        empty_df = pd.DataFrame(columns=[
            "Fold", "Train Start", "Train End", "Test Start", "Test End",
            "Train Size", "MAE", "RMSE", "MAPE (%)"
        ])
        summary = "**No valid folds.** Reduce `initial` + `horizon` or increase data length."
        return fig, empty_df, summary

    strategy_label = strategy
    fig = _make_figure(df, folds, current_fold, show_all, strategy_label)
    try:
        metrics_df = naive_metrics(df["y"], folds)

        # Append average row; the non-metric columns are left blank.
        avg_row = {
            "Fold": "Avg",
            "Train Start": "",
            "Train End": "",
            "Test Start": "",
            "Test End": "",
            "Train Size": "",
            "MAE": round(metrics_df["MAE"].mean(), 2),
            "RMSE": round(metrics_df["RMSE"].mean(), 2),
        }
        # MAPE entries may be the text "N/A"; coerce them to NaN so they
        # are skipped when averaging.
        mape_vals = pd.to_numeric(metrics_df["MAPE (%)"], errors="coerce")
        avg_row["MAPE (%)"] = round(mape_vals.mean(), 2) if not mape_vals.isna().all() else "N/A"
        avg_df = pd.concat([metrics_df, pd.DataFrame([avg_row])], ignore_index=True)

        summary = build_summary(folds, strategy, initial, step_size, metrics_df)
    finally:
        # Close only this call's figure (even on failure) rather than every
        # open figure, which could belong to another in-flight request.
        plt.close(fig)
    return fig, avg_df, summary
385
+
386
+ # ---------------------------------------------------------------------------
387
+ # Gradio UI
388
+ # ---------------------------------------------------------------------------
389
+
390
def build_app():
    """Construct the Gradio Blocks application.

    Builds the theme, the header/description/footer HTML, the control and
    output components, and wires the events: any parameter change re-renders
    and refreshes the fold slider range, releasing the fold slider
    re-renders, and the Play/Stop buttons drive a generator that steps
    through the folds one per second.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    theme = gr.themes.Soft(
        primary_hue=gr.themes.Color(
            c50="#eafaf9", c100="#d4f5f3", c200="#aaecea",
            c300="#84d6d3", c400="#5ec4c0", c500="#3eaea9",
            c600="#2e938e", c700="#237873", c800="#1a5d59",
            c900="#12423f", c950="#0a2725",
        ),
        secondary_hue=gr.themes.Color(
            c50="#fef2f3", c100="#fde6e8", c200="#fbd0d5",
            c300="#f7a4ae", c400="#f17182", c500="#C3142D",
            c600="#b01228", c700="#8B0E1E", c800="#6e0b18",
            c900="#5c0d17", c950="#33040a",
        ),
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    )

    with gr.Blocks(
        title="Cross-Validation Visualizer v1.0",
        theme=theme,
        css="""
        .gradio-container { max-width: 1280px !important; margin: auto; }
        footer { display: none !important; }
        .gr-button-primary { background: #C3142D !important; border: none !important; }
        .gr-button-primary:hover { background: #8B0E1E !important; }
        .gr-button-secondary { border-color: #84d6d3 !important; color: #84d6d3 !important; }
        .gr-button-secondary:hover { background: #84d6d3 !important; color: white !important; }
        .gr-input:focus { border-color: #84d6d3 !important; box-shadow: 0 0 0 2px rgba(132,214,211,0.2) !important; }
        """,
    ) as demo:
        gr.HTML("""
        <div style="display: flex; align-items: center; gap: 16px; padding: 16px 24px;
                    background: linear-gradient(135deg, #C3142D 0%, #8B0E1E 100%);
                    border-radius: 12px; margin-bottom: 16px; box-shadow: 0 4px 12px rgba(0,0,0,0.15);">
            <img src="https://miamioh.edu/miami-brand/_files/images/system/logo-usage/minimum-size/beveled-m-min-size.png"
                 alt="Miami University" style="height: 56px;">
            <div>
                <h1 style="margin: 0; color: white; font-size: 24px; font-weight: 700; letter-spacing: -0.5px;">
                    Cross-Validation Visualizer v1.0
                </h1>
                <p style="margin: 4px 0 0; color: rgba(255,255,255,0.85); font-size: 14px;">
                    ISA 444: Business Forecasting &middot; Farmer School of Business &middot; Miami University
                </p>
            </div>
        </div>
        """)

        gr.HTML("""
        <div style="background: #f8f9fa; border-left: 4px solid #84d6d3; padding: 12px 16px;
                    border-radius: 0 8px 8px 0; margin-bottom: 16px; font-size: 14px; color: #585E60;">
            Visualize time-series cross-validation strategies (expanding window and rolling/sliding window)
            with animated fold progression and per-fold accuracy metrics using a naive forecast.
            Understand how forecast accuracy is evaluated across folds.
        </div>
        """)

        with gr.Row():
            # ---- Left column: controls ----
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("### Data")
                dataset_dd = gr.Dropdown(
                    choices=["Airline Passengers", "Ohio Employment",
                             "Simple Trend + Noise", "Upload CSV"],
                    value="Simple Trend + Noise",
                    label="Dataset",
                )
                csv_upload = gr.File(
                    label="Upload CSV (columns: ds, y)",
                    file_types=[".csv"],
                    visible=False,
                )

                gr.Markdown("### CV Strategy")
                strategy_radio = gr.Radio(
                    choices=["Expanding Window", "Rolling/Sliding Window"],
                    value="Expanding Window",
                    label="Strategy",
                )

                gr.Markdown("### Parameters")
                initial_slider = gr.Slider(
                    minimum=12, maximum=120, value=60, step=1,
                    label="initial (initial training size)",
                )
                horizon_slider = gr.Slider(
                    minimum=1, maximum=24, value=12, step=1,
                    label="horizon (forecast horizon)",
                )
                step_slider = gr.Slider(
                    minimum=1, maximum=12, value=1, step=1,
                    label="step (step size between folds)",
                )
                window_slider = gr.Slider(
                    minimum=12, maximum=120, value=60, step=1,
                    label="window_size (rolling window only)",
                    visible=False,
                )

                gr.Markdown("### Animation Controls")
                fold_slider = gr.Slider(
                    minimum=1, maximum=1, value=1, step=1,
                    label="Current Fold",
                )
                with gr.Row():
                    play_btn = gr.Button("Play Animation", variant="primary")
                    stop_btn = gr.Button("Stop", variant="stop")
                show_all_cb = gr.Checkbox(label="Show All Folds", value=False)

            # ---- Right column: outputs ----
            with gr.Column(scale=2, min_width=500):
                plot_output = gr.Plot(label="Visualization")
                metrics_output = gr.Dataframe(
                    label="Per-Fold Metrics (Naive Forecast)",
                    interactive=False,
                    wrap=True,
                )
                summary_output = gr.Markdown(label="Summary")

        # ---- Visibility toggles ----
        def toggle_csv_upload(name):
            """Show the upload widget only when "Upload CSV" is selected."""
            return gr.update(visible=(name == "Upload CSV"))

        dataset_dd.change(toggle_csv_upload, inputs=[dataset_dd], outputs=[csv_upload])

        def toggle_window_slider(strategy):
            """Show the window-size slider only for the rolling strategy."""
            return gr.update(visible=(strategy == "Rolling/Sliding Window"))

        strategy_radio.change(toggle_window_slider, inputs=[strategy_radio], outputs=[window_slider])

        # ---- Gather all control inputs ----
        all_inputs = [
            dataset_dd, csv_upload, strategy_radio,
            initial_slider, horizon_slider, step_slider,
            window_slider, fold_slider, show_all_cb,
        ]
        all_outputs = [plot_output, metrics_output, summary_output]

        def refresh_and_run(dataset_name, file_obj, strategy, initial, horizon,
                            step_size, window_size, current_fold, show_all):
            """Update fold slider range, clamp current_fold, then run."""
            try:
                df = load_dataset(dataset_name, file_obj)
            except Exception:
                # Only the slider range is derived from this fallback;
                # run_visualizer below reloads the data and reports the
                # real error to the user.
                df = DATASETS["Simple Trend + Noise"]()
            n = len(df)
            folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)
            K = max(len(folds), 1)
            current_fold = max(1, min(current_fold, K))
            fig, metrics, summary = run_visualizer(
                dataset_name, file_obj, strategy, initial, horizon,
                step_size, window_size, current_fold, show_all,
            )
            return gr.update(maximum=K, value=current_fold), fig, metrics, summary

        combined_outputs = [fold_slider] + all_outputs

        # Trigger on any parameter change
        for ctrl in [dataset_dd, csv_upload, strategy_radio, initial_slider,
                     horizon_slider, step_slider, window_slider, show_all_cb]:
            ctrl.change(
                refresh_and_run,
                inputs=all_inputs,
                outputs=combined_outputs,
            )

        # Fold slider change (just re-render, no range update needed)
        fold_slider.release(
            run_visualizer,
            inputs=all_inputs,
            outputs=all_outputs,
        )

        # ---- Animation via a generator callback ----
        # A gr.State holds the "playing" flag that the Stop button clears.
        animation_state = gr.State({"playing": False})

        def start_animation(state, dataset_name, file_obj, strategy, initial,
                            horizon, step_size, window_size, current_fold, show_all):
            """Step through the folds, yielding one render per second.

            Each yield updates the state, the fold slider and the three
            outputs. The loop ends early once ``state["playing"]`` is
            cleared. Frames are always rendered in single-fold mode.
            """
            # NOTE(review): stopping relies on stop_animation mutating the
            # same dict object this generator holds — confirm that gr.State
            # hands the session's object to both handlers by reference.
            state["playing"] = True
            try:
                df = load_dataset(dataset_name, file_obj)
            except Exception:
                df = DATASETS["Simple Trend + Noise"]()
            n = len(df)
            folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)
            K = max(len(folds), 1)

            last_render = None
            try:
                for k in range(1, K + 1):
                    if not state.get("playing", False):
                        break
                    last_render = run_visualizer(
                        dataset_name, file_obj, strategy, initial, horizon,
                        step_size, window_size, k, False,
                    )
                    yield (state, gr.update(maximum=K, value=k), *last_render)
                    time.sleep(1.0)
            finally:
                # Clear the flag even if rendering raised or the generator
                # was closed, so the state never stays stuck on "playing".
                state["playing"] = False

            if last_render is None:
                # Stopped before the first frame: leave the outputs as-is.
                yield state, gr.update(), gr.update(), gr.update(), gr.update()
            else:
                yield (state, gr.update(), *last_render)

        def stop_animation(state):
            """Clear the playing flag so the running animation loop exits."""
            state["playing"] = False
            return state

        play_btn.click(
            start_animation,
            inputs=[animation_state] + all_inputs,
            outputs=[animation_state, fold_slider] + all_outputs,
        )
        stop_btn.click(stop_animation, inputs=[animation_state], outputs=[animation_state])

        # ---- Initial render on load ----
        demo.load(
            refresh_and_run,
            inputs=all_inputs,
            outputs=combined_outputs,
        )

        gr.HTML("""
        <div style="margin-top: 24px; padding: 16px; background: #f8f9fa; border-radius: 8px;
                    text-align: center; font-size: 13px; color: #585E60; border-top: 2px solid #84d6d3;">
            <div style="margin-bottom: 4px;">
                <strong style="color: #C3142D;">Developed by</strong>
                <a href="https://miamioh.edu/fsb/directory/?up=/directory/megahefm"
                   style="color: #84d6d3; text-decoration: none; font-weight: 600;">
                    Fadel M. Megahed
                </a>
                &middot; Gloss Professor of Analytics &middot; Miami University
            </div>
            <div style="font-size: 12px; color: #888;">
                Version 1.0.0 &middot; Spring 2026 &middot;
                <a href="https://github.com/fmegahed" style="color: #84d6d3; text-decoration: none;">GitHub</a> &middot;
                <a href="https://www.linkedin.com/in/fmegahed/" style="color: #84d6d3; text-decoration: none;">LinkedIn</a>
            </div>
        </div>
        """)

    return demo
633
+
634
+
635
# Script entry point: build the Blocks UI and start the Gradio server only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    app = build_app()
    app.launch()