Spaces:
Sleeping
Sleeping
File size: 32,447 Bytes
3b931bf 3c6f410 3b931bf 3c6f410 3b931bf 783bea4 3c6f410 f869907 3b931bf 411c555 3b931bf f869907 411c555 02f9977 411c555 3c6f410 02f9977 0a46d2f 02f9977 0a46d2f 02f9977 0a46d2f 02f9977 411c555 0a46d2f 3c6f410 411c555 02f9977 411c555 02f9977 411c555 3b931bf f869907 3b931bf 3c6f410 783bea4 411c555 783bea4 411c555 783bea4 6feb8ca 783bea4 411c555 02f9977 411c555 783bea4 3c6f410 02f9977 783bea4 411c555 02f9977 411c555 783bea4 411c555 783bea4 411c555 02f9977 411c555 783bea4 3c6f410 783bea4 411c555 783bea4 411c555 783bea4 6f3c0f1 783bea4 411c555 783bea4 411c555 783bea4 411c555 783bea4 0a46d2f 783bea4 411c555 783bea4 dbd4a93 783bea4 411c555 02f9977 411c555 02f9977 411c555 783bea4 3c6f410 411c555 0a46d2f 783bea4 411c555 783bea4 0a46d2f 02f9977 3c6f410 02f9977 783bea4 0a46d2f c54e39d 0a46d2f 12473a4 c54e39d 12473a4 83a38f7 e999dc3 83a38f7 dbd4a93 0a46d2f 411c555 0a46d2f 411c555 d8e4d75 0a46d2f c54e39d d8e4d75 3c6f410 411c555 12473a4 d8e4d75 5883a1f d8e4d75 0a46d2f a05d65f f405bfb a05d65f 02f9977 15bdd44 0a46d2f 411c555 0a46d2f 411c555 3c6f410 411c555 0a46d2f d8e4d75 411c555 783bea4 0a46d2f d55d5ac 6feb8ca d55d5ac 0a46d2f 783bea4 411c555 783bea4 02f9977 3c6f410 02f9977 783bea4 411c555 3c6f410 411c555 02f9977 3c6f410 411c555 02f9977 411c555 3c6f410 411c555 783bea4 3c6f410 783bea4 411c555 783bea4 02f9977 3c6f410 02f9977 6f3c0f1 02f9977 411c555 3c6f410 411c555 3c6f410 411c555 783bea4 411c555 3c6f410 411c555 6f3c0f1 48d943e 6f3c0f1 783bea4 3c6f410 783bea4 6e8c418 85f5b43 3c6f410 85f5b43 411c555 fdfc6fd 783bea4 411c555 783bea4 411c555 783bea4 411c555 3c6f410 783bea4 411c555 783bea4 3c6f410 783bea4 3c6f410 411c555 783bea4 3c6f410 783bea4 3c6f410 783bea4 3c6f410 3b931bf 411c555 3b931bf 783bea4 3b931bf 411c555 3b931bf 411c555 3b931bf 411c555 3b931bf 3c6f410 3b931bf 3c6f410 3b931bf 3c6f410 3b931bf 3c6f410 3b931bf 6f30bbd 3c6f410 6f30bbd 411c555 3c6f410 411c555 6f30bbd 3c6f410 6f30bbd 3c6f410 6f3c0f1 3c6f410 6f30bbd 02f9977 0a46d2f 02f9977 6f30bbd 3c6f410 3b931bf 783bea4 3b931bf 3c6f410 
3b931bf 3c6f410 3b931bf b2fbf5d 3b931bf b2fbf5d 3b931bf 783bea4 3b931bf 3c6f410 3b931bf 3c6f410 3b931bf 411c555 3c6f410 411c555 3b931bf 411c555 3b931bf 783bea4 3c6f410 783bea4 411c555 3b931bf 411c555 3b931bf 411c555 3c6f410 411c555 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 
469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 | import asyncio
import sys
import uuid
from datetime import datetime
from pathlib import Path
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
# -- Path setup so `src` is importable when running from src/ or project root --
_here = Path(__file__).resolve().parent  # src/
_root = _here.parent  # project root
# Put both directories at the front of sys.path, skipping any already listed.
for _p in map(str, (_here, _root)):
    if _p in sys.path:
        continue
    sys.path.insert(0, _p)
from controllers._data_extractor import DataExtractorController, UserQuery
from schemas import Message
# Incident palette keyed by a one-character code; every value is a
# (hex_color, display_label) pair.
INCIDENT_COLOR_MAP = dict(
    [
        ("1", ("#ef4444", "Fire")),
        ("2", ("#eab308", "Rupt/Exp")),
        ("3", ("#3b82f6", "EMS")),
        ("4", ("#22c55e", "Hazardous")),
        ("5", ("#a855f7", "Public Assist")),
        ("6", ("#06b6d4", "Good Intent")),
        ("7", ("#9ca3af", "False Alarms")),
        ("8", ("#84cc16", "Weather")),
        ("9", ("#c0c0c0", "Special Type")),
    ]
)

