Spaces:
Sleeping
Sleeping
| """ | |
| Cross-Validation Visualizer | |
| ============================ | |
| Visualize time-series cross-validation strategies (expanding window and | |
| rolling/sliding window) with animated fold progression and per-fold | |
| accuracy metrics using a naive forecast. | |
| Part of ISA 444: Business Forecasting — Spring 2026, Miami University. | |
| Deployed to HuggingFace Spaces as fmegahed/cv-visualizer. | |
| """ | |
| import io | |
| import time | |
| import threading | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib.lines import Line2D | |
# ---------------------------------------------------------------------------
# Color palette
# ---------------------------------------------------------------------------
# Colors used by the plotting helpers. The same TEAL and RED hex values also
# appear as literals in the Gradio theme and CSS inside build_app().
TEAL: str = "#84d6d3"       # training segments and bars
RED: str = "#C3142D"        # test / validation segments and bars
GRAY: str = "#CCCCCC"       # full-series background line
DARK_GRAY: str = "#888888"  # train/test boundary lines
WHITE: str = "#FFFFFF"      # figure face color
| # --------------------------------------------------------------------------- | |
| # Dataset generators | |
| # --------------------------------------------------------------------------- | |
| def _airline_passengers() -> pd.DataFrame: | |
| """Classic Box-Jenkins airline passengers (1949-1960, 144 obs).""" | |
| # Reproduce the well-known series with a multiplicative seasonal pattern. | |
| np.random.seed(42) | |
| n = 144 | |
| t = np.arange(n) | |
| trend = 132 + 2.4 * t | |
| seasonal_period = 12 | |
| seasonal = 40 * np.sin(2 * np.pi * t / seasonal_period) | |
| # Multiplicative-style growth in amplitude | |
| amplitude_growth = 1 + 0.006 * t | |
| y = trend * amplitude_growth + seasonal * amplitude_growth | |
| # Add a touch of noise | |
| y += np.random.normal(0, 5, n) | |
| dates = pd.date_range("1949-01-01", periods=n, freq="MS") | |
| return pd.DataFrame({"ds": dates, "y": np.round(y, 1)}) | |
| def _ohio_employment() -> pd.DataFrame: | |
| """Synthetic Ohio monthly employment (2010-2024, 180 obs).""" | |
| np.random.seed(123) | |
| n = 180 | |
| t = np.arange(n) | |
| trend = 5200 + 3.5 * t | |
| seasonal = 120 * np.sin(2 * np.pi * t / 12) + 60 * np.cos(2 * np.pi * t / 6) | |
| # Covid dip around index 120-130 (~ early 2020) | |
| dip = np.zeros(n) | |
| dip[120:132] = -np.array([200, 800, 1100, 900, 600, 400, 300, 200, 150, 100, 60, 30]) | |
| noise = np.random.normal(0, 40, n) | |
| y = trend + seasonal + dip + noise | |
| dates = pd.date_range("2010-01-01", periods=n, freq="MS") | |
| return pd.DataFrame({"ds": dates, "y": np.round(y, 1)}) | |
| def _simple_trend() -> pd.DataFrame: | |
| """Simple linear trend + noise (120 obs) for pedagogical clarity.""" | |
| np.random.seed(7) | |
| n = 120 | |
| t = np.arange(n) | |
| y = 0.5 * t + np.random.normal(0, 2, n) | |
| dates = pd.date_range("2015-01-01", periods=n, freq="MS") | |
| return pd.DataFrame({"ds": dates, "y": np.round(y, 2)}) | |
# Built-in datasets offered in the UI dropdown, mapping display name ->
# zero-argument generator returning a DataFrame with "ds" and "y" columns.
DATASETS = dict(
    [
        ("Airline Passengers", _airline_passengers),
        ("Ohio Employment", _ohio_employment),
        ("Simple Trend + Noise", _simple_trend),
    ]
)
# ---------------------------------------------------------------------------
# Fold computation
# ---------------------------------------------------------------------------
def compute_folds(n, initial, horizon, step, strategy, window_size=None):
    """Return a list of fold dicts with train/test index ranges.

    Each fold is a dict with keys "fold", "train_start", "train_end",
    "test_start" and "test_end". All bounds are half-open ``[start, end)``
    integer indices, and each test window starts exactly where training ends.
    Folds are produced until the next test window would run past ``n``.

    Args:
        n: Series length.
        initial: Training size of the first fold (expanding window). It is also
            the default window size for the rolling strategy.
        horizon: Length of each test window.
        step: How far each successive fold advances.
        strategy: "Expanding Window" keeps ``train_start`` at 0 and grows the
            training set by ``step`` per fold. Any other value slides a
            fixed-size training window forward by ``step``.
        window_size: Rolling-window training size; ``None`` means ``initial``.
            Ignored for the expanding strategy.

    Returns:
        list[dict]: Folds numbered from 1. The list is empty when no fold fits
        within ``n``, or when any of ``initial``, ``horizon``, ``step`` or the
        window size is not positive.
    """
    # Bounds are later used as slice indices, so coerce numeric inputs
    # (e.g. 60.0) to int.
    n, initial, horizon, step = int(n), int(initial), int(horizon), int(step)
    expanding = strategy == "Expanding Window"
    ws = initial if (expanding or window_size is None) else int(window_size)

    # Non-positive sizes give empty train/test windows. A non-positive step
    # would repeat (or rewind) the same fold. Report "no valid folds" instead.
    if min(initial, horizon, step, ws) <= 0:
        return []

    folds = []
    k = 0
    while True:
        if expanding:
            train_start, train_end = 0, initial + k * step
        else:
            train_start = k * step
            train_end = train_start + ws
        test_start = train_end
        test_end = test_start + horizon
        # step >= 1, so test_end strictly increases and this loop terminates.
        if test_end > n:
            break
        folds.append({
            "fold": k + 1,
            "train_start": train_start,
            "train_end": train_end,
            "test_start": test_start,
            "test_end": test_end,
        })
        k += 1
    return folds
