Spaces:
Sleeping
Sleeping
| import asyncio | |
| import sys | |
| import uuid | |
| from datetime import datetime | |
| from pathlib import Path | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
# -- Path setup so `src` is importable when running from src/ or project root --
_here = Path(__file__).resolve().parent  # src/
_root = _here.parent  # project root
# Each missing entry is pushed to the front of sys.path, so _root (handled
# second) ends up ahead of _here when both were absent.
for _p in map(str, (_here, _root)):
    if _p in sys.path:
        continue
    sys.path.insert(0, _p)
| from controllers._data_extractor import DataExtractorController, UserQuery | |
| from schemas import Message | |
# Incident-type palette keyed by the first character of an incident code:
# code prefix -> (hex colour, display label).
INCIDENT_COLOR_MAP = {
    "1": ("#ef4444", "Fire"),
    "2": ("#eab308", "Rupt/Exp"),
    "3": ("#3b82f6", "EMS"),
    "4": ("#22c55e", "Hazardous"),
    "5": ("#a855f7", "Public Assist"),
    "6": ("#06b6d4", "Good Intent"),
    "7": ("#9ca3af", "False Alarms"),
    "8": ("#84cc16", "Weather"),
    "9": ("#c0c0c0", "Special Type"),
}
# (hex colour, display label) for missing values and for empty strings.
INCIDENT_NULL_COLOR = ("#111111", "Not Entered")
INCIDENT_BLANK_COLOR = ("#f97316", "Blank")
# Canonical legend/axis order: the nine coded categories in code order, then
# "Not Entered", then "Blank". Derived from the palette so the two can't drift.
# The misspelt name is kept because the rest of the module refers to it;
# INCIDENT_CATEGORY_NAMES is a correctly spelt alias for new code.
INICIDENT_CATEGORY_NAMES = [label for _color, label in INCIDENT_COLOR_MAP.values()] + [
    INCIDENT_NULL_COLOR[1],
    INCIDENT_BLANK_COLOR[1],
]
INCIDENT_CATEGORY_NAMES = INICIDENT_CATEGORY_NAMES
# Lower-cased display label -> (hex colour, display label); used when a value
# contains a category name rather than a numeric code.
INCIDENT_NAME_COLOR_MAP = {
    label.lower(): (color, label) for color, label in INCIDENT_COLOR_MAP.values()
}
# Normalised (stripped, lower-cased) column names treated as incident-type columns.
INCIDENT_COL_NAMES = {
    "incidenttype",
    "incident_type",
    "incidentclassification",
    "incident_classification",
    "incident_category",
    "incidentcategory",
}
def detect_date_columns(df: pd.DataFrame):
    """Return the name of the first date-like column of *df*, or None.

    A column qualifies when it already has a datetime64 dtype, or when more
    than 80% of its values parse with ``pd.to_datetime(errors="coerce")``.
    Numeric-dtype columns, and text columns where more than 80% of values
    parse as numbers, are skipped so numeric codes/IDs are not mistaken for
    dates.
    """
    import warnings

    for col in df.columns:
        series = df[col]
        # Must come first: pd.to_numeric() converts datetime64 values to
        # integers, so the numeric-like test below would otherwise skip
        # genuine datetime columns.
        if pd.api.types.is_datetime64_any_dtype(series):
            return col
        if pd.api.types.is_numeric_dtype(series):
            continue
        # Deliberately broad: this is a best-effort heuristic and must never
        # break chart rendering.
        try:
            numeric_like_ratio = pd.to_numeric(series, errors="coerce").notna().mean()
        except Exception:
            numeric_like_ratio = 0.0
        if numeric_like_ratio > 0.8:
            continue
        try:
            with warnings.catch_warnings():
                # Free-text columns make pandas warn that it could not infer a
                # datetime format; that is expected here.
                warnings.simplefilter("ignore", UserWarning)
                converted = pd.to_datetime(series, errors="coerce")
        except Exception:
            continue
        if converted.notna().mean() > 0.8:
            return col
    return None
def get_date_range(df: pd.DataFrame):
    """Return the span of the first date-like column of *df* as display text.

    The column is chosen by ``detect_date_columns``. Both bounds are formatted
    as "Mon DD, YYYY" and joined with the separator used below. Returns None
    when there is no date-like column or it has no parseable values.
    """
    date_col = detect_date_columns(df)
    # ``is None`` instead of truthiness: a column literally named 0 or "" is
    # still a valid match.
    if date_col is None:
        return None
    stamps = pd.to_datetime(df[date_col], errors="coerce").dropna()
    if stamps.empty:
        return None
    day_fmt = "%b %d, %Y"
    return f"{stamps.min().strftime(day_fmt)} β {stamps.max().strftime(day_fmt)}"