# (hex_color, display_label) pairs for missing and empty values.
INCIDENT_NULL_COLOR = ("#111111", "Not Entered")
INCIDENT_BLANK_COLOR = ("#f97316", "Blank")

# Display labels in code order, followed by the two placeholder labels.
INICIDENT_CATEGORY_NAMES = [label for _, label in INCIDENT_COLOR_MAP.values()]
INICIDENT_CATEGORY_NAMES += [INCIDENT_NULL_COLOR[1], INCIDENT_BLANK_COLOR[1]]

# Same palette, keyed by the lower-cased display label instead of the code.
INCIDENT_NAME_COLOR_MAP = {
    label.lower(): (color, label) for color, label in INCIDENT_COLOR_MAP.values()
}

# Lower-cased column names that are treated as incident-type columns.
INCIDENT_COL_NAMES = {
    "incidenttype", "incident_type",
    "incidentclassification", "incident_classification",
    "incident_category", "incidentcategory",
}
def detect_date_columns(df: pd.DataFrame):
    """Return the label of the first column of *df* that holds dates.

    Columns are examined in order. A column with a datetime dtype is accepted
    immediately. Any other non-numeric column is accepted when more than 80%
    of its values parse as dates, unless more than 80% of them also parse as
    plain numbers (codes, years, ids).

    Args:
        df: Frame to inspect.

    Returns:
        The matching column label, or ``None`` when no column qualifies.
    """
    for col in df.columns:
        series = df[col]
        # Check the dtype first: pd.to_numeric can coerce datetime64 values to
        # integers, so running the numeric-like test first would reject
        # genuine datetime columns.
        if pd.api.types.is_datetime64_any_dtype(series):
            return col
        if pd.api.types.is_numeric_dtype(series):
            continue
        try:
            numeric_like_ratio = pd.to_numeric(series, errors="coerce").notna().mean()
        except Exception:
            # Best effort: values that cannot even be coerced are simply not
            # numeric-like.
            numeric_like_ratio = 0.0
        if numeric_like_ratio > 0.8:
            continue
        try:
            valid_ratio = pd.to_datetime(series, errors="coerce").notna().mean()
        except Exception:
            # Unparseable column: not a date column, move on to the next one.
            continue
        if valid_ratio > 0.8:
            return col
    return None
def get_date_range(df: pd.DataFrame):
    """Format the span covered by the date column of *df*.

    The column is chosen by ``detect_date_columns``. Returns a string made of
    the earliest and latest parsed dates (each formatted as ``%b %d, %Y``),
    or ``None`` when there is no date column or no value parses.
    """
    date_col = detect_date_columns(df)
    if not date_col:
        return None

    parsed = pd.to_datetime(df[date_col], errors="coerce").dropna()
    if parsed.empty:
        return None

    col_min, col_max = parsed.min(), parsed.max()
    if not col_min or not col_max:
        return None
    return f"{col_min.strftime('%b %d, %Y')} β {col_max.strftime('%b %d, %Y')}"
def _detect_incident_col(df: pd.DataFrame) -> list[str] | None:
    """Return the columns of *df* whose names mark them as incident types.

    A column qualifies when its name, stripped and lower-cased, is a member
    of ``INCIDENT_COL_NAMES``.

    Returns:
        The matching column names in frame order, or ``None`` when there are
        none.
    """
    incident_cols = [
        col
        for col in df.columns
        # Column labels are not necessarily strings (e.g. integer labels);
        # only a string can match a known incident column name.
        if isinstance(col, str) and col.strip().lower() in INCIDENT_COL_NAMES
    ]
    return incident_cols or None
def _incident_label_and_color(value) -> tuple[str, str]:
    """Return (display_label, hex_color) for a raw incidenttype value.

    Resolution order:
      1. ``None``, ``pd.NA``, float NaN and the string ``"nan"`` map to the
         ``INCIDENT_NULL_COLOR`` entry.
      2. A value that is empty after stripping maps to ``INCIDENT_BLANK_COLOR``.
      3. The first character is looked up in ``INCIDENT_COLOR_MAP``.
      4. Otherwise the first name from ``INICIDENT_CATEGORY_NAMES`` contained
         in the value (case-insensitive) selects the entry.
      5. Anything else keeps its own text with a neutral grey.
    """
    # pd.NA is tested by identity first: comparing it with == yields pd.NA,
    # whose truth value raises.
    if (
        value is None
        or value is pd.NA
        or (isinstance(value, float) and pd.isna(value))
        or value == "nan"
    ):
        return INCIDENT_NULL_COLOR[1], INCIDENT_NULL_COLOR[0]

    s = str(value).strip()
    if s == "":
        return INCIDENT_BLANK_COLOR[1], INCIDENT_BLANK_COLOR[0]

    # A leading "s" is folded to "S" before the lookup; every other first
    # character is used as-is.
    key = s[0].upper() if s[0].upper() == "S" else s[0]
    if key in INCIDENT_COLOR_MAP:
        color, label = INCIDENT_COLOR_MAP[key]
        return label, color

    lowered = s.lower()
    matched = next(
        (name for name in INICIDENT_CATEGORY_NAMES if name.lower() in lowered),
        None,
    )
    if matched is not None:
        name_key = matched.lower()
        if name_key in INCIDENT_NAME_COLOR_MAP:
            color, label = INCIDENT_NAME_COLOR_MAP[name_key]
            return label, color
        # The two placeholder names are not in the name map; give them their
        # own configured colours rather than the default grey.
        for color, label in (INCIDENT_NULL_COLOR, INCIDENT_BLANK_COLOR):
            if label == matched:
                return label, color

    return s, "#6b7280"
