Aryan Jain commited on
Commit
dca20c8
Β·
1 Parent(s): 9fdba00

show charts

Browse files
src/controllers/_data_extractor.py CHANGED
@@ -18,6 +18,7 @@ class Response(BaseModel):
18
  sql_query: str
19
  output: SQLQueryExtractor
20
 
 
21
  class DataExtractorController:
22
  def __init__(self):
23
  self.router = APIRouter()
@@ -36,7 +37,9 @@ class DataExtractorController:
36
  response, sql_query, output = await service.extract(
37
  user_query=user_query.user_query, message_history=message_history
38
  )
39
- return Response(status="success", data=response, sql_query=sql_query, output=output)
 
 
40
  except HTTPException as e:
41
  logger.error(e)
42
  raise e
 
18
  sql_query: str
19
  output: SQLQueryExtractor
20
 
21
+
22
  class DataExtractorController:
23
  def __init__(self):
24
  self.router = APIRouter()
 
37
  response, sql_query, output = await service.extract(
38
  user_query=user_query.user_query, message_history=message_history
39
  )
40
+ return Response(
41
+ status="success", data=response, sql_query=sql_query, output=output
42
+ )
43
  except HTTPException as e:
44
  logger.error(e)
45
  raise e
src/schemas/_pydantic_agent.py CHANGED
@@ -118,12 +118,13 @@ class OrderConditions(BaseModel):
118
  "Use 'ASC' for: oldest, lowest, alphabetical, earliest."
119
  ),
120
  )
121
-
122
-
123
  class ChartsType(Enum):
124
  """
125
  Enumeration of supported chart types for visualization.
126
  """
 
127
  PIE: str = "PIE"
128
  BAR: str = "BAR"
129
  LINE: str = "LINE"
@@ -221,7 +222,7 @@ class SQLQueryExtractor(BaseModel):
221
  "This must be set True if provided tool return true else set False. "
222
  ),
223
  )
224
-
225
  best_suitable_chart: ChartsType = Field(
226
  ...,
227
  description=(
@@ -233,7 +234,6 @@ class SQLQueryExtractor(BaseModel):
233
  ),
234
  )
235
 
236
-
237
  @model_validator(mode="after")
238
  def validate_sql_query(self):
239
  if not self.is_sql_query_verified_using_provided_tool:
 
118
  "Use 'ASC' for: oldest, lowest, alphabetical, earliest."
119
  ),
120
  )
121
+
122
+
123
  class ChartsType(Enum):
124
  """
125
  Enumeration of supported chart types for visualization.
126
  """
127
+
128
  PIE: str = "PIE"
129
  BAR: str = "BAR"
130
  LINE: str = "LINE"
 
222
  "This must be set True if provided tool return true else set False. "
223
  ),
224
  )
225
+
226
  best_suitable_chart: ChartsType = Field(
227
  ...,
228
  description=(
 
234
  ),
235
  )
236
 
 
237
  @model_validator(mode="after")
238
  def validate_sql_query(self):
239
  if not self.is_sql_query_verified_using_provided_tool:
src/streamlit_app.py CHANGED
@@ -9,8 +9,8 @@ import plotly.graph_objects as go
9
  from datetime import datetime
10
 
11
  # ── Path setup so `src` is importable when running from src/ or project root ──
12
- _here = Path(__file__).resolve().parent # src/
13
- _root = _here.parent # project root
14
  for _p in [str(_here), str(_root)]:
15
  if _p not in sys.path:
16
  sys.path.insert(0, _p)
@@ -27,7 +27,8 @@ st.set_page_config(
27
  )
28
 
29
  # ── Custom CSS ────────────────────────────────────────────────────────────────
30
- st.markdown("""
 
31
  <style>
32
  @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=Syne:wght@400;600;700;800&display=swap');
33
 
@@ -259,7 +260,9 @@ html, body, [class*="css"] {
259
  font-weight: 600; margin-bottom: 8px; font-family: var(--mono);
260
  }
261
  </style>
262
- """, unsafe_allow_html=True)
 
 
263
 
264
  # ── Plotly dark theme template ────────────────────────────────────────────────
