Selcan Yukcu committed on
Commit
7f3ee7b
·
1 Parent(s): b112622

refactor: separate helper functions and main file from client function. refactor output prints

Browse files
Files changed (4) hide show
  1. main.py +32 -0
  2. postgre_mcp_client.py +101 -93
  3. table_summary.txt +3 -0
  4. utils.py +36 -8
main.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from pathlib import Path
3
+ from typing import List
4
+ import asyncio
5
+ from postgre_mcp_client import pg_mcp_exec
6
+ import logging
7
+
8
+ #logger = logging.getLogger(__name__)
9
+ # TODO add config
10
def load_db_configs():
    """Load database configurations from configs.yaml in the working directory.

    Returns:
        The value stored under the top-level ``db_configs`` key of the YAML file.

    Raises:
        FileNotFoundError: if configs.yaml is not present.
        KeyError: if the parsed YAML has no ``db_configs`` key.
    """
    configs_path = Path("configs.yaml")

    if not configs_path.exists():
        raise FileNotFoundError("configs.yaml not found")

    # safe_load: never execute arbitrary YAML tags from a config file.
    # utf-8 explicitly, so non-ASCII config values parse the same on every OS.
    with configs_path.open(encoding="utf-8") as f:
        configs = yaml.safe_load(f)

    return configs["db_configs"]
21
+
22
+
23
async def main():
    """Entry coroutine: send one hard-coded request through the MCP pipeline."""
    # Database configs are not wired in yet:
    # configs = load_db_configs()

    user_request = "Show me the table of join posts and users tables."
    await pg_mcp_exec(user_request)
28
+
29
+
30
if __name__ == "__main__":
    # Script entry point: drive the async pipeline to completion.
    asyncio.run(main())
postgre_mcp_client.py CHANGED
@@ -1,118 +1,126 @@
1
- import asyncio
2
  import os.path
3
-
4
  from mcp import ClientSession, StdioServerParameters
5
  from mcp.client.stdio import stdio_client
6
-
7
  from langchain_mcp_adapters.tools import load_mcp_tools
8
  from langgraph.prebuilt import create_react_agent
9
  from langchain.chat_models import init_chat_model
10
  from conversation_memory import ConversationMemory
11
-
12
  from utils import parse_mcp_output, classify_intent
 
13
 
14
- llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",api_key ="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
15
 
16
- server_params = StdioServerParameters(
17
- command="python",
18
- args=[r"C:\Users\yukcus\Desktop\MCPTest\postgre_mcp_server.py"],
19
- )
20
 
21
- table_summary = """
22
- The users table stores information about the individuals who use the application. Each user is assigned a unique, auto-incrementing id that serves as the primary key. The username field holds the user's chosen display name and cannot be null, while the email field stores the user’s unique email address, also required and constrained to be unique to avoid duplicates. To track when a user was added to the system, the created_at column records the timestamp of their creation, with a default value set to the current time.
 
23
 
24
- The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
25
- """
 
 
 
 
 
 
 
 
26
 
27
- #request = "can you show me the result of the join of posts and users tables?"
28
- #request = "But you did not execute the query. Can you tell me why?"
29
- #request = "May ı see the table?"
30
- #request = "stop"
31
- #request = "how many columns are there in this joined table?"
32
- request = "send the table"
33
- async def main():
34
  async with stdio_client(server_params) as (read, write):
35
  async with ClientSession(read, write) as session:
36
- # Initialize the connection
37
  await session.initialize()
 
38
 
39
- memory = ConversationMemory()
40
-
41
- # Get tools
42
- tools = await load_mcp_tools(session)
43
- for tool in tools:
44
- tool.description += f" {table_summary}"
45
-
46
- if os.path.exists("memory.json"):
47
- memory = memory.load_memory()
48
- past_tools = memory.get_all_tools_used()
49
- past_queries = memory.get_last_n_queries()
50
- past_results = memory.get_last_n_results()
51
- past_requests = memory.get_all_user_messages()
52
-
53
- else:
54
- past_tools = "No tools found"
55
- past_queries ="No queries found"
56
- past_results = "No results found"
57
- past_requests = "No requests found"
58
-
59
 
