Commit ·
0191ae7
0
Parent(s):
Initial deploy: Time Series Visualizer v0.1.0
Browse files- Dockerfile +24 -0
- README.md +28 -0
- app.py +698 -0
- data/demo_multi_long.csv +0 -0
- data/demo_multi_wide.csv +29 -0
- data/demo_single.csv +121 -0
- requirements.txt +12 -0
- scripts/generate_demo_data.py +134 -0
- src/__init__.py +0 -0
- src/ai_interpretation.py +269 -0
- src/cleaning.py +329 -0
- src/diagnostics.py +509 -0
- src/plotting.py +671 -0
- src/querychat_helpers.py +161 -0
- src/ui_theme.py +327 -0
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
RUN apt-get update && \
|
| 4 |
+
apt-get install -y --no-install-recommends gcc g++ && \
|
| 5 |
+
rm -rf /var/lib/apt/lists/*
|
| 6 |
+
|
| 7 |
+
RUN useradd -m -u 1000 user
|
| 8 |
+
USER user
|
| 9 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 10 |
+
|
| 11 |
+
WORKDIR /app
|
| 12 |
+
|
| 13 |
+
COPY --chown=user:user requirements.txt .
|
| 14 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 15 |
+
pip install --no-cache-dir -r requirements.txt
|
| 16 |
+
|
| 17 |
+
COPY --chown=user:user . .
|
| 18 |
+
|
| 19 |
+
EXPOSE 7860
|
| 20 |
+
|
| 21 |
+
CMD ["streamlit", "run", "app.py", \
|
| 22 |
+
"--server.port=7860", \
|
| 23 |
+
"--server.address=0.0.0.0", \
|
| 24 |
+
"--browser.gatherUsageStats=false"]
|
README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Time Series Visualizer
|
| 3 |
+
emoji: 📈
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Time Series Visualizer + AI Chart Interpreter
|
| 12 |
+
|
| 13 |
+
A Streamlit app for Miami University Business Analytics students to upload CSV
|
| 14 |
+
time-series data, create publication-quality charts, and get AI-powered chart
|
| 15 |
+
interpretation.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- **Upload & Clean** — auto-detect delimiters, date columns, and numeric formats
|
| 20 |
+
- **9+ Chart Types** — line, seasonal, subseries, ACF/PACF, decomposition, rolling, YoY, lag, spaghetti
|
| 21 |
+
- **Multi-Series Support** — panel (small-multiples) and spaghetti plots for comparing series
|
| 22 |
+
- **AI Interpretation** — GPT-5.2 vision analyzes chart images and returns structured insights
|
| 23 |
+
- **QueryChat** — natural-language data filtering powered by DuckDB
|
| 24 |
+
|
| 25 |
+
## Privacy
|
| 26 |
+
|
| 27 |
+
All data processing happens in-memory. No data is persisted to disk.
|
| 28 |
+
Only chart PNG images (never raw data) are sent to the AI when you click "Interpret."
|
app.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Time Series Visualizer + AI Chart Interpreter
|
| 3 |
+
=============================================
|
| 4 |
+
Main Streamlit application. Run with:
|
| 5 |
+
|
| 6 |
+
streamlit run app.py --server.port=7860
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import hashlib
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
import matplotlib
|
| 18 |
+
matplotlib.use("Agg")
|
| 19 |
+
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import streamlit as st
|
| 22 |
+
|
| 23 |
+
from src.ui_theme import (
|
| 24 |
+
apply_miami_theme,
|
| 25 |
+
get_miami_mpl_style,
|
| 26 |
+
get_palette_colors,
|
| 27 |
+
render_palette_preview,
|
| 28 |
+
)
|
| 29 |
+
from src.cleaning import (
|
| 30 |
+
read_csv_upload,
|
| 31 |
+
suggest_date_columns,
|
| 32 |
+
suggest_numeric_columns,
|
| 33 |
+
clean_dataframe,
|
| 34 |
+
detect_frequency,
|
| 35 |
+
add_time_features,
|
| 36 |
+
CleaningReport,
|
| 37 |
+
FrequencyInfo,
|
| 38 |
+
)
|
| 39 |
+
from src.diagnostics import (
|
| 40 |
+
compute_summary_stats,
|
| 41 |
+
compute_acf_pacf,
|
| 42 |
+
compute_decomposition,
|
| 43 |
+
compute_rolling_stats,
|
| 44 |
+
compute_yoy_change,
|
| 45 |
+
compute_multi_series_summary,
|
| 46 |
+
)
|
| 47 |
+
from src.plotting import (
|
| 48 |
+
fig_to_png_bytes,
|
| 49 |
+
plot_line_with_markers,
|
| 50 |
+
plot_line_colored_markers,
|
| 51 |
+
plot_seasonal,
|
| 52 |
+
plot_seasonal_subseries,
|
| 53 |
+
plot_acf_pacf,
|
| 54 |
+
plot_decomposition,
|
| 55 |
+
plot_rolling_overlay,
|
| 56 |
+
plot_yoy_change,
|
| 57 |
+
plot_lag,
|
| 58 |
+
plot_panel,
|
| 59 |
+
plot_spaghetti,
|
| 60 |
+
)
|
| 61 |
+
from src.ai_interpretation import (
|
| 62 |
+
check_api_key_available,
|
| 63 |
+
interpret_chart,
|
| 64 |
+
render_interpretation,
|
| 65 |
+
)
|
| 66 |
+
from src.querychat_helpers import (
|
| 67 |
+
check_querychat_available,
|
| 68 |
+
create_querychat,
|
| 69 |
+
get_filtered_pandas_df,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Constants
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
_DATA_DIR = Path(__file__).parent / "data"
|
| 76 |
+
_DEMO_FILES = {
|
| 77 |
+
"Monthly Retail Sales (single)": _DATA_DIR / "demo_single.csv",
|
| 78 |
+
"Quarterly Revenue by Region (wide)": _DATA_DIR / "demo_multi_wide.csv",
|
| 79 |
+
"Daily Stock Prices – 20 Tickers (long)": _DATA_DIR / "demo_multi_long.csv",
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
_CHART_TYPES = [
|
| 83 |
+
"Line with Markers",
|
| 84 |
+
"Line – Colored Markers",
|
| 85 |
+
"Seasonal Plot",
|
| 86 |
+
"Seasonal Sub-series",
|
| 87 |
+
"ACF / PACF",
|
| 88 |
+
"Decomposition",
|
| 89 |
+
"Rolling Mean Overlay",
|
| 90 |
+
"Year-over-Year Change",
|
| 91 |
+
"Lag Plot",
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
_PALETTE_NAMES = ["Set2", "Dark2", "Set1", "Paired", "Pastel1", "Pastel2", "Accent"]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
# Helpers
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
+
|
| 101 |
+
def _df_hash(df: pd.DataFrame) -> str:
|
| 102 |
+
"""Fast hash of a DataFrame for cache-key / change-detection."""
|
| 103 |
+
return hashlib.md5(
|
| 104 |
+
pd.util.hash_pandas_object(df).values.tobytes()
|
| 105 |
+
).hexdigest()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _load_demo(path: Path) -> pd.DataFrame:
|
| 109 |
+
return pd.read_csv(path)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _render_cleaning_report(report: CleaningReport) -> None:
|
| 113 |
+
"""Show a data-quality card."""
|
| 114 |
+
c1, c2, c3 = st.columns(3)
|
| 115 |
+
c1.metric("Rows before", f"{report.rows_before:,}")
|
| 116 |
+
c2.metric("Rows after", f"{report.rows_after:,}")
|
| 117 |
+
c3.metric("Duplicates found", f"{report.duplicates_found:,}")
|
| 118 |
+
|
| 119 |
+
if report.missing_before:
|
| 120 |
+
with st.expander("Missing values"):
|
| 121 |
+
cols = list(report.missing_before.keys())
|
| 122 |
+
mc1, mc2 = st.columns(2)
|
| 123 |
+
with mc1:
|
| 124 |
+
st.write("**Before cleaning**")
|
| 125 |
+
for c in cols:
|
| 126 |
+
st.write(f"- {c}: {report.missing_before[c]}")
|
| 127 |
+
with mc2:
|
| 128 |
+
st.write("**After cleaning**")
|
| 129 |
+
for c in cols:
|
| 130 |
+
st.write(f"- {c}: {report.missing_after.get(c, 0)}")
|
| 131 |
+
|
| 132 |
+
if report.parsing_warnings:
|
| 133 |
+
with st.expander("Parsing warnings"):
|
| 134 |
+
for w in report.parsing_warnings:
|
| 135 |
+
st.warning(w)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _render_summary_stats(stats) -> None:
|
| 139 |
+
"""Render SummaryStats as metric cards + expander."""
|
| 140 |
+
row1 = st.columns(4)
|
| 141 |
+
row1[0].metric("Count", f"{stats.count:,}")
|
| 142 |
+
row1[1].metric("Missing", f"{stats.missing_count} ({stats.missing_pct:.1f}%)")
|
| 143 |
+
row1[2].metric("Mean", f"{stats.mean_val:,.2f}")
|
| 144 |
+
row1[3].metric("Std Dev", f"{stats.std_val:,.2f}")
|
| 145 |
+
|
| 146 |
+
row2 = st.columns(4)
|
| 147 |
+
row2[0].metric("Min", f"{stats.min_val:,.2f}")
|
| 148 |
+
row2[1].metric("25th %ile", f"{stats.p25:,.2f}")
|
| 149 |
+
row2[2].metric("Median", f"{stats.median_val:,.2f}")
|
| 150 |
+
row2[3].metric("75th %ile / Max", f"{stats.p75:,.2f} / {stats.max_val:,.2f}")
|
| 151 |
+
|
| 152 |
+
with st.expander("Trend & Stationarity"):
|
| 153 |
+
tc1, tc2 = st.columns(2)
|
| 154 |
+
tc1.metric(
|
| 155 |
+
"Trend slope (per period)",
|
| 156 |
+
f"{stats.trend_slope:,.4f}" if pd.notna(stats.trend_slope) else "N/A",
|
| 157 |
+
help="Slope from OLS on a numeric index.",
|
| 158 |
+
)
|
| 159 |
+
tc2.metric(
|
| 160 |
+
"Trend p-value",
|
| 161 |
+
f"{stats.trend_pvalue:.4f}" if pd.notna(stats.trend_pvalue) else "N/A",
|
| 162 |
+
)
|
| 163 |
+
ac1, ac2 = st.columns(2)
|
| 164 |
+
ac1.metric(
|
| 165 |
+
"ADF statistic",
|
| 166 |
+
f"{stats.adf_statistic:.4f}" if pd.notna(stats.adf_statistic) else "N/A",
|
| 167 |
+
help="Augmented Dickey-Fuller test statistic.",
|
| 168 |
+
)
|
| 169 |
+
ac2.metric(
|
| 170 |
+
"ADF p-value",
|
| 171 |
+
f"{stats.adf_pvalue:.4f}" if pd.notna(stats.adf_pvalue) else "N/A",
|
| 172 |
+
help="p < 0.05 suggests the series is stationary.",
|
| 173 |
+
)
|
| 174 |
+
st.caption(
|
| 175 |
+
f"Date range: {stats.date_start.date()} to {stats.date_end.date()} "
|
| 176 |
+
f"({stats.date_span_days:,} days)"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ---------------------------------------------------------------------------
|
| 181 |
+
# Page config
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
st.set_page_config(
|
| 184 |
+
page_title="Time Series Visualizer",
|
| 185 |
+
page_icon="\U0001f4c8",
|
| 186 |
+
layout="wide",
|
| 187 |
+
)
|
| 188 |
+
apply_miami_theme()
|
| 189 |
+
style_dict = get_miami_mpl_style()
|
| 190 |
+
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
# Session state initialisation
|
| 193 |
+
# ---------------------------------------------------------------------------
|
| 194 |
+
for key in [
|
| 195 |
+
"raw_df", "cleaned_df", "cleaning_report", "freq_info",
|
| 196 |
+
"date_col", "y_cols", "qc", "qc_hash",
|
| 197 |
+
]:
|
| 198 |
+
if key not in st.session_state:
|
| 199 |
+
st.session_state[key] = None
|
| 200 |
+
|
| 201 |
+
# ---------------------------------------------------------------------------
|
| 202 |
+
# Sidebar — Data input
|
| 203 |
+
# ---------------------------------------------------------------------------
|
| 204 |
+
with st.sidebar:
|
| 205 |
+
st.markdown(
|
| 206 |
+
"""
|
| 207 |
+
<div style="text-align:center; margin-bottom:0.5rem;">
|
| 208 |
+
<span style="font-size:1.6rem; font-weight:800; color:#C41230;">
|
| 209 |
+
Time Series Visualizer
|
| 210 |
+
</span><br>
|
| 211 |
+
<span style="font-size:0.82rem; color:#000;">
|
| 212 |
+
ISA 444 · Miami University
|
| 213 |
+
</span>
|
| 214 |
+
</div>
|
| 215 |
+
""",
|
| 216 |
+
unsafe_allow_html=True,
|
| 217 |
+
)
|
| 218 |
+
st.divider()
|
| 219 |
+
st.header("Data Input")
|
| 220 |
+
|
| 221 |
+
uploaded = st.file_uploader("Upload a CSV file", type=["csv", "tsv", "txt"])
|
| 222 |
+
|
| 223 |
+
demo_choice = st.selectbox(
|
| 224 |
+
"Or load a demo dataset",
|
| 225 |
+
["(none)"] + list(_DEMO_FILES.keys()),
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Load data
|
| 229 |
+
if uploaded is not None:
|
| 230 |
+
df_raw, delim = read_csv_upload(uploaded)
|
| 231 |
+
st.caption(f"Detected delimiter: `{repr(delim)}`")
|
| 232 |
+
st.session_state.raw_df = df_raw
|
| 233 |
+
elif demo_choice != "(none)":
|
| 234 |
+
st.session_state.raw_df = _load_demo(_DEMO_FILES[demo_choice])
|
| 235 |
+
# else: keep whatever was already in session state
|
| 236 |
+
|
| 237 |
+
raw_df: pd.DataFrame | None = st.session_state.raw_df
|
| 238 |
+
|
| 239 |
+
if raw_df is not None:
|
| 240 |
+
st.divider()
|
| 241 |
+
st.subheader("Column Selection")
|
| 242 |
+
|
| 243 |
+
# Auto-suggest
|
| 244 |
+
date_suggestions = suggest_date_columns(raw_df)
|
| 245 |
+
numeric_suggestions = suggest_numeric_columns(raw_df)
|
| 246 |
+
|
| 247 |
+
all_cols = list(raw_df.columns)
|
| 248 |
+
default_date_idx = all_cols.index(date_suggestions[0]) if date_suggestions else 0
|
| 249 |
+
|
| 250 |
+
date_col = st.selectbox("Date column", all_cols, index=default_date_idx)
|
| 251 |
+
|
| 252 |
+
remaining = [c for c in all_cols if c != date_col]
|
| 253 |
+
default_y = [c for c in numeric_suggestions if c != date_col]
|
| 254 |
+
y_cols = st.multiselect(
|
| 255 |
+
"Value column(s)",
|
| 256 |
+
remaining,
|
| 257 |
+
default=default_y[:4] if default_y else [],
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
st.session_state.date_col = date_col
|
| 261 |
+
st.session_state.y_cols = y_cols
|
| 262 |
+
|
| 263 |
+
st.divider()
|
| 264 |
+
st.subheader("Cleaning Options")
|
| 265 |
+
dup_action = st.selectbox(
|
| 266 |
+
"Duplicate dates",
|
| 267 |
+
["keep_last", "keep_first", "drop_all"],
|
| 268 |
+
)
|
| 269 |
+
missing_action = st.selectbox(
|
| 270 |
+
"Missing values",
|
| 271 |
+
["interpolate", "ffill", "drop"],
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Clean
|
| 275 |
+
if y_cols:
|
| 276 |
+
cleaned_df, report = clean_dataframe(
|
| 277 |
+
raw_df, date_col, y_cols,
|
| 278 |
+
dup_action=dup_action,
|
| 279 |
+
missing_action=missing_action,
|
| 280 |
+
)
|
| 281 |
+
freq_info = detect_frequency(cleaned_df, date_col)
|
| 282 |
+
cleaned_df = add_time_features(cleaned_df, date_col)
|
| 283 |
+
|
| 284 |
+
st.session_state.cleaned_df = cleaned_df
|
| 285 |
+
st.session_state.cleaning_report = report
|
| 286 |
+
st.session_state.freq_info = freq_info
|
| 287 |
+
|
| 288 |
+
st.caption(f"Frequency: **{freq_info.label}** "
|
| 289 |
+
f"({'regular' if freq_info.is_regular else 'irregular'})")
|
| 290 |
+
|
| 291 |
+
# Frequency override
|
| 292 |
+
freq_override = st.text_input(
|
| 293 |
+
"Override frequency label (optional)",
|
| 294 |
+
value="",
|
| 295 |
+
help="e.g. Daily, Weekly, Monthly, Quarterly, Yearly",
|
| 296 |
+
)
|
| 297 |
+
if freq_override.strip():
|
| 298 |
+
st.session_state.freq_info = FrequencyInfo(
|
| 299 |
+
label=freq_override.strip(),
|
| 300 |
+
median_delta=freq_info.median_delta,
|
| 301 |
+
is_regular=freq_info.is_regular,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# ------ QueryChat ------
|
| 305 |
+
if check_querychat_available():
|
| 306 |
+
current_hash = _df_hash(cleaned_df) + str(y_cols)
|
| 307 |
+
if st.session_state.qc_hash != current_hash:
|
| 308 |
+
st.session_state.qc = create_querychat(
|
| 309 |
+
cleaned_df,
|
| 310 |
+
name="uploaded data",
|
| 311 |
+
date_col=date_col,
|
| 312 |
+
y_cols=y_cols,
|
| 313 |
+
freq_label=st.session_state.freq_info.label,
|
| 314 |
+
)
|
| 315 |
+
st.session_state.qc_hash = current_hash
|
| 316 |
+
st.divider()
|
| 317 |
+
st.subheader("QueryChat")
|
| 318 |
+
st.session_state.qc.ui()
|
| 319 |
+
else:
|
| 320 |
+
st.divider()
|
| 321 |
+
st.info(
|
| 322 |
+
"Set `OPENAI_API_KEY` to enable QueryChat "
|
| 323 |
+
"(natural-language data filtering)."
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
# Reset button
|
| 327 |
+
st.divider()
|
| 328 |
+
if st.button("Reset all"):
|
| 329 |
+
for k in list(st.session_state.keys()):
|
| 330 |
+
del st.session_state[k]
|
| 331 |
+
st.rerun()
|
| 332 |
+
|
| 333 |
+
st.divider()
|
| 334 |
+
st.markdown(
|
| 335 |
+
"""
|
| 336 |
+
<div style="text-align:center; padding:0.5rem 0;">
|
| 337 |
+
<span style="font-size:0.75rem; color:#000;">
|
| 338 |
+
Developed by <strong>Fadel M. Megahed</strong><br>
|
| 339 |
+
for <strong>ISA 444</strong> · Miami University<br>
|
| 340 |
+
Version <strong>0.1.0</strong>
|
| 341 |
+
</span>
|
| 342 |
+
</div>
|
| 343 |
+
""",
|
| 344 |
+
unsafe_allow_html=True,
|
| 345 |
+
)
|
| 346 |
+
st.caption(
|
| 347 |
+
"**Privacy:** All processing is in-memory. "
|
| 348 |
+
"Only chart images (never raw data) are sent to the AI when you click Interpret."
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# ---------------------------------------------------------------------------
|
| 352 |
+
# Main area — guard
|
| 353 |
+
# ---------------------------------------------------------------------------
|
| 354 |
+
cleaned_df: pd.DataFrame | None = st.session_state.cleaned_df
|
| 355 |
+
date_col: str | None = st.session_state.date_col
|
| 356 |
+
y_cols: list[str] | None = st.session_state.y_cols
|
| 357 |
+
freq_info: FrequencyInfo | None = st.session_state.freq_info
|
| 358 |
+
report: CleaningReport | None = st.session_state.cleaning_report
|
| 359 |
+
|
| 360 |
+
if cleaned_df is None or not y_cols:
|
| 361 |
+
st.title("Time Series Visualizer")
|
| 362 |
+
st.write(
|
| 363 |
+
"Upload a CSV or choose a demo dataset from the sidebar to get started."
|
| 364 |
+
)
|
| 365 |
+
st.stop()
|
| 366 |
+
|
| 367 |
+
# If QueryChat is active, use its filtered df
|
| 368 |
+
if st.session_state.qc is not None:
|
| 369 |
+
working_df = get_filtered_pandas_df(st.session_state.qc)
|
| 370 |
+
if working_df.empty:
|
| 371 |
+
working_df = cleaned_df
|
| 372 |
+
else:
|
| 373 |
+
working_df = cleaned_df
|
| 374 |
+
|
| 375 |
+
# Data quality report
|
| 376 |
+
if report is not None:
|
| 377 |
+
with st.expander("Data Quality Report", expanded=False):
|
| 378 |
+
_render_cleaning_report(report)
|
| 379 |
+
|
| 380 |
+
# ---------------------------------------------------------------------------
|
| 381 |
+
# Tabs
|
| 382 |
+
# ---------------------------------------------------------------------------
|
| 383 |
+
tab_single, tab_few, tab_many = st.tabs([
|
| 384 |
+
"Single Series",
|
| 385 |
+
"Few Series (Panel)",
|
| 386 |
+
"Many Series (Spaghetti)",
|
| 387 |
+
])
|
| 388 |
+
|
| 389 |
+
# ===================================================================
|
| 390 |
+
# Tab A — Single Series
|
| 391 |
+
# ===================================================================
|
| 392 |
+
with tab_single:
|
| 393 |
+
if len(y_cols) == 1:
|
| 394 |
+
active_y = y_cols[0]
|
| 395 |
+
else:
|
| 396 |
+
active_y = st.selectbox("Select value column", y_cols, key="tab_a_y")
|
| 397 |
+
|
| 398 |
+
# ---- Date range filter ------------------------------------------------
|
| 399 |
+
dr_mode = st.radio(
|
| 400 |
+
"Date range",
|
| 401 |
+
["All", "Last N years", "Custom"],
|
| 402 |
+
horizontal=True,
|
| 403 |
+
key="dr_mode",
|
| 404 |
+
)
|
| 405 |
+
df_plot = working_df.copy()
|
| 406 |
+
if dr_mode == "Last N years":
|
| 407 |
+
n_years = st.slider("Years", 1, 20, 5, key="dr_n")
|
| 408 |
+
cutoff = df_plot[date_col].max() - pd.DateOffset(years=n_years)
|
| 409 |
+
df_plot = df_plot[df_plot[date_col] >= cutoff]
|
| 410 |
+
elif dr_mode == "Custom":
|
| 411 |
+
d_min = df_plot[date_col].min().date()
|
| 412 |
+
d_max = df_plot[date_col].max().date()
|
| 413 |
+
sel = st.slider("Date range", d_min, d_max, (d_min, d_max), key="dr_custom")
|
| 414 |
+
df_plot = df_plot[
|
| 415 |
+
(df_plot[date_col].dt.date >= sel[0])
|
| 416 |
+
& (df_plot[date_col].dt.date <= sel[1])
|
| 417 |
+
]
|
| 418 |
+
|
| 419 |
+
if df_plot.empty:
|
| 420 |
+
st.warning("No data in selected range.")
|
| 421 |
+
st.stop()
|
| 422 |
+
|
| 423 |
+
# ---- Chart controls ---------------------------------------------------
|
| 424 |
+
col_chart, col_opts = st.columns([2, 1])
|
| 425 |
+
with col_opts:
|
| 426 |
+
chart_type = st.selectbox("Chart type", _CHART_TYPES, key="chart_type_a")
|
| 427 |
+
|
| 428 |
+
palette_name = st.selectbox("Color palette", _PALETTE_NAMES, key="pal_a")
|
| 429 |
+
n_colors = max(12, len(y_cols))
|
| 430 |
+
palette_colors = get_palette_colors(palette_name, n_colors)
|
| 431 |
+
swatch_fig = render_palette_preview(palette_colors[:8])
|
| 432 |
+
st.pyplot(swatch_fig, width="stretch")
|
| 433 |
+
|
| 434 |
+
# Chart-specific controls
|
| 435 |
+
period_label = "month"
|
| 436 |
+
window_size = 12
|
| 437 |
+
lag_val = 1
|
| 438 |
+
decomp_model = "additive"
|
| 439 |
+
|
| 440 |
+
if chart_type in ("Seasonal Plot", "Seasonal Sub-series"):
|
| 441 |
+
period_label = st.selectbox("Period", ["month", "quarter"], key="period_a")
|
| 442 |
+
|
| 443 |
+
if chart_type == "Rolling Mean Overlay":
|
| 444 |
+
window_size = st.slider("Window", 2, 52, 12, key="window_a")
|
| 445 |
+
|
| 446 |
+
if chart_type == "Lag Plot":
|
| 447 |
+
lag_val = st.slider("Lag", 1, 52, 1, key="lag_a")
|
| 448 |
+
|
| 449 |
+
if chart_type == "Decomposition":
|
| 450 |
+
decomp_model = st.selectbox("Model", ["additive", "multiplicative"], key="decomp_a")
|
| 451 |
+
|
| 452 |
+
# ---- Render chart -----------------------------------------------------
|
| 453 |
+
with col_chart:
|
| 454 |
+
fig = None
|
| 455 |
+
try:
|
| 456 |
+
if chart_type == "Line with Markers":
|
| 457 |
+
fig = plot_line_with_markers(
|
| 458 |
+
df_plot, date_col, active_y,
|
| 459 |
+
title=f"{active_y} over Time",
|
| 460 |
+
style_dict=style_dict, palette_colors=palette_colors,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
elif chart_type == "Line – Colored Markers":
|
| 464 |
+
if "month" in df_plot.columns:
|
| 465 |
+
color_by = st.selectbox(
|
| 466 |
+
"Color by",
|
| 467 |
+
["month", "quarter", "year", "day_of_week"],
|
| 468 |
+
key="color_by_a",
|
| 469 |
+
)
|
| 470 |
+
else:
|
| 471 |
+
color_by = st.selectbox("Color by", [c for c in df_plot.columns if c not in (date_col, active_y)][:5], key="color_by_a")
|
| 472 |
+
fig = plot_line_colored_markers(
|
| 473 |
+
df_plot, date_col, active_y,
|
| 474 |
+
color_by=color_by, palette_colors=palette_colors,
|
| 475 |
+
title=f"{active_y} colored by {color_by}",
|
| 476 |
+
style_dict=style_dict,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
elif chart_type == "Seasonal Plot":
|
| 480 |
+
fig = plot_seasonal(
|
| 481 |
+
df_plot, date_col, active_y,
|
| 482 |
+
period=period_label,
|
| 483 |
+
palette_name_colors=palette_colors,
|
| 484 |
+
title=f"Seasonal Plot – {active_y}",
|
| 485 |
+
style_dict=style_dict,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
elif chart_type == "Seasonal Sub-series":
|
| 489 |
+
fig = plot_seasonal_subseries(
|
| 490 |
+
df_plot, date_col, active_y,
|
| 491 |
+
period=period_label,
|
| 492 |
+
title=f"Seasonal Sub-series – {active_y}",
|
| 493 |
+
style_dict=style_dict, palette_colors=palette_colors,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
elif chart_type == "ACF / PACF":
|
| 497 |
+
series = df_plot[active_y].dropna()
|
| 498 |
+
acf_vals, acf_ci, pacf_vals, pacf_ci = compute_acf_pacf(series)
|
| 499 |
+
fig = plot_acf_pacf(
|
| 500 |
+
acf_vals, acf_ci, pacf_vals, pacf_ci,
|
| 501 |
+
title=f"ACF / PACF – {active_y}",
|
| 502 |
+
style_dict=style_dict,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
elif chart_type == "Decomposition":
|
| 506 |
+
period_int = None
|
| 507 |
+
if freq_info and freq_info.label == "Monthly":
|
| 508 |
+
period_int = 12
|
| 509 |
+
elif freq_info and freq_info.label == "Quarterly":
|
| 510 |
+
period_int = 4
|
| 511 |
+
elif freq_info and freq_info.label == "Weekly":
|
| 512 |
+
period_int = 52
|
| 513 |
+
elif freq_info and freq_info.label == "Daily":
|
| 514 |
+
period_int = 365
|
| 515 |
+
|
| 516 |
+
result = compute_decomposition(
|
| 517 |
+
df_plot, date_col, active_y,
|
| 518 |
+
model=decomp_model, period=period_int,
|
| 519 |
+
)
|
| 520 |
+
fig = plot_decomposition(
|
| 521 |
+
result,
|
| 522 |
+
title=f"Decomposition – {active_y} ({decomp_model})",
|
| 523 |
+
style_dict=style_dict,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
elif chart_type == "Rolling Mean Overlay":
|
| 527 |
+
fig = plot_rolling_overlay(
|
| 528 |
+
df_plot, date_col, active_y,
|
| 529 |
+
window=window_size,
|
| 530 |
+
title=f"Rolling {window_size}-pt Mean – {active_y}",
|
| 531 |
+
style_dict=style_dict, palette_colors=palette_colors,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
elif chart_type == "Year-over-Year Change":
|
| 535 |
+
yoy_result = compute_yoy_change(df_plot, date_col, active_y)
|
| 536 |
+
yoy_df = pd.DataFrame({
|
| 537 |
+
"date": yoy_result[date_col],
|
| 538 |
+
"abs_change": yoy_result["yoy_abs_change"],
|
| 539 |
+
"pct_change": yoy_result["yoy_pct_change"],
|
| 540 |
+
}).dropna()
|
| 541 |
+
fig = plot_yoy_change(
|
| 542 |
+
df_plot, date_col, active_y, yoy_df,
|
| 543 |
+
title=f"Year-over-Year Change – {active_y}",
|
| 544 |
+
style_dict=style_dict,
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
elif chart_type == "Lag Plot":
|
| 548 |
+
fig = plot_lag(
|
| 549 |
+
df_plot[active_y],
|
| 550 |
+
lag=lag_val,
|
| 551 |
+
title=f"Lag-{lag_val} Plot – {active_y}",
|
| 552 |
+
style_dict=style_dict,
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
except Exception as exc:
|
| 556 |
+
st.error(f"Chart error: {exc}")
|
| 557 |
+
|
| 558 |
+
if fig is not None:
|
| 559 |
+
st.pyplot(fig, width="stretch")
|
| 560 |
+
|
| 561 |
+
# ---- Summary stats expander -------------------------------------------
|
| 562 |
+
with st.expander("Summary Statistics", expanded=False):
|
| 563 |
+
stats = compute_summary_stats(df_plot, date_col, active_y)
|
| 564 |
+
_render_summary_stats(stats)
|
| 565 |
+
|
| 566 |
+
# ---- AI Interpretation ------------------------------------------------
|
| 567 |
+
with st.expander("AI Chart Interpretation", expanded=False):
|
| 568 |
+
st.caption(
|
| 569 |
+
"The chart image (PNG) and metadata are sent to OpenAI. "
|
| 570 |
+
"No raw data leaves this app."
|
| 571 |
+
)
|
| 572 |
+
if not check_api_key_available():
|
| 573 |
+
st.warning("Set `OPENAI_API_KEY` to enable AI interpretation.")
|
| 574 |
+
elif fig is not None:
|
| 575 |
+
if st.button("Interpret Chart with AI", key="interpret_a"):
|
| 576 |
+
with st.spinner("Analyzing chart..."):
|
| 577 |
+
png = fig_to_png_bytes(fig)
|
| 578 |
+
date_range_str = (
|
| 579 |
+
f"{df_plot[date_col].min().date()} to "
|
| 580 |
+
f"{df_plot[date_col].max().date()}"
|
| 581 |
+
)
|
| 582 |
+
metadata = {
|
| 583 |
+
"chart_type": chart_type,
|
| 584 |
+
"frequency_label": freq_info.label if freq_info else "Unknown",
|
| 585 |
+
"date_range": date_range_str,
|
| 586 |
+
"y_column": active_y,
|
| 587 |
+
}
|
| 588 |
+
interp = interpret_chart(png, metadata)
|
| 589 |
+
render_interpretation(interp)
|
| 590 |
+
|
| 591 |
+
# ===================================================================
|
| 592 |
+
# Tab B — Few Series (Panel)
|
| 593 |
+
# ===================================================================
|
| 594 |
+
with tab_few:
|
| 595 |
+
if len(y_cols) < 2:
|
| 596 |
+
st.info("Select 2+ value columns in the sidebar to use panel plots.")
|
| 597 |
+
else:
|
| 598 |
+
st.subheader("Panel Plot (Small Multiples)")
|
| 599 |
+
|
| 600 |
+
panel_cols = st.multiselect(
|
| 601 |
+
"Columns to plot",
|
| 602 |
+
y_cols,
|
| 603 |
+
default=y_cols[:4],
|
| 604 |
+
key="panel_cols",
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
if panel_cols:
|
| 608 |
+
pc1, pc2 = st.columns(2)
|
| 609 |
+
with pc1:
|
| 610 |
+
panel_chart = st.selectbox(
|
| 611 |
+
"Chart type", ["line", "bar"], key="panel_chart"
|
| 612 |
+
)
|
| 613 |
+
with pc2:
|
| 614 |
+
shared_y = st.checkbox("Shared Y axis", value=True, key="panel_shared")
|
| 615 |
+
|
| 616 |
+
palette_name_b = st.selectbox("Color palette", _PALETTE_NAMES, key="pal_b")
|
| 617 |
+
palette_b = get_palette_colors(palette_name_b, len(panel_cols))
|
| 618 |
+
|
| 619 |
+
try:
|
| 620 |
+
fig_panel = plot_panel(
|
| 621 |
+
working_df, date_col, panel_cols,
|
| 622 |
+
chart_type=panel_chart,
|
| 623 |
+
shared_y=shared_y,
|
| 624 |
+
title="Panel Comparison",
|
| 625 |
+
style_dict=style_dict,
|
| 626 |
+
palette_colors=palette_b,
|
| 627 |
+
)
|
| 628 |
+
st.pyplot(fig_panel, width="stretch")
|
| 629 |
+
except Exception as exc:
|
| 630 |
+
st.error(f"Panel chart error: {exc}")
|
| 631 |
+
|
| 632 |
+
# Per-series summary table
|
| 633 |
+
with st.expander("Per-series Summary", expanded=False):
|
| 634 |
+
summary_df = compute_multi_series_summary(
|
| 635 |
+
working_df, date_col, panel_cols,
|
| 636 |
+
)
|
| 637 |
+
st.dataframe(
|
| 638 |
+
summary_df.style.format({
|
| 639 |
+
"mean": "{:,.2f}",
|
| 640 |
+
"std": "{:,.2f}",
|
| 641 |
+
"min": "{:,.2f}",
|
| 642 |
+
"max": "{:,.2f}",
|
| 643 |
+
"trend_slope": "{:,.4f}",
|
| 644 |
+
"adf_pvalue": "{:.4f}",
|
| 645 |
+
}),
|
| 646 |
+
width="stretch",
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
# ===================================================================
|
| 650 |
+
# Tab C — Many Series (Spaghetti)
|
| 651 |
+
# ===================================================================
|
| 652 |
+
with tab_many:
|
| 653 |
+
if len(y_cols) < 2:
|
| 654 |
+
st.info("Select 2+ value columns in the sidebar to use spaghetti plots.")
|
| 655 |
+
else:
|
| 656 |
+
st.subheader("Spaghetti Plot")
|
| 657 |
+
|
| 658 |
+
spag_cols = st.multiselect(
|
| 659 |
+
"Columns to include",
|
| 660 |
+
y_cols,
|
| 661 |
+
default=y_cols,
|
| 662 |
+
key="spag_cols",
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
if spag_cols:
|
| 666 |
+
sc1, sc2, sc3 = st.columns(3)
|
| 667 |
+
with sc1:
|
| 668 |
+
alpha_val = st.slider("Alpha", 0.05, 1.0, 0.15, 0.05, key="spag_alpha")
|
| 669 |
+
with sc2:
|
| 670 |
+
top_n = st.number_input("Highlight top N", 0, len(spag_cols), 0, key="spag_topn")
|
| 671 |
+
top_n = top_n if top_n > 0 else None
|
| 672 |
+
with sc3:
|
| 673 |
+
highlight = st.selectbox(
|
| 674 |
+
"Highlight series",
|
| 675 |
+
["(none)"] + spag_cols,
|
| 676 |
+
key="spag_highlight",
|
| 677 |
+
)
|
| 678 |
+
highlight_col = highlight if highlight != "(none)" else None
|
| 679 |
+
|
| 680 |
+
show_median = st.checkbox("Show Median + IQR band", value=False, key="spag_median")
|
| 681 |
+
|
| 682 |
+
palette_name_c = st.selectbox("Color palette", _PALETTE_NAMES, key="pal_c")
|
| 683 |
+
palette_c = get_palette_colors(palette_name_c, len(spag_cols))
|
| 684 |
+
|
| 685 |
+
try:
|
| 686 |
+
fig_spag = plot_spaghetti(
|
| 687 |
+
working_df, date_col, spag_cols,
|
| 688 |
+
alpha=alpha_val,
|
| 689 |
+
highlight_col=highlight_col,
|
| 690 |
+
top_n=top_n,
|
| 691 |
+
show_median_band=show_median,
|
| 692 |
+
title="Spaghetti Plot",
|
| 693 |
+
style_dict=style_dict,
|
| 694 |
+
palette_colors=palette_c,
|
| 695 |
+
)
|
| 696 |
+
st.pyplot(fig_spag, width="stretch")
|
| 697 |
+
except Exception as exc:
|
| 698 |
+
st.error(f"Spaghetti chart error: {exc}")
|
data/demo_multi_long.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/demo_multi_wide.csv
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
date,North,South,East,West
|
| 2 |
+
2017-01-01,102373.1,81565.82,120039.01,85866.99
|
| 3 |
+
2017-04-01,103071.84,86690.95,130160.6,92986.52
|
| 4 |
+
2017-07-01,105808.38,82351.48,120806.03,93145.11
|
| 5 |
+
2017-10-01,93194.45,78439.34,125560.51,88941.36
|
| 6 |
+
2018-01-01,104960.57,81159.93,125077.0,94745.14
|
| 7 |
+
2018-04-01,115571.37,89696.76,126428.53,110558.19
|
| 8 |
+
2018-07-01,101828.39,85679.22,121587.32,96512.67
|
| 9 |
+
2018-10-01,98901.11,78456.95,122047.42,94006.7
|
| 10 |
+
2019-01-01,106698.95,91997.32,125729.61,99262.01
|
| 11 |
+
2019-04-01,110689.57,93621.5,134342.0,104154.17
|
| 12 |
+
2019-07-01,103348.01,84426.09,129419.71,97054.19
|
| 13 |
+
2019-10-01,104005.69,85769.66,123581.51,96076.91
|
| 14 |
+
2020-01-01,106413.09,86675.95,127059.62,97281.52
|
| 15 |
+
2020-04-01,116820.78,97761.25,130855.46,104689.54
|
| 16 |
+
2020-07-01,108441.73,94675.79,129860.46,99743.91
|
| 17 |
+
2020-10-01,111649.8,84537.95,129569.2,97245.62
|
| 18 |
+
2021-01-01,110450.24,95690.13,133442.28,109743.98
|
| 19 |
+
2021-04-01,117633.82,99838.34,134862.78,102998.2
|
| 20 |
+
2021-07-01,116840.55,96866.18,134919.54,106458.78
|
| 21 |
+
2021-10-01,106507.41,95890.38,131355.95,95361.85
|
| 22 |
+
2022-01-01,116682.38,95263.84,133348.43,104584.2
|
| 23 |
+
2022-04-01,125721.43,99538.79,142261.18,115066.85
|
| 24 |
+
2022-07-01,112777.55,94931.46,137774.63,107792.84
|
| 25 |
+
2022-10-01,113953.9,90952.57,129971.09,100166.77
|
| 26 |
+
2023-01-01,119979.65,98968.69,140273.36,107054.09
|
| 27 |
+
2023-04-01,127345.47,106023.46,146682.35,117038.79
|
| 28 |
+
2023-07-01,117089.15,101630.07,144049.15,108608.9
|
| 29 |
+
2023-10-01,112638.63,99081.55,139761.41,107249.38
|
data/demo_single.csv
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
date,sales
|
| 2 |
+
2014-01-01,44065.23
|
| 3 |
+
2014-02-01,45923.47
|
| 4 |
+
2014-03-01,51695.38
|
| 5 |
+
2014-04-01,57646.06
|
| 6 |
+
2014-05-01,57259.9
|
| 7 |
+
2014-06-01,58531.73
|
| 8 |
+
2014-07-01,61286.63
|
| 9 |
+
2014-08-01,56934.87
|
| 10 |
+
2014-09-01,50661.05
|
| 11 |
+
2014-10-01,48885.12
|
| 12 |
+
2014-11-01,44144.96
|
| 13 |
+
2014-12-01,43268.54
|
| 14 |
+
2015-01-01,45955.72
|
| 15 |
+
2015-02-01,44773.44
|
| 16 |
+
2015-03-01,49350.16
|
| 17 |
+
2015-04-01,55875.42
|
| 18 |
+
2015-05-01,58102.54
|
| 19 |
+
2015-06-01,62028.49
|
| 20 |
+
2015-07-01,58712.16
|
| 21 |
+
2015-08-01,54975.39
|
| 22 |
+
2015-09-01,56931.3
|
| 23 |
+
2015-10-01,49748.45
|
| 24 |
+
2015-11-01,47606.85
|
| 25 |
+
2015-12-01,43750.5
|
| 26 |
+
2016-01-01,46783.03
|
| 27 |
+
2016-02-01,51221.85
|
| 28 |
+
2016-03-01,52898.01
|
| 29 |
+
2016-04-01,60151.4
|
| 30 |
+
2016-05-01,61326.93
|
| 31 |
+
2016-06-01,63216.61
|
| 32 |
+
2016-07-01,61724.79
|
| 33 |
+
2016-08-01,63904.56
|
| 34 |
+
2016-09-01,56373.01
|
| 35 |
+
2016-10-01,50484.58
|
| 36 |
+
2016-11-01,51516.89
|
| 37 |
+
2016-12-01,46558.31
|
| 38 |
+
2017-01-01,65689.52
|
| 39 |
+
2017-02-01,49480.66
|
| 40 |
+
2017-03-01,54943.63
|
| 41 |
+
2017-04-01,62193.72
|
| 42 |
+
2017-05-01,66405.14
|
| 43 |
+
2017-06-01,66542.74
|
| 44 |
+
2017-07-01,65096.91
|
| 45 |
+
2017-08-01,61997.79
|
| 46 |
+
2017-09-01,55842.96
|
| 47 |
+
2017-10-01,53560.31
|
| 48 |
+
2017-11-01,51350.52
|
| 49 |
+
2017-12-01,53514.24
|
| 50 |
+
2018-01-01,53359.03
|
| 51 |
+
2018-02-01,52273.92
|
| 52 |
+
2018-03-01,60648.17
|
| 53 |
+
2018-04-01,63429.84
|
| 54 |
+
2018-05-01,65974.36
|
| 55 |
+
2018-06-01,69823.35
|
| 56 |
+
2018-07-01,69790.2
|
| 57 |
+
2018-08-01,66862.56
|
| 58 |
+
2018-09-01,59521.56
|
| 59 |
+
2018-10-01,56781.58
|
| 60 |
+
2018-11-01,55334.32
|
| 61 |
+
2018-12-01,55751.09
|
| 62 |
+
2019-01-01,54113.45
|
| 63 |
+
2019-02-01,57828.68
|
| 64 |
+
2019-03-01,60187.33
|
| 65 |
+
2019-04-01,64207.59
|
| 66 |
+
2019-05-01,71353.25
|
| 67 |
+
2019-06-01,73712.48
|
| 68 |
+
2019-07-01,69984.18
|
| 69 |
+
2019-08-01,69407.07
|
| 70 |
+
2019-09-01,64323.27
|
| 71 |
+
2019-10-01,58509.76
|
| 72 |
+
2019-11-01,57794.59
|
| 73 |
+
2019-12-01,59276.07
|
| 74 |
+
2020-01-01,72400.14
|
| 75 |
+
2020-02-01,63729.29
|
| 76 |
+
2020-03-01,59560.51
|
| 77 |
+
2020-04-01,70643.81
|
| 78 |
+
2020-05-01,72302.3
|
| 79 |
+
2020-06-01,72801.99
|
| 80 |
+
2020-07-01,72711.72
|
| 81 |
+
2020-08-01,65824.86
|
| 82 |
+
2020-09-01,65560.66
|
| 83 |
+
2020-10-01,62914.23
|
| 84 |
+
2020-11-01,62427.58
|
| 85 |
+
2020-12-01,57563.46
|
| 86 |
+
2021-01-01,58254.81
|
| 87 |
+
2021-02-01,61996.49
|
| 88 |
+
2021-03-01,69030.8
|
| 89 |
+
2021-04-01,72057.5
|
| 90 |
+
2021-05-01,73468.68
|
| 91 |
+
2021-06-01,76826.53
|
| 92 |
+
2021-07-01,75122.36
|
| 93 |
+
2021-08-01,74137.29
|
| 94 |
+
2021-09-01,66995.89
|
| 95 |
+
2021-10-01,63944.68
|
| 96 |
+
2021-11-01,61087.58
|
| 97 |
+
2021-12-01,58072.97
|
| 98 |
+
2022-01-01,62864.04
|
| 99 |
+
2022-02-01,65922.11
|
| 100 |
+
2022-03-01,69610.23
|
| 101 |
+
2022-04-01,73330.83
|
| 102 |
+
2022-05-01,89097.46
|
| 103 |
+
2022-06-01,77358.71
|
| 104 |
+
2022-07-01,76642.77
|
| 105 |
+
2022-08-01,72995.45
|
| 106 |
+
2022-09-01,70477.43
|
| 107 |
+
2022-10-01,67808.1
|
| 108 |
+
2022-11-01,68044.17
|
| 109 |
+
2022-12-01,63749.16
|
| 110 |
+
2023-01-01,65186.9
|
| 111 |
+
2023-02-01,67651.11
|
| 112 |
+
2023-03-01,68162.46
|
| 113 |
+
2023-04-01,76146.97
|
| 114 |
+
2023-05-01,79448.66
|
| 115 |
+
2023-06-01,85526.48
|
| 116 |
+
2023-07-01,79343.48
|
| 117 |
+
2023-08-01,77603.09
|
| 118 |
+
2023-09-01,73130.58
|
| 119 |
+
2023-10-01,67062.64
|
| 120 |
+
2023-11-01,68957.44
|
| 121 |
+
2023-12-01,67303.87
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit==1.54.0
|
| 2 |
+
pandas==2.3.3
|
| 3 |
+
numpy==2.4.2
|
| 4 |
+
matplotlib==3.10.8
|
| 5 |
+
statsmodels==0.14.6
|
| 6 |
+
scipy==1.17.0
|
| 7 |
+
openai==2.2.0
|
| 8 |
+
querychat[streamlit]==0.5.1
|
| 9 |
+
duckdb==1.4.4
|
| 10 |
+
palettable==3.3.3
|
| 11 |
+
pydantic==2.12.5
|
| 12 |
+
python-dotenv==1.1.0
|
scripts/generate_demo_data.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate demo CSV datasets for the time-series visualization app."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
# Reproducibility
|
| 10 |
+
np.random.seed(42)
|
| 11 |
+
|
| 12 |
+
# Resolve paths relative to the project root (parent of scripts/)
|
| 13 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 14 |
+
DATA_DIR = PROJECT_ROOT / "data"
|
| 15 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# 1. data/demo_single.csv -- Monthly retail sales (Jan 2014 - Dec 2023)
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
def generate_single_series() -> pd.DataFrame:
    """Build a 10-year monthly retail-sales series (Jan 2014 - Dec 2023).

    The series combines a linear upward trend, an annual sine-wave
    seasonal cycle (maximal mid-year: ``sin(2*pi*(m-2)/12)`` peaks at
    month index 5, i.e. June), Gaussian noise, and three injected
    anomaly spikes.

    Returns
    -------
    pd.DataFrame
        Columns ``date`` (month starts) and ``sales`` (rounded to 2 dp).
    """
    n_periods = 120  # 10 years * 12 months
    idx = np.arange(n_periods)

    # Linear trend: ~50k baseline growing 200 per month.
    base = 50_000 + 200 * idx

    # Annual seasonal cycle (amplitude 8k, peaking around June).
    cycle = 8_000 * np.sin(2 * np.pi * ((idx % 12) - 2) / 12)

    # Gaussian measurement noise (single RNG draw keeps reproducibility).
    jitter = np.random.normal(0, 2_000, size=n_periods)

    values = base + cycle + jitter

    # Inject three one-off anomaly spikes for outlier-detection demos.
    values[[36, 72, 100]] += 15_000

    return pd.DataFrame(
        {
            "date": pd.date_range(start="2014-01-01", periods=n_periods, freq="MS"),
            "sales": np.round(values, 2),
        }
    )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# 2. data/demo_multi_wide.csv -- Quarterly revenue by region (Q1 2017 - Q4 2023)
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
def generate_multi_wide() -> pd.DataFrame:
    """Build quarterly revenue for four regions in wide format.

    Covers Q1 2017 - Q4 2023 (28 quarters). Each region gets its own
    base level, a shared linear trend slope, a quarterly sine-wave
    seasonal cycle, and independent Gaussian noise.

    Returns
    -------
    pd.DataFrame
        One ``date`` column plus one numeric column per region
        (``North``, ``South``, ``East``, ``West``).
    """
    n_quarters = 28  # 7 years * 4 quarters
    q_idx = np.arange(n_quarters)

    # Quarterly cycle is identical for every region — compute it once.
    season = 5_000 * np.sin(2 * np.pi * (q_idx % 4) / 4)

    columns: dict[str, object] = {
        "date": pd.date_range(start="2017-01-01", periods=n_quarters, freq="QS")
    }

    # NOTE: region order matters for reproducibility — each region
    # consumes one draw from the shared NumPy RNG stream in turn.
    region_bases = {
        "North": 100_000,
        "South": 80_000,
        "East": 120_000,
        "West": 90_000,
    }
    for region, base_level in region_bases.items():
        level = base_level + 800 * q_idx
        shocks = np.random.normal(0, 3_000, size=n_quarters)
        columns[region] = np.round(level + season + shocks, 2)

    return pd.DataFrame(columns)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
# 3. data/demo_multi_long.csv -- Daily stock prices for 20 tickers
|
| 79 |
+
# (2022-01-03 to 2023-12-29, business days only)
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
def generate_multi_long() -> pd.DataFrame:
    """Build daily GBM stock prices for 20 tickers in long format.

    Simulates geometric Brownian motion over business days from
    2022-01-03 to 2023-12-29 for tickers ``AAAA`` .. ``TTTT``.

    Returns
    -------
    pd.DataFrame
        Columns ``date``, ``ticker``, ``price`` — one row per
        ticker/day combination.
    """
    sessions = pd.bdate_range(start="2022-01-03", end="2023-12-29")
    n_sessions = len(sessions)

    drift = 0.0002
    vol = 0.02
    # Ito-corrected mean of the daily log-return distribution.
    log_mu = drift - 0.5 * vol**2

    pieces: list[pd.DataFrame] = []

    # Ticker order matters for reproducibility: each ticker draws its
    # starting price and return path from the shared RNG stream in turn.
    for i in range(20):
        symbol = chr(ord("A") + i) * 4  # AAAA, BBBB, ..., TTTT
        initial = np.random.uniform(50, 500)

        # GBM: S_t = S_0 * exp(cumsum(log returns))
        steps = np.random.normal(log_mu, vol, size=n_sessions)
        steps[0] = 0.0  # day one trades at exactly the starting price
        path = initial * np.exp(np.cumsum(steps))

        pieces.append(
            pd.DataFrame(
                {
                    "date": sessions,
                    "ticker": symbol,
                    "price": np.round(path, 2),
                }
            )
        )

    return pd.concat(pieces, ignore_index=True)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ---------------------------------------------------------------------------
|
| 117 |
+
# Main
|
| 118 |
+
# ---------------------------------------------------------------------------
|
| 119 |
+
def main() -> None:
    """Generate all three demo CSVs and write them into ``data/``.

    Side effects: writes ``demo_single.csv``, ``demo_multi_wide.csv``,
    and ``demo_multi_long.csv`` under ``DATA_DIR`` (overwriting any
    existing files) and prints a one-line summary per file.
    """
    single = generate_single_series()
    single.to_csv(DATA_DIR / "demo_single.csv", index=False)
    print(f"Wrote {len(single)} rows -> {DATA_DIR / 'demo_single.csv'}")

    wide = generate_multi_wide()
    wide.to_csv(DATA_DIR / "demo_multi_wide.csv", index=False)
    print(f"Wrote {len(wide)} rows -> {DATA_DIR / 'demo_multi_wide.csv'}")

    long = generate_multi_long()
    long.to_csv(DATA_DIR / "demo_multi_long.csv", index=False)
    print(f"Wrote {len(long)} rows -> {DATA_DIR / 'demo_multi_long.csv'}")


if __name__ == "__main__":
    main()
|
src/__init__.py
ADDED
|
File without changes
|
src/ai_interpretation.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ai_interpretation.py
|
| 3 |
+
--------------------
|
| 4 |
+
AI-powered chart interpretation using OpenAI GPT-5.2 vision with
|
| 5 |
+
Pydantic structured output.
|
| 6 |
+
|
| 7 |
+
Provides:
|
| 8 |
+
- Pydantic models for structured chart analysis results
|
| 9 |
+
- Vision-based chart interpretation via OpenAI's GPT-5.2 model
|
| 10 |
+
- Streamlit rendering of interpretation results
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import base64
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from typing import Literal
|
| 19 |
+
|
| 20 |
+
import openai
|
| 21 |
+
from pydantic import BaseModel, ConfigDict
|
| 22 |
+
import streamlit as st
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Pydantic models
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
class TrendInfo(BaseModel):
    """Describes the overall trend detected in the chart."""

    # Forbid extra keys so the generated JSON schema is closed
    # (presumably to satisfy OpenAI structured outputs — see interpret_chart).
    model_config = ConfigDict(extra="forbid")

    # Overall slope of the series as judged from the chart image.
    direction: Literal["upward", "downward", "flat", "mixed"]
    # Plain-language explanation of the trend assessment.
    description: str
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SeasonalityInfo(BaseModel):
    """Describes any seasonality detected in the chart."""

    # Closed schema (additionalProperties: false) for structured output.
    model_config = ConfigDict(extra="forbid")

    # True when a repeating seasonal pattern is visible.
    detected: bool
    # Human-readable period label (e.g. "12 months") or None if unknown.
    period: str | None
    # Plain-language explanation of the seasonality assessment.
    description: str
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class StationarityInfo(BaseModel):
    """Describes whether the series appears stationary."""

    # Closed schema (additionalProperties: false) for structured output.
    model_config = ConfigDict(extra="forbid")

    # Visual judgement only — not a formal statistical test result.
    likely_stationary: bool
    # Plain-language explanation of the stationarity assessment.
    description: str
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class AnomalyItem(BaseModel):
    """A single anomaly or outlier observation."""

    # Closed schema (additionalProperties: false) for structured output.
    model_config = ConfigDict(extra="forbid")

    # Where on the x-axis the anomaly appears (e.g. a date or range).
    approximate_location: str
    # What the anomaly looks like and why it stands out.
    description: str
    # Severity drives the color coding in render_interpretation.
    severity: Literal["low", "medium", "high"]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ChartInterpretation(BaseModel):
    """Complete structured interpretation of a time-series chart.

    Top-level response schema passed to the OpenAI structured-output
    parse call in :func:`interpret_chart` and rendered by
    :func:`render_interpretation`.
    """

    # Closed schema (additionalProperties: false) for structured output.
    model_config = ConfigDict(extra="forbid")

    # Chart type as recognised from the image (may differ from metadata).
    chart_type_detected: str
    trend: TrendInfo
    seasonality: SeasonalityInfo
    stationarity: StationarityInfo
    # Zero or more notable outliers/spikes spotted in the chart.
    anomalies: list[AnomalyItem]
    # Bullet-point highlights shown prominently in the UI.
    key_observations: list[str]
    # One-paragraph overall narrative.
    summary: str
    # Suggested next analytical steps for the student.
    recommendations: list[str]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
# API key check
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
def check_api_key_available() -> bool:
    """Report whether an OpenAI API key is configured.

    Returns
    -------
    bool
        ``True`` when the ``OPENAI_API_KEY`` environment variable is set
        and contains something other than whitespace.
    """
    return os.environ.get("OPENAI_API_KEY", "").strip() != ""
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Chart interpretation
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
_SYSTEM_PROMPT = (
|
| 98 |
+
"You are a careful time-series analyst helping business analytics "
|
| 99 |
+
"students. Analyze the chart image and provide a structured "
|
| 100 |
+
"interpretation. Be precise about what the data shows; flag anything "
|
| 101 |
+
"noteworthy. Use plain language suitable for students."
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def interpret_chart(
    png_bytes: bytes,
    metadata: dict,
) -> ChartInterpretation:
    """Send a chart image to GPT-5.2 vision and return a structured
    interpretation.

    Never raises: any failure (missing/invalid API key, network error,
    unparsable model response) is converted into a placeholder
    :class:`ChartInterpretation` whose ``summary`` carries the error text.

    Parameters
    ----------
    png_bytes:
        Raw PNG image bytes of the chart to analyse.
    metadata:
        Context about the chart. Expected keys:

        * ``chart_type`` -- e.g. ``"line"``, ``"bar"``, ``"decomposition"``
        * ``frequency_label`` -- e.g. ``"Monthly"``, ``"Daily"``
        * ``date_range`` -- human-readable date range string
        * ``y_column`` -- name of the value column being plotted
    """
    try:
        # Client picks up OPENAI_API_KEY from the environment
        # (see check_api_key_available / the error message below).
        client = openai.OpenAI()

        # Encode the PNG as a base64 data URI
        b64 = base64.b64encode(png_bytes).decode("utf-8")
        image_data_uri = f"data:image/png;base64,{b64}"

        chart_type = metadata.get("chart_type", "time-series")
        # default=str makes non-JSON values (dates, etc.) serializable.
        metadata_str = json.dumps(metadata, default=str)

        # NOTE(review): beta.chat.completions.parse validates the model
        # response against the ChartInterpretation schema; newer openai
        # SDKs also expose this as chat.completions.parse — confirm the
        # pinned SDK version still provides the beta path.
        response = client.beta.chat.completions.parse(
            model="gpt-5.2-2025-12-11",
            response_format=ChartInterpretation,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": image_data_uri},
                        },
                        {
                            "type": "text",
                            "text": (
                                f"Analyze this {chart_type} chart. "
                                f"Metadata: {metadata_str}"
                            ),
                        },
                    ],
                },
            ],
        )

        # Prefer the parsed structured output
        parsed = response.choices[0].message.parsed
        if parsed is not None:
            return parsed

        # Fallback: try to manually parse the raw content
        raw_content = response.choices[0].message.content or ""
        data = json.loads(raw_content)
        return ChartInterpretation(**data)

    except Exception as exc:  # noqa: BLE001
        # Return a minimal interpretation that surfaces the error
        # instead of crashing the Streamlit app.
        return ChartInterpretation(
            chart_type_detected="unknown",
            trend=TrendInfo(direction="mixed", description="Unable to determine."),
            seasonality=SeasonalityInfo(
                detected=False, period=None, description="Unable to determine."
            ),
            stationarity=StationarityInfo(
                likely_stationary=False, description="Unable to determine."
            ),
            anomalies=[],
            key_observations=["AI interpretation failed; see summary for details."],
            summary=f"Error during AI interpretation: {exc}",
            recommendations=["Check that your OPENAI_API_KEY is set and valid."],
        )
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ---------------------------------------------------------------------------
|
| 187 |
+
# Streamlit rendering
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
|
| 190 |
+
_DIRECTION_EMOJI = {
|
| 191 |
+
"upward": "\u2197\ufe0f", # arrow upper-right
|
| 192 |
+
"downward": "\u2198\ufe0f", # arrow lower-right
|
| 193 |
+
"flat": "\u27a1\ufe0f", # arrow right
|
| 194 |
+
"mixed": "\u2194\ufe0f", # left-right arrow
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
_SEVERITY_COLOR = {
|
| 198 |
+
"low": "green",
|
| 199 |
+
"medium": "orange",
|
| 200 |
+
"high": "red",
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def render_interpretation(interp: ChartInterpretation) -> None:
    """Render a :class:`ChartInterpretation` as a styled Streamlit card.

    Uses ``st.markdown``, ``st.expander``, and related widgets to lay out
    the interpretation in an easy-to-read format with sections for trend,
    seasonality, stationarity, anomalies, key observations, summary, and
    recommendations.

    Parameters
    ----------
    interp:
        The structured interpretation to display.
    """

    st.markdown("### AI Chart Interpretation")
    st.markdown(
        f"**Detected chart type:** {interp.chart_type_detected}"
    )

    # ---- Summary ----------------------------------------------------------
    st.markdown("---")
    st.markdown(f"**Summary:** {interp.summary}")

    # ---- Key observations -------------------------------------------------
    with st.expander("Key Observations", expanded=True):
        for obs in interp.key_observations:
            st.markdown(f"- {obs}")

    # ---- Trend ------------------------------------------------------------
    with st.expander("Trend Analysis"):
        arrow = _DIRECTION_EMOJI.get(interp.trend.direction, "")
        st.markdown(
            f"**Direction:** {interp.trend.direction.capitalize()} {arrow}"
        )
        st.markdown(interp.trend.description)

    # ---- Seasonality ------------------------------------------------------
    with st.expander("Seasonality"):
        status = "Detected" if interp.seasonality.detected else "Not detected"
        st.markdown(f"**Status:** {status}")
        if interp.seasonality.period:
            st.markdown(f"**Period:** {interp.seasonality.period}")
        st.markdown(interp.seasonality.description)

    # ---- Stationarity -----------------------------------------------------
    with st.expander("Stationarity"):
        label = (
            "Likely stationary"
            if interp.stationarity.likely_stationary
            else "Likely non-stationary"
        )
        st.markdown(f"**Assessment:** {label}")
        st.markdown(interp.stationarity.description)

    # ---- Anomalies --------------------------------------------------------
    with st.expander("Anomalies"):
        if not interp.anomalies:
            st.markdown("No anomalies detected.")
        else:
            for anomaly in interp.anomalies:
                # Color-code severity using Streamlit's :color[...] syntax.
                color = _SEVERITY_COLOR.get(anomaly.severity, "gray")
                st.markdown(
                    f"- **[{anomaly.approximate_location}]** "
                    f":{color}[{anomaly.severity.upper()}] "
                    f"-- {anomaly.description}"
                )

    # ---- Recommendations --------------------------------------------------
    with st.expander("Recommended Next Steps"):
        # BUG FIX: each st.markdown call renders independently, so a
        # literal "1." prefix showed every item as number 1 — number the
        # steps explicitly instead.
        for i, rec in enumerate(interp.recommendations, start=1):
            st.markdown(f"{i}. {rec}")
|
src/cleaning.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CSV ingest and auto-clean pipeline for time-series data.
|
| 3 |
+
|
| 4 |
+
Provides delimiter detection, date/numeric column suggestion,
|
| 5 |
+
numeric cleaning (currency, commas, percentages, parenthesised negatives),
|
| 6 |
+
duplicate and missing-value handling, frequency detection, and
|
| 7 |
+
calendar-feature extraction.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import csv
|
| 11 |
+
import io
|
| 12 |
+
import re
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from datetime import timedelta
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Dataclasses
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
@dataclass
class CleaningReport:
    """Summary produced by :func:`clean_dataframe`."""

    # Row counts before and after cleaning.
    rows_before: int = 0
    rows_after: int = 0
    # Number of duplicate rows found, and a description of what was done
    # about them (e.g. dropped/kept).
    duplicates_found: int = 0
    duplicates_action: str = ""
    # Per-column missing-value counts before/after cleaning.
    missing_before: dict = field(default_factory=dict)
    missing_after: dict = field(default_factory=dict)
    # Human-readable warnings accumulated during parsing/cleaning.
    parsing_warnings: list = field(default_factory=list)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
class FrequencyInfo:
    """Result of :func:`detect_frequency`."""

    # Human-readable frequency label (e.g. "Monthly"); "Unknown" when the
    # spacing could not be classified.
    label: str = "Unknown"
    # Median gap between consecutive timestamps.
    median_delta: timedelta = timedelta(0)
    # Whether the timestamps are (close to) evenly spaced.
    is_regular: bool = False
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Delimiter detection
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def detect_delimiter(file_bytes: bytes) -> str:
    """Return the most likely CSV delimiter for *file_bytes*.

    Runs :class:`csv.Sniffer` over the first 8 KB of (lossily decoded)
    text and falls back to a comma when sniffing fails.

    Parameters
    ----------
    file_bytes:
        Raw bytes of the uploaded file.
    """
    head = file_bytes[:8192].decode("utf-8", errors="replace")
    try:
        return csv.Sniffer().sniff(head).delimiter
    except csv.Error:
        # Sniffer could not decide (empty or ambiguous sample) — assume comma.
        return ","
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
# Reading uploads
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
def read_csv_upload(uploaded_file) -> tuple[pd.DataFrame, str]:
    """Read a Streamlit ``UploadedFile`` and return ``(df, delimiter)``.

    The delimiter is auto-detected via :func:`detect_delimiter`, and the
    file position is rewound afterwards so the object can be re-read.
    """
    payload = uploaded_file.getvalue()
    sep = detect_delimiter(payload)
    frame = pd.read_csv(
        io.StringIO(payload.decode("utf-8", errors="replace")),
        sep=sep,
    )
    uploaded_file.seek(0)  # leave the upload re-readable for callers
    return frame, sep
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# Column suggestion helpers
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
_DATE_NAME_TOKENS = re.compile(r"(date|time|year|month|day|period)", re.IGNORECASE)


def suggest_date_columns(df: pd.DataFrame) -> list[str]:
    """Return column names that are likely to contain date/time values.

    Checks are applied in order:

    1. Column already has a datetime dtype.
    2. :func:`pd.to_datetime` succeeds on the first non-null values.
       Numeric columns are excluded from this check: ``to_datetime``
       coerces plain numbers via the epoch, which previously flagged
       every numeric measure column (e.g. "sales") as a date candidate.
    3. The column *name* contains a date-related keyword (so numeric
       columns like "year" can still qualify).

    Returns
    -------
    list[str]
        Candidate column names in original column order, no duplicates.
    """
    candidates: list[str] = []

    def _add(col) -> None:
        # Preserve column order while avoiding duplicates.
        if col not in candidates:
            candidates.append(col)

    for col in df.columns:
        # 1. Already datetime
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            _add(col)
            continue

        # 2. Parseable as datetime (check up to first 5 non-null values);
        #    skipped for numeric dtypes — see docstring.
        if not pd.api.types.is_numeric_dtype(df[col]):
            sample = df[col].dropna().head(5)
            if not sample.empty:
                try:
                    pd.to_datetime(sample)
                    _add(col)
                    continue
                except (ValueError, TypeError, OverflowError):
                    pass

        # 3. Column name heuristic
        if _DATE_NAME_TOKENS.search(str(col)):
            _add(col)

    return candidates
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def suggest_numeric_columns(df: pd.DataFrame) -> list[str]:
    """Return columns that are numeric or could be cleaned to numeric.

    A non-numeric column qualifies when, after stripping common formatting
    characters (currency symbols, commas, ``%``, parentheses), at least half
    of its non-null values convert to a number.
    """
    suggestions: list[str] = []

    for name in df.columns:
        column = df[name]

        # Already numeric -- accept as-is.
        if pd.api.types.is_numeric_dtype(column):
            suggestions.append(name)
            continue

        # Probe a lightweight clean-up on the first 50 non-null values.
        probe = column.dropna().head(50).astype(str)
        if probe.empty:
            continue

        stripped = probe.str.replace(r"[\$\u20ac\u00a3,% ]", "", regex=True)
        stripped = stripped.str.replace(r"^\((.+)\)$", r"-\1", regex=True)
        parsed = pd.to_numeric(stripped, errors="coerce")

        # Accept when at least half the probed values parse cleanly.
        if parsed.notna().sum() >= max(1, len(probe) * 0.5):
            suggestions.append(name)

    return suggestions
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ---------------------------------------------------------------------------
|
| 159 |
+
# Numeric cleaning
|
| 160 |
+
# ---------------------------------------------------------------------------
|
| 161 |
+
|
| 162 |
+
def clean_numeric_series(series: pd.Series) -> pd.Series:
    """Clean a series into proper numeric values.

    Handles:
    * Currency symbols: ``$``, ``EUR`` (U+20AC), ``GBP`` (U+00A3)
    * Thousands separators (commas)
    * Percentage signs
    * Parenthesised negatives, e.g. ``(123)`` becomes ``-123``

    Values that still fail to parse become ``NaN``.
    """
    text = series.astype(str)

    # One pass removes currency markers, thousands separators, percent
    # signs, and whitespace.
    text = text.str.replace(r"[\$\u20ac\u00a3,%\s]", "", regex=True)

    # Accounting notation: a fully parenthesised value is negative.
    text = text.str.replace(r"^\((.+)\)$", r"-\1", regex=True)

    return pd.to_numeric(text, errors="coerce")
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
# Full cleaning pipeline
|
| 184 |
+
# ---------------------------------------------------------------------------
|
| 185 |
+
|
| 186 |
+
def clean_dataframe(
    df: pd.DataFrame,
    date_col: str,
    y_cols: list[str],
    dup_action: str = "keep_last",
    missing_action: str = "interpolate",
) -> tuple[pd.DataFrame, CleaningReport]:
    """Run the full cleaning pipeline and return ``(cleaned_df, report)``.

    Parameters
    ----------
    df:
        Input dataframe (will not be mutated).
    date_col:
        Name of the column to parse as dates.
    y_cols:
        Names of the value columns to clean to numeric.
    dup_action:
        How to handle duplicate dates: ``"keep_first"``, ``"keep_last"``,
        or ``"drop_all"``.
    missing_action:
        How to handle missing values in *y_cols*: ``"interpolate"``,
        ``"ffill"``, or ``"drop"``.
    """
    work = df.copy()
    report = CleaningReport()
    report.rows_before = len(work)

    # Parse the date column; on failure, fall back to element-wise
    # coercion so only the unparseable values become NaT.
    try:
        work[date_col] = pd.to_datetime(work[date_col])
    except Exception as exc:  # noqa: BLE001
        report.parsing_warnings.append(
            f"Date parsing issue in column '{date_col}': {exc}"
        )
        work[date_col] = pd.to_datetime(work[date_col], errors="coerce")

    unparsed = int(work[date_col].isna().sum())
    if unparsed > 0:
        report.parsing_warnings.append(
            f"{unparsed} value(s) in '{date_col}' could not be parsed as dates."
        )
        work = work.dropna(subset=[date_col])

    # Coerce any non-numeric value columns to numbers.
    for col in y_cols:
        if not pd.api.types.is_numeric_dtype(work[col]):
            work[col] = clean_numeric_series(work[col])

    # Snapshot missing counts before imputation.
    report.missing_before = {col: int(work[col].isna().sum()) for col in y_cols}

    # Resolve duplicate dates according to dup_action.
    dup_mask = work.duplicated(subset=[date_col], keep=False)
    report.duplicates_found = int(dup_mask.sum())
    report.duplicates_action = dup_action
    if report.duplicates_found > 0:
        if dup_action == "keep_first":
            work = work.drop_duplicates(subset=[date_col], keep="first")
        elif dup_action == "keep_last":
            work = work.drop_duplicates(subset=[date_col], keep="last")
        elif dup_action == "drop_all":
            work = work[~dup_mask]

    # Put rows in chronological order.
    work = work.sort_values(date_col).reset_index(drop=True)

    # Impute or drop remaining missing values.
    if missing_action == "interpolate":
        work[y_cols] = work[y_cols].interpolate(method="linear", limit_direction="both")
    elif missing_action == "ffill":
        work[y_cols] = work[y_cols].ffill().bfill()
    elif missing_action == "drop":
        work = work.dropna(subset=y_cols)

    report.missing_after = {col: int(work[col].isna().sum()) for col in y_cols}
    report.rows_after = len(work)

    return work, report
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# ---------------------------------------------------------------------------
|
| 274 |
+
# Frequency detection
|
| 275 |
+
# ---------------------------------------------------------------------------
|
| 276 |
+
|
| 277 |
+
def detect_frequency(df: pd.DataFrame, date_col: str) -> FrequencyInfo:
    """Classify the time-series frequency based on the median time delta.

    Returns a :class:`FrequencyInfo` with a human-readable label, the
    computed median delta, and whether the series is *regular* (standard
    deviation of deltas no more than 20 % of the median).
    """
    ordered = df[date_col].dropna().sort_values()
    if len(ordered) < 2:
        return FrequencyInfo(label="Unknown", median_delta=timedelta(0), is_regular=False)

    gaps = ordered.diff().dropna()
    typical = gaps.median()

    # Regular means the gap spread is small relative to the typical gap.
    if typical > timedelta(0):
        regular = bool(gaps.std() <= typical * 0.2)
    else:
        regular = False

    # Bucket the median gap (in whole days) into a named frequency.
    day_count = typical.days
    if day_count <= 1:
        name = "Daily"
    elif 5 <= day_count <= 9:
        name = "Weekly"
    elif 25 <= day_count <= 35:
        name = "Monthly"
    elif 85 <= day_count <= 100:
        name = "Quarterly"
    elif 350 <= day_count <= 380:
        name = "Yearly"
    else:
        name = "Irregular"

    return FrequencyInfo(label=name, median_delta=typical, is_regular=regular)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# ---------------------------------------------------------------------------
|
| 315 |
+
# Calendar feature extraction
|
| 316 |
+
# ---------------------------------------------------------------------------
|
| 317 |
+
|
| 318 |
+
def add_time_features(df: pd.DataFrame, date_col: str) -> pd.DataFrame:
    """Add calendar columns derived from *date_col*.

    New columns: ``year``, ``quarter``, ``month``, ``day_of_week``.
    The dataframe is returned (not copied) with new columns appended.
    """
    accessor = df[date_col].dt
    for feature, values in (
        ("year", accessor.year),
        ("quarter", accessor.quarter),
        ("month", accessor.month),
        ("day_of_week", accessor.dayofweek),
    ):
        df[feature] = values
    return df
|
src/diagnostics.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Time-series diagnostics utilities.
|
| 2 |
+
|
| 3 |
+
Provides summary statistics, stationarity tests, trend estimation,
|
| 4 |
+
autocorrelation analysis, seasonal decomposition, rolling statistics,
|
| 5 |
+
year-over-year change computation, and multi-series summaries.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from numpy.typing import NDArray
|
| 14 |
+
from scipy import stats
|
| 15 |
+
from statsmodels.tsa.stattools import adfuller, acf, pacf
|
| 16 |
+
from statsmodels.tsa.seasonal import seasonal_decompose, DecomposeResult
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Data classes
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
@dataclass
class SummaryStats:
    """Container for univariate time-series summary statistics.

    Produced by :func:`compute_summary_stats`; bundles descriptive
    statistics, date-range information, a linear-trend estimate, and
    Augmented Dickey-Fuller stationarity test results.
    """

    count: int                 # number of non-null observations
    missing_count: int         # number of null observations
    missing_pct: float         # missing_count as a percentage of all rows
    min_val: float
    max_val: float
    mean_val: float
    median_val: float
    std_val: float
    p25: float                 # 25th percentile
    p75: float                 # 75th percentile
    date_start: pd.Timestamp   # earliest date in the series
    date_end: pd.Timestamp     # latest date in the series
    date_span_days: int        # whole days between date_start and date_end
    trend_slope: float         # OLS slope per observation (compute_trend_slope)
    trend_pvalue: float        # p-value of the trend slope
    adf_statistic: float       # ADF test statistic (NaN when test fails)
    adf_pvalue: float          # ADF p-value (NaN when test fails)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Core helper functions
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def compute_adf_test(series: pd.Series) -> tuple[float, float]:
    """Run the Augmented Dickey-Fuller test for stationarity.

    Parameters
    ----------
    series : pd.Series
        The time-series values (NaNs are dropped automatically).

    Returns
    -------
    tuple[float, float]
        ``(adf_statistic, p_value)``. Returns ``(np.nan, np.nan)`` when the
        test cannot be performed (e.g. too few observations or constant data).
    """
    values = series.dropna()
    if len(values) < 2:
        return np.nan, np.nan
    try:
        # adfuller returns (statistic, pvalue, usedlag, nobs, ...);
        # only the first two matter here.
        stat, pvalue, *_ = adfuller(values, autolag="AIC")
    except Exception:
        # Best-effort: the caller treats NaN as "test unavailable".
        return np.nan, np.nan
    return float(stat), float(pvalue)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def compute_trend_slope(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
) -> tuple[float, float]:
    """Estimate a linear trend via OLS on a 0..n-1 numeric index.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain *date_col* and *y_col*.
    date_col : str
        Column with datetime-like values.
    y_col : str
        Column with numeric values.

    Returns
    -------
    tuple[float, float]
        ``(slope, p_value)`` from ``scipy.stats.linregress``.
        Returns ``(np.nan, np.nan)`` when the regression cannot be computed.
    """
    data = df[[date_col, y_col]].dropna()
    n = len(data)
    if n < 2:
        return np.nan, np.nan
    try:
        xs = np.arange(n, dtype=float)
        ys = data[y_col].astype(float).values
        fit = stats.linregress(xs, ys)
    except Exception:
        return np.nan, np.nan
    return float(fit.slope), float(fit.pvalue)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Summary statistics
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
def compute_summary_stats(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
) -> SummaryStats:
    """Compute a comprehensive set of summary statistics for a time series.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Name of the datetime column.
    y_col : str
        Name of the numeric value column.

    Returns
    -------
    SummaryStats
        Descriptive stats, date range info, trend slope / p-value, and
        ADF test results bundled in one dataclass.
    """
    values = df[y_col]
    timestamps = pd.to_datetime(df[date_col])

    n_total = len(values)
    n_missing = int(values.isna().sum())

    start = timestamps.min()
    end = timestamps.max()

    slope, slope_p = compute_trend_slope(df, date_col, y_col)
    adf_stat, adf_p = compute_adf_test(values)

    return SummaryStats(
        count=int(values.notna().sum()),
        missing_count=n_missing,
        missing_pct=(n_missing / n_total * 100.0) if n_total > 0 else 0.0,
        min_val=float(values.min()),
        max_val=float(values.max()),
        mean_val=float(values.mean()),
        median_val=float(values.median()),
        std_val=float(values.std()),
        p25=float(values.quantile(0.25)),
        p75=float(values.quantile(0.75)),
        date_start=start,
        date_end=end,
        date_span_days=int((end - start).days),
        trend_slope=slope,
        trend_pvalue=slope_p,
        adf_statistic=adf_stat,
        adf_pvalue=adf_p,
    )
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ---------------------------------------------------------------------------
|
| 179 |
+
# Autocorrelation / partial autocorrelation
|
| 180 |
+
# ---------------------------------------------------------------------------
|
| 181 |
+
|
| 182 |
+
def compute_acf_pacf(
    series: pd.Series,
    nlags: int = 40,
) -> tuple[NDArray, NDArray, NDArray, NDArray]:
    """Compute ACF and PACF with confidence intervals.

    Parameters
    ----------
    series : pd.Series
        The time-series values (NaNs are dropped automatically).
    nlags : int, optional
        Maximum number of lags (default 40). Automatically reduced when
        the series is short: the ACF supports at most ``len - 1`` lags,
        and statsmodels' PACF additionally requires fewer than
        ``len // 2`` lags.

    Returns
    -------
    tuple[ndarray, ndarray, ndarray, ndarray]
        ``(acf_values, acf_confint, pacf_values, pacf_confint)``

        * ``acf_values`` -- shape ``(nlags + 1,)``
        * ``acf_confint`` -- shape ``(nlags + 1, 2)``
        * ``pacf_values`` -- shape ``(nlags + 1,)``
        * ``pacf_confint`` -- shape ``(nlags + 1, 2)``

    Raises
    ------
    ValueError
        If the series has fewer than 2 non-NaN observations.
    """
    clean = series.dropna().values.astype(float)

    max_possible = len(clean) - 1
    if max_possible < 1:
        raise ValueError(
            "Series has fewer than 2 non-NaN observations; "
            "cannot compute ACF/PACF."
        )

    # Clamp nlags to what BOTH estimators accept.  statsmodels' pacf()
    # raises ValueError for nlags >= len(clean) // 2, so the previous
    # clamp to len - 1 alone still crashed on series shorter than about
    # 2 * nlags points.
    nlags = min(nlags, max_possible, max(1, len(clean) // 2 - 1))

    acf_values, acf_confint = acf(clean, nlags=nlags, alpha=0.05)
    pacf_values, pacf_confint = pacf(clean, nlags=nlags, alpha=0.05)

    return acf_values, acf_confint, pacf_values, pacf_confint
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# ---------------------------------------------------------------------------
|
| 224 |
+
# Seasonal decomposition
|
| 225 |
+
# ---------------------------------------------------------------------------
|
| 226 |
+
|
| 227 |
+
def _infer_period(df: pd.DataFrame, date_col: str) -> int:
    """Best-effort seasonal-period inference from the date column.

    Returns a sensible integer period, or raises ``ValueError`` when the
    frequency cannot be determined or mapped.
    """
    parsed = pd.to_datetime(df[date_col])
    inferred = pd.infer_freq(parsed)
    if inferred is None:
        raise ValueError(
            "Cannot infer a regular frequency from the date column. "
            "Please supply an explicit 'period' argument or resample the "
            "data to a regular frequency before calling compute_decomposition."
        )

    # Typical seasonal period for each pandas frequency alias.
    periods: dict[str, int] = {
        "D": 365,
        "B": 252,  # business days in a year
        "W": 52,
        "SM": 24,  # semi-monthly
        "BMS": 12,
        "BM": 12,
        "MS": 12,
        "M": 12,   # calendar month end
        "ME": 12,  # month-end (pandas >= 2.2)
        "QS": 4,
        "Q": 4,
        "QE": 4,
        "BQ": 4,
        "AS": 1,
        "A": 1,
        "YS": 1,
        "Y": 1,
        "YE": 1,
        "H": 24,
        "T": 60,
        "MIN": 60,
        "S": 60,
    }

    # Normalise anchored/multiplied offsets: "2W-SUN" -> "W".
    token = inferred.upper().lstrip("0123456789").split("-")[0]

    period = periods.get(token)
    if period is None:
        raise ValueError(
            f"Unable to map inferred frequency '{inferred}' to a seasonal period. "
            "Please provide an explicit 'period' argument."
        )
    return period
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def compute_decomposition(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    model: str = "additive",
    period: Optional[int] = None,
) -> DecomposeResult:
    """Decompose a time series into trend, seasonal, and residual parts.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column name.
    y_col : str
        Numeric value column name.
    model : str, optional
        ``"additive"`` (default) or ``"multiplicative"``.
    period : int or None, optional
        Seasonal period. When *None* the period is inferred from the date
        column's frequency.

    Returns
    -------
    statsmodels.tsa.seasonal.DecomposeResult

    Raises
    ------
    ValueError
        If a regular frequency cannot be inferred and *period* is not given.
    """
    frame = df[[date_col, y_col]].copy().set_index(date_col).sort_index()
    frame.index = pd.to_datetime(frame.index)

    # Fill small interior gaps so the decomposition does not choke on NaNs.
    frame[y_col] = frame[y_col].ffill().bfill()

    if period is None:
        period = _infer_period(df, date_col)

    # Give the index an explicit frequency when one can be inferred;
    # otherwise seasonal_decompose relies on the explicit period alone.
    if frame.index.freq is None:
        guessed = pd.infer_freq(frame.index)
        if guessed is not None:
            frame = frame.asfreq(guessed)
            # asfreq may introduce fresh NaNs at inserted timestamps.
            frame[y_col] = frame[y_col].ffill().bfill()

    return seasonal_decompose(frame[y_col], model=model, period=period)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# ---------------------------------------------------------------------------
|
| 342 |
+
# Rolling statistics
|
| 343 |
+
# ---------------------------------------------------------------------------
|
| 344 |
+
|
| 345 |
+
def compute_rolling_stats(
    df: pd.DataFrame,
    y_col: str,
    window: int = 12,
) -> pd.DataFrame:
    """Add rolling mean and rolling standard deviation columns.

    Parameters
    ----------
    df : pd.DataFrame
        Source data (not mutated).
    y_col : str
        Column over which rolling statistics are calculated.
    window : int, optional
        Rolling window size (default 12).

    Returns
    -------
    pd.DataFrame
        Copy of *df* with two extra columns: ``rolling_mean`` and
        ``rolling_std``.
    """
    result = df.copy()
    roller = result[y_col].rolling(window=window, min_periods=1)
    result["rolling_mean"] = roller.mean()
    result["rolling_std"] = roller.std()
    return result
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
# ---------------------------------------------------------------------------
|
| 374 |
+
# Year-over-year change
|
| 375 |
+
# ---------------------------------------------------------------------------
|
| 376 |
+
|
| 377 |
+
def _offset_for_frequency(df: pd.DataFrame, date_col: str) -> pd.DateOffset:
    """Return a 1-year ``DateOffset`` appropriate to the series frequency."""
    inferred = pd.infer_freq(pd.to_datetime(df[date_col]))

    if inferred is not None:
        # Normalise anchored/multiplied offsets: "2W-SUN" -> "W".
        token = inferred.upper().lstrip("0123456789").split("-")[0]
        # Sub-monthly data shifts by fixed days/weeks rather than by
        # calendar months.
        if token in {"D", "B", "H", "T", "MIN", "S"}:
            return pd.DateOffset(days=365)
        if token == "W":
            return pd.DateOffset(weeks=52)

    # Default: 12 calendar months covers monthly, quarterly, annual data.
    return pd.DateOffset(months=12)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def compute_yoy_change(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
) -> pd.DataFrame:
    """Compute year-over-year absolute and percentage change.

    The number of rows to shift is determined from the inferred frequency
    of the date column.

    Parameters
    ----------
    df : pd.DataFrame
        Source data (not mutated).
    date_col : str
        Datetime column name.
    y_col : str
        Numeric value column name.

    Returns
    -------
    pd.DataFrame
        Copy of *df* sorted by *date_col* with additional columns
        ``yoy_abs_change`` and ``yoy_pct_change``.
    """
    result = df.copy().sort_values(date_col).reset_index(drop=True)
    result[date_col] = pd.to_datetime(result[date_col])

    # Rows per year, keyed by the normalised frequency alias.
    rows_per_year: dict[str, int] = {
        "D": 365,
        "B": 252,
        "W": 52,
        "SM": 24,
        "BMS": 12,
        "BM": 12,
        "MS": 12,
        "M": 12,
        "ME": 12,
        "QS": 4,
        "Q": 4,
        "QE": 4,
        "BQ": 4,
        "AS": 1,
        "A": 1,
        "YS": 1,
        "Y": 1,
        "YE": 1,
        "H": 8760,
        "T": 525600,
        "MIN": 525600,
        "S": 31536000,
    }

    inferred = pd.infer_freq(result[date_col])
    if inferred is None:
        # Fallback: assume monthly data.
        lag = 12
    else:
        token = inferred.upper().lstrip("0123456789").split("-")[0]
        lag = rows_per_year.get(token, 12)

    prior_year = result[y_col].shift(lag)
    result["yoy_abs_change"] = result[y_col] - prior_year
    # Map zero denominators to NaN to avoid divide-by-zero artefacts.
    result["yoy_pct_change"] = (
        result["yoy_abs_change"] / prior_year.abs().replace(0, np.nan) * 100.0
    )

    return result
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# ---------------------------------------------------------------------------
|
| 466 |
+
# Multi-series summary
|
| 467 |
+
# ---------------------------------------------------------------------------
|
| 468 |
+
|
| 469 |
+
def compute_multi_series_summary(
    df: pd.DataFrame,
    date_col: str,
    y_cols: list[str],
) -> pd.DataFrame:
    """Produce a summary DataFrame with one row per value column.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column name.
    y_cols : list[str]
        List of numeric column names to summarise.

    Returns
    -------
    pd.DataFrame
        Columns: ``variable``, ``count``, ``mean``, ``std``, ``min``,
        ``max``, ``trend_slope``, ``adf_pvalue``.
    """
    records: list[dict] = []
    for name in y_cols:
        values = df[name]
        trend, _ = compute_trend_slope(df, date_col, name)
        _, pvalue = compute_adf_test(values)
        records.append(
            {
                "variable": name,
                "count": int(values.notna().sum()),
                "mean": float(values.mean()),
                "std": float(values.std()),
                "min": float(values.min()),
                "max": float(values.max()),
                "trend_slope": trend,
                "adf_pvalue": pvalue,
            }
        )

    return pd.DataFrame(records)
|
src/plotting.py
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
plotting.py
|
| 3 |
+
-----------
|
| 4 |
+
Chart-generation functions for time-series visualisation.
|
| 5 |
+
|
| 6 |
+
Every public function returns a :class:`matplotlib.figure.Figure` object.
|
| 7 |
+
Callers (e.g. Streamlit pages) can pass the figure to ``st.pyplot(fig)``
|
| 8 |
+
or convert it to PNG bytes via :func:`fig_to_png_bytes`.
|
| 9 |
+
|
| 10 |
+
All functions accept an optional *style_dict* (typically from
|
| 11 |
+
:func:`ui_theme.get_miami_mpl_style`) and an optional *palette_colors*
|
| 12 |
+
list so that colours stay consistent across the application.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import io
|
| 18 |
+
import math
|
| 19 |
+
from typing import Dict, List, Optional, Sequence
|
| 20 |
+
|
| 21 |
+
# CRITICAL: set the non-interactive backend before any other mpl import.
|
| 22 |
+
import matplotlib
|
| 23 |
+
matplotlib.use("Agg")
|
| 24 |
+
|
| 25 |
+
import matplotlib.pyplot as plt # noqa: E402
|
| 26 |
+
import matplotlib.dates as mdates # noqa: E402
|
| 27 |
+
import numpy as np # noqa: E402
|
| 28 |
+
import pandas as pd # noqa: E402
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Brand defaults (mirrors ui_theme.py)
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
MIAMI_RED: str = "#C41230"
|
| 34 |
+
_DEFAULT_FIG_SIZE = (10, 6)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# Utility
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
def fig_to_png_bytes(fig: matplotlib.figure.Figure, dpi: int = 150) -> bytes:
    """Render *fig* to an in-memory PNG and return the raw bytes.

    The figure is written with ``bbox_inches="tight"`` so nothing is
    clipped; the buffer never touches disk.
    """
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight")
        return buffer.getvalue()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# Internal helpers
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
class _StyleContext:
    """Context manager that temporarily applies *style_dict* to rcParams.

    On exit the previous values are restored so that other figures are not
    affected. Invalid keys or values are silently skipped in both
    directions, so a bad style entry can never break chart generation.
    """

    def __init__(self, style_dict: Optional[Dict[str, object]]):
        # May be None or empty, in which case __enter__/__exit__ are no-ops.
        self._style = style_dict
        # Snapshot of the rcParams values we overwrite, restored on exit.
        self._saved: Dict[str, object] = {}

    def __enter__(self) -> "_StyleContext":
        if self._style:
            for key, value in self._style.items():
                # NOTE: rcParams.get returns None for unknown keys; the
                # restore attempt for such a key raises and is skipped below.
                self._saved[key] = plt.rcParams.get(key)
                try:
                    plt.rcParams[key] = value
                except (KeyError, ValueError):
                    # Unknown rcParam key or rejected value — ignore.
                    pass
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Restore every rcParam we touched; ignore invalid restores so the
        # exit path can never raise over a style problem.
        for key, value in self._saved.items():
            try:
                plt.rcParams[key] = value
            except (KeyError, ValueError):
                pass
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _default_color(palette_colors: Optional[List[str]], idx: int = 0) -> str:
|
| 83 |
+
"""Pick a colour from *palette_colors* or fall back to MIAMI_RED."""
|
| 84 |
+
if palette_colors and len(palette_colors) > idx:
|
| 85 |
+
return palette_colors[idx % len(palette_colors)]
|
| 86 |
+
return MIAMI_RED
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _finish_figure(fig: matplotlib.figure.Figure) -> matplotlib.figure.Figure:
    """Apply common finishing touches and return the figure.

    Currently just ``tight_layout`` so labels and titles do not overlap;
    kept as a single hook so future global tweaks apply to every chart.
    """
    fig.tight_layout()
    return fig
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _auto_date_axis(ax: plt.Axes) -> None:
    """Auto-format and rotate date tick labels on *ax*'s x-axis.

    Uses matplotlib's automatic date locator/formatter pair and slants the
    labels 30 degrees so long date strings do not collide.
    """
    ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(mdates.AutoDateLocator()))
    for label in ax.get_xticklabels():
        label.set_rotation(30)
        label.set_ha("right")  # right-align so rotated text ends at its tick
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _grid_dims(n: int) -> tuple[int, int]:
|
| 104 |
+
"""Return (nrows, ncols) for a compact grid of *n* panels."""
|
| 105 |
+
ncols = min(n, 3)
|
| 106 |
+
nrows = math.ceil(n / ncols)
|
| 107 |
+
return nrows, ncols
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ===================================================================
|
| 111 |
+
# 1. Line with markers
|
| 112 |
+
# ===================================================================
|
| 113 |
+
|
| 114 |
+
def plot_line_with_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Simple line plot of *y_col* over *date_col* with small circle markers.

    Uses the first palette colour or *MIAMI_RED* as the default. The
    supplied ``style_dict`` (rcParams overrides) is applied only for the
    duration of this call.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        ax.plot(
            df[date_col],
            df[y_col],
            marker="o",
            markersize=4,
            linewidth=1.5,
            color=_default_color(palette_colors, 0),
            label=y_col,
        )
        if title:
            ax.set_title(title)
        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        ax.legend(loc="best")
        _auto_date_axis(ax)
        return _finish_figure(fig)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ===================================================================
|
| 144 |
+
# 2. Line with coloured markers
|
| 145 |
+
# ===================================================================
|
| 146 |
+
|
| 147 |
+
def plot_line_colored_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    color_by: str,
    palette_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Line plot where marker colour varies by a categorical column.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column plotted on the x-axis.
    y_col : str
        Numeric column plotted on the y-axis.
    color_by : str
        Categorical column; each unique value gets its own marker colour.
    palette_colors : list[str]
        Hex colours; cycled when there are more categories than colours.
    title : str, optional
        Axes title.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.

    A legend is added mapping each unique value of *color_by* to its
    colour.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        # Draw the connecting line in a neutral grey (zorder=1) so the
        # coloured markers drawn afterwards (zorder=2) sit on top.
        ax.plot(
            df[date_col], df[y_col],
            linewidth=1.0, color="#AAAAAA", zorder=1,
        )

        # Map categories to colours, in order of first appearance.
        categories = df[color_by].unique()
        n_cats = len(categories)
        if len(palette_colors) < n_cats:
            # cycle palette to cover all categories
            import itertools
            palette_colors = list(itertools.islice(
                itertools.cycle(palette_colors), n_cats
            ))

        color_map = {cat: palette_colors[i] for i, cat in enumerate(categories)}

        # One scatter call per category so each gets its own legend entry.
        for cat in categories:
            mask = df[color_by] == cat
            ax.scatter(
                df.loc[mask, date_col], df.loc[mask, y_col],
                c=color_map[cat], label=str(cat),
                s=30, zorder=2, edgecolors="white", linewidths=0.3,
            )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        # Split the legend into columns when there are many categories.
        ax.legend(title=color_by, loc="best", fontsize=8, ncol=max(1, n_cats // 8))
        return _finish_figure(fig)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ===================================================================
|
| 200 |
+
# 3. Seasonal plot
|
| 201 |
+
# ===================================================================
|
| 202 |
+
|
| 203 |
+
def plot_seasonal(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    palette_name_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Seasonal plot: one line per year/cycle, x-axis is within-period position.

    Parameters
    ----------
    df : pd.DataFrame
        Source data; *date_col* must already be datetime-typed (``.dt``
        accessors are used directly).
    date_col : str
        Datetime column name.
    y_col : str
        Numeric column to plot.
    period : str
        ``"month"`` (x = month 1-12) or ``"quarter"`` (x = quarter 1-4).
        Any string starting with "q" (case-insensitive) selects quarters;
        everything else falls back to months.
    palette_name_colors : list[str]
        List of hex colours; one per cycle/year, cycled when too short.
    title : str, optional
        Axes title.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    """
    with _StyleContext(style_dict):
        tmp = df[[date_col, y_col]].copy()
        tmp["_year"] = tmp[date_col].dt.year

        if period.lower().startswith("q"):
            tmp["_period_pos"] = tmp[date_col].dt.quarter
            x_label = "Quarter"
        else:
            tmp["_period_pos"] = tmp[date_col].dt.month
            x_label = "Month"

        years = sorted(tmp["_year"].unique())
        n_years = len(years)
        if len(palette_name_colors) < n_years:
            # Cycle the palette so every year still gets a colour.
            import itertools
            palette_name_colors = list(itertools.islice(
                itertools.cycle(palette_name_colors), n_years
            ))

        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        for i, year in enumerate(years):
            # Sort within the year so each line follows period order.
            sub = tmp[tmp["_year"] == year].sort_values("_period_pos")
            ax.plot(
                sub["_period_pos"], sub[y_col],
                marker="o", markersize=4, linewidth=1.4,
                color=palette_name_colors[i], label=str(year),
            )

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        # Multi-column legend keeps long year lists compact.
        ax.legend(title="Year", loc="best", fontsize=8, ncol=max(1, n_years // 6))
        return _finish_figure(fig)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# ===================================================================
|
| 258 |
+
# 4. Seasonal sub-series
|
| 259 |
+
# ===================================================================
|
| 260 |
+
|
| 261 |
+
def plot_seasonal_subseries(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Subseries plot with vertical panels for each season and horizontal mean lines.

    Parameters
    ----------
    df : pd.DataFrame
        Source data; *date_col* must already be datetime-typed.
    date_col : str
        Datetime column name.
    y_col : str
        Numeric column to plot.
    period : str
        ``"month"`` or ``"quarter"`` (anything starting with "q" selects
        quarters; everything else falls back to months).
    title : str, optional
        Figure suptitle.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    palette_colors : list[str], optional
        First entry is used for the per-season lines; defaults to red.
    """
    with _StyleContext(style_dict):
        tmp = df[[date_col, y_col]].copy()

        if period.lower().startswith("q"):
            tmp["_season"] = tmp[date_col].dt.quarter
            labels = {1: "Q1", 2: "Q2", 3: "Q3", 4: "Q4"}
        else:
            tmp["_season"] = tmp[date_col].dt.month
            labels = {
                1: "Jan", 2: "Feb", 3: "Mar", 4: "Apr",
                5: "May", 6: "Jun", 7: "Jul", 8: "Aug",
                9: "Sep", 10: "Oct", 11: "Nov", 12: "Dec",
            }

        # Only seasons actually present in the data get a panel.
        seasons = sorted(tmp["_season"].unique())
        n = len(seasons)
        # Widen the figure with the number of panels, never below 10in.
        fig_w = max(10, n * 1.3)
        fig, axes = plt.subplots(1, n, figsize=(fig_w, 5), sharey=True)
        if n == 1:
            # plt.subplots returns a bare Axes (not an array) when n == 1.
            axes = [axes]

        color = _default_color(palette_colors, 0)

        for idx, season in enumerate(seasons):
            ax = axes[idx]
            sub = tmp[tmp["_season"] == season].sort_values(date_col)
            # X positions are just 0..k-1 within the panel (chronological).
            x_positions = range(len(sub))
            ax.plot(x_positions, sub[y_col].values, marker="o", markersize=3,
                    linewidth=1.2, color=color)

            # Dashed horizontal line at this season's mean.
            mean_val = sub[y_col].mean()
            ax.axhline(mean_val, color=MIAMI_RED, linewidth=1.8, linestyle="--", alpha=0.8)

            ax.set_title(labels.get(season, str(season)), fontsize=10)
            ax.set_xticks([])  # within-panel positions are not meaningful ticks
            ax.tick_params(axis="y", labelsize=8)
            if idx == 0:
                ax.set_ylabel(y_col)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# ===================================================================
|
| 322 |
+
# 5. ACF / PACF
|
| 323 |
+
# ===================================================================
|
| 324 |
+
|
| 325 |
+
def plot_acf_pacf(
    acf_vals: np.ndarray,
    acf_ci: np.ndarray,
    pacf_vals: np.ndarray,
    pacf_ci: np.ndarray,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Side-by-side ACF and PACF bar plots with confidence-interval bands.

    Parameters
    ----------
    acf_vals, pacf_vals : np.ndarray
        1-D arrays of autocorrelation values (lag 0, 1, ...).
    acf_ci, pacf_ci : np.ndarray
        Arrays of shape ``(n_lags, 2)`` giving the lower and upper CI bounds.
    title : str, optional
        Figure suptitle.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    """
    with _StyleContext(style_dict):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Both panels are drawn identically; only the data and title differ.
        for ax, vals, ci, sub_title in [
            (ax1, acf_vals, acf_ci, "ACF"),
            (ax2, pacf_vals, pacf_ci, "PACF"),
        ]:
            lags = np.arange(len(vals))
            ax.bar(lags, vals, width=0.3, color=MIAMI_RED, alpha=0.85, zorder=2)

            # Confidence band drawn behind the bars (zorder=1).
            lower = ci[:, 0]
            upper = ci[:, 1]
            # Consistency fix: use the shared MIAMI_RED constant instead of
            # repeating the "#C41230" hex literal, so a brand-colour change
            # only ever has to happen in one place (same value today).
            ax.fill_between(lags, lower, upper, color=MIAMI_RED, alpha=0.12, zorder=1)
            ax.axhline(0, color="black", linewidth=0.8)

            ax.set_xlabel("Lag")
            ax.set_ylabel("Correlation")
            ax.set_title(sub_title)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# ===================================================================
|
| 368 |
+
# 6. Decomposition
|
| 369 |
+
# ===================================================================
|
| 370 |
+
|
| 371 |
+
def plot_decomposition(
    decomposition_result,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """4-panel plot: observed, trend, seasonal, residual.

    Parameters
    ----------
    decomposition_result:
        An object with ``.observed``, ``.trend``, ``.seasonal``, and
        ``.resid`` attributes (e.g. from ``statsmodels.tsa.seasonal_decompose``).
        Each component is plotted via its ``.index`` / ``.values``, so a
        pandas Series with a datetime index is expected.
    title : str, optional
        Figure suptitle.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    """
    with _StyleContext(style_dict):
        # Panel order mirrors the conventional statsmodels layout.
        components = [
            ("Observed", decomposition_result.observed),
            ("Trend", decomposition_result.trend),
            ("Seasonal", decomposition_result.seasonal),
            ("Residual", decomposition_result.resid),
        ]
        fig, axes = plt.subplots(4, 1, figsize=(10, 10), sharex=True)

        for ax, (label, series) in zip(axes, components):
            ax.plot(series.index, series.values, linewidth=1.2, color=MIAMI_RED)
            ax.set_ylabel(label, fontsize=10)
            ax.tick_params(axis="both", labelsize=9)

        # Date formatting on the shared x-axis (bottom panel)
        _auto_date_axis(axes[-1])

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.01)
        return _finish_figure(fig)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# ===================================================================
|
| 407 |
+
# 7. Rolling overlay
|
| 408 |
+
# ===================================================================
|
| 409 |
+
|
| 410 |
+
def plot_rolling_overlay(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    window: int,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Original series (light) with rolling-mean overlay (bold) and +/-1 std band.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column plotted on the x-axis.
    y_col : str
        Numeric column to smooth.
    window : int
        Rolling window size in observations. The window is centred, so
        the overlay is NaN (and simply not drawn) at both edges.
    title : str, optional
        Axes title.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    palette_colors : list[str], optional
        First colour is used for the raw series, second (if present) for
        the rolling mean; otherwise dark grey.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        raw_color = _default_color(palette_colors, 0)
        mean_color = _default_color(palette_colors, 1) if palette_colors and len(palette_colors) > 1 else "#333333"

        dates = df[date_col]
        vals = df[y_col]
        rolling_mean = vals.rolling(window=window, center=True).mean()
        rolling_std = vals.rolling(window=window, center=True).std()

        # Original series (light)
        ax.plot(dates, vals, linewidth=0.8, alpha=0.4, color=raw_color, label="Original")

        # Rolling mean (bold)
        ax.plot(dates, rolling_mean, linewidth=2.2, color=mean_color,
                label=f"{window}-pt Rolling Mean")

        # +/- 1 std band around the rolling mean.
        ax.fill_between(
            dates,
            rolling_mean - rolling_std,
            rolling_mean + rolling_std,
            alpha=0.15, color=mean_color, label="\u00b11 Std Dev",
        )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        ax.legend(loc="best")
        return _finish_figure(fig)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# ===================================================================
|
| 456 |
+
# 8. Year-over-Year change
|
| 457 |
+
# ===================================================================
|
| 458 |
+
|
| 459 |
+
def plot_yoy_change(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    yoy_df: pd.DataFrame,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Two-subplot bar chart: absolute YoY change (top) and percentage YoY change (bottom).

    Parameters
    ----------
    df, date_col, y_col:
        Accepted for interface symmetry with the other plot functions but
        not read here; all plotted data comes from *yoy_df*.
    yoy_df:
        DataFrame with columns ``"date"``, ``"abs_change"``, ``"pct_change"``.
        NOTE(review): the diagnostics module's YoY helper produces columns
        named ``yoy_abs_change`` / ``yoy_pct_change`` — confirm the caller
        renames them before passing the frame in.
    title : str, optional
        Figure suptitle.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    """
    with _StyleContext(style_dict):
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

        dates = yoy_df["date"]
        abs_change = yoy_df["abs_change"]
        pct_change = yoy_df["pct_change"]

        # Colours: green for positive, red for negative
        abs_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in abs_change]
        pct_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in pct_change]

        # width=20 gives bars roughly 20 days wide on a date axis —
        # sized for monthly data; presumably tuned for the demo sets.
        ax1.bar(dates, abs_change, color=abs_colors, width=20, edgecolor="white", linewidth=0.3)
        ax1.axhline(0, color="black", linewidth=0.6)
        ax1.set_ylabel("Absolute Change")
        ax1.set_title("Year-over-Year Absolute Change")

        ax2.bar(dates, pct_change, color=pct_colors, width=20, edgecolor="white", linewidth=0.3)
        ax2.axhline(0, color="black", linewidth=0.6)
        ax2.set_ylabel("% Change")
        ax2.set_title("Year-over-Year Percentage Change")

        # Only the bottom panel needs date formatting (x-axis is shared).
        _auto_date_axis(ax2)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# ===================================================================
|
| 503 |
+
# 9. Lag plot
|
| 504 |
+
# ===================================================================
|
| 505 |
+
|
| 506 |
+
def plot_lag(
    series: pd.Series,
    lag: int = 1,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Scatter plot of y(t) vs y(t-lag) with correlation-coefficient annotation.

    Parameters
    ----------
    series : pd.Series
        The time series; NaNs are dropped before lagging.
    lag : int, default 1
        Number of steps to lag by. Must be at least 1.
    title : str, optional
        Axes title; defaults to ``"Lag-{lag} Plot"``.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.

    Raises
    ------
    ValueError
        If *lag* < 1, or the series has too few non-NaN points to form at
        least two ``(y(t), y(t-lag))`` pairs.
    """
    # Validate up front: the original slicing silently produced an empty
    # slice for lag=0 (``y[:-0]``) and empty arrays for an over-long lag,
    # surfacing as a confusing matplotlib/corrcoef failure instead of a
    # clear error message.
    if lag < 1:
        raise ValueError(f"lag must be >= 1, got {lag}")

    with _StyleContext(style_dict):
        y = series.dropna().values
        if len(y) < lag + 2:
            raise ValueError(
                f"series has {len(y)} non-NaN points; need at least {lag + 2} "
                f"for a lag-{lag} plot"
            )
        y_t = y[lag:]
        y_lag = y[:-lag]

        # Pearson correlation between the series and its lagged copy.
        corr = np.corrcoef(y_t, y_lag)[0, 1]

        fig, ax = plt.subplots(figsize=(7, 7))
        ax.scatter(y_lag, y_t, alpha=0.5, s=20, color=MIAMI_RED, edgecolors="white", linewidths=0.3)

        # Correlation annotation pinned to the top-left corner.
        ax.annotate(
            f"r = {corr:.3f}",
            xy=(0.05, 0.95), xycoords="axes fraction",
            fontsize=12, fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="#CCCCCC", alpha=0.9),
            verticalalignment="top",
        )

        ax.set_xlabel(f"y(t\u2212{lag})")
        ax.set_ylabel("y(t)")
        if title:
            ax.set_title(title)
        else:
            ax.set_title(f"Lag-{lag} Plot")
        return _finish_figure(fig)
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
# ===================================================================
|
| 542 |
+
# 10. Panel (small multiples)
|
| 543 |
+
# ===================================================================
|
| 544 |
+
|
| 545 |
+
def plot_panel(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    chart_type: str = "line",
    shared_y: bool = True,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Small multiples: one subplot per *y_col* arranged in a grid.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column shared by every panel's x-axis.
    y_cols : list[str]
        Numeric columns; one panel each, at most three per row.
    chart_type : str, default "line"
        ``"bar"`` draws bars; any other value draws a line.
    shared_y : bool, default True
        If ``True`` all panels share the same y-axis limits.
    title : str, optional
        Figure suptitle.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    palette_colors : list[str], optional
        One colour per panel, index-matched to *y_cols*.
    """
    with _StyleContext(style_dict):
        n = len(y_cols)
        nrows, ncols = _grid_dims(n)
        # Scale figure size with grid shape, with sensible minimums.
        fig_h = max(4, nrows * 3.5)
        fig_w = max(8, ncols * 4.5)
        # squeeze=False guarantees a 2-D axes array even for a 1x1 grid.
        fig, axes = plt.subplots(
            nrows, ncols, figsize=(fig_w, fig_h),
            sharey=shared_y, squeeze=False,
        )
        flat_axes = axes.flatten()

        for i, col in enumerate(y_cols):
            ax = flat_axes[i]
            color = _default_color(palette_colors, i)

            if chart_type == "bar":
                ax.bar(df[date_col], df[col], color=color, width=2, edgecolor="white", linewidth=0.3)
            else:
                ax.plot(df[date_col], df[col], linewidth=1.3, color=color)

            ax.set_title(col, fontsize=10)
            _auto_date_axis(ax)

        # Hide unused subplots (the grid may be larger than n).
        for j in range(n, len(flat_axes)):
            flat_axes[j].set_visible(False)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
# ===================================================================
|
| 597 |
+
# 11. Spaghetti plot
|
| 598 |
+
# ===================================================================
|
| 599 |
+
|
| 600 |
+
def plot_spaghetti(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    alpha: float = 0.15,
    highlight_col: Optional[str] = None,
    top_n: Optional[int] = None,
    show_median_band: bool = False,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """All series on one plot at low opacity, with optional highlighting.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    date_col : str
        Datetime column plotted on the x-axis.
    y_cols : list[str]
        Numeric columns, one line each.
    alpha : float, default 0.15
        Opacity for the background spaghetti lines.
    highlight_col : str, optional
        Column name to draw with full opacity and a thicker line.
    top_n : int, optional
        If set, additionally highlight the *top_n* series by maximum value.
    show_median_band : bool, default False
        If ``True``, overlay the cross-series median line and shade the IQR.
    title : str, optional
        Axes title.
    style_dict : dict, optional
        rcParams overrides applied only for the duration of this call.
    palette_colors : list[str], optional
        One colour per series, index-matched to *y_cols*.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        dates = df[date_col]

        # Determine which columns to highlight: the explicit request plus
        # (optionally) the top-n series ranked by their maximum value.
        highlight_set: set[str] = set()
        if highlight_col and highlight_col in y_cols:
            highlight_set.add(highlight_col)
        if top_n:
            max_vals = {col: df[col].max() for col in y_cols}
            sorted_cols = sorted(max_vals, key=max_vals.get, reverse=True)  # type: ignore[arg-type]
            highlight_set.update(sorted_cols[:top_n])

        # Draw all series: highlighted ones bold/opaque on top (zorder=3),
        # the rest as faint background spaghetti (zorder=1). Only the
        # highlighted lines get labels, so only they appear in the legend.
        for i, col in enumerate(y_cols):
            color = _default_color(palette_colors, i)
            if col in highlight_set:
                ax.plot(dates, df[col], linewidth=2.0, alpha=0.9,
                        color=color, label=col, zorder=3)
            else:
                ax.plot(dates, df[col], linewidth=0.8, alpha=alpha,
                        color=color, zorder=1)

        # Median + IQR band computed across the series at each time point.
        if show_median_band:
            numeric_data = df[y_cols]
            median_line = numeric_data.median(axis=1)
            q1 = numeric_data.quantile(0.25, axis=1)
            q3 = numeric_data.quantile(0.75, axis=1)

            ax.plot(dates, median_line, linewidth=2.2, color="#333333",
                    label="Median", zorder=4)
            ax.fill_between(dates, q1, q3, alpha=0.2, color="#333333",
                            label="IQR", zorder=2)

        ax.set_xlabel(date_col)
        ax.set_ylabel("Value")
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)

        # Only add legend if there are labelled items
        handles, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend(loc="best", fontsize=8)
        return _finish_figure(fig)
|
src/querychat_helpers.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QueryChat initialization and filtered DataFrame helpers.
|
| 3 |
+
|
| 4 |
+
Provides convenience wrappers around the ``querychat`` library for
|
| 5 |
+
natural-language filtering of time-series DataFrames inside a Streamlit
|
| 6 |
+
app. All functions degrade gracefully when the package or an API key
|
| 7 |
+
is unavailable.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
from typing import List, Optional
|
| 14 |
+
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import streamlit as st
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from querychat.streamlit import QueryChat as _QueryChat
|
| 20 |
+
|
| 21 |
+
_QUERYCHAT_AVAILABLE = True
|
| 22 |
+
except ImportError: # pragma: no cover
|
| 23 |
+
_QUERYCHAT_AVAILABLE = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Availability check
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
def check_querychat_available() -> bool:
    """Return ``True`` when *querychat* is installed AND an API key is set.

    QueryChat requires the ``OPENAI_API_KEY`` environment variable, so
    both conditions must hold.  This helper lets callers gate UI
    elements behind a single boolean.
    """
    has_api_key = bool(os.environ.get("OPENAI_API_KEY"))
    return _QUERYCHAT_AVAILABLE and has_api_key
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# QueryChat factory
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
def create_querychat(
    df: pd.DataFrame,
    name: str = "dataset",
    date_col: str = "date",
    y_cols: Optional[List[str]] = None,
    freq_label: str = "",
    model: str = "openai/gpt-5.2-2025-12-11",
):
    """Create and return a QueryChat instance bound to *df*.

    Parameters
    ----------
    df:
        The pandas DataFrame to expose to the chat interface.
    name:
        A human-readable name for the dataset (used in the description;
        spaces are replaced with underscores for the table name).
    date_col:
        Name of the date/time column.
    y_cols:
        Names of the value (numeric) columns.  If ``None`` (or empty),
        the description reports "none specified".
    freq_label:
        Optional frequency label (e.g. ``"Monthly"``, ``"Daily"``).
    model:
        Chat-model identifier forwarded to QueryChat's ``client``
        argument.  Defaults to the previously hard-coded model so
        existing callers are unaffected; exposed as a parameter so the
        model can be changed without editing this function.

    Returns
    -------
    QueryChat instance
        The object returned by ``QueryChat()``.

    Raises
    ------
    RuntimeError
        If querychat is not installed.
    """
    if not _QUERYCHAT_AVAILABLE:
        raise RuntimeError(
            "The 'querychat' package is not installed. "
            "Install it with: pip install 'querychat[streamlit]'"
        )

    # None and [] are treated identically: both produce "none specified".
    value_cols = y_cols or []

    value_cols_str = ", ".join(value_cols) if value_cols else "none specified"
    # freq_part carries its own leading space so it concatenates cleanly.
    freq_part = f" Frequency: {freq_label}." if freq_label else ""

    data_description = (
        f"This dataset is named '{name}'. "
        f"It contains {len(df):,} rows. "
        f"The date column is '{date_col}'. "
        f"Value columns: {value_cols_str}."
        f"{freq_part}"
    )

    greeting = (
        f"Hi! I can help you filter and explore the **{name}** dataset. "
        "Try asking me something like:\n"
        '- "Show only 2023 data"\n'
        '- "Filter where sales > 60000"\n'
        '- "Show rows from January to March"'
    )

    qc = _QueryChat(
        data_source=df,
        # NOTE(review): spaces are stripped because table_name is
        # presumably used as an identifier — confirm against querychat docs.
        table_name=name.replace(" ", "_"),
        client=model,
        data_description=data_description,
        greeting=greeting,
    )

    return qc
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
# Filtered DataFrame extraction
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
|
| 121 |
+
def get_filtered_pandas_df(qc) -> pd.DataFrame:
    """Extract the currently filtered DataFrame from a QueryChat instance.

    ``qc.df()`` may hand back a *narwhals* (or polars) frame instead of a
    pandas one; such frames are converted via their ``to_pandas()``
    method.  Any failure falls back to the raw frame when it is already
    pandas, and finally to an empty DataFrame so the app keeps working.

    Parameters
    ----------
    qc:
        A QueryChat instance previously created via :func:`create_querychat`.

    Returns
    -------
    pd.DataFrame
        The filtered data as a pandas DataFrame.
    """
    try:
        filtered = qc.df()

        # narwhals / polars frames expose .to_pandas()
        if hasattr(filtered, "to_pandas"):
            return filtered.to_pandas()

        if isinstance(filtered, pd.DataFrame):
            return filtered

        # Unknown type -- last-resort conversion attempt.
        return pd.DataFrame(filtered)
    except Exception:  # noqa: BLE001
        # Surface the unfiltered data when possible so the app can
        # continue to function.
        try:
            fallback = qc.df()
        except Exception:  # noqa: BLE001
            return pd.DataFrame()
        if isinstance(fallback, pd.DataFrame):
            return fallback
        return pd.DataFrame()
|
src/ui_theme.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ui_theme.py
|
| 3 |
+
-----------
|
| 4 |
+
Miami University branded theme and styling utilities for Streamlit apps.
|
| 5 |
+
|
| 6 |
+
Provides:
|
| 7 |
+
- CSS injection for Streamlit components (buttons, sidebar, metrics, cards)
|
| 8 |
+
- Matplotlib rcParams styled with Miami branding
|
| 9 |
+
- ColorBrewer palette loading via palettable with graceful fallback
|
| 10 |
+
- Color-swatch preview figure generation
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import itertools
|
| 16 |
+
from typing import Dict, List, Optional
|
| 17 |
+
|
| 18 |
+
import matplotlib.figure
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import streamlit as st
|
| 21 |
+
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
# Brand constants — Miami University (Ohio) official palette
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Primary brand colours (public module API).
MIAMI_RED: str = "#C41230"
MIAMI_BLACK: str = "#000000"
MIAMI_WHITE: str = "#FFFFFF"

# Secondary palette tokens used only inside the CSS below.
_WHITE = "#FFFFFF"
_BLACK = "#000000"
_LIGHT_GRAY = "#F5F5F5"   # metric-card background
_BORDER_GRAY = "#E0E0E0"  # card borders / plot grid lines
_DARK_TEXT = "#000000"    # NOTE(review): not referenced below — confirm before removing
_HOVER_RED = "#9E0E26"    # darker red for button hover/active/focus states
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Streamlit CSS injection
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
def apply_miami_theme() -> None:
    """Inject Miami-branded CSS into the active Streamlit page.

    Styles affected:
        * Primary buttons -- Miami Red background with white text
        * Card containers -- subtle border and rounded corners
        * Sidebar header -- Miami Red accent bar
        * Metric cards -- light background with left red accent
    """
    # Doubled braces ({{ }}) are literal CSS braces inside this f-string;
    # single braces interpolate the brand-colour constants defined above.
    css = f"""
    <style>
    /* ---- Primary buttons ---- */
    .stButton > button[kind="primary"],
    .stButton > button {{
        background-color: {MIAMI_RED};
        color: {_WHITE};
        border: none;
        border-radius: 6px;
        padding: 0.5rem 1.25rem;
        font-weight: 600;
        transition: background-color 0.2s ease;
    }}
    .stButton > button:hover {{
        background-color: {_HOVER_RED};
        color: {_WHITE};
        border: none;
    }}
    .stButton > button:active,
    .stButton > button:focus {{
        background-color: {_HOVER_RED};
        color: {_WHITE};
        box-shadow: none;
    }}

    /* ---- Card borders ---- */
    div[data-testid="stExpander"],
    div[data-testid="stHorizontalBlock"] > div {{
        border: 1px solid {_BORDER_GRAY};
        border-radius: 8px;
        padding: 0.75rem;
    }}

    /* ---- Sidebar header accent ---- */
    section[data-testid="stSidebar"] > div:first-child {{
        border-top: 4px solid {MIAMI_RED};
    }}
    section[data-testid="stSidebar"] h1,
    section[data-testid="stSidebar"] h2,
    section[data-testid="stSidebar"] h3 {{
        color: {MIAMI_RED};
    }}

    /* ---- Metric cards ---- */
    div[data-testid="stMetric"] {{
        background-color: {_LIGHT_GRAY};
        border-left: 4px solid {MIAMI_RED};
        border-radius: 6px;
        padding: 0.75rem 1rem;
    }}
    div[data-testid="stMetric"] label {{
        color: {_BLACK};
        font-size: 0.85rem;
    }}
    div[data-testid="stMetric"] div[data-testid="stMetricValue"] {{
        color: {_BLACK};
        font-weight: 700;
    }}
    </style>
    """
    # unsafe_allow_html is required for raw <style> injection.
    st.markdown(css, unsafe_allow_html=True)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
# Matplotlib style dictionary
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
def get_miami_mpl_style() -> Dict[str, object]:
    """Return a dictionary of matplotlib rcParams for Miami branding.

    Usage::

        import matplotlib as mpl
        mpl.rcParams.update(get_miami_mpl_style())

    Or apply to a single figure::

        with mpl.rc_context(get_miami_mpl_style()):
            fig, ax = plt.subplots()
            ...
    """
    figure_section: Dict[str, object] = {
        "figure.facecolor": _WHITE,
        "figure.edgecolor": _WHITE,
        "figure.figsize": (10, 5),
        "figure.dpi": 100,
    }
    axes_section: Dict[str, object] = {
        "axes.facecolor": _WHITE,
        "axes.edgecolor": _BLACK,
        "axes.labelcolor": _BLACK,
        "axes.titlecolor": MIAMI_RED,
        "axes.labelsize": 12,
        "axes.titlesize": 14,
        "axes.titleweight": "bold",
        # Miami Red first, then black, then three neutral accents.
        "axes.prop_cycle": plt.cycler(
            color=[MIAMI_RED, _BLACK, "#4E79A7", "#F28E2B", "#76B7B2"]
        ),
        "axes.grid": True,
    }
    grid_and_ticks: Dict[str, object] = {
        "grid.color": _BORDER_GRAY,
        "grid.linestyle": "--",
        "grid.linewidth": 0.6,
        "grid.alpha": 0.7,
        "xtick.color": _BLACK,
        "ytick.color": _BLACK,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    }
    text_and_output: Dict[str, object] = {
        "legend.fontsize": 10,
        "legend.frameon": True,
        "legend.framealpha": 0.9,
        "legend.edgecolor": _BORDER_GRAY,
        "font.size": 11,
        "font.family": "sans-serif",
        "savefig.dpi": 150,
        "savefig.bbox": "tight",
    }

    style: Dict[str, object] = {}
    for section in (figure_section, axes_section, grid_and_ticks, text_and_output):
        style.update(section)
    return style
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
# ColorBrewer palette loading
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
|
| 176 |
+
# Mapping of short friendly names to palettable module paths
# (relative to the top-level ``palettable`` package).
_PALETTE_MAP: Dict[str, str] = {
    "Set1": "colorbrewer.qualitative.Set1",
    "Set2": "colorbrewer.qualitative.Set2",
    "Set3": "colorbrewer.qualitative.Set3",
    "Dark2": "colorbrewer.qualitative.Dark2",
    "Paired": "colorbrewer.qualitative.Paired",
    "Pastel1": "colorbrewer.qualitative.Pastel1",
    "Pastel2": "colorbrewer.qualitative.Pastel2",
    "Accent": "colorbrewer.qualitative.Accent",
    "Tab10": "colorbrewer.qualitative.Set1",  # fallback alias
}

# Colours returned when palettable is missing or a palette cannot be
# resolved: the two brand colours first, then neutral accents.
_FALLBACK_COLORS: List[str] = [
    MIAMI_RED,
    MIAMI_BLACK,
    "#4E79A7",
    "#F28E2B",
    "#76B7B2",
    "#E15759",
    "#59A14F",
    "#EDC948",
]
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _resolve_palette(name: str) -> Optional[List[str]]:
    """Dynamically import a palettable ColorBrewer palette by *name*.

    Palettable organises palettes by maximum number of classes, e.g.
    ``colorbrewer.qualitative.Set2_8``.  We find the variant with the
    most colours available so the caller gets the richest palette.

    Returns ``None`` when palettable is not installed, the module cannot
    be found, or no sized variant exists.
    """
    import importlib

    module_path = _PALETTE_MAP.get(name)
    if module_path is None:
        # Try a direct guess: colorbrewer.qualitative.<Name>
        module_path = f"colorbrewer.qualitative.{name}"

    try:
        mod = importlib.import_module(f"palettable.{module_path}")
    except ImportError:  # ModuleNotFoundError is a subclass
        return None

    # BUG FIX: derive the attribute prefix from the *resolved* module
    # path, not from the requested name.  Aliases such as "Tab10"
    # (mapped to the Set1 module) previously scanned for "Tab10_*"
    # attributes that do not exist and always fell back to None.
    base = module_path.rsplit(".", 1)[-1]

    # palettable stores each size variant as <Name>_<N> inside the
    # module; pick the variant with the most colours.
    best_attr: Optional[str] = None
    best_n = 0
    for attr_name in dir(mod):
        if not attr_name.startswith(base + "_"):
            continue
        try:
            size = int(attr_name.rsplit("_", 1)[-1])
        except ValueError:
            continue
        if size > best_n:
            best_n = size
            best_attr = attr_name

    if best_attr is None:
        return None

    palette_obj = getattr(mod, best_attr, None)
    if palette_obj is None:
        return None

    # palette_obj.colors is a list of (r, g, b) integer triples.
    return [
        "#{:02X}{:02X}{:02X}".format(*rgb) for rgb in palette_obj.colors
    ]
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_palette_colors(name: str = "Set2", n: int = 8) -> List[str]:
    """Return *n* hex colour strings from a ColorBrewer palette.

    Parameters
    ----------
    name:
        Friendly palette name such as ``"Set2"``, ``"Dark2"``, ``"Paired"``.
    n:
        Number of colours required (clamped to at least 1).  If it
        exceeds the palette length, the colours repeat cyclically.

    Returns
    -------
    list[str]
        List of *n* hex colour strings (e.g. ``["#66C2A5", ...]``).

    Notes
    -----
    Falls back to a built-in colour list when the requested palette
    cannot be found, so calling code never receives an empty list.
    """
    needed = max(1, n)

    palette = _resolve_palette(name)
    if palette is None:
        palette = _FALLBACK_COLORS

    # Tile the palette enough times to cover `needed`, then truncate —
    # equivalent to cycling through it.
    repeats = -(-needed // len(palette))  # ceiling division
    return (palette * repeats)[:needed]
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# ---------------------------------------------------------------------------
|
| 280 |
+
# Palette preview swatch
|
| 281 |
+
# ---------------------------------------------------------------------------
|
| 282 |
+
def render_palette_preview(
    colors: List[str],
    swatch_width: float = 1.0,
    swatch_height: float = 0.4,
) -> matplotlib.figure.Figure:
    """Create a small matplotlib figure showing colour swatches.

    Parameters
    ----------
    colors:
        Hex colour strings to display, one unit square per colour.
    swatch_width:
        Width of each individual swatch in inches (total figure width is
        never below 2 inches).
    swatch_height:
        Height of the swatch strip in inches.

    Returns
    -------
    matplotlib.figure.Figure
        A Figure instance ready to be passed to ``st.pyplot()`` or saved.
    """
    count = len(colors)
    fig, ax = plt.subplots(
        figsize=(max(swatch_width * count, 2.0), swatch_height + 0.3),
        dpi=100,
    )

    for position, hex_colour in enumerate(colors):
        swatch = plt.Rectangle(
            (position, 0),
            width=1,
            height=1,
            facecolor=hex_colour,
            edgecolor=_WHITE,
            linewidth=1.5,
        )
        ax.add_patch(swatch)

    ax.set_xlim(0, count)
    ax.set_ylim(0, 1)
    ax.set_aspect("equal")
    ax.axis("off")
    # Let the swatches fill the whole canvas — no margins.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.close(fig)  # prevent display in non-Streamlit contexts
    return fig
|