Selcan Yukcu commited on
Commit
a3f399c
·
1 Parent(s): 440696e

refactor: add parse_mcp_output into utils, give prompt_temp as a resource of server, improve prompt

Browse files
Files changed (3) hide show
  1. postgre_mcp_client.py +14 -84
  2. postgre_mcp_server.py +155 -0
  3. utils.py +74 -0
postgre_mcp_client.py CHANGED
@@ -7,11 +7,12 @@ from langgraph.prebuilt import create_react_agent
7
  from langchain.chat_models import init_chat_model
8
  from conversation_memory import ConversationMemory
9
 
 
 
10
  llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",api_key ="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
11
 
12
  server_params = StdioServerParameters(
13
  command="python",
14
- # buraya full path konulmalı
15
  args=[r"C:\Users\yukcus\Desktop\MCPTest\postgre_mcp_server.py"],
16
  )
17
 
@@ -20,81 +21,6 @@ The users table stores information about the individuals who use the application
20
 
21
  The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
22
  """
23
- prompt_temp = ""
24
-
25
- def parse_mcp_output(output_dict):
26
- result = []
27
- messages = output_dict.get("messages", [])
28
-
29
- for msg in messages:
30
- role_name = msg.__class__.__name__ # Example: HumanMessage, AIMessage, ToolMessage
31
- content = getattr(msg, "content", "")
32
-
33
- # AIMessage with tool call
34
- if role_name == "AIMessage":
35
- function_call = getattr(msg, "additional_kwargs", {}).get("function_call")
36
- if function_call:
37
- tool_name = function_call.get("name")
38
- arguments = function_call.get("arguments")
39
-
40
- # Check if arguments is a JSON string or a dict
41
- if isinstance(arguments, str):
42
- import json
43
- try:
44
- arguments_dict = json.loads(arguments)
45
- except json.JSONDecodeError:
46
- arguments_dict = {}
47
- else:
48
- arguments_dict = arguments or {}
49
-
50
- # Check for presence of "query" key
51
- if "query" in arguments_dict:
52
- print("query detected!!!")
53
- print(f"ai said:{content[0]}")
54
- print(arguments_dict["query"])
55
-
56
- result.append({
57
- "type": "ai_function_call",
58
- "ai_said": content,
59
- "tool": tool_name,
60
- "args": arguments
61
- })
62
- else:
63
- print(f"ai said:{content}")
64
- result.append({
65
- "type": "ai_function_call",
66
- "ai_said": content,
67
- "tool": tool_name,
68
- "args": arguments
69
- })
70
-
71
- else:
72
- print(f"ai final answer:{content}")
73
- result.append({
74
- "type": "ai_final_answer",
75
- "ai_said": content
76
- })
77
-
78
- # ToolMessage
79
- elif role_name == "ToolMessage":
80
- tool_name = getattr(msg, "name", None)
81
- print(f"tool response:{content}")
82
- result.append({
83
- "type": "tool_response",
84
- "tool": tool_name,
85
- "response": content
86
- })
87
-
88
- elif role_name == "HumanMessage":
89
- result.append({
90
- "type": "user_message",
91
- "content": content
92
- })
93
-
94
-
95
- return result
96
-
97
-
98
 
99
  request = "can you show me the result of the join of all tables?"
100
  request2 = "how many columns are there in this joined table?"
@@ -105,11 +31,6 @@ async def main():
105
  await session.initialize()
106
 
107
  memory = ConversationMemory()
108
- prompt = ""
109
- with open(r"C:\Users\yukcus\Desktop\MCPTest\prompt_temp.txt", 'r', encoding='utf-8') as file:
110
- prompt_temp = file.read()
111
-
112
- prompt += prompt_temp
113
 
114
  # Get tools
115
  tools = await load_mcp_tools(session)
@@ -122,12 +43,21 @@ async def main():
122
  past_results = memory.get_last_n_results()
123
  past_requests = memory.get_all_user_messages()
124
 
125
- prompt = prompt.format(
 
 
 
 
 
 
 
 
126
  user_requests=past_requests,
127
  past_tools=past_tools,
128
  last_queries=past_queries,
129
  last_results=past_results,
130
- new_request = request2
 
131
  )
132
 
133
 
@@ -137,7 +67,7 @@ async def main():
137
  agent_response = await agent.ainvoke({"messages": prompt})
138
 
139
 
140
- parsed_steps = parse_mcp_output(agent_response)
141
 
142
 
143
  memory.update_from_parsed(parsed_steps)
 
7
  from langchain.chat_models import init_chat_model
8
  from conversation_memory import ConversationMemory
9
 
10
+ from utils import parse_mcp_output
11
+
12
  llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",api_key ="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
13
 
14
  server_params = StdioServerParameters(
15
  command="python",
 
16
  args=[r"C:\Users\yukcus\Desktop\MCPTest\postgre_mcp_server.py"],
17
  )
18
 
 
21
 
22
  The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
23
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  request = "can you show me the result of the join of all tables?"
26
  request2 = "how many columns are there in this joined table?"
 
31
  await session.initialize()
32
 
33
  memory = ConversationMemory()
 
 
 
 
 
34
 
35
  # Get tools
36
  tools = await load_mcp_tools(session)
 
43
  past_results = memory.get_last_n_results()
44
  past_requests = memory.get_all_user_messages()
45
 
46
+ uri = f"resource://base_prompt_table"
47
+ resource = await session.read_resource(uri)
48
+ base_prompt = resource.contents[0].text
49
+
50
+ # Create a formatted string of tools
51
+ tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
52
+
53
+
54
+ prompt = base_prompt.format(
55
  user_requests=past_requests,
56
  past_tools=past_tools,
57
  last_queries=past_queries,
58
  last_results=past_results,
59
+ new_request = request2,
60
+ tools = tools_str
61
  )
62
 
63
 
 
67
  agent_response = await agent.ainvoke({"messages": prompt})
68
 
69
 
70
+ parsed_steps, query_store = parse_mcp_output(agent_response)
71
 
72
 
73
  memory.update_from_parsed(parsed_steps)
postgre_mcp_server.py CHANGED
@@ -49,6 +49,161 @@ mcp = FastMCP(
49
  )
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  @mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
53
  async def test_connection(ctx: Context) -> str:
54
  """Test database connection"""
 
49
  )
50
 
51
 
52
+ @mcp.resource(
53
+ uri="resource://base_prompt_table",
54
+ name="base_prompt_table",
55
+ description="A base prompt to generate description of a table"
56
+ )
57
+ async def base_prompt_table() -> str:
58
+ """Returns a base prompt to generate description of a table"""
59
+
60
+ base_prompt = """
61
+
62
+ ==========================
63
+ # Your Role
64
+ ==========================
65
+
66
+ You are an expert in generating SQL queries and interacting with a PostgreSQL database using **FastMCP tools**. These tools allow you to:
67
+
68
+ - List available tables
69
+ - Retrieve schema details
70
+ - Execute SQL queries
71
+
72
+ Each tool may also return previews or summaries of table contents to help you better understand the data structure.
73
+
74
+ You also have access to **short-term memory**, which stores relevant context from earlier queries. If memory contains the needed information, you **must use it** instead of repeating tool calls with the same input. Avoid redundant tool usage unless:
75
+ - The memory is empty, or
76
+ - A tool's output is outdated or missing
77
+
78
+ ---
79
+
80
+ ==========================
81
+ # Your Objective
82
+ ==========================
83
+
84
+ When a user submits a request, follow these steps:
85
+
86
+ 1. Analyze the request to determine the desired data or action.
87
+ 2. Use tools to gather any necessary information (e.g., list tables, get schema).
88
+ 3. Generate a valid SQL query (such as **SELECT**, **COUNT**, or other read-only operations) and clearly display the full query.
89
+ 4. Execute the query and return the result.
90
+ 5. Chain tools logically to build toward the answer.
91
+ 6. Explain your reasoning at every step for clarity and transparency.
92
+
93
+ ---
94
+
95
+ ==========================
96
+ # Critical Rules
97
+ ==========================
98
+
99
+ - Only use **read-only** SQL queries such as **SELECT**, **COUNT**, or queries with **GROUP BY**, **ORDER BY**, etc.
100
+ - **Never** use destructive operations like **DELETE**, **UPDATE**, **INSERT**, or **DROP**.
101
+ - Always show the SQL query you generate along with the execution result.
102
+ - Validate SQL syntax before execution.
103
+ - Never assume table or column names. Use tools to confirm structure.
104
+ - Use memory efficiently. Don’t rerun a tool unless necessary.
105
+
106
+ ---
107
+
108
+ ==========================
109
+ # Short-Term Memory
110
+ ==========================
111
+
112
+ You have access to the following memory from this conversation. Use it if applicable for the current request.
113
+
114
+ - Previous user requests: {user_requests}
115
+ - Tools used so far: {past_tools}
116
+ - Last SQL queries: {last_queries}
117
+ - Last result preview: {last_results}
118
+
119
+ ---
120
+
121
+ ==========================
122
+ # Tools
123
+ ==========================
124
+
125
+ You can use the following FastMCP tools. These allow you to create **read-only** queries, such as `SELECT`, `COUNT`, or queries with `GROUP BY`, `ORDER BY`, and similar clauses. You may chain tools together to gather the necessary information before generating your SQL query.
126
+
127
+ {tools}
128
+
129
+ ---
130
+
131
+ ==========================
132
+ # Tool Usage Examples
133
+ ==========================
134
+
135
+ ### Example 1 — Unknown Table Name:
136
+ **User Request:** "Get the total sales for each product."
137
+ **Steps:**
138
+ 1. List Tables → Identify a table like `sales_data`.
139
+ 2. Get Schema for `sales_data` → Confirm columns like `product_name`, `total_sales`.
140
+ 3. Generate and execute query:
141
+ ```sql
142
+ SELECT product_name, SUM(total_sales) FROM sales_data GROUP BY product_name;
143
+
144
+ ### Example 2 — Schema Uncertainty:
145
+ **User Request:** "Show customer emails from the database."
146
+
147
+ **Steps:**
148
+ 1. Use memory to check if we already retrieved schema.
149
+ 2. If not, List Tables → Identify a table like `customers`.
150
+ 3. Get Schema for `customers` → Confirm column `email`.
151
+ 4. Query:
152
+ ```sql
153
+ SELECT email FROM customers;
154
+
155
+ ### Example 3 — Memory Usage:
156
+ **User Request:** "Get top 5 most expensive products."
157
+
158
+ **Steps:**
159
+ 1. Check memory for schema of `products` table.
160
+ 2. If column `price` exists in memory, directly generate:
161
+ ```sql
162
+ SELECT * FROM products ORDER BY price DESC LIMIT 5;
163
+
164
+ ### Example 4 — COUNT Query:
165
+ **User Request:** "How many orders have been placed?"
166
+
167
+ **Steps:**
168
+ 1. List Tables → Identify a table like `orders`.
169
+ 2. Get Schema for `orders` → Confirm it's the right table.
170
+ 3. Query:
171
+ ```sql
172
+ SELECT COUNT(*) FROM orders;
173
+
174
+ ### Example 5 — WHERE Clause (Filtering):
175
+ **User Request:** "Get the names of customers from Germany."
176
+
177
+ **Steps:**
178
+ 1. Use memory or List Tables → Identify `customers` table.
179
+ 2. Get Schema for `customers` → Confirm columns like `country`, `name`.
180
+ 3. Generate and execute query:
181
+ ```sql
182
+ SELECT name FROM customers WHERE country = 'Germany';
183
+
184
+ ### Invalid Example — DELETE Operation (Not Allowed):
185
+ **User Request:** "Delete all customers from Germany."
186
+
187
+ **Response Guidance:**
188
+ - **Do not generate or execute** destructive queries such as `DELETE`.
189
+ - Instead, respond with a message like:
190
+ > Destructive operations such as `DELETE` are not permitted. I can help you retrieve the customers from Germany using a `SELECT` query instead:
191
+ > ```sql
192
+ > SELECT * FROM customers WHERE country = 'Germany';
193
+ > ```
194
+
195
+ =========================
196
+ # New User Request
197
+ =========================
198
+
199
+ Please fulfill the following request based on the above context:
200
+
201
+ {new_request}
202
+ """
203
+
204
+ return base_prompt
205
+
206
+
207
  @mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
208
  async def test_connection(ctx: Context) -> str:
209
  """Test database connection"""
utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def parse_mcp_output(output_dict):
2
+ result = []
3
+ messages = output_dict.get("messages", [])
4
+
5
+ query_store = []
6
+
7
+ for msg in messages:
8
+ role_name = msg.__class__.__name__ # Example: HumanMessage, AIMessage, ToolMessage
9
+ content = getattr(msg, "content", "")
10
+
11
+ # AIMessage with tool call
12
+ if role_name == "AIMessage":
13
+ function_call = getattr(msg, "additional_kwargs", {}).get("function_call")
14
+ if function_call:
15
+ tool_name = function_call.get("name")
16
+ arguments = function_call.get("arguments")
17
+
18
+ # Check if arguments is a JSON string or a dict
19
+ if isinstance(arguments, str):
20
+ import json
21
+ try:
22
+ arguments_dict = json.loads(arguments)
23
+ except json.JSONDecodeError:
24
+ arguments_dict = {}
25
+ else:
26
+ arguments_dict = arguments or {}
27
+
28
+ # Check for presence of "query" key
29
+ if "query" in arguments_dict:
30
+ print("query detected!!!")
31
+ print(f"ai said:{content[0]}")
32
+ print(arguments_dict["query"])
33
+ query_store.append(arguments_dict["query"])
34
+
35
+ result.append({
36
+ "type": "ai_function_call",
37
+ "ai_said": content,
38
+ "tool": tool_name,
39
+ "args": arguments
40
+ })
41
+ else:
42
+ print(f"ai said:{content}")
43
+ result.append({
44
+ "type": "ai_function_call",
45
+ "ai_said": content,
46
+ "tool": tool_name,
47
+ "args": arguments
48
+ })
49
+
50
+ else:
51
+ print(f"ai final answer:{content}")
52
+ result.append({
53
+ "type": "ai_final_answer",
54
+ "ai_said": content
55
+ })
56
+
57
+ # ToolMessage
58
+ elif role_name == "ToolMessage":
59
+ tool_name = getattr(msg, "name", None)
60
+ print(f"tool response:{content}")
61
+ result.append({
62
+ "type": "tool_response",
63
+ "tool": tool_name,
64
+ "response": content
65
+ })
66
+
67
+ elif role_name == "HumanMessage":
68
+ result.append({
69
+ "type": "user_message",
70
+ "content": content
71
+ })
72
+
73
+
74
+ return result, query_store