ashishkblink committed on
Commit
1dc4d12
·
verified ·
1 Parent(s): 6261b7f

Upload f5_tts/infer/infer_batch_parallel.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/infer/infer_batch_parallel.py +171 -0
f5_tts/infer/infer_batch_parallel.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import codecs
3
+ import os
4
+ import re
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import tomli
10
+ from cached_path import cached_path
11
+ import pandas as pd
12
+ from tqdm import tqdm
13
+
14
+ from f5_tts.infer.utils_infer import (
15
+ infer_process,
16
+ load_model,
17
+ load_vocoder,
18
+ preprocess_ref_audio_text,
19
+ remove_silence_for_generated_wav,
20
+ )
21
+ from f5_tts.model import DiT, UNetT
22
+
23
+
24
def run_batch_inference(prompt_paths, prompt_texts, texts, languages, categories, model_obj, vocoder, mel_spec_type, remove_silence, speed, output_dir):
    """Synthesize speech for each batch row and save WAVs under output_dir/<language>/.

    Args:
        prompt_paths: reference-audio file paths, one per item.
        prompt_texts: transcripts matching ``prompt_paths``.
        texts: texts to synthesize; may contain ``[voice]`` tags (only "main" is defined here).
        languages: language labels; used for the output sub-directory and filename.
        categories: category labels; used in the output filename.
        model_obj: loaded TTS model, forwarded to ``infer_process``.
        vocoder: loaded vocoder, forwarded to ``infer_process``.
        mel_spec_type: vocoder/mel type name ("vocos" or "bigvgan"), forwarded to ``infer_process``.
        remove_silence: truthy to post-process each generated WAV with silence removal.
        speed: speech-speed multiplier, forwarded to ``infer_process``.
        output_dir: root directory for generated audio.
    """
    # Best-effort sanity pass: report missing/invalid reference audio but keep going,
    # matching the original behavior (the hard failure stays commented out).
    invalid_count = 0
    for ref_audio in prompt_paths:
        if not isinstance(ref_audio, str) or not os.path.isfile(ref_audio):
            print(f"Invalid ref_audio: {ref_audio}")
            invalid_count += 1
            print(invalid_count)
            # raise ValueError(f"Invalid ref_audio: {ref_audio}")

    # total= gives tqdm a proper progress bar (enumerate/zip yields no __len__).
    for idx, (ref_audio, ref_text, text_gen, language, category) in tqdm(
        enumerate(zip(prompt_paths, prompt_texts, texts, languages, categories)),
        total=len(prompt_paths),
    ):
        voices = {"main": {"ref_audio": ref_audio, "ref_text": ref_text}}
        for voice in voices:
            voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
                voices[voice]["ref_audio"], voices[voice]["ref_text"]
            )
            print("Voice:", voice)
            print("Ref_audio:", voices[voice]["ref_audio"])
            print("Ref_text:", voices[voice]["ref_text"])

        generated_audio_segments = []
        # Split ahead of each [voice] tag (lookahead keeps the tag with its chunk).
        reg1 = r"(?=\[\w+\])"
        chunks = re.split(reg1, text_gen)
        reg2 = r"\[(\w+)\]"
        for text in chunks:
            if not text.strip():
                continue
            match = re.match(reg2, text)
            if match:
                voice = match[1]
            else:
                print("No voice tag found, using main.")
                voice = "main"
            if voice not in voices:
                print(f"Voice {voice} not found, using main.")
                voice = "main"
            text = re.sub(reg2, "", text)
            gen_text = text.strip()
            ref_audio = voices[voice]["ref_audio"]
            ref_text = voices[voice]["ref_text"]
            print(f"Voice: {voice}")
            # third return value is the spectrogram; unused here
            audio, final_sample_rate, _spectrogram = infer_process(
                ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
            )
            generated_audio_segments.append(audio)

        if generated_audio_segments:
            final_wave = np.concatenate(generated_audio_segments)
            filename = f"{language.upper()}_{category.upper()}_{idx}.wav"
            outfile_dir = os.path.join(output_dir, language)
            os.makedirs(outfile_dir, exist_ok=True)
            wave_path = Path(outfile_dir) / filename
            # Write directly to the target path. The original opened the file with
            # open(..., "wb") only to pass f.name to sf.write, which kept a handle
            # open while soundfile re-opened the same file (breaks on Windows).
            sf.write(str(wave_path), final_wave, final_sample_rate)
            if remove_silence:
                remove_silence_for_generated_wav(str(wave_path))
            print(f"Generated audio saved to: {wave_path}")
80
+
81
+
82
def main():
    """CLI entry point: parse arguments, load models, and run batch TTS inference.

    Reads a CSV (``--generate_csv``) with columns ``prompt_path``, ``prompt_text``,
    ``text``, ``language``, ``category`` and writes generated WAVs under
    ``--output_dir/<language>/``.
    """
    parser = argparse.ArgumentParser(
        prog="python3 infer-cli.py",
        description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
        epilog="Specify options above to override one or more settings from config.",
    )

    parser.add_argument(
        "-m",
        "--model",
        help="F5-TTS | E2-TTS",
    )
    parser.add_argument(
        "-p",
        "--ckpt_file",
        help="The Checkpoint .pt",
    )
    parser.add_argument(
        "-v",
        "--vocab_file",
        help="The vocab .txt",
    )

    parser.add_argument(
        "-f",
        "--generate_csv",
        type=str,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        help="Path to output folder..",
    )
    parser.add_argument(
        "--remove_silence",
        help="Remove silence.",
    )
    parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
    parser.add_argument(
        "--load_vocoder_from_local",
        action="store_true",
        help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Adjust the speed of the audio generation (default: 1.0)",
    )
    args = parser.parse_args()

    # Read texts and prompts to generate
    filepath = args.generate_csv
    df = pd.read_csv(filepath)
    prompt_paths = df["prompt_path"].tolist()
    prompt_texts = df["prompt_text"].tolist()
    texts = df["text"].tolist()
    languages = df["language"].tolist()
    categories = df["category"].tolist()

    # Model config
    model = args.model
    ckpt_file = args.ckpt_file
    vocab_file = args.vocab_file
    # --remove_silence takes a free-form string; the original used it by truthiness,
    # so "--remove_silence false" still removed silence. Parse it as a boolean instead.
    # Omitted (None) -> False, as before; "true"/"1"/"yes" -> True.
    remove_silence = str(args.remove_silence).strip().lower() in {"1", "true", "yes", "y"}
    speed = args.speed
    vocoder_name = args.vocoder_name
    mel_spec_type = vocoder_name  # infer/load APIs use the vocoder name as the mel type
    if vocoder_name == "vocos":
        vocoder_local_path = "../checkpoints/vocos-mel-24khz"
    elif vocoder_name == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)

    # Load the TTS model. NOTE(review): model_cls is hard-coded to DiT (F5-TTS config);
    # the --model flag only affects this log line.
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    print(f"Using {model}...")
    ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)

    # Batch inference
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)  # race-free vs. exists()+makedirs()
    run_batch_inference(prompt_paths, prompt_texts, texts, languages, categories, ema_model, vocoder, mel_spec_type, remove_silence, speed, output_dir)
168
+
169
+
170
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()