| def _detect_incident_col(df: pd.DataFrame) -> str | None: | |
| incident_cols = [col for col in df.columns if col.strip().lower() in INCIDENT_COL_NAMES] | |
| return incident_cols if incident_cols else None | |
def _incident_label_and_color(value) -> tuple[str, str]:
    """Return ``(display_label, hex_color)`` for a raw incident-type value.

    Resolution order:

    1. Missing values (None, the string "nan", or any scalar pandas treats as
       NA such as NaN / NaT / pd.NA) -> "Not Entered".
    2. Empty or whitespace-only text -> "Blank".
    3. The first character looked up in ``INCIDENT_COLOR_MAP`` (keys "1"-"9";
       presumably incident-code series such as "111" -> Fire -- confirm).
    4. The first name in ``INICIDENT_CATEGORY_NAMES`` that appears as a
       case-insensitive substring, if it has an ``INCIDENT_NAME_COLOR_MAP``
       entry.
    5. Otherwise the stripped text itself, in neutral gray "#6b7280".
    """
    null_result = (INCIDENT_NULL_COLOR[1], INCIDENT_NULL_COLOR[0])
    if value is None or (isinstance(value, str) and value == "nan"):
        return null_result
    if not isinstance(value, str):
        # pd.isna covers NaN, NaT and pd.NA; comparing pd.NA with "==" and then
        # using it in a boolean context would raise TypeError.
        try:
            missing = bool(pd.isna(value))
        except (TypeError, ValueError):
            missing = False
        if missing:
            return null_result

    text = str(value).strip()
    if not text:
        return INCIDENT_BLANK_COLOR[1], INCIDENT_BLANK_COLOR[0]

    # Lower-case "s" is folded to "S" (kept from the original lookup rule);
    # every other leading character is used as-is.
    lead = "S" if text[0] in ("s", "S") else text[0]
    coded = INCIDENT_COLOR_MAP.get(lead)
    if coded is not None:
        color, label = coded
        return label, color

    lowered = text.lower()
    name_hit = next(
        (name for name in INICIDENT_CATEGORY_NAMES if name.lower() in lowered),
        None,
    )
    if name_hit is not None:
        named = INCIDENT_NAME_COLOR_MAP.get(name_hit.lower())
        if named is not None:
            color, label = named
            return label, color
    return text, "#6b7280"
def _add_incident_category(df: pd.DataFrame, col: str) -> pd.DataFrame:
    """Return a copy of *df* with ``_incident_label`` / ``_incident_color`` columns.

    Each value of *col* is classified with ``_incident_label_and_color``;
    missing values are passed as None. The copy's index is reset (dropped).
    When *col* is absent or *df* is empty, a plain copy is returned unchanged.
    """
    out = df.copy()
    if out.empty or col not in out.columns:
        return out
    out = out.reset_index(drop=True)
    raw = out[col]
    as_text = raw.astype(str).where(raw.notna(), other=None)
    pairs = as_text.map(_incident_label_and_color)
    out["_incident_label"] = pairs.map(lambda pair: pair[0])
    out["_incident_color"] = pairs.map(lambda pair: pair[1])
    return out
# -- Page config --------------------------------------------------------------
# Page-level settings for the Streamlit app: wide layout, sidebar open.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Data Extractor AI",
    page_icon="π",
)
| # ββ Plotly template (adapts to Streamlit theme) ββββββββββββββββββββββββββββββ | |
| def _get_plotly_template(): | |
| """Return a Plotly template that works with Streamlit's current theme.""" | |
| return "plotly_white" | |
| # ββ Chart rendering βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CHART_ALIASES = { | |
| "pie": "pie", | |
| "bar": "bar", | |
| "line": "line", | |
| "bar_chart": "bar", | |
| "vertical_bar": "bar", | |
| "column": "bar", | |
| "grouped_bar": "bar", | |
| "stacked_bar": "bar", | |
| "line_chart": "line", | |
| "time_series": "line", | |
| "trend": "line", | |
| "donut": "pie", | |
| "doughnut": "pie", | |
| } | |
| def _normalise_chart_type(raw: str | None) -> str | None: | |
| if not raw: | |
| return None | |
| return CHART_ALIASES.get(raw.lower().strip(), raw.lower().strip()) | |
| def _guess_chart_type(df: pd.DataFrame) -> str: | |
| cols = list(df.columns) | |
| n_cols = len(cols) | |
| n_rows = len(df) | |
| numeric = df.select_dtypes(include="number").columns.tolist() | |
| if n_cols == 1 and numeric: | |
| return "histogram" | |
| cat = df.select_dtypes(exclude="number").columns.tolist() | |
| if n_cols == 2 and len(cat) == 1 and len(numeric) == 1: | |
| return "bar" if n_rows <= 50 else "line" | |
| if len(numeric) >= 2: | |
| return "line" | |
| return "bar" | |
def normalize(x):
    """Canonicalise a value for use as a categorical axis label.

    Anything ``float()`` accepts is rendered as text, with integral values
    losing their fractional part ("3.0" -> "3", 4 -> "4", "2.50" -> "2.5").
    Values that are not numeric are returned unchanged.
    """
    try:
        num = float(x)
    except (TypeError, ValueError, OverflowError):
        # Only "not a number" cases; a bare except would also swallow
        # KeyboardInterrupt/SystemExit.
        return x
    if num.is_integer():
        return str(int(num))
    return str(num)