# ---------------------------------------------------------------------------
# Naive forecast & metrics
# ---------------------------------------------------------------------------
def naive_metrics(y_series, folds):
    """Compute MAE, RMSE and MAPE per fold using a naive (last-value) forecast.

    For each fold, the last training value is repeated over the whole test
    window and compared with the actual test values.

    Args:
        y_series: Series values (pandas Series, numpy array or list); converted
            to float.
        folds: Fold dicts with "fold", "train_start", "train_end",
            "test_start" and "test_end" (half-open bounds).

    Returns:
        DataFrame with one row per fold and the columns Fold, Train Start,
        Train End, Test Start, Test End, Train Size, MAE, RMSE and MAPE (%).
        End indices are reported inclusive. MAPE averages only the test points
        whose actual value is non-zero. It is the string "N/A" when every test
        value is zero. An empty ``folds`` yields an empty frame that still has
        these columns.

    Raises:
        ValueError: If a fold has an empty training window (there is no value
            to carry forward).
    """
    columns = [
        "Fold", "Train Start", "Train End", "Test Start", "Test End",
        "Train Size", "MAE", "RMSE", "MAPE (%)",
    ]
    # One conversion path for Series, arrays and lists; float dtype keeps the
    # forecast and error arithmetic in floating point.
    y = np.asarray(y_series, dtype=float)
    records = []
    for f in folds:
        train_vals = y[f["train_start"]:f["train_end"]]
        test_vals = y[f["test_start"]:f["test_end"]]
        if train_vals.size == 0:
            raise ValueError(f"Fold {f['fold']} has an empty training window.")
        forecast = np.full_like(test_vals, train_vals[-1])
        errors = test_vals - forecast
        mae = np.mean(np.abs(errors))
        rmse = np.sqrt(np.mean(errors ** 2))
        # MAPE: skip (near-)zero actuals to avoid division by zero.
        nonzero = np.abs(test_vals) > 1e-8
        if nonzero.any():
            mape = np.mean(np.abs(errors[nonzero] / test_vals[nonzero])) * 100
        else:
            mape = np.nan
        records.append({
            "Fold": f["fold"],
            "Train Start": f["train_start"],
            "Train End": f["train_end"] - 1,
            "Test Start": f["test_start"],
            "Test End": f["test_end"] - 1,
            "Train Size": f["train_end"] - f["train_start"],
            "MAE": round(mae, 2),
            "RMSE": round(rmse, 2),
            "MAPE (%)": round(mape, 2) if not np.isnan(mape) else "N/A",
        })
    return pd.DataFrame(records, columns=columns)
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def _make_figure(df, folds, current_fold, show_all, strategy_label):
    """Build the matplotlib figure with either one or two panels.

    Args:
        df: DataFrame whose ``y`` column holds the series values.
        folds: Fold dicts from ``compute_folds``. Must be non-empty when
            ``show_all`` is false.
        current_fold: 1-based fold number to show. Out-of-range values are
            clamped to the first or last fold; ``None`` selects fold 1.
        show_all: If true, draw only the Gantt map of every fold. Otherwise
            draw the series split for the selected fold above a Gantt map that
            highlights it.
        strategy_label: Text appended to the figure title.

    Returns:
        matplotlib.figure.Figure: The assembled figure.
    """
    y = df["y"].values
    n = len(y)
    K = len(folds)

    if show_all:
        fig, ax_gantt = plt.subplots(figsize=(12, 5), facecolor=WHITE)
        _draw_gantt(ax_gantt, folds, current_fold=None, n=n, highlight=False)
        ax_gantt.set_title(
            f"All {K} Folds — {strategy_label}",
            fontsize=14, fontweight="bold", pad=10,
        )
        fig.tight_layout(pad=2.0)
        return fig

    # Two-panel layout
    fig, (ax_ts, ax_gantt) = plt.subplots(
        2, 1, figsize=(12, 7.5),
        gridspec_kw={"height_ratios": [2, 1.2]},
        facecolor=WHITE,
    )
    # Coerce defensively: list indexing below needs an int, and a UI slider
    # value may arrive as a float (or be missing).
    requested = 1 if current_fold is None else int(current_fold)
    fold_idx = max(0, min(requested - 1, K - 1))
    f = folds[fold_idx]
    x = np.arange(n)

    # --- Top panel: time series with CV split ---
    ax_ts.plot(x, y, color=GRAY, linewidth=1.2, zorder=1, label="Full series")
    # Training segment
    train_x = x[f["train_start"]:f["train_end"]]
    train_y = y[f["train_start"]:f["train_end"]]
    ax_ts.plot(train_x, train_y, color=TEAL, linewidth=2.4, zorder=3, label="Training")
    # Test segment
    test_x = x[f["test_start"]:f["test_end"]]
    test_y = y[f["test_start"]:f["test_end"]]
    ax_ts.plot(test_x, test_y, color=RED, linewidth=2.4, zorder=3, label="Test / Validation")
    # Boundary lines sit half a step before the first index of each segment.
    # Dashed = train/test split; dotted = start of a training window that does not begin at 0.
    ax_ts.axvline(f["train_end"] - 0.5, color=DARK_GRAY, linestyle="--", linewidth=1, zorder=2, alpha=0.7)
    if f["train_start"] > 0:
        ax_ts.axvline(f["train_start"] - 0.5, color=DARK_GRAY, linestyle=":", linewidth=1, zorder=2, alpha=0.5)
    ax_ts.set_title(
        f"Fold {f['fold']} of {K} — {strategy_label}",
        fontsize=14, fontweight="bold", pad=10,
    )
    ax_ts.set_xlabel("Time Index", fontsize=11)
    ax_ts.set_ylabel("y", fontsize=11)
    ax_ts.legend(loc="upper left", fontsize=9, framealpha=0.9)
    ax_ts.set_xlim(-1, n + 1)

    # --- Bottom panel: Gantt-style fold map ---
    _draw_gantt(ax_gantt, folds, current_fold=f["fold"], n=n, highlight=True)
    fig.tight_layout(pad=2.0)
    return fig
