Spaces:
Sleeping
Sleeping
Commit ·
5e87361
1
Parent(s): f4438c7
try new agent style
Browse files
- src/agent.py +77 -104
- src/nodes/agent.py +128 -0
- src/nodes/chat.py +1 -1
- src/nodes/planner.py +1 -1
- src/nodes/processor.py +1 -1
- src/{state.py → nodes/state.py} +0 -0
- src/nodes/validator.py +1 -1
- src/ui.py +38 -20
src/agent.py
CHANGED
|
@@ -1,14 +1,62 @@
|
|
|
|
|
|
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
-
from functools import partial
|
| 3 |
-
|
| 4 |
from langchain_mcp_adapters.client import MultiServerMCPClient
|
| 5 |
-
from langgraph.graph import StateGraph, END, START
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class AudioAgent:
|
| 14 |
def __init__(
|
|
@@ -24,105 +72,30 @@ class AudioAgent:
|
|
| 24 |
self._client = MultiServerMCPClient({
|
| 25 |
"audio-tools": {"url": self.server_url, "transport": "sse"}
|
| 26 |
})
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
return self.graph is not None
|
| 31 |
-
|
| 32 |
-
async def _build_graph(self) -> None:
|
| 33 |
-
"""Build the LangGraph workflow."""
|
| 34 |
-
|
| 35 |
-
_graph = StateGraph(
|
| 36 |
-
AgentState,
|
| 37 |
-
input=InputState,
|
| 38 |
-
output=OutputState
|
| 39 |
-
)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
"planner": "planner",
|
| 47 |
-
"end": END
|
| 48 |
-
}
|
| 49 |
)
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
# TODO: add validator edge to here
|
| 57 |
-
_graph.add_edge("audio_processor", "chat")
|
| 58 |
-
|
| 59 |
-
_graph.add_node("validator", validator_node)
|
| 60 |
-
_graph.add_conditional_edges(
|
| 61 |
-
"validator",
|
| 62 |
-
validator_node_router,
|
| 63 |
-
{
|
| 64 |
-
"chat": "chat",
|
| 65 |
-
"planner": "planner"
|
| 66 |
-
}
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
_graph.add_edge(START, "chat")
|
| 70 |
-
_graph.add_edge("chat", END)
|
| 71 |
-
self.graph = _graph.compile()
|
| 72 |
-
|
| 73 |
-
async def initialize(self) -> None:
|
| 74 |
-
"""Initialize the LangGraph workflow with audio tools."""
|
| 75 |
-
if self.is_initialized:
|
| 76 |
-
return
|
| 77 |
-
|
| 78 |
-
self.tools = await self._client.get_tools()
|
| 79 |
-
if not self.tools:
|
| 80 |
-
raise RuntimeError("No tools available from MCP server")
|
| 81 |
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
lines = user_message.split('\n')
|
| 88 |
-
clean_lines = []
|
| 89 |
-
|
| 90 |
-
for line in lines:
|
| 91 |
-
if line.strip().startswith('Audio file:'):
|
| 92 |
-
# Extract the file path
|
| 93 |
-
audio_path = line.replace('Audio file:', '').strip()
|
| 94 |
-
audio_files.append(audio_path)
|
| 95 |
-
else:
|
| 96 |
-
clean_lines.append(line)
|
| 97 |
-
|
| 98 |
-
clean_message = '\n'.join(clean_lines).strip()
|
| 99 |
-
return clean_message, audio_files
|
| 100 |
-
|
| 101 |
-
async def chat(self, user_message: str):
|
| 102 |
-
"""Stream chat responses with node information."""
|
| 103 |
-
if not self.is_initialized:
|
| 104 |
-
await self.initialize()
|
| 105 |
-
|
| 106 |
-
# Extract audio file paths from the message
|
| 107 |
-
clean_message, audio_files = self._extract_audio_paths(user_message)
|
| 108 |
-
|
| 109 |
-
# Set up initial state
|
| 110 |
-
initial_state = {
|
| 111 |
-
"user_input": clean_message,
|
| 112 |
-
"input_audio_files": audio_files,
|
| 113 |
-
"steps_details": [],
|
| 114 |
-
"plan": "",
|
| 115 |
-
"final_response": "",
|
| 116 |
-
"requires_processing": False,
|
| 117 |
-
"validator_feedback": "",
|
| 118 |
-
"output_audio_files": []
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
# Stream the graph execution
|
| 122 |
-
return await self.graph.ainvoke(initial_state, stream_mode="values")
|
| 123 |
-
|
| 124 |
-
def draw_graph(self) -> None:
|
| 125 |
-
"""Draw the graph to a file."""
|
| 126 |
-
graph_image = self.graph.get_graph().draw_mermaid_png()
|
| 127 |
-
with open("graph.png", "wb") as f:
|
| 128 |
-
f.write(graph_image)
|
|
|
|
| 1 |
+
from langgraph.prebuilt import create_react_agent
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
| 4 |
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
|
|
| 5 |
|
| 6 |
+
class AgentOutput(BaseModel):
    """Structured result the agent must produce for each request."""

    # Natural-language answer shown to the user.
    final_response: str = Field(default="", description="The final response to the user.")
    # Paths/URLs of audio files produced while handling the request.
    # default_factory avoids declaring a shared mutable [] default.
    output_audio_files: list[str] = Field(
        default_factory=list, description="The output audio files."
    )
|
| 9 |
+
|
| 10 |
+
system_prompt = """You are an expert Audio Processing Assistant with specialized capabilities in audio manipulation, analysis, and editing. Your primary purpose is to help users with audio-related tasks and provide knowledgeable assistance in the audio domain.
|
| 11 |
+
|
| 12 |
+
## Core Behavior Guidelines:
|
| 13 |
+
|
| 14 |
+
### Conversation Scope:
|
| 15 |
+
- ONLY engage in conversations related to audio processing, audio editing, sound engineering, music production, audio analysis, audio formats, and related audio technologies
|
| 16 |
+
- If a user asks about topics outside the audio domain, politely decline and redirect them back to audio-related assistance
|
| 17 |
+
- Be conversational, friendly, and helpful when discussing audio topics
|
| 18 |
+
- Share your expertise about audio concepts, techniques, and best practices when relevant
|
| 19 |
+
|
| 20 |
+
### Audio Processing Workflow:
|
| 21 |
+
When a user requests audio processing and provides input files, follow this structured approach:
|
| 22 |
+
|
| 23 |
+
1. **ANALYSIS PHASE:**
|
| 24 |
+
- Analyze the user's request to understand their goals
|
| 25 |
+
- Examine the provided input audio files if available
|
| 26 |
+
- Identify what audio processing operations are needed
|
| 27 |
+
|
| 28 |
+
2. **PLANNING PHASE:**
|
| 29 |
+
- Create a clear, step-by-step plan for the audio processing task
|
| 30 |
+
- Explain your plan to the user before execution
|
| 31 |
+
- Ensure the plan addresses their specific requirements
|
| 32 |
+
|
| 33 |
+
3. **EXECUTION PHASE:**
|
| 34 |
+
- Use the available audio tools to implement your plan
|
| 35 |
+
- Process the audio files according to the planned steps
|
| 36 |
+
- Handle any errors or unexpected results gracefully
|
| 37 |
+
|
| 38 |
+
4. **VALIDATION PHASE:**
|
| 39 |
+
- Verify that the processed audio meets the user's requirements
|
| 40 |
+
- Check the quality and correctness of the output
|
| 41 |
+
- Test that the processing achieved the desired results
|
| 42 |
+
|
| 43 |
+
5. **RESPONSE PHASE:**
|
| 44 |
+
- Provide a clear summary of what was accomplished
|
| 45 |
+
- Include the output audio files in your response
|
| 46 |
+
- Offer additional suggestions or next steps if relevant
|
| 47 |
+
|
| 48 |
+
## Available Context:
|
| 49 |
+
- You have access to input_audio_files when provided by the user
|
| 50 |
+
- You can generate output_audio_files through your audio processing tools
|
| 51 |
+
- Use your tools effectively to analyze, edit, convert, and manipulate audio
|
| 52 |
+
|
| 53 |
+
## Response Format:
|
| 54 |
+
- Always provide helpful, accurate information about audio topics
|
| 55 |
+
- When processing audio, be transparent about your process and results
|
| 56 |
+
- Include relevant technical details when appropriate
|
| 57 |
+
- Maintain a professional yet approachable tone
|
| 58 |
+
|
| 59 |
+
Remember: Stay focused on audio-related assistance and use your specialized tools to help users achieve their audio processing goals efficiently and effectively."""
|
| 60 |
|
| 61 |
class AudioAgent:
|
| 62 |
def __init__(
|
|
|
|
| 72 |
self._client = MultiServerMCPClient({
|
| 73 |
"audio-tools": {"url": self.server_url, "transport": "sse"}
|
| 74 |
})
|
| 75 |
+
|
| 76 |
+
self.agent = None
|
| 77 |
|
| 78 |
+
async def build_agent(self):
    """Create a ReAct agent wired to the MCP audio tools.

    Fetches the tool list from the MCP client, then builds the agent
    with the shared system prompt and structured-output schema.
    """
    mcp_tools = await self._client.get_tools()

    # NOTE(review): assumes get_tools() returns a usable (possibly empty)
    # list — confirm whether an empty tool set should be an error.
    return create_react_agent(
        model="gpt-4.1",
        tools=mcp_tools,
        prompt=system_prompt,
        response_format=AgentOutput,
    )
|
| 89 |
+
|
| 90 |
+
async def run_agent(self, user_input: str, input_audio_files: list[str]):
    """Run the agent on a user request, lazily building it on first use."""
    if self.agent is None:
        # Build once; reuse the same agent for subsequent calls.
        self.agent = await self.build_agent()

    # Assemble the same prompt text the original produced: a leading
    # newline, the request line, the files line, and a trailing newline.
    files_line = ', '.join(input_audio_files) if input_audio_files else 'None'
    input_context = f"\nUser Request: {user_input}\nInput Audio Files: {files_line}\n"

    return await self.agent.ainvoke(
        {"messages": [{"role": "user", "content": input_context}]}
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/nodes/agent.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
from langchain_mcp_adapters.client import MultiServerMCPClient
|
| 5 |
+
from langgraph.graph import StateGraph, END, START
|
| 6 |
+
|
| 7 |
+
from .state import AgentState, InputState, OutputState
|
| 8 |
+
from .chat import chat_node, chat_node_router
|
| 9 |
+
from .planner import planner_node
|
| 10 |
+
from .processor import processor_node
|
| 11 |
+
from .validator import validator_node, validator_node_router
|
| 12 |
+
|
| 13 |
+
class AudioAgent:
    """Audio-processing assistant built on a LangGraph workflow.

    Chat, planner, audio-processor, and validator nodes are wired into a
    state graph; audio tools are fetched from an MCP server during
    initialize() and bound into the processor node.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        server_url: str = "https://agents-mcp-hackathon-audioeditor.hf.space/gradio_api/mcp/sse",
    ):
        load_dotenv()
        self.model_name = model_name
        self.server_url = server_url
        self.graph = None
        # Fix: declare tools up front so _build_graph() sees an explicit
        # None instead of raising AttributeError when it is called before
        # initialize() has populated the tool list.
        self.tools = None

        self._client = MultiServerMCPClient({
            "audio-tools": {"url": self.server_url, "transport": "sse"}
        })

    @property
    def is_initialized(self) -> bool:
        # The compiled graph doubles as the "initialized" flag.
        return self.graph is not None

    async def _build_graph(self) -> None:
        """Build the LangGraph workflow."""
        _graph = StateGraph(
            AgentState,
            input=InputState,
            output=OutputState
        )

        # Entry node: routes either to the planner or straight to END.
        _graph.add_node("chat", chat_node)
        _graph.add_conditional_edges(
            "chat",
            chat_node_router,
            {
                "planner": "planner",
                "end": END
            }
        )

        _graph.add_node("planner", planner_node)
        _graph.add_edge("planner", "audio_processor")

        # Bind the MCP tools into the processor node at build time.
        processor_node_with_tools = partial(processor_node, tools=self.tools)
        _graph.add_node("audio_processor", processor_node_with_tools)
        # TODO: add validator edge to here
        _graph.add_edge("audio_processor", "chat")

        _graph.add_node("validator", validator_node)
        _graph.add_conditional_edges(
            "validator",
            validator_node_router,
            {
                "chat": "chat",
                "planner": "planner"
            }
        )

        _graph.add_edge(START, "chat")
        # NOTE(review): this unconditional chat -> END edge coexists with
        # the conditional edges from "chat" above — confirm the intended
        # routing out of the chat node before relying on it.
        _graph.add_edge("chat", END)
        self.graph = _graph.compile()

    async def initialize(self) -> None:
        """Initialize the LangGraph workflow with audio tools."""
        if self.is_initialized:
            return

        self.tools = await self._client.get_tools()
        if not self.tools:
            raise RuntimeError("No tools available from MCP server")

        await self._build_graph()

    def _extract_audio_paths(self, user_message: str) -> tuple[str, list[str]]:
        """Extract audio file paths from user message and return cleaned message.

        Lines of the form "Audio file: <path>" are collected as paths;
        every other line is kept in the cleaned message.
        """
        audio_files = []
        lines = user_message.split('\n')
        clean_lines = []

        for line in lines:
            if line.strip().startswith('Audio file:'):
                # Extract the file path
                audio_path = line.replace('Audio file:', '').strip()
                audio_files.append(audio_path)
            else:
                clean_lines.append(line)

        clean_message = '\n'.join(clean_lines).strip()
        return clean_message, audio_files

    async def chat(self, user_message: str):
        """Run one user turn through the graph and return the final state."""
        if not self.is_initialized:
            await self.initialize()

        # Extract audio file paths from the message
        clean_message, audio_files = self._extract_audio_paths(user_message)

        # Set up initial state
        initial_state = {
            "user_input": clean_message,
            "input_audio_files": audio_files,
            "steps_details": [],
            "plan": "",
            "final_response": "",
            "requires_processing": False,
            "validator_feedback": "",
            "output_audio_files": []
        }

        # Fix (comment was misleading): ainvoke runs the graph to
        # completion rather than streaming; stream_mode="values" only
        # selects the state-value representation of the result.
        return await self.graph.ainvoke(initial_state, stream_mode="values")

    def draw_graph(self) -> None:
        """Draw the graph to a file (graph.png in the working directory)."""
        graph_image = self.graph.get_graph().draw_mermaid_png()
        with open("graph.png", "wb") as f:
            f.write(graph_image)
|
src/nodes/chat.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from langchain_openai import ChatOpenAI
|
| 2 |
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
from langchain_core.runnables import RunnableParallel
|
| 4 |
-
from
|
| 5 |
from operator import itemgetter
|
| 6 |
|
| 7 |
def chat_node(state: ChatInputState) -> ChatOutputState:
|
|
|
|
| 1 |
from langchain_openai import ChatOpenAI
|
| 2 |
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
from langchain_core.runnables import RunnableParallel
|
| 4 |
+
from nodes.state import AgentState, ChatInputState, ChatOutputState
|
| 5 |
from operator import itemgetter
|
| 6 |
|
| 7 |
def chat_node(state: ChatInputState) -> ChatOutputState:
|
src/nodes/planner.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from langchain_openai import ChatOpenAI
|
| 2 |
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
from langchain_core.runnables import RunnableParallel
|
| 4 |
-
from
|
| 5 |
from operator import itemgetter
|
| 6 |
|
| 7 |
def planner_node(state: PlannerInputState) -> PlannerOutputState:
|
|
|
|
| 1 |
from langchain_openai import ChatOpenAI
|
| 2 |
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
from langchain_core.runnables import RunnableParallel
|
| 4 |
+
from nodes.state import AgentState, PlannerInputState, PlannerOutputState
|
| 5 |
from operator import itemgetter
|
| 6 |
|
| 7 |
def planner_node(state: PlannerInputState) -> PlannerOutputState:
|
src/nodes/processor.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from
|
| 2 |
from langgraph.prebuilt import create_react_agent
|
| 3 |
from pydantic import BaseModel, Field
|
| 4 |
|
|
|
|
| 1 |
+
from nodes.state import ProcessorInputState, ProcessorOutputState
|
| 2 |
from langgraph.prebuilt import create_react_agent
|
| 3 |
from pydantic import BaseModel, Field
|
| 4 |
|
src/{state.py → nodes/state.py}
RENAMED
|
File without changes
|
src/nodes/validator.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from langchain_openai import ChatOpenAI
|
| 2 |
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
-
from
|
| 4 |
from operator import itemgetter
|
| 5 |
from langchain_core.runnables import RunnableParallel
|
| 6 |
|
|
|
|
| 1 |
from langchain_openai import ChatOpenAI
|
| 2 |
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
+
from nodes.state import AgentState, ValidatorInputState, ValidatorOutputState
|
| 4 |
from operator import itemgetter
|
| 5 |
from langchain_core.runnables import RunnableParallel
|
| 6 |
|
src/ui.py
CHANGED
|
@@ -33,31 +33,42 @@ def user_input(user_message, audio_files, history):
|
|
| 33 |
|
| 34 |
audio_file_urls.append(get_share_url(file_path))
|
| 35 |
|
|
|
|
| 36 |
if audio_file_urls:
|
| 37 |
-
audio_list = "\n".join([f"
|
| 38 |
-
combined_message = f"{user_message}\n\n{audio_list}" if user_message.strip() else audio_list
|
| 39 |
else:
|
| 40 |
combined_message = user_message
|
| 41 |
|
| 42 |
history.append({"role": "user", "content": combined_message})
|
| 43 |
-
return "", [], history
|
| 44 |
|
| 45 |
-
async def bot_response(history):
|
| 46 |
"""
|
| 47 |
-
Generate bot response using the
|
| 48 |
"""
|
| 49 |
if not history or history[-1]["role"] != "user":
|
| 50 |
return history
|
| 51 |
|
|
|
|
| 52 |
user_message = history[-1]["content"]
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
#
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Extract the final response and audio files from the result
|
| 63 |
final_response = result.get("final_response", "")
|
|
@@ -80,14 +91,14 @@ async def bot_response(history):
|
|
| 80 |
|
| 81 |
return history
|
| 82 |
|
| 83 |
-
def bot_response_sync(history):
|
| 84 |
"""
|
| 85 |
Synchronous wrapper for the async bot response
|
| 86 |
"""
|
| 87 |
loop = asyncio.new_event_loop()
|
| 88 |
asyncio.set_event_loop(loop)
|
| 89 |
try:
|
| 90 |
-
return loop.run_until_complete(bot_response(history))
|
| 91 |
finally:
|
| 92 |
loop.close()
|
| 93 |
|
|
@@ -110,6 +121,9 @@ def create_interface():
|
|
| 110 |
**Supported formats**: MP3, WAV, M4A, FLAC, AAC, OGG
|
| 111 |
""")
|
| 112 |
|
|
|
|
|
|
|
|
|
|
| 113 |
with gr.Row():
|
| 114 |
with gr.Column(scale=2):
|
| 115 |
chatbot = gr.Chatbot(
|
|
@@ -184,35 +198,39 @@ def create_interface():
|
|
| 184 |
|
| 185 |
# Handle user input and bot response
|
| 186 |
def handle_submit(message, files, history):
|
| 187 |
-
|
|
|
|
| 188 |
|
| 189 |
msg.submit(
|
| 190 |
handle_submit,
|
| 191 |
[msg, audio_files, chatbot],
|
| 192 |
-
[msg, audio_files, chatbot],
|
| 193 |
queue=False
|
| 194 |
).then(
|
| 195 |
bot_response_sync,
|
| 196 |
-
chatbot,
|
| 197 |
chatbot
|
| 198 |
)
|
| 199 |
|
| 200 |
send_btn.click(
|
| 201 |
handle_submit,
|
| 202 |
[msg, audio_files, chatbot],
|
| 203 |
-
[msg, audio_files, chatbot],
|
| 204 |
queue=False
|
| 205 |
).then(
|
| 206 |
bot_response_sync,
|
| 207 |
-
chatbot,
|
| 208 |
chatbot
|
| 209 |
)
|
| 210 |
|
| 211 |
# Clear chat
|
|
|
|
|
|
|
|
|
|
| 212 |
clear_btn.click(
|
| 213 |
-
|
| 214 |
None,
|
| 215 |
-
[chatbot, audio_files],
|
| 216 |
queue=False
|
| 217 |
)
|
| 218 |
|
|
|
|
| 33 |
|
| 34 |
audio_file_urls.append(get_share_url(file_path))
|
| 35 |
|
| 36 |
+
# For display purposes, show what audio files were uploaded
|
| 37 |
if audio_file_urls:
|
| 38 |
+
audio_list = "\n".join([f"🎵 Uploaded: {url.split('/')[-1]}" for url in audio_file_urls])
|
| 39 |
+
combined_message = f"{user_message}\n\n{audio_list}" if user_message.strip() else f"Process uploaded audio files:\n{audio_list}"
|
| 40 |
else:
|
| 41 |
combined_message = user_message
|
| 42 |
|
| 43 |
history.append({"role": "user", "content": combined_message})
|
| 44 |
+
return "", [], history, audio_file_urls
|
| 45 |
|
| 46 |
+
async def bot_response(history, audio_file_urls):
|
| 47 |
"""
|
| 48 |
+
Generate bot response using the test agent
|
| 49 |
"""
|
| 50 |
if not history or history[-1]["role"] != "user":
|
| 51 |
return history
|
| 52 |
|
| 53 |
+
# Get the actual user message (without the audio file display text)
|
| 54 |
user_message = history[-1]["content"]
|
| 55 |
|
| 56 |
+
# Clean the user message by removing the uploaded file display text
|
| 57 |
+
if "🎵 Uploaded:" in user_message:
|
| 58 |
+
lines = user_message.split('\n')
|
| 59 |
+
clean_lines = []
|
| 60 |
+
for line in lines:
|
| 61 |
+
if not line.strip().startswith('🎵 Uploaded:'):
|
| 62 |
+
clean_lines.append(line)
|
| 63 |
+
user_message = '\n'.join(clean_lines).strip()
|
| 64 |
|
| 65 |
+
# If message is empty after cleaning, provide default message
|
| 66 |
+
if not user_message:
|
| 67 |
+
user_message = "Please process these audio files"
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
# Use the test agent's run_agent method with separate parameters
|
| 71 |
+
result = await agent.run_agent(user_message, audio_file_urls or [])
|
| 72 |
|
| 73 |
# Extract the final response and audio files from the result
|
| 74 |
final_response = result.get("final_response", "")
|
|
|
|
| 91 |
|
| 92 |
return history
|
| 93 |
|
| 94 |
+
def bot_response_sync(history, audio_file_urls):
    """
    Synchronous wrapper for the async bot response
    """
    # asyncio.run creates, runs, and reliably cleans up a fresh event loop
    # (including async-generator shutdown), replacing the manual
    # new_event_loop / set_event_loop / close dance, which also left the
    # thread's event-loop slot pointing at a closed loop.
    return asyncio.run(bot_response(history, audio_file_urls))
|
| 104 |
|
|
|
|
| 121 |
**Supported formats**: MP3, WAV, M4A, FLAC, AAC, OGG
|
| 122 |
""")
|
| 123 |
|
| 124 |
+
# Hidden state to store audio file URLs
|
| 125 |
+
audio_urls_state = gr.State([])
|
| 126 |
+
|
| 127 |
with gr.Row():
|
| 128 |
with gr.Column(scale=2):
|
| 129 |
chatbot = gr.Chatbot(
|
|
|
|
| 198 |
|
| 199 |
# Handle user input and bot response
|
| 200 |
def handle_submit(message, files, history):
    # user_input already yields the (msg, files, history, audio_urls)
    # 4-tuple Gradio expects, so pass its result straight through.
    return user_input(message, files, history)
|
| 203 |
|
| 204 |
msg.submit(
|
| 205 |
handle_submit,
|
| 206 |
[msg, audio_files, chatbot],
|
| 207 |
+
[msg, audio_files, chatbot, audio_urls_state],
|
| 208 |
queue=False
|
| 209 |
).then(
|
| 210 |
bot_response_sync,
|
| 211 |
+
[chatbot, audio_urls_state],
|
| 212 |
chatbot
|
| 213 |
)
|
| 214 |
|
| 215 |
send_btn.click(
|
| 216 |
handle_submit,
|
| 217 |
[msg, audio_files, chatbot],
|
| 218 |
+
[msg, audio_files, chatbot, audio_urls_state],
|
| 219 |
queue=False
|
| 220 |
).then(
|
| 221 |
bot_response_sync,
|
| 222 |
+
[chatbot, audio_urls_state],
|
| 223 |
chatbot
|
| 224 |
)
|
| 225 |
|
| 226 |
# Clear chat
|
| 227 |
+
def clear_chat():
    """Reset chat history, uploaded files, and stored audio URLs."""
    emptied = ([], [], [])
    return emptied
|
| 229 |
+
|
| 230 |
clear_btn.click(
|
| 231 |
+
clear_chat,
|
| 232 |
None,
|
| 233 |
+
[chatbot, audio_files, audio_urls_state],
|
| 234 |
queue=False
|
| 235 |
)
|
| 236 |
|