def sort_key(x):
    """Sort key that puts numeric-looking values first (numerically), then the rest.

    Returns ``(0, float_value)`` for values ``float()`` accepts and
    ``(1, x)`` otherwise. NaN (e.g. the string "nan" produced by
    ``astype(str)``) is treated as non-numeric: a float NaN key would make
    ``sorted`` order undefined, because NaN compares false with everything.
    """
    try:
        num = float(x)
    except (TypeError, ValueError, OverflowError):
        return (1, x)
    if num != num:  # NaN is the only float that is not equal to itself
        return (1, x)
    return (0, num)
def get_categorical_columns(df: pd.DataFrame, column_name: str) -> list[str]:
    """Turn *column_name* into an ordered categorical in place; return its sorted values.

    Side effect: ``df[column_name]`` is overwritten with the ``normalize``-d
    string form of each value, stored as an ordered ``pd.Categorical`` whose
    categories are sorted by ``sort_key``.

    Returns:
        Every row's value (duplicates included) in category order. Despite the
        function name, this is not a list of column names.
    """
    as_labels = df[column_name].astype(str).map(normalize)
    ordered_categories = sorted(as_labels.dropna().unique(), key=sort_key)
    df[column_name] = pd.Categorical(
        as_labels, categories=ordered_categories, ordered=True
    )
    return df[column_name].sort_values().tolist()
