Selcan Yukcu commited on
Commit ·
e27d11b
1
Parent(s): 0e5cf1e
fix: memory store issue
Browse files- conversation_memory.py +10 -10
- postgre_mcp_client.py +26 -13
- postgre_mcp_server.py +19 -2
- utils.py +0 -7
conversation_memory.py
CHANGED
|
@@ -2,27 +2,27 @@ import json
|
|
| 2 |
import os
|
| 3 |
class ConversationMemory:
|
| 4 |
def __init__(self):
|
| 5 |
-
self.history = [] # All parsed steps from all requests
|
| 6 |
self.tools_used = []
|
| 7 |
self.all_queries = []
|
| 8 |
self.query_results = []
|
| 9 |
-
self.
|
| 10 |
|
| 11 |
-
def update_from_parsed(self, parsed_steps):
|
| 12 |
tools = []
|
|
|
|
| 13 |
for step in parsed_steps:
|
| 14 |
-
if step['type'] == '
|
| 15 |
-
self.user_messages.append(step['content'])
|
| 16 |
-
elif step['type'] == 'ai_function_call':
|
| 17 |
tools.append(step['tool'])
|
| 18 |
if 'query' in step:
|
| 19 |
self.all_queries.append(step['query'])
|
|
|
|
|
|
|
|
|
|
| 20 |
elif step['type'] == 'tool_response':
|
| 21 |
if step['tool'] == 'execute_query':
|
| 22 |
self.query_results.append(step['response'])
|
| 23 |
|
| 24 |
self.tools_used.extend(tools)
|
| 25 |
-
self.history.append(parsed_steps)
|
| 26 |
|
| 27 |
def get_last_n_queries(self):
|
| 28 |
return list(set(self.all_queries))
|
|
@@ -34,7 +34,7 @@ class ConversationMemory:
|
|
| 34 |
return list(set(self.tools_used))
|
| 35 |
|
| 36 |
def get_all_user_messages(self):
|
| 37 |
-
return list(set(self.
|
| 38 |
|
| 39 |
def reset(self, path = "memory.json"):
|
| 40 |
os.remove(path)
|
|
@@ -42,9 +42,9 @@ class ConversationMemory:
|
|
| 42 |
|
| 43 |
def summary(self):
|
| 44 |
return {
|
| 45 |
-
"total_requests": len(self.
|
| 46 |
"tools_used": self.get_all_tools_used(),
|
| 47 |
-
"last_request": self.
|
| 48 |
"last_query": self.all_queries[-1] if self.all_queries else None,
|
| 49 |
"last_result": self.query_results[-1] if self.query_results else None,
|
| 50 |
}
|
|
|
|
| 2 |
import os
|
| 3 |
class ConversationMemory:
|
| 4 |
def __init__(self):
|
|
|
|
| 5 |
self.tools_used = []
|
| 6 |
self.all_queries = []
|
| 7 |
self.query_results = []
|
| 8 |
+
self.request = []
|
| 9 |
|
| 10 |
+
def update_from_parsed(self, parsed_steps, request):
|
| 11 |
tools = []
|
| 12 |
+
self.request.append(request)
|
| 13 |
for step in parsed_steps:
|
| 14 |
+
if step['type'] == 'ai_function_call':
|
|
|
|
|
|
|
| 15 |
tools.append(step['tool'])
|
| 16 |
if 'query' in step:
|
| 17 |
self.all_queries.append(step['query'])
|
| 18 |
+
elif step['type'] == 'ai_final_answer':
|
| 19 |
+
if 'query' in step['ai_said'][0]:
|
| 20 |
+
self.all_queries.append(step['ai_said'][1])
|
| 21 |
elif step['type'] == 'tool_response':
|
| 22 |
if step['tool'] == 'execute_query':
|
| 23 |
self.query_results.append(step['response'])
|
| 24 |
|
| 25 |
self.tools_used.extend(tools)
|
|
|
|
| 26 |
|
| 27 |
def get_last_n_queries(self):
|
| 28 |
return list(set(self.all_queries))
|
|
|
|
| 34 |
return list(set(self.tools_used))
|
| 35 |
|
| 36 |
def get_all_user_messages(self):
|
| 37 |
+
return list(set(self.request))
|
| 38 |
|
| 39 |
def reset(self, path = "memory.json"):
|
| 40 |
os.remove(path)
|
|
|
|
| 42 |
|
| 43 |
def summary(self):
|
| 44 |
return {
|
| 45 |
+
"total_requests": len(self.request),
|
| 46 |
"tools_used": self.get_all_tools_used(),
|
| 47 |
+
"last_request": self.request[-1] if self.request else None,
|
| 48 |
"last_query": self.all_queries[-1] if self.all_queries else None,
|
| 49 |
"last_result": self.query_results[-1] if self.query_results else None,
|
| 50 |
}
|
postgre_mcp_client.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
import asyncio
|
|
|
|
|
|
|
| 2 |
from mcp import ClientSession, StdioServerParameters
|
| 3 |
from mcp.client.stdio import stdio_client
|
| 4 |
|
|
@@ -22,9 +24,11 @@ The users table stores information about the individuals who use the application
|
|
| 22 |
The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
|
| 23 |
"""
|
| 24 |
|
| 25 |
-
request = "can you show me the result of the join of
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
async def main():
|
| 29 |
async with stdio_client(server_params) as (read, write):
|
| 30 |
async with ClientSession(read, write) as session:
|
|
@@ -38,11 +42,19 @@ async def main():
|
|
| 38 |
for tool in tools:
|
| 39 |
tool.description += f" {table_summary}"
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
intent = classify_intent(request)
|
| 48 |
|
|
@@ -56,7 +68,8 @@ async def main():
|
|
| 56 |
past_tools=past_tools,
|
| 57 |
last_queries=past_queries,
|
| 58 |
last_results=past_results,
|
| 59 |
-
new_request=
|
|
|
|
| 60 |
)
|
| 61 |
|
| 62 |
else:
|
|
@@ -73,7 +86,7 @@ async def main():
|
|
| 73 |
past_tools=past_tools,
|
| 74 |
last_queries=past_queries,
|
| 75 |
last_results=past_results,
|
| 76 |
-
new_request =
|
| 77 |
tools = tools_str
|
| 78 |
)
|
| 79 |
|
|
@@ -85,9 +98,9 @@ async def main():
|
|
| 85 |
|
| 86 |
|
| 87 |
parsed_steps, query_store = parse_mcp_output(agent_response)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
memory.update_from_parsed(parsed_steps)
|
| 91 |
|
| 92 |
if request.strip().lower() == "stop":
|
| 93 |
memory.reset()
|
|
|
|
| 1 |
import asyncio
|
| 2 |
+
import os.path
|
| 3 |
+
|
| 4 |
from mcp import ClientSession, StdioServerParameters
|
| 5 |
from mcp.client.stdio import stdio_client
|
| 6 |
|
|
|
|
| 24 |
The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
|
| 25 |
"""
|
| 26 |
|
| 27 |
+
#request = "can you show me the result of the join of postes and users tables?"
|
| 28 |
+
#request = "May ı see the table?"
|
| 29 |
+
#request = "stop"
|
| 30 |
+
#request = "how many columns are there in this joined table?"
|
| 31 |
+
request = "send the table"
|
| 32 |
async def main():
|
| 33 |
async with stdio_client(server_params) as (read, write):
|
| 34 |
async with ClientSession(read, write) as session:
|
|
|
|
| 42 |
for tool in tools:
|
| 43 |
tool.description += f" {table_summary}"
|
| 44 |
|
| 45 |
+
if os.path.exists("memory.json"):
|
| 46 |
+
memory = memory.load_memory()
|
| 47 |
+
past_tools = memory.get_all_tools_used()
|
| 48 |
+
past_queries = memory.get_last_n_queries()
|
| 49 |
+
past_results = memory.get_last_n_results()
|
| 50 |
+
past_requests = memory.get_all_user_messages()
|
| 51 |
+
|
| 52 |
+
else:
|
| 53 |
+
past_tools = "No tools found"
|
| 54 |
+
past_queries ="No queries found"
|
| 55 |
+
past_results = "No results found"
|
| 56 |
+
past_requests = "No requests found"
|
| 57 |
+
|
| 58 |
|
| 59 |
intent = classify_intent(request)
|
| 60 |
|
|
|
|
| 68 |
past_tools=past_tools,
|
| 69 |
last_queries=past_queries,
|
| 70 |
last_results=past_results,
|
| 71 |
+
new_request=request
|
| 72 |
+
|
| 73 |
)
|
| 74 |
|
| 75 |
else:
|
|
|
|
| 86 |
past_tools=past_tools,
|
| 87 |
last_queries=past_queries,
|
| 88 |
last_results=past_results,
|
| 89 |
+
new_request = request,
|
| 90 |
tools = tools_str
|
| 91 |
)
|
| 92 |
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
parsed_steps, query_store = parse_mcp_output(agent_response)
|
| 101 |
+
print("************")
|
| 102 |
+
print(parsed_steps)
|
| 103 |
+
memory.update_from_parsed(parsed_steps, request)
|
| 104 |
|
| 105 |
if request.strip().lower() == "stop":
|
| 106 |
memory.reset()
|
postgre_mcp_server.py
CHANGED
|
@@ -69,6 +69,7 @@ async def base_prompt_query() -> str:
|
|
| 69 |
- Retrieve schema details
|
| 70 |
- Execute SQL queries
|
| 71 |
|
|
|
|
| 72 |
Each tool may also return previews or summaries of table contents to help you better understand the data structure.
|
| 73 |
|
| 74 |
You also have access to **short-term memory**, which stores relevant context from earlier queries. If memory contains the needed information, you **must use it** instead of repeating tool calls with the same input. Avoid redundant tool usage unless:
|
|
@@ -86,9 +87,10 @@ async def base_prompt_query() -> str:
|
|
| 86 |
1. Analyze the request to determine the desired data or action.
|
| 87 |
2. Use tools to gather any necessary information (e.g., list tables, get schema).
|
| 88 |
3. Generate a valid SQL query (such as **SELECT**, **COUNT**, or other read-only operations) and clearly display the full query.
|
| 89 |
-
4. Execute the query and return the result.
|
| 90 |
-
5. Chain tools logically to build toward the answer.
|
| 91 |
6. Explain your reasoning at every step for clarity and transparency.
|
|
|
|
| 92 |
|
| 93 |
---
|
| 94 |
|
|
@@ -102,6 +104,8 @@ async def base_prompt_query() -> str:
|
|
| 102 |
- Validate SQL syntax before execution.
|
| 103 |
- Never assume table or column names. Use tools to confirm structure.
|
| 104 |
- Use memory efficiently. Don’t rerun a tool unless necessary.
|
|
|
|
|
|
|
| 105 |
|
| 106 |
---
|
| 107 |
|
|
@@ -141,6 +145,14 @@ async def base_prompt_query() -> str:
|
|
| 141 |
```sql
|
| 142 |
SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
### Example 2 — Schema Uncertainty:
|
| 145 |
**User Request:** "Show customer emails from the database."
|
| 146 |
|
|
@@ -152,6 +164,7 @@ async def base_prompt_query() -> str:
|
|
| 152 |
```sql
|
| 153 |
SELECT email FROM customers;
|
| 154 |
|
|
|
|
| 155 |
### Example 3 — Memory Usage:
|
| 156 |
**User Request:** "Get top 5 most expensive products."
|
| 157 |
|
|
@@ -161,6 +174,7 @@ async def base_prompt_query() -> str:
|
|
| 161 |
```sql
|
| 162 |
SELECT * FROM products ORDER BY price DESC LIMIT 5;
|
| 163 |
|
|
|
|
| 164 |
### Example 4 — COUNT Query:
|
| 165 |
**User Request:** "How many orders have been placed?"
|
| 166 |
|
|
@@ -171,6 +185,7 @@ async def base_prompt_query() -> str:
|
|
| 171 |
```sql
|
| 172 |
SELECT COUNT(*) FROM orders;
|
| 173 |
|
|
|
|
| 174 |
### Example 5 — WHERE Clause (Filtering):
|
| 175 |
**User Request:** "Get the names of customers from Germany."
|
| 176 |
|
|
@@ -181,6 +196,7 @@ async def base_prompt_query() -> str:
|
|
| 181 |
```sql
|
| 182 |
SELECT name FROM customers WHERE country = 'Germany';
|
| 183 |
|
|
|
|
| 184 |
## Example 6 — Basic Aggregation**
|
| 185 |
**User Request:** "Get total sales for each product"
|
| 186 |
|
|
@@ -190,6 +206,7 @@ async def base_prompt_query() -> str:
|
|
| 190 |
```sql
|
| 191 |
SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
|
| 192 |
|
|
|
|
| 193 |
### Invalid Example — DELETE Operation (Not Allowed):
|
| 194 |
**User Request:** "Delete all customers from Germany."
|
| 195 |
|
|
|
|
| 69 |
- Retrieve schema details
|
| 70 |
- Execute SQL queries
|
| 71 |
|
| 72 |
+
|
| 73 |
Each tool may also return previews or summaries of table contents to help you better understand the data structure.
|
| 74 |
|
| 75 |
You also have access to **short-term memory**, which stores relevant context from earlier queries. If memory contains the needed information, you **must use it** instead of repeating tool calls with the same input. Avoid redundant tool usage unless:
|
|
|
|
| 87 |
1. Analyze the request to determine the desired data or action.
|
| 88 |
2. Use tools to gather any necessary information (e.g., list tables, get schema).
|
| 89 |
3. Generate a valid SQL query (such as **SELECT**, **COUNT**, or other read-only operations) and clearly display the full query.
|
| 90 |
+
4. Execute the query and **return the result**.
|
| 91 |
+
5. **Chain tools logically to build toward the answer.**
|
| 92 |
6. Explain your reasoning at every step for clarity and transparency.
|
| 93 |
+
7. Show the result of the **execute_query** in your final answer.
|
| 94 |
|
| 95 |
---
|
| 96 |
|
|
|
|
| 104 |
- Validate SQL syntax before execution.
|
| 105 |
- Never assume table or column names. Use tools to confirm structure.
|
| 106 |
- Use memory efficiently. Don’t rerun a tool unless necessary.
|
| 107 |
+
- If you generate a SQL query, you **have to** use **execute_query** tool at the end.
|
| 108 |
+
|
| 109 |
|
| 110 |
---
|
| 111 |
|
|
|
|
| 145 |
```sql
|
| 146 |
SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
|
| 147 |
|
| 148 |
+
### Example 1 — Repeating the same request:
|
| 149 |
+
**User Request:** "Get the total sales for each product."
|
| 150 |
+
**Steps:**
|
| 151 |
+
1. Use memory to check if we already retrieved schema.
|
| 152 |
+
2. Retrieve the SQL query that gives the total sales for each product.
|
| 153 |
+
3. Only execute the query and show the user.
|
| 154 |
+
|
| 155 |
+
|
| 156 |
### Example 2 — Schema Uncertainty:
|
| 157 |
**User Request:** "Show customer emails from the database."
|
| 158 |
|
|
|
|
| 164 |
```sql
|
| 165 |
SELECT email FROM customers;
|
| 166 |
|
| 167 |
+
|
| 168 |
### Example 3 — Memory Usage:
|
| 169 |
**User Request:** "Get top 5 most expensive products."
|
| 170 |
|
|
|
|
| 174 |
```sql
|
| 175 |
SELECT * FROM products ORDER BY price DESC LIMIT 5;
|
| 176 |
|
| 177 |
+
|
| 178 |
### Example 4 — COUNT Query:
|
| 179 |
**User Request:** "How many orders have been placed?"
|
| 180 |
|
|
|
|
| 185 |
```sql
|
| 186 |
SELECT COUNT(*) FROM orders;
|
| 187 |
|
| 188 |
+
|
| 189 |
### Example 5 — WHERE Clause (Filtering):
|
| 190 |
**User Request:** "Get the names of customers from Germany."
|
| 191 |
|
|
|
|
| 196 |
```sql
|
| 197 |
SELECT name FROM customers WHERE country = 'Germany';
|
| 198 |
|
| 199 |
+
|
| 200 |
## Example 6 — Basic Aggregation**
|
| 201 |
**User Request:** "Get total sales for each product"
|
| 202 |
|
|
|
|
| 206 |
```sql
|
| 207 |
SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
|
| 208 |
|
| 209 |
+
|
| 210 |
### Invalid Example — DELETE Operation (Not Allowed):
|
| 211 |
**User Request:** "Delete all customers from Germany."
|
| 212 |
|
utils.py
CHANGED
|
@@ -67,13 +67,6 @@ def parse_mcp_output(output_dict):
|
|
| 67 |
"response": content
|
| 68 |
})
|
| 69 |
|
| 70 |
-
elif role_name == "HumanMessage":
|
| 71 |
-
result.append({
|
| 72 |
-
"type": "user_message",
|
| 73 |
-
"content": content
|
| 74 |
-
})
|
| 75 |
-
|
| 76 |
-
|
| 77 |
return result, query_store
|
| 78 |
|
| 79 |
|
|
|
|
| 67 |
"response": content
|
| 68 |
})
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
return result, query_store
|
| 71 |
|
| 72 |
|