# Scraped page header (HuggingFace Spaces): author fmegahed · commit b1894d3 "Update app.py" (verified).
"""
Decomposition Explorer
Interactive tool for exploring time-series decomposition methods.
Part of ISA 444: Business Forecasting at Miami University (Spring 2026).
Deployed to HuggingFace Spaces as fmegahed/decomposition-explorer.
"""
import io
import warnings
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from statsmodels.tsa.seasonal import STL, seasonal_decompose
# ---------------------------------------------------------------------------
# Color palette
# ---------------------------------------------------------------------------
CLR_PRIMARY = "#84d6d3"  # teal brand colour
CLR_ACCENT = "#C3142D"  # Miami red brand colour
# Component panels reuse the two brand colours; the remainder stays neutral grey.
CLR_TREND = CLR_ACCENT
CLR_SEASON = CLR_PRIMARY
CLR_RESID = "#666666"
# ---------------------------------------------------------------------------
# Built-in datasets
# ---------------------------------------------------------------------------
def _airline_passengers() -> pd.DataFrame:
"""Classic Box-Jenkins airline passengers (1949-1960, monthly)."""
try:
from statsmodels.datasets import co2 # noqa: F401
import statsmodels.api as sm
data = sm.datasets.get_rdataset("AirPassengers", "datasets").data
dates = pd.date_range(start="1949-01-01", periods=len(data), freq="MS")
# Prefer common value column names from Rdatasets
candidate_cols = [c for c in ["value", "passengers", "x"] if c in data.columns]
if candidate_cols:
y = pd.to_numeric(data[candidate_cols[0]], errors="coerce").to_numpy()
else:
# Fallback: take the last numeric column (and avoid obvious time columns)
numeric_cols = data.select_dtypes(include=["number"]).columns.tolist()
drop_cols = [c for c in ["time", "date", "year", "month"] if c in numeric_cols]
numeric_cols = [c for c in numeric_cols if c not in drop_cols]
if not numeric_cols:
raise ValueError(f"Could not identify value column in AirPassengers data: {list(data.columns)}")
y = pd.to_numeric(data[numeric_cols[-1]], errors="coerce").to_numpy()
return pd.DataFrame({"ds": dates, "y": y})
except Exception:
# Fallback: generate the well-known series manually
np.random.seed(0)
dates = pd.date_range("1949-01-01", "1960-12-01", freq="MS")
n = len(dates)
t = np.arange(n)
trend = 110 + 2.5 * t
seasonal_pattern = np.array(
[-24, -20, 2, -1, -5, 30, 47, 46, 14, -10, -25, -26]
)
season = np.tile(seasonal_pattern, n // 12 + 1)[:n]
noise = np.random.normal(0, 6, n)
y = trend + season * (1 + 0.02 * t) + noise
return pd.DataFrame({"ds": dates, "y": np.round(y, 1)})
def _us_retail_employment() -> pd.DataFrame:
"""Realistic synthetic monthly US retail employment (2000-2024)."""
np.random.seed(42)
dates = pd.date_range("2000-01-01", "2024-12-01", freq="MS")
n = len(dates)
t = np.arange(n)
# Trend: upward with dips around 2008-09 and 2020
trend = 15_000 + 12 * t
# 2008-2009 recession dip
recession_08 = -1400 * np.exp(-0.5 * ((t - 108) / 8) ** 2)
# 2020 COVID dip
covid_20 = -2800 * np.exp(-0.5 * ((t - 243) / 3) ** 2)
trend = trend + recession_08 + covid_20
# Seasonal pattern (retail peaks in Nov-Dec)
seasonal_pattern = np.array(
[-200, -350, -100, 50, 100, 150, 100, 80, -50, -100, 250, 500]
)
season = np.tile(seasonal_pattern, n // 12 + 1)[:n]
noise = np.random.normal(0, 60, n)
y = trend + season + noise
return pd.DataFrame({"ds": dates, "y": np.round(y, 1)})
def _ohio_nonfarm() -> pd.DataFrame:
"""Realistic synthetic monthly Ohio nonfarm employment (2010-2024)."""
np.random.seed(7)
dates = pd.date_range("2010-01-01", "2024-12-01", freq="MS")
n = len(dates)
t = np.arange(n)
trend = 5_100 + 4.5 * t
# COVID dip
covid = -650 * np.exp(-0.5 * ((t - 123) / 3) ** 2)
trend = trend + covid
seasonal_pattern = np.array(
[-80, -50, 30, 50, 70, 60, 20, 10, 30, 20, -30, -60]
)
season = np.tile(seasonal_pattern, n // 12 + 1)[:n]
noise = np.random.normal(0, 25, n)
y = trend + season + noise
return pd.DataFrame({"ds": dates, "y": np.round(y, 1)})
# Dropdown label -> zero-argument loader returning a DataFrame with columns ds, y.
# Insertion order is the order the choices appear in the dataset dropdown.
BUILTIN_DATASETS = {}
BUILTIN_DATASETS["Airline Passengers"] = _airline_passengers
BUILTIN_DATASETS["US Retail Employment"] = _us_retail_employment
BUILTIN_DATASETS["Ohio Nonfarm Employment"] = _ohio_nonfarm
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_dataset(name: str, csv_file) -> pd.DataFrame:
    """Return the series to decompose as a DataFrame with columns ``ds`` and ``y``.

    An uploaded CSV takes precedence over the dropdown choice.

    Args:
        name: Label of a built-in dataset (a key of ``BUILTIN_DATASETS``).
        csv_file: ``None``, a path string, or an object exposing the path as
            ``.name``. The CSV must contain ``ds`` (dates) and ``y`` (values).

    Returns:
        DataFrame sorted by ``ds`` with exactly the columns ``ds`` (datetime)
        and ``y`` (float). For uploads, rows whose date or value is missing or
        non-numeric are dropped.

    Raises:
        gr.Error: The CSV cannot be read or parsed, or *name* is unknown.
    """
    if csv_file is not None:
        path = getattr(csv_file, "name", csv_file)
        try:
            raw = pd.read_csv(path)
            if "ds" not in raw.columns or "y" not in raw.columns:
                raise ValueError("CSV must contain columns 'ds' and 'y'.")
            raw["ds"] = pd.to_datetime(raw["ds"])
            raw["y"] = pd.to_numeric(raw["y"], errors="coerce")
        except Exception as exc:
            raise gr.Error(f"Could not read uploaded CSV: {exc}") from exc
        # Blank dates become NaT and would corrupt the DatetimeIndex built
        # downstream, so drop them together with unparseable values. A stable
        # sort keeps file order for duplicate timestamps.
        return (
            raw[["ds", "y"]]
            .dropna(subset=["ds", "y"])
            .sort_values("ds", kind="stable")
            .reset_index(drop=True)
        )
    loader = BUILTIN_DATASETS.get(name)
    if loader is None:
        raise gr.Error(f"Unknown dataset: {name}")
    return loader()
def _ensure_odd(val: int) -> int:
"""Force a value to be odd (required by statsmodels windows)."""
val = int(val)
return val if val % 2 == 1 else val + 1
def _strength(residual: np.ndarray, component_plus_residual: np.ndarray) -> float:
"""Compute strength of a component: max(0, 1 - Var(R)/Var(C+R))."""
var_r = np.nanvar(residual)
var_cr = np.nanvar(component_plus_residual)
if var_cr == 0:
return 0.0
return float(max(0.0, 1.0 - var_r / var_cr))
# ---------------------------------------------------------------------------
# Core decomposition + plotting
# ---------------------------------------------------------------------------
def _run_decomposition(y_series, method, period, stl_seasonal, stl_trend, stl_robust):
    """Fit STL or classical decomposition on *y_series* and return the result.

    STL window lengths are forced odd, and ``stl_trend <= 0`` leaves the trend
    window to statsmodels. Warnings raised while fitting are suppressed. Any
    ``ValueError`` raised by statsmodels is re-raised as ``gr.Error`` so the
    message is shown in the UI instead of a generic failure.
    """
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if method == "STL":
                trend_window = _ensure_odd(stl_trend) if stl_trend > 0 else None
                return STL(
                    y_series,
                    period=period,
                    seasonal=_ensure_odd(stl_seasonal),
                    trend=trend_window,
                    robust=bool(stl_robust),
                ).fit()
            model_type = "additive" if "Additive" in method else "multiplicative"
            return seasonal_decompose(y_series, model=model_type, period=period)
    except ValueError as exc:
        raise gr.Error(f"Decomposition failed: {exc}") from exc


def _strength_measures(result, log_scale: bool):
    """Return ``(F_T, F_S, used_log)`` for a fitted decomposition *result*.

    Positions where any component is NaN (e.g. the ends left undefined by a
    classical moving average) are dropped. When *log_scale* is true and all
    remaining components are positive, they are log-transformed first. This
    turns a multiplicative model Y = T*S*R into an additive one
    (log Y = log T + log S + log R), so Var(R) and Var(T + R) are in the
    same units.
    """
    r = np.asarray(result.resid, dtype=float)
    t = np.asarray(result.trend, dtype=float)
    s = np.asarray(result.seasonal, dtype=float)
    keep = ~(np.isnan(r) | np.isnan(t) | np.isnan(s))
    r, t, s = r[keep], t[keep], s[keep]
    used_log = bool(log_scale and np.all(r > 0) and np.all(t > 0) and np.all(s > 0))
    if used_log:
        r, t, s = np.log(r), np.log(t), np.log(s)
    return _strength(r, t + r), _strength(r, s + r), used_log


def _plot_components(result, method: str, period: int):
    """Draw the observed / trend / seasonal / remainder panels and return the Figure."""
    fig, axes = plt.subplots(4, 1, figsize=(10, 8), sharex=True)
    fig.patch.set_facecolor("white")
    for ax in axes:
        ax.set_facecolor("white")
        ax.grid(True, linewidth=0.3, alpha=0.5)
    dates = result.observed.index
    # 1. Observed
    axes[0].plot(dates, result.observed, color=CLR_PRIMARY, linewidth=1.2)
    axes[0].set_ylabel("Observed", fontsize=10, fontweight="bold")
    # 2. Trend
    axes[1].plot(dates, result.trend, color=CLR_TREND, linewidth=1.4)
    axes[1].set_ylabel("Trend", fontsize=10, fontweight="bold")
    # 3. Seasonal
    axes[2].plot(dates, result.seasonal, color=CLR_SEASON, linewidth=1.0)
    axes[2].set_ylabel("Seasonal", fontsize=10, fontweight="bold")
    # 4. Residual
    axes[3].plot(dates, result.resid, color=CLR_RESID, linewidth=0.8, alpha=0.8)
    axes[3].set_ylabel("Remainder", fontsize=10, fontweight="bold")
    axes[3].set_xlabel("Date", fontsize=10)
    method_label = method if method == "STL" else method.replace("Classical ", "Classical – ")
    fig.suptitle(
        f"Decomposition · {method_label} · period = {period}",
        fontsize=13,
        fontweight="bold",
        y=0.98,
    )
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    # Every UI callback builds a new figure. Unregister it from pyplot so the
    # global figure manager does not keep them all alive. plt.close only
    # detaches the figure from pyplot; the returned Figure can still be drawn
    # and saved.
    plt.close(fig)
    return fig


def decompose_and_plot(
    dataset_name: str,
    csv_file,
    method: str,
    period: int,
    stl_seasonal: int,
    stl_trend: int,
    stl_robust: bool,
):
    """Run the selected decomposition and return ``(matplotlib Figure, summary string)``.

    Args:
        dataset_name: Built-in dataset label (ignored when *csv_file* is given).
        csv_file: Uploaded CSV (path or file-like with ``.name``) or ``None``.
        method: ``"STL"``, ``"Classical (Additive)"`` or ``"Classical (Multiplicative)"``.
        period: Season length; converted to ``int``.
        stl_seasonal: STL seasonal window (made odd).
        stl_trend: STL trend window (made odd); ``0`` means automatic.
        stl_robust: Whether STL uses robust fitting.

    Returns:
        The four-panel decomposition figure and a text summary of the trend
        and seasonality strength measures.

    Raises:
        gr.Error: Fewer than ``2 * period`` observations, an unreadable upload,
            or a decomposition rejected by statsmodels.
    """
    # Sliders may deliver floats; normalise once so checks and titles agree.
    period = int(period)
    df = _load_dataset(dataset_name, csv_file)
    if len(df) < 2 * period:
        raise gr.Error(
            f"Not enough observations ({len(df)}) for the chosen period ({period}). "
            f"Need at least {2 * period} observations."
        )
    y_series = pd.Series(df["y"].values, index=df["ds"])
    result = _run_decomposition(
        y_series, method, period, stl_seasonal, stl_trend, stl_robust
    )
    is_multiplicative = method != "STL" and "Additive" not in method
    f_trend, f_season, used_log = _strength_measures(result, log_scale=is_multiplicative)
    fig = _plot_components(result, method, period)
    summary = (
        f"Strength of Trend (F_T): {f_trend:.4f}\n"
        f"Strength of Seasonality (F_S): {f_season:.4f}\n\n"
        f"Formulas:\n"
        f" F_T = max(0, 1 − Var(R) / Var(T + R))\n"
        f" F_S = max(0, 1 − Var(R) / Var(S + R))"
    )
    if used_log:
        summary += "\n (multiplicative model: T, S, R are log-transformed first)"
    return fig, summary
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Miami-red ramp (c50 lightest -> c950 darkest); used for both primary and secondary hues.
_RED_STOPS = dict(
    c50="#fef2f3", c100="#fde6e8", c200="#fbd0d5",
    c300="#f7a4ae", c400="#f17182", c500="#C3142D",
    c600="#b01228", c700="#8B0E1E", c800="#6e0b18",
    c900="#5c0d17", c950="#33040a",
)
# Warm beige/grey ramp for the neutral hue.
_NEUTRAL_STOPS = dict(
    c50="#EDECE2", c100="#E5E4D9", c200="#DDDCD0",
    c300="#C8C7BC", c400="#A3A299", c500="#858479",
    c600="#6B6A61", c700="#53524B", c800="#3B3A35",
    c900="#252420", c950="#151410",
)
_THEME = gr.themes.Soft(
    primary_hue=gr.themes.Color(**_RED_STOPS),
    secondary_hue=gr.themes.Color(**_RED_STOPS),
    neutral_hue=gr.themes.Color(**_NEUTRAL_STOPS),
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)
_CSS = """
.gradio-container { max-width: 1280px !important; margin: auto; }
footer { display: none !important; }
.gr-button-primary { background: #C3142D !important; border: none !important; }
.gr-button-primary:hover { background: #8B0E1E !important; }
.gr-button-secondary { border-color: #C3142D !important; color: #C3142D !important; }
.gr-button-secondary:hover { background: #8B0E1E !important; color: white !important; }
.gr-input:focus { border-color: #C3142D !important; box-shadow: 0 0 0 2px rgba(195,20,45,0.2) !important; }
"""
def build_app() -> gr.Blocks:
    """Assemble the Decomposition Explorer Gradio UI and return the Blocks app.

    Layout, top to bottom:
    - a header banner and an intro note;
    - a row with a controls column (left) and an output column (right) that
      holds the plot and the strength-measure textbox;
    - a footer.

    Every control re-runs ``decompose_and_plot`` when its value changes, and
    the plot is also produced once on page load. ``gr.Blocks`` here receives
    only a title; no theme or CSS is passed in this function.
    """
    with gr.Blocks(title="Decomposition Explorer v1.0") as app:
        # Header banner: Miami logo (loaded from a GitHub raw URL) plus title/subtitle.
        gr.HTML("""
<div style="display: flex; align-items: center; gap: 16px; padding: 16px 24px;
background: linear-gradient(135deg, #C3142D 0%, #8B0E1E 100%);
border-radius: 12px; margin-bottom: 16px; box-shadow: 0 4px 12px rgba(0,0,0,0.15);">
<img src="https://raw.githubusercontent.com/fmegahed/isa401/main/figures/beveled-m-min-size.png"
alt="Miami University" style="height: 56px;">
<div>
<h1 style="margin: 0; color: white; font-size: 24px; font-weight: 700; letter-spacing: -0.5px;">
Decomposition Explorer v1.0
</h1>
<p style="margin: 4px 0 0; color: rgba(255,255,255,0.85); font-size: 14px;">
ISA 444: Business Forecasting &middot; Farmer School of Business &middot; Miami University
</p>
</div>
</div>
""")
        # Short description of what the tool does.
        gr.HTML("""
<div style="background: #EDECE2; border-left: 4px solid #C3142D; padding: 12px 16px;
border-radius: 0 8px 8px 0; margin-bottom: 16px; font-size: 14px; color: #585E60;">
Interactive tool for exploring time-series decomposition methods (Classical and STL).
Choose a built-in dataset or upload your own CSV, adjust decomposition parameters, and
examine trend, seasonal, and remainder components along with strength measures.
</div>
""")
        with gr.Row():
            # --- Left column: controls ------------------------------------
            with gr.Column(scale=1, min_width=280):
                dataset_dd = gr.Dropdown(
                    label="Dataset",
                    choices=list(BUILTIN_DATASETS.keys()),
                    value="Airline Passengers",
                )
                # type="filepath" hands the callback a path string (or None when empty).
                csv_upload = gr.File(
                    label="Or upload CSV (columns: ds, y)",
                    file_types=[".csv"],
                    type="filepath",
                )
                # These labels are matched by decompose_and_plot ("STL" / contains "Additive").
                method_radio = gr.Radio(
                    label="Decomposition Method",
                    choices=[
                        "Classical (Additive)",
                        "Classical (Multiplicative)",
                        "STL",
                    ],
                    value="STL",
                )
                period_slider = gr.Slider(
                    label="Period / Season Length",
                    minimum=2,
                    maximum=52,
                    step=1,
                    value=12,
                )
                # STL-specific controls
                # Visible initially because the default method is "STL";
                # toggled by toggle_stl below.
                stl_group = gr.Group(visible=True)
                with stl_group:
                    gr.Markdown("**STL Parameters**")
                    # min 7 with step 2 keeps the seasonal window odd.
                    stl_seasonal_slider = gr.Slider(
                        label="seasonal (seasonality window, odd)",
                        minimum=7,
                        maximum=51,
                        step=2,
                        value=13,
                    )
                    # Starting at 0 with step 2 yields even values; decompose_and_plot
                    # treats 0 as "auto" and passes other values through _ensure_odd.
                    stl_trend_slider = gr.Slider(
                        label="trend (trend window, odd; 0 = auto)",
                        minimum=0,
                        maximum=101,
                        step=2,
                        value=0,
                    )
                    stl_robust_cb = gr.Checkbox(
                        label="robust (robust to outliers)",
                        value=False,
                    )
            # --- Right column: output -------------------------------------
            with gr.Column(scale=3):
                plot_output = gr.Plot(label="Decomposition")
                summary_box = gr.Textbox(
                    label="Strength Measures",
                    lines=5,
                    interactive=False,
                )
        # --- Visibility toggle for STL controls ---------------------------
        def toggle_stl(method):
            # Returning a Group with only `visible` set updates the existing stl_group.
            return gr.Group(visible=(method == "STL"))
        method_radio.change(
            fn=toggle_stl,
            inputs=[method_radio],
            outputs=[stl_group],
        )
        # --- Gather all inputs --------------------------------------------
        # Order must match decompose_and_plot's positional parameters.
        all_inputs = [
            dataset_dd,
            csv_upload,
            method_radio,
            period_slider,
            stl_seasonal_slider,
            stl_trend_slider,
            stl_robust_cb,
        ]
        all_outputs = [plot_output, summary_box]
        # --- Wire change events -------------------------------------------
        # Any single control change recomputes the whole decomposition.
        for ctrl in all_inputs:
            ctrl.change(
                fn=decompose_and_plot,
                inputs=all_inputs,
                outputs=all_outputs,
            )
        # --- Initial load -------------------------------------------------
        app.load(
            fn=decompose_and_plot,
            inputs=all_inputs,
            outputs=all_outputs,
        )
        # Footer with author credits and links.
        gr.HTML("""
<div style="margin-top: 24px; padding: 16px; background: #EDECE2; border-radius: 8px;
text-align: center; font-size: 13px; color: #585E60; border-top: 2px solid #C3142D;">
<div style="margin-bottom: 4px;">
<strong style="color: #C3142D;">Developed by</strong>
<a href="https://miamioh.edu/fsb/directory/?up=/directory/megahefm"
style="color: #C3142D; text-decoration: none; font-weight: 600;">
Fadel M. Megahed
</a>
&middot; Gloss Professor of Analytics &middot; Miami University
</div>
<div style="font-size: 12px; color: #888;">
Version 1.0.0 &middot; Spring 2026 &middot;
<a href="https://github.com/fmegahed" style="color: #C3142D; text-decoration: none;">GitHub</a> &middot;
<a href="https://www.linkedin.com/in/fmegahed/" style="color: #C3142D; text-decoration: none;">LinkedIn</a>
</div>
</div>
""")
    return app
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Theme, CSS and server options are all supplied to launch(); SSR is turned off.
    launch_options = {
        "theme": _THEME,
        "css": _CSS,
        "ssr_mode": False,
        "allowed_paths": ["beveled-m-min-size.png"],
    }
    demo = build_app()
    demo.launch(**launch_options)