Spaces:
Sleeping
Sleeping
Upload 34 files
Browse files- Dockerfile +60 -0
- README.md +111 -11
- app.py +336 -0
- chatbot.py +391 -0
- config.py +280 -0
- database/__init__.py +30 -0
- database/__pycache__/__init__.cpython-311.pyc +0 -0
- database/__pycache__/connection.cpython-311.pyc +0 -0
- database/__pycache__/schema_introspector.cpython-311.pyc +0 -0
- database/connection.py +231 -0
- database/schema_introspector.py +648 -0
- llm/__init__.py +17 -0
- llm/__pycache__/__init__.cpython-311.pyc +0 -0
- llm/__pycache__/client.cpython-311.pyc +0 -0
- llm/client.py +188 -0
- memory.py +760 -0
- rag/__init__.py +20 -0
- rag/__pycache__/__init__.cpython-311.pyc +0 -0
- rag/__pycache__/document_processor.cpython-311.pyc +0 -0
- rag/__pycache__/embeddings.cpython-311.pyc +0 -0
- rag/__pycache__/rag_engine.cpython-311.pyc +0 -0
- rag/__pycache__/vector_store.cpython-311.pyc +0 -0
- rag/document_processor.py +122 -0
- rag/embeddings.py +206 -0
- rag/rag_engine.py +120 -0
- rag/vector_store.py +173 -0
- requirements.txt +31 -0
- router.py +164 -0
- sql/__init__.py +9 -0
- sql/__pycache__/__init__.cpython-311.pyc +0 -0
- sql/__pycache__/generator.cpython-311.pyc +0 -0
- sql/__pycache__/validator.cpython-311.pyc +0 -0
- sql/generator.py +159 -0
- sql/validator.py +163 -0
Dockerfile
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Spaces - Docker SDK
# Schema-Agnostic Database Chatbot with RAG
# Builds a non-root Streamlit image served on port 7860 (the port HF Spaces proxies).

FROM python:3.11-slim

# Set working directory
WORKDIR /app

# Set environment variables:
# - PYTHONDONTWRITEBYTECODE / PYTHONUNBUFFERED: no .pyc files, unbuffered logs
# - PYTHONPATH: make /app importable as the project root
# - model/cache dirs live under /app so the non-root user can write them
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers releases
# in favor of HF_HOME (already set) — kept here for older library versions.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONPATH=/app \
    HF_HOME=/app/.cache \
    TRANSFORMERS_CACHE=/app/.cache/transformers \
    SENTENCE_TRANSFORMERS_HOME=/app/.cache/sentence_transformers

# Install system dependencies (build-essential for compiling wheels, libpq-dev
# for PostgreSQL drivers, curl for the HEALTHCHECK below); clean apt caches to
# keep the layer small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    git \
    libpq-dev \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

# Create a non-root user for security (uid 1000 matches the HF Spaces convention)
RUN useradd -m -u 1000 appuser

# Create cache directories with proper permissions before dropping privileges
RUN mkdir -p /app/.cache/sentence_transformers /app/.cache/transformers /app/faiss_index \
    && chown -R appuser:appuser /app

# Copy requirements first for better layer caching (code changes don't
# invalidate the dependency layer)
COPY --chown=appuser:appuser requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY --chown=appuser:appuser . .

# Switch to non-root user
USER appuser

# Expose Streamlit port (HF Spaces expects 7860)
EXPOSE 7860

# Health check against Streamlit's built-in health endpoint
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl --fail http://localhost:7860/_stcore/health || exit 1

# Run Streamlit (file watcher disabled: code never changes inside the container)
CMD ["streamlit", "run", "app.py", \
     "--server.port=7860", \
     "--server.address=0.0.0.0", \
     "--server.enableCORS=true", \
     "--server.enableXsrfProtection=false", \
     "--browser.gatherUsageStats=false", \
     "--server.fileWatcherType=none"]
|
README.md
CHANGED
|
@@ -1,11 +1,111 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: docker
|
| 7 |
-
pinned: false
|
| 8 |
-
license: mit
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Database Copilot
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
app_port: 7860
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# 🤖 Database Copilot
|
| 13 |
+
|
| 14 |
+
A production-grade, **schema-agnostic chatbot** that connects to **any** database (MySQL, PostgreSQL, or SQLite) and provides intelligent querying through **RAG** (Retrieval-Augmented Generation) and **Text-to-SQL**.
|
| 15 |
+
|
| 16 |
+
**🆓 Powered by Groq for FREE LLM inference!**
|
| 17 |
+
|
| 18 |
+
## 🌟 Features
|
| 19 |
+
|
| 20 |
+
- **Multi-Database Support**: Works with **MySQL**, **PostgreSQL**, and **SQLite**
|
| 21 |
+
- **Schema-Agnostic**: Works with ANY database schema - no hardcoding required
|
| 22 |
+
- **Dynamic Introspection**: Automatically discovers tables, columns, and relationships
|
| 23 |
+
- **Hybrid Query Routing**: Intelligently routes queries to RAG or SQL based on intent
|
| 24 |
+
- **Semantic Search (RAG)**: FAISS-based vector search for text content
|
| 25 |
+
- **Text-to-SQL**: LLM-powered SQL generation with dialect-specific syntax
|
| 26 |
+
- **Security First**: Read-only queries, SQL validation, table whitelisting
|
| 27 |
+
- **FREE LLM**: Uses Groq API (free tier) with Llama 3.3, Mixtral, and Gemma models
|
| 28 |
+
|
| 29 |
+
## 🚀 Getting Started
|
| 30 |
+
|
| 31 |
+
### 1. Configure Secrets
|
| 32 |
+
|
| 33 |
+
This Space requires the following secrets to be set in your Hugging Face Space settings:
|
| 34 |
+
|
| 35 |
+
**Required:**
|
| 36 |
+
| Secret Name | Description |
|
| 37 |
+
|------------|-------------|
|
| 38 |
+
| `GROQ_API_KEY` | Your Groq API key ([Get FREE key](https://console.groq.com)) |
|
| 39 |
+
|
| 40 |
+
**Database Configuration (choose one):**
|
| 41 |
+
|
| 42 |
+
#### For MySQL:
|
| 43 |
+
| Secret Name | Description |
|
| 44 |
+
|------------|-------------|
|
| 45 |
+
| `DB_TYPE` | Set to `mysql` |
|
| 46 |
+
| `DB_HOST` | MySQL server hostname |
|
| 47 |
+
| `DB_PORT` | MySQL port (default: 3306) |
|
| 48 |
+
| `DB_DATABASE` | Database name |
|
| 49 |
+
| `DB_USERNAME` | Database username |
|
| 50 |
+
| `DB_PASSWORD` | Database password |
|
| 51 |
+
|
| 52 |
+
#### For PostgreSQL:
|
| 53 |
+
| Secret Name | Description |
|
| 54 |
+
|------------|-------------|
|
| 55 |
+
| `DB_TYPE` | Set to `postgresql` |
|
| 56 |
+
| `DB_HOST` | PostgreSQL server hostname |
|
| 57 |
+
| `DB_PORT` | PostgreSQL port (default: 5432) |
|
| 58 |
+
| `DB_DATABASE` | Database name |
|
| 59 |
+
| `DB_USERNAME` | Database username |
|
| 60 |
+
| `DB_PASSWORD` | Database password |
|
| 61 |
+
|
| 62 |
+
#### For SQLite:
|
| 63 |
+
| Secret Name | Description |
|
| 64 |
+
|------------|-------------|
|
| 65 |
+
| `DB_TYPE` | Set to `sqlite` |
|
| 66 |
+
| `SQLITE_PATH` | Path to SQLite database file |
|
| 67 |
+
|
| 68 |
+
**Optional:**
|
| 69 |
+
| Secret Name | Description | Default |
|
| 70 |
+
|------------|-------------|---------|
|
| 71 |
+
| `GROQ_MODEL` | Groq model to use | `llama-3.3-70b-versatile` |
|
| 72 |
+
| `DB_SSL_CA` | Path to SSL CA certificate | None |
|
| 73 |
+
|
| 74 |
+
### 2. Connect & Use
|
| 75 |
+
|
| 76 |
+
1. Click **"Connect & Initialize"** in the sidebar
|
| 77 |
+
2. Click **"Index Text Data"** to enable semantic search
|
| 78 |
+
3. Start asking questions about your data!
|
| 79 |
+
|
| 80 |
+
## 💬 Example Queries
|
| 81 |
+
|
| 82 |
+
**Semantic Search (RAG):**
|
| 83 |
+
- "What products are related to electronics?"
|
| 84 |
+
- "Tell me about customer feedback on shipping"
|
| 85 |
+
|
| 86 |
+
**Structured Queries (SQL):**
|
| 87 |
+
- "How many orders were placed last month?"
|
| 88 |
+
- "Show me the top 10 customers by revenue"
|
| 89 |
+
|
| 90 |
+
**Hybrid:**
|
| 91 |
+
- "Find customers who complained about delivery and show their order count"
|
| 92 |
+
|
| 93 |
+
## 🔒 Security
|
| 94 |
+
|
| 95 |
+
- **Read-Only Transactions**: All queries run in read-only mode
|
| 96 |
+
- **SQL Validation**: Only SELECT statements allowed
|
| 97 |
+
- **Forbidden Keywords**: INSERT, UPDATE, DELETE, DROP, etc. are blocked
|
| 98 |
+
- **Table Whitelisting**: Only discovered tables are queryable
|
| 99 |
+
- **Automatic LIMIT**: All queries have LIMIT clauses enforced
|
| 100 |
+
|
| 101 |
+
## 🆓 Why Groq?
|
| 102 |
+
|
| 103 |
+
[Groq](https://console.groq.com) provides **FREE API access** with incredibly fast inference:
|
| 104 |
+
- **Llama 3.3 70B** - Best quality, state-of-the-art
|
| 105 |
+
- **Llama 3.1 8B Instant** - Fastest responses
|
| 106 |
+
- **Mixtral 8x7B** - Great for code and SQL
|
| 107 |
+
- **Gemma 2 9B** - Google's efficient model
|
| 108 |
+
|
| 109 |
+
## 📝 License
|
| 110 |
+
|
| 111 |
+
MIT License
|
app.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Schema-Agnostic Database Chatbot - Streamlit Application
|
| 3 |
+
|
| 4 |
+
A production-grade chatbot that connects to ANY MySQL database
|
| 5 |
+
and provides intelligent querying through RAG and Text-to-SQL.
|
| 6 |
+
|
| 7 |
+
Uses Groq for FREE LLM inference!
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Load .env FIRST before any other imports
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
load_dotenv(Path(__file__).parent / ".env")
|
| 16 |
+
|
| 17 |
+
import streamlit as st
|
| 18 |
+
import uuid
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
|
| 21 |
+
# Page config must be first
|
| 22 |
+
st.set_page_config(
|
| 23 |
+
page_title="Database Copilot",
|
| 24 |
+
page_icon="🤖",
|
| 25 |
+
layout="wide",
|
| 26 |
+
initial_sidebar_state="expanded"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Imports
|
| 30 |
+
from config import config
|
| 31 |
+
from database import get_db, get_schema, get_introspector
|
| 32 |
+
from llm import create_llm_client
|
| 33 |
+
from chatbot import create_chatbot, DatabaseChatbot
|
| 34 |
+
from memory import create_memory, create_enhanced_memory, EnhancedChatMemory
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Groq models (all FREE!)
|
| 38 |
+
GROQ_MODELS = [
|
| 39 |
+
"llama-3.3-70b-versatile",
|
| 40 |
+
"llama-3.1-8b-instant",
|
| 41 |
+
"mixtral-8x7b-32768",
|
| 42 |
+
"gemma2-9b-it"
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def init_session_state():
    """Seed st.session_state with every key the app relies on.

    Only missing keys are filled in, so existing values survive
    Streamlit's script reruns. The "memory" key is created last because
    its factory reads session_id, user_id, and the summarization
    settings populated above it.
    """
    state = st.session_state

    if "session_id" not in state:
        state.session_id = str(uuid.uuid4())

    if "messages" not in state:
        state.messages = []

    if "chatbot" not in state:
        state.chatbot = None

    if "initialized" not in state:
        state.initialized = False

    if "user_id" not in state:
        state.user_id = "default"

    if "enable_summarization" not in state:
        state.enable_summarization = True

    if "summary_threshold" not in state:
        state.summary_threshold = 10

    if "memory" not in state:
        memory = create_enhanced_memory(
            state.session_id,
            user_id=state.user_id,
            enable_summarization=state.enable_summarization,
            summary_threshold=state.summary_threshold,
        )
        # A fresh load/reload starts from a clean temporary history.
        memory.clear_user_history()
        state.memory = memory

    if "indexed" not in state:
        state.indexed = False
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def render_sidebar():
    """Render the configuration sidebar.

    Sections, top to bottom:
      * User profile — changing the User ID behaves like "New Chat": a
        fresh session and temporary memory are created, while per-user
        permanent memory is left untouched.
      * Connect & Initialize — builds the chatbot from environment config.
      * Index Text Data — only shown once a database connection exists.
      * Status — connection state, table count, indexed-document count.
      * New Chat — clears the current session and starts a new one for
        the same user.

    Several actions call st.rerun(), so everything after a triggered
    button may not execute on that script run.
    """
    with st.sidebar:
        st.title("⚙️ Settings")

        # User Profile
        st.subheader("👤 User Profile")
        user_id = st.text_input(
            "User ID / Name",
            value=st.session_state.get("user_id", "default"),
            key="user_id_input",
            help="Your unique ID for private memory storage"
        )
        if user_id != st.session_state.get("user_id"):
            # USER ID CHANGE - Same behavior as "New Chat":
            # 1. Clear temporary memory (session history) for clean start
            # 2. Permanent memory remains UNTOUCHED (per-user storage)
            st.session_state.user_id = user_id
            st.session_state.session_id = str(uuid.uuid4())  # New session
            st.session_state.messages = []  # Clear UI chat history

            # Create memory for new user and clear their temp history (fresh start)
            st.session_state.memory = create_enhanced_memory(
                st.session_state.session_id,
                user_id=user_id,
                enable_summarization=st.session_state.enable_summarization,
                summary_threshold=st.session_state.summary_threshold
            )
            st.session_state.memory.clear_user_history()  # Clears _chatbot_memory, NOT _chatbot_permanent_memory_v2
            st.rerun()

        st.divider()

        # Initialize Button — connects to the DB and builds the chatbot
        if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
            with st.spinner("Connecting to database..."):
                success = initialize_chatbot()
                if success:
                    st.success("✅ Connected!")
                    st.rerun()

        # Index Button (only offered after a successful initialization)
        if st.session_state.initialized:
            if st.button("📚 Index Text Data", use_container_width=True):
                with st.spinner("Indexing text data..."):
                    index_data()
                    st.success("✅ Indexed!")
                    st.rerun()

        st.divider()

        # Status
        st.subheader("📊 Status")
        if st.session_state.initialized:
            st.success("Database: Connected")
            schema = get_schema()
            st.info(f"Tables: {len(schema.tables)}")

            if st.session_state.indexed:
                from rag import get_rag_engine
                engine = get_rag_engine()
                st.info(f"Indexed Docs: {engine.document_count}")
        else:
            st.warning("Not connected")

        # New Chat (Context Switch)
        if st.button("➕ New Chat", use_container_width=True, type="secondary"):
            # Clear previous session from DB
            if "memory" in st.session_state and st.session_state.memory:
                st.session_state.memory.clear()

            st.session_state.messages = []
            st.session_state.session_id = str(uuid.uuid4())  # Generate new session ID

            # Preserve current user ID and memory settings
            current_user = st.session_state.get("user_id", "default")
            st.session_state.memory = create_enhanced_memory(
                st.session_state.session_id,
                user_id=current_user,
                enable_summarization=st.session_state.enable_summarization,
                summary_threshold=st.session_state.summary_threshold
            )
            # Set LLM client if available (needed for memory summarization)
            if "llm" in st.session_state and st.session_state.llm:
                st.session_state.memory.set_llm_client(st.session_state.llm)
            st.rerun()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def initialize_chatbot() -> bool:
    """Build and initialize the chatbot from environment configuration.

    Reads GROQ_API_KEY / GROQ_MODEL from the environment, wires the LLM
    client into the chatbot (and, through set_llm_client, the router and
    SQL generator), then introspects the database. On success, stores the
    chatbot and LLM in session state and enables memory summarization.

    Returns:
        True when the chatbot connected and initialized; False otherwise
        (an st.error message is shown in that case).
    """
    try:
        api_key = os.getenv("GROQ_API_KEY", "")
        model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")

        if not api_key:
            st.error("GROQ_API_KEY not configured. Please set it in your .env file.")
            return False

        llm = create_llm_client("groq", api_key=api_key, model=model)

        bot = create_chatbot(llm)
        # Explicit call also configures the router and sql_generator.
        bot.set_llm_client(llm)

        ok, message = bot.initialize()
        if not ok:
            st.error(f"Initialization failed: {message}")
            return False

        st.session_state.chatbot = bot
        st.session_state.llm = llm  # Store LLM separately too
        st.session_state.initialized = True

        # Give the memory layer an LLM so it can summarize long histories.
        if hasattr(st.session_state.memory, 'set_llm_client'):
            st.session_state.memory.set_llm_client(llm)

        return True

    except Exception as e:
        st.error(f"Error: {str(e)}")
        return False
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def index_data():
    """Index text data from the database for RAG retrieval.

    Drives a progress bar while the chatbot walks every table; the
    callback fires once per table that was actually indexed. No-op when
    the chatbot has not been created yet.
    """
    if not st.session_state.chatbot:
        return

    progress = st.progress(0)
    status = st.empty()

    schema = get_schema()
    total_tables = len(schema.tables)
    indexed = 0

    def progress_callback(table_name, docs):
        # Count completed tables and reflect the fraction in the UI.
        nonlocal indexed
        indexed += 1
        # Guard against an empty schema (division by zero) and clamp to
        # 1.0 — st.progress rejects values outside [0.0, 1.0].
        fraction = min(indexed / total_tables, 1.0) if total_tables else 1.0
        progress.progress(fraction)
        status.text(f"Indexed {table_name}: {docs} documents")

    total_docs = st.session_state.chatbot.index_text_data(progress_callback)
    st.session_state.indexed = True
    status.text(f"Total: {total_docs} documents indexed")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def render_schema_explorer():
    """Show the discovered database schema inside a collapsible expander.

    Renders nothing until the chatbot has been initialized. Each table
    gets its row count and a one-line column summary.
    """
    if not st.session_state.initialized:
        return

    with st.expander("📋 Database Schema", expanded=False):
        schema = get_schema()

        for name, info in schema.tables.items():
            with st.container():
                st.markdown(f"**{name}** ({info.row_count or '?'} rows)")

                # One badge-decorated entry per column:
                # 🔑 = primary key, 📝 = text-typed (RAG-indexable).
                rendered = [
                    f"`{c.name}` {c.data_type} "
                    f"{'🔑' if c.is_primary_key else ''}{'📝' if c.is_text_type else ''}"
                    for c in info.columns
                ]
                st.caption(" | ".join(rendered))
                st.divider()
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def render_chat_interface():
    """Render the main chat interface.

    Replays the stored message history, then handles a new prompt: the
    user message is echoed and persisted, the chatbot produces a
    ChatResponse, and the answer (plus query type, SQL, and result rows
    when present) is displayed and saved to both the UI history and the
    conversation memory.
    """
    st.title("🤖 Database Copilot")
    st.caption("Schema-agnostic chatbot powered by Groq (FREE!)")

    # Schema explorer (no-op until initialized)
    render_schema_explorer()

    # Chat container
    chat_container = st.container()

    with chat_container:
        # Replay the persisted message history on every rerun
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])

                # Show metadata for assistant messages
                if msg["role"] == "assistant" and "metadata" in msg:
                    meta = msg["metadata"]
                    if meta.get("query_type"):
                        st.caption(f"Query type: {meta['query_type']}")
                    if meta.get("sql_query"):
                        with st.expander("SQL Query"):
                            st.code(meta["sql_query"], language="sql")

    # Chat input (walrus: prompt is truthy only when the user submitted text)
    if prompt := st.chat_input("Ask about your data..."):
        if not st.session_state.initialized:
            st.error("Please connect to a database first!")
            return

        # Record the user message in both UI history and memory
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.memory.add_message("user", prompt)

        with st.chat_message("user"):
            st.markdown(prompt)

        # Get response from the chatbot (memory supplies conversation context)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = st.session_state.chatbot.chat(
                    prompt,
                    st.session_state.memory
                )

            st.markdown(response.answer)

            # Show metadata (suppressed for plain "general" answers)
            if response.query_type != "general":
                st.caption(f"Query type: {response.query_type}")

            if response.sql_query:
                with st.expander("SQL Query"):
                    st.code(response.sql_query, language="sql")

            if response.sql_results:
                with st.expander("Results"):
                    st.dataframe(response.sql_results)

        # Persist the assistant turn (metadata lets the replay loop above
        # re-render the query type and SQL on later reruns)
        st.session_state.messages.append({
            "role": "assistant",
            "content": response.answer,
            "metadata": {
                "query_type": response.query_type,
                "sql_query": response.sql_query
            }
        })
        st.session_state.memory.add_message("assistant", response.answer)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def main():
    """Main application entry point.

    Order matters: session state must be initialized before the sidebar
    and chat interface read from it.
    """
    init_session_state()
    render_sidebar()
    render_chat_interface()


if __name__ == "__main__":
    main()
|
chatbot.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chatbot Core - Main orchestrator for the schema-agnostic database chatbot.
|
| 3 |
+
|
| 4 |
+
Combines all components:
|
| 5 |
+
- Schema introspection
|
| 6 |
+
- Query routing
|
| 7 |
+
- RAG retrieval
|
| 8 |
+
- SQL generation & execution
|
| 9 |
+
- Response generation
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
|
| 16 |
+
from database import get_db, get_schema, get_introspector
|
| 17 |
+
from rag import get_rag_engine
|
| 18 |
+
from sql import get_sql_generator, get_sql_validator
|
| 19 |
+
from llm import create_llm_client, LLMClient
|
| 20 |
+
from router import get_query_router, QueryType
|
| 21 |
+
from memory import ChatMemory, EnhancedChatMemory, create_memory
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class ChatResponse:
    """Response from the chatbot.

    Carries the natural-language answer plus optional provenance:
    the routed query type, any generated SQL and its result rows,
    RAG source documents, and an error description on failure.
    """
    answer: str
    query_type: str
    # default_factory gives every instance its own list (the idiomatic
    # dataclass replacement for a mutable default).
    sources: List[Dict[str, Any]] = field(default_factory=list)
    sql_query: Optional[str] = None
    sql_results: Optional[List[Dict]] = None
    error: Optional[str] = None

    def __post_init__(self):
        # Preserve the original contract: callers passing sources=None
        # explicitly still end up with an empty list.
        if self.sources is None:
            self.sources = []
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class DatabaseChatbot:
|
| 42 |
+
"""Main chatbot class orchestrating all components."""
|
| 43 |
+
|
| 44 |
+
RESPONSE_PROMPT = """You are a helpful database assistant. Answer the user's question based on the provided context.
|
| 45 |
+
|
| 46 |
+
IMPORTANT: Use the conversation history to understand follow-up questions. If the user refers to "it", "that", "the product", etc., look at the previous messages to understand what they're referring to.
|
| 47 |
+
|
| 48 |
+
{context}
|
| 49 |
+
|
| 50 |
+
USER QUESTION: {question}
|
| 51 |
+
|
| 52 |
+
INSTRUCTIONS:
|
| 53 |
+
- Answer ONLY based on the provided context AND conversation history
|
| 54 |
+
- Do NOT use outside knowledge, general assumptions, or hallucinate facts
|
| 55 |
+
- If the context doesn't contain the answer, explicitly state that the information is not available in the database
|
| 56 |
+
- Resolve pronouns using previous messages
|
| 57 |
+
- Be concise but complete
|
| 58 |
+
- Format data nicely
|
| 59 |
+
|
| 60 |
+
YOUR RESPONSE:"""
|
| 61 |
+
|
| 62 |
+
def __init__(self, llm_client: Optional[LLMClient] = None):
    """Wire together the chatbot's components (all fetched via their
    module-level singleton accessors).

    Args:
        llm_client: Optional LLM client. Note this constructor only
            stores it — call set_llm_client() to also propagate it to
            the router and SQL generator (app.py does exactly that).
    """
    self.db = get_db()
    self.introspector = get_introspector()
    self.rag_engine = get_rag_engine()
    # Pass database type to SQL generator for dialect-specific SQL
    db_type = self.db.db_type.value
    self.sql_generator = get_sql_generator(db_type)
    self.sql_validator = get_sql_validator()
    self.router = get_query_router()
    self.llm_client = llm_client

    # Lifecycle flags: set by initialize() and index_text_data() respectively.
    self._schema_initialized = False
    self._rag_initialized = False
|
| 75 |
+
|
| 76 |
+
def set_llm_client(self, llm_client: LLMClient):
    """Attach an LLM client to this chatbot and every sub-component that
    runs its own prompts (SQL generator and query router)."""
    self.llm_client = llm_client
    for component in (self.sql_generator, self.router):
        component.set_llm_client(llm_client)
|
| 81 |
+
|
| 82 |
+
def initialize(self) -> Tuple[bool, str]:
    """Connect to the database, introspect its schema, and arm the SQL
    validator with the discovered table names.

    Returns:
        (ok, message): ok is True when initialization succeeded; message
        is a human-readable summary or error description.
    """
    try:
        connected, detail = self.db.test_connection()
        if not connected:
            return False, f"Database connection failed: {detail}"

        # A forced refresh guarantees we see the schema as it is right now.
        discovered = self.introspector.introspect(force_refresh=True)

        # Only tables found during introspection may appear in generated SQL.
        self.sql_validator.set_allowed_tables(discovered.table_names)

        self._schema_initialized = True
        return True, f"Initialized with {len(discovered.tables)} tables"

    except Exception as e:
        logger.error(f"Initialization failed: {e}")
        return False, str(e)
|
| 103 |
+
|
| 104 |
+
def index_text_data(self, progress_callback=None) -> int:
    """Index all text data for RAG.

    Walks every discovered table, selects its text columns (plus the
    first primary key when one exists, used as the document id), and
    feeds up to 1000 rows per table into the RAG engine. Tables with no
    text columns are skipped; per-table failures are logged and skipped
    rather than aborting the whole pass.

    Args:
        progress_callback: Optional callable(table_name, doc_count),
            invoked once per successfully indexed table.

    Returns:
        Total number of documents indexed across all tables.

    Raises:
        RuntimeError: If initialize() has not been called yet.
    """
    if not self._schema_initialized:
        raise RuntimeError("Chatbot not initialized. Call initialize() first.")

    schema = get_schema()
    total_docs = 0

    for table_name, table_info in schema.tables.items():
        text_cols = [c.name for c in table_info.text_columns]
        if not text_cols:
            continue

        pk = table_info.primary_keys[0] if table_info.primary_keys else None
        cols_to_select = text_cols + ([pk] if pk else [])

        # NOTE(review): identifiers are interpolated directly; they come from
        # schema introspection (not user input), but unquoted names could still
        # break on reserved words or exotic identifiers — confirm acceptable.
        query = f"SELECT {', '.join(cols_to_select)} FROM {table_name} LIMIT 1000"

        try:
            rows = self.db.execute_query(query)
            docs = self.rag_engine.index_table(table_name, rows, text_cols, pk)
            total_docs += docs

            if progress_callback:
                progress_callback(table_name, docs)

        except Exception as e:
            # Best-effort indexing: one bad table must not sink the rest.
            logger.warning(f"Failed to index {table_name}: {e}")

    # Persist the vector index, then mark RAG as ready.
    self.rag_engine.save()
    self._rag_initialized = True

    return total_docs
|
| 137 |
+
|
| 138 |
+
def chat(self, query: str, memory: Optional[ChatMemory] = None) -> ChatResponse:
    """Process a user query and return a response.

    Flow:
      1. Guard: fail fast if the schema or LLM client is missing.
      2. Detect explicit "save to memory" commands and handle them locally
         (no LLM round-trip).
      3. Otherwise route the query (RAG / SQL / HYBRID / general) and
         dispatch to the matching handler.

    Args:
        query: raw user input.
        memory: optional ChatMemory holding session history and permanent
            context; required for memory commands to take effect.

    Returns:
        ChatResponse with the answer and query_type; on failure,
        query_type == "error" and `error` is set.
    """
    if not self._schema_initialized:
        return ChatResponse(answer="Chatbot not initialized.", query_type="error",
                            error="Call initialize() first")

    if not self.llm_client:
        return ChatResponse(answer="LLM not configured.", query_type="error",
                            error="Configure LLM client first")

    try:
        schema = get_schema()
        schema_context = schema.to_context_string()

        # Check for memory commands using regex for flexibility.
        # Capture group 1 is the optional content to persist
        # (e.g. "Remember that I like pizza" -> "I like pizza").
        import re
        save_pattern = re.compile(r"(?:please\s+)?(?:save|remember|memorize)\s+(?:this|that)?\s*(?:to\s+(?:main\s+)?memory)?\s*(?:that)?\s*:?\s*(.*)", re.IGNORECASE)
        match = save_pattern.match(query.strip())

        # Check if it looks like a command (starts with command words).
        # The extra "saved to" check covers phrasing like
        # "saved to main memory" that doesn't start with a verb.
        is_command = bool(match) and (
            query.lower().startswith(("save", "remember", "memorize")) or
            "saved to" in query.lower()  # specific user case "saved to main memory"
        )

        if is_command and memory:
            content_to_save = match.group(1).strip() if match else ""

            # If specific content is provided (e.g. "Remember that I like pizza")
            if content_to_save:
                # Save the explicit content
                success = memory.save_permanent_context(content_to_save)
                if success:
                    return ChatResponse(answer=f"💾 I've saved to your permanent memory: '{content_to_save}'", query_type="memory")
                else:
                    return ChatResponse(answer="❌ Failed to save to permanent memory. Please try again.", query_type="memory")

            # If no content (e.g. "Save this"), save the previous conversation turn
            elif len(memory.messages) >= 2:
                # [-1] is current command ("save to memory")
                # [-2] is previous assistant response
                # [-3] is previous user query (context for the response)

                msgs_to_save = []
                # We try to grab the last QA pair: User Prompt + AI Response
                # memory.messages structure: [User, AI, User, AI, User(current)]

                if len(memory.messages) >= 3:
                    msg_user = memory.messages[-3]
                    msg_ai = memory.messages[-2]

                    # Verify roles to ensure we are saving a Q&A pair
                    if msg_user.role == "user" and msg_ai.role == "assistant":
                        msgs_to_save = [msg_user, msg_ai]

                if msgs_to_save:
                    # Format: "User: ... | Assistant: ..."
                    context_str = f"User: {msgs_to_save[0].content} | Assistant: {msgs_to_save[1].content}"
                    success = memory.save_permanent_context(context_str)
                    if success:
                        return ChatResponse(answer="💾 I've saved our last exchange to your permanent memory.", query_type="memory")
                    else:
                        return ChatResponse(answer="❌ Failed to save to permanent memory.", query_type="memory")
                else:
                    return ChatResponse(answer="⚠️ I couldn't find a clear previous exchange to save. Try saying 'Remember that [fact]'.", query_type="memory")
            else:
                return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")

        # Route the query (RAG vs SQL vs HYBRID vs general chat).
        routing = self.router.route(query, schema_context)

        # Get chat history for context (last 5 turns).
        history = memory.get_context_messages(5) if memory else []

        # Process based on route
        if routing.query_type == QueryType.RAG:
            return self._handle_rag(query, history)
        elif routing.query_type == QueryType.SQL:
            return self._handle_sql(query, schema_context, history)
        elif routing.query_type == QueryType.HYBRID:
            return self._handle_hybrid(query, schema_context, history)
        else:
            return self._handle_general(query, history)

    except Exception as e:
        # Catch-all boundary: surface the error to the UI instead of crashing.
        logger.error(f"Chat error: {e}")
        return ChatResponse(answer=f"Error: {str(e)}", query_type="error", error=str(e))
|
| 227 |
+
|
| 228 |
+
def _handle_rag(self, query: str, history: List[Dict]) -> ChatResponse:
    """Answer a query purely via semantic (vector) retrieval."""
    retrieved = self.rag_engine.get_context(query, top_k=5)

    rendered = self.RESPONSE_PROMPT.format(
        context=f"RELEVANT DATA:\n{retrieved}", question=query
    )
    convo = self._construct_messages(
        "You are a helpful database assistant.",
        history,
        rendered,
    )

    reply = self.llm_client.chat(convo)

    # Surface a truncated slice of the retrieved context as the source.
    return ChatResponse(
        answer=reply,
        query_type="rag",
        sources=[{"type": "semantic_search", "context": retrieved[:500]}],
    )
|
| 244 |
+
|
| 245 |
+
def _handle_sql(self, query: str, schema_context: str, history: List[Dict]) -> ChatResponse:
    """Handle SQL-based query.

    Pipeline: generate SQL from the question -> validate/sanitize it ->
    execute -> (if empty) fall back to RAG -> summarize results via LLM.
    """
    # `explanation` is currently unused; kept for the generator's contract.
    sql, explanation = self.sql_generator.generate(query, schema_context, history)

    # Validate SQL (whitelist check + sanitization) before touching the DB.
    is_valid, msg, sanitized_sql = self.sql_validator.validate(sql)
    if not is_valid:
        return ChatResponse(answer=f"Could not generate safe query: {msg}",
                            query_type="sql", error=msg)

    # Execute query
    try:
        results = self.db.execute_query(sanitized_sql)
    except Exception as e:
        return ChatResponse(answer=f"Query execution failed: {e}",
                            query_type="sql", sql_query=sanitized_sql, error=str(e))

    # SMART FALLBACK: If SQL returns nothing, it might be a semantic issue (e.g. wrong column)
    # We try RAG as a fallback if SQL found nothing
    if not results:
        logger.info(f"SQL returned no results for query: '{query}'. Falling back to RAG.")
        rag_response = self._handle_rag(query, history)

        # Combine the info: "I couldn't find an exact match in the rows, but here is what I found semantically:"
        # NOTE(review): the message mentions "product descriptions" — looks
        # domain-specific for a schema-agnostic bot; confirm wording.
        rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the product descriptions:\n\n{rag_response.answer}"
        rag_response.query_type = "hybrid_fallback"
        rag_response.sql_query = sanitized_sql
        return rag_response

    # Generate response: feed query + formatted rows back to the LLM.
    context = f"SQL QUERY:\n{sanitized_sql}\n\nRESULTS:\n{self._format_results(results)}"
    prompt = self.RESPONSE_PROMPT.format(context=context, question=query)

    messages = self._construct_messages(
        "You are a helpful database assistant.",
        history,
        prompt
    )

    answer = self.llm_client.chat(messages)

    # Only the first 10 raw rows are attached to keep the response small.
    return ChatResponse(answer=answer, query_type="sql",
                        sql_query=sanitized_sql, sql_results=results[:10])
|
| 288 |
+
|
| 289 |
+
def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict]) -> ChatResponse:
    """Combine semantic retrieval with a best-effort SQL query.

    RAG context is always gathered; the SQL half is optional — any failure
    there is logged at debug level and silently omitted from the prompt.
    """
    semantic_hits = self.rag_engine.get_context(query, top_k=3)

    structured_part = ""
    executed_sql = None
    try:
        candidate_sql, _ = self.sql_generator.generate(query, schema_context, history)
        ok, _, sanitized_sql = self.sql_validator.validate(candidate_sql)
        if ok:
            rows = self.db.execute_query(sanitized_sql)
            structured_part = f"\nSQL RESULTS:\n{self._format_results(rows)}"
            executed_sql = sanitized_sql
    except Exception as e:
        logger.debug(f"SQL part of hybrid failed: {e}")

    combined = f"SEMANTIC SEARCH RESULTS:\n{semantic_hits}{structured_part}"
    reply = self.llm_client.chat(
        self._construct_messages(
            "You are a helpful database assistant.",
            history,
            self.RESPONSE_PROMPT.format(context=combined, question=query),
        )
    )

    return ChatResponse(answer=reply, query_type="hybrid", sql_query=executed_sql)
