Spaces:
Sleeping
Sleeping
Commit ·
52adb86
0
Parent(s):
Project Completion Commit
Browse files- .dockerignore +6 -0
- .gitignore +7 -0
- .python-version +1 -0
- Dockerfile +41 -0
- README.md +0 -0
- app.py +114 -0
- pyproject.toml +24 -0
- requirements.txt +20 -0
- src/embedding.py +71 -0
- src/graph.py +211 -0
- src/main.py +73 -0
- src/retrieval.py +70 -0
- src/scheme.py +35 -0
- start.sh +7 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git/
|
| 2 |
+
.gitignore
|
| 3 |
+
.env
|
| 4 |
+
.venv/
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
.venv/
|
| 5 |
+
.env/
|
| 6 |
+
venv/.env
|
| 7 |
+
.env
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. Base Image
|
| 2 |
+
# Keep in sync with requires-python (>=3.12) in pyproject.toml and with .python-version.
FROM python:3.12-slim
|
| 3 |
+
|
| 4 |
+
# 2. Environment Variables for Hugging Face compatibility
|
| 5 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 6 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 7 |
+
HOME=/app \
|
| 8 |
+
PATH=/app/.local/bin:$PATH
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
# 3. Install System Dependencies
|
| 13 |
+
# libpq-dev is for PostgreSQL, curl is for Streamlit health checks
|
| 14 |
+
RUN apt-get update && apt-get install -y \
|
| 15 |
+
build-essential \
|
| 16 |
+
libpq-dev \
|
| 17 |
+
curl \
|
| 18 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 19 |
+
|
| 20 |
+
# 4. Install uv (The blazing fast package manager)
|
| 21 |
+
RUN pip install uv
|
| 22 |
+
|
| 23 |
+
# 5. Cache & Install Python Dependencies
|
| 24 |
+
COPY pyproject.toml uv.lock ./
|
| 25 |
+
RUN uv pip install --system -r pyproject.toml
|
| 26 |
+
|
| 27 |
+
# 6. Copy your application code
|
| 28 |
+
COPY . .
|
| 29 |
+
|
| 30 |
+
# 7. Permissions: Make the script executable
|
| 31 |
+
RUN chmod +x start.sh
|
| 32 |
+
|
| 33 |
+
# 8. Permissions: Hugging Face runs as user 1000, not root!
|
| 34 |
+
RUN chown -R 1000:1000 /app
|
| 35 |
+
USER 1000
|
| 36 |
+
|
| 37 |
+
# 9. Expose Ports (7860 for UI, 8000 for internal API)
|
| 38 |
+
EXPOSE 7860 8000
|
| 39 |
+
|
| 40 |
+
# 10. Start the application
|
| 41 |
+
CMD ["./start.sh"]
|
README.md
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import requests
|
| 3 |
+
import uuid
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
# --- CONFIGURATION ---
|
| 7 |
+
API_URL = "http://localhost:8000" # Your FastAPI server URL
|
| 8 |
+
|
| 9 |
+
st.set_page_config(page_title="Text@SQL Agent", page_icon="🤖", layout="centered")
|
| 10 |
+
|
| 11 |
+
# --- SESSION STATE INITIALIZATION ---
|
| 12 |
+
# This ensures variables survive when Streamlit re-renders the page
|
| 13 |
+
if "thread_id" not in st.session_state:
|
| 14 |
+
st.session_state.thread_id = str(uuid.uuid4()) # Unique session ID for LangGraph memory
|
| 15 |
+
if "user_id" not in st.session_state:
|
| 16 |
+
st.session_state.user_id = "tenant_" + str(uuid.uuid4())[:8]
|
| 17 |
+
if "is_db_connected" not in st.session_state:
|
| 18 |
+
st.session_state.is_db_connected = False
|
| 19 |
+
if "connection_url" not in st.session_state:
|
| 20 |
+
st.session_state.connection_url = ""
|
| 21 |
+
if "chat_history" not in st.session_state:
|
| 22 |
+
st.session_state.chat_history = []
|
| 23 |
+
|
| 24 |
+
# --- SIDEBAR: DATABASE CONNECTION ---
|
| 25 |
+
with st.sidebar:
|
| 26 |
+
st.header("⚙️ Database Setup")
|
| 27 |
+
|
| 28 |
+
# If already connected, disable the input to enforce ONE database connection
|
| 29 |
+
db_input = st.text_input(
|
| 30 |
+
"Enter Database URL:",
|
| 31 |
+
disabled=st.session_state.is_db_connected
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
if not st.session_state.is_db_connected:
|
| 35 |
+
if st.button("Connect & Initialize", type="primary", use_container_width=True):
|
| 36 |
+
if not db_input:
|
| 37 |
+
st.error("Please enter a valid URL.")
|
| 38 |
+
else:
|
| 39 |
+
with st.spinner("Building embeddings and initializing agent..."):
|
| 40 |
+
try:
|
| 41 |
+
# 1. Hit your FastAPI upload endpoint
|
| 42 |
+
payload = {"connection_url": db_input, "user_id": st.session_state.user_id}
|
| 43 |
+
response = requests.post(f"{API_URL}/upload_url", json=payload)
|
| 44 |
+
|
| 45 |
+
if response.status_code == 200:
|
| 46 |
+
# 2. Lock the connection and unlock the chat
|
| 47 |
+
st.session_state.is_db_connected = True
|
| 48 |
+
st.session_state.connection_url = db_input
|
| 49 |
+
|
| 50 |
+
# Because your FastAPI upload uses BackgroundTasks, it returns instantly.
|
| 51 |
+
# We add a 15-second UI buffer here so the Qdrant embeddings have time to finish
|
| 52 |
+
# before the user fires off their first chat question.
|
| 53 |
+
time.sleep(15)
|
| 54 |
+
|
| 55 |
+
st.success("Database connected securely!")
|
| 56 |
+
st.rerun() # Refresh UI to unlock the chat window
|
| 57 |
+
else:
|
| 58 |
+
st.error(f"Failed to connect: {response.text}")
|
| 59 |
+
except requests.exceptions.ConnectionError:
|
| 60 |
+
st.error("🚨 Cannot connect to backend. Is FastAPI running?")
|
| 61 |
+
else:
|
| 62 |
+
st.success("✅ Connected to Database")
|
| 63 |
+
st.caption(f"URL: {st.session_state.connection_url}")
|
| 64 |
+
|
| 65 |
+
# Add a reset button just in case they want to start completely over
|
| 66 |
+
if st.button("Disconnect & Reset", use_container_width=True):
|
| 67 |
+
st.session_state.clear()
|
| 68 |
+
st.rerun()
|
| 69 |
+
|
| 70 |
+
# --- MAIN CHAT INTERFACE ---
|
| 71 |
+
st.title("🗣️ Text2SQL Agent")
|
| 72 |
+
|
| 73 |
+
# The Lock: Do not render the chat if DB is not connected
|
| 74 |
+
if not st.session_state.is_db_connected:
|
| 75 |
+
st.info("👈 Please connect your database in the sidebar to begin analyzing data.")
|
| 76 |
+
else:
|
| 77 |
+
# 1. Display previous chat messages from session state
|
| 78 |
+
for msg in st.session_state.chat_history:
|
| 79 |
+
with st.chat_message(msg["role"]):
|
| 80 |
+
st.markdown(msg["content"])
|
| 81 |
+
|
| 82 |
+
# 2. The Chat Input box
|
| 83 |
+
if user_query := st.chat_input("Ask a question about your data..."):
|
| 84 |
+
|
| 85 |
+
# Immediately display the user's question in the UI
|
| 86 |
+
st.session_state.chat_history.append({"role": "user", "content": user_query})
|
| 87 |
+
with st.chat_message("user"):
|
| 88 |
+
st.markdown(user_query)
|
| 89 |
+
|
| 90 |
+
# 3. Call the LangGraph Backend
|
| 91 |
+
with st.chat_message("assistant"):
|
| 92 |
+
with st.spinner("Analyzing schema and generating SQL..."):
|
| 93 |
+
try:
|
| 94 |
+
payload = {
|
| 95 |
+
"message": user_query,
|
| 96 |
+
"thread_id": st.session_state.thread_id,
|
| 97 |
+
"user_id": st.session_state.user_id,
|
| 98 |
+
"connection_url": st.session_state.connection_url
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
response = requests.post(f"{API_URL}/chat", json=payload)
|
| 102 |
+
|
| 103 |
+
if response.status_code == 200:
|
| 104 |
+
# Extract the final_result from your FastAPI JSON response
|
| 105 |
+
answer = response.json().get("response", "No response found.")
|
| 106 |
+
st.markdown(answer)
|
| 107 |
+
|
| 108 |
+
# Save the assistant's answer to the UI history
|
| 109 |
+
st.session_state.chat_history.append({"role": "assistant", "content": answer})
|
| 110 |
+
else:
|
| 111 |
+
st.error(f"Agent Error: {response.text}")
|
| 112 |
+
|
| 113 |
+
except requests.exceptions.ConnectionError:
|
| 114 |
+
st.error("🚨 Connection dropped. Ensure FastAPI is running.")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "text-to-sql-agent"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"faker>=40.21.0",
|
| 9 |
+
"fastapi>=0.136.3",
|
| 10 |
+
"fastembed>=0.8.0",
|
| 11 |
+
"langchain>=1.3.4",
|
| 12 |
+
"langchain-community>=0.4.2",
|
| 13 |
+
"langchain-core>=1.4.0",
|
| 14 |
+
"langchain-openai>=1.2.2",
|
| 15 |
+
"langgraph>=1.2.4",
|
| 16 |
+
"langgraph-checkpoint-postgres>=3.1.0",
|
| 17 |
+
"langsmith>=0.8.8",
|
| 18 |
+
"psycopg-binary>=3.3.4",
|
| 19 |
+
"python-dotenv>=1.2.2",
|
| 20 |
+
"qdrant-client>=1.18.0",
|
| 21 |
+
"sqlalchemy>=2.0.50",
|
| 22 |
+
"streamlit>=1.58.0",
|
| 23 |
+
"pymysql",
|
| 24 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sqlalchemy
|
| 2 |
+
langchain-core
|
| 3 |
+
qdrant-client
|
| 4 |
+
fastembed
|
| 5 |
+
python-dotenv
|
| 6 |
+
langchain
|
| 7 |
+
langchain-classic
|
| 8 |
+
langchain-community
|
| 9 |
+
langgraph
|
| 10 |
+
langchain-openai
|
| 11 |
+
pydantic
|
| 12 |
+
fastapi
|
| 13 |
+
langgraph-checkpoint-postgres
|
| 14 |
+
uvicorn
|
| 15 |
+
python-multipart
|
| 16 |
+
streamlit
|
| 17 |
+
requests
|
| 18 |
+
psycopg-pool
|
| 19 |
+
langsmith
|
| 20 |
+
pymysql
|
src/embedding.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from qdrant_client import QdrantClient
|
| 2 |
+
from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
|
| 3 |
+
from fastembed import TextEmbedding, SparseTextEmbedding
|
| 4 |
+
import uuid
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
import os
|
| 7 |
+
from src.scheme import create_scheme
|
| 8 |
+
|
| 9 |
+
COLLECTION_NAME = "Text2SQL"
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
qdrant_api = os.getenv("QDRANT_API_KEY")
|
| 14 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 15 |
+
|
| 16 |
+
def create_embeddings(connection_url : str , user_id : str) :
    """Embed the schema of a database and store it in Qdrant for one tenant.

    Ensures the shared collection (named dense + sparse vectors) exists,
    builds one document per table via ``create_scheme``, embeds each
    document with both models and upserts the resulting points tagged with
    ``user_id``.

    Args:
        connection_url: SQLAlchemy URL of the database to introspect.
        user_id: Tenant identifier written into every point's payload.

    Returns:
        None. The function only has side effects on the Qdrant collection.
    """
    import logging  # function-scope import so the file's import block stays untouched

    logger = logging.getLogger(__name__)

    client = QdrantClient(api_key=qdrant_api , url=qdrant_url)

    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")

    if not client.collection_exists(COLLECTION_NAME) :
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config={"dense": VectorParams(size=384, distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams()},
        )

    try:
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="user_id",
            field_schema="keyword",
        )
    except Exception:
        # Index creation stays best-effort (ingestion continues), but the
        # failure is recorded instead of being silently discarded.
        logger.warning("Could not create payload index on 'user_id'.", exc_info=True)

    docs = create_scheme(connection_url)
    if not docs:
        # Nothing to embed or upsert for a database without tables.
        logger.warning("No tables found for user %s; skipping ingestion.", user_id)
        return

    texts = [doc.page_content for doc in docs]

    points = []
    for doc, dense, sparse in zip(docs, dense_model.embed(texts), sparse_model.embed(texts)):
        table_name = doc.metadata.get("table_name")

        # Deterministic id per (tenant, table): ingesting the same schema
        # again overwrites the existing point instead of adding a duplicate.
        point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{user_id}/{table_name}"))

        points.append(
            PointStruct(
                id=point_id,
                vector={
                    "dense": dense.tolist(),
                    "sparse": {
                        'indices': sparse.indices.tolist(),
                        'values': sparse.values.tolist(),
                    },
                },
                payload={
                    'user_id': user_id,
                    'text': doc.page_content,
                    'table_name': table_name,
                },
            )
        )

    client.upsert(collection_name=COLLECTION_NAME, points=points)
|
src/graph.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict , Annotated , List , Optional
|
| 2 |
+
from langgraph.graph.message import add_messages
|
| 3 |
+
from langchain_core.messages import SystemMessage , HumanMessage
|
| 4 |
+
from langchain_openai import ChatOpenAI
|
| 5 |
+
from src.retrieval import retrieve
|
| 6 |
+
import os
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from langgraph.graph import StateGraph, START ,END
|
| 9 |
+
from pydantic import BaseModel , Field
|
| 10 |
+
import datetime
|
| 11 |
+
from langchain_community.utilities import SQLDatabase
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
class State(TypedDict) :
|
| 16 |
+
connection_url : str
|
| 17 |
+
user_id : str
|
| 18 |
+
messages : Annotated[List , add_messages]
|
| 19 |
+
scheme : str
|
| 20 |
+
sql_query : str
|
| 21 |
+
query_result : str
|
| 22 |
+
error : Optional[str]
|
| 23 |
+
retry : int
|
| 24 |
+
final_result : str
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
llm = ChatOpenAI(
|
| 28 |
+
model="openai/gpt-4o-mini",
|
| 29 |
+
openai_api_key=os.getenv("OPENROUTER_API_KEY"),
|
| 30 |
+
openai_api_base="https://openrouter.ai/api/v1",
|
| 31 |
+
temperature=0
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
class sql_query(BaseModel) :
|
| 35 |
+
generated_sql_query : str = Field(...,description="The raw, valid executable SQL query text. Contain absolutely NO markdown wrapping, code blocks, or conversational formatting.")
|
| 36 |
+
|
| 37 |
+
def retrieve_node(state : State) :
    """Graph node: fetch the schema text relevant to the latest message.

    Reads ``messages``, ``connection_url`` and ``user_id`` from the state,
    passes the content of the last message to ``retrieve`` and returns the
    result under the ``scheme`` key.
    """
    latest_question = state.get("messages")[-1].content

    return {
        'scheme' : retrieve(
            state.get("user_id") ,
            latest_question ,
            state.get("connection_url") ,
        )
    }
|
| 47 |
+
|
| 48 |
+
def generate_node(state : State) :
|
| 49 |
+
messages = state.get("messages")
|
| 50 |
+
scheme = state.get("scheme")
|
| 51 |
+
error = state.get("error")
|
| 52 |
+
wrong_query = state.get('sql_query')
|
| 53 |
+
|
| 54 |
+
llm_with_structured_output = llm.with_structured_output(sql_query)
|
| 55 |
+
|
| 56 |
+
history_messages = messages[:-1]
|
| 57 |
+
current_query_string = messages[-1].content
|
| 58 |
+
|
| 59 |
+
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
| 60 |
+
|
| 61 |
+
if history_messages:
|
| 62 |
+
history_text = "\n".join([
|
| 63 |
+
f"{msg.type.capitalize()}: {msg.content}"
|
| 64 |
+
for msg in history_messages
|
| 65 |
+
])
|
| 66 |
+
else:
|
| 67 |
+
history_text = "This is the first user request. No history exists."
|
| 68 |
+
|
| 69 |
+
if error and wrong_query :
|
| 70 |
+
error_context = f"""
|
| 71 |
+
=== 🚨 ERROR CORRECTION MODE 🚨 ===
|
| 72 |
+
Your previous attempt to answer this request failed.
|
| 73 |
+
[PREVIOUS BROKEN QUERY]:
|
| 74 |
+
{wrong_query}
|
| 75 |
+
|
| 76 |
+
[DATABASE ERROR MESSAGE]:
|
| 77 |
+
{error}
|
| 78 |
+
|
| 79 |
+
INSTRUCTION: Analyze the error message and the schema carefully. Fix the syntax, column names, or logic, and generate a CORRECTED query.
|
| 80 |
+
"""
|
| 81 |
+
else :
|
| 82 |
+
error_context = ""
|
| 83 |
+
|
| 84 |
+
system_prompt = SystemMessage(content=f"""You are an expert Data Analyst and Database Engineer.
|
| 85 |
+
Your job is to write highly optimized, perfectly accurate database queries based on user requests.
|
| 86 |
+
|
| 87 |
+
=== DATABASE SCHEMA & DIALECT ===
|
| 88 |
+
Look at the metadata below to identify the targeted database engine dialect and table layout:
|
| 89 |
+
{scheme}
|
| 90 |
+
|
| 91 |
+
=== CONVERSATION HISTORY ===
|
| 92 |
+
Use this previous context to resolve ambiguous terms (e.g., if the user says "filter those by...", look here to see what "those" refers to):
|
| 93 |
+
{history_text}
|
| 94 |
+
{error_context}
|
| 95 |
+
|
| 96 |
+
=== CRITICAL RULES ===
|
| 97 |
+
1. ALIGNMENT: Only use the tables and columns provided in the schema above. Do not hallucinate column names.
|
| 98 |
+
2. DIALECT MATCHING: Look at the 'Dialect:' specified above and write strict queries matching that exact syntax.
|
| 99 |
+
3. JOINS: Pay close attention to the FOREIGN KEY constraints provided in the schema to perform accurate JOINs.
|
| 100 |
+
4. CURRENT DATE: Today's date is {current_date}. Use this exact date for any relative time filters (e.g., "last month", "this year").
|
| 101 |
+
5. CASE SENSITIVITY: When filtering by strings, use case-insensitive comparisons (e.g., LOWER(column) = LOWER('value')) unless instructed otherwise.
|
| 102 |
+
6. SECURITY: NEVER generate DML queries (INSERT, UPDATE, DELETE, DROP). Only generate SELECT statements.
|
| 103 |
+
|
| 104 |
+
=== OUTPUT SELECTION RULES ===
|
| 105 |
+
1. If the user asks WHO / WHICH / WHAT IS THE NAME / identify a person, customer, user, product, company, or entity, return the human-readable name field, not just the ID.
|
| 106 |
+
2. If the schema has both an ID column and a name column, prefer selecting the name column in the final output.
|
| 107 |
+
3. If the name is in another table, use the required JOIN to fetch it.
|
| 108 |
+
4. Only return an ID alone when the user explicitly asks for the ID, or when no name-like field exists in the schema.
|
| 109 |
+
5. For count/number questions, return an aggregate numeric result, not a list of rows.
|
| 110 |
+
6. For "who/which" questions, do not answer with only identifiers if a readable label exists in the schema.
|
| 111 |
+
|
| 112 |
+
=== INSTRUCTIONS ===
|
| 113 |
+
First, think through the necessary tables, filters, joins, and the exact type of answer expected.
|
| 114 |
+
Then, provide the final executable SQL query specifically for the LATEST USER REQUEST.""")
|
| 115 |
+
|
| 116 |
+
final_msg = [
|
| 117 |
+
system_prompt,
|
| 118 |
+
HumanMessage(content=f"LATEST USER REQUEST:\n{current_query_string}")
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
response = llm_with_structured_output.invoke(final_msg)
|
| 122 |
+
|
| 123 |
+
return {'sql_query' : response.generated_sql_query , "error" : None}
|
| 124 |
+
|
| 125 |
+
def _is_read_only_query(sql) :
    """Return True when *sql* begins with SELECT or WITH after leading comments."""
    import re  # function-scope import so the file's import block stays untouched

    without_comments = re.sub(r"--[^\n]*|/\*.*?\*/", " ", sql or "", flags=re.DOTALL)
    return re.match(r"[\s(]*(select|with)\b", without_comments, re.IGNORECASE) is not None


def execute_node(state : State) :
    """Graph node: run the generated SQL against the user's database.

    Reads ``connection_url``, ``sql_query`` and ``retry`` from the state.

    Returns:
        On success ``{"query_result", "error": None, "retry": 0}``.
        On failure ``{"error": <message>, "retry": retry + 1}`` so the
        router can send the error back to the generator.
    """
    url = state.get("connection_url")
    query_text = state.get("sql_query")
    retry = state.get("retry" , 0)

    # The generator is only *asked* to emit SELECT statements; enforce it here
    # before anything reaches the database. This is a first-keyword check, not
    # a full parser -- NOTE(review): the connection should still use read-only
    # database credentials.
    if not _is_read_only_query(query_text) :
        return {
            'error' : "Rejected: only read-only SELECT statements may be executed." ,
            "retry" : retry+1
        }

    try :
        db = SQLDatabase.from_uri(url)

        result = db.run(query_text)

        return {"query_result" : result , "error" : None , "retry" : 0}

    except Exception as e :
        # Any driver/SQL failure is reported through the state, not raised.
        return {'error' : str(e) , "retry" : retry+1}
|
| 139 |
+
|
| 140 |
+
def routing(state : State) :
    """Choose the node that follows execution.

    Returns "generate_node" while an error is present and fewer than three
    retries have been recorded; otherwise returns "answer_node".
    """
    failure = state.get("error")
    attempts = state.get('retry' , 0)

    return "generate_node" if failure and attempts < 3 else "answer_node"
|
| 148 |
+
|
| 149 |
+
def answer_node(state : State) :
|
| 150 |
+
messages = state.get("messages")
|
| 151 |
+
query_result = state.get("query_result" , "No records found.")
|
| 152 |
+
error = state.get("error")
|
| 153 |
+
|
| 154 |
+
history_messages = messages[:-1]
|
| 155 |
+
user_query = messages[-1].content
|
| 156 |
+
|
| 157 |
+
if history_messages:
|
| 158 |
+
history_text = "\n".join([
|
| 159 |
+
f"{msg.type.capitalize()}: {msg.content}"
|
| 160 |
+
for msg in history_messages
|
| 161 |
+
])
|
| 162 |
+
else:
|
| 163 |
+
history_text = "This is the first user request. No history exists."
|
| 164 |
+
|
| 165 |
+
system_prompt = f"""You are a helpful Data Analyst communicating directly with a user.
|
| 166 |
+
|
| 167 |
+
=== CONVERSATION HISTORY ===
|
| 168 |
+
Use this to maintain the context and tone of the conversation:
|
| 169 |
+
{history_text}
|
| 170 |
+
|
| 171 |
+
=== EXECUTION CONTEXT ===\n"""
|
| 172 |
+
|
| 173 |
+
if error:
|
| 174 |
+
system_prompt += f"""Unfortunately, the database returned an error and the data could not be retrieved.
|
| 175 |
+
Error details: {error}
|
| 176 |
+
INSTRUCTION: Politely apologize to the user and briefly explain that you encountered a technical issue retrieving their specific request."""
|
| 177 |
+
else:
|
| 178 |
+
system_prompt += f"""The database returned this raw data: {query_result}
|
| 179 |
+
|
| 180 |
+
INSTRUCTIONS:
|
| 181 |
+
1. Answer using ONLY the returned data.
|
| 182 |
+
2. Never invent a name, value, or entity that is not present in the result.
|
| 183 |
+
3. If the result contains both an ID and a name, use the name in the final answer and mention the ID only if helpful.
|
| 184 |
+
4. If the result contains only an ID and the user asked for a name/person/entity, say that the returned data only contains an identifier and no readable name.
|
| 185 |
+
5. Do not substitute or guess a name from a customer_id or any other identifier.
|
| 186 |
+
6. Do not mention SQL, the database, schemas, or how you got the data.
|
| 187 |
+
7. Give a clean, professional, and conversational response."""
|
| 188 |
+
|
| 189 |
+
final_msg = [
|
| 190 |
+
SystemMessage(content=system_prompt),
|
| 191 |
+
HumanMessage(content=f"LATEST USER REQUEST:\n{user_query}")
|
| 192 |
+
]
|
| 193 |
+
|
| 194 |
+
response = llm.invoke(final_msg)
|
| 195 |
+
|
| 196 |
+
return {"messages": [response], "final_result": response.content}
|
| 197 |
+
|
| 198 |
+
workflow = StateGraph(State)
|
| 199 |
+
|
| 200 |
+
workflow.add_node("retrieve_node" , retrieve_node)
|
| 201 |
+
workflow.add_node("generate_node" , generate_node)
|
| 202 |
+
workflow.add_node("execute_node" , execute_node)
|
| 203 |
+
workflow.add_node("answer_node" , answer_node)
|
| 204 |
+
|
| 205 |
+
workflow.add_edge(START , "retrieve_node")
|
| 206 |
+
workflow.add_edge("retrieve_node" , "generate_node")
|
| 207 |
+
workflow.add_edge("generate_node" , "execute_node")
|
| 208 |
+
workflow.add_conditional_edges("execute_node" , routing , {
|
| 209 |
+
"answer_node" : "answer_node" , "generate_node" : "generate_node"
|
| 210 |
+
})
|
| 211 |
+
workflow.add_edge("answer_node" , END)
|
src/main.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI , HTTPException , BackgroundTasks
|
| 2 |
+
from src.embedding import create_embeddings
|
| 3 |
+
from src.graph import workflow
|
| 4 |
+
from pydantic import BaseModel , Field
|
| 5 |
+
from langgraph.checkpoint.postgres import PostgresSaver
|
| 6 |
+
from langchain_core.messages import HumanMessage
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
app = FastAPI(
|
| 10 |
+
title="Text2SQL Agent API",
|
| 11 |
+
description="A production-grade backend powering LangGraph agent.",
|
| 12 |
+
version="1.0.0"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
class UploadRequest(BaseModel):
|
| 16 |
+
connection_url: str = Field(..., description="Database URL")
|
| 17 |
+
user_id: str = Field(..., description="The unique identifier for the tenant context.")
|
| 18 |
+
|
| 19 |
+
class ChatRequest(BaseModel) :
|
| 20 |
+
message : str = Field(...,description="Input message by the user.")
|
| 21 |
+
thread_id : str = Field(...,description="Unique session ID to maintain short term memory.")
|
| 22 |
+
user_id : str = Field(...,description="The unique identifier for the tenant context.")
|
| 23 |
+
connection_url : str = Field(...,description="Database URL")
|
| 24 |
+
|
| 25 |
+
@app.post("/upload_url" , summary="Receive database URL and invoke ingestion pipeline.")
def upload(request : UploadRequest , background_tasks : BackgroundTasks) :
    """Schedule schema ingestion for the given database and return at once.

    ``create_embeddings`` is queued as a background task with the request's
    ``connection_url`` and ``user_id``, so the response only confirms that
    ingestion was started, not that it finished.
    """
    background_tasks.add_task(create_embeddings , request.connection_url , request.user_id)

    return {
        "status" : "success",
        "message" : "Ingestion Pipeline started !"
    }
|
| 33 |
+
|
| 34 |
+
@app.post("/chat" , summary="Return the response generated by the agent for the given user query.")
def chat_endpoint(request : ChatRequest) :
    """Run one conversational turn of the Text2SQL agent.

    Opens a Postgres checkpointer from the ``DATABASE_URI`` environment
    variable, compiles the workflow with it and invokes the graph with the
    caller's message under ``request.thread_id``.

    Returns:
        A dict with ``status``, ``thread_id`` and ``response`` (the graph's
        ``final_result``).

    Raises:
        HTTPException: status 500 when invoking the graph raises.
    """
    db_uri = os.getenv("DATABASE_URI")

    with PostgresSaver.from_conn_string(db_uri) as checkpointer:
        checkpointer.setup()

        agent = workflow.compile(
            checkpointer=checkpointer
        )
        config = {
            "configurable" : {
                'thread_id' : request.thread_id
            }
        }

        initial_state = {
            'connection_url' : request.connection_url ,
            'user_id' : request.user_id ,
            'messages' : [HumanMessage(content=request.message)],
            'retry' : 0
        }
        try :
            result = agent.invoke(initial_state , config=config)

            final_result = result.get("final_result")

            # Debug output uses .get(): 'query_result' is never written when
            # every execution attempt failed, and indexing it would turn a
            # valid answer into a KeyError and an HTTP 500.
            print("*"*50 , flush=True)
            print(f"\n\n Scheme : {result.get('scheme')}\n\n" , flush=True)
            print(f"\n\nSql Query : {result.get('sql_query')}\n\n" , flush=True)
            print(f"\n\nQuery Result : {result.get('query_result')}\n\n" , flush=True)

            return {
                "status": "success",
                "thread_id": request.thread_id,
                "response": final_result
            }

        except Exception as e :
            raise HTTPException(status_code=500 , detail=f"Error : {str(e)}")
|
src/retrieval.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from qdrant_client import QdrantClient
|
| 4 |
+
from qdrant_client import models
|
| 5 |
+
from fastembed import TextEmbedding, SparseTextEmbedding
|
| 6 |
+
from langchain_community.utilities import SQLDatabase
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
qdrant_api = os.getenv("QDRANT_API_KEY")
|
| 11 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 12 |
+
|
| 13 |
+
COLLECTION_NAME = "Text2SQL"
|
| 14 |
+
|
| 15 |
+
def retrieve(user_id : str , query : str , connection_url: str) :
    """Return the dialect and table definitions most relevant to *query*.

    Embeds the query with the dense and sparse models, runs a hybrid
    (RRF-fused) search restricted to points whose ``user_id`` payload
    matches, collects the distinct ``table_name`` values of the hits and
    asks ``SQLDatabase`` for the schema of those tables.

    Args:
        user_id: Tenant identifier used as the payload filter.
        query: Natural-language question to embed.
        connection_url: SQLAlchemy URL of the database to describe.

    Returns:
        A string of the form ``"Dialect : <dialect>\\n <table info>"``.
    """
    client = QdrantClient(api_key=qdrant_api , url=qdrant_url)

    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")

    # Plain Python lists, matching how vectors are serialised at ingestion.
    dense_query_vector = list(dense_model.embed([query]))[0].tolist()

    sparse_query = list(sparse_model.embed([query]))[0]
    sparse_query_vector = models.SparseVector(
        indices=sparse_query.indices.tolist(),
        values=sparse_query.values.tolist(),
    )

    user_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="user_id",
                match=models.MatchValue(value=user_id)
            )
        ]
    )

    results = client.query_points(
        collection_name=COLLECTION_NAME,
        prefetch=[
            models.Prefetch(
                query=dense_query_vector,
                limit=10,
                using="dense",
                filter=user_filter
            ),
            models.Prefetch(
                query=sparse_query_vector,
                using="sparse",
                limit=10,
                filter=user_filter
            )
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=10
    )

    # Distinct table names in ranking order; hits without a name are skipped.
    ranked_names = (
        (point.payload or {}).get('table_name') for point in results.points
    )
    tables = list(dict.fromkeys(name for name in ranked_names if name))

    db = SQLDatabase.from_uri(connection_url , sample_rows_in_table_info=0)

    # Drop names the database no longer exposes, so stale points cannot
    # make get_table_info() raise.
    usable = set(db.get_usable_table_names())
    tables = [name for name in tables if name in usable]

    dialect = db.dialect

    # With no usable hit (e.g. ingestion has not finished yet) describe every
    # table instead of handing the generator an empty schema.
    final_schemes = f"Dialect : {dialect}\n {db.get_table_info(table_names=tables or None)}"

    return final_schemes
|
src/scheme.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy import create_engine , inspect
|
| 2 |
+
from langchain_core.documents import Document
|
| 3 |
+
|
| 4 |
+
def create_scheme(database_url : str) -> list[Document] :
    """Describe every table of a database as a text document.

    For each table reported by the SQLAlchemy inspector, a ``Document`` is
    produced whose text lists the table name and its column names with
    underscores replaced by spaces; the original table name is kept in the
    ``table_name`` metadata entry.

    Args:
        database_url: SQLAlchemy URL of the database to introspect.

    Returns:
        One ``Document`` per table; an empty list when there are no tables.
    """
    engine = create_engine(database_url)

    try:
        inspector = inspect(engine)

        scheme_docs = []

        for table in inspector.get_table_names() :
            clean_table = table.replace("_" , " ")

            clean_columns = [
                col['name'].replace("_" , " ")
                for col in inspector.get_columns(table)
            ]

            doc = f"Table: {clean_table}.\nColumns: {', '.join(clean_columns)}"

            scheme_docs.append(
                Document(
                    page_content=doc,
                    metadata={
                        "table_name" : table
                    }
                )
            )

        return scheme_docs

    finally:
        # The engine is only needed for this introspection pass; release its
        # pooled connections instead of leaving them to garbage collection.
        engine.dispose()
|
start.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
# Container entry point: start the FastAPI backend in the background, then
# run the Streamlit UI in the foreground.
# The shebang is needed because the Dockerfile launches this file directly
# with exec-form CMD ["./start.sh"], i.e. without a shell wrapping it.

echo "Starting FastAPI Backend..."
uv run uvicorn src.main:app --host 0.0.0.0 --port 8000 &

# Give the backend a moment to start before the UI comes up.
sleep 3

echo "Starting Streamlit Frontend..."
uv run streamlit run app.py --server.port=7860 --server.address=0.0.0.0
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|