akseljoonas HF Staff commited on
Commit
0c252e4
·
1 Parent(s): 0d2b9f7

intermediary commit

Browse files
agent/config.py CHANGED
@@ -7,7 +7,7 @@ class MCPServerConfig(BaseModel):
7
 
8
  name: str
9
  command: str
10
- args: list[str]
11
  env: dict[str, str] | None = None
12
 
13
 
@@ -15,8 +15,8 @@ class Config(BaseModel):
15
  """Configuration manager"""
16
 
17
  model_name: str
18
- tools: list[Tool]
19
- system_prompt_path: str
20
  mcp_servers: list[MCPServerConfig] = []
21
 
22
 
 
7
 
8
  name: str
9
  command: str
10
+ args: list[str] = []
11
  env: dict[str, str] | None = None
12
 
13
 
 
15
  """Configuration manager"""
16
 
17
  model_name: str
18
+ tools: list[Tool] = []
19
+ system_prompt_path: str = ""
20
  mcp_servers: list[MCPServerConfig] = []
21
 
22
 
agent/context_manager/manager.py CHANGED
@@ -10,9 +10,7 @@ class ContextManager:
10
 
11
  def __init__(self):
12
  self.system_prompt = self._load_system_prompt()
13
- self.items: list[Message] = [
14
- Message(role="system", content=self.system_prompt)
15
- ]
16
 
17
  def _load_system_prompt(self):
18
  """Load the system prompt"""
@@ -35,8 +33,10 @@ class ContextManager:
35
  return
36
 
37
  # Always keep system prompt
38
- system_msg = self.items[0] if self.items and self.items[0].role == "system" else None
39
- messages_to_keep = self.items[-(target_size - 1):]
 
 
40
 
41
  if system_msg:
42
  self.items = [system_msg] + messages_to_keep
 
10
 
11
  def __init__(self):
12
  self.system_prompt = self._load_system_prompt()
13
+ self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
 
 
14
 
15
  def _load_system_prompt(self):
16
  """Load the system prompt"""
 
33
  return
34
 
35
  # Always keep system prompt
36
+ system_msg = (
37
+ self.items[0] if self.items and self.items[0].role == "system" else None
38
+ )
39
+ messages_to_keep = self.items[-(target_size - 1) :]
40
 
41
  if system_msg:
42
  self.items = [system_msg] + messages_to_keep
agent/core/__init__.py CHANGED
@@ -3,7 +3,13 @@ Core agent implementation
3
  Contains the main agent logic, decision-making, and orchestration
4
  """
5
 
6
- from agent.core.executor import ToolExecutor
7
- from agent.core.mcp_client import MCPClient, MCPServerConfig
8
 
9
- __all__ = ["ToolExecutor", "MCPClient", "MCPServerConfig"]
 
 
 
 
 
 
 
3
  Contains the main agent logic, decision-making, and orchestration
4
  """
5
 
6
+ from agent.core.mcp_client import McpClient, McpConnectionManager
7
+ from agent.core.tools import ToolRouter, ToolSpec, create_builtin_tools
8
 
9
+ __all__ = [
10
+ "McpClient",
11
+ "McpConnectionManager",
12
+ "ToolRouter",
13
+ "ToolSpec",
14
+ "create_builtin_tools",
15
+ ]
agent/core/agent_loop.py CHANGED
@@ -1,15 +1,11 @@
1
  """
2
- Main agent implementation
3
  """
4
 
5
  import asyncio
 
6
 
7
- from litellm import (
8
- ChatCompletionMessageToolCall,
9
- Message,
10
- ModelResponse,
11
- acompletion,
12
- )
13
 
14
  from agent.config import Config
15
  from agent.core.session import Event, OpType, Session
@@ -36,14 +32,16 @@ class Handlers:
36
  iteration = 0
37
  while iteration < max_iterations:
38
  messages = session.context_manager.get_messages()
39
- print(f"Messages: {messages}")
40
 
41
  try:
42
  response: ModelResponse = await acompletion(
43
  model=session.config.model_name,
44
  messages=messages,
45
- tools=session.config.tools,
 
46
  )
 
47
  message = response.choices[0].message
48
 
49
  # Extract content and tool calls
@@ -56,46 +54,58 @@ class Handlers:
56
  session.context_manager.add_message(assistant_msg)
57
 
58
  await session.send_event(
59
- Event(
60
- event_type="assistant_message",
61
- data={"message": assistant_msg},
62
- )
63
  )
