Vanshcc committed · verified
Commit f9ad313 · 1 Parent(s): b3fc82f

Upload 34 files
Dockerfile ADDED
@@ -0,0 +1,60 @@
+ # Hugging Face Spaces - Docker SDK
+ # Schema-Agnostic Database Chatbot with RAG
+
+ FROM python:3.11-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PYTHONPATH=/app \
+     HF_HOME=/app/.cache \
+     TRANSFORMERS_CACHE=/app/.cache/transformers \
+     SENTENCE_TRANSFORMERS_HOME=/app/.cache/sentence_transformers
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     curl \
+     git \
+     libpq-dev \
+     && rm -rf /var/lib/apt/lists/* \
+     && apt-get clean
+
+ # Create a non-root user for security
+ RUN useradd -m -u 1000 appuser
+
+ # Create cache directories with proper permissions
+ RUN mkdir -p /app/.cache/sentence_transformers /app/.cache/transformers /app/faiss_index \
+     && chown -R appuser:appuser /app
+
+ # Copy requirements first for better caching
+ COPY --chown=appuser:appuser requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY --chown=appuser:appuser . .
+
+ # Switch to non-root user
+ USER appuser
+
+ # Expose Streamlit port (HF Spaces expects 7860)
+ EXPOSE 7860
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+     CMD curl --fail http://localhost:7860/_stcore/health || exit 1
+
+ # Run Streamlit
+ CMD ["streamlit", "run", "app.py", \
+      "--server.port=7860", \
+      "--server.address=0.0.0.0", \
+      "--server.enableCORS=true", \
+      "--server.enableXsrfProtection=false", \
+      "--browser.gatherUsageStats=false", \
+      "--server.fileWatcherType=none"]
README.md CHANGED
@@ -1,11 +1,111 @@
- ---
- title: DB Chatbot
- emoji: 🌖
- colorFrom: gray
- colorTo: yellow
- sdk: docker
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Database Copilot
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: mit
+ app_port: 7860
+ ---
+
+ # 🤖 Database Copilot
+
+ A production-grade, **schema-agnostic chatbot** that connects to **any** database (MySQL, PostgreSQL, or SQLite) and provides intelligent querying through **RAG** (Retrieval-Augmented Generation) and **Text-to-SQL**.
+
+ **🆓 Powered by Groq for FREE LLM inference!**
+
+ ## 🌟 Features
+
+ - **Multi-Database Support**: Works with **MySQL**, **PostgreSQL**, and **SQLite**
+ - **Schema-Agnostic**: Works with ANY database schema, with no hardcoding required
+ - **Dynamic Introspection**: Automatically discovers tables, columns, and relationships
+ - **Hybrid Query Routing**: Intelligently routes queries to RAG or SQL based on intent
+ - **Semantic Search (RAG)**: FAISS-based vector search for text content
+ - **Text-to-SQL**: LLM-powered SQL generation with dialect-specific syntax
+ - **Security First**: Read-only queries, SQL validation, table whitelisting
+ - **FREE LLM**: Uses the Groq API (free tier) with Llama 3.3, Mixtral, and Gemma models
+
29
+ ## 🚀 Getting Started
30
+
31
+ ### 1. Configure Secrets
32
+
33
+ This Space requires the following secrets to be set in your Hugging Face Space settings:
34
+
35
+ **Required:**
36
+ | Secret Name | Description |
37
+ |------------|-------------|
38
+ | `GROQ_API_KEY` | Your Groq API key ([Get FREE key](https://console.groq.com)) |
39
+
40
+ **Database Configuration (choose one):**
41
+
42
+ #### For MySQL:
43
+ | Secret Name | Description |
44
+ |------------|-------------|
45
+ | `DB_TYPE` | Set to `mysql` |
46
+ | `DB_HOST` | MySQL server hostname |
47
+ | `DB_PORT` | MySQL port (default: 3306) |
48
+ | `DB_DATABASE` | Database name |
49
+ | `DB_USERNAME` | Database username |
50
+ | `DB_PASSWORD` | Database password |
51
+
52
+ #### For PostgreSQL:
53
+ | Secret Name | Description |
54
+ |------------|-------------|
55
+ | `DB_TYPE` | Set to `postgresql` |
56
+ | `DB_HOST` | PostgreSQL server hostname |
57
+ | `DB_PORT` | PostgreSQL port (default: 5432) |
58
+ | `DB_DATABASE` | Database name |
59
+ | `DB_USERNAME` | Database username |
60
+ | `DB_PASSWORD` | Database password |
61
+
62
+ #### For SQLite:
63
+ | Secret Name | Description |
64
+ |------------|-------------|
65
+ | `DB_TYPE` | Set to `sqlite` |
66
+ | `SQLITE_PATH` | Path to SQLite database file |
67
+
68
+ **Optional:**
69
+ | Secret Name | Description | Default |
70
+ |------------|-------------|---------|
71
+ | `GROQ_MODEL` | Groq model to use | `llama-3.3-70b-versatile` |
72
+ | `DB_SSL_CA` | Path to SSL CA certificate | None |
73
+
74
+ ### 2. Connect & Use
75
+
76
+ 1. Click **"Connect & Initialize"** in the sidebar
77
+ 2. Click **"Index Text Data"** to enable semantic search
78
+ 3. Start asking questions about your data!
79
+
80
+ ## 💬 Example Queries
81
+
82
+ **Semantic Search (RAG):**
83
+ - "What products are related to electronics?"
84
+ - "Tell me about customer feedback on shipping"
85
+
86
+ **Structured Queries (SQL):**
87
+ - "How many orders were placed last month?"
88
+ - "Show me the top 10 customers by revenue"
89
+
90
+ **Hybrid:**
91
+ - "Find customers who complained about delivery and show their order count"
92
+
93
+ ## 🔒 Security
94
+
95
+ - **Read-Only Transactions**: All queries run in read-only mode
96
+ - **SQL Validation**: Only SELECT statements allowed
97
+ - **Forbidden Keywords**: INSERT, UPDATE, DELETE, DROP, etc. are blocked
98
+ - **Table Whitelisting**: Only discovered tables are queryable
99
+ - **Automatic LIMIT**: All queries have LIMIT clauses enforced
100
+
101
+ ## 🆓 Why Groq?
102
+
103
+ [Groq](https://console.groq.com) provides **FREE API access** with incredibly fast inference:
104
+ - **Llama 3.3 70B** - Best quality, state-of-the-art
105
+ - **Llama 3.1 8B Instant** - Fastest responses
106
+ - **Mixtral 8x7B** - Great for code and SQL
107
+ - **Gemma 2 9B** - Google's efficient model
108
+
109
+ ## 📝 License
110
+
111
+ MIT License
app.py ADDED
@@ -0,0 +1,336 @@
+ """
+ Schema-Agnostic Database Chatbot - Streamlit Application
+
+ A production-grade chatbot that connects to any supported database
+ (MySQL, PostgreSQL, or SQLite) and provides intelligent querying
+ through RAG and Text-to-SQL.
+
+ Uses Groq for FREE LLM inference!
+ """
+
+ import os
+ from pathlib import Path
+
+ # Load .env FIRST before any other imports
+ from dotenv import load_dotenv
+ load_dotenv(Path(__file__).parent / ".env")
+
+ import streamlit as st
+ import uuid
+ from datetime import datetime
+
+ # Page config must be first
+ st.set_page_config(
+     page_title="Database Copilot",
+     page_icon="🤖",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+ # Imports
+ from config import config
+ from database import get_db, get_schema, get_introspector
+ from llm import create_llm_client
+ from chatbot import create_chatbot, DatabaseChatbot
+ from memory import create_memory, create_enhanced_memory, EnhancedChatMemory
+
+
+ # Groq models (all FREE!)
+ GROQ_MODELS = [
+     "llama-3.3-70b-versatile",
+     "llama-3.1-8b-instant",
+     "mixtral-8x7b-32768",
+     "gemma2-9b-it"
+ ]
+
+
+ def init_session_state():
+     """Initialize Streamlit session state."""
+     if "session_id" not in st.session_state:
+         st.session_state.session_id = str(uuid.uuid4())
+
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     if "chatbot" not in st.session_state:
+         st.session_state.chatbot = None
+
+     if "initialized" not in st.session_state:
+         st.session_state.initialized = False
+
+     if "user_id" not in st.session_state:
+         st.session_state.user_id = "default"
+
+     if "enable_summarization" not in st.session_state:
+         st.session_state.enable_summarization = True
+
+     if "summary_threshold" not in st.session_state:
+         st.session_state.summary_threshold = 10
+
+     if "memory" not in st.session_state:
+         st.session_state.memory = create_enhanced_memory(
+             st.session_state.session_id,
+             user_id=st.session_state.user_id,
+             enable_summarization=st.session_state.enable_summarization,
+             summary_threshold=st.session_state.summary_threshold
+         )
+         # Clear temporary memory on fresh load/reload
+         st.session_state.memory.clear_user_history()
+
+     if "indexed" not in st.session_state:
+         st.session_state.indexed = False
+
+
+ def render_sidebar():
+     """Render the configuration sidebar."""
+     with st.sidebar:
+         st.title("⚙️ Settings")
+
+         # User Profile
+         st.subheader("👤 User Profile")
+         user_id = st.text_input(
+             "User ID / Name",
+             value=st.session_state.get("user_id", "default"),
+             key="user_id_input",
+             help="Your unique ID for private memory storage"
+         )
+         if user_id != st.session_state.get("user_id"):
+             # USER ID CHANGE - same behavior as "New Chat":
+             # 1. Clear temporary memory (session history) for a clean start
+             # 2. Permanent memory remains UNTOUCHED (per-user storage)
+             st.session_state.user_id = user_id
+             st.session_state.session_id = str(uuid.uuid4())  # New session
+             st.session_state.messages = []  # Clear UI chat history
+
+             # Create memory for the new user and clear their temp history (fresh start)
+             st.session_state.memory = create_enhanced_memory(
+                 st.session_state.session_id,
+                 user_id=user_id,
+                 enable_summarization=st.session_state.enable_summarization,
+                 summary_threshold=st.session_state.summary_threshold
+             )
+             st.session_state.memory.clear_user_history()  # Clears _chatbot_memory, NOT _chatbot_permanent_memory_v2
+             st.rerun()
+
+         st.divider()
+
+         # Initialize Button
+         if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
+             with st.spinner("Connecting to database..."):
+                 success = initialize_chatbot()
+                 if success:
+                     st.success("✅ Connected!")
+                     st.rerun()
+
+         # Index Button (after initialization)
+         if st.session_state.initialized:
+             if st.button("📚 Index Text Data", use_container_width=True):
+                 with st.spinner("Indexing text data..."):
+                     index_data()
+                     st.success("✅ Indexed!")
+                     st.rerun()
+
+         st.divider()
+
+         # Status
+         st.subheader("📊 Status")
+         if st.session_state.initialized:
+             st.success("Database: Connected")
+             schema = get_schema()
+             st.info(f"Tables: {len(schema.tables)}")
+
+             if st.session_state.indexed:
+                 from rag import get_rag_engine
+                 engine = get_rag_engine()
+                 st.info(f"Indexed Docs: {engine.document_count}")
+         else:
+             st.warning("Not connected")
+
+         # New Chat (Context Switch)
+         if st.button("➕ New Chat", use_container_width=True, type="secondary"):
+             # Clear previous session from DB
+             if "memory" in st.session_state and st.session_state.memory:
+                 st.session_state.memory.clear()
+
+             st.session_state.messages = []
+             st.session_state.session_id = str(uuid.uuid4())  # Generate new session ID
+
+             # Preserve current user ID and memory settings
+             current_user = st.session_state.get("user_id", "default")
+             st.session_state.memory = create_enhanced_memory(
+                 st.session_state.session_id,
+                 user_id=current_user,
+                 enable_summarization=st.session_state.enable_summarization,
+                 summary_threshold=st.session_state.summary_threshold
+             )
+             # Set LLM client if available
+             if "llm" in st.session_state and st.session_state.llm:
+                 st.session_state.memory.set_llm_client(st.session_state.llm)
+             st.rerun()
+
+
+ def initialize_chatbot() -> bool:
+     """Initialize the chatbot using environment variables."""
+     try:
+         # Use Groq as the default provider (from environment)
+         api_key = os.getenv("GROQ_API_KEY", "")
+         model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
+
+         if not api_key:
+             st.error("GROQ_API_KEY not configured. Please set it in your .env file.")
+             return False
+
+         llm = create_llm_client("groq", api_key=api_key, model=model)
+
+         # Create and initialize chatbot
+         chatbot = create_chatbot(llm)
+
+         # Explicitly set LLM client (also configures router and sql_generator)
+         chatbot.set_llm_client(llm)
+
+         success, msg = chatbot.initialize()
+
+         if success:
+             st.session_state.chatbot = chatbot
+             st.session_state.llm = llm  # Store LLM separately too
+             st.session_state.initialized = True
+
+             # Set LLM client on memory for summarization
+             if hasattr(st.session_state.memory, 'set_llm_client'):
+                 st.session_state.memory.set_llm_client(llm)
+
+             return True
+         else:
+             st.error(f"Initialization failed: {msg}")
+             return False
+
+     except Exception as e:
+         st.error(f"Error: {str(e)}")
+         return False
+
+
+ def index_data():
+     """Index text data from the database."""
+     if st.session_state.chatbot:
+         progress = st.progress(0)
+         status = st.empty()
+
+         schema = get_schema()
+         total_tables = len(schema.tables)
+         indexed = 0
+
+         def progress_callback(table_name, docs):
+             nonlocal indexed
+             indexed += 1
+             progress.progress(indexed / total_tables)
+             status.text(f"Indexed {table_name}: {docs} documents")
+
+         total_docs = st.session_state.chatbot.index_text_data(progress_callback)
+         st.session_state.indexed = True
+         status.text(f"Total: {total_docs} documents indexed")
+
+
+ def render_schema_explorer():
+     """Render the schema explorer in an expander."""
+     if not st.session_state.initialized:
+         return
+
+     with st.expander("📋 Database Schema", expanded=False):
+         schema = get_schema()
+
+         for table_name, table_info in schema.tables.items():
+             with st.container():
+                 st.markdown(f"**{table_name}** ({table_info.row_count or '?'} rows)")
+
+                 cols = []
+                 for col in table_info.columns:
+                     pk = "🔑" if col.is_primary_key else ""
+                     txt = "📝" if col.is_text_type else ""
+                     cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
+
+                 st.caption(" | ".join(cols))
+                 st.divider()
+
+
+ def render_chat_interface():
+     """Render the main chat interface."""
+     st.title("🤖 Database Copilot")
+     st.caption("Schema-agnostic chatbot powered by Groq (FREE!)")
+
+     # Schema explorer
+     render_schema_explorer()
+
+     # Chat container
+     chat_container = st.container()
+
+     with chat_container:
+         # Display messages
+         for msg in st.session_state.messages:
+             with st.chat_message(msg["role"]):
+                 st.markdown(msg["content"])
+
+                 # Show metadata for assistant messages
+                 if msg["role"] == "assistant" and "metadata" in msg:
+                     meta = msg["metadata"]
+                     if meta.get("query_type"):
+                         st.caption(f"Query type: {meta['query_type']}")
+                     if meta.get("sql_query"):
+                         with st.expander("SQL Query"):
+                             st.code(meta["sql_query"], language="sql")
+
+     # Chat input
+     if prompt := st.chat_input("Ask about your data..."):
+         if not st.session_state.initialized:
+             st.error("Please connect to a database first!")
+             return
+
+         # Add user message
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         st.session_state.memory.add_message("user", prompt)
+
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         # Get response
+         with st.chat_message("assistant"):
+             with st.spinner("Thinking..."):
+                 response = st.session_state.chatbot.chat(
+                     prompt,
+                     st.session_state.memory
+                 )
+
+             st.markdown(response.answer)
+
+             # Show metadata
+             if response.query_type != "general":
+                 st.caption(f"Query type: {response.query_type}")
+
+             if response.sql_query:
+                 with st.expander("SQL Query"):
+                     st.code(response.sql_query, language="sql")
+
+             if response.sql_results:
+                 with st.expander("Results"):
+                     st.dataframe(response.sql_results)
+
+         # Save to memory
+         st.session_state.messages.append({
+             "role": "assistant",
+             "content": response.answer,
+             "metadata": {
+                 "query_type": response.query_type,
+                 "sql_query": response.sql_query
+             }
+         })
+         st.session_state.memory.add_message("assistant", response.answer)
+
+
+ def main():
+     """Main application entry point."""
+     init_session_state()
+     render_sidebar()
+     render_chat_interface()
+
+
+ if __name__ == "__main__":
+     main()
chatbot.py ADDED
@@ -0,0 +1,391 @@
1
+ """
2
+ Chatbot Core - Main orchestrator for the schema-agnostic database chatbot.
3
+
4
+ Combines all components:
5
+ - Schema introspection
6
+ - Query routing
7
+ - RAG retrieval
8
+ - SQL generation & execution
9
+ - Response generation
10
+ """
11
+
12
+ import logging
13
+ from typing import Dict, Any, List, Optional, Tuple
14
+ from dataclasses import dataclass
15
+
16
+ from database import get_db, get_schema, get_introspector
17
+ from rag import get_rag_engine
18
+ from sql import get_sql_generator, get_sql_validator
19
+ from llm import create_llm_client, LLMClient
20
+ from router import get_query_router, QueryType
21
+ from memory import ChatMemory, EnhancedChatMemory, create_memory
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ @dataclass
27
+ class ChatResponse:
28
+ """Response from the chatbot."""
29
+ answer: str
30
+ query_type: str
31
+ sources: List[Dict[str, Any]] = None
32
+ sql_query: Optional[str] = None
33
+ sql_results: Optional[List[Dict]] = None
34
+ error: Optional[str] = None
35
+
36
+ def __post_init__(self):
37
+ if self.sources is None:
38
+ self.sources = []
39
+
40
+
41
+ class DatabaseChatbot:
42
+ """Main chatbot class orchestrating all components."""
43
+
44
+ RESPONSE_PROMPT = """You are a helpful database assistant. Answer the user's question based on the provided context.
45
+
46
+ IMPORTANT: Use the conversation history to understand follow-up questions. If the user refers to "it", "that", "the product", etc., look at the previous messages to understand what they're referring to.
47
+
48
+ {context}
49
+
50
+ USER QUESTION: {question}
51
+
52
+ INSTRUCTIONS:
53
+ - Answer ONLY based on the provided context AND conversation history
54
+ - Do NOT use outside knowledge, general assumptions, or hallucinate facts
55
+ - If the context doesn't contain the answer, explicitly state that the information is not available in the database
56
+ - Resolve pronouns using previous messages
57
+ - Be concise but complete
58
+ - Format data nicely
59
+
60
+ YOUR RESPONSE:"""
61
+
62
+ def __init__(self, llm_client: Optional[LLMClient] = None):
63
+ self.db = get_db()
64
+ self.introspector = get_introspector()
65
+ self.rag_engine = get_rag_engine()
66
+ # Pass database type to SQL generator for dialect-specific SQL
67
+ db_type = self.db.db_type.value
68
+ self.sql_generator = get_sql_generator(db_type)
69
+ self.sql_validator = get_sql_validator()
70
+ self.router = get_query_router()
71
+ self.llm_client = llm_client
72
+
73
+ self._schema_initialized = False
74
+ self._rag_initialized = False
75
+
76
+ def set_llm_client(self, llm_client: LLMClient):
77
+ """Configure the LLM client."""
78
+ self.llm_client = llm_client
79
+ self.sql_generator.set_llm_client(llm_client)
80
+ self.router.set_llm_client(llm_client)
81
+
82
+ def initialize(self) -> Tuple[bool, str]:
83
+ """Initialize the chatbot by introspecting the database."""
84
+ try:
85
+ # Test connection
86
+ success, msg = self.db.test_connection()
87
+ if not success:
88
+ return False, f"Database connection failed: {msg}"
89
+
90
+ # Introspect schema
91
+ schema = self.introspector.introspect(force_refresh=True)
92
+
93
+ # Configure SQL validator with discovered tables
94
+ self.sql_validator.set_allowed_tables(schema.table_names)
95
+
96
+ self._schema_initialized = True
97
+
98
+ return True, f"Initialized with {len(schema.tables)} tables"
99
+
100
+ except Exception as e:
101
+ logger.error(f"Initialization failed: {e}")
102
+ return False, str(e)
103
+
104
+ def index_text_data(self, progress_callback=None) -> int:
105
+ """Index all text data for RAG."""
106
+ if not self._schema_initialized:
107
+ raise RuntimeError("Chatbot not initialized. Call initialize() first.")
108
+
109
+ schema = get_schema()
110
+ total_docs = 0
111
+
112
+ for table_name, table_info in schema.tables.items():
113
+ text_cols = [c.name for c in table_info.text_columns]
114
+ if not text_cols:
115
+ continue
116
+
117
+ pk = table_info.primary_keys[0] if table_info.primary_keys else None
118
+ cols_to_select = text_cols + ([pk] if pk else [])
119
+
120
+ query = f"SELECT {', '.join(cols_to_select)} FROM {table_name} LIMIT 1000"
121
+
122
+ try:
123
+ rows = self.db.execute_query(query)
124
+ docs = self.rag_engine.index_table(table_name, rows, text_cols, pk)
125
+ total_docs += docs
126
+
127
+ if progress_callback:
128
+ progress_callback(table_name, docs)
129
+
130
+ except Exception as e:
131
+ logger.warning(f"Failed to index {table_name}: {e}")
132
+
133
+ self.rag_engine.save()
134
+ self._rag_initialized = True
135
+
136
+ return total_docs
137
+
138
+ def chat(self, query: str, memory: Optional[ChatMemory] = None) -> ChatResponse:
139
+ """Process a user query and return a response."""
140
+ if not self._schema_initialized:
141
+ return ChatResponse(answer="Chatbot not initialized.", query_type="error",
142
+ error="Call initialize() first")
143
+
144
+ if not self.llm_client:
145
+ return ChatResponse(answer="LLM not configured.", query_type="error",
146
+ error="Configure LLM client first")
147
+
148
+ try:
149
+ schema = get_schema()
150
+ schema_context = schema.to_context_string()
151
+
152
+ # Check for memory commands
153
+ # Check for memory commands
154
+ # Check for memory commands using regex for flexibility
155
+ import re
156
+ save_pattern = re.compile(r"(?:please\s+)?(?:save|remember|memorize)\s+(?:this|that)?\s*(?:to\s+(?:main\s+)?memory)?\s*(?:that)?\s*:?\s*(.*)", re.IGNORECASE)
157
+ match = save_pattern.match(query.strip())
158
+
159
+ # Check if it looks like a command (starts with command words)
160
+ is_command = bool(match) and (
161
+ query.lower().startswith(("save", "remember", "memorize")) or
162
+ "saved to" in query.lower() # specific user case "saved to main memory"
163
+ )
164
+
165
+ if is_command and memory:
166
+ content_to_save = match.group(1).strip() if match else ""
167
+
168
+ # If specific content is provided (e.g. "Remember that I like pizza")
169
+ if content_to_save:
170
+ # Save the explicit content
171
+ success = memory.save_permanent_context(content_to_save)
172
+ if success:
173
+ return ChatResponse(answer=f"💾 I've saved to your permanent memory: '{content_to_save}'", query_type="memory")
174
+ else:
175
+ return ChatResponse(answer="❌ Failed to save to permanent memory. Please try again.", query_type="memory")
176
+
177
+ # If no content (e.g. "Save this"), save the previous conversation turn
178
+ elif len(memory.messages) >= 2:
179
+ # [-1] is current command ("save to memory")
180
+ # [-2] is previous assistant response
181
+ # [-3] is previous user query (context for the response)
182
+
183
+ msgs_to_save = []
184
+ # We try to grab the last QA pair: User Prompt + AI Response
185
+ # memory.messages structure: [User, AI, User, AI, User(current)]
186
+
187
+ if len(memory.messages) >= 3:
188
+ msg_user = memory.messages[-3]
189
+ msg_ai = memory.messages[-2]
190
+
191
+ # Verify roles to ensure we are saving a Q&A pair
192
+ if msg_user.role == "user" and msg_ai.role == "assistant":
193
+ msgs_to_save = [msg_user, msg_ai]
194
+
195
+ if msgs_to_save:
196
+ # Format: "User: ... | Assistant: ..."
197
+ context_str = f"User: {msgs_to_save[0].content} | Assistant: {msgs_to_save[1].content}"
198
+ success = memory.save_permanent_context(context_str)
199
+ if success:
200
+ return ChatResponse(answer="💾 I've saved our last exchange to your permanent memory.", query_type="memory")
201
+ else:
202
+ return ChatResponse(answer="❌ Failed to save to permanent memory.", query_type="memory")
203
+ else:
204
+ return ChatResponse(answer="⚠️ I couldn't find a clear previous exchange to save. Try saying 'Remember that [fact]'.", query_type="memory")
205
+ else:
206
+ return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")
207
+
208
+ # Route the query
209
+ routing = self.router.route(query, schema_context)
210
+
211
+ # Get chat history for context
212
+ history = memory.get_context_messages(5) if memory else []
213
+
214
+ # Process based on route
215
+ if routing.query_type == QueryType.RAG:
216
+ return self._handle_rag(query, history)
217
+ elif routing.query_type == QueryType.SQL:
218
+ return self._handle_sql(query, schema_context, history)
219
+ elif routing.query_type == QueryType.HYBRID:
220
+ return self._handle_hybrid(query, schema_context, history)
221
+ else:
222
+ return self._handle_general(query, history)
223
+
224
+ except Exception as e:
225
+ logger.error(f"Chat error: {e}")
226
+ return ChatResponse(answer=f"Error: {str(e)}", query_type="error", error=str(e))
227
+
228
+ def _handle_rag(self, query: str, history: List[Dict]) -> ChatResponse:
229
+ """Handle RAG-based query."""
230
+ context = self.rag_engine.get_context(query, top_k=5)
231
+
232
+ prompt = self.RESPONSE_PROMPT.format(context=f"RELEVANT DATA:\n{context}", question=query)
233
+
234
+ messages = self._construct_messages(
235
+ "You are a helpful database assistant.",
236
+ history,
237
+ prompt
238
+ )
239
+
240
+ answer = self.llm_client.chat(messages)
241
+
242
+ return ChatResponse(answer=answer, query_type="rag",
243
+ sources=[{"type": "semantic_search", "context": context[:500]}])
244
+
245
+ def _handle_sql(self, query: str, schema_context: str, history: List[Dict]) -> ChatResponse:
246
+ """Handle SQL-based query."""
247
+ sql, explanation = self.sql_generator.generate(query, schema_context, history)
248
+
249
+ # Validate SQL
250
+ is_valid, msg, sanitized_sql = self.sql_validator.validate(sql)
251
+ if not is_valid:
252
+ return ChatResponse(answer=f"Could not generate safe query: {msg}",
253
+ query_type="sql", error=msg)
254
+
255
+ # Execute query
256
+ try:
257
+ results = self.db.execute_query(sanitized_sql)
258
+ except Exception as e:
259
+ return ChatResponse(answer=f"Query execution failed: {e}",
260
+ query_type="sql", sql_query=sanitized_sql, error=str(e))
261
+
262
+ # SMART FALLBACK: If SQL returns nothing, it might be a semantic issue (e.g. wrong column)
263
+ # We try RAG as a fallback if SQL found nothing
264
+ if not results:
265
+ logger.info(f"SQL returned no results for query: '{query}'. Falling back to RAG.")
266
+ rag_response = self._handle_rag(query, history)
267
+
268
+ # Combine the info: "I couldn't find an exact match in the rows, but here is what I found semantically:"
269
+ rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the product descriptions:\n\n{rag_response.answer}"
270
+ rag_response.query_type = "hybrid_fallback"
271
+ rag_response.sql_query = sanitized_sql
272
+ return rag_response
273
+
274
+ # Generate response
275
+ context = f"SQL QUERY:\n{sanitized_sql}\n\nRESULTS:\n{self._format_results(results)}"
276
+ prompt = self.RESPONSE_PROMPT.format(context=context, question=query)
277
+
278
+ messages = self._construct_messages(
279
+ "You are a helpful database assistant.",
280
+ history,
281
+ prompt
282
+ )
283
+
284
+ answer = self.llm_client.chat(messages)
285
+
286
+        return ChatResponse(answer=answer, query_type="sql",
+                            sql_query=sanitized_sql, sql_results=results[:10])
+
+    def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict]) -> ChatResponse:
+        """Handle hybrid RAG + SQL query."""
+        # Get RAG context
+        rag_context = self.rag_engine.get_context(query, top_k=3)
+
+        # Try SQL as well
+        sql_context = ""
+        sql_query = None
+        try:
+            sql, _ = self.sql_generator.generate(query, schema_context, history)
+            is_valid, _, sanitized_sql = self.sql_validator.validate(sql)
+            if is_valid:
+                results = self.db.execute_query(sanitized_sql)
+                sql_context = f"\nSQL RESULTS:\n{self._format_results(results)}"
+                sql_query = sanitized_sql
+        except Exception as e:
+            logger.debug(f"SQL part of hybrid failed: {e}")
+
+        context = f"SEMANTIC SEARCH RESULTS:\n{rag_context}{sql_context}"
+        prompt = self.RESPONSE_PROMPT.format(context=context, question=query)
+
+        messages = self._construct_messages(
+            "You are a helpful database assistant.",
+            history,
+            prompt
+        )
+
+        answer = self.llm_client.chat(messages)
+
+        return ChatResponse(answer=answer, query_type="hybrid", sql_query=sql_query)
+
+    def _construct_messages(self, system_instruction: str, history: List[Dict], user_content: str) -> List[Dict]:
+        """Construct message list, merging system messages from history."""
+        # Check if first history item is a system message (from memory)
+        additional_context = ""
+        filtered_history = []
+
+        for msg in history:
+            if msg.get("role") == "system":
+                additional_context += f"\n\n{msg.get('content')}"
+            else:
+                filtered_history.append(msg)
+
+        full_system_prompt = f"{system_instruction}{additional_context}"
+
+        messages = [{"role": "system", "content": full_system_prompt}]
+        messages.extend(filtered_history)
+        messages.append({"role": "user", "content": user_content})
+
+        return messages
+
+    def _handle_general(self, query: str, history: List[Dict]) -> ChatResponse:
+        """Handle general conversation."""
+        # Use a strict prompt for general conversation as well to prevent hallucinations
+        strict_system_prompt = (
+            "You are a helpful database assistant.\n"
+            "INSTRUCTIONS:\n"
+            "- Answer ONLY based on the conversation history and any context provided within it.\n"
+            "- Do NOT use outside knowledge, general assumptions, or hallucinate facts.\n"
+            "- If the answer is not in the history or context, state that you don't have that information.\n"
+            "- Be concise."
+        )
+
+        messages = self._construct_messages(
+            strict_system_prompt,
+            history,
+            query
+        )
+        answer = self.llm_client.chat(messages)
+        return ChatResponse(answer=answer, query_type="general")
+
+    def _format_results(self, results: List[Dict], max_rows: int = 10) -> str:
+        """Format SQL results for display."""
+        if not results:
+            return "No results found."
+
+        rows = results[:max_rows]
+        lines = []
+
+        # Header
+        headers = list(rows[0].keys())
+        lines.append(" | ".join(headers))
+        lines.append("-" * len(lines[0]))
+
+        # Rows
+        for row in rows:
+            values = [str(v)[:50] for v in row.values()]
+            lines.append(" | ".join(values))
+
+        if len(results) > max_rows:
+            lines.append(f"... and {len(results) - max_rows} more rows")
+
+        return "\n".join(lines)
+
+    def get_schema_summary(self) -> str:
+        """Get a summary of the database schema."""
+        if not self._schema_initialized:
+            return "Schema not loaded."
+        return get_schema().to_context_string()
+
+
+def create_chatbot(llm_client: Optional[LLMClient] = None) -> DatabaseChatbot:
+    return DatabaseChatbot(llm_client)
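To illustrate the plain-text table that `_format_results` builds for the LLM context, here is a standalone sketch of the same logic. The free function below is a hypothetical extraction for demonstration, not part of the diff:

```python
def format_results(results, max_rows=10):
    # Mirrors DatabaseChatbot._format_results: header row, dash rule sized to
    # the header, values truncated to 50 chars, and a trailing overflow note.
    if not results:
        return "No results found."
    rows = results[:max_rows]
    headers = list(rows[0].keys())
    lines = [" | ".join(headers)]
    lines.append("-" * len(lines[0]))
    for row in rows:
        lines.append(" | ".join(str(v)[:50] for v in row.values()))
    if len(results) > max_rows:
        lines.append(f"... and {len(results) - max_rows} more rows")
    return "\n".join(lines)


print(format_results([{"id": 1, "name": "ada"}, {"id": 2, "name": "lin"}]))
# id | name
# ---------
# 1 | ada
# 2 | lin
```

Because the dash rule is sized to the header line only, wide values can overhang it; that is acceptable here since the output is consumed by the LLM, not rendered as a strict table.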
config.py ADDED
@@ -0,0 +1,280 @@
+"""
+Configuration module for the Schema-Agnostic Database Chatbot.
+
+This module handles all configuration including:
+- Database connection settings (MySQL, PostgreSQL, SQLite)
+- LLM provider settings (Groq / OpenAI / Local LLaMA)
+- Embedding model configuration
+- Security settings
+"""
+
+import os
+from pathlib import Path
+from dataclasses import dataclass, field
+from typing import Optional, List
+from enum import Enum
+
+# Load .env file BEFORE any os.getenv calls
+from dotenv import load_dotenv
+env_path = Path(__file__).parent / ".env"
+load_dotenv(env_path)
+
+
+class DatabaseType(Enum):
+    """Supported database types."""
+    MYSQL = "mysql"
+    POSTGRESQL = "postgresql"
+    SQLITE = "sqlite"
+
+
+class LLMProvider(Enum):
+    """Supported LLM providers."""
+    GROQ = "groq"  # FREE!
+    OPENAI = "openai"
+    LOCAL_LLAMA = "local_llama"
+
+
+class EmbeddingProvider(Enum):
+    """Supported embedding providers."""
+    OPENAI = "openai"
+    SENTENCE_TRANSFORMERS = "sentence_transformers"
+
+
+@dataclass
+class DatabaseConfig:
+    """
+    Database configuration supporting MySQL, PostgreSQL, and SQLite.
+
+    All sensitive values are loaded from environment variables.
+    """
+    # Database type (mysql, postgresql, sqlite)
+    db_type: DatabaseType = field(
+        default_factory=lambda: DatabaseType(os.getenv("DB_TYPE", "mysql").lower())
+    )
+
+    # Common connection settings (for MySQL/PostgreSQL)
+    host: str = field(default_factory=lambda: os.getenv("DB_HOST", os.getenv("MYSQL_HOST", "")))
+    port: int = field(default_factory=lambda: int(os.getenv("DB_PORT", os.getenv("MYSQL_PORT", "3306"))))
+    database: str = field(default_factory=lambda: os.getenv("DB_DATABASE", os.getenv("MYSQL_DATABASE", "")))
+    username: str = field(default_factory=lambda: os.getenv("DB_USERNAME", os.getenv("MYSQL_USERNAME", "")))
+    password: str = field(default_factory=lambda: os.getenv("DB_PASSWORD", os.getenv("MYSQL_PASSWORD", "")))
+
+    # SSL configuration
+    ssl_ca: Optional[str] = field(default_factory=lambda: os.getenv("DB_SSL_CA", os.getenv("MYSQL_SSL_CA", None)))
+
+    # SQLite-specific: path to database file
+    sqlite_path: str = field(default_factory=lambda: os.getenv("SQLITE_PATH", "./chatbot.db"))
+
+    @property
+    def connection_string(self) -> str:
+        """Generate SQLAlchemy connection string based on database type."""
+        if self.db_type == DatabaseType.SQLITE:
+            # SQLite uses file path
+            return f"sqlite:///{self.sqlite_path}"
+
+        elif self.db_type == DatabaseType.POSTGRESQL:
+            # PostgreSQL connection string
+            base_url = f"postgresql+psycopg2://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
+            if self.ssl_ca:
+                return f"{base_url}?sslmode=verify-full&sslrootcert={self.ssl_ca}"
+            return base_url
+
+        else:  # MySQL (default)
+            # MySQL connection string
+            base_url = f"mysql+pymysql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
+            if self.ssl_ca:
+                return f"{base_url}?ssl_ca={self.ssl_ca}"
+            return base_url
+
+    def is_configured(self) -> bool:
+        """Check if all required database settings are configured."""
+        if self.db_type == DatabaseType.SQLITE:
+            # SQLite only needs a valid path
+            return bool(self.sqlite_path)
+        else:
+            # MySQL/PostgreSQL need host, database, username, password
+            return all([self.host, self.database, self.username, self.password])
+
+    @property
+    def is_mysql(self) -> bool:
+        """Check if using MySQL."""
+        return self.db_type == DatabaseType.MYSQL
+
+    @property
+    def is_postgresql(self) -> bool:
+        """Check if using PostgreSQL."""
+        return self.db_type == DatabaseType.POSTGRESQL
+
+    @property
+    def is_sqlite(self) -> bool:
+        """Check if using SQLite."""
+        return self.db_type == DatabaseType.SQLITE
+
+
+@dataclass
+class LLMConfig:
+    """LLM configuration for query routing and response generation."""
+    provider: LLMProvider = field(
+        default_factory=lambda: LLMProvider(os.getenv("LLM_PROVIDER", "openai"))
+    )
+    openai_api_key: str = field(default_factory=lambda: os.getenv("OPENAI_API_KEY", ""))
+    openai_model: str = field(default_factory=lambda: os.getenv("OPENAI_MODEL", "gpt-4o-mini"))
+
+    # Local LLaMA settings
+    local_model_path: str = field(
+        default_factory=lambda: os.getenv("LOCAL_MODEL_PATH", "")
+    )
+    local_model_name: str = field(
+        default_factory=lambda: os.getenv("LOCAL_MODEL_NAME", "llama-2-7b-chat")
+    )
+
+    # Generation parameters
+    temperature: float = 0.1  # Low temperature for more deterministic outputs
+    max_tokens: int = 1024
+
+    def is_configured(self) -> bool:
+        """Check if LLM is properly configured."""
+        if self.provider == LLMProvider.OPENAI:
+            return bool(self.openai_api_key)
+        return bool(self.local_model_path)
+
+
+@dataclass
+class EmbeddingConfig:
+    """Embedding model configuration for RAG."""
+    provider: EmbeddingProvider = field(
+        default_factory=lambda: EmbeddingProvider(
+            os.getenv("EMBEDDING_PROVIDER", "sentence_transformers")
+        )
+    )
+
+    # OpenAI embedding settings
+    openai_embedding_model: str = "text-embedding-3-small"
+
+    # Sentence Transformers settings
+    st_model_name: str = field(
+        default_factory=lambda: os.getenv(
+            "EMBEDDING_MODEL",
+            "sentence-transformers/all-MiniLM-L6-v2"
+        )
+    )
+
+    # Embedding dimensions (varies by model)
+    embedding_dim: int = 384  # Default for all-MiniLM-L6-v2
+
+
+@dataclass
+class SecurityConfig:
+    """Security settings for SQL validation and execution."""
+
+    # SQL operations whitelist - ONLY SELECT allowed
+    allowed_operations: List[str] = field(default_factory=lambda: ["SELECT"])
+
+    # Dangerous keywords that should never appear in queries
+    forbidden_keywords: List[str] = field(default_factory=lambda: [
+        "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
+        "TRUNCATE", "GRANT", "REVOKE", "EXECUTE", "EXEC",
+        "INTO OUTFILE", "INTO DUMPFILE", "LOAD_FILE",
+        "INFORMATION_SCHEMA.USER_PRIVILEGES"
+    ])
+
+    # Maximum number of rows to return
+    max_result_rows: int = 100
+
+    # Default LIMIT clause if not specified
+    default_limit: int = 50
+
+
+@dataclass
+class RAGConfig:
+    """RAG (Retrieval-Augmented Generation) configuration."""
+
+    # FAISS index settings
+    faiss_index_path: str = "./faiss_index"
+
+    # Number of top results to retrieve
+    top_k: int = 5
+
+    # Minimum similarity score for relevance
+    similarity_threshold: float = 0.3
+
+    # Text columns to consider for RAG (common across database types)
+    text_column_types: List[str] = field(default_factory=lambda: [
+        # MySQL types
+        "TEXT", "MEDIUMTEXT", "LONGTEXT", "TINYTEXT", "VARCHAR", "CHAR",
+        # PostgreSQL types
+        "CHARACTER VARYING", "CHARACTER",
+        # SQLite types (SQLite is flexible but these are common)
+        "CLOB", "NVARCHAR", "NCHAR"
+    ])
+
+    # Minimum character length to consider a column for RAG
+    min_text_length: int = 50
+
+    # Chunk size for long text documents
+    chunk_size: int = 500
+    chunk_overlap: int = 50
+
+
+@dataclass
+class ChatConfig:
+    """Chat and memory configuration."""
+
+    # Short-term memory (in session)
+    max_session_messages: int = 20
+
+    # Long-term memory table name (will be created if not exists)
+    memory_table_name: str = "_chatbot_memory"
+
+    # Number of recent messages to include in context
+    context_messages: int = 5
+
+
+class AppConfig:
+    """
+    Main application configuration aggregator.
+
+    Combines all configuration sections and provides
+    validation methods.
+    """
+
+    def __init__(self):
+        self.database = DatabaseConfig()
+        self.llm = LLMConfig()
+        self.embedding = EmbeddingConfig()
+        self.security = SecurityConfig()
+        self.rag = RAGConfig()
+        self.chat = ChatConfig()
+
+    def validate(self) -> tuple[bool, List[str]]:
+        """
+        Validate all configuration settings.
+
+        Returns:
+            tuple: (is_valid, list of error messages)
+        """
+        errors = []
+
+        if not self.database.is_configured():
+            db_type = self.database.db_type.value.upper()
+            if self.database.is_sqlite:
+                errors.append("SQLite configuration incomplete. Check SQLITE_PATH environment variable.")
+            else:
+                errors.append(f"{db_type} configuration incomplete. Check DB_* environment variables.")
+
+        if not self.llm.is_configured():
+            errors.append(
+                f"LLM configuration incomplete for provider: {self.llm.provider.value}. "
+                "Check API keys or model paths."
+            )
+
+        return len(errors) == 0, errors
+
+    @classmethod
+    def from_env(cls) -> "AppConfig":
+        """Create configuration from environment variables."""
+        return cls()
+
+
+# Global configuration instance
+config = AppConfig.from_env()
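The `DatabaseConfig.connection_string` property above switches URL dialects per backend. A standalone sketch of that branching, with `build_connection_string` as a hypothetical free-function mirror of the property (not part of the diff), makes the three URL shapes easy to compare:

```python
def build_connection_string(db_type, *, sqlite_path="./chatbot.db",
                            username="", password="", host="", port=3306,
                            database="", ssl_ca=None):
    # Mirrors the dialect selection in DatabaseConfig.connection_string.
    if db_type == "sqlite":
        # SQLite is addressed by file path, no credentials
        return f"sqlite:///{sqlite_path}"
    if db_type == "postgresql":
        base = f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{database}"
        # SSL options ride along as query parameters
        return f"{base}?sslmode=verify-full&sslrootcert={ssl_ca}" if ssl_ca else base
    # MySQL (default)
    base = f"mysql+pymysql://{username}:{password}@{host}:{port}/{database}"
    return f"{base}?ssl_ca={ssl_ca}" if ssl_ca else base


print(build_connection_string("sqlite", sqlite_path="demo.db"))
# sqlite:///demo.db
print(build_connection_string("mysql", username="u", password="p",
                              host="h", port=3306, database="d"))
# mysql+pymysql://u:p@h:3306/d
```

Note that embedding raw credentials in the URL requires them to be URL-safe; the real config loads them from environment variables, so special characters in passwords would need quoting.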
database/__init__.py ADDED
@@ -0,0 +1,30 @@
+"""
+Database module for the Schema-Agnostic Chatbot.
+
+Provides:
+- Database connection management
+- Dynamic schema introspection
+- Safe query execution
+"""
+
+from .connection import DatabaseConnection, get_db, db_connection
+from .schema_introspector import (
+    SchemaIntrospector,
+    SchemaInfo,
+    TableInfo,
+    ColumnInfo,
+    get_introspector,
+    get_schema
+)
+
+__all__ = [
+    "DatabaseConnection",
+    "get_db",
+    "db_connection",
+    "SchemaIntrospector",
+    "SchemaInfo",
+    "TableInfo",
+    "ColumnInfo",
+    "get_introspector",
+    "get_schema"
+]
database/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (786 Bytes).
database/__pycache__/connection.cpython-311.pyc ADDED
Binary file (11.5 kB).
database/__pycache__/schema_introspector.cpython-311.pyc ADDED
Binary file (31.2 kB).
database/connection.py ADDED
@@ -0,0 +1,231 @@
+"""
+Database Connection Module - Multi-Database Support.
+
+This module provides:
+- SQLAlchemy engine and session management for MySQL, PostgreSQL, and SQLite
+- Connection pooling (for MySQL/PostgreSQL)
+- SSL/TLS support
+- Connection health checking
+"""
+
+import logging
+from contextlib import contextmanager
+from typing import Optional, Generator
+from sqlalchemy import create_engine, text, event
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker, Session
+from sqlalchemy.pool import QueuePool, StaticPool
+from sqlalchemy.exc import OperationalError, SQLAlchemyError
+
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from config import DatabaseConfig, DatabaseType, config
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnection:
+    """
+    Manages database connections with connection pooling.
+
+    Supports MySQL, PostgreSQL, and SQLite.
+    """
+
+    def __init__(self, db_config: Optional[DatabaseConfig] = None):
+        """
+        Initialize database connection manager.
+
+        Args:
+            db_config: Database configuration. Uses global config if not provided.
+        """
+        self.config = db_config or config.database
+        self._engine: Optional[Engine] = None
+        self._session_factory: Optional[sessionmaker] = None
+
+    def _create_engine(self) -> Engine:
+        """
+        Create SQLAlchemy engine with appropriate settings for each database type.
+
+        Returns:
+            Configured SQLAlchemy Engine instance
+        """
+        connect_args = {}
+
+        if self.config.db_type == DatabaseType.SQLITE:
+            # SQLite-specific settings
+            # Use StaticPool for SQLite to handle multi-threading
+            connect_args["check_same_thread"] = False
+
+            engine = create_engine(
+                self.config.connection_string,
+                poolclass=StaticPool,  # SQLite works best with StaticPool
+                connect_args=connect_args,
+                echo=False
+            )
+
+            # Enable foreign keys for SQLite
+            @event.listens_for(engine, "connect")
+            def set_sqlite_pragma(dbapi_connection, connection_record):
+                cursor = dbapi_connection.cursor()
+                cursor.execute("PRAGMA foreign_keys=ON")
+                cursor.close()
+
+        elif self.config.db_type == DatabaseType.POSTGRESQL:
+            # PostgreSQL-specific settings
+            if self.config.ssl_ca:
+                connect_args["sslmode"] = "verify-full"
+                connect_args["sslrootcert"] = self.config.ssl_ca
+
+            engine = create_engine(
+                self.config.connection_string,
+                poolclass=QueuePool,
+                pool_size=5,
+                max_overflow=10,
+                pool_timeout=30,
+                pool_recycle=1800,
+                pool_pre_ping=True,
+                connect_args=connect_args,
+                echo=False
+            )
+
+        else:  # MySQL (default)
+            # MySQL-specific settings (SSL for Aiven)
+            if self.config.ssl_ca:
+                connect_args["ssl"] = {
+                    "ca": self.config.ssl_ca,
+                    "check_hostname": True,
+                    "verify_mode": True
+                }
+
+            engine = create_engine(
+                self.config.connection_string,
+                poolclass=QueuePool,
+                pool_size=5,
+                max_overflow=10,
+                pool_timeout=30,
+                pool_recycle=1800,
+                pool_pre_ping=True,
+                connect_args=connect_args,
+                echo=False
+            )
+
+        return engine
+
+    @property
+    def engine(self) -> Engine:
+        """Get or create the SQLAlchemy engine."""
+        if self._engine is None:
+            self._engine = self._create_engine()
+        return self._engine
+
+    @property
+    def session_factory(self) -> sessionmaker:
+        """Get or create the session factory."""
+        if self._session_factory is None:
+            self._session_factory = sessionmaker(
+                bind=self.engine,
+                autocommit=False,
+                autoflush=False
+            )
+        return self._session_factory
+
+    @property
+    def db_type(self) -> DatabaseType:
+        """Get the current database type."""
+        return self.config.db_type
+
+    @contextmanager
+    def get_session(self) -> Generator[Session, None, None]:
+        """
+        Context manager for database sessions.
+
+        Yields:
+            SQLAlchemy Session instance
+
+        Example:
+            with db.get_session() as session:
+                result = session.execute(text("SELECT * FROM users"))
+        """
+        session = self.session_factory()
+        try:
+            yield session
+            session.commit()
+        except SQLAlchemyError as e:
+            session.rollback()
+            logger.error(f"Database session error: {e}")
+            raise
+        finally:
+            session.close()
+
+    def execute_query(self, query: str, params: Optional[dict] = None) -> list:
+        """
+        Execute a read-only SQL query and return results.
+
+        Args:
+            query: SQL query string (must be SELECT)
+            params: Optional query parameters for parameterized queries
+
+        Returns:
+            List of result rows as dictionaries
+        """
+        with self.get_session() as session:
+            result = session.execute(text(query), params or {})
+            # Convert rows to dictionaries for easier handling
+            columns = result.keys()
+            return [dict(zip(columns, row)) for row in result.fetchall()]
+
+    def execute_write(self, query: str, params: Optional[dict] = None) -> bool:
+        """
+        Execute a write operation (INSERT, UPDATE, DELETE, CREATE).
+
+        Args:
+            query: SQL query string
+            params: Optional query parameters
+
+        Returns:
+            bool: True if successful
+        """
+        with self.get_session() as session:
+            session.execute(text(query), params or {})
+            session.commit()
+            return True
+
+    def test_connection(self) -> tuple[bool, str]:
+        """
+        Test database connectivity.
+
+        Returns:
+            tuple: (success: bool, message: str)
+        """
+        try:
+            with self.get_session() as session:
+                result = session.execute(text("SELECT 1 as health_check"))
+                row = result.fetchone()
+                if row and row[0] == 1:
+                    db_type = self.config.db_type.value.upper()
+                    return True, f"{db_type} connection successful"
+                return False, "Unexpected result from health check query"
+        except OperationalError as e:
+            logger.error(f"Database connection failed: {e}")
+            return False, f"Connection failed: {str(e)}"
+        except Exception as e:
+            logger.error(f"Unexpected error during connection test: {e}")
+            return False, f"Unexpected error: {str(e)}"
+
+    def close(self):
+        """Close all connections and dispose of the engine."""
+        if self._engine:
+            self._engine.dispose()
+            self._engine = None
+            self._session_factory = None
+            logger.info("Database connections closed")
+
+
+# Create a global database connection instance
+db_connection = DatabaseConnection()
+
+
+def get_db() -> DatabaseConnection:
+    """Get the global database connection instance."""
+    return db_connection
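The heart of `execute_query` above is converting driver rows into dictionaries via `dict(zip(columns, row))`. A minimal sketch of that conversion, using the stdlib `sqlite3` driver instead of SQLAlchemy so it runs without dependencies (the `rows_as_dicts` helper is hypothetical, not part of the diff):

```python
import sqlite3


def rows_as_dicts(conn, query, params=()):
    # Same row -> dict conversion as DatabaseConnection.execute_query,
    # with column names taken from the cursor description.
    cur = conn.execute(query, params)
    columns = [desc[0] for desc in cur.description]
    return [dict(zip(columns, row)) for row in cur.fetchall()]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO users (name) VALUES ('ada')")
print(rows_as_dicts(conn, "SELECT * FROM users"))
# [{'id': 1, 'name': 'ada'}]
```

The dict form is what the chatbot's `_format_results` consumes, which is why `execute_query` does this conversion eagerly rather than returning raw `Row` objects.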
database/schema_introspector.py ADDED
@@ -0,0 +1,648 @@
+"""
+Dynamic Schema Introspection Module - Multi-Database Support.
+
+This module is the CORE of the schema-agnostic design.
+It dynamically discovers:
+- All tables in the database
+- All columns with their data types
+- Primary keys and foreign keys
+- Text-like columns for RAG indexing
+- Relationships between tables
+
+Supports MySQL, PostgreSQL, and SQLite.
+NEVER hardcodes any table or column names.
+"""
+
+import logging
+from dataclasses import dataclass, field
+from typing import List, Dict, Optional, Any
+from sqlalchemy import text, inspect
+from sqlalchemy.engine import Engine
+
+from .connection import get_db
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ColumnInfo:
+    """Information about a single database column."""
+    name: str
+    data_type: str
+    is_nullable: bool
+    is_primary_key: bool
+    max_length: Optional[int] = None
+    default_value: Optional[str] = None
+    comment: Optional[str] = None
+
+    @property
+    def is_text_type(self) -> bool:
+        """Check if this column contains text data suitable for RAG."""
+        text_types = [
+            # MySQL
+            'text', 'mediumtext', 'longtext', 'tinytext', 'varchar', 'char', 'json',
+            # PostgreSQL
+            'character varying', 'character', 'text', 'json', 'jsonb',
+            # SQLite (column affinity - TEXT)
+            'clob', 'nvarchar', 'nchar', 'ntext'
+        ]
+        data_type_lower = self.data_type.lower().split('(')[0].strip()
+        return data_type_lower in text_types
+
+    @property
+    def is_numeric(self) -> bool:
+        """Check if this column contains numeric data."""
+        numeric_types = [
+            # Common across databases
+            'int', 'integer', 'bigint', 'smallint', 'tinyint',
+            'decimal', 'numeric', 'float', 'double', 'real',
+            # PostgreSQL specific
+            'double precision', 'serial', 'bigserial', 'smallserial',
+            # SQLite (NUMERIC affinity)
+            'bool', 'boolean'
+        ]
+        data_type_lower = self.data_type.lower().split('(')[0].strip()
+        return data_type_lower in numeric_types
+
+
+@dataclass
+class TableInfo:
+    """Complete information about a database table."""
+    name: str
+    columns: List[ColumnInfo] = field(default_factory=list)
+    primary_keys: List[str] = field(default_factory=list)
+    foreign_keys: Dict[str, str] = field(default_factory=dict)  # column -> referenced_table.column
+    row_count: Optional[int] = None
+    comment: Optional[str] = None
+
+    @property
+    def text_columns(self) -> List[ColumnInfo]:
+        """Get columns suitable for text/RAG indexing."""
+        return [col for col in self.columns if col.is_text_type]
+
+    @property
+    def column_names(self) -> List[str]:
+        """Get list of all column names."""
+        return [col.name for col in self.columns]
+
+    def get_column(self, name: str) -> Optional[ColumnInfo]:
+        """Get column info by name."""
+        for col in self.columns:
+            if col.name.lower() == name.lower():
+                return col
+        return None
+
+
+@dataclass
+class SchemaInfo:
+    """Complete database schema information."""
+    database_name: str
+    tables: Dict[str, TableInfo] = field(default_factory=dict)
+
+    @property
+    def table_names(self) -> List[str]:
+        """Get list of all table names."""
+        return list(self.tables.keys())
+
+    @property
+    def all_text_columns(self) -> List[tuple]:
+        """Get all text columns across all tables as (table, column) tuples."""
+        result = []
+        for table_name, table_info in self.tables.items():
+            for col in table_info.text_columns:
+                result.append((table_name, col.name))
+        return result
+
+    def to_context_string(self) -> str:
+        """
+        Generate a natural language description of the schema.
+        This is used as context for the LLM.
+        """
+        lines = [f"Database: {self.database_name}", ""]
+        lines.append("Available Tables:")
+        lines.append("-" * 40)
+
+        for table_name, table_info in self.tables.items():
+            lines.append(f"\nTable: {table_name}")
+            if table_info.comment:
+                lines.append(f"  Description: {table_info.comment}")
+            if table_info.row_count is not None:
+                lines.append(f"  Approximate rows: {table_info.row_count}")
+
+            lines.append("  Columns:")
+            for col in table_info.columns:
+                pk_marker = " [PRIMARY KEY]" if col.is_primary_key else ""
+                nullable = " (nullable)" if col.is_nullable else " (required)"
+                lines.append(f"    - {col.name}: {col.data_type}{pk_marker}{nullable}")
+                if col.comment:
+                    lines.append(f"      Comment: {col.comment}")
+
+            if table_info.foreign_keys:
+                lines.append("  Foreign Keys:")
+                for col, ref in table_info.foreign_keys.items():
+                    lines.append(f"    - {col} -> {ref}")
+
+        return "\n".join(lines)
+
+    def to_sql_ddl(self) -> str:
+        """
+        Generate SQL-like DDL representation of the schema.
+        Useful for SQL generation context.
+        """
+        ddl_lines = []
+
+        for table_name, table_info in self.tables.items():
+            ddl_lines.append(f"CREATE TABLE {table_name} (")
+
+            col_defs = []
+            for col in table_info.columns:
+                col_def = f"    {col.name} {col.data_type}"
+                if col.is_primary_key:
+                    col_def += " PRIMARY KEY"
+                if not col.is_nullable:
+                    col_def += " NOT NULL"
+                col_defs.append(col_def)
+
+            ddl_lines.append(",\n".join(col_defs))
+            ddl_lines.append(");\n")
+
+        return "\n".join(ddl_lines)
+
+
+class SchemaIntrospector:
+    """
+    Dynamically introspects database schema.
+
+    This is the key component that enables schema-agnostic operation.
+    It queries database system catalogs to discover the complete schema.
+    Supports MySQL, PostgreSQL, and SQLite.
+    """
+
+    # System tables to exclude from introspection
+    SYSTEM_TABLES = {
+        '_chatbot_memory',  # Our own chat history table
+        '_chatbot_permanent_memory_v2',
+        '_chatbot_user_summaries',
+        'schema_migrations',
+        'flyway_schema_history',
+        # SQLite internal tables
+        'sqlite_sequence',
+        'sqlite_stat1',
+        'sqlite_stat4'
+    }
+
+    def __init__(self, engine: Optional[Engine] = None):
+        """
+        Initialize the introspector.
+
+        Args:
+            engine: SQLAlchemy engine. Uses global connection if not provided.
+        """
+        self.db = get_db()
+        self._cached_schema: Optional[SchemaInfo] = None
+
+    def introspect(self, force_refresh: bool = False) -> SchemaInfo:
+        """
+        Perform complete schema introspection.
+
+        Args:
+            force_refresh: If True, bypass cache and re-introspect
+
+        Returns:
+            SchemaInfo object with complete schema details
+        """
+        if self._cached_schema is not None and not force_refresh:
+            return self._cached_schema
+
+        logger.info("Starting schema introspection...")
+
+        # Get database name
+        db_name = self._get_database_name()
+
+        # Get all user tables
+        tables = self._get_tables()
+
+        schema = SchemaInfo(database_name=db_name)
+
+        for table_name in tables:
+            if table_name in self.SYSTEM_TABLES:
+                continue
+            # Also skip tables that start with underscore (internal tables)
+            if table_name.startswith('_chatbot'):
+                continue
+
+            table_info = self._introspect_table(table_name)
+            if table_info:
+                schema.tables[table_name] = table_info
+
+        self._cached_schema = schema
+        logger.info(f"Schema introspection complete. Found {len(schema.tables)} tables.")
+
+        return schema
+
+    def _get_database_name(self) -> str:
+        """Get the current database name."""
+        db_type = self.db.db_type
+
+        try:
+            if db_type.value == "sqlite":
+                # For SQLite, return the database file name
+                return self.db.config.sqlite_path.split('/')[-1]
+            elif db_type.value == "postgresql":
+                result = self.db.execute_query("SELECT current_database() as db_name")
+                return result[0]['db_name'] if result else "unknown"
+            else:  # MySQL
+                result = self.db.execute_query("SELECT DATABASE() as db_name")
+                return result[0]['db_name'] if result else "unknown"
+        except Exception as e:
+            logger.error(f"Error getting database name: {e}")
+            return "unknown"
+
+    def _get_tables(self) -> List[str]:
+        """
+        Get all user tables from the database.
+        Uses database-specific queries for comprehensive discovery.
+        """
+        db_type = self.db.db_type
+
+        try:
+            if db_type.value == "sqlite":
+                query = """
+                    SELECT name as table_name
+                    FROM sqlite_master
+                    WHERE type='table'
+                    AND name NOT LIKE 'sqlite_%'
+                    ORDER BY name
+                """
+                result = self.db.execute_query(query)
+                return [row['table_name'] for row in result]
+
+            elif db_type.value == "postgresql":
+                query = """
+                    SELECT table_name
+                    FROM information_schema.tables
+                    WHERE table_schema = 'public'
+                    AND table_type = 'BASE TABLE'
+                    ORDER BY table_name
+                """
+                result = self.db.execute_query(query)
+                return [row['table_name'] for row in result]
+
+            else:  # MySQL
+                query = """
+                    SELECT TABLE_NAME
+                    FROM INFORMATION_SCHEMA.TABLES
+                    WHERE TABLE_SCHEMA = DATABASE()
+                    AND TABLE_TYPE = 'BASE TABLE'
+                    ORDER BY TABLE_NAME
+                """
+                result = self.db.execute_query(query)
+                return [row['TABLE_NAME'] for row in result]
+
+        except Exception as e:
+            logger.error(f"Error getting tables: {e}")
+            return []
+
+    def _introspect_table(self, table_name: str) -> Optional[TableInfo]:
+        """
+        Get complete information about a specific table.
+
+        Args:
+            table_name: Name of the table to introspect
+
+        Returns:
+            TableInfo object or None if table doesn't exist
+        """
+        try:
+            # Get column information
+            columns = self._get_columns(table_name)
+
+            # Get primary keys
+            primary_keys = self._get_primary_keys(table_name)
+
+            # Get foreign keys
+            foreign_keys = self._get_foreign_keys(table_name)
+
+            # Get approximate row count (fast estimation)
+            row_count = self._get_row_count(table_name)
+
+            # Get table comment (not available in SQLite)
+            comment = self._get_table_comment(table_name)
+
+            # Mark primary key columns
+            for col in columns:
+                col.is_primary_key = col.name in primary_keys
+
+            return TableInfo(
+                name=table_name,
+                columns=columns,
+                primary_keys=primary_keys,
+                foreign_keys=foreign_keys,
+                row_count=row_count,
+                comment=comment
+            )
+
+        except Exception as e:
+            logger.error(f"Error introspecting table {table_name}: {e}")
+            return None
+
+    def _get_columns(self, table_name: str) -> List[ColumnInfo]:
+        """Get all columns for a table."""
+        db_type = self.db.db_type
+
+        try:
+            if db_type.value == "sqlite":
+                query = f"PRAGMA table_info('{table_name}')"
+                result = self.db.execute_query(query)
+
+                columns = []
+                for row in result:
+                    columns.append(ColumnInfo(
+                        name=row['name'],
+                        data_type=row['type'] or 'TEXT',  # SQLite columns can have no type
+                        is_nullable=row['notnull'] == 0,
+                        is_primary_key=row['pk'] == 1,
+                        max_length=None,
+                        default_value=row['dflt_value'],
+                        comment=None  # SQLite doesn't support column comments
+                    ))
+                return columns
+
+            elif db_type.value == "postgresql":
+                query = """
+                    SELECT
+                        column_name,
+                        data_type,
+                        is_nullable,
+                        column_default,
+                        character_maximum_length,
+                        col_description(
+                            (SELECT oid FROM pg_class WHERE relname = :table_name),
+                            ordinal_position
+                        ) as column_comment
+                    FROM information_schema.columns
+                    WHERE table_schema = 'public'
+                    AND table_name = :table_name
+                    ORDER BY ordinal_position
+                """
+                result = self.db.execute_query(query, {"table_name": table_name})
+
+                columns = []
+                for row in result:
+                    columns.append(ColumnInfo(
+                        name=row['column_name'],
+                        data_type=row['data_type'],
+                        is_nullable=row['is_nullable'] == 'YES',
396
+ is_primary_key=False, # Will be set later
397
+ max_length=row['character_maximum_length'],
398
+ default_value=row['column_default'],
399
+ comment=row.get('column_comment')
400
+ ))
401
+ return columns
402
+
403
+ else: # MySQL
404
+ query = """
405
+ SELECT
406
+ COLUMN_NAME,
407
+ COLUMN_TYPE,
408
+ IS_NULLABLE,
409
+ COLUMN_DEFAULT,
410
+ CHARACTER_MAXIMUM_LENGTH,
411
+ COLUMN_COMMENT
412
+ FROM INFORMATION_SCHEMA.COLUMNS
413
+ WHERE TABLE_SCHEMA = DATABASE()
414
+ AND TABLE_NAME = :table_name
415
+ ORDER BY ORDINAL_POSITION
416
+ """
417
+ result = self.db.execute_query(query, {"table_name": table_name})
418
+
419
+ columns = []
420
+ for row in result:
421
+ columns.append(ColumnInfo(
422
+ name=row['COLUMN_NAME'],
423
+ data_type=row['COLUMN_TYPE'],
424
+ is_nullable=row['IS_NULLABLE'] == 'YES',
425
+ is_primary_key=False, # Will be set later
426
+ max_length=row['CHARACTER_MAXIMUM_LENGTH'],
427
+ default_value=row['COLUMN_DEFAULT'],
428
+ comment=row['COLUMN_COMMENT'] if row['COLUMN_COMMENT'] else None
429
+ ))
430
+ return columns
431
+
432
+ except Exception as e:
433
+ logger.error(f"Error getting columns for {table_name}: {e}")
434
+ return []
435
+
436
+ def _get_primary_keys(self, table_name: str) -> List[str]:
437
+ """Get primary key columns for a table."""
438
+ db_type = self.db.db_type
439
+
440
+ try:
441
+ if db_type.value == "sqlite":
442
+ query = f"PRAGMA table_info('{table_name}')"
443
+ result = self.db.execute_query(query)
444
+ return [row['name'] for row in result if row['pk'] > 0]
445
+
446
+ elif db_type.value == "postgresql":
447
+ query = """
448
+ SELECT a.attname as column_name
449
+ FROM pg_index i
450
+ JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
451
+ WHERE i.indrelid = :table_name::regclass
452
+ AND i.indisprimary
453
+ """
454
+ result = self.db.execute_query(query, {"table_name": table_name})
455
+ return [row['column_name'] for row in result]
456
+
457
+ else: # MySQL
458
+ query = """
459
+ SELECT COLUMN_NAME
460
+ FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
461
+ WHERE TABLE_SCHEMA = DATABASE()
462
+ AND TABLE_NAME = :table_name
463
+ AND CONSTRAINT_NAME = 'PRIMARY'
464
+ ORDER BY ORDINAL_POSITION
465
+ """
466
+ result = self.db.execute_query(query, {"table_name": table_name})
467
+ return [row['COLUMN_NAME'] for row in result]
468
+
469
+ except Exception as e:
470
+ logger.error(f"Error getting primary keys for {table_name}: {e}")
471
+ return []
472
+
473
+ def _get_foreign_keys(self, table_name: str) -> Dict[str, str]:
474
+ """Get foreign key relationships for a table."""
475
+ db_type = self.db.db_type
476
+
477
+ try:
478
+ if db_type.value == "sqlite":
479
+ query = f"PRAGMA foreign_key_list('{table_name}')"
480
+ result = self.db.execute_query(query)
481
+ return {
482
+ row['from']: f"{row['table']}.{row['to']}"
483
+ for row in result
484
+ }
485
+
486
+ elif db_type.value == "postgresql":
487
+ query = """
488
+ SELECT
489
+ kcu.column_name,
490
+ ccu.table_name AS foreign_table_name,
491
+ ccu.column_name AS foreign_column_name
492
+ FROM information_schema.table_constraints AS tc
493
+ JOIN information_schema.key_column_usage AS kcu
494
+ ON tc.constraint_name = kcu.constraint_name
495
+ AND tc.table_schema = kcu.table_schema
496
+ JOIN information_schema.constraint_column_usage AS ccu
497
+ ON ccu.constraint_name = tc.constraint_name
498
+ AND ccu.table_schema = tc.table_schema
499
+ WHERE tc.constraint_type = 'FOREIGN KEY'
500
+ AND tc.table_name = :table_name
501
+ """
502
+ result = self.db.execute_query(query, {"table_name": table_name})
503
+ return {
504
+ row['column_name']: f"{row['foreign_table_name']}.{row['foreign_column_name']}"
505
+ for row in result
506
+ }
507
+
508
+ else: # MySQL
509
+ query = """
510
+ SELECT
511
+ COLUMN_NAME,
512
+ REFERENCED_TABLE_NAME,
513
+ REFERENCED_COLUMN_NAME
514
+ FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
515
+ WHERE TABLE_SCHEMA = DATABASE()
516
+ AND TABLE_NAME = :table_name
517
+ AND REFERENCED_TABLE_NAME IS NOT NULL
518
+ """
519
+ result = self.db.execute_query(query, {"table_name": table_name})
520
+ return {
521
+ row['COLUMN_NAME']: f"{row['REFERENCED_TABLE_NAME']}.{row['REFERENCED_COLUMN_NAME']}"
522
+ for row in result
523
+ }
524
+
525
+ except Exception as e:
526
+ logger.error(f"Error getting foreign keys for {table_name}: {e}")
527
+ return {}
528
+
529
+ def _get_row_count(self, table_name: str) -> Optional[int]:
530
+ """
531
+ Get approximate row count for a table.
532
+ Uses different strategies per database.
533
+ """
534
+ db_type = self.db.db_type
535
+
536
+ try:
537
+ if db_type.value == "sqlite":
538
+ # SQLite doesn't have stats table, use max rowid for estimation
539
+ query = f"SELECT MAX(rowid) as row_count FROM \"{table_name}\""
540
+ result = self.db.execute_query(query)
541
+ return result[0]['row_count'] if result and result[0]['row_count'] else 0
542
+
543
+ elif db_type.value == "postgresql":
544
+ # Use pg_stat_user_tables for fast estimation
545
+ query = """
546
+ SELECT n_live_tup as row_count
547
+ FROM pg_stat_user_tables
548
+ WHERE relname = :table_name
549
+ """
550
+ result = self.db.execute_query(query, {"table_name": table_name})
551
+ return result[0]['row_count'] if result else None
552
+
553
+ else: # MySQL
554
+ query = """
555
+ SELECT TABLE_ROWS
556
+ FROM INFORMATION_SCHEMA.TABLES
557
+ WHERE TABLE_SCHEMA = DATABASE()
558
+ AND TABLE_NAME = :table_name
559
+ """
560
+ result = self.db.execute_query(query, {"table_name": table_name})
561
+ return result[0]['TABLE_ROWS'] if result else None
562
+
563
+ except Exception as e:
564
+ logger.error(f"Error getting row count for {table_name}: {e}")
565
+ return None
566
+
567
+ def _get_table_comment(self, table_name: str) -> Optional[str]:
568
+ """Get table comment/description."""
569
+ db_type = self.db.db_type
570
+
571
+ try:
572
+ if db_type.value == "sqlite":
573
+ # SQLite doesn't support table comments
574
+ return None
575
+
576
+ elif db_type.value == "postgresql":
577
+ query = """
578
+ SELECT obj_description(:table_name::regclass, 'pg_class') as table_comment
579
+ """
580
+ result = self.db.execute_query(query, {"table_name": table_name})
581
+ comment = result[0]['table_comment'] if result else None
582
+ return comment if comment else None
583
+
584
+ else: # MySQL
585
+ query = """
586
+ SELECT TABLE_COMMENT
587
+ FROM INFORMATION_SCHEMA.TABLES
588
+ WHERE TABLE_SCHEMA = DATABASE()
589
+ AND TABLE_NAME = :table_name
590
+ """
591
+ result = self.db.execute_query(query, {"table_name": table_name})
592
+ comment = result[0]['TABLE_COMMENT'] if result else None
593
+ return comment if comment else None
594
+
595
+ except Exception as e:
596
+ logger.error(f"Error getting table comment for {table_name}: {e}")
597
+ return None
598
+
599
+ def get_text_columns_for_rag(self, min_length: int = 50) -> List[Dict[str, Any]]:
600
+ """
601
+ Get all text columns suitable for RAG indexing.
602
+
603
+ Args:
604
+ min_length: Minimum max_length for varchar columns to be considered
605
+
606
+ Returns:
607
+ List of dicts with table name, column name, and metadata
608
+ """
609
+ schema = self.introspect()
610
+ text_columns = []
611
+
612
+ for table_name, table_info in schema.tables.items():
613
+ for col in table_info.columns:
614
+ if col.is_text_type:
615
+ # Skip very short varchar columns
616
+ if col.max_length and col.max_length < min_length:
617
+ continue
618
+
619
+ text_columns.append({
620
+ "table": table_name,
621
+ "column": col.name,
622
+ "data_type": col.data_type,
623
+ "primary_keys": table_info.primary_keys,
624
+ "max_length": col.max_length
625
+ })
626
+
627
+ return text_columns
628
+
629
+ def refresh_cache(self) -> SchemaInfo:
630
+ """Force refresh the cached schema."""
631
+ return self.introspect(force_refresh=True)
632
+
633
+
634
+ # Global introspector instance
635
+ _introspector: Optional[SchemaIntrospector] = None
636
+
637
+
638
+ def get_introspector() -> SchemaIntrospector:
639
+ """Get or create the global schema introspector."""
640
+ global _introspector
641
+ if _introspector is None:
642
+ _introspector = SchemaIntrospector()
643
+ return _introspector
644
+
645
+
646
+ def get_schema() -> SchemaInfo:
647
+ """Convenience function to get the current schema."""
648
+ return get_introspector().introspect()
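The SQLite branch of `_get_tables` above can be exercised end-to-end against an in-memory database. This is a minimal self-contained sketch using the stdlib `sqlite3` module directly, rather than the project's connection wrapper (whose `execute_query` helper lives in a file outside this diff):

```python
import sqlite3

def list_user_tables(conn: sqlite3.Connection) -> list:
    # Mirrors the SQLite branch of _get_tables: user tables only,
    # internal sqlite_* tables excluded, sorted by name.
    rows = conn.execute(
        """
        SELECT name FROM sqlite_master
        WHERE type='table' AND name NOT LIKE 'sqlite_%'
        ORDER BY name
        """
    ).fetchall()
    return [r[0] for r in rows]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER)")
print(list_user_tables(conn))  # ['orders', 'users']
```

The `sqlite_%` filter matters because SQLite creates internal bookkeeping tables (e.g. `sqlite_sequence` once an `AUTOINCREMENT` column exists) that would otherwise leak into the schema the chatbot sees.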
llm/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """LLM module exports."""
+
+ from .client import (
+     LLMClient,
+     GroqClient,
+     OpenAIClient,
+     LocalLLaMAClient,
+     create_llm_client
+ )
+
+ __all__ = [
+     "LLMClient",
+     "GroqClient",
+     "OpenAIClient",
+     "LocalLLaMAClient",
+     "create_llm_client"
+ ]
llm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (433 Bytes). View file
 
llm/__pycache__/client.cpython-311.pyc ADDED
Binary file (8.38 kB). View file
 
llm/client.py ADDED
@@ -0,0 +1,188 @@
+ """
+ LLM Client - Unified interface for Groq, OpenAI, and local models.
+
+ Groq is the DEFAULT provider (free tier available).
+ """
+
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import List, Dict, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMClient(ABC):
+     """Abstract base class for LLM clients."""
+
+     @abstractmethod
+     def chat(self, messages: List[Dict[str, str]]) -> str:
+         pass
+
+     @abstractmethod
+     def is_available(self) -> bool:
+         pass
+
+
+ class GroqClient(LLMClient):
+     """
+     Groq API client - FREE and FAST inference.
+
+     Available models:
+     - llama-3.3-70b-versatile (recommended)
+     - llama-3.1-8b-instant (faster)
+     - mixtral-8x7b-32768
+     - gemma2-9b-it
+     """
+
+     AVAILABLE_MODELS = [
+         "llama-3.3-70b-versatile",
+         "llama-3.1-70b-versatile",
+         "llama-3.1-8b-instant",
+         "llama3-70b-8192",
+         "llama3-8b-8192",
+         "mixtral-8x7b-32768",
+         "gemma2-9b-it"
+     ]
+
+     def __init__(
+         self,
+         api_key: str,
+         model: str = "llama-3.3-70b-versatile",
+         temperature: float = 0.1,
+         max_tokens: int = 1024
+     ):
+         self.api_key = api_key
+         self.model = model
+         self.temperature = temperature
+         self.max_tokens = max_tokens
+         self._client = None
+
+     @property
+     def client(self):
+         if self._client is None:
+             from groq import Groq
+             self._client = Groq(api_key=self.api_key)
+         return self._client
+
+     def chat(self, messages: List[Dict[str, str]]) -> str:
+         response = self.client.chat.completions.create(
+             model=self.model,
+             messages=messages,
+             temperature=self.temperature,
+             max_tokens=self.max_tokens
+         )
+         return response.choices[0].message.content
+
+     def is_available(self) -> bool:
+         try:
+             # Simple test call
+             self.client.models.list()
+             return True
+         except Exception as e:
+             logger.warning(f"Groq availability check failed: {e}")
+             return False
+
+
+ class OpenAIClient(LLMClient):
+     """OpenAI API client (paid)."""
+
+     def __init__(
+         self,
+         api_key: str,
+         model: str = "gpt-4o-mini",
+         temperature: float = 0.1,
+         max_tokens: int = 1024
+     ):
+         self.api_key = api_key
+         self.model = model
+         self.temperature = temperature
+         self.max_tokens = max_tokens
+         self._client = None
+
+     @property
+     def client(self):
+         if self._client is None:
+             from openai import OpenAI
+             self._client = OpenAI(api_key=self.api_key)
+         return self._client
+
+     def chat(self, messages: List[Dict[str, str]]) -> str:
+         response = self.client.chat.completions.create(
+             model=self.model,
+             messages=messages,
+             temperature=self.temperature,
+             max_tokens=self.max_tokens
+         )
+         return response.choices[0].message.content
+
+     def is_available(self) -> bool:
+         try:
+             self.client.models.list()
+             return True
+         except Exception:
+             return False
+
+
+ class LocalLLaMAClient(LLMClient):
+     """Local LLaMA/Phi model client via transformers."""
+
+     def __init__(
+         self,
+         model_name: str = "microsoft/Phi-3-mini-4k-instruct",
+         temperature: float = 0.1,
+         max_tokens: int = 1024
+     ):
+         self.model_name = model_name
+         self.temperature = temperature
+         self.max_tokens = max_tokens
+         self._pipeline = None
+
+     @property
+     def pipeline(self):
+         if self._pipeline is None:
+             from transformers import pipeline
+             logger.info(f"Loading local model: {self.model_name}")
+             self._pipeline = pipeline(
+                 "text-generation",
+                 model=self.model_name,
+                 torch_dtype="auto",
+                 device_map="auto"
+             )
+         return self._pipeline
+
+     def chat(self, messages: List[Dict[str, str]]) -> str:
+         output = self.pipeline(
+             messages,
+             max_new_tokens=self.max_tokens,
+             temperature=self.temperature,
+             do_sample=True
+         )
+         return output[0]["generated_text"][-1]["content"]
+
+     def is_available(self) -> bool:
+         try:
+             _ = self.pipeline
+             return True
+         except Exception:
+             return False
+
+
+ def create_llm_client(provider: str = "groq", **kwargs) -> LLMClient:
+     """
+     Factory function to create an LLM client.
+
+     Args:
+         provider: "groq" (default, free), "openai", or "local"
+         **kwargs: Provider-specific arguments
+
+     Returns:
+         Configured LLMClient instance
+     """
+     if provider == "groq":
+         return GroqClient(**kwargs)
+     elif provider == "openai":
+         return OpenAIClient(**kwargs)
+     elif provider == "local":
+         return LocalLLaMAClient(**kwargs)
+     else:
+         raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'openai', or 'local'")
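The ABC-plus-factory shape of `llm/client.py` can be sketched without any provider SDK. The `EchoClient` below is a hypothetical stand-in (the real clients need API keys and the `groq`/`openai` packages); it only illustrates the dispatch pattern, including the `ValueError` on an unknown provider name:

```python
from abc import ABC, abstractmethod
from typing import Dict, List

class LLMClient(ABC):
    # Same interface shape as the module's ABC: one chat() call, one health check.
    @abstractmethod
    def chat(self, messages: List[Dict[str, str]]) -> str: ...

    @abstractmethod
    def is_available(self) -> bool: ...

class EchoClient(LLMClient):
    # Hypothetical test double; the real providers wrap the Groq/OpenAI SDKs.
    def chat(self, messages: List[Dict[str, str]]) -> str:
        return messages[-1]["content"].upper()

    def is_available(self) -> bool:
        return True

def create_client(provider: str = "echo", **kwargs) -> LLMClient:
    # Mirrors create_llm_client: string name -> concrete client, else ValueError.
    registry = {"echo": EchoClient}
    if provider not in registry:
        raise ValueError(f"Unknown provider: {provider}")
    return registry[provider](**kwargs)

client = create_client("echo")
print(client.chat([{"role": "user", "content": "hello"}]))  # HELLO
```

Lazy-constructing the SDK client behind a `@property`, as `GroqClient.client` does, keeps import cost and credential checks out of `__init__`, so a client can be configured without the provider's package even being installed.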
memory.py ADDED
@@ -0,0 +1,760 @@
+ """
+ Chat Memory - Short-term and long-term memory management.
+
+ Supports MySQL, PostgreSQL, and SQLite with dialect-specific DDL.
+ """
+
+ import logging
+ import json
+ from typing import List, Dict, Any, Optional
+ from datetime import datetime
+ from dataclasses import dataclass
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ChatMessage:
+     role: str  # "user" or "assistant"
+     content: str
+     timestamp: Optional[datetime] = None
+     metadata: Optional[Dict[str, Any]] = None
+
+     def __post_init__(self):
+         if self.timestamp is None:
+             self.timestamp = datetime.now()
+         if self.metadata is None:
+             self.metadata = {}
+
+     def to_dict(self) -> Dict[str, str]:
+         return {"role": self.role, "content": self.content}
+
+
+ def get_memory_table_ddl(db_type: str) -> str:
+     """Get the DDL for the chat memory table based on database type."""
+     if db_type == "postgresql":
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_memory (
+                 id SERIAL PRIMARY KEY,
+                 session_id VARCHAR(255) NOT NULL,
+                 user_id VARCHAR(255) NOT NULL DEFAULT 'default',
+                 role VARCHAR(50) NOT NULL,
+                 content TEXT NOT NULL,
+                 metadata JSONB,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+             )
+         """
+     elif db_type == "sqlite":
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_memory (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 session_id TEXT NOT NULL,
+                 user_id TEXT NOT NULL DEFAULT 'default',
+                 role TEXT NOT NULL,
+                 content TEXT NOT NULL,
+                 metadata TEXT,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+             )
+         """
+     else:  # MySQL
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_memory (
+                 id INT AUTO_INCREMENT PRIMARY KEY,
+                 session_id VARCHAR(255) NOT NULL,
+                 user_id VARCHAR(255) NOT NULL DEFAULT 'default',
+                 role VARCHAR(50) NOT NULL,
+                 content TEXT NOT NULL,
+                 metadata JSON,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                 INDEX idx_session (session_id),
+                 INDEX idx_user (user_id),
+                 INDEX idx_created (created_at)
+             )
+         """
+
+
+ def get_permanent_memory_ddl(db_type: str) -> str:
+     """Get the DDL for the permanent memory table based on database type."""
+     if db_type == "postgresql":
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_permanent_memory_v2 (
+                 id SERIAL PRIMARY KEY,
+                 user_id VARCHAR(255) NOT NULL DEFAULT 'default',
+                 content TEXT NOT NULL,
+                 tags VARCHAR(255),
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+             )
+         """
+     elif db_type == "sqlite":
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_permanent_memory_v2 (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 user_id TEXT NOT NULL DEFAULT 'default',
+                 content TEXT NOT NULL,
+                 tags TEXT,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+             )
+         """
+     else:  # MySQL
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_permanent_memory_v2 (
+                 id INT AUTO_INCREMENT PRIMARY KEY,
+                 user_id VARCHAR(255) NOT NULL DEFAULT 'default',
+                 content TEXT NOT NULL,
+                 tags VARCHAR(255),
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                 INDEX idx_user (user_id)
+             )
+         """
+
+
+ def get_summary_table_ddl(db_type: str) -> str:
+     """Get the DDL for the summary table based on database type."""
+     if db_type == "postgresql":
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_user_summaries (
+                 id SERIAL PRIMARY KEY,
+                 user_id VARCHAR(255) NOT NULL UNIQUE,
+                 summary TEXT NOT NULL,
+                 message_count INT DEFAULT 0,
+                 last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+             )
+         """
+     elif db_type == "sqlite":
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_user_summaries (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 user_id TEXT NOT NULL UNIQUE,
+                 summary TEXT NOT NULL,
+                 message_count INTEGER DEFAULT 0,
+                 last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+             )
+         """
+     else:  # MySQL
+         return """
+             CREATE TABLE IF NOT EXISTS _chatbot_user_summaries (
+                 id INT AUTO_INCREMENT PRIMARY KEY,
+                 user_id VARCHAR(255) NOT NULL,
+                 summary TEXT NOT NULL,
+                 message_count INT DEFAULT 0,
+                 last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+                 UNIQUE KEY idx_user (user_id)
+             )
+         """
+
+
+ def get_upsert_summary_query(db_type: str) -> str:
+     """Get the upsert query for the summary based on database type."""
+     if db_type == "postgresql":
+         return """
+             INSERT INTO _chatbot_user_summaries
+                 (user_id, summary, message_count, last_updated)
+             VALUES (:user_id, :summary, :message_count, CURRENT_TIMESTAMP)
+             ON CONFLICT (user_id)
+             DO UPDATE SET
+                 summary = EXCLUDED.summary,
+                 message_count = EXCLUDED.message_count,
+                 last_updated = CURRENT_TIMESTAMP
+         """
+     elif db_type == "sqlite":
+         return """
+             INSERT INTO _chatbot_user_summaries
+                 (user_id, summary, message_count, last_updated)
+             VALUES (:user_id, :summary, :message_count, CURRENT_TIMESTAMP)
+             ON CONFLICT(user_id)
+             DO UPDATE SET
+                 summary = excluded.summary,
+                 message_count = excluded.message_count,
+                 last_updated = CURRENT_TIMESTAMP
+         """
+     else:  # MySQL
+         return """
+             INSERT INTO _chatbot_user_summaries
+                 (user_id, summary, message_count)
+             VALUES (:user_id, :summary, :message_count)
+             ON DUPLICATE KEY UPDATE
+                 summary = :summary,
+                 message_count = :message_count,
+                 last_updated = CURRENT_TIMESTAMP
+         """
+
+
+ class ChatMemory:
+     """Manages chat history with short-term and long-term storage."""
+
+     def __init__(self, session_id: str, user_id: str = "default", max_messages: int = 20, db_connection=None):
+         self.session_id = session_id
+         self.user_id = user_id
+         self.max_messages = max_messages
+         self.db = db_connection
+         self.messages: List[ChatMessage] = []
+         self._db_type = None
+
+         if self.db:
+             self._db_type = self.db.db_type.value
+             self._ensure_tables()
+
+     def _ensure_tables(self):
+         """Create memory tables if they don't exist."""
+         try:
+             memory_ddl = get_memory_table_ddl(self._db_type)
+             permanent_ddl = get_permanent_memory_ddl(self._db_type)
+
+             self.db.execute_write(memory_ddl)
+             self.db.execute_write(permanent_ddl)
+
+             # Create indexes for SQLite and PostgreSQL (MySQL creates them inline)
+             if self._db_type in ("sqlite", "postgresql"):
+                 self._create_indexes()
+
+             # Migration: Ensure user_id column exists (MySQL only, for legacy support)
+             if self._db_type == "mysql":
+                 self._migrate_mysql_user_id()
+
+         except Exception as e:
+             logger.warning(f"Failed to create memory tables: {e}")
+
+     def _create_indexes(self):
+         """Create indexes for SQLite and PostgreSQL."""
+         try:
+             # The same CREATE INDEX IF NOT EXISTS statements work for both dialects
+             self.db.execute_write("CREATE INDEX IF NOT EXISTS idx_memory_session ON _chatbot_memory(session_id)")
+             self.db.execute_write("CREATE INDEX IF NOT EXISTS idx_memory_user ON _chatbot_memory(user_id)")
+             self.db.execute_write("CREATE INDEX IF NOT EXISTS idx_memory_created ON _chatbot_memory(created_at)")
+             self.db.execute_write("CREATE INDEX IF NOT EXISTS idx_permanent_user ON _chatbot_permanent_memory_v2(user_id)")
+         except Exception as e:
+             logger.debug(f"Index creation (may already exist): {e}")
+
+     def _migrate_mysql_user_id(self):
+         """Migrate the MySQL table to include a user_id column if missing."""
+         try:
+             check_query = """
+                 SELECT COLUMN_NAME
+                 FROM INFORMATION_SCHEMA.COLUMNS
+                 WHERE TABLE_SCHEMA = :db_name
+                 AND TABLE_NAME = '_chatbot_memory'
+                 AND COLUMN_NAME = 'user_id'
+             """
+             db_name = self.db.config.database
+             result = self.db.execute_query(check_query, {"db_name": db_name})
+
+             if not result:
+                 self.db.execute_write("ALTER TABLE _chatbot_memory ADD COLUMN user_id VARCHAR(255) NOT NULL DEFAULT 'default' AFTER session_id")
+                 self.db.execute_write("CREATE INDEX idx_user ON _chatbot_memory(user_id)")
+                 logger.info("Migrated _chatbot_memory to include user_id")
+         except Exception as e:
+             logger.debug(f"Migration check failed: {e}")
+
+     def add_message(self, role: str, content: str, metadata: Optional[Dict] = None):
+         """Add a message to memory and optionally persist it."""
+         msg = ChatMessage(role=role, content=content, metadata=metadata)
+         self.messages.append(msg)
+
+         # Trim if exceeds max (short-term)
+         if len(self.messages) > self.max_messages:
+             self.messages = self.messages[-self.max_messages:]
+
+         # Persist to DB (session history)
+         if self.db:
+             try:
+                 query = """
+                     INSERT INTO _chatbot_memory (session_id, user_id, role, content, metadata)
+                     VALUES (:session_id, :user_id, :role, :content, :metadata)
+                 """
+                 self.db.execute_write(query, {
+                     "session_id": self.session_id,
+                     "user_id": self.user_id,
+                     "role": role,
+                     "content": content,
+                     "metadata": json.dumps(metadata) if metadata else None
+                 })
+             except Exception as e:
+                 logger.warning(f"Failed to persist message: {e}")
+
+     def save_permanent_context(self, content: str, tags: str = "user_saved"):
+         """Save specific context explicitly to permanent memory for this user."""
+         if not self.db:
+             return False, "No database connection"
+
+         try:
+             query = """
+                 INSERT INTO _chatbot_permanent_memory_v2 (user_id, content, tags)
+                 VALUES (:user_id, :content, :tags)
+             """
+             self.db.execute_write(query, {
+                 "user_id": self.user_id,
+                 "content": content,
+                 "tags": tags
+             })
+             return True, "Context saved to permanent memory"
+         except Exception as e:
+             logger.error(f"Failed to save permanent context: {e}")
+             return False, str(e)
+
+     def get_permanent_context(self, limit: int = 5) -> List[str]:
+         """Retrieve recent permanent context for this user only."""
+         if not self.db:
+             return []
+
+         try:
+             # Use database-agnostic LIMIT syntax
+             query = """
+                 SELECT content FROM _chatbot_permanent_memory_v2
+                 WHERE user_id = :user_id
+                 ORDER BY created_at DESC LIMIT :limit
+             """
+             rows = self.db.execute_query(query, {
+                 "user_id": self.user_id,
+                 "limit": limit
+             })
+             return [row['content'] for row in rows]
+         except Exception as e:
+             logger.warning(f"Failed to load permanent context: {e}")
+             return []
+
+     def get_messages(self, limit: Optional[int] = None) -> List[Dict[str, str]]:
+         """Get messages for LLM context."""
+         msgs = self.messages if limit is None else self.messages[-limit:]
+         return [m.to_dict() for m in msgs]
+
+     def get_context_messages(self, count: int = 5) -> List[Dict[str, str]]:
+         """Get recent messages plus permanent context for injection."""
+         # Get short-term session messages
+         context = self.get_messages(limit=count)
+
+         # Inject permanent memory if available
+         perm_docs = self.get_permanent_context(limit=3)
+         if perm_docs:
+             perm_context = f"IMPORTANT CONTEXT FOR USER '{self.user_id}':\n" + "\n".join(perm_docs)
+             # Add as a system note at the start
+             context.insert(0, {"role": "system", "content": perm_context})
+
+         return context
+
+     def clear(self):
+         """Clear current session memory and remove it from the DB (temporary history)."""
+         self.messages = []
+
+         if self.db:
+             try:
+                 # Delete temporary messages for this session
+                 query = "DELETE FROM _chatbot_memory WHERE session_id = :session_id"
+                 self.db.execute_write(query, {"session_id": self.session_id})
+                 logger.info(f"Cleared session memory for {self.session_id}")
+             except Exception as e:
+                 logger.warning(f"Failed to clear memory from DB: {e}")
+
+     def clear_user_history(self):
+         """Clear ALL temporary history for this user (across all sessions)."""
+         self.messages = []
+         if self.db:
+             try:
+                 query = "DELETE FROM _chatbot_memory WHERE user_id = :user_id"
+                 self.db.execute_write(query, {"user_id": self.user_id})
+                 logger.info(f"Cleared all temporary history for user: {self.user_id}")
+             except Exception as e:
+                 logger.warning(f"Failed to clear user history from DB: {e}")
+
+
+ class ConversationSummaryMemory:
+     """
+     Per-user conversation summary memory using an LLM for summarization.
+
+     This class maintains a running summary of the conversation, updating it
+     periodically (when the message count exceeds a threshold). This dramatically
+     reduces token usage while preserving context for long conversations.
+
+     Features:
+     - Automatic summarization when the threshold is reached
+     - Per-user summary storage in the database
+     - Combines summary + recent messages for optimal context
+     - Lazy summarization (only when needed)
+     """
+
+     SUMMARIZATION_PROMPT = """You are a conversation summarizer. Create a concise summary of the conversation below that captures:
+ 1. Key topics discussed
+ 2. Important facts or preferences mentioned by the user
+ 3. Any decisions or conclusions reached
+ 4. Context needed for follow-up questions
+
+ Keep the summary under 300 words but include all important details.
+
+ CONVERSATION:
+ {conversation}
+
+ SUMMARY:"""
+
+     INCREMENTAL_SUMMARY_PROMPT = """You are a conversation summarizer. Update the existing summary to incorporate new messages.
+
+ EXISTING SUMMARY:
+ {existing_summary}
+
+ NEW MESSAGES:
+ {new_messages}
+
+ Create an updated, comprehensive summary that:
+ 1. Incorporates new information from the recent messages
+ 2. Retains important context from the existing summary
+ 3. Removes redundant or outdated information
+ 4. Stays under 300 words
+
+ UPDATED SUMMARY:"""
+
+     def __init__(
+         self,
+         user_id: str,
+         session_id: str,
+         db_connection=None,
+         llm_client=None,
+         summary_threshold: int = 10,  # Summarize every N messages
+         recent_messages_count: int = 5  # Keep this many recent messages verbatim
+     ):
+         self.user_id = user_id
+         self.session_id = session_id
+         self.db = db_connection
+         self.llm = llm_client
+         self.summary_threshold = summary_threshold
+         self.recent_messages_count = recent_messages_count
+         self._db_type = None
+
+         self._cached_summary: Optional[str] = None
+         self._messages_since_summary: int = 0
+
+         if self.db:
+             self._db_type = self.db.db_type.value
+             self._ensure_tables()
+             self._load_state()
+
+     def _ensure_tables(self):
+         """Create the summary table if it doesn't exist."""
+         try:
+             ddl = get_summary_table_ddl(self._db_type)
+             self.db.execute_write(ddl)
+         except Exception as e:
+             logger.warning(f"Failed to create summary table: {e}")
+
+     def _load_state(self):
+         """Load existing summary state from the database (per-user, not per-session)."""
+         try:
+             query = """
+                 SELECT summary, message_count FROM _chatbot_user_summaries
+                 WHERE user_id = :user_id
+             """
+             rows = self.db.execute_query(query, {
+                 "user_id": self.user_id
+             })
+             if rows:
+                 self._cached_summary = rows[0].get('summary')
+                 self._messages_since_summary = 0  # Reset since we loaded
+                 logger.debug(f"Loaded summary for user {self.user_id}")
+         except Exception as e:
+             logger.warning(f"Failed to load summary state: {e}")
457
+
458
+ def set_llm_client(self, llm_client):
459
+ """Set the LLM client for summarization."""
460
+ self.llm = llm_client
461
+
462
+ def on_message_added(self, message_count: int):
463
+ """
464
+ Called after a message is added to track when to summarize.
465
+
466
+ Args:
467
+ message_count: Current total number of messages in the conversation
468
+ """
469
+ self._messages_since_summary += 1
470
+
471
+ # Check if we should summarize
472
+ if self._messages_since_summary >= self.summary_threshold:
473
+ self._trigger_summarization()
474
+
475
+ def _trigger_summarization(self):
476
+ """Trigger summarization of the conversation."""
477
+ if not self.llm:
478
+ logger.warning("Cannot summarize: No LLM client configured")
479
+ return
480
+
481
+ if not self.db:
482
+ logger.warning("Cannot summarize: No database connection")
483
+ return
484
+
485
+ try:
486
+ # Get messages that need to be summarized
487
+ query = """
488
+ SELECT role, content FROM _chatbot_memory
489
+ WHERE user_id = :user_id AND session_id = :session_id
490
+ ORDER BY created_at ASC
491
+ """
492
+ rows = self.db.execute_query(query, {
493
+ "user_id": self.user_id,
494
+ "session_id": self.session_id
495
+ })
496
+
497
+ if not rows:
498
+ return
499
+
500
+ # Format conversation for summarization
501
+ conversation_text = self._format_messages_for_summary(rows)
502
+
503
+ # Generate summary
504
+ if self._cached_summary:
505
+ # Incremental update
506
+ prompt = self.INCREMENTAL_SUMMARY_PROMPT.format(
507
+ existing_summary=self._cached_summary,
508
+ new_messages=conversation_text
509
+ )
510
+ else:
511
+ # Fresh summary
512
+ prompt = self.SUMMARIZATION_PROMPT.format(conversation=conversation_text)
513
+
514
+ messages = [
515
+ {"role": "system", "content": "You are a helpful assistant that creates concise conversation summaries."},
516
+ {"role": "user", "content": prompt}
517
+ ]
518
+
519
+ summary = self.llm.chat(messages)
520
+
521
+ # Save to database
522
+ self._save_summary(summary, len(rows))
523
+
524
+ self._cached_summary = summary
525
+ self._messages_since_summary = 0
526
+
527
+ logger.info(f"Generated summary for user {self.user_id}")
528
+
529
+ except Exception as e:
530
+ logger.error(f"Summarization failed: {e}")
531
+
532
+ def _format_messages_for_summary(self, messages: List[Dict]) -> str:
533
+ """Format messages as text for summarization."""
534
+ lines = []
535
+ for msg in messages:
536
+ role = msg.get('role', 'unknown').upper()
537
+ content = msg.get('content', '')
538
+ lines.append(f"{role}: {content}")
539
+ return "\n\n".join(lines)
540
+
541
+ def _save_summary(self, summary: str, message_count: int):
542
+ """Save or update summary in database (per-user)."""
543
+ try:
544
+ query = get_upsert_summary_query(self._db_type)
545
+ self.db.execute_write(query, {
546
+ "user_id": self.user_id,
547
+ "summary": summary,
548
+ "message_count": message_count
549
+ })
550
+ except Exception as e:
551
+ logger.error(f"Failed to save summary: {e}")
552
+
553
+ def get_summary(self) -> Optional[str]:
554
+ """Get the current conversation summary."""
555
+ return self._cached_summary
556
+
557
+ def get_context_for_llm(self, recent_messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
558
+ """
559
+ Get optimized context for LLM calls.
560
+
561
+ Combines the summary (if available) with recent messages for optimal
562
+ token usage while maintaining context.
563
+
564
+ Args:
565
+ recent_messages: List of recent messages to include verbatim
566
+
567
+ Returns:
568
+ List of messages with summary prepended as system context
569
+ """
570
+ context_messages = []
571
+
572
+ # Add summary as system context if available
573
+ if self._cached_summary:
574
+ summary_context = f"""CONVERSATION SUMMARY (previous context):
575
+ {self._cached_summary}
576
+
577
+ Use this summary to understand the conversation history and context for follow-up questions."""
578
+ context_messages.append({
579
+ "role": "system",
580
+ "content": summary_context
581
+ })
582
+
583
+ # Add recent messages verbatim
584
+ context_messages.extend(recent_messages[-self.recent_messages_count:])
585
+
586
+ return context_messages
587
+
588
+ def force_summarize(self):
589
+ """Force immediate summarization regardless of threshold."""
590
+ self._trigger_summarization()
591
+
592
+ def clear_summary(self):
593
+ """Clear the summary for this user."""
594
+ self._cached_summary = None
595
+ self._messages_since_summary = 0
596
+
597
+ if self.db:
598
+ try:
599
+ query = "DELETE FROM _chatbot_user_summaries WHERE user_id = :user_id"
600
+ self.db.execute_write(query, {
601
+ "user_id": self.user_id
602
+ })
603
+ logger.info(f"Cleared summary for user: {self.user_id}")
604
+ except Exception as e:
605
+ logger.warning(f"Failed to clear summary: {e}")
606
+
607
+ def clear_all_user_summaries(self):
608
+ """Clear all summaries for this user (alias for clear_summary since it's now per-user)."""
609
+ self.clear_summary()
610
+
611
+
612
+ class EnhancedChatMemory(ChatMemory):
613
+ """
614
+ Enhanced ChatMemory with integrated conversation summarization.
615
+
616
+ Combines the standard ChatMemory functionality with ConversationSummaryMemory
617
+ for automatic summarization and optimized context retrieval.
618
+ """
619
+
620
+ def __init__(
621
+ self,
622
+ session_id: str,
623
+ user_id: str = "default",
624
+ max_messages: int = 20,
625
+ db_connection=None,
626
+ llm_client=None,
627
+ enable_summarization: bool = True,
628
+ summary_threshold: int = 10
629
+ ):
630
+ super().__init__(session_id, user_id, max_messages, db_connection)
631
+
632
+ self.enable_summarization = enable_summarization
633
+ self.summary_memory: Optional[ConversationSummaryMemory] = None
634
+
635
+ if enable_summarization:
636
+ self.summary_memory = ConversationSummaryMemory(
637
+ user_id=user_id,
638
+ session_id=session_id,
639
+ db_connection=db_connection,
640
+ llm_client=llm_client,
641
+ summary_threshold=summary_threshold
642
+ )
643
+
644
+ def set_llm_client(self, llm_client):
645
+ """Set the LLM client for summarization."""
646
+ if self.summary_memory:
647
+ self.summary_memory.set_llm_client(llm_client)
648
+
649
+ def add_message(self, role: str, content: str, metadata: Dict = None):
650
+ """Add a message and trigger summarization check."""
651
+ super().add_message(role, content, metadata)
652
+
653
+ # Notify summary memory of new message
654
+ if self.summary_memory:
655
+ self.summary_memory.on_message_added(len(self.messages))
656
+
657
+ def get_context_messages(self, count: int = 5) -> List[Dict[str, str]]:
658
+ """
659
+ Get context messages with summary integration.
660
+
661
+ If summarization is enabled and a summary exists, it will be
662
+ prepended to provide historical context while keeping recent
663
+ messages verbatim.
664
+ """
665
+ # Get base context from parent
666
+ base_context = super().get_context_messages(count)
667
+
668
+ # If summarization is enabled, use summary-enhanced context
669
+ if self.summary_memory and self.summary_memory.get_summary():
670
+ # Filter out system messages from base context (we'll add summary separately)
671
+ filtered = [m for m in base_context if m.get("role") != "system"]
672
+
673
+ # Get summary-enhanced context
674
+ enhanced = self.summary_memory.get_context_for_llm(filtered)
675
+
676
+ # Re-add permanent memory context if it was present
677
+ for msg in base_context:
678
+ if msg.get("role") == "system" and "IMPORTANT CONTEXT" in msg.get("content", ""):
679
+ enhanced.insert(0, msg)
680
+
681
+ return enhanced
682
+
683
+ return base_context
684
+
685
+ def get_summary(self) -> Optional[str]:
686
+ """Get the current conversation summary."""
687
+ if self.summary_memory:
688
+ return self.summary_memory.get_summary()
689
+ return None
690
+
691
+ def force_summarize(self):
692
+ """Force immediate summarization."""
693
+ if self.summary_memory:
694
+ self.summary_memory.force_summarize()
695
+
696
+ def clear(self):
697
+ """Clear session memory but KEEP the summary (long-term memory)."""
698
+ super().clear()
699
+ # NOTE: Summary is intentionally NOT cleared here
700
+ # Summary acts as long-term memory that persists across chat sessions
701
+
702
+ def clear_with_summary(self):
703
+ """Clear session memory AND the summary (full reset)."""
704
+ super().clear()
705
+ if self.summary_memory:
706
+ self.summary_memory.clear_summary()
707
+
708
+ def clear_user_history(self):
709
+ """Clear all user temp history but KEEP summaries."""
710
+ super().clear_user_history()
711
+ # NOTE: Summaries are intentionally NOT cleared
712
+ # They persist as long-term memory for the user
713
+
714
+ def clear_all_including_summaries(self):
715
+ """Clear ALL user data including summaries (complete wipe)."""
716
+ super().clear_user_history()
717
+ if self.summary_memory:
718
+ self.summary_memory.clear_all_user_summaries()
719
+
720
+
721
+ def create_memory(session_id: str, user_id: str = "default", max_messages: int = 20) -> ChatMemory:
722
+ """Create a standard ChatMemory instance."""
723
+ from database import get_db
724
+ db = get_db()
725
+ return ChatMemory(session_id=session_id, user_id=user_id, max_messages=max_messages, db_connection=db)
726
+
727
+
728
+ def create_enhanced_memory(
729
+ session_id: str,
730
+ user_id: str = "default",
731
+ max_messages: int = 20,
732
+ llm_client=None,
733
+ enable_summarization: bool = True,
734
+ summary_threshold: int = 10
735
+ ) -> EnhancedChatMemory:
736
+ """
737
+ Create an EnhancedChatMemory with summarization support.
738
+
739
+ Args:
740
+ session_id: Unique session identifier
741
+ user_id: User identifier for per-user memory isolation
742
+ max_messages: Maximum messages to keep in short-term memory
743
+ llm_client: LLM client for summarization (can be set later)
744
+ enable_summarization: Whether to enable automatic summarization
745
+ summary_threshold: Summarize after this many messages
746
+
747
+ Returns:
748
+ EnhancedChatMemory instance with summarization capabilities
749
+ """
750
+ from database import get_db
751
+ db = get_db()
752
+ return EnhancedChatMemory(
753
+ session_id=session_id,
754
+ user_id=user_id,
755
+ max_messages=max_messages,
756
+ db_connection=db,
757
+ llm_client=llm_client,
758
+ enable_summarization=enable_summarization,
759
+ summary_threshold=summary_threshold
760
+ )
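The bookkeeping that decides when `on_message_added` triggers a summary can be sketched in isolation (a simplified stand-in for the class, not the class itself):

```python
class SummaryTrigger:
    """Simplified stand-in for the threshold logic in ConversationSummaryMemory."""

    def __init__(self, summary_threshold=10):
        self.summary_threshold = summary_threshold
        self.messages_since_summary = 0

    def on_message_added(self):
        """Return True when a summarization pass should fire."""
        self.messages_since_summary += 1
        if self.messages_since_summary >= self.summary_threshold:
            self.messages_since_summary = 0  # reset, as _trigger_summarization does
            return True
        return False

trigger = SummaryTrigger(summary_threshold=3)
fired = [trigger.on_message_added() for _ in range(7)]
# Fires on the 3rd and 6th messages
```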
rag/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """RAG module exports."""
2
+
3
+ from .embeddings import (
4
+ EmbeddingProvider,
5
+ SentenceTransformerEmbedding,
6
+ OpenAIEmbedding,
7
+ get_embedding_provider,
8
+ create_embedding_provider
9
+ )
10
+ from .document_processor import Document, DocumentProcessor, get_document_processor
11
+ from .vector_store import VectorStore, get_vector_store
12
+ from .rag_engine import RAGEngine, get_rag_engine
13
+
14
+ __all__ = [
15
+ "EmbeddingProvider", "SentenceTransformerEmbedding", "OpenAIEmbedding",
16
+ "get_embedding_provider", "create_embedding_provider",
17
+ "Document", "DocumentProcessor", "get_document_processor",
18
+ "VectorStore", "get_vector_store",
19
+ "RAGEngine", "get_rag_engine"
20
+ ]
rag/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (873 Bytes).
 
rag/__pycache__/document_processor.cpython-311.pyc ADDED
Binary file (6.98 kB).
 
rag/__pycache__/embeddings.cpython-311.pyc ADDED
Binary file (9.78 kB).
 
rag/__pycache__/rag_engine.cpython-311.pyc ADDED
Binary file (5.62 kB).
 
rag/__pycache__/vector_store.cpython-311.pyc ADDED
Binary file (10.6 kB).
 
rag/document_processor.py ADDED
@@ -0,0 +1,122 @@
1
+ """
2
+ Document Processor for RAG.
3
+
4
+ Converts database rows into semantic documents for embedding.
5
+ """
6
+
7
+ import logging
8
+ import hashlib
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Dict, Any, Optional, Generator
11
+ import re
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @dataclass
17
+ class Document:
18
+ """Semantic document from the database."""
19
+ id: str
20
+ content: str
21
+ metadata: Dict[str, Any] = field(default_factory=dict)
22
+ table_name: str = ""
23
+ column_name: str = ""
24
+ primary_key_value: Optional[str] = None
25
+ chunk_index: int = 0
26
+ total_chunks: int = 1
27
+
28
+ def __post_init__(self):
29
+ if not self.id:
30
+ hash_input = f"{self.table_name}:{self.column_name}:{self.primary_key_value}:{self.chunk_index}"
31
+ self.id = hashlib.md5(hash_input.encode()).hexdigest()
32
+
33
+ def to_context_string(self) -> str:
34
+ source = f"[Source: {self.table_name}.{self.column_name}"
35
+ if self.primary_key_value:
36
+ source += f" (id: {self.primary_key_value})"
37
+ source += "]"
38
+ return f"{source}\n{self.content}"
39
+
40
+
41
+ class TextChunker:
42
+ """Splits long text into overlapping chunks."""
43
+
44
+ def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
45
+ self.chunk_size = chunk_size
46
+ self.chunk_overlap = chunk_overlap
47
+ self.sentence_pattern = re.compile(r'(?<=[.!?])\s+(?=[A-Z])')
48
+
49
+ def chunk_text(self, text: str) -> List[str]:
50
+ if not text or len(text) <= self.chunk_size:
51
+ return [text] if text else []
52
+
53
+ sentences = self.sentence_pattern.split(text)
54
+ chunks = []
55
+ current_chunk = []
56
+ current_length = 0
57
+
58
+ for sentence in sentences:
+ sentence = sentence.strip()
+ if not sentence:
+ continue
+
+ if current_length + len(sentence) + 1 > self.chunk_size:
+ if current_chunk:
+ chunks.append(' '.join(current_chunk))
+ # Seed the next chunk with the tail of the previous one (chunk_overlap)
+ overlap, overlap_len = [], 0
+ for prev in reversed(current_chunk):
+ if overlap_len + len(prev) + 1 > self.chunk_overlap:
+ break
+ overlap.insert(0, prev)
+ overlap_len += len(prev) + 1
+ current_chunk = overlap + [sentence]
+ current_length = overlap_len + len(sentence)
+ else:
+ current_chunk = [sentence]
+ current_length = len(sentence)
+ else:
+ current_chunk.append(sentence)
+ current_length += len(sentence) + 1
71
+
72
+ if current_chunk:
73
+ chunks.append(' '.join(current_chunk))
74
+
75
+ return chunks if chunks else [text]
76
+
77
+
78
+ class DocumentProcessor:
79
+ """Converts database rows into semantic documents."""
80
+
81
+ def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
82
+ self.chunker = TextChunker(chunk_size, chunk_overlap)
83
+
84
+ def process_row(
85
+ self, row: Dict[str, Any], table_name: str,
86
+ text_columns: List[str], primary_key_column: Optional[str] = None
87
+ ) -> List[Document]:
88
+ documents = []
89
+ pk_value = str(row.get(primary_key_column, "")) if primary_key_column else None
90
+
91
+ for column_name in text_columns:
92
+ text = row.get(column_name)
93
+ if not text or not isinstance(text, str):
94
+ continue
95
+
96
+ text = text.strip()
97
+ if not text:
98
+ continue
99
+
100
+ chunks = self.chunker.chunk_text(text)
101
+ for i, chunk in enumerate(chunks):
102
+ doc = Document(
103
+ id="", content=chunk, table_name=table_name,
104
+ column_name=column_name, primary_key_value=pk_value,
105
+ chunk_index=i, total_chunks=len(chunks),
106
+ metadata={"table": table_name, "column": column_name, "pk": pk_value}
107
+ )
108
+ documents.append(doc)
109
+
110
+ return documents
111
+
112
+ def process_rows(
113
+ self, rows: List[Dict[str, Any]], table_name: str,
114
+ text_columns: List[str], primary_key_column: Optional[str] = None
115
+ ) -> Generator[Document, None, None]:
116
+ for row in rows:
117
+ for doc in self.process_row(row, table_name, text_columns, primary_key_column):
118
+ yield doc
119
+
120
+
121
+ def get_document_processor(chunk_size: int = 500, chunk_overlap: int = 50) -> DocumentProcessor:
122
+ return DocumentProcessor(chunk_size, chunk_overlap)
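The greedy sentence-packing rule in `TextChunker` can be sketched as a standalone function (overlap handling omitted for brevity; the sample text is illustrative):

```python
import re

def chunk_text(text, chunk_size=500):
    """Pack whole sentences greedily into chunks of at most chunk_size chars."""
    if not text or len(text) <= chunk_size:
        return [text] if text else []
    # Same sentence boundary heuristic as TextChunker.sentence_pattern
    sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text)
    chunks, current, length = [], [], 0
    for s in (s.strip() for s in sentences):
        if not s:
            continue
        if length + len(s) + 1 > chunk_size:
            if current:
                chunks.append(' '.join(current))
            current, length = [s], len(s)
        else:
            current.append(s)
            length += len(s) + 1
    if current:
        chunks.append(' '.join(current))
    return chunks or [text]

text = "First sentence here. Second sentence follows. Third one ends it."
chunks = chunk_text(text, chunk_size=40)
# Three chunks, one sentence each
```

Sentences are never split mid-way, so each chunk stays a coherent unit for embedding.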
rag/embeddings.py ADDED
@@ -0,0 +1,206 @@
1
+ """
2
+ Embedding Generation Module.
3
+
4
+ Supports:
5
+ - Sentence Transformers (local, free)
6
+ - OpenAI Embeddings (cloud, paid)
7
+
8
+ Configurable via environment variables.
9
+ """
10
+
11
+ import logging
12
+ from abc import ABC, abstractmethod
13
+ from typing import List, Optional
14
+ import numpy as np
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class EmbeddingProvider(ABC):
20
+ """Abstract base class for embedding providers."""
21
+
22
+ @abstractmethod
23
+ def embed_text(self, text: str) -> np.ndarray:
24
+ """Generate embedding for a single text."""
25
+ pass
26
+
27
+ @abstractmethod
28
+ def embed_texts(self, texts: List[str]) -> np.ndarray:
29
+ """Generate embeddings for multiple texts."""
30
+ pass
31
+
32
+ @property
33
+ @abstractmethod
34
+ def dimension(self) -> int:
35
+ """Return the embedding dimension."""
36
+ pass
37
+
38
+
39
+ class SentenceTransformerEmbedding(EmbeddingProvider):
40
+ """
41
+ Sentence Transformers embedding provider.
42
+
43
+ Uses local models, no API key required.
44
+ Default: all-MiniLM-L6-v2 (384 dimensions)
45
+ """
46
+
47
+ def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
48
+ """
49
+ Initialize the Sentence Transformer model.
50
+
51
+ Args:
52
+ model_name: HuggingFace model name
53
+ """
54
+ self.model_name = model_name
55
+ self._model = None
56
+ self._dimension = None
57
+
58
+ @property
59
+ def model(self):
60
+ """Lazy load the model."""
61
+ if self._model is None:
62
+ try:
63
+ from sentence_transformers import SentenceTransformer
64
+ logger.info(f"Loading embedding model: {self.model_name}")
65
+ self._model = SentenceTransformer(self.model_name)
66
+ self._dimension = self._model.get_sentence_embedding_dimension()
67
+ logger.info(f"Model loaded. Embedding dimension: {self._dimension}")
68
+ except ImportError:
69
+ raise ImportError(
70
+ "sentence-transformers is required. Install with: pip install sentence-transformers"
71
+ )
72
+ return self._model
73
+
74
+ @property
75
+ def dimension(self) -> int:
76
+ """Get embedding dimension."""
77
+ if self._dimension is None:
78
+ _ = self.model # Force model load
79
+ return self._dimension
80
+
81
+ def embed_text(self, text: str) -> np.ndarray:
82
+ """Generate embedding for a single text."""
83
+ return self.model.encode(text, convert_to_numpy=True)
84
+
85
+ def embed_texts(self, texts: List[str]) -> np.ndarray:
86
+ """Generate embeddings for multiple texts."""
87
+ return self.model.encode(texts, convert_to_numpy=True, show_progress_bar=len(texts) > 100)
88
+
89
+
90
+ class OpenAIEmbedding(EmbeddingProvider):
91
+ """
92
+ OpenAI embedding provider.
93
+
94
+ Uses OpenAI API, requires API key.
95
+ Default: text-embedding-3-small (1536 dimensions)
96
+ """
97
+
98
+ DIMENSION_MAP = {
99
+ "text-embedding-3-small": 1536,
100
+ "text-embedding-3-large": 3072,
101
+ "text-embedding-ada-002": 1536
102
+ }
103
+
104
+ def __init__(self, api_key: str, model_name: str = "text-embedding-3-small"):
105
+ """
106
+ Initialize OpenAI embedding client.
107
+
108
+ Args:
109
+ api_key: OpenAI API key
110
+ model_name: OpenAI embedding model name
111
+ """
112
+ self.api_key = api_key
113
+ self.model_name = model_name
114
+ self._client = None
115
+ self._dimension = self.DIMENSION_MAP.get(model_name, 1536)
116
+
117
+ @property
118
+ def client(self):
119
+ """Lazy load the OpenAI client."""
120
+ if self._client is None:
121
+ try:
122
+ from openai import OpenAI
123
+ self._client = OpenAI(api_key=self.api_key)
124
+ except ImportError:
125
+ raise ImportError(
126
+ "openai is required. Install with: pip install openai"
127
+ )
128
+ return self._client
129
+
130
+ @property
131
+ def dimension(self) -> int:
132
+ """Get embedding dimension."""
133
+ return self._dimension
134
+
135
+ def embed_text(self, text: str) -> np.ndarray:
136
+ """Generate embedding for a single text."""
137
+ response = self.client.embeddings.create(
138
+ input=text,
139
+ model=self.model_name
140
+ )
141
+ return np.array(response.data[0].embedding, dtype=np.float32)
142
+
143
+ def embed_texts(self, texts: List[str]) -> np.ndarray:
144
+ """Generate embeddings for multiple texts (batch)."""
145
+ # OpenAI API supports batching up to 2048 inputs
146
+ batch_size = 100
147
+ all_embeddings = []
148
+
149
+ for i in range(0, len(texts), batch_size):
150
+ batch = texts[i:i + batch_size]
151
+ response = self.client.embeddings.create(
152
+ input=batch,
153
+ model=self.model_name
154
+ )
155
+ embeddings = [np.array(d.embedding, dtype=np.float32) for d in response.data]
156
+ all_embeddings.extend(embeddings)
157
+
158
+ return np.array(all_embeddings)
159
+
160
+
161
+ def create_embedding_provider(
162
+ provider_type: str = "sentence_transformers",
163
+ model_name: Optional[str] = None,
164
+ api_key: Optional[str] = None
165
+ ) -> EmbeddingProvider:
166
+ """
167
+ Factory function to create the appropriate embedding provider.
168
+
169
+ Args:
170
+ provider_type: "sentence_transformers" or "openai"
171
+ model_name: Model name (optional, uses defaults)
172
+ api_key: API key for OpenAI (required if using OpenAI)
173
+
174
+ Returns:
175
+ Configured EmbeddingProvider instance
176
+ """
177
+ if provider_type == "openai":
178
+ if not api_key:
179
+ raise ValueError("OpenAI API key is required for OpenAI embeddings")
180
+ return OpenAIEmbedding(
181
+ api_key=api_key,
182
+ model_name=model_name or "text-embedding-3-small"
183
+ )
184
+ else:
185
+ return SentenceTransformerEmbedding(
186
+ model_name=model_name or "sentence-transformers/all-MiniLM-L6-v2"
187
+ )
188
+
189
+
190
+ # Global embedding provider instance
191
+ _embedding_provider: Optional[EmbeddingProvider] = None
192
+
193
+
194
+ def get_embedding_provider() -> EmbeddingProvider:
195
+ """Get or create the global embedding provider."""
196
+ global _embedding_provider
197
+ if _embedding_provider is None:
198
+ # Default to sentence transformers (free, local)
199
+ _embedding_provider = SentenceTransformerEmbedding()
200
+ return _embedding_provider
201
+
202
+
203
+ def set_embedding_provider(provider: EmbeddingProvider):
204
+ """Set the global embedding provider."""
205
+ global _embedding_provider
206
+ _embedding_provider = provider
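The batch slicing in `OpenAIEmbedding.embed_texts` is plain stride-based list slicing; a standalone sketch (the batch size mirrors the module's conservative 100):

```python
def batched(items, batch_size=100):
    """Yield successive slices of at most batch_size items."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

sizes = [len(b) for b in batched(list(range(250)), batch_size=100)]
# 250 inputs split as 100 + 100 + 50
```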
rag/rag_engine.py ADDED
@@ -0,0 +1,120 @@
1
+ """
2
+ RAG Engine - Orchestrates the retrieval-augmented generation pipeline.
3
+
4
+ Handles:
5
+ - Automatic indexing of text columns from the database
6
+ - Semantic retrieval using FAISS
7
+ - Context building for the LLM
8
+ """
9
+
10
+ import logging
11
+ from typing import List, Dict, Any, Optional, Tuple
12
+
13
+ from .document_processor import Document, get_document_processor
14
+ from .vector_store import VectorStore, get_vector_store
15
+ from .embeddings import get_embedding_provider
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class RAGEngine:
21
+ """Main RAG engine for semantic retrieval from database text."""
22
+
23
+ def __init__(self, vector_store: Optional[VectorStore] = None):
24
+ self.vector_store = vector_store or get_vector_store()
25
+ self.doc_processor = get_document_processor()
26
+ self.indexed_tables: Dict[str, bool] = {}
27
+
28
+ def index_table(
29
+ self,
30
+ table_name: str,
31
+ rows: List[Dict[str, Any]],
32
+ text_columns: List[str],
33
+ primary_key_column: Optional[str] = None
34
+ ) -> int:
35
+ """
36
+ Index text data from a table.
37
+
38
+ Returns:
39
+ Number of documents indexed
40
+ """
41
+ documents = list(self.doc_processor.process_rows(
42
+ rows, table_name, text_columns, primary_key_column
43
+ ))
44
+
45
+ if documents:
46
+ self.vector_store.add_documents(documents)
47
+ self.indexed_tables[table_name] = True
48
+ logger.info(f"Indexed {len(documents)} documents from {table_name}")
49
+
50
+ return len(documents)
51
+
52
+ def search(
53
+ self,
54
+ query: str,
55
+ top_k: int = 5,
56
+ table_filter: Optional[List[str]] = None
57
+ ) -> List[Tuple[Document, float]]:
58
+ """
59
+ Search for relevant documents.
60
+
61
+ Args:
62
+ query: Search query
63
+ top_k: Number of results
64
+ table_filter: Optional list of tables to search in
65
+
66
+ Returns:
67
+ List of (document, score) tuples
68
+ """
69
+ results = self.vector_store.search(query, top_k=top_k * 2)
70
+
71
+ if table_filter:
72
+ results = [
73
+ (doc, score) for doc, score in results
74
+ if doc.table_name in table_filter
75
+ ]
76
+
77
+ return results[:top_k]
78
+
79
+ def get_context(
80
+ self,
81
+ query: str,
82
+ top_k: int = 5,
83
+ table_filter: Optional[List[str]] = None
84
+ ) -> str:
85
+ """
86
+ Get formatted context for LLM from search results.
87
+ """
88
+ results = self.search(query, top_k, table_filter)
89
+
90
+ if not results:
91
+ return "No relevant information found in the database."
92
+
93
+ context_parts = []
94
+ for doc, score in results:
95
+ context_parts.append(doc.to_context_string())
96
+
97
+ return "\n\n---\n\n".join(context_parts)
98
+
99
+ def clear_index(self):
100
+ """Clear the entire index."""
101
+ self.vector_store.clear()
102
+ self.indexed_tables = {}
103
+
104
+ def save(self):
105
+ """Save the index to disk."""
106
+ self.vector_store.save()
107
+
108
+ @property
109
+ def document_count(self) -> int:
110
+ return len(self.vector_store)
111
+
112
+
113
+ _rag_engine: Optional[RAGEngine] = None
114
+
115
+
116
+ def get_rag_engine() -> RAGEngine:
117
+ global _rag_engine
118
+ if _rag_engine is None:
119
+ _rag_engine = RAGEngine()
120
+ return _rag_engine
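Putting `Document.to_context_string` and `get_context` together, the context handed to the LLM looks like the following (the table names and row values here are made up for illustration):

```python
def to_context_string(table, column, pk, content):
    """Mirror of Document.to_context_string: a source header, then the text."""
    source = f"[Source: {table}.{column}"
    if pk:
        source += f" (id: {pk})"
    source += "]"
    return f"{source}\n{content}"

parts = [
    to_context_string("products", "description", "42", "Wireless keyboard with backlight."),
    to_context_string("reviews", "body", "7", "Great battery life."),
]
context = "\n\n---\n\n".join(parts)  # same separator get_context uses
```

The source headers let the LLM cite which table and row a retrieved passage came from.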
rag/vector_store.py ADDED
@@ -0,0 +1,173 @@
1
+ """
2
+ FAISS Vector Store for RAG.
3
+
4
+ Manages the FAISS index for semantic search over database text content.
5
+ """
6
+
7
+ import logging
8
+ import pickle
9
+ import os
10
+ from typing import List, Dict, Any, Optional, Tuple
11
+ import numpy as np
12
+
13
+ try:
14
+ import faiss
15
+ except ImportError:
16
+ faiss = None
17
+
18
+ from .document_processor import Document
19
+ from .embeddings import get_embedding_provider, EmbeddingProvider
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class VectorStore:
25
+ """FAISS-based vector store for semantic search."""
26
+
27
+ def __init__(
28
+ self,
29
+ embedding_provider: Optional[EmbeddingProvider] = None,
30
+ index_path: str = "./faiss_index"
31
+ ):
32
+ if faiss is None:
33
+ raise ImportError("faiss-cpu is required. Install with: pip install faiss-cpu")
34
+
35
+ self.embedding_provider = embedding_provider or get_embedding_provider()
36
+ self.index_path = index_path
37
+ self.dimension = self.embedding_provider.dimension
38
+
39
+ self.index: Optional[faiss.IndexFlatIP] = None
40
+ self.documents: List[Document] = []
41
+ self.id_to_idx: Dict[str, int] = {}
42
+
43
+ self._initialize_index()
44
+
45
+ def _initialize_index(self):
46
+ """Initialize or load the FAISS index."""
47
+ index_file = os.path.join(self.index_path, "index.faiss")
48
+ docs_file = os.path.join(self.index_path, "documents.pkl")
49
+
50
+ if os.path.exists(index_file) and os.path.exists(docs_file):
51
+ try:
52
+ # Check file size - if 0 something is wrong
53
+ if os.path.getsize(index_file) > 0:
54
+ self.index = faiss.read_index(index_file)
55
+ with open(docs_file, 'rb') as f:
56
+ self.documents = pickle.load(f)
57
+ self.id_to_idx = {doc.id: i for i, doc in enumerate(self.documents)}
58
+
59
+ # Verify index dimension matches expected
60
+ if self.index.d != self.dimension:
61
+ logger.warning(f"Index dimension mismatch: {self.index.d} != {self.dimension}. Resetting.")
62
+ raise ValueError("Dimension mismatch")
63
+
64
+ logger.info(f"Loaded index with {len(self.documents)} documents")
65
+ return
66
+ except (Exception, RuntimeError) as e:
67
+ logger.warning(f"Failed to load index (might be corrupted or memory error): {e}")
68
+ # If loading fails, we should probably backup the broken files or just overwrite
69
+ if os.path.exists(index_file):
70
+ try:
71
+ os.rename(index_file, index_file + ".bak")
72
+ os.rename(docs_file, docs_file + ".bak")
73
+ except:
74
+ pass
75
+
76
+ # Create new index (Inner Product for cosine similarity with normalized vectors)
77
+ self.index = faiss.IndexFlatIP(self.dimension)
78
+ self.documents = []
79
+ self.id_to_idx = {}
80
+ logger.info(f"Created new FAISS index with dimension {self.dimension}")
81
+
82
+ def add_documents(self, documents: List[Document], batch_size: int = 100):
83
+ """Add documents to the vector store."""
84
+ if not documents:
85
+ return
86
+
87
+ new_docs = [doc for doc in documents if doc.id not in self.id_to_idx]
88
+ if not new_docs:
89
+ logger.info("No new documents to add")
90
+ return
91
+
92
+ logger.info(f"Adding {len(new_docs)} documents to index")
93
+
94
+ for i in range(0, len(new_docs), batch_size):
95
+ batch = new_docs[i:i + batch_size]
96
+ texts = [doc.content for doc in batch]
97
+
98
+ embeddings = self.embedding_provider.embed_texts(texts)
99
+
100
+ # Normalize for cosine similarity
101
+ faiss.normalize_L2(embeddings)
102
+
103
+ start_idx = len(self.documents)
104
+ self.index.add(embeddings)
105
+
106
+ for j, doc in enumerate(batch):
107
+ self.documents.append(doc)
108
+ self.id_to_idx[doc.id] = start_idx + j
109
+
110
+ logger.info(f"Index now contains {len(self.documents)} documents")
111
+
112
+ def search(
113
+ self, query: str, top_k: int = 5, threshold: float = 0.0
114
+ ) -> List[Tuple[Document, float]]:
115
+ """Search for similar documents."""
116
+ if not self.documents:
117
+ return []
118
+
119
+ query_embedding = self.embedding_provider.embed_text(query)
120
+ query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
121
+ faiss.normalize_L2(query_embedding)
122
+
123
+ k = min(top_k, len(self.documents))
124
+ scores, indices = self.index.search(query_embedding, k)
125
+
126
+ results = []
127
+ for score, idx in zip(scores[0], indices[0]):
128
+ if idx >= 0 and score >= threshold:
129
+ results.append((self.documents[idx], float(score)))
130
+
131
+ return results
132
+
133
+ def save(self):
134
+ """Save the index to disk."""
135
+ os.makedirs(self.index_path, exist_ok=True)
136
+
137
+ index_file = os.path.join(self.index_path, "index.faiss")
138
+ docs_file = os.path.join(self.index_path, "documents.pkl")
139
+
140
+ faiss.write_index(self.index, index_file)
141
+ with open(docs_file, 'wb') as f:
142
+ pickle.dump(self.documents, f)
143
+
144
+ logger.info(f"Saved index with {len(self.documents)} documents")
145
+
146
+ def clear(self):
147
+ """Clear the index."""
148
+ self.index = faiss.IndexFlatIP(self.dimension)
149
+ self.documents = []
150
+ self.id_to_idx = {}
151
+
152
+ # Delete files
153
+ index_file = os.path.join(self.index_path, "index.faiss")
154
+ docs_file = os.path.join(self.index_path, "documents.pkl")
155
+
156
+ for f in [index_file, docs_file]:
157
+ if os.path.exists(f):
158
+ os.remove(f)
159
+
160
+ logger.info("Index cleared")
161
+
162
+ def __len__(self) -> int:
163
+ return len(self.documents)
164
+
165
+
166
+ _vector_store: Optional[VectorStore] = None
167
+
168
+
169
+ def get_vector_store() -> VectorStore:
170
+ global _vector_store
171
+ if _vector_store is None:
172
+ _vector_store = VectorStore()
173
+ return _vector_store
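The `search()` path relies on one identity: the inner product of L2-normalized vectors equals their cosine similarity, which is why `IndexFlatIP` combined with `faiss.normalize_L2` yields cosine scores, and why `threshold` acts as a cosine cutoff. A minimal pure-Python sketch of that equivalence (illustrative only, no FAISS needed):

```python
import math

def cosine_via_normalized_ip(a, b):
    # L2-normalize each vector, then take a plain inner product:
    # this equals cosine similarity, mirroring IndexFlatIP + normalize_L2
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum((x / na) * (y / nb) for x, y in zip(a, b))

print(cosine_via_normalized_ip([3.0, 4.0], [4.0, 3.0]))  # ≈ 0.96
```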
requirements.txt ADDED
@@ -0,0 +1,31 @@
+ # Schema-Agnostic Database Chatbot
+ # Multi-database support: MySQL, PostgreSQL, SQLite
+
+ # Core dependencies
+ streamlit>=1.30.0
+ sqlalchemy>=2.0.0
+
+ # Database drivers
+ pymysql>=1.1.0           # MySQL driver
+ psycopg2-binary>=2.9.9   # PostgreSQL driver
+ # SQLite is built into Python - no driver needed
+
+ # RAG dependencies
+ faiss-cpu>=1.7.4
+ sentence-transformers>=2.2.2
+
+ # LLM dependencies
+ groq>=0.4.0              # Free API tier
+ openai>=1.0.0            # Optional, for the OpenAI provider
+
+ # For local models (optional)
+ # transformers>=4.36.0
+ # torch>=2.0.0
+
+ # SQL parsing and validation
+ sqlparse>=0.4.4
+
+ # Utilities
+ python-dotenv>=1.0.0
+ numpy>=1.24.0
+ pandas>=2.0.0
router.py ADDED
@@ -0,0 +1,164 @@
+ """
+ Query Router - Decides between RAG, SQL, or a hybrid approach.
+
+ Analyzes user intent and routes each query to the appropriate handler.
+ """
+
+ import logging
+ from enum import Enum
+ from typing import List, Optional
+ from dataclasses import dataclass
+
+ logger = logging.getLogger(__name__)
+
+
+ class QueryType(Enum):
+     RAG = "rag"          # Semantic search in text
+     SQL = "sql"          # Structured query
+     HYBRID = "hybrid"    # Both RAG and SQL
+     GENERAL = "general"  # General conversation
+
+
+ @dataclass
+ class RoutingDecision:
+     query_type: QueryType
+     confidence: float
+     reasoning: str
+     suggested_tables: Optional[List[str]] = None
+
+     def __post_init__(self):
+         if self.suggested_tables is None:
+             self.suggested_tables = []
+
+
+ class QueryRouter:
+     """Routes queries to appropriate handlers based on intent analysis."""
+
+     ROUTING_PROMPT = """Analyze this user query and determine the best approach to answer it.
+
+ DATABASE SCHEMA:
+ {schema}
+
+ USER QUERY: {query}
+
+ Determine if this query needs:
+ 1. RAG - Semantic search through text content (searching for meanings, concepts, descriptions)
+ 2. SQL - Structured database query (counting, filtering, aggregating, specific lookups)
+ 3. HYBRID - Both semantic search and structured query
+ 4. GENERAL - General conversation not requiring database access
+
+ Respond in this exact format:
+ TYPE: [RAG|SQL|HYBRID|GENERAL]
+ CONFIDENCE: [0.0-1.0]
+ TABLES: [comma-separated list of relevant tables, or NONE]
+ REASONING: [brief explanation]"""
+
+     def __init__(self, llm_client=None):
+         self.llm_client = llm_client
+
+     def set_llm_client(self, llm_client):
+         self.llm_client = llm_client
+
+     def route(self, query: str, schema_context: str) -> RoutingDecision:
+         """Analyze the query and determine routing."""
+         if not self.llm_client:
+             # Fall back to simple keyword heuristics
+             return self._heuristic_route(query)
+
+         prompt = self.ROUTING_PROMPT.format(schema=schema_context, query=query)
+
+         try:
+             response = self.llm_client.chat([
+                 {"role": "system", "content": "You are a query routing assistant."},
+                 {"role": "user", "content": prompt}
+             ])
+             return self._parse_routing_response(response)
+         except Exception as e:
+             logger.warning(f"LLM routing failed: {e}, using heuristics")
+             return self._heuristic_route(query)
+
+     def _parse_routing_response(self, response: str) -> RoutingDecision:
+         """Parse the LLM routing response."""
+         lines = response.strip().split('\n')
+
+         query_type = QueryType.GENERAL
+         confidence = 0.5
+         tables = []
+         reasoning = ""
+
+         for line in lines:
+             line = line.strip()
+             if line.startswith("TYPE:"):
+                 type_str = line.replace("TYPE:", "").strip().upper()
+                 query_type = QueryType[type_str] if type_str in QueryType.__members__ else QueryType.GENERAL
+             elif line.startswith("CONFIDENCE:"):
+                 try:
+                     confidence = float(line.replace("CONFIDENCE:", "").strip())
+                 except ValueError:
+                     confidence = 0.5
+             elif line.startswith("TABLES:"):
+                 tables_str = line.replace("TABLES:", "").strip()
+                 if tables_str.upper() != "NONE":
+                     tables = [t.strip() for t in tables_str.split(",")]
+             elif line.startswith("REASONING:"):
+                 reasoning = line.replace("REASONING:", "").strip()
+
+         return RoutingDecision(query_type, confidence, reasoning, tables)
+
+     def _heuristic_route(self, query: str) -> RoutingDecision:
+         """Simple keyword-based routing used when no LLM is available."""
+         query_lower = query.lower()
+
+         # SQL keywords - structured data retrieval
+         sql_keywords = [
+             'how many', 'count', 'total', 'average', 'sum', 'max', 'min',
+             'list all', 'show all', 'find all', 'get all', 'between',
+             'greater than', 'less than', 'equal to', 'top', 'bottom',
+             # Data listing patterns
+             'what products', 'what customers', 'what orders', 'what items',
+             'show me', 'list', 'display', 'give me', 'get me',
+             'all products', 'all customers', 'all orders',
+             'products do you have', 'customers do you have',
+             'from new york', 'from chicago', 'from los angeles',
+             # Specific lookups
+             'price of', 'cost of', 'stock of', 'quantity',
+             'where', 'which', 'who'
+         ]
+
+         # RAG keywords - semantic/conceptual questions
+         rag_keywords = [
+             'what is the policy', 'explain', 'describe', 'tell me about',
+             'meaning of', 'definition', 'why', 'how does', 'what does',
+             'similar to', 'return policy', 'shipping policy', 'warranty',
+             'support', 'help with', 'information about', 'details about'
+         ]
+
+         sql_score = sum(1 for kw in sql_keywords if kw in query_lower)
+         rag_score = sum(1 for kw in rag_keywords if kw in query_lower)
+
+         # Boost SQL score for common listing patterns
+         if any(word in query_lower for word in ['products', 'customers', 'orders', 'items']):
+             if any(word in query_lower for word in ['what', 'show', 'list', 'all', 'have']):
+                 sql_score += 2
+
+         if sql_score > rag_score:
+             return RoutingDecision(QueryType.SQL, 0.8, "SQL query for data retrieval")
+         elif rag_score > sql_score:
+             return RoutingDecision(QueryType.RAG, 0.8, "Semantic search for concepts")
+         elif sql_score > 0 and rag_score > 0:
+             return RoutingDecision(QueryType.HYBRID, 0.6, "Mixed query type")
+         else:
+             # Default to SQL for simple questions about data
+             if any(word in query_lower for word in ['products', 'customers', 'orders']):
+                 return RoutingDecision(QueryType.SQL, 0.6, "Default to SQL for data tables")
+             return RoutingDecision(QueryType.RAG, 0.5, "Default to semantic search")
+
+
+ _router: Optional[QueryRouter] = None
+
+
+ def get_query_router() -> QueryRouter:
+     """Return the module-level QueryRouter singleton."""
+     global _router
+     if _router is None:
+         _router = QueryRouter()
+     return _router
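When no LLM client is configured, `_heuristic_route` falls back to counting keyword hits. A trimmed, standalone sketch of that scoring scheme (the keyword lists here are abbreviated and the tie-breaking is simplified relative to the module above):

```python
# Abbreviated keyword lists; the module's lists are longer and it also
# applies a listing-pattern boost before comparing scores.
SQL_KEYWORDS = ['how many', 'count', 'total', 'average', 'list all', 'price of']
RAG_KEYWORDS = ['explain', 'describe', 'tell me about', 'return policy', 'why']

def heuristic_route(query: str) -> str:
    q = query.lower()
    sql_score = sum(1 for kw in SQL_KEYWORDS if kw in q)
    rag_score = sum(1 for kw in RAG_KEYWORDS if kw in q)
    if sql_score > rag_score:
        return "sql"
    if rag_score > sql_score:
        return "rag"
    return "hybrid" if sql_score > 0 else "general"

print(heuristic_route("How many orders were placed last month?"))  # sql
print(heuristic_route("Explain the return policy"))                # rag
```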
sql/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """SQL module exports."""
+
+ from .validator import SQLValidator, SQLValidationError, get_sql_validator
+ from .generator import SQLGenerator, get_sql_generator
+
+ __all__ = [
+     "SQLValidator", "SQLValidationError", "get_sql_validator",
+     "SQLGenerator", "get_sql_generator",
+ ]
sql/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (480 Bytes).
sql/__pycache__/generator.cpython-311.pyc ADDED
Binary file (6.94 kB).
sql/__pycache__/validator.cpython-311.pyc ADDED
Binary file (7.14 kB).
sql/generator.py ADDED
@@ -0,0 +1,159 @@
+ """
+ Text-to-SQL Generator - Multi-Database Support.
+
+ Uses an LLM to generate SQL queries from natural language,
+ with dynamic schema context. Supports MySQL, PostgreSQL, and SQLite.
+ """
+
+ import logging
+ import re
+ from typing import Dict, List, Optional, Tuple
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_sql_dialect(db_type: str) -> str:
+     """Get the SQL dialect name for the given database type."""
+     dialects = {
+         "mysql": "MySQL",
+         "postgresql": "PostgreSQL",
+         "sqlite": "SQLite"
+     }
+     return dialects.get(db_type, "SQL")
+
+
+ def get_dialect_specific_hints(db_type: str) -> str:
+     """Get database-specific hints for SQL generation."""
+     if db_type == "postgresql":
+         return """
+ PostgreSQL-SPECIFIC NOTES:
+ - Use ILIKE for case-insensitive pattern matching (instead of LIKE)
+ - String concatenation uses the || operator
+ - Use LIMIT at the end of queries
+ - Boolean values are TRUE/FALSE (not 1/0)
+ - Use double quotes for identifiers with special chars, single quotes for strings
+ """
+     elif db_type == "sqlite":
+         return """
+ SQLite-SPECIFIC NOTES:
+ - LIKE is case-insensitive for ASCII characters by default
+ - Use || for string concatenation
+ - No ILIKE - use LIKE (case-insensitive) or GLOB (case-sensitive)
+ - Use LIMIT at the end of queries
+ - Boolean values are 1/0
+ - Use strftime() for date functions instead of DATE_FORMAT
+ """
+     else:  # MySQL
+         return """
+ MySQL-SPECIFIC NOTES:
+ - LIKE is case-insensitive for non-binary strings
+ - Use CONCAT() for string concatenation
+ - Use LIMIT at the end of queries
+ - Boolean values are 1/0
+ - Use backticks for identifiers with special chars, single quotes for strings
+ """
+
+
+ class SQLGenerator:
+     """Generates SQL queries from natural language using an LLM."""
+
+     SYSTEM_PROMPT_TEMPLATE = """You are a SQL expert. Generate {dialect} SELECT queries based on user questions.
+
+ RULES:
+ 1. ONLY generate SELECT statements.
+ 2. NEVER use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, or TRUNCATE.
+ 3. Always include a LIMIT clause (max 50 rows unless specified).
+ 4. Use table and column names EXACTLY as shown in the schema.
+ 5. AMBIGUITY: If the user asks for a category, type, or specific value, and you are unsure which column it belongs to:
+    - Check multiple likely columns (e.g., `category`, `sub_category`, `type`, `description`).
+    - Use pattern matching for flexibility.
+    - Use `OR` to combine multiple column checks.
+ 6. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.
+ 7. Return ONLY the SQL query, no explanations.
+
+ {dialect_hints}
+
+ DATABASE SCHEMA:
+ {schema}
+
+ Generate a single {dialect} SELECT query to answer the user's question."""
+
+     def __init__(self, llm_client=None, db_type: str = "mysql"):
+         self.llm_client = llm_client
+         self.db_type = db_type
+
+     def set_llm_client(self, llm_client):
+         self.llm_client = llm_client
+
+     def set_db_type(self, db_type: str):
+         """Set the database type for SQL generation."""
+         self.db_type = db_type
+
+     def generate(
+         self,
+         question: str,
+         schema_context: str,
+         chat_history: Optional[List[Dict[str, str]]] = None
+     ) -> Tuple[str, str]:
+         """
+         Generate SQL from natural language.
+
+         Returns:
+             Tuple of (sql_query, explanation)
+         """
+         if not self.llm_client:
+             raise ValueError("LLM client not configured")
+
+         dialect = get_sql_dialect(self.db_type)
+         dialect_hints = get_dialect_specific_hints(self.db_type)
+
+         system_prompt = self.SYSTEM_PROMPT_TEMPLATE.format(
+             dialect=dialect,
+             dialect_hints=dialect_hints,
+             schema=schema_context
+         )
+
+         messages = [{"role": "system", "content": system_prompt}]
+
+         if chat_history:
+             messages.extend(chat_history[-3:])  # Last 3 messages for context
+
+         messages.append({"role": "user", "content": question})
+
+         response = self.llm_client.chat(messages)
+
+         # Extract SQL from the response
+         sql = self._extract_sql(response)
+
+         return sql, response
+
+     def _extract_sql(self, response: str) -> str:
+         """Extract the SQL query from an LLM response."""
+         # Look for SQL in fenced code blocks
+         code_block = re.search(r'```(?:sql)?\s*(.*?)```', response, re.DOTALL | re.IGNORECASE)
+         if code_block:
+             return code_block.group(1).strip()
+
+         # Fall back to the first SELECT statement
+         select_match = re.search(
+             r'(SELECT\s+.+?(?:;|$))',
+             response,
+             re.DOTALL | re.IGNORECASE
+         )
+         if select_match:
+             return select_match.group(1).strip().rstrip(';')
+
+         return response.strip()
+
+
+ _generator: Optional[SQLGenerator] = None
+
+
+ def get_sql_generator(db_type: str = "mysql") -> SQLGenerator:
+     global _generator
+     if _generator is None:
+         _generator = SQLGenerator(db_type=db_type)
+     else:
+         _generator.set_db_type(db_type)
+     return _generator
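`_extract_sql` tries a fenced code block first, then falls back to the first SELECT statement (note that inside a raw string a single `\s` is needed; a doubled `\\s` would match a literal backslash). A standalone sketch of that extraction order (`extract_sql` is an illustrative free function, not the repo API):

```python
import re

def extract_sql(response: str) -> str:
    # Prefer SQL inside a fenced ```sql ... ``` block
    block = re.search(r'```(?:sql)?\s*(.*?)```', response, re.DOTALL | re.IGNORECASE)
    if block:
        return block.group(1).strip()
    # Otherwise grab the first SELECT statement up to ; or end of text
    m = re.search(r'(SELECT\s+.+?(?:;|$))', response, re.DOTALL | re.IGNORECASE)
    if m:
        return m.group(1).strip().rstrip(';')
    return response.strip()

print(extract_sql("Here you go:\n```sql\nSELECT name FROM products LIMIT 10;\n```"))
print(extract_sql("Sure: SELECT 1;"))  # SELECT 1
```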
sql/validator.py ADDED
@@ -0,0 +1,163 @@
+ """
+ SQL Validator - Security layer for SQL queries.
+
+ Ensures ONLY safe SELECT queries are executed.
+ Validates against a table whitelist and blocks dangerous operations.
+ """
+
+ import logging
+ import re
+ from typing import List, Optional, Set, Tuple
+
+ import sqlparse
+ from sqlparse.sql import Statement
+
+ logger = logging.getLogger(__name__)
+
+
+ class SQLValidationError(Exception):
+     """Raised when SQL validation fails."""
+     pass
+
+
+ class SQLValidator:
+     """Validates SQL queries for safety before execution."""
+
+     FORBIDDEN_KEYWORDS = {
+         'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER',
+         'TRUNCATE', 'GRANT', 'REVOKE', 'EXECUTE', 'EXEC',
+         'INTO OUTFILE', 'INTO DUMPFILE', 'LOAD_FILE', 'LOAD DATA'
+     }
+
+     FORBIDDEN_PATTERNS = [
+         r'INTO\s+OUTFILE',
+         r'INTO\s+DUMPFILE',
+         r'LOAD_FILE\s*\(',
+         r'LOAD\s+DATA',
+         r';\s*(?:DROP|DELETE|UPDATE|INSERT)',  # Multi-statement attacks
+         r'--',        # SQL line comments (potential injection)
+         r'/\*.*\*/',  # Block comments
+     ]
+
+     def __init__(self, allowed_tables: Optional[Set[str]] = None, max_limit: int = 100):
+         self.allowed_tables = allowed_tables or set()
+         self.max_limit = max_limit
+         self._compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.FORBIDDEN_PATTERNS]
+
+     def set_allowed_tables(self, tables: List[str]):
+         """Set the whitelist of allowed tables."""
+         self.allowed_tables = set(tables)
+
+     def validate(self, sql: str) -> Tuple[bool, str, Optional[str]]:
+         """
+         Validate a SQL query for safety.
+
+         Returns:
+             Tuple of (is_valid, message, sanitized_sql)
+         """
+         if not sql or not sql.strip():
+             return False, "Empty SQL query", None
+
+         sql = sql.strip()
+
+         # Check for forbidden patterns
+         for pattern in self._compiled_patterns:
+             if pattern.search(sql):
+                 return False, "Forbidden pattern detected in query", None
+
+         # Parse SQL
+         try:
+             parsed = sqlparse.parse(sql)
+         except Exception as e:
+             return False, f"Failed to parse SQL: {e}", None
+
+         if not parsed:
+             return False, "Failed to parse SQL query", None
+
+         # Only allow single statements
+         if len(parsed) > 1:
+             return False, "Multiple SQL statements not allowed", None
+
+         statement = parsed[0]
+
+         # Check statement type
+         stmt_type = statement.get_type()
+         if stmt_type != 'SELECT':
+             return False, f"Only SELECT statements allowed, got: {stmt_type}", None
+
+         # Check for forbidden keywords (word-boundary match, so identifiers
+         # such as `created_at` are not rejected for containing CREATE)
+         sql_upper = sql.upper()
+         for keyword in self.FORBIDDEN_KEYWORDS:
+             if re.search(r'\b' + re.escape(keyword) + r'\b', sql_upper):
+                 return False, f"Forbidden keyword detected: {keyword}", None
+
+         # Extract and validate tables
+         tables = self._extract_tables(statement)
+         if self.allowed_tables:
+             invalid_tables = tables - self.allowed_tables
+             if invalid_tables:
+                 return False, f"Access denied to tables: {invalid_tables}", None
+
+         # Ensure a LIMIT clause exists
+         sanitized = self._ensure_limit(sql)
+
+         return True, "Query validated successfully", sanitized
+
+     def _extract_tables(self, statement: Statement) -> Set[str]:
+         """Extract table names from a SELECT statement using regex.
+
+         Note: does not handle quoted or schema-qualified names.
+         """
+         tables = set()
+         sql = str(statement)
+
+         # Find table names after FROM and JOIN
+         from_pattern = re.compile(
+             r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)',
+             re.IGNORECASE
+         )
+         join_pattern = re.compile(
+             r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)',
+             re.IGNORECASE
+         )
+
+         for match in from_pattern.finditer(sql):
+             tables.add(match.group(1))
+
+         for match in join_pattern.finditer(sql):
+             tables.add(match.group(1))
+
+         return tables
+
+     def _ensure_limit(self, sql: str) -> str:
+         """Ensure the query has a LIMIT clause, capped at max_limit."""
+         sql_upper = sql.upper()
+
+         if 'LIMIT' in sql_upper:
+             # Cap an over-large limit
+             limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper)
+             if limit_match:
+                 current_limit = int(limit_match.group(1))
+                 if current_limit > self.max_limit:
+                     sql = re.sub(
+                         r'LIMIT\s+\d+',
+                         f'LIMIT {self.max_limit}',
+                         sql,
+                         flags=re.IGNORECASE
+                     )
+             return sql
+         else:
+             # Append a LIMIT clause
+             sql = sql.rstrip(';').strip()
+             return f"{sql} LIMIT {self.max_limit}"
+
+
+ _validator: Optional[SQLValidator] = None
+
+
+ def get_sql_validator() -> SQLValidator:
+     global _validator
+     if _validator is None:
+         _validator = SQLValidator()
+     return _validator
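`_ensure_limit` guarantees every validated query carries a row cap: an existing over-large LIMIT is rewritten down, and a missing LIMIT is appended. A standalone sketch of that cap-or-append behavior (`ensure_limit` is an illustrative free-function version of the method, using the default cap of 100):

```python
import re

def ensure_limit(sql: str, max_limit: int = 100) -> str:
    # Cap an existing LIMIT, or append one if missing
    m = re.search(r'LIMIT\s+(\d+)', sql, re.IGNORECASE)
    if m:
        if int(m.group(1)) > max_limit:
            return re.sub(r'LIMIT\s+\d+', f'LIMIT {max_limit}', sql, flags=re.IGNORECASE)
        return sql
    return f"{sql.rstrip(';').strip()} LIMIT {max_limit}"

print(ensure_limit("SELECT * FROM products"))            # SELECT * FROM products LIMIT 100
print(ensure_limit("SELECT * FROM products LIMIT 500"))  # SELECT * FROM products LIMIT 100
```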