def _add_incident_category(df: pd.DataFrame, col: str) -> pd.DataFrame:
    """Return a copy of *df* with incident label and colour columns added.

    The values of *col* are passed through ``_incident_label_and_color`` and
    the results stored in ``_incident_label`` and ``_incident_color``. The
    copy gets a fresh 0..n-1 index. When *col* is absent or the frame is
    empty, an unmodified copy is returned.
    """
    result = df.copy()
    if col not in result.columns or result.empty:
        return result

    result = result.reset_index(drop=True)
    source = result[col]
    # Stringify the values but keep missing ones as None so they are mapped
    # as missing rather than as the text "nan".
    as_text = source.astype(str).where(source.notna(), other=None)
    pairs = as_text.map(_incident_label_and_color)
    result["_incident_label"] = pairs.apply(lambda pair: pair[0])
    result["_incident_color"] = pairs.apply(lambda pair: pair[1])
    return result
# -- Page config --
# Sets the page title and icon, uses the wide layout and opens the sidebar on
# load.
# NOTE(review): the page_icon text looks mis-encoded (likely an emoji
# originally) — confirm against the real file.
st.set_page_config(
    page_title="Data Extractor AI",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded",
)
# -- Plotly template (adapts to Streamlit theme) --
def _get_plotly_template():
    """Return the name of the Plotly template used for charts.

    The value is currently the fixed template ``"plotly_white"``.
    """
    template_name = "plotly_white"
    return template_name
# -- Chart rendering --
# Accepted chart-type spellings, each mapped to one of the canonical names
# "pie", "bar" or "line".
CHART_ALIASES = dict(
    pie="pie",
    bar="bar",
    line="line",
    bar_chart="bar",
    vertical_bar="bar",
    column="bar",
    grouped_bar="bar",
    stacked_bar="bar",
    line_chart="line",
    time_series="line",
    trend="line",
    donut="pie",
    doughnut="pie",
)
def _normalise_chart_type(raw: str | None) -> str | None:
    """Resolve a requested chart type to its canonical name.

    The input is lower-cased and stripped, then looked up in
    ``CHART_ALIASES``; an unknown name is returned in its cleaned form.
    ``None`` or an empty string yields ``None``.
    """
    if raw:
        cleaned = raw.lower().strip()
        return CHART_ALIASES.get(cleaned, cleaned)
    return None
def _guess_chart_type(df: pd.DataFrame) -> str:
    """Pick a chart type from the column layout of *df*.

    * one numeric column only -> ``"histogram"``
    * one non-numeric plus one numeric column -> ``"bar"`` for up to 50 rows,
      ``"line"`` beyond that
    * two or more numeric columns -> ``"line"``
    * anything else -> ``"bar"``
    """
    column_count = len(list(df.columns))
    numeric_cols = df.select_dtypes(include="number").columns.tolist()
    if column_count == 1 and numeric_cols:
        return "histogram"

    other_cols = df.select_dtypes(exclude="number").columns.tolist()
    is_label_value_pair = (
        column_count == 2 and len(other_cols) == 1 and len(numeric_cols) == 1
    )
    if is_label_value_pair:
        if len(df) > 50:
            return "line"
        return "bar"

    return "line" if len(numeric_cols) >= 2 else "bar"
def normalize(x):
    """Return a canonical string for a numeric-looking value.

    Values that convert to a whole number become that integer's text
    (``"3.0"`` -> ``"3"``); other numbers become ``str(float(x))``. A value
    that cannot be converted to a float is returned unchanged.
    """
    try:
        num = float(x)
    except (TypeError, ValueError, OverflowError):
        # Not a number (or too large for a float): leave the value alone.
        # Only conversion errors are caught so that interrupts propagate.
        return x
    if num.is_integer():
        return str(int(num))
    return str(num)
def sort_key(x):
    """Return a sort key that orders numbers before text.

    Numeric values give ``(0, float(x))`` and sort numerically. Values that
    are not numbers give ``(1, x)`` and sort after every number. NaN is
    treated as text, ``(1, str(x))``.
    """
    try:
        num = float(x)
    except (TypeError, ValueError, OverflowError):
        return (1, x)
    if num != num:
        # NaN compares false against everything, which would make the sort
        # result depend on input order; rank it with the text values instead.
        return (1, str(x))
    return (0, num)
