Siratish committed on
Commit
6d8b9a4
·
1 Parent(s): 7a98153

fix seed, clean code

Browse files
Files changed (1) hide show
  1. app.py +55 -86
app.py CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from importlib.resources import files
3
 
@@ -13,20 +25,6 @@ import sys
13
  sys.path.append('F5-TTS/src')
14
  sys.path.append('SmoothCache/SmoothCache')
15
 
16
- from f5_tts.infer.utils_infer import (
17
- cross_fade_duration,
18
- infer_process,
19
- load_model,
20
- load_vocoder,
21
- preprocess_ref_audio_text,
22
- speed
23
- )
24
- from smooth_cache_helper import SmoothCacheHelper
25
-
26
- import gradio as gr
27
- import numpy as np
28
- from functools import lru_cache
29
- from PIL import Image, ImageDraw
30
 
31
  try:
32
  import spaces
@@ -35,12 +33,14 @@ try:
35
  except ImportError:
36
  USING_SPACES = False
37
 
 
38
  def gpu_decorator(func):
39
  if USING_SPACES:
40
  return spaces.GPU(func)
41
  else:
42
  return func
43
 
 
44
  # Constants
45
  layer_names = ['attn', 'ff']
46
  colors_rgb = [(255, 103, 35), (0, 210, 106)] # orange, green
@@ -76,14 +76,19 @@ cache_schedule = {
76
  'ff': presets[default_preset]['ff'][:]
77
  }
78
 
79
- config = tomli.load(open(os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), "rb"))
 
 
 
 
80
 
81
  model = config.get("model", "F5TTS_v1_Base")
82
  ckpt_file = config.get("ckpt_file", "")
83
  vocab_file = config.get("vocab_file", "")
84
 
85
  model_cfg = OmegaConf.load(
86
- config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
 
87
  )
88
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
89
  model_arc = model_cfg.model.arch
@@ -91,7 +96,8 @@ model_arc = model_cfg.model.arch
91
  repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
92
 
93
  if not ckpt_file:
94
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
 
95
 
96
  if not vocab_file:
97
  vocab_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/vocab.txt"))
@@ -103,10 +109,12 @@ ema_model = load_model(
103
 
104
  vocoder = load_vocoder()
105
 
 
106
  @gpu_decorator
107
  def render_grid(schedule: dict) -> np.ndarray:
108
  n_steps = len(schedule['attn'])
109
- img = Image.new("RGB", (n_steps * (cell_size + spacing), n_layers * (cell_size + spacing)), "white")
 
110
  draw = ImageDraw.Draw(img)
111
 
112
  for row in range(n_layers):
@@ -121,6 +129,7 @@ def render_grid(schedule: dict) -> np.ndarray:
121
 
122
  return np.array(img)
123
 
 
124
  @gpu_decorator
125
  def apply_preset(preset_name):
126
  global cache_schedule
@@ -130,6 +139,7 @@ def apply_preset(preset_name):
130
  cache_schedule['ff'] = schedule['ff'][:]
131
  return render_grid(cache_schedule), len(cache_schedule['attn'])
132
 
 
133
  @gpu_decorator
134
  def toggle_cell(evt: gr.SelectData):
135
  global cache_schedule
@@ -140,6 +150,7 @@ def toggle_cell(evt: gr.SelectData):
140
  cache_schedule[layer][col] ^= 1
141
  return render_grid(cache_schedule), "Custom"
142
 
 
143
  @gpu_decorator
144
  def reset_schedule(n_steps):
145
  global cache_schedule
@@ -149,46 +160,37 @@ def reset_schedule(n_steps):
149
  }
150
  return render_grid(cache_schedule), "Custom"
151
 
 
152
  @gpu_decorator
153
  def update_nfe(nfe_value):
154
  return reset_schedule(nfe_value)
155
 
 
156
  @gpu_decorator
157
  def load_default():
158
  return render_grid(cache_schedule), default_preset
159
 
 
160
  @lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
161
  @gpu_decorator