def render_chart(
    df: pd.DataFrame,
    incident_cols: list[str] | None = None,
    chart_type_raw: str | None = None,
    key_prefix: str = "chart",
):
    """Render an interactive bar / line / pie chart of *df* in Streamlit.

    Column selectors (labels/values for pie; X, Y, colour plus two checkboxes
    otherwise) are shown first, then a Plotly figure is built and displayed.

    Args:
        df: query result to plot; an empty frame only shows an info note.
        incident_cols: columns to treat as incident types. When one is chosen
            as X (or as colour for bar/line), its values are mapped to display
            labels and colours via ``_add_incident_category`` and ordered by
            ``INICIDENT_CATEGORY_NAMES``.
        chart_type_raw: chart-type hint, normalised by
            ``_normalise_chart_type``; falls back to ``_guess_chart_type``.
        key_prefix: prefix for the selector widget keys, so several charts
            can coexist on one page.

    Any exception while building the figure is shown with ``st.warning`` and
    nothing is plotted.

    NOTE(review): this body was recovered from a flattened paste with no
    indentation; the nesting of a few statements in the coloured-bar branch
    (the duplicate aggregation and the ``x_order`` / ``category_orders``
    lines) is a best reconstruction -- confirm against version control.
    """
    if df.empty:
        st.info("No data to chart.")
        return
    df = df.copy()
    # Categorical columns (render_message casts incident columns to
    # "category") become plain strings; literal "nan" strings become None.
    # NOTE(review): on pandas < 2.x, Series.replace(x, None) meant
    # method="pad" (forward fill), not "replace with None" -- confirm version.
    for c in df.columns:
        if hasattr(df[c], "cat"):
            df[c] = df[c].astype(str).replace("nan", None)
    chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
    incident_cols = incident_cols or []
    active_incident_col: str | None = None
    df_plot = df.copy()
    cols = list(df.columns)
    numeric = df.select_dtypes(include="number").columns.tolist()
    cat = df.select_dtypes(exclude="number").columns.tolist()
    # Prefer an incident column for X, then the first non-numeric column.
    default_x = (
        incident_cols[0]
        if incident_cols and incident_cols[0] in cols
        else (cat[0] if cat else cols[0])
    )
    default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
    # -- Column selectors --
    st.caption("βοΈ Configure columns")
    if chart_type == "pie":
        c1, c2 = st.columns(2)
        with c1:
            x_col = st.selectbox(
                "Labels",
                options=cols,
                index=cols.index(default_x) if default_x in cols else 0,
                key=f"{key_prefix}_pie_x",
            )
        with c2:
            # NOTE: val_opts aliases `numeric` (or `cols`), so the appends
            # below also mutate that list. Adds text columns whose values all
            # parse as numbers.
            val_opts = numeric if numeric else cols
            for col in cols:
                try:
                    if pd.to_numeric(df[col], errors="coerce").notna().all() and col not in val_opts:
                        val_opts.append(col)
                except Exception:
                    pass
            y_col = st.selectbox(
                "Values",
                options=val_opts,
                index=val_opts.index(default_y) if default_y in val_opts else 0,
                key=f"{key_prefix}_pie_y",
            )
        color_col = None
    else:
        # view_all_labels / view_horizontal_stacked exist only on this path;
        # they are read only in the bar branch, which pie never reaches.
        c1, c2, c3 = st.columns(3)
        with c1:
            x_col = st.selectbox(
                "X axis",
                options=cols,
                index=cols.index(default_x) if default_x in cols else 0,
                key=f"{key_prefix}_x",
            )
            view_all_labels = st.checkbox(
                "View All Labels",
                key=f"{key_prefix}_x_all_labels",
            )
        with c2:
            y_opts = numeric if numeric else cols
            y_col = st.selectbox(
                "Y axis",
                options=y_opts,
                index=y_opts.index(default_y) if default_y in y_opts else 0,
                key=f"{key_prefix}_y",
            )
        with c3:
            color_options = ["None"] + [c for c in cols if c not in (x_col, y_col)]
            color_sel = st.selectbox(
                "Color / Group",
                options=color_options,
                index=0,
                key=f"{key_prefix}_color",
            )
            view_horizontal_stacked = st.checkbox(
                "Horizontal Stacked",
                key=f"{key_prefix}_stacked",
            )
        color_col = None if color_sel == "None" else color_sel
    # -- Incident color mapping --
    # active_incident_col is still None here (it is only assigned inside the
    # per-chart branches below), so this block never runs as written.
    incident_color_map = None
    if active_incident_col and "_incident_label" in df_plot.columns:
        incident_color_map = dict(
            zip(df_plot["_incident_label"], df_plot["_incident_color"])
        )
    # -- Build chart --
    fig = None
    tmpl = _get_plotly_template()
    date_range = get_date_range(df_plot)
    title = f"{y_col} by {x_col}"
    if date_range:
        title += f" ({date_range})"
    try:
        # ----- BAR -----
        if chart_type == "bar":
            # NOTE(review): this conversion is discarded -- df_plot is rebuilt
            # from `df` on the next statement. The category x-axis set further
            # down is what actually shows every label.
            if view_all_labels:
                df_plot[x_col] = df_plot[x_col].astype(str)
            active_incident_col = x_col if x_col in incident_cols else None
            df_plot = (
                _add_incident_category(df, active_incident_col)
                if active_incident_col
                else df.copy()
            )
            # Display label -> colour for the incident X column, if any.
            incident_color_map = (
                dict(
                    zip(
                        df_plot["_incident_label"].tolist(),
                        df_plot["_incident_color"].tolist(),
                    )
                )
                if active_incident_col and "_incident_label" in df_plot.columns
                else None
            )
            if color_col:
                # Grouped bars: one bar per (x, colour) pair.
                df_plot_copy = df_plot.copy()
                color_incident_color_map = None
                if color_col in incident_cols:
                    # Recomputes _incident_label/_incident_color from the
                    # colour column (replacing any X-derived ones).
                    df_plot_copy = _add_incident_category(df_plot_copy, color_col)
                    color_incident_color_map = dict(
                        zip(
                            df_plot_copy["_incident_label"].tolist(),
                            df_plot_copy["_incident_color"].tolist(),
                        )
                    )
                if not color_incident_color_map:
                    color_incident_color_map = incident_color_map
                if active_incident_col is not None:
                    x_col = "_incident_label"
                if color_col in incident_cols:
                    color_col = "_incident_label"
                if view_horizontal_stacked:
                    df_plot_copy[color_col] = df_plot_copy[color_col].astype(str)
                # Sum Y over duplicate (x, colour) pairs, keeping one row each.
                if df_plot_copy.duplicated(subset=[x_col, color_col]).any():
                    df_plot_copy[y_col] = (
                        df_plot_copy.groupby([x_col, color_col])[y_col]
                        .transform("sum")
                    )
                    df_plot_copy = df_plot_copy.drop_duplicates(subset=[x_col, color_col])
                # NOTE: barmode is "group" even when "Horizontal Stacked" is
                # ticked; that checkbox only affects the str cast above.
                bar_kwargs = dict(
                    x=x_col, y=y_col, color=color_col, barmode="group", template=tmpl, text=y_col
                )
                category_orders = {}
                if active_incident_col is not None:
                    bar_kwargs["x"] = "_incident_label"
                    color_incident_color_map = incident_color_map
                bar_kwargs["color_discrete_map"] = color_incident_color_map
                if (
                    color_incident_color_map
                    and color_col in incident_cols
                ) or (
                    color_incident_color_map
                    and color_col in ["_incident_label"]
                ):
                    bar_kwargs["x"] = (
                        "_incident_label" if active_incident_col is not None else x_col
                    )
                    bar_kwargs["color"] = "_incident_label"
                    bar_kwargs["color_discrete_map"] = color_incident_color_map
                    category_orders["_incident_label"] = INICIDENT_CATEGORY_NAMES
                # get_categorical_columns also rewrites the X column of
                # df_plot_copy in place as an ordered categorical. When X is
                # _incident_label this overrides the canonical order set above.
                x_order = get_categorical_columns(df_plot_copy, bar_kwargs["x"])
                category_orders[bar_kwargs["x"]] = x_order
                if category_orders:
                    bar_kwargs["category_orders"] = category_orders
                fig = px.bar(df_plot_copy, **bar_kwargs)
                # Annotate each X position with the total of its bars.
                group_x = bar_kwargs["x"]
                group_y = bar_kwargs["y"]
                group_color = bar_kwargs.get("color")
                x_values = category_orders.get(group_x, df_plot_copy[group_x].drop_duplicates().tolist())
                if group_color:
                    total_base = (
                        df_plot_copy[[group_x, group_color, group_y]]
                        .drop_duplicates()
                    )
                else:
                    total_base = (
                        df_plot_copy[[group_x, group_y]]
                        .drop_duplicates()
                    )
                group_totals = (
                    total_base.groupby(group_x)[group_y]
                    .sum()
                    .reset_index()
                )
                totals_map = dict(
                    zip(group_totals[group_x], group_totals[group_y])
                )
                # x_values may repeat (get_categorical_columns returns one
                # entry per row), which repeats the annotation at that X.
                for x_val in x_values:
                    if x_val in totals_map:
                        fig.add_annotation(
                            x=x_val,
                            y=totals_map[x_val],
                            text=f"{totals_map[x_val]}",
                            showarrow=False,
                            yshift=15,
                            font=dict(size=14),
                            xanchor="center"
                        )
            elif incident_color_map and active_incident_col is not None:
                # Incident X without a colour column: one bar per incident
                # label, with Y summed over duplicates.
                df_plot_copy = df_plot.copy()
                if df_plot[x_col].duplicated().any():
                    df_plot_copy[y_col] = df_plot_copy.groupby("_incident_label")[y_col].transform("sum")
                    df_plot_copy.drop_duplicates(subset=["_incident_label"], inplace=True)
                fig = px.bar(
                    df_plot_copy,
                    x="_incident_label",
                    y=y_col,
                    color="_incident_label",
                    template=tmpl,
                    color_discrete_map=incident_color_map,
                    text=y_col,
                    category_orders={"_incident_label": INICIDENT_CATEGORY_NAMES}
                )
            else:
                # Plain bar chart; Y summed over duplicate X values.
                df_plot_copy = df_plot.copy()
                if df_plot[x_col].duplicated().any():
                    df_plot_copy[y_col] = df_plot_copy.groupby(x_col)[y_col].transform("sum")
                    df_plot_copy.drop_duplicates(subset=[x_col], inplace=True)
                fig = px.bar(df_plot_copy, x=x_col, y=y_col, template=tmpl, text=y_col)
            fig.update_traces(textposition="outside")
            if view_all_labels:
                # Category axis with an explicit tick for every value.
                # (Also rewrites df_plot[x_col] in place; harmless here.)
                fig.update_xaxes(
                    type="category",
                    tickmode="array",
                    tickvals=get_categorical_columns(df_plot, x_col)
                )
            if active_incident_col is not None:
                fig.update_xaxes(
                    categoryorder="array",
                    categoryarray=INICIDENT_CATEGORY_NAMES
                )
        # ----- LINE -----
        elif chart_type == "line":
            active_incident_col = x_col if x_col in incident_cols else None
            df_plot = (
                _add_incident_category(df, active_incident_col)
                if active_incident_col
                else df.copy()
            )
            incident_color_map = (
                dict(
                    zip(
                        df_plot["_incident_label"].tolist(),
                        df_plot["_incident_color"].tolist(),
                    )
                )
                if active_incident_col and "_incident_label" in df_plot.columns
                else None
            )
            if color_col:
                line_kwargs = dict(
                    x=x_col, y=y_col, color=color_col, markers=True, template=tmpl
                )
                # Only when X is also an incident column is incident_color_map
                # set, so this needs both X and colour to be incident columns.
                if incident_color_map and color_col in incident_cols:
                    line_kwargs["x"] = (
                        "_incident_label" if active_incident_col is not None else x_col
                    )
                    line_kwargs["color"] = "_incident_label"
                    line_kwargs["color_discrete_map"] = incident_color_map
                fig = px.line(df_plot, **line_kwargs)
            elif incident_color_map and active_incident_col is not None:
                fig = px.line(
                    df_plot,
                    x="_incident_label",
                    y=y_col,
                    color="_incident_label",
                    markers=True,
                    template=tmpl,
                    color_discrete_map=incident_color_map,
                )
            else:
                fig = px.line(df_plot, x=x_col, y=y_col, markers=True, template=tmpl)
        # ----- PIE -----
        elif chart_type == "pie":
            active_incident_col = x_col if x_col in incident_cols else None
            df_plot = (
                _add_incident_category(df, active_incident_col)
                if active_incident_col
                else df.copy()
            )
            incident_color_map = (
                dict(
                    zip(
                        df_plot["_incident_label"].tolist(),
                        df_plot["_incident_color"].tolist(),
                    )
                )
                if active_incident_col and "_incident_label" in df_plot.columns
                else None
            )
            # Slice sizes must be numeric; unparseable values count as 0.
            if y_col:
                df_plot[y_col] = pd.to_numeric(df_plot[y_col], errors="coerce").fillna(0)
            if incident_color_map and active_incident_col is not None:
                fig = px.pie(
                    df_plot,
                    names="_incident_label",
                    values=y_col,
                    hole=0.35,
                    template=tmpl,
                    color="_incident_label",
                    color_discrete_map=incident_color_map,
                )
            else:
                fig = px.pie(
                    df_plot, names=x_col, values=y_col, hole=0.35, template=tmpl
                )
            fig.update_traces(textinfo="percent+label")
        # ----- FALLBACK -----
        # Reached for unsupported types, e.g. "histogram" from
        # _guess_chart_type. The "not recognized" title is replaced by the
        # update_layout call below.
        else:
            fig = px.bar(
                df_plot,
                x=x_col,
                y=y_col,
                template=tmpl,
                title=f"Chart type '{chart_type_raw}' not recognized",
            )
        # NOTE: str.capitalize() lower-cases everything after the first
        # character (e.g. "EMS" -> "ems", "Jan" -> "jan").
        fig.update_layout(
            title={
                "text": title.replace("_", " ").capitalize(),
                "x": 0.5,
                "xanchor": "center",
                "font": {
                    "size": 24
                }
            }
        )
    except Exception as e:
        st.warning(f"Could not render `{chart_type}` chart: {e}")
        return
    if fig:
        # NOTE(review): the key is a fresh uuid4 on every script run (not
        # derived from key_prefix), so it differs across reruns.
        st.plotly_chart(
            fig,
            use_container_width=True,
            config={"displayModeBar": False},
            key=f"{uuid.uuid4()}_plot",
        )