265
  PLOTLY_TEMPLATE = dict(
@@ -269,8 +272,16 @@ PLOTLY_TEMPLATE = dict(
269
  font=dict(color="#e8eaf0", family="JetBrains Mono, monospace", size=11),
270
  xaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
271
  yaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
272
- colorway=["#4f8ef7", "#7c5cfc", "#22d3a5", "#f75f5f", "#f7a24f",
273
- "#e879f9", "#38bdf8", "#fb923c"],
 
 
 
 
 
 
 
 
274
  legend=dict(bgcolor="#1c2030", bordercolor="#2a2f42", borderwidth=1),
275
  margin=dict(l=40, r=20, t=40, b=40),
276
  )
@@ -281,14 +292,20 @@ PLOTLY_TEMPLATE = dict(
281
 
282
  CHART_ALIASES = {
283
  # ── ChartsType enum values (uppercase from Pydantic/Enum .value) ──
284
- "pie": "pie",
285
- "bar": "bar",
286
  "line": "line",
287
  # ── Common string variants (lowercase) ──
288
- "bar_chart": "bar", "vertical_bar": "bar", "column": "bar",
289
- "grouped_bar": "bar", "stacked_bar": "bar",
290
- "line_chart": "line", "time_series": "line", "trend": "line",
291
- "donut": "pie", "doughnut": "pie",
 
 
 
 
 
 
292
  }
293
 
294
 
@@ -322,7 +339,9 @@ def _guess_chart_type(df: pd.DataFrame) -> str:
322
  return "bar"
323
 
324
 
325
- def render_chart(df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix: str = "chart"):
 
 
326
  """
327
  Render a Plotly chart with user-controlled column selectors.
328
  The user picks x, y, and optional color columns via st.selectbox.
@@ -334,19 +353,19 @@ def render_chart(df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix
334
  return
335
 
336
  chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
337
- cols = list(df.columns)
338
  numeric = df.select_dtypes(include="number").columns.tolist()
339
- cat = df.select_dtypes(exclude="number").columns.tolist()
340
 
341
  # ── Smart defaults: pre-select the most sensible column per role ─────────────────────
342
- default_x = cat[0] if cat else cols[0]
343
  default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
344
 
345
  # ── Column selector UI ─────────────────────────────────────────────────────────────────
346
  st.markdown(
347
  '<div style="font-size:0.7rem;color:var(--muted);font-family:var(--mono);'
348
  'text-transform:uppercase;letter-spacing:0.08em;margin-bottom:8px;">'
349
- 'βš™οΈ Configure columns</div>',
350
  unsafe_allow_html=True,
351
  )
352
 
@@ -408,36 +427,60 @@ def render_chart(df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix
408
  try:
409
  if chart_type == "bar":
410
  if color_col:
411
- fig = px.bar(df, x=x_col, y=y_col, color=color_col,
412
- barmode="group", template=PLOTLY_TEMPLATE)
 
 
 
 
 
 
413
  else:
414
  fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE)
415
 
416
  elif chart_type == "line":
417
  if color_col:
418
- fig = px.line(df, x=x_col, y=y_col, color=color_col,
419
- markers=True, template=PLOTLY_TEMPLATE)
 
 
 
 
 
 
420
  else:
421
- fig = px.line(df, x=x_col, y=y_col, markers=True,
422
- template=PLOTLY_TEMPLATE)
 
423
 
424
  elif chart_type == "pie":
425
- fig = px.pie(df, names=x_col, values=y_col,
426
- hole=0.35, template=PLOTLY_TEMPLATE)
 
427
  fig.update_traces(textinfo="percent+label")
428
 
429
  else: # unrecognised / None β†’ heuristic fallback bar
430
- fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE,
431
- title=f"Chart type '{chart_type_raw}' not recognized β€” showing bar chart")
 
 
 
 
 
432
 
433
  except Exception as e:
434
- st.warning(f"⚠ Could not render `{chart_type}` chart: {e}. Check the column types selected above.")
 
 
435
  return
436
 
437
  if fig:
438
- fig.update_layout(paper_bgcolor="#141720", plot_bgcolor="#0d0f14", font_color="#e8eaf0")
 
 
439
  st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
440
 
 
441
  def render_crosstab(df: pd.DataFrame):
442
  """
443
  Auto-build a crosstab-style pivot summary.
@@ -453,41 +496,51 @@ def render_crosstab(df: pd.DataFrame):
453
  return
454
 
455
  numeric = df.select_dtypes(include="number").columns.tolist()
456
- cat = df.select_dtypes(exclude="number").columns.tolist()
457
 
458
  try:
459
  if len(cat) >= 2 and len(numeric) >= 1:
460
  pivot = df.pivot_table(
461
- index=cat[0], columns=cat[1],
462
- values=numeric[0], aggfunc="sum", fill_value=0,
 
 
 
 
 
 
 
463
  )
464
- st.markdown(f'<div class="table-label">πŸ“ Crosstab β€” {cat[0]} Γ— {cat[1]} (sum of {numeric[0]})</div>',
465
- unsafe_allow_html=True)
466
  st.dataframe(pivot, use_container_width=True)
467
 
468
  elif len(cat) == 1 and len(numeric) >= 1:
469
- summary = (
470
- df.groupby(cat[0])[numeric]
471
- .agg(["sum", "mean", "count"])
472
- )
473
  summary.columns = [f"{v}_{f}" for v, f in summary.columns]
474
  summary = summary.reset_index()
475
- st.markdown(f'<div class="table-label">πŸ“ Summary β€” grouped by {cat[0]}</div>',
476
- unsafe_allow_html=True)
 
 
477
  st.dataframe(summary, use_container_width=True, hide_index=True)
478
 
479
  elif len(numeric) >= 2:
480
  corr = df[numeric].corr().round(3)
481
- st.markdown('<div class="table-label">πŸ“ Correlation Matrix</div>',
482
- unsafe_allow_html=True)
483
- st.dataframe(corr.style.background_gradient(cmap="Blues", axis=None),
484
- use_container_width=True)
 
 
 
 
485
 
486
  else:
487
  desc = df.describe(include="all").T.reset_index()
488
  desc.rename(columns={"index": "column"}, inplace=True)
489
- st.markdown('<div class="table-label">πŸ“ Statistical Summary</div>',
490
- unsafe_allow_html=True)
 
 
491
  st.dataframe(desc, use_container_width=True, hide_index=True)
492
 
493
  except Exception as e:
@@ -500,6 +553,7 @@ def render_crosstab(df: pd.DataFrame):
500
  def get_controller():
501
  return DataExtractorController()
502
 
 
503
  controller = get_controller()
504
 
505
  # ── Session state ────────────────────────────────────────────────────────────
@@ -510,6 +564,7 @@ if "total_queries" not in st.session_state:
510
  if "successful_queries" not in st.session_state:
511
  st.session_state.successful_queries = 0
512
 
 
513
  # ── Helpers ──────────────────────────────────────────────────────────────────
514
  def build_message_history() -> list[Message]:
515
  return [
@@ -517,23 +572,28 @@ def build_message_history() -> list[Message]:
517
  for msg in st.session_state.chat_history
518
  ]
519
 
 
520
  def call_controller(user_query: str):
521
  uq = UserQuery(user_query=user_query)
522
  history = build_message_history()
523
  response = asyncio.run(controller.extrcat(user_query=uq, message_history=history))
524
  return response
525
 
 
526
  def render_message(msg):
527
  is_user = msg["role"] == "user"
528
  role_class = "user" if is_user else "bot"
529
  avatar = "U" if is_user else "AI"
530
  ts = msg.get("ts", "")
531
 
532
- st.markdown(f"""
 
533
  <div class="msg-wrap {role_class}">
534
  <div class="avatar {role_class}">{avatar}</div>
535
  <div style="max-width:72%">
536
- """, unsafe_allow_html=True)
 
 
537
 
538
  with st.container():
539
  if "status" in msg:
@@ -541,13 +601,15 @@ def render_message(msg):
541
  label = "βœ“ success" if msg["status"] == "success" else "βœ— error"
542
  st.markdown(
543
  f'<span class="badge {badge_class}">{label}</span>',
544
- unsafe_allow_html=True
545
  )
546
 
547
  st.markdown(msg["content"])
548
 
549
  if msg.get("sql"):
550
- st.markdown('<div class="sql-label">⚑ Generated SQL</div>', unsafe_allow_html=True)
 
 
551
  st.code(msg["sql"], language="sql")
552
 
553
  st.markdown(f'<div class="ts">{ts}</div>', unsafe_allow_html=True)
@@ -555,14 +617,20 @@ def render_message(msg):
555
 
556
  # ── Data table ──
557
  if msg.get("data") and len(msg["data"]) > 0:
558
- st.markdown('<div class="table-label">πŸ“Š Query Results</div>', unsafe_allow_html=True)
 
 
559
  df = pd.DataFrame(msg["data"])
560
  st.dataframe(df, use_container_width=True, hide_index=True)
561
- elif msg.get("status") == "success" and "data" in msg and len(msg.get("data", [])) == 0:
 
 
 
 
562
  st.markdown(
563
  '<div style="color:#6b7280;font-size:0.8rem;margin-top:8px;'
564
  'font-family:monospace">⚠ Query returned 0 rows.</div>',
565
- unsafe_allow_html=True
566
  )
567
 
568
 
@@ -601,24 +669,30 @@ with st.sidebar:
601
 
602
 
603
  # ── Main layout ───────────────────────────────────────────────────────────────
604
- st.markdown("""
 
605
  <div style="margin-bottom: 1.5rem;">
606
  <div class="page-title">Firerms Data Extractor Chatbot</div>
607
  <div class="page-sub">Natural language β†’ SQL β†’ Results</div>
608
  </div>
609
- """, unsafe_allow_html=True)
 
 
610
 
611
  # ── Chat area ─────────────────────────────────────────────────────────────────
612
  chat_container = st.container()
613
  with chat_container:
614
  if not st.session_state.chat_history:
615
- st.markdown("""
 
616
  <div class="empty-state">
617
  <div class="empty-icon">πŸ”</div>
618
  <div class="empty-title">Ask anything about your data</div>
619
  <div class="empty-hint">Type a natural language question and the AI will generate SQL and return results.</div>
620
  </div>
621
- """, unsafe_allow_html=True)
 
 
622
  else:
623
  for msg in st.session_state.chat_history:
624
  render_message(msg)
@@ -632,24 +706,28 @@ if not prompt and prefill:
632
  if prompt:
633
  ts_now = datetime.now().strftime("%H:%M:%S")
634
 
635
- st.session_state.chat_history.append({
636
- "role": "user",
637
- "content": prompt,
638
- "ts": ts_now,
639
- })
 
 
640
  st.session_state.total_queries += 1
641
 
642
  with st.spinner("Generating SQL and fetching results…"):
643
  try:
644
  result = call_controller(prompt)
645
- status = result.status
646
- sql = result.sql_query
647
- data = result.data or []
648
 
649
  # ── Extract best_suitable_chart from result.output (SQLQueryExtractor) ──
650
  # result.output.best_suitable_chart is a ChartsType enum β†’ use .value for the string
651
  try:
652
- best_chart = result.output.best_suitable_chart.value # e.g. "PIE", "BAR", "LINE"
 
 
653
  except Exception:
654
  best_chart = None
655
 
@@ -660,25 +738,29 @@ if prompt:
660
  else f"Query returned status: `{status}`."
661
  )
662
 
663
- st.session_state.chat_history.append({
664
- "role": "assistant",
665
- "content": content,
666
- "sql": sql,
667
- "data": data,
668
- "status": status,
669
- "best_suitable_chart": best_chart,
670
- "ts": datetime.now().strftime("%H:%M:%S"),
671
- })
 
 
672
 
673
  if status == "success":
674
  st.session_state.successful_queries += 1
675
 
676
  except Exception as e:
677
- st.session_state.chat_history.append({
678
- "role": "assistant",
679
- "content": f"❌ Error: {str(e)}",
680
- "status": "error",
681
- "ts": datetime.now().strftime("%H:%M:%S"),
682
- })
683
-
684
- st.rerun()
 
 
 
9
  from datetime import datetime
10
 
11
  # ── Path setup so `src` is importable when running from src/ or project root ──
12
+ _here = Path(__file__).resolve().parent # src/
13
+ _root = _here.parent # project root
14
  for _p in [str(_here), str(_root)]:
15
  if _p not in sys.path:
16
  sys.path.insert(0, _p)
 
27
  )
