Siratish committed on
Commit
46224e4
·
1 Parent(s): 91dde72

improve performance

Browse files
Files changed (1) hide show
  1. app.py +41 -23
app.py CHANGED
@@ -75,6 +75,13 @@ cache_schedule = {
75
  'ff': presets[default_preset]['ff'][:]
76
  }
77
 
 
 
 
 
 
 
 
78
  seed = np.random.randint(0, 2**31 - 1)
79
  torch.manual_seed(seed)
80
 
@@ -170,7 +177,6 @@ def load_default():
170
  return render_grid(cache_schedule), default_preset
171
 
172
 
173
- @lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
174
  @gpu_decorator
175
  def infer(
176
  ref_audio_orig,
@@ -178,32 +184,36 @@ def infer(
178
  gen_text,
179
  nfe_step=32,
180
  ):
181
- global cache_schedule
182
  show_info = gr.Info
183
  if not ref_audio_orig:
184
  gr.Warning("Please provide reference audio.")
185
- return gr.update(), gr.update(), ref_text
186
 
187
  if not gen_text.strip():
188
- gr.Warning("Please enter text to generate or upload a text file.")
189
- return gr.update(), gr.update(), ref_text
190
 
191
  ref_audio, ref_text = preprocess_ref_audio_text(
192
  ref_audio_orig, ref_text, show_info=show_info)
193
- start_time = time.time()
194
- final_wave, final_sample_rate, _ = infer_process(
195
- ref_audio,
196
- ref_text,
197
- gen_text,
198
- ema_model,
199
- vocoder,
200
- cross_fade_duration=cross_fade_duration,
201
- nfe_step=nfe_step,
202
- speed=speed,
203
- show_info=show_info,
204
- progress=gr.Progress(),
205
- )
206
- process_time = time.time() - start_time
 
 
 
 
207
  cache_helper = SmoothCacheHelper(
208
  model=ema_model.transformer,
209
  block_classes=get_class("f5_tts.model.modules.DiTBlock"),
@@ -227,13 +237,21 @@ def infer(
227
  process_time_cache = time.time() - start_time
228
  cache_helper.disable()
229
 
230
- return (final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache), process_time, process_time_cache
 
 
 
 
 
 
 
 
231
 
232
 
233
  with gr.Blocks() as demo:
234
  gr.Markdown("## F5-TTS + SmoothCache")
235
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
236
- ref_text_input = gr.Textbox(label="Reference Text")
237
  gen_text_input = gr.Textbox(label="Text to Generate")
238
  with gr.Row():
239
  with gr.Column(scale=0):
@@ -260,12 +278,12 @@ with gr.Blocks() as demo:
260
  preset_dropdown.change(
261
  fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
262
  image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
263
- nfe_slider.input(fn=update_nfe, inputs=nfe_slider,
264
  outputs=[image, preset_dropdown])
265
  generate_btn.click(
266
  infer,
267
  inputs=[ref_audio_input, ref_text_input, gen_text_input, nfe_slider],
268
- outputs=[audio_output, audio_output_cache,
269
  process_time, process_time_cache],
270
  )
271
  demo.load(fn=load_default, outputs=[image, preset_dropdown])
 
75
  'ff': presets[default_preset]['ff'][:]
76
  }
77
 
78
+ recent_input = {
79
+ "ref_audio": None,
80
+ "ref_text": None,
81
+ "gen_text": None,
82
+ "nfe_step": None
83
+ }
84
+
85
  seed = np.random.randint(0, 2**31 - 1)
86
  torch.manual_seed(seed)
87
 
 
177
  return render_grid(cache_schedule), default_preset
178
 
179
 
 
180
  @gpu_decorator
181
  def infer(
182
  ref_audio_orig,
 
184
  gen_text,
185
  nfe_step=32,
186
  ):
187
+ global cache_schedule, recent_input
188
  show_info = gr.Info
189
  if not ref_audio_orig:
190
  gr.Warning("Please provide reference audio.")
191
+ return gr.update(), gr.update(), ref_text, gr.update(), gr.update()
192
 
193
  if not gen_text.strip():
194
+ gr.Warning("Please enter text to generate.")
195
+ return gr.update(), gr.update(), ref_text, gr.update(), gr.update()
196
 
197
  ref_audio, ref_text = preprocess_ref_audio_text(
198
  ref_audio_orig, ref_text, show_info=show_info)
199
+ skip_no_cache = False
200
+ if recent_input["ref_audio"] == ref_audio_orig and recent_input["ref_text"] == ref_text and recent_input["gen_text"] == gen_text and recent_input["nfe_step"] == nfe_step:
201
+ skip_no_cache = True
202
+ if not skip_no_cache:
203
+ start_time = time.time()
204
+ final_wave, final_sample_rate, _ = infer_process(
205
+ ref_audio,
206
+ ref_text,
207
+ gen_text,
208
+ ema_model,
209
+ vocoder,
210
+ cross_fade_duration=cross_fade_duration,
211
+ nfe_step=nfe_step,
212
+ speed=speed,
213
+ show_info=show_info,
214
+ progress=gr.Progress(),
215
+ )
216
+ process_time = time.time() - start_time
217
  cache_helper = SmoothCacheHelper(
218
  model=ema_model.transformer,
219
  block_classes=get_class("f5_tts.model.modules.DiTBlock"),
 
237
  process_time_cache = time.time() - start_time
238
  cache_helper.disable()
239
 
240
+ recent_input["ref_audio"] = ref_audio_orig
241
+ recent_input["ref_text"] = ref_text
242
+ recent_input["gen_text"] = gen_text
243
+ recent_input["nfe_step"] = nfe_step
244
+
245
+ if skip_no_cache:
246
+ print("skip")
247
+ return gr.update(), (final_sample_rate_cache, final_wave_cache), ref_text, gr.update(), process_time_cache
248
+ return (final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache), ref_text, process_time, process_time_cache
249
 
250
 
251
  with gr.Blocks() as demo:
252
  gr.Markdown("## F5-TTS + SmoothCache")
253
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
254
+ ref_text_input = gr.Textbox(label="Reference Text (Optional)")
255
  gen_text_input = gr.Textbox(label="Text to Generate")
256
  with gr.Row():
257
  with gr.Column(scale=0):
 
278
  preset_dropdown.change(
279
  fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
280
  image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
281
+ nfe_slider.release(fn=update_nfe, inputs=nfe_slider,
282
  outputs=[image, preset_dropdown])
283
  generate_btn.click(
284
  infer,
285
  inputs=[ref_audio_input, ref_text_input, gen_text_input, nfe_slider],
286
+ outputs=[audio_output, audio_output_cache, ref_text_input,
287
  process_time, process_time_cache],
288
  )
289
  demo.load(fn=load_default, outputs=[image, preset_dropdown])