def _draw_gantt(ax, folds, current_fold, n, highlight):
    """Paint the fold map onto ``ax``.

    Each fold gets one horizontal row, labelled by its "fold" number, with a
    teal training bar followed by a red test bar. When ``highlight`` is true,
    the row whose number equals ``current_fold`` is drawn taller, with a thick
    black outline, above the other rows. The x-range spans the ``n`` time
    indices, and the y-axis is inverted so that fold 1 sits at the top.
    """
    fold_count = len(folds)
    for fold in folds:
        row = fold["fold"]
        selected = highlight and row == current_fold
        if selected:
            bar_h, outline, width, layer = 0.85, "black", 1.8, 3
        else:
            bar_h, outline, width, layer = 0.6, "#666666", 0.5, 2
        # Training bar first, then the test bar.
        segments = (
            (fold["train_start"], fold["train_end"], TEAL),
            (fold["test_start"], fold["test_end"], RED),
        )
        for lo, hi, shade in segments:
            ax.barh(
                row, hi - lo, left=lo,
                height=bar_h, color=shade, edgecolor=outline,
                linewidth=width, zorder=layer,
            )

    ax.set_xlabel("Time Index", fontsize=11)
    ax.set_ylabel("Fold", fontsize=11)
    ax.set_xlim(-1, n + 1)
    ax.set_ylim(0.2, fold_count + 0.8)
    ax.set_yticks(range(1, fold_count + 1))
    ax.invert_yaxis()

    legend_items = [
        mpatches.Patch(facecolor=shade, edgecolor="#333", label=text)
        for shade, text in ((TEAL, "Training"), (RED, "Test"))
    ]
    ax.legend(handles=legend_items, loc="upper right", fontsize=9, framealpha=0.9)
