YigitSekerci commited on
Commit
5a5e484
·
1 Parent(s): 23c0e5d

add llm controls

Browse files
Files changed (2) hide show
  1. src/agent.py +16 -2
  2. src/ui.py +79 -36
src/agent.py CHANGED
@@ -2,6 +2,7 @@ from langgraph.prebuilt import create_react_agent
2
  from pydantic import BaseModel, Field
3
  from dotenv import load_dotenv
4
  from langchain_mcp_adapters.client import MultiServerMCPClient
 
5
  import os
6
 
7
  class AgentOutput(BaseModel):
@@ -73,10 +74,14 @@ Output Audio Files: {output_audio_files}
73
  class AudioAgent:
74
  def __init__(
75
  self,
76
- model_name: str = "gpt-4.1-mini",
 
 
77
  ):
78
  load_dotenv()
79
  self.model_name = model_name
 
 
80
  self.server_url = os.getenv("MCP_SERVER")
81
  self.graph = None
82
 
@@ -87,10 +92,19 @@ class AudioAgent:
87
  self.agent = None
88
 
89
  async def build_agent(self):
 
 
 
90
  tools = await self._client.get_tools()
91
 
92
- agent = create_react_agent(
93
  model=self.model_name,
 
 
 
 
 
 
94
  tools=tools,
95
  prompt=system_prompt,
96
  response_format=AgentOutput,
 
2
  from pydantic import BaseModel, Field
3
  from dotenv import load_dotenv
4
  from langchain_mcp_adapters.client import MultiServerMCPClient
5
+ from langchain_openai import ChatOpenAI
6
  import os
7
 
8
  class AgentOutput(BaseModel):
 
74
  class AudioAgent:
75
  def __init__(
76
  self,
77
+ model_name: str = "gpt-4.1",
78
+ temperature: float = 0.3,
79
+ api_key: str = None,
80
  ):
81
  load_dotenv()
82
  self.model_name = model_name
83
+ self.temperature = temperature
84
+ self.api_key = api_key # or os.getenv("OPENAI_API_KEY")
85
  self.server_url = os.getenv("MCP_SERVER")
86
  self.graph = None
87
 
 
92
  self.agent = None
93
 
94
  async def build_agent(self):
95
+ if not self.api_key:
96
+ raise ValueError("OpenAI API key is required")
97
+
98
  tools = await self._client.get_tools()
99
 
100
+ llm = ChatOpenAI(
101
  model=self.model_name,
102
+ temperature=self.temperature,
103
+ api_key=self.api_key
104
+ )
105
+
106
+ agent = create_react_agent(
107
+ model=llm,
108
  tools=tools,
109
  prompt=system_prompt,
110
  response_format=AgentOutput,
src/ui.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  from .agent import AudioAgent
4
 
5
  # Global agent instance
6
- agent = AudioAgent()
7
 
8
  # Global demo instance
9
  demo = None
@@ -14,10 +14,28 @@ def get_share_url(path):
14
  return path
15
  return f"{demo.share_url}/gradio_api/file={path}"
16
 
17
- def user_input(user_message, audio_files, history, custom_history):
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  """
19
  Handle user input with text and audio files
20
  """
 
 
 
 
 
21
  if not user_message.strip() and not audio_files:
22
  return "", audio_files, history, custom_history
23
 
@@ -46,12 +64,15 @@ def user_input(user_message, audio_files, history, custom_history):
46
  "input_files": audio_file_urls
47
  })
48
 
49
- return "", audio_files, history, audio_file_urls, custom_history
50
 
51
  async def bot_response(history, audio_file_urls, custom_history):
52
  """
53
  Generate bot response using the agent
54
  """
 
 
 
55
  if not history or history[-1]["role"] != "user":
56
  return history, []
57
 
@@ -65,7 +86,7 @@ async def bot_response(history, audio_file_urls, custom_history):
65
 
66
  try:
67
  # Use the agent's run_agent method with history
68
- result = await agent.run_agent(user_message, input_files, custom_history)
69
 
70
  # Extract the final response and audio files from the result
71
  final_response = result.final_response
