YigitSekerci committed on
Commit
eff95ca
·
1 Parent(s): 026baed

simplify agent

Browse files
Files changed (1) hide show
  1. src/agent.py +59 -339
src/agent.py CHANGED
@@ -1,370 +1,90 @@
1
  import asyncio
2
- import json
3
- import logging
4
- from typing import List, Dict, Any, Optional, Tuple, Union
5
- from langchain_mcp_adapters.client import MultiServerMCPClient
6
- from langgraph.prebuilt import create_react_agent
7
- from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
8
- from langchain_core.output_parsers import StrOutputParser
9
- from langchain_core.exceptions import OutputParserException
10
  from dotenv import load_dotenv
11
- from langchain_openai import ChatOpenAI
12
-
13
- # Configure logging
14
- logging.basicConfig(level=logging.INFO)
15
- logger = logging.getLogger(__name__)
16
-
17
- load_dotenv()
18
-
19
- class AudioAgentError(Exception):
20
- """Custom exception for AudioAgent errors"""
21
- pass
22
-
23
-
24
- class AudioAgentInitializationError(AudioAgentError):
25
- """Raised when agent initialization fails"""
26
- pass
27
-
28
-
29
- class AudioAgentChatError(AudioAgentError):
30
- """Raised when chat processing fails"""
31
- pass
32
 
 
 
 
33
 
34
  class AudioAgent:
35
  """
36
- A class to manage an audio-focused AI agent with MCP tools integration.
37
-
38
- This agent connects to audio tools via MCP and provides a conversational interface
39
- using LangChain's robust message handling and output parsing.
40
  """
41
-
42
- def __init__(self, model_name: str = "gpt-4o", server_url: str = "http://127.0.0.1:7860/gradio_api/mcp/sse"):
43
- """
44
- Initialize the AudioAgent.
45
-
46
- Args:
47
- model_name: The language model to use for the agent
48
- server_url: The URL of the MCP server providing audio tools
49
- """
50
  self.model_name = model_name
51
  self.server_url = server_url
52
- self._agent = None
53
- self._tools = None
54
- self._llm = None
55
- self._is_initialized = False
56
- self._output_parser = StrOutputParser()
57
-
58
- # Initialize MCP client
59
  self._client = MultiServerMCPClient({
60
- "audio-tools": {
61
- "url": server_url,
62
- "transport": "sse",
63
- }
64
  })
65
-
 
 
66
  @property
67
  def is_initialized(self) -> bool:
68
- """Check if the agent is initialized and ready to use."""
69
- return self._is_initialized
70
-
71
  async def initialize(self) -> None:
72
- """
73
- Initialize the agent with tools from the MCP client.
74
-
75
- Raises:
76
- AudioAgentInitializationError: If initialization fails
77
- """
78
- if self._is_initialized:
79
- logger.info("Agent already initialized")
80
  return
81
-
82
- try:
83
- logger.info("Initializing AudioAgent...")
84
-
85
- logger.info(f"Initializing LLM: {self.model_name}")
86
- self._llm = ChatOpenAI(model=self.model_name, temperature=0, streaming=True)
87
 