def get_categorical_columns(df: pd.DataFrame, column_name: str) -> list[str]:
    """Turn a column into an ordered categorical and return its sorted values.

    The column is stringified, passed through ``normalize`` and replaced
    **in place** on *df* by an ordered ``pd.Categorical`` whose categories
    are ranked by ``sort_key``.

    Returns:
        The column's values in category order, one entry per row (duplicates
        are kept).
    """
    labels = df[column_name].astype(str).map(normalize)
    ordered_categories = sorted(labels.dropna().unique(), key=sort_key)
    # Side effect on the caller's frame: the column becomes categorical.
    df[column_name] = pd.Categorical(
        labels, categories=ordered_categories, ordered=True
    )
    sorted_frame = df.sort_values(by=column_name).reset_index(drop=True)
    return sorted_frame[column_name].tolist()
def render_chart(
    df: pd.DataFrame,
    incident_cols: list[str] | None = None,
    chart_type_raw: str | None = None,
    key_prefix: str = "chart",
):
    """Render column selectors and a Plotly bar, line or pie chart for *df*.

    Args:
        df: Data to plot. It is copied before use.
        incident_cols: Columns holding incident-type values. When such a
            column is selected, its values are replaced by category labels
            (see ``_add_incident_category``) and drawn with the matching
            colours.
        chart_type_raw: Requested chart type; aliases are resolved by
            ``_normalise_chart_type``. When empty the type is guessed from
            the data.
        key_prefix: Prefix of the Streamlit widget keys created here.

    Returns:
        None. An empty frame shows an info message; any error raised while
        building the figure is reported with ``st.warning``.
    """
    if df.empty:
        st.info("No data to chart.")
        return
    df = df.copy()
    # Categorical columns are turned back into plain strings, with the text
    # "nan" replaced by None.
    for c in df.columns:
        if hasattr(df[c], "cat"):
            df[c] = df[c].astype(str).replace("nan", None)
    chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
    incident_cols = incident_cols or []
    active_incident_col: str | None = None
    df_plot = df.copy()
    cols = list(df.columns)
    numeric = df.select_dtypes(include="number").columns.tolist()
    cat = df.select_dtypes(exclude="number").columns.tolist()
    # Default X: first incident column if present, else first non-numeric
    # column, else the first column.
    default_x = (
        incident_cols[0]
        if incident_cols and incident_cols[0] in cols
        else (cat[0] if cat else cols[0])
    )
    default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
    # -- Column selectors --
    st.caption("βοΈ Configure columns")
    if chart_type == "pie":
        c1, c2 = st.columns(2)
        with c1:
            x_col = st.selectbox(
                "Labels",
                options=cols,
                index=cols.index(default_x) if default_x in cols else 0,
                key=f"{key_prefix}_pie_x",
            )
        with c2:
            # Work on a copy so that extending the options does not also
            # modify `numeric` / `cols`.
            val_opts = list(numeric) if numeric else list(cols)
            # Also offer columns whose values are all numeric-looking.
            for col in cols:
                try:
                    if pd.to_numeric(df[col], errors="coerce").notna().all() and col not in val_opts:
                        val_opts.append(col)
                except Exception:
                    # Best effort: a column that cannot be inspected is simply
                    # not offered.
                    pass
            y_col = st.selectbox(
                "Values",
                options=val_opts,
                index=val_opts.index(default_y) if default_y in val_opts else 0,
                key=f"{key_prefix}_pie_y",
            )
        color_col = None
    else:
        c1, c2, c3 = st.columns(3)
        with c1:
            x_col = st.selectbox(
                "X axis",
                options=cols,
                index=cols.index(default_x) if default_x in cols else 0,
                key=f"{key_prefix}_x",
            )
            view_all_labels = st.checkbox(
                "View All Labels",
                key=f"{key_prefix}_x_all_labels",
            )
        with c2:
            y_opts = numeric if numeric else cols
            y_col = st.selectbox(
                "Y axis",
                options=y_opts,
                index=y_opts.index(default_y) if default_y in y_opts else 0,
                key=f"{key_prefix}_y",
            )
        with c3:
            color_options = ["None"] + [c for c in cols if c not in (x_col, y_col)]
            color_sel = st.selectbox(
                "Color / Group",
                options=color_options,
                index=0,
                key=f"{key_prefix}_color",
            )
            view_horizontal_stacked = st.checkbox(
                "Horizontal Stacked",
                key=f"{key_prefix}_stacked",
            )
        color_col = None if color_sel == "None" else color_sel
    # -- Incident color mapping --
    # Each chart branch below builds the map it needs.
    incident_color_map = None
    # -- Build chart --
    fig = None
    tmpl = _get_plotly_template()
    date_range = get_date_range(df_plot)
    title = f"{y_col} by {x_col}"
    if date_range:
        title += f" ({date_range})"
    try:
        # ----- BAR -----
        if chart_type == "bar":
            if view_all_labels:
                # NOTE(review): this conversion is discarded, because df_plot
                # is rebuilt from df just below.
                df_plot[x_col] = df_plot[x_col].astype(str)
            active_incident_col = x_col if x_col in incident_cols else None
            df_plot = (
                _add_incident_category(df, active_incident_col)
                if active_incident_col
                else df.copy()
            )
            incident_color_map = (
                dict(
                    zip(
                        df_plot["_incident_label"].tolist(),
                        df_plot["_incident_color"].tolist(),
                    )
                )
                if active_incident_col and "_incident_label" in df_plot.columns
                else None
            )
            if color_col:
                df_plot_copy = df_plot.copy()
                color_incident_color_map = None
                if color_col in incident_cols:
                    df_plot_copy = _add_incident_category(df_plot_copy, color_col)
                    color_incident_color_map = dict(
                        zip(
                            df_plot_copy["_incident_label"].tolist(),
                            df_plot_copy["_incident_color"].tolist(),
                        )
                    )
                if not color_incident_color_map:
                    color_incident_color_map = incident_color_map
                if active_incident_col is not None:
                    x_col = "_incident_label"
                if color_col in incident_cols:
                    color_col = "_incident_label"
                if view_horizontal_stacked:
                    df_plot_copy[color_col] = df_plot_copy[color_col].astype(str)
                # Collapse repeated (x, colour) pairs into one summed row.
                if df_plot_copy.duplicated(subset=[x_col, color_col]).any():
                    df_plot_copy[y_col] = (
                        df_plot_copy.groupby([x_col, color_col])[y_col]
                        .transform("sum")
                    )
                    df_plot_copy = df_plot_copy.drop_duplicates(subset=[x_col, color_col])
                bar_kwargs = dict(
                    x=x_col, y=y_col, color=color_col, barmode="group", template=tmpl, text=y_col
                )
                category_orders = {}
                if active_incident_col is not None:
                    bar_kwargs["x"] = "_incident_label"
                    color_incident_color_map = incident_color_map
                    bar_kwargs["color_discrete_map"] = color_incident_color_map
                if (
                    color_incident_color_map
                    and color_col in incident_cols
                ) or (
                    color_incident_color_map
                    and color_col in ["_incident_label"]
                ):
                    bar_kwargs["x"] = (
                        "_incident_label" if active_incident_col is not None else x_col
                    )
                    bar_kwargs["color"] = "_incident_label"
                    bar_kwargs["color_discrete_map"] = color_incident_color_map
                    category_orders["_incident_label"] = INICIDENT_CATEGORY_NAMES
                # NOTE(review): assumed to run for every grouped bar chart,
                # not only the incident case — confirm the nesting.
                # get_categorical_columns also converts the x column of
                # df_plot_copy into an ordered categorical, in place.
                x_order = get_categorical_columns(df_plot_copy, bar_kwargs["x"])
                category_orders[bar_kwargs["x"]] = x_order
                if category_orders:
                    bar_kwargs["category_orders"] = category_orders
                fig = px.bar(df_plot_copy, **bar_kwargs)
                group_x = bar_kwargs["x"]
                group_y = bar_kwargs["y"]
                group_color = bar_kwargs.get("color")
                x_values = category_orders.get(group_x, df_plot_copy[group_x].drop_duplicates().tolist())
                # The ordering list holds one entry per row; keep each x value
                # once so every group gets a single total annotation.
                x_values = list(dict.fromkeys(x_values))
                if group_color:
                    total_base = (
                        df_plot_copy[[group_x, group_color, group_y]]
                        .drop_duplicates()
                    )
                else:
                    total_base = (
                        df_plot_copy[[group_x, group_y]]
                        .drop_duplicates()
                    )
                group_totals = (
                    total_base.groupby(group_x)[group_y]
                    .sum()
                    .reset_index()
                )
                totals_map = dict(
                    zip(group_totals[group_x], group_totals[group_y])
                )
                # Write the per-x total above each group of bars.
                for x_val in x_values:
                    if x_val in totals_map:
                        fig.add_annotation(
                            x=x_val,
                            y=totals_map[x_val],
                            text=f"{totals_map[x_val]}",
                            showarrow=False,
                            yshift=15,
                            font=dict(size=14),
                            xanchor="center"
                        )
            elif incident_color_map and active_incident_col is not None:
                df_plot_copy = df_plot.copy()
                # Sum rows that share an incident label into one bar.
                if df_plot[x_col].duplicated().any():
                    df_plot_copy[y_col] = df_plot_copy.groupby("_incident_label")[y_col].transform("sum")
                    df_plot_copy.drop_duplicates(subset=["_incident_label"], inplace=True)
                fig = px.bar(
                    df_plot_copy,
                    x="_incident_label",
                    y=y_col,
                    color="_incident_label",
                    template=tmpl,
                    color_discrete_map=incident_color_map,
                    text=y_col,
                    category_orders={"_incident_label": INICIDENT_CATEGORY_NAMES}
                )
            else:
                df_plot_copy = df_plot.copy()
                # Sum rows that share an x value into one bar.
                if df_plot[x_col].duplicated().any():
                    df_plot_copy[y_col] = df_plot_copy.groupby(x_col)[y_col].transform("sum")
                    df_plot_copy.drop_duplicates(subset=[x_col], inplace=True)
                fig = px.bar(df_plot_copy, x=x_col, y=y_col, template=tmpl, text=y_col)
            fig.update_traces(textposition="outside")
            if view_all_labels:
                fig.update_xaxes(
                    type="category",
                    tickmode="array",
                    tickvals=get_categorical_columns(df_plot, x_col)
                )
            if active_incident_col is not None:
                fig.update_xaxes(
                    categoryorder="array",
                    categoryarray=INICIDENT_CATEGORY_NAMES
                )
        # ----- LINE -----
        elif chart_type == "line":
            active_incident_col = x_col if x_col in incident_cols else None
            df_plot = (
                _add_incident_category(df, active_incident_col)
                if active_incident_col
                else df.copy()
            )
            incident_color_map = (
                dict(
                    zip(
                        df_plot["_incident_label"].tolist(),
                        df_plot["_incident_color"].tolist(),
                    )
                )
                if active_incident_col and "_incident_label" in df_plot.columns
                else None
            )
            if color_col:
                line_kwargs = dict(
                    x=x_col, y=y_col, color=color_col, markers=True, template=tmpl
                )
                if incident_color_map and color_col in incident_cols:
                    line_kwargs["x"] = (
                        "_incident_label" if active_incident_col is not None else x_col
                    )
                    line_kwargs["color"] = "_incident_label"
                    line_kwargs["color_discrete_map"] = incident_color_map
                fig = px.line(df_plot, **line_kwargs)
            elif incident_color_map and active_incident_col is not None:
                fig = px.line(
                    df_plot,
                    x="_incident_label",
                    y=y_col,
                    color="_incident_label",
                    markers=True,
                    template=tmpl,
                    color_discrete_map=incident_color_map,
                )
            else:
                fig = px.line(df_plot, x=x_col, y=y_col, markers=True, template=tmpl)
        # ----- PIE -----
        elif chart_type == "pie":
            active_incident_col = x_col if x_col in incident_cols else None
            df_plot = (
                _add_incident_category(df, active_incident_col)
                if active_incident_col
                else df.copy()
            )
            incident_color_map = (
                dict(
                    zip(
                        df_plot["_incident_label"].tolist(),
                        df_plot["_incident_color"].tolist(),
                    )
                )
                if active_incident_col and "_incident_label" in df_plot.columns
                else None
            )
            if y_col:
                # Values that do not parse as numbers count as 0.
                df_plot[y_col] = pd.to_numeric(df_plot[y_col], errors="coerce").fillna(0)
            if incident_color_map and active_incident_col is not None:
                fig = px.pie(
                    df_plot,
                    names="_incident_label",
                    values=y_col,
                    hole=0.35,
                    template=tmpl,
                    color="_incident_label",
                    color_discrete_map=incident_color_map,
                )
            else:
                fig = px.pie(
                    df_plot, names=x_col, values=y_col, hole=0.35, template=tmpl
                )
            fig.update_traces(textinfo="percent+label")
        # ----- FALLBACK -----
        else:
            fig = px.bar(
                df_plot,
                x=x_col,
                y=y_col,
                template=tmpl,
                title=f"Chart type '{chart_type_raw}' not recognized",
            )
        # Applied to every chart type, so it also replaces the fallback title.
        fig.update_layout(
            title={
                "text": title.replace("_", " ").capitalize(),
                "x": 0.5,
                "xanchor": "center",
                "font": {
                    "size": 24
                }
            }
        )
    except Exception as e:
        st.warning(f"Could not render `{chart_type}` chart: {e}")
        return
    if fig:
        st.plotly_chart(
            fig,
            use_container_width=True,
            config={"displayModeBar": False},
            # A fresh random key on every call.
            key=f"{uuid.uuid4()}_plot",
        )
