ashishkblink commited on
Commit
303ceb6
·
verified ·
1 Parent(s): b56816d

Upload f5_tts/infer/infer_cli_batch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/infer/infer_cli_batch.py +245 -0
f5_tts/infer/infer_cli_batch.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import codecs
3
+ import os
4
+ import re
5
+ from importlib.resources import files
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import tomli
11
+ from cached_path import cached_path
12
+ import pandas as pd
13
+
14
+ from f5_tts.infer.utils_infer import (
15
+ infer_process,
16
+ load_model,
17
+ load_vocoder,
18
+ preprocess_ref_audio_text,
19
+ remove_silence_for_generated_wav,
20
+ )
21
+ from f5_tts.model import DiT, UNetT
22
+
23
# Command-line interface for batch E2/F5 TTS inference.
# Any option left unset here falls back to the TOML config file loaded below.
parser = argparse.ArgumentParser(
    prog="python3 infer-cli.py",
    description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c", "--config",
    help="Configuration file. Default=infer/examples/basic/basic.toml",
    default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
)
parser.add_argument("-m", "--model", help="F5-TTS | E2-TTS")
parser.add_argument("-p", "--ckpt_file", help="The Checkpoint .pt")
parser.add_argument("-v", "--vocab_file", help="The vocab .txt")
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
# "666" is a sentinel meaning "not supplied on the CLI" (checked against below).
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
parser.add_argument("-t", "--gen_text", type=str, help="Text to generate.")
parser.add_argument("-f", "--gen_file", type=str, help="File with text to generate. Ignores --gen_text")
parser.add_argument("-o", "--output_dir", type=str, help="Path to output folder..")
parser.add_argument("-w", "--output_file", type=str, help="Filename of output file..")
# NOTE(review): no action= given, so --remove_silence takes a free-form string;
# any non-empty value (even "false") is truthy downstream — confirm intent.
parser.add_argument("--remove_silence", help="Remove silence.")
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
parser.add_argument(
    "--load_vocoder_from_local",
    action="store_true",
    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
)
parser.add_argument("--speed", type=float, default=1.0, help="Adjust the speed of the audio generation (default: 1.0)")
args = parser.parse_args()
92
+
93
# ---- Resolve runtime settings: CLI flags win over the TOML config file. ----
# NOTE(review): the config handle is opened without a `with`; it is leaked
# until GC. Keys ref_audio/ref_text/gen_file/... must exist in the TOML or
# this raises KeyError — confirm the shipped basic.toml provides them.
config = tomli.load(open(args.config, "rb"))

ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
# "666" is the argparse default sentinel for "--ref_text not passed".
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_file = args.gen_file if args.gen_file else config["gen_file"]


if gen_file:
    # Read texts from CSV file
    # assumes the CSV has a 'text' column — TODO confirm against producers
    df = pd.read_csv(gen_file)
    text_list = df['text'].tolist()
else:
    # If no file provided, use single text
    gen_text = args.gen_text if args.gen_text else config["gen_text"]
    text_list = [gen_text]
output_dir = args.output_dir if args.output_dir else config["output_dir"]
output_file = args.output_file if args.output_file else config["output_file"]
model = args.model if args.model else config["model"]
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
speed = args.speed

# Base output path; per-sentence filenames are derived in main_process().
wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"


# patches for pip pkg user
# Rewrite bundled example paths to their installed package locations.
if "infer/examples/" in ref_audio:
    ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
if "infer/examples/" in gen_file:
    gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
if "voices" in config:
    for voice in config["voices"]:
        voice_ref_audio = config["voices"][voice]["ref_audio"]
        if "infer/examples/" in voice_ref_audio:
            config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))

# The vocoder name doubles as the mel-spectrogram type passed to the models.
vocoder_name = args.vocoder_name
mel_spec_type = args.vocoder_name
if vocoder_name == "vocos":
    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