88
- # Get tools from MCP client
89
- self._tools = await self._client.get_tools()
90
- if not self._tools:
91
- raise AudioAgentInitializationError("No tools available from MCP client")
92
-
93
- logger.info(f"Loaded {len(self._tools)} tools: {[tool.name for tool in self._tools]}")
94
-
95
- # Create the agent
96
- self._agent = create_react_agent(
97
- self._llm,
98
- self._tools,
99
- )
100
-
101
- self._is_initialized = True
102
- logger.info("AudioAgent initialized successfully")
103
-
104
- except Exception as e:
105
- error_msg = f"Failed to initialize AudioAgent: {str(e)}"
106
- logger.error(error_msg)
107
- raise AudioAgentInitializationError(error_msg) from e
108
-
109
- def _convert_to_langchain_messages(self, history: List[Tuple[str, Optional[str]]]) -> List[BaseMessage]:
110
- """
111
- Convert chat history to LangChain message objects.
112
-
113
- Args:
114
- history: List of (human_message, ai_response) tuples
115
-
116
- Returns:
117
- List of LangChain BaseMessage objects
118
- """
119
- messages = []
120
- for human_msg, ai_msg in history:
121
- if human_msg and human_msg.strip():
122
- messages.append(HumanMessage(content=human_msg.strip()))
123
- if ai_msg and ai_msg.strip():
124
- messages.append(AIMessage(content=ai_msg.strip()))
125
- return messages
126
-
127
- async def _extract_response_content(self, response: Dict[str, Any]) -> str:
128
- """
129
- Extract the content from the agent's response using LangChain output parser.
130
-
131
- Args:
132
- response: The response from the agent
133
-
134
- Returns:
135
- The extracted content as a string
136
-
137
- Raises:
138
- AudioAgentChatError: If response parsing fails
139
- """
140
- try:
141
- if not response:
142
- raise OutputParserException("Received empty response from agent")
143
-
144
- if "messages" not in response or not response["messages"]:
145
- raise OutputParserException("No messages found in agent response")
146
-
147
- last_message = response["messages"][-1]
148
-
149
- # Handle different message formats
150
- if hasattr(last_message, 'content'):
151
- content = last_message.content
152
- elif isinstance(last_message, dict) and 'content' in last_message:
153
- content = last_message['content']
154
- else:
155
- content = str(last_message)
156
-
157
- # Use LangChain's output parser for robust string processing
158
- parsed_content = await self._output_parser.aparse(content)
159
- return parsed_content if parsed_content else "I couldn't generate a response."
160
-
161
- except OutputParserException as e:
162
- logger.warning(f"Output parsing failed: {e}")
163
- raise AudioAgentChatError(f"Failed to parse agent response: {str(e)}") from e
164
- except Exception as e:
165
- logger.error(f"Unexpected error in response extraction: {e}")
166
- raise AudioAgentChatError(f"Error extracting response content: {str(e)}") from e
167
-
168
- def _validate_message(self, message: str) -> str:
169
  """
170
- Validate and sanitize the input message.
171
-
172
- Args:
173
- message: The user's message
174
-
175
- Returns:
176
- The validated and sanitized message
177
-
178
- Raises:
179
- AudioAgentChatError: If message is invalid
180
  """
181
- if not message:
182
- raise AudioAgentChatError("Message cannot be None")
183
-
184
- cleaned_message = message.strip()
185
- if not cleaned_message:
186
- raise AudioAgentChatError("Message cannot be empty or only whitespace")
187
-
188
- if len(cleaned_message) > 10000:
189
- raise AudioAgentChatError("Message is too long (max 10,000 characters)")
190
-
191
- return cleaned_message
192
 
193
- async def chat(self, message: str, history: Optional[List[Tuple[str, Optional[str]]]] = None) -> str:
194
  """
195
- Process a chat message with the agent using LangChain's robust message handling.
196
-
197
- Args:
198
- message: The user's message
199
- history: Previous chat history as list of (human, ai) tuples
200
-
201
- Returns:
202
- The agent's response
203
-
204
- Raises:
205
- AudioAgentChatError: If chat processing fails
206
- AudioAgentInitializationError: If agent is not initialized
207
  """
208
- # Validate input
209
- validated_message = self._validate_message(message)
210
-
211
- # Ensure agent is initialized
212
- if not self._is_initialized:
213
  await self.initialize()
