Vanshcc commited on
Commit
6ca635c
·
verified ·
1 Parent(s): bed117c

Add Select Table option

Browse files
Files changed (4) hide show
  1. app.py +640 -616
  2. chatbot.py +20 -11
  3. database/schema_introspector.py +4 -1
  4. rag/rag_engine.py +1 -1
app.py CHANGED
@@ -1,616 +1,640 @@
1
- """
2
- Schema-Agnostic Database Chatbot - Streamlit Application
3
-
4
- A production-grade chatbot that connects to ANY database
5
- (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
- through RAG and Text-to-SQL.
7
-
8
- Uses Groq for FREE LLM inference!
9
- """
10
-
11
- import os
12
- from pathlib import Path
13
-
14
- # Load .env FIRST before any other imports
15
- from dotenv import load_dotenv
16
- load_dotenv(Path(__file__).parent / ".env")
17
-
18
- import streamlit as st
19
- import uuid
20
- from datetime import datetime
21
-
22
- # Page config must be first
23
- st.set_page_config(
24
- page_title="OnceDataBot",
25
- page_icon="🤖",
26
- layout="wide",
27
- initial_sidebar_state="expanded"
28
- )
29
-
30
- # Imports
31
- from config import config, DatabaseConfig, DatabaseType
32
- from database import get_db, get_schema, get_introspector
33
- from database.connection import DatabaseConnection
34
- from llm import create_llm_client
35
- from chatbot import create_chatbot, DatabaseChatbot
36
- from memory import ChatMemory, EnhancedChatMemory
37
-
38
-
39
- # Groq models (all FREE!)
40
- GROQ_MODELS = [
41
- "llama-3.3-70b-versatile",
42
- "llama-3.1-8b-instant",
43
- "mixtral-8x7b-32768",
44
- "gemma2-9b-it"
45
- ]
46
-
47
- # Database types
48
- DB_TYPES = {
49
- "MySQL": "mysql",
50
- "PostgreSQL": "postgresql"
51
- }
52
-
53
-
54
-
55
-
56
- def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
57
- """Create a custom database configuration from user input."""
58
- return DatabaseConfig(
59
- db_type=DatabaseType(db_type),
60
- host=kwargs.get("host", ""),
61
- port=kwargs.get("port", 3306 if db_type == "mysql" else 5432),
62
- database=kwargs.get("database", ""),
63
- username=kwargs.get("username", ""),
64
- password=kwargs.get("password", ""),
65
- ssl_ca=kwargs.get("ssl_ca", None)
66
- )
67
-
68
-
69
- def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
70
- enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
71
- """Create enhanced memory with a custom database connection."""
72
- return EnhancedChatMemory(
73
- session_id=session_id,
74
- user_id=user_id,
75
- max_messages=20,
76
- db_connection=db_connection,
77
- llm_client=llm_client,
78
- enable_summarization=enable_summarization,
79
- summary_threshold=summary_threshold
80
- )
81
-
82
-
83
- def init_session_state():
84
- """Initialize Streamlit session state."""
85
- if "session_id" not in st.session_state:
86
- st.session_state.session_id = str(uuid.uuid4())
87
-
88
- if "messages" not in st.session_state:
89
- st.session_state.messages = []
90
-
91
- if "chatbot" not in st.session_state:
92
- st.session_state.chatbot = None
93
-
94
- if "initialized" not in st.session_state:
95
- st.session_state.initialized = False
96
-
97
- if "user_id" not in st.session_state:
98
- st.session_state.user_id = "default"
99
-
100
- if "enable_summarization" not in st.session_state:
101
- st.session_state.enable_summarization = True
102
-
103
- if "summary_threshold" not in st.session_state:
104
- st.session_state.summary_threshold = 10
105
-
106
- if "memory" not in st.session_state:
107
- st.session_state.memory = None
108
-
109
- if "indexed" not in st.session_state:
110
- st.session_state.indexed = False
111
-
112
- if "db_source" not in st.session_state:
113
- st.session_state.db_source = "environment" # "environment" or "custom"
114
-
115
- if "custom_db_config" not in st.session_state:
116
- st.session_state.custom_db_config = None
117
-
118
- if "custom_db_connection" not in st.session_state:
119
- st.session_state.custom_db_connection = None
120
-
121
-
122
- def render_database_config():
123
- """Render database configuration section in sidebar."""
124
- st.subheader("🗄️ Database Configuration")
125
-
126
- # Database source selection
127
- db_source = st.radio(
128
- "Database Source",
129
- options=["Use Environment Variables", "Custom Database"],
130
- index=0 if st.session_state.db_source == "environment" else 1,
131
- key="db_source_radio",
132
- help="Choose to use .env settings or enter custom credentials"
133
- )
134
-
135
- st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
136
-
137
- if st.session_state.db_source == "environment":
138
- # Show current environment config
139
- current_db_type = config.database.db_type.value.upper()
140
- st.info(f"📌 Using {current_db_type} from environment")
141
- st.caption(f"Host: {config.database.host}")
142
- return None
143
-
144
- else:
145
- # Custom database configuration
146
- st.markdown("##### Enter Database Credentials")
147
-
148
- # Database type selector
149
- db_type_label = st.selectbox(
150
- "Database Type",
151
- options=list(DB_TYPES.keys()),
152
- index=0,
153
- key="custom_db_type"
154
- )
155
- db_type = DB_TYPES[db_type_label]
156
-
157
- if True: # MySQL or PostgreSQL (SQLite removed)
158
- # MySQL or PostgreSQL
159
- col1, col2 = st.columns([3, 1])
160
- with col1:
161
- host = st.text_input(
162
- "Host",
163
- value="",
164
- key="db_host_input",
165
- placeholder="your-database-host.com"
166
- )
167
- with col2:
168
- default_port = 3306 if db_type == "mysql" else 5432
169
- port = st.number_input(
170
- "Port",
171
- value=default_port,
172
- min_value=1,
173
- max_value=65535,
174
- key="db_port_input"
175
- )
176
-
177
- database = st.text_input(
178
- "Database Name",
179
- value="",
180
- key="db_name_input",
181
- placeholder="your_database"
182
- )
183
-
184
- username = st.text_input(
185
- "Username",
186
- value="",
187
- key="db_user_input",
188
- placeholder="your_username"
189
- )
190
-
191
- password = st.text_input(
192
- "Password",
193
- value="",
194
- type="password",
195
- key="db_pass_input"
196
- )
197
-
198
- # Optional SSL
199
- with st.expander("🔒 SSL Settings (Optional)"):
200
- ssl_ca = st.text_input(
201
- "SSL CA Certificate Path",
202
- value="",
203
- key="ssl_ca_input",
204
- help="Path to SSL CA certificate file (for cloud databases like Aiven)"
205
- )
206
-
207
- return {
208
- "db_type": db_type,
209
- "host": host,
210
- "port": int(port),
211
- "database": database,
212
- "username": username,
213
- "password": password,
214
- "ssl_ca": ssl_ca if ssl_ca else None
215
- }
216
-
217
-
218
- def render_sidebar():
219
- """Render the configuration sidebar."""
220
- with st.sidebar:
221
- st.title("⚙️ Settings")
222
-
223
- # User Profile
224
- st.subheader("👤 User Profile")
225
- user_id = st.text_input(
226
- "User ID / Name",
227
- value=st.session_state.get("user_id", "default"),
228
- key="user_id_input",
229
- help="Your unique ID for private memory storage"
230
- )
231
- if user_id != st.session_state.get("user_id"):
232
- st.session_state.user_id = user_id
233
- st.session_state.session_id = str(uuid.uuid4())
234
- st.session_state.messages = []
235
-
236
- # Recreate memory for new user
237
- if st.session_state.custom_db_connection:
238
- st.session_state.memory = create_custom_memory(
239
- st.session_state.session_id,
240
- user_id,
241
- st.session_state.custom_db_connection,
242
- st.session_state.get("llm"),
243
- st.session_state.enable_summarization,
244
- st.session_state.summary_threshold
245
- )
246
- elif st.session_state.initialized:
247
- from memory import create_enhanced_memory
248
- st.session_state.memory = create_enhanced_memory(
249
- st.session_state.session_id,
250
- user_id=user_id,
251
- enable_summarization=st.session_state.enable_summarization,
252
- summary_threshold=st.session_state.summary_threshold
253
- )
254
-
255
- if st.session_state.memory:
256
- st.session_state.memory.clear_user_history()
257
- st.rerun()
258
-
259
- st.divider()
260
-
261
- # Database Configuration
262
- custom_db_params = render_database_config()
263
-
264
- st.divider()
265
-
266
- # LLM Configuration
267
- st.subheader("🤖 LLM Configuration")
268
-
269
- # Show status of API key
270
- if os.getenv("GROQ_API_KEY"):
271
- st.success("✓ API Key configured")
272
- else:
273
- st.warning("⚠️ GROQ_API_KEY not set in environment")
274
-
275
- st.divider()
276
-
277
- # Initialize Button
278
- if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
279
- with st.spinner("Connecting to database..."):
280
- success = initialize_chatbot(custom_db_params, None, None)
281
- if success:
282
- st.success(" Connected!")
283
- st.rerun()
284
-
285
- # Index Button (after initialization)
286
- if st.session_state.initialized:
287
- if st.button("📚 Index Text Data", use_container_width=True):
288
- with st.spinner("Indexing text data..."):
289
- index_data()
290
- st.success(" Indexed!")
291
- st.rerun()
292
-
293
- st.divider()
294
-
295
- # Status
296
- st.subheader("📊 Status")
297
- if st.session_state.initialized:
298
- # Show database type
299
- if st.session_state.custom_db_connection:
300
- db_type = st.session_state.custom_db_connection.db_type.value.upper()
301
- else:
302
- db_type = get_db().db_type.value.upper()
303
-
304
- st.success(f"Database: {db_type} ✓")
305
-
306
- try:
307
- schema = get_schema()
308
- st.info(f"Tables: {len(schema.tables)}")
309
- except:
310
- st.warning("Schema not loaded")
311
-
312
- if st.session_state.indexed:
313
- from rag import get_rag_engine
314
- engine = get_rag_engine()
315
- st.info(f"Indexed Docs: {engine.document_count}")
316
- else:
317
- st.warning("Not connected")
318
-
319
- # New Chat
320
- if st.button(" New Chat", use_container_width=True, type="secondary"):
321
- if st.session_state.memory:
322
- st.session_state.memory.clear()
323
-
324
- st.session_state.messages = []
325
- st.session_state.session_id = str(uuid.uuid4())
326
-
327
- current_user = st.session_state.get("user_id", "default")
328
-
329
- if st.session_state.custom_db_connection:
330
- st.session_state.memory = create_custom_memory(
331
- st.session_state.session_id,
332
- current_user,
333
- st.session_state.custom_db_connection,
334
- st.session_state.get("llm"),
335
- st.session_state.enable_summarization,
336
- st.session_state.summary_threshold
337
- )
338
- elif st.session_state.initialized:
339
- from memory import create_enhanced_memory
340
- st.session_state.memory = create_enhanced_memory(
341
- st.session_state.session_id,
342
- user_id=current_user,
343
- enable_summarization=st.session_state.enable_summarization,
344
- summary_threshold=st.session_state.summary_threshold
345
- )
346
- if st.session_state.get("llm"):
347
- st.session_state.memory.set_llm_client(st.session_state.llm)
348
-
349
- st.rerun()
350
-
351
- # Disconnect button (when using custom DB)
352
- if st.session_state.initialized and st.session_state.db_source == "custom":
353
- if st.button("🔌 Disconnect", use_container_width=True):
354
- if st.session_state.custom_db_connection:
355
- st.session_state.custom_db_connection.close()
356
- st.session_state.custom_db_connection = None
357
- st.session_state.chatbot = None
358
- st.session_state.initialized = False
359
- st.session_state.indexed = False
360
- st.session_state.memory = None
361
- st.success("Disconnected!")
362
- st.rerun()
363
-
364
-
365
- def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
366
- """Initialize the chatbot with either environment or custom database."""
367
- try:
368
- # Get API key
369
- groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
370
- groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
371
-
372
- if not groq_api_key:
373
- st.error("GROQ_API_KEY not configured. Please enter your API key.")
374
- return False
375
-
376
- # Create LLM client
377
- llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
378
-
379
- # Create database connection
380
- if custom_db_params and st.session_state.db_source == "custom":
381
- # Validate custom params
382
- db_type = custom_db_params.get("db_type", "mysql")
383
-
384
- if True:
385
- if not all([custom_db_params.get("host"),
386
- custom_db_params.get("database"),
387
- custom_db_params.get("username")]):
388
- st.error("Please fill in all required database fields.")
389
- return False
390
-
391
- # Create custom config
392
- db_config = create_custom_db_config(**custom_db_params)
393
-
394
- # Create custom connection
395
- custom_connection = DatabaseConnection(db_config)
396
-
397
- # Test connection
398
- success, msg = custom_connection.test_connection()
399
- if not success:
400
- st.error(f"Connection failed: {msg}")
401
- return False
402
-
403
- st.session_state.custom_db_connection = custom_connection
404
- st.session_state.custom_db_config = db_config
405
-
406
- # Override the global db connection for the chatbot
407
- # We need to create a chatbot with this custom connection
408
- from chatbot import DatabaseChatbot
409
- from database.schema_introspector import SchemaIntrospector
410
- from rag import get_rag_engine
411
- from sql import get_sql_generator, get_sql_validator
412
- from router import get_query_router
413
-
414
- chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
415
- chatbot.db = custom_connection
416
- chatbot.introspector = SchemaIntrospector()
417
- chatbot.introspector.db = custom_connection
418
- chatbot.rag_engine = get_rag_engine()
419
- chatbot.sql_generator = get_sql_generator(db_type)
420
- chatbot.sql_validator = get_sql_validator()
421
- chatbot.router = get_query_router()
422
- chatbot.llm_client = llm
423
- chatbot._schema_initialized = False
424
- chatbot._rag_initialized = False
425
-
426
- # Set LLM client
427
- chatbot.set_llm_client(llm)
428
-
429
- # Initialize (introspect schema)
430
- schema = chatbot.introspector.introspect(force_refresh=True)
431
- chatbot.sql_validator.set_allowed_tables(schema.table_names)
432
- chatbot._schema_initialized = True
433
-
434
- st.session_state.chatbot = chatbot
435
-
436
- else:
437
- # Use environment-based connection (existing flow)
438
- chatbot = create_chatbot(llm)
439
- chatbot.set_llm_client(llm)
440
-
441
- success, msg = chatbot.initialize()
442
- if not success:
443
- st.error(f"Initialization failed: {msg}")
444
- return False
445
-
446
- st.session_state.chatbot = chatbot
447
- st.session_state.custom_db_connection = None
448
-
449
- st.session_state.llm = llm
450
- st.session_state.initialized = True
451
- st.session_state.indexed = False # Reset index status on new connection
452
-
453
- # Clear RAG index to ensure no data from previous DB connection persists
454
- if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
455
- chatbot.rag_engine.clear_index()
456
-
457
- # Create memory with appropriate connection
458
- db_conn = st.session_state.custom_db_connection or get_db()
459
- st.session_state.memory = create_custom_memory(
460
- st.session_state.session_id,
461
- st.session_state.user_id,
462
- db_conn,
463
- llm,
464
- st.session_state.enable_summarization,
465
- st.session_state.summary_threshold
466
- )
467
-
468
- return True
469
-
470
- except Exception as e:
471
- st.error(f"Error: {str(e)}")
472
- import traceback
473
- st.error(traceback.format_exc())
474
- return False
475
-
476
-
477
- def index_data():
478
- """Index text data from the database."""
479
- if st.session_state.chatbot:
480
- progress = st.progress(0)
481
- status = st.empty()
482
-
483
- # Get schema from the correct introspector
484
- schema = st.session_state.chatbot.introspector.introspect()
485
- total_tables = len(schema.tables)
486
- indexed = 0
487
-
488
- def progress_callback(table_name, docs):
489
- nonlocal indexed
490
- indexed += 1
491
- progress.progress(indexed / total_tables)
492
- status.text(f"Indexed {table_name}: {docs} documents")
493
-
494
- total_docs = st.session_state.chatbot.index_text_data(progress_callback)
495
- st.session_state.indexed = True
496
- status.text(f"Total: {total_docs} documents indexed")
497
-
498
-
499
- def render_schema_explorer():
500
- """Render schema explorer in an expander."""
501
- if not st.session_state.initialized:
502
- return
503
-
504
- with st.expander("📋 Database Schema", expanded=False):
505
- try:
506
- schema = st.session_state.chatbot.introspector.introspect()
507
-
508
- for table_name, table_info in schema.tables.items():
509
- with st.container():
510
- st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
511
-
512
- cols = []
513
- for col in table_info.columns:
514
- pk = "🔑" if col.is_primary_key else ""
515
- txt = "📝" if col.is_text_type else ""
516
- cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
517
-
518
- st.caption(" | ".join(cols))
519
- st.divider()
520
- except Exception as e:
521
- st.error(f"Error loading schema: {e}")
522
-
523
-
524
- def render_chat_interface():
525
- """Render the main chat interface."""
526
- st.title("🤖 OnceDataBot")
527
- st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL • Powered by Groq (FREE!)")
528
-
529
- # Schema explorer
530
- render_schema_explorer()
531
-
532
- # Chat container
533
- chat_container = st.container()
534
-
535
- with chat_container:
536
- # Display messages
537
- for msg in st.session_state.messages:
538
- with st.chat_message(msg["role"]):
539
- st.markdown(msg["content"])
540
-
541
- # Show metadata for assistant messages
542
- if msg["role"] == "assistant" and "metadata" in msg:
543
- meta = msg["metadata"]
544
- if meta.get("query_type"):
545
- st.caption(f"Query type: {meta['query_type']}")
546
- if meta.get("sql_query"):
547
- with st.expander("SQL Query"):
548
- st.code(meta["sql_query"], language="sql")
549
-
550
- # Chat input
551
- if prompt := st.chat_input("Ask about your data..."):
552
- if not st.session_state.initialized:
553
- st.error("Please connect to a database first!")
554
- return
555
-
556
- # Add user message
557
- st.session_state.messages.append({"role": "user", "content": prompt})
558
- if st.session_state.memory:
559
- st.session_state.memory.add_message("user", prompt)
560
-
561
- with st.chat_message("user"):
562
- st.markdown(prompt)
563
-
564
- # Get response
565
- with st.chat_message("assistant"):
566
- with st.spinner("Thinking..."):
567
- response = st.session_state.chatbot.chat(
568
- prompt,
569
- st.session_state.memory
570
- )
571
-
572
- st.markdown(response.answer)
573
-
574
- # Show metadata
575
- if response.query_type != "general":
576
- st.caption(f"Query type: {response.query_type}")
577
-
578
- if response.sql_query:
579
- with st.expander("SQL Query"):
580
- st.code(response.sql_query, language="sql")
581
-
582
- if response.sql_results:
583
- with st.expander("Results"):
584
- st.dataframe(response.sql_results)
585
-
586
- # Save to memory
587
- st.session_state.messages.append({
588
- "role": "assistant",
589
- "content": response.answer,
590
- "metadata": {
591
- "query_type": response.query_type,
592
- "sql_query": response.sql_query
593
- }
594
- })
595
- if st.session_state.memory:
596
- st.session_state.memory.add_message("assistant", response.answer)
597
-
598
-
599
- def main():
600
- """Main application entry point."""
601
- init_session_state()
602
-
603
- # Auto-connect to environment database on first load
604
- if "auto_connect_attempted" not in st.session_state:
605
- st.session_state.auto_connect_attempted = True
606
- if st.session_state.db_source == "environment":
607
- success = initialize_chatbot()
608
- if success:
609
- st.toast("✅ Auto-connected to database!")
610
-
611
- render_sidebar()
612
- render_chat_interface()
613
-
614
-
615
- if __name__ == "__main__":
616
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Schema-Agnostic Database Chatbot - Streamlit Application
3
+
4
+ A production-grade chatbot that connects to ANY database
5
+ (MySQL, PostgreSQL, SQLite) and provides intelligent querying
6
+ through RAG and Text-to-SQL.
7
+
8
+ Uses Groq for FREE LLM inference!
9
+ """
10
+
11
+ import os
12
+ from pathlib import Path
13
+
14
+ # Load .env FIRST before any other imports
15
+ from dotenv import load_dotenv
16
+ load_dotenv(Path(__file__).parent / ".env")
17
+
18
+ import streamlit as st
19
+ import uuid
20
+ from datetime import datetime
21
+
22
+ # Page config must be first
23
+ st.set_page_config(
24
+ page_title="OnceDataBot",
25
+ page_icon="🤖",
26
+ layout="wide",
27
+ initial_sidebar_state="expanded"
28
+ )
29
+
30
+ # Imports
31
+ from config import config, DatabaseConfig, DatabaseType
32
+ from database import get_db, get_schema, get_introspector
33
+ from database.connection import DatabaseConnection
34
+ from llm import create_llm_client
35
+ from chatbot import create_chatbot, DatabaseChatbot
36
+ from memory import ChatMemory, EnhancedChatMemory
37
+
38
+
39
+ # Groq models (all FREE!)
40
+ GROQ_MODELS = [
41
+ "llama-3.3-70b-versatile",
42
+ "llama-3.1-8b-instant",
43
+ "mixtral-8x7b-32768",
44
+ "gemma2-9b-it"
45
+ ]
46
+
47
+ # Database types
48
+ DB_TYPES = {
49
+ "MySQL": "mysql",
50
+ "PostgreSQL": "postgresql"
51
+ }
52
+
53
+
54
+
55
+
56
+ def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
57
+ """Create a custom database configuration from user input."""
58
+ return DatabaseConfig(
59
+ db_type=DatabaseType(db_type),
60
+ host=kwargs.get("host", ""),
61
+ port=kwargs.get("port", 3306 if db_type == "mysql" else 5432),
62
+ database=kwargs.get("database", ""),
63
+ username=kwargs.get("username", ""),
64
+ password=kwargs.get("password", ""),
65
+ ssl_ca=kwargs.get("ssl_ca", None)
66
+ )
67
+
68
+
69
+ def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
70
+ enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
71
+ """Create enhanced memory with a custom database connection."""
72
+ return EnhancedChatMemory(
73
+ session_id=session_id,
74
+ user_id=user_id,
75
+ max_messages=20,
76
+ db_connection=db_connection,
77
+ llm_client=llm_client,
78
+ enable_summarization=enable_summarization,
79
+ summary_threshold=summary_threshold
80
+ )
81
+
82
+
83
+ def init_session_state():
84
+ """Initialize Streamlit session state."""
85
+ if "session_id" not in st.session_state:
86
+ st.session_state.session_id = str(uuid.uuid4())
87
+
88
+ if "messages" not in st.session_state:
89
+ st.session_state.messages = []
90
+
91
+ if "chatbot" not in st.session_state:
92
+ st.session_state.chatbot = None
93
+
94
+ if "initialized" not in st.session_state:
95
+ st.session_state.initialized = False
96
+
97
+ if "user_id" not in st.session_state:
98
+ st.session_state.user_id = "default"
99
+
100
+ if "enable_summarization" not in st.session_state:
101
+ st.session_state.enable_summarization = True
102
+
103
+ if "summary_threshold" not in st.session_state:
104
+ st.session_state.summary_threshold = 10
105
+
106
+ if "memory" not in st.session_state:
107
+ st.session_state.memory = None
108
+
109
+ if "indexed" not in st.session_state:
110
+ st.session_state.indexed = False
111
+
112
+ if "db_source" not in st.session_state:
113
+ st.session_state.db_source = "environment" # "environment" or "custom"
114
+
115
+ if "custom_db_config" not in st.session_state:
116
+ st.session_state.custom_db_config = None
117
+
118
+ if "custom_db_connection" not in st.session_state:
119
+ st.session_state.custom_db_connection = None
120
+
121
+ if "ignored_tables" not in st.session_state:
122
+ st.session_state.ignored_tables = set()
123
+
124
+
125
+ def render_database_config():
126
+ """Render database configuration section in sidebar."""
127
+ st.subheader("🗄️ Database Configuration")
128
+
129
+ # Database source selection
130
+ db_source = st.radio(
131
+ "Database Source",
132
+ options=["Use Environment Variables", "Custom Database"],
133
+ index=0 if st.session_state.db_source == "environment" else 1,
134
+ key="db_source_radio",
135
+ help="Choose to use .env settings or enter custom credentials"
136
+ )
137
+
138
+ st.session_state.db_source = "environment" if db_source == "Use Environment Variables" else "custom"
139
+
140
+ if st.session_state.db_source == "environment":
141
+ # Show current environment config
142
+ current_db_type = config.database.db_type.value.upper()
143
+ st.info(f"📌 Using {current_db_type} from environment")
144
+ st.caption(f"Host: {config.database.host}")
145
+ return None
146
+
147
+ else:
148
+ # Custom database configuration
149
+ st.markdown("##### Enter Database Credentials")
150
+
151
+ # Database type selector
152
+ db_type_label = st.selectbox(
153
+ "Database Type",
154
+ options=list(DB_TYPES.keys()),
155
+ index=0,
156
+ key="custom_db_type"
157
+ )
158
+ db_type = DB_TYPES[db_type_label]
159
+
160
+ if True: # MySQL or PostgreSQL (SQLite removed)
161
+ # MySQL or PostgreSQL
162
+ col1, col2 = st.columns([3, 1])
163
+ with col1:
164
+ host = st.text_input(
165
+ "Host",
166
+ value="",
167
+ key="db_host_input",
168
+ placeholder="your-database-host.com"
169
+ )
170
+ with col2:
171
+ default_port = 3306 if db_type == "mysql" else 5432
172
+ port = st.number_input(
173
+ "Port",
174
+ value=default_port,
175
+ min_value=1,
176
+ max_value=65535,
177
+ key="db_port_input"
178
+ )
179
+
180
+ database = st.text_input(
181
+ "Database Name",
182
+ value="",
183
+ key="db_name_input",
184
+ placeholder="your_database"
185
+ )
186
+
187
+ username = st.text_input(
188
+ "Username",
189
+ value="",
190
+ key="db_user_input",
191
+ placeholder="your_username"
192
+ )
193
+
194
+ password = st.text_input(
195
+ "Password",
196
+ value="",
197
+ type="password",
198
+ key="db_pass_input"
199
+ )
200
+
201
+ # Optional SSL
202
+ with st.expander("🔒 SSL Settings (Optional)"):
203
+ ssl_ca = st.text_input(
204
+ "SSL CA Certificate Path",
205
+ value="",
206
+ key="ssl_ca_input",
207
+ help="Path to SSL CA certificate file (for cloud databases like Aiven)"
208
+ )
209
+
210
+ return {
211
+ "db_type": db_type,
212
+ "host": host,
213
+ "port": int(port),
214
+ "database": database,
215
+ "username": username,
216
+ "password": password,
217
+ "ssl_ca": ssl_ca if ssl_ca else None
218
+ }
219
+
220
+
221
+ def render_sidebar():
222
+ """Render the configuration sidebar."""
223
+ with st.sidebar:
224
+ st.title("⚙️ Settings")
225
+
226
+ # User Profile
227
+ st.subheader("👤 User Profile")
228
+ user_id = st.text_input(
229
+ "User ID / Name",
230
+ value=st.session_state.get("user_id", "default"),
231
+ key="user_id_input",
232
+ help="Your unique ID for private memory storage"
233
+ )
234
+ if user_id != st.session_state.get("user_id"):
235
+ st.session_state.user_id = user_id
236
+ st.session_state.session_id = str(uuid.uuid4())
237
+ st.session_state.messages = []
238
+
239
+ # Recreate memory for new user
240
+ if st.session_state.custom_db_connection:
241
+ st.session_state.memory = create_custom_memory(
242
+ st.session_state.session_id,
243
+ user_id,
244
+ st.session_state.custom_db_connection,
245
+ st.session_state.get("llm"),
246
+ st.session_state.enable_summarization,
247
+ st.session_state.summary_threshold
248
+ )
249
+ elif st.session_state.initialized:
250
+ from memory import create_enhanced_memory
251
+ st.session_state.memory = create_enhanced_memory(
252
+ st.session_state.session_id,
253
+ user_id=user_id,
254
+ enable_summarization=st.session_state.enable_summarization,
255
+ summary_threshold=st.session_state.summary_threshold
256
+ )
257
+
258
+ if st.session_state.memory:
259
+ st.session_state.memory.clear_user_history()
260
+ st.rerun()
261
+
262
+ st.divider()
263
+
264
+ # Database Configuration
265
+ custom_db_params = render_database_config()
266
+
267
+ st.divider()
268
+
269
+ # LLM Configuration
270
+ st.subheader("🤖 LLM Configuration")
271
+
272
+ # Show status of API key
273
+ if os.getenv("GROQ_API_KEY"):
274
+ st.success("✓ API Key configured")
275
+ else:
276
+ st.warning("⚠️ GROQ_API_KEY not set in environment")
277
+
278
+ st.divider()
279
+
280
+ # Initialize Button
281
+ if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
282
+ with st.spinner("Connecting to database..."):
283
+ success = initialize_chatbot(custom_db_params, None, None)
284
+ if success:
285
+ st.success("✅ Connected!")
286
+ st.rerun()
287
+
288
+ # Index Button (after initialization)
289
+ if st.session_state.initialized:
290
+ if st.button("📚 Index Text Data", use_container_width=True):
291
+ with st.spinner("Indexing text data..."):
292
+ index_data()
293
+ st.success("✅ Indexed!")
294
+ st.rerun()
295
+
296
+ st.divider()
297
+
298
+ # Status
299
+ st.subheader("📊 Status")
300
+ if st.session_state.initialized:
301
+ # Show database type
302
+ if st.session_state.custom_db_connection:
303
+ db_type = st.session_state.custom_db_connection.db_type.value.upper()
304
+ else:
305
+ db_type = get_db().db_type.value.upper()
306
+
307
+ st.success(f"Database: {db_type} ✓")
308
+
309
+ try:
310
+ schema = get_schema()
311
+ st.info(f"Tables: {len(schema.tables)}")
312
+ except:
313
+ st.warning("Schema not loaded")
314
+
315
+ if st.session_state.indexed:
316
+ from rag import get_rag_engine
317
+ engine = get_rag_engine()
318
+ st.info(f"Indexed Docs: {engine.document_count}")
319
+ else:
320
+ st.warning("Not connected")
321
+
322
+ # New Chat
323
+ if st.button("➕ New Chat", use_container_width=True, type="secondary"):
324
+ if st.session_state.memory:
325
+ st.session_state.memory.clear()
326
+
327
+ st.session_state.messages = []
328
+ st.session_state.session_id = str(uuid.uuid4())
329
+
330
+ current_user = st.session_state.get("user_id", "default")
331
+
332
+ if st.session_state.custom_db_connection:
333
+ st.session_state.memory = create_custom_memory(
334
+ st.session_state.session_id,
335
+ current_user,
336
+ st.session_state.custom_db_connection,
337
+ st.session_state.get("llm"),
338
+ st.session_state.enable_summarization,
339
+ st.session_state.summary_threshold
340
+ )
341
+ elif st.session_state.initialized:
342
+ from memory import create_enhanced_memory
343
+ st.session_state.memory = create_enhanced_memory(
344
+ st.session_state.session_id,
345
+ user_id=current_user,
346
+ enable_summarization=st.session_state.enable_summarization,
347
+ summary_threshold=st.session_state.summary_threshold
348
+ )
349
+ if st.session_state.get("llm"):
350
+ st.session_state.memory.set_llm_client(st.session_state.llm)
351
+
352
+ st.rerun()
353
+
354
+ # Disconnect button (when using custom DB)
355
+ if st.session_state.initialized and st.session_state.db_source == "custom":
356
+ if st.button("���� Disconnect", use_container_width=True):
357
+ if st.session_state.custom_db_connection:
358
+ st.session_state.custom_db_connection.close()
359
+ st.session_state.custom_db_connection = None
360
+ st.session_state.chatbot = None
361
+ st.session_state.initialized = False
362
+ st.session_state.indexed = False
363
+ st.session_state.memory = None
364
+ st.success("Disconnected!")
365
+ st.rerun()
366
+
367
+
368
+ def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
369
+ """Initialize the chatbot with either environment or custom database."""
370
+ try:
371
+ # Get API key
372
+ groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
373
+ groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
374
+
375
+ if not groq_api_key:
376
+ st.error("GROQ_API_KEY not configured. Please enter your API key.")
377
+ return False
378
+
379
+ # Create LLM client
380
+ llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)
381
+
382
+ # Create database connection
383
+ if custom_db_params and st.session_state.db_source == "custom":
384
+ # Validate custom params
385
+ db_type = custom_db_params.get("db_type", "mysql")
386
+
387
+ if True:
388
+ if not all([custom_db_params.get("host"),
389
+ custom_db_params.get("database"),
390
+ custom_db_params.get("username")]):
391
+ st.error("Please fill in all required database fields.")
392
+ return False
393
+
394
+ # Create custom config
395
+ db_config = create_custom_db_config(**custom_db_params)
396
+
397
+ # Create custom connection
398
+ custom_connection = DatabaseConnection(db_config)
399
+
400
+ # Test connection
401
+ success, msg = custom_connection.test_connection()
402
+ if not success:
403
+ st.error(f"Connection failed: {msg}")
404
+ return False
405
+
406
+ st.session_state.custom_db_connection = custom_connection
407
+ st.session_state.custom_db_config = db_config
408
+
409
+ # Override the global db connection for the chatbot
410
+ # We need to create a chatbot with this custom connection
411
+ from chatbot import DatabaseChatbot
412
+ from database.schema_introspector import SchemaIntrospector
413
+ from rag import get_rag_engine
414
+ from sql import get_sql_generator, get_sql_validator
415
+ from router import get_query_router
416
+
417
+ chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
418
+ chatbot.db = custom_connection
419
+ chatbot.introspector = SchemaIntrospector()
420
+ chatbot.introspector.db = custom_connection
421
+ chatbot.rag_engine = get_rag_engine()
422
+ chatbot.sql_generator = get_sql_generator(db_type)
423
+ chatbot.sql_validator = get_sql_validator()
424
+ chatbot.router = get_query_router()
425
+ chatbot.llm_client = llm
426
+ chatbot._schema_initialized = False
427
+ chatbot._rag_initialized = False
428
+
429
+ # Set LLM client
430
+ chatbot.set_llm_client(llm)
431
+
432
+ # Initialize (introspect schema)
433
+ schema = chatbot.introspector.introspect(force_refresh=True)
434
+ chatbot.sql_validator.set_allowed_tables(schema.table_names)
435
+ chatbot._schema_initialized = True
436
+
437
+ st.session_state.chatbot = chatbot
438
+
439
+ else:
440
+ # Use environment-based connection (existing flow)
441
+ chatbot = create_chatbot(llm)
442
+ chatbot.set_llm_client(llm)
443
+
444
+ success, msg = chatbot.initialize()
445
+ if not success:
446
+ st.error(f"Initialization failed: {msg}")
447
+ return False
448
+
449
+ st.session_state.chatbot = chatbot
450
+ st.session_state.custom_db_connection = None
451
+
452
+ st.session_state.llm = llm
453
+ st.session_state.initialized = True
454
+ st.session_state.indexed = False # Reset index status on new connection
455
+
456
+ # Clear RAG index to ensure no data from previous DB connection persists
457
+ if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
458
+ chatbot.rag_engine.clear_index()
459
+
460
+ # Create memory with appropriate connection
461
+ db_conn = st.session_state.custom_db_connection or get_db()
462
+ st.session_state.memory = create_custom_memory(
463
+ st.session_state.session_id,
464
+ st.session_state.user_id,
465
+ db_conn,
466
+ llm,
467
+ st.session_state.enable_summarization,
468
+ st.session_state.summary_threshold
469
+ )
470
+
471
+ return True
472
+
473
+ except Exception as e:
474
+ st.error(f"Error: {str(e)}")
475
+ import traceback
476
+ st.error(traceback.format_exc())
477
+ return False
478
+
479
+
480
+ def index_data():
481
+ """Index text data from the database."""
482
+ if st.session_state.chatbot:
483
+ progress = st.progress(0)
484
+ status = st.empty()
485
+
486
+ # Get schema from the correct introspector
487
+ schema = st.session_state.chatbot.introspector.introspect()
488
+ total_tables = len(schema.tables)
489
+ indexed = 0
490
+
491
+ def progress_callback(table_name, docs):
492
+ nonlocal indexed
493
+ indexed += 1
494
+ progress.progress(indexed / total_tables)
495
+ status.text(f"Indexed {table_name}: {docs} documents")
496
+
497
+ total_docs = st.session_state.chatbot.index_text_data(progress_callback)
498
+ st.session_state.indexed = True
499
+ status.text(f"Total: {total_docs} documents indexed")
500
+
501
+
502
+ def render_schema_explorer():
503
+ """Render schema explorer in an expander."""
504
+ if not st.session_state.initialized:
505
+ return
506
+
507
+ with st.expander("📋 Database Schema", expanded=False):
508
+ try:
509
+ schema = st.session_state.chatbot.introspector.introspect()
510
+
511
+ st.markdown("Uncheck tables to exclude them from the chat context.")
512
+
513
+ for table_name, table_info in schema.tables.items():
514
+ col1, col2 = st.columns([0.05, 0.95])
515
+
516
+ with col1:
517
+ is_active = table_name not in st.session_state.ignored_tables
518
+ active = st.checkbox(
519
+ "Use",
520
+ value=is_active,
521
+ key=f"use_{table_name}",
522
+ label_visibility="collapsed",
523
+ help=f"Include {table_name} in chat analysis"
524
+ )
525
+
526
+ if not active:
527
+ st.session_state.ignored_tables.add(table_name)
528
+ else:
529
+ st.session_state.ignored_tables.discard(table_name)
530
+
531
+ with col2:
532
+ with st.container():
533
+ st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
534
+
535
+ cols = []
536
+ for col in table_info.columns:
537
+ pk = "🔑" if col.is_primary_key else ""
538
+ txt = "📝" if col.is_text_type else ""
539
+ cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
540
+
541
+ st.caption(" | ".join(cols))
542
+ st.divider()
543
+ except Exception as e:
544
+ st.error(f"Error loading schema: {e}")
545
+
546
+
547
+ def render_chat_interface():
548
+ """Render the main chat interface."""
549
+ st.title("🤖 OnceDataBot")
550
+ st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL • Powered by Groq (FREE!)")
551
+
552
+ # Schema explorer
553
+ render_schema_explorer()
554
+
555
+ # Chat container
556
+ chat_container = st.container()
557
+
558
+ with chat_container:
559
+ # Display messages
560
+ for msg in st.session_state.messages:
561
+ with st.chat_message(msg["role"]):
562
+ st.markdown(msg["content"])
563
+
564
+ # Show metadata for assistant messages
565
+ if msg["role"] == "assistant" and "metadata" in msg:
566
+ meta = msg["metadata"]
567
+ if meta.get("query_type"):
568
+ st.caption(f"Query type: {meta['query_type']}")
569
+ if meta.get("sql_query"):
570
+ with st.expander("SQL Query"):
571
+ st.code(meta["sql_query"], language="sql")
572
+
573
+ # Chat input
574
+ if prompt := st.chat_input("Ask about your data..."):
575
+ if not st.session_state.initialized:
576
+ st.error("Please connect to a database first!")
577
+ return
578
+
579
+ # Add user message
580
+ st.session_state.messages.append({"role": "user", "content": prompt})
581
+ if st.session_state.memory:
582
+ st.session_state.memory.add_message("user", prompt)
583
+
584
+ with st.chat_message("user"):
585
+ st.markdown(prompt)
586
+
587
+ # Get response
588
+ with st.chat_message("assistant"):
589
+ with st.spinner("Thinking..."):
590
+ response = st.session_state.chatbot.chat(
591
+ prompt,
592
+ st.session_state.memory,
593
+ ignored_tables=list(st.session_state.ignored_tables)
594
+ )
595
+
596
+ st.markdown(response.answer)
597
+
598
+ # Show metadata
599
+ if response.query_type != "general":
600
+ st.caption(f"Query type: {response.query_type}")
601
+
602
+ if response.sql_query:
603
+ with st.expander("SQL Query"):
604
+ st.code(response.sql_query, language="sql")
605
+
606
+ if response.sql_results:
607
+ with st.expander("Results"):
608
+ st.dataframe(response.sql_results)
609
+
610
+ # Save to memory
611
+ st.session_state.messages.append({
612
+ "role": "assistant",
613
+ "content": response.answer,
614
+ "metadata": {
615
+ "query_type": response.query_type,
616
+ "sql_query": response.sql_query
617
+ }
618
+ })
619
+ if st.session_state.memory:
620
+ st.session_state.memory.add_message("assistant", response.answer)
621
+
622
+
623
+ def main():
624
+ """Main application entry point."""
625
+ init_session_state()
626
+
627
+ # Auto-connect to environment database on first load
628
+ if "auto_connect_attempted" not in st.session_state:
629
+ st.session_state.auto_connect_attempted = True
630
+ if st.session_state.db_source == "environment":
631
+ success = initialize_chatbot()
632
+ if success:
633
+ st.toast("✅ Auto-connected to database!")
634
+
635
+ render_sidebar()
636
+ render_chat_interface()
637
+
638
+
639
+ if __name__ == "__main__":
640
+ main()
chatbot.py CHANGED
@@ -158,7 +158,7 @@ YOUR RESPONSE:"""
158
 