|
| 319 |
+
|
| 320 |
+
def _construct_messages(self, system_instruction: str, history: List[Dict], user_content: str) -> List[Dict]:
|
| 321 |
+
"""Construct message list, merging system messages from history."""
|
| 322 |
+
# Check if first history item is a system message (from memory)
|
| 323 |
+
additional_context = ""
|
| 324 |
+
filtered_history = []
|
| 325 |
+
|
| 326 |
+
for msg in history:
|
| 327 |
+
if msg.get("role") == "system":
|
| 328 |
+
additional_context += f"\n\n{msg.get('content')}"
|
| 329 |
+
else:
|
| 330 |
+
filtered_history.append(msg)
|
| 331 |
+
|
| 332 |
+
full_system_prompt = f"{system_instruction}{additional_context}"
|
| 333 |
+
|
| 334 |
+
messages = [{"role": "system", "content": full_system_prompt}]
|
| 335 |
+
messages.extend(filtered_history)
|
| 336 |
+
messages.append({"role": "user", "content": user_content})
|
| 337 |
+
|
| 338 |
+
return messages
|
| 339 |
+
|
| 340 |
+
def _handle_general(self, query: str, history: List[Dict]) -> ChatResponse:
    """Handle plain conversation with no retrieval or SQL.

    A strict guardrail system prompt keeps the model grounded in the
    conversation history and prevents hallucinated facts.
    """
    # Guardrail text is intentionally byte-identical to the established
    # prompt so LLM behavior is unchanged.
    guardrails = (
        "You are a helpful database assistant.\n"
        "INSTRUCTIONS:\n"
        "- Answer ONLY based on the conversation history and any context provided within it.\n"
        "- Do NOT use outside knowledge, general assumptions, or hallucinate facts.\n"
        "- If the answer is not in the history or context, state that you don't have that information.\n"
        "- Be concise."
    )

    reply = self.llm_client.chat(
        self._construct_messages(guardrails, history, query)
    )
    return ChatResponse(answer=reply, query_type="general")
|
| 359 |
+
|
| 360 |
+
def _format_results(self, results: List[Dict], max_rows: int = 10) -> str:
|
| 361 |
+
"""Format SQL results for display."""
|
| 362 |
+
if not results:
|
| 363 |
+
return "No results found."
|
| 364 |
+
|
| 365 |
+
rows = results[:max_rows]
|
| 366 |
+
lines = []
|
| 367 |
+
|
| 368 |
+
# Header
|
| 369 |
+
headers = list(rows[0].keys())
|
| 370 |
+
lines.append(" | ".join(headers))
|
| 371 |
+
lines.append("-" * len(lines[0]))
|
| 372 |
+
|
| 373 |
+
# Rows
|
| 374 |
+
for row in rows:
|
| 375 |
+
values = [str(v)[:50] for v in row.values()]
|
| 376 |
+
lines.append(" | ".join(values))
|
| 377 |
+
|
| 378 |
+
if len(results) > max_rows:
|
| 379 |
+
lines.append(f"... and {len(results) - max_rows} more rows")
|
| 380 |
+
|
| 381 |
+
return "\n".join(lines)
|
| 382 |
+
|
| 383 |
+
def get_schema_summary(self) -> str:
    """Get a summary of the database schema.

    Returns "Schema not loaded." until initialize() has succeeded;
    afterwards, the introspected schema's LLM-ready context string.
    """
    if not self._schema_initialized:
        return "Schema not loaded."
    return get_schema().to_context_string()
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def create_chatbot(llm_client: Optional[LLMClient] = None) -> DatabaseChatbot:
    """Factory: build a DatabaseChatbot, optionally pre-wired with an LLM client."""
    return DatabaseChatbot(llm_client)