def render_crosstab(df: pd.DataFrame):
    """Show a summary table of *df* chosen from its column types.

    * two or more non-numeric columns and a numeric one: pivot of the first
      two non-numeric columns, summing the first numeric column
    * exactly one non-numeric column and a numeric one: sum / mean / count of
      every numeric column per group
    * two or more numeric columns: correlation matrix
    * otherwise: ``describe`` output

    An empty frame shows an info message. If building the table fails, a
    warning is shown together with the plain ``describe`` output.
    """
    if df.empty:
        st.info("No data to summarise.")
        return

    numeric = df.select_dtypes(include="number").columns.tolist()
    cat = df.select_dtypes(exclude="number").columns.tolist()
    has_measure = len(numeric) >= 1

    try:
        if len(cat) >= 2 and has_measure:
            pivot = df.pivot_table(
                index=cat[0],
                columns=cat[1],
                values=numeric[0],
                aggfunc="sum",
                fill_value=0,
            )
            st.caption(f"Crosstab β {cat[0]} x {cat[1]} (sum of {numeric[0]})")
            st.dataframe(pivot, use_container_width=True)
            return

        if len(cat) == 1 and has_measure:
            grouped = df.groupby(cat[0])[numeric]
            summary = grouped.agg(["sum", "mean", "count"])
            # Flatten the (column, function) header into "column_function".
            summary.columns = [f"{v}_{f}" for v, f in summary.columns]
            summary = summary.reset_index()
            st.caption(f"Summary β grouped by {cat[0]}")
            st.dataframe(summary, use_container_width=True, hide_index=True)
            return

        if len(numeric) >= 2:
            corr = df[numeric].corr().round(3)
            st.caption("Correlation Matrix")
            styled = corr.style.background_gradient(cmap="Blues", axis=None)
            st.dataframe(styled, use_container_width=True)
            return

        desc = df.describe(include="all").T.reset_index()
        desc = desc.rename(columns={"index": "column"})
        st.caption("Statistical Summary")
        st.dataframe(desc, use_container_width=True, hide_index=True)
    except Exception as e:
        st.warning(f"Could not build crosstab: {e}")
        st.dataframe(df.describe(include="all").T, use_container_width=True)