def render_crosstab(df: pd.DataFrame):
    """Show a tabular summary of *df* whose form depends on its column types.

    * two or more non-numeric columns and a numeric one -> pivot of the first
      two non-numeric columns, summing the first numeric column;
    * exactly one non-numeric column and numeric ones -> sum/mean/count of
      every numeric column grouped by it;
    * otherwise, two or more numeric columns -> correlation matrix;
    * otherwise -> transposed ``describe(include="all")``.

    Any failure is reported with ``st.warning`` and the raw ``describe`` table
    is shown instead.
    """
    if df.empty:
        st.info("No data to summarise.")
        return
    number_cols = df.select_dtypes(include="number").columns.tolist()
    label_cols = df.select_dtypes(exclude="number").columns.tolist()
    n_labels = len(label_cols)
    try:
        if number_cols and n_labels >= 2:
            rows_by, cols_by, measure = label_cols[0], label_cols[1], number_cols[0]
            pivot = df.pivot_table(
                index=rows_by,
                columns=cols_by,
                values=measure,
                aggfunc="sum",
                fill_value=0,
            )
            st.caption(f"Crosstab β {rows_by} x {cols_by} (sum of {measure})")
            st.dataframe(pivot, use_container_width=True)
        elif number_cols and n_labels == 1:
            group_by = label_cols[0]
            summary = df.groupby(group_by)[number_cols].agg(["sum", "mean", "count"])
            # Flatten the (column, statistic) MultiIndex into "col_stat".
            summary.columns = [f"{name}_{stat}" for name, stat in summary.columns]
            st.caption(f"Summary β grouped by {group_by}")
            st.dataframe(summary.reset_index(), use_container_width=True, hide_index=True)
        elif len(number_cols) >= 2:
            correlations = df[number_cols].corr().round(3)
            st.caption("Correlation Matrix")
            st.dataframe(
                correlations.style.background_gradient(cmap="Blues", axis=None),
                use_container_width=True,
            )
        else:
            stats = df.describe(include="all").T.reset_index()
            stats.rename(columns={"index": "column"}, inplace=True)
            st.caption("Statistical Summary")
            st.dataframe(stats, use_container_width=True, hide_index=True)
    except Exception as e:
        st.warning(f"Could not build crosstab: {e}")
        st.dataframe(df.describe(include="all").T, use_container_width=True)
