ashishkblink commited on
Commit
5ee4d87
·
verified ·
1 Parent(s): 64f9887

Upload f5_tts/train/finetune_gradio.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/train/finetune_gradio.py +1846 -0
f5_tts/train/finetune_gradio.py ADDED
@@ -0,0 +1,1846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import queue
3
+ import re
4
+
5
+ import gc
6
+ import json
7
+ import os
8
+ import platform
9
+ import psutil
10
+ import random
11
+ import signal
12
+ import shutil
13
+ import subprocess
14
+ import sys
15
+ import tempfile
16
+ import time
17
+ from glob import glob
18
+
19
+ import click
20
+ import gradio as gr
21
+ import librosa
22
+ import numpy as np
23
+ import torch
24
+ import torchaudio
25
+ from datasets import Dataset as Dataset_
26
+ from datasets.arrow_writer import ArrowWriter
27
+ from safetensors.torch import save_file
28
+ from scipy.io import wavfile
29
+ from cached_path import cached_path
30
+ from f5_tts.api import F5TTS
31
+ from f5_tts.model.utils import convert_char_to_pinyin
32
+ from f5_tts.infer.utils_infer import transcribe
33
+ from importlib.resources import files
34
+
35
+
36
# Handle of the currently running training subprocess; None while idle.
training_process = None
system = platform.system()
python_executable = sys.executable if sys.executable else "python"
# Cached inference API object and the parameters it was last built with.
tts_api = None
last_checkpoint = ""
last_device = ""
last_ema = None


# Locations resolved relative to the installed ``f5_tts`` package.
path_data = str(files("f5_tts").joinpath("../../data"))
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))

# Pick the compute backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
50
+
51
+
52
+ # Save settings to a JSON file
53
def save_settings(
    project_name,
    exp_name,
    learning_rate,
    batch_size_per_gpu,
    batch_size_type,
    max_samples,
    grad_accumulation_steps,
    max_grad_norm,
    epochs,
    num_warmup_updates,
    save_per_updates,
    last_per_steps,
    finetune,
    file_checkpoint_train,
    tokenizer_type,
    tokenizer_file,
    mixed_precision,
    logger,
    ch_8bit_adam,
):
    """Write the training options of a project to ``setting.json``.

    The file is created inside ``<path_project_ckpts>/<project_name>``; the
    directory is created when it does not exist yet. ``ch_8bit_adam`` is
    stored under the key ``"bnb_optimizer"``.

    Returns:
        The confirmation string ``"Settings saved!"``.
    """
    project_dir = os.path.join(path_project_ckpts, project_name)
    os.makedirs(project_dir, exist_ok=True)

    payload = dict(
        exp_name=exp_name,
        learning_rate=learning_rate,
        batch_size_per_gpu=batch_size_per_gpu,
        batch_size_type=batch_size_type,
        max_samples=max_samples,
        grad_accumulation_steps=grad_accumulation_steps,
        max_grad_norm=max_grad_norm,
        epochs=epochs,
        num_warmup_updates=num_warmup_updates,
        save_per_updates=save_per_updates,
        last_per_steps=last_per_steps,
        finetune=finetune,
        file_checkpoint_train=file_checkpoint_train,
        tokenizer_type=tokenizer_type,
        tokenizer_file=tokenizer_file,
        mixed_precision=mixed_precision,
        logger=logger,
        bnb_optimizer=ch_8bit_adam,
    )

    settings_path = os.path.join(project_dir, "setting.json")
    with open(settings_path, "w") as handle:
        json.dump(payload, handle, indent=4)
    return "Settings saved!"
101
+
102
+
103
+ # Load settings from a JSON file
104
def load_settings(project_name):
    """Load the training options of a project, falling back to defaults.

    The ``_pinyin`` / ``_char`` markers are removed from ``project_name``
    before ``<path_project_ckpts>/<project_name>/setting.json`` is looked up.
    Values found in that file override the built-in defaults; any key the
    file lacks (for example in files written by an older version) keeps its
    default instead of raising ``KeyError``.

    Returns:
        An 18-tuple in this fixed order: exp_name, learning_rate,
        batch_size_per_gpu, batch_size_type, max_samples,
        grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates,
        save_per_updates, last_per_steps, finetune, file_checkpoint_train,
        tokenizer_type, tokenizer_file, mixed_precision, logger,
        bnb_optimizer.
    """
    # Insertion order of this dict defines the order of the returned tuple.
    defaults = {
        "exp_name": "F5TTS_Base",
        "learning_rate": 1e-05,
        "batch_size_per_gpu": 1000,
        "batch_size_type": "frame",
        "max_samples": 64,
        "grad_accumulation_steps": 1,
        "max_grad_norm": 1,
        "epochs": 100,
        "num_warmup_updates": 2,
        "save_per_updates": 300,
        "last_per_steps": 100,
        "finetune": True,
        "file_checkpoint_train": "",
        "tokenizer_type": "pinyin",
        "tokenizer_file": "",
        "mixed_precision": "none",
        "logger": "wandb",
        "bnb_optimizer": False,
    }

    project_name = project_name.replace("_pinyin", "").replace("_char", "")
    file_setting = os.path.join(path_project_ckpts, project_name, "setting.json")

    settings = dict(defaults)
    if os.path.isfile(file_setting):
        with open(file_setting, "r") as f:
            settings.update(json.load(f))

    # Iterate the defaults (not the merged dict) so unknown extra keys in the
    # file can never change the shape of the result.
    return tuple(settings[key] for key in defaults)