# -- Controller singleton --
@st.cache_resource
def get_controller():
    """Create the ``DataExtractorController``.

    The function is wrapped in ``st.cache_resource``.
    """
    shared_controller = DataExtractorController()
    return shared_controller
controller = get_controller()
# -- Session state --
# Seed each value only when it is missing; existing values are left untouched.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
if "total_queries" not in st.session_state:
    st.session_state["total_queries"] = 0
if "successful_queries" not in st.session_state:
    st.session_state["successful_queries"] = 0
# -- Helpers --
def build_message_history() -> list[Message]:
    """Convert the stored chat history into ``Message`` objects.

    Only the ``role`` and ``content`` of each entry in
    ``st.session_state.chat_history`` are carried over.
    """
    history: list[Message] = []
    for entry in st.session_state.chat_history:
        history.append(Message(role=entry["role"], content=entry["content"]))
    return history
def call_controller(user_query: str):
    """Send *user_query* and the chat history to the controller.

    The controller call is awaited to completion with ``asyncio.run`` and its
    result returned.
    """
    query = UserQuery(user_query=user_query)
    past_messages = build_message_history()
    # NOTE(review): the controller method is spelled `extrcat` here;
    # presumably this matches its definition — verify.
    pending = controller.extrcat(user_query=query, message_history=past_messages)
    return asyncio.run(pending)
