Vanshcc commited on
Commit
5f2c193
·
verified ·
1 Parent(s): 7c3bad4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1090 -1002
app.py CHANGED
@@ -1,1002 +1,1090 @@
1
- """
2
- Schema-Agnostic Database Chatbot - Streamlit Application
3
-
4
- A production-grade chatbot that connects to ANY database
5
- (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
- through RAG and Text-to-SQL.
7
-
8
- Uses Groq for FREE LLM inference!
9
- """
10
-
11
- import os
12
- from pathlib import Path
13
-
14
- # Load .env FIRST before any other imports
15
- from dotenv import load_dotenv
16
- load_dotenv(Path(__file__).parent / ".env")
17
-
18
- import streamlit as st
19
- import uuid
20
- import time
21
- import io
22
- import csv
23
- import base64
24
- from datetime import datetime
25
-
26
- # Page config must be first
27
- st.set_page_config(
28
- page_title="OnceDataBot",
29
- page_icon="🤖",
30
- layout="wide",
31
- initial_sidebar_state="expanded"
32
- )
33
-
34
- # Imports
35
- from config import config, DatabaseConfig, DatabaseType
36
- from database import get_db, get_schema, get_introspector
37
- from database.connection import DatabaseConnection
38
- from llm import create_llm_client
39
- from chatbot import create_chatbot, DatabaseChatbot
40
- from memory import ChatMemory, EnhancedChatMemory
41
- from viz_utils import render_visualization
42
-
43
-
44
-
45
-
46
-
47
-
48
- # Groq models (all FREE!)
49
- GROQ_MODELS = [
50
- "llama-3.3-70b-versatile",
51
- "llama-3.1-8b-instant",
52
- "mixtral-8x7b-32768",
53
- "gemma2-9b-it"
54
- ]
55
-
56
- # Database types
57
- DB_TYPES = {
58
- "MySQL": "mysql",
59
- "PostgreSQL": "postgresql"
60
- }
61
-
62
- # Supported languages for multi-language responses
63
- SUPPORTED_LANGUAGES = {
64
- "English": "en",
65
- "हिन्दी (Hindi)": "hi",
66
- "Español (Spanish)": "es",
67
- "Français (French)": "fr",
68
- "Deutsch (German)": "de",
69
- "中文 (Chinese)": "zh",
70
- "日本語 (Japanese)": "ja",
71
- "한국어 (Korean)": "ko",
72
- "Português (Portuguese)": "pt",
73
- "العربية (Arabic)": "ar",
74
- "Русский (Russian)": "ru",
75
- "Italiano (Italian)": "it",
76
- "Nederlands (Dutch)": "nl",
77
- "தமிழ் (Tamil)": "ta",
78
- "తెలుగు (Telugu)": "te",
79
- "मराठी (Marathi)": "mr",
80
- "বাংলা (Bengali)": "bn",
81
- "ગુજરાতી (Gujarati)": "gu"
82
- }
83
-
84
-
85
- def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
86
- """Create a custom database configuration from user input."""
87
- return DatabaseConfig(
88
- db_type=DatabaseType(db_type),
89
- host=kwargs.get("host", ""),
90
- port=kwargs.get("port", 3306 if db_type == "mysql" else 5432),
91
- database=kwargs.get("database", ""),
92
- username=kwargs.get("username", ""),
93
- password=kwargs.get("password", ""),
94
- ssl_ca=kwargs.get("ssl_ca", None)
95
- )
96
-
97
-
98
- def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
99
- enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
100
- """Create enhanced memory with a custom database connection."""
101
- return EnhancedChatMemory(
102
- session_id=session_id,
103
- user_id=user_id,
104
- max_messages=20,
105
- db_connection=db_connection,
106
- llm_client=llm_client,
107
- enable_summarization=enable_summarization,
108
- summary_threshold=summary_threshold
109
- )
110
-
111
-
112
- def init_session_state():
113
- """Initialize Streamlit session state."""
114
- if "session_id" not in st.session_state:
115
- st.session_state.session_id = str(uuid.uuid4())
116
-
117
- if "messages" not in st.session_state:
118
- st.session_state.messages = []
119
-
120
- if "chatbot" not in st.session_state:
121
- st.session_state.chatbot = None
122
-
123
- if "initialized" not in st.session_state:
124
- st.session_state.initialized = False
125
-
126
- if "user_id" not in st.session_state:
127
- st.session_state.user_id = "default"
128
-
129
- if "enable_summarization" not in st.session_state:
130
- st.session_state.enable_summarization = True
131
-
132
- if "summary_threshold" not in st.session_state:
133
- st.session_state.summary_threshold = 10
134
-
135
- if "memory" not in st.session_state:
136
- st.session_state.memory = None
137
-
138
- if "indexed" not in st.session_state:
139
- st.session_state.indexed = False
140
-
141
- if "db_source" not in st.session_state:
142
- st.session_state.db_source = "environment" # "environment" or "custom"
143
-
144
- if "custom_db_config" not in st.session_state:
145
- st.session_state.custom_db_config = None
146
-
147
- if "custom_db_connection" not in st.session_state:
148
- st.session_state.custom_db_connection = None
149
-
150
- if "ignored_tables" not in st.session_state:
151
- st.session_state.ignored_tables = set()
152
-
153
- if "response_language" not in st.session_state:
154
- st.session_state.response_language = "English"
155
-
156
- if "favorites" not in st.session_state:
157
- st.session_state.favorites = [] # List of message indices that are favorited
158
-
159
-
160
- def export_results_to_csv(results: list) -> str:
161
- """Convert SQL results to CSV format and return as downloadable string."""
162
- if not results:
163
- return ""
164
-
165
- output = io.StringIO()
166
- writer = csv.DictWriter(output, fieldnames=results[0].keys())
167
- writer.writeheader()
168
- writer.writerows(results)
169
- return output.getvalue()
170
-
171
-
172
- def export_chat_to_text() -> str:
173
- """Export chat messages to text format."""
174
- if not st.session_state.messages:
175
- return "No messages to export."
176
-
177
- lines = []
178
- lines.append("=" * 50)
179
- lines.append(f"OnceDataBot Chat Export")
180
- lines.append(f"Exported: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
181
- lines.append(f"User: {st.session_state.user_id}")
182
- lines.append("=" * 50)
183
- lines.append("")
184
-
185
- for i, msg in enumerate(st.session_state.messages):
186
- role = "🧑 User" if msg["role"] == "user" else "🤖 Assistant"
187
- is_favorited = "⭐ " if i in st.session_state.favorites else ""
188
- lines.append(f"{is_favorited}{role}:")
189
- lines.append(msg["content"])
190
-
191
- if msg["role"] == "assistant" and "metadata" in msg:
192
- meta = msg["metadata"]
193
- if meta.get("sql_query"):
194
- lines.append(f"\n📝 SQL Query: {meta['sql_query']}")
195
- if meta.get("query_type"):
196
- lines.append(f"📌 Query Type: {meta['query_type']}")
197
- if meta.get("execution_time"):
198
- lines.append(f"⏱️ Execution Time: {meta['execution_time']:.2f}s")
199
-
200
- lines.append("-" * 40)
201
- lines.append("")
202
-
203
- return "\n".join(lines)
204
-
205
-
206
- def render_copy_button(text: str, key: str):
207
- """Render a copy to clipboard button using Streamlit."""
208
- # Using a workaround with st.code which has built-in copy
209
- st.code(text, language="sql")
210
-
211
-
212
- def render_database_config():
213
- """Render database configuration section in sidebar."""
214
- st.subheader("🗄️ Database Configuration")
215
-
216
- # Database source selection
217
- db_source = st.radio(
218
- "Database Source",
219
- options=["Use Environment Variables", "Custom Database"],
220
- index=0 if st.session_state.db_source == "environment" else 1,
221
- key="db_source_radio",
222
- help="Choose to use .env settings or enter custom credentials"
223
- )
224
-
225
- st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
226
-
227
- if st.session_state.db_source == "environment":
228
- # Show current environment config
229
- current_db_type = config.database.db_type.value.upper()
230
- st.info(f"📌 Using {current_db_type} from environment")
231
- st.caption(f"Host: {config.database.host}")
232
- return None
233
-
234
- else:
235
- # Custom database configuration
236
- st.markdown("##### Enter Database Credentials")
237
-
238
- # Database type selector
239
- db_type_label = st.selectbox(
240
- "Database Type",
241
- options=list(DB_TYPES.keys()),
242
- index=0,
243
- key="custom_db_type"
244
- )
245
- db_type = DB_TYPES[db_type_label]
246
-
247
- if True: # MySQL or PostgreSQL (SQLite removed)
248
- # MySQL or PostgreSQL
249
- col1, col2 = st.columns([3, 1])
250
- with col1:
251
- host = st.text_input(
252
- "Host",
253
- value="",
254
- key="db_host_input",
255
- placeholder="your-database-host.com"
256
- )
257
- with col2:
258
- default_port = 3306 if db_type == "mysql" else 5432
259
- port = st.number_input(
260
- "Port",
261
- value=default_port,
262
- min_value=1,
263
- max_value=65535,
264
- key="db_port_input"
265
- )
266
-
267
- database = st.text_input(
268
- "Database Name",
269
- value="",
270
- key="db_name_input",
271
- placeholder="your_database"
272
- )
273
-
274
- username = st.text_input(
275
- "Username",
276
- value="",
277
- key="db_user_input",
278
- placeholder="your_username"
279
- )
280
-
281
- password = st.text_input(
282
- "Password",
283
- value="",
284
- type="password",
285
- key="db_pass_input"
286
- )
287
-
288
- # Optional SSL
289
- with st.expander("🔒 SSL Settings (Optional)"):
290
- ssl_ca = st.text_input(
291
- "SSL CA Certificate Path",
292
- value="",
293
- key="ssl_ca_input",
294
- help="Path to SSL CA certificate file (for cloud databases like Aiven)"
295
- )
296
-
297
- return {
298
- "db_type": db_type,
299
- "host": host,
300
- "port": int(port),
301
- "database": database,
302
- "username": username,
303
- "password": password,
304
- "ssl_ca": ssl_ca if ssl_ca else None
305
- }
306
-
307
-
308
- def render_sidebar():
309
- """Render the configuration sidebar."""
310
- with st.sidebar:
311
- st.title("⚙️ Settings")
312
-
313
- # Session Dashboard
314
- if st.session_state.messages:
315
- st.markdown("### 📊 Session Stats")
316
-
317
- # Calculate stats
318
- total_msgs = len(st.session_state.messages)
319
- assistant_msgs = [m for m in st.session_state.messages if m.get("role") == "assistant"]
320
- sql_queries = sum(1 for m in assistant_msgs if m.get("metadata", {}).get("sql_query"))
321
-
322
- total_tokens = 0
323
- exec_times = []
324
- for m in assistant_msgs:
325
- meta = m.get("metadata", {})
326
- total_tokens += meta.get("token_usage", {}).get("total", 0)
327
- if meta.get("execution_time"):
328
- exec_times.append(meta["execution_time"])
329
-
330
- avg_time = sum(exec_times) / len(exec_times) if exec_times else 0
331
-
332
- col_s1, col_s2 = st.columns(2)
333
- col_s1.metric("Queries", sql_queries)
334
- col_s2.metric("Tokens", f"{total_tokens:,}")
335
- st.caption(f"⏱️ Avg Time: {avg_time:.2f}s | 💬 Msgs: {total_msgs}")
336
- st.divider()
337
-
338
-
339
-
340
- # User Profile
341
- st.subheader("👤 User Profile")
342
- user_id = st.text_input(
343
- "User ID / Name",
344
- value=st.session_state.get("user_id", "default"),
345
- key="user_id_input",
346
- help="Your unique ID for private memory storage"
347
- )
348
- if user_id != st.session_state.get("user_id"):
349
- st.session_state.user_id = user_id
350
- st.session_state.session_id = str(uuid.uuid4())
351
- st.session_state.messages = []
352
-
353
- # Recreate memory for new user
354
- if st.session_state.custom_db_connection:
355
- st.session_state.memory = create_custom_memory(
356
- st.session_state.session_id,
357
- user_id,
358
- st.session_state.custom_db_connection,
359
- st.session_state.get("llm"),
360
- st.session_state.enable_summarization,
361
- st.session_state.summary_threshold
362
- )
363
- elif st.session_state.initialized:
364
- from memory import create_enhanced_memory
365
- st.session_state.memory = create_enhanced_memory(
366
- st.session_state.session_id,
367
- user_id=user_id,
368
- enable_summarization=st.session_state.enable_summarization,
369
- summary_threshold=st.session_state.summary_threshold
370
- )
371
-
372
- if st.session_state.memory:
373
- st.session_state.memory.clear_user_history()
374
- st.rerun()
375
-
376
- st.divider()
377
-
378
- # Language Selection
379
- st.subheader("🌐 Response Language")
380
- selected_language = st.selectbox(
381
- "Select Language",
382
- options=list(SUPPORTED_LANGUAGES.keys()),
383
- index=list(SUPPORTED_LANGUAGES.keys()).index(st.session_state.response_language),
384
- key="language_selector",
385
- help="Choose the language for chatbot responses"
386
- )
387
- if selected_language != st.session_state.response_language:
388
- st.session_state.response_language = selected_language
389
- st.toast(f"🌐 Language changed to {selected_language}")
390
-
391
- st.divider()
392
-
393
- # Export Chat Button
394
- if st.session_state.messages:
395
- st.download_button(
396
- label="📄 Export Chat",
397
- data=export_chat_to_text(),
398
- file_name=f"chat_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
399
- mime="text/plain",
400
- use_container_width=True,
401
- help="Download your chat conversation as a text file"
402
- )
403
-
404
- st.divider()
405
-
406
- # Database Configuration
407
- custom_db_params = render_database_config()
408
-
409
- st.divider()
410
-
411
- # LLM Configuration
412
- st.subheader("🤖 LLM Configuration")
413
-
414
- # Show status of API key
415
- if os.getenv("GROQ_API_KEY"):
416
- st.success("✓ API Key configured")
417
- else:
418
- st.warning("⚠️ GROQ_API_KEY not set in environment")
419
-
420
- st.divider()
421
-
422
- # Initialize Button
423
- if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
424
- with st.spinner("Connecting to database..."):
425
- success = initialize_chatbot(custom_db_params, None, None)
426
- if success:
427
- st.success("✅ Connected!")
428
- st.rerun()
429
-
430
- # Index Button (after initialization)
431
- if st.session_state.initialized:
432
- if st.button("📚 Index Text Data", use_container_width=True):
433
- with st.spinner("Indexing text data..."):
434
- index_data()
435
- st.success("✅ Indexed!")
436
- st.rerun()
437
-
438
- st.divider()
439
-
440
- # Status
441
- st.subheader("📊 Status")
442
- if st.session_state.initialized:
443
- # Show database type
444
- if st.session_state.custom_db_connection:
445
- db_type = st.session_state.custom_db_connection.db_type.value.upper()
446
- else:
447
- db_type = get_db().db_type.value.upper()
448
-
449
- st.success(f"Database: {db_type} ✓")
450
-
451
- try:
452
- schema = get_schema()
453
- st.info(f"Tables: {len(schema.tables)}")
454
- except:
455
- st.warning("Schema not loaded")
456
-
457
- if st.session_state.indexed:
458
- from rag import get_rag_engine
459
- engine = get_rag_engine()
460
- st.info(f"Indexed Docs: {engine.document_count}")
461
- else:
462
- st.warning("Not connected")
463
-
464
- # New Chat
465
- if st.button("➕ New Chat", use_container_width=True, type="secondary"):
466
- if st.session_state.memory:
467
- st.session_state.memory.clear()
468
-
469
- st.session_state.messages = []
470
- st.session_state.session_id = str(uuid.uuid4())
471
-
472
- current_user = st.session_state.get("user_id", "default")
473
-
474
- if st.session_state.custom_db_connection:
475
- st.session_state.memory = create_custom_memory(
476
- st.session_state.session_id,
477
- current_user,
478
- st.session_state.custom_db_connection,
479
- st.session_state.get("llm"),
480
- st.session_state.enable_summarization,
481
- st.session_state.summary_threshold
482
- )
483
- elif st.session_state.initialized:
484
- from memory import create_enhanced_memory
485
- st.session_state.memory = create_enhanced_memory(
486
- st.session_state.session_id,
487
- user_id=current_user,
488
- enable_summarization=st.session_state.enable_summarization,
489
- summary_threshold=st.session_state.summary_threshold
490
- )
491
- if st.session_state.get("llm"):
492
- st.session_state.memory.set_llm_client(st.session_state.llm)
493
-
494
- st.rerun()
495
-
496
- # Disconnect button (when using custom DB)
497
- if st.session_state.initialized and st.session_state.db_source == "custom":
498
- if st.button("🔌 Disconnect", use_container_width=True):
499
- if st.session_state.custom_db_connection:
500
- st.session_state.custom_db_connection.close()
501
- st.session_state.custom_db_connection = None
502
- st.session_state.chatbot = None
503
- st.session_state.initialized = False
504
- st.session_state.indexed = False
505
- st.session_state.memory = None
506
- st.success("Disconnected!")
507
- st.rerun()
508
-
509
- st.divider()
510
-
511
- # Chat History Section
512
- if st.session_state.memory:
513
- st.subheader("🕰️ Chat History")
514
- sessions = st.session_state.memory.get_user_sessions()
515
-
516
- if not sessions:
517
- st.caption("No previous chats found.")
518
- else:
519
- for session in sessions:
520
- # Highlight current session
521
- is_current = session["id"] == st.session_state.session_id
522
- icon = "🟢" if is_current else "💬"
523
-
524
- if st.button(
525
- f"{icon} {session['title']}",
526
- key=f"hist_{session['id']}",
527
- use_container_width=True,
528
- type="secondary" if not is_current else "primary"
529
- ):
530
- if not is_current:
531
- # Load selected session
532
- st.session_state.session_id = session["id"]
533
- st.session_state.memory.session_id = session["id"]
534
- st.session_state.memory.messages = [] # Clear current state local cache
535
-
536
- # Load from DB
537
- msgs = st.session_state.memory.load_session(session["id"])
538
- st.session_state.messages = msgs
539
-
540
- # Re-populate memory object messages list for context
541
- # (We need to convert dicts back to ChatMessage objects implicitly or just rely on reload)
542
- # Actually, we should probably re-init the memory to be safe or manually populate
543
- # Let's manually populate to keep the connection valid
544
- from memory import ChatMessage
545
- st.session_state.memory.messages = [
546
- ChatMessage(
547
- role=m['role'],
548
- content=m['content'],
549
- metadata=m.get('metadata')
550
- ) for m in msgs
551
- ]
552
-
553
- st.rerun()
554
-
555
-
556
- def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
557
- """Initialize the chatbot with either environment or custom database."""
558
- try:
559
- # Get API key
560
- groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
561
- groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
562
-
563
- if not groq_api_key:
564
- st.error("GROQ_API_KEY not configured. Please enter your API key.")
565
- return False
566
-
567
- # Create LLM client
568
- llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
569
-
570
- # Create database connection
571
- if custom_db_params and st.session_state.db_source == "custom":
572
- # Validate custom params
573
- db_type = custom_db_params.get("db_type", "mysql")
574
-
575
- if True:
576
- if not all([custom_db_params.get("host"),
577
- custom_db_params.get("database"),
578
- custom_db_params.get("username")]):
579
- st.error("Please fill in all required database fields.")
580
- return False
581
-
582
- # Create custom config
583
- db_config = create_custom_db_config(**custom_db_params)
584
-
585
- # Create custom connection
586
- custom_connection = DatabaseConnection(db_config)
587
-
588
- # Test connection
589
- success, msg = custom_connection.test_connection()
590
- if not success:
591
- st.error(f"Connection failed: {msg}")
592
- return False
593
-
594
- st.session_state.custom_db_connection = custom_connection
595
- st.session_state.custom_db_config = db_config
596
-
597
- # Override the global db connection for the chatbot
598
- # We need to create a chatbot with this custom connection
599
- from chatbot import DatabaseChatbot
600
- from database.schema_introspector import SchemaIntrospector
601
- from rag import get_rag_engine
602
- from sql import get_sql_generator, get_sql_validator
603
- from router import get_query_router
604
-
605
- chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
606
- chatbot.db = custom_connection
607
- chatbot.introspector = SchemaIntrospector()
608
- chatbot.introspector.db = custom_connection
609
- chatbot.rag_engine = get_rag_engine()
610
- chatbot.sql_generator = get_sql_generator(db_type)
611
- chatbot.sql_validator = get_sql_validator()
612
- chatbot.router = get_query_router()
613
- chatbot.llm_client = llm
614
- chatbot._schema_initialized = False
615
- chatbot._rag_initialized = False
616
-
617
- # Set LLM client
618
- chatbot.set_llm_client(llm)
619
-
620
- # Initialize (introspect schema)
621
- schema = chatbot.introspector.introspect(force_refresh=True)
622
- chatbot.sql_validator.set_allowed_tables(schema.table_names)
623
- chatbot._schema_initialized = True
624
-
625
- st.session_state.chatbot = chatbot
626
-
627
- else:
628
- # Use environment-based connection (existing flow)
629
- chatbot = create_chatbot(llm)
630
- chatbot.set_llm_client(llm)
631
-
632
- success, msg = chatbot.initialize()
633
- if not success:
634
- st.error(f"Initialization failed: {msg}")
635
- return False
636
-
637
- st.session_state.chatbot = chatbot
638
- st.session_state.custom_db_connection = None
639
-
640
- st.session_state.llm = llm
641
- st.session_state.initialized = True
642
- st.session_state.indexed = False # Reset index status on new connection
643
-
644
- # Clear RAG index to ensure no data from previous DB connection persists
645
- if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
646
- chatbot.rag_engine.clear_index()
647
-
648
- # Create memory with appropriate connection
649
- db_conn = st.session_state.custom_db_connection or get_db()
650
- st.session_state.memory = create_custom_memory(
651
- st.session_state.session_id,
652
- st.session_state.user_id,
653
- db_conn,
654
- llm,
655
- st.session_state.enable_summarization,
656
- st.session_state.summary_threshold
657
- )
658
-
659
- return True
660
-
661
- except Exception as e:
662
- st.error(f"Error: {str(e)}")
663
- import traceback
664
- st.error(traceback.format_exc())
665
- return False
666
-
667
-
668
- def index_data():
669
- """Index text data from the database."""
670
- if st.session_state.chatbot:
671
- progress = st.progress(0)
672
- status = st.empty()
673
-
674
- # Get schema from the correct introspector
675
- schema = st.session_state.chatbot.introspector.introspect()
676
- total_tables = len(schema.tables)
677
- indexed = 0
678
-
679
- def progress_callback(table_name, docs):
680
- nonlocal indexed
681
- indexed += 1
682
- progress.progress(indexed / total_tables)
683
- status.text(f"Indexed {table_name}: {docs} documents")
684
-
685
- total_docs = st.session_state.chatbot.index_text_data(progress_callback)
686
- st.session_state.indexed = True
687
- status.text(f"Total: {total_docs} documents indexed")
688
-
689
-
690
- def render_schema_explorer():
691
- """Render schema explorer in an expander."""
692
- if not st.session_state.initialized:
693
- return
694
-
695
- with st.expander("📋 Database Schema", expanded=False):
696
- try:
697
- schema = st.session_state.chatbot.introspector.introspect()
698
-
699
- tab_list, tab_erd = st.tabs(["📋 Table List", "🕸️ Schema Diagram"])
700
-
701
- with tab_list:
702
- st.markdown("Uncheck tables to exclude them from the chat context.")
703
-
704
- for table_name, table_info in schema.tables.items():
705
- col1, col2 = st.columns([0.05, 0.95])
706
-
707
- with col1:
708
- is_active = table_name not in st.session_state.ignored_tables
709
- active = st.checkbox(
710
- "Use",
711
- value=is_active,
712
- key=f"use_{table_name}",
713
- label_visibility="collapsed",
714
- help=f"Include {table_name} in chat analysis"
715
- )
716
-
717
- if not active:
718
- st.session_state.ignored_tables.add(table_name)
719
- else:
720
- st.session_state.ignored_tables.discard(table_name)
721
-
722
- with col2:
723
- with st.container():
724
- st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
725
-
726
- cols = []
727
- for col in table_info.columns:
728
- pk = "🔑" if col.is_primary_key else ""
729
- txt = "📝" if col.is_text_type else ""
730
- cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
731
-
732
- st.caption(" | ".join(cols))
733
- st.divider()
734
-
735
- with tab_erd:
736
- if len(schema.tables) > 50:
737
- st.warning("⚠️ Too many tables to visualize effectively (limit: 50).")
738
- else:
739
- try:
740
- # Build Graphviz DOT string
741
- dot = ['digraph Database {']
742
- dot.append(' rankdir=LR;')
743
- dot.append(' node [shape=box, style="filled,rounded", fillcolor="#f0f2f6", fontname="Arial", fontsize=10];')
744
- dot.append(' edge [fontname="Arial", fontsize=9, color="#666666"];')
745
-
746
- # Add nodes (tables)
747
- for table_name in schema.tables:
748
- if table_name not in st.session_state.ignored_tables:
749
- dot.append(f' "{table_name}" [label="{table_name}", fillcolor="#e1effe", color="#1e40af"];')
750
- else:
751
- dot.append(f' "{table_name}" [label="{table_name} (ignored)", fillcolor="#f3f4f6", color="#9ca3af", fontcolor="#9ca3af"];')
752
-
753
- # Add edges (relationships)
754
- has_edges = False
755
- for table_name, table_info in schema.tables.items():
756
- for col_name, ref_str in table_info.foreign_keys.items():
757
- # ref_str format: "referenced_table.referenced_column"
758
- if "." in ref_str:
759
- ref_table = ref_str.split(".")[0]
760
- # specific_col = ref_str.split(".")[1]
761
-
762
- # Only draw if both tables exist in our schema list
763
- if ref_table in schema.tables:
764
- dot.append(f' "{table_name}" -> "{ref_table}" [label="{col_name}"];')
765
- has_edges = True
766
-
767
- dot.append('}')
768
- graph_code = "\n".join(dot)
769
- st.graphviz_chart(graph_code, width="stretch")
770
-
771
- if not has_edges:
772
- st.info("No foreign key relationships detected in the schema metadata.")
773
-
774
- except Exception as e:
775
- st.error(f"Could not render diagram: {e}")
776
-
777
- except Exception as e:
778
- st.error(f"Error loading schema: {e}")
779
-
780
-
781
- def render_chat_interface():
782
- """Render the main chat interface."""
783
- st.title("🤖 OnceDataBot")
784
- st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL • Powered by Groq (FREE!)")
785
-
786
- # Schema explorer
787
- render_schema_explorer()
788
-
789
- # Chat container
790
- chat_container = st.container()
791
-
792
- with chat_container:
793
- # Display messages
794
- for i, msg in enumerate(st.session_state.messages):
795
- with st.chat_message(msg["role"]):
796
- # Create columns for message and favorite button
797
- msg_col, fav_col = st.columns([0.95, 0.05])
798
-
799
- with msg_col:
800
- st.markdown(msg["content"])
801
-
802
- with fav_col:
803
- # Favorite button for assistant messages
804
- if msg["role"] == "assistant":
805
- is_favorited = i in st.session_state.favorites
806
- if st.button(
807
- "⭐" if is_favorited else "☆",
808
- key=f"fav_{i}",
809
- help="Click to favorite/unfavorite this response"
810
- ):
811
- if is_favorited:
812
- st.session_state.favorites.remove(i)
813
- else:
814
- st.session_state.favorites.append(i)
815
- st.rerun()
816
-
817
- # Show metadata for assistant messages
818
- if msg["role"] == "assistant" and "metadata" in msg:
819
- meta = msg["metadata"]
820
-
821
- # Show token usage in a dropdown expander
822
- if "token_usage" in meta:
823
- usage = meta["token_usage"]
824
- total = usage.get('total', 0)
825
-
826
- with st.expander(f"📊 Token Usage ({total:,} total)", expanded=False):
827
- # Create styled token usage boxes using columns
828
- st.markdown("""
829
- <style>
830
- .token-box {
831
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
832
- border-radius: 12px;
833
- padding: 12px 16px;
834
- color: white;
835
- text-align: center;
836
- box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
837
- margin: 4px 0;
838
- }
839
- .token-box-input {
840
- background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
841
- box-shadow: 0 4px 15px rgba(17, 153, 142, 0.3);
842
- }
843
- .token-box-output {
844
- background: linear-gradient(135deg, #ee0979 0%, #ff6a00 100%);
845
- box-shadow: 0 4px 15px rgba(238, 9, 121, 0.3);
846
- }
847
- .token-box-total {
848
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
849
- box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
850
- }
851
- .token-label {
852
- font-size: 11px;
853
- text-transform: uppercase;
854
- letter-spacing: 1px;
855
- opacity: 0.9;
856
- margin-bottom: 4px;
857
- }
858
- .token-value {
859
- font-size: 20px;
860
- font-weight: 700;
861
- }
862
- </style>
863
- """, unsafe_allow_html=True)
864
-
865
- col1, col2, col3 = st.columns(3)
866
-
867
- with col1:
868
- st.markdown(f"""
869
- <div class="token-box token-box-input">
870
- <div class="token-label">📥 Input Tokens</div>
871
- <div class="token-value">{usage.get('input', 0):,}</div>
872
- </div>
873
- """, unsafe_allow_html=True)
874
-
875
- with col2:
876
- st.markdown(f"""
877
- <div class="token-box token-box-output">
878
- <div class="token-label">📤 Output Tokens</div>
879
- <div class="token-value">{usage.get('output', 0):,}</div>
880
- </div>
881
- """, unsafe_allow_html=True)
882
-
883
- with col3:
884
- st.markdown(f"""
885
- <div class="token-box token-box-total">
886
- <div class="token-label">📊 Total Tokens</div>
887
- <div class="token-value">{usage.get('total', 0):,}</div>
888
- </div>
889
- """, unsafe_allow_html=True)
890
-
891
- if meta.get("query_type"):
892
- # Show query type and execution time on same line
893
- info_text = f"Query type: {meta['query_type']}"
894
- if meta.get("execution_time"):
895
- info_text += f" ⏱️ {meta['execution_time']:.2f}s"
896
- st.caption(info_text)
897
-
898
- # SQL Query expander
899
- if meta.get("sql_query"):
900
- with st.expander("🛠️ SQL Query & Details"):
901
- st.code(meta["sql_query"], language="sql")
902
-
903
- # Visualizations and CSV export
904
- if meta.get("sql_results"):
905
- # Only render viz if we have results
906
- render_visualization(meta["sql_results"], f"viz_{i}")
907
-
908
- # CSV Export button
909
- csv_data = export_results_to_csv(meta["sql_results"])
910
- if csv_data:
911
- st.download_button(
912
- label="📊 Export to CSV",
913
- data=csv_data,
914
- file_name=f"query_results_{i}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
915
- mime="text/csv",
916
- key=f"csv_export_{i}",
917
- help="Download query results as CSV file"
918
- )
919
-
920
- # Chat input
921
- if prompt := st.chat_input("Ask about your data..."):
922
- if not st.session_state.initialized:
923
- st.error("Please connect to a database first!")
924
- return
925
-
926
- # Add user message
927
- st.session_state.messages.append({"role": "user", "content": prompt})
928
-
929
- # Calculate memory context for display? No, just render user msg
930
- with st.chat_message("user"):
931
- st.markdown(prompt)
932
-
933
- # Get response
934
- with st.spinner("Thinking..."):
935
- try:
936
- # Add memory interaction
937
- if st.session_state.memory:
938
- st.session_state.memory.add_message("user", prompt)
939
-
940
- # Track execution time
941
- start_time = time.time()
942
-
943
- response = st.session_state.chatbot.chat(
944
- prompt,
945
- st.session_state.memory,
946
- ignored_tables=list(st.session_state.ignored_tables),
947
- language=st.session_state.response_language
948
- )
949
-
950
- execution_time = time.time() - start_time
951
-
952
-
953
-
954
- # Create metadata dict
955
- metadata = {
956
- "query_type": response.query_type,
957
- "sql_query": response.sql_query,
958
- "sql_results": response.sql_results,
959
- "token_usage": response.token_usage,
960
- "execution_time": execution_time
961
- }
962
-
963
- # Save to session state
964
- st.session_state.messages.append({
965
- "role": "assistant",
966
- "content": response.answer,
967
- "metadata": metadata
968
- })
969
-
970
- # Set flag to auto-read the latest response
971
- st.session_state.auto_read_latest = True
972
-
973
- # Save to active memory
974
- if st.session_state.memory:
975
- st.session_state.memory.add_message("assistant", response.answer)
976
-
977
- st.rerun()
978
-
979
- except Exception as e:
980
- st.error(f"An error occurred: {e}")
981
- import traceback
982
- st.error(traceback.format_exc())
983
-
984
-
985
- def main():
986
- """Main application entry point."""
987
- init_session_state()
988
-
989
- # Auto-connect to environment database on first load
990
- if "auto_connect_attempted" not in st.session_state:
991
- st.session_state.auto_connect_attempted = True
992
- if st.session_state.db_source == "environment":
993
- success = initialize_chatbot()
994
- if success:
995
- st.toast("✅ Auto-connected to database!")
996
-
997
- render_sidebar()
998
- render_chat_interface()
999
-
1000
-
1001
- if __name__ == "__main__":
1002
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Schema-Agnostic Database Chatbot - Streamlit Application
3
+
4
+ A production-grade chatbot that connects to ANY database
5
+ (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
+ through RAG and Text-to-SQL.
7
+
8
+ Uses Groq for FREE LLM inference!
9
+ """
10
+
11
+ import os
12
+ from pathlib import Path
13
+
14
+ # Load .env FIRST before any other imports
15
+ from dotenv import load_dotenv
16
+ load_dotenv(Path(__file__).parent / ".env")
17
+
18
+ import streamlit as st
19
+ import uuid
20
+ import time
21
+ import io
22
+ import csv
23
+ import base64
24
+ import pandas as pd
25
+ from datetime import datetime
26
+
27
+ # Page config must be first
28
+ st.set_page_config(
29
+ page_title="OnceDataBot",
30
+ page_icon="🤖",
31
+ layout="wide",
32
+ initial_sidebar_state="expanded"
33
+ )
34
+
35
+ # Imports
36
+ from config import config, DatabaseConfig, DatabaseType
37
+ from database import get_db, get_schema, get_introspector
38
+ from database.connection import DatabaseConnection
39
+ from llm import create_llm_client
40
+ from chatbot import create_chatbot, DatabaseChatbot
41
+ from memory import ChatMemory, EnhancedChatMemory
42
+ from viz_utils import render_visualization
43
+
44
+
45
+
46
+
47
+
48
+
49
+ # Groq models (all FREE!)
50
+ GROQ_MODELS = [
51
+ "llama-3.3-70b-versatile",
52
+ "llama-3.1-8b-instant",
53
+ "mixtral-8x7b-32768",
54
+ "gemma2-9b-it"
55
+ ]
56
+
57
+ # Database types
58
+ DB_TYPES = {
59
+ "MySQL": "mysql",
60
+ "PostgreSQL": "postgresql",
61
+ "SQLite": "sqlite"
62
+ }
63
+
64
+ # Supported languages for multi-language responses
65
+ SUPPORTED_LANGUAGES = {
66
+ "English": "en",
67
+ "हिन्दी (Hindi)": "hi",
68
+ "Español (Spanish)": "es",
69
+ "Français (French)": "fr",
70
+ "Deutsch (German)": "de",
71
+ "中文 (Chinese)": "zh",
72
+ "日本語 (Japanese)": "ja",
73
+ "한국어 (Korean)": "ko",
74
+ "Português (Portuguese)": "pt",
75
+ "العربية (Arabic)": "ar",
76
+ "Русский (Russian)": "ru",
77
+ "Italiano (Italian)": "it",
78
+ "Nederlands (Dutch)": "nl",
79
+ "தமிழ் (Tamil)": "ta",
80
+ "తెలుగు (Telugu)": "te",
81
+ "मराठी (Marathi)": "mr",
82
+ "বাংলা (Bengali)": "bn",
83
+ "ગુજરાতી (Gujarati)": "gu"
84
+ }
85
+
86
+
87
+ def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
88
+ """Create a custom database configuration from user input."""
89
+ return DatabaseConfig(
90
+ db_type=DatabaseType(db_type),
91
+ host=kwargs.get("host", ""),
92
+ port=kwargs.get("port", 3306 if db_type == "mysql" else 5432),
93
+ database=kwargs.get("database", ""),
94
+ username=kwargs.get("username", ""),
95
+ password=kwargs.get("password", ""),
96
+ ssl_ca=kwargs.get("ssl_ca", None)
97
+ )
98
+
99
+
100
+ def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
101
+ enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
102
+ """Create enhanced memory with a custom database connection."""
103
+ return EnhancedChatMemory(
104
+ session_id=session_id,
105
+ user_id=user_id,
106
+ max_messages=20,
107
+ db_connection=db_connection,
108
+ llm_client=llm_client,
109
+ enable_summarization=enable_summarization,
110
+ summary_threshold=summary_threshold
111
+ )
112
+
113
+
114
+ def init_session_state():
115
+ """Initialize Streamlit session state."""
116
+ if "session_id" not in st.session_state:
117
+ st.session_state.session_id = str(uuid.uuid4())
118
+
119
+ if "messages" not in st.session_state:
120
+ st.session_state.messages = []
121
+
122
+ if "chatbot" not in st.session_state:
123
+ st.session_state.chatbot = None
124
+
125
+ if "initialized" not in st.session_state:
126
+ st.session_state.initialized = False
127
+
128
+ if "user_id" not in st.session_state:
129
+ st.session_state.user_id = "default"
130
+
131
+ if "enable_summarization" not in st.session_state:
132
+ st.session_state.enable_summarization = True
133
+
134
+ if "summary_threshold" not in st.session_state:
135
+ st.session_state.summary_threshold = 10
136
+
137
+ if "memory" not in st.session_state:
138
+ st.session_state.memory = None
139
+
140
+ if "indexed" not in st.session_state:
141
+ st.session_state.indexed = False
142
+
143
+ if "db_source" not in st.session_state:
144
+ st.session_state.db_source = "environment" # "environment" or "custom"
145
+
146
+ if "custom_db_config" not in st.session_state:
147
+ st.session_state.custom_db_config = None
148
+
149
+ if "custom_db_connection" not in st.session_state:
150
+ st.session_state.custom_db_connection = None
151
+
152
+ if "ignored_tables" not in st.session_state:
153
+ st.session_state.ignored_tables = set()
154
+
155
+ if "response_language" not in st.session_state:
156
+ st.session_state.response_language = "English"
157
+
158
+ if "favorites" not in st.session_state:
159
+ st.session_state.favorites = [] # List of message indices that are favorited
160
+
161
+
162
+ def export_results_to_csv(results: list) -> str:
163
+ """Convert SQL results to CSV format and return as downloadable string."""
164
+ if not results:
165
+ return ""
166
+
167
+ output = io.StringIO()
168
+ writer = csv.DictWriter(output, fieldnames=results[0].keys())
169
+ writer.writeheader()
170
+ writer.writerows(results)
171
+ return output.getvalue()
172
+
173
+
174
+ def export_chat_to_text() -> str:
175
+ """Export chat messages to text format."""
176
+ if not st.session_state.messages:
177
+ return "No messages to export."
178
+
179
+ lines = []
180
+ lines.append("=" * 50)
181
+ lines.append(f"OnceDataBot Chat Export")
182
+ lines.append(f"Exported: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
183
+ lines.append(f"User: {st.session_state.user_id}")
184
+ lines.append("=" * 50)
185
+ lines.append("")
186
+
187
+ for i, msg in enumerate(st.session_state.messages):
188
+ role = "🧑 User" if msg["role"] == "user" else "🤖 Assistant"
189
+ is_favorited = "⭐ " if i in st.session_state.favorites else ""
190
+ lines.append(f"{is_favorited}{role}:")
191
+ lines.append(msg["content"])
192
+
193
+ if msg["role"] == "assistant" and "metadata" in msg:
194
+ meta = msg["metadata"]
195
+ if meta.get("sql_query"):
196
+ lines.append(f"\n📝 SQL Query: {meta['sql_query']}")
197
+ if meta.get("query_type"):
198
+ lines.append(f"📌 Query Type: {meta['query_type']}")
199
+ if meta.get("execution_time"):
200
+ lines.append(f"⏱️ Execution Time: {meta['execution_time']:.2f}s")
201
+
202
+ lines.append("-" * 40)
203
+ lines.append("")
204
+
205
+ return "\n".join(lines)
206
+
207
+
208
+ def render_copy_button(text: str, key: str):
209
+ """Render a copy to clipboard button using Streamlit."""
210
+ # Using a workaround with st.code which has built-in copy
211
+ st.code(text, language="sql")
212
+
213
+
214
+ def render_database_config():
215
+ """Render database configuration section in sidebar."""
216
+ st.subheader("🗄️ Database Configuration")
217
+
218
+ # Database source selection
219
+ db_source = st.radio(
220
+ "Database Source",
221
+ options=["Use Environment Variables", "Custom Database"],
222
+ index=0 if st.session_state.db_source == "environment" else 1,
223
+ key="db_source_radio",
224
+ help="Choose to use .env settings or enter custom credentials"
225
+ )
226
+
227
+ st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
228
+
229
+ if st.session_state.db_source == "environment":
230
+ # Show current environment config
231
+ current_db_type = config.database.db_type.value.upper()
232
+ st.info(f"📌 Using {current_db_type} from environment")
233
+ st.caption(f"Host: {config.database.host}")
234
+ return None
235
+
236
+ else:
237
+ # Custom database configuration
238
+ st.markdown("##### Enter Database Credentials")
239
+
240
+ # Database type selector
241
+ db_type_label = st.selectbox(
242
+ "Database Type",
243
+ options=list(DB_TYPES.keys()),
244
+ index=0,
245
+ key="custom_db_type"
246
+ )
247
+ db_type = DB_TYPES[db_type_label]
248
+
249
+ if db_type == "sqlite":
250
+ # SQLite only needs a file path
251
+ database = st.text_input(
252
+ "SQLite Database File",
253
+ value="ingested_data.db",
254
+ key="db_sqlite_path",
255
+ help="Path to the .db file (will be created if it doesn't exist)"
256
+ )
257
+ return {
258
+ "db_type": db_type,
259
+ "database": database
260
+ }
261
+
262
+ else: # MySQL or PostgreSQL
263
+ # MySQL or PostgreSQL
264
+ col1, col2 = st.columns([3, 1])
265
+ with col1:
266
+ host = st.text_input(
267
+ "Host",
268
+ value="",
269
+ key="db_host_input",
270
+ placeholder="your-database-host.com"
271
+ )
272
+ with col2:
273
+ default_port = 3306 if db_type == "mysql" else 5432
274
+ port = st.number_input(
275
+ "Port",
276
+ value=default_port,
277
+ min_value=1,
278
+ max_value=65535,
279
+ key="db_port_input"
280
+ )
281
+
282
+ database = st.text_input(
283
+ "Database Name",
284
+ value="",
285
+ key="db_name_input",
286
+ placeholder="your_database"
287
+ )
288
+
289
+ username = st.text_input(
290
+ "Username",
291
+ value="",
292
+ key="db_user_input",
293
+ placeholder="your_username"
294
+ )
295
+
296
+ password = st.text_input(
297
+ "Password",
298
+ value="",
299
+ type="password",
300
+ key="db_pass_input"
301
+ )
302
+
303
+ # Optional SSL
304
+ with st.expander("🔒 SSL Settings (Optional)"):
305
+ ssl_ca = st.text_input(
306
+ "SSL CA Certificate Path",
307
+ value="",
308
+ key="ssl_ca_input",
309
+ help="Path to SSL CA certificate file (for cloud databases like Aiven)"
310
+ )
311
+
312
+ return {
313
+ "db_type": db_type,
314
+ "host": host,
315
+ "port": int(port),
316
+ "database": database,
317
+ "username": username,
318
+ "password": password,
319
+ "ssl_ca": ssl_ca if ssl_ca else None
320
+ }
321
+
322
+
323
+ def render_sidebar():
324
+ """Render the configuration sidebar."""
325
+ with st.sidebar:
326
+ st.title("⚙️ Settings")
327
+
328
+ # Session Dashboard
329
+ if st.session_state.messages:
330
+ st.markdown("### 📊 Session Stats")
331
+
332
+ # Calculate stats
333
+ total_msgs = len(st.session_state.messages)
334
+ assistant_msgs = [m for m in st.session_state.messages if m.get("role") == "assistant"]
335
+ sql_queries = sum(1 for m in assistant_msgs if m.get("metadata", {}).get("sql_query"))
336
+
337
+ total_tokens = 0
338
+ exec_times = []
339
+ for m in assistant_msgs:
340
+ meta = m.get("metadata", {})
341
+ total_tokens += meta.get("token_usage", {}).get("total", 0)
342
+ if meta.get("execution_time"):
343
+ exec_times.append(meta["execution_time"])
344
+
345
+ avg_time = sum(exec_times) / len(exec_times) if exec_times else 0
346
+
347
+ col_s1, col_s2 = st.columns(2)
348
+ col_s1.metric("Queries", sql_queries)
349
+ col_s2.metric("Tokens", f"{total_tokens:,}")
350
+ st.caption(f"⏱️ Avg Time: {avg_time:.2f}s | 💬 Msgs: {total_msgs}")
351
+ st.divider()
352
+
353
+
354
+
355
+ # User Profile
356
+ st.subheader("👤 User Profile")
357
+ user_id = st.text_input(
358
+ "User ID / Name",
359
+ value=st.session_state.get("user_id", "default"),
360
+ key="user_id_input",
361
+ help="Your unique ID for private memory storage"
362
+ )
363
+ if user_id != st.session_state.get("user_id"):
364
+ st.session_state.user_id = user_id
365
+ st.session_state.session_id = str(uuid.uuid4())
366
+ st.session_state.messages = []
367
+
368
+ # Recreate memory for new user
369
+ if st.session_state.custom_db_connection:
370
+ st.session_state.memory = create_custom_memory(
371
+ st.session_state.session_id,
372
+ user_id,
373
+ st.session_state.custom_db_connection,
374
+ st.session_state.get("llm"),
375
+ st.session_state.enable_summarization,
376
+ st.session_state.summary_threshold
377
+ )
378
+ elif st.session_state.initialized:
379
+ from memory import create_enhanced_memory
380
+ st.session_state.memory = create_enhanced_memory(
381
+ st.session_state.session_id,
382
+ user_id=user_id,
383
+ enable_summarization=st.session_state.enable_summarization,
384
+ summary_threshold=st.session_state.summary_threshold
385
+ )
386
+
387
+ if st.session_state.memory:
388
+ st.session_state.memory.clear_user_history()
389
+ st.rerun()
390
+
391
+ st.divider()
392
+
393
+ # Language Selection
394
+ st.subheader("🌐 Response Language")
395
+ selected_language = st.selectbox(
396
+ "Select Language",
397
+ options=list(SUPPORTED_LANGUAGES.keys()),
398
+ index=list(SUPPORTED_LANGUAGES.keys()).index(st.session_state.response_language),
399
+ key="language_selector",
400
+ help="Choose the language for chatbot responses"
401
+ )
402
+ if selected_language != st.session_state.response_language:
403
+ st.session_state.response_language = selected_language
404
+ st.toast(f"🌐 Language changed to {selected_language}")
405
+
406
+ st.divider()
407
+
408
+ if st.session_state.messages:
409
+ st.download_button(
410
+ label="📄 Export Chat",
411
+ data=export_chat_to_text(),
412
+ file_name=f"chat_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
413
+ mime="text/plain",
414
+ use_container_width=True,
415
+ help="Download your chat conversation as a text file"
416
+ )
417
+
418
+ st.divider()
419
+
420
+ # CSV Ingestion Section
421
+ st.subheader("📥 Ingest CSV Data")
422
+ uploaded_files = st.file_uploader(
423
+ "Upload CSV(s) to create database",
424
+ type=["csv"],
425
+ accept_multiple_files=True,
426
+ help="Your CSVs will be converted to tables in a local SQLite database"
427
+ )
428
+
429
+ if uploaded_files:
430
+ if st.button("🚀 Upload & Initialize", use_container_width=True):
431
+ with st.spinner("Processing CSVs..."):
432
+ success_count = 0
433
+ table_names = []
434
+ for uploaded_file in uploaded_files:
435
+ success, name, rows = ingest_csv(uploaded_file)
436
+ if success:
437
+ success_count += 1
438
+ table_names.append(name)
439
+ else:
440
+ st.error(f"Failed to ingest {uploaded_file.name}: {name}")
441
+
442
+ if success_count > 0:
443
+ st.success(f"Successfully ingested {success_count} file(s) as tables: {', '.join(table_names)}")
444
+
445
+ # Now initialize chatbot with this SQLite DB
446
+ sqlite_params = {
447
+ "db_type": "sqlite",
448
+ "database": "ingested_data.db"
449
+ }
450
+ # Temporarily set db_source to custom for initialization
451
+ old_source = st.session_state.db_source
452
+ st.session_state.db_source = "custom"
453
+ init_success = initialize_chatbot(sqlite_params, None, None)
454
+ if not init_success:
455
+ st.session_state.db_source = old_source
456
+ else:
457
+ st.rerun()
458
+
459
+ st.divider()
460
+
461
+ # Database Configuration
462
+ custom_db_params = render_database_config()
463
+
464
+ st.divider()
465
+
466
+ # LLM Configuration
467
+ st.subheader("🤖 LLM Configuration")
468
+
469
+ # Show status of API key
470
+ if os.getenv("GROQ_API_KEY"):
471
+ st.success("✓ API Key configured")
472
+ else:
473
+ st.warning("⚠️ GROQ_API_KEY not set in environment")
474
+
475
+ st.divider()
476
+
477
+ # Initialize Button
478
+ if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
479
+ with st.spinner("Connecting to database..."):
480
+ success = initialize_chatbot(custom_db_params, None, None)
481
+ if success:
482
+ st.success("✅ Connected!")
483
+ st.rerun()
484
+
485
+ # Index Button (after initialization)
486
+ if st.session_state.initialized:
487
+ if st.button("📚 Index Text Data", use_container_width=True):
488
+ with st.spinner("Indexing text data..."):
489
+ index_data()
490
+ st.success("✅ Indexed!")
491
+ st.rerun()
492
+
493
+ st.divider()
494
+
495
+ # Status
496
+ st.subheader("📊 Status")
497
+ if st.session_state.initialized:
498
+ # Show database type
499
+ if st.session_state.custom_db_connection:
500
+ db_type = st.session_state.custom_db_connection.db_type.value.upper()
501
+ else:
502
+ db_type = get_db().db_type.value.upper()
503
+
504
+ st.success(f"Database: {db_type} ✓")
505
+
506
+ try:
507
+ schema = get_schema()
508
+ st.info(f"Tables: {len(schema.tables)}")
509
+ except:
510
+ st.warning("Schema not loaded")
511
+
512
+ if st.session_state.indexed:
513
+ from rag import get_rag_engine
514
+ engine = get_rag_engine()
515
+ st.info(f"Indexed Docs: {engine.document_count}")
516
+ else:
517
+ st.warning("Not connected")
518
+
519
+ # New Chat
520
+ if st.button("➕ New Chat", use_container_width=True, type="secondary"):
521
+ if st.session_state.memory:
522
+ st.session_state.memory.clear()
523
+
524
+ st.session_state.messages = []
525
+ st.session_state.session_id = str(uuid.uuid4())
526
+
527
+ current_user = st.session_state.get("user_id", "default")
528
+
529
+ if st.session_state.custom_db_connection:
530
+ st.session_state.memory = create_custom_memory(
531
+ st.session_state.session_id,
532
+ current_user,
533
+ st.session_state.custom_db_connection,
534
+ st.session_state.get("llm"),
535
+ st.session_state.enable_summarization,
536
+ st.session_state.summary_threshold
537
+ )
538
+ elif st.session_state.initialized:
539
+ from memory import create_enhanced_memory
540
+ st.session_state.memory = create_enhanced_memory(
541
+ st.session_state.session_id,
542
+ user_id=current_user,
543
+ enable_summarization=st.session_state.enable_summarization,
544
+ summary_threshold=st.session_state.summary_threshold
545
+ )
546
+ if st.session_state.get("llm"):
547
+ st.session_state.memory.set_llm_client(st.session_state.llm)
548
+
549
+ st.rerun()
550
+
551
+ # Disconnect button (when using custom DB)
552
+ if st.session_state.initialized and st.session_state.db_source == "custom":
553
+ if st.button("🔌 Disconnect", use_container_width=True):
554
+ if st.session_state.custom_db_connection:
555
+ st.session_state.custom_db_connection.close()
556
+ st.session_state.custom_db_connection = None
557
+ st.session_state.chatbot = None
558
+ st.session_state.initialized = False
559
+ st.session_state.indexed = False
560
+ st.session_state.memory = None
561
+ st.success("Disconnected!")
562
+ st.rerun()
563
+
564
+ st.divider()
565
+
566
+ # Chat History Section
567
+ if st.session_state.memory:
568
+ st.subheader("🕰️ Chat History")
569
+ sessions = st.session_state.memory.get_user_sessions()
570
+
571
+ if not sessions:
572
+ st.caption("No previous chats found.")
573
+ else:
574
+ for session in sessions:
575
+ # Highlight current session
576
+ is_current = session["id"] == st.session_state.session_id
577
+ icon = "🟢" if is_current else "💬"
578
+
579
+ if st.button(
580
+ f"{icon} {session['title']}",
581
+ key=f"hist_{session['id']}",
582
+ use_container_width=True,
583
+ type="secondary" if not is_current else "primary"
584
+ ):
585
+ if not is_current:
586
+ # Load selected session
587
+ st.session_state.session_id = session["id"]
588
+ st.session_state.memory.session_id = session["id"]
589
+ st.session_state.memory.messages = [] # Clear current state local cache
590
+
591
+ # Load from DB
592
+ msgs = st.session_state.memory.load_session(session["id"])
593
+ st.session_state.messages = msgs
594
+
595
+ # Re-populate memory object messages list for context
596
+ # (We need to convert dicts back to ChatMessage objects implicitly or just rely on reload)
597
+ # Actually, we should probably re-init the memory to be safe or manually populate
598
+ # Let's manually populate to keep the connection valid
599
+ from memory import ChatMessage
600
+ st.session_state.memory.messages = [
601
+ ChatMessage(
602
+ role=m['role'],
603
+ content=m['content'],
604
+ metadata=m.get('metadata')
605
+ ) for m in msgs
606
+ ]
607
+
608
+ st.rerun()
609
+
610
+
611
+ def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
612
+ """Initialize the chatbot with either environment or custom database."""
613
+ try:
614
+ # Get API key
615
+ groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
616
+ groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
617
+
618
+ if not groq_api_key:
619
+ st.error("GROQ_API_KEY not configured. Please enter your API key.")
620
+ return False
621
+
622
+ # Create LLM client
623
+ llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
624
+
625
+ # Create database connection
626
+ if custom_db_params and st.session_state.db_source == "custom":
627
+ # Validate custom params
628
+ db_type = custom_db_params.get("db_type", "mysql")
629
+
630
+ if db_type != "sqlite":
631
+ if not all([custom_db_params.get("host"),
632
+ custom_db_params.get("database"),
633
+ custom_db_params.get("username")]):
634
+ st.error("Please fill in all required database fields.")
635
+ return False
636
+ else:
637
+ if not custom_db_params.get("database"):
638
+ st.error("Please specify a SQLite database file path.")
639
+ return False
640
+
641
+ # Create custom config
642
+ db_config = create_custom_db_config(**custom_db_params)
643
+
644
+ # Create custom connection
645
+ custom_connection = DatabaseConnection(db_config)
646
+
647
+ # Test connection
648
+ success, msg = custom_connection.test_connection()
649
+ if not success:
650
+ st.error(f"Connection failed: {msg}")
651
+ return False
652
+
653
+ st.session_state.custom_db_connection = custom_connection
654
+ st.session_state.custom_db_config = db_config
655
+
656
+ # Override the global db connection for the chatbot
657
+ # We need to create a chatbot with this custom connection
658
+ from chatbot import DatabaseChatbot
659
+ from database.schema_introspector import SchemaIntrospector
660
+ from rag import get_rag_engine
661
+ from sql import get_sql_generator, get_sql_validator
662
+ from router import get_query_router
663
+
664
+ chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
665
+ chatbot.db = custom_connection
666
+ chatbot.introspector = SchemaIntrospector()
667
+ chatbot.introspector.db = custom_connection
668
+ chatbot.rag_engine = get_rag_engine()
669
+ chatbot.sql_generator = get_sql_generator(db_type)
670
+ chatbot.sql_validator = get_sql_validator()
671
+ chatbot.router = get_query_router()
672
+ chatbot.llm_client = llm
673
+ chatbot._schema_initialized = False
674
+ chatbot._rag_initialized = False
675
+
676
+ # Set LLM client
677
+ chatbot.set_llm_client(llm)
678
+
679
+ # Initialize (introspect schema)
680
+ schema = chatbot.introspector.introspect(force_refresh=True)
681
+ chatbot.sql_validator.set_allowed_tables(schema.table_names)
682
+ chatbot._schema_initialized = True
683
+
684
+ st.session_state.chatbot = chatbot
685
+
686
+ else:
687
+ # Use environment-based connection (existing flow)
688
+ chatbot = create_chatbot(llm)
689
+ chatbot.set_llm_client(llm)
690
+
691
+ success, msg = chatbot.initialize()
692
+ if not success:
693
+ st.error(f"Initialization failed: {msg}")
694
+ return False
695
+
696
+ st.session_state.chatbot = chatbot
697
+ st.session_state.custom_db_connection = None
698
+
699
+ st.session_state.llm = llm
700
+ st.session_state.initialized = True
701
+ st.session_state.indexed = False # Reset index status on new connection
702
+
703
+ # Clear RAG index to ensure no data from previous DB connection persists
704
+ if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
705
+ chatbot.rag_engine.clear_index()
706
+
707
+ # Create memory with appropriate connection
708
+ db_conn = st.session_state.custom_db_connection or get_db()
709
+ st.session_state.memory = create_custom_memory(
710
+ st.session_state.session_id,
711
+ st.session_state.user_id,
712
+ db_conn,
713
+ llm,
714
+ st.session_state.enable_summarization,
715
+ st.session_state.summary_threshold
716
+ )
717
+
718
+ return True
719
+
720
+ except Exception as e:
721
+ st.error(f"Error: {str(e)}")
722
+ import traceback
723
+ st.error(traceback.format_exc())
724
+ return False
725
+
726
+
727
+ def ingest_csv(uploaded_file):
728
+ """Ingest a CSV file into a SQLite database."""
729
+ from sqlalchemy import create_engine
730
+
731
+ try:
732
+ # 1. Read CSV
733
+ # Reset file pointer to beginning in case it was read before
734
+ uploaded_file.seek(0)
735
+ df = pd.read_csv(uploaded_file)
736
+
737
+ # 2. Clean table name from filename
738
+ table_name = Path(uploaded_file.name).stem.replace(" ", "_").replace("-", "_").lower()
739
+ # Ensure it starts with a letter and only contains alphanumeric/underscore
740
+ table_name = "".join([c for c in table_name if c.isalnum() or c == "_"])
741
+ if not table_name[0].isalpha():
742
+ table_name = "t_" + table_name
743
+
744
+ # 3. Create/Connect to SQLite DB
745
+ db_path = "ingested_data.db"
746
+ engine = create_engine(f"sqlite:///{db_path}")
747
+
748
+ # 4. Write to DB
749
+ df.to_sql(table_name, engine, if_exists='replace', index=False)
750
+
751
+ return True, table_name, len(df)
752
+ except Exception as e:
753
+ return False, str(e), 0
754
+
755
+
756
+ def index_data():
757
+ """Index text data from the database."""
758
+ if st.session_state.chatbot:
759
+ progress = st.progress(0)
760
+ status = st.empty()
761
+
762
+ # Get schema from the correct introspector
763
+ schema = st.session_state.chatbot.introspector.introspect()
764
+ total_tables = len(schema.tables)
765
+ indexed = 0
766
+
767
+ def progress_callback(table_name, docs):
768
+ nonlocal indexed
769
+ indexed += 1
770
+ progress.progress(indexed / total_tables)
771
+ status.text(f"Indexed {table_name}: {docs} documents")
772
+
773
+ total_docs = st.session_state.chatbot.index_text_data(progress_callback)
774
+ st.session_state.indexed = True
775
+ status.text(f"Total: {total_docs} documents indexed")
776
+
777
+
778
+ def render_schema_explorer():
779
+ """Render schema explorer in an expander."""
780
+ if not st.session_state.initialized:
781
+ return
782
+
783
+ with st.expander("📋 Database Schema", expanded=False):
784
+ try:
785
+ schema = st.session_state.chatbot.introspector.introspect()
786
+
787
+ tab_list, tab_erd = st.tabs(["📋 Table List", "🕸️ Schema Diagram"])
788
+
789
+ with tab_list:
790
+ st.markdown("Uncheck tables to exclude them from the chat context.")
791
+
792
+ for table_name, table_info in schema.tables.items():
793
+ col1, col2 = st.columns([0.05, 0.95])
794
+
795
+ with col1:
796
+ is_active = table_name not in st.session_state.ignored_tables
797
+ active = st.checkbox(
798
+ "Use",
799
+ value=is_active,
800
+ key=f"use_{table_name}",
801
+ label_visibility="collapsed",
802
+ help=f"Include {table_name} in chat analysis"
803
+ )
804
+
805
+ if not active:
806
+ st.session_state.ignored_tables.add(table_name)
807
+ else:
808
+ st.session_state.ignored_tables.discard(table_name)
809
+
810
+ with col2:
811
+ with st.container():
812
+ st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
813
+
814
+ cols = []
815
+ for col in table_info.columns:
816
+ pk = "🔑" if col.is_primary_key else ""
817
+ txt = "📝" if col.is_text_type else ""
818
+ cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
819
+
820
+ st.caption(" | ".join(cols))
821
+ st.divider()
822
+
823
+ with tab_erd:
824
+ if len(schema.tables) > 50:
825
+ st.warning("⚠️ Too many tables to visualize effectively (limit: 50).")
826
+ else:
827
+ try:
828
+ # Build Graphviz DOT string
829
+ dot = ['digraph Database {']
830
+ dot.append(' rankdir=LR;')
831
+ dot.append(' node [shape=box, style="filled,rounded", fillcolor="#f0f2f6", fontname="Arial", fontsize=10];')
832
+ dot.append(' edge [fontname="Arial", fontsize=9, color="#666666"];')
833
+
834
+ # Add nodes (tables)
835
+ for table_name in schema.tables:
836
+ if table_name not in st.session_state.ignored_tables:
837
+ dot.append(f' "{table_name}" [label="{table_name}", fillcolor="#e1effe", color="#1e40af"];')
838
+ else:
839
+ dot.append(f' "{table_name}" [label="{table_name} (ignored)", fillcolor="#f3f4f6", color="#9ca3af", fontcolor="#9ca3af"];')
840
+
841
+ # Add edges (relationships)
842
+ has_edges = False
843
+ for table_name, table_info in schema.tables.items():
844
+ for col_name, ref_str in table_info.foreign_keys.items():
845
+ # ref_str format: "referenced_table.referenced_column"
846
+ if "." in ref_str:
847
+ ref_table = ref_str.split(".")[0]
848
+ # specific_col = ref_str.split(".")[1]
849
+
850
+ # Only draw if both tables exist in our schema list
851
+ if ref_table in schema.tables:
852
+ dot.append(f' "{table_name}" -> "{ref_table}" [label="{col_name}"];')
853
+ has_edges = True
854
+
855
+ dot.append('}')
856
+ graph_code = "\n".join(dot)
857
+ st.graphviz_chart(graph_code, width="stretch")
858
+
859
+ if not has_edges:
860
+ st.info("No foreign key relationships detected in the schema metadata.")
861
+
862
+ except Exception as e:
863
+ st.error(f"Could not render diagram: {e}")
864
+
865
+ except Exception as e:
866
+ st.error(f"Error loading schema: {e}")
867
+
868
+
869
+ def render_chat_interface():
870
+ """Render the main chat interface."""
871
+ st.title("🤖 OnceDataBot")
872
+ st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL • Powered by Groq (FREE!)")
873
+
874
+ # Schema explorer
875
+ render_schema_explorer()
876
+
877
+ # Chat container
878
+ chat_container = st.container()
879
+
880
+ with chat_container:
881
+ # Display messages
882
+ for i, msg in enumerate(st.session_state.messages):
883
+ with st.chat_message(msg["role"]):
884
+ # Create columns for message and favorite button
885
+ msg_col, fav_col = st.columns([0.95, 0.05])
886
+
887
+ with msg_col:
888
+ st.markdown(msg["content"])
889
+
890
+ with fav_col:
891
+ # Favorite button for assistant messages
892
+ if msg["role"] == "assistant":
893
+ is_favorited = i in st.session_state.favorites
894
+ if st.button(
895
+ "⭐" if is_favorited else "☆",
896
+ key=f"fav_{i}",
897
+ help="Click to favorite/unfavorite this response"
898
+ ):
899
+ if is_favorited:
900
+ st.session_state.favorites.remove(i)
901
+ else:
902
+ st.session_state.favorites.append(i)
903
+ st.rerun()
904
+
905
+ # Show metadata for assistant messages
906
+ if msg["role"] == "assistant" and "metadata" in msg:
907
+ meta = msg["metadata"]
908
+
909
+ # Show token usage in a dropdown expander
910
+ if "token_usage" in meta:
911
+ usage = meta["token_usage"]
912
+ total = usage.get('total', 0)
913
+
914
+ with st.expander(f"📊 Token Usage ({total:,} total)", expanded=False):
915
+ # Create styled token usage boxes using columns
916
+ st.markdown("""
917
+ <style>
918
+ .token-box {
919
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
920
+ border-radius: 12px;
921
+ padding: 12px 16px;
922
+ color: white;
923
+ text-align: center;
924
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
925
+ margin: 4px 0;
926
+ }
927
+ .token-box-input {
928
+ background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
929
+ box-shadow: 0 4px 15px rgba(17, 153, 142, 0.3);
930
+ }
931
+ .token-box-output {
932
+ background: linear-gradient(135deg, #ee0979 0%, #ff6a00 100%);
933
+ box-shadow: 0 4px 15px rgba(238, 9, 121, 0.3);
934
+ }
935
+ .token-box-total {
936
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
937
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
938
+ }
939
+ .token-label {
940
+ font-size: 11px;
941
+ text-transform: uppercase;
942
+ letter-spacing: 1px;
943
+ opacity: 0.9;
944
+ margin-bottom: 4px;
945
+ }
946
+ .token-value {
947
+ font-size: 20px;
948
+ font-weight: 700;
949
+ }
950
+ </style>
951
+ """, unsafe_allow_html=True)
952
+
953
+ col1, col2, col3 = st.columns(3)
954
+
955
+ with col1:
956
+ st.markdown(f"""
957
+ <div class="token-box token-box-input">
958
+ <div class="token-label">📥 Input Tokens</div>
959
+ <div class="token-value">{usage.get('input', 0):,}</div>
960
+ </div>
961
+ """, unsafe_allow_html=True)
962
+
963
+ with col2:
964
+ st.markdown(f"""
965
+ <div class="token-box token-box-output">
966
+ <div class="token-label">📤 Output Tokens</div>
967
+ <div class="token-value">{usage.get('output', 0):,}</div>
968
+ </div>
969
+ """, unsafe_allow_html=True)
970
+
971
+ with col3:
972
+ st.markdown(f"""
973
+ <div class="token-box token-box-total">
974
+ <div class="token-label">📊 Total Tokens</div>
975
+ <div class="token-value">{usage.get('total', 0):,}</div>
976
+ </div>
977
+ """, unsafe_allow_html=True)
978
+
979
+ if meta.get("query_type"):
980
+ # Show query type and execution time on same line
981
+ info_text = f"Query type: {meta['query_type']}"
982
+ if meta.get("execution_time"):
983
+ info_text += f" • ⏱️ {meta['execution_time']:.2f}s"
984
+ st.caption(info_text)
985
+
986
+ # SQL Query expander
987
+ if meta.get("sql_query"):
988
+ with st.expander("🛠️ SQL Query & Details"):
989
+ st.code(meta["sql_query"], language="sql")
990
+
991
+ # Visualizations and CSV export
992
+ if meta.get("sql_results"):
993
+ # Only render viz if we have results
994
+ render_visualization(meta["sql_results"], f"viz_{i}")
995
+
996
+ # CSV Export button
997
+ csv_data = export_results_to_csv(meta["sql_results"])
998
+ if csv_data:
999
+ st.download_button(
1000
+ label="📊 Export to CSV",
1001
+ data=csv_data,
1002
+ file_name=f"query_results_{i}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
1003
+ mime="text/csv",
1004
+ key=f"csv_export_{i}",
1005
+ help="Download query results as CSV file"
1006
+ )
1007
+
1008
+ # Chat input
1009
+ if prompt := st.chat_input("Ask about your data..."):
1010
+ if not st.session_state.initialized:
1011
+ st.error("Please connect to a database first!")
1012
+ return
1013
+
1014
+ # Add user message
1015
+ st.session_state.messages.append({"role": "user", "content": prompt})
1016
+
1017
+ # Calculate memory context for display? No, just render user msg
1018
+ with st.chat_message("user"):
1019
+ st.markdown(prompt)
1020
+
1021
+ # Get response
1022
+ with st.spinner("Thinking..."):
1023
+ try:
1024
+ # Add memory interaction
1025
+ if st.session_state.memory:
1026
+ st.session_state.memory.add_message("user", prompt)
1027
+
1028
+ # Track execution time
1029
+ start_time = time.time()
1030
+
1031
+ response = st.session_state.chatbot.chat(
1032
+ prompt,
1033
+ st.session_state.memory,
1034
+ ignored_tables=list(st.session_state.ignored_tables),
1035
+ language=st.session_state.response_language
1036
+ )
1037
+
1038
+ execution_time = time.time() - start_time
1039
+
1040
+
1041
+
1042
+ # Create metadata dict
1043
+ metadata = {
1044
+ "query_type": response.query_type,
1045
+ "sql_query": response.sql_query,
1046
+ "sql_results": response.sql_results,
1047
+ "token_usage": response.token_usage,
1048
+ "execution_time": execution_time
1049
+ }
1050
+
1051
+ # Save to session state
1052
+ st.session_state.messages.append({
1053
+ "role": "assistant",
1054
+ "content": response.answer,
1055
+ "metadata": metadata
1056
+ })
1057
+
1058
+ # Set flag to auto-read the latest response
1059
+ st.session_state.auto_read_latest = True
1060
+
1061
+ # Save to active memory
1062
+ if st.session_state.memory:
1063
+ st.session_state.memory.add_message("assistant", response.answer)
1064
+
1065
+ st.rerun()
1066
+
1067
+ except Exception as e:
1068
+ st.error(f"An error occurred: {e}")
1069
+ import traceback
1070
+ st.error(traceback.format_exc())
1071
+
1072
+
1073
+ def main():
1074
+ """Main application entry point."""
1075
+ init_session_state()
1076
+
1077
+ # Auto-connect to environment database on first load
1078
+ if "auto_connect_attempted" not in st.session_state:
1079
+ st.session_state.auto_connect_attempted = True
1080
+ if st.session_state.db_source == "environment":
1081
+ success = initialize_chatbot()
1082
+ if success:
1083
+ st.toast("✅ Auto-connected to database!")
1084
+
1085
+ render_sidebar()
1086
+ render_chat_interface()
1087
+
1088
+
1089
+ if __name__ == "__main__":
1090
+ main()