28
 
29
  # ── Custom CSS ────────────────────────────────────────────────────────────────
30
+ st.markdown(
31
+ """
32
  <style>
33
  @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=Syne:wght@400;600;700;800&display=swap');
34
 
 
260
  font-weight: 600; margin-bottom: 8px; font-family: var(--mono);
261
  }
262
  </style>
263
+ """,
264
+ unsafe_allow_html=True,
265
+ )
266
 
267
  # ── Plotly dark theme template ────────────────────────────────────────────────
268
  PLOTLY_TEMPLATE = dict(
 
272
  font=dict(color="#e8eaf0", family="JetBrains Mono, monospace", size=11),
273
  xaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
274
  yaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
275
+ colorway=[
276
+ "#4f8ef7",
277
+ "#7c5cfc",
278
+ "#22d3a5",
279
+ "#f75f5f",
280
+ "#f7a24f",
281
+ "#e879f9",
282
+ "#38bdf8",
283
+ "#fb923c",
284
+ ],
285
  legend=dict(bgcolor="#1c2030", bordercolor="#2a2f42", borderwidth=1),
286
  margin=dict(l=40, r=20, t=40, b=40),
287
  )
 
292
 
293
  CHART_ALIASES = {
294
  # ── ChartsType enum values (uppercase from Pydantic/Enum .value) ──
295
+ "pie": "pie",
296
+ "bar": "bar",
297
  "line": "line",
298
  # ── Common string variants (lowercase) ──
299
+ "bar_chart": "bar",
300
+ "vertical_bar": "bar",
301
+ "column": "bar",
302
+ "grouped_bar": "bar",
303
+ "stacked_bar": "bar",
304
+ "line_chart": "line",
305
+ "time_series": "line",
306
+ "trend": "line",
307
+ "donut": "pie",
308
+ "doughnut": "pie",
309
  }