60
  intent = classify_intent(request)
 
61
 
62
- if intent == "superset_request":
63
- uri = f"resource://last_prompt"
64
- resource = await session.read_resource(uri)
65
- base_prompt = resource.contents[0].text
66
-
67
- prompt = base_prompt.format(
68
- user_requests=past_requests,
69
- past_tools=past_tools,
70
- last_queries=past_queries,
71
- last_results=past_results,
72
- new_request=request
73
-
74
- )
75
-
76
- else:
77
- uri = f"resource://base_prompt"
78
- resource = await session.read_resource(uri)
79
- base_prompt = resource.contents[0].text
80
-
81
- # Create a formatted string of tools
82
- tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
83
-
84
-
85
- prompt = base_prompt.format(
86
- user_requests=past_requests,
87
- past_tools=past_tools,
88
- last_queries=past_queries,
89
- last_results=past_results,
90
- new_request = request,
91
- tools = tools_str
92
- )
93
-
94
-
95
-
96
- # Create and run the agent
97
  agent = create_react_agent(llm, tools)
98
  agent_response = await agent.ainvoke({"messages": prompt})
99
 
100
-
101
- parsed_steps, query_store = parse_mcp_output(agent_response)
102
- print("************")
103
- print(parsed_steps)
104
  memory.update_from_parsed(parsed_steps, request)
105
 
106
- if request.strip().lower() == "stop":
107
- memory.reset()
108
- print("Conversation memory reset.")
109
-
110
- else:
111
- memory.save_memory()
112
- #pprint(parsed_steps)
113
-
114
-
115
-
116
- asyncio.run(main())
117
-
118
- # open up a new cmd shell and run awith python mcp_client.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os.path
 
2
  from mcp import ClientSession, StdioServerParameters
3
  from mcp.client.stdio import stdio_client
 
4
  from langchain_mcp_adapters.tools import load_mcp_tools
5
  from langgraph.prebuilt import create_react_agent
6
  from langchain.chat_models import init_chat_model
7
  from conversation_memory import ConversationMemory
 
8
  from utils import parse_mcp_output, classify_intent
9
+ import logging
10
 
11
+ logger = logging.getLogger(__name__)
12
 
13
async def pg_mcp_exec(request: str):
    """
    Execute the full PostgreSQL MCP pipeline: load summary, connect session,
    load memory and tools, build prompt, run agent, update memory.

    Args:
        request (str): User's request input.

    Returns:
        The agent response (messages payload returned by the react agent).
    """
    # TODO: give summary file path from config
    table_summary = load_table_summary("table_summary.txt")
    server_params = get_server_params()

    # SECURITY: the API key must come from the environment, never from source.
    # A key was previously hard-coded (and committed) here — it should be
    # revoked/rotated. With api_key=None the provider falls back to its own
    # environment-variable resolution.
    llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",
                          api_key=os.environ.get("GOOGLE_API_KEY"))

    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            memory = await load_or_create_memory()

            tools = await load_and_enrich_tools(session, table_summary)
            past_data = get_memory_snapshot(memory)

            intent = classify_intent(request)
            prompt = await build_prompt(session, intent, request, tools, past_data)

            agent = create_react_agent(llm, tools)
            agent_response = await agent.ainvoke({"messages": prompt})

            parsed_steps, _ = parse_mcp_output(agent_response)
            memory.update_from_parsed(parsed_steps, request)

            await handle_memory_save_or_reset(memory, request)

            return agent_response