64
 
65
  # If no tool calls, we're done
66
  if not tool_calls:
67
  break
68
 
69
- for tool_call in tool_calls:
70
- print(f"Executing tool: {tool_call.function.name}")
71
- result = await session.tool_executor.execute_tool(tool_call)
72
- print(result)
73
- tool_output = Message(
74
- role="tool", content=result.output, success=result.success
 
 
 
 
 
 
 
 
75
  )
76
- session.context_manager.add_message(tool_output)
 
 
 
 
 
 
 
 
77
 
78
  await session.send_event(
79
  Event(
80
  event_type="tool_output",
81
- data={"message": tool_output},
 
 
 
 
 
82
  )
83
  )
84
 
85
  iteration += 1
86
 
87
  except Exception as e:
88
- import traceback
89
-
90
  await session.send_event(
91
- Event(
92
- event_type="error",
93
- data={"error": traceback.print_exc() + str(e)},
94
- )
95
  )
96
  break
97
 
98
- # Send completion event
99
  await session.send_event(
100
  Event(
101
  event_type="turn_complete",
@@ -153,7 +163,7 @@ async def process_submission(session: Session, submission) -> bool:
153
 
154
  if op.op_type == OpType.USER_INPUT:
155
  text = op.data.get("text", "") if op.data else ""
156
- await Handlers.run_agent(session, text, max_iterations=10)
157
  return True
158
 
159
  if op.op_type == OpType.INTERRUPT:
@@ -178,23 +188,56 @@ async def process_submission(session: Session, submission) -> bool:
178
  async def submission_loop(
179
  submission_queue: asyncio.Queue,
180
  event_queue: asyncio.Queue,
 
181
  config: Config | None = None,
182
  ) -> None:
183
  """
184
  Main agent loop - processes submissions and dispatches to handlers.
185
  This is the core of the agent (like submission_loop in codex.rs:1259-1340)
186
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  session = Session(event_queue, config=config)
 
188
  print("🤖 Agent loop started")
189
 
190
- # Initialize MCP connections
191
- if session.config.mcp_servers:
192
- try:
193
- print(f"Initializing MCP connections for {session.config.mcp_servers}")
194
- await session.initialize_mcp()
195
- except Exception as e:
196
- print(f"⚠️ Error initializing MCP: {e}")
197
-
198
  try:
199
  # Main processing loop
200
  while session.is_running:
@@ -208,8 +251,12 @@ async def submission_loop(
208
  break
209
  except Exception as e:
210
  print(f"❌ Error in agent loop: {e}")
211
- await session.send_event(Event(event_type="error", data={"error": str(e)}))
 
 
212
  finally:
213
- await session.cleanup()
 
 
214
 
215
  print("🛑 Agent loop exited")
 
1
  """
2
+ Main agent implementation with integrated tool system and MCP support
3
  """
4
 
5
  import asyncio
6
+ import json
7
 
8
+ from litellm import ChatCompletionMessageToolCall, Message, ModelResponse, acompletion
 
 
 
 
 
9
 
10
  from agent.config import Config
11
  from agent.core.session import Event, OpType, Session
 
32
  iteration = 0
33
  while iteration < max_iterations:
34
  messages = session.context_manager.get_messages()
35
+ tools = session.tool_router.get_tool_specs_for_llm()
36
 
37
  try:
38
  response: ModelResponse = await acompletion(
39
  model=session.config.model_name,
40
  messages=messages,
41
+ tools=tools,
42
+ tool_choice="auto",
43
  )
44
+
45
  message = response.choices[0].message
46
 
47
  # Extract content and tool calls
 
54
  session.context_manager.add_message(assistant_msg)
55
 
56
  await session.send_event(
57
+ Event(event_type="assistant_message", data={"content": content})
 
 
 
58
  )
59
 
60
  # If no tool calls, we're done
61
  if not tool_calls:
62
  break
63
 
64
+ # Execute tools
65
+ for tc in tool_calls:
66
+ tool_name = tc.function.name
67
+ tool_args = json.loads(tc.function.arguments)
68
+
69
+ await session.send_event(
70
+ Event(
71
+ event_type="tool_call",
72
+ data={"tool": tool_name, "arguments": tool_args},
73
+ )
74
+ )
75
+
76
+ output, success = await session.tool_router.execute_tool(
77
+ tool_name, tool_args
78
  )
79
+
80
+ # Add tool result to history
81
+ tool_msg = Message(
82
+ role="tool",
83
+ content=output,
84
+ tool_call_id=tc.id,
85
+ name=tool_name,
86
+ )
87
+ session.context_manager.add_message(tool_msg)
88
 
89
  await session.send_event(
90
  Event(
91
  event_type="tool_output",
92
+ data={
93
+ "tool": tool_name,
94
+ "output": output[:200]
95
+ + ("..." if len(output) > 200 else ""),
96
+ "success": success,
97
+ },
98
  )
99
  )
100
 
101
  iteration += 1
102
 
103
  except Exception as e:
 
 
104
  await session.send_event(
105
+ Event(event_type="error", data={"error": str(e)})
 
 
 
106
  )
107
  break
108
 
 
109
  await session.send_event(
110
  Event(
111
  event_type="turn_complete",
 
163
 
164
  if op.op_type == OpType.USER_INPUT:
165
  text = op.data.get("text", "") if op.data else ""
166
+ await Handlers.user_input(session, text)
167
  return True
168
 
169
  if op.op_type == OpType.INTERRUPT:
 
188
  async def submission_loop(
189
  submission_queue: asyncio.Queue,
190
  event_queue: asyncio.Queue,
191
+ tool_router=None,
192
  config: Config | None = None,
193
  ) -> None:
194
  """
195
  Main agent loop - processes submissions and dispatches to handlers.
196
  This is the core of the agent (like submission_loop in codex.rs:1259-1340)
197
  """
198
+ # Import here to avoid circular imports
199
+ from agent.core.mcp_client import McpConnectionManager
200
+ from agent.core.tools import ToolRouter, create_builtin_tools
201
+
202
+ # Initialize MCP and tools
203
+ if tool_router is None:
204
+ mcp_manager = McpConnectionManager()
205
+
206
+ # Add MCP servers from config
207
+ if config and config.mcp_servers:
208
+ print("🔌 Initializing MCP connections...")
209
+ for server_config in config.mcp_servers:
210
+ try:
211
+ await mcp_manager.add_server(
212
+ server_name=server_config.name,
213
+ command=server_config.command,
214
+ args=server_config.args,
215
+ env=server_config.env,
216
+ )
217
+ except Exception as e:
218
+ print(
219
+ f"⚠️ Failed to connect to MCP server {server_config.name}: {e}"
220
+ )
221
+
222
+ # Create tool router
223
+ tool_router = ToolRouter(mcp_manager)
224
+
225
+ # Register built-in tools
226
+ for tool in create_builtin_tools():
227
+ tool_router.register_tool(tool)
228
+
229
+ # Register MCP tools
230
+ tool_router.register_mcp_tools()
231
+
232
+ print(f"📦 Registered {len(tool_router.tools)} tools:")
233
+ for tool_name in tool_router.tools.keys():
234
+ print(f" - {tool_name}")
235
+
236
+ # Create session and assign tool router
237
  session = Session(event_queue, config=config)
238
+ session.tool_router = tool_router
239
  print("🤖 Agent loop started")
240
 
 
 
 
 
 
 
 
 
241
  try:
242
  # Main processing loop
243
  while session.is_running:
 
251
  break
252
  except Exception as e:
253
  print(f"❌ Error in agent loop: {e}")
254
+ await session.send_event(
255
+ Event(event_type="error", data={"error": str(e)})
256
+ )
257
  finally:
258
+ # Cleanup MCP connections
259
+ if hasattr(tool_router, "mcp_manager") and tool_router.mcp_manager:
260
+ await tool_router.mcp_manager.shutdown_all()
261
 
262
  print("🛑 Agent loop exited")
agent/core/executor.py DELETED
@@ -1,48 +0,0 @@
1
- """
2
- Task execution engine
3
- """
4
-
5
- import json
6
- from typing import Any, List
7
-
8
- from litellm import ChatCompletionMessageToolCall
9
- from pydantic import BaseModel
10
-
11
- ToolCall = ChatCompletionMessageToolCall
12
-
13
-
14
- class ToolResult(BaseModel):
15
- output: str
16
- success: bool
17
-
18
-
19
- class ToolExecutor:
20
- """Executes planned tasks using available tools"""
21
-
22
- def __init__(self, tools: List[Any] = None, mcp_client=None):
23
- self.tools = tools or []
24
- self.mcp_client = mcp_client
25
-
26
- async def execute_tool(self, tool_call: ToolCall) -> ToolResult:
27
- """Execute a single step in the plan"""
28
- tool_name = tool_call.function.name
29
-
30
- # Parse arguments
31
- try:
32
- if isinstance(tool_call.function.arguments, str):
33
- tool_args = json.loads(tool_call.function.arguments)
34
- else:
35
- tool_args = tool_call.function.arguments
36
- except json.JSONDecodeError as e:
37
- return ToolResult(
38
- output=f"Error parsing tool arguments: {str(e)}", success=False
39
- )
40
-
41
- # Check if this is an MCP tool (prefixed with server name)
42
- if self.mcp_client and "__" in tool_name:
43
- success, result = await self.mcp_client.call_tool(tool_name, tool_args)
44
- return ToolResult(output=result, success=success)
45
-
46
- # If not an MCP tool, try local tools
47
- # TODO: Implement local tool execution
48
- return ToolResult(output=f"Tool {tool_name} not found", success=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
agent/core/mcp_client.py CHANGED
@@ -1,149 +1,164 @@
1
  """
2
  MCP (Model Context Protocol) client integration for the agent
 
3
  """
4
 
 
5
  from contextlib import AsyncExitStack
6
- from typing import Optional
7
 
8
- from mcp import ClientSession, StdioServerParameters
9
  from mcp.client.stdio import stdio_client
10
 
11
 
12
- class MCPServerConfig:
13
- """Configuration for an MCP server"""
 
 
 
14
 
15
  def __init__(
16
  self,
17
- name: str,
18
  command: str,
19
- args: list[str],
20
- env: Optional[dict[str, str]] = None,
21
  ):
22
- self.name = name
23
  self.command = command
24
- self.args = args
25
- self.env = env
26
-
27
-
28
- class MCPClient:
29
- """
30
- Manages connections to MCP servers and provides tool access
31
- """
32
-
33
- def __init__(self):
34
- self.sessions: dict[str, ClientSession] = {}
35
  self.exit_stack = AsyncExitStack()
36
- self._tools_cache: Optional[list[dict]] = None
37
 
38
- async def connect_to_server(self, server_config: MCPServerConfig) -> None:
39
- """
40
- Connect to an MCP server
 
41
 
42
- Args:
43
- server_config: Configuration for the MCP server
44
- """
45
  server_params = StdioServerParameters(
46
- command=server_config.command,
47
- args=server_config.args,
48
- env=server_config.env,
49
  )
50
 
51
- stdio_transport = await self.exit_stack.enter_async_context(
 
52
  stdio_client(server_params)
53
  )
54
- stdio, write = stdio_transport
55
- session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
56
-
57
- await session.initialize()
58
-
59
- # Store the session
60
- self.sessions[server_config.name] = session
61
-
62
- # Invalidate tools cache
63
- self._tools_cache = None
64
-
65
- print(f"✅ Connected to MCP server: {server_config.name}")
66
-
67
- async def list_tools(self) -> list[dict]:
68
- """
69
- Get all available tools from all connected servers
70
-
71
- Returns:
72
- List of tool definitions compatible with LiteLLM format
73
- """
74
- if self._tools_cache is not None:
75
- return self._tools_cache
76
-
77
- all_tools = []
78
-
79
- for server_name, session in self.sessions.items():
80
- try:
81
- response = await session.list_tools()
82
- for tool in response.tools:
83
- # Convert MCP tool format to LiteLLM tool format
84
- tool_def = {
85
- "type": "function",
86
- "function": {
87
- "name": f"{server_name}__{tool.name}", # Prefix with server name
88
- "description": tool.description or "",
89
- "parameters": tool.inputSchema,
90
- },
91
- }
92
- all_tools.append(tool_def)
93
- except Exception as e:
94
- print(f"⚠️ Error listing tools from {server_name}: {e}")
95
-
96
- self._tools_cache = all_tools
97
- return all_tools
98
 
99
- async def call_tool(self, tool_name: str, tool_args: dict) -> tuple[bool, str]:
100
- """
101
- Call a tool on the appropriate MCP server
 
102
 
103
- Args:
104
- tool_name: Name of the tool (format: "server_name__tool_name")
105
- tool_args: Arguments to pass to the tool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- Returns:
108
- Tuple of (success, result_content)
109
- """
110
- # Parse server name from tool name
111
- if "__" not in tool_name:
112
- return False, f"Invalid tool name format: {tool_name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- server_name, actual_tool_name = tool_name.split("__", 1)
 
 
115
 
116
- if server_name not in self.sessions:
117
- return False, f"Server not found: {server_name}"
 
118
 
119
- session = self.sessions[server_name]
120
 
121
- try:
122
- result = await session.call_tool(actual_tool_name, tool_args)
123
-
124
- # Extract content from result
125
- if hasattr(result, "content"):
126
- if isinstance(result.content, list):
127
- # Handle list of content items
128
- content_parts = []
129
- for item in result.content:
130
- if hasattr(item, "text"):
131
- content_parts.append(item.text)
132
- else:
133
- content_parts.append(str(item))
134
- content = "\n".join(content_parts)
135
- else:
136
- content = str(result.content)
137
- else:
138
- content = str(result)
139
-
140
- return True, content
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  except Exception as e:
143
- return False, f"Error calling tool {tool_name}: {str(e)}"
 
 
 
 
 
 
 
 
144
 
145
- async def cleanup(self) -> None:
146
- """Clean up all MCP connections"""
147
- await self.exit_stack.aclose()
148
- self.sessions.clear()
149
- self._tools_cache = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  MCP (Model Context Protocol) client integration for the agent
3
+ Based on the official MCP SDK implementation
4
  """
5
 
6
+ import os
7
  from contextlib import AsyncExitStack
8
+ from typing import Any, Optional
9
 
10
+ from mcp import ClientSession, StdioServerParameters, types
11
  from mcp.client.stdio import stdio_client
12
 
13
 
14
+ class McpClient:
15
+ """
16
+ Client for connecting to MCP servers using the official MCP SDK.
17
+ Based on codex-rs/core/src/mcp_connection_manager.rs
18
+ """
19
 
20
  def __init__(
21
  self,
22
+ server_name: str,
23
  command: str,
24
+ args: list[str] | None = None,
25
+ env: dict[str, str] | None = None,
26
  ):
27
+ self.server_name = server_name
28
  self.command = command
29
+ self.args = args or []
30
+ self.env = env or {}
31
+ self.session: Optional[ClientSession] = None
32
+ self.tools: dict[str, dict[str, Any]] = {}
 
 
 
 
 
 
 
33
  self.exit_stack = AsyncExitStack()
 
34
 
35
+ async def start(self) -> None:
36
+ """Start the MCP server connection using official SDK"""
37
+ # Merge environment variables
38
+ full_env = {**dict(os.environ), **self.env} if self.env else None
39
 
40
+ # Create server parameters
 
 
41
  server_params = StdioServerParameters(
42
+ command=self.command,
43
+ args=self.args,
44
+ env=full_env,
45
  )
46
 
47
+ # Connect using stdio_client
48
+ read, write = await self.exit_stack.enter_async_context(
49
  stdio_client(server_params)
50
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # Create session
53
+ self.session = await self.exit_stack.enter_async_context(
54
+ ClientSession(read, write)
55
+ )
56
 
57
+ # Initialize
58
+ await self.session.initialize()
59
+
60
+ # List available tools
61
+ tools_result = await self.session.list_tools()
62
+ for tool in tools_result.tools:
63
+ qualified_name = f"mcp__{self.server_name}__{tool.name}"
64
+ self.tools[qualified_name] = {
65
+ "name": tool.name,
66
+ "description": tool.description or "",
67
+ "inputSchema": tool.inputSchema,
68
+ }
69
+
70
+ async def call_tool(
71
+ self, tool_name: str, arguments: dict[str, Any]
72
+ ) -> tuple[str, bool]:
73
+ """Execute a tool on the MCP server"""
74
+ if not self.session:
75
+ return "Client not connected", False
76
+
77
+ # Strip the mcp__servername__ prefix to get the actual tool name
78
+ actual_tool_name = tool_name.split("__")[-1]
79
 
80
+ try:
81
+ result = await self.session.call_tool(actual_tool_name, arguments)
82
+
83
+ # Extract text from content
84
+ text_parts = []
85
+ for content in result.content:
86
+ if isinstance(content, types.TextContent):
87
+ text_parts.append(content.text)
88
+ elif isinstance(content, types.ImageContent):
89
+ text_parts.append(f"[Image: {content.mimeType}]")
90
+ elif isinstance(content, types.EmbeddedResource):
91
+ text_parts.append(f"[Resource: {content.resource}]")
92
+
93
+ output = "\n".join(text_parts) if text_parts else str(result.content)
94
+ success = not result.isError
95
+
96
+ return output, success
97
+ except Exception as e:
98
+ return f"Tool call failed: {str(e)}", False
99
 
100
+ def get_tools(self) -> dict[str, dict[str, Any]]:
101
+ """Get all available tools from this server"""
102
+ return self.tools.copy()
103
 
104
+ async def shutdown(self) -> None:
105
+ """Shutdown the MCP server connection"""
106
+ await self.exit_stack.aclose()
107
 
 
108
 
109
+ class McpConnectionManager:
110
+ """
111
+ Manages multiple MCP server connections.
112
+ Based on codex-rs/core/src/mcp_connection_manager.rs
113
+ """
114
+
115
+ def __init__(self):
116
+ self.clients: dict[str, McpClient] = {}
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ async def add_server(
119
+ self,
120
+ server_name: str,
121
+ command: str,
122
+ args: list[str] | None = None,
123
+ env: dict[str, str] | None = None,
124
+ ) -> bool:
125
+ """Add and start an MCP server"""
126
+ try:
127
+ client = McpClient(server_name, command, args, env)
128
+ await client.start()
129
+ self.clients[server_name] = client
130
+ print(
131
+ f"✅ MCP server '{server_name}' connected with {len(client.tools)} tools"
132
+ )
133
+ return True
134
  except Exception as e:
135
+ print(f" Failed to start MCP server '{server_name}': {e}")
136
+ return False
137
+
138
+ def list_all_tools(self) -> dict[str, dict[str, Any]]:
139
+ """Aggregate tools from all connected servers"""
140
+ all_tools = {}
141
+ for client in self.clients.values():
142
+ all_tools.update(client.get_tools())
143
+ return all_tools
144
 
145
+ async def call_tool(
146
+ self, tool_name: str, arguments: dict[str, Any]
147
+ ) -> tuple[str, bool]:
148
+ """Route tool call to the appropriate MCP server"""
149
+ # Extract server name from qualified tool name: mcp__servername__toolname
150
+ if tool_name.startswith("mcp__"):
151
+ parts = tool_name.split("__")
152
+ if len(parts) >= 3:
153
+ server_name = parts[1]
154
+ if server_name in self.clients:
155
+ return await self.clients[server_name].call_tool(
156
+ tool_name, arguments
157
+ )
158
+
159
+ return "Unknown MCP tool", False
160
+
161
+ async def shutdown_all(self) -> None:
162
+ """Shutdown all MCP servers"""
163
+ for client in self.clients.values():
164
+ await client.shutdown()
agent/core/session.py CHANGED
@@ -1,13 +1,10 @@
1
  import asyncio
 
2
  from enum import Enum
3
- from typing import Any, Literal
4
-
5
- from pydantic import BaseModel
6
 
7
  from agent.config import Config
8
  from agent.context_manager.manager import ContextManager
9
- from agent.core import ToolExecutor
10
- from agent.core.mcp_client import MCPClient, MCPServerConfig as MCPServerConfigClass
11
 
12
 
13
  class OpType(Enum):
@@ -19,19 +16,10 @@ class OpType(Enum):
19
  SHUTDOWN = "shutdown"
20
 
21
 
22
- class Event(BaseModel):
23
- event_type: Literal[
24
- "processing",
25
- "assistant_message",
26
- "tool_output",
27
- "turn_complete",
28
- "compacted",
29
- "undo_complete",
30
- "shutdown",
31
- "error",
32
- "interrupted",
33
- ]
34
- data: dict[str, Any] | None = None
35
 
36
 
37
  class Session:
@@ -40,7 +28,11 @@ class Session:
40
  Similar to Session in codex-rs/core/src/codex.rs
41
  """
42
 
43
- def __init__(self, event_queue: asyncio.Queue, config: Config | None = None):
 
 
 
 
44
  self.context_manager = ContextManager()
45
  self.event_queue = event_queue
46
  self.config = config or Config(
@@ -48,42 +40,9 @@ class Session:
48
  tools=[],
49
  system_prompt_path="",
50
  )
51
-
52
- # Initialize MCP client
53
- self.mcp_client = MCPClient()
54
- self.tool_executor = ToolExecutor(mcp_client=self.mcp_client)
55
-
56
  self.is_running = True
57
  self.current_task: asyncio.Task | None = None
58
- self._mcp_initialized = False
59
-
60
- async def initialize_mcp(self) -> None:
61
- """Initialize MCP server connections"""
62
- if self._mcp_initialized:
63
- return
64
-
65
- for server_config in self.config.mcp_servers:
66
- try:
67
- mcp_server_config = MCPServerConfigClass(
68
- name=server_config.name,
69
- command=server_config.command,
70
- args=server_config.args,
71
- env=server_config.env,
72
- )
73
- await self.mcp_client.connect_to_server(mcp_server_config)
74
- except Exception as e:
75
- print(f"⚠️ Failed to connect to MCP server {server_config.name}: {e}")
76
-
77
- # Get MCP tools and merge with config tools
78
- try:
79
- mcp_tools = await self.mcp_client.list_tools()
80
- # Merge with existing tools
81
- self.config.tools = list(self.config.tools) + mcp_tools
82
- print(f"📦 Loaded {len(mcp_tools)} tools from MCP servers")
83
- except Exception as e:
84
- print(f"⚠️ Error loading MCP tools: {e}")
85
-
86
- self._mcp_initialized = True
87
 
88
  async def send_event(self, event: Event) -> None:
89
  """Send event back to client"""
@@ -93,7 +52,3 @@ class Session:
93
  """Interrupt current running task"""
94
  if self.current_task and not self.current_task.done():
95
  self.current_task.cancel()
96
-
97
- async def cleanup(self) -> None:
98
- """Cleanup session resources"""
99
- await self.mcp_client.cleanup()
 
1
  import asyncio
2
+ from dataclasses import dataclass
3
  from enum import Enum
4
+ from typing import Any, Optional
 
 
5
 
6
  from agent.config import Config
7
  from agent.context_manager.manager import ContextManager
 
 
8
 
9
 
10
  class OpType(Enum):
 
16
  SHUTDOWN = "shutdown"
17
 
18
 
19
+ @dataclass
20
+ class Event:
21
+ event_type: str
22
+ data: Optional[dict[str, Any]] = None
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  class Session:
 
28
  Similar to Session in codex-rs/core/src/codex.rs
29
  """
30
 
31
+ def __init__(
32
+ self,
33
+ event_queue: asyncio.Queue,
34
+ config: Config | None = None,
35
+ ):
36
  self.context_manager = ContextManager()
37
  self.event_queue = event_queue
38
  self.config = config or Config(
 
40
  tools=[],
41
  system_prompt_path="",
42
  )
 
 
 
 
 
43
  self.is_running = True
44
  self.current_task: asyncio.Task | None = None
45
+ self.tool_router = None # Set by submission_loop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  async def send_event(self, event: Event) -> None:
48
  """Send event back to client"""
 
52
  """Interrupt current running task"""
53
  if self.current_task and not self.current_task.done():
54
  self.current_task.cancel()
 
 
 
 
agent/core/tools.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tool system for the agent
3
+ Provides ToolSpec and ToolRouter for managing both built-in and MCP tools
4
+ """
5
+
6
+ import subprocess
7
+ from dataclasses import dataclass
8
+ from typing import Any, Awaitable, Callable, Optional
9
+
10
+ from agent.core.mcp_client import McpConnectionManager
11
+
12
+
13
+ @dataclass
14
+ class ToolSpec:
15
+ """Tool specification for LLM"""
16
+
17
+ name: str
18
+ description: str
19
+ parameters: dict[str, Any]
20
+ handler: Optional[Callable[[dict[str, Any]], Awaitable[tuple[str, bool]]]] = None
21
+
22
+
23
+ class ToolRouter:
24
+ """
25
+ Routes tool calls to appropriate handlers.
26
+ Based on codex-rs/core/src/tools/router.rs
27
+ """
28
+
29
+ def __init__(self, mcp_manager: Optional[McpConnectionManager] = None):
30
+ self.tools: dict[str, ToolSpec] = {}
31
+ self.mcp_manager = mcp_manager
32
+
33
+ def register_tool(self, spec: ToolSpec) -> None:
34
+ """Register a tool with its handler"""
35
+ self.tools[spec.name] = spec
36
+
37
+ def register_mcp_tools(self) -> None:
38
+ """Register all MCP tools from the connection manager"""
39
+ if not self.mcp_manager:
40
+ return
41
+
42
+ mcp_tools = self.mcp_manager.list_all_tools()
43
+ for tool_name, tool_def in mcp_tools.items():
44
+ spec = ToolSpec(
45
+ name=tool_name,
46
+ description=tool_def.get("description", ""),
47
+ parameters=tool_def.get(
48
+ "inputSchema", {"type": "object", "properties": {}}
49
+ ),
50
+ handler=None, # MCP tools use the manager
51
+ )
52
+ self.tools[tool_name] = spec
53
+
54
+ def get_tool_specs_for_llm(self) -> list[dict[str, Any]]:
55
+ """Get tool specifications in OpenAI format"""
56
+ specs = []
57
+ for tool in self.tools.values():
58
+ specs.append(
59
+ {
60
+ "type": "function",
61
+ "function": {
62
+ "name": tool.name,
63
+ "description": tool.description,
64
+ "parameters": tool.parameters,
65
+ },
66
+ }
67
+ )
68
+ return specs
69
+
70
+ async def execute_tool(
71
+ self, tool_name: str, arguments: dict[str, Any]
72
+ ) -> tuple[str, bool]:
73
+ """Execute a tool by name"""
74
+ if tool_name not in self.tools:
75
+ return f"Unknown tool: {tool_name}", False
76
+
77
+ tool = self.tools[tool_name]
78
+
79
+ # MCP tool
80
+ if tool_name.startswith("mcp__") and self.mcp_manager:
81
+ return await self.mcp_manager.call_tool(tool_name, arguments)
82
+
83
+ # Built-in tool with handler
84
+ if tool.handler:
85
+ return await tool.handler(arguments)
86
+
87
+ return "Tool has no handler", False
88
+
89
+
90
+ # ============================================================================
91
+ # BUILT-IN TOOL HANDLERS
92
+ # ============================================================================
93
+
94
+
95
+ async def bash_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
96
+ """Execute bash command"""
97
+ try:
98
+ command = arguments.get("command", "")
99
+ result = subprocess.run(
100
+ command, shell=True, capture_output=True, text=True, timeout=30
101
+ )
102
+ output = result.stdout if result.returncode == 0 else result.stderr
103
+ success = result.returncode == 0
104
+ return output, success
105
+ except Exception as e:
106
+ return f"Error: {str(e)}", False
107
+
108
+
109
+ async def read_file_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
110
+ """Read file contents"""
111
+ try:
112
+ path = arguments.get("path", "")
113
+ with open(path, "r") as f:
114
+ content = f.read()
115
+ return content, True
116
+ except Exception as e:
117
+ return f"Error reading file: {str(e)}", False
118
+
119
+
120
+ async def write_file_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
121
+ """Write to file"""
122
+ try:
123
+ path = arguments.get("path", "")
124
+ content = arguments.get("content", "")
125
+ with open(path, "w") as f:
126
+ f.write(content)
127
+ return f"Successfully wrote to {path}", True
128
+ except Exception as e:
129
+ return f"Error writing file: {str(e)}", False
130
+
131
+
132
+ def create_builtin_tools() -> list[ToolSpec]:
133
+ """Create built-in tool specifications"""
134
+ return [
135
+ ToolSpec(
136
+ name="bash",
137
+ description="Execute a bash command and return its output",
138
+ parameters={
139
+ "type": "object",
140
+ "properties": {
141
+ "command": {
142
+ "type": "string",
143
+ "description": "The bash command to execute",
144
+ }
145
+ },
146
+ "required": ["command"],
147
+ },
148
+ handler=bash_handler,
149
+ ),
150
+ ToolSpec(
151
+ name="read_file",
152
+ description="Read the contents of a file",
153
+ parameters={
154
+ "type": "object",
155
+ "properties": {
156
+ "path": {
157
+ "type": "string",
158
+ "description": "Path to the file to read",
159
+ }
160
+ },
161
+ "required": ["path"],
162
+ },
163
+ handler=read_file_handler,
164
+ ),
165
+ ToolSpec(
166
+ name="write_file",
167
+ description="Write content to a file",
168
+ parameters={
169
+ "type": "object",
170
+ "properties": {
171
+ "path": {
172
+ "type": "string",
173
+ "description": "Path to the file to write",
174
+ },
175
+ "content": {
176
+ "type": "string",
177
+ "description": "Content to write to the file",
178
+ },
179
+ },
180
+ "required": ["path", "content"],
181
+ },
182
+ handler=write_file_handler,
183
+ ),
184
+ ]