Vanshcc commited on
Commit
7f8cace
·
verified ·
1 Parent(s): b0ebd7c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +636 -636
app.py CHANGED
@@ -1,636 +1,636 @@
1
- """
2
- Schema-Agnostic Database Chatbot - Streamlit Application
3
-
4
- A production-grade chatbot that connects to ANY database
5
- (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
- through RAG and Text-to-SQL.
7
-
8
- Uses Groq for FREE LLM inference!
9
- """
10
-
11
- import os
12
- from pathlib import Path
13
-
14
- # Load .env FIRST before any other imports
15
- from dotenv import load_dotenv
16
- load_dotenv(Path(__file__).parent / ".env")
17
-
18
- import streamlit as st
19
- import uuid
20
- from datetime import datetime
21
-
22
- # Page config must be first
23
- st.set_page_config(
24
- page_title="OnceDataBot",
25
- page_icon="🤖",
26
- layout="wide",
27
- initial_sidebar_state="expanded"
28
- )
29
-
30
- # Imports
31
- from config import config, DatabaseConfig, DatabaseType
32
- from database import get_db, get_schema, get_introspector
33
- from database.connection import DatabaseConnection
34
- from llm import create_llm_client
35
- from chatbot import create_chatbot, DatabaseChatbot
36
- from memory import ChatMemory, EnhancedChatMemory
37
-
38
-
39
- # Groq models (all FREE!)
40
- GROQ_MODELS = [
41
- "llama-3.3-70b-versatile",
42
- "llama-3.1-8b-instant",
43
- "mixtral-8x7b-32768",
44
- "gemma2-9b-it"
45
- ]
46
-
47
- # Database types
48
- DB_TYPES = {
49
- "MySQL": "mysql",
50
- "PostgreSQL": "postgresql",
51
- "SQLite": "sqlite"
52
- }
53
-
54
-
55
- def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
56
- """Create a custom database configuration from user input."""
57
- db_config = DatabaseConfig.__new__(DatabaseConfig)
58
-
59
- # Set database type
60
- db_config.db_type = DatabaseType(db_type)
61
-
62
- # Set connection parameters
63
- db_config.host = kwargs.get("host", "")
64
- db_config.port = kwargs.get("port", 3306 if db_type == "mysql" else 5432)
65
- db_config.database = kwargs.get("database", "")
66
- db_config.username = kwargs.get("username", "")
67
- db_config.password = kwargs.get("password", "")
68
- db_config.ssl_ca = kwargs.get("ssl_ca", None)
69
- db_config.sqlite_path = kwargs.get("sqlite_path", "./chatbot.db")
70
-
71
- return db_config
72
-
73
-
74
- def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
75
- enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
76
- """Create enhanced memory with a custom database connection."""
77
- return EnhancedChatMemory(
78
- session_id=session_id,
79
- user_id=user_id,
80
- max_messages=20,
81
- db_connection=db_connection,
82
- llm_client=llm_client,
83
- enable_summarization=enable_summarization,
84
- summary_threshold=summary_threshold
85
- )
86
-
87
-
88
- def init_session_state():
89
- """Initialize Streamlit session state."""
90
- if "session_id" not in st.session_state:
91
- st.session_state.session_id = str(uuid.uuid4())
92
-
93
- if "messages" not in st.session_state:
94
- st.session_state.messages = []
95
-
96
- if "chatbot" not in st.session_state:
97
- st.session_state.chatbot = None
98
-
99
- if "initialized" not in st.session_state:
100
- st.session_state.initialized = False
101
-
102
- if "user_id" not in st.session_state:
103
- st.session_state.user_id = "default"
104
-
105
- if "enable_summarization" not in st.session_state:
106
- st.session_state.enable_summarization = True
107
-
108
- if "summary_threshold" not in st.session_state:
109
- st.session_state.summary_threshold = 10
110
-
111
- if "memory" not in st.session_state:
112
- st.session_state.memory = None
113
-
114
- if "indexed" not in st.session_state:
115
- st.session_state.indexed = False
116
-
117
- if "db_source" not in st.session_state:
118
- st.session_state.db_source = "environment" # "environment" or "custom"
119
-
120
- if "custom_db_config" not in st.session_state:
121
- st.session_state.custom_db_config = None
122
-
123
- if "custom_db_connection" not in st.session_state:
124
- st.session_state.custom_db_connection = None
125
-
126
-
127
- def render_database_config():
128
- """Render database configuration section in sidebar."""
129
- st.subheader("🗄️ Database Configuration")
130
-
131
- # Database source selection
132
- db_source = st.radio(
133
- "Database Source",
134
- options=["Use Environment Variables", "Custom Database"],
135
- index=0 if st.session_state.db_source == "environment" else 1,
136
- key="db_source_radio",
137
- help="Choose to use .env settings or enter custom credentials"
138
- )
139
-
140
- st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
141
-
142
- if st.session_state.db_source == "environment":
143
- # Show current environment config
144
- current_db_type = config.database.db_type.value.upper()
145
- st.info(f"📌 Using {current_db_type} from environment")
146
- if config.database.is_sqlite:
147
- st.caption(f"Path: {config.database.sqlite_path}")
148
- else:
149
- st.caption(f"Host: {config.database.host}")
150
- return None
151
-
152
- else:
153
- # Custom database configuration
154
- st.markdown("##### Enter Database Credentials")
155
-
156
- # Database type selector
157
- db_type_label = st.selectbox(
158
- "Database Type",
159
- options=list(DB_TYPES.keys()),
160
- index=0,
161
- key="custom_db_type"
162
- )
163
- db_type = DB_TYPES[db_type_label]
164
-
165
- if db_type == "sqlite":
166
- # SQLite only needs file path
167
- sqlite_path = st.text_input(
168
- "Database File Path",
169
- value="./chatbot.db",
170
- key="sqlite_path_input",
171
- help="Path to SQLite database file (will be created if doesn't exist)"
172
- )
173
-
174
- return {
175
- "db_type": db_type,
176
- "sqlite_path": sqlite_path
177
- }
178
-
179
- else:
180
- # MySQL or PostgreSQL
181
- col1, col2 = st.columns([3, 1])
182
- with col1:
183
- host = st.text_input(
184
- "Host",
185
- value="",
186
- key="db_host_input",
187
- placeholder="your-database-host.com"
188
- )
189
- with col2:
190
- default_port = 3306 if db_type == "mysql" else 5432
191
- port = st.number_input(
192
- "Port",
193
- value=default_port,
194
- min_value=1,
195
- max_value=65535,
196
- key="db_port_input"
197
- )
198
-
199
- database = st.text_input(
200
- "Database Name",
201
- value="",
202
- key="db_name_input",
203
- placeholder="your_database"
204
- )
205
-
206
- username = st.text_input(
207
- "Username",
208
- value="",
209
- key="db_user_input",
210
- placeholder="your_username"
211
- )
212
-
213
- password = st.text_input(
214
- "Password",
215
- value="",
216
- type="password",
217
- key="db_pass_input"
218
- )
219
-
220
- # Optional SSL
221
- with st.expander("🔒 SSL Settings (Optional)"):
222
- ssl_ca = st.text_input(
223
- "SSL CA Certificate Path",
224
- value="",
225
- key="ssl_ca_input",
226
- help="Path to SSL CA certificate file (for cloud databases like Aiven)"
227
- )
228
-
229
- return {
230
- "db_type": db_type,
231
- "host": host,
232
- "port": int(port),
233
- "database": database,
234
- "username": username,
235
- "password": password,
236
- "ssl_ca": ssl_ca if ssl_ca else None
237
- }
238
-
239
-
240
- def render_sidebar():
241
- """Render the configuration sidebar."""
242
- with st.sidebar:
243
- st.title("⚙️ Settings")
244
-
245
- # User Profile
246
- st.subheader("👤 User Profile")
247
- user_id = st.text_input(
248
- "User ID / Name",
249
- value=st.session_state.get("user_id", "default"),
250
- key="user_id_input",
251
- help="Your unique ID for private memory storage"
252
- )
253
- if user_id != st.session_state.get("user_id"):
254
- st.session_state.user_id = user_id
255
- st.session_state.session_id = str(uuid.uuid4())
256
- st.session_state.messages = []
257
-
258
- # Recreate memory for new user
259
- if st.session_state.custom_db_connection:
260
- st.session_state.memory = create_custom_memory(
261
- st.session_state.session_id,
262
- user_id,
263
- st.session_state.custom_db_connection,
264
- st.session_state.get("llm"),
265
- st.session_state.enable_summarization,
266
- st.session_state.summary_threshold
267
- )
268
- elif st.session_state.initialized:
269
- from memory import create_enhanced_memory
270
- st.session_state.memory = create_enhanced_memory(
271
- st.session_state.session_id,
272
- user_id=user_id,
273
- enable_summarization=st.session_state.enable_summarization,
274
- summary_threshold=st.session_state.summary_threshold
275
- )
276
-
277
- if st.session_state.memory:
278
- st.session_state.memory.clear_user_history()
279
- st.rerun()
280
-
281
- st.divider()
282
-
283
- # Database Configuration
284
- custom_db_params = render_database_config()
285
-
286
- st.divider()
287
-
288
- # LLM Configuration
289
- st.subheader("🤖 LLM Configuration")
290
-
291
- # Model selection only - API key from environment
292
- groq_model = st.selectbox(
293
- "Model",
294
- options=GROQ_MODELS,
295
- index=0,
296
- key="groq_model_select"
297
- )
298
-
299
- # Show status of API key
300
- if os.getenv("GROQ_API_KEY"):
301
- st.success("✓ API Key configured")
302
- else:
303
- st.warning("⚠️ GROQ_API_KEY not set in environment")
304
-
305
- st.divider()
306
-
307
- # Initialize Button
308
- if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
309
- with st.spinner("Connecting to database..."):
310
- success = initialize_chatbot(custom_db_params, None, groq_model)
311
- if success:
312
- st.success("✅ Connected!")
313
- st.rerun()
314
-
315
- # Index Button (after initialization)
316
- if st.session_state.initialized:
317
- if st.button("📚 Index Text Data", use_container_width=True):
318
- with st.spinner("Indexing text data..."):
319
- index_data()
320
- st.success("✅ Indexed!")
321
- st.rerun()
322
-
323
- st.divider()
324
-
325
- # Status
326
- st.subheader("📊 Status")
327
- if st.session_state.initialized:
328
- # Show database type
329
- if st.session_state.custom_db_connection:
330
- db_type = st.session_state.custom_db_connection.db_type.value.upper()
331
- else:
332
- db_type = get_db().db_type.value.upper()
333
-
334
- st.success(f"Database: {db_type} ✓")
335
-
336
- try:
337
- schema = get_schema()
338
- st.info(f"Tables: {len(schema.tables)}")
339
- except:
340
- st.warning("Schema not loaded")
341
-
342
- if st.session_state.indexed:
343
- from rag import get_rag_engine
344
- engine = get_rag_engine()
345
- st.info(f"Indexed Docs: {engine.document_count}")
346
- else:
347
- st.warning("Not connected")
348
-
349
- # New Chat
350
- if st.button("➕ New Chat", use_container_width=True, type="secondary"):
351
- if st.session_state.memory:
352
- st.session_state.memory.clear()
353
-
354
- st.session_state.messages = []
355
- st.session_state.session_id = str(uuid.uuid4())
356
-
357
- current_user = st.session_state.get("user_id", "default")
358
-
359
- if st.session_state.custom_db_connection:
360
- st.session_state.memory = create_custom_memory(
361
- st.session_state.session_id,
362
- current_user,
363
- st.session_state.custom_db_connection,
364
- st.session_state.get("llm"),
365
- st.session_state.enable_summarization,
366
- st.session_state.summary_threshold
367
- )
368
- elif st.session_state.initialized:
369
- from memory import create_enhanced_memory
370
- st.session_state.memory = create_enhanced_memory(
371
- st.session_state.session_id,
372
- user_id=current_user,
373
- enable_summarization=st.session_state.enable_summarization,
374
- summary_threshold=st.session_state.summary_threshold
375
- )
376
- if st.session_state.get("llm"):
377
- st.session_state.memory.set_llm_client(st.session_state.llm)
378
-
379
- st.rerun()
380
-
381
- # Disconnect button (when using custom DB)
382
- if st.session_state.initialized and st.session_state.db_source == "custom":
383
- if st.button("🔌 Disconnect", use_container_width=True):
384
- if st.session_state.custom_db_connection:
385
- st.session_state.custom_db_connection.close()
386
- st.session_state.custom_db_connection = None
387
- st.session_state.chatbot = None
388
- st.session_state.initialized = False
389
- st.session_state.indexed = False
390
- st.session_state.memory = None
391
- st.success("Disconnected!")
392
- st.rerun()
393
-
394
-
395
- def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
396
- """Initialize the chatbot with either environment or custom database."""
397
- try:
398
- # Get API key
399
- groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
400
- groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
401
-
402
- if not groq_api_key:
403
- st.error("GROQ_API_KEY not configured. Please enter your API key.")
404
- return False
405
-
406
- # Create LLM client
407
- llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
408
-
409
- # Create database connection
410
- if custom_db_params and st.session_state.db_source == "custom":
411
- # Validate custom params
412
- db_type = custom_db_params.get("db_type", "mysql")
413
-
414
- if db_type == "sqlite":
415
- if not custom_db_params.get("sqlite_path"):
416
- st.error("Please provide SQLite database path.")
417
- return False
418
- else:
419
- if not all([custom_db_params.get("host"),
420
- custom_db_params.get("database"),
421
- custom_db_params.get("username")]):
422
- st.error("Please fill in all required database fields.")
423
- return False
424
-
425
- # Create custom config
426
- db_config = create_custom_db_config(**custom_db_params)
427
-
428
- # Create custom connection
429
- custom_connection = DatabaseConnection(db_config)
430
-
431
- # Test connection
432
- success, msg = custom_connection.test_connection()
433
- if not success:
434
- st.error(f"Connection failed: {msg}")
435
- return False
436
-
437
- st.session_state.custom_db_connection = custom_connection
438
- st.session_state.custom_db_config = db_config
439
-
440
- # Override the global db connection for the chatbot
441
- # We need to create a chatbot with this custom connection
442
- from chatbot import DatabaseChatbot
443
- from database.schema_introspector import SchemaIntrospector
444
- from rag import get_rag_engine
445
- from sql import get_sql_generator, get_sql_validator
446
- from router import get_query_router
447
-
448
- chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
449
- chatbot.db = custom_connection
450
- chatbot.introspector = SchemaIntrospector()
451
- chatbot.introspector.db = custom_connection
452
- chatbot.rag_engine = get_rag_engine()
453
- chatbot.sql_generator = get_sql_generator(db_type)
454
- chatbot.sql_validator = get_sql_validator()
455
- chatbot.router = get_query_router()
456
- chatbot.llm_client = llm
457
- chatbot._schema_initialized = False
458
- chatbot._rag_initialized = False
459
-
460
- # Set LLM client
461
- chatbot.set_llm_client(llm)
462
-
463
- # Initialize (introspect schema)
464
- schema = chatbot.introspector.introspect(force_refresh=True)
465
- chatbot.sql_validator.set_allowed_tables(schema.table_names)
466
- chatbot._schema_initialized = True
467
-
468
- st.session_state.chatbot = chatbot
469
-
470
- else:
471
- # Use environment-based connection (existing flow)
472
- chatbot = create_chatbot(llm)
473
- chatbot.set_llm_client(llm)
474
-
475
- success, msg = chatbot.initialize()
476
- if not success:
477
- st.error(f"Initialization failed: {msg}")
478
- return False
479
-
480
- st.session_state.chatbot = chatbot
481
- st.session_state.custom_db_connection = None
482
-
483
- st.session_state.llm = llm
484
- st.session_state.initialized = True
485
-
486
- # Create memory with appropriate connection
487
- db_conn = st.session_state.custom_db_connection or get_db()
488
- st.session_state.memory = create_custom_memory(
489
- st.session_state.session_id,
490
- st.session_state.user_id,
491
- db_conn,
492
- llm,
493
- st.session_state.enable_summarization,
494
- st.session_state.summary_threshold
495
- )
496
-
497
- return True
498
-
499
- except Exception as e:
500
- st.error(f"Error: {str(e)}")
501
- import traceback
502
- st.error(traceback.format_exc())
503
- return False
504
-
505
-
506
- def index_data():
507
- """Index text data from the database."""
508
- if st.session_state.chatbot:
509
- progress = st.progress(0)
510
- status = st.empty()
511
-
512
- # Get schema from the correct introspector
513
- schema = st.session_state.chatbot.introspector.introspect()
514
- total_tables = len(schema.tables)
515
- indexed = 0
516
-
517
- def progress_callback(table_name, docs):
518
- nonlocal indexed
519
- indexed += 1
520
- progress.progress(indexed / total_tables)
521
- status.text(f"Indexed {table_name}: {docs} documents")
522
-
523
- total_docs = st.session_state.chatbot.index_text_data(progress_callback)
524
- st.session_state.indexed = True
525
- status.text(f"Total: {total_docs} documents indexed")
526
-
527
-
528
- def render_schema_explorer():
529
- """Render schema explorer in an expander."""
530
- if not st.session_state.initialized:
531
- return
532
-
533
- with st.expander("📋 Database Schema", expanded=False):
534
- try:
535
- schema = st.session_state.chatbot.introspector.introspect()
536
-
537
- for table_name, table_info in schema.tables.items():
538
- with st.container():
539
- st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
540
-
541
- cols = []
542
- for col in table_info.columns:
543
- pk = "🔑" if col.is_primary_key else ""
544
- txt = "📝" if col.is_text_type else ""
545
- cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
546
-
547
- st.caption(" | ".join(cols))
548
- st.divider()
549
- except Exception as e:
550
- st.error(f"Error loading schema: {e}")
551
-
552
-
553
- def render_chat_interface():
554
- """Render the main chat interface."""
555
- st.title("🤖 Database Copilot")
556
- st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL | SQLite • Powered by Groq (FREE!)")
557
-
558
- # Schema explorer
559
- render_schema_explorer()
560
-
561
- # Chat container
562
- chat_container = st.container()
563
-
564
- with chat_container:
565
- # Display messages
566
- for msg in st.session_state.messages:
567
- with st.chat_message(msg["role"]):
568
- st.markdown(msg["content"])
569
-
570
- # Show metadata for assistant messages
571
- if msg["role"] == "assistant" and "metadata" in msg:
572
- meta = msg["metadata"]
573
- if meta.get("query_type"):
574
- st.caption(f"Query type: {meta['query_type']}")
575
- if meta.get("sql_query"):
576
- with st.expander("SQL Query"):
577
- st.code(meta["sql_query"], language="sql")
578
-
579
- # Chat input
580
- if prompt := st.chat_input("Ask about your data..."):
581
- if not st.session_state.initialized:
582
- st.error("Please connect to a database first!")
583
- return
584
-
585
- # Add user message
586
- st.session_state.messages.append({"role": "user", "content": prompt})
587
- if st.session_state.memory:
588
- st.session_state.memory.add_message("user", prompt)
589
-
590
- with st.chat_message("user"):
591
- st.markdown(prompt)
592
-
593
- # Get response
594
- with st.chat_message("assistant"):
595
- with st.spinner("Thinking..."):
596
- response = st.session_state.chatbot.chat(
597
- prompt,
598
- st.session_state.memory
599
- )
600
-
601
- st.markdown(response.answer)
602
-
603
- # Show metadata
604
- if response.query_type != "general":
605
- st.caption(f"Query type: {response.query_type}")
606
-
607
- if response.sql_query:
608
- with st.expander("SQL Query"):
609
- st.code(response.sql_query, language="sql")
610
-
611
- if response.sql_results:
612
- with st.expander("Results"):
613
- st.dataframe(response.sql_results)
614
-
615
- # Save to memory
616
- st.session_state.messages.append({
617
- "role": "assistant",
618
- "content": response.answer,
619
- "metadata": {
620
- "query_type": response.query_type,
621
- "sql_query": response.sql_query
622
- }
623
- })
624
- if st.session_state.memory:
625
- st.session_state.memory.add_message("assistant", response.answer)
626
-
627
-
628
- def main():
629
- """Main application entry point."""
630
- init_session_state()
631
- render_sidebar()
632
- render_chat_interface()
633
-
634
-
635
- if __name__ == "__main__":
636
- main()
 
