Spaces:
Sleeping
Sleeping
improve performance
Browse files
app.py
CHANGED
|
@@ -75,6 +75,13 @@ cache_schedule = {
|
|
| 75 |
'ff': presets[default_preset]['ff'][:]
|
| 76 |
}
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
seed = np.random.randint(0, 2**31 - 1)
|
| 79 |
torch.manual_seed(seed)
|
| 80 |
|
|
@@ -170,7 +177,6 @@ def load_default():
|
|
| 170 |
return render_grid(cache_schedule), default_preset
|
| 171 |
|
| 172 |
|
| 173 |
-
@lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
|
| 174 |
@gpu_decorator
|
| 175 |
def infer(
|
| 176 |
ref_audio_orig,
|
|
@@ -178,32 +184,36 @@ def infer(
|
|
| 178 |
gen_text,
|
| 179 |
nfe_step=32,
|
| 180 |
):
|
| 181 |
-
global cache_schedule
|
| 182 |
show_info = gr.Info
|
| 183 |
if not ref_audio_orig:
|
| 184 |
gr.Warning("Please provide reference audio.")
|
| 185 |
-
return gr.update(), gr.update(), ref_text
|
| 186 |
|
| 187 |
if not gen_text.strip():
|
| 188 |
-
gr.Warning("Please enter text to generate
|
| 189 |
-
return gr.update(), gr.update(), ref_text
|
| 190 |
|
| 191 |
ref_audio, ref_text = preprocess_ref_audio_text(
|
| 192 |
ref_audio_orig, ref_text, show_info=show_info)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
cache_helper = SmoothCacheHelper(
|
| 208 |
model=ema_model.transformer,
|
| 209 |
block_classes=get_class("f5_tts.model.modules.DiTBlock"),
|
|
@@ -227,13 +237,21 @@ def infer(
|
|
| 227 |
process_time_cache = time.time() - start_time
|
| 228 |
cache_helper.disable()
|
| 229 |
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
|
| 233 |
with gr.Blocks() as demo:
|
| 234 |
gr.Markdown("## F5-TTS + SmoothCache")
|
| 235 |
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
| 236 |
-
ref_text_input = gr.Textbox(label="Reference Text")
|
| 237 |
gen_text_input = gr.Textbox(label="Text to Generate")
|
| 238 |
with gr.Row():
|
| 239 |
with gr.Column(scale=0):
|
|
@@ -260,12 +278,12 @@ with gr.Blocks() as demo:
|
|
| 260 |
preset_dropdown.change(
|
| 261 |
fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
|
| 262 |
image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
|
| 263 |
-
nfe_slider.
|
| 264 |
outputs=[image, preset_dropdown])
|
| 265 |
generate_btn.click(
|
| 266 |
infer,
|
| 267 |
inputs=[ref_audio_input, ref_text_input, gen_text_input, nfe_slider],
|
| 268 |
-
outputs=[audio_output, audio_output_cache,
|
| 269 |
process_time, process_time_cache],
|
| 270 |
)
|
| 271 |
demo.load(fn=load_default, outputs=[image, preset_dropdown])
|
|
|
|
| 75 |
'ff': presets[default_preset]['ff'][:]
|
| 76 |
}
|
| 77 |
|
| 78 |
+
recent_input = {
|
| 79 |
+
"ref_audio": None,
|
| 80 |
+
"ref_text": None,
|
| 81 |
+
"gen_text": None,
|
| 82 |
+
"nfe_step": None
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
seed = np.random.randint(0, 2**31 - 1)
|
| 86 |
torch.manual_seed(seed)
|
| 87 |
|
|
|
|
| 177 |
return render_grid(cache_schedule), default_preset
|
| 178 |
|
| 179 |
|
|
|
|
| 180 |
@gpu_decorator
|
| 181 |
def infer(
|
| 182 |
ref_audio_orig,
|
|
|
|
| 184 |
gen_text,
|
| 185 |
nfe_step=32,
|
| 186 |
):
|
| 187 |
+
global cache_schedule, recent_input
|
| 188 |
show_info = gr.Info
|
| 189 |
if not ref_audio_orig:
|
| 190 |
gr.Warning("Please provide reference audio.")
|
| 191 |
+
return gr.update(), gr.update(), ref_text, gr.update(), gr.update()
|
| 192 |
|
| 193 |
if not gen_text.strip():
|
| 194 |
+
gr.Warning("Please enter text to generate.")
|
| 195 |
+
return gr.update(), gr.update(), ref_text, gr.update(), gr.update()
|
| 196 |
|
| 197 |
ref_audio, ref_text = preprocess_ref_audio_text(
|
| 198 |
ref_audio_orig, ref_text, show_info=show_info)
|
| 199 |
+
skip_no_cache = False
|
| 200 |
+
if recent_input["ref_audio"] == ref_audio_orig and recent_input["ref_text"] == ref_text and recent_input["gen_text"] == gen_text and recent_input["nfe_step"] == nfe_step:
|
| 201 |
+
skip_no_cache = True
|
| 202 |
+
if not skip_no_cache:
|
| 203 |
+
start_time = time.time()
|
| 204 |
+
final_wave, final_sample_rate, _ = infer_process(
|
| 205 |
+
ref_audio,
|
| 206 |
+
ref_text,
|
| 207 |
+
gen_text,
|
| 208 |
+
ema_model,
|
| 209 |
+
vocoder,
|
| 210 |
+
cross_fade_duration=cross_fade_duration,
|
| 211 |
+
nfe_step=nfe_step,
|
| 212 |
+
speed=speed,
|
| 213 |
+
show_info=show_info,
|
| 214 |
+
progress=gr.Progress(),
|
| 215 |
+
)
|
| 216 |
+
process_time = time.time() - start_time
|
| 217 |
cache_helper = SmoothCacheHelper(
|
| 218 |
model=ema_model.transformer,
|
| 219 |
block_classes=get_class("f5_tts.model.modules.DiTBlock"),
|
|
|
|
| 237 |
process_time_cache = time.time() - start_time
|
| 238 |
cache_helper.disable()
|
| 239 |
|
| 240 |
+
recent_input["ref_audio"] = ref_audio_orig
|
| 241 |
+
recent_input["ref_text"] = ref_text
|
| 242 |
+
recent_input["gen_text"] = gen_text
|
| 243 |
+
recent_input["nfe_step"] = nfe_step
|
| 244 |
+
|
| 245 |
+
if skip_no_cache:
|
| 246 |
+
print("skip")
|
| 247 |
+
return gr.update(), (final_sample_rate_cache, final_wave_cache), ref_text, gr.update(), process_time_cache
|
| 248 |
+
return (final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache), ref_text, process_time, process_time_cache
|
| 249 |
|
| 250 |
|
| 251 |
with gr.Blocks() as demo:
|
| 252 |
gr.Markdown("## F5-TTS + SmoothCache")
|
| 253 |
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
| 254 |
+
ref_text_input = gr.Textbox(label="Reference Text (Optional)")
|
| 255 |
gen_text_input = gr.Textbox(label="Text to Generate")
|
| 256 |
with gr.Row():
|
| 257 |
with gr.Column(scale=0):
|
|
|
|
| 278 |
preset_dropdown.change(
|
| 279 |
fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
|
| 280 |
image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
|
| 281 |
+
nfe_slider.release(fn=update_nfe, inputs=nfe_slider,
|
| 282 |
outputs=[image, preset_dropdown])
|
| 283 |
generate_btn.click(
|
| 284 |
infer,
|
| 285 |
inputs=[ref_audio_input, ref_text_input, gen_text_input, nfe_slider],
|
| 286 |
+
outputs=[audio_output, audio_output_cache, ref_text_input,
|
| 287 |
process_time, process_time_cache],
|
| 288 |
)
|
| 289 |
demo.load(fn=load_default, outputs=[image, preset_dropdown])
|