310
 
311
 
 
339
  return "bar"
340
 
341
 
342
+ def render_chart(
343
+ df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix: str = "chart"
344
+ ):
345
  """
346
  Render a Plotly chart with user-controlled column selectors.
347
  The user picks x, y, and optional color columns via st.selectbox.
 
353
  return
354
 
355
  chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
356
+ cols = list(df.columns)
357
  numeric = df.select_dtypes(include="number").columns.tolist()
358
+ cat = df.select_dtypes(exclude="number").columns.tolist()
359
 
360
  # ── Smart defaults: pre-select the most sensible column per role ─────────────────────
361
+ default_x = cat[0] if cat else cols[0]
362
  default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
363
 
364
  # ── Column selector UI ─────────────────────────────────────────────────────────────────
365
  st.markdown(
366
  '<div style="font-size:0.7rem;color:var(--muted);font-family:var(--mono);'
367
  'text-transform:uppercase;letter-spacing:0.08em;margin-bottom:8px;">'
368
+ "βš™οΈ Configure columns</div>",
369
  unsafe_allow_html=True,
370
  )
371
 
 
427
  try:
428
  if chart_type == "bar":
429
  if color_col:
430
+ fig = px.bar(
431
+ df,
432
+ x=x_col,
433
+ y=y_col,
434
+ color=color_col,
435
+ barmode="group",
436
+ template=PLOTLY_TEMPLATE,
437
+ )
438
  else:
439
  fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE)
440
 
441
  elif chart_type == "line":
442
  if color_col:
443
+ fig = px.line(
444
+ df,
445
+ x=x_col,
446
+ y=y_col,
447
+ color=color_col,
448
+ markers=True,
449
+ template=PLOTLY_TEMPLATE,
450
+ )
451
  else:
452
+ fig = px.line(
453
+ df, x=x_col, y=y_col, markers=True, template=PLOTLY_TEMPLATE
454
+ )
455
 
456
  elif chart_type == "pie":
457
+ fig = px.pie(
458
+ df, names=x_col, values=y_col, hole=0.35, template=PLOTLY_TEMPLATE
459
+ )
460
  fig.update_traces(textinfo="percent+label")
461
 
462
  else: # unrecognised / None β†’ heuristic fallback bar
463
+ fig = px.bar(
464
+ df,
465
+ x=x_col,
466
+ y=y_col,
467
+ template=PLOTLY_TEMPLATE,
468
+ title=f"Chart type '{chart_type_raw}' not recognized β€” showing bar chart",
469
+ )
470
 
471
  except Exception as e:
472
+ st.warning(
473
+ f"⚠ Could not render `{chart_type}` chart: {e}. Check the column types selected above."
474
+ )
475
  return
476
 
477
  if fig:
478
+ fig.update_layout(
479
+ paper_bgcolor="#141720", plot_bgcolor="#0d0f14", font_color="#e8eaf0"
480
+ )
481
  st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
482
 
483
+
484
  def render_crosstab(df: pd.DataFrame):
485
  """
