Selcan Yukcu committed on
Commit
0e5cf1e
·
1 Parent(s): a3f399c

feat: intent classification, intent base prompt choice

Browse files
Files changed (4) hide show
  1. conversation_memory.py +3 -2
  2. postgre_mcp_client.py +31 -14
  3. postgre_mcp_server.py +118 -5
  4. utils.py +23 -0
conversation_memory.py CHANGED
@@ -1,5 +1,5 @@
1
  import json
2
-
3
  class ConversationMemory:
4
  def __init__(self):
5
  self.history = [] # All parsed steps from all requests
@@ -36,7 +36,8 @@ class ConversationMemory:
36
  def get_all_user_messages(self):
37
  return list(set(self.user_messages))
38
 
39
- def reset(self):
 
40
  self.__init__() # Re-initialize the object
41
 
42
  def summary(self):
 
1
  import json
2
+ import os
3
  class ConversationMemory:
4
  def __init__(self):
5
  self.history = [] # All parsed steps from all requests
 
36
  def get_all_user_messages(self):
37
  return list(set(self.user_messages))
38
 
39
def reset(self, path="memory.json"):
    """Delete the file at *path* (if present) and re-initialize this object.

    Args:
        path: Location of the file to remove. Defaults to ``"memory.json"``.

    A missing file is not an error: resetting before anything was written
    must still clear the in-memory state.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Nothing on disk to delete; continue with the in-memory reset.
        pass
    self.__init__()  # Re-initialize the object
42
 
43
  def summary(self):
postgre_mcp_client.py CHANGED
@@ -7,7 +7,7 @@ from langgraph.prebuilt import create_react_agent
7
  from langchain.chat_models import init_chat_model
8
  from conversation_memory import ConversationMemory
9
 
10
- from utils import parse_mcp_output
11
 
12
  llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",api_key ="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
13
 
@@ -24,6 +24,7 @@ The posts table represents content created by users, such as blog posts or messa
24
 
25
  request = "can you show me the result of the join of all tables?"
26
  request2 = "how many columns are there in this joined table?"
 
27
  async def main():
28
  async with stdio_client(server_params) as (read, write):
29
  async with ClientSession(read, write) as session:
@@ -43,22 +44,38 @@ async def main():
43
  past_results = memory.get_last_n_results()
44
  past_requests = memory.get_all_user_messages()
45
 
46
- uri = f"resource://base_prompt_table"
47
- resource = await session.read_resource(uri)
48
- base_prompt = resource.contents[0].text
49
 
50
- # Create a formatted string of tools
51
- tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
 
 
52
 
 
 
 
 
 
 
 
53
 
54
- prompt = base_prompt.format(
55
- user_requests=past_requests,
56
- past_tools=past_tools,
57
- last_queries=past_queries,
58
- last_results=past_results,
59
- new_request = request2,
60
- tools = tools_str
61
- )
 
 
 
 
 
 
 
 
 
62
 
63
 
64
 
 
7
  from langchain.chat_models import init_chat_model
8
  from conversation_memory import ConversationMemory
9
 
10
+ from utils import parse_mcp_output, classify_intent
11
 
12
  llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",api_key ="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
13
 
 
24
 
25
  request = "can you show me the result of the join of all tables?"
26
  request2 = "how many columns are there in this joined table?"
27
+ request3 = "send the last table"
28
  async def main():
29
  async with stdio_client(server_params) as (read, write):
30
  async with ClientSession(read, write) as session:
 
44
  past_results = memory.get_last_n_results()
45
  past_requests = memory.get_all_user_messages()
46
 
47
+ intent = classify_intent(request)
 
 
48
 
49
+ if intent == "superset_request":
50
+ uri = f"resource://last_prompt"
51
+ resource = await session.read_resource(uri)
52
+ base_prompt = resource.contents[0].text
53
 
54
+ prompt = base_prompt.format(
55
+ user_requests=past_requests,
56
+ past_tools=past_tools,
57
+ last_queries=past_queries,
58
+ last_results=past_results,
59
+ new_request=request3
60
+ )
61
 
62
+ else:
63
+ uri = f"resource://base_prompt"
64
+ resource = await session.read_resource(uri)
65
+ base_prompt = resource.contents[0].text
66
+
67
+ # Create a formatted string of tools
68
+ tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
69
+
70
+
71
+ prompt = base_prompt.format(
72
+ user_requests=past_requests,
73
+ past_tools=past_tools,
74
+ last_queries=past_queries,
75
+ last_results=past_results,
76
+ new_request = request2,
77
+ tools = tools_str
78
+ )
79
 
80
 
81
 
postgre_mcp_server.py CHANGED
@@ -50,12 +50,12 @@ mcp = FastMCP(
50
 
51
 
52
  @mcp.resource(
53
- uri="resource://base_prompt_table",
54
- name="base_prompt_table",
55
- description="A base prompt to generate description of a table"
56
  )
57
- async def base_prompt_table() -> str:
58
- """Returns a base prompt to generate description of a table"""
59
 
60
  base_prompt = """
61
 
@@ -181,6 +181,15 @@ async def base_prompt_table() -> str:
181
  ```sql
182
  SELECT name FROM customers WHERE country = 'Germany';
183
 
 
 
 
 
 
 
 
 
 
184
  ### Invalid Example — DELETE Operation (Not Allowed):
185
  **User Request:** "Delete all customers from Germany."
186
 
@@ -204,6 +213,110 @@ async def base_prompt_table() -> str:
204
  return base_prompt
205
 
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  @mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
208
  async def test_connection(ctx: Context) -> str:
209
  """Test database connection"""
 
50
 
51
 
52
  @mcp.resource(
53
+ uri="resource://base_prompt",
54
+ name="base_prompt",
55
+ description="A base prompt to generate SQL queries and answer questions"
56
  )
57
+ async def base_prompt_query() -> str:
58
+ """Returns a base prompt to generate sql queries and answer questions"""
59
 
60
  base_prompt = """
61
 
 
181
  ```sql