|
config.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration module for the Schema-Agnostic Database Chatbot.
|
| 3 |
+
|
| 4 |
+
This module handles all configuration including:
|
| 5 |
+
- Database connection settings (MySQL, PostgreSQL, SQLite)
|
| 6 |
+
- LLM provider settings (Groq / OpenAI / Local LLaMA)
|
| 7 |
+
- Embedding model configuration
|
| 8 |
+
- Security settings
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import Optional, List
|
| 15 |
+
from enum import Enum
|
| 16 |
+
|
| 17 |
+
# Load .env file BEFORE any os.getenv calls
|
| 18 |
+
from dotenv import load_dotenv
|
| 19 |
+
env_path = Path(__file__).parent / ".env"
|
| 20 |
+
load_dotenv(env_path)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DatabaseType(Enum):
    """Supported database types.

    Values match the lowercase DB_TYPE environment variable
    (see DatabaseConfig).
    """
    MYSQL = "mysql"            # connected via pymysql
    POSTGRESQL = "postgresql"  # connected via psycopg2
    SQLITE = "sqlite"          # file-based, path from SQLITE_PATH
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LLMProvider(Enum):
    """Supported LLM providers.

    Values match the LLM_PROVIDER environment variable (see LLMConfig).
    """
    GROQ = "groq"  # FREE!
    OPENAI = "openai"
    LOCAL_LLAMA = "local_llama"  # model loaded from LOCAL_MODEL_PATH
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class EmbeddingProvider(Enum):
    """Supported embedding providers for RAG vectorization.

    Values match the EMBEDDING_PROVIDER environment variable
    (see EmbeddingConfig).
    """
    OPENAI = "openai"
    SENTENCE_TRANSFORMERS = "sentence_transformers"  # local, no API key needed
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
class DatabaseConfig:
    """
    Database configuration supporting MySQL, PostgreSQL, and SQLite.

    All sensitive values are loaded from environment variables at
    instantiation time (DB_* names, with legacy MYSQL_* fallbacks).
    """
    # Database type (mysql, postgresql, sqlite)
    db_type: DatabaseType = field(
        default_factory=lambda: DatabaseType(os.getenv("DB_TYPE", "mysql").lower())
    )

    # Common connection settings (for MySQL/PostgreSQL)
    # NOTE(review): the port default "3306" also applies when DB_TYPE is
    # postgresql (whose conventional port is 5432) — set DB_PORT explicitly.
    host: str = field(default_factory=lambda: os.getenv("DB_HOST", os.getenv("MYSQL_HOST", "")))
    port: int = field(default_factory=lambda: int(os.getenv("DB_PORT", os.getenv("MYSQL_PORT", "3306"))))
    database: str = field(default_factory=lambda: os.getenv("DB_DATABASE", os.getenv("MYSQL_DATABASE", "")))
    username: str = field(default_factory=lambda: os.getenv("DB_USERNAME", os.getenv("MYSQL_USERNAME", "")))
    password: str = field(default_factory=lambda: os.getenv("DB_PASSWORD", os.getenv("MYSQL_PASSWORD", "")))

    # SSL configuration (path to CA certificate, e.g. for Aiven)
    ssl_ca: Optional[str] = field(default_factory=lambda: os.getenv("DB_SSL_CA", os.getenv("MYSQL_SSL_CA", None)))

    # SQLite-specific: path to database file
    sqlite_path: str = field(default_factory=lambda: os.getenv("SQLITE_PATH", "./chatbot.db"))

    @property
    def connection_string(self) -> str:
        """Generate SQLAlchemy connection string based on database type.

        Username and password are URL-encoded so credentials containing
        reserved characters ('@', ':', '/', '%', ...) do not corrupt the
        URL — per SQLAlchemy's guidance on escaping passwords in URLs.
        """
        if self.db_type == DatabaseType.SQLITE:
            # SQLite uses a plain file path; no credentials involved.
            return f"sqlite:///{self.sqlite_path}"

        # Bug fix: previously raw credentials were interpolated into the
        # URL, which broke for passwords with special characters.
        from urllib.parse import quote_plus
        user = quote_plus(self.username)
        pwd = quote_plus(self.password)

        if self.db_type == DatabaseType.POSTGRESQL:
            # PostgreSQL connection string
            base_url = f"postgresql+psycopg2://{user}:{pwd}@{self.host}:{self.port}/{self.database}"
            if self.ssl_ca:
                return f"{base_url}?sslmode=verify-full&sslrootcert={self.ssl_ca}"
            return base_url

        else:  # MySQL (default)
            # MySQL connection string
            base_url = f"mysql+pymysql://{user}:{pwd}@{self.host}:{self.port}/{self.database}"
            if self.ssl_ca:
                return f"{base_url}?ssl_ca={self.ssl_ca}"
            return base_url

    def is_configured(self) -> bool:
        """Check if all required database settings are configured."""
        if self.db_type == DatabaseType.SQLITE:
            # SQLite only needs a valid path
            return bool(self.sqlite_path)
        else:
            # MySQL/PostgreSQL need host, database, username, password
            return all([self.host, self.database, self.username, self.password])

    @property
    def is_mysql(self) -> bool:
        """Check if using MySQL."""
        return self.db_type == DatabaseType.MYSQL

    @property
    def is_postgresql(self) -> bool:
        """Check if using PostgreSQL."""
        return self.db_type == DatabaseType.POSTGRESQL

    @property
    def is_sqlite(self) -> bool:
        """Check if using SQLite."""
        return self.db_type == DatabaseType.SQLITE
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@dataclass
class LLMConfig:
    """LLM configuration for query routing and response generation.

    Supports Groq, OpenAI, and a local LLaMA model; the active provider
    is chosen via the LLM_PROVIDER environment variable.
    """
    provider: LLMProvider = field(
        default_factory=lambda: LLMProvider(os.getenv("LLM_PROVIDER", "openai"))
    )
    openai_api_key: str = field(default_factory=lambda: os.getenv("OPENAI_API_KEY", ""))
    openai_model: str = field(default_factory=lambda: os.getenv("OPENAI_MODEL", "gpt-4o-mini"))

    # Groq settings (the provider enum lists GROQ, but no key was read
    # before — is_configured() wrongly fell back to the local-model check)
    groq_api_key: str = field(default_factory=lambda: os.getenv("GROQ_API_KEY", ""))

    # Local LLaMA settings
    local_model_path: str = field(
        default_factory=lambda: os.getenv("LOCAL_MODEL_PATH", "")
    )
    local_model_name: str = field(
        default_factory=lambda: os.getenv("LOCAL_MODEL_NAME", "llama-2-7b-chat")
    )

    # Generation parameters
    temperature: float = 0.1  # Low temperature for more deterministic outputs
    max_tokens: int = 1024

    def is_configured(self) -> bool:
        """Check if LLM is properly configured for the selected provider."""
        if self.provider == LLMProvider.OPENAI:
            return bool(self.openai_api_key)
        if self.provider == LLMProvider.GROQ:
            # Bug fix: GROQ previously validated local_model_path instead
            # of its API key.
            return bool(self.groq_api_key)
        return bool(self.local_model_path)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass
class EmbeddingConfig:
    """Embedding model configuration for RAG.

    Defaults to local sentence-transformers (no API key required);
    EMBEDDING_PROVIDER / EMBEDDING_MODEL env vars override.
    """
    provider: EmbeddingProvider = field(
        default_factory=lambda: EmbeddingProvider(
            os.getenv("EMBEDDING_PROVIDER", "sentence_transformers")
        )
    )

    # OpenAI embedding settings
    openai_embedding_model: str = "text-embedding-3-small"

    # Sentence Transformers settings
    st_model_name: str = field(
        default_factory=lambda: os.getenv(
            "EMBEDDING_MODEL",
            "sentence-transformers/all-MiniLM-L6-v2"
        )
    )

    # Embedding dimensions (varies by model)
    # NOTE(review): if EMBEDDING_MODEL is overridden, this dim is NOT
    # updated automatically — confirm it matches the chosen model.
    embedding_dim: int = 384  # Default for all-MiniLM-L6-v2
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclass
class SecurityConfig:
    """Security settings for SQL validation and execution.

    Consumed by the SQL validator: only whitelisted operations run, any
    forbidden keyword rejects the query, and result sizes are capped.
    """

    # SQL operations whitelist - ONLY SELECT allowed
    allowed_operations: List[str] = field(default_factory=lambda: ["SELECT"])

    # Dangerous keywords that should never appear in queries
    # (DML/DDL, privilege changes, and file-access primitives)
    forbidden_keywords: List[str] = field(default_factory=lambda: [
        "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
        "TRUNCATE", "GRANT", "REVOKE", "EXECUTE", "EXEC",
        "INTO OUTFILE", "INTO DUMPFILE", "LOAD_FILE",
        "INFORMATION_SCHEMA.USER_PRIVILEGES"
    ])

    # Maximum number of rows to return
    max_result_rows: int = 100

    # Default LIMIT clause if not specified
    default_limit: int = 50
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@dataclass
class RAGConfig:
    """RAG (Retrieval-Augmented Generation) configuration."""

    # FAISS index settings (path where the vector index is persisted)
    faiss_index_path: str = "./faiss_index"

    # Number of top results to retrieve
    top_k: int = 5

    # Minimum similarity score for relevance
    similarity_threshold: float = 0.3

    # Text columns to consider for RAG (common across database types)
    text_column_types: List[str] = field(default_factory=lambda: [
        # MySQL types
        "TEXT", "MEDIUMTEXT", "LONGTEXT", "TINYTEXT", "VARCHAR", "CHAR",
        # PostgreSQL types
        "CHARACTER VARYING", "CHARACTER",
        # SQLite types (SQLite is flexible but these are common)
        "CLOB", "NVARCHAR", "NCHAR"
    ])

    # Minimum character length to consider a column for RAG
    min_text_length: int = 50

    # Chunk size for long text documents (characters per chunk, with overlap
    # so sentences straddling a boundary still appear intact in one chunk)
    chunk_size: int = 500
    chunk_overlap: int = 50
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@dataclass
class ChatConfig:
    """Chat and memory configuration."""

    # Short-term memory (in session): max turns kept before trimming
    max_session_messages: int = 20

    # Long-term memory table name (will be created if not exists)
    memory_table_name: str = "_chatbot_memory"

    # Number of recent messages to include in context sent to the LLM
    context_messages: int = 5
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class AppConfig:
    """
    Main application configuration aggregator.

    Instantiates every configuration section (each of which reads its own
    environment variables) and offers a single validate() entry point.
    """

    def __init__(self):
        self.database = DatabaseConfig()
        self.llm = LLMConfig()
        self.embedding = EmbeddingConfig()
        self.security = SecurityConfig()
        self.rag = RAGConfig()
        self.chat = ChatConfig()

    def validate(self) -> tuple[bool, List[str]]:
        """
        Validate all configuration settings.

        Returns:
            tuple: (is_valid, list of error messages)
        """
        problems: List[str] = []

        if not self.database.is_configured():
            if self.database.is_sqlite:
                problems.append("SQLite configuration incomplete. Check SQLITE_PATH environment variable.")
            else:
                label = self.database.db_type.value.upper()
                problems.append(f"{label} configuration incomplete. Check DB_* environment variables.")

        if not self.llm.is_configured():
            problems.append(
                f"LLM configuration incomplete for provider: {self.llm.provider.value}. "
                "Check API keys or model paths."
            )

        return not problems, problems

    @classmethod
    def from_env(cls) -> "AppConfig":
        """Create configuration from environment variables."""
        return cls()
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# Global configuration instance
|
| 280 |
+
config = AppConfig.from_env()
|
database/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database module for the Schema-Agnostic Chatbot.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- Database connection management
|
| 6 |
+
- Dynamic schema introspection
|
| 7 |
+
- Safe query execution
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from .connection import DatabaseConnection, get_db, db_connection
|
| 11 |
+
from .schema_introspector import (
|
| 12 |
+
SchemaIntrospector,
|
| 13 |
+
SchemaInfo,
|
| 14 |
+
TableInfo,
|
| 15 |
+
ColumnInfo,
|
| 16 |
+
get_introspector,
|
| 17 |
+
get_schema
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"DatabaseConnection",
|
| 22 |
+
"get_db",
|
| 23 |
+
"db_connection",
|
| 24 |
+
"SchemaIntrospector",
|
| 25 |
+
"SchemaInfo",
|
| 26 |
+
"TableInfo",
|
| 27 |
+
"ColumnInfo",
|
| 28 |
+
"get_introspector",
|
| 29 |
+
"get_schema"
|
| 30 |
+
]
|
database/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (786 Bytes). View file
|
|
|
database/__pycache__/connection.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
database/__pycache__/schema_introspector.cpython-311.pyc
ADDED
|
Binary file (31.2 kB). View file
|
|
|
database/connection.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database Connection Module - Multi-Database Support.
|
| 3 |
+
|
| 4 |
+
This module provides:
|
| 5 |
+
- SQLAlchemy engine and session management for MySQL, PostgreSQL, and SQLite
|
| 6 |
+
- Connection pooling (for MySQL/PostgreSQL)
|
| 7 |
+
- SSL/TLS support
|
| 8 |
+
- Connection health checking
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from contextlib import contextmanager
|
| 13 |
+
from typing import Optional, Generator
|
| 14 |
+
from sqlalchemy import create_engine, text, event
|
| 15 |
+
from sqlalchemy.engine import Engine
|
| 16 |
+
from sqlalchemy.orm import sessionmaker, Session
|
| 17 |
+
from sqlalchemy.pool import QueuePool, StaticPool
|
| 18 |
+
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
| 19 |
+
|
| 20 |
+
import sys
|
| 21 |
+
import os
|
| 22 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 23 |
+
from config import DatabaseConfig, DatabaseType, config
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DatabaseConnection:
|
| 29 |
+
"""
|
| 30 |
+
Manages database connections with connection pooling.
|
| 31 |
+
|
| 32 |
+
Supports MySQL, PostgreSQL, and SQLite.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, db_config: Optional[DatabaseConfig] = None):
    """
    Initialize database connection manager.

    Args:
        db_config: Database configuration. Uses global config if not provided.
    """
    self.config = db_config or config.database
    # Engine and session factory are created lazily on first property access.
    self._engine: Optional[Engine] = None
    self._session_factory: Optional[sessionmaker] = None
|
| 45 |
+
|
| 46 |
+
def _create_engine(self) -> Engine:
    """
    Create SQLAlchemy engine with appropriate settings for each database type.

    SQLite gets a StaticPool (single shared connection, safe across
    threads via check_same_thread=False); MySQL and PostgreSQL get a
    QueuePool with pre-ping health checks and periodic recycling.

    Returns:
        Configured SQLAlchemy Engine instance
    """
    connect_args = {}

    if self.config.db_type == DatabaseType.SQLITE:
        # SQLite-specific settings
        # Use StaticPool for SQLite to handle multi-threading
        connect_args["check_same_thread"] = False

        engine = create_engine(
            self.config.connection_string,
            poolclass=StaticPool,  # SQLite works best with StaticPool
            connect_args=connect_args,
            echo=False
        )

        # Enable foreign keys for SQLite (they are OFF by default on
        # every new connection, so set the pragma on each connect event)
        @event.listens_for(engine, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            cursor.execute("PRAGMA foreign_keys=ON")
            cursor.close()

    elif self.config.db_type == DatabaseType.POSTGRESQL:
        # PostgreSQL-specific settings
        if self.config.ssl_ca:
            connect_args["sslmode"] = "verify-full"
            connect_args["sslrootcert"] = self.config.ssl_ca

        engine = create_engine(
            self.config.connection_string,
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
            pool_recycle=1800,   # recycle before typical server-side idle timeouts
            pool_pre_ping=True,  # validate connections before handing them out
            connect_args=connect_args,
            echo=False
        )

    else:  # MySQL (default)
        # MySQL-specific settings (SSL for Aiven)
        if self.config.ssl_ca:
            # NOTE(review): PyMySQL usually expects `verify_mode` to be an
            # ssl.CERT_* constant; passing True may not enforce certificate
            # verification as intended — confirm against PyMySQL docs.
            connect_args["ssl"] = {
                "ca": self.config.ssl_ca,
                "check_hostname": True,
                "verify_mode": True
            }

        engine = create_engine(
            self.config.connection_string,
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
            pool_recycle=1800,
            pool_pre_ping=True,
            connect_args=connect_args,
            echo=False
        )

    return engine
|
| 114 |
+
|
| 115 |
+
@property
def engine(self) -> Engine:
    """Get or create the SQLAlchemy engine (built lazily, then cached)."""
    if self._engine is None:
        self._engine = self._create_engine()
    return self._engine
|
| 121 |
+
|
| 122 |
+
@property
def session_factory(self) -> sessionmaker:
    """Get or create the session factory bound to the lazy engine."""
    if self._session_factory is None:
        self._session_factory = sessionmaker(
            bind=self.engine,
            autocommit=False,  # commits are issued explicitly in get_session()
            autoflush=False
        )
    return self._session_factory
|
| 132 |
+
|
| 133 |
+
    @property
    def db_type(self) -> DatabaseType:
        """Get the current database type (delegates to the active config)."""
        return self.config.db_type
|
| 137 |
+
|
| 138 |
+
@contextmanager
|
| 139 |
+
def get_session(self) -> Generator[Session, None, None]:
|
| 140 |
+
"""
|
| 141 |
+
Context manager for database sessions.
|
| 142 |
+
|
| 143 |
+
Yields:
|
| 144 |
+
SQLAlchemy Session instance
|
| 145 |
+
|
| 146 |
+
Example:
|
| 147 |
+
with db.get_session() as session:
|
| 148 |
+
result = session.execute(text("SELECT * FROM users"))
|
| 149 |
+
"""
|
| 150 |
+
session = self.session_factory()
|
| 151 |
+
try:
|
| 152 |
+
yield session
|
| 153 |
+
session.commit()
|
| 154 |
+
except SQLAlchemyError as e:
|
| 155 |
+
session.rollback()
|
| 156 |
+
logger.error(f"Database session error: {e}")
|
| 157 |
+
raise
|
| 158 |
+
finally:
|
| 159 |
+
session.close()
|
| 160 |
+
|
| 161 |
+
def execute_query(self, query: str, params: Optional[dict] = None) -> list:
|
| 162 |
+
"""
|
| 163 |
+
Execute a read-only SQL query and return results.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
query: SQL query string (must be SELECT)
|
| 167 |
+
params: Optional query parameters for parameterized queries
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
List of result rows as dictionaries
|
| 171 |
+
"""
|
| 172 |
+
with self.get_session() as session:
|
| 173 |
+
result = session.execute(text(query), params or {})
|
| 174 |
+
# Convert rows to dictionaries for easier handling
|
| 175 |
+
columns = result.keys()
|
| 176 |
+
return [dict(zip(columns, row)) for row in result.fetchall()]
|
| 177 |
+
|
| 178 |
+
def execute_write(self, query: str, params: Optional[dict] = None) -> bool:
|
| 179 |
+
"""
|
| 180 |
+
Execute a write operation (INSERT, UPDATE, DELETE, CREATE).
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
query: SQL query string
|
| 184 |
+
params: Optional query parameters
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
bool: True if successful
|
| 188 |
+
"""
|
| 189 |
+
with self.get_session() as session:
|
| 190 |
+
session.execute(text(query), params or {})
|
| 191 |
+
session.commit()
|
| 192 |
+
return True
|
| 193 |
+
|
| 194 |
+
def test_connection(self) -> tuple[bool, str]:
|
| 195 |
+
"""
|
| 196 |
+
Test database connectivity.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
tuple: (success: bool, message: str)
|
| 200 |
+
"""
|
| 201 |
+
try:
|
| 202 |
+
with self.get_session() as session:
|
| 203 |
+
result = session.execute(text("SELECT 1 as health_check"))
|
| 204 |
+
row = result.fetchone()
|
| 205 |
+
if row and row[0] == 1:
|
| 206 |
+
db_type = self.config.db_type.value.upper()
|
| 207 |
+
return True, f"{db_type} connection successful"
|
| 208 |
+
return False, "Unexpected result from health check query"
|
| 209 |
+
except OperationalError as e:
|
| 210 |
+
logger.error(f"Database connection failed: {e}")
|
| 211 |
+
return False, f"Connection failed: {str(e)}"
|
| 212 |
+
except Exception as e:
|
| 213 |
+
logger.error(f"Unexpected error during connection test: {e}")
|
| 214 |
+
return False, f"Unexpected error: {str(e)}"
|
| 215 |
+
|
| 216 |
+
def close(self):
|
| 217 |
+
"""Close all connections and dispose of the engine."""
|
| 218 |
+
if self._engine:
|
| 219 |
+
self._engine.dispose()
|
| 220 |
+
self._engine = None
|
| 221 |
+
self._session_factory = None
|
| 222 |
+
logger.info("Database connections closed")
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# Module-level singleton shared by the rest of the application.
db_connection = DatabaseConnection()


def get_db() -> DatabaseConnection:
    """Return the process-wide DatabaseConnection singleton."""
    return db_connection
|
database/schema_introspector.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dynamic Schema Introspection Module - Multi-Database Support.
|
| 3 |
+
|
| 4 |
+
This module is the CORE of the schema-agnostic design.
|
| 5 |
+
It dynamically discovers:
|
| 6 |
+
- All tables in the database
|
| 7 |
+
- All columns with their data types
|
| 8 |
+
- Primary keys and foreign keys
|
| 9 |
+
- Text-like columns for RAG indexing
|
| 10 |
+
- Relationships between tables
|
| 11 |
+
|
| 12 |
+
Supports MySQL, PostgreSQL, and SQLite.
|
| 13 |
+
NEVER hardcodes any table or column names.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import List, Dict, Optional, Any
|
| 19 |
+
from sqlalchemy import text, inspect
|
| 20 |
+
from sqlalchemy.engine import Engine
|
| 21 |
+
|
| 22 |
+
from .connection import get_db
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class ColumnInfo:
    """Schema metadata for a single database column."""
    name: str
    data_type: str
    is_nullable: bool
    is_primary_key: bool
    max_length: Optional[int] = None
    default_value: Optional[str] = None
    comment: Optional[str] = None

    @property
    def is_text_type(self) -> bool:
        """Check if this column contains text data suitable for RAG."""
        # Normalise e.g. "varchar(255)" -> "varchar" before matching.
        base_type = self.data_type.lower().split('(')[0].strip()
        return base_type in {
            # MySQL
            'text', 'mediumtext', 'longtext', 'tinytext', 'varchar', 'char', 'json',
            # PostgreSQL
            'character varying', 'character', 'jsonb',
            # SQLite (TEXT affinity)
            'clob', 'nvarchar', 'nchar', 'ntext',
        }

    @property
    def is_numeric(self) -> bool:
        """Check if this column contains numeric data."""
        base_type = self.data_type.lower().split('(')[0].strip()
        return base_type in {
            # Common across databases
            'int', 'integer', 'bigint', 'smallint', 'tinyint',
            'decimal', 'numeric', 'float', 'double', 'real',
            # PostgreSQL specific
            'double precision', 'serial', 'bigserial', 'smallserial',
            # SQLite (NUMERIC affinity)
            'bool', 'boolean',
        }
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
class TableInfo:
    """Complete information about a database table."""
    name: str
    columns: List[ColumnInfo] = field(default_factory=list)
    primary_keys: List[str] = field(default_factory=list)
    # Maps a local column name to "referenced_table.referenced_column".
    foreign_keys: Dict[str, str] = field(default_factory=dict)
    row_count: Optional[int] = None
    comment: Optional[str] = None

    @property
    def text_columns(self) -> List[ColumnInfo]:
        """Columns whose type is text-like and therefore RAG-indexable."""
        return [column for column in self.columns if column.is_text_type]

    @property
    def column_names(self) -> List[str]:
        """Names of every column, in declaration order."""
        return [column.name for column in self.columns]

    def get_column(self, name: str) -> Optional[ColumnInfo]:
        """Look up a column by name (case-insensitive); None if absent."""
        wanted = name.lower()
        return next(
            (column for column in self.columns if column.name.lower() == wanted),
            None
        )
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclass
class SchemaInfo:
    """Complete database schema information."""
    database_name: str
    tables: Dict[str, TableInfo] = field(default_factory=dict)

    @property
    def table_names(self) -> List[str]:
        """All table names currently known to the schema."""
        return list(self.tables)

    @property
    def all_text_columns(self) -> List[tuple]:
        """Every (table, column) pair whose column is text-like."""
        return [
            (table_name, column.name)
            for table_name, table_info in self.tables.items()
            for column in table_info.text_columns
        ]

    def to_context_string(self) -> str:
        """
        Generate a natural language description of the schema.
        This is used as context for the LLM.
        """
        parts = [
            f"Database: {self.database_name}",
            "",
            "Available Tables:",
            "-" * 40,
        ]

        for table_name, table_info in self.tables.items():
            parts.append(f"\nTable: {table_name}")
            if table_info.comment:
                parts.append(f"  Description: {table_info.comment}")
            if table_info.row_count is not None:
                parts.append(f"  Approximate rows: {table_info.row_count}")

            parts.append("  Columns:")
            for column in table_info.columns:
                pk_marker = " [PRIMARY KEY]" if column.is_primary_key else ""
                nullable = " (nullable)" if column.is_nullable else " (required)"
                parts.append(f"    - {column.name}: {column.data_type}{pk_marker}{nullable}")
                if column.comment:
                    parts.append(f"      Comment: {column.comment}")

            if table_info.foreign_keys:
                parts.append("  Foreign Keys:")
                for col, ref in table_info.foreign_keys.items():
                    parts.append(f"    - {col} -> {ref}")

        return "\n".join(parts)

    def to_sql_ddl(self) -> str:
        """
        Generate SQL-like DDL representation of the schema.
        Useful for SQL generation context.
        """
        ddl_lines = []

        for table_name, table_info in self.tables.items():
            ddl_lines.append(f"CREATE TABLE {table_name} (")

            col_defs = []
            for column in table_info.columns:
                col_def = f"  {column.name} {column.data_type}"
                if column.is_primary_key:
                    col_def += " PRIMARY KEY"
                if not column.is_nullable:
                    col_def += " NOT NULL"
                col_defs.append(col_def)

            ddl_lines.append(",\n".join(col_defs))
            ddl_lines.append(");\n")

        return "\n".join(ddl_lines)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class SchemaIntrospector:
    """
    Dynamically introspects database schema.

    This is the key component that enables schema-agnostic operation.
    It queries database system catalogs to discover the complete schema.
    Supports MySQL, PostgreSQL, and SQLite.

    Results are cached on the instance after the first introspect() call;
    use refresh_cache() / introspect(force_refresh=True) to re-scan.
    """

    # System tables to exclude from introspection: the chatbot's own
    # bookkeeping tables, common migration-tracking tables, and SQLite
    # internal tables.
    SYSTEM_TABLES = {
        '_chatbot_memory',  # Our own chat history table
        '_chatbot_permanent_memory_v2',
        '_chatbot_user_summaries',
        'schema_migrations',
        'flyway_schema_history',
        # SQLite internal tables
        'sqlite_sequence',
        'sqlite_stat1',
        'sqlite_stat4'
    }

    def __init__(self, engine: Optional[Engine] = None):
        """
        Initialize the introspector.

        Args:
            engine: SQLAlchemy engine. Uses global connection if not provided.

        NOTE(review): the `engine` argument is currently never used —
        all queries go through the global connection from get_db().
        Confirm whether a caller-supplied engine should be honoured.
        """
        self.db = get_db()
        # Lazily populated by introspect(); None means "not yet scanned".
        self._cached_schema: Optional[SchemaInfo] = None

    def introspect(self, force_refresh: bool = False) -> SchemaInfo:
        """
        Perform complete schema introspection.

        Args:
            force_refresh: If True, bypass cache and re-introspect

        Returns:
            SchemaInfo object with complete schema details
        """
        if self._cached_schema is not None and not force_refresh:
            return self._cached_schema

        logger.info("Starting schema introspection...")

        # Get database name
        db_name = self._get_database_name()

        # Get all user tables
        tables = self._get_tables()

        schema = SchemaInfo(database_name=db_name)

        for table_name in tables:
            if table_name in self.SYSTEM_TABLES:
                continue
            # Also skip tables that start with the chatbot prefix, in case
            # new internal tables appear that SYSTEM_TABLES doesn't list.
            if table_name.startswith('_chatbot'):
                continue

            # _introspect_table returns None on error; such tables are
            # silently omitted from the schema (the error is logged there).
            table_info = self._introspect_table(table_name)
            if table_info:
                schema.tables[table_name] = table_info

        self._cached_schema = schema
        logger.info(f"Schema introspection complete. Found {len(schema.tables)} tables.")

        return schema

    def _get_database_name(self) -> str:
        """Get the current database name (file name for SQLite); "unknown" on failure."""
        db_type = self.db.db_type

        try:
            if db_type.value == "sqlite":
                # For SQLite, return the database file name
                return self.db.config.sqlite_path.split('/')[-1]
            elif db_type.value == "postgresql":
                result = self.db.execute_query("SELECT current_database() as db_name")
                return result[0]['db_name'] if result else "unknown"
            else:  # MySQL
                result = self.db.execute_query("SELECT DATABASE() as db_name")
                return result[0]['db_name'] if result else "unknown"
        except Exception as e:
            logger.error(f"Error getting database name: {e}")
            return "unknown"

    def _get_tables(self) -> List[str]:
        """
        Get all user tables from the database.
        Uses database-specific queries for comprehensive discovery.
        Returns an empty list on any error (logged).
        """
        db_type = self.db.db_type

        try:
            if db_type.value == "sqlite":
                query = """
                    SELECT name as table_name
                    FROM sqlite_master
                    WHERE type='table'
                    AND name NOT LIKE 'sqlite_%'
                    ORDER BY name
                """
                result = self.db.execute_query(query)
                return [row['table_name'] for row in result]

            elif db_type.value == "postgresql":
                query = """
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                    AND table_type = 'BASE TABLE'
                    ORDER BY table_name
                """
                result = self.db.execute_query(query)
                return [row['table_name'] for row in result]

            else:  # MySQL
                query = """
                    SELECT TABLE_NAME
                    FROM INFORMATION_SCHEMA.TABLES
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_TYPE = 'BASE TABLE'
                    ORDER BY TABLE_NAME
                """
                result = self.db.execute_query(query)
                return [row['TABLE_NAME'] for row in result]

        except Exception as e:
            logger.error(f"Error getting tables: {e}")
            return []

    def _introspect_table(self, table_name: str) -> Optional[TableInfo]:
        """
        Get complete information about a specific table.

        Args:
            table_name: Name of the table to introspect

        Returns:
            TableInfo object or None if introspection fails (error logged)
        """
        try:
            # Get column information
            columns = self._get_columns(table_name)

            # Get primary keys
            primary_keys = self._get_primary_keys(table_name)

            # Get foreign keys
            foreign_keys = self._get_foreign_keys(table_name)

            # Get approximate row count (fast estimation)
            row_count = self._get_row_count(table_name)

            # Get table comment (not available in SQLite)
            comment = self._get_table_comment(table_name)

            # Mark primary key columns — this overrides whatever
            # is_primary_key value _get_columns produced per dialect.
            for col in columns:
                col.is_primary_key = col.name in primary_keys

            return TableInfo(
                name=table_name,
                columns=columns,
                primary_keys=primary_keys,
                foreign_keys=foreign_keys,
                row_count=row_count,
                comment=comment
            )

        except Exception as e:
            logger.error(f"Error introspecting table {table_name}: {e}")
            return None

    def _get_columns(self, table_name: str) -> List[ColumnInfo]:
        """Get all columns for a table, in ordinal order; [] on error (logged)."""
        db_type = self.db.db_type

        try:
            if db_type.value == "sqlite":
                # NOTE(review): PRAGMA does not accept bound parameters, so
                # the table name is interpolated. Names come from
                # sqlite_master via _get_tables, limiting the injection
                # surface — confirm no untrusted names reach this path.
                query = f"PRAGMA table_info('{table_name}')"
                result = self.db.execute_query(query)

                columns = []
                for row in result:
                    columns.append(ColumnInfo(
                        name=row['name'],
                        data_type=row['type'] or 'TEXT',  # SQLite columns can have no type
                        is_nullable=row['notnull'] == 0,
                        # NOTE(review): 'pk' is 1..N for composite keys, so
                        # == 1 marks only the first PK column here (contrast
                        # _get_primary_keys, which uses pk > 0). Harmless for
                        # introspect(), which re-derives the flag, but direct
                        # callers would see the inconsistency.
                        is_primary_key=row['pk'] == 1,
                        max_length=None,
                        default_value=row['dflt_value'],
                        comment=None  # SQLite doesn't support column comments
                    ))
                return columns

            elif db_type.value == "postgresql":
                query = """
                    SELECT
                        column_name,
                        data_type,
                        is_nullable,
                        column_default,
                        character_maximum_length,
                        col_description(
                            (SELECT oid FROM pg_class WHERE relname = :table_name),
                            ordinal_position
                        ) as column_comment
                    FROM information_schema.columns
                    WHERE table_schema = 'public'
                    AND table_name = :table_name
                    ORDER BY ordinal_position
                """
                result = self.db.execute_query(query, {"table_name": table_name})

                columns = []
                for row in result:
                    columns.append(ColumnInfo(
                        name=row['column_name'],
                        data_type=row['data_type'],
                        is_nullable=row['is_nullable'] == 'YES',
                        is_primary_key=False,  # Will be set later by _introspect_table
                        max_length=row['character_maximum_length'],
                        default_value=row['column_default'],
                        comment=row.get('column_comment')
                    ))
                return columns

            else:  # MySQL
                query = """
                    SELECT
                        COLUMN_NAME,
                        COLUMN_TYPE,
                        IS_NULLABLE,
                        COLUMN_DEFAULT,
                        CHARACTER_MAXIMUM_LENGTH,
                        COLUMN_COMMENT
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                    ORDER BY ORDINAL_POSITION
                """
                result = self.db.execute_query(query, {"table_name": table_name})

                columns = []
                for row in result:
                    columns.append(ColumnInfo(
                        name=row['COLUMN_NAME'],
                        data_type=row['COLUMN_TYPE'],
                        is_nullable=row['IS_NULLABLE'] == 'YES',
                        is_primary_key=False,  # Will be set later by _introspect_table
                        max_length=row['CHARACTER_MAXIMUM_LENGTH'],
                        default_value=row['COLUMN_DEFAULT'],
                        comment=row['COLUMN_COMMENT'] if row['COLUMN_COMMENT'] else None
                    ))
                return columns

        except Exception as e:
            logger.error(f"Error getting columns for {table_name}: {e}")
            return []

    def _get_primary_keys(self, table_name: str) -> List[str]:
        """Get primary key column names for a table; [] on error (logged)."""
        db_type = self.db.db_type

        try:
            if db_type.value == "sqlite":
                # pk > 0 correctly includes every column of a composite key.
                query = f"PRAGMA table_info('{table_name}')"
                result = self.db.execute_query(query)
                return [row['name'] for row in result if row['pk'] > 0]

            elif db_type.value == "postgresql":
                query = """
                    SELECT a.attname as column_name
                    FROM pg_index i
                    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
                    WHERE i.indrelid = :table_name::regclass
                    AND i.indisprimary
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return [row['column_name'] for row in result]

            else:  # MySQL
                query = """
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                    AND CONSTRAINT_NAME = 'PRIMARY'
                    ORDER BY ORDINAL_POSITION
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return [row['COLUMN_NAME'] for row in result]

        except Exception as e:
            logger.error(f"Error getting primary keys for {table_name}: {e}")
            return []

    def _get_foreign_keys(self, table_name: str) -> Dict[str, str]:
        """
        Get foreign key relationships for a table.

        Returns:
            Mapping of local column name -> "referenced_table.referenced_column";
            {} on error (logged).
        """
        db_type = self.db.db_type

        try:
            if db_type.value == "sqlite":
                query = f"PRAGMA foreign_key_list('{table_name}')"
                result = self.db.execute_query(query)
                return {
                    row['from']: f"{row['table']}.{row['to']}"
                    for row in result
                }

            elif db_type.value == "postgresql":
                query = """
                    SELECT
                        kcu.column_name,
                        ccu.table_name AS foreign_table_name,
                        ccu.column_name AS foreign_column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                        AND tc.table_schema = kcu.table_schema
                    JOIN information_schema.constraint_column_usage AS ccu
                        ON ccu.constraint_name = tc.constraint_name
                        AND ccu.table_schema = tc.table_schema
                    WHERE tc.constraint_type = 'FOREIGN KEY'
                    AND tc.table_name = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return {
                    row['column_name']: f"{row['foreign_table_name']}.{row['foreign_column_name']}"
                    for row in result
                }

            else:  # MySQL
                query = """
                    SELECT
                        COLUMN_NAME,
                        REFERENCED_TABLE_NAME,
                        REFERENCED_COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                    AND REFERENCED_TABLE_NAME IS NOT NULL
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return {
                    row['COLUMN_NAME']: f"{row['REFERENCED_TABLE_NAME']}.{row['REFERENCED_COLUMN_NAME']}"
                    for row in result
                }

        except Exception as e:
            logger.error(f"Error getting foreign keys for {table_name}: {e}")
            return {}

    def _get_row_count(self, table_name: str) -> Optional[int]:
        """
        Get approximate row count for a table.
        Uses different strategies per database; values are estimates, not
        exact counts. Returns None on error (logged).
        """
        db_type = self.db.db_type

        try:
            if db_type.value == "sqlite":
                # SQLite doesn't have stats table, use max rowid for estimation.
                # NOTE(review): MAX(rowid) overestimates after deletes and is
                # unavailable for WITHOUT ROWID tables — acceptable for an
                # approximation, but confirm.
                query = f"SELECT MAX(rowid) as row_count FROM \"{table_name}\""
                result = self.db.execute_query(query)
                return result[0]['row_count'] if result and result[0]['row_count'] else 0

            elif db_type.value == "postgresql":
                # Use pg_stat_user_tables for fast estimation
                query = """
                    SELECT n_live_tup as row_count
                    FROM pg_stat_user_tables
                    WHERE relname = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return result[0]['row_count'] if result else None

            else:  # MySQL
                # TABLE_ROWS is InnoDB's statistics estimate, not exact.
                query = """
                    SELECT TABLE_ROWS
                    FROM INFORMATION_SCHEMA.TABLES
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return result[0]['TABLE_ROWS'] if result else None

        except Exception as e:
            logger.error(f"Error getting row count for {table_name}: {e}")
            return None

    def _get_table_comment(self, table_name: str) -> Optional[str]:
        """Get table comment/description; None for SQLite or on error (logged)."""
        db_type = self.db.db_type

        try:
            if db_type.value == "sqlite":
                # SQLite doesn't support table comments
                return None

            elif db_type.value == "postgresql":
                query = """
                    SELECT obj_description(:table_name::regclass, 'pg_class') as table_comment
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                comment = result[0]['table_comment'] if result else None
                # Empty-string comments are normalised to None.
                return comment if comment else None

            else:  # MySQL
                query = """
                    SELECT TABLE_COMMENT
                    FROM INFORMATION_SCHEMA.TABLES
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                comment = result[0]['TABLE_COMMENT'] if result else None
                # Empty-string comments are normalised to None.
                return comment if comment else None

        except Exception as e:
            logger.error(f"Error getting table comment for {table_name}: {e}")
            return None

    def get_text_columns_for_rag(self, min_length: int = 50) -> List[Dict[str, Any]]:
        """
        Get all text columns suitable for RAG indexing.

        Args:
            min_length: Minimum max_length for varchar columns to be considered
                (columns with no declared max_length are always kept)

        Returns:
            List of dicts with table name, column name, and metadata
        """
        schema = self.introspect()
        text_columns = []

        for table_name, table_info in schema.tables.items():
            for col in table_info.columns:
                if col.is_text_type:
                    # Skip very short varchar columns
                    if col.max_length and col.max_length < min_length:
                        continue

                    text_columns.append({
                        "table": table_name,
                        "column": col.name,
                        "data_type": col.data_type,
                        "primary_keys": table_info.primary_keys,
                        "max_length": col.max_length
                    })

        return text_columns

    def refresh_cache(self) -> SchemaInfo:
        """Force refresh the cached schema."""
        return self.introspect(force_refresh=True)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
# Global introspector instance
|
| 635 |
+
_introspector: Optional[SchemaIntrospector] = None
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def get_introspector() -> SchemaIntrospector:
|
| 639 |
+
"""Get or create the global schema introspector."""
|
| 640 |
+
global _introspector
|
| 641 |
+
if _introspector is None:
|
| 642 |
+
_introspector = SchemaIntrospector()
|
| 643 |
+
return _introspector
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def get_schema() -> SchemaInfo:
    """Shorthand for fetching the current schema via the shared introspector."""
    introspector = get_introspector()
    return introspector.introspect()
|
llm/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM module exports."""
|
| 2 |
+
|
| 3 |
+
from .client import (
|
| 4 |
+
LLMClient,
|
| 5 |
+
GroqClient,
|
| 6 |
+
OpenAIClient,
|
| 7 |
+
LocalLLaMAClient,
|
| 8 |
+
create_llm_client
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"LLMClient",
|
| 13 |
+
"GroqClient",
|
| 14 |
+
"OpenAIClient",
|
| 15 |
+
"LocalLLaMAClient",
|
| 16 |
+
"create_llm_client"
|
| 17 |
+
]
|
llm/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (433 Bytes). View file
|
|
|
llm/__pycache__/client.cpython-311.pyc
ADDED
|
Binary file (8.38 kB). View file
|
|
|
llm/client.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Client - Unified interface for Groq, OpenAI, and local models.
|
| 3 |
+
|
| 4 |
+
Groq is the DEFAULT provider (free tier available).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from typing import List, Dict, Optional
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LLMClient(ABC):
    """Abstract base class for LLM clients."""

    @abstractmethod
    def chat(self, messages: List[Dict[str, str]]) -> str:
        # Contract: take a list of {"role": ..., "content": ...} dicts and
        # return the assistant's reply text.
        pass

    @abstractmethod
    def is_available(self) -> bool:
        # Contract: return True if the backing provider can currently serve
        # requests; implementations should not raise.
        pass
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class GroqClient(LLMClient):
    """
    Groq API client - FREE and FAST inference.

    Available models:
    - llama-3.3-70b-versatile (recommended)
    - llama-3.1-8b-instant (faster)
    - mixtral-8x7b-32768
    - gemma2-9b-it
    """

    AVAILABLE_MODELS = [
        "llama-3.3-70b-versatile",
        "llama-3.1-70b-versatile",
        "llama-3.1-8b-instant",
        "llama3-70b-8192",
        "llama3-8b-8192",
        "mixtral-8x7b-32768",
        "gemma2-9b-it"
    ]

    def __init__(
        self,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        temperature: float = 0.1,
        max_tokens: int = 1024
    ):
        # The SDK object is created lazily, so constructing this is cheap.
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._client = None

    @property
    def client(self):
        """Instantiate and memoize the underlying Groq SDK client."""
        if self._client is None:
            from groq import Groq
            self._client = Groq(api_key=self.api_key)
        return self._client

    def chat(self, messages: List[Dict[str, str]]) -> str:
        """Run one chat completion and return the assistant's reply text."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def is_available(self) -> bool:
        """Probe the API with a lightweight models.list(); False on failure."""
        try:
            self.client.models.list()
        except Exception as e:
            logger.warning(f"Groq availability check failed: {e}")
            return False
        return True
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class OpenAIClient(LLMClient):
    """OpenAI API client (paid)."""

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o-mini",
        temperature: float = 0.1,
        max_tokens: int = 1024
    ):
        # SDK object is built on first use, not at construction time.
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._client = None

    @property
    def client(self):
        """Create and memoize the OpenAI SDK client on first access."""
        if self._client is None:
            from openai import OpenAI
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    def chat(self, messages: List[Dict[str, str]]) -> str:
        """Run a single chat completion and return the reply text."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )
        return completion.choices[0].message.content

    def is_available(self) -> bool:
        """True when the API answers a models.list() probe."""
        try:
            self.client.models.list()
        except Exception:
            return False
        return True
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class LocalLLaMAClient(LLMClient):
    """Local LLaMA/Phi model client via transformers."""

    def __init__(
        self,
        model_name: str = "microsoft/Phi-3-mini-4k-instruct",
        temperature: float = 0.1,
        max_tokens: int = 1024
    ):
        # The HF pipeline is heavy, so it is loaded lazily on first use.
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._pipeline = None

    @property
    def pipeline(self):
        """Load and cache the transformers text-generation pipeline."""
        if self._pipeline is None:
            from transformers import pipeline
            logger.info(f"Loading local model: {self.model_name}")
            self._pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                torch_dtype="auto",
                device_map="auto"
            )
        return self._pipeline

    def chat(self, messages: List[Dict[str, str]]) -> str:
        """Generate a reply with the local pipeline and return its text."""
        generated = self.pipeline(
            messages,
            max_new_tokens=self.max_tokens,
            temperature=self.temperature,
            do_sample=True
        )
        # Chat-style pipelines return the message list; the final entry holds
        # the newly generated assistant turn.
        return generated[0]["generated_text"][-1]["content"]

    def is_available(self) -> bool:
        """True when the pipeline can be constructed (model load succeeds)."""
        try:
            _ = self.pipeline
            return True
        except Exception:
            return False
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def create_llm_client(provider: str = "groq", **kwargs) -> LLMClient:
    """
    Factory function to create LLM client.

    Args:
        provider: "groq" (default, free), "openai", or "local".
            Matching is case-insensitive and tolerates surrounding
            whitespace, so "Groq" or " OPENAI " also work.
        **kwargs: Provider-specific arguments

    Returns:
        Configured LLMClient instance

    Raises:
        ValueError: If the provider name is not one of the supported values.
    """
    # Normalize so config/env values like "Groq" don't fall through to the
    # error branch (previously the comparison was exact-match only).
    normalized = provider.strip().lower()
    if normalized == "groq":
        return GroqClient(**kwargs)
    elif normalized == "openai":
        return OpenAIClient(**kwargs)
    elif normalized == "local":
        return LocalLLaMAClient(**kwargs)
    else:
        raise ValueError(f"Unknown provider: {provider}. Use 'groq', 'openai', or 'local'")