486
  Auto-build a crosstab-style pivot summary.
 
496
  return
497
 
498
  numeric = df.select_dtypes(include="number").columns.tolist()
499
+ cat = df.select_dtypes(exclude="number").columns.tolist()
500
 
501
  try:
502
  if len(cat) >= 2 and len(numeric) >= 1:
503
  pivot = df.pivot_table(
504
+ index=cat[0],
505
+ columns=cat[1],
506
+ values=numeric[0],
507
+ aggfunc="sum",
508
+ fill_value=0,
509
+ )
510
+ st.markdown(
511
+ f'<div class="table-label">πŸ“ Crosstab β€” {cat[0]} Γ— {cat[1]} (sum of {numeric[0]})</div>',
512
+ unsafe_allow_html=True,
513
  )
 
 
514
  st.dataframe(pivot, use_container_width=True)
515
 
516
  elif len(cat) == 1 and len(numeric) >= 1:
517
+ summary = df.groupby(cat[0])[numeric].agg(["sum", "mean", "count"])
 
 
 
518
  summary.columns = [f"{v}_{f}" for v, f in summary.columns]
519
  summary = summary.reset_index()
520
+ st.markdown(
521
+ f'<div class="table-label">πŸ“ Summary β€” grouped by {cat[0]}</div>',
522
+ unsafe_allow_html=True,
523
+ )
524
  st.dataframe(summary, use_container_width=True, hide_index=True)