162
  def infer(
163
  ref_audio_orig,
164
  ref_text,
165
  gen_text,
166
- #model,
167
- #remove_silence,
168
- #seed,
169
- #cross_fade_duration=0.15,
170
  nfe_step=32,
171
- #speed=1,
172
- #show_info=gr.Info,
173
  ):
174
  global cache_schedule
175
- show_info=gr.Info
176
  if not ref_audio_orig:
177
  gr.Warning("Please provide reference audio.")
178
  return gr.update(), gr.update(), ref_text
179
 
180
- # Set inference seed
181
- # if seed < 0 or seed > 2**31 - 1:
182
- # gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
183
- seed = np.random.randint(0, 2**31 - 1)
184
- torch.manual_seed(seed)
185
- used_seed = seed
186
-
187
  if not gen_text.strip():
188
  gr.Warning("Please enter text to generate or upload a text file.")
189
  return gr.update(), gr.update(), ref_text
190
 
191
- ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
192
  start_time = time.time()
193
  final_wave, final_sample_rate, _ = infer_process(
194
  ref_audio,
@@ -206,7 +208,7 @@ def infer(
206
  cache_helper = SmoothCacheHelper(
207
  model=ema_model.transformer,
208
  block_classes=get_class("f5_tts.model.modules.DiTBlock"),
209
- components_to_wrap=['attn','ff'],
210
  schedule=cache_schedule
211
  )
212
  cache_helper.enable()
@@ -226,79 +228,46 @@ def infer(
226
  process_time_cache = time.time() - start_time
227
  cache_helper.disable()
228
 
229
- # Remove silence
230
- # if remove_silence:
231
- # with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
232
- # temp_path = f.name
233
- # try:
234
- # sf.write(temp_path, final_wave, final_sample_rate)
235
- # remove_silence_for_generated_wav(f.name)
236
- # final_wave, _ = torchaudio.load(f.name)
237
- # finally:
238
- # os.unlink(temp_path)
239
- # final_wave = final_wave.squeeze().cpu().numpy()
240
-
241
- # Save the spectrogram
242
- # with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
243
- # spectrogram_path = tmp_spectrogram.name
244
- # save_spectrogram(combined_spectrogram, spectrogram_path)
245
-
246
  return (final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache), process_time, process_time_cache
247
 
248
 
249
  with gr.Blocks() as demo:
250
  gr.Markdown("## F5-TTS + SmoothCache")
251
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
252
- ref_text_input = gr.Textbox(
253
- label="Reference Text",
254
- #info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
255
- # lines=2,
256
- # scale=4,
257
- )
258
- gen_text_input = gr.Textbox(
259
- label="Text to Generate",
260
- # lines=10,
261
- # max_lines=40,
262
- # scale=4,
263
- )
264
  with gr.Row():
265
  with gr.Column(scale=0):
266
- preset_dropdown = gr.Dropdown(choices=list(presets.keys()) + ["Custom"], label="Choose Preset", value=default_preset)
267
- nfe_slider = gr.Slider(4, 64, value=len(cache_schedule['attn']), step=1, label="Number of Steps (NFE)")
 
 
268
  with gr.Column(scale=1):
269
- gr.Markdown("Click Grid to Customize Cache Schedule<br>🟧 = Compute Attn Layer / 🟩 = Compute FFN Layer / ⬜ = Cached Layer")
 
270
  image = gr.Image(type="numpy", label="", interactive=True, scale=1)
271
- #reset_btn = gr.Button("Reset to All Cached")
272
- #current_label = gr.Textbox(label="Current Preset", interactive=False)
273
  generate_btn = gr.Button("Synthesize", variant="primary")
274
  with gr.Row():
275
  with gr.Group():
276
  audio_output = gr.Audio(label="Synthesized Audio (No Cache)")
277
- process_time = gr.Textbox(label="⏱ Process Time", interactive=False)
 
278
  with gr.Group():
279
  audio_output_cache = gr.Audio(label="Synthesized Audio (Cache)")
280
- process_time_cache = gr.Textbox(label="⏱ Process Time", interactive=False)
 
281
 
282
  # Wire up logic
283
- preset_dropdown.change(fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
284
- #preset_dropdown.change(fn=lambda x: x, inputs=preset_dropdown, outputs=current_label)
285
  image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
286
- #reset_btn.click(fn=reset_schedule, inputs=nfe_slider, outputs=[image, preset_dropdown])
287
- nfe_slider.input(fn=update_nfe, inputs=nfe_slider, outputs=[image, preset_dropdown])
288
  generate_btn.click(
289
  infer,
290
- inputs=[
291
- ref_audio_input,
292
- ref_text_input,
293
- gen_text_input,
294
- #remove_silence,
295
- #randomize_seed,
296
- #np.random.randint(0, 2**31 - 1),
297
- #cross_fade_duration_slider,
298
- nfe_slider,
299
- #speed_slider,
300
- ],
301
- outputs=[audio_output, audio_output_cache, process_time, process_time_cache],
302
  )
303
  demo.load(fn=load_default, outputs=[image, preset_dropdown])
304
 
 
1
+ from PIL import Image, ImageDraw
2
+ from functools import lru_cache
3
+ import gradio as gr
4
+ from smooth_cache_helper import SmoothCacheHelper
5
+ from f5_tts.infer.utils_infer import (
6
+ cross_fade_duration,
7
+ infer_process,
8
+ load_model,
9
+ load_vocoder,
10
+ preprocess_ref_audio_text,
11
+ speed
12
+ )
13
  import os
14
  from importlib.resources import files
15
 
 
25
  sys.path.append('F5-TTS/src')
26
  sys.path.append('SmoothCache/SmoothCache')
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  try:
30
  import spaces
 
33
  except ImportError:
34
  USING_SPACES = False
35
 
36
+
37
  def gpu_decorator(func):
38
  if USING_SPACES:
39
  return spaces.GPU(func)
40
  else:
41
  return func
42
 
43
+
44
  # Constants
45
  layer_names = ['attn', 'ff']
46
  colors_rgb = [(255, 103, 35), (0, 210, 106)] # orange, green
 
76
  'ff': presets[default_preset]['ff'][:]
77
  }
78
 
79
+ seed = np.random.randint(0, 2**31 - 1)
80
+ torch.manual_seed(seed)
81
+
82
+ config = tomli.load(open(os.path.join(files("f5_tts").joinpath(
83
+ "infer/examples/basic"), "basic.toml"), "rb"))
84
 
85
  model = config.get("model", "F5TTS_v1_Base")
86
  ckpt_file = config.get("ckpt_file", "")
87
  vocab_file = config.get("vocab_file", "")
88
 
89
  model_cfg = OmegaConf.load(
90
+ config.get("model_cfg", str(
91
+ files("f5_tts").joinpath(f"configs/{model}.yaml")))
92
  )
93
  model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
94
  model_arc = model_cfg.model.arch
 
96
  repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
97
 
98
  if not ckpt_file:
99
+ ckpt_file = str(cached_path(
100
+ f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
101
 
102
  if not vocab_file:
103
  vocab_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/vocab.txt"))
 
109
 
110
  vocoder = load_vocoder()
111
 
112
+
113
  @gpu_decorator
114
  def render_grid(schedule: dict) -> np.ndarray:
115
  n_steps = len(schedule['attn'])
116
+ img = Image.new("RGB", (n_steps * (cell_size + spacing),
117
+ n_layers * (cell_size + spacing)), "white")
118
  draw = ImageDraw.Draw(img)
119
 
120
  for row in range(n_layers):
 
129
 
130
  return np.array(img)
131
 
132
+
133
  @gpu_decorator
134
  def apply_preset(preset_name):
135
  global cache_schedule
 
139
  cache_schedule['ff'] = schedule['ff'][:]
140
  return render_grid(cache_schedule), len(cache_schedule['attn'])
141
 
142
+
143
  @gpu_decorator
144
  def toggle_cell(evt: gr.SelectData):
145
  global cache_schedule
 
150
  cache_schedule[layer][col] ^= 1
151
  return render_grid(cache_schedule), "Custom"
152
 
153
+
154
  @gpu_decorator
155
  def reset_schedule(n_steps):
156
  global cache_schedule
 
160
  }
161
  return render_grid(cache_schedule), "Custom"
162
 
163
+
164
  @gpu_decorator
165
  def update_nfe(nfe_value):
166
  return reset_schedule(nfe_value)
167
 
168
+
169
  @gpu_decorator
170
  def load_default():
171
  return render_grid(cache_schedule), default_preset
172
 
173
+
174
  @lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
175
  @gpu_decorator
176
  def infer(
177
  ref_audio_orig,
178
  ref_text,
179
  gen_text,
 
 
 
 
180
  nfe_step=32,
 
 
181
  ):
182
  global cache_schedule
183
+ show_info = gr.Info
184
  if not ref_audio_orig:
185
  gr.Warning("Please provide reference audio.")
186
  return gr.update(), gr.update(), ref_text
187
 
 
 
 
 
 
 
 
188
  if not gen_text.strip():
189
  gr.Warning("Please enter text to generate or upload a text file.")
190
  return gr.update(), gr.update(), ref_text
191
 
192
+ ref_audio, ref_text = preprocess_ref_audio_text(
193
+ ref_audio_orig, ref_text, show_info=show_info)
194
  start_time = time.time()
195
  final_wave, final_sample_rate, _ = infer_process(
196
  ref_audio,
 
208
  cache_helper = SmoothCacheHelper(
209
  model=ema_model.transformer,
210
  block_classes=get_class("f5_tts.model.modules.DiTBlock"),
211
+ components_to_wrap=['attn', 'ff'],
212
  schedule=cache_schedule
213
  )
214
  cache_helper.enable()
 
228
  process_time_cache = time.time() - start_time
229
  cache_helper.disable()
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  return (final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache), process_time, process_time_cache
232
 
233
 
234
  with gr.Blocks() as demo:
235
  gr.Markdown("## F5-TTS + SmoothCache")
236
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
237
+ ref_text_input = gr.Textbox(label="Reference Text")
238
+ gen_text_input = gr.Textbox(label="Text to Generate")
 
 
 
 
 
 
 
 
 
 
239
  with gr.Row():
240
  with gr.Column(scale=0):
241
+ preset_dropdown = gr.Dropdown(choices=list(
242
+ presets.keys()) + ["Custom"], label="Choose Preset", value=default_preset)
243
+ nfe_slider = gr.Slider(4, 64, value=len(
244
+ cache_schedule['attn']), step=1, label="Number of Steps (NFE)")
245
  with gr.Column(scale=1):
246
+ gr.Markdown(
247
+ "Click Grid to Customize Cache Schedule<br>🟧 = Compute Attn Layer / 🟩 = Compute FFN Layer / ⬜ = Cached Layer")
248
  image = gr.Image(type="numpy", label="", interactive=True, scale=1)
 
 
249
  generate_btn = gr.Button("Synthesize", variant="primary")
250
  with gr.Row():
251
  with gr.Group():
252
  audio_output = gr.Audio(label="Synthesized Audio (No Cache)")
253
+ process_time = gr.Textbox(
254
+ label="⏱ Process Time", interactive=False)
255
  with gr.Group():
256
  audio_output_cache = gr.Audio(label="Synthesized Audio (Cache)")
257
+ process_time_cache = gr.Textbox(
258
+ label="⏱ Process Time", interactive=False)
259
 
260
  # Wire up logic
261
+ preset_dropdown.change(
262
+ fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
263
  image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
264
+ nfe_slider.input(fn=update_nfe, inputs=nfe_slider,
265
+ outputs=[image, preset_dropdown])
266
  generate_btn.click(
267
  infer,
268
+ inputs=[ref_audio_input, ref_text_input, gen_text_input, nfe_slider],
269
+ outputs=[audio_output, audio_output_cache,
270
+ process_time, process_time_cache],
 
 
 
 
 
 
 
 
 
271
  )
272
  demo.load(fn=load_default, outputs=[image, preset_dropdown])
273