Selcan Yukcu committed on
Commit
7afc0de
·
1 Parent(s): 7f3ee7b

feat: smolagent option

Browse files
Files changed (1) hide show
  1. postgre_smolagent_clinet.py +128 -0
postgre_smolagent_clinet.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from mcp import ClientSession, StdioServerParameters
3
+ from mcp.client.stdio import stdio_client
4
+ from langchain_mcp_adapters.tools import load_mcp_tools
5
+ from langgraph.prebuilt import create_react_agent
6
+ from langchain.chat_models import init_chat_model
7
+ from conversation_memory import ConversationMemory
8
+ from utils import parse_mcp_output, classify_intent
9
+ import logging
10
+ from smolagents import LiteLLMModel, ToolCollection, CodeAgent
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
async def pg_mcp_smolagent_exec(request: str) -> str:
    """
    Execute the full PostgreSQL MCP pipeline with a smolagents CodeAgent:
    load the table summary, connect an MCP stdio session, load memory and
    tools, build the prompt, run the agent, and update memory.

    Args:
        request (str): User's request input.

    Returns:
        str: Agent response message.

    Raises:
        RuntimeError: If the GEMINI_API_KEY environment variable is not set.
    """
    # TODO: give summary file path from config
    table_summary = load_table_summary("table_summary.txt")
    server_params = get_server_params()

    # SECURITY: the API key must come from the environment, never from source.
    # The previously hard-coded key was committed and must be treated as
    # leaked — revoke it and issue a new one.
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY environment variable is not set.")
    llm = LiteLLMModel(model_id="gemini/gemini-2.0-flash-lite", api_key=api_key)

    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            memory = await load_or_create_memory()

            intent = classify_intent(request)
            with ToolCollection.from_mcp(server_params, trust_remote_code=True) as tool_collection:
                # Create a Code agent using the LLM and the MCP-provided tools
                agent = CodeAgent(model=llm, tools=[*tool_collection.tools], add_base_tools=True)
                tools = await load_and_enrich_tools(session, table_summary)
                past_data = get_memory_snapshot(memory)

                prompt = await build_prompt(session, intent, request, tools, past_data)
                agent_response = agent.run(task=prompt, stream=False)

            parsed_steps, _ = parse_mcp_output(agent_response)
            memory.update_from_parsed(parsed_steps, request)

            await handle_memory_save_or_reset(memory, request)

            return agent_response
55
+
56
+
57
+ # ---------------- Helper Functions ---------------- #
58
+
59
def load_table_summary(path: str) -> str:
    """
    Read and return the entire table-summary text file.

    Args:
        path (str): Path to the summary file.

    Returns:
        str: File contents.
    """
    # Explicit encoding: without it, open() uses the platform locale
    # (e.g. cp1252 on Windows), which can mis-decode UTF-8 summaries.
    with open(path, 'r', encoding='utf-8') as file:
        return file.read()
62
+
63
def get_server_params() -> StdioServerParameters:
    """
    Build stdio launch parameters for the PostgreSQL MCP server.

    The server script path can be overridden with the PG_MCP_SERVER_PATH
    environment variable; the original hard-coded path is kept as the
    default for backward compatibility.

    Returns:
        StdioServerParameters: Command/args for launching the MCP server.
    """
    # TODO: give server params from config
    server_script = os.environ.get(
        "PG_MCP_SERVER_PATH",
        r"C:\Users\yukcus\Desktop\MCPTest\postgre_mcp_server.py",
    )
    return StdioServerParameters(
        command="python",
        args=[server_script],
    )
69
+
70
async def load_or_create_memory() -> ConversationMemory:
    """Return persisted conversation memory if a snapshot exists, else a fresh one."""
    memory = ConversationMemory()
    # No saved snapshot on disk -> start with an empty memory object.
    if not os.path.exists("memory.json"):
        return memory
    return memory.load_memory()
75
+
76
async def load_and_enrich_tools(session: ClientSession, summary: str):
    """Load MCP tools from *session*, appending the table summary to each description."""
    enriched = await load_mcp_tools(session)
    # Give the agent schema context by folding the summary into every tool description.
    for mcp_tool in enriched:
        mcp_tool.description = f"{mcp_tool.description} {summary}"
    return enriched
81
+
82
def get_memory_snapshot(memory: ConversationMemory) -> dict:
    """Snapshot past conversation state, or placeholder text when no memory file exists."""
    # Without a persisted memory file there is nothing to report yet.
    if not os.path.exists("memory.json"):
        return {
            "past_tools": "No tools found",
            "past_queries": "No queries found",
            "past_results": "No results found",
            "past_requests": "No requests found",
        }
    return {
        "past_tools": memory.get_all_tools_used(),
        "past_queries": memory.get_last_n_queries(),
        "past_results": memory.get_last_n_results(),
        "past_requests": memory.get_all_user_messages(),
    }
96
+
97
async def build_prompt(session, intent, request, tools, past_data):
    """
    Build the agent prompt from an MCP-served template.

    Args:
        session: Active MCP ClientSession used to read prompt resources.
        intent (str): Classified user intent ("superset_request" or other).
        request (str): The new user request.
        tools: Loaded MCP tools (used only for the conversational template).
        past_data (dict): Snapshot with keys "past_requests", "past_tools",
            "past_queries", "past_results".

    Returns:
        str: The fully formatted prompt text.
    """
    superset_prompt = await session.read_resource("resource://last_prompt")
    conversation_prompt = await session.read_resource("resource://base_prompt")
    # TODO: add uri's from config

    # Both templates share the same history placeholders; build them once
    # instead of duplicating the kwargs in each branch.
    common = {
        "user_requests": past_data["past_requests"],
        "past_tools": past_data["past_tools"],
        "last_queries": past_data["past_queries"],
        "last_results": past_data["past_results"],
        "new_request": request,
    }
    if intent == "superset_request":
        template = superset_prompt.contents[0].text
        return template.format(**common)
    template = conversation_prompt.contents[0].text
    tools_str = "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
    return template.format(tools=tools_str, **common)
121
+
122
async def handle_memory_save_or_reset(memory: ConversationMemory, request: str):
    """Reset memory when the user says "stop"; otherwise persist it to disk."""
    normalized = request.strip().lower()
    if normalized == "stop":
        memory.reset()
        logger.info("Conversation memory reset.")
        return
    memory.save_memory()
    logger.info("Conversation memory saved.")