LightRT committed on
Commit
52adb86
·
0 Parent(s):

Project Completion Commit

Browse files
Files changed (15) hide show
  1. .dockerignore +6 -0
  2. .gitignore +7 -0
  3. .python-version +1 -0
  4. Dockerfile +41 -0
  5. README.md +0 -0
  6. app.py +114 -0
  7. pyproject.toml +24 -0
  8. requirements.txt +20 -0
  9. src/embedding.py +71 -0
  10. src/graph.py +211 -0
  11. src/main.py +73 -0
  12. src/retrieval.py +70 -0
  13. src/scheme.py +35 -0
  14. start.sh +7 -0
  15. uv.lock +0 -0
.dockerignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .git/
2
+ .gitignore
3
+ .env
4
+ .venv/
5
+ __pycache__/
6
+ *.pyc
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ .venv/
5
+ .env/
6
+ venv/.env
7
+ .env
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
Dockerfile ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# 1. Base image
# Must satisfy `requires-python = ">=3.12"` in pyproject.toml (and .python-version).
FROM python:3.12-slim

# 2. Environment variables for Hugging Face compatibility
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    HOME=/app \
    PATH=/app/.local/bin:$PATH

WORKDIR /app

# 3. Install system dependencies
# libpq-dev is for PostgreSQL, curl is for Streamlit health checks.
# --no-install-recommends keeps the image from pulling optional packages.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    libpq-dev \
    curl \
    && rm -rf /var/lib/apt/lists/*

# 4. Install uv (package manager)
RUN pip install --no-cache-dir uv

# 5. Cache & install Python dependencies
# Copied before the application code so this layer is reused when only code changes.
COPY pyproject.toml uv.lock ./
RUN uv pip install --system -r pyproject.toml

# 6. Copy the application code
COPY . .

# 7. Permissions: make the launch script executable
RUN chmod +x start.sh

# 8. Permissions: Hugging Face runs as user 1000, not root
RUN chown -R 1000:1000 /app
USER 1000

# 9. Expose ports (7860 for UI, 8000 for internal API)
EXPOSE 7860 8000

# 10. Start the application
# Invoke through sh explicitly: exec-form CMD does not go through a shell, so a
# script without a shebang line cannot be executed directly.
CMD ["sh", "./start.sh"]
README.md ADDED
File without changes
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Streamlit chat UI for the Text2SQL agent; talks to the FastAPI backend over HTTP."""
import time
import uuid
from urllib.parse import urlsplit, urlunsplit

import requests
import streamlit as st

# --- CONFIGURATION ---
API_URL = "http://localhost:8000"  # Your FastAPI server URL

# The upload endpoint only schedules ingestion as a background task and returns
# immediately, so the UI waits this long before unlocking the chat. This is a fixed
# delay, not a completion check.
INGESTION_WAIT_SECONDS = 15
# Without explicit timeouts a hung backend would freeze the page indefinitely.
UPLOAD_TIMEOUT_SECONDS = 30
CHAT_TIMEOUT_SECONDS = 300


def _mask_credentials(url: str) -> str:
    """Return *url* with any password in its authority part replaced by ``***``."""
    try:
        parts = urlsplit(url)
    except ValueError:
        return "***"
    userinfo, at_sign, hostport = parts.netloc.rpartition("@")
    user, colon, _password = userinfo.partition(":")
    if not (at_sign and colon):
        # No "user:password@" section, nothing to hide.
        return url
    return urlunsplit(parts._replace(netloc=f"{user}:***@{hostport}"))


st.set_page_config(page_title="Text2SQL Agent", page_icon="🤖", layout="centered")

# --- SESSION STATE INITIALIZATION ---
# This ensures variables survive when Streamlit re-renders the page
if "thread_id" not in st.session_state:
    st.session_state.thread_id = str(uuid.uuid4())  # Unique session ID for LangGraph memory
if "user_id" not in st.session_state:
    st.session_state.user_id = "tenant_" + str(uuid.uuid4())[:8]
if "is_db_connected" not in st.session_state:
    st.session_state.is_db_connected = False
if "connection_url" not in st.session_state:
    st.session_state.connection_url = ""
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# --- SIDEBAR: DATABASE CONNECTION ---
with st.sidebar:
    st.header("⚙️ Database Setup")

    # If already connected, disable the input to enforce ONE database connection
    db_input = st.text_input(
        "Enter Database URL:",
        disabled=st.session_state.is_db_connected,
    )

    if not st.session_state.is_db_connected:
        if st.button("Connect & Initialize", type="primary", use_container_width=True):
            if not db_input:
                st.error("Please enter a valid URL.")
            else:
                with st.spinner("Building embeddings and initializing agent..."):
                    try:
                        # 1. Hit the FastAPI upload endpoint
                        payload = {"connection_url": db_input, "user_id": st.session_state.user_id}
                        response = requests.post(
                            f"{API_URL}/upload_url",
                            json=payload,
                            timeout=UPLOAD_TIMEOUT_SECONDS,
                        )

                        if response.status_code == 200:
                            # 2. Lock the connection and unlock the chat
                            st.session_state.is_db_connected = True
                            st.session_state.connection_url = db_input

                            # Give the background ingestion time to finish before the
                            # user fires off their first chat question.
                            time.sleep(INGESTION_WAIT_SECONDS)

                            st.success("Database connected securely!")
                            st.rerun()  # Refresh UI to unlock the chat window
                        else:
                            st.error(f"Failed to connect: {response.text}")
                    except requests.exceptions.ConnectionError:
                        st.error("🚨 Cannot connect to backend. Is FastAPI running?")
                    except requests.exceptions.Timeout:
                        st.error("⏱️ The backend did not respond in time. Please try again.")
                    except requests.exceptions.RequestException as exc:
                        st.error(f"Request failed: {exc}")
    else:
        st.success("✅ Connected to Database")
        # Never echo the password part of the connection string back into the page.
        st.caption(f"URL: {_mask_credentials(st.session_state.connection_url)}")

        # Add a reset button just in case they want to start completely over
        if st.button("Disconnect & Reset", use_container_width=True):
            st.session_state.clear()
            st.rerun()

# --- MAIN CHAT INTERFACE ---
st.title("🗣️ Text2SQL Agent")

# The Lock: Do not render the chat if DB is not connected
if not st.session_state.is_db_connected:
    st.info("👈 Please connect your database in the sidebar to begin analyzing data.")
else:
    # 1. Display previous chat messages from session state
    for msg in st.session_state.chat_history:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # 2. The Chat Input box
    if user_query := st.chat_input("Ask a question about your data..."):

        # Immediately display the user's question in the UI
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        with st.chat_message("user"):
            st.markdown(user_query)

        # 3. Call the LangGraph Backend
        with st.chat_message("assistant"):
            with st.spinner("Analyzing schema and generating SQL..."):
                try:
                    payload = {
                        "message": user_query,
                        "thread_id": st.session_state.thread_id,
                        "user_id": st.session_state.user_id,
                        "connection_url": st.session_state.connection_url,
                    }

                    response = requests.post(
                        f"{API_URL}/chat",
                        json=payload,
                        timeout=CHAT_TIMEOUT_SECONDS,
                    )

                    if response.status_code == 200:
                        # Extract the final_result from the FastAPI JSON response.
                        # `or` also covers a present-but-null "response" field.
                        answer = response.json().get("response") or "No response found."
                        st.markdown(answer)

                        # Save the assistant's answer to the UI history
                        st.session_state.chat_history.append({"role": "assistant", "content": answer})
                    else:
                        st.error(f"Agent Error: {response.text}")

                except requests.exceptions.ConnectionError:
                    st.error("🚨 Connection dropped. Ensure FastAPI is running.")
                except requests.exceptions.Timeout:
                    st.error("⏱️ The agent took too long to respond. Please try again.")
                except requests.exceptions.RequestException as exc:
                    st.error(f"Request failed: {exc}")
pyproject.toml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "text-to-sql-agent"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "faker>=40.21.0",
9
+ "fastapi>=0.136.3",
10
+ "fastembed>=0.8.0",
11
+ "langchain>=1.3.4",
12
+ "langchain-community>=0.4.2",
13
+ "langchain-core>=1.4.0",
14
+ "langchain-openai>=1.2.2",
15
+ "langgraph>=1.2.4",
16
+ "langgraph-checkpoint-postgres>=3.1.0",
17
+ "langsmith>=0.8.8",
18
+ "psycopg-binary>=3.3.4",
19
+ "python-dotenv>=1.2.2",
20
+ "qdrant-client>=1.18.0",
21
+ "sqlalchemy>=2.0.50",
22
+ "streamlit>=1.58.0",
23
+ "pymysql",
24
+ ]
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sqlalchemy
2
+ langchain-core
3
+ qdrant-client
4
+ fastembed
5
+ python-dotenv
6
+ langchain
7
+ langchain-classic
8
+ langchain-community
9
+ langgraph
10
+ langchain-openai
11
+ pydantic
12
+ fastapi
13
+ langgraph-checkpoint-postgres
14
+ uvicorn
15
+ python-multipart
16
+ streamlit
17
+ requests
18
+ psycopg-pool
19
+ langsmith
20
+ pymysql
src/embedding.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qdrant_client import QdrantClient
2
+ from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
3
+ from fastembed import TextEmbedding, SparseTextEmbedding
4
+ import uuid
5
+ from dotenv import load_dotenv
6
+ import os
7
+ from src.scheme import create_scheme
8
+
9
+ COLLECTION_NAME = "Text2SQL"
10
+
11
+ load_dotenv()
12
+
13
+ qdrant_api = os.getenv("QDRANT_API_KEY")
14
+ qdrant_url = os.getenv("QDRANT_URL")
15
+
16
def create_embeddings(connection_url: str, user_id: str):
    """Index the schema of a database into the shared Qdrant collection.

    One point is stored per table, carrying a dense and a sparse vector of the
    table description plus a payload with ``user_id``, ``text`` and ``table_name``.

    Args:
        connection_url: Database URL whose tables are described and embedded.
        user_id: Tenant identifier written to every point's payload.
    """
    import logging

    log = logging.getLogger(__name__)

    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)

    if not client.collection_exists(COLLECTION_NAME):
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config={"dense": VectorParams(size=384, distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams()},
        )

    try:
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="user_id",
            field_schema="keyword",
        )
    except Exception as exc:
        # Still best-effort (ingestion continues), but no longer silent.
        log.warning("Could not create payload index on user_id: %s", exc)

    docs = create_scheme(connection_url)
    if not docs:
        # No tables to describe: skip model loading and the empty upsert.
        log.warning("No tables found for user %s; nothing to embed.", user_id)
        return

    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")

    texts = [doc.page_content for doc in docs]
    dense_vectors = dense_model.embed(texts)
    sparse_vectors = sparse_model.embed(texts)

    points = []
    for doc, dense, sparse in zip(docs, dense_vectors, sparse_vectors):
        table_name = doc.metadata.get("table_name")

        # Derive the id from tenant + table so re-ingesting the same database
        # overwrites the existing points instead of adding duplicates.
        if table_name:
            point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{user_id}/{table_name}"))
        else:
            point_id = str(uuid.uuid4())

        points.append(
            PointStruct(
                id=point_id,
                vector={
                    "dense": dense.tolist(),
                    "sparse": {
                        'indices': sparse.indices.tolist(),
                        'values': sparse.values.tolist(),
                    },
                },
                payload={
                    'user_id': user_id,
                    'text': doc.page_content,
                    'table_name': table_name,
                },
            )
        )

    client.upsert(collection_name=COLLECTION_NAME, points=points)
src/graph.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict , Annotated , List , Optional
2
+ from langgraph.graph.message import add_messages
3
+ from langchain_core.messages import SystemMessage , HumanMessage
4
+ from langchain_openai import ChatOpenAI
5
+ from src.retrieval import retrieve
6
+ import os
7
+ from dotenv import load_dotenv
8
+ from langgraph.graph import StateGraph, START ,END
9
+ from pydantic import BaseModel , Field
10
+ import datetime
11
+ from langchain_community.utilities import SQLDatabase
12
+
13
+ load_dotenv()
14
+
15
class State(TypedDict):
    """State dictionary passed between the nodes of the Text2SQL graph."""

    # Inputs supplied by the caller.
    connection_url: str
    user_id: str
    # Conversation messages; merged through the add_messages reducer.
    messages: Annotated[List, add_messages]
    # Schema text written by retrieve_node.
    scheme: str
    # Query written by generate_node and read by execute_node.
    sql_query: str
    # Output of a successful execution.
    query_result: str
    # Message of the last execution failure, or None.
    error: Optional[str]
    # Count of consecutive failed executions.
    retry: int
    # Reply text written by answer_node.
    final_result: str
25
+
26
+
27
# Shared chat model used by the generate and answer nodes. Requests are sent to the
# OpenRouter base URL with the key read from OPENROUTER_API_KEY.
llm = ChatOpenAI(
    temperature=0,
    model="openai/gpt-4o-mini",
    openai_api_base="https://openrouter.ai/api/v1",
    openai_api_key=os.getenv("OPENROUTER_API_KEY"),
)
33
+
34
class sql_query(BaseModel):
    """Structured-output schema holding the single SQL statement produced by the LLM."""

    generated_sql_query: str = Field(
        ...,
        description="The raw, valid executable SQL query text. Contain absolutely NO markdown wrapping, code blocks, or conversational formatting.",
    )
36
+
37
def retrieve_node(state: State):
    """Fetch the schema text relevant to the latest message.

    Reads ``messages``, ``connection_url`` and ``user_id`` from the state, passes the
    content of the last message to ``retrieve`` and returns the result under ``scheme``.
    """
    conversation = state.get("messages")
    database_url = state.get("connection_url")
    tenant = state.get("user_id")

    latest_question = conversation[-1].content

    return {'scheme': retrieve(tenant, latest_question, database_url)}
47
+
48
def generate_node(state: State):
    """Ask the LLM for a SQL query answering the latest message.

    The prompt contains the retrieved schema, the earlier messages rendered as text,
    today's date and, when both a previous error and a previous query are present in
    the state, an error-correction section quoting them.

    Returns:
        A state update with the generated ``sql_query`` and ``error`` reset to None.
    """
    conversation = state.get("messages")
    schema_text = state.get("scheme")
    last_error = state.get("error")
    failed_query = state.get('sql_query')

    structured_llm = llm.with_structured_output(sql_query)

    earlier_messages = conversation[:-1]
    latest_request = conversation[-1].content

    today = datetime.datetime.now().strftime("%Y-%m-%d")

    # Everything except the last message is flattened into "Type: content" lines.
    history_text = (
        "\n".join(f"{msg.type.capitalize()}: {msg.content}" for msg in earlier_messages)
        if earlier_messages
        else "This is the first user request. No history exists."
    )

    # Only populated when a previous attempt left both a query and an error behind.
    error_context = ""
    if last_error and failed_query:
        error_context = f"""
=== 🚨 ERROR CORRECTION MODE 🚨 ===
Your previous attempt to answer this request failed.
[PREVIOUS BROKEN QUERY]:
{failed_query}

[DATABASE ERROR MESSAGE]:
{last_error}

INSTRUCTION: Analyze the error message and the schema carefully. Fix the syntax, column names, or logic, and generate a CORRECTED query.
"""

    prompt_text = f"""You are an expert Data Analyst and Database Engineer.
Your job is to write highly optimized, perfectly accurate database queries based on user requests.

=== DATABASE SCHEMA & DIALECT ===
Look at the metadata below to identify the targeted database engine dialect and table layout:
{schema_text}

=== CONVERSATION HISTORY ===
Use this previous context to resolve ambiguous terms (e.g., if the user says "filter those by...", look here to see what "those" refers to):
{history_text}
{error_context}

=== CRITICAL RULES ===
1. ALIGNMENT: Only use the tables and columns provided in the schema above. Do not hallucinate column names.
2. DIALECT MATCHING: Look at the 'Dialect:' specified above and write strict queries matching that exact syntax.
3. JOINS: Pay close attention to the FOREIGN KEY constraints provided in the schema to perform accurate JOINs.
4. CURRENT DATE: Today's date is {today}. Use this exact date for any relative time filters (e.g., "last month", "this year").
5. CASE SENSITIVITY: When filtering by strings, use case-insensitive comparisons (e.g., LOWER(column) = LOWER('value')) unless instructed otherwise.
6. SECURITY: NEVER generate DML queries (INSERT, UPDATE, DELETE, DROP). Only generate SELECT statements.

=== OUTPUT SELECTION RULES ===
1. If the user asks WHO / WHICH / WHAT IS THE NAME / identify a person, customer, user, product, company, or entity, return the human-readable name field, not just the ID.
2. If the schema has both an ID column and a name column, prefer selecting the name column in the final output.
3. If the name is in another table, use the required JOIN to fetch it.
4. Only return an ID alone when the user explicitly asks for the ID, or when no name-like field exists in the schema.
5. For count/number questions, return an aggregate numeric result, not a list of rows.
6. For "who/which" questions, do not answer with only identifiers if a readable label exists in the schema.

=== INSTRUCTIONS ===
First, think through the necessary tables, filters, joins, and the exact type of answer expected.
Then, provide the final executable SQL query specifically for the LATEST USER REQUEST."""

    llm_input = [
        SystemMessage(content=prompt_text),
        HumanMessage(content=f"LATEST USER REQUEST:\n{latest_request}"),
    ]

    generated = structured_llm.invoke(llm_input)

    return {'sql_query': generated.generated_sql_query, "error": None}
124
+
125
def execute_node(state: State):
    """Run the generated SQL against the user's database.

    The generation prompt asks the LLM for SELECT statements only, but that is merely
    an instruction. This node enforces it before touching the database: the statement
    must start with SELECT or WITH and must be a single statement. The check is
    best-effort (it does not parse SQL); a read-only database role remains the
    stronger protection.

    Returns:
        On success ``query_result`` with ``error`` cleared and ``retry`` reset to 0;
        on a rejected or failing query the ``error`` text and ``retry`` incremented.
    """
    import re

    url = state.get("connection_url")
    statement = (state.get("sql_query") or "").strip()
    retry = state.get("retry", 0)

    # Drop trailing semicolons so a lone terminator is not mistaken for a second statement.
    single = statement.rstrip("; \t\r\n")
    starts_read_only = re.match(r"[\s(]*(select|with)\b", single, re.IGNORECASE)
    if not starts_read_only or ";" in single:
        return {
            'error': "Rejected: only a single read-only SELECT statement may be executed.",
            "retry": retry + 1,
        }

    try:
        db = SQLDatabase.from_uri(url)

        result = db.run(statement)

        return {"query_result": result, "error": None, "retry": 0}

    except Exception as e:
        # The message is fed back to generate_node so the LLM can correct the query.
        return {'error': str(e), "retry": retry + 1}
139
+
140
def routing(state: State):
    """Choose the node that follows execute_node.

    Returns ``"generate_node"`` while an error is present and fewer than three
    retries have been recorded, otherwise ``"answer_node"``.
    """
    should_retry = bool(state.get("error")) and state.get('retry', 0) < 3
    return "generate_node" if should_retry else "answer_node"
148
+
149
def answer_node(state: State):
    """Turn the query outcome into a conversational reply.

    When the state carries an error the LLM is asked to apologise; otherwise it is
    asked to answer using only the returned data.

    Returns:
        A state update appending the LLM response to ``messages`` and storing its
        text in ``final_result``.
    """
    messages = state.get("messages")
    # `or` rather than a .get() default: the key is usually present, and an empty
    # result must still be presented to the LLM as "No records found.".
    query_result = state.get("query_result") or "No records found."
    error = state.get("error")

    history_messages = messages[:-1]
    user_query = messages[-1].content

    if history_messages:
        history_text = "\n".join(
            f"{msg.type.capitalize()}: {msg.content}" for msg in history_messages
        )
    else:
        history_text = "This is the first user request. No history exists."

    system_prompt = f"""You are a helpful Data Analyst communicating directly with a user.

=== CONVERSATION HISTORY ===
Use this to maintain the context and tone of the conversation:
{history_text}

=== EXECUTION CONTEXT ===\n"""

    if error:
        system_prompt += f"""Unfortunately, the database returned an error and the data could not be retrieved.
Error details: {error}
INSTRUCTION: Politely apologize to the user and briefly explain that you encountered a technical issue retrieving their specific request."""
    else:
        system_prompt += f"""The database returned this raw data: {query_result}

INSTRUCTIONS:
1. Answer using ONLY the returned data.
2. Never invent a name, value, or entity that is not present in the result.
3. If the result contains both an ID and a name, use the name in the final answer and mention the ID only if helpful.
4. If the result contains only an ID and the user asked for a name/person/entity, say that the returned data only contains an identifier and no readable name.
5. Do not substitute or guess a name from a customer_id or any other identifier.
6. Do not mention SQL, the database, schemas, or how you got the data.
7. Give a clean, professional, and conversational response."""

    final_msg = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"LATEST USER REQUEST:\n{user_query}"),
    ]

    response = llm.invoke(final_msg)

    return {"messages": [response], "final_result": response.content}
197
+
198
# Graph layout: retrieve -> generate -> execute, then either back to generate
# (as decided by `routing`) or on to answer, which ends the run.
workflow = StateGraph(State)

# Nodes
workflow.add_node("retrieve_node", retrieve_node)
workflow.add_node("generate_node", generate_node)
workflow.add_node("execute_node", execute_node)
workflow.add_node("answer_node", answer_node)

# Edges
workflow.add_edge(START, "retrieve_node")
workflow.add_edge("retrieve_node", "generate_node")
workflow.add_edge("generate_node", "execute_node")
workflow.add_conditional_edges(
    "execute_node",
    routing,
    {
        "answer_node": "answer_node",
        "generate_node": "generate_node",
    },
)
workflow.add_edge("answer_node", END)
src/main.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI , HTTPException , BackgroundTasks
2
+ from src.embedding import create_embeddings
3
+ from src.graph import workflow
4
+ from pydantic import BaseModel , Field
5
+ from langgraph.checkpoint.postgres import PostgresSaver
6
+ from langchain_core.messages import HumanMessage
7
+ import os
8
+
9
# FastAPI application object; the routes below are registered on it.
app = FastAPI(
    version="1.0.0",
    title="Text2SQL Agent API",
    description="A production-grade backend powering LangGraph agent.",
)
14
+
15
class UploadRequest(BaseModel):
    """Request body of the ``/upload_url`` endpoint."""

    connection_url: str = Field(
        ...,
        description="Database URL",
    )
    user_id: str = Field(
        ...,
        description="The unique identifier for the tenant context.",
    )
18
+
19
class ChatRequest(BaseModel):
    """Request body of the ``/chat`` endpoint."""

    message: str = Field(
        ...,
        description="Input message by the user.",
    )
    thread_id: str = Field(
        ...,
        description="Unique session ID to maintain short term memory.",
    )
    user_id: str = Field(
        ...,
        description="The unique identifier for the tenant context.",
    )
    connection_url: str = Field(
        ...,
        description="Database URL",
    )
24
+
25
@app.post("/upload_url", summary="Receive database URL and invoke ingestion pipeline.")
def upload(request: UploadRequest, background_tasks: BackgroundTasks):
    """Schedule schema ingestion for the given database and return immediately.

    ``create_embeddings`` is queued as a background task, so the response only
    confirms that ingestion was scheduled, not that it has finished or succeeded.
    """
    background_tasks.add_task(create_embeddings, request.connection_url, request.user_id)

    return {
        "status": "success",
        "message": "Ingestion Pipeline started !",
    }
33
+
34
@app.post("/chat", summary="Return the response generated by the agent for the given user query.")
def chat_endpoint(request: ChatRequest):
    """Run one conversational turn through the LangGraph agent.

    The graph is compiled with a Postgres checkpointer keyed by ``thread_id``.

    Raises:
        HTTPException: 500 when ``DATABASE_URI`` is not set or when the checkpointer
            or the agent run raises.
    """
    db_uri = os.getenv("DATABASE_URI")
    if not db_uri:
        raise HTTPException(status_code=500, detail="Error : DATABASE_URI is not configured.")

    config = {
        "configurable": {
            'thread_id': request.thread_id
        }
    }

    # The checkpointer restores the previous turn's state for this thread, so every
    # per-turn field is reset here. Otherwise a turn that ended in failure would push
    # the next question into error-correction mode with an unrelated stale query.
    initial_state = {
        'connection_url': request.connection_url,
        'user_id': request.user_id,
        'messages': [HumanMessage(content=request.message)],
        'retry': 0,
        'error': None,
        'sql_query': "",
        'query_result': "",
    }

    try:
        with PostgresSaver.from_conn_string(db_uri) as checkpointer:
            checkpointer.setup()

            agent = workflow.compile(checkpointer=checkpointer)
            result = agent.invoke(initial_state, config=config)

        final_result = result.get("final_result")

        # Debug output. .get() is required: when every attempt failed the state may
        # lack some of these keys, and indexing would turn a valid apology reply
        # into a 500.
        print("*" * 50, flush=True)
        print(f"\n\n Scheme : {result.get('scheme')}\n\n", flush=True)
        print(f"\n\nSql Query : {result.get('sql_query')}\n\n", flush=True)
        print(f"\n\nQuery Result : {result.get('query_result')}\n\n", flush=True)

        return {
            "status": "success",
            "thread_id": request.thread_id,
            "response": final_result,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error : {str(e)}") from e
src/retrieval.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from qdrant_client import QdrantClient
4
+ from qdrant_client import models
5
+ from fastembed import TextEmbedding, SparseTextEmbedding
6
+ from langchain_community.utilities import SQLDatabase
7
+
8
+ load_dotenv()
9
+
10
+ qdrant_api = os.getenv("QDRANT_API_KEY")
11
+ qdrant_url = os.getenv("QDRANT_URL")
12
+
13
+ COLLECTION_NAME = "Text2SQL"
14
+
15
def retrieve(user_id: str, query: str, connection_url: str):
    """Return the dialect and table definitions relevant to *query*.

    Runs a dense + sparse search in Qdrant restricted to the caller's ``user_id``,
    fuses both result lists with RRF, and collects the distinct table names from the
    hits. The table definitions are then read from the database itself.

    Args:
        user_id: Tenant identifier used to filter the vector search.
        query: The user's natural-language question.
        connection_url: Database URL used to read table definitions.

    Returns:
        A string of the form ``"Dialect : <dialect>\\n <table info>"``.
    """
    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)

    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")

    # Convert the embeddings to plain lists, as the ingestion side does.
    dense_query_vector = next(iter(dense_model.embed([query]))).tolist()

    sparse_query = next(iter(sparse_model.embed([query])))
    sparse_query_vector = models.SparseVector(
        indices=sparse_query.indices.tolist(),
        values=sparse_query.values.tolist(),
    )

    user_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="user_id",
                match=models.MatchValue(value=user_id),
            )
        ]
    )

    results = client.query_points(
        collection_name=COLLECTION_NAME,
        prefetch=[
            models.Prefetch(
                query=dense_query_vector,
                limit=10,
                using="dense",
                filter=user_filter,
            ),
            models.Prefetch(
                query=sparse_query_vector,
                using="sparse",
                limit=10,
                filter=user_filter,
            ),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=10,
    )

    db = SQLDatabase.from_uri(connection_url, sample_rows_in_table_info=0)

    # Keep hit order, drop duplicates, and ignore names that no longer exist in the
    # database (get_table_info raises on unknown tables).
    usable_tables = set(db.get_usable_table_names())
    tables = []
    for point in results.points:
        table = (point.payload or {}).get('table_name')
        if table and table in usable_tables and table not in tables:
            tables.append(table)

    # An empty list would yield an empty schema and invite hallucinated columns
    # (e.g. when ingestion has not finished yet); None means "all usable tables".
    table_info = db.get_table_info(table_names=tables or None)

    return f"Dialect : {db.dialect}\n {table_info}"
src/scheme.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine , inspect
2
+ from langchain_core.documents import Document
3
+
4
def create_scheme(database_url: str) -> list[Document]:
    """Describe every table of a database as a LangChain ``Document``.

    Each document's text lists the table name and its column names with underscores
    replaced by spaces; the original table name is kept in ``metadata["table_name"]``.

    Args:
        database_url: SQLAlchemy URL of the database to inspect.

    Returns:
        One ``Document`` per table; an empty list when the database has no tables.
    """
    scheme_docs = []

    engine = create_engine(database_url)
    try:
        inspector = inspect(engine)

        for table in inspector.get_table_names():
            clean_table = table.replace("_", " ")
            clean_columns = [
                col['name'].replace("_", " ") for col in inspector.get_columns(table)
            ]

            scheme_docs.append(
                Document(
                    page_content=f"Table: {clean_table}.\nColumns: {', '.join(clean_columns)}",
                    metadata={"table_name": table},
                )
            )
    finally:
        # The engine is only needed for inspection; release its pooled connections.
        engine.dispose()

    return scheme_docs
start.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
#!/bin/sh
# Launch the FastAPI backend in the background, then the Streamlit UI in the foreground.
# The shebang is required because the Dockerfile starts this script with an exec-form CMD.
set -e

echo "Starting FastAPI Backend..."
uv run uvicorn src.main:app --host 0.0.0.0 --port 8000 &

# Fixed pause before starting the UI (presumably to let the backend come up first).
sleep 3

echo "Starting Streamlit Frontend..."
# exec replaces the shell with Streamlit so it receives the container's stop signals.
exec uv run streamlit run app.py --server.port=7860 --server.address=0.0.0.0
uv.lock ADDED
The diff for this file is too large to render. See raw diff