YigitSekerci committed on
Commit
5e87361
·
1 Parent(s): f4438c7

try new agent style

Browse files
src/agent.py CHANGED
@@ -1,14 +1,62 @@
 
 
1
  from dotenv import load_dotenv
2
- from functools import partial
3
-
4
  from langchain_mcp_adapters.client import MultiServerMCPClient
5
- from langgraph.graph import StateGraph, END, START
6
 
7
- from .state import AgentState, InputState, OutputState
8
- from .nodes.chat import chat_node, chat_node_router
9
- from .nodes.planner import planner_node
10
- from .nodes.processor import processor_node
11
- from .nodes.validator import validator_node, validator_node_router
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  class AudioAgent:
14
  def __init__(
@@ -24,105 +72,30 @@ class AudioAgent:
24
  self._client = MultiServerMCPClient({
25
  "audio-tools": {"url": self.server_url, "transport": "sse"}
26
  })
 
 
27
 
28
- @property
29
- def is_initialized(self) -> bool:
30
- return self.graph is not None
31
-
32
- async def _build_graph(self) -> None:
33
- """Build the LangGraph workflow."""
34
-
35
- _graph = StateGraph(
36
- AgentState,
37
- input=InputState,
38
- output=OutputState
39
- )
40
 
41
- _graph.add_node("chat", chat_node)
42
- _graph.add_conditional_edges(
43
- "chat",
44
- chat_node_router,
45
- {
46
- "planner": "planner",
47
- "end": END
48
- }
49
  )
50
 
51
- _graph.add_node("planner", planner_node)
52
- _graph.add_edge("planner", "audio_processor")
53
-
54
- processor_node_with_tools = partial(processor_node, tools=self.tools)
55
- _graph.add_node("audio_processor", processor_node_with_tools)
56
- # TODO: add validator edge to here
57
- _graph.add_edge("audio_processor", "chat")
58
-
59
- _graph.add_node("validator", validator_node)
60
- _graph.add_conditional_edges(
61
- "validator",
62
- validator_node_router,
63
- {
64
- "chat": "chat",
65
- "planner": "planner"
66
- }
67
- )
68
-
69
- _graph.add_edge(START, "chat")
70
- _graph.add_edge("chat", END)
71
- self.graph = _graph.compile()
72
-
73
- async def initialize(self) -> None:
74
- """Initialize the LangGraph workflow with audio tools."""
75
- if self.is_initialized:
76
- return
77
-
78
- self.tools = await self._client.get_tools()
79
- if not self.tools:
80
- raise RuntimeError("No tools available from MCP server")
81
 
82
- await self._build_graph()
 
 
 
83
 
84
- def _extract_audio_paths(self, user_message: str) -> tuple[str, list[str]]:
85
- """Extract audio file paths from user message and return cleaned message."""
86
- audio_files = []
87
- lines = user_message.split('\n')
88
- clean_lines = []
89
-
90
- for line in lines:
91
- if line.strip().startswith('Audio file:'):
92
- # Extract the file path
93
- audio_path = line.replace('Audio file:', '').strip()
94
- audio_files.append(audio_path)
95
- else:
96
- clean_lines.append(line)
97
-
98
- clean_message = '\n'.join(clean_lines).strip()
99
- return clean_message, audio_files
100
-
101
- async def chat(self, user_message: str):
102
- """Stream chat responses with node information."""
103
- if not self.is_initialized:
104
- await self.initialize()
105
-
106
- # Extract audio file paths from the message
107
- clean_message, audio_files = self._extract_audio_paths(user_message)
108
-
109
- # Set up initial state
110
- initial_state = {
111
- "user_input": clean_message,
112
- "input_audio_files": audio_files,
113
- "steps_details": [],
114
- "plan": "",
115
- "final_response": "",
116
- "requires_processing": False,
117
- "validator_feedback": "",
118
- "output_audio_files": []
119
- }
120
-
121
- # Stream the graph execution
122
- return await self.graph.ainvoke(initial_state, stream_mode="values")
123
-
124
- def draw_graph(self) -> None:
125
- """Draw the graph to a file."""
126
- graph_image = self.graph.get_graph().draw_mermaid_png()
127
- with open("graph.png", "wb") as f:
128
- f.write(graph_image)
 
1
+ from langgraph.prebuilt import create_react_agent
2
+ from pydantic import BaseModel, Field
3
  from dotenv import load_dotenv
 
 
4
  from langchain_mcp_adapters.client import MultiServerMCPClient
 
5
 
6
class AgentOutput(BaseModel):
    """Structured response schema the ReAct agent must produce."""

    # Human-readable answer shown to the user.
    final_response: str = Field(default="", description="The final response to the user.")
    # Paths/URLs of audio files produced during processing.
    # default_factory ensures each instance gets its own list instead of
    # sharing one mutable default object.
    output_audio_files: list[str] = Field(default_factory=list, description="The output audio files.")
9
+
10
# System prompt for the ReAct audio agent: restricts conversation to the
# audio domain and prescribes an analyze -> plan -> execute -> validate ->
# respond workflow. (Runtime string: content reproduced unchanged.)
system_prompt = """You are an expert Audio Processing Assistant with specialized capabilities in audio manipulation, analysis, and editing. Your primary purpose is to help users with audio-related tasks and provide knowledgeable assistance in the audio domain.

## Core Behavior Guidelines:

### Conversation Scope:
- ONLY engage in conversations related to audio processing, audio editing, sound engineering, music production, audio analysis, audio formats, and related audio technologies
- If a user asks about topics outside the audio domain, politely decline and redirect them back to audio-related assistance
- Be conversational, friendly, and helpful when discussing audio topics
- Share your expertise about audio concepts, techniques, and best practices when relevant

### Audio Processing Workflow:
When a user requests audio processing and provides input files, follow this structured approach:

1. **ANALYSIS PHASE:**
- Analyze the user's request to understand their goals
- Examine the provided input audio files if available
- Identify what audio processing operations are needed

2. **PLANNING PHASE:**
- Create a clear, step-by-step plan for the audio processing task
- Explain your plan to the user before execution
- Ensure the plan addresses their specific requirements

3. **EXECUTION PHASE:**
- Use the available audio tools to implement your plan
- Process the audio files according to the planned steps
- Handle any errors or unexpected results gracefully

4. **VALIDATION PHASE:**
- Verify that the processed audio meets the user's requirements
- Check the quality and correctness of the output
- Test that the processing achieved the desired results

5. **RESPONSE PHASE:**
- Provide a clear summary of what was accomplished
- Include the output audio files in your response
- Offer additional suggestions or next steps if relevant

## Available Context:
- You have access to input_audio_files when provided by the user
- You can generate output_audio_files through your audio processing tools
- Use your tools effectively to analyze, edit, convert, and manipulate audio

## Response Format:
- Always provide helpful, accurate information about audio topics
- When processing audio, be transparent about your process and results
- Include relevant technical details when appropriate
- Maintain a professional yet approachable tone

Remember: Stay focused on audio-related assistance and use your specialized tools to help users achieve their audio processing goals efficiently and effectively."""
60
 
61
  class AudioAgent:
62
  def __init__(
 
72
  self._client = MultiServerMCPClient({
73
  "audio-tools": {"url": self.server_url, "transport": "sse"}
74
  })
75
+
76
+ self.agent = None
77
 
78
async def build_agent(self):
    """Fetch the MCP audio tools and build a ReAct agent around them.

    Returns:
        A compiled LangGraph ReAct agent configured with the audio tools,
        the audio-domain system prompt, and structured output.

    Raises:
        RuntimeError: If the MCP server exposes no tools (an agent without
            tools could never process audio, so fail fast).
    """
    tools = await self._client.get_tools()
    if not tools:
        raise RuntimeError("No tools available from MCP server")

    # NOTE(review): model is hard-coded here; presumably it should use
    # the constructor's model_name argument instead -- confirm intent.
    agent = create_react_agent(
        model="gpt-4.1",
        tools=tools,
        prompt=system_prompt,
        response_format=AgentOutput,
    )
    return agent
89
+
90
async def run_agent(self, user_input: str, input_audio_files: list[str]):
    """Run one user request through the ReAct agent, building it lazily.

    Args:
        user_input: The user's request text.
        input_audio_files: URLs/paths of uploaded audio files (may be empty).

    Returns:
        The agent result dict. When the agent produced a structured
        response, the parsed fields are flattened in as "final_response"
        and "output_audio_files" -- the keys the UI layer reads via
        result.get("final_response") / result.get("output_audio_files").
    """
    if self.agent is None:
        # Build once and cache; later calls reuse the same agent.
        self.agent = await self.build_agent()

    input_context = f"""
User Request: {user_input}
Input Audio Files: {', '.join(input_audio_files) if input_audio_files else 'None'}
"""

    result = await self.agent.ainvoke(
        {"messages": [{"role": "user", "content": input_context}]}
    )

    # create_react_agent(response_format=...) returns the parsed output under
    # "structured_response", not as top-level keys. Flatten it so callers that
    # expect "final_response" / "output_audio_files" keep working.
    structured = result.get("structured_response")
    if structured is not None:
        return {
            **result,
            "final_response": structured.final_response,
            "output_audio_files": structured.output_audio_files,
        }
    return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/nodes/agent.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ from functools import partial
3
+
4
+ from langchain_mcp_adapters.client import MultiServerMCPClient
5
+ from langgraph.graph import StateGraph, END, START
6
+
7
+ from .state import AgentState, InputState, OutputState
8
+ from .chat import chat_node, chat_node_router
9
+ from .planner import planner_node
10
+ from .processor import processor_node
11
+ from .validator import validator_node, validator_node_router
12
+
13
class AudioAgent:
    """Graph-based audio agent: chat -> planner -> processor (-> validator) workflow."""

    def __init__(
        self,
        model_name: str = "gpt-4o",
        server_url: str = "https://agents-mcp-hackathon-audioeditor.hf.space/gradio_api/mcp/sse",
    ):
        load_dotenv()
        self.model_name = model_name
        self.server_url = server_url
        self.graph = None
        # Populated by initialize(); declared here so attribute access on an
        # un-initialized instance fails predictably instead of AttributeError.
        self.tools = None

        self._client = MultiServerMCPClient({
            "audio-tools": {"url": self.server_url, "transport": "sse"}
        })

    @property
    def is_initialized(self) -> bool:
        # The compiled graph is only assigned at the end of _build_graph().
        return self.graph is not None

    async def _build_graph(self) -> None:
        """Build the LangGraph workflow."""

        _graph = StateGraph(
            AgentState,
            input=InputState,
            output=OutputState
        )

        _graph.add_node("chat", chat_node)
        # chat either hands off to the planner or ends the turn; the router
        # decides, so no unconditional chat -> END edge is added (one would
        # fan out to END on every turn in parallel with the planner branch).
        _graph.add_conditional_edges(
            "chat",
            chat_node_router,
            {
                "planner": "planner",
                "end": END
            }
        )

        _graph.add_node("planner", planner_node)
        _graph.add_edge("planner", "audio_processor")

        # Bind the MCP tools into the processor node at graph-build time.
        processor_node_with_tools = partial(processor_node, tools=self.tools)
        _graph.add_node("audio_processor", processor_node_with_tools)
        # TODO: route audio_processor -> validator here instead of straight
        # back to chat; until then the validator node below is unreachable.
        _graph.add_edge("audio_processor", "chat")

        _graph.add_node("validator", validator_node)
        _graph.add_conditional_edges(
            "validator",
            validator_node_router,
            {
                "chat": "chat",
                "planner": "planner"
            }
        )

        _graph.add_edge(START, "chat")
        self.graph = _graph.compile()

    async def initialize(self) -> None:
        """Initialize the LangGraph workflow with audio tools."""
        if self.is_initialized:
            return

        self.tools = await self._client.get_tools()
        if not self.tools:
            raise RuntimeError("No tools available from MCP server")

        await self._build_graph()

    def _extract_audio_paths(self, user_message: str) -> tuple[str, list[str]]:
        """Extract audio file paths from user message and return cleaned message."""
        audio_files = []
        lines = user_message.split('\n')
        clean_lines = []

        for line in lines:
            if line.strip().startswith('Audio file:'):
                # Strip only the leading marker; removeprefix (unlike
                # replace()) cannot corrupt a path that itself happens to
                # contain the literal text "Audio file:".
                audio_path = line.strip().removeprefix('Audio file:').strip()
                audio_files.append(audio_path)
            else:
                clean_lines.append(line)

        clean_message = '\n'.join(clean_lines).strip()
        return clean_message, audio_files

    async def chat(self, user_message: str):
        """Run one user turn through the graph and return the final state."""
        if not self.is_initialized:
            await self.initialize()

        # Extract audio file paths from the message
        clean_message, audio_files = self._extract_audio_paths(user_message)

        # Set up initial state
        initial_state = {
            "user_input": clean_message,
            "input_audio_files": audio_files,
            "steps_details": [],
            "plan": "",
            "final_response": "",
            "requires_processing": False,
            "validator_feedback": "",
            "output_audio_files": []
        }

        # Run the graph to completion and return the final state values.
        return await self.graph.ainvoke(initial_state, stream_mode="values")

    def draw_graph(self) -> None:
        """Draw the graph to a file."""
        graph_image = self.graph.get_graph().draw_mermaid_png()
        with open("graph.png", "wb") as f:
            f.write(graph_image)
src/nodes/chat.py CHANGED
@@ -1,7 +1,7 @@
1
  from langchain_openai import ChatOpenAI
2
  from langchain_core.prompts import ChatPromptTemplate
3
  from langchain_core.runnables import RunnableParallel
4
- from src.state import AgentState, ChatInputState, ChatOutputState
5
  from operator import itemgetter
6
 
7
  def chat_node(state: ChatInputState) -> ChatOutputState:
 
1
  from langchain_openai import ChatOpenAI
2
  from langchain_core.prompts import ChatPromptTemplate
3
  from langchain_core.runnables import RunnableParallel
4
+ from nodes.state import AgentState, ChatInputState, ChatOutputState
5
  from operator import itemgetter
6
 
7
  def chat_node(state: ChatInputState) -> ChatOutputState:
src/nodes/planner.py CHANGED
@@ -1,7 +1,7 @@
1
  from langchain_openai import ChatOpenAI
2
  from langchain_core.prompts import ChatPromptTemplate
3
  from langchain_core.runnables import RunnableParallel
4
- from src.state import AgentState, PlannerInputState, PlannerOutputState
5
  from operator import itemgetter
6
 
7
  def planner_node(state: PlannerInputState) -> PlannerOutputState:
 
1
  from langchain_openai import ChatOpenAI
2
  from langchain_core.prompts import ChatPromptTemplate
3
  from langchain_core.runnables import RunnableParallel
4
+ from nodes.state import AgentState, PlannerInputState, PlannerOutputState
5
  from operator import itemgetter
6
 
7
  def planner_node(state: PlannerInputState) -> PlannerOutputState:
src/nodes/processor.py CHANGED
@@ -1,4 +1,4 @@
1
- from src.state import ProcessorInputState, ProcessorOutputState
2
  from langgraph.prebuilt import create_react_agent
3
  from pydantic import BaseModel, Field
4
 
 
1
+ from nodes.state import ProcessorInputState, ProcessorOutputState
2
  from langgraph.prebuilt import create_react_agent
3
  from pydantic import BaseModel, Field
4
 
src/{state.py → nodes/state.py} RENAMED
File without changes
src/nodes/validator.py CHANGED
@@ -1,6 +1,6 @@
1
  from langchain_openai import ChatOpenAI
2
  from langchain_core.prompts import ChatPromptTemplate
3
- from src.state import AgentState, ValidatorInputState, ValidatorOutputState
4
  from operator import itemgetter
5
  from langchain_core.runnables import RunnableParallel
6
 
 
1
  from langchain_openai import ChatOpenAI
2
  from langchain_core.prompts import ChatPromptTemplate
3
+ from nodes.state import AgentState, ValidatorInputState, ValidatorOutputState
4
  from operator import itemgetter
5
  from langchain_core.runnables import RunnableParallel
6
 
src/ui.py CHANGED
@@ -33,31 +33,42 @@ def user_input(user_message, audio_files, history):
33
 
34
  audio_file_urls.append(get_share_url(file_path))
35
 
 
36
  if audio_file_urls:
37
- audio_list = "\n".join([f"Audio file: {url}" for url in audio_file_urls])
38
- combined_message = f"{user_message}\n\n{audio_list}" if user_message.strip() else audio_list
39
  else:
40
  combined_message = user_message
41
 
42
  history.append({"role": "user", "content": combined_message})
43
- return "", [], history
44
 
45
- async def bot_response(history):
46
  """
47
- Generate bot response using the simple chat method
48
  """
49
  if not history or history[-1]["role"] != "user":
50
  return history
51
 
 
52
  user_message = history[-1]["content"]
53
 
54
- try:
55
- # Initialize agent if not already done
56
- if not agent.is_initialized:
57
- await agent.initialize()
 
 
 
 
58
 
59
- # Get the response from the agent
60
- result = await agent.chat(user_message)
 
 
 
 
 
61
 
62
  # Extract the final response and audio files from the result
63
  final_response = result.get("final_response", "")
@@ -80,14 +91,14 @@ async def bot_response(history):
80
 
81
  return history
82
 
83
- def bot_response_sync(history):
84
  """
85
  Synchronous wrapper for the async bot response
86
  """
87
  loop = asyncio.new_event_loop()
88
  asyncio.set_event_loop(loop)
89
  try:
90
- return loop.run_until_complete(bot_response(history))
91
  finally:
92
  loop.close()
93
 
@@ -110,6 +121,9 @@ def create_interface():
110
  **Supported formats**: MP3, WAV, M4A, FLAC, AAC, OGG
111
  """)
112
 
 
 
 
113
  with gr.Row():
114
  with gr.Column(scale=2):
115
  chatbot = gr.Chatbot(
@@ -184,35 +198,39 @@ def create_interface():
184
 
185
  # Handle user input and bot response
186
  def handle_submit(message, files, history):
187
- return user_input(message, files, history)
 
188
 
189
  msg.submit(
190
  handle_submit,
191
  [msg, audio_files, chatbot],
192
- [msg, audio_files, chatbot],
193
  queue=False
194
  ).then(
195
  bot_response_sync,
196
- chatbot,
197
  chatbot
198
  )
199
 
200
  send_btn.click(
201
  handle_submit,
202
  [msg, audio_files, chatbot],
203
- [msg, audio_files, chatbot],
204
  queue=False
205
  ).then(
206
  bot_response_sync,
207
- chatbot,
208
  chatbot
209
  )
210
 
211
  # Clear chat
 
 
 
212
  clear_btn.click(
213
- lambda: ([], []),
214
  None,
215
- [chatbot, audio_files],
216
  queue=False
217
  )
218
 
 
33
 
34
  audio_file_urls.append(get_share_url(file_path))
35
 
36
+ # For display purposes, show what audio files were uploaded
37
  if audio_file_urls:
38
+ audio_list = "\n".join([f"🎵 Uploaded: {url.split('/')[-1]}" for url in audio_file_urls])
39
+ combined_message = f"{user_message}\n\n{audio_list}" if user_message.strip() else f"Process uploaded audio files:\n{audio_list}"
40
  else:
41
  combined_message = user_message
42
 
43
  history.append({"role": "user", "content": combined_message})
44
+ return "", [], history, audio_file_urls
45
 
46
+ async def bot_response(history, audio_file_urls):
47
  """
48
+ Generate bot response using the test agent
49
  """
50
  if not history or history[-1]["role"] != "user":
51
  return history
52
 
53
+ # Get the actual user message (without the audio file display text)
54
  user_message = history[-1]["content"]
55
 
56
+ # Clean the user message by removing the uploaded file display text
57
+ if "🎵 Uploaded:" in user_message:
58
+ lines = user_message.split('\n')
59
+ clean_lines = []
60
+ for line in lines:
61
+ if not line.strip().startswith('🎵 Uploaded:'):
62
+ clean_lines.append(line)
63
+ user_message = '\n'.join(clean_lines).strip()
64
 
65
+ # If message is empty after cleaning, provide default message
66
+ if not user_message:
67
+ user_message = "Please process these audio files"
68
+
69
+ try:
70
+ # Use the test agent's run_agent method with separate parameters
71
+ result = await agent.run_agent(user_message, audio_file_urls or [])
72
 
73
  # Extract the final response and audio files from the result
74
  final_response = result.get("final_response", "")
 
91
 
92
  return history
93
 
94
def bot_response_sync(history, audio_file_urls):
    """
    Synchronous wrapper for the async bot response.

    Uses asyncio.run(), which creates a fresh event loop, runs the coroutine,
    and always tears the loop down again. The previous manual
    new_event_loop()/set_event_loop()/close() sequence left a *closed* loop
    installed as the thread's current event loop after each call.
    """
    return asyncio.run(bot_response(history, audio_file_urls))
104
 
 
121
  **Supported formats**: MP3, WAV, M4A, FLAC, AAC, OGG
122
  """)
123
 
124
+ # Hidden state to store audio file URLs
125
+ audio_urls_state = gr.State([])
126
+
127
  with gr.Row():
128
  with gr.Column(scale=2):
129
  chatbot = gr.Chatbot(
 
198
 
199
  # Handle user input and bot response
200
def handle_submit(message, files, history):
    # Delegate straight to user_input: it already returns the
    # (message, files, history, audio_urls) 4-tuple Gradio expects,
    # so there is nothing to unpack and repack here.
    return user_input(message, files, history)
203
 
204
  msg.submit(
205
  handle_submit,
206
  [msg, audio_files, chatbot],
207
+ [msg, audio_files, chatbot, audio_urls_state],
208
  queue=False
209
  ).then(
210
  bot_response_sync,
211
+ [chatbot, audio_urls_state],
212
  chatbot
213
  )
214
 
215
  send_btn.click(
216
  handle_submit,
217
  [msg, audio_files, chatbot],
218
+ [msg, audio_files, chatbot, audio_urls_state],
219
  queue=False
220
  ).then(
221
  bot_response_sync,
222
+ [chatbot, audio_urls_state],
223
  chatbot
224
  )
225
 
226
  # Clear chat
227
def clear_chat():
    """Reset the chatbot history, uploaded files, and stored audio URLs."""
    empty_history, empty_uploads, empty_urls = [], [], []
    return empty_history, empty_uploads, empty_urls
229
+
230
  clear_btn.click(
231
+ clear_chat,
232
  None,
233
+ [chatbot, audio_files, audio_urls_state],
234
  queue=False
235
  )
236