182
  SELECT name FROM customers WHERE country = 'Germany';
183
 
184
+ ## Example 6 — Basic Aggregation**
185
+ **User Request:** "Get total sales for each product"
186
+
187
+ **Steps:**
188
+ 1. Use memory or List Tables → Get schema for `sales_data`
189
+ 2. Generate and execute query:
190
+ ```sql
191
+ SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
192
+
193
  ### Invalid Example — DELETE Operation (Not Allowed):
194
  **User Request:** "Delete all customers from Germany."
195
 
 
213
  return base_prompt
214
 
215
 
216
@mcp.resource(
    uri="resource://last_prompt",
    name="last_prompt",
    description="A prompt that identifies the most recent SQL query related to the user's request and reformats it into ANSI SQL syntax for use in Superset."
)
async def last_prompt() -> str:
    """Return the prompt template for re-emitting a previous query as ANSI SQL.

    Exposed as the MCP resource ``resource://last_prompt``. The returned text
    is a ``str.format`` template with the placeholders ``{user_requests}``,
    ``{past_tools}``, ``{last_queries}``, ``{last_results}`` and
    ``{new_request}``. Any literal brace added to the text later must be
    doubled (``{{`` / ``}}``) or formatting will fail.

    Returns:
        The unformatted prompt template.
    """

    base_prompt = """

==========================
# Your Role
==========================

You are an expert at reading and understanding SQL queries.
Your task is to retrieve the **exact SQL query** that produced a previously seen result, convert the query to the **ANSI SQL query** and return **only the ANSI SQL query** — no explanation, reasoning, or commentary.

You have access to a **short-term memory**, which stores relevant context from earlier interactions in the current conversation.

---

==========================
# Your Objective
==========================

When a user submits a request (e.g., *"send me that table"*, *"send the last query"*, etc.), follow these steps:

1. Identify which previous result the user is referring to, using your short-term memory.
2. Retrieve the corresponding SQL query that produced that result.
3. Convert the SQL query to the ANSI SQL query
4. Return **only** that ANSI SQL query.

---

==========================
# Critical Rules
==========================

- Do **not** ask questions or request clarification.
- Do **not** explain anything to the user.
- Only use the **memory** to determine which query is relevant.
- Respond with the **exact ANSI SQL query only**, formatted cleanly.
- Do **not** guess — only retrieve queries that actually exist in memory.
- If no query fits, respond with: "Query not found."

---

==========================
# Short-Term Memory
==========================

You have access to the following memory from this conversation:

- **Previous user requests**:
`{user_requests}`

- **Tools used so far**:
`{past_tools}`

- **Recent SQL queries**:
`{last_queries}`

- **Result preview from last query**:
`{last_results}`

Use this memory to resolve any references in the user's latest request.

---

==========================
# Examples
==========================

### Example 1 — Referring last query (Check the memory and find the most recent query that generates a table) :
**User Request:** "send the last table"
**You return:**
```sql
SELECT * FROM posts INNER JOIN users ON posts.user_id = users.id;
```

### Example 2 — Referring to a specific query (check the memory and find the query that returns the count of users):
**User Request:** "send the query that gave us the count of users"
**You return:**
```sql
SELECT COUNT(*) FROM users;
```

### Example 3 — Referring latest known query (Check the memory and find the most recent query.)
**User Request:** "send"
**You return:**
(latest known query):
```sql
SELECT * FROM posts WHERE user_id = 1;
```

Remember: Only respond with valid SQL from memory converted to the ANSI SQL. No assumptions. No explanations.

=========================
# New User Request
=========================
{new_request}

"""

    return base_prompt
318
+
319
+
320
  @mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
321
  async def test_connection(ctx: Context) -> str:
322
  """Test database connection"""
utils.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  def parse_mcp_output(output_dict):
2
  result = []
3
  messages = output_dict.get("messages", [])
@@ -72,3 +75,23 @@ def parse_mcp_output(output_dict):
72
 
73
 
74
  return result, query_store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
  def parse_mcp_output(output_dict):
5
  result = []
6
  messages = output_dict.get("messages", [])
 
75
 
76
 
77
  return result, query_store
78
+
79
+
80
+
81
+
82
+
83
# Phrases that mark a request as a "superset_request". Multi-word entries
# are kept for readability even where a shorter entry already covers them.
_SUPERSET_KEYWORDS = (
    "send to superset", "chart", "visualize", "visualise",
    "plot", "graph", "send this", "send that", "create a chart",
    "push to superset", "make a chart", "show chart", "dashboard", "send",
)

# A single alternation anchored on a leading word boundary: "graphs" and
# "sending" still match, but a keyword buried inside another word
# ("demographics", "paragraph", "uncharted") does not. Words inside a phrase
# may be separated by any run of whitespace.
_SUPERSET_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(
        r"\s+".join(re.escape(word) for word in keyword.split())
        for keyword in _SUPERSET_KEYWORDS
    )
    + r")"
)


def classify_intent(user_input: str) -> str:
    """Classify a user request as ``"superset_request"`` or ``"sql_request"``.

    Matching is case-insensitive and keyword based.

    Args:
        user_input: Raw text of the user's request. ``None`` or an empty
            string is treated as having no keywords.

    Returns:
        ``"superset_request"`` when any superset keyword starts a word in the
        input, otherwise the fallback ``"sql_request"``.
    """
    text = (user_input or "").lower()

    # Check for superset intent
    if _SUPERSET_PATTERN.search(text):
        return "superset_request"

    # Fallback
    return "sql_request"