159
  return total_docs
160
 
161
- def chat(self, query: str, memory: Optional[ChatMemory] = None) -> ChatResponse:
162
  """Process a user query and return a response."""
163
  if not self._schema_initialized:
164
  return ChatResponse(answer="Chatbot not initialized.", query_type="error",
@@ -171,7 +171,16 @@ YOUR RESPONSE:"""
171
  try:
172
  # Use instance introspector
173
  schema = self.introspector.introspect()
174
- schema_context = schema.to_context_string()
 
 
 
 
 
 
 
 
 
175
 
176
  # Check for memory commands
177
  # Check for memory commands
@@ -237,11 +246,11 @@ YOUR RESPONSE:"""
237
 
238
  # Process based on route
239
  if routing.query_type == QueryType.RAG:
240
- return self._handle_rag(query, history)
241
  elif routing.query_type == QueryType.SQL:
242
- return self._handle_sql(query, schema_context, history)
243
  elif routing.query_type == QueryType.HYBRID:
244
- return self._handle_hybrid(query, schema_context, history)
245
  else:
246
  return self._handle_general(query, history)
247
 
@@ -249,9 +258,9 @@ YOUR RESPONSE:"""
249
  logger.error(f"Chat error: {e}")
250
  return ChatResponse(answer=f"Error: {str(e)}", query_type="error", error=str(e))
251
 
252
- def _handle_rag(self, query: str, history: List[Dict]) -> ChatResponse:
253
  """Handle RAG-based query."""
254
- context = self.rag_engine.get_context(query, top_k=5)
255
 
256
  prompt = self.RESPONSE_PROMPT.format(context=f"RELEVANT DATA:\n{context}", question=query)
257
 
@@ -266,7 +275,7 @@ YOUR RESPONSE:"""
266
  return ChatResponse(answer=answer, query_type="rag",
267
  sources=[{"type": "semantic_search", "context": context[:500]}])
268
 
269
- def _handle_sql(self, query: str, schema_context: str, history: List[Dict]) -> ChatResponse:
270
  """Handle SQL-based query."""
271
  sql, explanation = self.sql_generator.generate(query, schema_context, history)
272
 
@@ -287,7 +296,7 @@ YOUR RESPONSE:"""
287
  # We try RAG as a fallback if SQL found nothing
288
  if not results:
289
  logger.info(f"SQL returned no results for query: '{query}'. Falling back to RAG.")
290
- rag_response = self._handle_rag(query, history)
291
 
292
  # Combine the info: "I couldn't find an exact match in the rows, but here is what I found semantically:"
293
  rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the product descriptions:\n\n{rag_response.answer}"
@@ -310,10 +319,10 @@ YOUR RESPONSE:"""
310
  return ChatResponse(answer=answer, query_type="sql",
311
  sql_query=sanitized_sql, sql_results=results[:10])
312
 
313
- def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict]) -> ChatResponse:
314
  """Handle hybrid RAG + SQL query."""
