AryanSifars commited on
Commit
6f30bbd
Β·
verified Β·
1 Parent(s): dca20c8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +121 -169
src/streamlit_app.py CHANGED
@@ -9,8 +9,8 @@ import plotly.graph_objects as go
9
  from datetime import datetime
10
 
11
  # ── Path setup so `src` is importable when running from src/ or project root ──
12
- _here = Path(__file__).resolve().parent # src/
13
- _root = _here.parent # project root
14
  for _p in [str(_here), str(_root)]:
15
  if _p not in sys.path:
16
  sys.path.insert(0, _p)
@@ -27,8 +27,7 @@ st.set_page_config(
27
  )
28
 
29
  # ── Custom CSS ────────────────────────────────────────────────────────────────
30
- st.markdown(
31
- """
32
  <style>
33
  @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=Syne:wght@400;600;700;800&display=swap');
34
 
@@ -260,9 +259,7 @@ html, body, [class*="css"] {
260
  font-weight: 600; margin-bottom: 8px; font-family: var(--mono);
261
  }
262
  </style>
263
- """,
264
- unsafe_allow_html=True,
265
- )
266
 
267
  # ── Plotly dark theme template ────────────────────────────────────────────────
268
  PLOTLY_TEMPLATE = dict(
@@ -272,16 +269,8 @@ PLOTLY_TEMPLATE = dict(
272
  font=dict(color="#e8eaf0", family="JetBrains Mono, monospace", size=11),
273
  xaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
274
  yaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
275
- colorway=[
276
- "#4f8ef7",
277
- "#7c5cfc",
278
- "#22d3a5",
279
- "#f75f5f",
280
- "#f7a24f",
281
- "#e879f9",
282
- "#38bdf8",
283
- "#fb923c",
284
- ],
285
  legend=dict(bgcolor="#1c2030", bordercolor="#2a2f42", borderwidth=1),
286
  margin=dict(l=40, r=20, t=40, b=40),
287
  )
@@ -292,20 +281,14 @@ PLOTLY_TEMPLATE = dict(
292
 
293
  CHART_ALIASES = {
294
  # ── ChartsType enum values (uppercase from Pydantic/Enum .value) ──
295
- "pie": "pie",
296
- "bar": "bar",
297
  "line": "line",
298
  # ── Common string variants (lowercase) ──
299
- "bar_chart": "bar",
300
- "vertical_bar": "bar",
301
- "column": "bar",
302
- "grouped_bar": "bar",
303
- "stacked_bar": "bar",
304
- "line_chart": "line",
305
- "time_series": "line",
306
- "trend": "line",
307
- "donut": "pie",
308
- "doughnut": "pie",
309
  }
310
 
311
 
@@ -339,9 +322,7 @@ def _guess_chart_type(df: pd.DataFrame) -> str:
339
  return "bar"
340
 
341
 
342
- def render_chart(
343
- df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix: str = "chart"
344
- ):
345
  """
346
  Render a Plotly chart with user-controlled column selectors.
347
  The user picks x, y, and optional color columns via st.selectbox.
@@ -353,19 +334,19 @@ def render_chart(
353
  return
354
 
355
  chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
356
- cols = list(df.columns)
357
  numeric = df.select_dtypes(include="number").columns.tolist()
358
- cat = df.select_dtypes(exclude="number").columns.tolist()
359
 
360
  # ── Smart defaults: pre-select the most sensible column per role ─────────────────────
361
- default_x = cat[0] if cat else cols[0]
362
  default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
363
 
364
  # ── Column selector UI ─────────────────────────────────────────────────────────────────
365
  st.markdown(
366
  '<div style="font-size:0.7rem;color:var(--muted);font-family:var(--mono);'
367
  'text-transform:uppercase;letter-spacing:0.08em;margin-bottom:8px;">'
368
- "βš™οΈ Configure columns</div>",
369
  unsafe_allow_html=True,
370
  )
371
 
@@ -427,60 +408,36 @@ def render_chart(
427
  try:
428
  if chart_type == "bar":
429
  if color_col:
430
- fig = px.bar(
431
- df,
432
- x=x_col,
433
- y=y_col,
434
- color=color_col,
435
- barmode="group",
436
- template=PLOTLY_TEMPLATE,
437
- )
438
  else:
439
  fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE)
440
 
441
  elif chart_type == "line":
442
  if color_col:
443
- fig = px.line(
444
- df,
445
- x=x_col,
446
- y=y_col,
447
- color=color_col,
448
- markers=True,
449
- template=PLOTLY_TEMPLATE,
450
- )
451
  else:
452
- fig = px.line(
453
- df, x=x_col, y=y_col, markers=True, template=PLOTLY_TEMPLATE
454
- )
455
 
456
  elif chart_type == "pie":
457
- fig = px.pie(
458
- df, names=x_col, values=y_col, hole=0.35, template=PLOTLY_TEMPLATE
459
- )
460
  fig.update_traces(textinfo="percent+label")
461
 
462
  else: # unrecognised / None β†’ heuristic fallback bar
463
- fig = px.bar(
464
- df,
465
- x=x_col,
466
- y=y_col,
467
- template=PLOTLY_TEMPLATE,
468
- title=f"Chart type '{chart_type_raw}' not recognized β€” showing bar chart",
469
- )
470
 
471
  except Exception as e:
472
- st.warning(
473
- f"⚠ Could not render `{chart_type}` chart: {e}. Check the column types selected above."
474
- )
475
  return
476
 
477
  if fig:
478
- fig.update_layout(
479
- paper_bgcolor="#141720", plot_bgcolor="#0d0f14", font_color="#e8eaf0"
480
- )
481
  st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
482
 
483
-
484
  def render_crosstab(df: pd.DataFrame):
485
  """
486
  Auto-build a crosstab-style pivot summary.
@@ -496,51 +453,41 @@ def render_crosstab(df: pd.DataFrame):
496
  return
497
 
498
  numeric = df.select_dtypes(include="number").columns.tolist()
499
- cat = df.select_dtypes(exclude="number").columns.tolist()
500
 
501
  try:
502
  if len(cat) >= 2 and len(numeric) >= 1:
503
  pivot = df.pivot_table(
504
- index=cat[0],
505
- columns=cat[1],
506
- values=numeric[0],
507
- aggfunc="sum",
508
- fill_value=0,
509
- )
510
- st.markdown(
511
- f'<div class="table-label">πŸ“ Crosstab β€” {cat[0]} Γ— {cat[1]} (sum of {numeric[0]})</div>',
512
- unsafe_allow_html=True,
513
  )
 
 
514
  st.dataframe(pivot, use_container_width=True)
515
 
516
  elif len(cat) == 1 and len(numeric) >= 1:
517
- summary = df.groupby(cat[0])[numeric].agg(["sum", "mean", "count"])
 
 
 
518
  summary.columns = [f"{v}_{f}" for v, f in summary.columns]
519
  summary = summary.reset_index()
520
- st.markdown(
521
- f'<div class="table-label">πŸ“ Summary β€” grouped by {cat[0]}</div>',
522
- unsafe_allow_html=True,
523
- )
524
  st.dataframe(summary, use_container_width=True, hide_index=True)
525
 
526
  elif len(numeric) >= 2:
527
  corr = df[numeric].corr().round(3)
528
- st.markdown(
529
- '<div class="table-label">πŸ“ Correlation Matrix</div>',
530
- unsafe_allow_html=True,
531
- )
532
- st.dataframe(
533
- corr.style.background_gradient(cmap="Blues", axis=None),
534
- use_container_width=True,
535
- )
536
 
537
  else:
538
  desc = df.describe(include="all").T.reset_index()
539
  desc.rename(columns={"index": "column"}, inplace=True)
540
- st.markdown(
541
- '<div class="table-label">πŸ“ Statistical Summary</div>',
542
- unsafe_allow_html=True,
543
- )
544
  st.dataframe(desc, use_container_width=True, hide_index=True)
545
 
546
  except Exception as e:
@@ -553,7 +500,6 @@ def render_crosstab(df: pd.DataFrame):
553
  def get_controller():
554
  return DataExtractorController()
555
 
556
-
557
  controller = get_controller()
558
 
559
  # ── Session state ────────────────────────────────────────────────────────────
@@ -564,7 +510,6 @@ if "total_queries" not in st.session_state:
564
  if "successful_queries" not in st.session_state:
565
  st.session_state.successful_queries = 0
566
 
567
-
568
  # ── Helpers ──────────────────────────────────────────────────────────────────
569
  def build_message_history() -> list[Message]:
570
  return [
@@ -572,28 +517,23 @@ def build_message_history() -> list[Message]:
572
  for msg in st.session_state.chat_history
573
  ]
574
 
575
-
576
  def call_controller(user_query: str):
577
  uq = UserQuery(user_query=user_query)
578
  history = build_message_history()
579
  response = asyncio.run(controller.extrcat(user_query=uq, message_history=history))
580
  return response
581
 
582
-
583
  def render_message(msg):
584
  is_user = msg["role"] == "user"
585
  role_class = "user" if is_user else "bot"
586
  avatar = "U" if is_user else "AI"
587
  ts = msg.get("ts", "")
588
 
589
- st.markdown(
590
- f"""
591
  <div class="msg-wrap {role_class}">
592
  <div class="avatar {role_class}">{avatar}</div>
593
  <div style="max-width:72%">
594
- """,
595
- unsafe_allow_html=True,
596
- )
597
 
598
  with st.container():
599
  if "status" in msg:
@@ -601,36 +541,62 @@ def render_message(msg):
601
  label = "βœ“ success" if msg["status"] == "success" else "βœ— error"
602
  st.markdown(
603
  f'<span class="badge {badge_class}">{label}</span>',
604
- unsafe_allow_html=True,
605
  )
606
 
607
  st.markdown(msg["content"])
608
 
609
  if msg.get("sql"):
610
- st.markdown(
611
- '<div class="sql-label">⚑ Generated SQL</div>', unsafe_allow_html=True
612
- )
613
  st.code(msg["sql"], language="sql")
614
 
615
  st.markdown(f'<div class="ts">{ts}</div>', unsafe_allow_html=True)
616
  st.markdown("</div></div>", unsafe_allow_html=True)
617
 
618
- # ── Data table ──
619
- if msg.get("data") and len(msg["data"]) > 0:
620
- st.markdown(
621
- '<div class="table-label">πŸ“Š Query Results</div>', unsafe_allow_html=True
622
- )
623
- df = pd.DataFrame(msg["data"])
624
- st.dataframe(df, use_container_width=True, hide_index=True)
625
- elif (
626
- msg.get("status") == "success"
627
- and "data" in msg
628
- and len(msg.get("data", [])) == 0
629
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  st.markdown(
631
  '<div style="color:#6b7280;font-size:0.8rem;margin-top:8px;'
632
  'font-family:monospace">⚠ Query returned 0 rows.</div>',
633
- unsafe_allow_html=True,
634
  )
635
 
636
 
@@ -669,30 +635,24 @@ with st.sidebar:
669
 
670
 
671
  # ── Main layout ───────────────────────────────────────────────────────────────
672
- st.markdown(
673
- """
674
  <div style="margin-bottom: 1.5rem;">
675
  <div class="page-title">Firerms Data Extractor Chatbot</div>
676
  <div class="page-sub">Natural language β†’ SQL β†’ Results</div>
677
  </div>
678
- """,
679
- unsafe_allow_html=True,
680
- )
681
 
682
  # ── Chat area ─────────────────────────────────────────────────────────────────
683
  chat_container = st.container()
684
  with chat_container:
685
  if not st.session_state.chat_history:
686
- st.markdown(
687
- """
688
  <div class="empty-state">
689
  <div class="empty-icon">πŸ”</div>
690
  <div class="empty-title">Ask anything about your data</div>
691
  <div class="empty-hint">Type a natural language question and the AI will generate SQL and return results.</div>
692
  </div>
693
- """,
694
- unsafe_allow_html=True,
695
- )
696
  else:
697
  for msg in st.session_state.chat_history:
698
  render_message(msg)
@@ -706,28 +666,24 @@ if not prompt and prefill:
706
  if prompt:
707
  ts_now = datetime.now().strftime("%H:%M:%S")
708
 
709
- st.session_state.chat_history.append(
710
- {
711
- "role": "user",
712
- "content": prompt,
713
- "ts": ts_now,
714
- }
715
- )
716
  st.session_state.total_queries += 1
717
 
718
  with st.spinner("Generating SQL and fetching results…"):
719
  try:
720
  result = call_controller(prompt)
721
- status = result.status
722
- sql = result.sql_query
723
- data = result.data or []
724
 
725
  # ── Extract best_suitable_chart from result.output (SQLQueryExtractor) ──
726
  # result.output.best_suitable_chart is a ChartsType enum β†’ use .value for the string
727
  try:
728
- best_chart = (
729
- result.output.best_suitable_chart.value
730
- ) # e.g. "PIE", "BAR", "LINE"
731
  except Exception:
732
  best_chart = None
733
 
@@ -738,29 +694,25 @@ if prompt:
738
  else f"Query returned status: `{status}`."
739
  )
740
 
741
- st.session_state.chat_history.append(
742
- {
743
- "role": "assistant",
744
- "content": content,
745
- "sql": sql,
746
- "data": data,
747
- "status": status,
748
- "best_suitable_chart": best_chart,
749
- "ts": datetime.now().strftime("%H:%M:%S"),
750
- }
751
- )
752
 
753
  if status == "success":
754
  st.session_state.successful_queries += 1
755
 
756
  except Exception as e:
757
- st.session_state.chat_history.append(
758
- {
759
- "role": "assistant",
760
- "content": f"❌ Error: {str(e)}",
761
- "status": "error",
762
- "ts": datetime.now().strftime("%H:%M:%S"),
763
- }
764
- )
765
-
766
- st.rerun()
 
9
  from datetime import datetime
10
 
11
  # ── Path setup so `src` is importable when running from src/ or project root ──
12
+ _here = Path(__file__).resolve().parent # src/
13
+ _root = _here.parent # project root
14
  for _p in [str(_here), str(_root)]:
15
  if _p not in sys.path:
16
  sys.path.insert(0, _p)
 
27
  )
28
 
29
  # ── Custom CSS ────────────────────────────────────────────────────────────────
30
+ st.markdown("""
 
31
  <style>
32
  @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=Syne:wght@400;600;700;800&display=swap');
33
 
 
259
  font-weight: 600; margin-bottom: 8px; font-family: var(--mono);
260
  }
261
  </style>
262
+ """, unsafe_allow_html=True)
 
 
263
 
264
  # ── Plotly dark theme template ────────────────────────────────────────────────
265
  PLOTLY_TEMPLATE = dict(
 
269
  font=dict(color="#e8eaf0", family="JetBrains Mono, monospace", size=11),
270
  xaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
271
  yaxis=dict(gridcolor="#2a2f42", linecolor="#2a2f42", zerolinecolor="#2a2f42"),
272
+ colorway=["#4f8ef7", "#7c5cfc", "#22d3a5", "#f75f5f", "#f7a24f",
273
+ "#e879f9", "#38bdf8", "#fb923c"],
 
 
 
 
 
 
 
 
274
  legend=dict(bgcolor="#1c2030", bordercolor="#2a2f42", borderwidth=1),
275
  margin=dict(l=40, r=20, t=40, b=40),
276
  )
 
281
 
282
  CHART_ALIASES = {
283
  # ── ChartsType enum values (uppercase from Pydantic/Enum .value) ──
284
+ "pie": "pie",
285
+ "bar": "bar",
286
  "line": "line",
287
  # ── Common string variants (lowercase) ──
288
+ "bar_chart": "bar", "vertical_bar": "bar", "column": "bar",
289
+ "grouped_bar": "bar", "stacked_bar": "bar",
290
+ "line_chart": "line", "time_series": "line", "trend": "line",
291
+ "donut": "pie", "doughnut": "pie",
 
 
 
 
 
 
292
  }
293
 
294
 
 
322
  return "bar"
323
 
324
 
325
+ def render_chart(df: pd.DataFrame, chart_type_raw: str | None = None, key_prefix: str = "chart"):
 
 
326
  """
327
  Render a Plotly chart with user-controlled column selectors.
328
  The user picks x, y, and optional color columns via st.selectbox.
 
334
  return
335
 
336
  chart_type = _normalise_chart_type(chart_type_raw) or _guess_chart_type(df)
337
+ cols = list(df.columns)
338
  numeric = df.select_dtypes(include="number").columns.tolist()
339
+ cat = df.select_dtypes(exclude="number").columns.tolist()
340
 
341
  # ── Smart defaults: pre-select the most sensible column per role ─────────────────────
342
+ default_x = cat[0] if cat else cols[0]
343
  default_y = numeric[0] if numeric else (cols[1] if len(cols) > 1 else cols[0])
344
 
345
  # ── Column selector UI ─────────────────────────────────────────────────────────────────
346
  st.markdown(
347
  '<div style="font-size:0.7rem;color:var(--muted);font-family:var(--mono);'
348
  'text-transform:uppercase;letter-spacing:0.08em;margin-bottom:8px;">'
349
+ 'βš™οΈ Configure columns</div>',
350
  unsafe_allow_html=True,
351
  )
352
 
 
408
  try:
409
  if chart_type == "bar":
410
  if color_col:
411
+ fig = px.bar(df, x=x_col, y=y_col, color=color_col,
412
+ barmode="group", template=PLOTLY_TEMPLATE)
 
 
 
 
 
 
413
  else:
414
  fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE)
415
 
416
  elif chart_type == "line":
417
  if color_col:
418
+ fig = px.line(df, x=x_col, y=y_col, color=color_col,
419
+ markers=True, template=PLOTLY_TEMPLATE)
 
 
 
 
 
 
420
  else:
421
+ fig = px.line(df, x=x_col, y=y_col, markers=True,
422
+ template=PLOTLY_TEMPLATE)
 
423
 
424
  elif chart_type == "pie":
425
+ fig = px.pie(df, names=x_col, values=y_col,
426
+ hole=0.35, template=PLOTLY_TEMPLATE)
 
427
  fig.update_traces(textinfo="percent+label")
428
 
429
  else: # unrecognised / None β†’ heuristic fallback bar
430
+ fig = px.bar(df, x=x_col, y=y_col, template=PLOTLY_TEMPLATE,
431
+ title=f"Chart type '{chart_type_raw}' not recognized β€” showing bar chart")
 
 
 
 
 
432
 
433
  except Exception as e:
434
+ st.warning(f"⚠ Could not render `{chart_type}` chart: {e}. Check the column types selected above.")
 
 
435
  return
436
 
437
  if fig:
438
+ fig.update_layout(paper_bgcolor="#141720", plot_bgcolor="#0d0f14", font_color="#e8eaf0")
 
 
439
  st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
440
 
 
441
  def render_crosstab(df: pd.DataFrame):
442
  """
443
  Auto-build a crosstab-style pivot summary.
 
453
  return
454
 
455
  numeric = df.select_dtypes(include="number").columns.tolist()
456
+ cat = df.select_dtypes(exclude="number").columns.tolist()
457
 
458
  try:
459
  if len(cat) >= 2 and len(numeric) >= 1:
460
  pivot = df.pivot_table(
461
+ index=cat[0], columns=cat[1],
462
+ values=numeric[0], aggfunc="sum", fill_value=0,
 
 
 
 
 
 
 
463
  )
464
+ st.markdown(f'<div class="table-label">πŸ“ Crosstab β€” {cat[0]} Γ— {cat[1]} (sum of {numeric[0]})</div>',
465
+ unsafe_allow_html=True)
466
  st.dataframe(pivot, use_container_width=True)
467
 
468
  elif len(cat) == 1 and len(numeric) >= 1:
469
+ summary = (
470
+ df.groupby(cat[0])[numeric]
471
+ .agg(["sum", "mean", "count"])
472
+ )
473
  summary.columns = [f"{v}_{f}" for v, f in summary.columns]
474
  summary = summary.reset_index()
475
+ st.markdown(f'<div class="table-label">πŸ“ Summary β€” grouped by {cat[0]}</div>',
476
+ unsafe_allow_html=True)
 
 
477
  st.dataframe(summary, use_container_width=True, hide_index=True)
478
 
479
  elif len(numeric) >= 2:
480
  corr = df[numeric].corr().round(3)
481
+ st.markdown('<div class="table-label">πŸ“ Correlation Matrix</div>',
482
+ unsafe_allow_html=True)
483
+ st.dataframe(corr.style.background_gradient(cmap="Blues", axis=None),
484
+ use_container_width=True)
 
 
 
 
485
 
486
  else:
487
  desc = df.describe(include="all").T.reset_index()
488
  desc.rename(columns={"index": "column"}, inplace=True)
489
+ st.markdown('<div class="table-label">πŸ“ Statistical Summary</div>',
490
+ unsafe_allow_html=True)
 
 
491
  st.dataframe(desc, use_container_width=True, hide_index=True)
492
 
493
  except Exception as e:
 
500
  def get_controller():
501
  return DataExtractorController()
502
 
 
503
  controller = get_controller()
504
 
505
  # ── Session state ────────────────────────────────────────────────────────────
 
510
  if "successful_queries" not in st.session_state:
511
  st.session_state.successful_queries = 0
512
 
 
513
  # ── Helpers ──────────────────────────────────────────────────────────────────
514
  def build_message_history() -> list[Message]:
515
  return [
 
517
  for msg in st.session_state.chat_history
518
  ]
519
 
 
520
  def call_controller(user_query: str):
521
  uq = UserQuery(user_query=user_query)
522
  history = build_message_history()
523
  response = asyncio.run(controller.extrcat(user_query=uq, message_history=history))
524
  return response
525
 
 
526
  def render_message(msg):
527
  is_user = msg["role"] == "user"
528
  role_class = "user" if is_user else "bot"
529
  avatar = "U" if is_user else "AI"
530
  ts = msg.get("ts", "")
531
 
532
+ st.markdown(f"""
 
533
  <div class="msg-wrap {role_class}">
534
  <div class="avatar {role_class}">{avatar}</div>
535
  <div style="max-width:72%">
536
+ """, unsafe_allow_html=True)
 
 
537
 
538
  with st.container():
539
  if "status" in msg:
 
541
  label = "βœ“ success" if msg["status"] == "success" else "βœ— error"
542
  st.markdown(
543
  f'<span class="badge {badge_class}">{label}</span>',
544
+ unsafe_allow_html=True
545
  )
546
 
547
  st.markdown(msg["content"])
548
 
549
  if msg.get("sql"):
550
+ st.markdown('<div class="sql-label">⚑ Generated SQL</div>', unsafe_allow_html=True)
 
 
551
  st.code(msg["sql"], language="sql")
552
 
553
  st.markdown(f'<div class="ts">{ts}</div>', unsafe_allow_html=True)
554
  st.markdown("</div></div>", unsafe_allow_html=True)
555
 
556
+ # ── Multi-view data panel ──────────────────────────────────────────────
557
+ data = msg.get("data", [])
558
+ chart_hint = msg.get("best_suitable_chart") # e.g. "bar", "line", "pie"
559
+
560
+ if data and len(data) > 0 and msg.get("status") == "success":
561
+ df = pd.DataFrame(data)
562
+
563
+ tab_table, tab_crosstab, tab_chart = st.tabs([
564
+ "πŸ“‹ Table",
565
+ "πŸ“ Crosstab / Summary",
566
+ "πŸ“Š Chart",
567
+ ])
568
+
569
+ with tab_table:
570
+ st.markdown('<div class="table-label">Query Results</div>', unsafe_allow_html=True)
571
+ st.dataframe(df, use_container_width=True, hide_index=True)
572
+
573
+ with tab_crosstab:
574
+ render_crosstab(df)
575
+
576
+ with tab_chart:
577
+ if chart_hint:
578
+ icon_map = {"BAR": "πŸ“Š", "PIE": "πŸ₯§", "LINE": "πŸ“ˆ"}
579
+ icon = icon_map.get(str(chart_hint).upper(), "πŸ“Š")
580
+ st.markdown(
581
+ f'<div class="chart-label">{icon} {chart_hint}'
582
+ f' <span style="color:var(--muted);font-weight:400">(AI suggested)</span></div>',
583
+ unsafe_allow_html=True,
584
+ )
585
+ else:
586
+ st.markdown(
587
+ '<div class="chart-label">πŸ“Š Auto-detected chart type</div>',
588
+ unsafe_allow_html=True,
589
+ )
590
+ # key_prefix is unique per message (uses timestamp) so each
591
+ # chart's column selectors have independent Streamlit widget keys.
592
+ chart_key = f"chart_{msg.get('ts', 'x').replace(':', '_')}"
593
+ render_chart(df, chart_hint, key_prefix=chart_key)
594
+
595
+ elif msg.get("status") == "success" and not data:
596
  st.markdown(
597
  '<div style="color:#6b7280;font-size:0.8rem;margin-top:8px;'
598
  'font-family:monospace">⚠ Query returned 0 rows.</div>',
599
+ unsafe_allow_html=True
600
  )
601
 
602
 
 
635
 
636
 
637
  # ── Main layout ───────────────────────────────────────────────────────────────
638
+ st.markdown("""
 
639
  <div style="margin-bottom: 1.5rem;">
640
  <div class="page-title">Firerms Data Extractor Chatbot</div>
641
  <div class="page-sub">Natural language β†’ SQL β†’ Results</div>
642
  </div>
643
+ """, unsafe_allow_html=True)
 
 
644
 
645
  # ── Chat area ─────────────────────────────────────────────────────────────────
646
  chat_container = st.container()
647
  with chat_container:
648
  if not st.session_state.chat_history:
649
+ st.markdown("""
 
650
  <div class="empty-state">
651
  <div class="empty-icon">πŸ”</div>
652
  <div class="empty-title">Ask anything about your data</div>
653
  <div class="empty-hint">Type a natural language question and the AI will generate SQL and return results.</div>
654
  </div>
655
+ """, unsafe_allow_html=True)
 
 
656
  else:
657
  for msg in st.session_state.chat_history:
658
  render_message(msg)
 
666
  if prompt:
667
  ts_now = datetime.now().strftime("%H:%M:%S")
668
 
669
+ st.session_state.chat_history.append({
670
+ "role": "user",
671
+ "content": prompt,
672
+ "ts": ts_now,
673
+ })
 
 
674
  st.session_state.total_queries += 1
675
 
676
  with st.spinner("Generating SQL and fetching results…"):
677
  try:
678
  result = call_controller(prompt)
679
+ status = result.status
680
+ sql = result.sql_query
681
+ data = result.data or []
682
 
683
  # ── Extract best_suitable_chart from result.output (SQLQueryExtractor) ──
684
  # result.output.best_suitable_chart is a ChartsType enum β†’ use .value for the string
685
  try:
686
+ best_chart = result.output.best_suitable_chart.value # e.g. "PIE", "BAR", "LINE"
 
 
687
  except Exception:
688
  best_chart = None
689
 
 
694
  else f"Query returned status: `{status}`."
695
  )
696
 
697
+ st.session_state.chat_history.append({
698
+ "role": "assistant",
699
+ "content": content,
700
+ "sql": sql,
701
+ "data": data,
702
+ "status": status,
703
+ "best_suitable_chart": best_chart,
704
+ "ts": datetime.now().strftime("%H:%M:%S"),
705
+ })
 
 
706
 
707
  if status == "success":
708
  st.session_state.successful_queries += 1
709
 
710
  except Exception as e:
711
+ st.session_state.chat_history.append({
712
+ "role": "assistant",
713
+ "content": f"❌ Error: {str(e)}",
714
+ "status": "error",
715
+ "ts": datetime.now().strftime("%H:%M:%S"),
716
+ })
717
+
718
+ st.rerun()