525
 
526
  elif len(numeric) >= 2:
527
  corr = df[numeric].corr().round(3)
528
+ st.markdown(
529
+ '<div class="table-label">πŸ“ Correlation Matrix</div>',
530
+ unsafe_allow_html=True,
531
+ )
532
+ st.dataframe(
533
+ corr.style.background_gradient(cmap="Blues", axis=None),
534
+ use_container_width=True,
535
+ )
536
 
537
  else:
538
  desc = df.describe(include="all").T.reset_index()
539
  desc.rename(columns={"index": "column"}, inplace=True)
540
+ st.markdown(
541
+ '<div class="table-label">πŸ“ Statistical Summary</div>',
542
+ unsafe_allow_html=True,
543
+ )
544
  st.dataframe(desc, use_container_width=True, hide_index=True)
545
 
546
  except Exception as e:
 
553
  def get_controller():
554
  return DataExtractorController()
555
 
556
+
557
  controller = get_controller()
558
 
559
  # ── Session state ────────────────────────────────────────────────────────────
 
564
  if "successful_queries" not in st.session_state:
565
  st.session_state.successful_queries = 0
566
 
567
+
568
  # ── Helpers ──────────────────────────────────────────────────────────────────
569
  def build_message_history() -> list[Message]:
570
  return [
 
572
  for msg in st.session_state.chat_history
573
  ]
574
 
575
+
576
  def call_controller(user_query: str):
577
  uq = UserQuery(user_query=user_query)
578
  history = build_message_history()
579
  response = asyncio.run(controller.extrcat(user_query=uq, message_history=history))
580
  return response
581
 
582
+
583
  def render_message(msg):
584
  is_user = msg["role"] == "user"
585
  role_class = "user" if is_user else "bot"
586
  avatar = "U" if is_user else "AI"
587
  ts = msg.get("ts", "")
588
 
589
+ st.markdown(
590
+ f"""
591
  <div class="msg-wrap {role_class}">
592
  <div class="avatar {role_class}">{avatar}</div>
593
  <div style="max-width:72%">
594
+ """,
595
+ unsafe_allow_html=True,
596
+ )
597
 
598
  with st.container():
599
  if "status" in msg:
 
601
  label = "βœ“ success" if msg["status"] == "success" else "βœ— error"
