Selcan Yukcu commited on
Commit
e27d11b
·
1 Parent(s): 0e5cf1e

fix: memory store issue

Browse files
Files changed (4) hide show
  1. conversation_memory.py +10 -10
  2. postgre_mcp_client.py +26 -13
  3. postgre_mcp_server.py +19 -2
  4. utils.py +0 -7
conversation_memory.py CHANGED
@@ -2,27 +2,27 @@ import json
2
  import os
3
  class ConversationMemory:
4
  def __init__(self):
5
- self.history = [] # All parsed steps from all requests
6
  self.tools_used = []
7
  self.all_queries = []
8
  self.query_results = []
9
- self.user_messages = []
10
 
11
- def update_from_parsed(self, parsed_steps):
12
  tools = []
 
13
  for step in parsed_steps:
14
- if step['type'] == 'user_message':
15
- self.user_messages.append(step['content'])
16
- elif step['type'] == 'ai_function_call':
17
  tools.append(step['tool'])
18
  if 'query' in step:
19
  self.all_queries.append(step['query'])
 
 
 
20
  elif step['type'] == 'tool_response':
21
  if step['tool'] == 'execute_query':
22
  self.query_results.append(step['response'])
23
 
24
  self.tools_used.extend(tools)
25
- self.history.append(parsed_steps)
26
 
27
  def get_last_n_queries(self):
28
  return list(set(self.all_queries))
@@ -34,7 +34,7 @@ class ConversationMemory:
34
  return list(set(self.tools_used))
35
 
36
  def get_all_user_messages(self):
37
- return list(set(self.user_messages))
38
 
39
  def reset(self, path = "memory.json"):
40
  os.remove(path)
@@ -42,9 +42,9 @@ class ConversationMemory:
42
 
43
  def summary(self):
44
  return {
45
- "total_requests": len(self.user_messages),
46
  "tools_used": self.get_all_tools_used(),
47
- "last_request": self.user_messages[-1] if self.user_messages else None,
48
  "last_query": self.all_queries[-1] if self.all_queries else None,
49
  "last_result": self.query_results[-1] if self.query_results else None,
50
  }
 
2
  import os
3
  class ConversationMemory:
4
  def __init__(self):
 
5
  self.tools_used = []
6
  self.all_queries = []
7
  self.query_results = []
8
+ self.request = []
9
 
10
+ def update_from_parsed(self, parsed_steps, request):
11
  tools = []
12
+ self.request.append(request)
13
  for step in parsed_steps:
14
+ if step['type'] == 'ai_function_call':
 
 
15
  tools.append(step['tool'])
16
  if 'query' in step:
17
  self.all_queries.append(step['query'])
18
+ elif step['type'] == 'ai_final_answer':
19
+ if 'query' in step['ai_said'][0]:
20
+ self.all_queries.append(step['ai_said'][1])
21
  elif step['type'] == 'tool_response':
22
  if step['tool'] == 'execute_query':
23
  self.query_results.append(step['response'])
24
 
25
  self.tools_used.extend(tools)
 
26
 
27
  def get_last_n_queries(self):
28
  return list(set(self.all_queries))
 
34
  return list(set(self.tools_used))
35
 
36
  def get_all_user_messages(self):
37
+ return list(set(self.request))
38
 
39
  def reset(self, path = "memory.json"):
40
  os.remove(path)
 
42
 
43
  def summary(self):
44
  return {
45
+ "total_requests": len(self.request),
46
  "tools_used": self.get_all_tools_used(),
47
+ "last_request": self.request[-1] if self.request else None,
48
  "last_query": self.all_queries[-1] if self.all_queries else None,
49
  "last_result": self.query_results[-1] if self.query_results else None,
50
  }
postgre_mcp_client.py CHANGED
@@ -1,4 +1,6 @@
1
  import asyncio
 
 
2
  from mcp import ClientSession, StdioServerParameters
3
  from mcp.client.stdio import stdio_client
4
 
@@ -22,9 +24,11 @@ The users table stores information about the individuals who use the application
22
  The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