# -- Controller ---------------------------------------------------------------
def get_controller():
    """Build a new DataExtractorController.

    NOTE(review): despite the old "singleton" label this is not cached, so a
    new controller is created on every Streamlit script run. Before caching it
    (e.g. with ``st.cache_resource``), confirm it is safe to share across
    sessions and across event loops; every query runs under a fresh
    ``asyncio.run`` loop.
    """
    instance = DataExtractorController()
    return instance


controller = get_controller()
# -- Session state ------------------------------------------------------------
# Create per-session keys on the first run only; later reruns keep their values.
for _state_key, _initial in (
    ("chat_history", []),
    ("total_queries", 0),
    ("successful_queries", 0),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# -- Helpers ------------------------------------------------------------------
def build_message_history() -> list[Message]:
    """Convert the session's chat entries into ``Message`` objects (role + content only)."""
    entries = st.session_state.chat_history
    return [Message(role=entry["role"], content=entry["content"]) for entry in entries]
def call_controller(user_query: str):
    """Send *user_query* plus the current chat history to the controller and wait.

    The async ``controller.extrcat`` coroutine is driven to completion with
    ``asyncio.run``, so this blocks the Streamlit script until it returns.
    NOTE(review): callers append the user's message to chat_history before
    calling, so the history passed here already contains this prompt --
    confirm the controller does not add it a second time.
    """
    query_model = UserQuery(user_query=user_query)
    prior_messages = build_message_history()
    return asyncio.run(
        controller.extrcat(user_query=query_model, message_history=prior_messages)
    )
def render_message(msg):
    """Render one chat-history entry as a Streamlit chat bubble.

    ``msg`` is one of the dicts stored in ``st.session_state.chat_history``.
    It always has ``role`` and ``content``. Assistant entries may also carry
    ``status``, ``sql``, ``ts``, ``data`` (list of row dicts),
    ``best_suitable_chart`` and ``incident_col``. Successful assistant
    entries with rows get Table / Crosstab / Chart tabs.
    """
    is_user = msg["role"] == "user"
    avatar = "π€" if is_user else "π€"
    role = "user" if is_user else "assistant"
    with st.chat_message(role, avatar=avatar):
        # Status badge (assistant messages only)
        if "status" in msg and not is_user:
            if msg["status"] == "success":
                st.success("Query executed successfully", icon="β ")
            else:
                st.error("Query failed", icon="β")
        st.markdown(msg["content"])
        # SQL block
        if msg.get("sql"):
            with st.expander("Generated SQL", expanded=False):
                st.code(msg["sql"], language="sql")
        # Timestamp
        if msg.get("ts"):
            st.caption(msg["ts"])
        # -- Multi-view data panel --
        data = msg.get("data", [])
        chart_hint = msg.get("best_suitable_chart")
        if data and msg.get("status") == "success":
            df = pd.DataFrame(data)
            tab_table, tab_crosstab, tab_chart = st.tabs(
                ["π Table", "π Crosstab / Summary", "π Chart"]
            )
            with tab_table:
                st.dataframe(df, use_container_width=True, hide_index=True)
            with tab_crosstab:
                render_crosstab(df)
            with tab_chart:
                if chart_hint:
                    icon_map = {"BAR": "π", "PIE": "π₯§", "LINE": "π"}
                    icon = icon_map.get(str(chart_hint).upper(), "π")
                    st.caption(f"{icon} {chart_hint} (AI suggested)")
                else:
                    st.caption("π Auto-detected chart type")
                charts = [
                    "BAR",
                    "LINE",
                    "PIE"
                ]
                # The hint may arrive in any case ("pie", "Pie"). Match it
                # case-insensitively, as the caption above already does.
                suggested = str(chart_hint).upper() if chart_hint else None
                if suggested in charts:
                    default_index = charts.index(suggested)
                    # Tag an option as AI-suggested only when it really was.
                    charts[default_index] = f"{charts[default_index]} (AI suggested)"
                else:
                    default_index = charts.index("BAR")
                selected = st.selectbox(
                    "Select chart type",
                    options=charts,
                    index=default_index,
                    key=f"chart_type_{msg.get('ts', 'x')}",
                )
                chart_hint = selected.replace(" (AI suggested)", "") if selected else None
                chart_key = f"chart_{msg.get('ts', 'x').replace(':', '_')}"
                incident_cols = []
                if msg.get("incident_col"):
                    incident_cols.append(msg["incident_col"])
                # _detect_incident_col returns a list (or None). Merge it in
                # without repeating a column the model already named.
                for detected_col in _detect_incident_col(df) or []:
                    if detected_col not in incident_cols:
                        incident_cols.append(detected_col)
                for col in incident_cols:
                    if col in df.columns:
                        df[col] = df[col].astype("category")
                render_chart(df, incident_cols, chart_hint, key_prefix=chart_key)
        elif msg.get("status") == "success" and not data:
            st.info("Query returned 0 rows.")
# -- Sidebar ------------------------------------------------------------------
with st.sidebar:
    st.header("π Session Stats")
    queries_col, success_col = st.columns(2)
    queries_col.metric("Queries", st.session_state.total_queries)
    success_col.metric("Success", st.session_state.successful_queries)
    st.divider()

    st.header("π‘ Example Prompts")
    examples = [
        "List all fire incidents in last 10 years",
        "Show top 10 incidents by type",
        "Count incidents per year",
        "Find incidents with alarm time after 6pm",
        "List unique incident types",
    ]
    # A clicked example is stashed as "prefill" and sent as the prompt by the
    # chat-input handling further down the script.
    for example_prompt in examples:
        clicked = st.button(
            example_prompt,
            use_container_width=True,
            key=f"ex_{example_prompt[:20]}",
        )
        if clicked:
            st.session_state["prefill"] = example_prompt
    st.divider()

    if st.button("π Clear History", use_container_width=True):
        st.session_state.chat_history = []
        st.session_state.total_queries = 0
        st.session_state.successful_queries = 0
        st.rerun()

    if st.session_state.chat_history:
        with st.expander("π Raw Message History"):
            st.json(build_message_history())
# -- Main layout --------------------------------------------------------------
st.title("Firerms Data Extractor Chatbot")
st.caption("Natural language β SQL β Results")

# -- Chat area ----------------------------------------------------------------
# Replays the whole history on every rerun, or shows a welcome panel when the
# history is empty.
chat_container = st.container()
with chat_container:
    if st.session_state.chat_history:
        for msg in st.session_state.chat_history:
            render_message(msg)
    else:
        st.markdown("---")
        _left, centre, _right = st.columns([1, 2, 1])
        with centre:
            welcome_html = "".join(
                (
                    "<div style='text-align:center;padding:40px 0;'>",
                    "<p style='font-size:3rem;'>π</p>",
                    "<h3>Ask anything about your data</h3>",
                    "<p>Type a natural language question and the AI will generate SQL and return results.</p>",
                    "</div>",
                )
            )
            st.markdown(welcome_html, unsafe_allow_html=True)
# -- Chat input ---------------------------------------------------------------
# A sidebar example button stores its text as "prefill". It is used only when
# the user did not type anything this run.
prefill = st.session_state.pop("prefill", "")
prompt = st.chat_input("Ask a question about your dataβ¦", key="chat_input")
if not prompt and prefill:
    prompt = prefill
if prompt:
    ts_now = datetime.now().strftime("%H:%M:%S")
    st.session_state.chat_history.append(
        {"role": "user", "content": prompt, "ts": ts_now}
    )
    st.session_state.total_queries += 1
    with st.spinner("Generating SQL and fetching resultsβ¦"):
        try:
            result = call_controller(prompt)
            status = result.status
            sql = result.sql_query
            data = result.data or []
            # `output` carries the model's chart/incident hints. A missing
            # output means "no hints", not a failed query.
            output = getattr(result, "output", None)
            try:
                best_chart = output.best_suitable_chart.value
            except AttributeError:
                best_chart = None
            incident_col = None
            if getattr(output, "is_incident_category_required", False):
                incident_col = getattr(
                    output, "column_name_mapped_with_incident_category", None
                )
            row_count = len(data)
            content = (
                f"Query executed successfully. Returned **{row_count}** row(s)."
                if status == "success"
                else f"Query returned status: `{status}`."
            )
            st.session_state.chat_history.append(
                {
                    "role": "assistant",
                    "content": content,
                    "sql": sql,
                    "data": data,
                    "status": status,
                    "best_suitable_chart": best_chart,
                    "incident_col": incident_col,
                    "ts": datetime.now().strftime("%H:%M:%S"),
                }
            )
            if status == "success":
                st.session_state.successful_queries += 1
        except Exception as e:
            # Top-level UI boundary: show the error as a chat message instead
            # of crashing the app.
            st.session_state.chat_history.append(
                {
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                    "status": "error",
                    "ts": datetime.now().strftime("%H:%M:%S"),
                }
            )
    st.rerun()