# local_path is only consulted when --load_vocoder_from_local is set.
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
139
+
140
+
141
+ # load models
142
+ if model == "F5-TTS":
143
+ model_cls = DiT
144
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
145
+ if ckpt_file == "":
146
+ if vocoder_name == "vocos":
147
+ repo_name = "F5-TTS"
148
+ exp_name = "F5TTS_Base"
149
+ ckpt_step = 1200000
150
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
151
+ # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
152
+ elif vocoder_name == "bigvgan":
153
+ repo_name = "F5-TTS"
154
+ exp_name = "F5TTS_Base_bigvgan"
155
+ ckpt_step = 1250000
156
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
157
+
158
+ elif model == "E2-TTS":
159
+ assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
160
+ model_cls = UNetT
161
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
162
+ if ckpt_file == "":
163
+ repo_name = "E2-TTS"
164
+ exp_name = "E2TTS_Base"
165
+ ckpt_step = 1200000
166
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
167
+ # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
168
+
169
+
170
+ print(f"Using {model}...")
171
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
172
+
173
+
174
def main_process(ref_audio, ref_text, text_list, model_obj, mel_spec_type, remove_silence, speed):
    """Synthesize every text in ``text_list`` and write one WAV per text.

    Each text may contain ``[voice]`` tags that switch between voices
    declared in the module-level ``config``; untagged chunks use the
    "main" voice built from ``ref_audio``/``ref_text``.

    Args:
        ref_audio: path to the main reference audio clip.
        ref_text: transcript of the main reference audio.
        text_list: texts to synthesize, one output file each.
        model_obj: loaded TTS model passed through to ``infer_process``.
        mel_spec_type: mel-spectrogram type matching the loaded vocoder.
        remove_silence: truthy → post-process each WAV to strip silence.
        speed: generation speed multiplier.

    Side effects: writes WAV files under the module-level ``output_dir``
    (named after ``output_file`` plus sentence index and first words) and
    prints progress. Reads the module-level ``config`` and ``vocoder``.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice
    # Normalize every voice's reference audio/text once up front.
    for voice in voices:
        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
            voices[voice]["ref_audio"], voices[voice]["ref_text"]
        )
        print("Voice:", voice)
        print("Ref_audio:", voices[voice]["ref_audio"])
        print("Ref_text:", voices[voice]["ref_text"])

    # Compile the voice-tag patterns once; they are reused for every chunk.
    split_before_tag = re.compile(r"(?=\[\w+\])")  # lookahead keeps the tag with its chunk
    voice_tag = re.compile(r"\[(\w+)\]")

    # Process each text in the list
    for idx, text_gen in enumerate(text_list):
        generated_audio_segments = []
        chunks = split_before_tag.split(text_gen)
        for text in chunks:
            if not text.strip():
                continue
            match = voice_tag.match(text)
            if match:
                voice = match[1]
            else:
                print("No voice tag found, using main.")
                voice = "main"
            if voice not in voices:
                print(f"Voice {voice} not found, using main.")
                voice = "main"
            text = voice_tag.sub("", text)
            gen_text = text.strip()
            ref_audio = voices[voice]["ref_audio"]
            ref_text = voices[voice]["ref_text"]
            print(f"Voice: {voice}")
            audio, final_sample_rate, _spectrogram = infer_process(
                ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
            )
            generated_audio_segments.append(audio)

        if generated_audio_segments:
            final_wave = np.concatenate(generated_audio_segments)

            # exist_ok avoids the check-then-create race of the original
            # os.path.exists / os.makedirs pair.
            os.makedirs(output_dir, exist_ok=True)

            # Get first 3 words from the text
            first_three_words = '_'.join(text_gen.split()[:3])
            # Remove any special characters that might cause issues in filenames
            first_three_words = re.sub(r'[^\w\s-]', '', first_three_words)
            # Create filename with index and first 3 words
            filename = f"{Path(output_file).stem}__sentence{(idx+1):03d}_{first_three_words}{Path(output_file).suffix}"

            wave_path = Path(output_dir) / filename

            # Write by path directly. The original opened the file and then
            # passed f.name to sf.write, double-opening the same path while
            # the handle was held — which fails on Windows.
            sf.write(wave_path, final_wave, final_sample_rate)
            # Remove silence
            if remove_silence:
                remove_silence_for_generated_wav(wave_path)
            print(f"Generated audio saved to: {wave_path}")
238
+
239
+
240
def main():
    """Entry point: run batch synthesis with the module-level settings."""
    main_process(ref_audio, ref_text, text_list, ema_model, mel_spec_type, remove_silence, speed)


if __name__ == "__main__":
    main()