23
  """
24
 
25
- request = "can you show me the result of the join of all tables?"
26
- request2 = "how many columns are there in this joined table?"
27
- request3 = "send the last table"
 
 
28
  async def main():
29
  async with stdio_client(server_params) as (read, write):
30
  async with ClientSession(read, write) as session:
@@ -38,11 +42,19 @@ async def main():
38
  for tool in tools:
39
  tool.description += f" {table_summary}"
40
 
41
- memory = memory.load_memory()
42
- past_tools = memory.get_all_tools_used()
43
- past_queries = memory.get_last_n_queries()
44
- past_results = memory.get_last_n_results()
45
- past_requests = memory.get_all_user_messages()
 
 
 
 
 
 
 
 
46
 
47
  intent = classify_intent(request)
48
 
@@ -56,7 +68,8 @@ async def main():
56
  past_tools=past_tools,
57
  last_queries=past_queries,
58
  last_results=past_results,
59
- new_request=request3
 
60
  )
61
 
62
  else:
@@ -73,7 +86,7 @@ async def main():
73
  past_tools=past_tools,
74
  last_queries=past_queries,
75
  last_results=past_results,
76
- new_request = request2,
77
  tools = tools_str
78
  )
79
 
@@ -85,9 +98,9 @@ async def main():
85
 
86
 
87
  parsed_steps, query_store = parse_mcp_output(agent_response)
88
-
89
-
90
- memory.update_from_parsed(parsed_steps)
91
 
92
  if request.strip().lower() == "stop":
93
  memory.reset()
 
1
  import asyncio
2
+ import os.path
3
+
4
  from mcp import ClientSession, StdioServerParameters
5
  from mcp.client.stdio import stdio_client
6
 
 
24
  The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
25
  """
26
 
27
+ #request = "can you show me the result of the join of postes and users tables?"
28
+ #request = "May ı see the table?"
29
+ #request = "stop"
30
+ #request = "how many columns are there in this joined table?"
31
+ request = "send the table"
32
  async def main():
33
  async with stdio_client(server_params) as (read, write):
34
  async with ClientSession(read, write) as session:
 
42
  for tool in tools:
43
  tool.description += f" {table_summary}"
44
 
45
+ if os.path.exists("memory.json"):
46
+ memory = memory.load_memory()
47
+ past_tools = memory.get_all_tools_used()
48
+ past_queries = memory.get_last_n_queries()
49
+ past_results = memory.get_last_n_results()
50
+ past_requests = memory.get_all_user_messages()
51
+
52
+ else:
53
+ past_tools = "No tools found"
54
+ past_queries ="No queries found"
55
+ past_results = "No results found"
56
+ past_requests = "No requests found"
57
+
58
 
59
  intent = classify_intent(request)
60
 
 
68
  past_tools=past_tools,
69
  last_queries=past_queries,
70
  last_results=past_results,
71
+ new_request=request
72
+
73
  )
74
 
75
  else:
 
86
  past_tools=past_tools,
87
  last_queries=past_queries,
88
  last_results=past_results,
89
+ new_request = request,
90
  tools = tools_str
91
  )
92
 
 
98
 
99
 
100
  parsed_steps, query_store = parse_mcp_output(agent_response)
101
+ print("************")
102
+ print(parsed_steps)
103
+ memory.update_from_parsed(parsed_steps, request)
104
 
105
  if request.strip().lower() == "stop":
106
  memory.reset()
postgre_mcp_server.py CHANGED
@@ -69,6 +69,7 @@ async def base_prompt_query() -> str:
69
  - Retrieve schema details
70
  - Execute SQL queries
71
 
 
72
  Each tool may also return previews or summaries of table contents to help you better understand the data structure.
73
 
74
  You also have access to **short-term memory**, which stores relevant context from earlier queries. If memory contains the needed information, you **must use it** instead of repeating tool calls with the same input. Avoid redundant tool usage unless:
@@ -86,9 +87,10 @@ async def base_prompt_query() -> str:
86
  1. Analyze the request to determine the desired data or action.
87
  2. Use tools to gather any necessary information (e.g., list tables, get schema).
88
  3. Generate a valid SQL query (such as **SELECT**, **COUNT**, or other read-only operations) and clearly display the full query.
89
- 4. Execute the query and return the result.
90
- 5. Chain tools logically to build toward the answer.
91
  6. Explain your reasoning at every step for clarity and transparency.
 
92
 
93
  ---
94
 
@@ -102,6 +104,8 @@ async def base_prompt_query() -> str:
102
  - Validate SQL syntax before execution.
103
  - Never assume table or column names. Use tools to confirm structure.
104
  - Use memory efficiently. Don’t rerun a tool unless necessary.
 
 
105
 
106
  ---
107
 
@@ -141,6 +145,14 @@ async def base_prompt_query() -> str:
141
  ```sql
142
  SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
143
 
 
 
 
 
 
 
 
 
144
  ### Example 2 — Schema Uncertainty:
145
  **User Request:** "Show customer emails from the database."
146
 
@@ -152,6 +164,7 @@ async def base_prompt_query() -> str:
152
  ```sql
153
  SELECT email FROM customers;