@@ -87,16 +108,7 @@ async def bot_response(history, audio_file_urls, custom_history):
87
  return history, output_audio_files
88
 
89
  except Exception as e:
90
- history.append({
91
- "role": "assistant",
92
- "content": f"❌ **Error**: {e}",
93
- })
94
- custom_history.append({
95
- "role": "assistant",
96
- "content": f"❌ **Error**: {e}",
97
- "output_files": []
98
- })
99
- return history, []
100
 
101
  def bot_response_sync(history, audio_file_urls, custom_history):
102
  """
@@ -122,7 +134,7 @@ def create_interface():
122
  # Hidden state to store audio file URLs and custom history
123
  audio_urls_state = gr.State([])
124
  custom_history_state = gr.State([])
125
-
126
  with gr.Row():
127
  with gr.Column(scale=2):
128
  chatbot = gr.Chatbot(
@@ -133,42 +145,73 @@ def create_interface():
133
  )
134
 
135
  with gr.Column(scale=1):
136
- audio_files = gr.File(
137
- file_count="multiple",
138
- file_types=["audio"],
139
- label="Upload Audio Files to Process",
140
- height=150
141
- )
142
- output_audio_files = gr.File(
143
- file_count="multiple",
144
- file_types=["audio"],
145
- label="Download Generated Audio",
146
- interactive=False,
147
- height=150
148
- )
149
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  with gr.Row(equal_height=True):
151
  msg = gr.Textbox(
152
  label="Describe what you want to do?",
153
  placeholder="e.g., 'Remove filler words and improve audio quality''",
154
  lines=3,
155
- scale=4
156
  )
157
  send_btn = gr.Button("Ask", variant="primary", scale=1, size="lg")
158
 
 
 
 
159
  # Handle user input and bot response
160
- def handle_submit(message, files, history, custom_history):
161
- new_msg, new_files, updated_history, audio_urls, updated_custom_history = user_input(message, files, history, custom_history)
162
- return new_msg, new_files, updated_history, audio_urls, updated_custom_history
 
 
163
 
164
  def handle_bot_response(history, audio_urls, custom_history):
165
  updated_history, output_files = bot_response_sync(history, audio_urls, custom_history)
166
- return updated_history, output_files, custom_history
 
167
 
168
  msg.submit(
169
  handle_submit,
 
170
  [msg, audio_files, chatbot, custom_history_state],
171
- [msg, audio_files, chatbot, audio_urls_state, custom_history_state],
172
  queue=False
173
  ).then(
174
  handle_bot_response,
@@ -178,8 +221,8 @@ def create_interface():
178
 
179
  send_btn.click(
180
  handle_submit,
 
181
  [msg, audio_files, chatbot, custom_history_state],
182
- [msg, audio_files, chatbot, audio_urls_state, custom_history_state],
183
  queue=False
184
  ).then(
185
  handle_bot_response,
 
3
  from .agent import AudioAgent
4
 
5
  # Global agent instance
6
+ agent = None
7
 
8
  # Global demo instance
9
  demo = None
 
14
  return path
15
  return f"{demo.share_url}/gradio_api/file={path}"
16
 
17
+ def update_agent(model_name, temperature, api_key):
18
+ """Update the agent with new configuration"""
19
+ global agent
20
+ try:
21
+ agent = AudioAgent(
22
+ model_name=model_name,
23
+ temperature=float(temperature),
24
+ api_key=api_key
25
+ )
26
+ return True, None
27
+ except Exception as e:
28
+ return False, str(e)
29
+
30
+ def user_input(user_message, audio_files, history, custom_history, model_name, temperature, api_key):
31
  """
32
  Handle user input with text and audio files
33
  """
34
+ # Try to update agent configuration
35
+ success, error = update_agent(model_name, temperature, api_key)
36
+ if not success:
37
+ raise gr.Error(error)
38
+
39
  if not user_message.strip() and not audio_files:
40
  return "", audio_files, history, custom_history
41
 
 
64
  "input_files": audio_file_urls
65
  })
66
 
67
+ return "", audio_files, history, custom_history
68
 
69
  async def bot_response(history, audio_file_urls, custom_history):
70
  """