602
  st.markdown(
603
  f'<span class="badge {badge_class}">{label}</span>',
604
+ unsafe_allow_html=True,
605
  )
606
 
607
  st.markdown(msg["content"])
608
 
609
  if msg.get("sql"):
610
+ st.markdown(
611
+ '<div class="sql-label">⚑ Generated SQL</div>', unsafe_allow_html=True
612
+ )
613
  st.code(msg["sql"], language="sql")
614
 
615
  st.markdown(f'<div class="ts">{ts}</div>', unsafe_allow_html=True)
 
617
 
618
  # ── Data table ──
619
  if msg.get("data") and len(msg["data"]) > 0:
620
+ st.markdown(
621
+ '<div class="table-label">πŸ“Š Query Results</div>', unsafe_allow_html=True
622
+ )
623
  df = pd.DataFrame(msg["data"])
624
  st.dataframe(df, use_container_width=True, hide_index=True)
625
+ elif (
626
+ msg.get("status") == "success"
627
+ and "data" in msg
628
+ and len(msg.get("data", [])) == 0
629
+ ):
630
  st.markdown(
631
  '<div style="color:#6b7280;font-size:0.8rem;margin-top:8px;'
632
  'font-family:monospace">⚠ Query returned 0 rows.</div>',
633
+ unsafe_allow_html=True,
634
  )
635
 
636
 
 
669
 
670
 
671
  # ── Main layout ───────────────────────────────────────────────────────────────
672
+ st.markdown(
673
+ """
674
  <div style="margin-bottom: 1.5rem;">
675
  <div class="page-title">Firerms Data Extractor Chatbot</div>
676
  <div class="page-sub">Natural language β†’ SQL β†’ Results</div>
677
  </div>
678
+ """,
679
+ unsafe_allow_html=True,
680
+ )
681
 
682
  # ── Chat area ─────────────────────────────────────────────────────────────────
683
  chat_container = st.container()
684
  with chat_container:
685
  if not st.session_state.chat_history:
686
+ st.markdown(
687
+ """
688
  <div class="empty-state">
689
  <div class="empty-icon">πŸ”</div>
690
  <div class="empty-title">Ask anything about your data</div>
691
  <div class="empty-hint">Type a natural language question and the AI will generate SQL and return results.</div>
692
  </div>
693
+ """,
694
+ unsafe_allow_html=True,
695
+ )
696
  else:
697
  for msg in st.session_state.chat_history:
698
  render_message(msg)
 
706
  if prompt:
707
  ts_now = datetime.now().strftime("%H:%M:%S")
708
 
709
+ st.session_state.chat_history.append(
710
+ {
711
+ "role": "user",
712
+ "content": prompt,
713
+ "ts": ts_now,
714
+ }
715
+ )
716
  st.session_state.total_queries += 1
717
 
718
  with st.spinner("Generating SQL and fetching results…"):
719
  try:
720
  result = call_controller(prompt)
721
+ status = result.status
722
+ sql = result.sql_query
723
+ data = result.data or []
724
 
725
  # ── Extract best_suitable_chart from result.output (SQLQueryExtractor) ──
726
  # result.output.best_suitable_chart is a ChartsType enum β†’ use .value for the string
727
  try:
728
+ best_chart = (
729
+ result.output.best_suitable_chart.value
730
+ ) # e.g. "PIE", "BAR", "LINE"
731
  except Exception:
732
  best_chart = None
733
 
 
738
  else f"Query returned status: `{status}`."
739
  )
740
 
741
+ st.session_state.chat_history.append(
742
+ {
743
+ "role": "assistant",
744
+ "content": content,
745
+ "sql": sql,
746
+ "data": data,
747
+ "status": status,
748
+ "best_suitable_chart": best_chart,
749
+ "ts": datetime.now().strftime("%H:%M:%S"),
750
+ }
751
+ )
752
 
753
  if status == "success":
754
  st.session_state.successful_queries += 1
755
 
756
  except Exception as e:
757
+ st.session_state.chat_history.append(
758
+ {
759
+ "role": "assistant",
760
+ "content": f"❌ Error: {str(e)}",
761
+ "status": "error",
762
+ "ts": datetime.now().strftime("%H:%M:%S"),
763
+ }
764
+ )
765
+
766
+ st.rerun()