214
-
215
- try:
216
- # Convert history to LangChain messages
217
- langchain_messages = self._convert_to_langchain_messages(history or [])
218
-
219
- # Add current message
220
- langchain_messages.append(HumanMessage(content=validated_message))
221
-
222
- # Prepare input for the agent
223
- input_data = {"messages": langchain_messages}
224
-
225
- logger.info(f"Processing message: {validated_message[:50]}{'...' if len(validated_message) > 50 else ''}")
226
-
227
- # Get response from agent
228
- response = await self._agent.ainvoke(input_data)
229
-
230
- # Extract and return content using output parser
231
- content = await self._extract_response_content(response)
232
- logger.info("Message processed successfully")
233
- return content
234
-
235
- except AudioAgentChatError:
236
- # Re-raise our custom errors
237
- raise
238
- except Exception as e:
239
- error_msg = f"Failed to process chat message: {str(e)}"
240
- logger.error(error_msg)
241
- raise AudioAgentChatError(error_msg) from e
242
-
243
- def chat_sync(self, message: str, history: Optional[List[Tuple[str, Optional[str]]]] = None) -> str:
244
- """
245
- Synchronous wrapper for the async chat method.
246
-
247
- Args:
248
- message: The user's message
249
- history: Previous chat history as list of (human, ai) tuples
250
-
251
- Returns:
252
- The agent's response
253
- """
254
- try:
255
- return asyncio.run(self.chat(message, history))
256
- except Exception as e:
257
- logger.error(f"Error in synchronous chat: {e}")
258
- raise
259
-
260
- async def get_available_tools(self) -> List[str]:
261
- """
262
- Get the list of available tool names.
263
-
264
- Returns:
265
- List of tool names
266
-
267
- Raises:
268
- AudioAgentInitializationError: If initialization fails
269
- """
270
- try:
271
- if not self._is_initialized:
272
- await self.initialize()
273
- return [tool.name for tool in self._tools] if self._tools else []
274
- except Exception as e:
275
- error_msg = f"Failed to get available tools: {str(e)}"
276
- logger.error(error_msg)
277
- raise AudioAgentInitializationError(error_msg) from e
278
-
279
- async def stream_chat(self, message: str, history: Optional[List[Tuple[str, Optional[str]]]] = None):
280
  """
281
- Stream a chat response with intermediate steps.
282
-
283
- Args:
284
- message: The user's message
285
- history: Previous chat history as list of (human, ai) tuples
286
-
287
- Yields:
288
- Formatted strings for thought process and final response.
289
- The string is prefixed with 'thought:', 'response_chunk:', or 'error:'.
290
-
291
- Raises:
292
- AudioAgentChatError: If streaming fails
293
  """
294
- # Validate input
295
- validated_message = self._validate_message(message)
296
-
297
- # Ensure agent is initialized
298
- if not self._is_initialized:
299
  await self.initialize()
300
-
301
- try:
302
- # Convert history to LangChain messages
303
- langchain_messages = self._convert_to_langchain_messages(history or [])
304
-
305
- # Add current message
306
- langchain_messages.append(HumanMessage(content=validated_message))
307
-
308
- # Prepare input for the agent
309
- input_data = {"messages": langchain_messages}
310
-
311
- logger.info(f"Streaming message: {validated_message[:50]}{'...' if len(validated_message) > 50 else ''}")
312
-
313
- final_response = ""
314
- # Use astream_events to get intermediate steps
315
- async for event in self._agent.astream_events(input_data, version="v1"):
316
- kind = event["event"]
317
- if kind == "on_chat_model_stream":
318
- content = event["data"]["chunk"].content
319
- if content:
320
- final_response += content
321
- yield f"response_chunk:{content}"
322
-
323
- elif kind == "on_tool_start":
324
- yield f"thought:Calling tool `{event['name']}` with input:\n```json\n{json.dumps(event['data'].get('input'), indent=2)}\n```"
325
- elif kind == "on_tool_end":
326
- yield f"thought:Tool `{event['name']}` finished. Output:\n```\n{event['data'].get('output')}\n```"
327
 
328
- if not final_response:
329
- logger.warning("Streaming finished but no final response was generated.")
330
- yield "response_chunk:I couldn't generate a response."
331
-
332
- except Exception as e:
333
- error_msg = f"Failed to stream chat message: {str(e)}"
334
- logger.error(error_msg, exc_info=True)
335
- yield f"error:{error_msg}"
336
- # Re-raising the exception might be too much if the error is already yielded
337
- # raise AudioAgentChatError(error_msg) from e
338
 
339
  async def main():