|
memory.py
ADDED
|
@@ -0,0 +1,760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat Memory - Short-term and long-term memory management.
|
| 3 |
+
|
| 4 |
+
Supports MySQL, PostgreSQL, and SQLite with dialect-specific DDL.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import json
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class ChatMessage:
    """A single chat turn, normalized for storage and LLM consumption."""

    role: str  # "user" or "assistant"
    content: str
    # Both default to None and are backfilled in __post_init__, which keeps
    # the dataclass free of mutable/shared default values.
    timestamp: Optional[datetime] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Backfill defaults; this also applies when a caller passes None
        # explicitly, not only when the argument is omitted.
        if self.timestamp is None:
            self.timestamp = datetime.now()
        if self.metadata is None:
            self.metadata = {}

    def to_dict(self) -> Dict[str, str]:
        # Minimal {"role", "content"} mapping expected by chat LLM APIs;
        # timestamp/metadata are intentionally excluded.
        return {"role": self.role, "content": self.content}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_memory_table_ddl(db_type: str) -> str:
    """Get the DDL for chat memory table based on database type.

    Args:
        db_type: "postgresql", "sqlite", or anything else (treated as MySQL).

    Returns:
        Dialect-specific ``CREATE TABLE IF NOT EXISTS _chatbot_memory`` DDL.
        Only the MySQL variant declares its indexes inline; the other
        dialects create indexes with separate statements.
    """
    if db_type == "postgresql":
        # PostgreSQL: SERIAL id, JSONB metadata.
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_memory (
                id SERIAL PRIMARY KEY,
                session_id VARCHAR(255) NOT NULL,
                user_id VARCHAR(255) NOT NULL DEFAULT 'default',
                role VARCHAR(50) NOT NULL,
                content TEXT NOT NULL,
                metadata JSONB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
    elif db_type == "sqlite":
        # SQLite: AUTOINCREMENT id; metadata stored as TEXT (JSON string).
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_memory (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                user_id TEXT NOT NULL DEFAULT 'default',
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
    else:  # MySQL
        # MySQL: AUTO_INCREMENT id, native JSON column, inline indexes.
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_memory (
                id INT AUTO_INCREMENT PRIMARY KEY,
                session_id VARCHAR(255) NOT NULL,
                user_id VARCHAR(255) NOT NULL DEFAULT 'default',
                role VARCHAR(50) NOT NULL,
                content TEXT NOT NULL,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX idx_session (session_id),
                INDEX idx_user (user_id),
                INDEX idx_created (created_at)
            )
        """
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_permanent_memory_ddl(db_type: str) -> str:
    """Get the DDL for permanent memory table based on database type.

    Args:
        db_type: "postgresql", "sqlite", or anything else (treated as MySQL).

    Returns:
        Dialect-specific DDL for ``_chatbot_permanent_memory_v2``, the
        per-user long-term notes table (user_id, content, optional tags).
    """
    if db_type == "postgresql":
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_permanent_memory_v2 (
                id SERIAL PRIMARY KEY,
                user_id VARCHAR(255) NOT NULL DEFAULT 'default',
                content TEXT NOT NULL,
                tags VARCHAR(255),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
    elif db_type == "sqlite":
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_permanent_memory_v2 (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT NOT NULL DEFAULT 'default',
                content TEXT NOT NULL,
                tags TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
    else:  # MySQL
        # Only MySQL declares the user_id index inline.
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_permanent_memory_v2 (
                id INT AUTO_INCREMENT PRIMARY KEY,
                user_id VARCHAR(255) NOT NULL DEFAULT 'default',
                content TEXT NOT NULL,
                tags VARCHAR(255),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX idx_user (user_id)
            )
        """
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_summary_table_ddl(db_type: str) -> str:
    """Get the DDL for summary table based on database type.

    Args:
        db_type: "postgresql", "sqlite", or anything else (treated as MySQL).

    Returns:
        Dialect-specific DDL for ``_chatbot_user_summaries``. user_id is
        unique in every dialect (UNIQUE constraint, or UNIQUE KEY in MySQL)
        so upserts can key on it; MySQL also auto-touches last_updated via
        ON UPDATE CURRENT_TIMESTAMP.
    """
    if db_type == "postgresql":
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_user_summaries (
                id SERIAL PRIMARY KEY,
                user_id VARCHAR(255) NOT NULL UNIQUE,
                summary TEXT NOT NULL,
                message_count INT DEFAULT 0,
                last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
    elif db_type == "sqlite":
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_user_summaries (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT NOT NULL UNIQUE,
                summary TEXT NOT NULL,
                message_count INTEGER DEFAULT 0,
                last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
    else:  # MySQL
        return """
            CREATE TABLE IF NOT EXISTS _chatbot_user_summaries (
                id INT AUTO_INCREMENT PRIMARY KEY,
                user_id VARCHAR(255) NOT NULL,
                summary TEXT NOT NULL,
                message_count INT DEFAULT 0,
                last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                UNIQUE KEY idx_user (user_id)
            )
        """
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_upsert_summary_query(db_type: str) -> str:
    """Get the upsert query for summary based on database type.

    Args:
        db_type: "postgresql", "sqlite", or anything else (treated as MySQL).

    Returns:
        Parameterized upsert keyed on user_id. Bind parameters are
        :user_id, :summary, :message_count. PostgreSQL/SQLite use
        ON CONFLICT ... DO UPDATE; MySQL uses ON DUPLICATE KEY UPDATE and
        reuses the named parameters in the UPDATE clause.
    """
    if db_type == "postgresql":
        return """
            INSERT INTO _chatbot_user_summaries
            (user_id, summary, message_count, last_updated)
            VALUES (:user_id, :summary, :message_count, CURRENT_TIMESTAMP)
            ON CONFLICT (user_id)
            DO UPDATE SET
                summary = EXCLUDED.summary,
                message_count = EXCLUDED.message_count,
                last_updated = CURRENT_TIMESTAMP
        """
    elif db_type == "sqlite":
        return """
            INSERT INTO _chatbot_user_summaries
            (user_id, summary, message_count, last_updated)
            VALUES (:user_id, :summary, :message_count, CURRENT_TIMESTAMP)
            ON CONFLICT(user_id)
            DO UPDATE SET
                summary = excluded.summary,
                message_count = excluded.message_count,
                last_updated = CURRENT_TIMESTAMP
        """
    else:  # MySQL
        # last_updated is omitted from the INSERT column list; the table's
        # DEFAULT / ON UPDATE CURRENT_TIMESTAMP handles it.
        return """
            INSERT INTO _chatbot_user_summaries
            (user_id, summary, message_count)
            VALUES (:user_id, :summary, :message_count)
            ON DUPLICATE KEY UPDATE
                summary = :summary,
                message_count = :message_count,
                last_updated = CURRENT_TIMESTAMP
        """
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class ChatMemory:
    """Manages chat history with short-term and long-term storage.

    Short-term memory is a bounded in-process list of recent messages; each
    message is also persisted to ``_chatbot_memory`` when a DB connection is
    available. Permanent per-user notes live in
    ``_chatbot_permanent_memory_v2`` and are injected into the LLM context.
    All DB failures are logged and degrade gracefully to in-memory behavior.
    """

    def __init__(self, session_id: str, user_id: str = "default", max_messages: int = 20, db_connection=None):
        """
        Args:
            session_id: Identifier of the current chat session.
            user_id: Owner of the history and permanent notes.
            max_messages: Cap on in-memory messages (oldest are dropped).
            db_connection: Optional wrapper exposing execute_query /
                execute_write and a db_type enum; None means memory-only.
        """
        self.session_id = session_id
        self.user_id = user_id
        self.max_messages = max_messages
        self.db = db_connection
        self.messages: List[ChatMessage] = []
        self._db_type = None

        if self.db:
            self._db_type = self.db.db_type.value
            self._ensure_tables()

    def _ensure_tables(self):
        """Create memory tables if they don't exist."""
        try:
            self.db.execute_write(get_memory_table_ddl(self._db_type))
            self.db.execute_write(get_permanent_memory_ddl(self._db_type))

            # SQLite/PostgreSQL need separate index statements; the MySQL
            # DDL declares its indexes inline.
            if self._db_type in ("sqlite", "postgresql"):
                self._create_indexes()

            # Migration: ensure user_id column exists (MySQL only, for
            # tables created before multi-user support).
            if self._db_type == "mysql":
                self._migrate_mysql_user_id()

        except Exception as e:
            logger.warning(f"Failed to create memory tables: {e}")

    def _create_indexes(self):
        """Create indexes for SQLite and PostgreSQL.

        Both dialects accept identical ``CREATE INDEX IF NOT EXISTS`` syntax,
        so one statement list serves both (previously duplicated per-dialect).
        """
        index_statements = (
            "CREATE INDEX IF NOT EXISTS idx_memory_session ON _chatbot_memory(session_id)",
            "CREATE INDEX IF NOT EXISTS idx_memory_user ON _chatbot_memory(user_id)",
            "CREATE INDEX IF NOT EXISTS idx_memory_created ON _chatbot_memory(created_at)",
            "CREATE INDEX IF NOT EXISTS idx_permanent_user ON _chatbot_permanent_memory_v2(user_id)",
        )
        try:
            if self._db_type in ("sqlite", "postgresql"):
                for stmt in index_statements:
                    self.db.execute_write(stmt)
        except Exception as e:
            logger.debug(f"Index creation (may already exist): {e}")

    def _migrate_mysql_user_id(self):
        """Migrate MySQL table to include user_id column if missing."""
        try:
            check_query = """
                SELECT COLUMN_NAME
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = :db_name
                AND TABLE_NAME = '_chatbot_memory'
                AND COLUMN_NAME = 'user_id'
            """
            db_name = self.db.config.database
            result = self.db.execute_query(check_query, {"db_name": db_name})

            if not result:
                self.db.execute_write("ALTER TABLE _chatbot_memory ADD COLUMN user_id VARCHAR(255) NOT NULL DEFAULT 'default' AFTER session_id")
                self.db.execute_write("CREATE INDEX idx_user ON _chatbot_memory(user_id)")
                logger.info("Migrated _chatbot_memory to include user_id")
        except Exception as e:
            logger.debug(f"Migration check failed: {e}")

    def add_message(self, role: str, content: str, metadata: Dict = None):
        """Add a message to in-memory history and persist it when possible.

        Args:
            role: "user" or "assistant".
            content: Message text.
            metadata: Optional extra data, JSON-serialized for storage.
        """
        msg = ChatMessage(role=role, content=content, metadata=metadata)
        self.messages.append(msg)

        # Trim oldest messages beyond the short-term cap.
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages:]

        # Persist to DB (session history); failures are logged, not raised.
        if self.db:
            try:
                query = """
                    INSERT INTO _chatbot_memory (session_id, user_id, role, content, metadata)
                    VALUES (:session_id, :user_id, :role, :content, :metadata)
                """
                self.db.execute_write(query, {
                    "session_id": self.session_id,
                    "user_id": self.user_id,
                    "role": role,
                    "content": content,
                    "metadata": json.dumps(metadata) if metadata else None
                })
            except Exception as e:
                logger.warning(f"Failed to persist message: {e}")

    def save_permanent_context(self, content: str, tags: str = "user_saved"):
        """Save specific context explicitly to permanent memory for this user.

        Returns:
            Tuple of (success flag, human-readable status message).
        """
        if not self.db:
            return False, "No database connection"

        try:
            query = """
                INSERT INTO _chatbot_permanent_memory_v2 (user_id, content, tags)
                VALUES (:user_id, :content, :tags)
            """
            self.db.execute_write(query, {
                "user_id": self.user_id,
                "content": content,
                "tags": tags
            })
            return True, "Context saved to permanent memory"
        except Exception as e:
            logger.error(f"Failed to save permanent context: {e}")
            return False, str(e)

    def get_permanent_context(self, limit: int = 5) -> List[str]:
        """Retrieve recent permanent context strings for this user only."""
        if not self.db:
            return []

        try:
            # LIMIT with a bound parameter works on all supported dialects.
            query = """
                SELECT content FROM _chatbot_permanent_memory_v2
                WHERE user_id = :user_id
                ORDER BY created_at DESC LIMIT :limit
            """
            rows = self.db.execute_query(query, {
                "user_id": self.user_id,
                "limit": limit
            })
            return [row['content'] for row in rows]
        except Exception as e:
            logger.warning(f"Failed to load permanent context: {e}")
            return []

    def get_messages(self, limit: Optional[int] = None) -> List[Dict[str, str]]:
        """Return messages as {"role", "content"} dicts for LLM context."""
        msgs = self.messages if limit is None else self.messages[-limit:]
        return [m.to_dict() for m in msgs]

    def get_context_messages(self, count: int = 5) -> List[Dict[str, str]]:
        """Return recent messages, prefixed with permanent context if any."""
        # Short-term session messages first.
        context = self.get_messages(limit=count)

        # Inject permanent memory as a leading system note, if available.
        perm_docs = self.get_permanent_context(limit=3)
        if perm_docs:
            perm_context = f"IMPORTANT CONTEXT FOR USER '{self.user_id}':\n" + "\n".join(perm_docs)
            context.insert(0, {"role": "system", "content": perm_context})

        return context

    def clear(self):
        """Clear current session memory and remove from DB (temporary history)."""
        self.messages = []

        if self.db:
            try:
                # Delete temporary messages for this session only.
                query = "DELETE FROM _chatbot_memory WHERE session_id = :session_id"
                self.db.execute_write(query, {"session_id": self.session_id})
                logger.info(f"Cleared session memory for {self.session_id}")
            except Exception as e:
                logger.warning(f"Failed to clear memory from DB: {e}")

    def clear_user_history(self):
        """Clear ALL temporary history for this user (across all sessions)."""
        self.messages = []
        if self.db:
            try:
                query = "DELETE FROM _chatbot_memory WHERE user_id = :user_id"
                self.db.execute_write(query, {"user_id": self.user_id})
                logger.info(f"Cleared all temporary history for user: {self.user_id}")
            except Exception as e:
                logger.warning(f"Failed to clear user history from DB: {e}")
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class ConversationSummaryMemory:
|
| 365 |
+
"""
|
| 366 |
+
Per-user conversation summary memory using LLM for summarization.
|
| 367 |
+
|
| 368 |
+
This class maintains a running summary of the conversation, updating it
|
| 369 |
+
periodically (when message count exceeds threshold). This dramatically
|
| 370 |
+
reduces token usage while preserving context for long conversations.
|
| 371 |
+
|
| 372 |
+
Features:
|
| 373 |
+
- Automatic summarization when threshold is reached
|
| 374 |
+
- Per-user summary storage in database
|
| 375 |
+
- Combines summary + recent messages for optimal context
|
| 376 |
+
- Lazy summarization (only when needed)
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
SUMMARIZATION_PROMPT = """You are a conversation summarizer. Create a concise summary of the conversation below that captures:
|
| 380 |
+
1. Key topics discussed
|
| 381 |
+
2. Important facts or preferences mentioned by the user
|
| 382 |
+
3. Any decisions or conclusions reached
|
| 383 |
+
4. Context needed for follow-up questions
|
| 384 |
+
|
| 385 |
+
Keep the summary under 300 words but include all important details.
|
| 386 |
+
|
| 387 |
+
CONVERSATION:
|
| 388 |
+
{conversation}
|
| 389 |
+
|
| 390 |
+
SUMMARY:"""
|
| 391 |
+
|
| 392 |
+
INCREMENTAL_SUMMARY_PROMPT = """You are a conversation summarizer. Update the existing summary to incorporate new messages.
|
| 393 |
+
|
| 394 |
+
EXISTING SUMMARY:
|
| 395 |
+
{existing_summary}
|
| 396 |
+
|
| 397 |
+
NEW MESSAGES:
|
| 398 |
+
{new_messages}
|
| 399 |
+
|
| 400 |
+
Create an updated, comprehensive summary that:
|
| 401 |
+
1. Incorporates new information from the recent messages
|
| 402 |
+
2. Retains important context from the existing summary
|
| 403 |
+
3. Removes redundant or outdated information
|
| 404 |
+
4. Stays under 300 words
|
| 405 |
+
|
| 406 |
+
UPDATED SUMMARY:"""
|
| 407 |
+
|
| 408 |
+
def __init__(
|
| 409 |
+
self,
|
| 410 |
+
user_id: str,
|
| 411 |
+
session_id: str,
|
| 412 |
+
db_connection=None,
|
| 413 |
+
llm_client=None,
|
| 414 |
+
summary_threshold: int = 10, # Summarize every N messages
|
| 415 |
+
recent_messages_count: int = 5 # Keep this many recent messages verbatim
|
| 416 |
+
):
|
| 417 |
+
self.user_id = user_id
|
| 418 |
+
self.session_id = session_id
|
| 419 |
+
self.db = db_connection
|
| 420 |
+
self.llm = llm_client
|
| 421 |
+
self.summary_threshold = summary_threshold
|
| 422 |
+
self.recent_messages_count = recent_messages_count
|
| 423 |
+
self._db_type = None
|
| 424 |
+
|
| 425 |
+
self._cached_summary: Optional[str] = None
|
| 426 |
+
self._messages_since_summary: int = 0
|
| 427 |
+
|
| 428 |
+
if self.db:
|
| 429 |
+
self._db_type = self.db.db_type.value
|
| 430 |
+
self._ensure_tables()
|
| 431 |
+
self._load_state()
|
| 432 |
+
|
| 433 |
+
def _ensure_tables(self):
|
| 434 |
+
"""Create summary table if it doesn't exist."""
|
| 435 |
+
try:
|
| 436 |
+
ddl = get_summary_table_ddl(self._db_type)
|
| 437 |
+
self.db.execute_write(ddl)
|
| 438 |
+
except Exception as e:
|
| 439 |
+
logger.warning(f"Failed to create summary table: {e}")
|
| 440 |
+
|
| 441 |
+
def _load_state(self):
|
| 442 |
+
"""Load existing summary state from database (per-user, not per-session)."""
|
| 443 |
+
try:
|
| 444 |
+
query = """
|
| 445 |
+
SELECT summary, message_count FROM _chatbot_user_summaries
|
| 446 |
+
WHERE user_id = :user_id
|
| 447 |
+
"""
|
| 448 |
+
rows = self.db.execute_query(query, {
|
| 449 |
+
"user_id": self.user_id
|
| 450 |
+
})
|
| 451 |
+
if rows:
|
| 452 |
+
self._cached_summary = rows[0].get('summary')
|
| 453 |
+
self._messages_since_summary = 0 # Reset since we loaded
|
| 454 |
+
logger.debug(f"Loaded summary for user {self.user_id}")
|
| 455 |
+
except Exception as e:
|
| 456 |
+
logger.warning(f"Failed to load summary state: {e}")
|
| 457 |
+
|
| 458 |
+
def set_llm_client(self, llm_client):
|
| 459 |
+
"""Set the LLM client for summarization."""
|
| 460 |
+
self.llm = llm_client
|
| 461 |
+
|
| 462 |
+
def on_message_added(self, message_count: int):
    """
    Notify the memory that a message was stored; summarize when due.

    Args:
        message_count: Current total number of messages in the conversation
    """
    self._messages_since_summary = self._messages_since_summary + 1

    # Summarize once enough new messages have accumulated.
    summarization_due = self._messages_since_summary >= self.summary_threshold
    if summarization_due:
        self._trigger_summarization()
|
| 474 |
+
|
| 475 |
+
def _trigger_summarization(self):
    """Trigger summarization of the conversation."""
    # Both collaborators are optional at construction time; bail out
    # gracefully instead of raising so message flow is never blocked.
    if not self.llm:
        logger.warning("Cannot summarize: No LLM client configured")
        return

    if not self.db:
        logger.warning("Cannot summarize: No database connection")
        return

    try:
        # Get messages that need to be summarized.
        # NOTE(review): this re-reads the ENTIRE session history each time,
        # not only messages since the last summary — confirm that is intended.
        query = """
            SELECT role, content FROM _chatbot_memory
            WHERE user_id = :user_id AND session_id = :session_id
            ORDER BY created_at ASC
        """
        rows = self.db.execute_query(query, {
            "user_id": self.user_id,
            "session_id": self.session_id
        })

        if not rows:
            return

        # Format conversation for summarization
        conversation_text = self._format_messages_for_summary(rows)

        # Generate summary
        if self._cached_summary:
            # Incremental update: fold the new messages into the prior summary
            prompt = self.INCREMENTAL_SUMMARY_PROMPT.format(
                existing_summary=self._cached_summary,
                new_messages=conversation_text
            )
        else:
            # Fresh summary
            prompt = self.SUMMARIZATION_PROMPT.format(conversation=conversation_text)

        messages = [
            {"role": "system", "content": "You are a helpful assistant that creates concise conversation summaries."},
            {"role": "user", "content": prompt}
        ]

        summary = self.llm.chat(messages)

        # Save to database
        self._save_summary(summary, len(rows))

        # Update in-memory state only after the persist attempt.
        self._cached_summary = summary
        self._messages_since_summary = 0

        logger.info(f"Generated summary for user {self.user_id}")

    except Exception as e:
        # Summarization is an optimization; a failure must not break chat.
        logger.error(f"Summarization failed: {e}")
|
| 531 |
+
|
| 532 |
+
def _format_messages_for_summary(self, messages: List[Dict]) -> str:
|
| 533 |
+
"""Format messages as text for summarization."""
|
| 534 |
+
lines = []
|
| 535 |
+
for msg in messages:
|
| 536 |
+
role = msg.get('role', 'unknown').upper()
|
| 537 |
+
content = msg.get('content', '')
|
| 538 |
+
lines.append(f"{role}: {content}")
|
| 539 |
+
return "\n\n".join(lines)
|
| 540 |
+
|
| 541 |
+
def _save_summary(self, summary: str, message_count: int):
    """Save or update summary in database (per-user)."""
    try:
        # Upsert syntax differs per backend (e.g. ON CONFLICT vs
        # ON DUPLICATE KEY), hence the per-db-type query helper.
        query = get_upsert_summary_query(self._db_type)
        self.db.execute_write(query, {
            "user_id": self.user_id,
            "summary": summary,
            "message_count": message_count
        })
    except Exception as e:
        logger.error(f"Failed to save summary: {e}")
|
| 552 |
+
|
| 553 |
+
def get_summary(self) -> Optional[str]:
    """Get the current conversation summary (None if nothing summarized yet)."""
    return self._cached_summary
|
| 556 |
+
|
| 557 |
+
def get_context_for_llm(self, recent_messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Get optimized context for LLM calls.

    Combines the summary (if available) with recent messages for optimal
    token usage while maintaining context.

    Args:
        recent_messages: List of recent messages to include verbatim

    Returns:
        List of messages with summary prepended as system context
    """
    context_messages = []

    # Add summary as system context if available
    if self._cached_summary:
        summary_context = f"""CONVERSATION SUMMARY (previous context):
{self._cached_summary}

Use this summary to understand the conversation history and context for follow-up questions."""
        context_messages.append({
            "role": "system",
            "content": summary_context
        })

    # Add recent messages verbatim.
    # BUG FIX: a plain [-count:] slice returns the FULL list when
    # recent_messages_count == 0 (lst[-0:] == lst[0:]); guard so that a
    # zero count really means "no verbatim messages".
    if self.recent_messages_count > 0:
        context_messages.extend(recent_messages[-self.recent_messages_count:])

    return context_messages
|
| 587 |
+
|
| 588 |
+
def force_summarize(self):
    """Force immediate summarization regardless of threshold."""
    # Delegates to the same path the threshold uses; no-ops safely when
    # no LLM client or database connection is configured.
    self._trigger_summarization()
|
| 591 |
+
|
| 592 |
+
def clear_summary(self):
    """Clear the summary for this user."""
    # Reset in-memory state first so callers observe the effect even if
    # the database delete below fails.
    self._cached_summary = None
    self._messages_since_summary = 0

    if self.db:
        try:
            query = "DELETE FROM _chatbot_user_summaries WHERE user_id = :user_id"
            self.db.execute_write(query, {
                "user_id": self.user_id
            })
            logger.info(f"Cleared summary for user: {self.user_id}")
        except Exception as e:
            logger.warning(f"Failed to clear summary: {e}")
|
| 606 |
+
|
| 607 |
+
def clear_all_user_summaries(self):
    """Clear all summaries for this user (alias for clear_summary since it's now per-user)."""
    # Kept as a separate name for backward compatibility with callers that
    # predate the per-user summary model.
    self.clear_summary()
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class EnhancedChatMemory(ChatMemory):
    """
    Enhanced ChatMemory with integrated conversation summarization.

    Combines the standard ChatMemory functionality with ConversationSummaryMemory
    for automatic summarization and optimized context retrieval.
    """

    def __init__(
        self,
        session_id: str,
        user_id: str = "default",
        max_messages: int = 20,
        db_connection=None,
        llm_client=None,
        enable_summarization: bool = True,
        summary_threshold: int = 10
    ):
        super().__init__(session_id, user_id, max_messages, db_connection)

        self.enable_summarization = enable_summarization
        # None when summarization is disabled; every summary-related method
        # below null-checks this attribute.
        self.summary_memory: Optional[ConversationSummaryMemory] = None

        if enable_summarization:
            self.summary_memory = ConversationSummaryMemory(
                user_id=user_id,
                session_id=session_id,
                db_connection=db_connection,
                llm_client=llm_client,
                summary_threshold=summary_threshold
            )

    def set_llm_client(self, llm_client):
        """Set the LLM client for summarization."""
        if self.summary_memory:
            self.summary_memory.set_llm_client(llm_client)

    def add_message(self, role: str, content: str, metadata: Dict = None):
        """Add a message and trigger summarization check."""
        super().add_message(role, content, metadata)

        # Notify summary memory of new message so it can decide whether
        # the summarization threshold has been reached.
        if self.summary_memory:
            self.summary_memory.on_message_added(len(self.messages))

    def get_context_messages(self, count: int = 5) -> List[Dict[str, str]]:
        """
        Get context messages with summary integration.

        If summarization is enabled and a summary exists, it will be
        prepended to provide historical context while keeping recent
        messages verbatim.
        """
        # Get base context from parent
        base_context = super().get_context_messages(count)

        # If summarization is enabled, use summary-enhanced context
        if self.summary_memory and self.summary_memory.get_summary():
            # Filter out system messages from base context (we'll add summary separately)
            filtered = [m for m in base_context if m.get("role") != "system"]

            # Get summary-enhanced context
            enhanced = self.summary_memory.get_context_for_llm(filtered)

            # Re-add permanent memory context if it was present.
            # NOTE(review): matching on the literal "IMPORTANT CONTEXT" marker
            # couples this to ChatMemory's message wording — confirm if that
            # wording ever changes.
            for msg in base_context:
                if msg.get("role") == "system" and "IMPORTANT CONTEXT" in msg.get("content", ""):
                    enhanced.insert(0, msg)

            return enhanced

        return base_context

    def get_summary(self) -> Optional[str]:
        """Get the current conversation summary."""
        if self.summary_memory:
            return self.summary_memory.get_summary()
        return None

    def force_summarize(self):
        """Force immediate summarization."""
        if self.summary_memory:
            self.summary_memory.force_summarize()

    def clear(self):
        """Clear session memory but KEEP the summary (long-term memory)."""
        super().clear()
        # NOTE: Summary is intentionally NOT cleared here
        # Summary acts as long-term memory that persists across chat sessions

    def clear_with_summary(self):
        """Clear session memory AND the summary (full reset)."""
        super().clear()
        if self.summary_memory:
            self.summary_memory.clear_summary()

    def clear_user_history(self):
        """Clear all user temp history but KEEP summaries."""
        super().clear_user_history()
        # NOTE: Summaries are intentionally NOT cleared
        # They persist as long-term memory for the user

    def clear_all_including_summaries(self):
        """Clear ALL user data including summaries (complete wipe)."""
        super().clear_user_history()
        if self.summary_memory:
            self.summary_memory.clear_all_user_summaries()
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def create_memory(session_id: str, user_id: str = "default", max_messages: int = 20) -> ChatMemory:
    """Create a standard ChatMemory instance bound to the shared database."""
    # Function-scope import — presumably to avoid an import cycle with the
    # database package at module load time; confirm before hoisting.
    from database import get_db
    db = get_db()
    return ChatMemory(session_id=session_id, user_id=user_id, max_messages=max_messages, db_connection=db)
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def create_enhanced_memory(
    session_id: str,
    user_id: str = "default",
    max_messages: int = 20,
    llm_client=None,
    enable_summarization: bool = True,
    summary_threshold: int = 10
) -> EnhancedChatMemory:
    """
    Create an EnhancedChatMemory with summarization support.

    Args:
        session_id: Unique session identifier
        user_id: User identifier for per-user memory isolation
        max_messages: Maximum messages to keep in short-term memory
        llm_client: LLM client for summarization (can be set later)
        enable_summarization: Whether to enable automatic summarization
        summary_threshold: Summarize after this many messages

    Returns:
        EnhancedChatMemory instance with summarization capabilities
    """
    # Function-scope import — presumably to avoid an import cycle with the
    # database package at module load time; confirm before hoisting.
    from database import get_db
    db = get_db()
    return EnhancedChatMemory(
        session_id=session_id,
        user_id=user_id,
        max_messages=max_messages,
        db_connection=db,
        llm_client=llm_client,
        enable_summarization=enable_summarization,
        summary_threshold=summary_threshold
    )
|
rag/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAG module exports."""
|
| 2 |
+
|
| 3 |
+
from .embeddings import (
|
| 4 |
+
EmbeddingProvider,
|
| 5 |
+
SentenceTransformerEmbedding,
|
| 6 |
+
OpenAIEmbedding,
|
| 7 |
+
get_embedding_provider,
|
| 8 |
+
create_embedding_provider
|
| 9 |
+
)
|
| 10 |
+
from .document_processor import Document, DocumentProcessor, get_document_processor
|
| 11 |
+
from .vector_store import VectorStore, get_vector_store
|
| 12 |
+
from .rag_engine import RAGEngine, get_rag_engine
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"EmbeddingProvider", "SentenceTransformerEmbedding", "OpenAIEmbedding",
|
| 16 |
+
"get_embedding_provider", "create_embedding_provider",
|
| 17 |
+
"Document", "DocumentProcessor", "get_document_processor",
|
| 18 |
+
"VectorStore", "get_vector_store",
|
| 19 |
+
"RAGEngine", "get_rag_engine"
|
| 20 |
+
]
|
rag/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (873 Bytes). View file
|
|
|
rag/__pycache__/document_processor.cpython-311.pyc
ADDED
|
Binary file (6.98 kB). View file
|
|
|
rag/__pycache__/embeddings.cpython-311.pyc
ADDED
|
Binary file (9.78 kB). View file
|
|
|
rag/__pycache__/rag_engine.cpython-311.pyc
ADDED
|
Binary file (5.62 kB). View file
|
|
|
rag/__pycache__/vector_store.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
rag/document_processor.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Document Processor for RAG.
|
| 3 |
+
|
| 4 |
+
Converts database rows into semantic documents for embedding.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import hashlib
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import List, Dict, Any, Optional, Generator
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class Document:
    """A semantic document extracted from a database cell (possibly one chunk of it)."""
    id: str
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    table_name: str = ""
    column_name: str = ""
    primary_key_value: Optional[str] = None
    chunk_index: int = 0
    total_chunks: int = 1

    def __post_init__(self):
        # Derive a stable, deterministic id from the document's location
        # when the caller did not supply one explicitly.
        if self.id:
            return
        key = ":".join(
            str(part) for part in (
                self.table_name,
                self.column_name,
                self.primary_key_value,
                self.chunk_index,
            )
        )
        self.id = hashlib.md5(key.encode()).hexdigest()

    def to_context_string(self) -> str:
        """Render the content prefixed with a human-readable source tag."""
        pk_part = f" (id: {self.primary_key_value})" if self.primary_key_value else ""
        header = f"[Source: {self.table_name}.{self.column_name}{pk_part}]"
        return f"{header}\n{self.content}"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TextChunker:
    """Splits long text into overlapping, sentence-aligned chunks.

    BUG FIX: the previous implementation stored ``chunk_overlap`` but never
    used it — adjacent chunks shared no text despite the class claiming to
    produce "overlapping chunks". Overlap is now honored: each new chunk is
    seeded with up to ``chunk_overlap`` characters of trailing sentences
    from the previous chunk, so context is preserved across boundaries.
    """

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        # chunk_size: soft maximum characters per chunk (a single sentence
        # longer than this still becomes its own chunk).
        # chunk_overlap: max characters of trailing sentences repeated at
        # the start of the next chunk (0 disables overlap).
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # Split after terminal punctuation, before a capitalized word.
        self.sentence_pattern = re.compile(r'(?<=[.!?])\s+(?=[A-Z])')

    def chunk_text(self, text: str) -> List[str]:
        """Return the text split into chunks; [] for empty input, [text] if short."""
        if not text or len(text) <= self.chunk_size:
            return [text] if text else []

        sentences = [s.strip() for s in self.sentence_pattern.split(text) if s.strip()]

        chunks: List[str] = []
        current: List[str] = []
        current_length = 0

        for sentence in sentences:
            if current and current_length + len(sentence) + 1 > self.chunk_size:
                chunks.append(' '.join(current))
                # Seed the next chunk with trailing sentences of the one we
                # just emitted, up to chunk_overlap characters.
                overlap: List[str] = []
                overlap_len = 0
                for prev in reversed(current):
                    if overlap_len + len(prev) + 1 > self.chunk_overlap:
                        break
                    overlap.insert(0, prev)
                    overlap_len += len(prev) + 1
                current = overlap + [sentence]
                current_length = overlap_len + len(sentence)
            else:
                current.append(sentence)
                current_length += len(sentence) + 1

        if current:
            chunks.append(' '.join(current))

        # Safety net: never return empty output for non-empty input.
        return chunks if chunks else [text]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DocumentProcessor:
    """Converts database rows into semantic documents ready for embedding."""

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        self.chunker = TextChunker(chunk_size, chunk_overlap)

    def process_row(
        self, row: Dict[str, Any], table_name: str,
        text_columns: List[str], primary_key_column: Optional[str] = None
    ) -> List[Document]:
        """Turn one row into zero or more chunked Documents."""
        pk_value = str(row.get(primary_key_column, "")) if primary_key_column else None

        docs: List[Document] = []
        for column_name in text_columns:
            raw = row.get(column_name)
            # Only non-empty string cells are indexable.
            if not isinstance(raw, str):
                continue
            cleaned = raw.strip()
            if not cleaned:
                continue

            pieces = self.chunker.chunk_text(cleaned)
            docs.extend(
                Document(
                    id="", content=piece, table_name=table_name,
                    column_name=column_name, primary_key_value=pk_value,
                    chunk_index=idx, total_chunks=len(pieces),
                    metadata={"table": table_name, "column": column_name, "pk": pk_value}
                )
                for idx, piece in enumerate(pieces)
            )

        return docs

    def process_rows(
        self, rows: List[Dict[str, Any]], table_name: str,
        text_columns: List[str], primary_key_column: Optional[str] = None
    ) -> Generator[Document, None, None]:
        """Lazily yield Documents for every row in *rows*."""
        for row in rows:
            yield from self.process_row(row, table_name, text_columns, primary_key_column)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_document_processor(chunk_size: int = 500, chunk_overlap: int = 50) -> DocumentProcessor:
    """Factory for a DocumentProcessor with the given chunking parameters."""
    return DocumentProcessor(chunk_size, chunk_overlap)
|
rag/embeddings.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding Generation Module.
|
| 3 |
+
|
| 4 |
+
Supports:
|
| 5 |
+
- Sentence Transformers (local, free)
|
| 6 |
+
- OpenAI Embeddings (cloud, paid)
|
| 7 |
+
|
| 8 |
+
Configurable via environment variables.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from abc import ABC, abstractmethod
|
| 13 |
+
from typing import List, Optional
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class EmbeddingProvider(ABC):
    """Abstract base class for embedding providers."""

    # Concrete implementations in this module return float numpy arrays:
    # shape (dim,) from embed_text and (n, dim) from embed_texts.

    @abstractmethod
    def embed_text(self, text: str) -> np.ndarray:
        """Generate embedding for a single text."""
        pass

    @abstractmethod
    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for multiple texts."""
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Return the embedding dimension."""
        pass
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SentenceTransformerEmbedding(EmbeddingProvider):
    """
    Sentence Transformers embedding provider.

    Uses local models, no API key required.
    Default: all-MiniLM-L6-v2 (384 dimensions)
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize the Sentence Transformer model.

        Args:
            model_name: HuggingFace model name
        """
        # The model is loaded lazily on first access (see the `model`
        # property), so construction is cheap and the heavy import of
        # sentence-transformers is deferred.
        self._dimension: Optional[int] = None
        self.model_name = model_name
        self._model = None
        self._dimension = None

    @property
    def model(self):
        """Lazy load the model."""
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
                logger.info(f"Loading embedding model: {self.model_name}")
                self._model = SentenceTransformer(self.model_name)
                self._dimension = self._model.get_sentence_embedding_dimension()
                logger.info(f"Model loaded. Embedding dimension: {self._dimension}")
            except ImportError:
                raise ImportError(
                    "sentence-transformers is required. Install with: pip install sentence-transformers"
                )
        return self._model

    @property
    def dimension(self) -> int:
        """Get embedding dimension."""
        if self._dimension is None:
            _ = self.model  # Force model load
        return self._dimension

    def embed_text(self, text: str) -> np.ndarray:
        """Generate embedding for a single text."""
        return self.model.encode(text, convert_to_numpy=True)

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for multiple texts."""
        # Progress bar only for large batches, to keep logs quiet.
        return self.model.encode(texts, convert_to_numpy=True, show_progress_bar=len(texts) > 100)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class OpenAIEmbedding(EmbeddingProvider):
    """
    OpenAI embedding provider.

    Uses OpenAI API, requires API key.
    Default: text-embedding-3-small (1536 dimensions)
    """

    # Known model → embedding dimension; unknown models default to 1536.
    DIMENSION_MAP = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536
    }

    def __init__(self, api_key: str, model_name: str = "text-embedding-3-small"):
        """
        Initialize OpenAI embedding client.

        Args:
            api_key: OpenAI API key
            model_name: OpenAI embedding model name
        """
        self.api_key = api_key
        self.model_name = model_name
        self._client = None
        self._dimension = self.DIMENSION_MAP.get(model_name, 1536)

    @property
    def client(self):
        """Lazily construct the OpenAI client on first use."""
        if self._client is None:
            try:
                from openai import OpenAI
            except ImportError:
                raise ImportError(
                    "openai is required. Install with: pip install openai"
                )
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    @property
    def dimension(self) -> int:
        """Get embedding dimension."""
        return self._dimension

    def embed_text(self, text: str) -> np.ndarray:
        """Generate embedding for a single text."""
        response = self.client.embeddings.create(
            input=text,
            model=self.model_name
        )
        return np.array(response.data[0].embedding, dtype=np.float32)

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for multiple texts (batch)."""
        # OpenAI API supports batching up to 2048 inputs
        batch_size = 100
        collected: List[np.ndarray] = []

        start = 0
        while start < len(texts):
            chunk = texts[start:start + batch_size]
            response = self.client.embeddings.create(
                input=chunk,
                model=self.model_name
            )
            collected.extend(
                np.array(item.embedding, dtype=np.float32) for item in response.data
            )
            start += batch_size

        return np.array(collected)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def create_embedding_provider(
    provider_type: str = "sentence_transformers",
    model_name: Optional[str] = None,
    api_key: Optional[str] = None
) -> EmbeddingProvider:
    """
    Factory function to create the appropriate embedding provider.

    Args:
        provider_type: "sentence_transformers" or "openai"
        model_name: Model name (optional, uses defaults)
        api_key: API key for OpenAI (required if using OpenAI)

    Returns:
        Configured EmbeddingProvider instance

    Raises:
        ValueError: If "openai" is requested without an API key.
    """
    if provider_type == "openai":
        if not api_key:
            raise ValueError("OpenAI API key is required for OpenAI embeddings")
        return OpenAIEmbedding(
            api_key=api_key,
            model_name=model_name or "text-embedding-3-small"
        )

    # FIX: previously any other value fell through silently, so a typo such
    # as "open_ai" quietly selected the local provider. Keep the fallback
    # for backward compatibility but make the surprise visible in the logs.
    if provider_type != "sentence_transformers":
        logger.warning(
            "Unknown embedding provider type %r; falling back to sentence_transformers",
            provider_type
        )
    return SentenceTransformerEmbedding(
        model_name=model_name or "sentence-transformers/all-MiniLM-L6-v2"
    )
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Global embedding provider instance
|
| 191 |
+
_embedding_provider: Optional[EmbeddingProvider] = None
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_embedding_provider() -> EmbeddingProvider:
    """Get or create the global embedding provider."""
    # NOTE(review): not guarded by a lock — two threads racing on first
    # call could each build a provider; confirm single-threaded startup.
    global _embedding_provider
    if _embedding_provider is None:
        # Default to sentence transformers (free, local)
        _embedding_provider = SentenceTransformerEmbedding()
    return _embedding_provider
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def set_embedding_provider(provider: EmbeddingProvider):
    """Set the global embedding provider (e.g. to switch to OpenAI embeddings)."""
    global _embedding_provider
    _embedding_provider = provider
|
rag/rag_engine.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG Engine - Orchestrates the retrieval-augmented generation pipeline.
|
| 3 |
+
|
| 4 |
+
Handles:
|
| 5 |
+
- Automatic indexing of text columns from the database
|
| 6 |
+
- Semantic retrieval using FAISS
|
| 7 |
+
- Context building for the LLM
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
from .document_processor import Document, get_document_processor
|
| 14 |
+
from .vector_store import VectorStore, get_vector_store
|
| 15 |
+
from .embeddings import get_embedding_provider
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RAGEngine:
|
| 21 |
+
"""Main RAG engine for semantic retrieval from database text."""
|
| 22 |
+
|
| 23 |
+
def __init__(self, vector_store: Optional[VectorStore] = None):
|
| 24 |
+
self.vector_store = vector_store or get_vector_store()
|
| 25 |
+
self.doc_processor = get_document_processor()
|
| 26 |
+
self.indexed_tables: Dict[str, bool] = {}
|
| 27 |
+
|
| 28 |
+
def index_table(
|
| 29 |
+
self,
|
| 30 |
+
table_name: str,
|
| 31 |
+
rows: List[Dict[str, Any]],
|
| 32 |
+
text_columns: List[str],
|
| 33 |
+
primary_key_column: Optional[str] = None
|
| 34 |
+
) -> int:
|
| 35 |
+
"""
|
| 36 |
+
Index text data from a table.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Number of documents indexed
|
| 40 |
+
"""
|
| 41 |
+
documents = list(self.doc_processor.process_rows(
|
| 42 |
+
rows, table_name, text_columns, primary_key_column
|
| 43 |
+
))
|
| 44 |
+
|
| 45 |
+
if documents:
|
| 46 |
+
self.vector_store.add_documents(documents)
|
| 47 |
+
self.indexed_tables[table_name] = True
|
| 48 |
+
logger.info(f"Indexed {len(documents)} documents from {table_name}")
|
| 49 |
+
|
| 50 |
+
return len(documents)
|
| 51 |
+
|
| 52 |
+
def search(
|
| 53 |
+
self,
|
| 54 |
+
query: str,
|
| 55 |
+
top_k: int = 5,
|
| 56 |
+
table_filter: Optional[List[str]] = None
|
| 57 |
+
) -> List[Tuple[Document, float]]:
|
| 58 |
+
"""
|
| 59 |
+
Search for relevant documents.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
query: Search query
|
| 63 |
+
top_k: Number of results
|
| 64 |
+
table_filter: Optional list of tables to search in
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
List of (document, score) tuples
|
| 68 |
+
"""
|
| 69 |
+
results = self.vector_store.search(query, top_k=top_k * 2)
|
| 70 |
+
|
| 71 |
+
if table_filter:
|
| 72 |
+
results = [
|
| 73 |
+
(doc, score) for doc, score in results
|
| 74 |
+
if doc.table_name in table_filter
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
return results[:top_k]
|
| 78 |
+
|
| 79 |
+
def get_context(
|
| 80 |
+
self,
|
| 81 |
+
query: str,
|
| 82 |
+
top_k: int = 5,
|
| 83 |
+
table_filter: Optional[List[str]] = None
|
| 84 |
+
) -> str:
|
| 85 |
+
"""
|
| 86 |
+
Get formatted context for LLM from search results.
|
| 87 |
+
"""
|
| 88 |
+
results = self.search(query, top_k, table_filter)
|
| 89 |
+
|
| 90 |
+
if not results:
|
| 91 |
+
return "No relevant information found in the database."
|
| 92 |
+
|
| 93 |
+
context_parts = []
|
| 94 |
+
for doc, score in results:
|
| 95 |
+
context_parts.append(doc.to_context_string())
|
| 96 |
+
|
| 97 |
+
return "\n\n---\n\n".join(context_parts)
|
| 98 |
+
|
| 99 |
+
def clear_index(self):
|
| 100 |
+
"""Clear the entire index."""
|
| 101 |
+
self.vector_store.clear()
|
| 102 |
+
self.indexed_tables = {}
|
| 103 |
+
|
| 104 |
+
def save(self):
|
| 105 |
+
"""Save the index to disk."""
|
| 106 |
+
self.vector_store.save()
|
| 107 |
+
|
| 108 |
+
    @property
    def document_count(self) -> int:
        """Total number of documents currently held by the vector store."""
        return len(self.vector_store)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Process-wide singleton; created lazily by get_rag_engine().
_rag_engine: Optional[RAGEngine] = None


def get_rag_engine() -> RAGEngine:
    """Return the shared RAGEngine instance, constructing it on first use."""
    global _rag_engine
    engine = _rag_engine
    if engine is None:
        engine = RAGEngine()
        _rag_engine = engine
    return engine
|
rag/vector_store.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FAISS Vector Store for RAG.
|
| 3 |
+
|
| 4 |
+
Manages the FAISS index for semantic search over database text content.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import pickle
|
| 9 |
+
import os
|
| 10 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import faiss
|
| 15 |
+
except ImportError:
|
| 16 |
+
faiss = None
|
| 17 |
+
|
| 18 |
+
from .document_processor import Document
|
| 19 |
+
from .embeddings import get_embedding_provider, EmbeddingProvider
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class VectorStore:
    """FAISS-based vector store for semantic search.

    Embeddings are L2-normalized and stored in an inner-product index, so
    returned scores are cosine similarities. The documents and their
    id -> position mapping are kept in Python-side structures that are
    pickled next to the FAISS index when save() is called.
    """

    def __init__(
        self,
        embedding_provider: Optional[EmbeddingProvider] = None,
        index_path: str = "./faiss_index"
    ):
        """
        Args:
            embedding_provider: Source of text embeddings; defaults to the
                shared provider from get_embedding_provider().
            index_path: Directory holding index.faiss / documents.pkl.

        Raises:
            ImportError: If faiss is not installed.
        """
        if faiss is None:
            raise ImportError("faiss-cpu is required. Install with: pip install faiss-cpu")

        self.embedding_provider = embedding_provider or get_embedding_provider()
        self.index_path = index_path
        self.dimension = self.embedding_provider.dimension

        self.index: Optional[faiss.IndexFlatIP] = None
        self.documents: List[Document] = []
        # Maps document id -> position in self.documents (== row in the FAISS index).
        self.id_to_idx: Dict[str, int] = {}

        self._initialize_index()

    def _initialize_index(self):
        """Load a persisted index from disk, or create a fresh one.

        A stored index is rejected (and moved aside as *.bak) when it is
        empty, unreadable, or was built with a different embedding dimension.
        """
        index_file = os.path.join(self.index_path, "index.faiss")
        docs_file = os.path.join(self.index_path, "documents.pkl")

        if os.path.exists(index_file) and os.path.exists(docs_file):
            try:
                # A zero-byte index file indicates a previously interrupted save.
                if os.path.getsize(index_file) > 0:
                    self.index = faiss.read_index(index_file)
                    with open(docs_file, 'rb') as f:
                        self.documents = pickle.load(f)
                    self.id_to_idx = {doc.id: i for i, doc in enumerate(self.documents)}

                    # Reject indexes built with a different embedding model.
                    if self.index.d != self.dimension:
                        logger.warning(f"Index dimension mismatch: {self.index.d} != {self.dimension}. Resetting.")
                        raise ValueError("Dimension mismatch")

                    logger.info(f"Loaded index with {len(self.documents)} documents")
                    return
            except Exception as e:
                logger.warning(f"Failed to load index (might be corrupted or memory error): {e}")
                # Move the unusable files aside so a fresh index can be saved.
                # os.replace overwrites a stale .bak (os.rename would fail on
                # Windows if the backup already exists).
                for path in (index_file, docs_file):
                    if os.path.exists(path):
                        try:
                            os.replace(path, path + ".bak")
                        except OSError:
                            # Best-effort backup; a fresh save will overwrite anyway.
                            pass

        # Fresh index: inner product over normalized vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.id_to_idx = {}
        logger.info(f"Created new FAISS index with dimension {self.dimension}")

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Embed and add documents, skipping ids that are already indexed.

        Args:
            documents: Documents to add. Duplicates (by id) against the
                existing index AND within this call are ignored.
            batch_size: Number of documents embedded per provider call.
        """
        if not documents:
            return

        # De-duplicate against the index and within the incoming batch itself.
        seen = set(self.id_to_idx)
        new_docs = []
        for doc in documents:
            if doc.id not in seen:
                seen.add(doc.id)
                new_docs.append(doc)

        if not new_docs:
            logger.info("No new documents to add")
            return

        logger.info(f"Adding {len(new_docs)} documents to index")

        for i in range(0, len(new_docs), batch_size):
            batch = new_docs[i:i + batch_size]
            texts = [doc.content for doc in batch]

            embeddings = self.embedding_provider.embed_texts(texts)
            # faiss requires contiguous float32; providers may return float64.
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

            # Normalize in place so inner product == cosine similarity.
            faiss.normalize_L2(embeddings)

            start_idx = len(self.documents)
            self.index.add(embeddings)

            for j, doc in enumerate(batch):
                self.documents.append(doc)
                self.id_to_idx[doc.id] = start_idx + j

        logger.info(f"Index now contains {len(self.documents)} documents")

    def search(
        self, query: str, top_k: int = 5, threshold: float = 0.0
    ) -> List[Tuple[Document, float]]:
        """Return up to top_k (document, cosine-score) pairs for the query.

        Args:
            query: Natural-language search text.
            top_k: Maximum number of hits to return.
            threshold: Minimum similarity score for a hit to be included.
        """
        if not self.documents:
            return []

        query_embedding = self.embedding_provider.embed_text(query)
        query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
        faiss.normalize_L2(query_embedding)

        k = min(top_k, len(self.documents))
        scores, indices = self.index.search(query_embedding, k)

        results = []
        # faiss pads missing results with index -1; skip those entries.
        for score, idx in zip(scores[0], indices[0]):
            if idx >= 0 and score >= threshold:
                results.append((self.documents[idx], float(score)))

        return results

    def save(self):
        """Persist the FAISS index and document list under index_path."""
        os.makedirs(self.index_path, exist_ok=True)

        index_file = os.path.join(self.index_path, "index.faiss")
        docs_file = os.path.join(self.index_path, "documents.pkl")

        faiss.write_index(self.index, index_file)
        with open(docs_file, 'wb') as f:
            pickle.dump(self.documents, f)

        logger.info(f"Saved index with {len(self.documents)} documents")

    def clear(self):
        """Reset the in-memory index and delete the persisted files."""
        self.index = faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.id_to_idx = {}

        # Delete the on-disk artifacts so a reload does not resurrect old data.
        index_file = os.path.join(self.index_path, "index.faiss")
        docs_file = os.path.join(self.index_path, "documents.pkl")

        for f in [index_file, docs_file]:
            if os.path.exists(f):
                os.remove(f)

        logger.info("Index cleared")

    def __len__(self) -> int:
        """Number of documents currently stored."""
        return len(self.documents)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Process-wide singleton; created lazily by get_vector_store().
_vector_store: Optional[VectorStore] = None


def get_vector_store() -> VectorStore:
    """Return the shared VectorStore, building it lazily on first access."""
    global _vector_store
    store = _vector_store
    if store is None:
        store = VectorStore()
        _vector_store = store
    return store
|
requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Schema-Agnostic Database Chatbot
|
| 2 |
+
# Multi-Database Support: MySQL, PostgreSQL, SQLite
|
| 3 |
+
|
| 4 |
+
# Core dependencies
|
| 5 |
+
streamlit>=1.30.0
|
| 6 |
+
sqlalchemy>=2.0.0
|
| 7 |
+
|
| 8 |
+
# Database drivers
|
| 9 |
+
pymysql>=1.1.0 # MySQL driver
|
| 10 |
+
psycopg2-binary>=2.9.9 # PostgreSQL driver
|
| 11 |
+
# SQLite is built into Python - no driver needed
|
| 12 |
+
|
| 13 |
+
# RAG dependencies
|
| 14 |
+
faiss-cpu>=1.7.4
|
| 15 |
+
sentence-transformers>=2.2.2
|
| 16 |
+
|
| 17 |
+
# LLM dependencies
|
| 18 |
+
groq>=0.4.0 # FREE API!
|
| 19 |
+
openai>=1.0.0 # Optional, for OpenAI provider
|
| 20 |
+
|
| 21 |
+
# For local models (optional)
|
| 22 |
+
# transformers>=4.36.0
|
| 23 |
+
# torch>=2.0.0
|
| 24 |
+
|
| 25 |
+
# SQL parsing and validation
|
| 26 |
+
sqlparse>=0.4.4
|
| 27 |
+
|
| 28 |
+
# Utilities
|
| 29 |
+
python-dotenv>=1.0.0
|
| 30 |
+
numpy>=1.24.0
|
| 31 |
+
pandas>=2.0.0
|
router.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Query Router - Decides between RAG, SQL, or hybrid approach.
|
| 3 |
+
|
| 4 |
+
Analyzes user intent and routes to the appropriate handler.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from enum import Enum
|
| 9 |
+
from typing import Dict, Any, Optional, Tuple, List
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class QueryType(Enum):
    """Dispatch category assigned to an incoming user query by the router."""
    RAG = "rag"          # Semantic search in text
    SQL = "sql"          # Structured query
    HYBRID = "hybrid"    # Both RAG and SQL
    GENERAL = "general"  # General conversation
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class RoutingDecision:
    """Outcome of routing a user query to a handler."""
    query_type: QueryType   # Chosen handler category
    confidence: float       # Router's confidence, expected in [0.0, 1.0]
    reasoning: str          # Human-readable justification for the choice
    # None is used as a sentinel because a mutable [] default is not allowed
    # on a dataclass field; it is normalized to [] in __post_init__.
    suggested_tables: Optional[List[str]] = None

    def __post_init__(self):
        # Normalize the sentinel so callers can always iterate the list.
        if self.suggested_tables is None:
            self.suggested_tables = []
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class QueryRouter:
    """Routes queries to appropriate handlers based on intent analysis.

    Uses the configured LLM client to classify the query when one is
    available; otherwise (or on LLM failure) falls back to keyword
    heuristics in _heuristic_route().
    """

    # Template sent to the LLM; filled with the live schema text and the
    # user's query. The response format it requests is what
    # _parse_routing_response() expects.
    ROUTING_PROMPT = """Analyze this user query and determine the best approach to answer it.

DATABASE SCHEMA:
{schema}

USER QUERY: {query}

Determine if this query needs:
1. RAG - Semantic search through text content (searching for meanings, concepts, descriptions)
2. SQL - Structured database query (counting, filtering, aggregating, specific lookups)
3. HYBRID - Both semantic search and structured query
4. GENERAL - General conversation not requiring database access

Respond in this exact format:
TYPE: [RAG|SQL|HYBRID|GENERAL]
CONFIDENCE: [0.0-1.0]
TABLES: [comma-separated list of relevant tables, or NONE]
REASONING: [brief explanation]"""

    def __init__(self, llm_client=None):
        # Client may be attached later via set_llm_client().
        self.llm_client = llm_client

    def set_llm_client(self, llm_client):
        """Attach or replace the LLM client used for routing."""
        self.llm_client = llm_client

    def route(self, query: str, schema_context: str) -> RoutingDecision:
        """Analyze query and determine routing.

        Args:
            query: Raw user question.
            schema_context: Textual summary of the database schema.

        Returns:
            RoutingDecision from the LLM when possible, otherwise from
            keyword heuristics.
        """
        if not self.llm_client:
            # Fallback to simple heuristics
            return self._heuristic_route(query)

        prompt = self.ROUTING_PROMPT.format(schema=schema_context, query=query)

        try:
            response = self.llm_client.chat([
                {"role": "system", "content": "You are a query routing assistant."},
                {"role": "user", "content": prompt}
            ])
            return self._parse_routing_response(response)
        except Exception as e:
            # Any client/transport failure degrades gracefully to heuristics.
            logger.warning(f"LLM routing failed: {e}, using heuristics")
            return self._heuristic_route(query)

    def _parse_routing_response(self, response: str) -> RoutingDecision:
        """Parse LLM routing response.

        Expects the TYPE/CONFIDENCE/TABLES/REASONING line format requested
        in ROUTING_PROMPT; missing or malformed fields fall back to safe
        defaults (GENERAL, 0.5, [], "").
        """
        lines = response.strip().split('\n')

        # Defaults used when a field is absent or unparsable.
        query_type = QueryType.GENERAL
        confidence = 0.5
        tables = []
        reasoning = ""

        for line in lines:
            line = line.strip()
            if line.startswith("TYPE:"):
                # NOTE(review): str.replace removes every occurrence of the
                # label, not just the leading one — fine for well-formed output.
                type_str = line.replace("TYPE:", "").strip().upper()
                query_type = QueryType[type_str] if type_str in QueryType.__members__ else QueryType.GENERAL
            elif line.startswith("CONFIDENCE:"):
                try:
                    confidence = float(line.replace("CONFIDENCE:", "").strip())
                except ValueError:
                    confidence = 0.5
            elif line.startswith("TABLES:"):
                tables_str = line.replace("TABLES:", "").strip()
                if tables_str.upper() != "NONE":
                    tables = [t.strip() for t in tables_str.split(",")]
            elif line.startswith("REASONING:"):
                reasoning = line.replace("REASONING:", "").strip()

        return RoutingDecision(query_type, confidence, reasoning, tables)

    def _heuristic_route(self, query: str) -> RoutingDecision:
        """Simple heuristic-based routing when LLM is unavailable.

        Scores the query against SQL-ish and RAG-ish keyword lists and
        routes to the higher-scoring side; equal nonzero scores become
        HYBRID, and everything else falls through to the defaults below.
        """
        query_lower = query.lower()

        # SQL keywords - for structured data retrieval
        sql_keywords = [
            'how many', 'count', 'total', 'average', 'sum', 'max', 'min',
            'list all', 'show all', 'find all', 'get all', 'between',
            'greater than', 'less than', 'equal to', 'top', 'bottom',
            # Data listing patterns
            'what products', 'what customers', 'what orders', 'what items',
            'show me', 'list', 'display', 'give me', 'get me',
            'all products', 'all customers', 'all orders',
            'products do you have', 'customers do you have',
            'from new york', 'from chicago', 'from los angeles',
            # Specific lookups
            'price of', 'cost of', 'stock of', 'quantity',
            'where', 'which', 'who'
        ]

        # RAG keywords - for semantic/conceptual questions
        rag_keywords = [
            'what is the policy', 'explain', 'describe', 'tell me about',
            'meaning of', 'definition', 'why', 'how does', 'what does',
            'similar to', 'return policy', 'shipping policy', 'warranty',
            'support', 'help with', 'information about', 'details about'
        ]

        sql_score = sum(1 for kw in sql_keywords if kw in query_lower)
        rag_score = sum(1 for kw in rag_keywords if kw in query_lower)

        # Boost SQL score for common listing patterns
        if any(word in query_lower for word in ['products', 'customers', 'orders', 'items']):
            if any(word in query_lower for word in ['what', 'show', 'list', 'all', 'have']):
                sql_score += 2

        if sql_score > rag_score:
            return RoutingDecision(QueryType.SQL, 0.8, "SQL query for data retrieval")
        elif rag_score > sql_score:
            return RoutingDecision(QueryType.RAG, 0.8, "Semantic search for concepts")
        elif sql_score > 0 and rag_score > 0:
            # Reached only on equal nonzero scores: genuinely mixed intent.
            return RoutingDecision(QueryType.HYBRID, 0.6, "Mixed query type")
        else:
            # Default to SQL for simple questions about data
            if any(word in query_lower for word in ['products', 'customers', 'orders']):
                return RoutingDecision(QueryType.SQL, 0.6, "Default to SQL for data tables")
            return RoutingDecision(QueryType.RAG, 0.5, "Default to semantic search")
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Process-wide singleton; created lazily by get_query_router().
_router: Optional[QueryRouter] = None


def get_query_router() -> QueryRouter:
    """Return the shared QueryRouter, creating it on first call."""
    global _router
    router = _router
    if router is None:
        router = QueryRouter()
        _router = router
    return router
|
sql/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQL module exports."""
|
| 2 |
+
|
| 3 |
+
from .validator import SQLValidator, SQLValidationError, get_sql_validator
|
| 4 |
+
from .generator import SQLGenerator, get_sql_generator
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"SQLValidator", "SQLValidationError", "get_sql_validator",
|
| 8 |
+
"SQLGenerator", "get_sql_generator"
|
| 9 |
+
]
|
sql/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (480 Bytes). View file
|
|
|
sql/__pycache__/generator.cpython-311.pyc
ADDED
|
Binary file (6.94 kB). View file
|
|
|
sql/__pycache__/validator.cpython-311.pyc
ADDED
|
Binary file (7.14 kB). View file
|
|
|
sql/generator.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Text-to-SQL Generator - Multi-Database Support.
|
| 3 |
+
|
| 4 |
+
Uses LLM to generate SQL queries from natural language,
|
| 5 |
+
with dynamic schema context. Supports MySQL, PostgreSQL, and SQLite.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Optional, Dict, Any, List, Tuple
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_sql_dialect(db_type: str) -> str:
    """Map an internal database type key to a human-readable dialect name.

    Unknown keys fall back to the generic label "SQL".
    """
    if db_type == "mysql":
        return "MySQL"
    if db_type == "postgresql":
        return "PostgreSQL"
    if db_type == "sqlite":
        return "SQLite"
    return "SQL"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_dialect_specific_hints(db_type: str) -> str:
    """Return dialect-specific guidance appended to the SQL generation prompt.

    MySQL notes are the fallback for any unrecognized database type.
    """
    hints = {
        "postgresql": """
PostgreSQL-SPECIFIC NOTES:
- Use ILIKE for case-insensitive pattern matching (instead of LIKE)
- String concatenation uses || operator
- Use LIMIT at the end of queries
- Boolean values are TRUE/FALSE (not 1/0)
- Use double quotes for identifiers with special chars, single quotes for strings
""",
        "sqlite": """
SQLite-SPECIFIC NOTES:
- LIKE is case-insensitive for ASCII characters by default
- Use || for string concatenation
- No ILIKE - use LIKE (case-insensitive) or GLOB (case-sensitive)
- Use LIMIT at the end of queries
- Boolean values are 1/0
- Uses strftime() for date functions instead of DATE_FORMAT
""",
    }
    mysql_hints = """
MySQL-SPECIFIC NOTES:
- LIKE is case-insensitive for non-binary strings
- Use CONCAT() for string concatenation
- Use LIMIT at the end of queries
- Boolean values are 1/0
- Use backticks for identifiers with special chars, single quotes for strings
"""
    return hints.get(db_type, mysql_hints)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SQLGenerator:
    """Generates SQL queries from natural language using an LLM.

    Builds a dialect-aware system prompt from the live schema, sends the
    user's question (plus a little chat history) to the LLM, and extracts
    a bare SELECT statement from the response.
    """

    SYSTEM_PROMPT_TEMPLATE = """You are a SQL expert. Generate {dialect} SELECT queries based on user questions.

RULES:
1. ONLY generate SELECT statements.
2. NEVER use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, or TRUNCATE.
3. Always include a LIMIT clause (max 50 rows unless specified).
4. Use table and column names EXACTLY as shown in the schema.
5. AMBIGUITY: If the user asks for a category, type, or specific value, and you are unsure which column it belongs to:
   - Check multiple likely columns (e.g., `category`, `sub_category`, `type`, `description`).
   - Use pattern matching for flexibility.
   - Use `OR` to combine multiple column checks.
6. DATA AWARENESS: In footwear databases, specific types like 'Formal', 'Casual', or 'Sports' often appear in `sub_category` OR `category`. Check both if available.
7. Return ONLY the SQL query, no explanations.

{dialect_hints}

DATABASE SCHEMA:
{schema}

Generate a single {dialect} SELECT query to answer the user's question."""

    def __init__(self, llm_client=None, db_type: str = "mysql"):
        """
        Args:
            llm_client: Object exposing chat(messages) -> str; may be
                injected later via set_llm_client().
            db_type: One of "mysql", "postgresql", "sqlite".
        """
        self.llm_client = llm_client
        self.db_type = db_type

    def set_llm_client(self, llm_client):
        """Inject or replace the LLM client used for generation."""
        self.llm_client = llm_client

    def set_db_type(self, db_type: str):
        """Set the database type for SQL generation."""
        self.db_type = db_type

    def generate(
        self,
        question: str,
        schema_context: str,
        chat_history: Optional[List[Dict[str, str]]] = None
    ) -> Tuple[str, str]:
        """
        Generate SQL from natural language.

        Args:
            question: The user's natural-language question.
            schema_context: Schema description injected into the prompt.
            chat_history: Optional prior chat messages for context.

        Returns:
            Tuple of (sql_query, raw LLM response).

        Raises:
            ValueError: If no LLM client has been configured.
        """
        if not self.llm_client:
            raise ValueError("LLM client not configured")

        dialect = get_sql_dialect(self.db_type)
        dialect_hints = get_dialect_specific_hints(self.db_type)

        system_prompt = self.SYSTEM_PROMPT_TEMPLATE.format(
            dialect=dialect,
            dialect_hints=dialect_hints,
            schema=schema_context
        )

        messages = [{"role": "system", "content": system_prompt}]

        if chat_history:
            # Last 3 exchanges only, to keep the prompt small.
            messages.extend(chat_history[-3:])

        messages.append({"role": "user", "content": question})

        response = self.llm_client.chat(messages)

        # Extract SQL from response
        sql = self._extract_sql(response)

        return sql, response

    def _extract_sql(self, response: str) -> str:
        """Extract the SQL query text from an LLM response.

        Tries, in order: a fenced ```sql code block, then a bare SELECT
        statement; falls back to the stripped response unchanged.
        """
        # BUG FIX: these patterns previously used '\\s' inside raw strings,
        # which matches a literal backslash followed by 's' and therefore
        # never matched real LLM output (responses leaked through verbatim,
        # backticks and all).
        code_block = re.search(r'```(?:sql)?\s*(.*?)```', response, re.DOTALL | re.IGNORECASE)
        if code_block:
            return code_block.group(1).strip()

        # Look for a bare SELECT statement, up to the first ';' or end of text.
        select_match = re.search(
            r'(SELECT\s+.+?(?:;|$))',
            response,
            re.DOTALL | re.IGNORECASE
        )
        if select_match:
            return select_match.group(1).strip().rstrip(';')

        return response.strip()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# Process-wide singleton; created lazily by get_sql_generator().
_generator: Optional[SQLGenerator] = None


def get_sql_generator(db_type: str = "mysql") -> SQLGenerator:
    """Return the shared SQLGenerator, retargeted at the requested dialect."""
    global _generator
    gen = _generator
    if gen is None:
        gen = SQLGenerator(db_type=db_type)
        _generator = gen
    else:
        # Reuse the existing instance but switch its dialect.
        gen.set_db_type(db_type)
    return gen
|
sql/validator.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL Validator - Security layer for SQL queries.
|
| 3 |
+
|
| 4 |
+
Ensures ONLY safe SELECT queries are executed.
|
| 5 |
+
Validates against whitelist and blocks dangerous operations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import re
|
| 10 |
+
from typing import List, Tuple, Optional, Set
|
| 11 |
+
import sqlparse
|
| 12 |
+
from sqlparse.sql import Statement, Token, Identifier, IdentifierList
|
| 13 |
+
from sqlparse.tokens import Keyword, DML
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SQLValidationError(Exception):
    """Signals that a SQL query failed safety validation."""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SQLValidator:
    """Validates SQL queries for safety before execution.

    Defense-in-depth layer: only single SELECT statements are allowed,
    dangerous keywords/patterns are rejected, table access can be
    restricted to a whitelist, and a LIMIT clause is enforced.
    """

    # Statement types / functions that must never appear in a query.
    FORBIDDEN_KEYWORDS = {
        'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER',
        'TRUNCATE', 'GRANT', 'REVOKE', 'EXECUTE', 'EXEC',
        'INTO OUTFILE', 'INTO DUMPFILE', 'LOAD_FILE', 'LOAD DATA'
    }

    # Regex patterns for injection techniques that keyword matching misses.
    FORBIDDEN_PATTERNS = [
        r'INTO\s+OUTFILE',
        r'INTO\s+DUMPFILE',
        r'LOAD_FILE\s*\(',
        r'LOAD\s+DATA',
        r';\s*(?:DROP|DELETE|UPDATE|INSERT)',  # Multi-statement attacks
        r'--',  # SQL comments (potential injection)
        r'/\*.*\*/',  # Block comments
    ]

    def __init__(self, allowed_tables: Optional[Set[str]] = None, max_limit: int = 100):
        """
        Args:
            allowed_tables: Whitelist of table names; an empty set disables
                the table check.
            max_limit: Hard cap applied to (or injected into) LIMIT clauses.
        """
        self.allowed_tables = allowed_tables or set()
        self.max_limit = max_limit
        self._compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.FORBIDDEN_PATTERNS]
        # BUG FIX: the old check used substring matching on the uppercased
        # SQL, so harmless identifiers like `created_at` (contains CREATE)
        # or `last_update` (contains UPDATE) were rejected. Word-boundary
        # patterns match only whole keywords; multi-word keywords tolerate
        # any whitespace between their parts.
        self._keyword_patterns = [
            (kw, re.compile(
                r'\b' + r'\s+'.join(re.escape(part) for part in kw.split()) + r'\b',
                re.IGNORECASE
            ))
            for kw in self.FORBIDDEN_KEYWORDS
        ]

    def set_allowed_tables(self, tables: List[str]):
        """Set the whitelist of allowed tables."""
        self.allowed_tables = set(tables)

    def validate(self, sql: str) -> Tuple[bool, str, Optional[str]]:
        """
        Validate SQL query for safety.

        Args:
            sql: Candidate query text.

        Returns:
            Tuple of (is_valid, message, sanitized_sql). sanitized_sql is
            None when validation fails, otherwise the query with a LIMIT
            clause enforced.
        """
        if not sql or not sql.strip():
            return False, "Empty SQL query", None

        sql = sql.strip()

        # Reject known injection patterns before any parsing.
        for pattern in self._compiled_patterns:
            if pattern.search(sql):
                return False, "Forbidden pattern detected in query", None

        # Parse SQL
        try:
            parsed = sqlparse.parse(sql)
        except Exception as e:
            return False, f"Failed to parse SQL: {e}", None

        if not parsed:
            return False, "Failed to parse SQL query", None

        # Only allow single statements
        if len(parsed) > 1:
            return False, "Multiple SQL statements not allowed", None

        statement = parsed[0]

        # Check statement type
        stmt_type = statement.get_type()
        if stmt_type != 'SELECT':
            return False, f"Only SELECT statements allowed, got: {stmt_type}", None

        # Whole-word scan for forbidden keywords (see __init__).
        for keyword, kw_pattern in self._keyword_patterns:
            if kw_pattern.search(sql):
                return False, f"Forbidden keyword detected: {keyword}", None

        # Extract and validate tables
        tables = self._extract_tables(statement)
        if self.allowed_tables:
            invalid_tables = tables - self.allowed_tables
            if invalid_tables:
                return False, f"Access denied to tables: {invalid_tables}", None

        # Ensure LIMIT clause exists
        sanitized = self._ensure_limit(sql)

        return True, "Query validated successfully", sanitized

    def _extract_tables(self, statement: "Statement") -> Set[str]:
        """Extract table names from a SELECT statement using regex.

        Looks for identifiers following FROM and JOIN; the string
        annotation avoids evaluating the sqlparse type at class-definition
        time.
        """
        tables = set()
        sql = str(statement)

        # Pattern: FROM table_name or JOIN table_name
        from_pattern = re.compile(
            r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)',
            re.IGNORECASE
        )
        join_pattern = re.compile(
            r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)',
            re.IGNORECASE
        )

        # Find all FROM tables
        for match in from_pattern.finditer(sql):
            tables.add(match.group(1))

        # Find all JOIN tables
        for match in join_pattern.finditer(sql):
            tables.add(match.group(1))

        return tables

    def _ensure_limit(self, sql: str) -> str:
        """Cap an existing LIMIT at max_limit, or append one if absent.

        A query containing the word LIMIT without a numeric argument is
        returned unchanged (matching the original behavior).
        """
        # \b prevents identifiers such as `limit_value` from suppressing
        # the injected LIMIT (the old substring check did).
        if re.search(r'\bLIMIT\b', sql, re.IGNORECASE):
            limit_match = re.search(r'LIMIT\s+(\d+)', sql, re.IGNORECASE)
            if limit_match:
                current_limit = int(limit_match.group(1))
                if current_limit > self.max_limit:
                    # Replace with max limit
                    sql = re.sub(
                        r'LIMIT\s+\d+',
                        f'LIMIT {self.max_limit}',
                        sql,
                        flags=re.IGNORECASE
                    )
            return sql
        else:
            # Add LIMIT clause
            sql = sql.rstrip(';').strip()
            return f"{sql} LIMIT {self.max_limit}"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Process-wide singleton; created lazily by get_sql_validator().
_validator: Optional[SQLValidator] = None


def get_sql_validator() -> SQLValidator:
    """Return the shared SQLValidator, creating it on first call."""
    global _validator
    validator = _validator
    if validator is None:
        validator = SQLValidator()
        _validator = validator
    return validator
|