Spaces:
Sleeping
Sleeping
Aryan Jain commited on
Commit Β·
dca20c8
1
Parent(s): 9fdba00
show charts
Browse files- src/controllers/_data_extractor.py +4 -1
- src/schemas/_pydantic_agent.py +4 -4
- src/streamlit_app.py +165 -83
src/controllers/_data_extractor.py
CHANGED
|
@@ -18,6 +18,7 @@ class Response(BaseModel):
|
|
| 18 |
sql_query: str
|
| 19 |
output: SQLQueryExtractor
|
| 20 |
|
|
|
|
| 21 |
class DataExtractorController:
|
| 22 |
def __init__(self):
|
| 23 |
self.router = APIRouter()
|
|
@@ -36,7 +37,9 @@ class DataExtractorController:
|
|
| 36 |
response, sql_query, output = await service.extract(
|
| 37 |
user_query=user_query.user_query, message_history=message_history
|
| 38 |
)
|
| 39 |
-
return Response(
|
|
|
|
|
|
|
| 40 |
except HTTPException as e:
|
| 41 |
logger.error(e)
|
| 42 |
raise e
|
|
|
|
| 18 |
sql_query: str
|
| 19 |
output: SQLQueryExtractor
|
| 20 |
|
| 21 |
+
|
| 22 |
class DataExtractorController:
|
| 23 |
def __init__(self):
|
| 24 |
self.router = APIRouter()
|
|
|
|
| 37 |
response, sql_query, output = await service.extract(
|
| 38 |
user_query=user_query.user_query, message_history=message_history
|
| 39 |
)
|
| 40 |
+
return Response(
|
| 41 |
+
status="success", data=response, sql_query=sql_query, output=output
|
| 42 |
+
)
|
| 43 |
except HTTPException as e:
|
| 44 |
logger.error(e)
|
| 45 |
raise e
|
src/schemas/_pydantic_agent.py
CHANGED
|
@@ -118,12 +118,13 @@ class OrderConditions(BaseModel):
|
|
| 118 |
"Use 'ASC' for: oldest, lowest, alphabetical, earliest."
|
| 119 |
),
|
| 120 |
)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
class ChartsType(Enum):
|
| 124 |
"""
|
| 125 |
Enumeration of supported chart types for visualization.
|
| 126 |
"""
|
|
|
|
| 127 |
PIE: str = "PIE"
|
| 128 |
BAR: str = "BAR"
|
| 129 |
LINE: str = "LINE"
|
|
@@ -221,7 +222,7 @@ class SQLQueryExtractor(BaseModel):
|
|
| 221 |
"This must be set True if provided tool return true else set False. "
|
| 222 |
),
|
| 223 |
)
|
| 224 |
-
|
| 225 |
best_suitable_chart: ChartsType = Field(
|
| 226 |
...,
|
| 227 |
description=(
|
|
@@ -233,7 +234,6 @@ class SQLQueryExtractor(BaseModel):
|
|
| 233 |
),
|
| 234 |
)
|
| 235 |
|
| 236 |
-
|
| 237 |
@model_validator(mode="after")
|
| 238 |
def validate_sql_query(self):
|
| 239 |
if not self.is_sql_query_verified_using_provided_tool:
|
|
|
|
| 118 |
"Use 'ASC' for: oldest, lowest, alphabetical, earliest."
|
| 119 |
),
|
| 120 |
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
class ChartsType(Enum):
|
| 124 |
"""
|
| 125 |
Enumeration of supported chart types for visualization.
|
| 126 |
"""
|
| 127 |
+
|
| 128 |
PIE: str = "PIE"
|
| 129 |
BAR: str = "BAR"
|
| 130 |
LINE: str = "LINE"
|
|
|
|
| 222 |
"This must be set True if provided tool return true else set False. "
|
| 223 |
),
|
| 224 |
)
|
| 225 |
+
|
| 226 |
best_suitable_chart: ChartsType = Field(
|
| 227 |
...,
|
| 228 |
description=(
|
|
|
|
| 234 |
),
|
| 235 |
)
|
| 236 |
|
|
|
|
| 237 |
@model_validator(mode="after")
|
| 238 |
def validate_sql_query(self):
|
| 239 |
if not self.is_sql_query_verified_using_provided_tool:
|
src/streamlit_app.py
CHANGED
|
@@ -9,8 +9,8 @@ import plotly.graph_objects as go
|
|
| 9 |
from datetime import datetime
|
| 10 |
|
| 11 |
# ββ Path setup so `src` is importable when running from src/ or project root ββ
|
| 12 |
-
_here = Path(__file__).resolve().parent
|
| 13 |
-
_root = _here.parent
|
| 14 |
for _p in [str(_here), str(_root)]:
|
| 15 |
if _p not in sys.path:
|
| 16 |
sys.path.insert(0, _p)
|
|
@@ -27,7 +27,8 @@ st.set_page_config(
|
|
| 27 |
)
|
| 28 |
|
| 29 |
# ββ Custom CSS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
-
st.markdown(
|
|
|
|
| 31 |
<style>
|
| 32 |
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=Syne:wght@400;600;700;800&display=swap');
|
| 33 |
|
|
@@ -259,7 +260,9 @@ html, body, [class*="css"] {
|
|
| 259 |
font-weight: 600; margin-bottom: 8px; font-family: var(--mono);
|
| 260 |
}
|
| 261 |
</style>
|
| 262 |
-
""",
|
|
|
|
|
|
|
| 263 |
|
| 264 |
# ββ Plotly dark theme template ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 265 |
PLOTLY_TEMPLATE = dict(
|
|
@@ -269,8 +272,16 @@ PLOTLY_TEMPLATE = dict(
|
|
| 269 |
font=dict(color="#e8eaf0", family="JetBrains Mono, monospace", size=11),
|
| 270 |
xaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
|
| 271 |
yaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
|
| 272 |
-
colorway=[
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
legend=dict(bgcolor="#1c2030", bordercolor="#2a2f42", borderwidth=1),
|
| 275 |
margin=dict(l=40, r=20, t=40, b=40),
|
| 276 |
)
|
|
@@ -281,14 +292,20 @@ PLOTLY_TEMPLATE = dict(
|
|
| 281 |
|
| 282 |
CHART_ALIASES = {
|
| 283 |
# ββ ChartsType enum values (uppercase from Pydantic/Enum .value) ββ
|
| 284 |
-
"pie":
|
| 285 |
-
"bar":
|
| 286 |
"line": "line",
|
| 287 |
# ββ Common string variants (lowercase) ββ
|
| 288 |
-
"bar_chart": "bar",
|
| 289 |
-
"
|
| 290 |
-
"
|
| 291 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
}
|
| 293 |
|
| 294 |
|
|
@@ -322,7 +339,9 @@ def _guess_chart_type(df: pd.DataFrame) -> str:
|
|
| 322 |
return "bar"
|
| 323 |
|
| 324 |
|
| 325 |
-
def render_chart(
|
|
|
|
|
|
|
| 326 |
"""
|
| 327 |
Render a Plotly chart with user-controlled column selectors.
|
| 328 |
The user picks x, y, and optional color columns via st.selectbox.
|
|
@@ -334,19 +353,19 @@ def render_chart(df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix
|
|
| 334 |
return
|
| 335 |
|
| 336 |
chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
|
| 337 |
-
cols
|
| 338 |
numeric = df.select_dtypes(include="number").columns.tolist()
|
| 339 |
-
cat
|
| 340 |
|
| 341 |
# ββ Smart defaults: pre-select the most sensible column per role βββββββββββββββββββββ
|
| 342 |
-
default_x = cat[0]
|
| 343 |
default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
|
| 344 |
|
| 345 |
# ββ Column selector UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 346 |
st.markdown(
|
| 347 |
'<div style="font-size:0.7rem;color:var(--muted);font-family:var(--mono);'
|
| 348 |
'text-transform:uppercase;letter-spacing:0.08em;margin-bottom:8px;">'
|
| 349 |
-
|
| 350 |
unsafe_allow_html=True,
|
| 351 |
)
|
| 352 |
|
|
@@ -408,36 +427,60 @@ def render_chart(df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix
|
|
| 408 |
try:
|
| 409 |
if chart_type == "bar":
|
| 410 |
if color_col:
|
| 411 |
-
fig = px.bar(
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
else:
|
| 414 |
fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE)
|
| 415 |
|
| 416 |
elif chart_type == "line":
|
| 417 |
if color_col:
|
| 418 |
-
fig = px.line(
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
else:
|
| 421 |
-
fig = px.line(
|
| 422 |
-
|
|
|
|
| 423 |
|
| 424 |
elif chart_type == "pie":
|
| 425 |
-
fig = px.pie(
|
| 426 |
-
|
|
|
|
| 427 |
fig.update_traces(textinfo="percent+label")
|
| 428 |
|
| 429 |
else: # unrecognised / None β heuristic fallback bar
|
| 430 |
-
fig = px.bar(
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
except Exception as e:
|
| 434 |
-
st.warning(
|
|
|
|
|
|
|
| 435 |
return
|
| 436 |
|
| 437 |
if fig:
|
| 438 |
-
fig.update_layout(
|
|
|
|
|
|
|
| 439 |
st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
|
| 440 |
|
|
|
|
| 441 |
def render_crosstab(df: pd.DataFrame):
|
| 442 |
"""
|
| 443 |
Auto-build a crosstab-style pivot summary.
|
|
@@ -453,41 +496,51 @@ def render_crosstab(df: pd.DataFrame):
|
|
| 453 |
return
|
| 454 |
|
| 455 |
numeric = df.select_dtypes(include="number").columns.tolist()
|
| 456 |
-
cat
|
| 457 |
|
| 458 |
try:
|
| 459 |
if len(cat) >= 2 and len(numeric) >= 1:
|
| 460 |
pivot = df.pivot_table(
|
| 461 |
-
index=cat[0],
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
)
|
| 464 |
-
st.markdown(f'<div class="table-label">π Crosstab β {cat[0]} Γ {cat[1]} (sum of {numeric[0]})</div>',
|
| 465 |
-
unsafe_allow_html=True)
|
| 466 |
st.dataframe(pivot, use_container_width=True)
|
| 467 |
|
| 468 |
elif len(cat) == 1 and len(numeric) >= 1:
|
| 469 |
-
summary = (
|
| 470 |
-
df.groupby(cat[0])[numeric]
|
| 471 |
-
.agg(["sum", "mean", "count"])
|
| 472 |
-
)
|
| 473 |
summary.columns = [f"{v}_{f}" for v, f in summary.columns]
|
| 474 |
summary = summary.reset_index()
|
| 475 |
-
st.markdown(
|
| 476 |
-
|
|
|
|
|
|
|
| 477 |
st.dataframe(summary, use_container_width=True, hide_index=True)
|
| 478 |
|
| 479 |
elif len(numeric) >= 2:
|
| 480 |
corr = df[numeric].corr().round(3)
|
| 481 |
-
st.markdown(
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
else:
|
| 487 |
desc = df.describe(include="all").T.reset_index()
|
| 488 |
desc.rename(columns={"index": "column"}, inplace=True)
|
| 489 |
-
st.markdown(
|
| 490 |
-
|
|
|
|
|
|
|
| 491 |
st.dataframe(desc, use_container_width=True, hide_index=True)
|
| 492 |
|
| 493 |
except Exception as e:
|
|
@@ -500,6 +553,7 @@ def render_crosstab(df: pd.DataFrame):
|
|
| 500 |
def get_controller():
|
| 501 |
return DataExtractorController()
|
| 502 |
|
|
|
|
| 503 |
controller = get_controller()
|
| 504 |
|
| 505 |
# ββ Session state ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -510,6 +564,7 @@ if "total_queries" not in st.session_state:
|
|
| 510 |
if "successful_queries" not in st.session_state:
|
| 511 |
st.session_state.successful_queries = 0
|
| 512 |
|
|
|
|
| 513 |
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 514 |
def build_message_history() -> list[Message]:
|
| 515 |
return [
|
|
@@ -517,23 +572,28 @@ def build_message_history() -> list[Message]:
|
|
| 517 |
for msg in st.session_state.chat_history
|
| 518 |
]
|
| 519 |
|
|
|
|
| 520 |
def call_controller(user_query: str):
|
| 521 |
uq = UserQuery(user_query=user_query)
|
| 522 |
history = build_message_history()
|
| 523 |
response = asyncio.run(controller.extrcat(user_query=uq, message_history=history))
|
| 524 |
return response
|
| 525 |
|
|
|
|
| 526 |
def render_message(msg):
|
| 527 |
is_user = msg["role"] == "user"
|
| 528 |
role_class = "user" if is_user else "bot"
|
| 529 |
avatar = "U" if is_user else "AI"
|
| 530 |
ts = msg.get("ts", "")
|
| 531 |
|
| 532 |
-
st.markdown(
|
|
|
|
| 533 |
<div class="msg-wrap {role_class}">
|
| 534 |
<div class="avatar {role_class}">{avatar}</div>
|
| 535 |
<div style="max-width:72%">
|
| 536 |
-
""",
|
|
|
|
|
|
|
| 537 |
|
| 538 |
with st.container():
|
| 539 |
if "status" in msg:
|
|
@@ -541,13 +601,15 @@ def render_message(msg):
|
|
| 541 |
label = "β success" if msg["status"] == "success" else "β error"
|
| 542 |
st.markdown(
|
| 543 |
f'<span class="badge {badge_class}">{label}</span>',
|
| 544 |
-
unsafe_allow_html=True
|
| 545 |
)
|
| 546 |
|
| 547 |
st.markdown(msg["content"])
|
| 548 |
|
| 549 |
if msg.get("sql"):
|
| 550 |
-
st.markdown(
|
|
|
|
|
|
|
| 551 |
st.code(msg["sql"], language="sql")
|
| 552 |
|
| 553 |
st.markdown(f'<div class="ts">{ts}</div>', unsafe_allow_html=True)
|
|
@@ -555,14 +617,20 @@ def render_message(msg):
|
|
| 555 |
|
| 556 |
# ββ Data table ββ
|
| 557 |
if msg.get("data") and len(msg["data"]) > 0:
|
| 558 |
-
st.markdown(
|
|
|
|
|
|
|
| 559 |
df = pd.DataFrame(msg["data"])
|
| 560 |
st.dataframe(df, use_container_width=True, hide_index=True)
|
| 561 |
-
elif
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
st.markdown(
|
| 563 |
'<div style="color:#6b7280;font-size:0.8rem;margin-top:8px;'
|
| 564 |
'font-family:monospace">β Query returned 0 rows.</div>',
|
| 565 |
-
unsafe_allow_html=True
|
| 566 |
)
|
| 567 |
|
| 568 |
|
|
@@ -601,24 +669,30 @@ with st.sidebar:
|
|
| 601 |
|
| 602 |
|
| 603 |
# ββ Main layout βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 604 |
-
st.markdown(
|
|
|
|
| 605 |
<div style="margin-bottom: 1.5rem;">
|
| 606 |
<div class="page-title">Firerms Data Extractor Chatbot</div>
|
| 607 |
<div class="page-sub">Natural language β SQL β Results</div>
|
| 608 |
</div>
|
| 609 |
-
""",
|
|
|
|
|
|
|
| 610 |
|
| 611 |
# ββ Chat area βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 612 |
chat_container = st.container()
|
| 613 |
with chat_container:
|
| 614 |
if not st.session_state.chat_history:
|
| 615 |
-
st.markdown(
|
|
|
|
| 616 |
<div class="empty-state">
|
| 617 |
<div class="empty-icon">π</div>
|
| 618 |
<div class="empty-title">Ask anything about your data</div>
|
| 619 |
<div class="empty-hint">Type a natural language question and the AI will generate SQL and return results.</div>
|
| 620 |
</div>
|
| 621 |
-
""",
|
|
|
|
|
|
|
| 622 |
else:
|
| 623 |
for msg in st.session_state.chat_history:
|
| 624 |
render_message(msg)
|
|
@@ -632,24 +706,28 @@ if not prompt and prefill:
|
|
| 632 |
if prompt:
|
| 633 |
ts_now = datetime.now().strftime("%H:%M:%S")
|
| 634 |
|
| 635 |
-
st.session_state.chat_history.append(
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
|
|
|
|
|
|
| 640 |
st.session_state.total_queries += 1
|
| 641 |
|
| 642 |
with st.spinner("Generating SQL and fetching resultsβ¦"):
|
| 643 |
try:
|
| 644 |
result = call_controller(prompt)
|
| 645 |
-
status
|
| 646 |
-
sql
|
| 647 |
-
data
|
| 648 |
|
| 649 |
# ββ Extract best_suitable_chart from result.output (SQLQueryExtractor) ββ
|
| 650 |
# result.output.best_suitable_chart is a ChartsType enum β use .value for the string
|
| 651 |
try:
|
| 652 |
-
best_chart =
|
|
|
|
|
|
|
| 653 |
except Exception:
|
| 654 |
best_chart = None
|
| 655 |
|
|
@@ -660,25 +738,29 @@ if prompt:
|
|
| 660 |
else f"Query returned status: `{status}`."
|
| 661 |
)
|
| 662 |
|
| 663 |
-
st.session_state.chat_history.append(
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
|
|
|
|
|
|
| 672 |
|
| 673 |
if status == "success":
|
| 674 |
st.session_state.successful_queries += 1
|
| 675 |
|
| 676 |
except Exception as e:
|
| 677 |
-
st.session_state.chat_history.append(
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
|
|
|
|
|
|
|
|
| 9 |
from datetime import datetime
|
| 10 |
|
| 11 |
# ββ Path setup so `src` is importable when running from src/ or project root ββ
|
| 12 |
+
_here = Path(__file__).resolve().parent # src/
|
| 13 |
+
_root = _here.parent # project root
|
| 14 |
for _p in [str(_here), str(_root)]:
|
| 15 |
if _p not in sys.path:
|
| 16 |
sys.path.insert(0, _p)
|
|
|
|
| 27 |
)
|
| 28 |
|
| 29 |
# ββ Custom CSS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
+
st.markdown(
|
| 31 |
+
"""
|
| 32 |
<style>
|
| 33 |
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=Syne:wght@400;600;700;800&display=swap');
|
| 34 |
|
|
|
|
| 260 |
font-weight: 600; margin-bottom: 8px; font-family: var(--mono);
|
| 261 |
}
|
| 262 |
</style>
|
| 263 |
+
""",
|
| 264 |
+
unsafe_allow_html=True,
|
| 265 |
+
)
|
| 266 |
|
| 267 |
# ββ Plotly dark theme template ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 268 |
PLOTLY_TEMPLATE = dict(
|
|
|
|
| 272 |
font=dict(color="#e8eaf0", family="JetBrains Mono, monospace", size=11),
|
| 273 |
xaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
|
| 274 |
yaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
|
| 275 |
+
colorway=[
|
| 276 |
+
"#4f8ef7",
|
| 277 |
+
"#7c5cfc",
|
| 278 |
+
"#22d3a5",
|
| 279 |
+
"#f75f5f",
|
| 280 |
+
"#f7a24f",
|
| 281 |
+
"#e879f9",
|
| 282 |
+
"#38bdf8",
|
| 283 |
+
"#fb923c",
|
| 284 |
+
],
|
| 285 |
legend=dict(bgcolor="#1c2030", bordercolor="#2a2f42", borderwidth=1),
|
| 286 |
margin=dict(l=40, r=20, t=40, b=40),
|
| 287 |
)
|
|
|
|
| 292 |
|
| 293 |
CHART_ALIASES = {
|
| 294 |
# ββ ChartsType enum values (uppercase from Pydantic/Enum .value) ββ
|
| 295 |
+
"pie": "pie",
|
| 296 |
+
"bar": "bar",
|
| 297 |
"line": "line",
|
| 298 |
# ββ Common string variants (lowercase) ββ
|
| 299 |
+
"bar_chart": "bar",
|
| 300 |
+
"vertical_bar": "bar",
|
| 301 |
+
"column": "bar",
|
| 302 |
+
"grouped_bar": "bar",
|
| 303 |
+
"stacked_bar": "bar",
|
| 304 |
+
"line_chart": "line",
|
| 305 |
+
"time_series": "line",
|
| 306 |
+
"trend": "line",
|
| 307 |
+
"donut": "pie",
|
| 308 |
+
"doughnut": "pie",
|
| 309 |
}
|
| 310 |
|
| 311 |
|
|
|
|
| 339 |
return "bar"
|
| 340 |
|
| 341 |
|
| 342 |
+
def render_chart(
|
| 343 |
+
df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix: str = "chart"
|
| 344 |
+
):
|
| 345 |
"""
|
| 346 |
Render a Plotly chart with user-controlled column selectors.
|
| 347 |
The user picks x, y, and optional color columns via st.selectbox.
|
|
|
|
| 353 |
return
|
| 354 |
|
| 355 |
chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
|
| 356 |
+
cols = list(df.columns)
|
| 357 |
numeric = df.select_dtypes(include="number").columns.tolist()
|
| 358 |
+
cat = df.select_dtypes(exclude="number").columns.tolist()
|
| 359 |
|
| 360 |
# ββ Smart defaults: pre-select the most sensible column per role βββββββββββββββββββββ
|
| 361 |
+
default_x = cat[0] if cat else cols[0]
|
| 362 |
default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
|
| 363 |
|
| 364 |
# ββ Column selector UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 365 |
st.markdown(
|
| 366 |
'<div style="font-size:0.7rem;color:var(--muted);font-family:var(--mono);'
|
| 367 |
'text-transform:uppercase;letter-spacing:0.08em;margin-bottom:8px;">'
|
| 368 |
+
"βοΈ Configure columns</div>",
|
| 369 |
unsafe_allow_html=True,
|
| 370 |
)
|
| 371 |
|
|
|
|
| 427 |
try:
|
| 428 |
if chart_type == "bar":
|
| 429 |
if color_col:
|
| 430 |
+
fig = px.bar(
|
| 431 |
+
df,
|
| 432 |
+
x=x_col,
|
| 433 |
+
y=y_col,
|
| 434 |
+
color=color_col,
|
| 435 |
+
barmode="group",
|
| 436 |
+
template=PLOTLY_TEMPLATE,
|
| 437 |
+
)
|
| 438 |
else:
|
| 439 |
fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE)
|
| 440 |
|
| 441 |
elif chart_type == "line":
|
| 442 |
if color_col:
|
| 443 |
+
fig = px.line(
|
| 444 |
+
df,
|
| 445 |
+
x=x_col,
|
| 446 |
+
y=y_col,
|
| 447 |
+
color=color_col,
|
| 448 |
+
markers=True,
|
| 449 |
+
template=PLOTLY_TEMPLATE,
|
| 450 |
+
)
|
| 451 |
else:
|
| 452 |
+
fig = px.line(
|
| 453 |
+
df, x=x_col, y=y_col, markers=True, template=PLOTLY_TEMPLATE
|
| 454 |
+
)
|
| 455 |
|
| 456 |
elif chart_type == "pie":
|
| 457 |
+
fig = px.pie(
|
| 458 |
+
df, names=x_col, values=y_col, hole=0.35, template=PLOTLY_TEMPLATE
|
| 459 |
+
)
|
| 460 |
fig.update_traces(textinfo="percent+label")
|
| 461 |
|
| 462 |
else: # unrecognised / None β heuristic fallback bar
|
| 463 |
+
fig = px.bar(
|
| 464 |
+
df,
|
| 465 |
+
x=x_col,
|
| 466 |
+
y=y_col,
|
| 467 |
+
template=PLOTLY_TEMPLATE,
|
| 468 |
+
title=f"Chart type '{chart_type_raw}' not recognized β showing bar chart",
|
| 469 |
+
)
|
| 470 |
|
| 471 |
except Exception as e:
|
| 472 |
+
st.warning(
|
| 473 |
+
f"β Could not render `{chart_type}` chart: {e}. Check the column types selected above."
|
| 474 |
+
)
|
| 475 |
return
|
| 476 |
|
| 477 |
if fig:
|
| 478 |
+
fig.update_layout(
|
| 479 |
+
paper_bgcolor="#141720", plot_bgcolor="#0d0f14", font_color="#e8eaf0"
|
| 480 |
+
)
|
| 481 |
st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
|
| 482 |
|
| 483 |
+
|
| 484 |
def render_crosstab(df: pd.DataFrame):
|
| 485 |
"""
|
| 486 |
Auto-build a crosstab-style pivot summary.
|
|
|
|
| 496 |
return
|
| 497 |
|
| 498 |
numeric = df.select_dtypes(include="number").columns.tolist()
|
| 499 |
+
cat = df.select_dtypes(exclude="number").columns.tolist()
|
| 500 |
|
| 501 |
try:
|
| 502 |
if len(cat) >= 2 and len(numeric) >= 1:
|
| 503 |
pivot = df.pivot_table(
|
| 504 |
+
index=cat[0],
|
| 505 |
+
columns=cat[1],
|
| 506 |
+
values=numeric[0],
|
| 507 |
+
aggfunc="sum",
|
| 508 |
+
fill_value=0,
|
| 509 |
+
)
|
| 510 |
+
st.markdown(
|
| 511 |
+
f'<div class="table-label">π Crosstab β {cat[0]} Γ {cat[1]} (sum of {numeric[0]})</div>',
|
| 512 |
+
unsafe_allow_html=True,
|
| 513 |
)
|
|
|
|
|
|
|
| 514 |
st.dataframe(pivot, use_container_width=True)
|
| 515 |
|
| 516 |
elif len(cat) == 1 and len(numeric) >= 1:
|
| 517 |
+
summary = df.groupby(cat[0])[numeric].agg(["sum", "mean", "count"])
|
|
|
|
|
|
|
|
|
|
| 518 |
summary.columns = [f"{v}_{f}" for v, f in summary.columns]
|
| 519 |
summary = summary.reset_index()
|
| 520 |
+
st.markdown(
|
| 521 |
+
f'<div class="table-label">π Summary β grouped by {cat[0]}</div>',
|
| 522 |
+
unsafe_allow_html=True,
|
| 523 |
+
)
|
| 524 |
st.dataframe(summary, use_container_width=True, hide_index=True)
|
| 525 |
|
| 526 |
elif len(numeric) >= 2:
|
| 527 |
corr = df[numeric].corr().round(3)
|
| 528 |
+
st.markdown(
|
| 529 |
+
'<div class="table-label">π Correlation Matrix</div>',
|
| 530 |
+
unsafe_allow_html=True,
|
| 531 |
+
)
|
| 532 |
+
st.dataframe(
|
| 533 |
+
corr.style.background_gradient(cmap="Blues", axis=None),
|
| 534 |
+
use_container_width=True,
|
| 535 |
+
)
|
| 536 |
|
| 537 |
else:
|
| 538 |
desc = df.describe(include="all").T.reset_index()
|
| 539 |
desc.rename(columns={"index": "column"}, inplace=True)
|
| 540 |
+
st.markdown(
|
| 541 |
+
'<div class="table-label">π Statistical Summary</div>',
|
| 542 |
+
unsafe_allow_html=True,
|
| 543 |
+
)
|
| 544 |
st.dataframe(desc, use_container_width=True, hide_index=True)
|
| 545 |
|
| 546 |
except Exception as e:
|
|
|
|
| 553 |
def get_controller():
|
| 554 |
return DataExtractorController()
|
| 555 |
|
| 556 |
+
|
| 557 |
controller = get_controller()
|
| 558 |
|
| 559 |
# ββ Session state ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 564 |
if "successful_queries" not in st.session_state:
|
| 565 |
st.session_state.successful_queries = 0
|
| 566 |
|
| 567 |
+
|
| 568 |
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 569 |
def build_message_history() -> list[Message]:
|
| 570 |
return [
|
|
|
|
| 572 |
for msg in st.session_state.chat_history
|
| 573 |
]
|
| 574 |
|
| 575 |
+
|
| 576 |
def call_controller(user_query: str):
|
| 577 |
uq = UserQuery(user_query=user_query)
|
| 578 |
history = build_message_history()
|
| 579 |
response = asyncio.run(controller.extrcat(user_query=uq, message_history=history))
|
| 580 |
return response
|
| 581 |
|
| 582 |
+
|
| 583 |
def render_message(msg):
|
| 584 |
is_user = msg["role"] == "user"
|
| 585 |
role_class = "user" if is_user else "bot"
|
| 586 |
avatar = "U" if is_user else "AI"
|
| 587 |
ts = msg.get("ts", "")
|
| 588 |
|
| 589 |
+
st.markdown(
|
| 590 |
+
f"""
|
| 591 |
<div class="msg-wrap {role_class}">
|
| 592 |
<div class="avatar {role_class}">{avatar}</div>
|
| 593 |
<div style="max-width:72%">
|
| 594 |
+
""",
|
| 595 |
+
unsafe_allow_html=True,
|
| 596 |
+
)
|
| 597 |
|
| 598 |
with st.container():
|
| 599 |
if "status" in msg:
|
|
|
|
| 601 |
label = "β success" if msg["status"] == "success" else "β error"
|
| 602 |
st.markdown(
|
| 603 |
f'<span class="badge {badge_class}">{label}</span>',
|
| 604 |
+
unsafe_allow_html=True,
|
| 605 |
)
|
| 606 |
|
| 607 |
st.markdown(msg["content"])
|
| 608 |
|
| 609 |
if msg.get("sql"):
|
| 610 |
+
st.markdown(
|
| 611 |
+
'<div class="sql-label">β‘ Generated SQL</div>', unsafe_allow_html=True
|
| 612 |
+
)
|
| 613 |
st.code(msg["sql"], language="sql")
|
| 614 |
|
| 615 |
st.markdown(f'<div class="ts">{ts}</div>', unsafe_allow_html=True)
|
|
|
|
| 617 |
|
| 618 |
# ββ Data table ββ
|
| 619 |
if msg.get("data") and len(msg["data"]) > 0:
|
| 620 |
+
st.markdown(
|
| 621 |
+
'<div class="table-label">π Query Results</div>', unsafe_allow_html=True
|
| 622 |
+
)
|
| 623 |
df = pd.DataFrame(msg["data"])
|
| 624 |
st.dataframe(df, use_container_width=True, hide_index=True)
|
| 625 |
+
elif (
|
| 626 |
+
msg.get("status") == "success"
|
| 627 |
+
and "data" in msg
|
| 628 |
+
and len(msg.get("data", [])) == 0
|
| 629 |
+
):
|
| 630 |
st.markdown(
|
| 631 |
'<div style="color:#6b7280;font-size:0.8rem;margin-top:8px;'
|
| 632 |
'font-family:monospace">β Query returned 0 rows.</div>',
|
| 633 |
+
unsafe_allow_html=True,
|
| 634 |
)
|
| 635 |
|
| 636 |
|
|
|
|
| 669 |
|
| 670 |
|
| 671 |
# ββ Main layout βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 672 |
+
st.markdown(
|
| 673 |
+
"""
|
| 674 |
<div style="margin-bottom: 1.5rem;">
|
| 675 |
<div class="page-title">Firerms Data Extractor Chatbot</div>
|
| 676 |
<div class="page-sub">Natural language β SQL β Results</div>
|
| 677 |
</div>
|
| 678 |
+
""",
|
| 679 |
+
unsafe_allow_html=True,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
# ββ Chat area βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 683 |
chat_container = st.container()
|
| 684 |
with chat_container:
|
| 685 |
if not st.session_state.chat_history:
|
| 686 |
+
st.markdown(
|
| 687 |
+
"""
|
| 688 |
<div class="empty-state">
|
| 689 |
<div class="empty-icon">π</div>
|
| 690 |
<div class="empty-title">Ask anything about your data</div>
|
| 691 |
<div class="empty-hint">Type a natural language question and the AI will generate SQL and return results.</div>
|
| 692 |
</div>
|
| 693 |
+
""",
|
| 694 |
+
unsafe_allow_html=True,
|
| 695 |
+
)
|
| 696 |
else:
|
| 697 |
for msg in st.session_state.chat_history:
|
| 698 |
render_message(msg)
|
|
|
|
| 706 |
if prompt:
|
| 707 |
ts_now = datetime.now().strftime("%H:%M:%S")
|
| 708 |
|
| 709 |
+
st.session_state.chat_history.append(
|
| 710 |
+
{
|
| 711 |
+
"role": "user",
|
| 712 |
+
"content": prompt,
|
| 713 |
+
"ts": ts_now,
|
| 714 |
+
}
|
| 715 |
+
)
|
| 716 |
st.session_state.total_queries += 1
|
| 717 |
|
| 718 |
with st.spinner("Generating SQL and fetching resultsβ¦"):
|
| 719 |
try:
|
| 720 |
result = call_controller(prompt)
|
| 721 |
+
status = result.status
|
| 722 |
+
sql = result.sql_query
|
| 723 |
+
data = result.data or []
|
| 724 |
|
| 725 |
# ββ Extract best_suitable_chart from result.output (SQLQueryExtractor) ββ
|
| 726 |
# result.output.best_suitable_chart is a ChartsType enum β use .value for the string
|
| 727 |
try:
|
| 728 |
+
best_chart = (
|
| 729 |
+
result.output.best_suitable_chart.value
|
| 730 |
+
) # e.g. "PIE", "BAR", "LINE"
|
| 731 |
except Exception:
|
| 732 |
best_chart = None
|
| 733 |
|
|
|
|
| 738 |
else f"Query returned status: `{status}`."
|
| 739 |
)
|
| 740 |
|
| 741 |
+
st.session_state.chat_history.append(
|
| 742 |
+
{
|
| 743 |
+
"role": "assistant",
|
| 744 |
+
"content": content,
|
| 745 |
+
"sql": sql,
|
| 746 |
+
"data": data,
|
| 747 |
+
"status": status,
|
| 748 |
+
"best_suitable_chart": best_chart,
|
| 749 |
+
"ts": datetime.now().strftime("%H:%M:%S"),
|
| 750 |
+
}
|
| 751 |
+
)
|
| 752 |
|
| 753 |
if status == "success":
|
| 754 |
st.session_state.successful_queries += 1
|
| 755 |
|
| 756 |
except Exception as e:
|
| 757 |
+
st.session_state.chat_history.append(
|
| 758 |
+
{
|
| 759 |
+
"role": "assistant",
|
| 760 |
+
"content": f"β Error: {str(e)}",
|
| 761 |
+
"status": "error",
|
| 762 |
+
"ts": datetime.now().strftime("%H:%M:%S"),
|
| 763 |
+
}
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
st.rerun()
|