71
  Generate bot response using the agent
72
  """
73
+ if not agent:
74
+ raise gr.Error("Please configure the agent first")
75
+
76
  if not history or history[-1]["role"] != "user":
77
  return history, []
78
 
 
86
 
87
  try:
88
  # Use the agent's run_agent method with history
89
+ result = await agent.run_agent(user_message, input_files, custom_history[:-1])
90
 
91
  # Extract the final response and audio files from the result
92
  final_response = result.final_response
 
108
  return history, output_audio_files
109
 
110
  except Exception as e:
111
+ raise gr.Error(str(e))
 
 
 
 
 
 
 
 
 
112
 
113
  def bot_response_sync(history, audio_file_urls, custom_history):
114
  """
 
134
  # Hidden state to store audio file URLs and custom history
135
  audio_urls_state = gr.State([])
136
  custom_history_state = gr.State([])
137
+
138
  with gr.Row():
139
  with gr.Column(scale=2):
140
  chatbot = gr.Chatbot(
 
145
  )
146
 
147
  with gr.Column(scale=1):
148
+ # Model Configuration
149
+ with gr.Group():
150
+ model_name = gr.Dropdown(
151
+ choices=["gpt-4.1", "gpt-4.1-mini", "gpt-4o", "o3"],
152
+ value="gpt-4.1",
153
+ label="Model",
154
+ info="Select the model to use"
155
+ )
156
+ temperature = gr.Slider(
157
+ minimum=0.0,
158
+ maximum=1.0,
159
+ value=0.3,
160
+ step=0.1,
161
+ label="Temperature",
162
+ info="Higher values make output more random"
163
+ )
164
+ api_key = gr.Textbox(
165
+ label="OpenAI API Key",
166
+ placeholder="sk-...",
167
+ type="password",
168
+ info="Your OpenAI API key"
169
+ )
170
+
171
+ with gr.Group():
172
+ audio_files = gr.File(
173
+ file_count="multiple",
174
+ file_types=["audio"],
175
+ label="Upload Audio Files to Process",
176
+ height=150
177
+ )
178
+ output_audio_files = gr.File(
179
+ file_count="multiple",
180
+ file_types=["audio"],
181
+ label="Download Generated Audio",
182
+ height=150,
183
+ interactive=False,
184
+ visible=False # Start hidden
185
+ )
186
+
187
  with gr.Row(equal_height=True):
188
  msg = gr.Textbox(
189
  label="Describe what you want to do?",
190
  placeholder="e.g., 'Remove filler words and improve audio quality''",
191
  lines=3,
192
+ scale=6
193
  )
194
  send_btn = gr.Button("Ask", variant="primary", scale=1, size="lg")
195
 
196
+ # Error message component
197
+ error_msg = gr.Textbox(label="Error", visible=False)
198
+
199
  # Handle user input and bot response
200
+ def handle_submit(message, files, history, custom_history, model, temp, key):
201
+ new_msg, new_files, updated_history, updated_custom_history = user_input(
202
+ message, files, history, custom_history, model, temp, key
203
+ )
204
+ return new_msg, new_files, updated_history, updated_custom_history
205
 
206
  def handle_bot_response(history, audio_urls, custom_history):
207
  updated_history, output_files = bot_response_sync(history, audio_urls, custom_history)
208
+ output_visible = bool(output_files) # True if there are files, else False
209
+ return updated_history, gr.update(value=output_files, visible=output_visible), custom_history
210
 
211
  msg.submit(
212
  handle_submit,
213
+ [msg, audio_files, chatbot, custom_history_state, model_name, temperature, api_key],
214
  [msg, audio_files, chatbot, custom_history_state],
 
215
  queue=False
216
  ).then(
217
  handle_bot_response,
 
221
 
222
  send_btn.click(
223
  handle_submit,
224
+ [msg, audio_files, chatbot, custom_history_state, model_name, temperature, api_key],
225
  [msg, audio_files, chatbot, custom_history_state],
 
226
  queue=False
227
  ).then(
228
  handle_bot_response,