# ---------------------------------------------------------------------------
# Summary text
# ---------------------------------------------------------------------------
def build_summary(folds, strategy, initial, step, metrics_df):
    """Return a Markdown summary of the cross-validation run.

    Args:
        folds: Fold dicts as produced by ``compute_folds``.
        strategy: "Expanding Window" selects the expanding-window wording; any
            other value selects the rolling-window wording.
        initial: Kept for backward compatibility only. Training sizes are
            read from ``folds``, so the text always matches the folds that were
            actually evaluated.
        step: Kept for backward compatibility only (see ``initial``).
        metrics_df: Per-fold metrics with "MAE", "RMSE" and "MAPE (%)"
            columns. The MAPE column may contain the string "N/A".

    Returns:
        str: Markdown text. When ``folds`` is empty, a short "No valid folds"
        message is returned instead.
    """
    K = len(folds)
    if K == 0:
        return "**No valid folds.** Adjust the parameters so that at least one fold fits within the data."

    avg_mae = metrics_df["MAE"].mean()
    avg_rmse = metrics_df["RMSE"].mean()
    # "N/A" entries (folds whose test values were all zero) become NaN, which
    # mean() skips. An all-NaN column averages to NaN and is shown as N/A.
    avg_mape = pd.to_numeric(metrics_df["MAPE (%)"], errors="coerce").mean()
    if np.isnan(avg_mape):
        mape_line = "- **Average MAPE:** N/A"
    else:
        mape_line = f"- **Average MAPE:** {avg_mape:.2f}%"

    lines = [
        "### Summary",
        f"- **Total folds:** {K}",
        f"- **Average MAE:** {avg_mae:.2f}",
        f"- **Average RMSE:** {avg_rmse:.2f}",
        mape_line,
        "",
    ]

    first_size = folds[0]["train_end"] - folds[0]["train_start"]
    if strategy == "Expanding Window":
        last_size = folds[-1]["train_end"] - folds[-1]["train_start"]
        lines.append(
            f"*Expanding window*: training set grows from **{first_size}** to "
            f"**{last_size}** observations across {K} folds."
        )
    else:
        lines.append(
            f"*Rolling / sliding window*: fixed training size of **{first_size}** "
            f"observations slides forward across {K} folds."
        )
    lines.append("")
    lines.append(
        "Forecasts use a **naive model** (last training value repeated over "
        "the horizon) to keep focus on the CV visualization concept."
    )
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Main update callback
# ---------------------------------------------------------------------------
def load_dataset(name, file_obj):
    """Return a DataFrame given the selector value and optional upload.

    Args:
        name: Dropdown value. Either a key of ``DATASETS`` or "Upload CSV".
        file_obj: The uploaded file, either as an object with a ``.name`` path
            or as a path/buffer accepted by ``pandas.read_csv``. May be None.

    Returns:
        DataFrame with columns ``ds`` and ``y``. Unknown names, and
        "Upload CSV" with no file, fall back to "Simple Trend + Noise".

    Raises:
        gr.Error: If the uploaded CSV lacks a ``ds`` or ``y`` column, or if
            ``y`` contains non-numeric or missing values.
    """
    if name == "Upload CSV" and file_obj is not None:
        source = file_obj.name if hasattr(file_obj, "name") else file_obj
        raw = pd.read_csv(source)
        if "ds" not in raw.columns or "y" not in raw.columns:
            raise gr.Error("Uploaded CSV must contain columns named 'ds' and 'y'.")
        data = raw[["ds", "y"]].copy()
        # Validate y up front. A stray text cell or gap should give a clear
        # message here, not an opaque numpy error or NaN metrics later.
        data["y"] = pd.to_numeric(data["y"], errors="coerce")
        if data["y"].isna().any():
            raise gr.Error("Column 'y' in the uploaded CSV must be numeric with no missing values.")
        return data
    return DATASETS.get(name, DATASETS["Simple Trend + Noise"])()
