ashishkblink committed on
Commit
591aecb
·
verified ·
1 Parent(s): e53ebe7

Upload f5_tts/infer/infer_cli.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/infer/infer_cli.py +226 -0
f5_tts/infer/infer_cli.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import codecs
3
+ import os
4
+ import re
5
+ from importlib.resources import files
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import tomli
11
+ from cached_path import cached_path
12
+
13
+ from f5_tts.infer.utils_infer import (
14
+ infer_process,
15
+ load_model,
16
+ load_vocoder,
17
+ preprocess_ref_audio_text,
18
+ remove_silence_for_generated_wav,
19
+ )
20
+ from f5_tts.model import DiT, UNetT
21
+
22
# ---- Command-line interface ----------------------------------------------
# Every CLI option overrides the corresponding key in the TOML config file.
parser = argparse.ArgumentParser(
    prog="python3 infer-cli.py",
    description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c",
    "--config",
    help="Configuration file. Default=infer/examples/basic/basic.toml",
    default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
)
parser.add_argument(
    "-m",
    "--model",
    help="F5-TTS | E2-TTS",
)
parser.add_argument(
    "-p",
    "--ckpt_file",
    help="The Checkpoint .pt",
)
parser.add_argument(
    "-v",
    "--vocab_file",
    help="The vocab .txt",
)
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
# "666" is a sentinel meaning "not supplied on the CLI"; it lets an *empty*
# --ref_text still override the config (empty string triggers auto-transcription
# downstream, so plain falsiness can't be used here).
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
parser.add_argument(
    "-t",
    "--gen_text",
    type=str,
    help="Text to generate.",
)
parser.add_argument(
    "-f",
    "--gen_file",
    type=str,
    help="File with text to generate. Ignores --gen_text",
)
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    help="Path to output folder..",
)
parser.add_argument(
    "-w",
    "--output_file",
    type=str,
    help="Filename of output file..",
)
# Fix: was a value-taking option, so "--remove_silence false" evaluated truthy
# downstream; a boolean flag matches how the value is actually consumed.
parser.add_argument(
    "--remove_silence",
    action="store_true",
    help="Remove silence.",
)
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
parser.add_argument(
    "--load_vocoder_from_local",
    action="store_true",
    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
)
parser.add_argument(
    "--speed",
    type=float,
    default=1.0,
    help="Adjust the speed of the audio generation (default: 1.0)",
)
args = parser.parse_args()
91
+
92
# ---- Resolve effective settings: CLI args take precedence over the config ----
# Fix: open the config inside a `with` block instead of leaking the handle.
with open(args.config, "rb") as config_fp:
    config = tomli.load(config_fp)

ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
# "666" is the argparse sentinel for "not given on the CLI" (see parser setup).
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
gen_file = args.gen_file if args.gen_file else config["gen_file"]

# Patches for pip pkg user: example assets ship inside the installed f5_tts
# package, so rewrite package-relative example paths to absolute ones.
if "infer/examples/" in ref_audio:
    ref_audio = str(files("f5_tts").joinpath(ref_audio))
if "infer/examples/" in gen_file:
    gen_file = str(files("f5_tts").joinpath(gen_file))
if "voices" in config:
    for voice in config["voices"]:
        voice_ref_audio = config["voices"][voice]["ref_audio"]
        if "infer/examples/" in voice_ref_audio:
            config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(voice_ref_audio))

# A text file given via --gen_file wins over --gen_text.
# Fix: close the file deterministically instead of relying on GC.
if gen_file:
    with codecs.open(gen_file, "r", "utf-8") as gen_fp:
        gen_text = gen_fp.read()

output_dir = args.output_dir if args.output_dir else config["output_dir"]
output_file = args.output_file if args.output_file else config["output_file"]
model = args.model if args.model else config["model"]
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
speed = args.speed

wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
122
+
123
# ---- Vocoder --------------------------------------------------------------
# The vocoder choice also fixes the mel-spectrogram flavour the TTS model must
# produce, so the same value is kept under both names used later in the file.
vocoder_name = args.vocoder_name
mel_spec_type = args.vocoder_name

# Local checkpoint locations used when --load_vocoder_from_local is given;
# argparse `choices` guarantees the name is one of these two.
_VOCODER_LOCAL_PATHS = {
    "vocos": "../checkpoints/vocos-mel-24khz",
    "bigvgan": "../checkpoints/bigvgan_v2_24khz_100band_256x",
}
vocoder_local_path = _VOCODER_LOCAL_PATHS[vocoder_name]

vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
131
+
132
+
133
# load models
# Select architecture + config for the requested model family, resolve a
# checkpoint (explicit --ckpt_file wins; otherwise a default is chosen), then
# load the EMA model used for inference.
if model == "F5-TTS":
    model_cls = DiT
    # F5-TTS base architecture hyperparameters.
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    if ckpt_file == "":
        if vocoder_name == "vocos":
            repo_name = "F5-TTS"
            exp_name = "F5TTS_Base"
            ckpt_step = 1200000
            # NOTE(review): hard-coded machine-local checkpoint path replaces the
            # usual HuggingFace download; repo_name/exp_name/ckpt_step above are
            # left unused on this branch — confirm this override is intentional.
            ckpt_file = "/home/tts/ttsteam/repos/en_f5/F5-TTS/ckpts/expresso/model_356000.pt"
            # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
        elif vocoder_name == "bigvgan":
            repo_name = "F5-TTS"
            exp_name = "F5TTS_Base_bigvgan"
            ckpt_step = 1250000
            # Download (and cache) the default checkpoint from the HF Hub.
            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))

elif model == "E2-TTS":
    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
    model_cls = UNetT
    # E2-TTS base architecture hyperparameters.
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
    if ckpt_file == "":
        repo_name = "E2-TTS"
        exp_name = "E2TTS_Base"
        ckpt_step = 1200000
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path


print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
164
+
165
+
166
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
    """Synthesize speech for *text_gen* and write the result to ``wave_path``.

    The text may contain ``[voice_name]`` tags that switch between the voices
    declared in the config's ``voices`` table; untagged chunks use the main
    voice built from *ref_audio* / *ref_text*.

    Args:
        ref_audio: path to the main reference audio clip.
        ref_text: transcript of the main reference audio.
        text_gen: text to synthesize, optionally containing ``[voice]`` tags.
        model_obj: loaded TTS model.
        mel_spec_type: vocoder/mel flavour name ("vocos" or "bigvgan").
        remove_silence: truthy to strip long silences from the generated wav.
        speed: speech-rate multiplier passed through to inference.

    Side effects: reads module globals ``config``, ``vocoder``, ``output_dir``
    and ``wave_path``; creates the output directory and writes the wav file.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice
    for voice in voices:
        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
            voices[voice]["ref_audio"], voices[voice]["ref_text"]
        )
        print("Voice:", voice)
        print("Ref_audio:", voices[voice]["ref_audio"])
        print("Ref_text:", voices[voice]["ref_text"])

    generated_audio_segments = []
    # Split *before* each [tag] (lookahead keeps the tag attached to its chunk).
    reg1 = r"(?=\[\w+\])"
    chunks = re.split(reg1, text_gen)
    reg2 = r"\[(\w+)\]"
    for text in chunks:
        if not text.strip():
            continue
        match = re.match(reg2, text)
        if match:
            voice = match[1]
        else:
            print("No voice tag found, using main.")
            voice = "main"
        if voice not in voices:
            print(f"Voice {voice} not found, using main.")
            voice = "main"
        # Drop the tag before synthesis; renamed locals avoid shadowing the
        # function parameters as the original code did.
        chunk_text = re.sub(reg2, "", text).strip()
        chunk_ref_audio = voices[voice]["ref_audio"]
        chunk_ref_text = voices[voice]["ref_text"]
        print(f"Voice: {voice}")
        audio_segment, final_sample_rate, _spectrogram = infer_process(
            chunk_ref_audio, chunk_ref_text, chunk_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
        )
        generated_audio_segments.append(audio_segment)

    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)

        # Fix: exist_ok avoids the exists()/makedirs() race of the original.
        os.makedirs(output_dir, exist_ok=True)

        # Fix: write directly via soundfile; the original opened the file with
        # open(..., "wb") only to pass f.name to sf.write, opening it twice.
        sf.write(str(wave_path), final_wave, final_sample_rate)
        # Remove silence
        if remove_silence:
            remove_silence_for_generated_wav(str(wave_path))
        print(wave_path)
219
+
220
+
221
def main():
    """CLI entry point: run one batch-inference pass with the resolved settings."""
    main_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        mel_spec_type,
        remove_silence,
        speed,
    )


if __name__ == "__main__":
    main()