1
+ """
2
+ Schema-Agnostic Database Chatbot - Streamlit Application
3
+
4
+ A production-grade chatbot that connects to ANY database
5
+ (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
+ through RAG and Text-to-SQL.
7
+
8
+ Uses Groq for FREE LLM inference!
9
+ """
10
+
11
+ import os
12
+ from pathlib import Path
13
+
14
+ # Load .env FIRST before any other imports
15
+ from dotenv import load_dotenv
16
+ load_dotenv(Path(__file__).parent / ".env")
17
+
18
+ import streamlit as st
19
+ import uuid
20
+ from datetime import datetime
21
+
22
+ # Page config must be first
23
+ st.set_page_config(
24
+ page_title="OnceDataBot",
25
+ page_icon="🤖",
26
+ layout="wide",
27
+ initial_sidebar_state="expanded"
28
+ )
29
+
30
+ # Imports
31
+ from config import config, DatabaseConfig, DatabaseType
32
+ from database import get_db, get_schema, get_introspector
33
+ from database.connection import DatabaseConnection
34
+ from llm import create_llm_client
35
+ from chatbot import create_chatbot, DatabaseChatbot
36
+ from memory import ChatMemory, EnhancedChatMemory
37
+
38
+
39
+ # Groq models (all FREE!)
40
+ GROQ_MODELS = [
41
+ "llama-3.3-70b-versatile",
42
+ "llama-3.1-8b-instant",
43
+ "mixtral-8x7b-32768",
44
+ "gemma2-9b-it"
45
+ ]
46
+
47
+ # Database types
48
+ DB_TYPES = {
49
+ "MySQL": "mysql",
50
+ "PostgreSQL": "postgresql",
51
+ "SQLite": "sqlite"
52
+ }
53
+
54
+
55
+ def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
56
+ """Create a custom database configuration from user input."""
57
+ db_config = DatabaseConfig.__new__(DatabaseConfig)
58
+
59
+ # Set database type
60
+ db_config.db_type = DatabaseType(db_type)
61
+
62
+ # Set connection parameters
63
+ db_config.host = kwargs.get("host", "")
64
+ db_config.port = kwargs.get("port", 3306 if db_type == "mysql" else 5432)
65
+ db_config.database = kwargs.get("database", "")
66
+ db_config.username = kwargs.get("username", "")
67
+ db_config.password = kwargs.get("password", "")
68
+ db_config.ssl_ca = kwargs.get("ssl_ca", None)
69
+ db_config.sqlite_path = kwargs.get("sqlite_path", "./chatbot.db")
70
+
71
+ return db_config
72
+
73
+
74
+ def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
75
+ enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
76
+ """Create enhanced memory with a custom database connection."""
77
+ return EnhancedChatMemory(
78
+ session_id=session_id,
79
+ user_id=user_id,
80
+ max_messages=20,
81
+ db_connection=db_connection,
82
+ llm_client=llm_client,
83
+ enable_summarization=enable_summarization,
84
+ summary_threshold=summary_threshold
85
+ )
86
+
87
+
88
+ def init_session_state():
89
+ """Initialize Streamlit session state."""
90
+ if "session_id" not in st.session_state:
91
+ st.session_state.session_id = str(uuid.uuid4())
92
+
93
+ if "messages" not in st.session_state:
94
+ st.session_state.messages = []
95
+
96
+ if "chatbot" not in st.session_state:
97
+ st.session_state.chatbot = None
98
+
99
+ if "initialized" not in st.session_state:
100
+ st.session_state.initialized = False
101
+
102
+ if "user_id" not in st.session_state:
103
+ st.session_state.user_id = "default"
104
+
105
+ if "enable_summarization" not in st.session_state:
106
+ st.session_state.enable_summarization = True
107
+
108
+ if "summary_threshold" not in st.session_state:
109
+ st.session_state.summary_threshold = 10
110
+
111
+ if "memory" not in st.session_state:
112
+ st.session_state.memory = None
113
+
114
+ if "indexed" not in st.session_state:
115
+ st.session_state.indexed = False
116
+
117
+ if "db_source" not in st.session_state:
118
+ st.session_state.db_source = "environment" # "environment" or "custom"
119
+
120
+ if "custom_db_config" not in st.session_state:
121
+ st.session_state.custom_db_config = None
122
+
123
+ if "custom_db_connection" not in st.session_state:
124
+ st.session_state.custom_db_connection = None
125
+
126
+
127
+ def render_database_config():
128
+ """Render database configuration section in sidebar."""
129
+ st.subheader("🗄️ Database Configuration")
130
+
131
+ # Database source selection
132
+ db_source = st.radio(
133
+ "Database Source",
134
+ options=["Use Environment Variables", "Custom Database"],
135
+ index=0 if st.session_state.db_source == "environment" else 1,
136
+ key="db_source_radio",
137
+ help="Choose to use .env settings or enter custom credentials"
138
+ )
139
+
140
+ st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
141
+
142
+ if st.session_state.db_source == "environment":
143
+ # Show current environment config
144
+ current_db_type = config.database.db_type.value.upper()
145
+ st.info(f"📌 Using {current_db_type} from environment")
146
+ if config.database.is_sqlite:
147
+ st.caption(f"Path: {config.database.sqlite_path}")
148
+ else:
149
+ st.caption(f"Host: {config.database.host}")
150
+ return None
151
+
152
+ else:
153
+ # Custom database configuration
154
+ st.markdown("##### Enter Database Credentials")
155
+
156
+ # Database type selector
157
+ db_type_label = st.selectbox(
158
+ "Database Type",
159
+ options=list(DB_TYPES.keys()),
160
+ index=0,
161
+ key="custom_db_type"
162
+ )
163
+ db_type = DB_TYPES[db_type_label]
164
+
165
+ if db_type == "sqlite":
166
+ # SQLite only needs file path
167
+ sqlite_path = st.text_input(
168
+ "Database File Path",
169
+ value="./chatbot.db",
170
+ key="sqlite_path_input",
171
+ help="Path to SQLite database file (will be created if doesn't exist)"
172
+ )
173
+
174
+ return {
175
+ "db_type": db_type,
176
+ "sqlite_path": sqlite_path
177
+ }
178
+
179
+ else:
180
+ # MySQL or PostgreSQL
181
+ col1, col2 = st.columns([3, 1])
182
+ with col1:
183
+ host = st.text_input(
184
+ "Host",
185
+ value="",
186
+ key="db_host_input",
187
+ placeholder="your-database-host.com"
188
+ )
189
+ with col2:
190
+ default_port = 3306 if db_type == "mysql" else 5432
191
+ port = st.number_input(
192
+ "Port",
193
+ value=default_port,
194
+ min_value=1,
195
+ max_value=65535,
196
+ key="db_port_input"
197
+ )
198
+
199
+ database = st.text_input(
200
+ "Database Name",
201
+ value="",
202
+ key="db_name_input",
203
+ placeholder="your_database"
204
+ )
205
+
206
+ username = st.text_input(
207
+ "Username",
208
+ value="",
209
+ key="db_user_input",
210
+ placeholder="your_username"
211
+ )
212
+
213
+ password = st.text_input(
214
+ "Password",
215
+ value="",
216
+ type="password",
217
+ key="db_pass_input"
218
+ )
219
+
220
+ # Optional SSL
221
+ with st.expander("🔒 SSL Settings (Optional)"):
222
+ ssl_ca = st.text_input(
223
+ "SSL CA Certificate Path",
224
+ value="",
225
+ key="ssl_ca_input",
226
+ help="Path to SSL CA certificate file (for cloud databases like Aiven)"
227
+ )
228
+
229
+ return {
230
+ "db_type": db_type,
231
+ "host": host,
232
+ "port": int(port),
233
+ "database": database,
234
+ "username": username,
235
+ "password": password,
236
+ "ssl_ca": ssl_ca if ssl_ca else None
237
+ }
238
+
239
+
240
+ def render_sidebar():
241
+ """Render the configuration sidebar."""
242
+ with st.sidebar:
243
+ st.title("⚙️ Settings")
244
+
245
+ # User Profile
246
+ st.subheader("👤 User Profile")
247
+ user_id = st.text_input(
248
+ "User ID / Name",
249
+ value=st.session_state.get("user_id", "default"),
250
+ key="user_id_input",
251
+ help="Your unique ID for private memory storage"
252
+ )
253
+ if user_id != st.session_state.get("user_id"):
254
+ st.session_state.user_id = user_id
255
+ st.session_state.session_id = str(uuid.uuid4())
256
+ st.session_state.messages = []
257
+
258
+ # Recreate memory for new user
259
+ if st.session_state.custom_db_connection:
260
+ st.session_state.memory = create_custom_memory(
261
+ st.session_state.session_id,
262
+ user_id,
263
+ st.session_state.custom_db_connection,
264
+ st.session_state.get("llm"),
265
+ st.session_state.enable_summarization,
266
+ st.session_state.summary_threshold
267
+ )
268
+ elif st.session_state.initialized:
269
+ from memory import create_enhanced_memory
270
+ st.session_state.memory = create_enhanced_memory(
271
+ st.session_state.session_id,
272
+ user_id=user_id,
273
+ enable_summarization=st.session_state.enable_summarization,
274
+ summary_threshold=st.session_state.summary_threshold
275
+ )
276
+
277
+ if st.session_state.memory:
278
+ st.session_state.memory.clear_user_history()
279
+ st.rerun()
280
+
281
+ st.divider()
282
+
283
+ # Database Configuration
284
+ custom_db_params = render_database_config()
285
+
286
+ st.divider()
287
+
288
+ # LLM Configuration
289
+ st.subheader("🤖 LLM Configuration")
290
+
291
+ # Model selection only - API key from environment
292
+ groq_model = st.selectbox(
293
+ "Model",
294
+ options=GROQ_MODELS,
295
+ index=0,
296
+ key="groq_model_select"
297
+ )
298
+
299
+ # Show status of API key
300
+ if os.getenv("GROQ_API_KEY"):
301
+ st.success("✓ API Key configured")
302
+ else:
303
+ st.warning("⚠️ GROQ_API_KEY not set in environment")
304
+
305
+ st.divider()
306
+
307
+ # Initialize Button
308
+ if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
309
+ with st.spinner("Connecting to database..."):
310
+ success = initialize_chatbot(custom_db_params, None, groq_model)
311
+ if success:
312
+ st.success("✅ Connected!")
313
+ st.rerun()
314
+
315
+ # Index Button (after initialization)
316
+ if st.session_state.initialized:
317
+ if st.button("📚 Index Text Data", use_container_width=True):
318
+ with st.spinner("Indexing text data..."):
319
+ index_data()
320
+ st.success("✅ Indexed!")
321
+ st.rerun()
322
+
323
+ st.divider()
324
+
325
+ # Status
326
+ st.subheader("📊 Status")
327
+ if st.session_state.initialized:
328
+ # Show database type
329
+ if st.session_state.custom_db_connection:
330
+ db_type = st.session_state.custom_db_connection.db_type.value.upper()
331
+ else:
332
+ db_type = get_db().db_type.value.upper()
333
+
334
+ st.success(f"Database: {db_type} ✓")
335
+
336
+ try:
337
+ schema = get_schema()
338
+ st.info(f"Tables: {len(schema.tables)}")
339
+ except:
340
+ st.warning("Schema not loaded")
341
+
342
+ if st.session_state.indexed:
343
+ from rag import get_rag_engine
344
+ engine = get_rag_engine()
345
+ st.info(f"Indexed Docs: {engine.document_count}")
346
+ else:
347
+ st.warning("Not connected")
348
+
349
+ # New Chat
350
+ if st.button("➕ New Chat", use_container_width=True, type="secondary"):
351
+ if st.session_state.memory:
352
+ st.session_state.memory.clear()
353
+
354
+ st.session_state.messages = []
355
+ st.session_state.session_id = str(uuid.uuid4())
356
+
357
+ current_user = st.session_state.get("user_id", "default")
358
+
359
+ if st.session_state.custom_db_connection:
360
+ st.session_state.memory = create_custom_memory(
361
+ st.session_state.session_id,
362
+ current_user,
363
+ st.session_state.custom_db_connection,
364
+ st.session_state.get("llm"),
365
+ st.session_state.enable_summarization,
366
+ st.session_state.summary_threshold
367
+ )
368
+ elif st.session_state.initialized:
369
+ from memory import create_enhanced_memory
370
+ st.session_state.memory = create_enhanced_memory(
371
+ st.session_state.session_id,
372
+ user_id=current_user,
373
+ enable_summarization=st.session_state.enable_summarization,
374
+ summary_threshold=st.session_state.summary_threshold
375
+ )
376
+ if st.session_state.get("llm"):
377
+ st.session_state.memory.set_llm_client(st.session_state.llm)
378
+
379
+ st.rerun()
380
+
381
+ # Disconnect button (when using custom DB)
382
+ if st.session_state.initialized and st.session_state.db_source == "custom":
383
+ if st.button("🔌 Disconnect", use_container_width=True):
384
+ if st.session_state.custom_db_connection:
385
+ st.session_state.custom_db_connection.close()
386
+ st.session_state.custom_db_connection = None
387
+ st.session_state.chatbot = None
388
+ st.session_state.initialized = False
389
+ st.session_state.indexed = False
390
+ st.session_state.memory = None
391
+ st.success("Disconnected!")
392
+ st.rerun()
393
+
394
+
395
+ def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
396
+ """Initialize the chatbot with either environment or custom database."""
397
+ try:
398
+ # Get API key
399
+ groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
400
+ groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
401
+
402
+ if not groq_api_key:
403
+ st.error("GROQ_API_KEY not configured. Please enter your API key.")
404
+ return False
405
+
406
+ # Create LLM client
407
+ llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
408
+
409
+ # Create database connection
410
+ if custom_db_params and st.session_state.db_source == "custom":
411
+ # Validate custom params
412
+ db_type = custom_db_params.get("db_type", "mysql")
413
+
414
+ if db_type == "sqlite":
415
+ if not custom_db_params.get("sqlite_path"):
416
+ st.error("Please provide SQLite database path.")
417
+ return False
418
+ else:
419
+ if not all([custom_db_params.get("host"),
420
+ custom_db_params.get("database"),
421
+ custom_db_params.get("username")]):
422
+ st.error("Please fill in all required database fields.")
423
+ return False
424
+
425
+ # Create custom config
426
+ db_config = create_custom_db_config(**custom_db_params)
427
+
428
+ # Create custom connection
429
+ custom_connection = DatabaseConnection(db_config)
430
+
431
+ # Test connection
432
+ success, msg = custom_connection.test_connection()
433
+ if not success:
434
+ st.error(f"Connection failed: {msg}")
435
+ return False
436
+
437
+ st.session_state.custom_db_connection = custom_connection
438
+ st.session_state.custom_db_config = db_config
439
+
440
+ # Override the global db connection for the chatbot
441
+ # We need to create a chatbot with this custom connection
442
+ from chatbot import DatabaseChatbot
443
+ from database.schema_introspector import SchemaIntrospector
444
+ from rag import get_rag_engine
445
+ from sql import get_sql_generator, get_sql_validator
446
+ from router import get_query_router
447
+
448
+ chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
449
+ chatbot.db = custom_connection
450
+ chatbot.introspector = SchemaIntrospector()
451
+ chatbot.introspector.db = custom_connection
452
+ chatbot.rag_engine = get_rag_engine()
453
+ chatbot.sql_generator = get_sql_generator(db_type)
454
+ chatbot.sql_validator = get_sql_validator()
455
+ chatbot.router = get_query_router()
456
+ chatbot.llm_client = llm
457
+ chatbot._schema_initialized = False
458
+ chatbot._rag_initialized = False
459
+
460
+ # Set LLM client
461
+ chatbot.set_llm_client(llm)
462
+
463
+ # Initialize (introspect schema)
464
+ schema = chatbot.introspector.introspect(force_refresh=True)
465
+ chatbot.sql_validator.set_allowed_tables(schema.table_names)
466
+ chatbot._schema_initialized = True
467
+
468
+ st.session_state.chatbot = chatbot
469
+
470
+ else:
471
+ # Use environment-based connection (existing flow)
472
+ chatbot = create_chatbot(llm)
473
+ chatbot.set_llm_client(llm)
474
+
475
+ success, msg = chatbot.initialize()
476
+ if not success:
477
+ st.error(f"Initialization failed: {msg}")
478
+ return False
479
+
480
+ st.session_state.chatbot = chatbot
481
+ st.session_state.custom_db_connection = None
482
+
483
+ st.session_state.llm = llm
484
+ st.session_state.initialized = True
485
+
486
+ # Create memory with appropriate connection
487
+ db_conn = st.session_state.custom_db_connection or get_db()
488
+ st.session_state.memory = create_custom_memory(
489
+ st.session_state.session_id,
490
+ st.session_state.user_id,
491
+ db_conn,
492
+ llm,
493
+ st.session_state.enable_summarization,
494
+ st.session_state.summary_threshold
495
+ )
496
+
497
+ return True
498
+
499
+ except Exception as e:
500
+ st.error(f"Error: {str(e)}")
501
+ import traceback
502
+ st.error(traceback.format_exc())
503
+ return False
504
+
505
+
506
+ def index_data():
507
+ """Index text data from the database."""
508
+ if st.session_state.chatbot:
509
+ progress = st.progress(0)
510
+ status = st.empty()
511
+
512
+ # Get schema from the correct introspector
513
+ schema = st.session_state.chatbot.introspector.introspect()
514
+ total_tables = len(schema.tables)
515
+ indexed = 0
516
+
517
+ def progress_callback(table_name, docs):
518
+ nonlocal indexed
519
+ indexed += 1
520
+ progress.progress(indexed / total_tables)
521
+ status.text(f"Indexed {table_name}: {docs} documents")
522
+
523
+ total_docs = st.session_state.chatbot.index_text_data(progress_callback)
524
+ st.session_state.indexed = True
525
+ status.text(f"Total: {total_docs} documents indexed")
526
+
527
+
528
+ def render_schema_explorer():
529
+ """Render schema explorer in an expander."""
530
+ if not st.session_state.initialized:
531
+ return
532
+
533
+ with st.expander("📋 Database Schema", expanded=False):
534
+ try:
535
+ schema = st.session_state.chatbot.introspector.introspect()
536
+
537
+ for table_name, table_info in schema.tables.items():
538
+ with st.container():
539
+ st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
540
+
541
+ cols = []
542
+ for col in table_info.columns:
543
+ pk = "🔑" if col.is_primary_key else ""
544
+ txt = "📝" if col.is_text_type else ""
545
+ cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
546
+
547
+ st.caption(" | ".join(cols))
548
+ st.divider()
549
+ except Exception as e:
550
+ st.error(f"Error loading schema: {e}")
551
+
552
+
553
+ def render_chat_interface():
554
+ """Render the main chat interface."""
555
+ st.title("🤖 OnceDataBot")
556
+ st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL | SQLite • Powered by Groq (FREE!)")
557
+
558
+ # Schema explorer
559
+ render_schema_explorer()
560
+
561
+ # Chat container
562
+ chat_container = st.container()
563
+
564
+ with chat_container:
565
+ # Display messages
566
+ for msg in st.session_state.messages:
567
+ with st.chat_message(msg["role"]):
568
+ st.markdown(msg["content"])
569
+
570
+ # Show metadata for assistant messages
571
+ if msg["role"] == "assistant" and "metadata" in msg:
572
+ meta = msg["metadata"]
573
+ if meta.get("query_type"):
574
+ st.caption(f"Query type: {meta['query_type']}")
575
+ if meta.get("sql_query"):
576
+ with st.expander("SQL Query"):
577
+ st.code(meta["sql_query"], language="sql")
578
+
579
+ # Chat input
580
+ if prompt := st.chat_input("Ask about your data..."):
581
+ if not st.session_state.initialized:
582
+ st.error("Please connect to a database first!")
583
+ return
584
+
585
+ # Add user message
586
+ st.session_state.messages.append({"role": "user", "content": prompt})
587
+ if st.session_state.memory:
588
+ st.session_state.memory.add_message("user", prompt)
589
+
590
+ with st.chat_message("user"):
591
+ st.markdown(prompt)
592
+
593
+ # Get response
594
+ with st.chat_message("assistant"):
595
+ with st.spinner("Thinking..."):
596
+ response = st.session_state.chatbot.chat(
597
+ prompt,
598
+ st.session_state.memory
599
+ )
600
+
601
+ st.markdown(response.answer)
602
+
603
+ # Show metadata
604
+ if response.query_type != "general":
605
+ st.caption(f"Query type: {response.query_type}")
606
+
607
+ if response.sql_query:
608
+ with st.expander("SQL Query"):
609
+ st.code(response.sql_query, language="sql")
610
+
611
+ if response.sql_results:
612
+ with st.expander("Results"):
613
+ st.dataframe(response.sql_results)
614
+
615
+ # Save to memory
616
+ st.session_state.messages.append({
617
+ "role": "assistant",
618
+ "content": response.answer,
619
+ "metadata": {
620
+ "query_type": response.query_type,
621
+ "sql_query": response.sql_query
622
+ }
623
+ })
624
+ if st.session_state.memory:
625
+ st.session_state.memory.add_message("assistant", response.answer)
626
+
627
+
628
+ def main():
629
+ """Main application entry point."""
630
+ init_session_state()
631
+ render_sidebar()
632
+ render_chat_interface()
633
+
634
+
635
+ if __name__ == "__main__":
636
+ main()