ashishkblink commited on
Commit
0b8fc6a
·
verified ·
1 Parent(s): 1219672

Upload f5_tts/eval/eval_infer_batch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/eval/eval_infer_batch.py +207 -0
f5_tts/eval/eval_infer_batch.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ import argparse
7
+ import time
8
+ from importlib.resources import files
9
+
10
+ import torch
11
+ import torchaudio
12
+ from accelerate import Accelerator
13
+ from tqdm import tqdm
14
+
15
+ from f5_tts.eval.utils_eval import (
16
+ get_inference_prompt,
17
+ get_librispeech_test_clean_metainfo,
18
+ get_seedtts_testset_metainfo,
19
+ )
20
+ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
21
+ from f5_tts.model import CFM, DiT, UNetT
22
+ from f5_tts.model.utils import get_tokenizer
23
+
24
+ accelerator = Accelerator()
25
+ device = f"cuda:{accelerator.process_index}"
26
+
27
+
28
+ # --------------------- Dataset Settings -------------------- #
29
+
30
+ target_sample_rate = 24000
31
+ n_mel_channels = 100
32
+ hop_length = 256
33
+ win_length = 1024
34
+ n_fft = 1024
35
+ target_rms = 0.1
36
+
37
+ rel_path = str(files("f5_tts").joinpath("../../"))
38
+
39
+
40
+ def main():
41
+ # ---------------------- infer setting ---------------------- #
42
+
43
+ parser = argparse.ArgumentParser(description="batch inference")
44
+
45
+ parser.add_argument("-s", "--seed", default=None, type=int)
46
+ parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
47
+ parser.add_argument("-n", "--expname", required=True)
48
+ parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
49
+ parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
50
+ parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
51
+
52
+ parser.add_argument("-nfe", "--nfestep", default=32, type=int)
53
+ parser.add_argument("-o", "--odemethod", default="euler")
54
+ parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
55
+
56
+ parser.add_argument("-t", "--testset", required=True)
57
+
58
+ args = parser.parse_args()
59
+
60
+ seed = args.seed
61
+ dataset_name = args.dataset
62
+ exp_name = args.expname
63
+ ckpt_step = args.ckptstep
64
+ ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
65
+ mel_spec_type = args.mel_spec_type
66
+ tokenizer = args.tokenizer
67
+
68
+ nfe_step = args.nfestep
69
+ ode_method = args.odemethod
70
+ sway_sampling_coef = args.swaysampling
71
+
72
+ testset = args.testset
73
+
74
+ infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
75
+ cfg_strength = 2.0
76
+ speed = 1.0
77
+ use_truth_duration = False
78
+ no_ref_audio = False
79
+
80
+ if exp_name == "F5TTS_Base":
81
+ model_cls = DiT
82
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
83
+
84
+ elif exp_name == "E2TTS_Base":
85
+ model_cls = UNetT
86
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
87
+
88
+ if testset == "ls_pc_test_clean":
89
+ metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
90
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
91
+ metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
92
+
93
+ elif testset == "seedtts_test_zh":
94
+ metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
95
+ metainfo = get_seedtts_testset_metainfo(metalst)
96
+
97
+ elif testset == "seedtts_test_en":
98
+ metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
99
+ metainfo = get_seedtts_testset_metainfo(metalst)
100
+
101
+ # path to save genereted wavs
102
+ output_dir = (
103
+ f"{rel_path}/"
104
+ f"results/{exp_name}_{ckpt_step}/{testset}/"
105
+ f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
106
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
107
+ f"_cfg{cfg_strength}_speed{speed}"
108
+ f"{'_gt-dur' if use_truth_duration else ''}"
109
+ f"{'_no-ref-audio' if no_ref_audio else ''}"
110
+ )
111
+
112
+ # -------------------------------------------------#
113
+
114
+ use_ema = True
115
+
116
+ prompts_all = get_inference_prompt(
117
+ metainfo,
118
+ speed=speed,
119
+ tokenizer=tokenizer,
120
+ target_sample_rate=target_sample_rate,
121
+ n_mel_channels=n_mel_channels,
122
+ hop_length=hop_length,
123
+ mel_spec_type=mel_spec_type,
124
+ target_rms=target_rms,
125
+ use_truth_duration=use_truth_duration,
126
+ infer_batch_size=infer_batch_size,
127
+ )
128
+
129
+ # Vocoder model
130
+ local = False
131
+ if mel_spec_type == "vocos":
132
+ vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
133
+ elif mel_spec_type == "bigvgan":
134
+ vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
135
+ vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
136
+
137
+ # Tokenizer
138
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
139
+
140
+ # Model
141
+ model = CFM(
142
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
143
+ mel_spec_kwargs=dict(
144
+ n_fft=n_fft,
145
+ hop_length=hop_length,
146
+ win_length=win_length,
147
+ n_mel_channels=n_mel_channels,
148
+ target_sample_rate=target_sample_rate,
149
+ mel_spec_type=mel_spec_type,
150
+ ),
151
+ odeint_kwargs=dict(
152
+ method=ode_method,
153
+ ),
154
+ vocab_char_map=vocab_char_map,
155
+ ).to(device)
156
+
157
+ dtype = torch.float32 if mel_spec_type == "bigvgan" else None
158
+ model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
159
+
160
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
161
+ os.makedirs(output_dir)
162
+
163
+ # start batch inference
164
+ accelerator.wait_for_everyone()
165
+ start = time.time()
166
+
167
+ with accelerator.split_between_processes(prompts_all) as prompts:
168
+ for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
169
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
170
+ ref_mels = ref_mels.to(device)
171
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
172
+ total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
173
+
174
+ # Inference
175
+ with torch.inference_mode():
176
+ generated, _ = model.sample(
177
+ cond=ref_mels,
178
+ text=final_text_list,
179
+ duration=total_mel_lens,
180
+ lens=ref_mel_lens,
181
+ steps=nfe_step,
182
+ cfg_strength=cfg_strength,
183
+ sway_sampling_coef=sway_sampling_coef,
184
+ no_ref_audio=no_ref_audio,
185
+ seed=seed,
186
+ )
187
+ # Final result
188
+ for i, gen in enumerate(generated):
189
+ gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
190
+ gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
191
+ if mel_spec_type == "vocos":
192
+ generated_wave = vocoder.decode(gen_mel_spec).cpu()
193
+ elif mel_spec_type == "bigvgan":
194
+ generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
195
+
196
+ if ref_rms_list[i] < target_rms:
197
+ generated_wave = generated_wave * ref_rms_list[i] / target_rms
198
+ torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
199
+
200
+ accelerator.wait_for_everyone()
201
+ if accelerator.is_main_process:
202
+ timediff = time.time() - start
203
+ print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
204
+
205
+
206
+ if __name__ == "__main__":
207
+ main()