import asyncio import sys import uuid from datetime import datetime from pathlib import Path import pandas as pd import plotly.express as px import plotly.graph_objects as go import streamlit as st # ── Path setup so `src` is importable when running from src/ or project root ── _here = Path(__file__).resolve().parent # src/ _root = _here.parent # project root for _p in [str(_here), str(_root)]: if _p not in sys.path: sys.path.insert(0, _p) from controllers._data_extractor import DataExtractorController, UserQuery from schemas import Message INCIDENT_COLOR_MAP = { "1": ("#ef4444", "Fire"), "2": ("#eab308", "Rupt/Exp"), "3": ("#3b82f6", "EMS"), "4": ("#22c55e", "Hazardous"), "5": ("#a855f7", "Public Assist"), "6": ("#06b6d4", "Good Intent"), "7": ("#9ca3af", "False Alarms"), "8": ("#84cc16", "Weather"), "9": ("#c0c0c0", "Special Type"), } INCIDENT_NULL_COLOR = ("#111111", "Not Entered") INCIDENT_BLANK_COLOR = ("#f97316", "Blank") INICIDENT_CATEGORY_NAMES = [ "Fire", "Rupt/Exp", "EMS", "Hazardous", "Public Assist", "Good Intent", "False Alarms", "Weather", "Special Type", "Not Entered", "Blank", ] INCIDENT_NAME_COLOR_MAP = { "fire": ("#ef4444", "Fire"), "rupt/exp": ("#eab308", "Rupt/Exp"), "ems": ("#3b82f6", "EMS"), "hazardous": ("#22c55e", "Hazardous"), "public assist": ("#a855f7", "Public Assist"), "good intent": ("#06b6d4", "Good Intent"), "false alarms": ("#9ca3af", "False Alarms"), "weather": ("#84cc16", "Weather"), "special type": ("#c0c0c0", "Special Type"), } INCIDENT_COL_NAMES = { "incidenttype", "incident_type", "incidentclassification", "incident_classification", "incident_category", "incidentcategory", } def detect_date_columns(df: pd.DataFrame): for col in df.columns: if pd.api.types.is_numeric_dtype(df[col]): continue try: numeric_like_ratio = pd.to_numeric(df[col], errors="coerce").notna().mean() if numeric_like_ratio > 0.8: continue except Exception: pass if pd.api.types.is_datetime64_any_dtype(df[col]): return col else: try: converted = pd.to_datetime(df[col], errors="coerce") valid_ratio = converted.notna().mean() if valid_ratio > 0.8: return col except Exception: pass return None def get_date_range(df: pd.DataFrame): date_col = detect_date_columns(df) if not date_col: return None series = pd.to_datetime(df[date_col], errors="coerce").dropna() if series.empty: return None col_min = series.min() col_max = series.max() if col_min and col_max: return f"{col_min.strftime('%b %d, %Y')} → {col_max.strftime('%b %d, %Y')}" return None def _detect_incident_col(df: pd.DataFrame) -> str | None: incident_cols = [col for col in df.columns if col.strip().lower() in INCIDENT_COL_NAMES] return incident_cols if incident_cols else None def _incident_label_and_color(value) -> tuple[str, str]: """Return (display_label, hex_color) for a raw incidenttype value.""" if value is None or value == "nan" or (isinstance(value, float) and pd.isna(value)): return INCIDENT_NULL_COLOR[1], INCIDENT_NULL_COLOR[0] s = str(value).strip() if s == "": return INCIDENT_BLANK_COLOR[1], INCIDENT_BLANK_COLOR[0] prefix = s[0].upper() if s[0].isalpha() else s[0] key = s[0].upper() if s[0].upper() == "S" else s[0] if key in INCIDENT_COLOR_MAP: color, label = INCIDENT_COLOR_MAP[key] return label, color elif any( category_name.lower() in s.lower() for category_name in INICIDENT_CATEGORY_NAMES ): category_name_found = next( category_name for category_name in INICIDENT_CATEGORY_NAMES if category_name.lower() in s.lower() ) name_key = category_name_found.lower() if name_key in INCIDENT_NAME_COLOR_MAP: color, label = INCIDENT_NAME_COLOR_MAP[name_key] return label, color return s, "#6b7280" def _add_incident_category(df: pd.DataFrame, col: str) -> pd.DataFrame: df = df.copy() if col not in df.columns or df.empty: return df df = df.reset_index(drop=True) plain = df[col].astype(str).where(df[col].notna(), other=None) mapped = plain.map(_incident_label_and_color) df["_incident_label"] = mapped.apply(lambda x: x[0]) df["_incident_color"] = mapped.apply(lambda x: x[1]) return df # ── Page config ────────────────────────────────────────────────────────────── st.set_page_config( page_title="Data Extractor AI", page_icon="🔍", layout="wide", initial_sidebar_state="expanded", ) # ── Plotly template (adapts to Streamlit theme) ────────────────────────────── def _get_plotly_template(): """Return a Plotly template that works with Streamlit's current theme.""" return "plotly_white" # ── Chart rendering ─────────────────────────────────────────────────────────── CHART_ALIASES = { "pie": "pie", "bar": "bar", "line": "line", "bar_chart": "bar", "vertical_bar": "bar", "column": "bar", "grouped_bar": "bar", "stacked_bar": "bar", "line_chart": "line", "time_series": "line", "trend": "line", "donut": "pie", "doughnut": "pie", } def _normalise_chart_type(raw: str | None) -> str | None: if not raw: return None return CHART_ALIASES.get(raw.lower().strip(), raw.lower().strip()) def _guess_chart_type(df: pd.DataFrame) -> str: cols = list(df.columns) n_cols = len(cols) n_rows = len(df) numeric = df.select_dtypes(include="number").columns.tolist() if n_cols == 1 and numeric: return "histogram" cat = df.select_dtypes(exclude="number").columns.tolist() if n_cols == 2 and len(cat) == 1 and len(numeric) == 1: return "bar" if n_rows <= 50 else "line" if len(numeric) >= 2: return "line" return "bar" def normalize(x): try: num = float(x) if num.is_integer(): return str(int(num)) return str(num) except: return x def sort_key(x): try: return (0, float(x)) except: return (1, x) def get_categorical_columns(df: pd.DataFrame, column_name: str) -> list[str]: df[column_name] = df[column_name].astype(str) df[column_name] = df[column_name].map(normalize) cats = sorted(df[column_name].dropna().unique(), key=sort_key) df[column_name] = pd.Categorical(df[column_name], categories=cats, ordered=True) s = df.sort_values(by=column_name).reset_index(drop=True)[column_name] return s.tolist() def render_chart( df: pd.DataFrame, incident_cols: list[str] | None = None, chart_type_raw: str | None = None, key_prefix: str = "chart", ): if df.empty: st.info("No data to chart.") return df = df.copy() for c in df.columns: if hasattr(df[c], "cat"): df[c] = df[c].astype(str).replace("nan", None) chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df) incident_cols = incident_cols or [] active_incident_col: str | None = None df_plot = df.copy() cols = list(df.columns) numeric = df.select_dtypes(include="number").columns.tolist() cat = df.select_dtypes(exclude="number").columns.tolist() default_x = ( incident_cols[0] if incident_cols and incident_cols[0] in cols else (cat[0] if cat else cols[0]) ) default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0]) # ── Column selectors ── st.caption("⚙️ Configure columns") if chart_type == "pie": c1, c2 = st.columns(2) with c1: x_col = st.selectbox( "Labels", options=cols, index=cols.index(default_x) if default_x in cols else 0, key=f"{key_prefix}_pie_x", ) with c2: val_opts = numeric if numeric else cols for col in cols: try: if pd.to_numeric(df[col], errors="coerce").notna().all() and col not in val_opts: val_opts.append(col) except Exception: pass y_col = st.selectbox( "Values", options=val_opts, index=val_opts.index(default_y) if default_y in val_opts else 0, key=f"{key_prefix}_pie_y", ) color_col = None else: c1, c2, c3 = st.columns(3) with c1: x_col = st.selectbox( "X axis", options=cols, index=cols.index(default_x) if default_x in cols else 0, key=f"{key_prefix}_x", ) view_all_labels = st.checkbox( "View All Labels", key=f"{key_prefix}_x_all_labels", ) with c2: y_opts = numeric if numeric else cols y_col = st.selectbox( "Y axis", options=y_opts, index=y_opts.index(default_y) if default_y in y_opts else 0, key=f"{key_prefix}_y", ) with c3: color_options = ["None"] + [c for c in cols if c not in (x_col, y_col)] color_sel = st.selectbox( "Color / Group", options=color_options, index=0, key=f"{key_prefix}_color", ) view_horizontal_stacked = st.checkbox( "Horizontal Stacked", key=f"{key_prefix}_stacked", ) color_col = None if color_sel == "None" else color_sel # ── Incident color mapping ───────────────────────── incident_color_map = None if active_incident_col and "_incident_label" in df_plot.columns: incident_color_map = dict( zip(df_plot["_incident_label"], df_plot["_incident_color"]) ) # ── Build chart ──────────────────────────────────── fig = None tmpl = _get_plotly_template() date_range = get_date_range(df_plot) title = f"{y_col} by {x_col}" if date_range: title += f" ({date_range})" try: # ───── BAR ───── if chart_type == "bar": if view_all_labels: df_plot[x_col] = df_plot[x_col].astype(str) active_incident_col = x_col if x_col in incident_cols else None df_plot = ( _add_incident_category(df, active_incident_col) if active_incident_col else df.copy() ) incident_color_map = ( dict( zip( df_plot["_incident_label"].tolist(), df_plot["_incident_color"].tolist(), ) ) if active_incident_col and "_incident_label" in df_plot.columns else None ) if color_col: df_plot_copy = df_plot.copy() color_incident_color_map = None if color_col in incident_cols: df_plot_copy = _add_incident_category(df_plot_copy, color_col) color_incident_color_map = dict( zip( df_plot_copy["_incident_label"].tolist(), df_plot_copy["_incident_color"].tolist(), ) ) if not color_incident_color_map: color_incident_color_map = incident_color_map if active_incident_col is not None: x_col = "_incident_label" if color_col in incident_cols: color_col = "_incident_label" if view_horizontal_stacked: df_plot_copy[color_col] = df_plot_copy[color_col].astype(str) if df_plot_copy.duplicated(subset=[x_col, color_col]).any(): df_plot_copy[y_col] = ( df_plot_copy.groupby([x_col, color_col])[y_col] .transform("sum") ) df_plot_copy = df_plot_copy.drop_duplicates(subset=[x_col, color_col]) bar_kwargs = dict( x=x_col, y=y_col, color=color_col, barmode="group", template=tmpl, text=y_col ) category_orders = {} if active_incident_col is not None: bar_kwargs["x"] = "_incident_label" color_incident_color_map = incident_color_map bar_kwargs["color_discrete_map"] = color_incident_color_map if ( color_incident_color_map and color_col in incident_cols ) or ( color_incident_color_map and color_col in ["_incident_label"] ): bar_kwargs["x"] = ( "_incident_label" if active_incident_col is not None else x_col ) bar_kwargs["color"] = "_incident_label" bar_kwargs["color_discrete_map"] = color_incident_color_map category_orders["_incident_label"] = INICIDENT_CATEGORY_NAMES x_order = get_categorical_columns(df_plot_copy, bar_kwargs["x"]) category_orders[bar_kwargs["x"]] = x_order if category_orders: bar_kwargs["category_orders"] = category_orders fig = px.bar(df_plot_copy, **bar_kwargs) group_x = bar_kwargs["x"] group_y = bar_kwargs["y"] group_color = bar_kwargs.get("color") x_values = category_orders.get(group_x, df_plot_copy[group_x].drop_duplicates().tolist()) if group_color: total_base = ( df_plot_copy[[group_x, group_color, group_y]] .drop_duplicates() ) else: total_base = ( df_plot_copy[[group_x, group_y]] .drop_duplicates() ) group_totals = ( total_base.groupby(group_x)[group_y] .sum() .reset_index() ) totals_map = dict( zip(group_totals[group_x], group_totals[group_y]) ) for x_val in x_values: if x_val in totals_map: fig.add_annotation( x=x_val, y=totals_map[x_val], text=f"{totals_map[x_val]}", showarrow=False, yshift=15, font=dict(size=14), xanchor="center" ) elif incident_color_map and active_incident_col is not None: df_plot_copy = df_plot.copy() if df_plot[x_col].duplicated().any(): df_plot_copy[y_col] = df_plot_copy.groupby("_incident_label")[y_col].transform("sum") df_plot_copy.drop_duplicates(subset=["_incident_label"], inplace=True) fig = px.bar( df_plot_copy, x="_incident_label", y=y_col, color="_incident_label", template=tmpl, color_discrete_map=incident_color_map, text=y_col, category_orders={"_incident_label": INICIDENT_CATEGORY_NAMES} ) else: df_plot_copy = df_plot.copy() if df_plot[x_col].duplicated().any(): df_plot_copy[y_col] = df_plot_copy.groupby(x_col)[y_col].transform("sum") df_plot_copy.drop_duplicates(subset=[x_col], inplace=True) fig = px.bar(df_plot_copy, x=x_col, y=y_col, template=tmpl, text=y_col) fig.update_traces(textposition="outside") if view_all_labels: fig.update_xaxes( type="category", tickmode="array", tickvals=get_categorical_columns(df_plot, x_col) ) if active_incident_col is not None: fig.update_xaxes( categoryorder="array", categoryarray=INICIDENT_CATEGORY_NAMES ) # ───── LINE ───── elif chart_type == "line": active_incident_col = x_col if x_col in incident_cols else None df_plot = ( _add_incident_category(df, active_incident_col) if active_incident_col else df.copy() ) incident_color_map = ( dict( zip( df_plot["_incident_label"].tolist(), df_plot["_incident_color"].tolist(), ) ) if active_incident_col and "_incident_label" in df_plot.columns else None ) if color_col: line_kwargs = dict( x=x_col, y=y_col, color=color_col, markers=True, template=tmpl ) if incident_color_map and color_col in incident_cols: line_kwargs["x"] = ( "_incident_label" if active_incident_col is not None else x_col ) line_kwargs["color"] = "_incident_label" line_kwargs["color_discrete_map"] = incident_color_map fig = px.line(df_plot, **line_kwargs) elif incident_color_map and active_incident_col is not None: fig = px.line( df_plot, x="_incident_label", y=y_col, color="_incident_label", markers=True, template=tmpl, color_discrete_map=incident_color_map, ) else: fig = px.line(df_plot, x=x_col, y=y_col, markers=True, template=tmpl) # ───── PIE ───── elif chart_type == "pie": active_incident_col = x_col if x_col in incident_cols else None df_plot = ( _add_incident_category(df, active_incident_col) if active_incident_col else df.copy() ) incident_color_map = ( dict( zip( df_plot["_incident_label"].tolist(), df_plot["_incident_color"].tolist(), ) ) if active_incident_col and "_incident_label" in df_plot.columns else None ) if y_col: df_plot[y_col] = pd.to_numeric(df_plot[y_col], errors="coerce").fillna(0) if incident_color_map and active_incident_col is not None: fig = px.pie( df_plot, names="_incident_label", values=y_col, hole=0.35, template=tmpl, color="_incident_label", color_discrete_map=incident_color_map, ) else: fig = px.pie( df_plot, names=x_col, values=y_col, hole=0.35, template=tmpl ) fig.update_traces(textinfo="percent+label") # ───── FALLBACK ───── else: fig = px.bar( df_plot, x=x_col, y=y_col, template=tmpl, title=f"Chart type '{chart_type_raw}' not recognized", ) fig.update_layout( title={ "text": title.replace("_", " ").capitalize(), "x": 0.5, "xanchor": "center", "font": { "size": 24 } } ) except Exception as e: st.warning(f"Could not render `{chart_type}` chart: {e}") return if fig: st.plotly_chart( fig, use_container_width=True, config={"displayModeBar": False}, key=f"{uuid.uuid4()}_plot", ) def render_crosstab(df: pd.DataFrame): if df.empty: st.info("No data to summarise.") return numeric = df.select_dtypes(include="number").columns.tolist() cat = df.select_dtypes(exclude="number").columns.tolist() try: if len(cat) >= 2 and len(numeric) >= 1: pivot = df.pivot_table( index=cat[0], columns=cat[1], values=numeric[0], aggfunc="sum", fill_value=0, ) st.caption(f"Crosstab — {cat[0]} x {cat[1]} (sum of {numeric[0]})") st.dataframe(pivot, use_container_width=True) elif len(cat) == 1 and len(numeric) >= 1: summary = df.groupby(cat[0])[numeric].agg(["sum", "mean", "count"]) summary.columns = [f"{v}_{f}" for v, f in summary.columns] summary = summary.reset_index() st.caption(f"Summary — grouped by {cat[0]}") st.dataframe(summary, use_container_width=True, hide_index=True) elif len(numeric) >= 2: corr = df[numeric].corr().round(3) st.caption("Correlation Matrix") st.dataframe( corr.style.background_gradient(cmap="Blues", axis=None), use_container_width=True, ) else: desc = df.describe(include="all").T.reset_index() desc.rename(columns={"index": "column"}, inplace=True) st.caption("Statistical Summary") st.dataframe(desc, use_container_width=True, hide_index=True) except Exception as e: st.warning(f"Could not build crosstab: {e}") st.dataframe(df.describe(include="all").T, use_container_width=True) # ── Controller singleton ───────────────────────────────────────────────────── @st.cache_resource def get_controller(): return DataExtractorController() controller = get_controller() # ── Session state ──────────────────────────────────────────────────────────── if "chat_history" not in st.session_state: st.session_state.chat_history = [] if "total_queries" not in st.session_state: st.session_state.total_queries = 0 if "successful_queries" not in st.session_state: st.session_state.successful_queries = 0 # ── Helpers ────────────────────────────────────────────────────────────────── def build_message_history() -> list[Message]: return [ Message(role=msg["role"], content=msg["content"]) for msg in st.session_state.chat_history ] def call_controller(user_query: str): uq = UserQuery(user_query=user_query) history = build_message_history() response = asyncio.run(controller.extrcat(user_query=uq, message_history=history)) return response def render_message(msg): is_user = msg["role"] == "user" avatar = "👤" if is_user else "🤖" role = "user" if is_user else "assistant" with st.chat_message(role, avatar=avatar): # Status badge if "status" in msg and not is_user: if msg["status"] == "success": st.success("Query executed successfully", icon="✅") else: st.error("Query failed", icon="❌") st.markdown(msg["content"]) # SQL block if msg.get("sql"): with st.expander("Generated SQL", expanded=False): st.code(msg["sql"], language="sql") # Timestamp if msg.get("ts"): st.caption(msg["ts"]) # ── Multi-view data panel ────────────────────────────────────────────── data = msg.get("data", []) chart_hint = msg.get("best_suitable_chart") if data and len(data) > 0 and msg.get("status") == "success": df = pd.DataFrame(data) tab_table, tab_crosstab, tab_chart = st.tabs( ["📋 Table", "📐 Crosstab / Summary", "📊 Chart"] ) with tab_table: st.dataframe(df, use_container_width=True, hide_index=True) with tab_crosstab: render_crosstab(df) with tab_chart: if chart_hint: icon_map = {"BAR": "📊", "PIE": "🥧", "LINE": "📈"} icon = icon_map.get(str(chart_hint).upper(), "📊") st.caption(f"{icon} {chart_hint} (AI suggested)") else: st.caption("📊 Auto-detected chart type") charts = [ "BAR", "LINE", "PIE" ] default_option = chart_hint if chart_hint in charts else "BAR" default_index = charts.index(default_option) charts[default_index] = f"{charts[default_index]} (AI suggested)" chart_hint = st.selectbox( "Select chart type", options=charts, index=default_index, key=f"chart_type_{msg.get('ts', 'x')}", ) chart_hint = chart_hint.replace(" (AI suggested)", "") if chart_hint else None chart_key = f"chart_{msg.get('ts', 'x').replace(':', '_')}" incident_cols = [] if msg.get("incident_col"): incident_cols.append(msg["incident_col"]) detected = _detect_incident_col(df) if detected and detected not in incident_cols: incident_cols.extend(detected) for col in incident_cols: if col in df.columns: df[col] = df[col].astype("category") render_chart(df, incident_cols, chart_hint, key_prefix=chart_key) elif msg.get("status") == "success" and not data: st.info("Query returned 0 rows.") # ── Sidebar ─────────────────────────────────────────────────────────────────── with st.sidebar: st.header("📊 Session Stats") col1, col2 = st.columns(2) with col1: st.metric("Queries", st.session_state.total_queries) with col2: st.metric("Success", st.session_state.successful_queries) st.divider() st.header("💡 Example Prompts") examples = [ "List all fire incidents in last 10 years", "Show top 10 incidents by type", "Count incidents per year", "Find incidents with alarm time after 6pm", "List unique incident types", ] for ex in examples: if st.button(ex, use_container_width=True, key=f"ex_{ex[:20]}"): st.session_state["prefill"] = ex st.divider() if st.button("🗑 Clear History", use_container_width=True): st.session_state.chat_history = [] st.session_state.total_queries = 0 st.session_state.successful_queries = 0 st.rerun() if st.session_state.chat_history: with st.expander("🔍 Raw Message History"): st.json(build_message_history()) # ── Main layout ─────────────────────────────────────────────────────────────── st.title("Firerms Data Extractor Chatbot") st.caption("Natural language → SQL → Results") # ── Chat area ───────────────────────────────────────────────────────────────── chat_container = st.container() with chat_container: if not st.session_state.chat_history: st.markdown("---") col1, col2, col3 = st.columns([1, 2, 1]) with col2: st.markdown( "
🔍
" "Type a natural language question and the AI will generate SQL and return results.
" "