Selcan Yukcu commited on
Commit ·
0e5cf1e
1
Parent(s): a3f399c
feat: intent classification, intent base prompt choice
Browse files- conversation_memory.py +3 -2
- postgre_mcp_client.py +31 -14
- postgre_mcp_server.py +118 -5
- utils.py +23 -0
conversation_memory.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import json
|
| 2 |
-
|
| 3 |
class ConversationMemory:
|
| 4 |
def __init__(self):
|
| 5 |
self.history = [] # All parsed steps from all requests
|
|
@@ -36,7 +36,8 @@ class ConversationMemory:
|
|
| 36 |
def get_all_user_messages(self):
|
| 37 |
return list(set(self.user_messages))
|
| 38 |
|
| 39 |
-
def reset(self):
|
|
|
|
| 40 |
self.__init__() # Re-initialize the object
|
| 41 |
|
| 42 |
def summary(self):
|
|
|
|
| 1 |
import json
|
| 2 |
+
import os
|
| 3 |
class ConversationMemory:
|
| 4 |
def __init__(self):
|
| 5 |
self.history = [] # All parsed steps from all requests
|
|
|
|
| 36 |
def get_all_user_messages(self):
|
| 37 |
return list(set(self.user_messages))
|
| 38 |
|
| 39 |
+
def reset(self, path = "memory.json"):
|
| 40 |
+
os.remove(path)
|
| 41 |
self.__init__() # Re-initialize the object
|
| 42 |
|
| 43 |
def summary(self):
|
postgre_mcp_client.py
CHANGED
|
@@ -7,7 +7,7 @@ from langgraph.prebuilt import create_react_agent
|
|
| 7 |
from langchain.chat_models import init_chat_model
|
| 8 |
from conversation_memory import ConversationMemory
|
| 9 |
|
| 10 |
-
from utils import parse_mcp_output
|
| 11 |
|
| 12 |
llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",api_key ="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
|
| 13 |
|
|
@@ -24,6 +24,7 @@ The posts table represents content created by users, such as blog posts or messa
|
|
| 24 |
|
| 25 |
request = "can you show me the result of the join of all tables?"
|
| 26 |
request2 = "how many columns are there in this joined table?"
|
|
|
|
| 27 |
async def main():
|
| 28 |
async with stdio_client(server_params) as (read, write):
|
| 29 |
async with ClientSession(read, write) as session:
|
|
@@ -43,22 +44,38 @@ async def main():
|
|
| 43 |
past_results = memory.get_last_n_results()
|
| 44 |
past_requests = memory.get_all_user_messages()
|
| 45 |
|
| 46 |
-
|
| 47 |
-
resource = await session.read_resource(uri)
|
| 48 |
-
base_prompt = resource.contents[0].text
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
|
|
|
|
| 7 |
from langchain.chat_models import init_chat_model
|
| 8 |
from conversation_memory import ConversationMemory
|
| 9 |
|
| 10 |
+
from utils import parse_mcp_output, classify_intent
|
| 11 |
|
| 12 |
llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",api_key ="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
|
| 13 |
|
|
|
|
| 24 |
|
| 25 |
request = "can you show me the result of the join of all tables?"
|
| 26 |
request2 = "how many columns are there in this joined table?"
|
| 27 |
+
request3 = "send the last table"
|
| 28 |
async def main():
|
| 29 |
async with stdio_client(server_params) as (read, write):
|
| 30 |
async with ClientSession(read, write) as session:
|
|
|
|
| 44 |
past_results = memory.get_last_n_results()
|
| 45 |
past_requests = memory.get_all_user_messages()
|
| 46 |
|
| 47 |
+
intent = classify_intent(request)
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
if intent == "superset_request":
|
| 50 |
+
uri = f"resource://last_prompt"
|
| 51 |
+
resource = await session.read_resource(uri)
|
| 52 |
+
base_prompt = resource.contents[0].text
|
| 53 |
|
| 54 |
+
prompt = base_prompt.format(
|
| 55 |
+
user_requests=past_requests,
|
| 56 |
+
past_tools=past_tools,
|
| 57 |
+
last_queries=past_queries,
|
| 58 |
+
last_results=past_results,
|
| 59 |
+
new_request=request3
|
| 60 |
+
)
|
| 61 |
|
| 62 |
+
else:
|
| 63 |
+
uri = f"resource://base_prompt"
|
| 64 |
+
resource = await session.read_resource(uri)
|
| 65 |
+
base_prompt = resource.contents[0].text
|
| 66 |
+
|
| 67 |
+
# Create a formatted string of tools
|
| 68 |
+
tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
prompt = base_prompt.format(
|
| 72 |
+
user_requests=past_requests,
|
| 73 |
+
past_tools=past_tools,
|
| 74 |
+
last_queries=past_queries,
|
| 75 |
+
last_results=past_results,
|
| 76 |
+
new_request = request2,
|
| 77 |
+
tools = tools_str
|
| 78 |
+
)
|
| 79 |
|
| 80 |
|
| 81 |
|
postgre_mcp_server.py
CHANGED
|
@@ -50,12 +50,12 @@ mcp = FastMCP(
|
|
| 50 |
|
| 51 |
|
| 52 |
@mcp.resource(
|
| 53 |
-
uri="resource://
|
| 54 |
-
name="
|
| 55 |
-
description="A base prompt to generate
|
| 56 |
)
|
| 57 |
-
async def
|
| 58 |
-
"""Returns a base prompt to generate
|
| 59 |
|
| 60 |
base_prompt = """
|
| 61 |
|
|
@@ -181,6 +181,15 @@ async def base_prompt_table() -> str:
|
|
| 181 |
```sql
|
| 182 |
SELECT name FROM customers WHERE country = 'Germany';
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
### Invalid Example — DELETE Operation (Not Allowed):
|
| 185 |
**User Request:** "Delete all customers from Germany."
|
| 186 |
|
|
@@ -204,6 +213,110 @@ async def base_prompt_table() -> str:
|
|
| 204 |
return base_prompt
|
| 205 |
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
@mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
|
| 208 |
async def test_connection(ctx: Context) -> str:
|
| 209 |
"""Test database connection"""
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
@mcp.resource(
|
| 53 |
+
uri="resource://base_prompt",
|
| 54 |
+
name="base_prompt",
|
| 55 |
+
description="A base prompt to generate SQL queries and answer questions"
|
| 56 |
)
|
| 57 |
+
async def base_prompt_query() -> str:
|
| 58 |
+
"""Returns a base prompt to generate sql queries and answer questions"""
|
| 59 |
|
| 60 |
base_prompt = """
|
| 61 |
|
|
|
|
| 181 |
```sql
|
| 182 |
SELECT name FROM customers WHERE country = 'Germany';
|
| 183 |
|
| 184 |
+
## Example 6 — Basic Aggregation**
|
| 185 |
+
**User Request:** "Get total sales for each product"
|
| 186 |
+
|
| 187 |
+
**Steps:**
|
| 188 |
+
1. Use memory or List Tables → Get schema for `sales_data`
|
| 189 |
+
2. Generate and execute query:
|
| 190 |
+
```sql
|
| 191 |
+
SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
|
| 192 |
+
|
| 193 |
### Invalid Example — DELETE Operation (Not Allowed):
|
| 194 |
**User Request:** "Delete all customers from Germany."
|
| 195 |
|
|
|
|
| 213 |
return base_prompt
|
| 214 |
|
| 215 |
|
| 216 |
+
@mcp.resource(
|
| 217 |
+
uri="resource://last_prompt",
|
| 218 |
+
name="last_prompt",
|
| 219 |
+
description="A prompt that identifies the most recent SQL query related to the user's request and reformats it into ANSI SQL syntax for use in Superset."
|
| 220 |
+
)
|
| 221 |
+
async def last_prompt() -> str:
|
| 222 |
+
"""A prompt that identifies the most recent SQL query related to the user's request and reformats it into ANSI SQL syntax for use in Superset."""
|
| 223 |
+
|
| 224 |
+
base_prompt = """
|
| 225 |
+
|
| 226 |
+
==========================
|
| 227 |
+
# Your Role
|
| 228 |
+
==========================
|
| 229 |
+
|
| 230 |
+
You are an expert at reading and understanding SQL queries.
|
| 231 |
+
Your task is to retrieve the **exact SQL query** that produced a previously seen result, convert the query to the **ANSI SQL query** and return **only the ANSI SQL query** — no explanation, reasoning, or commentary.
|
| 232 |
+
|
| 233 |
+
You have access to a **short-term memory**, which stores relevant context from earlier interactions in the current conversation.
|
| 234 |
+
|
| 235 |
+
---
|
| 236 |
+
|
| 237 |
+
==========================
|
| 238 |
+
# Your Objective
|
| 239 |
+
==========================
|
| 240 |
+
|
| 241 |
+
When a user submits a request (e.g., *"send me that table"*, *"send the last query"*, etc.), follow these steps:
|
| 242 |
+
|
| 243 |
+
1. Identify which previous result the user is referring to, using your short-term memory.
|
| 244 |
+
2. Retrieve the corresponding SQL query that produced that result.
|
| 245 |
+
3. Convert the SQL query to the ANSI SQL query
|
| 246 |
+
3. Return **only** that ANSI SQL query.
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
==========================
|
| 251 |
+
# Critical Rules
|
| 252 |
+
==========================
|
| 253 |
+
|
| 254 |
+
- Do **not** ask questions or request clarification.
|
| 255 |
+
- Do **not** explain anything to the user.
|
| 256 |
+
- Only use the **memory** to determine which query is relevant.
|
| 257 |
+
- Respond with the **exact ANSI SQL query only**, formatted cleanly.
|
| 258 |
+
- Do **not** guess — only retrieve queries that actually exist in memory.
|
| 259 |
+
- If no query fits, respond with: "Query not found."
|
| 260 |
+
|
| 261 |
+
---
|
| 262 |
+
|
| 263 |
+
==========================
|
| 264 |
+
# Short-Term Memory
|
| 265 |
+
==========================
|
| 266 |
+
|
| 267 |
+
You have access to the following memory from this conversation:
|
| 268 |
+
|
| 269 |
+
- **Previous user requests**:
|
| 270 |
+
`{user_requests}`
|
| 271 |
+
|
| 272 |
+
- **Tools used so far**:
|
| 273 |
+
`{past_tools}`
|
| 274 |
+
|
| 275 |
+
- **Recent SQL queries**:
|
| 276 |
+
`{last_queries}`
|
| 277 |
+
|
| 278 |
+
- **Result preview from last query**:
|
| 279 |
+
`{last_results}`
|
| 280 |
+
|
| 281 |
+
Use this memory to resolve any references in the user's latest request.
|
| 282 |
+
|
| 283 |
+
---
|
| 284 |
+
|
| 285 |
+
==========================
|
| 286 |
+
# Examples
|
| 287 |
+
==========================
|
| 288 |
+
|
| 289 |
+
### Example 1 — Referring last query (Check the memory and find the most recent query that generates a table) :
|
| 290 |
+
**User Request:** "send the last table"
|
| 291 |
+
**You return:**
|
| 292 |
+
```sql
|
| 293 |
+
SELECT * FROM posts INNER JOIN users ON posts.user_id = users.id;
|
| 294 |
+
|
| 295 |
+
### Example 2 — Referring to a specific query (check the memory and find the query that returns the count of users):
|
| 296 |
+
**User Request:** "send the query that gave us the count of users"
|
| 297 |
+
**You return:**
|
| 298 |
+
```sql
|
| 299 |
+
SELECT COUNT(*) FROM users;
|
| 300 |
+
|
| 301 |
+
### Example 3 — Referring latest known query (Check the memory and find the most recent query.)
|
| 302 |
+
**User Request:** "send"
|
| 303 |
+
**You return:**
|
| 304 |
+
(latest known query):
|
| 305 |
+
```sql
|
| 306 |
+
SELECT * FROM posts WHERE user_id = 1;
|
| 307 |
+
|
| 308 |
+
Remember: Only respond with valid SQL from memory converted to the ANSI SQL. No assumptions. No explanations.
|
| 309 |
+
|
| 310 |
+
=========================
|
| 311 |
+
# New User Request
|
| 312 |
+
=========================
|
| 313 |
+
{new_request}
|
| 314 |
+
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
return base_prompt
|
| 318 |
+
|
| 319 |
+
|
| 320 |
@mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
|
| 321 |
async def test_connection(ctx: Context) -> str:
|
| 322 |
"""Test database connection"""
|
utils.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def parse_mcp_output(output_dict):
|
| 2 |
result = []
|
| 3 |
messages = output_dict.get("messages", [])
|
|
@@ -72,3 +75,23 @@ def parse_mcp_output(output_dict):
|
|
| 72 |
|
| 73 |
|
| 74 |
return result, query_store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
|
| 4 |
def parse_mcp_output(output_dict):
|
| 5 |
result = []
|
| 6 |
messages = output_dict.get("messages", [])
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
return result, query_store
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def classify_intent(user_input: str) -> str:
|
| 84 |
+
user_input = user_input.lower().strip()
|
| 85 |
+
|
| 86 |
+
superset_keywords = [
|
| 87 |
+
"send to superset", "chart", "visualize", "visualise",
|
| 88 |
+
"plot", "graph", "send this", "send that", "create a chart",
|
| 89 |
+
"push to superset", "make a chart", "show chart", "dashboard", "send"
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
# Check for superset intent
|
| 93 |
+
if any(kw in user_input for kw in superset_keywords):
|
| 94 |
+
return "superset_request"
|
| 95 |
+
|
| 96 |
+
# Fallback
|
| 97 |
+
return "sql_request"
|