def update_total_folds(dataset_name, file_obj, strategy, initial, horizon, step_size, window_size):
    """Build a fold-slider update whose maximum is the number of valid folds.

    The maximum is never below 1. Any failure while loading the data yields a
    one-fold slider instead of an error. The slider value is always reset to
    fold 1.
    """
    try:
        frame = load_dataset(dataset_name, file_obj)
    except Exception:
        return gr.update(maximum=1, value=1)
    fold_total = len(compute_folds(len(frame), initial, horizon, step_size, strategy, window_size))
    upper = max(fold_total, 1)
    # upper is at least 1, so fold 1 is always a valid position.
    return gr.update(maximum=upper, value=1)
def run_visualizer(dataset_name, file_obj, strategy, initial, horizon, step_size, window_size, current_fold, show_all):
    """Core callback — returns (figure, metrics_df, summary_md).

    Args:
        dataset_name: Forwarded to ``load_dataset``.
        file_obj: Forwarded to ``load_dataset``.
        strategy: Forwarded to ``compute_folds``.
        initial: Forwarded to ``compute_folds``.
        horizon: Forwarded to ``compute_folds``.
        step_size: Forwarded to ``compute_folds``.
        window_size: Forwarded to ``compute_folds``.
        current_fold: 1-based fold shown in the two-panel view.
        show_all: Draw the all-folds Gantt map instead of a single fold.

    Returns:
        A tuple ``(fig, metrics_df, summary_md)``. ``metrics_df`` holds the
        per-fold metrics plus an "Avg" row. When no fold fits, the tuple is a
        placeholder figure, an empty table with the usual columns, and an
        explanatory message.

    Raises:
        gr.Error: If the data cannot be loaded.
    """
    try:
        df = load_dataset(dataset_name, file_obj)
    except gr.Error:
        raise
    except Exception as exc:
        raise gr.Error(f"Could not load data: {exc}") from exc

    folds = compute_folds(len(df), initial, horizon, step_size, strategy, window_size)

    if not folds:
        fig, ax = plt.subplots(figsize=(12, 4), facecolor=WHITE)
        ax.text(0.5, 0.5, "No valid folds — adjust parameters.",
                ha="center", va="center", fontsize=14, transform=ax.transAxes)
        ax.axis("off")
        empty_df = pd.DataFrame(columns=[
            "Fold", "Train Start", "Train End", "Test Start", "Test End",
            "Train Size", "MAE", "RMSE", "MAPE (%)"
        ])
        summary = "**No valid folds.** Reduce `initial` + `horizon` or increase data length."
        # Release pyplot's references on this path too. Otherwise every
        # empty-fold render stays registered with pyplot and figures accumulate.
        plt.close("all")
        return fig, empty_df, summary

    fig = _make_figure(df, folds, current_fold, show_all, strategy)
    metrics_df = naive_metrics(df["y"], folds)

    # Append an average row. MAPE may contain "N/A" strings; these are coerced
    # to NaN and skipped by mean().
    mape_vals = pd.to_numeric(metrics_df["MAPE (%)"], errors="coerce")
    avg_row = {
        "Fold": "Avg",
        "Train Start": "",
        "Train End": "",
        "Test Start": "",
        "Test End": "",
        "Train Size": "",
        "MAE": round(metrics_df["MAE"].mean(), 2),
        "RMSE": round(metrics_df["RMSE"].mean(), 2),
        "MAPE (%)": "N/A" if mape_vals.isna().all() else round(mape_vals.mean(), 2),
    }
    avg_df = pd.concat([metrics_df, pd.DataFrame([avg_row])], ignore_index=True)

    summary = build_summary(folds, strategy, initial, step_size, metrics_df)
    # Drop pyplot's registry entries. The returned Figure object itself is
    # still handed to the caller.
    plt.close("all")
    return fig, avg_df, summary
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_app():
    """Build and return the Gradio Blocks app for the CV visualizer.

    Layout:
    - A header banner.
    - A left column of data, strategy, parameter and animation controls.
    - A right column with the plot, the per-fold metrics table and a Markdown
      summary.

    Behaviour:
    - Every parameter control re-renders through ``refresh_and_run``.
    - The fold slider re-renders when released.
    - Play streams one frame per fold; Stop cancels that stream.

    Returns:
        gr.Blocks: The configured app (not launched).
    """
    # Brand theme: teal primary ramp (c300 = "#84d6d3") and red secondary ramp (c500 = "#C3142D").
    theme = gr.themes.Soft(
        primary_hue=gr.themes.Color(
            c50="#eafaf9", c100="#d4f5f3", c200="#aaecea",
            c300="#84d6d3", c400="#5ec4c0", c500="#3eaea9",
            c600="#2e938e", c700="#237873", c800="#1a5d59",
            c900="#12423f", c950="#0a2725",
        ),
        secondary_hue=gr.themes.Color(
            c50="#fef2f3", c100="#fde6e8", c200="#fbd0d5",
            c300="#f7a4ae", c400="#f17182", c500="#C3142D",
            c600="#b01228", c700="#8B0E1E", c800="#6e0b18",
            c900="#5c0d17", c950="#33040a",
        ),
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    )
    with gr.Blocks(
        title="Cross-Validation Visualizer v1.0",
        theme=theme,
        css="""
        .gradio-container { max-width: 1280px !important; margin: auto; }
        footer { display: none !important; }
        .gr-button-primary { background: #C3142D !important; border: none !important; }
        .gr-button-primary:hover { background: #8B0E1E !important; }
        .gr-button-secondary { border-color: #84d6d3 !important; color: #84d6d3 !important; }
        .gr-button-secondary:hover { background: #84d6d3 !important; color: white !important; }
        .gr-input:focus { border-color: #84d6d3 !important; box-shadow: 0 0 0 2px rgba(132,214,211,0.2) !important; }
        """,
    ) as demo:
        gr.HTML("""
        <div style="display: flex; align-items: center; gap: 16px; padding: 16px 24px;
                    background: linear-gradient(135deg, #C3142D 0%, #8B0E1E 100%);
                    border-radius: 12px; margin-bottom: 16px; box-shadow: 0 4px 12px rgba(0,0,0,0.15);">
            <img src="https://miamioh.edu/miami-brand/_files/images/system/logo-usage/minimum-size/beveled-m-min-size.png"
                 alt="Miami University" style="height: 56px;">
            <div>
                <h1 style="margin: 0; color: white; font-size: 24px; font-weight: 700; letter-spacing: -0.5px;">
                    Cross-Validation Visualizer v1.0
                </h1>
                <p style="margin: 4px 0 0; color: rgba(255,255,255,0.85); font-size: 14px;">
                    ISA 444: Business Forecasting · Farmer School of Business · Miami University
                </p>
            </div>
        </div>
        """)
        gr.HTML("""
        <div style="background: #f8f9fa; border-left: 4px solid #84d6d3; padding: 12px 16px;
                    border-radius: 0 8px 8px 0; margin-bottom: 16px; font-size: 14px; color: #585E60;">
            Visualize time-series cross-validation strategies (expanding window and rolling/sliding window)
            with animated fold progression and per-fold accuracy metrics using a naive forecast.
            Understand how forecast accuracy is evaluated across folds.
        </div>
        """)

        with gr.Row():
            # ---- Left column: controls ----
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("### Data")
                dataset_dd = gr.Dropdown(
                    choices=["Airline Passengers", "Ohio Employment",
                             "Simple Trend + Noise", "Upload CSV"],
                    value="Simple Trend + Noise",
                    label="Dataset",
                )
                csv_upload = gr.File(
                    label="Upload CSV (columns: ds, y)",
                    file_types=[".csv"],
                    visible=False,
                )
                gr.Markdown("### CV Strategy")
                strategy_radio = gr.Radio(
                    choices=["Expanding Window", "Rolling/Sliding Window"],
                    value="Expanding Window",
                    label="Strategy",
                )
                gr.Markdown("### Parameters")
                initial_slider = gr.Slider(
                    minimum=12, maximum=120, value=60, step=1,
                    label="initial (initial training size)",
                )
                horizon_slider = gr.Slider(
                    minimum=1, maximum=24, value=12, step=1,
                    label="horizon (forecast horizon)",
                )
                step_slider = gr.Slider(
                    minimum=1, maximum=12, value=3, step=1,
                    label="step (step size between folds)",
                )
                window_slider = gr.Slider(
                    minimum=12, maximum=120, value=60, step=1,
                    label="window_size (rolling window only)",
                    visible=False,
                )
                gr.Markdown("### Animation Controls")
                fold_slider = gr.Slider(
                    minimum=1, maximum=1, value=1, step=1,
                    label="Current Fold",
                )
                with gr.Row():
                    play_btn = gr.Button("Play Animation", variant="primary")
                    stop_btn = gr.Button("Stop", variant="stop")
                show_all_cb = gr.Checkbox(label="Show All Folds", value=False)

            # ---- Right column: outputs ----
            with gr.Column(scale=2, min_width=500):
                plot_output = gr.Plot(label="Visualization")
                metrics_output = gr.Dataframe(
                    label="Per-Fold Metrics (Naive Forecast)",
                    interactive=False,
                    wrap=True,
                )
                summary_output = gr.Markdown(label="Summary")

        # ---- Visibility toggles ----
        def toggle_csv_upload(name):
            return gr.update(visible=(name == "Upload CSV"))

        dataset_dd.change(toggle_csv_upload, inputs=[dataset_dd], outputs=[csv_upload])

        def toggle_window_slider(strategy):
            return gr.update(visible=(strategy == "Rolling/Sliding Window"))

        strategy_radio.change(toggle_window_slider, inputs=[strategy_radio], outputs=[window_slider])

        # ---- Gather all control inputs ----
        all_inputs = [
            dataset_dd, csv_upload, strategy_radio,
            initial_slider, horizon_slider, step_slider,
            window_slider, fold_slider, show_all_cb,
        ]
        all_outputs = [plot_output, metrics_output, summary_output]

        def refresh_and_run(dataset_name, file_obj, strategy, initial, horizon,
                            step_size, window_size, current_fold, show_all):
            """Update fold slider range, clamp current_fold, then run."""
            try:
                df = load_dataset(dataset_name, file_obj)
            except Exception:
                df = DATASETS["Simple Trend + Noise"]()
            n = len(df)
            folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)
            K = max(len(folds), 1)
            current_fold = max(1, min(current_fold, K))
            fig, metrics, summary = run_visualizer(
                dataset_name, file_obj, strategy, initial, horizon,
                step_size, window_size, current_fold, show_all,
            )
            return gr.update(maximum=K, value=current_fold), fig, metrics, summary

        combined_outputs = [fold_slider] + all_outputs

        # Trigger on any parameter change
        for ctrl in [dataset_dd, csv_upload, strategy_radio, initial_slider,
                     horizon_slider, step_slider, window_slider, show_all_cb]:
            ctrl.change(
                refresh_and_run,
                inputs=all_inputs,
                outputs=combined_outputs,
            )

        # Fold slider: re-render when the user releases it (no range update needed).
        fold_slider.release(
            run_visualizer,
            inputs=all_inputs,
            outputs=all_outputs,
        )

        # ---- Animation: a generator that streams one frame per fold ----
        # The dict in gr.State carries a "playing" flag that stop_animation clears.
        animation_state = gr.State({"playing": False})

        def start_animation(state, dataset_name, file_obj, strategy, initial,
                            horizon, step_size, window_size, current_fold, show_all):
            """Yield one frame per fold, about one second apart, until done or stopped.

            ``current_fold`` and ``show_all`` are accepted only because the
            inputs list is shared with the other handlers. Playback always
            starts at fold 1 in the two-panel view.
            """
            state["playing"] = True
            try:
                df = load_dataset(dataset_name, file_obj)
            except Exception:
                df = DATASETS["Simple Trend + Noise"]()
            n = len(df)
            folds = compute_folds(n, initial, horizon, step_size, strategy, window_size)
            K = max(len(folds), 1)
            for k in range(1, K + 1):
                if not state.get("playing", False):
                    break
                fig, metrics, summary = run_visualizer(
                    dataset_name, file_obj, strategy, initial, horizon,
                    step_size, window_size, k, False,
                )
                yield state, gr.update(maximum=K, value=k), fig, metrics, summary
                # Pause between frames only; there is no wait after the final fold.
                if k < K:
                    time.sleep(1.0)
            state["playing"] = False
            # Publish the reset flag only. The last rendered frame stays on
            # screen, and no per-frame names are needed if the loop never ran.
            yield state, gr.update(), gr.update(), gr.update(), gr.update()

        def stop_animation(state):
            state["playing"] = False
            return state

        play_event = play_btn.click(
            start_animation,
            inputs=[animation_state] + all_inputs,
            outputs=[animation_state, fold_slider] + all_outputs,
        )
        # cancels= stops the running generator directly. The state flag alone
        # may not be observed by the generator if Gradio gives each event its
        # own copy of the state value.
        stop_btn.click(
            stop_animation,
            inputs=[animation_state],
            outputs=[animation_state],
            cancels=[play_event],
        )

        # ---- Initial render on load ----
        demo.load(
            refresh_and_run,
            inputs=all_inputs,
            outputs=combined_outputs,
        )

        gr.HTML("""
        <div style="margin-top: 24px; padding: 16px; background: #f8f9fa; border-radius: 8px;
                    text-align: center; font-size: 13px; color: #585E60; border-top: 2px solid #84d6d3;">
            <div style="margin-bottom: 4px;">
                <strong style="color: #C3142D;">Developed by</strong>
                <a href="https://miamioh.edu/fsb/directory/?up=/directory/megahefm"
                   style="color: #84d6d3; text-decoration: none; font-weight: 600;">
                    Fadel M. Megahed
                </a>
                · Glos Professor in Business · Miami University
            </div>
            <div style="font-size: 12px; color: #888;">
                Version 1.0.0 · Spring 2026 ·
                <a href="https://github.com/fmegahed" style="color: #84d6d3; text-decoration: none;">GitHub</a> ·
                <a href="https://www.linkedin.com/in/fmegahed/" style="color: #84d6d3; text-decoration: none;">LinkedIn</a>
            </div>
        </div>
        """)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it without a public share link.
    app: gr.Blocks = build_app()
    app.launch(share=False)