340
- """Example usage and testing"""
341
- try:
342
- # Create and initialize agent
343
- agent = AudioAgent()
344
- await agent.initialize()
345
-
346
- # Show available tools
347
- tools = await agent.get_available_tools()
348
- print(f"Available tools: {tools}")
349
-
350
- # Test chat
351
- #response = await agent.chat("What tools do you have?")
352
- #print(f"Agent response: {response}")
353
-
354
- # Test streaming (if supported)
355
- print("\nTesting streaming:")
356
- full_response = ""
357
- async for chunk in agent.stream_chat("Tell me about audio processing"):
358
- if chunk.startswith("response_chunk:"):
359
- full_response += chunk[len("response_chunk:"):]
360
- else:
361
- print(chunk)
362
- print(f"Final response: {full_response}")
363
-
364
- except AudioAgentError as e:
365
- logger.error(f"AudioAgent error: {e}")
366
- except Exception as e:
367
- logger.error(f"Unexpected error: {e}")
368
 
369
  if __name__ == "__main__":
370
  asyncio.run(main())
 
1
  import asyncio
 
 
 
 
 
 
 
 
2
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ from langchain_mcp_adapters.client import MultiServerMCPClient
5
+ from langgraph.prebuilt import create_react_agent
6
+ from langgraph.graph.graph import CompiledGraph
7
 
8
class AudioAgent:
    """
    Wraps a LangGraph REACT agent over your MCP audio-tools,
    exposing both one-shot and streaming chat methods.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        server_url: str = "http://127.0.0.1:7860/gradio_api/mcp/sse",
    ):
        """
        Args:
            model_name: Chat model identifier handed to ``create_react_agent``.
            server_url: SSE endpoint of the MCP server exposing the audio tools.
        """
        load_dotenv()  # pick up credentials (e.g. OPENAI_API_KEY) from .env

        self.model_name = model_name
        self.server_url = server_url

        # SSE client for your audio tools
        self._client = MultiServerMCPClient({
            "audio-tools": {"url": self.server_url, "transport": "sse"}
        })

        # Built lazily on first use; None means "not initialized yet".
        self._agent = None

    @property
    def is_initialized(self) -> bool:
        """True once the REACT agent graph has been built."""
        return self._agent is not None

    async def initialize(self) -> None:
        """Fetch tools from MCP and build a streaming-capable LangGraph REACT agent.

        Idempotent: returns immediately if the agent already exists.

        Raises:
            RuntimeError: If the MCP server reports no tools.
        """
        if self.is_initialized:
            return

        tools = await self._client.get_tools()
        if not tools:
            raise RuntimeError("No tools available from MCP server")

        self._agent: CompiledGraph = create_react_agent(
            model=self.model_name,
            tools=tools,
            prompt="""
            You are a helpful assistant that can use the following tools to help the user.
            """
        )

    def process_user_input(self, user_input: str) -> dict:
        """
        Process user input and return a prompt for the agent.
        """
        return {"messages": [{"role": "user", "content": user_input}]}

    async def chat(self, prompt: str) -> str:
        """
        One-shot chat: returns the full LLM + tool-augmented reply.

        Args:
            prompt: The user's message.

        Returns:
            The text content of the agent's final message.
        """
        if not self.is_initialized:
            await self.initialize()
        result = await self._agent.ainvoke(self.process_user_input(prompt))
        # ainvoke returns the whole graph state; the final message in the
        # "messages" list is the agent's reply — return just its text.
        return result["messages"][-1].content

    async def stream_chat(self, prompt: str):
        """
        Streaming chat: yields response text chunks as they are produced.

        Args:
            prompt: The user's message.

        Yields:
            Non-empty content chunks from the agent's message stream.
        """
        if not self.is_initialized:
            await self.initialize()

        async for msg, _metadata in self._agent.astream(
            self.process_user_input(prompt),
            stream_mode="messages"
        ):
            if msg.content:
                yield msg.content
 
 
 
 
78
 
79
  async def main():
80
+ agent = AudioAgent()
81
+ # one-shot example
82
+ reply = await agent.chat("Hi! What audio tools are available?")
83
+ print("→", reply)
84
+
85
+ # streaming example
86
+ async for msg in agent.stream_chat("Explain how audio normalization works."):
87
+ print(msg, end="", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
# Entry point: run the async demo in a fresh event loop when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())