53
+
54
+
55
+ # ---------------- Helper Functions ---------------- #
56
+
57
def load_table_summary(path: str) -> str:
    """Return the full text of the table-summary file at *path*.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    # utf-8 explicitly: the shipped summary text contains non-ASCII punctuation,
    # so relying on the platform default encoding would break on some systems.
    with open(path, 'r', encoding='utf-8') as file:
        return file.read()
60
+
61
def get_server_params() -> StdioServerParameters:
    """Build the stdio transport parameters for the PostgreSQL MCP server."""
    # TODO: give server params from config
    server_script = r"C:\Users\yukcus\Desktop\MCPTest\postgre_mcp_server.py"
    return StdioServerParameters(command="python", args=[server_script])
67
+
68
async def load_or_create_memory() -> ConversationMemory:
    """Return the persisted conversation memory if one exists, else a fresh one."""
    fresh = ConversationMemory()
    if not os.path.exists("memory.json"):
        return fresh
    # A prior session left memory.json behind — rehydrate from it.
    return fresh.load_memory()
73
+
74
async def load_and_enrich_tools(session: ClientSession, summary: str):
    """Load MCP tools from *session*, appending *summary* to every description.

    Enriching each tool's description gives the agent the table schema context.
    """
    enriched = await load_mcp_tools(session)
    for t in enriched:
        t.description = f"{t.description} {summary}"
    return enriched
79
+
80
def get_memory_snapshot(memory: ConversationMemory) -> dict:
    """Collect past tools/queries/results/requests from *memory* as a dict.

    When no memory.json has been persisted yet, placeholder strings are
    returned instead of querying the (empty) memory object.
    """
    if not os.path.exists("memory.json"):
        return {
            "past_tools": "No tools found",
            "past_queries": "No queries found",
            "past_results": "No results found",
            "past_requests": "No requests found"
        }
    return {
        "past_tools": memory.get_all_tools_used(),
        "past_queries": memory.get_last_n_queries(),
        "past_results": memory.get_last_n_results(),
        "past_requests": memory.get_all_user_messages()
    }
94
+
95
async def build_prompt(session, intent, request, tools, past_data):
    """Fetch the prompt template matching *intent* and fill it with history.

    Args:
        session: Active MCP client session used to read prompt resources.
        intent: Output of classify_intent(); "superset_request" selects the
            follow-up template, anything else the base conversation template.
        request: The new user request to embed in the prompt.
        tools: Loaded MCP tools (listed only in the base-prompt branch).
        past_data: Snapshot dict produced by get_memory_snapshot().

    Returns:
        str: The fully formatted prompt.
    """
    # TODO: add uri's from config
    # Fetch only the resource this call actually needs — previously both
    # prompts were read on every call, costing an extra MCP round-trip.
    if intent == "superset_request":
        resource = await session.read_resource("resource://last_prompt")
        template = resource.contents[0].text
        return template.format(
            user_requests=past_data["past_requests"],
            past_tools=past_data["past_tools"],
            last_queries=past_data["past_queries"],
            last_results=past_data["past_results"],
            new_request=request
        )

    resource = await session.read_resource("resource://base_prompt")
    template = resource.contents[0].text
    tools_str = "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
    return template.format(
        user_requests=past_data["past_requests"],
        past_tools=past_data["past_tools"],
        last_queries=past_data["past_queries"],
        last_results=past_data["past_results"],
        new_request=request,
        tools=tools_str
    )
119
+
120
async def handle_memory_save_or_reset(memory: ConversationMemory, request: str):
    """Reset memory when the user says "stop"; otherwise persist it to disk."""
    wants_reset = request.strip().lower() == "stop"
    if wants_reset:
        memory.reset()
        logger.info("Conversation memory reset.")
        return
    memory.save_memory()
    logger.info("Conversation memory saved.")
table_summary.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ The users table stores information about the individuals who use the application. Each user is assigned a unique, auto-incrementing id that serves as the primary key. The username field holds the user's chosen display name and cannot be null, while the email field stores the user’s unique email address, also required and constrained to be unique to avoid duplicates. To track when a user was added to the system, the created_at column records the timestamp of their creation, with a default value set to the current time.
2
+
3
+ The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
utils.py CHANGED
@@ -1,10 +1,12 @@
1
  import re
 
 
 
2
 
3
-
4
  def parse_mcp_output(output_dict):
5
  result = []
6
  messages = output_dict.get("messages", [])
7
-
8
  query_store = []
9
 
10
  for msg in messages:
@@ -30,9 +32,20 @@ def parse_mcp_output(output_dict):
30
 
31
  # Check for presence of "query" key
32
  if "query" in arguments_dict:
33
- print("query detected!!!")
34
- print(f"ai said:{content[0]}")
35
- print(arguments_dict["query"])
 
 
 
 
 
 
 
 
 
 
 
36
  query_store.append(arguments_dict["query"])
37
 
38
  result.append({
@@ -42,7 +55,16 @@ def parse_mcp_output(output_dict):
42
  "args": arguments
43
  })
44
  else:
45
- print(f"ai said:{content}")
 
 
 
 
 
 
 
 
 
46
  result.append({
47
  "type": "ai_function_call",
48
  "ai_said": content,
@@ -51,7 +73,10 @@ def parse_mcp_output(output_dict):
51
  })
52
 
53
  else:
54
- print(f"ai final answer:{content}")
 
 
 
55
  result.append({
56
  "type": "ai_final_answer",
57
  "ai_said": content
@@ -60,7 +85,9 @@ def parse_mcp_output(output_dict):
60
  # ToolMessage
61
  elif role_name == "ToolMessage":
62
  tool_name = getattr(msg, "name", None)
63
- print(f"tool response:{content}")
 
 
64
  result.append({
65
  "type": "tool_response",
66
  "tool": tool_name,
@@ -88,3 +115,4 @@ def classify_intent(user_input: str) -> str:
88
 
89
  # Fallback
90
  return "sql_request"
 
 
1
  import re
2
+ import os
3
+ from conversation_memory import ConversationMemory
4
+ import logging
5
 
6
+ logger = logging.getLogger(__name__)
7
  def parse_mcp_output(output_dict):
8
  result = []
9
  messages = output_dict.get("messages", [])
 
10
  query_store = []
11
 
12
  for msg in messages:
 
32
 
33
  # Check for presence of "query" key
34
  if "query" in arguments_dict:
35
+ #print("query detected!!!")
36
+ print(f"=============== AI Reasoning Step ===============")
37
+ print(content[0])
38
+ print()
39
+ print("=============== AI used the following tools ===============")
40
+ print(tool_name)
41
+ print()
42
+ print("=============== AI generated the following query ===============")
43
+ print(arguments_dict['query'])
44
+
45
+ logger.info(f"ai said:{content[0]}")
46
+ logger.info(f"ai used:{tool_name}")
47
+ logger.info(f"generated query:{arguments_dict['query']}")
48
+ #print(arguments_dict["query"])
49
  query_store.append(arguments_dict["query"])
50
 
51
  result.append({
 
55
  "args": arguments
56
  })
57
  else:
58
+ #print(f"ai said:{content}")
59
+ logger.info(f"ai said:{content}")
60
+ logger.info(f"ai used:{tool_name}")
61
+ print(f"=============== AI Reasoning Step ===============")
62
+ print(content)
63
+ print()
64
+ print("=============== AI used the following tools ===============")
65
+ print(tool_name)
66
+ print()
67
+
68
  result.append({
69
  "type": "ai_function_call",
70
  "ai_said": content,
 
73
  })
74
 
75
  else:
76
+ #print(f"ai final answer:{content}")
77
+ logger.info(f"ai final answer:{content}")
78
+ print("=============== AI's final answer ===============")
79
+ print(content)
80
  result.append({
81
  "type": "ai_final_answer",
82
  "ai_said": content
 
85
  # ToolMessage
86
  elif role_name == "ToolMessage":
87
  tool_name = getattr(msg, "name", None)
88
+ print("=============== The tool returned the following response ===============")
89
+ print(content)
90
+ logger.info(f"tool response:{content}")
91
  result.append({
92
  "type": "tool_response",
93
  "tool": tool_name,
 
115
 
116
  # Fallback
117
  return "sql_request"
118
+