315
  # Get RAG context
316
- rag_context = self.rag_engine.get_context(query, top_k=3)
317
 
318
  # Try SQL as well
319
  sql_context = ""
 
158
 
159
  return total_docs
160
 
161
+ def chat(self, query: str, memory: Optional[ChatMemory] = None, ignored_tables: Optional[List[str]] = None) -> ChatResponse:
162
  """Process a user query and return a response."""
163
  if not self._schema_initialized:
164
  return ChatResponse(answer="Chatbot not initialized.", query_type="error",
 
171
  try:
172
  # Use instance introspector
173
  schema = self.introspector.introspect()
174
+ schema_context = schema.to_context_string(ignored_tables=ignored_tables)
175
+
176
+ # Calculate allowed tables for RAG and Validator
177
+ allowed_tables = None
178
+ if ignored_tables:
179
+ allowed_tables = [t for t in schema.table_names if t not in ignored_tables]
180
+ # Update validator to only allow these tables
181
+ self.sql_validator.set_allowed_tables(allowed_tables)
182
+ else:
183
+ self.sql_validator.set_allowed_tables(schema.table_names)
184
 
185
  # Check for memory commands
186
  # Check for memory commands
 
246
 
247
  # Process based on route
248
  if routing.query_type == QueryType.RAG:
249
+ return self._handle_rag(query, history, allowed_tables)
250
  elif routing.query_type == QueryType.SQL:
251
+ return self._handle_sql(query, schema_context, history, allowed_tables)
252
  elif routing.query_type == QueryType.HYBRID:
253
+ return self._handle_hybrid(query, schema_context, history, allowed_tables)
254
  else:
255
  return self._handle_general(query, history)
256
 
 
258
  logger.error(f"Chat error: {e}")
259
  return ChatResponse(answer=f"Error: {str(e)}", query_type="error", error=str(e))
260
 
261
+ def _handle_rag(self, query: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
262
  """Handle RAG-based query."""
263
+ context = self.rag_engine.get_context(query, top_k=5, table_filter=allowed_tables)
264
 
265
  prompt = self.RESPONSE_PROMPT.format(context=f"RELEVANT DATA:\n{context}", question=query)
266
 
 
275
  return ChatResponse(answer=answer, query_type="rag",
276
  sources=[{"type": "semantic_search", "context": context[:500]}])
277
 
278
+ def _handle_sql(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
279
  """Handle SQL-based query."""
280
  sql, explanation = self.sql_generator.generate(query, schema_context, history)
281
 
 
296
  # We try RAG as a fallback if SQL found nothing
297
  if not results:
298
  logger.info(f"SQL returned no results for query: '{query}'. Falling back to RAG.")
299
+ rag_response = self._handle_rag(query, history, allowed_tables)
300
 
301
  # Combine the info: "I couldn't find an exact match in the rows, but here is what I found semantically:"
302
  rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the product descriptions:\n\n{rag_response.answer}"
 
319
  return ChatResponse(answer=answer, query_type="sql",
320
  sql_query=sanitized_sql, sql_results=results[:10])
321
 
322
+ def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None) -> ChatResponse:
323
  """Handle hybrid RAG + SQL query."""
324
  # Get RAG context
325
+ rag_context = self.rag_engine.get_context(query, top_k=3, table_filter=allowed_tables)
326
 
327
  # Try SQL as well
328
  sql_context = ""
database/schema_introspector.py CHANGED
@@ -109,7 +109,7 @@ class SchemaInfo:
109
  result.append((table_name, col.name))
110
  return result
111
 
112
- def to_context_string(self) -> str:
113
  """
114
  Generate a natural language description of the schema.
115
  This is used as context for the LLM.
@@ -119,6 +119,9 @@ class SchemaInfo:
119
  lines.append("-" * 40)
120
 
121
  for table_name, table_info in self.tables.items():
 
 
 
122
  lines.append(f"\nTable: {table_name}")
123
  if table_info.comment:
124
  lines.append(f" Description: {table_info.comment}")
 
109
  result.append((table_name, col.name))
110
  return result
111
 
112
+ def to_context_string(self, ignored_tables: Optional[List[str]] = None) -> str:
113
  """
114
  Generate a natural language description of the schema.
115
  This is used as context for the LLM.
 
119
  lines.append("-" * 40)
120
 
121
  for table_name, table_info in self.tables.items():
122
+ if ignored_tables and table_name in ignored_tables:
123
+ continue
124
+
125
  lines.append(f"\nTable: {table_name}")
126
  if table_info.comment:
127
  lines.append(f" Description: {table_info.comment}")
rag/rag_engine.py CHANGED
@@ -68,7 +68,7 @@ class RAGEngine:
68
  """
69
  results = self.vector_store.search(query, top_k=top_k * 2)
70
 
71
- if table_filter:
72
  results = [
73
  (doc, score) for doc, score in results
74
  if doc.table_name in table_filter
 
68
  """
69
  results = self.vector_store.search(query, top_k=top_k * 2)
70
 
71
+ if table_filter is not None:
72
  results = [
73
  (doc, score) for doc, score in results
74
  if doc.table_name in table_filter