154
 
 
155
  ### Example 3 — Memory Usage:
156
  **User Request:** "Get top 5 most expensive products."
157
 
@@ -161,6 +174,7 @@ async def base_prompt_query() -> str:
161
  ```sql
162
  SELECT * FROM products ORDER BY price DESC LIMIT 5;
163
 
 
164
  ### Example 4 — COUNT Query:
165
  **User Request:** "How many orders have been placed?"
166
 
@@ -171,6 +185,7 @@ async def base_prompt_query() -> str:
171
  ```sql
172
  SELECT COUNT(*) FROM orders;
173
 
 
174
  ### Example 5 — WHERE Clause (Filtering):
175
  **User Request:** "Get the names of customers from Germany."
176
 
@@ -181,6 +196,7 @@ async def base_prompt_query() -> str:
181
  ```sql
182
  SELECT name FROM customers WHERE country = 'Germany';
183
 
 
184
  ## Example 6 — Basic Aggregation**
185
  **User Request:** "Get total sales for each product"
186
 
@@ -190,6 +206,7 @@ async def base_prompt_query() -> str:
190
  ```sql
191
  SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
192
 
 
193
  ### Invalid Example — DELETE Operation (Not Allowed):
194
  **User Request:** "Delete all customers from Germany."
195
 
 
69
  - Retrieve schema details
70
  - Execute SQL queries
71
 
72
+
73
  Each tool may also return previews or summaries of table contents to help you better understand the data structure.
74
 
75
  You also have access to **short-term memory**, which stores relevant context from earlier queries. If memory contains the needed information, you **must use it** instead of repeating tool calls with the same input. Avoid redundant tool usage unless:
 
87
  1. Analyze the request to determine the desired data or action.
88
  2. Use tools to gather any necessary information (e.g., list tables, get schema).
89
  3. Generate a valid SQL query (such as **SELECT**, **COUNT**, or other read-only operations) and clearly display the full query.
90
+ 4. Execute the query and **return the result**.
91
+ 5. **Chain tools logically to build toward the answer.**
92
  6. Explain your reasoning at every step for clarity and transparency.
93
+ 7. Show the result of the **execute_query** in your final answer.
94
 
95
  ---
96
 
 
104
  - Validate SQL syntax before execution.
105
  - Never assume table or column names. Use tools to confirm structure.
106
  - Use memory efficiently. Don’t rerun a tool unless necessary.
107
+ - If you generate a SQL query, you **have to** use **execute_query** tool at the end.
108
+
109
 
110
  ---
111
 
 
145
  ```sql
146
  SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
147
 
148
+ ### Example 1 — Repeating the same request:
149
+ **User Request:** "Get the total sales for each product."
150
+ **Steps:**
151
+ 1. Use memory to check if we already retrieved schema.
152
+ 2. Retrieve the SQL query that gives the total sales for each product.
153
+ 3. Only execute the query and show the user.
154
+
155
+
156
  ### Example 2 — Schema Uncertainty:
157
  **User Request:** "Show customer emails from the database."
158
 
 
164
  ```sql
165
  SELECT email FROM customers;
166
 
167
+
168
  ### Example 3 — Memory Usage:
169
  **User Request:** "Get top 5 most expensive products."
170
 
 
174
  ```sql
175
  SELECT * FROM products ORDER BY price DESC LIMIT 5;
176
 
177
+
178
  ### Example 4 — COUNT Query:
179
  **User Request:** "How many orders have been placed?"
180
 
 
185
  ```sql
186
  SELECT COUNT(*) FROM orders;
187
 
188
+
189
  ### Example 5 — WHERE Clause (Filtering):
190
  **User Request:** "Get the names of customers from Germany."
191
 
 
196
  ```sql
197
  SELECT name FROM customers WHERE country = 'Germany';
198
 
199
+
200
  ## Example 6 — Basic Aggregation**
201
  **User Request:** "Get total sales for each product"
202
 
 
206
  ```sql
207
  SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
208
 
209
+
210
  ### Invalid Example — DELETE Operation (Not Allowed):
211
  **User Request:** "Delete all customers from Germany."
212
 
utils.py CHANGED
@@ -67,13 +67,6 @@ def parse_mcp_output(output_dict):
67
  "response": content
68
  })
69
 
70
- elif role_name == "HumanMessage":
71
- result.append({
72
- "type": "user_message",
73
- "content": content
74
- })
75
-
76
-
77
  return result, query_store
78
 
79
 
 
67
  "response": content
68
  })
69
 
 
 
 
 
 
 
 
70
  return result, query_store
71
 
72