def render_message(msg):
    """Render one chat history entry.

    Shows, in order: a status badge (assistant messages only), the message
    text, the generated SQL, the timestamp and, for successful queries that
    returned rows, a table / summary / chart panel.

    Args:
        msg: History entry. ``role`` and ``content`` are required; ``status``,
            ``sql``, ``ts``, ``data``, ``best_suitable_chart`` and
            ``incident_col`` are read when present.
    """
    is_user = msg["role"] == "user"
    # NOTE(review): both avatar strings look mis-encoded (likely emoji
    # originally) — confirm against the real file.
    avatar = "π€" if is_user else "π€"
    role = "user" if is_user else "assistant"
    with st.chat_message(role, avatar=avatar):
        # Status badge
        if "status" in msg and not is_user:
            # The icon literals were mis-encoded (the first was split across
            # two lines). They are restored here as emoji; the failure icon
            # is a best guess — confirm.
            if msg["status"] == "success":
                st.success("Query executed successfully", icon="✅")
            else:
                st.error("Query failed", icon="❌")
        st.markdown(msg["content"])
        # SQL block
        if msg.get("sql"):
            with st.expander("Generated SQL", expanded=False):
                st.code(msg["sql"], language="sql")
        # Timestamp
        if msg.get("ts"):
            st.caption(msg["ts"])
        # -- Multi-view data panel --
        data = msg.get("data", [])
        chart_hint = msg.get("best_suitable_chart")
        if data and len(data) > 0 and msg.get("status") == "success":
            df = pd.DataFrame(data)
            tab_table, tab_crosstab, tab_chart = st.tabs(
                ["π Table", "π Crosstab / Summary", "π Chart"]
            )
            with tab_table:
                st.dataframe(df, use_container_width=True, hide_index=True)
            with tab_crosstab:
                render_crosstab(df)
            with tab_chart:
                if chart_hint:
                    icon_map = {"BAR": "π", "PIE": "π₯§", "LINE": "π"}
                    icon = icon_map.get(str(chart_hint).upper(), "π")
                    st.caption(f"{icon} {chart_hint} (AI suggested)")
                else:
                    st.caption("π Auto-detected chart type")
                charts = [
                    "BAR",
                    "LINE",
                    "PIE"
                ]
                # Match the hint case-insensitively, and tag an option as
                # suggested only when the hint really names it.
                hint_name = str(chart_hint).upper() if chart_hint else None
                suggested = hint_name if hint_name in charts else None
                default_index = charts.index(suggested) if suggested else 0
                options = [
                    f"{name} (AI suggested)" if name == suggested else name
                    for name in charts
                ]
                selected = st.selectbox(
                    "Select chart type",
                    options=options,
                    index=default_index,
                    key=f"chart_type_{msg.get('ts', 'x')}",
                )
                chart_hint = selected.replace(" (AI suggested)", "") if selected else None
                chart_key = f"chart_{msg.get('ts', 'x').replace(':', '_')}"
                # Incident columns: the one stored on the message first, then
                # any detected from the column names, each listed once.
                incident_cols = []
                if msg.get("incident_col"):
                    incident_cols.append(msg["incident_col"])
                detected = _detect_incident_col(df) or []
                if isinstance(detected, str):
                    detected = [detected]
                for col in detected:
                    if col not in incident_cols:
                        incident_cols.append(col)
                for col in incident_cols:
                    if col in df.columns:
                        df[col] = df[col].astype("category")
                render_chart(df, incident_cols, chart_hint, key_prefix=chart_key)
        elif msg.get("status") == "success" and not data:
            st.info("Query returned 0 rows.")
