jhansss committed on
Commit
4067f95
·
1 Parent(s): ff8bce5

Refactor run_pipeline and update_metrics methods to support inference on HF ZeroGPU

Browse files
Files changed (1) hide show
  1. interface.py +32 -28
interface.py CHANGED
@@ -1,5 +1,6 @@
1
  import time
2
  import uuid
 
3
 
4
  import gradio as gr
5
  import spaces
@@ -9,6 +10,34 @@ from characters import CHARACTERS
9
  from pipeline import SingingDialoguePipeline
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class GradioInterface:
13
  def __init__(self, options_config: str, default_config: str):
14
  self.options = self.load_config(options_config)
@@ -148,12 +177,12 @@ class GradioInterface:
148
  fn=self.update_voice, inputs=voice_radio, outputs=voice_radio
149
  )
150
  mic_input.change(
151
- fn=self.run_pipeline,
152
  inputs=mic_input,
153
  outputs=[interaction_log, audio_output],
154
  )
155
  metrics_button.click(
156
- fn=self.update_metrics,
157
  inputs=audio_output,
158
  outputs=[metrics_output],
159
  )
@@ -161,6 +190,7 @@ class GradioInterface:
161
  return demo
162
  except Exception as e:
163
  import traceback
 
164
  print(traceback.format_exc())
165
  return gr.Blocks()
166
 
@@ -212,29 +242,3 @@ class GradioInterface:
212
  def update_voice(self, voice):
213
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][voice]
214
  return gr.update(value=voice)
215
-
216
- @spaces.GPU
217
- def run_pipeline(self, audio_path):
218
- if not audio_path:
219
- return gr.update(value=None), gr.update(value=None)
220
- tmp_file = f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
221
- self.results = self.pipeline.run(
222
- audio_path,
223
- self.svs_model_map[self.current_svs_model]["lang"],
224
- self.character_info[self.current_character].prompt,
225
- self.current_voice,
226
- output_audio_path=tmp_file,
227
- )
228
- formatted_logs = f"ASR: {self.results['asr_text']}\nLLM: {self.results['llm_text']}"
229
- return gr.update(value=formatted_logs), gr.update(
230
- value=self.results["output_audio_path"]
231
- )
232
-
233
- @spaces.GPU
234
- def update_metrics(self, audio_path):
235
- if not audio_path or not self.results:
236
- return gr.update(value="")
237
- results = self.pipeline.evaluate(audio_path, **self.results)
238
- results.update(self.results.get("metrics", {}))
239
- formatted_metrics = "\n".join([f"{k}: {v}" for k, v in results.items()])
240
- return gr.update(value=formatted_metrics)
 
1
  import time
2
  import uuid
3
+ from functools import partial
4
 
5
  import gradio as gr
6
  import spaces
 
10
  from pipeline import SingingDialoguePipeline
11
 
12
 
13
@spaces.GPU(duration=120)
def run_pipeline(audio_path, interface):
    """Run the singing-dialogue pipeline on a recording and build UI updates.

    Args:
        audio_path: Path of the recorded input audio. A falsy value (no
            recording) short-circuits and clears both outputs.
        interface: Object exposing ``pipeline``, ``svs_model_map``,
            ``current_svs_model``, ``character_info``, ``current_character``
            and ``current_voice``; the latest pipeline output is stored on it
            as ``results``.

    Returns:
        A 2-tuple of ``gr.update`` objects: the formatted ASR/LLM log text
        and the path of the generated audio.
    """
    if not audio_path:
        return gr.update(value=None), gr.update(value=None)

    # Timestamp plus a short random suffix keeps output file names unique.
    tmp_file = f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
    results = interface.pipeline.run(
        audio_path,
        interface.svs_model_map[interface.current_svs_model]["lang"],
        interface.character_info[interface.current_character].prompt,
        interface.current_voice,
        output_audio_path=tmp_file,
    )

    # update_metrics reads interface.results, so the latest run must be kept
    # there; without this assignment the metrics step never sees this output.
    # NOTE(review): if this function executes in a separate worker process
    # under ZeroGPU, the assignment may not reach the caller's object --
    # confirm, and consider passing results through gr.State instead.
    interface.results = results

    formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
    return gr.update(value=formatted_logs), gr.update(
        value=results["output_audio_path"]
    )
29
+
30
+
31
@spaces.GPU(duration=120)
def update_metrics(audio_path, interface):
    """Evaluate the given audio and return the metrics as a text update.

    Args:
        audio_path: Path of the audio to evaluate.
        interface: Object exposing ``pipeline`` and ``results`` (the stored
            output of the most recent pipeline run).

    Returns:
        A ``gr.update`` whose value is one ``"name: value"`` line per metric,
        or an empty string when there is no audio or no stored results.
    """
    # Nothing to evaluate unless both the audio and a previous run exist.
    if not (audio_path and interface.results):
        return gr.update(value="")

    metrics = interface.pipeline.evaluate(audio_path, **interface.results)
    # Entries under the stored "metrics" key override the evaluated ones.
    metrics.update(interface.results.get("metrics", {}))

    lines = [f"{name}: {score}" for name, score in metrics.items()]
    return gr.update(value="\n".join(lines))
39
+
40
+
41
  class GradioInterface:
42
  def __init__(self, options_config: str, default_config: str):
43
  self.options = self.load_config(options_config)
 
177
  fn=self.update_voice, inputs=voice_radio, outputs=voice_radio
178
  )
179
  mic_input.change(
180
+ fn=partial(run_pipeline, interface=self),
181
  inputs=mic_input,
182
  outputs=[interaction_log, audio_output],
183
  )
184
  metrics_button.click(
185
+ fn=partial(update_metrics, interface=self),
186
  inputs=audio_output,
187
  outputs=[metrics_output],
188
  )
 
190
  return demo
191
  except Exception as e:
192
  import traceback
193
+
194
  print(traceback.format_exc())
195
  return gr.Blocks()
196
 
 
242
  def update_voice(self, voice):
243
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][voice]
244
  return gr.update(value=voice)