JRealValdes commited on
Commit
b6e93af
·
1 Parent(s): 9994c78

MCP Memory Agent

Browse files

MCP Client Refactor

Added file to README Structure

README.md CHANGED
@@ -61,6 +61,7 @@ jarvis/
61
  ├── enums/
62
  │ └── core_enums.py
63
  ├── mcp/
 
64
  │ └── servers/
65
  │ └── math_server.py
66
  ├── media/
 
61
  ├── enums/
62
  │ └── core_enums.py
63
  ├── mcp/
64
+ │ ├── server_config.json
65
  │ └── servers/
66
  │ └── math_server.py
67
  ├── media/
agents/factory.py CHANGED
@@ -1,5 +1,6 @@
1
  from enums.core_enums import ModelEnum
2
  from agents.jarvis_memory_agent import JarvisMemoryAgent
 
3
  from agents.jarvis_basic_agent import JarvisBasicAgent
4
 
5
  models_with_memory = [ModelEnum.GPT_3_5]
@@ -8,6 +9,7 @@ def build_agent(model_used: ModelEnum):
8
  if model_used in [ModelEnum.ZEPHYR, ModelEnum.MISTRAL]:
9
  return JarvisBasicAgent(model_used)
10
  elif model_used == ModelEnum.GPT_3_5:
11
- return JarvisMemoryAgent(model_used)
 
12
  else:
13
  raise ValueError("Modelo no soportado.")
 
1
  from enums.core_enums import ModelEnum
2
  from agents.jarvis_memory_agent import JarvisMemoryAgent
3
+ from agents.jarvis_mcp_memory_agent import JarvisMcpMemoryAgent
4
  from agents.jarvis_basic_agent import JarvisBasicAgent
5
 
6
  models_with_memory = [ModelEnum.GPT_3_5]
 
9
  if model_used in [ModelEnum.ZEPHYR, ModelEnum.MISTRAL]:
10
  return JarvisBasicAgent(model_used)
11
  elif model_used == ModelEnum.GPT_3_5:
12
+ # return JarvisMemoryAgent(model_used)
13
+ return JarvisMcpMemoryAgent(model_used)
14
  else:
15
  raise ValueError("Modelo no soportado.")
agents/jarvis_mcp_memory_agent.py CHANGED
@@ -1,17 +1,114 @@
1
  import os
 
 
2
  from contextlib import AsyncExitStack
3
  from typing import Annotated
4
  from typing_extensions import TypedDict
5
  from enums.core_enums import ModelEnum
6
  from langchain_openai import ChatOpenAI
7
- from langchain_ollama import ChatOllama
8
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
9
- from langgraph.graph import StateGraph
10
  from langgraph.checkpoint.memory import MemorySaver
11
- from langgraph.prebuilt import create_react_agent, ToolNode, tools_condition
12
  from langgraph.graph.message import add_messages
 
13
  from tools.calc import calculate_tool
14
  from tools.speech_to_text import speech_to_text_tool
 
 
15
 
16
  local_tools = [calculate_tool, speech_to_text_tool]
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
+ import asyncio
4
  from contextlib import AsyncExitStack
5
  from typing import Annotated
6
  from typing_extensions import TypedDict
7
  from enums.core_enums import ModelEnum
8
  from langchain_openai import ChatOpenAI
9
+ from langchain_mcp_adapters.tools import load_mcp_tools
 
 
10
  from langgraph.checkpoint.memory import MemorySaver
11
+ from langgraph.graph import StateGraph
12
  from langgraph.graph.message import add_messages
13
+ from langgraph.prebuilt import ToolNode, tools_condition
14
  from tools.calc import calculate_tool
15
  from tools.speech_to_text import speech_to_text_tool
16
+ from mcp import ClientSession, StdioServerParameters
17
+ from mcp.client.stdio import stdio_client
18
 
19
  local_tools = [calculate_tool, speech_to_text_tool]
20
 
21
+ class State(TypedDict):
22
+ messages: Annotated[list, add_messages]
23
+
24
+ server_config_path = "mcp\server_config.json"
25
+
26
+ class JarvisMcpMemoryAgent:
27
+ def __init__(self, model_enum: ModelEnum):
28
+ self.model_enum = model_enum
29
+ self.exit_stack = None
30
+ self.tools = None
31
+ self.graph = None
32
+ self.memory = None
33
+ self._is_connected = False
34
+ self.sessions_to_tools = {}
35
+
36
+ def _create_langgraph_agent(self, model_enum: ModelEnum, tools, memory=None):
37
+ if model_enum == ModelEnum.GPT_3_5:
38
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
39
+ else:
40
+ raise ValueError(f"Unsupported model: {model_enum}")
41
+
42
+ graph_builder = StateGraph(State)
43
+ llm_with_tools = llm.bind_tools(tools)
44
+
45
+ def chatbot(state: State):
46
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
47
+
48
+ graph_builder.add_node("chatbot", chatbot)
49
+ graph_builder.add_node("tools", ToolNode(tools=tools))
50
+ graph_builder.add_conditional_edges("chatbot", tools_condition)
51
+ graph_builder.add_edge("tools", "chatbot")
52
+ graph_builder.set_entry_point("chatbot")
53
+
54
+ if memory is None:
55
+ memory = MemorySaver()
56
+
57
+ self.graph = graph_builder.compile(checkpointer=memory)
58
+ self.memory = memory
59
+
60
+
61
+ async def _connect_to_server(self, server_name, server_config):
62
+ server_params = StdioServerParameters(**server_config)
63
+ read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
64
+ session = await self.exit_stack.enter_async_context(ClientSession(read, write))
65
+ await session.initialize()
66
+ mcp_tools = await load_mcp_tools(session)
67
+ self.sessions_to_tools[session] = mcp_tools
68
+ self.tools.extend(mcp_tools)
69
+
70
+
71
+ async def initialize_mcp_connection(self):
72
+ if self._is_connected:
73
+ print("[INFO] Initialize MCP Connection called, but agent MCP services are already initialized")
74
+ return
75
+
76
+ self.exit_stack = AsyncExitStack()
77
+ await self.exit_stack.__aenter__() # TODO: Check if neccesary
78
+
79
+ with open(server_config_path, "r") as file:
80
+ data = json.load(file)
81
+ servers = data.get("mcpServers", {})
82
+
83
+ for server_name, server_config in servers.items():
84
+ await self._connect_to_server(server_name, server_config)
85
+
86
+ self._is_connected = True
87
+
88
+ async def setup_mcp(self):
89
+ self.tools = local_tools
90
+
91
+ await self.initialize_mcp_connection()
92
+
93
+ self._create_langgraph_agent(self.model_enum, self.tools, memory=self.memory)
94
+
95
+ async def ainvoke(self, **kwargs):
96
+ if not self._is_connected:
97
+ await self.setup_mcp()
98
+
99
+ result = await self.graph.ainvoke(**kwargs)
100
+ return result
101
+
102
+ async def aclose(self):
103
+ if self.exit_stack:
104
+ await self.exit_stack.aclose()
105
+ self.exit_stack = None
106
+ self._is_connected = False
107
+
108
+ def invoke(self, **kwargs):
109
+ async def _wrapped():
110
+ try:
111
+ return await self.ainvoke(**kwargs)
112
+ finally:
113
+ await self.aclose()
114
+ return asyncio.run(_wrapped())
mcp/server_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mcpServers": {
3
+ "math": {
4
+ "command": "python",
5
+ "args": [
6
+ "mcp/servers/math_server.py"
7
+ ]
8
+ }
9
+ }
10
+ }
11
+