fmegahed commited on
Commit
0191ae7
·
0 Parent(s):

Initial deploy: Time Series Visualizer v0.1.0

Browse files
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ RUN apt-get update && \
4
+ apt-get install -y --no-install-recommends gcc g++ && \
5
+ rm -rf /var/lib/apt/lists/*
6
+
7
+ RUN useradd -m -u 1000 user
8
+ USER user
9
+ ENV PATH="/home/user/.local/bin:$PATH"
10
+
11
+ WORKDIR /app
12
+
13
+ COPY --chown=user:user requirements.txt .
14
+ RUN pip install --no-cache-dir --upgrade pip && \
15
+ pip install --no-cache-dir -r requirements.txt
16
+
17
+ COPY --chown=user:user . .
18
+
19
+ EXPOSE 7860
20
+
21
+ CMD ["streamlit", "run", "app.py", \
22
+ "--server.port=7860", \
23
+ "--server.address=0.0.0.0", \
24
+ "--browser.gatherUsageStats=false"]
README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Time Series Visualizer
3
+ emoji: 📈
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # Time Series Visualizer + AI Chart Interpreter
12
+
13
+ A Streamlit app for Miami University Business Analytics students to upload CSV
14
+ time-series data, create publication-quality charts, and get AI-powered chart
15
+ interpretation.
16
+
17
+ ## Features
18
+
19
+ - **Upload & Clean** — auto-detect delimiters, date columns, and numeric formats
20
+ - **9+ Chart Types** — line, seasonal, subseries, ACF/PACF, decomposition, rolling, YoY, lag, spaghetti
21
+ - **Multi-Series Support** — panel (small-multiples) and spaghetti plots for comparing series
22
+ - **AI Interpretation** — GPT-5.2 vision analyzes chart images and returns structured insights
23
+ - **QueryChat** — natural-language data filtering powered by DuckDB
24
+
25
+ ## Privacy
26
+
27
+ All data processing happens in-memory. No data is persisted to disk.
28
+ Only chart PNG images (never raw data) are sent to the AI when you click "Interpret."
app.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Time Series Visualizer + AI Chart Interpreter
3
+ =============================================
4
+ Main Streamlit application. Run with:
5
+
6
+ streamlit run app.py --server.port=7860
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import hashlib
12
+ from pathlib import Path
13
+
14
+ from dotenv import load_dotenv
15
+ load_dotenv()
16
+
17
+ import matplotlib
18
+ matplotlib.use("Agg")
19
+
20
+ import pandas as pd
21
+ import streamlit as st
22
+
23
+ from src.ui_theme import (
24
+ apply_miami_theme,
25
+ get_miami_mpl_style,
26
+ get_palette_colors,
27
+ render_palette_preview,
28
+ )
29
+ from src.cleaning import (
30
+ read_csv_upload,
31
+ suggest_date_columns,
32
+ suggest_numeric_columns,
33
+ clean_dataframe,
34
+ detect_frequency,
35
+ add_time_features,
36
+ CleaningReport,
37
+ FrequencyInfo,
38
+ )
39
+ from src.diagnostics import (
40
+ compute_summary_stats,
41
+ compute_acf_pacf,
42
+ compute_decomposition,
43
+ compute_rolling_stats,
44
+ compute_yoy_change,
45
+ compute_multi_series_summary,
46
+ )
47
+ from src.plotting import (
48
+ fig_to_png_bytes,
49
+ plot_line_with_markers,
50
+ plot_line_colored_markers,
51
+ plot_seasonal,
52
+ plot_seasonal_subseries,
53
+ plot_acf_pacf,
54
+ plot_decomposition,
55
+ plot_rolling_overlay,
56
+ plot_yoy_change,
57
+ plot_lag,
58
+ plot_panel,
59
+ plot_spaghetti,
60
+ )
61
+ from src.ai_interpretation import (
62
+ check_api_key_available,
63
+ interpret_chart,
64
+ render_interpretation,
65
+ )
66
+ from src.querychat_helpers import (
67
+ check_querychat_available,
68
+ create_querychat,
69
+ get_filtered_pandas_df,
70
+ )
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Constants
74
+ # ---------------------------------------------------------------------------
75
+ _DATA_DIR = Path(__file__).parent / "data"
76
+ _DEMO_FILES = {
77
+ "Monthly Retail Sales (single)": _DATA_DIR / "demo_single.csv",
78
+ "Quarterly Revenue by Region (wide)": _DATA_DIR / "demo_multi_wide.csv",
79
+ "Daily Stock Prices – 20 Tickers (long)": _DATA_DIR / "demo_multi_long.csv",
80
+ }
81
+
82
+ _CHART_TYPES = [
83
+ "Line with Markers",
84
+ "Line – Colored Markers",
85
+ "Seasonal Plot",
86
+ "Seasonal Sub-series",
87
+ "ACF / PACF",
88
+ "Decomposition",
89
+ "Rolling Mean Overlay",
90
+ "Year-over-Year Change",
91
+ "Lag Plot",
92
+ ]
93
+
94
+ _PALETTE_NAMES = ["Set2", "Dark2", "Set1", "Paired", "Pastel1", "Pastel2", "Accent"]
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Helpers
99
+ # ---------------------------------------------------------------------------
100
+
101
+ def _df_hash(df: pd.DataFrame) -> str:
102
+ """Fast hash of a DataFrame for cache-key / change-detection."""
103
+ return hashlib.md5(
104
+ pd.util.hash_pandas_object(df).values.tobytes()
105
+ ).hexdigest()
106
+
107
+
108
+ def _load_demo(path: Path) -> pd.DataFrame:
109
+ return pd.read_csv(path)
110
+
111
+
112
def _render_cleaning_report(report: CleaningReport) -> None:
    """Render the cleaning report: row/duplicate metrics plus optional
    expanders for missing-value counts and parsing warnings."""
    top = st.columns(3)
    top[0].metric("Rows before", f"{report.rows_before:,}")
    top[1].metric("Rows after", f"{report.rows_after:,}")
    top[2].metric("Duplicates found", f"{report.duplicates_found:,}")

    if report.missing_before:
        with st.expander("Missing values"):
            affected = list(report.missing_before.keys())
            before_col, after_col = st.columns(2)
            with before_col:
                st.write("**Before cleaning**")
                for name in affected:
                    st.write(f"- {name}: {report.missing_before[name]}")
            with after_col:
                st.write("**After cleaning**")
                for name in affected:
                    # .get(): a column may have been fully repaired/removed
                    st.write(f"- {name}: {report.missing_after.get(name, 0)}")

    if report.parsing_warnings:
        with st.expander("Parsing warnings"):
            for warning in report.parsing_warnings:
                st.warning(warning)
136
+
137
+
138
def _render_summary_stats(stats) -> None:
    """Display summary statistics as two rows of metric cards plus a
    Trend & Stationarity expander (OLS slope and ADF test results)."""
    top_row = st.columns(4)
    top_row[0].metric("Count", f"{stats.count:,}")
    top_row[1].metric("Missing", f"{stats.missing_count} ({stats.missing_pct:.1f}%)")
    top_row[2].metric("Mean", f"{stats.mean_val:,.2f}")
    top_row[3].metric("Std Dev", f"{stats.std_val:,.2f}")

    bottom_row = st.columns(4)
    bottom_row[0].metric("Min", f"{stats.min_val:,.2f}")
    bottom_row[1].metric("25th %ile", f"{stats.p25:,.2f}")
    bottom_row[2].metric("Median", f"{stats.median_val:,.2f}")
    bottom_row[3].metric("75th %ile / Max", f"{stats.p75:,.2f} / {stats.max_val:,.2f}")

    with st.expander("Trend & Stationarity"):
        slope_col, trend_p_col = st.columns(2)
        # pd.notna guards: these stats may be NaN for short/degenerate series
        slope_text = (
            f"{stats.trend_slope:,.4f}" if pd.notna(stats.trend_slope) else "N/A"
        )
        slope_col.metric(
            "Trend slope (per period)",
            slope_text,
            help="Slope from OLS on a numeric index.",
        )
        trend_p_text = (
            f"{stats.trend_pvalue:.4f}" if pd.notna(stats.trend_pvalue) else "N/A"
        )
        trend_p_col.metric("Trend p-value", trend_p_text)

        adf_stat_col, adf_p_col = st.columns(2)
        adf_stat_text = (
            f"{stats.adf_statistic:.4f}" if pd.notna(stats.adf_statistic) else "N/A"
        )
        adf_stat_col.metric(
            "ADF statistic",
            adf_stat_text,
            help="Augmented Dickey-Fuller test statistic.",
        )
        adf_p_text = (
            f"{stats.adf_pvalue:.4f}" if pd.notna(stats.adf_pvalue) else "N/A"
        )
        adf_p_col.metric(
            "ADF p-value",
            adf_p_text,
            help="p < 0.05 suggests the series is stationary.",
        )
        st.caption(
            f"Date range: {stats.date_start.date()} to {stats.date_end.date()} "
            f"({stats.date_span_days:,} days)"
        )
178
+
179
+
180
+ # ---------------------------------------------------------------------------
181
+ # Page config
182
+ # ---------------------------------------------------------------------------
183
+ st.set_page_config(
184
+ page_title="Time Series Visualizer",
185
+ page_icon="\U0001f4c8",
186
+ layout="wide",
187
+ )
188
+ apply_miami_theme()
189
+ style_dict = get_miami_mpl_style()
190
+
191
+ # ---------------------------------------------------------------------------
192
+ # Session state initialisation
193
+ # ---------------------------------------------------------------------------
194
+ for key in [
195
+ "raw_df", "cleaned_df", "cleaning_report", "freq_info",
196
+ "date_col", "y_cols", "qc", "qc_hash",
197
+ ]:
198
+ if key not in st.session_state:
199
+ st.session_state[key] = None
200
+
201
+ # ---------------------------------------------------------------------------
202
+ # Sidebar — Data input
203
+ # ---------------------------------------------------------------------------
204
+ with st.sidebar:
205
+ st.markdown(
206
+ """
207
+ <div style="text-align:center; margin-bottom:0.5rem;">
208
+ <span style="font-size:1.6rem; font-weight:800; color:#C41230;">
209
+ Time Series Visualizer
210
+ </span><br>
211
+ <span style="font-size:0.82rem; color:#000;">
212
+ ISA 444 &middot; Miami University
213
+ </span>
214
+ </div>
215
+ """,
216
+ unsafe_allow_html=True,
217
+ )
218
+ st.divider()
219
+ st.header("Data Input")
220
+
221
+ uploaded = st.file_uploader("Upload a CSV file", type=["csv", "tsv", "txt"])
222
+
223
+ demo_choice = st.selectbox(
224
+ "Or load a demo dataset",
225
+ ["(none)"] + list(_DEMO_FILES.keys()),
226
+ )
227
+
228
+ # Load data
229
+ if uploaded is not None:
230
+ df_raw, delim = read_csv_upload(uploaded)
231
+ st.caption(f"Detected delimiter: `{repr(delim)}`")
232
+ st.session_state.raw_df = df_raw
233
+ elif demo_choice != "(none)":
234
+ st.session_state.raw_df = _load_demo(_DEMO_FILES[demo_choice])
235
+ # else: keep whatever was already in session state
236
+
237
+ raw_df: pd.DataFrame | None = st.session_state.raw_df
238
+
239
+ if raw_df is not None:
240
+ st.divider()
241
+ st.subheader("Column Selection")
242
+
243
+ # Auto-suggest
244
+ date_suggestions = suggest_date_columns(raw_df)
245
+ numeric_suggestions = suggest_numeric_columns(raw_df)
246
+
247
+ all_cols = list(raw_df.columns)
248
+ default_date_idx = all_cols.index(date_suggestions[0]) if date_suggestions else 0
249
+
250
+ date_col = st.selectbox("Date column", all_cols, index=default_date_idx)
251
+
252
+ remaining = [c for c in all_cols if c != date_col]
253
+ default_y = [c for c in numeric_suggestions if c != date_col]
254
+ y_cols = st.multiselect(
255
+ "Value column(s)",
256
+ remaining,
257
+ default=default_y[:4] if default_y else [],
258
+ )
259
+
260
+ st.session_state.date_col = date_col
261
+ st.session_state.y_cols = y_cols
262
+
263
+ st.divider()
264
+ st.subheader("Cleaning Options")
265
+ dup_action = st.selectbox(
266
+ "Duplicate dates",
267
+ ["keep_last", "keep_first", "drop_all"],
268
+ )
269
+ missing_action = st.selectbox(
270
+ "Missing values",
271
+ ["interpolate", "ffill", "drop"],
272
+ )
273
+
274
+ # Clean
275
+ if y_cols:
276
+ cleaned_df, report = clean_dataframe(
277
+ raw_df, date_col, y_cols,
278
+ dup_action=dup_action,
279
+ missing_action=missing_action,
280
+ )
281
+ freq_info = detect_frequency(cleaned_df, date_col)
282
+ cleaned_df = add_time_features(cleaned_df, date_col)
283
+
284
+ st.session_state.cleaned_df = cleaned_df
285
+ st.session_state.cleaning_report = report
286
+ st.session_state.freq_info = freq_info
287
+
288
+ st.caption(f"Frequency: **{freq_info.label}** "
289
+ f"({'regular' if freq_info.is_regular else 'irregular'})")
290
+
291
+ # Frequency override
292
+ freq_override = st.text_input(
293
+ "Override frequency label (optional)",
294
+ value="",
295
+ help="e.g. Daily, Weekly, Monthly, Quarterly, Yearly",
296
+ )
297
+ if freq_override.strip():
298
+ st.session_state.freq_info = FrequencyInfo(
299
+ label=freq_override.strip(),
300
+ median_delta=freq_info.median_delta,
301
+ is_regular=freq_info.is_regular,
302
+ )
303
+
304
+ # ------ QueryChat ------
305
+ if check_querychat_available():
306
+ current_hash = _df_hash(cleaned_df) + str(y_cols)
307
+ if st.session_state.qc_hash != current_hash:
308
+ st.session_state.qc = create_querychat(
309
+ cleaned_df,
310
+ name="uploaded data",
311
+ date_col=date_col,
312
+ y_cols=y_cols,
313
+ freq_label=st.session_state.freq_info.label,
314
+ )
315
+ st.session_state.qc_hash = current_hash
316
+ st.divider()
317
+ st.subheader("QueryChat")
318
+ st.session_state.qc.ui()
319
+ else:
320
+ st.divider()
321
+ st.info(
322
+ "Set `OPENAI_API_KEY` to enable QueryChat "
323
+ "(natural-language data filtering)."
324
+ )
325
+
326
+ # Reset button
327
+ st.divider()
328
+ if st.button("Reset all"):
329
+ for k in list(st.session_state.keys()):
330
+ del st.session_state[k]
331
+ st.rerun()
332
+
333
+ st.divider()
334
+ st.markdown(
335
+ """
336
+ <div style="text-align:center; padding:0.5rem 0;">
337
+ <span style="font-size:0.75rem; color:#000;">
338
+ Developed by <strong>Fadel M. Megahed</strong><br>
339
+ for <strong>ISA 444</strong> &middot; Miami University<br>
340
+ Version <strong>0.1.0</strong>
341
+ </span>
342
+ </div>
343
+ """,
344
+ unsafe_allow_html=True,
345
+ )
346
+ st.caption(
347
+ "**Privacy:** All processing is in-memory. "
348
+ "Only chart images (never raw data) are sent to the AI when you click Interpret."
349
+ )
350
+
351
+ # ---------------------------------------------------------------------------
352
+ # Main area — guard
353
+ # ---------------------------------------------------------------------------
354
+ cleaned_df: pd.DataFrame | None = st.session_state.cleaned_df
355
+ date_col: str | None = st.session_state.date_col
356
+ y_cols: list[str] | None = st.session_state.y_cols
357
+ freq_info: FrequencyInfo | None = st.session_state.freq_info
358
+ report: CleaningReport | None = st.session_state.cleaning_report
359
+
360
+ if cleaned_df is None or not y_cols:
361
+ st.title("Time Series Visualizer")
362
+ st.write(
363
+ "Upload a CSV or choose a demo dataset from the sidebar to get started."
364
+ )
365
+ st.stop()
366
+
367
+ # If QueryChat is active, use its filtered df
368
+ if st.session_state.qc is not None:
369
+ working_df = get_filtered_pandas_df(st.session_state.qc)
370
+ if working_df.empty:
371
+ working_df = cleaned_df
372
+ else:
373
+ working_df = cleaned_df
374
+
375
+ # Data quality report
376
+ if report is not None:
377
+ with st.expander("Data Quality Report", expanded=False):
378
+ _render_cleaning_report(report)
379
+
380
+ # ---------------------------------------------------------------------------
381
+ # Tabs
382
+ # ---------------------------------------------------------------------------
383
+ tab_single, tab_few, tab_many = st.tabs([
384
+ "Single Series",
385
+ "Few Series (Panel)",
386
+ "Many Series (Spaghetti)",
387
+ ])
388
+
389
+ # ===================================================================
390
+ # Tab A — Single Series
391
+ # ===================================================================
392
+ with tab_single:
393
+ if len(y_cols) == 1:
394
+ active_y = y_cols[0]
395
+ else:
396
+ active_y = st.selectbox("Select value column", y_cols, key="tab_a_y")
397
+
398
+ # ---- Date range filter ------------------------------------------------
399
+ dr_mode = st.radio(
400
+ "Date range",
401
+ ["All", "Last N years", "Custom"],
402
+ horizontal=True,
403
+ key="dr_mode",
404
+ )
405
+ df_plot = working_df.copy()
406
+ if dr_mode == "Last N years":
407
+ n_years = st.slider("Years", 1, 20, 5, key="dr_n")
408
+ cutoff = df_plot[date_col].max() - pd.DateOffset(years=n_years)
409
+ df_plot = df_plot[df_plot[date_col] >= cutoff]
410
+ elif dr_mode == "Custom":
411
+ d_min = df_plot[date_col].min().date()
412
+ d_max = df_plot[date_col].max().date()
413
+ sel = st.slider("Date range", d_min, d_max, (d_min, d_max), key="dr_custom")
414
+ df_plot = df_plot[
415
+ (df_plot[date_col].dt.date >= sel[0])
416
+ & (df_plot[date_col].dt.date <= sel[1])
417
+ ]
418
+
419
+ if df_plot.empty:
420
+ st.warning("No data in selected range.")
421
+ st.stop()
422
+
423
+ # ---- Chart controls ---------------------------------------------------
424
+ col_chart, col_opts = st.columns([2, 1])
425
+ with col_opts:
426
+ chart_type = st.selectbox("Chart type", _CHART_TYPES, key="chart_type_a")
427
+
428
+ palette_name = st.selectbox("Color palette", _PALETTE_NAMES, key="pal_a")
429
+ n_colors = max(12, len(y_cols))
430
+ palette_colors = get_palette_colors(palette_name, n_colors)
431
+ swatch_fig = render_palette_preview(palette_colors[:8])
432
+ st.pyplot(swatch_fig, width="stretch")
433
+
434
+ # Chart-specific controls
435
+ period_label = "month"
436
+ window_size = 12
437
+ lag_val = 1
438
+ decomp_model = "additive"
439
+
440
+ if chart_type in ("Seasonal Plot", "Seasonal Sub-series"):
441
+ period_label = st.selectbox("Period", ["month", "quarter"], key="period_a")
442
+
443
+ if chart_type == "Rolling Mean Overlay":
444
+ window_size = st.slider("Window", 2, 52, 12, key="window_a")
445
+
446
+ if chart_type == "Lag Plot":
447
+ lag_val = st.slider("Lag", 1, 52, 1, key="lag_a")
448
+
449
+ if chart_type == "Decomposition":
450
+ decomp_model = st.selectbox("Model", ["additive", "multiplicative"], key="decomp_a")
451
+
452
+ # ---- Render chart -----------------------------------------------------
453
+ with col_chart:
454
+ fig = None
455
+ try:
456
+ if chart_type == "Line with Markers":
457
+ fig = plot_line_with_markers(
458
+ df_plot, date_col, active_y,
459
+ title=f"{active_y} over Time",
460
+ style_dict=style_dict, palette_colors=palette_colors,
461
+ )
462
+
463
+ elif chart_type == "Line – Colored Markers":
464
+ if "month" in df_plot.columns:
465
+ color_by = st.selectbox(
466
+ "Color by",
467
+ ["month", "quarter", "year", "day_of_week"],
468
+ key="color_by_a",
469
+ )
470
+ else:
471
+ color_by = st.selectbox("Color by", [c for c in df_plot.columns if c not in (date_col, active_y)][:5], key="color_by_a")
472
+ fig = plot_line_colored_markers(
473
+ df_plot, date_col, active_y,
474
+ color_by=color_by, palette_colors=palette_colors,
475
+ title=f"{active_y} colored by {color_by}",
476
+ style_dict=style_dict,
477
+ )
478
+
479
+ elif chart_type == "Seasonal Plot":
480
+ fig = plot_seasonal(
481
+ df_plot, date_col, active_y,
482
+ period=period_label,
483
+ palette_name_colors=palette_colors,
484
+ title=f"Seasonal Plot – {active_y}",
485
+ style_dict=style_dict,
486
+ )
487
+
488
+ elif chart_type == "Seasonal Sub-series":
489
+ fig = plot_seasonal_subseries(
490
+ df_plot, date_col, active_y,
491
+ period=period_label,
492
+ title=f"Seasonal Sub-series – {active_y}",
493
+ style_dict=style_dict, palette_colors=palette_colors,
494
+ )
495
+
496
+ elif chart_type == "ACF / PACF":
497
+ series = df_plot[active_y].dropna()
498
+ acf_vals, acf_ci, pacf_vals, pacf_ci = compute_acf_pacf(series)
499
+ fig = plot_acf_pacf(
500
+ acf_vals, acf_ci, pacf_vals, pacf_ci,
501
+ title=f"ACF / PACF – {active_y}",
502
+ style_dict=style_dict,
503
+ )
504
+
505
+ elif chart_type == "Decomposition":
506
+ period_int = None
507
+ if freq_info and freq_info.label == "Monthly":
508
+ period_int = 12
509
+ elif freq_info and freq_info.label == "Quarterly":
510
+ period_int = 4
511
+ elif freq_info and freq_info.label == "Weekly":
512
+ period_int = 52
513
+ elif freq_info and freq_info.label == "Daily":
514
+ period_int = 365
515
+
516
+ result = compute_decomposition(
517
+ df_plot, date_col, active_y,
518
+ model=decomp_model, period=period_int,
519
+ )
520
+ fig = plot_decomposition(
521
+ result,
522
+ title=f"Decomposition – {active_y} ({decomp_model})",
523
+ style_dict=style_dict,
524
+ )
525
+
526
+ elif chart_type == "Rolling Mean Overlay":
527
+ fig = plot_rolling_overlay(
528
+ df_plot, date_col, active_y,
529
+ window=window_size,
530
+ title=f"Rolling {window_size}-pt Mean – {active_y}",
531
+ style_dict=style_dict, palette_colors=palette_colors,
532
+ )
533
+
534
+ elif chart_type == "Year-over-Year Change":
535
+ yoy_result = compute_yoy_change(df_plot, date_col, active_y)
536
+ yoy_df = pd.DataFrame({
537
+ "date": yoy_result[date_col],
538
+ "abs_change": yoy_result["yoy_abs_change"],
539
+ "pct_change": yoy_result["yoy_pct_change"],
540
+ }).dropna()
541
+ fig = plot_yoy_change(
542
+ df_plot, date_col, active_y, yoy_df,
543
+ title=f"Year-over-Year Change – {active_y}",
544
+ style_dict=style_dict,
545
+ )
546
+
547
+ elif chart_type == "Lag Plot":
548
+ fig = plot_lag(
549
+ df_plot[active_y],
550
+ lag=lag_val,
551
+ title=f"Lag-{lag_val} Plot – {active_y}",
552
+ style_dict=style_dict,
553
+ )
554
+
555
+ except Exception as exc:
556
+ st.error(f"Chart error: {exc}")
557
+
558
+ if fig is not None:
559
+ st.pyplot(fig, width="stretch")
560
+
561
+ # ---- Summary stats expander -------------------------------------------
562
+ with st.expander("Summary Statistics", expanded=False):
563
+ stats = compute_summary_stats(df_plot, date_col, active_y)
564
+ _render_summary_stats(stats)
565
+
566
+ # ---- AI Interpretation ------------------------------------------------
567
+ with st.expander("AI Chart Interpretation", expanded=False):
568
+ st.caption(
569
+ "The chart image (PNG) and metadata are sent to OpenAI. "
570
+ "No raw data leaves this app."
571
+ )
572
+ if not check_api_key_available():
573
+ st.warning("Set `OPENAI_API_KEY` to enable AI interpretation.")
574
+ elif fig is not None:
575
+ if st.button("Interpret Chart with AI", key="interpret_a"):
576
+ with st.spinner("Analyzing chart..."):
577
+ png = fig_to_png_bytes(fig)
578
+ date_range_str = (
579
+ f"{df_plot[date_col].min().date()} to "
580
+ f"{df_plot[date_col].max().date()}"
581
+ )
582
+ metadata = {
583
+ "chart_type": chart_type,
584
+ "frequency_label": freq_info.label if freq_info else "Unknown",
585
+ "date_range": date_range_str,
586
+ "y_column": active_y,
587
+ }
588
+ interp = interpret_chart(png, metadata)
589
+ render_interpretation(interp)
590
+
591
+ # ===================================================================
592
+ # Tab B — Few Series (Panel)
593
+ # ===================================================================
594
+ with tab_few:
595
+ if len(y_cols) < 2:
596
+ st.info("Select 2+ value columns in the sidebar to use panel plots.")
597
+ else:
598
+ st.subheader("Panel Plot (Small Multiples)")
599
+
600
+ panel_cols = st.multiselect(
601
+ "Columns to plot",
602
+ y_cols,
603
+ default=y_cols[:4],
604
+ key="panel_cols",
605
+ )
606
+
607
+ if panel_cols:
608
+ pc1, pc2 = st.columns(2)
609
+ with pc1:
610
+ panel_chart = st.selectbox(
611
+ "Chart type", ["line", "bar"], key="panel_chart"
612
+ )
613
+ with pc2:
614
+ shared_y = st.checkbox("Shared Y axis", value=True, key="panel_shared")
615
+
616
+ palette_name_b = st.selectbox("Color palette", _PALETTE_NAMES, key="pal_b")
617
+ palette_b = get_palette_colors(palette_name_b, len(panel_cols))
618
+
619
+ try:
620
+ fig_panel = plot_panel(
621
+ working_df, date_col, panel_cols,
622
+ chart_type=panel_chart,
623
+ shared_y=shared_y,
624
+ title="Panel Comparison",
625
+ style_dict=style_dict,
626
+ palette_colors=palette_b,
627
+ )
628
+ st.pyplot(fig_panel, width="stretch")
629
+ except Exception as exc:
630
+ st.error(f"Panel chart error: {exc}")
631
+
632
+ # Per-series summary table
633
+ with st.expander("Per-series Summary", expanded=False):
634
+ summary_df = compute_multi_series_summary(
635
+ working_df, date_col, panel_cols,
636
+ )
637
+ st.dataframe(
638
+ summary_df.style.format({
639
+ "mean": "{:,.2f}",
640
+ "std": "{:,.2f}",
641
+ "min": "{:,.2f}",
642
+ "max": "{:,.2f}",
643
+ "trend_slope": "{:,.4f}",
644
+ "adf_pvalue": "{:.4f}",
645
+ }),
646
+ width="stretch",
647
+ )
648
+
649
+ # ===================================================================
650
+ # Tab C — Many Series (Spaghetti)
651
+ # ===================================================================
652
+ with tab_many:
653
+ if len(y_cols) < 2:
654
+ st.info("Select 2+ value columns in the sidebar to use spaghetti plots.")
655
+ else:
656
+ st.subheader("Spaghetti Plot")
657
+
658
+ spag_cols = st.multiselect(
659
+ "Columns to include",
660
+ y_cols,
661
+ default=y_cols,
662
+ key="spag_cols",
663
+ )
664
+
665
+ if spag_cols:
666
+ sc1, sc2, sc3 = st.columns(3)
667
+ with sc1:
668
+ alpha_val = st.slider("Alpha", 0.05, 1.0, 0.15, 0.05, key="spag_alpha")
669
+ with sc2:
670
+ top_n = st.number_input("Highlight top N", 0, len(spag_cols), 0, key="spag_topn")
671
+ top_n = top_n if top_n > 0 else None
672
+ with sc3:
673
+ highlight = st.selectbox(
674
+ "Highlight series",
675
+ ["(none)"] + spag_cols,
676
+ key="spag_highlight",
677
+ )
678
+ highlight_col = highlight if highlight != "(none)" else None
679
+
680
+ show_median = st.checkbox("Show Median + IQR band", value=False, key="spag_median")
681
+
682
+ palette_name_c = st.selectbox("Color palette", _PALETTE_NAMES, key="pal_c")
683
+ palette_c = get_palette_colors(palette_name_c, len(spag_cols))
684
+
685
+ try:
686
+ fig_spag = plot_spaghetti(
687
+ working_df, date_col, spag_cols,
688
+ alpha=alpha_val,
689
+ highlight_col=highlight_col,
690
+ top_n=top_n,
691
+ show_median_band=show_median,
692
+ title="Spaghetti Plot",
693
+ style_dict=style_dict,
694
+ palette_colors=palette_c,
695
+ )
696
+ st.pyplot(fig_spag, width="stretch")
697
+ except Exception as exc:
698
+ st.error(f"Spaghetti chart error: {exc}")
data/demo_multi_long.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/demo_multi_wide.csv ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ date,North,South,East,West
2
+ 2017-01-01,102373.1,81565.82,120039.01,85866.99
3
+ 2017-04-01,103071.84,86690.95,130160.6,92986.52
4
+ 2017-07-01,105808.38,82351.48,120806.03,93145.11
5
+ 2017-10-01,93194.45,78439.34,125560.51,88941.36
6
+ 2018-01-01,104960.57,81159.93,125077.0,94745.14
7
+ 2018-04-01,115571.37,89696.76,126428.53,110558.19
8
+ 2018-07-01,101828.39,85679.22,121587.32,96512.67
9
+ 2018-10-01,98901.11,78456.95,122047.42,94006.7
10
+ 2019-01-01,106698.95,91997.32,125729.61,99262.01
11
+ 2019-04-01,110689.57,93621.5,134342.0,104154.17
12
+ 2019-07-01,103348.01,84426.09,129419.71,97054.19
13
+ 2019-10-01,104005.69,85769.66,123581.51,96076.91
14
+ 2020-01-01,106413.09,86675.95,127059.62,97281.52
15
+ 2020-04-01,116820.78,97761.25,130855.46,104689.54
16
+ 2020-07-01,108441.73,94675.79,129860.46,99743.91
17
+ 2020-10-01,111649.8,84537.95,129569.2,97245.62
18
+ 2021-01-01,110450.24,95690.13,133442.28,109743.98
19
+ 2021-04-01,117633.82,99838.34,134862.78,102998.2
20
+ 2021-07-01,116840.55,96866.18,134919.54,106458.78
21
+ 2021-10-01,106507.41,95890.38,131355.95,95361.85
22
+ 2022-01-01,116682.38,95263.84,133348.43,104584.2
23
+ 2022-04-01,125721.43,99538.79,142261.18,115066.85
24
+ 2022-07-01,112777.55,94931.46,137774.63,107792.84
25
+ 2022-10-01,113953.9,90952.57,129971.09,100166.77
26
+ 2023-01-01,119979.65,98968.69,140273.36,107054.09
27
+ 2023-04-01,127345.47,106023.46,146682.35,117038.79
28
+ 2023-07-01,117089.15,101630.07,144049.15,108608.9
29
+ 2023-10-01,112638.63,99081.55,139761.41,107249.38
data/demo_single.csv ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ date,sales
2
+ 2014-01-01,44065.23
3
+ 2014-02-01,45923.47
4
+ 2014-03-01,51695.38
5
+ 2014-04-01,57646.06
6
+ 2014-05-01,57259.9
7
+ 2014-06-01,58531.73
8
+ 2014-07-01,61286.63
9
+ 2014-08-01,56934.87
10
+ 2014-09-01,50661.05
11
+ 2014-10-01,48885.12
12
+ 2014-11-01,44144.96
13
+ 2014-12-01,43268.54
14
+ 2015-01-01,45955.72
15
+ 2015-02-01,44773.44
16
+ 2015-03-01,49350.16
17
+ 2015-04-01,55875.42
18
+ 2015-05-01,58102.54
19
+ 2015-06-01,62028.49
20
+ 2015-07-01,58712.16
21
+ 2015-08-01,54975.39
22
+ 2015-09-01,56931.3
23
+ 2015-10-01,49748.45
24
+ 2015-11-01,47606.85
25
+ 2015-12-01,43750.5
26
+ 2016-01-01,46783.03
27
+ 2016-02-01,51221.85
28
+ 2016-03-01,52898.01
29
+ 2016-04-01,60151.4
30
+ 2016-05-01,61326.93
31
+ 2016-06-01,63216.61
32
+ 2016-07-01,61724.79
33
+ 2016-08-01,63904.56
34
+ 2016-09-01,56373.01
35
+ 2016-10-01,50484.58
36
+ 2016-11-01,51516.89
37
+ 2016-12-01,46558.31
38
+ 2017-01-01,65689.52
39
+ 2017-02-01,49480.66
40
+ 2017-03-01,54943.63
41
+ 2017-04-01,62193.72
42
+ 2017-05-01,66405.14
43
+ 2017-06-01,66542.74
44
+ 2017-07-01,65096.91
45
+ 2017-08-01,61997.79
46
+ 2017-09-01,55842.96
47
+ 2017-10-01,53560.31
48
+ 2017-11-01,51350.52
49
+ 2017-12-01,53514.24
50
+ 2018-01-01,53359.03
51
+ 2018-02-01,52273.92
52
+ 2018-03-01,60648.17
53
+ 2018-04-01,63429.84
54
+ 2018-05-01,65974.36
55
+ 2018-06-01,69823.35
56
+ 2018-07-01,69790.2
57
+ 2018-08-01,66862.56
58
+ 2018-09-01,59521.56
59
+ 2018-10-01,56781.58
60
+ 2018-11-01,55334.32
61
+ 2018-12-01,55751.09
62
+ 2019-01-01,54113.45
63
+ 2019-02-01,57828.68
64
+ 2019-03-01,60187.33
65
+ 2019-04-01,64207.59
66
+ 2019-05-01,71353.25
67
+ 2019-06-01,73712.48
68
+ 2019-07-01,69984.18
69
+ 2019-08-01,69407.07
70
+ 2019-09-01,64323.27
71
+ 2019-10-01,58509.76
72
+ 2019-11-01,57794.59
73
+ 2019-12-01,59276.07
74
+ 2020-01-01,72400.14
75
+ 2020-02-01,63729.29
76
+ 2020-03-01,59560.51
77
+ 2020-04-01,70643.81
78
+ 2020-05-01,72302.3
79
+ 2020-06-01,72801.99
80
+ 2020-07-01,72711.72
81
+ 2020-08-01,65824.86
82
+ 2020-09-01,65560.66
83
+ 2020-10-01,62914.23
84
+ 2020-11-01,62427.58
85
+ 2020-12-01,57563.46
86
+ 2021-01-01,58254.81
87
+ 2021-02-01,61996.49
88
+ 2021-03-01,69030.8
89
+ 2021-04-01,72057.5
90
+ 2021-05-01,73468.68
91
+ 2021-06-01,76826.53
92
+ 2021-07-01,75122.36
93
+ 2021-08-01,74137.29
94
+ 2021-09-01,66995.89
95
+ 2021-10-01,63944.68
96
+ 2021-11-01,61087.58
97
+ 2021-12-01,58072.97
98
+ 2022-01-01,62864.04
99
+ 2022-02-01,65922.11
100
+ 2022-03-01,69610.23
101
+ 2022-04-01,73330.83
102
+ 2022-05-01,89097.46
103
+ 2022-06-01,77358.71
104
+ 2022-07-01,76642.77
105
+ 2022-08-01,72995.45
106
+ 2022-09-01,70477.43
107
+ 2022-10-01,67808.1
108
+ 2022-11-01,68044.17
109
+ 2022-12-01,63749.16
110
+ 2023-01-01,65186.9
111
+ 2023-02-01,67651.11
112
+ 2023-03-01,68162.46
113
+ 2023-04-01,76146.97
114
+ 2023-05-01,79448.66
115
+ 2023-06-01,85526.48
116
+ 2023-07-01,79343.48
117
+ 2023-08-01,77603.09
118
+ 2023-09-01,73130.58
119
+ 2023-10-01,67062.64
120
+ 2023-11-01,68957.44
121
+ 2023-12-01,67303.87
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.54.0
2
+ pandas==2.3.3
3
+ numpy==2.4.2
4
+ matplotlib==3.10.8
5
+ statsmodels==0.14.6
6
+ scipy==1.17.0
7
+ openai==2.2.0
8
+ querychat[streamlit]==0.5.1
9
+ duckdb==1.4.4
10
+ palettable==3.3.3
11
+ pydantic==2.12.5
12
+ python-dotenv==1.1.0
scripts/generate_demo_data.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate demo CSV datasets for the time-series visualization app."""
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ # Reproducibility
10
+ np.random.seed(42)
11
+
12
+ # Resolve paths relative to the project root (parent of scripts/)
13
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
14
+ DATA_DIR = PROJECT_ROOT / "data"
15
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # 1. data/demo_single.csv -- Monthly retail sales (Jan 2014 - Dec 2023)
20
+ # ---------------------------------------------------------------------------
21
def generate_single_series() -> pd.DataFrame:
    """Generate a monthly retail-sales series (Jan 2014 - Dec 2023).

    The series combines an upward linear trend, a December-peaking
    seasonal sine wave, Gaussian noise, and three injected anomaly
    spikes so downstream diagnostics have something to find.

    Returns
    -------
    pd.DataFrame
        Columns ``date`` (month starts) and ``sales`` (rounded to 2 dp).
    """
    n = 120  # 10 years * 12 months
    dates = pd.date_range(start="2014-01-01", periods=n, freq="MS")

    months = np.arange(n)

    # Upward trend: start ~50 000, grow ~200 per month
    trend = 50_000 + 200 * months

    # Seasonal component: sin wave peaking in December (month index 11).
    # sin peaks at pi/2, i.e. when (month_of_year - 8) / 12 == 1/4, which
    # holds exactly for month_of_year == 11 (December).
    # BUG FIX: the previous shift of -2 put the peak at month 5 (June),
    # contradicting the stated December peak.
    month_of_year = months % 12
    seasonal = 8_000 * np.sin(2 * np.pi * (month_of_year - 8) / 12)

    # Random noise
    noise = np.random.normal(0, 2_000, size=n)

    sales = trend + seasonal + noise

    # Inject 2-3 anomaly spikes
    for idx in [36, 72, 100]:
        sales[idx] += 15_000

    df = pd.DataFrame({"date": dates, "sales": np.round(sales, 2)})
    return df
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # 2. data/demo_multi_wide.csv -- Quarterly revenue by region (Q1 2017 - Q4 2023)
51
+ # ---------------------------------------------------------------------------
52
def generate_multi_wide() -> pd.DataFrame:
    """Build a wide-format quarterly revenue table, one column per region.

    Covers Q1 2017 through Q4 2023; each region has its own base level
    plus a shared linear trend, quarterly seasonality, and noise.
    """
    n_quarters = 28  # 7 years * 4 quarters
    quarter_index = np.arange(n_quarters)
    q_of_year = quarter_index % 4  # 0=Q1 .. 3=Q4

    columns: dict[str, object] = {
        "date": pd.date_range(start="2017-01-01", periods=n_quarters, freq="QS")
    }

    region_bases = {
        "North": 100_000,
        "South": 80_000,
        "East": 120_000,
        "West": 90_000,
    }

    # Seasonal shape is identical for all regions, so compute it once.
    seasonal = 5_000 * np.sin(2 * np.pi * q_of_year / 4)

    for region, base_level in region_bases.items():
        level = base_level + 800 * quarter_index + seasonal
        level = level + np.random.normal(0, 3_000, size=n_quarters)
        columns[region] = np.round(level, 2)

    return pd.DataFrame(columns)
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # 3. data/demo_multi_long.csv -- Daily stock prices for 20 tickers
79
+ # (2022-01-03 to 2023-12-29, business days only)
80
+ # ---------------------------------------------------------------------------
81
def generate_multi_long() -> pd.DataFrame:
    """Build a long-format daily price panel for 20 synthetic tickers.

    Prices follow geometric Brownian motion over business days from
    2022-01-03 to 2023-12-29; tickers are AAAA, BBBB, ..., TTTT.
    """
    dates = pd.bdate_range(start="2022-01-03", end="2023-12-29")
    n_days = len(dates)

    # 20 simple four-letter tickers: AAAA, BBBB, ..., TTTT
    symbols = [chr(ord("A") + i) * 4 for i in range(20)]

    drift = 0.0002
    vol = 0.02

    pieces: list[pd.DataFrame] = []

    for symbol in symbols:
        start_price = np.random.uniform(50, 500)

        # GBM: S_t = S_0 * exp(cumsum(log returns))
        rets = np.random.normal(drift - 0.5 * vol**2, vol, size=n_days)
        rets[0] = 0  # first day opens exactly at start_price
        path = start_price * np.exp(np.cumsum(rets))

        pieces.append(
            pd.DataFrame(
                {
                    "date": dates,
                    "ticker": symbol,
                    "price": np.round(path, 2),
                }
            )
        )

    return pd.concat(pieces, ignore_index=True)
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # Main
118
+ # ---------------------------------------------------------------------------
119
def main() -> None:
    """Generate all three demo CSVs into the data/ directory."""
    jobs = [
        ("demo_single.csv", generate_single_series),
        ("demo_multi_wide.csv", generate_multi_wide),
        ("demo_multi_long.csv", generate_multi_long),
    ]
    for filename, builder in jobs:
        frame = builder()
        target = DATA_DIR / filename
        frame.to_csv(target, index=False)
        print(f"Wrote {len(frame)} rows -> {target}")
131
+
132
+
133
# Allow direct execution: python scripts/generate_demo_data.py
if __name__ == "__main__":
    main()
src/__init__.py ADDED
File without changes
src/ai_interpretation.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ai_interpretation.py
3
+ --------------------
4
+ AI-powered chart interpretation using OpenAI GPT-5.2 vision with
5
+ Pydantic structured output.
6
+
7
+ Provides:
8
+ - Pydantic models for structured chart analysis results
9
+ - Vision-based chart interpretation via OpenAI's GPT-5.2 model
10
+ - Streamlit rendering of interpretation results
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import base64
16
+ import json
17
+ import os
18
+ from typing import Literal
19
+
20
+ import openai
21
+ from pydantic import BaseModel, ConfigDict
22
+ import streamlit as st
23
+
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Pydantic models
27
+ # ---------------------------------------------------------------------------
28
+
29
class TrendInfo(BaseModel):
    """Describes the overall trend detected in the chart."""

    # Forbid unknown keys so the generated JSON schema is strict
    # (OpenAI structured outputs require additionalProperties: false).
    model_config = ConfigDict(extra="forbid")

    # Coarse classification of the trend.
    direction: Literal["upward", "downward", "flat", "mixed"]
    # Plain-language explanation of the trend.
    description: str
36
+
37
+
38
class SeasonalityInfo(BaseModel):
    """Describes any seasonality detected in the chart."""

    # Strict schema for OpenAI structured outputs.
    model_config = ConfigDict(extra="forbid")

    # True when a repeating pattern is visible.
    detected: bool
    # Human-readable cycle length (e.g. "12 months"); None when unknown.
    period: str | None
    description: str
46
+
47
+
48
class StationarityInfo(BaseModel):
    """Describes whether the series appears stationary."""

    # Strict schema for OpenAI structured outputs.
    model_config = ConfigDict(extra="forbid")

    # Visual assessment only — not a statistical test result.
    likely_stationary: bool
    description: str
55
+
56
+
57
class AnomalyItem(BaseModel):
    """A single anomaly or outlier observation."""

    # Strict schema for OpenAI structured outputs.
    model_config = ConfigDict(extra="forbid")

    # Where on the time axis the anomaly sits (free-form, e.g. "mid-2022").
    approximate_location: str
    description: str
    # Used by render_interpretation() to pick a display color.
    severity: Literal["low", "medium", "high"]
65
+
66
+
67
class ChartInterpretation(BaseModel):
    """Complete structured interpretation of a time-series chart."""

    # Strict schema for OpenAI structured outputs.
    model_config = ConfigDict(extra="forbid")

    # What kind of chart the model believes it is looking at.
    chart_type_detected: str
    # Structured sub-analyses.
    trend: TrendInfo
    seasonality: SeasonalityInfo
    stationarity: StationarityInfo
    anomalies: list[AnomalyItem]
    # Free-form bullet points, one per observation.
    key_observations: list[str]
    # One-paragraph overall summary shown first in the UI.
    summary: str
    # Suggested follow-up analyses for the student.
    recommendations: list[str]
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # API key check
84
+ # ---------------------------------------------------------------------------
85
+
86
def check_api_key_available() -> bool:
    """Return ``True`` when a non-blank ``OPENAI_API_KEY`` is configured."""
    return os.environ.get("OPENAI_API_KEY", "").strip() != ""
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Chart interpretation
95
+ # ---------------------------------------------------------------------------
96
+
97
# System prompt shared by every interpretation request; keeps the model in
# a teaching-assistant register aimed at business analytics students.
_SYSTEM_PROMPT = (
    "You are a careful time-series analyst helping business analytics "
    "students. Analyze the chart image and provide a structured "
    "interpretation. Be precise about what the data shows; flag anything "
    "noteworthy. Use plain language suitable for students."
)
103
+
104
+
105
def interpret_chart(
    png_bytes: bytes,
    metadata: dict,
) -> ChartInterpretation:
    """Send a chart image to GPT-5.2 vision and return a structured
    interpretation.

    Never raises: any failure (missing/invalid API key, network error,
    unparsable response) is converted into a placeholder
    :class:`ChartInterpretation` whose ``summary`` carries the error text.

    Parameters
    ----------
    png_bytes:
        Raw PNG image bytes of the chart to analyse.
    metadata:
        Context about the chart. Expected keys:

        * ``chart_type`` -- e.g. ``"line"``, ``"bar"``, ``"decomposition"``
        * ``frequency_label`` -- e.g. ``"Monthly"``, ``"Daily"``
        * ``date_range`` -- human-readable date range string
        * ``y_column`` -- name of the value column being plotted
    """
    try:
        # Client picks up OPENAI_API_KEY from the environment.
        client = openai.OpenAI()

        # Encode the PNG as a base64 data URI
        b64 = base64.b64encode(png_bytes).decode("utf-8")
        image_data_uri = f"data:image/png;base64,{b64}"

        chart_type = metadata.get("chart_type", "time-series")
        # default=str makes timestamps and other objects JSON-safe.
        metadata_str = json.dumps(metadata, default=str)

        # NOTE(review): newer openai SDKs expose this as
        # client.chat.completions.parse (non-beta); confirm before upgrading.
        response = client.beta.chat.completions.parse(
            model="gpt-5.2-2025-12-11",
            response_format=ChartInterpretation,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": image_data_uri},
                        },
                        {
                            "type": "text",
                            "text": (
                                f"Analyze this {chart_type} chart. "
                                f"Metadata: {metadata_str}"
                            ),
                        },
                    ],
                },
            ],
        )

        # Prefer the parsed structured output
        parsed = response.choices[0].message.parsed
        if parsed is not None:
            return parsed

        # Fallback: try to manually parse the raw content
        raw_content = response.choices[0].message.content or ""
        data = json.loads(raw_content)
        return ChartInterpretation(**data)

    except Exception as exc:  # noqa: BLE001
        # Return a minimal interpretation that surfaces the error
        return ChartInterpretation(
            chart_type_detected="unknown",
            trend=TrendInfo(direction="mixed", description="Unable to determine."),
            seasonality=SeasonalityInfo(
                detected=False, period=None, description="Unable to determine."
            ),
            stationarity=StationarityInfo(
                likely_stationary=False, description="Unable to determine."
            ),
            anomalies=[],
            key_observations=["AI interpretation failed; see summary for details."],
            summary=f"Error during AI interpretation: {exc}",
            recommendations=["Check that your OPENAI_API_KEY is set and valid."],
        )
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # Streamlit rendering
188
+ # ---------------------------------------------------------------------------
189
+
190
# Presentation lookup tables used by render_interpretation().
_DIRECTION_EMOJI = {
    "upward": "\u2197\ufe0f",  # arrow upper-right
    "downward": "\u2198\ufe0f",  # arrow lower-right
    "flat": "\u27a1\ufe0f",  # arrow right
    "mixed": "\u2194\ufe0f",  # left-right arrow
}

# Streamlit markdown color names keyed by anomaly severity.
_SEVERITY_COLOR = {
    "low": "green",
    "medium": "orange",
    "high": "red",
}
202
+
203
+
204
def render_interpretation(interp: ChartInterpretation) -> None:
    """Render a :class:`ChartInterpretation` as a styled Streamlit card.

    Sections: detected chart type, summary, key observations (expanded by
    default), trend, seasonality, stationarity, anomalies, and recommended
    next steps. Purely presentational; returns nothing.
    """

    st.markdown("### AI Chart Interpretation")
    st.markdown(
        f"**Detected chart type:** {interp.chart_type_detected}"
    )

    # ---- Summary ----------------------------------------------------------
    st.markdown("---")
    st.markdown(f"**Summary:** {interp.summary}")

    # ---- Key observations -------------------------------------------------
    with st.expander("Key Observations", expanded=True):
        for obs in interp.key_observations:
            st.markdown(f"- {obs}")

    # ---- Trend ------------------------------------------------------------
    with st.expander("Trend Analysis"):
        arrow = _DIRECTION_EMOJI.get(interp.trend.direction, "")
        st.markdown(
            f"**Direction:** {interp.trend.direction.capitalize()} {arrow}"
        )
        st.markdown(interp.trend.description)

    # ---- Seasonality ------------------------------------------------------
    with st.expander("Seasonality"):
        status = "Detected" if interp.seasonality.detected else "Not detected"
        st.markdown(f"**Status:** {status}")
        if interp.seasonality.period:
            st.markdown(f"**Period:** {interp.seasonality.period}")
        st.markdown(interp.seasonality.description)

    # ---- Stationarity -----------------------------------------------------
    with st.expander("Stationarity"):
        label = (
            "Likely stationary"
            if interp.stationarity.likely_stationary
            else "Likely non-stationary"
        )
        st.markdown(f"**Assessment:** {label}")
        st.markdown(interp.stationarity.description)

    # ---- Anomalies --------------------------------------------------------
    with st.expander("Anomalies"):
        if not interp.anomalies:
            st.markdown("No anomalies detected.")
        else:
            for anomaly in interp.anomalies:
                # Unknown severities fall back to gray rather than erroring.
                color = _SEVERITY_COLOR.get(anomaly.severity, "gray")
                st.markdown(
                    f"- **[{anomaly.approximate_location}]** "
                    f":{color}[{anomaly.severity.upper()}] "
                    f"-- {anomaly.description}"
                )

    # ---- Recommendations --------------------------------------------------
    with st.expander("Recommended Next Steps"):
        # BUG FIX: each st.markdown call renders as an independent Markdown
        # document, so a hard-coded "1." prefix displayed every item as
        # "1.". Number the items explicitly instead.
        for i, rec in enumerate(interp.recommendations, start=1):
            st.markdown(f"{i}. {rec}")
src/cleaning.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CSV ingest and auto-clean pipeline for time-series data.
3
+
4
+ Provides delimiter detection, date/numeric column suggestion,
5
+ numeric cleaning (currency, commas, percentages, parenthesised negatives),
6
+ duplicate and missing-value handling, frequency detection, and
7
+ calendar-feature extraction.
8
+ """
9
+
10
+ import csv
11
+ import io
12
+ import re
13
+ from dataclasses import dataclass, field
14
+ from datetime import timedelta
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Dataclasses
22
+ # ---------------------------------------------------------------------------
23
+
24
@dataclass
class CleaningReport:
    """Summary produced by :func:`clean_dataframe`."""

    rows_before: int = 0          # row count before any cleaning
    rows_after: int = 0           # row count after the full pipeline
    duplicates_found: int = 0     # rows sharing a date with another row
    duplicates_action: str = ""   # which dup_action was applied
    # Per-column NaN counts: measured after numeric coercion but before
    # imputation, and again after imputation.
    missing_before: dict = field(default_factory=dict)
    missing_after: dict = field(default_factory=dict)
    parsing_warnings: list = field(default_factory=list)  # human-readable notes
35
+
36
+
37
@dataclass
class FrequencyInfo:
    """Result of :func:`detect_frequency`."""

    label: str = "Unknown"                  # e.g. "Daily", "Monthly", "Irregular"
    median_delta: timedelta = timedelta(0)  # median gap between observations
    is_regular: bool = False                # std of gaps < 20% of the median
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Delimiter detection
48
+ # ---------------------------------------------------------------------------
49
+
50
def detect_delimiter(file_bytes: bytes) -> str:
    """Guess the delimiter of a CSV payload, defaulting to a comma.

    Only the first 8 KB are inspected; undecodable bytes are replaced so
    decoding never fails, and :class:`csv.Sniffer` does the actual work.
    """
    head = file_bytes[:8192].decode("utf-8", errors="replace")
    try:
        return csv.Sniffer().sniff(head).delimiter
    except csv.Error:
        # Sniffer could not decide (e.g. empty or single-column input).
        return ","
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Reading uploads
66
+ # ---------------------------------------------------------------------------
67
+
68
def read_csv_upload(uploaded_file) -> tuple[pd.DataFrame, str]:
    """Parse an uploaded CSV file object into ``(dataframe, delimiter)``.

    The delimiter is auto-detected from the raw bytes, and the file
    position is reset afterwards so callers may read the object again.
    """
    payload = uploaded_file.getvalue()
    sep = detect_delimiter(payload)
    text = payload.decode("utf-8", errors="replace")
    frame = pd.read_csv(io.StringIO(text), sep=sep)
    # Rewind in case the caller wants to read again
    uploaded_file.seek(0)
    return frame, sep
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Column suggestion helpers
85
+ # ---------------------------------------------------------------------------
86
+
87
# Case-insensitive keywords suggesting a column holds date-like values.
_DATE_NAME_TOKENS = re.compile(r"(date|time|year|month|day|period)", re.IGNORECASE)
88
+
89
+
90
def suggest_date_columns(df: pd.DataFrame) -> list[str]:
    """Return column names that are likely to contain date/time values.

    Checks are applied in order:

    1. Column already has a datetime dtype.
    2. :func:`pd.to_datetime` succeeds on the first non-null values.
       Numeric columns are excluded from this check: ``to_datetime``
       happily interprets raw numbers as nanosecond epochs, which would
       falsely flag every numeric column as a date.
    3. The column *name* contains a date-related keyword (this still
       applies to numeric columns, e.g. an integer ``year`` column).
    """
    candidates: list[str] = []

    def _add(col) -> None:
        # Preserve first-hit ordering without duplicates.
        if col not in candidates:
            candidates.append(col)

    for col in df.columns:
        # 1. Already datetime
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            _add(col)
            continue

        # 2. Parseable as datetime (check up to first 5 non-null values).
        #    BUG FIX: skip numeric dtypes here — pd.to_datetime accepts
        #    plain numbers (epoch nanoseconds) without error.
        if not pd.api.types.is_numeric_dtype(df[col]):
            sample = df[col].dropna().head(5)
            if not sample.empty:
                try:
                    pd.to_datetime(sample)
                    _add(col)
                    continue
                except (ValueError, TypeError, OverflowError):
                    pass

        # 3. Column name heuristic
        if _DATE_NAME_TOKENS.search(str(col)):
            _add(col)

    return candidates
125
+
126
+
127
def suggest_numeric_columns(df: pd.DataFrame) -> list[str]:
    """Return columns that are numeric or could be cleaned to numeric.

    A non-numeric column qualifies when, after stripping common formatting
    (currency symbols, thousands commas, ``%``, accounting parentheses),
    at least half of a 50-row non-null sample converts to numbers.
    """
    numeric_like: list[str] = []

    for name in df.columns:
        column = df[name]
        if pd.api.types.is_numeric_dtype(column):
            numeric_like.append(name)
            continue

        # Attempt lightweight cleaning on a sample of string values.
        sample = column.dropna().head(50).astype(str)
        if sample.empty:
            continue

        stripped = sample.str.replace(r"[\$\u20ac\u00a3,% ]", "", regex=True)
        stripped = stripped.str.replace(r"^\((.+)\)$", r"-\1", regex=True)
        converted = pd.to_numeric(stripped, errors="coerce")
        if converted.notna().sum() >= max(1, len(sample) * 0.5):
            numeric_like.append(name)

    return numeric_like
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # Numeric cleaning
160
+ # ---------------------------------------------------------------------------
161
+
162
def clean_numeric_series(series: pd.Series) -> pd.Series:
    """Coerce formatted values into floats; unparseable entries become NaN.

    Handles dollar/euro/pound signs, thousands commas, percent marks,
    embedded whitespace, and accounting-style parenthesised negatives
    (``(123.45)`` becomes ``-123.45``).
    """
    text = series.astype(str)

    # One pass drops currency symbols, commas, percent signs, whitespace.
    text = text.str.replace(r"[\$\u20ac\u00a3,%\s]", "", regex=True)

    # Accounting negatives: (123.45) -> -123.45
    text = text.str.replace(r"^\((.+)\)$", r"-\1", regex=True)

    return pd.to_numeric(text, errors="coerce")
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Full cleaning pipeline
184
+ # ---------------------------------------------------------------------------
185
+
186
def clean_dataframe(
    df: pd.DataFrame,
    date_col: str,
    y_cols: list[str],
    dup_action: str = "keep_last",
    missing_action: str = "interpolate",
) -> tuple[pd.DataFrame, CleaningReport]:
    """Run the full cleaning pipeline and return ``(cleaned_df, report)``.

    Steps, in order: parse dates (dropping unparseable rows), coerce the
    value columns to numeric, resolve duplicate dates, sort by date, then
    impute or drop missing values. The input dataframe is never mutated.

    Parameters
    ----------
    df:
        Input dataframe (copied internally).
    date_col:
        Name of the column to parse as dates.
    y_cols:
        Names of the value columns to clean to numeric.
    dup_action:
        ``"keep_first"``, ``"keep_last"``, or ``"drop_all"``.
    missing_action:
        ``"interpolate"``, ``"ffill"``, or ``"drop"``.
    """
    work = df.copy()
    report = CleaningReport(rows_before=len(work))

    # Parse dates; on a wholesale failure, retry element-wise so partial
    # failures become NaT instead of aborting.
    try:
        work[date_col] = pd.to_datetime(work[date_col])
    except Exception as exc:  # noqa: BLE001
        report.parsing_warnings.append(
            f"Date parsing issue in column '{date_col}': {exc}"
        )
        work[date_col] = pd.to_datetime(work[date_col], errors="coerce")

    unparsed = int(work[date_col].isna().sum())
    if unparsed > 0:
        report.parsing_warnings.append(
            f"{unparsed} value(s) in '{date_col}' could not be parsed as dates."
        )
        work = work.dropna(subset=[date_col])

    # Coerce value columns that are not already numeric.
    for col in y_cols:
        if not pd.api.types.is_numeric_dtype(work[col]):
            work[col] = clean_numeric_series(work[col])

    # Missing-value snapshot *before* imputation (but after coercion).
    report.missing_before = {col: int(work[col].isna().sum()) for col in y_cols}

    # Duplicate dates: keep=False marks every member of a duplicate group.
    dup_mask = work.duplicated(subset=[date_col], keep=False)
    report.duplicates_found = int(dup_mask.sum())
    report.duplicates_action = dup_action

    if report.duplicates_found > 0:
        if dup_action == "keep_first":
            work = work.drop_duplicates(subset=[date_col], keep="first")
        elif dup_action == "keep_last":
            work = work.drop_duplicates(subset=[date_col], keep="last")
        elif dup_action == "drop_all":
            work = work[~dup_mask]

    # Chronological order with a fresh index.
    work = work.sort_values(date_col).reset_index(drop=True)

    # Missing-value handling.
    if missing_action == "interpolate":
        work[y_cols] = work[y_cols].interpolate(method="linear", limit_direction="both")
    elif missing_action == "ffill":
        # Forward-fill, then back-fill so leading gaps are covered too.
        work[y_cols] = work[y_cols].ffill().bfill()
    elif missing_action == "drop":
        work = work.dropna(subset=y_cols)

    report.missing_after = {col: int(work[col].isna().sum()) for col in y_cols}
    report.rows_after = len(work)

    return work, report
271
+
272
+
273
+ # ---------------------------------------------------------------------------
274
+ # Frequency detection
275
+ # ---------------------------------------------------------------------------
276
+
277
def detect_frequency(df: pd.DataFrame, date_col: str) -> FrequencyInfo:
    """Classify the sampling frequency from the median gap between dates.

    The series counts as *regular* when the standard deviation of the
    gaps stays below 20% of the median gap. Fewer than two usable dates
    yields the "Unknown" label.
    """
    dates = df[date_col].dropna().sort_values()
    if len(dates) < 2:
        return FrequencyInfo(label="Unknown", median_delta=timedelta(0), is_regular=False)

    gaps = dates.diff().dropna()
    med = gaps.median()

    # Regularity: std < 20% of median (guard against a zero median).
    regular = bool(gaps.std() <= med * 0.2) if med > timedelta(0) else False

    # Map the median gap (whole days) onto a coarse frequency bucket.
    day_count = med.days
    if day_count <= 1:
        bucket = "Daily"
    elif 5 <= day_count <= 9:
        bucket = "Weekly"
    elif 25 <= day_count <= 35:
        bucket = "Monthly"
    elif 85 <= day_count <= 100:
        bucket = "Quarterly"
    elif 350 <= day_count <= 380:
        bucket = "Yearly"
    else:
        bucket = "Irregular"

    return FrequencyInfo(label=bucket, median_delta=med, is_regular=regular)
312
+
313
+
314
+ # ---------------------------------------------------------------------------
315
+ # Calendar feature extraction
316
+ # ---------------------------------------------------------------------------
317
+
318
def add_time_features(df: pd.DataFrame, date_col: str) -> pd.DataFrame:
    """Append ``year``, ``quarter``, ``month``, ``day_of_week`` columns.

    NOTE: the input dataframe is modified in place and also returned
    (no copy is made).
    """
    stamps = df[date_col].dt
    for feature, values in (
        ("year", stamps.year),
        ("quarter", stamps.quarter),
        ("month", stamps.month),
        ("day_of_week", stamps.dayofweek),
    ):
        df[feature] = values
    return df
src/diagnostics.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Time-series diagnostics utilities.
2
+
3
+ Provides summary statistics, stationarity tests, trend estimation,
4
+ autocorrelation analysis, seasonal decomposition, rolling statistics,
5
+ year-over-year change computation, and multi-series summaries.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ from numpy.typing import NDArray
14
+ from scipy import stats
15
+ from statsmodels.tsa.stattools import adfuller, acf, pacf
16
+ from statsmodels.tsa.seasonal import seasonal_decompose, DecomposeResult
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Data classes
21
+ # ---------------------------------------------------------------------------
22
+
23
@dataclass
class SummaryStats:
    """Container for univariate time-series summary statistics."""

    # Observation counts
    count: int           # non-null observations
    missing_count: int   # null observations
    missing_pct: float   # missing_count as a percentage of all rows
    # Distribution of the value column
    min_val: float
    max_val: float
    mean_val: float
    median_val: float
    std_val: float
    p25: float           # 25th percentile
    p75: float           # 75th percentile
    # Date coverage
    date_start: pd.Timestamp
    date_end: pd.Timestamp
    date_span_days: int
    # Linear trend fit (OLS against observation index)
    trend_slope: float
    trend_pvalue: float
    # Augmented Dickey-Fuller stationarity test
    adf_statistic: float
    adf_pvalue: float
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Core helper functions
48
+ # ---------------------------------------------------------------------------
49
+
50
def compute_adf_test(series: pd.Series) -> tuple[float, float]:
    """Run the Augmented Dickey-Fuller stationarity test.

    NaNs are dropped first. Returns ``(statistic, p_value)``, or
    ``(nan, nan)`` when the test cannot run (too few observations,
    constant data, or any other statsmodels failure).
    """
    values = series.dropna()
    if len(values) < 2:
        return np.nan, np.nan
    try:
        stat, pvalue = adfuller(values, autolag="AIC")[:2]
    except Exception:
        return np.nan, np.nan
    return float(stat), float(pvalue)
72
+
73
+
74
def compute_trend_slope(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
) -> tuple[float, float]:
    """Fit a straight line (OLS) against observation order.

    Rows with a missing date or value are dropped first. Returns
    ``(slope, p_value)`` per observation step, or ``(nan, nan)`` when
    fewer than two complete rows remain or the fit fails.
    """
    rows = df[[date_col, y_col]].dropna()
    if len(rows) < 2:
        return np.nan, np.nan
    try:
        fit = stats.linregress(
            np.arange(len(rows), dtype=float),
            rows[y_col].astype(float).values,
        )
        return float(fit.slope), float(fit.pvalue)
    except Exception:
        return np.nan, np.nan
106
+
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Summary statistics
110
+ # ---------------------------------------------------------------------------
111
+
112
def compute_summary_stats(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
) -> SummaryStats:
    """Bundle descriptive stats, date coverage, trend fit, and ADF results.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Name of the datetime column.
    y_col : str
        Name of the numeric value column.

    Returns
    -------
    SummaryStats
        Dataclass with descriptive statistics, date-range information,
        OLS trend slope / p-value, and ADF test results.
    """
    values = df[y_col]
    stamps = pd.to_datetime(df[date_col])

    n_total = len(values)
    n_missing = int(values.isna().sum())

    first = stamps.min()
    last = stamps.max()

    slope, slope_p = compute_trend_slope(df, date_col, y_col)
    adf_stat, adf_p = compute_adf_test(values)

    return SummaryStats(
        count=int(values.notna().sum()),
        missing_count=n_missing,
        missing_pct=(n_missing / n_total * 100.0) if n_total > 0 else 0.0,
        min_val=float(values.min()),
        max_val=float(values.max()),
        mean_val=float(values.mean()),
        median_val=float(values.median()),
        std_val=float(values.std()),
        p25=float(values.quantile(0.25)),
        p75=float(values.quantile(0.75)),
        date_start=first,
        date_end=last,
        date_span_days=int((last - first).days),
        trend_slope=slope,
        trend_pvalue=slope_p,
        adf_statistic=adf_stat,
        adf_pvalue=adf_p,
    )
176
+
177
+
178
+ # ---------------------------------------------------------------------------
179
+ # Autocorrelation / partial autocorrelation
180
+ # ---------------------------------------------------------------------------
181
+
182
def compute_acf_pacf(
    series: pd.Series,
    nlags: int = 40,
) -> tuple[NDArray, NDArray, NDArray, NDArray]:
    """Compute ACF and PACF with confidence intervals.

    Parameters
    ----------
    series : pd.Series
        The time-series values (NaNs are dropped automatically).
    nlags : int, optional
        Maximum number of lags (default 40).  Automatically reduced to just
        under half the (non-NaN) sample size, which is the largest lag the
        PACF estimator supports.

    Returns
    -------
    tuple[ndarray, ndarray, ndarray, ndarray]
        ``(acf_values, acf_confint, pacf_values, pacf_confint)``

        * ``acf_values`` -- shape ``(nlags + 1,)``
        * ``acf_confint`` -- shape ``(nlags + 1, 2)``
        * ``pacf_values`` -- shape ``(nlags + 1,)``
        * ``pacf_confint`` -- shape ``(nlags + 1, 2)``

    Raises
    ------
    ValueError
        If fewer than 4 non-NaN observations are available (the minimum for
        which at least one PACF lag can be estimated).
    """
    clean = series.dropna().values.astype(float)

    # statsmodels' pacf raises when nlags >= nobs // 2 ("lags up to 50% of
    # the sample size").  The old cap of nobs - 1 did not prevent that, so
    # e.g. a 30-point series with the default nlags=40 crashed.  Use the
    # same (stricter) cap for the ACF too so the returned arrays stay
    # aligned lag-for-lag.
    max_lag = len(clean) // 2 - 1
    if max_lag < 1:
        raise ValueError(
            "Series has fewer than 4 non-NaN observations; "
            "cannot compute ACF/PACF."
        )
    nlags = min(nlags, max_lag)

    acf_values, acf_confint = acf(clean, nlags=nlags, alpha=0.05)
    pacf_values, pacf_confint = pacf(clean, nlags=nlags, alpha=0.05)

    return acf_values, acf_confint, pacf_values, pacf_confint
221
+
222
+
223
+ # ---------------------------------------------------------------------------
224
+ # Seasonal decomposition
225
+ # ---------------------------------------------------------------------------
226
+
227
+ def _infer_period(df: pd.DataFrame, date_col: str) -> int:
228
+ """Best-effort period inference from the date column's frequency.
229
+
230
+ Returns a sensible integer period or raises ``ValueError`` when the
231
+ frequency cannot be determined.
232
+ """
233
+ dates = pd.to_datetime(df[date_col])
234
+ freq = pd.infer_freq(dates)
235
+ if freq is None:
236
+ raise ValueError(
237
+ "Cannot infer a regular frequency from the date column. "
238
+ "Please supply an explicit 'period' argument or resample the "
239
+ "data to a regular frequency before calling compute_decomposition."
240
+ )
241
+
242
+ # Map common frequency strings to typical seasonal periods.
243
+ freq_upper = freq.upper()
244
+ period_map: dict[str, int] = {
245
+ "D": 365,
246
+ "B": 252, # business days in a year
247
+ "W": 52,
248
+ "SM": 24, # semi-monthly
249
+ "BMS": 12,
250
+ "BM": 12,
251
+ "MS": 12,
252
+ "M": 12, # calendar month end
253
+ "ME": 12, # month-end (pandas >= 2.2)
254
+ "QS": 4,
255
+ "Q": 4,
256
+ "QE": 4,
257
+ "BQ": 4,
258
+ "AS": 1,
259
+ "A": 1,
260
+ "YS": 1,
261
+ "Y": 1,
262
+ "YE": 1,
263
+ "H": 24,
264
+ "T": 60,
265
+ "MIN": 60,
266
+ "S": 60,
267
+ }
268
+
269
+ # Strip leading digits (e.g. "2W" -> "W") to normalise anchored offsets.
270
+ stripped = freq_upper.lstrip("0123456789")
271
+ # Also strip any anchor suffix like "W-SUN" -> "W".
272
+ base = stripped.split("-")[0]
273
+
274
+ if base in period_map:
275
+ return period_map[base]
276
+
277
+ raise ValueError(
278
+ f"Unable to map inferred frequency '{freq}' to a seasonal period. "
279
+ "Please provide an explicit 'period' argument."
280
+ )
281
+
282
+
283
def compute_decomposition(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    model: str = "additive",
    period: Optional[int] = None,
) -> DecomposeResult:
    """Decompose a time series into trend, seasonal, and residual components.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column name.
    y_col : str
        Numeric value column name.
    model : str, optional
        ``"additive"`` (default) or ``"multiplicative"``.
    period : int or None, optional
        Seasonal period.  When *None* the period is inferred from the date
        column's frequency.

    Returns
    -------
    statsmodels.tsa.seasonal.DecomposeResult

    Raises
    ------
    ValueError
        If a regular frequency cannot be inferred and *period* is not given.
    """
    frame = df[[date_col, y_col]].copy().set_index(date_col).sort_index()
    frame.index = pd.to_datetime(frame.index)

    # Fill small interior gaps so decomposition does not fail on a handful
    # of NaNs.
    frame[y_col] = frame[y_col].ffill().bfill()

    if period is None:
        period = _infer_period(df, date_col)

    # Give the index an explicit frequency when one can be inferred;
    # seasonal_decompose prefers it.  The explicit `period` still acts as
    # the fallback when inference fails.
    if frame.index.freq is None:
        freq_guess = pd.infer_freq(frame.index)
        if freq_guess is not None:
            frame = frame.asfreq(freq_guess)
            frame[y_col] = frame[y_col].ffill().bfill()

    return seasonal_decompose(frame[y_col], model=model, period=period)
339
+
340
+
341
+ # ---------------------------------------------------------------------------
342
+ # Rolling statistics
343
+ # ---------------------------------------------------------------------------
344
+
345
def compute_rolling_stats(
    df: pd.DataFrame,
    y_col: str,
    window: int = 12,
) -> pd.DataFrame:
    """Return a copy of *df* with rolling mean / std columns appended.

    Parameters
    ----------
    df : pd.DataFrame
        Source data (not mutated).
    y_col : str
        Column over which rolling statistics are calculated.
    window : int, optional
        Rolling window size (default 12).

    Returns
    -------
    pd.DataFrame
        Copy of *df* with two extra columns: ``rolling_mean`` and
        ``rolling_std``.
    """
    result = df.copy()
    # min_periods=1 keeps the leading edge populated instead of NaN-padded.
    roller = result[y_col].rolling(window=window, min_periods=1)
    result["rolling_mean"] = roller.mean()
    result["rolling_std"] = roller.std()
    return result
371
+
372
+
373
+ # ---------------------------------------------------------------------------
374
+ # Year-over-year change
375
+ # ---------------------------------------------------------------------------
376
+
377
+ def _offset_for_frequency(df: pd.DataFrame, date_col: str) -> pd.DateOffset:
378
+ """Return a 1-year ``DateOffset`` appropriate to the series frequency."""
379
+ dates = pd.to_datetime(df[date_col])
380
+ freq = pd.infer_freq(dates)
381
+
382
+ if freq is not None:
383
+ freq_upper = freq.upper().lstrip("0123456789").split("-")[0]
384
+ # For sub-monthly frequencies we shift by 365 days / 52 weeks etc.
385
+ if freq_upper in {"D", "B"}:
386
+ return pd.DateOffset(days=365)
387
+ if freq_upper in {"W"}:
388
+ return pd.DateOffset(weeks=52)
389
+ if freq_upper in {"H", "T", "MIN", "S"}:
390
+ return pd.DateOffset(days=365)
391
+
392
+ # Default: shift by 12 months (works for M, Q, and annual data).
393
+ return pd.DateOffset(months=12)
394
+
395
+
396
def compute_yoy_change(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
) -> pd.DataFrame:
    """Compute year-over-year absolute and percentage change.

    The number of periods to shift is determined from the inferred frequency
    of the date column.

    Parameters
    ----------
    df : pd.DataFrame
        Source data (not mutated).
    date_col : str
        Datetime column name.
    y_col : str
        Numeric value column name.

    Returns
    -------
    pd.DataFrame
        Copy of *df* sorted by *date_col* with additional columns
        ``yoy_abs_change`` and ``yoy_pct_change``.
    """
    result = df.copy().sort_values(date_col).reset_index(drop=True)
    result[date_col] = pd.to_datetime(result[date_col])

    # Rows corresponding to ~1 year for each recognised frequency code.
    rows_per_year: dict[str, int] = {
        "D": 365,
        "B": 252,
        "W": 52,
        "SM": 24,
        "BMS": 12,
        "BM": 12,
        "MS": 12,
        "M": 12,
        "ME": 12,
        "QS": 4,
        "Q": 4,
        "QE": 4,
        "BQ": 4,
        "AS": 1,
        "A": 1,
        "YS": 1,
        "Y": 1,
        "YE": 1,
        "H": 8760,
        "T": 525600,
        "MIN": 525600,
        "S": 31536000,
    }

    inferred = pd.infer_freq(result[date_col])
    if inferred is None:
        # Irregular spacing: fall back to a monthly assumption.
        shift_by = 12
    else:
        code = inferred.upper().lstrip("0123456789").split("-")[0]
        shift_by = rows_per_year.get(code, 12)

    prior = result[y_col].shift(shift_by)
    result["yoy_abs_change"] = result[y_col] - prior
    # Guard against division by zero: a zero prior value yields NaN percent.
    result["yoy_pct_change"] = (
        result["yoy_abs_change"] / prior.abs().replace(0, np.nan) * 100.0
    )

    return result
463
+
464
+
465
+ # ---------------------------------------------------------------------------
466
+ # Multi-series summary
467
+ # ---------------------------------------------------------------------------
468
+
469
def compute_multi_series_summary(
    df: pd.DataFrame,
    date_col: str,
    y_cols: list[str],
) -> pd.DataFrame:
    """Produce a summary DataFrame with one row per value column.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column name.
    y_cols : list[str]
        List of numeric column names to summarise.

    Returns
    -------
    pd.DataFrame
        Columns: ``variable``, ``count``, ``mean``, ``std``, ``min``,
        ``max``, ``trend_slope``, ``adf_pvalue``.
    """
    records: list[dict] = []
    for name in y_cols:
        values = df[name]
        slope, _pvalue = compute_trend_slope(df, date_col, name)
        _stat, adf_pvalue = compute_adf_test(values)
        records.append(
            {
                "variable": name,
                "count": int(values.notna().sum()),
                "mean": float(values.mean()),
                "std": float(values.std()),
                "min": float(values.min()),
                "max": float(values.max()),
                "trend_slope": slope,
                "adf_pvalue": adf_pvalue,
            }
        )

    return pd.DataFrame(records)
src/plotting.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ plotting.py
3
+ -----------
4
+ Chart-generation functions for time-series visualisation.
5
+
6
+ Every public function returns a :class:`matplotlib.figure.Figure` object.
7
+ Callers (e.g. Streamlit pages) can pass the figure to ``st.pyplot(fig)``
8
+ or convert it to PNG bytes via :func:`fig_to_png_bytes`.
9
+
10
+ All functions accept an optional *style_dict* (typically from
11
+ :func:`ui_theme.get_miami_mpl_style`) and an optional *palette_colors*
12
+ list so that colours stay consistent across the application.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import io
18
+ import math
19
+ from typing import Dict, List, Optional, Sequence
20
+
21
+ # CRITICAL: set the non-interactive backend before any other mpl import.
22
+ import matplotlib
23
+ matplotlib.use("Agg")
24
+
25
+ import matplotlib.pyplot as plt # noqa: E402
26
+ import matplotlib.dates as mdates # noqa: E402
27
+ import numpy as np # noqa: E402
28
+ import pandas as pd # noqa: E402
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Brand defaults (mirrors ui_theme.py)
32
+ # ---------------------------------------------------------------------------
33
+ MIAMI_RED: str = "#C41230"
34
+ _DEFAULT_FIG_SIZE = (10, 6)
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Utility
39
+ # ---------------------------------------------------------------------------
40
+
41
def fig_to_png_bytes(fig: matplotlib.figure.Figure, dpi: int = 150) -> bytes:
    """Render *fig* to an in-memory PNG and return the raw bytes."""
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
        return buffer.getvalue()
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Internal helpers
51
+ # ---------------------------------------------------------------------------
52
+
53
+ class _StyleContext:
54
+ """Context manager that temporarily applies *style_dict* to rcParams.
55
+
56
+ On exit the previous values are restored so that other figures are not
57
+ affected.
58
+ """
59
+
60
+ def __init__(self, style_dict: Optional[Dict[str, object]]):
61
+ self._style = style_dict
62
+ self._saved: Dict[str, object] = {}
63
+
64
+ def __enter__(self) -> "_StyleContext":
65
+ if self._style:
66
+ for key, value in self._style.items():
67
+ self._saved[key] = plt.rcParams.get(key)
68
+ try:
69
+ plt.rcParams[key] = value
70
+ except (KeyError, ValueError):
71
+ pass
72
+ return self
73
+
74
+ def __exit__(self, *exc_info: object) -> None:
75
+ for key, value in self._saved.items():
76
+ try:
77
+ plt.rcParams[key] = value
78
+ except (KeyError, ValueError):
79
+ pass
80
+
81
+
82
+ def _default_color(palette_colors: Optional[List[str]], idx: int = 0) -> str:
83
+ """Pick a colour from *palette_colors* or fall back to MIAMI_RED."""
84
+ if palette_colors and len(palette_colors) > idx:
85
+ return palette_colors[idx % len(palette_colors)]
86
+ return MIAMI_RED
87
+
88
+
89
def _finish_figure(fig: matplotlib.figure.Figure) -> matplotlib.figure.Figure:
    """Apply common finishing touches and return the figure.

    Currently this only runs ``tight_layout`` so labels and titles do not
    overlap; returning *fig* lets plot functions end with
    ``return _finish_figure(fig)``.
    """
    fig.tight_layout()
    return fig
93
+
94
+
95
def _auto_date_axis(ax: plt.Axes) -> None:
    """Auto-format and rotate date tick labels on the x-axis of *ax*."""
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(locator))
    # Slanted, right-anchored labels avoid overlap on dense date axes.
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(30)
        tick_label.set_ha("right")
101
+
102
+
103
+ def _grid_dims(n: int) -> tuple[int, int]:
104
+ """Return (nrows, ncols) for a compact grid of *n* panels."""
105
+ ncols = min(n, 3)
106
+ nrows = math.ceil(n / ncols)
107
+ return nrows, ncols
108
+
109
+
110
+ # ===================================================================
111
+ # 1. Line with markers
112
+ # ===================================================================
113
+
114
def plot_line_with_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Simple line plot with small circle markers.

    Uses the first palette colour or *MIAMI_RED* as the default.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        ax.plot(
            df[date_col],
            df[y_col],
            marker="o",
            markersize=4,
            linewidth=1.5,
            color=_default_color(palette_colors, 0),
            label=y_col,
        )
        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        ax.legend(loc="best")
        return _finish_figure(fig)
141
+
142
+
143
+ # ===================================================================
144
+ # 2. Line with coloured markers
145
+ # ===================================================================
146
+
147
def plot_line_colored_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    color_by: str,
    palette_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Line plot where marker colour varies by a categorical column.

    A legend is added mapping each unique value of *color_by* to its
    colour.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        # Neutral grey connecting line drawn underneath the coloured markers.
        ax.plot(df[date_col], df[y_col], linewidth=1.0, color="#AAAAAA", zorder=1)

        categories = df[color_by].unique()
        n_cats = len(categories)
        # Wrap around the palette (modulo) when there are more categories
        # than colours — same assignment as cycling the palette to length.
        color_map = {
            cat: palette_colors[i % len(palette_colors)]
            for i, cat in enumerate(categories)
        }

        for cat in categories:
            selected = df[color_by] == cat
            ax.scatter(
                df.loc[selected, date_col],
                df.loc[selected, y_col],
                c=color_map[cat],
                label=str(cat),
                s=30,
                zorder=2,
                edgecolors="white",
                linewidths=0.3,
            )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        ax.legend(title=color_by, loc="best", fontsize=8, ncol=max(1, n_cats // 8))
        return _finish_figure(fig)
197
+
198
+
199
+ # ===================================================================
200
+ # 3. Seasonal plot
201
+ # ===================================================================
202
+
203
def plot_seasonal(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    palette_name_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Seasonal plot: one line per year/cycle, x-axis is within-period position.

    Parameters
    ----------
    period:
        ``"month"`` (x = month 1-12) or ``"quarter"`` (x = quarter 1-4).
    palette_name_colors:
        List of hex colours; one per cycle/year.
    """
    with _StyleContext(style_dict):
        # NOTE(review): uses the .dt accessor — assumes date_col is already
        # datetime dtype; confirm upstream conversion.
        data = df[[date_col, y_col]].copy()
        data["_year"] = data[date_col].dt.year

        if period.lower().startswith("q"):
            data["_period_pos"] = data[date_col].dt.quarter
            x_label = "Quarter"
        else:
            data["_period_pos"] = data[date_col].dt.month
            x_label = "Month"

        years = sorted(data["_year"].unique())
        n_years = len(years)

        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        for i, year in enumerate(years):
            one_year = data[data["_year"] == year].sort_values("_period_pos")
            ax.plot(
                one_year["_period_pos"],
                one_year[y_col],
                marker="o",
                markersize=4,
                linewidth=1.4,
                # Wrap the palette when there are more years than colours.
                color=palette_name_colors[i % len(palette_name_colors)],
                label=str(year),
            )

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        ax.legend(title="Year", loc="best", fontsize=8, ncol=max(1, n_years // 6))
        return _finish_figure(fig)
255
+
256
+
257
+ # ===================================================================
258
+ # 4. Seasonal sub-series
259
+ # ===================================================================
260
+
261
def plot_seasonal_subseries(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Subseries plot with vertical panels for each season and horizontal mean lines.

    One panel per season (month or quarter); within a panel, observations
    are plotted in chronological order and a dashed red line marks that
    season's mean.

    Parameters
    ----------
    period:
        ``"month"`` or ``"quarter"``.
    palette_colors:
        Optional palette; the first colour is used for the data lines.
    """
    with _StyleContext(style_dict):
        # NOTE(review): uses the .dt accessor — assumes date_col is already
        # datetime dtype; confirm upstream conversion.
        tmp = df[[date_col, y_col]].copy()

        if period.lower().startswith("q"):
            tmp["_season"] = tmp[date_col].dt.quarter
            labels = {1: "Q1", 2: "Q2", 3: "Q3", 4: "Q4"}
        else:
            tmp["_season"] = tmp[date_col].dt.month
            labels = {
                1: "Jan", 2: "Feb", 3: "Mar", 4: "Apr",
                5: "May", 6: "Jun", 7: "Jul", 8: "Aug",
                9: "Sep", 10: "Oct", 11: "Nov", 12: "Dec",
            }

        seasons = sorted(tmp["_season"].unique())
        n = len(seasons)
        # Widen the figure with the number of panels, never below 10 in.
        fig_w = max(10, n * 1.3)
        fig, axes = plt.subplots(1, n, figsize=(fig_w, 5), sharey=True)
        if n == 1:
            # plt.subplots returns a bare Axes (not an array) when n == 1.
            axes = [axes]

        color = _default_color(palette_colors, 0)

        for idx, season in enumerate(seasons):
            ax = axes[idx]
            sub = tmp[tmp["_season"] == season].sort_values(date_col)
            # Positional x-axis within the panel: one slot per observation,
            # in chronological order.
            x_positions = range(len(sub))
            ax.plot(x_positions, sub[y_col].values, marker="o", markersize=3,
                    linewidth=1.2, color=color)

            # Dashed horizontal line at this season's mean.
            mean_val = sub[y_col].mean()
            ax.axhline(mean_val, color=MIAMI_RED, linewidth=1.8, linestyle="--", alpha=0.8)

            ax.set_title(labels.get(season, str(season)), fontsize=10)
            ax.set_xticks([])
            ax.tick_params(axis="y", labelsize=8)
            # y-label only on the leftmost panel (shared y-axis).
            if idx == 0:
                ax.set_ylabel(y_col)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
319
+
320
+
321
+ # ===================================================================
322
+ # 5. ACF / PACF
323
+ # ===================================================================
324
+
325
def plot_acf_pacf(
    acf_vals: np.ndarray,
    acf_ci: np.ndarray,
    pacf_vals: np.ndarray,
    pacf_ci: np.ndarray,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Side-by-side ACF and PACF bar plots with confidence-interval bands.

    Parameters
    ----------
    acf_vals, pacf_vals:
        1-D arrays of autocorrelation values (lag 0, 1, ...).
    acf_ci, pacf_ci:
        Arrays of shape ``(n_lags, 2)`` giving the lower and upper CI bounds.
    """
    with _StyleContext(style_dict):
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        panel_specs = zip(
            axes,
            (acf_vals, pacf_vals),
            (acf_ci, pacf_ci),
            ("ACF", "PACF"),
        )
        for ax, vals, ci, panel_title in panel_specs:
            lags = np.arange(len(vals))
            ax.bar(lags, vals, width=0.3, color=MIAMI_RED, alpha=0.85, zorder=2)
            # Shaded confidence band behind the bars.
            ax.fill_between(lags, ci[:, 0], ci[:, 1], color="#C41230", alpha=0.12, zorder=1)
            ax.axhline(0, color="black", linewidth=0.8)
            ax.set_xlabel("Lag")
            ax.set_ylabel("Correlation")
            ax.set_title(panel_title)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
365
+
366
+
367
+ # ===================================================================
368
+ # 6. Decomposition
369
+ # ===================================================================
370
+
371
def plot_decomposition(
    decomposition_result,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """4-panel plot: observed, trend, seasonal, residual.

    Parameters
    ----------
    decomposition_result:
        An object with ``.observed``, ``.trend``, ``.seasonal``, and
        ``.resid`` attributes (e.g. from ``statsmodels.tsa.seasonal_decompose``).
    """
    with _StyleContext(style_dict):
        fig, axes = plt.subplots(4, 1, figsize=(10, 10), sharex=True)

        panels = (
            ("Observed", decomposition_result.observed),
            ("Trend", decomposition_result.trend),
            ("Seasonal", decomposition_result.seasonal),
            ("Residual", decomposition_result.resid),
        )
        for ax, (name, component) in zip(axes, panels):
            ax.plot(component.index, component.values, linewidth=1.2, color=MIAMI_RED)
            ax.set_ylabel(name, fontsize=10)
            ax.tick_params(axis="both", labelsize=9)

        # With sharex=True only the bottom panel carries tick labels.
        _auto_date_axis(axes[-1])

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.01)
        return _finish_figure(fig)
404
+
405
+
406
+ # ===================================================================
407
+ # 7. Rolling overlay
408
+ # ===================================================================
409
+
410
def plot_rolling_overlay(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    window: int,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Original series (light) with rolling-mean overlay (bold) and ±1 std band."""
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        raw_color = _default_color(palette_colors, 0)
        # Second palette colour for the overlay when available, else grey.
        if palette_colors and len(palette_colors) > 1:
            mean_color = _default_color(palette_colors, 1)
        else:
            mean_color = "#333333"

        x = df[date_col]
        y = df[y_col]
        roller = y.rolling(window=window, center=True)
        mean_line = roller.mean()
        std_line = roller.std()

        # Raw series at low opacity underneath the overlay.
        ax.plot(x, y, linewidth=0.8, alpha=0.4, color=raw_color, label="Original")
        ax.plot(x, mean_line, linewidth=2.2, color=mean_color,
                label=f"{window}-pt Rolling Mean")
        ax.fill_between(
            x,
            mean_line - std_line,
            mean_line + std_line,
            alpha=0.15,
            color=mean_color,
            label="\u00b11 Std Dev",
        )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        ax.legend(loc="best")
        return _finish_figure(fig)
453
+
454
+
455
+ # ===================================================================
456
+ # 8. Year-over-Year change
457
+ # ===================================================================
458
+
459
def plot_yoy_change(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    yoy_df: pd.DataFrame,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Two-subplot bar chart: absolute YoY change (top) and percentage YoY change (bottom).

    Parameters
    ----------
    df, y_col:
        Unused by the plot; kept for signature compatibility with the other
        plot functions.
    date_col:
        Name of the date column to look for in *yoy_df* when it has no
        ``"date"`` column.
    yoy_df:
        DataFrame with a date column plus change columns.  Accepts either the
        generic names (``"date"``, ``"abs_change"``, ``"pct_change"``) or the
        names produced by ``compute_yoy_change`` in this project
        (*date_col*, ``"yoy_abs_change"``, ``"yoy_pct_change"``).

    Raises
    ------
    KeyError
        If none of the accepted column names are present in *yoy_df*.
    """

    def _pick(frame: pd.DataFrame, *candidates: str) -> str:
        """Return the first candidate column name present in *frame*."""
        for name in candidates:
            if name in frame.columns:
                return name
        raise KeyError(
            f"yoy_df must contain one of {candidates}; got {list(frame.columns)}"
        )

    with _StyleContext(style_dict):
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

        # Resolve columns flexibly: the docstring names and the names
        # emitted by compute_yoy_change were previously inconsistent.
        dates = yoy_df[_pick(yoy_df, "date", date_col)]
        abs_change = yoy_df[_pick(yoy_df, "abs_change", "yoy_abs_change")]
        pct_change = yoy_df[_pick(yoy_df, "pct_change", "yoy_pct_change")]

        # Green bars for gains, red for losses (NaN compares False -> red,
        # but NaN-height bars are not drawn anyway).
        abs_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in abs_change]
        pct_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in pct_change]

        ax1.bar(dates, abs_change, color=abs_colors, width=20, edgecolor="white", linewidth=0.3)
        ax1.axhline(0, color="black", linewidth=0.6)
        ax1.set_ylabel("Absolute Change")
        ax1.set_title("Year-over-Year Absolute Change")

        ax2.bar(dates, pct_change, color=pct_colors, width=20, edgecolor="white", linewidth=0.3)
        ax2.axhline(0, color="black", linewidth=0.6)
        ax2.set_ylabel("% Change")
        ax2.set_title("Year-over-Year Percentage Change")

        _auto_date_axis(ax2)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
500
+
501
+
502
+ # ===================================================================
503
+ # 9. Lag plot
504
+ # ===================================================================
505
+
506
def plot_lag(
    series: pd.Series,
    lag: int = 1,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Scatter plot of y(t) vs y(t-lag) with correlation-coefficient annotation.

    Parameters
    ----------
    series:
        The time-series values; NaNs are dropped before pairing.
    lag:
        Positive lag (default 1).

    Raises
    ------
    ValueError
        If *lag* is not positive, or the series has too few non-NaN
        observations to form at least two ``(y(t), y(t-lag))`` pairs.
    """
    with _StyleContext(style_dict):
        # Guard against degenerate inputs: lag=0 would make y[:-0] an empty
        # slice, and fewer than two pairs makes np.corrcoef meaningless.
        if lag < 1:
            raise ValueError(f"lag must be a positive integer, got {lag}")
        y = series.dropna().values
        if len(y) < lag + 2:
            raise ValueError(
                f"Need at least {lag + 2} non-NaN observations for a "
                f"lag-{lag} plot; got {len(y)}."
            )

        y_t = y[lag:]
        y_lag = y[:-lag]
        corr = np.corrcoef(y_t, y_lag)[0, 1]

        fig, ax = plt.subplots(figsize=(7, 7))
        ax.scatter(y_lag, y_t, alpha=0.5, s=20, color=MIAMI_RED, edgecolors="white", linewidths=0.3)

        # Correlation annotation pinned to the top-left corner.
        ax.annotate(
            f"r = {corr:.3f}",
            xy=(0.05, 0.95), xycoords="axes fraction",
            fontsize=12, fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="#CCCCCC", alpha=0.9),
            verticalalignment="top",
        )

        ax.set_xlabel(f"y(t\u2212{lag})")
        ax.set_ylabel("y(t)")
        if title:
            ax.set_title(title)
        else:
            ax.set_title(f"Lag-{lag} Plot")
        return _finish_figure(fig)
539
+
540
+
541
+ # ===================================================================
542
+ # 10. Panel (small multiples)
543
+ # ===================================================================
544
+
545
def plot_panel(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    chart_type: str = "line",
    shared_y: bool = True,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Small multiples: one subplot per *y_col* arranged in a grid.

    Parameters
    ----------
    chart_type:
        ``"line"`` or ``"bar"``.
    shared_y:
        If ``True`` all panels share the same y-axis limits.
    """
    with _StyleContext(style_dict):
        n_panels = len(y_cols)
        nrows, ncols = _grid_dims(n_panels)
        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=(max(8, ncols * 4.5), max(4, nrows * 3.5)),
            sharey=shared_y,
            squeeze=False,
        )
        panels = axes.flatten()

        for i, (ax, col) in enumerate(zip(panels, y_cols)):
            color = _default_color(palette_colors, i)
            if chart_type == "bar":
                ax.bar(df[date_col], df[col], color=color, width=2,
                       edgecolor="white", linewidth=0.3)
            else:
                ax.plot(df[date_col], df[col], linewidth=1.3, color=color)
            ax.set_title(col, fontsize=10)
            _auto_date_axis(ax)

        # Grid cells beyond the last series stay blank.
        for spare in panels[n_panels:]:
            spare.set_visible(False)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
594
+
595
+
596
+ # ===================================================================
597
+ # 11. Spaghetti plot
598
+ # ===================================================================
599
+
600
def plot_spaghetti(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    alpha: float = 0.15,
    highlight_col: Optional[str] = None,
    top_n: Optional[int] = None,
    show_median_band: bool = False,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """All series on one plot at low opacity, with optional highlighting.

    Parameters
    ----------
    alpha:
        Opacity for the background spaghetti lines.
    highlight_col:
        Column name to draw with full opacity and thicker line.
    top_n:
        If set, highlight the *top_n* series by maximum value.
    show_median_band:
        If ``True``, overlay the median line and shade the IQR.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        dates = df[date_col]

        # Determine which columns to highlight; highlight_col and top_n
        # are combined (union), not mutually exclusive.
        highlight_set: set[str] = set()
        if highlight_col and highlight_col in y_cols:
            highlight_set.add(highlight_col)
        if top_n:
            # Rank series by their maximum value; ties keep dict order.
            max_vals = {col: df[col].max() for col in y_cols}
            sorted_cols = sorted(max_vals, key=max_vals.get, reverse=True)  # type: ignore[arg-type]
            highlight_set.update(sorted_cols[:top_n])

        # Draw all series: highlighted ones get a label (and hence a legend
        # entry) and a higher zorder; background ones stay unlabelled.
        for i, col in enumerate(y_cols):
            color = _default_color(palette_colors, i)
            if col in highlight_set:
                ax.plot(dates, df[col], linewidth=2.0, alpha=0.9,
                        color=color, label=col, zorder=3)
            else:
                ax.plot(dates, df[col], linewidth=0.8, alpha=alpha,
                        color=color, zorder=1)

        # Median + IQR band computed row-wise across all columns in y_cols.
        if show_median_band:
            numeric_data = df[y_cols]
            median_line = numeric_data.median(axis=1)
            q1 = numeric_data.quantile(0.25, axis=1)
            q3 = numeric_data.quantile(0.75, axis=1)

            ax.plot(dates, median_line, linewidth=2.2, color="#333333",
                    label="Median", zorder=4)
            ax.fill_between(dates, q1, q3, alpha=0.2, color="#333333",
                            label="IQR", zorder=2)

        ax.set_xlabel(date_col)
        ax.set_ylabel("Value")
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)

        # Only add a legend when something was labelled (highlights or the
        # median/IQR overlay) — avoids an empty-legend warning.
        handles, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend(loc="best", fontsize=8)
        return _finish_figure(fig)
src/querychat_helpers.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ QueryChat initialization and filtered DataFrame helpers.
3
+
4
+ Provides convenience wrappers around the ``querychat`` library for
5
+ natural-language filtering of time-series DataFrames inside a Streamlit
6
+ app. All functions degrade gracefully when the package or an API key
7
+ is unavailable.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import os
13
+ from typing import List, Optional
14
+
15
+ import pandas as pd
16
+ import streamlit as st
17
+
18
+ try:
19
+ from querychat.streamlit import QueryChat as _QueryChat
20
+
21
+ _QUERYCHAT_AVAILABLE = True
22
+ except ImportError: # pragma: no cover
23
+ _QUERYCHAT_AVAILABLE = False
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Availability check
28
+ # ---------------------------------------------------------------------------
29
+
30
def check_querychat_available() -> bool:
    """Report whether the QueryChat feature can be used right now.

    Two conditions must both hold: the ``querychat`` package imported
    successfully at module load time, and an ``OPENAI_API_KEY`` is present
    in the environment. Callers use this boolean to gate chat UI elements.
    """
    has_api_key = bool(os.environ.get("OPENAI_API_KEY"))
    return _QUERYCHAT_AVAILABLE and has_api_key
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # QueryChat factory
43
+ # ---------------------------------------------------------------------------
44
+
45
def create_querychat(
    df: pd.DataFrame,
    name: str = "dataset",
    date_col: str = "date",
    y_cols: Optional[List[str]] = None,
    freq_label: str = "",
):
    """Build and return a QueryChat instance bound to *df*.

    Parameters
    ----------
    df:
        The pandas DataFrame to expose to the chat interface.
    name:
        Human-readable dataset name; spaces are replaced with underscores
        to form the SQL table name.
    date_col:
        Name of the date/time column (mentioned in the data description).
    y_cols:
        Numeric value columns. ``None`` (or an empty list) is described
        to the model as "none specified".
    freq_label:
        Optional frequency label such as ``"Monthly"`` or ``"Daily"``.

    Returns
    -------
    QueryChat instance
        The object returned by ``QueryChat()``.

    Raises
    ------
    RuntimeError
        If the querychat package is not importable.
    """
    if not _QUERYCHAT_AVAILABLE:
        raise RuntimeError(
            "The 'querychat' package is not installed. "
            "Install it with: pip install 'querychat[streamlit]'"
        )

    value_columns = list(y_cols) if y_cols else []
    joined_cols = ", ".join(value_columns) if value_columns else "none specified"
    freq_sentence = f" Frequency: {freq_label}." if freq_label else ""

    # Assemble the dataset description sentence by sentence.
    description = "".join(
        [
            f"This dataset is named '{name}'. ",
            f"It contains {len(df):,} rows. ",
            f"The date column is '{date_col}'. ",
            f"Value columns: {joined_cols}.",
            freq_sentence,
        ]
    )

    welcome = (
        f"Hi! I can help you filter and explore the **{name}** dataset. "
        "Try asking me something like:\n"
        '- "Show only 2023 data"\n'
        '- "Filter where sales > 60000"\n'
        '- "Show rows from January to March"'
    )

    return _QueryChat(
        data_source=df,
        table_name=name.replace(" ", "_"),
        client="openai/gpt-5.2-2025-12-11",
        data_description=description,
        greeting=welcome,
    )
115
+
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # Filtered DataFrame extraction
119
+ # ---------------------------------------------------------------------------
120
+
121
def get_filtered_pandas_df(qc) -> pd.DataFrame:
    """Extract the currently filtered DataFrame from a QueryChat instance.

    The underlying ``qc.df()`` may return a *narwhals* (or polars-style)
    DataFrame rather than a pandas one; this helper converts when needed.
    On any failure an empty DataFrame is returned so the calling app can
    continue to function.

    Parameters
    ----------
    qc:
        A QueryChat instance previously created via
        :func:`create_querychat` (anything exposing a ``df()`` method).

    Returns
    -------
    pd.DataFrame
        The filtered data as a pandas DataFrame, or an empty frame when
        retrieval or conversion fails.
    """
    # Fetch exactly once. Re-calling qc.df() inside an error path (as a
    # fallback) would almost certainly raise again and could duplicate
    # side effects of the query execution.
    try:
        result = qc.df()
    except Exception:  # noqa: BLE001
        return pd.DataFrame()

    # Already a pandas DataFrame -- nothing to convert.
    if isinstance(result, pd.DataFrame):
        return result

    # narwhals / polars DataFrames expose .to_pandas().
    if hasattr(result, "to_pandas"):
        try:
            return result.to_pandas()
        except Exception:  # noqa: BLE001
            return pd.DataFrame()

    # Unknown type -- attempt a direct conversion as a last resort.
    try:
        return pd.DataFrame(result)
    except Exception:  # noqa: BLE001
        return pd.DataFrame()
src/ui_theme.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ui_theme.py
3
+ -----------
4
+ Miami University branded theme and styling utilities for Streamlit apps.
5
+
6
+ Provides:
7
+ - CSS injection for Streamlit components (buttons, sidebar, metrics, cards)
8
+ - Matplotlib rcParams styled with Miami branding
9
+ - ColorBrewer palette loading via palettable with graceful fallback
10
+ - Color-swatch preview figure generation
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import itertools
16
+ from typing import Dict, List, Optional
17
+
18
+ import matplotlib.figure
19
+ import matplotlib.pyplot as plt
20
+ import streamlit as st
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # Brand constants — Miami University (Ohio) official palette
24
+ # ---------------------------------------------------------------------------
25
+ MIAMI_RED: str = "#C41230"
26
+ MIAMI_BLACK: str = "#000000"
27
+ MIAMI_WHITE: str = "#FFFFFF"
28
+
29
+ # Secondary palette tokens used only inside the CSS below.
30
+ _WHITE = "#FFFFFF"
31
+ _BLACK = "#000000"
32
+ _LIGHT_GRAY = "#F5F5F5"
33
+ _BORDER_GRAY = "#E0E0E0"
34
+ _DARK_TEXT = "#000000"
35
+ _HOVER_RED = "#9E0E26"
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Streamlit CSS injection
40
+ # ---------------------------------------------------------------------------
41
def apply_miami_theme() -> None:
    """Inject Miami-branded CSS into the active Streamlit page.

    Emits a single ``<style>`` block via ``st.markdown`` with
    ``unsafe_allow_html=True``, so it must be called on every script
    rerun (Streamlit does not persist injected CSS across runs).

    Styles affected:
    * Primary buttons -- Miami Red background, white text, darker-red
      hover/active states
    * Card containers -- subtle border and rounded corners on expanders
      and horizontal-block children
    * Sidebar -- Miami Red top accent bar and red h1/h2/h3 headings
    * Metric cards -- light background with a left red accent bar
    """
    # NOTE: the selectors target Streamlit's data-testid attributes,
    # which are not a stable public API -- verify after Streamlit upgrades.
    css = f"""
    <style>
    /* ---- Primary buttons ---- */
    .stButton > button[kind="primary"],
    .stButton > button {{
        background-color: {MIAMI_RED};
        color: {_WHITE};
        border: none;
        border-radius: 6px;
        padding: 0.5rem 1.25rem;
        font-weight: 600;
        transition: background-color 0.2s ease;
    }}
    .stButton > button:hover {{
        background-color: {_HOVER_RED};
        color: {_WHITE};
        border: none;
    }}
    .stButton > button:active,
    .stButton > button:focus {{
        background-color: {_HOVER_RED};
        color: {_WHITE};
        box-shadow: none;
    }}

    /* ---- Card borders ---- */
    div[data-testid="stExpander"],
    div[data-testid="stHorizontalBlock"] > div {{
        border: 1px solid {_BORDER_GRAY};
        border-radius: 8px;
        padding: 0.75rem;
    }}

    /* ---- Sidebar header accent ---- */
    section[data-testid="stSidebar"] > div:first-child {{
        border-top: 4px solid {MIAMI_RED};
    }}
    section[data-testid="stSidebar"] h1,
    section[data-testid="stSidebar"] h2,
    section[data-testid="stSidebar"] h3 {{
        color: {MIAMI_RED};
    }}

    /* ---- Metric cards ---- */
    div[data-testid="stMetric"] {{
        background-color: {_LIGHT_GRAY};
        border-left: 4px solid {MIAMI_RED};
        border-radius: 6px;
        padding: 0.75rem 1rem;
    }}
    div[data-testid="stMetric"] label {{
        color: {_BLACK};
        font-size: 0.85rem;
    }}
    div[data-testid="stMetric"] div[data-testid="stMetricValue"] {{
        color: {_BLACK};
        font-weight: 700;
    }}
    </style>
    """
    st.markdown(css, unsafe_allow_html=True)
111
+
112
+
113
+ # ---------------------------------------------------------------------------
114
+ # Matplotlib style dictionary
115
+ # ---------------------------------------------------------------------------
116
def get_miami_mpl_style() -> Dict[str, object]:
    """Build matplotlib rcParams implementing the Miami brand look.

    Apply globally::

        import matplotlib as mpl
        mpl.rcParams.update(get_miami_mpl_style())

    or scoped to a single figure::

        with mpl.rc_context(get_miami_mpl_style()):
            fig, ax = plt.subplots()
            ...

    Returns
    -------
    dict[str, object]
        rcParams suitable for ``mpl.rcParams.update`` or ``mpl.rc_context``.
    """
    # Series colour cycle: Miami Red and black first, then neutral accents.
    series_cycle = plt.cycler(
        color=[MIAMI_RED, _BLACK, "#4E79A7", "#F28E2B", "#76B7B2"]
    )

    sections = (
        {  # figure canvas
            "figure.facecolor": _WHITE,
            "figure.edgecolor": _WHITE,
            "figure.figsize": (10, 5),
            "figure.dpi": 100,
        },
        {  # axes frame, labels, and title
            "axes.facecolor": _WHITE,
            "axes.edgecolor": _BLACK,
            "axes.labelcolor": _BLACK,
            "axes.titlecolor": MIAMI_RED,
            "axes.labelsize": 12,
            "axes.titlesize": 14,
            "axes.titleweight": "bold",
            "axes.prop_cycle": series_cycle,
        },
        {  # dashed light-gray grid
            "axes.grid": True,
            "grid.color": _BORDER_GRAY,
            "grid.linestyle": "--",
            "grid.linewidth": 0.6,
            "grid.alpha": 0.7,
        },
        {  # tick colours and label sizes
            "xtick.color": _BLACK,
            "ytick.color": _BLACK,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
        },
        {  # legend framing
            "legend.fontsize": 10,
            "legend.frameon": True,
            "legend.framealpha": 0.9,
            "legend.edgecolor": _BORDER_GRAY,
        },
        {  # typography
            "font.size": 11,
            "font.family": "sans-serif",
        },
        {  # export defaults
            "savefig.dpi": 150,
            "savefig.bbox": "tight",
        },
    )

    style: Dict[str, object] = {}
    for section in sections:
        style.update(section)
    return style
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # ColorBrewer palette loading
174
+ # ---------------------------------------------------------------------------
175
+
176
+ # Mapping of short friendly names to palettable module paths.
177
+ _PALETTE_MAP: Dict[str, str] = {
178
+ "Set1": "colorbrewer.qualitative.Set1",
179
+ "Set2": "colorbrewer.qualitative.Set2",
180
+ "Set3": "colorbrewer.qualitative.Set3",
181
+ "Dark2": "colorbrewer.qualitative.Dark2",
182
+ "Paired": "colorbrewer.qualitative.Paired",
183
+ "Pastel1": "colorbrewer.qualitative.Pastel1",
184
+ "Pastel2": "colorbrewer.qualitative.Pastel2",
185
+ "Accent": "colorbrewer.qualitative.Accent",
186
+ "Tab10": "colorbrewer.qualitative.Set1", # fallback alias
187
+ }
188
+
189
+ _FALLBACK_COLORS: List[str] = [
190
+ MIAMI_RED,
191
+ MIAMI_BLACK,
192
+ "#4E79A7",
193
+ "#F28E2B",
194
+ "#76B7B2",
195
+ "#E15759",
196
+ "#59A14F",
197
+ "#EDC948",
198
+ ]
199
+
200
+
201
def _resolve_palette(name: str) -> Optional[List[str]]:
    """Dynamically import a palettable ColorBrewer palette by *name*.

    Palettable organises palettes by maximum number of classes: each size
    variant lives as a module attribute named ``<Base>_<N>``, e.g.
    ``colorbrewer.qualitative.Set2`` exposes ``Set2_3`` ... ``Set2_8``.
    This helper imports the module mapped from *name* and returns the
    variant with the most colours, as ``#RRGGBB`` hex strings.

    Returns ``None`` when palettable is not installed, the palette module
    cannot be found, or no size variant exists -- callers fall back to a
    default colour list.
    """
    import importlib

    module_path = _PALETTE_MAP.get(name)
    if module_path is None:
        # Unknown friendly name: guess colorbrewer.qualitative.<Name>.
        module_path = f"colorbrewer.qualitative.{name}"

    try:
        mod = importlib.import_module(f"palettable.{module_path}")
    except ImportError:  # palettable missing or palette module absent
        return None

    # BUGFIX: the attribute prefix must come from the *resolved* module
    # path, not from the caller-supplied name. Aliases such as "Tab10"
    # map to the Set1 module, whose size variants are named Set1_3 ...
    # Set1_9 -- searching for "Tab10_*" there would never match and the
    # alias silently failed.
    base = module_path.rsplit(".", 1)[-1]

    # Discover the size variant with the most colours.
    best_attr: Optional[str] = None
    best_n = 0
    for attr_name in dir(mod):
        if not attr_name.startswith(base + "_"):
            continue
        try:
            size = int(attr_name.rsplit("_", 1)[-1])
        except ValueError:
            continue
        if size > best_n:
            best_n = size
            best_attr = attr_name

    if best_attr is None:
        return None

    palette_obj = getattr(mod, best_attr, None)
    if palette_obj is None:
        return None

    return [
        "#{:02X}{:02X}{:02X}".format(*rgb) for rgb in palette_obj.colors
    ]
246
+
247
+
248
def get_palette_colors(name: str = "Set2", n: int = 8) -> List[str]:
    """Return *n* hex colours from the named ColorBrewer palette.

    Parameters
    ----------
    name:
        Friendly palette name such as ``"Set2"``, ``"Dark2"``, ``"Paired"``.
    n:
        How many colours to return; values below 1 are treated as 1.

    Returns
    -------
    list[str]
        Exactly *n* hex colour strings (e.g. ``["#66C2A5", ...]``). When
        *n* exceeds the palette size the palette repeats; when the
        palette cannot be resolved a built-in fallback list is used, so
        calling code never receives an empty list.
    """
    count = max(1, n)

    source = _resolve_palette(name)
    if source is None:
        source = _FALLBACK_COLORS

    # Repeat the palette as needed so callers always get *count* colours.
    repeated = itertools.cycle(source)
    return [colour for colour in itertools.islice(repeated, count)]
277
+
278
+
279
+ # ---------------------------------------------------------------------------
280
+ # Palette preview swatch
281
+ # ---------------------------------------------------------------------------
282
def render_palette_preview(
    colors: List[str],
    swatch_width: float = 1.0,
    swatch_height: float = 0.4,
) -> matplotlib.figure.Figure:
    """Build a horizontal strip of colour swatches as a matplotlib figure.

    Parameters
    ----------
    colors:
        Hex colour strings; one unit-square swatch is drawn per entry,
        left to right.
    swatch_width:
        Width allotted per swatch, in inches (total width never drops
        below 2 inches).
    swatch_height:
        Height of the swatch strip, in inches.

    Returns
    -------
    matplotlib.figure.Figure
        A closed Figure ready for ``st.pyplot()`` or ``savefig``.
    """
    count = len(colors)
    total_width = max(swatch_width * count, 2.0)
    fig, ax = plt.subplots(
        figsize=(total_width, swatch_height + 0.3), dpi=100
    )

    # One unit-square patch per colour, separated by thin white edges.
    for position, hex_code in enumerate(colors):
        patch = plt.Rectangle(
            (position, 0),
            width=1,
            height=1,
            facecolor=hex_code,
            edgecolor=_WHITE,
            linewidth=1.5,
        )
        ax.add_patch(patch)

    ax.set_xlim(0, count)
    ax.set_ylim(0, 1)
    ax.set_aspect("equal")
    ax.axis("off")
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # Close so interactive backends do not auto-display; the Figure
    # object itself stays fully usable by the caller.
    plt.close(fig)
    return fig