177
+
178
+
179
+ # Load metadata
180
def get_audio_duration(audio_path):
    """Return the duration, in seconds, of the audio file at ``audio_path``.

    The file is loaded with ``torchaudio.load`` and the length of the second
    dimension of the returned tensor is divided by the sample rate.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    frame_count = waveform.shape[1]
    return frame_count / sample_rate
184
+
185
+
186
def clear_text(text):
    """Normalise a transcript: lower-case it, then trim surrounding whitespace."""
    lowered = text.lower()
    return lowered.strip()
189
+
190
+
191
def get_rms(
    y,
    frame_length=2048,
    hop_length=512,
    pad_mode="constant",
):  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
    """Compute the root-mean-square energy of overlapping frames of ``y``.

    ``y`` is padded with ``frame_length // 2`` values on each side (using
    ``pad_mode``), cut into windows of ``frame_length`` samples along the
    last axis, and every ``hop_length``-th window is kept.

    Returns:
        An array whose second-to-last axis has length 1 and whose last axis
        indexes the kept frames.
    """
    half_window = int(frame_length // 2)
    padded = np.pad(y, (half_window, half_window), mode=pad_mode)

    # Build a zero-copy view with one extra trailing axis that walks through
    # each window: windows[..., i, j] == padded[..., i + j].
    window_count_shape = list(padded.shape)
    window_count_shape[-1] -= frame_length - 1
    windows = np.lib.stride_tricks.as_strided(
        padded,
        shape=tuple(window_count_shape) + (frame_length,),
        strides=padded.strides + (padded.strides[-1],),
    )

    # Put the within-window axis before the frame axis, then keep every
    # hop_length-th frame.
    windows = np.moveaxis(windows, -1, -2)
    selected = windows[..., 0::hop_length]

    # Mean power over the within-window axis, then its square root.
    mean_power = np.mean(np.abs(selected) ** 2, axis=-2, keepdims=True)
    return np.sqrt(mean_power)
222
+
223
+
224
class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
    """Split a waveform into clips at stretches of low RMS energy.

    All durations passed to the constructor are in milliseconds and are
    converted to frame counts (one frame per ``hop_size``) internally.
    """

    def __init__(
        self,
        sr: int,
        threshold: float = -40.0,
        min_length: int = 2000,
        min_interval: int = 300,
        hop_size: int = 20,
        max_sil_kept: int = 2000,
    ):
        """Configure the slicer.

        Args:
            sr: Sample rate of the waveforms that will be sliced.
            threshold: Silence threshold in dB; converted to a linear
                amplitude with ``10 ** (threshold / 20)``.
            min_length: Minimum clip length (ms) before a cut is allowed.
            min_interval: Minimum silence length (ms) that triggers a cut.
            hop_size: Frame hop (ms) used for the RMS analysis.
            max_sil_kept: Maximum silence (ms) kept around a clip.

        Raises:
            ValueError: If ``min_length >= min_interval >= hop_size`` or
                ``max_sil_kept >= hop_size`` does not hold.
        """
        if not min_length >= min_interval >= hop_size:
            raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
        if not max_sil_kept >= hop_size:
            raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.0)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        """Return the samples between frame indices ``begin`` and ``end``.

        Frame indices are scaled by ``hop_size``; the end is clamped to the
        waveform length. Works on 1-D arrays and on arrays whose last axis is
        time.
        """
        if len(waveform.shape) > 1:
            return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
        else:
            return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]

    # @timeit
    def slice(self, waveform):
        """Cut ``waveform`` at silences.

        Returns:
            A list of ``[chunk, start_sample, end_sample]`` entries. Every
            return path uses this shape, including the one for inputs too
            short to analyse, so callers can always unpack three values.
        """
        if len(waveform.shape) > 1:
            samples = waveform.mean(axis=0)
        else:
            samples = waveform
        if samples.shape[0] <= self.min_length:
            # Too short to analyse: hand back the whole input, in the same
            # [chunk, start, end] form as the other return paths.
            # NOTE(review): this compares a sample count with a frame count,
            # exactly as the original did — confirm that is intended.
            return [[waveform, 0, int(samples.shape[0])]]
        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
        sil_tags = []
        silence_start = None
        clip_start = 0
        for i, rms in enumerate(rms_list):
            # Keep looping while frame is silent.
            if rms < self.threshold:
                # Record start of silent frames.
                if silence_start is None:
                    silence_start = i
                continue
            # Keep looping while frame is not silent and silence start has not been recorded.
            if silence_start is None:
                continue
            # Clear recorded silence start if interval is not enough or clip is too short
            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue
            # Need slicing. Record the range of silent frames to be removed.
            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start : i + 1].argmin() + silence_start
                if silence_start == 0:
                    sil_tags.append((0, pos))
                else:
                    sil_tags.append((pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
                pos += i - self.max_sil_kept
                pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                else:
                    sil_tags.append((pos_l, pos_r))
                clip_start = pos_r
            silence_start = None
        # Deal with trailing silence.
        total_frames = rms_list.shape[0]
        if silence_start is not None and total_frames - silence_start >= self.min_interval:
            silence_end = min(total_frames, silence_start + self.max_sil_kept)
            pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
            sil_tags.append((pos, total_frames + 1))
        # Apply and return slices.
        # Each entry is: audio + start time + end time (both in samples).
        if len(sil_tags) == 0:
            return [[waveform, 0, int(total_frames * self.hop_size)]]
        else:
            chunks = []
            if sil_tags[0][0] > 0:
                chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
            for i in range(len(sil_tags) - 1):
                chunks.append(
                    [
                        self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
                        int(sil_tags[i][1] * self.hop_size),
                        int(sil_tags[i + 1][0] * self.hop_size),
                    ]
                )
            if sil_tags[-1][1] < total_frames:
                chunks.append(
                    [
                        self._apply_slice(waveform, sil_tags[-1][1], total_frames),
                        int(sil_tags[-1][1] * self.hop_size),
                        int(total_frames * self.hop_size),
                    ]
                )
            return chunks
339
+
340
+
341
+ # terminal
342
def terminate_process_tree(pid, including_parent=True):
    """Send SIGTERM to every descendant of ``pid`` and optionally to ``pid``.

    Does nothing when the process no longer exists. Failures to signal an
    individual process (``OSError``) are ignored.
    """
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        # Process already terminated
        return

    # Descendants are signalled first, the root process last.
    targets = list(root.children(recursive=True))
    if including_parent:
        targets.append(root)

    for proc in targets:
        try:
            os.kill(proc.pid, signal.SIGTERM)  # or signal.SIGKILL
        except OSError:
            pass
360
+
361
+
362
def terminate_process(pid):
    """Stop process ``pid`` together with its children.

    On Windows this shells out to ``taskkill``; elsewhere it delegates to
    ``terminate_process_tree``.
    """
    if system != "Windows":
        terminate_process_tree(pid)
        return
    os.system(f"taskkill /t /f /pid {pid}")
368
+
369
+
370
def start_training(
    dataset_name="",
    exp_name="F5TTS_Base",
    learning_rate=1e-4,
    batch_size_per_gpu=400,
    batch_size_type="frame",
    max_samples=64,
    grad_accumulation_steps=1,
    max_grad_norm=1.0,
    epochs=11,
    num_warmup_updates=200,
    save_per_updates=400,
    last_per_steps=800,
    finetune=True,
    file_checkpoint_train="",
    tokenizer_type="pinyin",
    tokenizer_file="",
    mixed_precision="fp16",
    stream=False,
    logger="wandb",
    ch_8bit_adam=False,
):
    """Launch ``accelerate launch finetune_cli.py`` and report its progress.

    This is a generator. Every item it yields is a 3-tuple of a status
    string and two ``gr.update(interactive=...)`` objects (presumably for the
    start and stop buttons — confirm against the UI wiring).

    The settings are persisted with ``save_settings`` before the subprocess
    is started. With ``stream=False`` the function blocks until the process
    exits; with ``stream=True`` stdout/stderr are read on background threads
    and forwarded line by line, with tqdm-style progress lines condensed into
    a short summary.
    """
    global training_process, tts_api, stop_signal

    # Drop any loaded inference model so its memory is free for training.
    if tts_api is not None:
        tts_api = None
        gc.collect()
        torch.cuda.empty_cache()

    path_project = os.path.join(path_data, dataset_name)

    if not os.path.isdir(path_project):
        yield (
            f"There is not project with name {dataset_name}",
            gr.update(interactive=True),
            gr.update(interactive=False),
        )
        return

    file_raw = os.path.join(path_project, "raw.arrow")
    if not os.path.isfile(file_raw):
        yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
        return

    # Check if a training process is already running
    if training_process is not None:
        # A value passed to ``return`` inside a generator is never delivered
        # to the consumer, so the message has to be yielded first.
        yield "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
        return

    yield "start train", gr.update(interactive=False), gr.update(interactive=False)

    # Command to run the training script with the specified arguments

    if tokenizer_file == "":
        if dataset_name.endswith("_pinyin"):
            tokenizer_type = "pinyin"
        elif dataset_name.endswith("_char"):
            tokenizer_type = "char"
    else:
        tokenizer_type = "custom"

    dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")

    if mixed_precision != "none":
        fp16 = f"--mixed_precision={mixed_precision}"
    else:
        fp16 = ""

    # NOTE(review): the command is a shell string run with shell=True, so
    # values containing spaces or shell metacharacters are not escaped.
    cmd = (
        f"accelerate launch {fp16} {file_train} --exp_name {exp_name} "
        f"--learning_rate {learning_rate} "
        f"--batch_size_per_gpu {batch_size_per_gpu} "
        f"--batch_size_type {batch_size_type} "
        f"--max_samples {max_samples} "
        f"--grad_accumulation_steps {grad_accumulation_steps} "
        f"--max_grad_norm {max_grad_norm} "
        f"--epochs {epochs} "
        f"--num_warmup_updates {num_warmup_updates} "
        f"--save_per_updates {save_per_updates} "
        f"--last_per_steps {last_per_steps} "
        f"--dataset_name {dataset_name}"
    )

    cmd += f" --finetune {finetune}"

    if file_checkpoint_train != "":
        cmd += f" --pretrain {file_checkpoint_train}"

    if tokenizer_file != "":
        cmd += f" --tokenizer_path {tokenizer_file}"

    cmd += f" --tokenizer {tokenizer_type} "

    cmd += f" --log_samples True --logger {logger} "

    if ch_8bit_adam:
        cmd += " --bnb_optimizer True "

    print("run command : \n" + cmd + "\n")

    save_settings(
        dataset_name,
        exp_name,
        learning_rate,
        batch_size_per_gpu,
        batch_size_type,
        max_samples,
        grad_accumulation_steps,
        max_grad_norm,
        epochs,
        num_warmup_updates,
        save_per_updates,
        last_per_steps,
        finetune,
        file_checkpoint_train,
        tokenizer_type,
        tokenizer_file,
        mixed_precision,
        logger,
        ch_8bit_adam,
    )

    try:
        if not stream:
            # Start the training process
            training_process = subprocess.Popen(cmd, shell=True)

            time.sleep(5)
            yield "train start", gr.update(interactive=False), gr.update(interactive=True)

            # Wait for the training process to finish
            training_process.wait()
        else:

            def stream_output(pipe, output_queue):
                """Forward every line read from ``pipe`` into ``output_queue``."""
                try:
                    for line in iter(pipe.readline, ""):
                        output_queue.put(line)
                except Exception as e:
                    output_queue.put(f"Error reading pipe: {str(e)}")
                finally:
                    pipe.close()

            env = os.environ.copy()
            env["PYTHONUNBUFFERED"] = "1"

            training_process = subprocess.Popen(
                cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
            )
            yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)

            stdout_queue = queue.Queue()
            stderr_queue = queue.Queue()

            stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue))
            stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue))
            stdout_thread.daemon = True
            stderr_thread.daemon = True
            stdout_thread.start()
            stderr_thread.start()

            # Compiled once here instead of on every polled line.
            progress_pattern = re.compile(
                r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), step=(\d+)"
            )

            stop_signal = False
            while True:
                if stop_signal:
                    training_process.terminate()
                    time.sleep(0.5)
                    if training_process.poll() is None:
                        training_process.kill()
                    yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False)
                    break

                process_status = training_process.poll()

                # Handle stdout
                try:
                    while True:
                        output = stdout_queue.get_nowait()
                        print(output, end="")
                        match = progress_pattern.search(output)
                        if match:
                            current_epoch = match.group(1)
                            total_epochs = match.group(2)
                            percent_complete = match.group(3)
                            elapsed_time = match.group(4)
                            loss = match.group(5)
                            current_step = match.group(6)
                            message = (
                                f"Epoch: {current_epoch}/{total_epochs}, "
                                f"Progress: {percent_complete}%, "
                                f"Elapsed Time: {elapsed_time}, "
                                f"Loss: {loss}, "
                                f"Step: {current_step}"
                            )
                            yield message, gr.update(interactive=False), gr.update(interactive=True)
                        elif output.strip():
                            yield output, gr.update(interactive=False), gr.update(interactive=True)
                except queue.Empty:
                    pass

                # Handle stderr
                try:
                    while True:
                        error_output = stderr_queue.get_nowait()
                        print(error_output, end="")
                        if error_output.strip():
                            yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True)
                except queue.Empty:
                    pass

                # Only finish once the process has exited AND both queues are
                # drained, so trailing output is not lost.
                if process_status is not None and stdout_queue.empty() and stderr_queue.empty():
                    if process_status != 0:
                        yield (
                            f"Process crashed with exit code {process_status}!",
                            gr.update(interactive=False),
                            gr.update(interactive=True),
                        )
                    else:
                        yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
                    break

                # Small sleep to prevent CPU thrashing
                time.sleep(0.1)

            # Clean up
            training_process.stdout.close()
            training_process.stderr.close()
            training_process.wait()

        time.sleep(1)

        if training_process is None:
            text_info = "train stop"
        else:
            text_info = "train complete !"

    except Exception as e:  # Catch all exceptions
        # Ensure that we reset the training process variable in case of an error
        text_info = f"An error occurred: {str(e)}"

    training_process = None

    yield text_info, gr.update(interactive=True), gr.update(interactive=False)
615
+
616
+
617
def stop_training():
    """Ask the running training process to stop.

    Returns a status string plus two ``gr.update(interactive=...)`` objects.
    When a process is running its process tree is sent SIGTERM and the
    module-level ``stop_signal`` flag is raised; ``training_process`` itself
    is left untouched here.
    """
    global training_process, stop_signal

    if training_process is not None:
        terminate_process_tree(training_process.pid)
        stop_signal = True
        return "train stop", gr.update(interactive=True), gr.update(interactive=False)

    return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
626
+
627
+
628
def get_list_projects():
    """List the project directories found directly under ``path_data``.

    The ``Emilia_ZH_EN_pinyin`` directory is skipped (matched
    case-insensitively). Names are returned exactly as they exist on disk so
    that paths built from them resolve on case-sensitive filesystems.

    Returns:
        ``(project_list, selected)`` where ``selected`` is the last entry of
        ``project_list`` or ``None`` when no project exists.
    """
    project_list = []
    for folder in os.listdir(path_data):
        if not os.path.isdir(os.path.join(path_data, folder)):
            continue
        # Lower-case only for the comparison; keep the real name in the list.
        if folder.lower() == "emilia_zh_en_pinyin":
            continue
        project_list.append(folder)

    projects_selelect = project_list[-1] if project_list else None

    return project_list, projects_selelect
642
+
643
+
644
def create_data_project(name, tokenizer_type):
    """Create the folder layout for a new project and refresh the project list.

    The project directory is named ``<name>_<tokenizer_type>`` and receives a
    ``dataset`` sub-directory. Returns a ``gr.update`` carrying the current
    project choices with the new project selected.
    """
    name = name + "_" + tokenizer_type
    project_dir = os.path.join(path_data, name)
    os.makedirs(project_dir, exist_ok=True)
    os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
    # Only the list of projects is needed; the suggested selection is ignored.
    project_list, _ = get_list_projects()
    return gr.update(choices=project_list, value=name)
650
+
651
+
652
def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
    """Slice audio into segments, transcribe them and write ``metadata.csv``.

    Args:
        name_project: Project directory name under ``path_data``.
        audio_files: Audio file paths to process when ``user`` is false.
        language: Passed through to ``transcribe``.
        user: When true, audio is taken from the project's ``dataset``
            directory instead of ``audio_files``.
        progress: Gradio progress tracker.

    Any existing ``wavs`` directory and ``metadata.csv`` of the project are
    deleted first. Each segment is written as a 24 kHz 16-bit WAV and gets a
    ``name|text`` line in the metadata file.

    Returns:
        A human-readable summary string, or an error message when there is
        nothing to process.
    """
    path_project = os.path.join(path_data, name_project)
    path_dataset = os.path.join(path_project, "dataset")
    path_project_wavs = os.path.join(path_project, "wavs")
    file_metadata = os.path.join(path_project, "metadata.csv")

    if not user and audio_files is None:
        return "You need to load an audio file."

    if os.path.isdir(path_project_wavs):
        shutil.rmtree(path_project_wavs)

    if os.path.isfile(file_metadata):
        os.remove(file_metadata)

    os.makedirs(path_project_wavs, exist_ok=True)

    if user:
        file_audios = [
            file
            for pattern in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
            for file in glob(os.path.join(path_dataset, pattern))
        ]
        if not file_audios:
            return "No audio file was found in the dataset."
    else:
        file_audios = audio_files

    alpha = 0.5
    _max = 1.0
    slicer = Slicer(24000)

    num = 0
    error_num = 0
    metadata_lines = []
    for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len(file_audios)):
        audio, _ = librosa.load(file_audio, sr=24000, mono=True)

        list_slicer = slicer.slice(audio)
        for chunk, _start, _end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
            if chunk.size == 0:
                # Nothing to write or transcribe for an empty slice.
                continue

            name_segment = f"segment_{num}"
            file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")

            tmp_max = np.abs(chunk).max()
            if tmp_max > 1:
                chunk /= tmp_max
            if tmp_max > 0:
                # Blend a peak-scaled copy with the original signal. Skipped
                # for an all-zero chunk, where dividing by the peak would
                # produce NaNs.
                chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
            wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))

            try:
                text = transcribe(file_segment, language)
                text = text.lower().strip().replace('"', "")

                metadata_lines.append(f"{name_segment}|{text}\n")

                num += 1
            except Exception as e:
                # Best effort: count the failure and continue with the next
                # segment. ``num`` is not advanced, so the next segment
                # reuses this file name.
                print(f"Error transcribing {file_segment}: {e}")
                error_num += 1

    with open(file_metadata, "w", encoding="utf-8-sig") as f:
        f.write("".join(metadata_lines))

    # Only mention errors when at least one segment failed.
    error_text = f"\nerror files : {error_num}" if error_num else ""

    return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
721
+
722
+
723
def format_seconds_to_hms(seconds):
    """Format a duration given in seconds as a zero-padded ``HH:MM:SS`` string.

    Fractional seconds are truncated.
    """
    whole_hours = int(seconds / 3600)
    whole_minutes = int((seconds % 3600) / 60)
    leftover_seconds = int(seconds % 60)
    return f"{whole_hours:02d}:{whole_minutes:02d}:{leftover_seconds:02d}"
728
+
729
+
730
def get_correct_audio_path(
    audio_input,
    base_path="wavs",
    supported_formats=("wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"),
):
    """Resolve a metadata audio reference to a file path.

    Args:
        audio_input: Absolute path, relative file name, or bare name without
            extension.
        base_path: Directory that relative references are resolved against.
        supported_formats: Extensions (without dot) that are recognised,
            in probing order.

    Returns:
        * ``audio_input`` itself when it is absolute and has a supported
          extension;
        * ``base_path/audio_input`` when it is relative with a supported
          extension;
        * otherwise the first existing ``<name>.<ext>`` candidate, falling
          back to the first supported extension when none exists.

        A path string is always returned (never ``None``), so callers can
        pass the result straight to ``os.path.isfile``. The extension check
        ignores case.
    """
    suffixes = tuple(f".{ext.lower()}" for ext in supported_formats)
    is_absolute = os.path.isabs(audio_input)

    if audio_input.lower().endswith(suffixes):
        return audio_input if is_absolute else os.path.join(base_path, audio_input)

    # No recognised extension: probe each supported one. Absolute inputs are
    # probed in place, relative ones under base_path.
    stem = audio_input if is_absolute else os.path.join(base_path, audio_input)
    for ext in supported_formats:
        candidate = f"{stem}.{ext}"
        if os.path.exists(candidate):
            return candidate
    return f"{stem}.{supported_formats[0]}"
759
+
760
+
761
def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
    """Build the training files of a project from its ``metadata.csv``.

    Each ``name|text`` line is validated (file exists, duration between 1 and
    25 seconds, text of at least 3 characters), the text is normalised and
    passed through ``convert_char_to_pinyin``, and the accepted samples are
    written to ``raw.arrow`` together with ``duration.json``.

    Args:
        name_project: Project directory name under ``path_data``.
        ch_tokenizer: When true, ``vocab.txt`` is regenerated from the
            characters seen in the accepted texts; otherwise an existing
            ``vocab.txt`` is used, copied from ``Emilia_ZH_EN_pinyin`` when
            the project has none.
        progress: Gradio progress tracker.

    Returns:
        ``(report, new_vocab)``: a human-readable report (or error message)
        and the newly written vocabulary text (empty unless ``ch_tokenizer``).
    """
    path_project = os.path.join(path_data, name_project)
    path_project_wavs = os.path.join(path_project, "wavs")
    file_metadata = os.path.join(path_project, "metadata.csv")
    file_raw = os.path.join(path_project, "raw.arrow")
    file_duration = os.path.join(path_project, "duration.json")
    file_vocab = os.path.join(path_project, "vocab.txt")

    if not os.path.isfile(file_metadata):
        return "The file was not found in " + file_metadata, ""

    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        lines = f.read().split("\n")

    duration_list = []
    total_duration = 0
    result = []
    error_files = []
    text_vocab_set = set()
    # ``total`` must be a number, not the list of lines itself.
    for line in progress.tqdm(lines, total=len(lines)):
        sp_line = line.split("|")
        if len(sp_line) != 2:
            continue
        name_audio, text = sp_line

        file_audio = get_correct_audio_path(name_audio, path_project_wavs)

        if not os.path.isfile(file_audio):
            error_files.append([file_audio, "error path"])
            continue

        try:
            duration = get_audio_duration(file_audio)
        except Exception as e:
            error_files.append([file_audio, "duration"])
            print(f"Error processing {file_audio}: {e}")
            continue

        if duration > 25:
            error_files.append([file_audio, "duration > 25 sec"])
            continue
        if duration < 1:
            error_files.append([file_audio, "duration < 1 sec "])
            continue
        if len(text) < 3:
            error_files.append([file_audio, "very small text len 3"])
            continue

        text = clear_text(text)
        text = convert_char_to_pinyin([text], polyphone=True)[0]

        duration_list.append(duration)

        result.append({"audio_path": file_audio, "text": text, "duration": duration})
        if ch_tokenizer:
            text_vocab_set.update(list(text))

        total_duration += duration

    if not duration_list:
        return f"Error: No audio files found in the specified path : {path_project_wavs}", ""

    min_second = round(min(duration_list), 2)
    max_second = round(max(duration_list), 2)

    with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
        for sample in progress.tqdm(result, total=len(result), desc="prepare data"):
            writer.write(sample)

    with open(file_duration, "w") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    new_vocal = ""
    if not ch_tokenizer:
        if not os.path.isfile(file_vocab):
            file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
            if not os.path.isfile(file_vocab_finetune):
                return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", ""
            shutil.copy2(file_vocab_finetune, file_vocab)

        with open(file_vocab, "r", encoding="utf-8-sig") as f:
            # Strip only the line terminator: a final line without a trailing
            # newline must keep its last character.
            vocab_char_map = {char.rstrip("\n"): i for i, char in enumerate(f)}
        vocab_size = len(vocab_char_map)

    else:
        new_vocal = "".join(vocab + "\n" for vocab in sorted(text_vocab_set))
        with open(file_vocab, "w", encoding="utf-8-sig") as f:
            f.write(new_vocal)
        vocab_size = len(text_vocab_set)

    if error_files:
        error_text = "\n".join([" = ".join(item) for item in error_files])
    else:
        error_text = ""

    return (
        f"prepare complete \nsamples : {len(result)}\ntime data : {format_seconds_to_hms(total_duration)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\nvocab : {vocab_size}\n{error_text}",
        new_vocal,
    )
869
+
870
+
871
def check_user(value):
    """Build two Gradio visibility updates that mirror each other.

    The first update is visible when ``value`` is falsy, the second carries
    ``value`` itself as its visibility flag.
    """
    hide_when_set = gr.update(visible=not value)
    show_when_set = gr.update(visible=value)
    return hide_when_set, show_when_set
873
+
874
+
875
def calculate_train(
    name_project,
    batch_size_type,
    max_samples,
    learning_rate,
    num_warmup_updates,
    save_per_updates,
    last_per_steps,
    finetune,
):
    """Suggest training settings for a project from its ``duration.json``.

    Args:
        name_project: Project folder name, resolved under ``path_data``.
        batch_size_type: ``"frame"`` selects the frame-based batch heuristic;
            any other value selects the sample-based one.
        max_samples: Current UI value, echoed back if the project is missing.
        learning_rate: Current UI value, echoed back if the project is missing.
        num_warmup_updates: Current UI value, echoed back if the project is missing.
        save_per_updates: Current UI value, echoed back if the project is missing.
        last_per_steps: Current UI value, echoed back if the project is missing.
        finetune: Truthy picks a learning rate of 1e-5, otherwise 7.5e-5.

    Returns:
        An 8-tuple ``(batch_size_per_gpu, max_samples, num_warmup_updates,
        save_per_updates, last_per_steps, samples, learning_rate, epochs)``.
        When ``duration.json`` does not exist the sixth item is the message
        ``"project not found !"`` and the epochs slot is an empty
        ``gr.update()`` so that field is left untouched.
    """

    def make_even(num):
        # The suggested values are always rounded up to an even number.
        return num + 1 if num % 2 != 0 else num

    path_project = os.path.join(path_data, name_project)
    file_duration = os.path.join(path_project, "duration.json")

    if not os.path.isfile(file_duration):
        # The caller wires eight output components, so eight values are
        # required here as well.
        return (
            1000,
            max_samples,
            num_warmup_updates,
            save_per_updates,
            last_per_steps,
            "project not found !",
            learning_rate,
            gr.update(),
        )

    with open(file_duration, "r") as file:
        data = json.load(file)

    duration_list = data["duration"]
    samples = len(duration_list)
    hours = sum(duration_list) / 3600

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        total_memory = 0
        for i in range(gpu_count):
            total_memory += torch.cuda.get_device_properties(i).total_memory / (1024**3)  # in GB
    else:
        # MPS and plain-CPU hosts are both sized from available system RAM.
        # Previously the CPU-only case left these names unbound.
        gpu_count = 1
        total_memory = psutil.virtual_memory().available / (1024**3)

    if batch_size_type == "frame":
        # Floor at 2 so hosts with under 2 GB do not divide by zero below.
        batch = max(make_even(int(total_memory * 0.5)), 2)
        # NOTE(review): the frame budget shrinks as memory grows, which looks
        # inverted; kept as-is, confirm the intended heuristic.
        batch_size_per_gpu = int(38400 / batch)
    else:
        batch_size_per_gpu = make_even(int(total_memory / 8))

    if batch_size_per_gpu <= 0:
        batch_size_per_gpu = 1

    if samples < 64:
        max_samples = int(samples * 0.25)
    else:
        max_samples = 64

    num_warmup_updates = int(samples * 0.05)
    save_per_updates = int(samples * 0.10)
    last_per_steps = int(save_per_updates * 0.25)

    max_samples = make_even(max_samples)
    num_warmup_updates = make_even(num_warmup_updates)
    save_per_updates = make_even(save_per_updates)
    last_per_steps = make_even(last_per_steps)
    if last_per_steps <= 0:
        last_per_steps = 2

    total_hours = hours
    mel_hop_length = 256
    mel_sampling_rate = 24000

    # target
    wanted_max_updates = 1000000

    # train params
    gpus = gpu_count
    frames_per_gpu = batch_size_per_gpu
    grad_accum = 1

    # intermediate
    mini_batch_frames = frames_per_gpu * grad_accum * gpus
    mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
    updates_per_epoch = total_hours / mini_batch_hours

    if updates_per_epoch > 0:
        epochs = int(wanted_max_updates / updates_per_epoch)
    else:
        # No audio duration recorded: leave the epochs field unchanged
        # instead of dividing by zero.
        epochs = gr.update()

    if finetune:
        learning_rate = 1e-5
    else:
        learning_rate = 7.5e-5

    return (
        batch_size_per_gpu,
        max_samples,
        num_warmup_updates,
        save_per_updates,
        last_per_steps,
        samples,
        learning_rate,
        epochs,
    )
985
+
986
+
987
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
    """Write a checkpoint that contains only the ``ema_model_state_dict`` entry.

    Args:
        checkpoint_path: Path of the full checkpoint to read.
        new_checkpoint_path: Destination path. A trailing ``.pt`` extension is
            swapped for ``.safetensors`` (or the reverse) to match the chosen
            format; any other path is used unchanged.
        safetensors: Truthy saves the state dict with ``save_file``; otherwise
            ``torch.save`` stores ``{"ema_model_state_dict": ...}``.

    Returns:
        A status message: the saved path on success, a notice when the key is
        absent, or ``"An error occurred: ..."`` for any exception.
    """
    try:
        # NOTE(review): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("Original Checkpoint Keys:", checkpoint.keys())

        ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
        if ema_model_state_dict is None:
            return "No 'ema_model_state_dict' found in the checkpoint."

        # Only the real file extension is rewritten; a plain str.replace would
        # also alter directory names or extensions such as ".pth".
        root, ext = os.path.splitext(new_checkpoint_path)

        if safetensors:
            if ext == ".pt":
                new_checkpoint_path = root + ".safetensors"
            save_file(ema_model_state_dict, new_checkpoint_path)
        else:
            if ext == ".safetensors":
                new_checkpoint_path = root + ".pt"
            new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
            torch.save(new_checkpoint, new_checkpoint_path)

        return f"New checkpoint saved at: {new_checkpoint_path}"

    except Exception as e:
        return f"An error occurred: {e}"
1008
+
1009
+
1010
def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
    """Append randomly initialised rows to the EMA text-embedding matrix.

    The checkpoint at ``ckpt_path`` is loaded on CPU, the tensor stored under
    ``ema_model.transformer.text_embed.text_embed.weight`` inside
    ``ema_model_state_dict`` is grown by ``num_new_tokens`` rows, and the
    whole checkpoint is written to ``new_ckpt_path``.

    Returns:
        The row count of the enlarged embedding matrix.
    """
    # Seed every generator with a fixed value so the new rows are reproducible.
    seed = 666
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    checkpoint = torch.load(ckpt_path, map_location="cpu")

    state = checkpoint.get("ema_model_state_dict", {})
    weight_key = "ema_model.transformer.text_embed.text_embed.weight"
    old_weight = state[weight_key]

    old_rows = old_weight.size(0)
    width = old_weight.size(1)
    new_rows = old_rows + num_new_tokens

    # Existing rows are copied verbatim; the added rows are standard-normal noise.
    grown = torch.zeros((new_rows, width))
    grown[:old_rows] = old_weight
    grown[old_rows:] = torch.randn((num_new_tokens, width))
    state[weight_key] = grown

    torch.save(checkpoint, new_ckpt_path)

    return new_rows
1041
+
1042
+
1043
def vocab_count(text):
    """Return, as a string, how many symbols the comma-separated ``text`` holds.

    Entries that are empty once spaces are removed are not counted, so an
    empty textbox yields ``"0"`` rather than ``"1"`` and a trailing comma does
    not add a phantom symbol.
    """
    return str(sum(1 for item in text.split(",") if item.replace(" ", "")))
1045
+
1046
+
1047
def vocab_extend(project_name, symbols, model_type):
    """Extend the base vocabulary with new symbols and grow the base checkpoint.

    Reads ``Emilia_ZH_EN_pinyin/vocab.txt`` under ``path_data``, appends every
    requested symbol that is not already present, writes the result to the
    project's ``vocab.txt``, and calls ``expand_model_embeddings`` on the
    downloaded base checkpoint so its text embedding gains one row per added
    symbol.

    Args:
        project_name: Project folder name under ``path_data``.
        symbols: Comma-separated symbols; spaces inside each entry are removed.
        model_type: ``"F5-TTS"`` selects the F5-TTS base checkpoint; any other
            value selects the E2-TTS one.

    Returns:
        A human-readable status message.
    """
    if symbols == "":
        return "Symbols empty!"

    name_project = project_name
    path_project = os.path.join(path_data, name_project)
    file_vocab_project = os.path.join(path_project, "vocab.txt")

    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
    if not os.path.isfile(file_vocab):
        return f"the file {file_vocab} not found !"

    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        vocab = f.read().split("\n")
    known = set(vocab)

    # Keep first-seen order, skip blanks, and skip repeats so a symbol typed
    # twice does not add two vocabulary lines and two embedding rows.
    miss_symbols = []
    for item in symbols.split(","):
        item = item.replace(" ", "")
        if not item or item in known:
            continue
        known.add(item)
        miss_symbols.append(item)

    if not miss_symbols:
        return "Symbols are okay no need to extend."

    size_vocab = len(vocab)
    # Drop only the empty entry produced by a trailing newline; a file without
    # one must not lose its last real symbol.
    if vocab and vocab[-1] == "":
        vocab.pop()
    vocab.extend(miss_symbols)
    vocab.append("")

    with open(file_vocab_project, "w", encoding="utf-8") as f:
        f.write("\n".join(vocab))

    if model_type == "F5-TTS":
        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
    else:
        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))

    vocab_size_new = len(miss_symbols)

    dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
    new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
    os.makedirs(new_ckpt_path, exist_ok=True)
    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")

    size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)

    vocab_new = "\n".join(miss_symbols)
    return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
1104
+
1105
+
1106
def vocab_check(project_name):
    """List characters used by the project's transcripts that the base vocab lacks.

    Compares every character of the lower-cased, stripped text column of the
    project's ``metadata.csv`` (``audio|text`` rows; other rows are ignored)
    against the lines of ``Emilia_ZH_EN_pinyin/vocab.txt``.

    Returns:
        ``(info, missing)`` where ``missing`` is a comma-joined string of the
        absent characters in first-seen order, or ``""`` when none are missing
        or a required file does not exist.
    """
    path_project = os.path.join(path_data, project_name)
    file_metadata = os.path.join(path_project, "metadata.csv")

    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
    if not os.path.isfile(file_vocab):
        return f"the file {file_vocab} not found !", ""

    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        vocab = set(f.read().split("\n"))

    if not os.path.isfile(file_metadata):
        return f"the file {file_metadata} not found !", ""

    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        rows = f.read().split("\n")

    miss_symbols = []
    already_reported = set()
    for row in rows:
        fields = row.split("|")
        if len(fields) != 2:
            continue

        for char in fields[1].lower().strip():
            if char in vocab or char in already_reported:
                continue
            already_reported.add(char)
            miss_symbols.append(char)

    if miss_symbols:
        info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
        return info, ",".join(miss_symbols)

    return "You can train using your language !", ""
1149
+
1150
+
1151
def get_random_sample_prepare(project_name):
    """Pick one random row of the project's ``raw.arrow`` for preview.

    Returns:
        ``(text, audio_path)``: the row's text items rendered as a bracketed,
        quoted list and the row's audio path, or ``("", None)`` when
        ``raw.arrow`` does not exist.
    """
    file_arrow = os.path.join(path_data, project_name, "raw.arrow")
    if not os.path.isfile(file_arrow):
        return "", None

    dataset = Dataset_.from_file(file_arrow)
    # Shuffle with a random seed and keep the first row.
    picked = dataset.shuffle(seed=random.randint(0, 1000)).select([0])

    quoted = ["' " + token + " '" for token in picked["text"][0]]
    text = "[" + " , ".join(quoted) + "]"
    return text, picked["audio_path"][0]
1162
+
1163
+
1164
def get_random_sample_transcribe(project_name):
    """Pick one random ``audio|text`` row from the project's ``metadata.csv``.

    Rows that do not split into exactly two ``|``-separated fields are
    ignored. The audio field is resolved through ``get_correct_audio_path``
    against the project's ``wavs`` folder.

    Returns:
        ``(text, audio_path)``, or ``("", None)`` when the file is missing or
        holds no usable row.
    """
    path_project = os.path.join(path_data, project_name)
    file_metadata = os.path.join(path_project, "metadata.csv")
    if not os.path.isfile(file_metadata):
        return "", None

    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        rows = f.read().split("\n")

    candidates = []
    for row in rows:
        fields = row.split("|")
        if len(fields) == 2:
            # Resolve the audio entry against the project's wavs folder.
            resolved = get_correct_audio_path(fields[0], os.path.join(path_project, "wavs"))
            candidates.append([resolved, fields[1]])

    if not candidates:
        return "", None

    chosen = random.choice(candidates)
    return chosen[1], chosen[0]
1191
+
1192
+
1193
def get_random_sample_infer(project_name):
    """Return a random transcript twice plus its audio path.

    Delegates to ``get_random_sample_transcribe`` and repeats the text so it
    can fill two text outputs followed by one audio output.
    """
    sample_text, sample_audio = get_random_sample_transcribe(project_name)
    return sample_text, sample_text, sample_audio
1200
+
1201
+
1202
def infer(
    project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence
):
    """Synthesize ``gen_text`` with a project checkpoint and return the wav path.

    The ``F5TTS`` instance is cached in module globals and rebuilt only when
    the checkpoint, the target device, or the EMA flag differs from the last
    successful load. While ``training_process`` is set, the device is forced
    to ``"cpu"``; otherwise ``None`` is passed to ``F5TTS``.

    Args:
        project: Project folder name; its ``vocab.txt`` under ``path_data`` is used.
        file_checkpoint: Path of the checkpoint file to load.
        exp_name: Passed to ``F5TTS`` as ``model_type``.
        ref_text: Reference transcript; lower-cased and stripped before use.
        ref_audio: Reference audio file path.
        gen_text: Text to synthesize; lower-cased and stripped before use.
        nfe_step: Forwarded to ``tts_api.infer`` unchanged.
        use_ema: Forwarded to ``F5TTS`` unchanged.
        speed: Forwarded to ``tts_api.infer`` unchanged.
        seed: Forwarded to ``tts_api.infer`` unchanged.
        remove_silence: Forwarded to ``tts_api.infer`` unchanged.

    Returns:
        ``(wav_path, device, seed_str)``. When the checkpoint is missing the
        tuple is ``(None, "checkpoint not found!", "")`` so that all three
        output components still receive a value.
    """
    global last_checkpoint, last_device, tts_api, last_ema

    # An empty selection arrives as None, which os.path.isfile rejects.
    if not file_checkpoint or not os.path.isfile(file_checkpoint):
        return None, "checkpoint not found!", ""

    device_test = "cpu" if training_process is not None else None

    if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None:
        vocab_file = os.path.join(path_data, project, "vocab.txt")

        tts_api = F5TTS(
            model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
        )

        # Record the cache keys only after the model was built; if loading
        # raises, the next call retries instead of reusing a stale model.
        last_checkpoint = file_checkpoint
        last_device = device_test
        last_ema = use_ema

        print("update >> ", device_test, file_checkpoint, use_ema)

    # Reserve a file name, then close the handle before inference writes to
    # it; reopening a still-open temporary file is not portable.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        wave_path = f.name

    tts_api.infer(
        gen_text=gen_text.lower().strip(),
        ref_text=ref_text.lower().strip(),
        ref_file=ref_audio,
        nfe_step=nfe_step,
        file_wave=wave_path,
        speed=speed,
        seed=seed,
        remove_silence=remove_silence,
    )
    return wave_path, tts_api.device, str(tts_api.seed)
1245
+
1246
+
1247
def check_finetune(finetune):
    """Return three Gradio updates that set ``interactive`` to ``finetune``."""
    return tuple(gr.update(interactive=finetune) for _ in range(3))
1249
+
1250
+
1251
def get_checkpoints_project(project_name, is_gradio=True):
    """List a project's ``*.pt`` checkpoints in training-step order.

    The ``_pinyin`` / ``_char`` suffix is removed from ``project_name`` to
    find the folder under ``path_project_ckpts``. Files named
    ``model_<step>.pt`` are ordered by step, ``model_last.pt`` comes last, and
    any other ``.pt`` file is placed first instead of raising.

    Args:
        project_name: Project name, or ``None`` when nothing is selected.
        is_gradio: When true, return a ``gr.update`` for a dropdown;
            otherwise return the raw values.

    Returns:
        ``gr.update(choices=..., value=...)`` when ``is_gradio`` is true, else
        ``(files, selected)``; ``selected`` is the first file, or ``None`` if
        there are none. A ``None`` project yields ``([], "")`` in raw mode.
    """

    def order(path):
        name = os.path.basename(path)
        if name == "model_last.pt":
            return (2, 0)
        try:
            return (1, int(name.split("_")[1].split(".")[0]))
        except (IndexError, ValueError):
            # Unrecognised file names sort ahead of the numbered checkpoints.
            return (0, 0)

    if project_name is None:
        # A dropdown output needs an update object, not a 2-tuple.
        if is_gradio:
            return gr.update(choices=[], value=None)
        return [], ""

    project_name = project_name.replace("_pinyin", "").replace("_char", "")

    if os.path.isdir(path_project_ckpts):
        files_checkpoints = sorted(glob(os.path.join(path_project_ckpts, project_name, "*.pt")), key=order)
    else:
        files_checkpoints = []

    selelect_checkpoint = None if not files_checkpoints else files_checkpoints[0]

    if is_gradio:
        return gr.update(choices=files_checkpoints, value=selelect_checkpoint)

    return files_checkpoints, selelect_checkpoint
1273
+
1274
+
1275
def get_audio_project(project_name, is_gradio=True):
    """List the sample-audio prefixes found in a project's ``samples`` folder.

    Looks for ``*_gen.wav`` files under
    ``path_project_ckpts/<project>/samples`` (project name with the
    ``_pinyin`` / ``_char`` suffix removed), orders them by the integer that
    follows the first underscore of the file name, and strips the
    ``_gen.wav`` suffix from each path. Files whose name carries no such
    integer are placed first instead of raising.

    Args:
        project_name: Project name, or ``None`` when nothing is selected.
        is_gradio: When true, return a ``gr.update`` for a dropdown;
            otherwise return the raw values.

    Returns:
        ``gr.update(choices=..., value=...)`` when ``is_gradio`` is true, else
        ``(prefixes, selected)``; ``selected`` is the first prefix, or
        ``None`` if there are none. A ``None`` project yields ``([], "")`` in
        raw mode.
    """
    suffix = "_gen.wav"

    def step_of(path):
        try:
            return int(os.path.basename(path).split("_")[1].split(".")[0])
        except (IndexError, ValueError):
            return -1

    if project_name is None:
        # A dropdown output needs an update object, not a 2-tuple.
        if is_gradio:
            return gr.update(choices=[], value=None)
        return [], ""

    project_name = project_name.replace("_pinyin", "").replace("_char", "")

    if os.path.isdir(path_project_ckpts):
        found = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav"))
        generated = sorted((item for item in found if item.endswith(suffix)), key=step_of)
        # Cut only the trailing suffix; directory names are left untouched.
        files_audios = [item[: -len(suffix)] for item in generated]
    else:
        files_audios = []

    selelect_checkpoint = None if not files_audios else files_audios[0]

    if is_gradio:
        return gr.update(choices=files_audios, value=selelect_checkpoint)

    return files_audios, selelect_checkpoint
1294
+
1295
+
1296
def get_gpu_stats():
    """Describe the available accelerator(s) as a multi-line string.

    With CUDA, one section per device reports its name, total memory in GB,
    and the allocated and reserved memory in MB. With MPS, total system
    memory is reported and the allocated/reserved figures are shown as zero.
    Otherwise the string ``"No GPU available"`` is returned.
    """
    if torch.cuda.is_available():
        sections = []
        for idx in range(torch.cuda.device_count()):
            name = torch.cuda.get_device_name(idx)
            props = torch.cuda.get_device_properties(idx)
            total_gb = props.total_memory / (1024**3)
            allocated_mb = torch.cuda.memory_allocated(idx) / (1024**2)
            reserved_mb = torch.cuda.memory_reserved(idx) / (1024**2)
            sections.append(
                f"GPU {idx} Name: {name}\n"
                f"Total GPU memory (GPU {idx}): {total_gb:.2f} GB\n"
                f"Allocated GPU memory (GPU {idx}): {allocated_mb:.2f} MB\n"
                f"Reserved GPU memory (GPU {idx}): {reserved_mb:.2f} MB\n\n"
            )
        return "".join(sections)

    if torch.backends.mps.is_available():
        # System RAM is reported for MPS; no separate device total is queried.
        total_gb = psutil.virtual_memory().total / (1024**3)
        allocated_mb = 0
        reserved_mb = 0
        return (
            "MPS GPU\n"
            f"Total system memory: {total_gb:.2f} GB\n"
            f"Allocated GPU memory (MPS): {allocated_mb:.2f} MB\n"
            f"Reserved GPU memory (MPS): {reserved_mb:.2f} MB\n"
        )

    return "No GPU available"
1334
+
1335
+
1336
def get_cpu_stats():
    """Report CPU usage, system memory, and this process's nice value.

    CPU usage is sampled with ``psutil.cpu_percent(interval=1)``, so the call
    blocks for that sampling interval. Memory figures are shown in MB.
    """
    cpu_usage = psutil.cpu_percent(interval=1)

    memory = psutil.virtual_memory()
    mebibyte = 1024**2
    used_mb = memory.used / mebibyte
    total_mb = memory.total / mebibyte

    nice_value = psutil.Process(os.getpid()).nice()

    lines = [
        f"CPU Usage: {cpu_usage:.2f}%",
        f"System Memory: {used_mb:.2f} MB used / {total_mb:.2f} MB total ({memory.percent}% used)",
        f"Process Priority (Nice value): {nice_value}",
    ]
    return "\n".join(lines)
1354
+
1355
+
1356
def get_combined_stats():
    """Join the GPU and CPU reports under two Markdown-style headings."""
    gpu_report = get_gpu_stats()
    cpu_report = get_cpu_stats()
    return "\n".join(["### GPU Stats", gpu_report, "", "### CPU Stats", cpu_report])
1361
+
1362
+
1363
def get_audio_select(file_sample):
    """Map a sample prefix to its reference and generated wav paths.

    Returns ``(prefix + "_ref.wav", prefix + "_gen.wav")``, or
    ``(None, None)`` when ``file_sample`` is ``None``.
    """
    if file_sample is None:
        return None, None
    return file_sample + "_ref.wav", file_sample + "_gen.wav"
1372
+
1373
+
1374
+ with gr.Blocks() as app:
1375
+ gr.Markdown(
1376
+ """
1377
+ # E2/F5 TTS Automatic Finetune
1378
+
1379
+ This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
1380
+
1381
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
1382
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
1383
+
1384
+ The checkpoints support English and Chinese.
1385
+
1386
+ For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
1387
+ """
1388
+ )
1389
+
1390
+ with gr.Row():
1391
+ projects, projects_selelect = get_list_projects()
1392
+ tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char", "custom"], value="pinyin")
1393
+ project_name = gr.Textbox(label="Project Name", value="my_speak")
1394
+ bt_create = gr.Button("Create a New Project")
1395
+
1396
+ with gr.Row():
1397
+ cm_project = gr.Dropdown(
1398
+ choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6
1399
+ )
1400
+ ch_refresh_project = gr.Button("Refresh", scale=1)
1401
+
1402
+ bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
1403
+
1404
+ with gr.Tabs():
1405
+ with gr.TabItem("Transcribe Data"):
1406
+ gr.Markdown("""```plaintext
1407
+ Skip this step if you have your dataset, metadata.csv, and a folder wavs with all the audio files.
1408
+ ```""")
1409
+
1410
+ ch_manual = gr.Checkbox(label="Audio from Path", value=False)
1411
+
1412
+ mark_info_transcribe = gr.Markdown(
1413
+ """```plaintext
1414
+ Place your 'wavs' folder and 'metadata.csv' file in the '{your_project_name}' directory.
1415
+
1416
+ my_speak/
1417
+
1418
+ └── dataset/
1419
+ ├── audio1.wav
1420
+ └── audio2.wav
1421
+ ...
1422
+ ```""",
1423
+ visible=False,
1424
+ )
1425
+
1426
+ audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
1427
+ txt_lang = gr.Text(label="Language", value="English")
1428
+ bt_transcribe = bt_create = gr.Button("Transcribe")
1429
+ txt_info_transcribe = gr.Text(label="Info", value="")
1430
+ bt_transcribe.click(
1431
+ fn=transcribe_all,
1432
+ inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
1433
+ outputs=[txt_info_transcribe],
1434
+ )
1435
+ ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
1436
+
1437
+ random_sample_transcribe = gr.Button("Random Sample")
1438
+
1439
+ with gr.Row():
1440
+ random_text_transcribe = gr.Text(label="Text")
1441
+ random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
1442
+
1443
+ random_sample_transcribe.click(
1444
+ fn=get_random_sample_transcribe,
1445
+ inputs=[cm_project],
1446
+ outputs=[random_text_transcribe, random_audio_transcribe],
1447
+ )
1448
+
1449
+ with gr.TabItem("Vocab Check"):
1450
+ gr.Markdown("""```plaintext
1451
+ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. For fine-tuning a new language.
1452
+ ```""")
1453
+
1454
+ check_button = gr.Button("Check Vocab")
1455
+ txt_info_check = gr.Text(label="Info", value="")
1456
+
1457
+ gr.Markdown("""```plaintext
1458
+ Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
1459
+ ```""")
1460
+
1461
+ exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
1462
+
1463
+ with gr.Row():
1464
+ txt_extend = gr.Textbox(
1465
+ label="Symbols",
1466
+ value="",
1467
+ placeholder="To add new symbols, make sure to use ',' for each symbol",
1468
+ scale=6,
1469
+ )
1470
+ txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)
1471
+
1472
+ extend_button = gr.Button("Extend")
1473
+ txt_info_extend = gr.Text(label="Info", value="")
1474
+
1475
+ txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
1476
+ check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
1477
+ extend_button.click(
1478
+ fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
1479
+ )
1480
+
1481
+ with gr.TabItem("Prepare Data"):
1482
+ gr.Markdown("""```plaintext
1483
+ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
1484
+ ```""")
1485
+
1486
+ gr.Markdown(
1487
+ """```plaintext
1488
+ Place all your "wavs" folder and your "metadata.csv" file in your project name directory.
1489
+
1490
+ Supported audio formats: "wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"
1491
+
1492
+ Example wav format:
1493
+ my_speak/
1494
+
1495
+ ├── wavs/
1496
+ │ ├── audio1.wav
1497
+ │ └── audio2.wav
1498
+ | ...
1499
+
1500
+ └── metadata.csv
1501
+
1502
+ File format metadata.csv:
1503
+
1504
+ audio1|text1 or audio1.wav|text1 or your_path/audio1.wav|text1
1505
+ audio2|text1 or audio2.wav|text1 or your_path/audio2.wav|text1
1506
+ ...
1507
+
1508
+ ```"""
1509
+ )
1510
+ ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
1511
+
1512
+ bt_prepare = bt_create = gr.Button("Prepare")
1513
+ txt_info_prepare = gr.Text(label="Info", value="")
1514
+ txt_vocab_prepare = gr.Text(label="Vocab", value="")
1515
+
1516
+ bt_prepare.click(
1517
+ fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
1518
+ )
1519
+
1520
+ random_sample_prepare = gr.Button("Random Sample")
1521
+
1522
+ with gr.Row():
1523
+ random_text_prepare = gr.Text(label="Tokenizer")
1524
+ random_audio_prepare = gr.Audio(label="Audio", type="filepath")
1525
+
1526
+ random_sample_prepare.click(
1527
+ fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1528
+ )
1529
+
1530
+ with gr.TabItem("Train Data"):
1531
+ gr.Markdown("""```plaintext
1532
+ The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per steps are set correctly, or change them manually as needed.
1533
+ If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1534
+ ```""")
1535
+ with gr.Row():
1536
+ bt_calculate = bt_create = gr.Button("Auto Settings")
1537
+ lb_samples = gr.Label(label="Samples")
1538
+ batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
1539
+
1540
+ with gr.Row():
1541
+ ch_finetune = bt_create = gr.Checkbox(label="Finetune", value=True)
1542
+ tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
1543
+ file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
1544
+
1545
+ with gr.Row():
1546
+ exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
1547
+ learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
1548
+
1549
+ with gr.Row():
1550
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
1551
+ max_samples = gr.Number(label="Max Samples", value=64)
1552
+
1553
+ with gr.Row():
1554
+ grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
1555
+ max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
1556
+
1557
+ with gr.Row():
1558
+ epochs = gr.Number(label="Epochs", value=10)
1559
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
1560
+
1561
+ with gr.Row():
1562
+ save_per_updates = gr.Number(label="Save per Updates", value=300)
1563
+ last_per_steps = gr.Number(label="Last per Steps", value=100)
1564
+
1565
+ with gr.Row():
1566
+ ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1567
+ mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
1568
+ cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1569
+ start_button = gr.Button("Start Training")
1570
+ stop_button = gr.Button("Stop Training", interactive=False)
1571
+
1572
+ if projects_selelect is not None:
1573
+ (
1574
+ exp_namev,
1575
+ learning_ratev,
1576
+ batch_size_per_gpuv,
1577
+ batch_size_typev,
1578
+ max_samplesv,
1579
+ grad_accumulation_stepsv,
1580
+ max_grad_normv,
1581
+ epochsv,
1582
+ num_warmupv_updatesv,
1583
+ save_per_updatesv,
1584
+ last_per_stepsv,
1585
+ finetunev,
1586
+ file_checkpoint_trainv,
1587
+ tokenizer_typev,
1588
+ tokenizer_filev,
1589
+ mixed_precisionv,
1590
+ cd_loggerv,
1591
+ ch_8bit_adamv,
1592
+ ) = load_settings(projects_selelect)
1593
+ exp_name.value = exp_namev
1594
+ learning_rate.value = learning_ratev
1595
+ batch_size_per_gpu.value = batch_size_per_gpuv
1596
+ batch_size_type.value = batch_size_typev
1597
+ max_samples.value = max_samplesv
1598
+ grad_accumulation_steps.value = grad_accumulation_stepsv
1599
+ max_grad_norm.value = max_grad_normv
1600
+ epochs.value = epochsv
1601
+ num_warmup_updates.value = num_warmupv_updatesv
1602
+ save_per_updates.value = save_per_updatesv
1603
+ last_per_steps.value = last_per_stepsv
1604
+ ch_finetune.value = finetunev
1605
+ file_checkpoint_train.value = file_checkpoint_trainv
1606
+ tokenizer_type.value = tokenizer_typev
1607
+ tokenizer_file.value = tokenizer_filev
1608
+ mixed_precision.value = mixed_precisionv
1609
+ cd_logger.value = cd_loggerv
1610
+ ch_8bit_adam.value = ch_8bit_adamv
1611
+
1612
+ ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1613
+ txt_info_train = gr.Text(label="Info", value="")
1614
+
1615
+ list_audios, select_audio = get_audio_project(projects_selelect, False)
1616
+
1617
+ select_audio_ref = select_audio
1618
+ select_audio_gen = select_audio
1619
+
1620
+ if select_audio is not None:
1621
+ select_audio_ref += "_ref.wav"
1622
+ select_audio_gen += "_gen.wav"
1623
+
1624
+ with gr.Row():
1625
+ ch_list_audio = gr.Dropdown(
1626
+ choices=list_audios,
1627
+ value=select_audio,
1628
+ label="Audios",
1629
+ allow_custom_value=True,
1630
+ scale=6,
1631
+ interactive=True,
1632
+ )
1633
+ bt_stream_audio = gr.Button("Refresh", scale=1)
1634
+ bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
1635
+ cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
1636
+
1637
+ with gr.Row():
1638
+ audio_ref_stream = gr.Audio(label="Original", type="filepath", value=select_audio_ref)
1639
+ audio_gen_stream = gr.Audio(label="Generate", type="filepath", value=select_audio_gen)
1640
+
1641
+ ch_list_audio.change(
1642
+ fn=get_audio_select,
1643
+ inputs=[ch_list_audio],
1644
+ outputs=[audio_ref_stream, audio_gen_stream],
1645
+ )
1646
+
1647
+ start_button.click(
1648
+ fn=start_training,
1649
+ inputs=[
1650
+ cm_project,
1651
+ exp_name,
1652
+ learning_rate,
1653
+ batch_size_per_gpu,
1654
+ batch_size_type,
1655
+ max_samples,
1656
+ grad_accumulation_steps,
1657
+ max_grad_norm,
1658
+ epochs,
1659
+ num_warmup_updates,
1660
+ save_per_updates,
1661
+ last_per_steps,
1662
+ ch_finetune,
1663
+ file_checkpoint_train,
1664
+ tokenizer_type,
1665
+ tokenizer_file,
1666
+ mixed_precision,
1667
+ ch_stream,
1668
+ cd_logger,
1669
+ ch_8bit_adam,
1670
+ ],
1671
+ outputs=[txt_info_train, start_button, stop_button],
1672
+ )
1673
+ stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
1674
+
1675
+ bt_calculate.click(
1676
+ fn=calculate_train,
1677
+ inputs=[
1678
+ cm_project,
1679
+ batch_size_type,
1680
+ max_samples,
1681
+ learning_rate,
1682
+ num_warmup_updates,
1683
+ save_per_updates,
1684
+ last_per_steps,
1685
+ ch_finetune,
1686
+ ],
1687
+ outputs=[
1688
+ batch_size_per_gpu,
1689
+ max_samples,
1690
+ num_warmup_updates,
1691
+ save_per_updates,
1692
+ last_per_steps,
1693
+ lb_samples,
1694
+ learning_rate,
1695
+ epochs,
1696
+ ],
1697
+ )
1698
+
1699
+ ch_finetune.change(
1700
+ check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
1701
+ )
1702
+
1703
+ def setup_load_settings():
1704
+ output_components = [
1705
+ exp_name,
1706
+ learning_rate,
1707
+ batch_size_per_gpu,
1708
+ batch_size_type,
1709
+ max_samples,
1710
+ grad_accumulation_steps,
1711
+ max_grad_norm,
1712
+ epochs,
1713
+ num_warmup_updates,
1714
+ save_per_updates,
1715
+ last_per_steps,
1716
+ ch_finetune,
1717
+ file_checkpoint_train,
1718
+ tokenizer_type,
1719
+ tokenizer_file,
1720
+ mixed_precision,
1721
+ cd_logger,
1722
+ ]
1723
+
1724
+ return output_components
1725
+
1726
+ outputs = setup_load_settings()
1727
+
1728
+ cm_project.change(
1729
+ fn=load_settings,
1730
+ inputs=[cm_project],
1731
+ outputs=outputs,
1732
+ )
1733
+
1734
+ ch_refresh_project.click(
1735
+ fn=load_settings,
1736
+ inputs=[cm_project],
1737
+ outputs=outputs,
1738
+ )
1739
+
1740
with gr.TabItem("Test Model"):
    # Usage hint rendered at the top of the tab.
    gr.Markdown("""```plaintext
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
```""")
    exp_name = gr.Radio(
        label="Model",
        choices=["F5-TTS", "E2-TTS"],
        value="F5-TTS",
    )

    # Called with False here so the result unpacks into the dropdown's initial
    # choices and selection; the event handlers further down call the same
    # function without that argument.
    initial_checkpoints = get_checkpoints_project(projects_selelect, False)
    list_checkpoints, checkpoint_select = initial_checkpoints

    # Generation parameters share a single row.
    with gr.Row():
        nfe_step = gr.Number(label="NFE Step", value=32)
        speed = gr.Slider(
            label="Speed",
            value=1.0,
            minimum=0.3,
            maximum=2.0,
            step=0.1,
        )
        seed = gr.Number(label="Seed", value=-1, minimum=-1)
        remove_silence = gr.Checkbox(label="Remove Silence")

    ch_use_ema = gr.Checkbox(label="Use EMA", value=True)

    # Checkpoint picker with its refresh button next to it.
    with gr.Row():
        cm_checkpoint = gr.Dropdown(
            choices=list_checkpoints,
            value=checkpoint_select,
            label="Checkpoints",
            allow_custom_value=True,
        )
        bt_checkpoint_refresh = gr.Button("Refresh")

    random_sample_infer = gr.Button("Random Sample")

    ref_text = gr.Textbox(label="Ref Text")
    ref_audio = gr.Audio(label="Audio Ref", type="filepath")
    gen_text = gr.Textbox(label="Gen Text")

    # "Random Sample" fills the reference text, generation text and reference
    # audio from the selected project.
    sample_targets = [ref_text, gen_text, ref_audio]
    random_sample_infer.click(
        fn=get_random_sample_infer,
        inputs=[cm_project],
        outputs=sample_targets,
    )

    with gr.Row():
        txt_info_gpu = gr.Textbox("", label="Device")
        seed_info = gr.Text(label="Seed :")
        check_button_infer = gr.Button("Infer")

    gen_audio = gr.Audio(label="Audio Gen", type="filepath")

    # Positional order matters: these are handed to infer() in this sequence.
    infer_inputs = [
        cm_project,
        cm_checkpoint,
        exp_name,
        ref_text,
        ref_audio,
        gen_text,
        nfe_step,
        ch_use_ema,
        speed,
        seed,
        remove_silence,
    ]
    infer_outputs = [gen_audio, txt_info_gpu, seed_info]
    check_button_infer.click(fn=infer, inputs=infer_inputs, outputs=infer_outputs)

    # Refresh the checkpoint dropdown on an explicit click and whenever the
    # project selection changes.
    for checkpoint_trigger in (bt_checkpoint_refresh.click, cm_project.change):
        checkpoint_trigger(
            fn=get_checkpoints_project,
            inputs=[cm_project],
            outputs=[cm_checkpoint],
        )
1797
+
1798
with gr.TabItem("Reduce Checkpoint"):
    # Explanation shown above the form.
    gr.Markdown("""```plaintext
Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training.
```""")
    txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
    txt_path_checkpoint_small = gr.Text(label="Path to Output:")
    # A checkbox holds a boolean; start it explicitly unchecked instead of
    # passing an empty string and relying on its falsiness.
    ch_safetensors = gr.Checkbox(label="Safetensors", value=False)
    txt_info_reduse = gr.Text(label="Info", value="")
    reduse_button = gr.Button("Reduce")
    # The handler receives (input path, output path, safetensors flag) and its
    # result is written to the info box.
    reduse_button.click(
        fn=extract_and_save_ema_model,
        inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
        outputs=[txt_info_reduse],
    )
1812
+
1813
with gr.TabItem("System Info"):
    output_box = gr.Textbox(label="GPU and CPU Information", lines=20)

    def update_stats():
        """Return the report produced by get_combined_stats()."""
        return get_combined_stats()

    update_button = gr.Button("Update Stats")
    update_button.click(fn=update_stats, outputs=output_box)

    def auto_update():
        """Return a component update carrying a freshly collected stats report."""
        return gr.update(value=update_stats())

    # gr.update() only builds an update payload; calling it with fn/inputs/outputs
    # registered nothing and its result was discarded, so the box stayed empty
    # until the button was pressed. Register a page-load event on the app so the
    # report is filled in as soon as the UI is opened.
    app.load(fn=auto_update, inputs=[], outputs=output_box)
1826
+
1827
+
1828
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
# An on/off pair is used because a lone flag whose default is already True can
# never be used to turn API access off (and on some click releases passing it
# actually yields False). "--api" / "-a" keep working; "--no-api" disables it.
@click.option(
    "--api/--no-api",
    "-a",
    default=True,
    help="Allow API access",
)
def main(port, host, share, api):
    """Start the Gradio app with the given server options."""
    print("Starting app...")
    # `api` controls both whether the queue's API is open and whether the API
    # page is shown.
    app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)


if __name__ == "__main__":
    main()