Vanshcc commited on
Commit
bed117c
Β·
verified Β·
1 Parent(s): a8441ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +616 -613
app.py CHANGED
@@ -1,613 +1,616 @@
1
- """
2
- Schema-Agnostic Database Chatbot - Streamlit Application
3
-
4
- A production-grade chatbot that connects to ANY database
5
- (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
- through RAG and Text-to-SQL.
7
-
8
- Uses Groq for FREE LLM inference!
9
- """
10
-
11
- import os
12
- from pathlib import Path
13
-
14
- # Load .env FIRST before any other imports
15
- from dotenv import load_dotenv
16
- load_dotenv(Path(__file__).parent / ".env")
17
-
18
- import streamlit as st
19
- import uuid
20
- from datetime import datetime
21
-
22
- # Page config must be first
23
- st.set_page_config(
24
- page_title="OnceDataBot",
25
- page_icon="πŸ€–",
26
- layout="wide",
27
- initial_sidebar_state="expanded"
28
- )
29
-
30
- # Imports
31
- from config import config, DatabaseConfig, DatabaseType
32
- from database import get_db, get_schema, get_introspector
33
- from database.connection import DatabaseConnection
34
- from llm import create_llm_client
35
- from chatbot import create_chatbot, DatabaseChatbot
36
- from memory import ChatMemory, EnhancedChatMemory
37
-
38
-
39
- # Groq models (all FREE!)
40
- GROQ_MODELS = [
41
- "llama-3.3-70b-versatile",
42
- "llama-3.1-8b-instant",
43
- "mixtral-8x7b-32768",
44
- "gemma2-9b-it"
45
- ]
46
-
47
- # Database types
48
- DB_TYPES = {
49
- "MySQL": "mysql",
50
- "PostgreSQL": "postgresql"
51
- }
52
-
53
-
54
- def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
55
- """Create a custom database configuration from user input."""
56
- return DatabaseConfig(
57
- db_type=DatabaseType(db_type),
58
- host=kwargs.get("host", ""),
59
- port=kwargs.get("port", 3306 if db_type == "mysql" else 5432),
60
- database=kwargs.get("database", ""),
61
- username=kwargs.get("username", ""),
62
- password=kwargs.get("password", ""),
63
- ssl_ca=kwargs.get("ssl_ca", None)
64
- )
65
-
66
-
67
- def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
68
- enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
69
- """Create enhanced memory with a custom database connection."""
70
- return EnhancedChatMemory(
71
- session_id=session_id,
72
- user_id=user_id,
73
- max_messages=20,
74
- db_connection=db_connection,
75
- llm_client=llm_client,
76
- enable_summarization=enable_summarization,
77
- summary_threshold=summary_threshold
78
- )
79
-
80
-
81
- def init_session_state():
82
- """Initialize Streamlit session state."""
83
- if "session_id" not in st.session_state:
84
- st.session_state.session_id = str(uuid.uuid4())
85
-
86
- if "messages" not in st.session_state:
87
- st.session_state.messages = []
88
-
89
- if "chatbot" not in st.session_state:
90
- st.session_state.chatbot = None
91
-
92
- if "initialized" not in st.session_state:
93
- st.session_state.initialized = False
94
-
95
- if "user_id" not in st.session_state:
96
- st.session_state.user_id = "default"
97
-
98
- if "enable_summarization" not in st.session_state:
99
- st.session_state.enable_summarization = True
100
-
101
- if "summary_threshold" not in st.session_state:
102
- st.session_state.summary_threshold = 10
103
-
104
- if "memory" not in st.session_state:
105
- st.session_state.memory = None
106
-
107
- if "indexed" not in st.session_state:
108
- st.session_state.indexed = False
109
-
110
- if "db_source" not in st.session_state:
111
- st.session_state.db_source = "environment" # "environment" or "custom"
112
-
113
- if "custom_db_config" not in st.session_state:
114
- st.session_state.custom_db_config = None
115
-
116
- if "custom_db_connection" not in st.session_state:
117
- st.session_state.custom_db_connection = None
118
-
119
-
120
- def render_database_config():
121
- """Render database configuration section in sidebar."""
122
- st.subheader("πŸ—„οΈ Database Configuration")
123
-
124
- # Database source selection
125
- db_source = st.radio(
126
- "Database Source",
127
- options=["Use Environment Variables", "Custom Database"],
128
- index=0 if st.session_state.db_source == "environment" else 1,
129
- key="db_source_radio",
130
- help="Choose to use .env settings or enter custom credentials"
131
- )
132
-
133
- st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
134
-
135
- if st.session_state.db_source == "environment":
136
- # Show current environment config
137
- current_db_type = config.database.db_type.value.upper()
138
- st.info(f"πŸ“Œ Using {current_db_type} from environment")
139
- st.caption(f"Host: {config.database.host}")
140
- return None
141
-
142
- else:
143
- # Custom database configuration
144
- st.markdown("##### Enter Database Credentials")
145
-
146
- # Database type selector
147
- db_type_label = st.selectbox(
148
- "Database Type",
149
- options=list(DB_TYPES.keys()),
150
- index=0,
151
- key="custom_db_type"
152
- )
153
- db_type = DB_TYPES[db_type_label]
154
-
155
- if True: # MySQL or PostgreSQL (SQLite removed)
156
- # MySQL or PostgreSQL
157
- col1, col2 = st.columns([3, 1])
158
- with col1:
159
- host = st.text_input(
160
- "Host",
161
- value="",
162
- key="db_host_input",
163
- placeholder="your-database-host.com"
164
- )
165
- with col2:
166
- default_port = 3306 if db_type == "mysql" else 5432
167
- port = st.number_input(
168
- "Port",
169
- value=default_port,
170
- min_value=1,
171
- max_value=65535,
172
- key="db_port_input"
173
- )
174
-
175
- database = st.text_input(
176
- "Database Name",
177
- value="",
178
- key="db_name_input",
179
- placeholder="your_database"
180
- )
181
-
182
- username = st.text_input(
183
- "Username",
184
- value="",
185
- key="db_user_input",
186
- placeholder="your_username"
187
- )
188
-
189
- password = st.text_input(
190
- "Password",
191
- value="",
192
- type="password",
193
- key="db_pass_input"
194
- )
195
-
196
- # Optional SSL
197
- with st.expander("πŸ”’ SSL Settings (Optional)"):
198
- ssl_ca = st.text_input(
199
- "SSL CA Certificate Path",
200
- value="",
201
- key="ssl_ca_input",
202
- help="Path to SSL CA certificate file (for cloud databases like Aiven)"
203
- )
204
-
205
- return {
206
- "db_type": db_type,
207
- "host": host,
208
- "port": int(port),
209
- "database": database,
210
- "username": username,
211
- "password": password,
212
- "ssl_ca": ssl_ca if ssl_ca else None
213
- }
214
-
215
-
216
- def render_sidebar():
217
- """Render the configuration sidebar."""
218
- with st.sidebar:
219
- st.title("βš™οΈ Settings")
220
-
221
- # User Profile
222
- st.subheader("πŸ‘€ User Profile")
223
- user_id = st.text_input(
224
- "User ID / Name",
225
- value=st.session_state.get("user_id", "default"),
226
- key="user_id_input",
227
- help="Your unique ID for private memory storage"
228
- )
229
- if user_id != st.session_state.get("user_id"):
230
- st.session_state.user_id = user_id
231
- st.session_state.session_id = str(uuid.uuid4())
232
- st.session_state.messages = []
233
-
234
- # Recreate memory for new user
235
- if st.session_state.custom_db_connection:
236
- st.session_state.memory = create_custom_memory(
237
- st.session_state.session_id,
238
- user_id,
239
- st.session_state.custom_db_connection,
240
- st.session_state.get("llm"),
241
- st.session_state.enable_summarization,
242
- st.session_state.summary_threshold
243
- )
244
- elif st.session_state.initialized:
245
- from memory import create_enhanced_memory
246
- st.session_state.memory = create_enhanced_memory(
247
- st.session_state.session_id,
248
- user_id=user_id,
249
- enable_summarization=st.session_state.enable_summarization,
250
- summary_threshold=st.session_state.summary_threshold
251
- )
252
-
253
- if st.session_state.memory:
254
- st.session_state.memory.clear_user_history()
255
- st.rerun()
256
-
257
- st.divider()
258
-
259
- # Database Configuration
260
- custom_db_params = render_database_config()
261
-
262
- st.divider()
263
-
264
- # LLM Configuration
265
- st.subheader("πŸ€– LLM Configuration")
266
-
267
- # Model selection only - API key from environment
268
- groq_model = st.selectbox(
269
- "Model",
270
- options=GROQ_MODELS,
271
- index=0,
272
- key="groq_model_select"
273
- )
274
-
275
- # Show status of API key
276
- if os.getenv("GROQ_API_KEY"):
277
- st.success("βœ“ API Key configured")
278
- else:
279
- st.warning("⚠️ GROQ_API_KEY not set in environment")
280
-
281
- st.divider()
282
-
283
- # Initialize Button
284
- if st.button("πŸš€ Connect & Initialize", use_container_width=True, type="primary"):
285
- with st.spinner("Connecting to database..."):
286
- success = initialize_chatbot(custom_db_params, None, groq_model)
287
- if success:
288
- st.success("βœ… Connected!")
289
- st.rerun()
290
-
291
- # Index Button (after initialization)
292
- if st.session_state.initialized:
293
- if st.button("πŸ“š Index Text Data", use_container_width=True):
294
- with st.spinner("Indexing text data..."):
295
- index_data()
296
- st.success("βœ… Indexed!")
297
- st.rerun()
298
-
299
- st.divider()
300
-
301
- # Status
302
- st.subheader("πŸ“Š Status")
303
- if st.session_state.initialized:
304
- # Show database type
305
- if st.session_state.custom_db_connection:
306
- db_type = st.session_state.custom_db_connection.db_type.value.upper()
307
- else:
308
- db_type = get_db().db_type.value.upper()
309
-
310
- st.success(f"Database: {db_type} βœ“")
311
-
312
- try:
313
- schema = get_schema()
314
- st.info(f"Tables: {len(schema.tables)}")
315
- except:
316
- st.warning("Schema not loaded")
317
-
318
- if st.session_state.indexed:
319
- from rag import get_rag_engine
320
- engine = get_rag_engine()
321
- st.info(f"Indexed Docs: {engine.document_count}")
322
- else:
323
- st.warning("Not connected")
324
-
325
- # New Chat
326
- if st.button("βž• New Chat", use_container_width=True, type="secondary"):
327
- if st.session_state.memory:
328
- st.session_state.memory.clear()
329
-
330
- st.session_state.messages = []
331
- st.session_state.session_id = str(uuid.uuid4())
332
-
333
- current_user = st.session_state.get("user_id", "default")
334
-
335
- if st.session_state.custom_db_connection:
336
- st.session_state.memory = create_custom_memory(
337
- st.session_state.session_id,
338
- current_user,
339
- st.session_state.custom_db_connection,
340
- st.session_state.get("llm"),
341
- st.session_state.enable_summarization,
342
- st.session_state.summary_threshold
343
- )
344
- elif st.session_state.initialized:
345
- from memory import create_enhanced_memory
346
- st.session_state.memory = create_enhanced_memory(
347
- st.session_state.session_id,
348
- user_id=current_user,
349
- enable_summarization=st.session_state.enable_summarization,
350
- summary_threshold=st.session_state.summary_threshold
351
- )
352
- if st.session_state.get("llm"):
353
- st.session_state.memory.set_llm_client(st.session_state.llm)
354
-
355
- st.rerun()
356
-
357
- # Disconnect button (when using custom DB)
358
- if st.session_state.initialized and st.session_state.db_source == "custom":
359
- if st.button("πŸ”Œ Disconnect", use_container_width=True):
360
- if st.session_state.custom_db_connection:
361
- st.session_state.custom_db_connection.close()
362
- st.session_state.custom_db_connection = None
363
- st.session_state.chatbot = None
364
- st.session_state.initialized = False
365
- st.session_state.indexed = False
366
- st.session_state.memory = None
367
- st.success("Disconnected!")
368
- st.rerun()
369
-
370
-
371
- def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
372
- """Initialize the chatbot with either environment or custom database."""
373
- try:
374
- # Get API key
375
- groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
376
- groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
377
-
378
- if not groq_api_key:
379
- st.error("GROQ_API_KEY not configured. Please enter your API key.")
380
- return False
381
-
382
- # Create LLM client
383
- llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
384
-
385
- # Create database connection
386
- if custom_db_params and st.session_state.db_source == "custom":
387
- # Validate custom params
388
- db_type = custom_db_params.get("db_type", "mysql")
389
-
390
- if True:
391
- if not all([custom_db_params.get("host"),
392
- custom_db_params.get("database"),
393
- custom_db_params.get("username")]):
394
- st.error("Please fill in all required database fields.")
395
- return False
396
-
397
- # Create custom config
398
- db_config = create_custom_db_config(**custom_db_params)
399
-
400
- # Create custom connection
401
- custom_connection = DatabaseConnection(db_config)
402
-
403
- # Test connection
404
- success, msg = custom_connection.test_connection()
405
- if not success:
406
- st.error(f"Connection failed: {msg}")
407
- return False
408
-
409
- st.session_state.custom_db_connection = custom_connection
410
- st.session_state.custom_db_config = db_config
411
-
412
- # Override the global db connection for the chatbot
413
- # We need to create a chatbot with this custom connection
414
- from chatbot import DatabaseChatbot
415
- from database.schema_introspector import SchemaIntrospector
416
- from rag import get_rag_engine
417
- from sql import get_sql_generator, get_sql_validator
418
- from router import get_query_router
419
-
420
- chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
421
- chatbot.db = custom_connection
422
- chatbot.introspector = SchemaIntrospector()
423
- chatbot.introspector.db = custom_connection
424
- chatbot.rag_engine = get_rag_engine()
425
- chatbot.sql_generator = get_sql_generator(db_type)
426
- chatbot.sql_validator = get_sql_validator()
427
- chatbot.router = get_query_router()
428
- chatbot.llm_client = llm
429
- chatbot._schema_initialized = False
430
- chatbot._rag_initialized = False
431
-
432
- # Set LLM client
433
- chatbot.set_llm_client(llm)
434
-
435
- # Initialize (introspect schema)
436
- schema = chatbot.introspector.introspect(force_refresh=True)
437
- chatbot.sql_validator.set_allowed_tables(schema.table_names)
438
- chatbot._schema_initialized = True
439
-
440
- st.session_state.chatbot = chatbot
441
-
442
- else:
443
- # Use environment-based connection (existing flow)
444
- chatbot = create_chatbot(llm)
445
- chatbot.set_llm_client(llm)
446
-
447
- success, msg = chatbot.initialize()
448
- if not success:
449
- st.error(f"Initialization failed: {msg}")
450
- return False
451
-
452
- st.session_state.chatbot = chatbot
453
- st.session_state.custom_db_connection = None
454
-
455
- st.session_state.llm = llm
456
- st.session_state.initialized = True
457
- st.session_state.indexed = False # Reset index status on new connection
458
-
459
- # Clear RAG index to ensure no data from previous DB connection persists
460
- if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
461
- chatbot.rag_engine.clear_index()
462
-
463
- # Create memory with appropriate connection
464
- db_conn = st.session_state.custom_db_connection or get_db()
465
- st.session_state.memory = create_custom_memory(
466
- st.session_state.session_id,
467
- st.session_state.user_id,
468
- db_conn,
469
- llm,
470
- st.session_state.enable_summarization,
471
- st.session_state.summary_threshold
472
- )
473
-
474
- return True
475
-
476
- except Exception as e:
477
- st.error(f"Error: {str(e)}")
478
- import traceback
479
- st.error(traceback.format_exc())
480
- return False
481
-
482
-
483
- def index_data():
484
- """Index text data from the database."""
485
- if st.session_state.chatbot:
486
- progress = st.progress(0)
487
- status = st.empty()
488
-
489
- # Get schema from the correct introspector
490
- schema = st.session_state.chatbot.introspector.introspect()
491
- total_tables = len(schema.tables)
492
- indexed = 0
493
-
494
- def progress_callback(table_name, docs):
495
- nonlocal indexed
496
- indexed += 1
497
- progress.progress(indexed / total_tables)
498
- status.text(f"Indexed {table_name}: {docs} documents")
499
-
500
- total_docs = st.session_state.chatbot.index_text_data(progress_callback)
501
- st.session_state.indexed = True
502
- status.text(f"Total: {total_docs} documents indexed")
503
-
504
-
505
- def render_schema_explorer():
506
- """Render schema explorer in an expander."""
507
- if not st.session_state.initialized:
508
- return
509
-
510
- with st.expander("πŸ“‹ Database Schema", expanded=False):
511
- try:
512
- schema = st.session_state.chatbot.introspector.introspect()
513
-
514
- for table_name, table_info in schema.tables.items():
515
- with st.container():
516
- st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
517
-
518
- cols = []
519
- for col in table_info.columns:
520
- pk = "πŸ”‘" if col.is_primary_key else ""
521
- txt = "πŸ“" if col.is_text_type else ""
522
- cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
523
-
524
- st.caption(" | ".join(cols))
525
- st.divider()
526
- except Exception as e:
527
- st.error(f"Error loading schema: {e}")
528
-
529
-
530
- def render_chat_interface():
531
- """Render the main chat interface."""
532
- st.title("πŸ€– OnceDataBot")
533
- st.caption("Schema-agnostic chatbot β€’ MySQL | PostgreSQL β€’ Powered by Groq (FREE!)")
534
-
535
- # Schema explorer
536
- render_schema_explorer()
537
-
538
- # Chat container
539
- chat_container = st.container()
540
-
541
- with chat_container:
542
- # Display messages
543
- for msg in st.session_state.messages:
544
- with st.chat_message(msg["role"]):
545
- st.markdown(msg["content"])
546
-
547
- # Show metadata for assistant messages
548
- if msg["role"] == "assistant" and "metadata" in msg:
549
- meta = msg["metadata"]
550
- if meta.get("query_type"):
551
- st.caption(f"Query type: {meta['query_type']}")
552
- if meta.get("sql_query"):
553
- with st.expander("SQL Query"):
554
- st.code(meta["sql_query"], language="sql")
555
-
556
- # Chat input
557
- if prompt := st.chat_input("Ask about your data..."):
558
- if not st.session_state.initialized:
559
- st.error("Please connect to a database first!")
560
- return
561
-
562
- # Add user message
563
- st.session_state.messages.append({"role": "user", "content": prompt})
564
- if st.session_state.memory:
565
- st.session_state.memory.add_message("user", prompt)
566
-
567
- with st.chat_message("user"):
568
- st.markdown(prompt)
569
-
570
- # Get response
571
- with st.chat_message("assistant"):
572
- with st.spinner("Thinking..."):
573
- response = st.session_state.chatbot.chat(
574
- prompt,
575
- st.session_state.memory
576
- )
577
-
578
- st.markdown(response.answer)
579
-
580
- # Show metadata
581
- if response.query_type != "general":
582
- st.caption(f"Query type: {response.query_type}")
583
-
584
- if response.sql_query:
585
- with st.expander("SQL Query"):
586
- st.code(response.sql_query, language="sql")
587
-
588
- if response.sql_results:
589
- with st.expander("Results"):
590
- st.dataframe(response.sql_results)
591
-
592
- # Save to memory
593
- st.session_state.messages.append({
594
- "role": "assistant",
595
- "content": response.answer,
596
- "metadata": {
597
- "query_type": response.query_type,
598
- "sql_query": response.sql_query
599
- }
600
- })
601
- if st.session_state.memory:
602
- st.session_state.memory.add_message("assistant", response.answer)
603
-
604
-
605
- def main():
606
- """Main application entry point."""
607
- init_session_state()
608
- render_sidebar()
609
- render_chat_interface()
610
-
611
-
612
- if __name__ == "__main__":
613
- main()
 
 
 
 
1
+ """
2
+ Schema-Agnostic Database Chatbot - Streamlit Application
3
+
4
+ A production-grade chatbot that connects to ANY database
5
+ (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
+ through RAG and Text-to-SQL.
7
+
8
+ Uses Groq for FREE LLM inference!
9
+ """
10
+
11
+ import os
12
+ from pathlib import Path
13
+
14
+ # Load .env FIRST before any other imports
15
+ from dotenv import load_dotenv
16
+ load_dotenv(Path(__file__).parent / ".env")
17
+
18
+ import streamlit as st
19
+ import uuid
20
+ from datetime import datetime
21
+
22
+ # Page config must be first
23
+ st.set_page_config(
24
+ page_title="OnceDataBot",
25
+ page_icon="πŸ€–",
26
+ layout="wide",
27
+ initial_sidebar_state="expanded"
28
+ )
29
+
30
+ # Imports
31
+ from config import config, DatabaseConfig, DatabaseType
32
+ from database import get_db, get_schema, get_introspector
33
+ from database.connection import DatabaseConnection
34
+ from llm import create_llm_client
35
+ from chatbot import create_chatbot, DatabaseChatbot
36
+ from memory import ChatMemory, EnhancedChatMemory
37
+
38
+
39
+ # Groq models (all FREE!)
40
+ GROQ_MODELS = [
41
+ "llama-3.3-70b-versatile",
42
+ "llama-3.1-8b-instant",
43
+ "mixtral-8x7b-32768",
44
+ "gemma2-9b-it"
45
+ ]
46
+
47
+ # Database types
48
+ DB_TYPES = {
49
+ "MySQL": "mysql",
50
+ "PostgreSQL": "postgresql"
51
+ }
52
+
53
+
54
+
55
+
56
+ def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
57
+ """Create a custom database configuration from user input."""
58
+ return DatabaseConfig(
59
+ db_type=DatabaseType(db_type),
60
+ host=kwargs.get("host", ""),
61
+ port=kwargs.get("port", 3306 if db_type == "mysql" else 5432),
62
+ database=kwargs.get("database", ""),
63
+ username=kwargs.get("username", ""),
64
+ password=kwargs.get("password", ""),
65
+ ssl_ca=kwargs.get("ssl_ca", None)
66
+ )
67
+
68
+
69
+ def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
70
+ enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
71
+ """Create enhanced memory with a custom database connection."""
72
+ return EnhancedChatMemory(
73
+ session_id=session_id,
74
+ user_id=user_id,
75
+ max_messages=20,
76
+ db_connection=db_connection,
77
+ llm_client=llm_client,
78
+ enable_summarization=enable_summarization,
79
+ summary_threshold=summary_threshold
80
+ )
81
+
82
+
83
+ def init_session_state():
84
+ """Initialize Streamlit session state."""
85
+ if "session_id" not in st.session_state:
86
+ st.session_state.session_id = str(uuid.uuid4())
87
+
88
+ if "messages" not in st.session_state:
89
+ st.session_state.messages = []
90
+
91
+ if "chatbot" not in st.session_state:
92
+ st.session_state.chatbot = None
93
+
94
+ if "initialized" not in st.session_state:
95
+ st.session_state.initialized = False
96
+
97
+ if "user_id" not in st.session_state:
98
+ st.session_state.user_id = "default"
99
+
100
+ if "enable_summarization" not in st.session_state:
101
+ st.session_state.enable_summarization = True
102
+
103
+ if "summary_threshold" not in st.session_state:
104
+ st.session_state.summary_threshold = 10
105
+
106
+ if "memory" not in st.session_state:
107
+ st.session_state.memory = None
108
+
109
+ if "indexed" not in st.session_state:
110
+ st.session_state.indexed = False
111
+
112
+ if "db_source" not in st.session_state:
113
+ st.session_state.db_source = "environment" # "environment" or "custom"
114
+
115
+ if "custom_db_config" not in st.session_state:
116
+ st.session_state.custom_db_config = None
117
+
118
+ if "custom_db_connection" not in st.session_state:
119
+ st.session_state.custom_db_connection = None
120
+
121
+
122
+ def render_database_config():
123
+ """Render database configuration section in sidebar."""
124
+ st.subheader("πŸ—„οΈ Database Configuration")
125
+
126
+ # Database source selection
127
+ db_source = st.radio(
128
+ "Database Source",
129
+ options=["Use Environment Variables", "Custom Database"],
130
+ index=0 if st.session_state.db_source == "environment" else 1,
131
+ key="db_source_radio",
132
+ help="Choose to use .env settings or enter custom credentials"
133
+ )
134
+
135
+ st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
136
+
137
+ if st.session_state.db_source == "environment":
138
+ # Show current environment config
139
+ current_db_type = config.database.db_type.value.upper()
140
+ st.info(f"πŸ“Œ Using {current_db_type} from environment")
141
+ st.caption(f"Host: {config.database.host}")
142
+ return None
143
+
144
+ else:
145
+ # Custom database configuration
146
+ st.markdown("##### Enter Database Credentials")
147
+
148
+ # Database type selector
149
+ db_type_label = st.selectbox(
150
+ "Database Type",
151
+ options=list(DB_TYPES.keys()),
152
+ index=0,
153
+ key="custom_db_type"
154
+ )
155
+ db_type = DB_TYPES[db_type_label]
156
+
157
+ if True: # MySQL or PostgreSQL (SQLite removed)
158
+ # MySQL or PostgreSQL
159
+ col1, col2 = st.columns([3, 1])
160
+ with col1:
161
+ host = st.text_input(
162
+ "Host",
163
+ value="",
164
+ key="db_host_input",
165
+ placeholder="your-database-host.com"
166
+ )
167
+ with col2:
168
+ default_port = 3306 if db_type == "mysql" else 5432
169
+ port = st.number_input(
170
+ "Port",
171
+ value=default_port,
172
+ min_value=1,
173
+ max_value=65535,
174
+ key="db_port_input"
175
+ )
176
+
177
+ database = st.text_input(
178
+ "Database Name",
179
+ value="",
180
+ key="db_name_input",
181
+ placeholder="your_database"
182
+ )
183
+
184
+ username = st.text_input(
185
+ "Username",
186
+ value="",
187
+ key="db_user_input",
188
+ placeholder="your_username"
189
+ )
190
+
191
+ password = st.text_input(
192
+ "Password",
193
+ value="",
194
+ type="password",
195
+ key="db_pass_input"
196
+ )
197
+
198
+ # Optional SSL
199
+ with st.expander("πŸ”’ SSL Settings (Optional)"):
200
+ ssl_ca = st.text_input(
201
+ "SSL CA Certificate Path",
202
+ value="",
203
+ key="ssl_ca_input",
204
+ help="Path to SSL CA certificate file (for cloud databases like Aiven)"
205
+ )
206
+
207
+ return {
208
+ "db_type": db_type,
209
+ "host": host,
210
+ "port": int(port),
211
+ "database": database,
212
+ "username": username,
213
+ "password": password,
214
+ "ssl_ca": ssl_ca if ssl_ca else None
215
+ }
216
+
217
+
218
+ def render_sidebar():
219
+ """Render the configuration sidebar."""
220
+ with st.sidebar:
221
+ st.title("βš™οΈ Settings")
222
+
223
+ # User Profile
224
+ st.subheader("πŸ‘€ User Profile")
225
+ user_id = st.text_input(
226
+ "User ID / Name",
227
+ value=st.session_state.get("user_id", "default"),
228
+ key="user_id_input",
229
+ help="Your unique ID for private memory storage"
230
+ )
231
+ if user_id != st.session_state.get("user_id"):
232
+ st.session_state.user_id = user_id
233
+ st.session_state.session_id = str(uuid.uuid4())
234
+ st.session_state.messages = []
235
+
236
+ # Recreate memory for new user
237
+ if st.session_state.custom_db_connection:
238
+ st.session_state.memory = create_custom_memory(
239
+ st.session_state.session_id,
240
+ user_id,
241
+ st.session_state.custom_db_connection,
242
+ st.session_state.get("llm"),
243
+ st.session_state.enable_summarization,
244
+ st.session_state.summary_threshold
245
+ )
246
+ elif st.session_state.initialized:
247
+ from memory import create_enhanced_memory
248
+ st.session_state.memory = create_enhanced_memory(
249
+ st.session_state.session_id,
250
+ user_id=user_id,
251
+ enable_summarization=st.session_state.enable_summarization,
252
+ summary_threshold=st.session_state.summary_threshold
253
+ )
254
+
255
+ if st.session_state.memory:
256
+ st.session_state.memory.clear_user_history()
257
+ st.rerun()
258
+
259
+ st.divider()
260
+
261
+ # Database Configuration
262
+ custom_db_params = render_database_config()
263
+
264
+ st.divider()
265
+
266
+ # LLM Configuration
267
+ st.subheader("πŸ€– LLM Configuration")
268
+
269
+ # Show status of API key
270
+ if os.getenv("GROQ_API_KEY"):
271
+ st.success("βœ“ API Key configured")
272
+ else:
273
+ st.warning("⚠️ GROQ_API_KEY not set in environment")
274
+
275
+ st.divider()
276
+
277
+ # Initialize Button
278
+ if st.button("πŸš€ Connect & Initialize", use_container_width=True, type="primary"):
279
+ with st.spinner("Connecting to database..."):
280
+ success = initialize_chatbot(custom_db_params, None, None)
281
+ if success:
282
+ st.success("βœ… Connected!")
283
+ st.rerun()
284
+
285
+ # Index Button (after initialization)
286
+ if st.session_state.initialized:
287
+ if st.button("πŸ“š Index Text Data", use_container_width=True):
288
+ with st.spinner("Indexing text data..."):
289
+ index_data()
290
+ st.success("βœ… Indexed!")
291
+ st.rerun()
292
+
293
+ st.divider()
294
+
295
+ # Status
296
+ st.subheader("πŸ“Š Status")
297
+ if st.session_state.initialized:
298
+ # Show database type
299
+ if st.session_state.custom_db_connection:
300
+ db_type = st.session_state.custom_db_connection.db_type.value.upper()
301
+ else:
302
+ db_type = get_db().db_type.value.upper()
303
+
304
+ st.success(f"Database: {db_type} βœ“")
305
+
306
+ try:
307
+ schema = get_schema()
308
+ st.info(f"Tables: {len(schema.tables)}")
309
+ except:
310
+ st.warning("Schema not loaded")
311
+
312
+ if st.session_state.indexed:
313
+ from rag import get_rag_engine
314
+ engine = get_rag_engine()
315
+ st.info(f"Indexed Docs: {engine.document_count}")
316
+ else:
317
+ st.warning("Not connected")
318
+
319
+ # New Chat
320
+ if st.button("βž• New Chat", use_container_width=True, type="secondary"):
321
+ if st.session_state.memory:
322
+ st.session_state.memory.clear()
323
+
324
+ st.session_state.messages = []
325
+ st.session_state.session_id = str(uuid.uuid4())
326
+
327
+ current_user = st.session_state.get("user_id", "default")
328
+
329
+ if st.session_state.custom_db_connection:
330
+ st.session_state.memory = create_custom_memory(
331
+ st.session_state.session_id,
332
+ current_user,
333
+ st.session_state.custom_db_connection,
334
+ st.session_state.get("llm"),
335
+ st.session_state.enable_summarization,
336
+ st.session_state.summary_threshold
337
+ )
338
+ elif st.session_state.initialized:
339
+ from memory import create_enhanced_memory
340
+ st.session_state.memory = create_enhanced_memory(
341
+ st.session_state.session_id,
342
+ user_id=current_user,
343
+ enable_summarization=st.session_state.enable_summarization,
344
+ summary_threshold=st.session_state.summary_threshold
345
+ )
346
+ if st.session_state.get("llm"):
347
+ st.session_state.memory.set_llm_client(st.session_state.llm)
348
+
349
+ st.rerun()
350
+
351
+ # Disconnect button (when using custom DB)
352
+ if st.session_state.initialized and st.session_state.db_source == "custom":
353
+ if st.button("πŸ”Œ Disconnect", use_container_width=True):
354
+ if st.session_state.custom_db_connection:
355
+ st.session_state.custom_db_connection.close()
356
+ st.session_state.custom_db_connection = None
357
+ st.session_state.chatbot = None
358
+ st.session_state.initialized = False
359
+ st.session_state.indexed = False
360
+ st.session_state.memory = None
361
+ st.success("Disconnected!")
362
+ st.rerun()
363
+
364
+
365
+ def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
366
+ """Initialize the chatbot with either environment or custom database."""
367
+ try:
368
+ # Get API key
369
+ groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
370
+ groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
371
+
372
+ if not groq_api_key:
373
+ st.error("GROQ_API_KEY not configured. Please enter your API key.")
374
+ return False
375
+
376
+ # Create LLM client
377
+ llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
378
+
379
+ # Create database connection
380
+ if custom_db_params and st.session_state.db_source == "custom":
381
+ # Validate custom params
382
+ db_type = custom_db_params.get("db_type", "mysql")
383
+
384
+ if True:
385
+ if not all([custom_db_params.get("host"),
386
+ custom_db_params.get("database"),
387
+ custom_db_params.get("username")]):
388
+ st.error("Please fill in all required database fields.")
389
+ return False
390
+
391
+ # Create custom config
392
+ db_config = create_custom_db_config(**custom_db_params)
393
+
394
+ # Create custom connection
395
+ custom_connection = DatabaseConnection(db_config)
396
+
397
+ # Test connection
398
+ success, msg = custom_connection.test_connection()
399
+ if not success:
400
+ st.error(f"Connection failed: {msg}")
401
+ return False
402
+
403
+ st.session_state.custom_db_connection = custom_connection
404
+ st.session_state.custom_db_config = db_config
405
+
406
+ # Override the global db connection for the chatbot
407
+ # We need to create a chatbot with this custom connection
408
+ from chatbot import DatabaseChatbot
409
+ from database.schema_introspector import SchemaIntrospector
410
+ from rag import get_rag_engine
411
+ from sql import get_sql_generator, get_sql_validator
412
+ from router import get_query_router
413
+
414
+ chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
415
+ chatbot.db = custom_connection
416
+ chatbot.introspector = SchemaIntrospector()
417
+ chatbot.introspector.db = custom_connection
418
+ chatbot.rag_engine = get_rag_engine()
419
+ chatbot.sql_generator = get_sql_generator(db_type)
420
+ chatbot.sql_validator = get_sql_validator()
421
+ chatbot.router = get_query_router()
422
+ chatbot.llm_client = llm
423
+ chatbot._schema_initialized = False
424
+ chatbot._rag_initialized = False
425
+
426
+ # Set LLM client
427
+ chatbot.set_llm_client(llm)
428
+
429
+ # Initialize (introspect schema)
430
+ schema = chatbot.introspector.introspect(force_refresh=True)
431
+ chatbot.sql_validator.set_allowed_tables(schema.table_names)
432
+ chatbot._schema_initialized = True
433
+
434
+ st.session_state.chatbot = chatbot
435
+
436
+ else:
437
+ # Use environment-based connection (existing flow)
438
+ chatbot = create_chatbot(llm)
439
+ chatbot.set_llm_client(llm)
440
+
441
+ success, msg = chatbot.initialize()
442
+ if not success:
443
+ st.error(f"Initialization failed: {msg}")
444
+ return False
445
+
446
+ st.session_state.chatbot = chatbot
447
+ st.session_state.custom_db_connection = None
448
+
449
+ st.session_state.llm = llm
450
+ st.session_state.initialized = True
451
+ st.session_state.indexed = False # Reset index status on new connection
452
+
453
+ # Clear RAG index to ensure no data from previous DB connection persists
454
+ if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
455
+ chatbot.rag_engine.clear_index()
456
+
457
+ # Create memory with appropriate connection
458
+ db_conn = st.session_state.custom_db_connection or get_db()
459
+ st.session_state.memory = create_custom_memory(
460
+ st.session_state.session_id,
461
+ st.session_state.user_id,
462
+ db_conn,
463
+ llm,
464
+ st.session_state.enable_summarization,
465
+ st.session_state.summary_threshold
466
+ )
467
+
468
+ return True
469
+
470
+ except Exception as e:
471
+ st.error(f"Error: {str(e)}")
472
+ import traceback
473
+ st.error(traceback.format_exc())
474
+ return False
475
+
476
+
477
+ def index_data():
478
+ """Index text data from the database."""
479
+ if st.session_state.chatbot:
480
+ progress = st.progress(0)
481
+ status = st.empty()
482
+
483
+ # Get schema from the correct introspector
484
+ schema = st.session_state.chatbot.introspector.introspect()
485
+ total_tables = len(schema.tables)
486
+ indexed = 0
487
+
488
+ def progress_callback(table_name, docs):
489
+ nonlocal indexed
490
+ indexed += 1
491
+ progress.progress(indexed / total_tables)
492
+ status.text(f"Indexed {table_name}: {docs} documents")
493
+
494
+ total_docs = st.session_state.chatbot.index_text_data(progress_callback)
495
+ st.session_state.indexed = True
496
+ status.text(f"Total: {total_docs} documents indexed")
497
+
498
+
499
+ def render_schema_explorer():
500
+ """Render schema explorer in an expander."""
501
+ if not st.session_state.initialized:
502
+ return
503
+
504
+ with st.expander("πŸ“‹ Database Schema", expanded=False):
505
+ try:
506
+ schema = st.session_state.chatbot.introspector.introspect()
507
+
508
+ for table_name, table_info in schema.tables.items():
509
+ with st.container():
510
+ st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
511
+
512
+ cols = []
513
+ for col in table_info.columns:
514
+ pk = "πŸ”‘" if col.is_primary_key else ""
515
+ txt = "πŸ“" if col.is_text_type else ""
516
+ cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
517
+
518
+ st.caption(" | ".join(cols))
519
+ st.divider()
520
+ except Exception as e:
521
+ st.error(f"Error loading schema: {e}")
522
+
523
+
524
+ def render_chat_interface():
525
+ """Render the main chat interface."""
526
+ st.title("πŸ€– OnceDataBot")
527
+ st.caption("Schema-agnostic chatbot β€’ MySQL | PostgreSQL β€’ Powered by Groq (FREE!)")
528
+
529
+ # Schema explorer
530
+ render_schema_explorer()
531
+
532
+ # Chat container
533
+ chat_container = st.container()
534
+
535
+ with chat_container:
536
+ # Display messages
537
+ for msg in st.session_state.messages:
538
+ with st.chat_message(msg["role"]):
539
+ st.markdown(msg["content"])
540
+
541
+ # Show metadata for assistant messages
542
+ if msg["role"] == "assistant" and "metadata" in msg:
543
+ meta = msg["metadata"]
544
+ if meta.get("query_type"):
545
+ st.caption(f"Query type: {meta['query_type']}")
546
+ if meta.get("sql_query"):
547
+ with st.expander("SQL Query"):
548
+ st.code(meta["sql_query"], language="sql")
549
+
550
+ # Chat input
551
+ if prompt := st.chat_input("Ask about your data..."):
552
+ if not st.session_state.initialized:
553
+ st.error("Please connect to a database first!")
554
+ return
555
+
556
+ # Add user message
557
+ st.session_state.messages.append({"role": "user", "content": prompt})
558
+ if st.session_state.memory:
559
+ st.session_state.memory.add_message("user", prompt)
560
+
561
+ with st.chat_message("user"):
562
+ st.markdown(prompt)
563
+
564
+ # Get response
565
+ with st.chat_message("assistant"):
566
+ with st.spinner("Thinking..."):
567
+ response = st.session_state.chatbot.chat(
568
+ prompt,
569
+ st.session_state.memory
570
+ )
571
+
572
+ st.markdown(response.answer)
573
+
574
+ # Show metadata
575
+ if response.query_type != "general":
576
+ st.caption(f"Query type: {response.query_type}")
577
+
578
+ if response.sql_query:
579
+ with st.expander("SQL Query"):
580
+ st.code(response.sql_query, language="sql")
581
+
582
+ if response.sql_results:
583
+ with st.expander("Results"):
584
+ st.dataframe(response.sql_results)
585
+
586
+ # Save to memory
587
+ st.session_state.messages.append({
588
+ "role": "assistant",
589
+ "content": response.answer,
590
+ "metadata": {
591
+ "query_type": response.query_type,
592
+ "sql_query": response.sql_query
593
+ }
594
+ })
595
+ if st.session_state.memory:
596
+ st.session_state.memory.add_message("assistant", response.answer)
597
+
598
+
599
+ def main():
600
+ """Main application entry point."""
601
+ init_session_state()
602
+
603
+ # Auto-connect to environment database on first load
604
+ if "auto_connect_attempted" not in st.session_state:
605
+ st.session_state.auto_connect_attempted = True
606
+ if st.session_state.db_source == "environment":
607
+ success = initialize_chatbot()
608
+ if success:
609
+ st.toast("βœ… Auto-connected to database!")
610
+
611
+ render_sidebar()
612
+ render_chat_interface()
613
+
614
+
615
+ if __name__ == "__main__":
616
+ main()