Selcan Yukcu
commited on
Commit
·
7f3ee7b
1
Parent(s):
b112622
refactor: separate helper functions and main file from client function. refactor output prints
Browse files- main.py +32 -0
- postgre_mcp_client.py +101 -93
- table_summary.txt +3 -0
- utils.py +36 -8
main.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List
|
| 4 |
+
import asyncio
|
| 5 |
+
from postgre_mcp_client import pg_mcp_exec
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
#logger = logging.getLogger(__name__)
|
| 9 |
+
# TODO add config
|
| 10 |
+
def load_db_configs():
|
| 11 |
+
"""Load database configurations from databases.yaml"""
|
| 12 |
+
configs_path = Path("configs.yaml")
|
| 13 |
+
|
| 14 |
+
if not configs_path.exists():
|
| 15 |
+
raise FileNotFoundError("configs.yaml not found")
|
| 16 |
+
|
| 17 |
+
with open(configs_path) as f:
|
| 18 |
+
configs = yaml.safe_load(f)
|
| 19 |
+
|
| 20 |
+
return configs["db_configs"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
async def main():
|
| 24 |
+
#configs = load_db_configs()
|
| 25 |
+
|
| 26 |
+
request = "Show me the table of join posts and users tables."
|
| 27 |
+
await pg_mcp_exec(request)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if __name__ == "__main__":
|
| 31 |
+
|
| 32 |
+
asyncio.run(main())
|
postgre_mcp_client.py
CHANGED
|
@@ -1,118 +1,126 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
import os.path
|
| 3 |
-
|
| 4 |
from mcp import ClientSession, StdioServerParameters
|
| 5 |
from mcp.client.stdio import stdio_client
|
| 6 |
-
|
| 7 |
from langchain_mcp_adapters.tools import load_mcp_tools
|
| 8 |
from langgraph.prebuilt import create_react_agent
|
| 9 |
from langchain.chat_models import init_chat_model
|
| 10 |
from conversation_memory import ConversationMemory
|
| 11 |
-
|
| 12 |
from utils import parse_mcp_output, classify_intent
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
#request = "can you show me the result of the join of posts and users tables?"
|
| 28 |
-
#request = "But you did not execute the query. Can you tell me why?"
|
| 29 |
-
#request = "May ı see the table?"
|
| 30 |
-
#request = "stop"
|
| 31 |
-
#request = "how many columns are there in this joined table?"
|
| 32 |
-
request = "send the table"
|
| 33 |
-
async def main():
|
| 34 |
async with stdio_client(server_params) as (read, write):
|
| 35 |
async with ClientSession(read, write) as session:
|
| 36 |
-
# Initialize the connection
|
| 37 |
await session.initialize()
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# Get tools
|
| 42 |
-
tools = await load_mcp_tools(session)
|
| 43 |
-
for tool in tools:
|
| 44 |
-
tool.description += f" {table_summary}"
|
| 45 |
-
|
| 46 |
-
if os.path.exists("memory.json"):
|
| 47 |
-
memory = memory.load_memory()
|
| 48 |
-
past_tools = memory.get_all_tools_used()
|
| 49 |
-
past_queries = memory.get_last_n_queries()
|
| 50 |
-
past_results = memory.get_last_n_results()
|
| 51 |
-
past_requests = memory.get_all_user_messages()
|
| 52 |
-
|
| 53 |
-
else:
|
| 54 |
-
past_tools = "No tools found"
|
| 55 |
-
past_queries ="No queries found"
|
| 56 |
-
past_results = "No results found"
|
| 57 |
-
past_requests = "No requests found"
|
| 58 |
-
|
| 59 |
|
| 60 |
intent = classify_intent(request)
|
|
|
|
| 61 |
|
| 62 |
-
if intent == "superset_request":
|
| 63 |
-
uri = f"resource://last_prompt"
|
| 64 |
-
resource = await session.read_resource(uri)
|
| 65 |
-
base_prompt = resource.contents[0].text
|
| 66 |
-
|
| 67 |
-
prompt = base_prompt.format(
|
| 68 |
-
user_requests=past_requests,
|
| 69 |
-
past_tools=past_tools,
|
| 70 |
-
last_queries=past_queries,
|
| 71 |
-
last_results=past_results,
|
| 72 |
-
new_request=request
|
| 73 |
-
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
else:
|
| 77 |
-
uri = f"resource://base_prompt"
|
| 78 |
-
resource = await session.read_resource(uri)
|
| 79 |
-
base_prompt = resource.contents[0].text
|
| 80 |
-
|
| 81 |
-
# Create a formatted string of tools
|
| 82 |
-
tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
prompt = base_prompt.format(
|
| 86 |
-
user_requests=past_requests,
|
| 87 |
-
past_tools=past_tools,
|
| 88 |
-
last_queries=past_queries,
|
| 89 |
-
last_results=past_results,
|
| 90 |
-
new_request = request,
|
| 91 |
-
tools = tools_str
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
# Create and run the agent
|
| 97 |
agent = create_react_agent(llm, tools)
|
| 98 |
agent_response = await agent.ainvoke({"messages": prompt})
|
| 99 |
|
| 100 |
-
|
| 101 |
-
parsed_steps, query_store = parse_mcp_output(agent_response)
|
| 102 |
-
print("************")
|
| 103 |
-
print(parsed_steps)
|
| 104 |
memory.update_from_parsed(parsed_steps, request)
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os.path
|
|
|
|
| 2 |
from mcp import ClientSession, StdioServerParameters
|
| 3 |
from mcp.client.stdio import stdio_client
|
|
|
|
| 4 |
from langchain_mcp_adapters.tools import load_mcp_tools
|
| 5 |
from langgraph.prebuilt import create_react_agent
|
| 6 |
from langchain.chat_models import init_chat_model
|
| 7 |
from conversation_memory import ConversationMemory
|
|
|
|
| 8 |
from utils import parse_mcp_output, classify_intent
|
| 9 |
+
import logging
|
| 10 |
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
+
async def pg_mcp_exec(request: str) -> str:
|
| 14 |
+
"""
|
| 15 |
+
Execute the full PostgreSQL MCP pipeline: load summary, connect session,
|
| 16 |
+
load memory and tools, build prompt, run agent, update memory.
|
| 17 |
|
| 18 |
+
Args:
|
| 19 |
+
request (str): User's request input.
|
| 20 |
+
llm (Any): Language model for reasoning agent.
|
| 21 |
|
| 22 |
+
Returns:
|
| 23 |
+
str: Agent response message.
|
| 24 |
+
"""
|
| 25 |
+
# TODO: give summary file path from config
|
| 26 |
+
table_summary = load_table_summary("table_summary.txt")
|
| 27 |
+
server_params = get_server_params()
|
| 28 |
+
|
| 29 |
+
# TODO: give key from env
|
| 30 |
+
llm = init_chat_model(model="gemini-2.0-flash-lite", model_provider="google_genai",
|
| 31 |
+
api_key="AIzaSyAuxYmci0DVU5l5L_YcxLlxHzR5MLn70js")
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
async with stdio_client(server_params) as (read, write):
|
| 34 |
async with ClientSession(read, write) as session:
|
|
|
|
| 35 |
await session.initialize()
|
| 36 |
+
memory = await load_or_create_memory()
|
| 37 |
|
| 38 |
+
tools = await load_and_enrich_tools(session, table_summary)
|
| 39 |
+
past_data = get_memory_snapshot(memory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
intent = classify_intent(request)
|
| 42 |
+
prompt = await build_prompt(session, intent, request, tools, past_data)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
agent = create_react_agent(llm, tools)
|
| 45 |
agent_response = await agent.ainvoke({"messages": prompt})
|
| 46 |
|
| 47 |
+
parsed_steps, _ = parse_mcp_output(agent_response)
|
|
|
|
|
|
|
|
|
|
| 48 |
memory.update_from_parsed(parsed_steps, request)
|
| 49 |
|
| 50 |
+
await handle_memory_save_or_reset(memory, request)
|
| 51 |
+
|
| 52 |
+
return agent_response
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ---------------- Helper Functions ---------------- #
|
| 56 |
+
|
| 57 |
+
def load_table_summary(path: str) -> str:
|
| 58 |
+
with open(path, 'r') as file:
|
| 59 |
+
return file.read()
|
| 60 |
+
|
| 61 |
+
def get_server_params() -> StdioServerParameters:
|
| 62 |
+
# TODO: give server params from config
|
| 63 |
+
return StdioServerParameters(
|
| 64 |
+
command="python",
|
| 65 |
+
args=[r"C:\Users\yukcus\Desktop\MCPTest\postgre_mcp_server.py"],
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
async def load_or_create_memory() -> ConversationMemory:
|
| 69 |
+
memory = ConversationMemory()
|
| 70 |
+
if os.path.exists("memory.json"):
|
| 71 |
+
return memory.load_memory()
|
| 72 |
+
return memory
|
| 73 |
+
|
| 74 |
+
async def load_and_enrich_tools(session: ClientSession, summary: str):
|
| 75 |
+
tools = await load_mcp_tools(session)
|
| 76 |
+
for tool in tools:
|
| 77 |
+
tool.description += f" {summary}"
|
| 78 |
+
return tools
|
| 79 |
+
|
| 80 |
+
def get_memory_snapshot(memory: ConversationMemory) -> dict:
|
| 81 |
+
if os.path.exists("memory.json"):
|
| 82 |
+
return {
|
| 83 |
+
"past_tools": memory.get_all_tools_used(),
|
| 84 |
+
"past_queries": memory.get_last_n_queries(),
|
| 85 |
+
"past_results": memory.get_last_n_results(),
|
| 86 |
+
"past_requests": memory.get_all_user_messages()
|
| 87 |
+
}
|
| 88 |
+
return {
|
| 89 |
+
"past_tools": "No tools found",
|
| 90 |
+
"past_queries": "No queries found",
|
| 91 |
+
"past_results": "No results found",
|
| 92 |
+
"past_requests": "No requests found"
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
async def build_prompt(session, intent, request, tools, past_data):
|
| 96 |
+
superset_prompt = await session.read_resource("resource://last_prompt")
|
| 97 |
+
conversation_prompt = await session.read_resource("resource://base_prompt")
|
| 98 |
+
# TODO: add uri's from config
|
| 99 |
+
if intent == "superset_request":
|
| 100 |
+
template = superset_prompt.contents[0].text
|
| 101 |
+
return template.format(
|
| 102 |
+
user_requests=past_data["past_requests"],
|
| 103 |
+
past_tools=past_data["past_tools"],
|
| 104 |
+
last_queries=past_data["past_queries"],
|
| 105 |
+
last_results=past_data["past_results"],
|
| 106 |
+
new_request=request
|
| 107 |
+
)
|
| 108 |
+
else:
|
| 109 |
+
template = conversation_prompt.contents[0].text
|
| 110 |
+
tools_str = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
|
| 111 |
+
return template.format(
|
| 112 |
+
user_requests=past_data["past_requests"],
|
| 113 |
+
past_tools=past_data["past_tools"],
|
| 114 |
+
last_queries=past_data["past_queries"],
|
| 115 |
+
last_results=past_data["past_results"],
|
| 116 |
+
new_request=request,
|
| 117 |
+
tools=tools_str
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
async def handle_memory_save_or_reset(memory: ConversationMemory, request: str):
|
| 121 |
+
if request.strip().lower() == "stop":
|
| 122 |
+
memory.reset()
|
| 123 |
+
logger.info("Conversation memory reset.")
|
| 124 |
+
else:
|
| 125 |
+
memory.save_memory()
|
| 126 |
+
logger.info("Conversation memory saved.")
|
table_summary.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The users table stores information about the individuals who use the application. Each user is assigned a unique, auto-incrementing id that serves as the primary key. The username field holds the user's chosen display name and cannot be null, while the email field stores the user’s unique email address, also required and constrained to be unique to avoid duplicates. To track when a user was added to the system, the created_at column records the timestamp of their creation, with a default value set to the current time.
|
| 2 |
+
|
| 3 |
+
The posts table represents content created by users, such as blog posts or messages. Like the users table, each entry has a unique, auto-incrementing id as the primary key. The user_id field links each post to its author by referencing the id field in the users table, establishing a one-to-many relationship between users and posts. The title column holds a brief summary or headline of the post, while the content field contains the full text. A created_at timestamp is also included to record when each post was created, with a default value of the current time.
|
utils.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
import re
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
def parse_mcp_output(output_dict):
|
| 5 |
result = []
|
| 6 |
messages = output_dict.get("messages", [])
|
| 7 |
-
|
| 8 |
query_store = []
|
| 9 |
|
| 10 |
for msg in messages:
|
|
@@ -30,9 +32,20 @@ def parse_mcp_output(output_dict):
|
|
| 30 |
|
| 31 |
# Check for presence of "query" key
|
| 32 |
if "query" in arguments_dict:
|
| 33 |
-
print("query detected!!!")
|
| 34 |
-
print(f"
|
| 35 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
query_store.append(arguments_dict["query"])
|
| 37 |
|
| 38 |
result.append({
|
|
@@ -42,7 +55,16 @@ def parse_mcp_output(output_dict):
|
|
| 42 |
"args": arguments
|
| 43 |
})
|
| 44 |
else:
|
| 45 |
-
print(f"ai said:{content}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
result.append({
|
| 47 |
"type": "ai_function_call",
|
| 48 |
"ai_said": content,
|
|
@@ -51,7 +73,10 @@ def parse_mcp_output(output_dict):
|
|
| 51 |
})
|
| 52 |
|
| 53 |
else:
|
| 54 |
-
print(f"ai final answer:{content}")
|
|
|
|
|
|
|
|
|
|
| 55 |
result.append({
|
| 56 |
"type": "ai_final_answer",
|
| 57 |
"ai_said": content
|
|
@@ -60,7 +85,9 @@ def parse_mcp_output(output_dict):
|
|
| 60 |
# ToolMessage
|
| 61 |
elif role_name == "ToolMessage":
|
| 62 |
tool_name = getattr(msg, "name", None)
|
| 63 |
-
print(
|
|
|
|
|
|
|
| 64 |
result.append({
|
| 65 |
"type": "tool_response",
|
| 66 |
"tool": tool_name,
|
|
@@ -88,3 +115,4 @@ def classify_intent(user_input: str) -> str:
|
|
| 88 |
|
| 89 |
# Fallback
|
| 90 |
return "sql_request"
|
|
|
|
|
|
| 1 |
import re
|
| 2 |
+
import os
|
| 3 |
+
from conversation_memory import ConversationMemory
|
| 4 |
+
import logging
|
| 5 |
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
def parse_mcp_output(output_dict):
|
| 8 |
result = []
|
| 9 |
messages = output_dict.get("messages", [])
|
|
|
|
| 10 |
query_store = []
|
| 11 |
|
| 12 |
for msg in messages:
|
|
|
|
| 32 |
|
| 33 |
# Check for presence of "query" key
|
| 34 |
if "query" in arguments_dict:
|
| 35 |
+
#print("query detected!!!")
|
| 36 |
+
print(f"=============== AI Reasoning Step ===============")
|
| 37 |
+
print(content[0])
|
| 38 |
+
print()
|
| 39 |
+
print("=============== AI used the following tools ===============")
|
| 40 |
+
print(tool_name)
|
| 41 |
+
print()
|
| 42 |
+
print("=============== AI generated the following query ===============")
|
| 43 |
+
print(arguments_dict['query'])
|
| 44 |
+
|
| 45 |
+
logger.info(f"ai said:{content[0]}")
|
| 46 |
+
logger.info(f"ai used:{tool_name}")
|
| 47 |
+
logger.info(f"generated query:{arguments_dict['query']}")
|
| 48 |
+
#print(arguments_dict["query"])
|
| 49 |
query_store.append(arguments_dict["query"])
|
| 50 |
|
| 51 |
result.append({
|
|
|
|
| 55 |
"args": arguments
|
| 56 |
})
|
| 57 |
else:
|
| 58 |
+
#print(f"ai said:{content}")
|
| 59 |
+
logger.info(f"ai said:{content}")
|
| 60 |
+
logger.info(f"ai used:{tool_name}")
|
| 61 |
+
print(f"=============== AI Reasoning Step ===============")
|
| 62 |
+
print(content)
|
| 63 |
+
print()
|
| 64 |
+
print("=============== AI used the following tools ===============")
|
| 65 |
+
print(tool_name)
|
| 66 |
+
print()
|
| 67 |
+
|
| 68 |
result.append({
|
| 69 |
"type": "ai_function_call",
|
| 70 |
"ai_said": content,
|
|
|
|
| 73 |
})
|
| 74 |
|
| 75 |
else:
|
| 76 |
+
#print(f"ai final answer:{content}")
|
| 77 |
+
logger.info(f"ai final answer:{content}")
|
| 78 |
+
print("=============== AI's final answer ===============")
|
| 79 |
+
print(content)
|
| 80 |
result.append({
|
| 81 |
"type": "ai_final_answer",
|
| 82 |
"ai_said": content
|
|
|
|
| 85 |
# ToolMessage
|
| 86 |
elif role_name == "ToolMessage":
|
| 87 |
tool_name = getattr(msg, "name", None)
|
| 88 |
+
print("=============== The tool returned the following response ===============")
|
| 89 |
+
print(content)
|
| 90 |
+
logger.info(f"tool response:{content}")
|
| 91 |
result.append({
|
| 92 |
"type": "tool_response",
|
| 93 |
"tool": tool_name,
|
|
|
|
| 115 |
|
| 116 |
# Fallback
|
| 117 |
return "sql_request"
|
| 118 |
+
|