# -- Sidebar --
# Session counters, one-click example prompts, a history reset button and a
# raw view of the message history.
with st.sidebar:
    st.header("π Session Stats")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Queries", st.session_state.total_queries)
    with col2:
        st.metric("Success", st.session_state.successful_queries)
    st.divider()
    st.header("π‘ Example Prompts")
    examples = [
        "List all fire incidents in last 10 years",
        "Show top 10 incidents by type",
        "Count incidents per year",
        "Find incidents with alarm time after 6pm",
        "List unique incident types",
    ]
    for ex in examples:
        # The widget key uses the first 20 characters of the prompt.
        if st.button(ex, use_container_width=True, key=f"ex_{ex[:20]}"):
            # Stored under "prefill"; the chat input section pops this value
            # and submits it as the prompt.
            st.session_state["prefill"] = ex
    st.divider()
    if st.button("π Clear History", use_container_width=True):
        # Reset the history and both counters, then rerun the script.
        st.session_state.chat_history = []
        st.session_state.total_queries = 0
        st.session_state.successful_queries = 0
        st.rerun()
    if st.session_state.chat_history:
        with st.expander("π Raw Message History"):
            st.json(build_message_history())
# -- Main layout --
st.title("Firerms Data Extractor Chatbot")
st.caption("Natural language β SQL β Results")
# -- Chat area --
chat_container = st.container()
with chat_container:
    if not st.session_state.chat_history:
        # Empty state: a centred placeholder in the middle of three columns.
        st.markdown("---")
        col1, col2, col3 = st.columns([1, 2, 1])
        with col2:
            st.markdown(
                "<div style='text-align:center;padding:40px 0;'>"
                "<p style='font-size:3rem;'>π</p>"
                "<h3>Ask anything about your data</h3>"
                "<p>Type a natural language question and the AI will generate SQL and return results.</p>"
                "</div>",
                unsafe_allow_html=True,
            )
    else:
        # Replay the stored conversation in order.
        for msg in st.session_state.chat_history:
            render_message(msg)
# -- Chat input --
# A prompt comes either from the chat box or from an example button, which
# stores its text under "prefill". The prompt is recorded, sent to the
# controller, and the reply (or the error) is appended to the history before
# the script reruns.
prefill = st.session_state.pop("prefill", "")
prompt = st.chat_input("Ask a question about your dataβ¦", key="chat_input")
if not prompt and prefill:
    prompt = prefill
if prompt:
    ts_now = datetime.now().strftime("%H:%M:%S")
    st.session_state.chat_history.append(
        {"role": "user", "content": prompt, "ts": ts_now}
    )
    st.session_state.total_queries += 1
    with st.spinner("Generating SQL and fetching resultsβ¦"):
        try:
            result = call_controller(prompt)
            status = result.status
            sql = result.sql_query
            data = result.data or []
            # The chart hint and the incident column are optional extras: a
            # missing or partial `output` must not turn a successful query
            # into an error message.
            output = getattr(result, "output", None)
            try:
                best_chart = output.best_suitable_chart.value
            except Exception:
                best_chart = None
            incident_col = None
            if getattr(output, "is_incident_category_required", False):
                incident_col = getattr(
                    output, "column_name_mapped_with_incident_category", None
                )
            row_count = len(data)
            content = (
                f"Query executed successfully. Returned **{row_count}** row(s)."
                if status == "success"
                else f"Query returned status: `{status}`."
            )
            st.session_state.chat_history.append(
                {
                    "role": "assistant",
                    "content": content,
                    "sql": sql,
                    "data": data,
                    "status": status,
                    "best_suitable_chart": best_chart,
                    "incident_col": incident_col,
                    "ts": datetime.now().strftime("%H:%M:%S"),
                }
            )
            if status == "success":
                st.session_state.successful_queries += 1
        except Exception as e:
            # Any failure is shown in the conversation as an assistant
            # message instead of crashing the page.
            st.session_state.chat_history.append(
                {
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                    "status": "error",
                    "ts": datetime.now().strftime("%H:%M:%S"),
                }
            )
    # Rerun so the chat area above re-renders with the new messages.
    st.rerun()
|