Vanshcc committed on
Commit
8eee2bf
·
verified ·
1 Parent(s): 4890957

features add

Browse files
Files changed (2) hide show
  1. app.py +223 -34
  2. viz_utils.py +47 -2
app.py CHANGED
@@ -17,6 +17,10 @@ load_dotenv(Path(__file__).parent / ".env")
17
 
18
  import streamlit as st
19
  import uuid
 
 
 
 
20
  from datetime import datetime
21
 
22
  # Page config must be first
@@ -74,12 +78,10 @@ SUPPORTED_LANGUAGES = {
74
  "తెలుగు (Telugu)": "te",
75
  "मराठी (Marathi)": "mr",
76
  "বাংলা (Bengali)": "bn",
77
- "ગુજરાતી (Gujarati)": "gu"
78
  }
79
 
80
 
81
-
82
-
83
  def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
84
  """Create a custom database configuration from user input."""
85
  return DatabaseConfig(
@@ -150,6 +152,61 @@ def init_session_state():
150
 
151
  if "response_language" not in st.session_state:
152
  st.session_state.response_language = "English"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
 
155
  def render_database_config():
@@ -253,6 +310,33 @@ def render_sidebar():
253
  with st.sidebar:
254
  st.title("⚙️ Settings")
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  # User Profile
257
  st.subheader("👤 User Profile")
258
  user_id = st.text_input(
@@ -306,6 +390,19 @@ def render_sidebar():
306
 
307
  st.divider()
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  # Database Configuration
310
  custom_db_params = render_database_config()
311
 
@@ -599,38 +696,84 @@ def render_schema_explorer():
599
  try:
600
  schema = st.session_state.chatbot.introspector.introspect()
601
 
602
- st.markdown("Uncheck tables to exclude them from the chat context.")
603
 
604
- for table_name, table_info in schema.tables.items():
605
- col1, col2 = st.columns([0.05, 0.95])
606
 
607
- with col1:
608
- is_active = table_name not in st.session_state.ignored_tables
609
- active = st.checkbox(
610
- "Use",
611
- value=is_active,
612
- key=f"use_{table_name}",
613
- label_visibility="collapsed",
614
- help=f"Include {table_name} in chat analysis"
615
- )
616
 
617
- if not active:
618
- st.session_state.ignored_tables.add(table_name)
619
- else:
620
- st.session_state.ignored_tables.discard(table_name)
621
-
622
- with col2:
623
- with st.container():
624
- st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
 
625
 
626
- cols = []
627
- for col in table_info.columns:
628
- pk = "🔑" if col.is_primary_key else ""
629
- txt = "📝" if col.is_text_type else ""
630
- cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
631
 
632
- st.caption(" | ".join(cols))
633
- st.divider()
 
 
 
 
 
 
 
 
634
  except Exception as e:
635
  st.error(f"Error loading schema: {e}")
636
 
@@ -650,7 +793,26 @@ def render_chat_interface():
650
  # Display messages
651
  for i, msg in enumerate(st.session_state.messages):
652
  with st.chat_message(msg["role"]):
653
- st.markdown(msg["content"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
 
655
  # Show metadata for assistant messages
656
  if msg["role"] == "assistant" and "metadata" in msg:
@@ -727,17 +889,33 @@ def render_chat_interface():
727
  """, unsafe_allow_html=True)
728
 
729
  if meta.get("query_type"):
730
- st.caption(f"Query type: {meta['query_type']}")
 
 
 
 
731
 
732
  # SQL Query expander
733
  if meta.get("sql_query"):
734
  with st.expander("🛠️ SQL Query & Details"):
735
  st.code(meta["sql_query"], language="sql")
736
 
737
- # Visualizations
738
  if meta.get("sql_results"):
739
  # Only render viz if we have results
740
  render_visualization(meta["sql_results"], f"viz_{i}")
 
 
 
 
 
 
 
 
 
 
 
 
741
 
742
  # Chat input
743
  if prompt := st.chat_input("Ask about your data..."):
@@ -759,6 +937,9 @@ def render_chat_interface():
759
  if st.session_state.memory:
760
  st.session_state.memory.add_message("user", prompt)
761
 
 
 
 
762
  response = st.session_state.chatbot.chat(
763
  prompt,
764
  st.session_state.memory,
@@ -766,12 +947,17 @@ def render_chat_interface():
766
  language=st.session_state.response_language
767
  )
768
 
 
 
 
 
769
  # Create metadata dict
770
  metadata = {
771
  "query_type": response.query_type,
772
  "sql_query": response.sql_query,
773
  "sql_results": response.sql_results,
774
- "token_usage": response.token_usage
 
775
  }
776
 
777
  # Save to session state
@@ -781,6 +967,9 @@ def render_chat_interface():
781
  "metadata": metadata
782
  })
783
 
 
 
 
784
  # Save to active memory
785
  if st.session_state.memory:
786
  st.session_state.memory.add_message("assistant", response.answer)
 
17
 
18
  import streamlit as st
19
  import uuid
20
+ import time
21
+ import io
22
+ import csv
23
+ import base64
24
  from datetime import datetime
25
 
26
  # Page config must be first
 
78
  "తెలుగు (Telugu)": "te",
79
  "मराठी (Marathi)": "mr",
80
  "বাংলা (Bengali)": "bn",
81
+ "ગુજરાতી (Gujarati)": "gu"
82
  }
83
 
84
 
 
 
85
  def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
86
  """Create a custom database configuration from user input."""
87
  return DatabaseConfig(
 
152
 
153
  if "response_language" not in st.session_state:
154
  st.session_state.response_language = "English"
155
+
156
+ if "favorites" not in st.session_state:
157
+ st.session_state.favorites = [] # List of message indices that are favorited
158
+
159
+
160
def export_results_to_csv(results: list) -> str:
    """Serialize SQL result rows to CSV text.

    Args:
        results: Rows as a list of dicts mapping column name to value.

    Returns:
        CSV text with a header row followed by one line per row, or an
        empty string when ``results`` is empty or ``None``.
    """
    if not results:
        return ""

    # Build the header from every row (first-seen order) rather than only
    # the first one: DictWriter raises ValueError on a row carrying a key
    # that is missing from ``fieldnames``.
    fieldnames = list(dict.fromkeys(key for row in results for key in row))

    output = io.StringIO()
    # restval="" leaves the cell empty for rows that lack a column.
    writer = csv.DictWriter(output, fieldnames=fieldnames, restval="")
    writer.writeheader()
    writer.writerows(results)
    return output.getvalue()
170
+
171
+
172
def export_chat_to_text() -> str:
    """Build a plain-text transcript of the current chat session.

    Reads ``messages``, ``user_id`` and ``favorites`` from
    ``st.session_state``. The transcript starts with a banner holding the
    export timestamp and user id; each message follows, prefixed with a
    star when its index is in ``favorites``. Assistant messages that carry
    a ``metadata`` dict also get their SQL query, query type and execution
    time appended when those values are present and truthy.

    Returns:
        The transcript joined with newlines, or the literal
        ``"No messages to export."`` when the session has no messages.
    """
    if not st.session_state.messages:
        return "No messages to export."

    banner = "=" * 50
    transcript = [
        banner,
        "OnceDataBot Chat Export",
        f"Exported: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"User: {st.session_state.user_id}",
        banner,
        "",
    ]

    for index, message in enumerate(st.session_state.messages):
        if message["role"] == "user":
            speaker = "🧑 User"
        else:
            speaker = "🤖 Assistant"
        star = "⭐ " if index in st.session_state.favorites else ""
        transcript.append(f"{star}{speaker}:")
        transcript.append(message["content"])

        # Only assistant messages carry query metadata.
        if message["role"] == "assistant" and "metadata" in message:
            details = message["metadata"]
            if details.get("sql_query"):
                transcript.append(f"\n📝 SQL Query: {details['sql_query']}")
            if details.get("query_type"):
                transcript.append(f"📌 Query Type: {details['query_type']}")
            if details.get("execution_time"):
                transcript.append(f"⏱️ Execution Time: {details['execution_time']:.2f}s")

        # Separator line and a blank line between messages.
        transcript.append("-" * 40)
        transcript.append("")

    return "\n".join(transcript)
204
+
205
+
206
def render_copy_button(text: str, key: str):
    """Show ``text`` in a SQL-highlighted Streamlit code block.

    No separate button widget is created; the function relies on the copy
    control that ``st.code`` renders with the block.

    Args:
        text: The text to display (highlighted as SQL).
        key: Accepted for the caller's convenience; currently unused.
    """
    # st.code is used as the copy-to-clipboard workaround.
    st.code(text, language="sql")
210
 
211
 
212
  def render_database_config():
 
310
  with st.sidebar:
311
  st.title("⚙️ Settings")
312
 
313
+ # Session Dashboard
314
+ if st.session_state.messages:
315
+ st.markdown("### 📊 Session Stats")
316
+
317
+ # Calculate stats
318
+ total_msgs = len(st.session_state.messages)
319
+ assistant_msgs = [m for m in st.session_state.messages if m.get("role") == "assistant"]
320
+ sql_queries = sum(1 for m in assistant_msgs if m.get("metadata", {}).get("sql_query"))
321
+
322
+ total_tokens = 0
323
+ exec_times = []
324
+ for m in assistant_msgs:
325
+ meta = m.get("metadata", {})
326
+ total_tokens += meta.get("token_usage", {}).get("total", 0)
327
+ if meta.get("execution_time"):
328
+ exec_times.append(meta["execution_time"])
329
+
330
+ avg_time = sum(exec_times) / len(exec_times) if exec_times else 0
331
+
332
+ col_s1, col_s2 = st.columns(2)
333
+ col_s1.metric("Queries", sql_queries)
334
+ col_s2.metric("Tokens", f"{total_tokens:,}")
335
+ st.caption(f"⏱️ Avg Time: {avg_time:.2f}s | 💬 Msgs: {total_msgs}")
336
+ st.divider()
337
+
338
+
339
+
340
  # User Profile
341
  st.subheader("👤 User Profile")
342
  user_id = st.text_input(
 
390
 
391
  st.divider()
392
 
393
+ # Export Chat Button
394
+ if st.session_state.messages:
395
+ st.download_button(
396
+ label="📄 Export Chat",
397
+ data=export_chat_to_text(),
398
+ file_name=f"chat_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
399
+ mime="text/plain",
400
+ use_container_width=True,
401
+ help="Download your chat conversation as a text file"
402
+ )
403
+
404
+ st.divider()
405
+
406
  # Database Configuration
407
  custom_db_params = render_database_config()
408
 
 
696
  try:
697
  schema = st.session_state.chatbot.introspector.introspect()
698
 
699
+ tab_list, tab_erd = st.tabs(["📋 Table List", "🕸️ Schema Diagram"])
700
 
701
+ with tab_list:
702
+ st.markdown("Uncheck tables to exclude them from the chat context.")
703
 
704
+ for table_name, table_info in schema.tables.items():
705
+ col1, col2 = st.columns([0.05, 0.95])
 
 
 
 
 
 
 
706
 
707
+ with col1:
708
+ is_active = table_name not in st.session_state.ignored_tables
709
+ active = st.checkbox(
710
+ "Use",
711
+ value=is_active,
712
+ key=f"use_{table_name}",
713
+ label_visibility="collapsed",
714
+ help=f"Include {table_name} in chat analysis"
715
+ )
716
 
717
+ if not active:
718
+ st.session_state.ignored_tables.add(table_name)
719
+ else:
720
+ st.session_state.ignored_tables.discard(table_name)
721
+
722
+ with col2:
723
+ with st.container():
724
+ st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
725
+
726
+ cols = []
727
+ for col in table_info.columns:
728
+ pk = "🔑" if col.is_primary_key else ""
729
+ txt = "📝" if col.is_text_type else ""
730
+ cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
731
+
732
+ st.caption(" | ".join(cols))
733
+ st.divider()
734
+
735
+ with tab_erd:
736
+ if len(schema.tables) > 50:
737
+ st.warning("⚠️ Too many tables to visualize effectively (limit: 50).")
738
+ else:
739
+ try:
740
+ # Build Graphviz DOT string
741
+ dot = ['digraph Database {']
742
+ dot.append(' rankdir=LR;')
743
+ dot.append(' node [shape=box, style="filled,rounded", fillcolor="#f0f2f6", fontname="Arial", fontsize=10];')
744
+ dot.append(' edge [fontname="Arial", fontsize=9, color="#666666"];')
745
+
746
+ # Add nodes (tables)
747
+ for table_name in schema.tables:
748
+ if table_name not in st.session_state.ignored_tables:
749
+ dot.append(f' "{table_name}" [label="{table_name}", fillcolor="#e1effe", color="#1e40af"];')
750
+ else:
751
+ dot.append(f' "{table_name}" [label="{table_name} (ignored)", fillcolor="#f3f4f6", color="#9ca3af", fontcolor="#9ca3af"];')
752
+
753
+ # Add edges (relationships)
754
+ has_edges = False
755
+ for table_name, table_info in schema.tables.items():
756
+ for col_name, ref_str in table_info.foreign_keys.items():
757
+ # ref_str format: "referenced_table.referenced_column"
758
+ if "." in ref_str:
759
+ ref_table = ref_str.split(".")[0]
760
+ # specific_col = ref_str.split(".")[1]
761
+
762
+ # Only draw if both tables exist in our schema list
763
+ if ref_table in schema.tables:
764
+ dot.append(f' "{table_name}" -> "{ref_table}" [label="{col_name}"];')
765
+ has_edges = True
766
 
767
+ dot.append('}')
768
+ graph_code = "\n".join(dot)
769
+ st.graphviz_chart(graph_code, width="stretch")
770
+
771
+ if not has_edges:
772
+ st.info("No foreign key relationships detected in the schema metadata.")
773
+
774
+ except Exception as e:
775
+ st.error(f"Could not render diagram: {e}")
776
+
777
  except Exception as e:
778
  st.error(f"Error loading schema: {e}")
779
 
 
793
  # Display messages
794
  for i, msg in enumerate(st.session_state.messages):
795
  with st.chat_message(msg["role"]):
796
+ # Create columns for message and favorite button
797
+ msg_col, fav_col = st.columns([0.95, 0.05])
798
+
799
+ with msg_col:
800
+ st.markdown(msg["content"])
801
+
802
+ with fav_col:
803
+ # Favorite button for assistant messages
804
+ if msg["role"] == "assistant":
805
+ is_favorited = i in st.session_state.favorites
806
+ if st.button(
807
+ "⭐" if is_favorited else "☆",
808
+ key=f"fav_{i}",
809
+ help="Click to favorite/unfavorite this response"
810
+ ):
811
+ if is_favorited:
812
+ st.session_state.favorites.remove(i)
813
+ else:
814
+ st.session_state.favorites.append(i)
815
+ st.rerun()
816
 
817
  # Show metadata for assistant messages
818
  if msg["role"] == "assistant" and "metadata" in msg:
 
889
  """, unsafe_allow_html=True)
890
 
891
  if meta.get("query_type"):
892
+ # Show query type and execution time on same line
893
+ info_text = f"Query type: {meta['query_type']}"
894
+ if meta.get("execution_time"):
895
+ info_text += f" • ⏱️ {meta['execution_time']:.2f}s"
896
+ st.caption(info_text)
897
 
898
  # SQL Query expander
899
  if meta.get("sql_query"):
900
  with st.expander("🛠️ SQL Query & Details"):
901
  st.code(meta["sql_query"], language="sql")
902
 
903
+ # Visualizations and CSV export
904
  if meta.get("sql_results"):
905
  # Only render viz if we have results
906
  render_visualization(meta["sql_results"], f"viz_{i}")
907
+
908
+ # CSV Export button
909
+ csv_data = export_results_to_csv(meta["sql_results"])
910
+ if csv_data:
911
+ st.download_button(
912
+ label="📊 Export to CSV",
913
+ data=csv_data,
914
+ file_name=f"query_results_{i}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
915
+ mime="text/csv",
916
+ key=f"csv_export_{i}",
917
+ help="Download query results as CSV file"
918
+ )
919
 
920
  # Chat input
921
  if prompt := st.chat_input("Ask about your data..."):
 
937
  if st.session_state.memory:
938
  st.session_state.memory.add_message("user", prompt)
939
 
940
+ # Track execution time
941
+ start_time = time.time()
942
+
943
  response = st.session_state.chatbot.chat(
944
  prompt,
945
  st.session_state.memory,
 
947
  language=st.session_state.response_language
948
  )
949
 
950
+ execution_time = time.time() - start_time
951
+
952
+
953
+
954
  # Create metadata dict
955
  metadata = {
956
  "query_type": response.query_type,
957
  "sql_query": response.sql_query,
958
  "sql_results": response.sql_results,
959
+ "token_usage": response.token_usage,
960
+ "execution_time": execution_time
961
  }
962
 
963
  # Save to session state
 
967
  "metadata": metadata
968
  })
969
 
970
+ # Set flag to auto-read the latest response
971
+ st.session_state.auto_read_latest = True
972
+
973
  # Save to active memory
974
  if st.session_state.memory:
975
  st.session_state.memory.add_message("assistant", response.answer)
viz_utils.py CHANGED
@@ -11,11 +11,55 @@ def render_visualization(results, key_prefix):
11
  df = pd.DataFrame(results)
12
 
13
  with st.expander("📊 Results & Visualization", expanded=False):
14
- tab_data, tab_viz = st.tabs(["📄 Data", "📈 Visualize"])
15
 
16
  with tab_data:
17
- st.dataframe(df)
 
 
 
 
 
 
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  with tab_viz:
20
  numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
21
  categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
@@ -33,6 +77,7 @@ def render_visualization(results, key_prefix):
33
  with col2:
34
  # Default X axis logic
35
  x_options = df.columns.tolist()
 
36
  default_x = categorical_cols[0] if categorical_cols else x_options[0]
37
  # Find index safely
38
  try:
 
11
  df = pd.DataFrame(results)
12
 
13
  with st.expander("📊 Results & Visualization", expanded=False):
14
+ tab_data, tab_summary, tab_viz = st.tabs(["📄 Data", "🧮 Summary", "📈 Visualize"])
15
 
16
  with tab_data:
17
+ col_search, col_info = st.columns([3, 1])
18
+ with col_search:
19
+ search_term = st.text_input(
20
+ "🔍 Filter Data",
21
+ placeholder="Type to search...",
22
+ label_visibility="collapsed",
23
+ key=f"{key_prefix}_search"
24
+ )
25
 
26
+ if search_term:
27
+ # Filter dataframe (case-insensitive) across all columns
28
+ mask = df.astype(str).apply(
29
+ lambda x: x.str.contains(search_term, case=False, na=False)
30
+ ).any(axis=1)
31
+ filtered_df = df[mask]
32
+ with col_info:
33
+ st.caption(f"Showing {len(filtered_df)} / {len(df)} rows")
34
+ st.dataframe(filtered_df, width="stretch")
35
+ else:
36
+ with col_info:
37
+ st.caption(f"Total rows: {len(df)}")
38
+ st.dataframe(df, width="stretch")
39
+
40
+ with tab_summary:
41
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
42
+ if not numeric_cols:
43
+ st.info("No numeric columns found for summary statistics.")
44
+ else:
45
+ st.markdown("### 📊 Quick Statistics")
46
+
47
+ # Create metrics for each numeric column
48
+ for col in numeric_cols:
49
+ st.markdown(f"**{col}**")
50
+ c1, c2, c3, c4 = st.columns(4)
51
+
52
+ series = df[col]
53
+ with c1:
54
+ st.metric("Total", f"{series.sum():,.2f}")
55
+ with c2:
56
+ st.metric("Average", f"{series.mean():,.2f}")
57
+ with c3:
58
+ st.metric("Min", f"{series.min():,.2f}")
59
+ with c4:
60
+ st.metric("Max", f"{series.max():,.2f}")
61
+ st.divider()
62
+
63
  with tab_viz:
64
  numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
65
  categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
 
77
  with col2:
78
  # Default X axis logic
79
  x_options = df.columns.tolist()
80
+ # Try to find a good categorical column for X axis
81
  default_x = categorical_cols[0] if categorical_cols else x_options[0]
82
  # Find index safely
83
  try: