ashishkblink commited on
Commit
441f260
·
verified ·
1 Parent(s): c47c9cd

Upload f5_tts/model/dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/model/dataset.py +331 -0
f5_tts/model/dataset.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from importlib.resources import files
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ from datasets import Dataset as Dataset_
9
+ from datasets import load_from_disk
10
+ from torch import nn
11
+ from torch.utils.data import Dataset, Sampler
12
+ from tqdm import tqdm
13
+
14
+ from f5_tts.model.modules import MelSpec
15
+ from f5_tts.model.utils import default
16
+
17
+
18
+ class HFDataset(Dataset):
19
+ def __init__(
20
+ self,
21
+ hf_dataset: Dataset,
22
+ target_sample_rate=24_000,
23
+ n_mel_channels=100,
24
+ hop_length=256,
25
+ n_fft=1024,
26
+ win_length=1024,
27
+ mel_spec_type="vocos",
28
+ ):
29
+ self.data = hf_dataset
30
+ self.target_sample_rate = target_sample_rate
31
+ self.hop_length = hop_length
32
+
33
+ self.mel_spectrogram = MelSpec(
34
+ n_fft=n_fft,
35
+ hop_length=hop_length,
36
+ win_length=win_length,
37
+ n_mel_channels=n_mel_channels,
38
+ target_sample_rate=target_sample_rate,
39
+ mel_spec_type=mel_spec_type,
40
+ )
41
+
42
+ def get_frame_len(self, index):
43
+ row = self.data[index]
44
+ audio = row["audio"]["array"]
45
+ sample_rate = row["audio"]["sampling_rate"]
46
+ return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
47
+
48
+ def __len__(self):
49
+ return len(self.data)
50
+
51
+ def __getitem__(self, index):
52
+ row = self.data[index]
53
+ audio = row["audio"]["array"]
54
+
55
+ # logger.info(f"Audio shape: {audio.shape}")
56
+
57
+ sample_rate = row["audio"]["sampling_rate"]
58
+ duration = audio.shape[-1] / sample_rate
59
+
60
+ if duration > 30 or duration < 0.3:
61
+ return self.__getitem__((index + 1) % len(self.data))
62
+
63
+ audio_tensor = torch.from_numpy(audio).float()
64
+
65
+ if sample_rate != self.target_sample_rate:
66
+ resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
67
+ audio_tensor = resampler(audio_tensor)
68
+
69
+ audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
70
+
71
+ mel_spec = self.mel_spectrogram(audio_tensor)
72
+
73
+ mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
74
+
75
+ text = row["text"]
76
+
77
+ return dict(
78
+ mel_spec=mel_spec,
79
+ text=text,
80
+ )
81
+
82
+
83
+ class CustomDataset(Dataset):
84
+ def __init__(
85
+ self,
86
+ custom_dataset: Dataset,
87
+ durations=None,
88
+ target_sample_rate=24_000,
89
+ hop_length=256,
90
+ n_mel_channels=100,
91
+ n_fft=1024,
92
+ win_length=1024,
93
+ mel_spec_type="vocos",
94
+ preprocessed_mel=False,
95
+ mel_spec_module: nn.Module | None = None,
96
+ ):
97
+ self.data = custom_dataset
98
+ self.durations = durations
99
+ self.target_sample_rate = target_sample_rate
100
+ self.hop_length = hop_length
101
+ self.n_fft = n_fft
102
+ self.win_length = win_length
103
+ self.mel_spec_type = mel_spec_type
104
+ self.preprocessed_mel = preprocessed_mel
105
+
106
+ if not preprocessed_mel:
107
+ self.mel_spectrogram = default(
108
+ mel_spec_module,
109
+ MelSpec(
110
+ n_fft=n_fft,
111
+ hop_length=hop_length,
112
+ win_length=win_length,
113
+ n_mel_channels=n_mel_channels,
114
+ target_sample_rate=target_sample_rate,
115
+ mel_spec_type=mel_spec_type,
116
+ ),
117
+ )
118
+
119
+ def get_frame_len(self, index):
120
+ if (
121
+ self.durations is not None
122
+ ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
123
+ return self.durations[index] * self.target_sample_rate / self.hop_length
124
+ return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
125
+
126
+ def __len__(self):
127
+ return len(self.data)
128
+
129
+ def __getitem__(self, index):
130
+ while True:
131
+ row = self.data[index]
132
+ audio_path = row["audio_path"]
133
+ # YOTTA Specific path fixes. Please don't ever do this, and fix the dataset arrow instead!
134
+ audio_path = audio_path.replace('/home/tts/ttsteam/datasets', '/projects/data/ttsteam/datasets/')
135
+
136
+ if 'limmits' in audio_path:
137
+ lang_spk = audio_path.split('limmits/')[1].split('/')[0]
138
+ lang, spk = lang_spk.split('_')
139
+ audio_path = audio_path.replace(f'limmits/{lang_spk}', f'limmits/processed_datasets/{lang}/{spk}')
140
+ audio_path = audio_path.replace('processed/datasets', '')
141
+ if 'indictts' in audio_path:
142
+ audio_path = audio_path.replace('/wavs-24k/', '/wavs-22k/')
143
+
144
+ text = row["text"]
145
+ duration = row["duration"]
146
+
147
+ # filter by given length
148
+ if 0.3 <= duration <= 30:
149
+ break # valid
150
+
151
+ index = (index + 1) % len(self.data)
152
+
153
+ if self.preprocessed_mel:
154
+ mel_spec = torch.tensor(row["mel_spec"])
155
+ else:
156
+ audio, source_sample_rate = torchaudio.load(audio_path)
157
+
158
+ # make sure mono input
159
+ if audio.shape[0] > 1:
160
+ audio = torch.mean(audio, dim=0, keepdim=True)
161
+
162
+ # resample if necessary
163
+ if source_sample_rate != self.target_sample_rate:
164
+ resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
165
+ audio = resampler(audio)
166
+
167
+ # to mel spectrogram
168
+ mel_spec = self.mel_spectrogram(audio)
169
+ mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
170
+
171
+ return {
172
+ "mel_spec": mel_spec,
173
+ "text": text,
174
+ }
175
+
176
+
177
+ # Dynamic Batch Sampler
178
+ class DynamicBatchSampler(Sampler[list[int]]):
179
+ """Extension of Sampler that will do the following:
180
+ 1. Change the batch size (essentially number of sequences)
181
+ in a batch to ensure that the total number of frames are less
182
+ than a certain threshold.
183
+ 2. Make sure the padding efficiency in the batch is high.
184
+ """
185
+
186
+ def __init__(
187
+ self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
188
+ ):
189
+ self.sampler = sampler
190
+ self.frames_threshold = frames_threshold
191
+ self.max_samples = max_samples
192
+
193
+ indices, batches = [], []
194
+ data_source = self.sampler.data_source
195
+
196
+ for idx in tqdm(
197
+ self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
198
+ ):
199
+ indices.append((idx, data_source.get_frame_len(idx)))
200
+ indices.sort(key=lambda elem: elem[1])
201
+
202
+ batch = []
203
+ batch_frames = 0
204
+ for idx, frame_len in tqdm(
205
+ indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
206
+ ):
207
+ if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
208
+ batch.append(idx)
209
+ batch_frames += frame_len
210
+ else:
211
+ if len(batch) > 0:
212
+ batches.append(batch)
213
+ if frame_len <= self.frames_threshold:
214
+ batch = [idx]
215
+ batch_frames = frame_len
216
+ else:
217
+ batch = []
218
+ batch_frames = 0
219
+
220
+ if not drop_last and len(batch) > 0:
221
+ batches.append(batch)
222
+
223
+ del indices
224
+
225
+ # if want to have different batches between epochs, may just set a seed and log it in ckpt
226
+ # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
227
+ # e.g. for epoch n, use (random_seed + n)
228
+ random.seed(random_seed)
229
+ random.shuffle(batches)
230
+
231
+ self.batches = batches
232
+
233
+ def __iter__(self):
234
+ return iter(self.batches)
235
+
236
+ def __len__(self):
237
+ return len(self.batches)
238
+
239
+
240
+ # Load dataset
241
+
242
+
243
+ def load_dataset(
244
+ dataset_name: str,
245
+ tokenizer: str = "pinyin",
246
+ dataset_type: str = "CustomDatasetPath",
247
+ audio_type: str = "raw",
248
+ mel_spec_module: nn.Module | None = None,
249
+ mel_spec_kwargs: dict = dict(),
250
+ data_dir: str = None,
251
+ ) -> CustomDataset | HFDataset:
252
+ """
253
+ dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
254
+ - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
255
+ """
256
+
257
+ print("Loading dataset ...")
258
+
259
+ if dataset_type == "CustomDataset":
260
+ rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
261
+ if audio_type == "raw":
262
+ try:
263
+ train_dataset = load_from_disk(f"{rel_data_path}/raw")
264
+ except: # noqa: E722
265
+ train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
266
+ preprocessed_mel = False
267
+ elif audio_type == "mel":
268
+ train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
269
+ preprocessed_mel = True
270
+ with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
271
+ data_dict = json.load(f)
272
+ durations = data_dict["duration"]
273
+ train_dataset = CustomDataset(
274
+ train_dataset,
275
+ durations=durations,
276
+ preprocessed_mel=preprocessed_mel,
277
+ mel_spec_module=mel_spec_module,
278
+ **mel_spec_kwargs,
279
+ )
280
+
281
+ elif dataset_type == "CustomDatasetPath":
282
+ try:
283
+ train_dataset = load_from_disk(f"{data_dir}/raw")
284
+ except: # noqa: E722
285
+ train_dataset = Dataset_.from_file(f"{data_dir}/raw.arrow")
286
+ preprocessed_mel = False
287
+ with open(f"{data_dir}/duration.json", "r", encoding="utf-8") as f:
288
+ data_dict = json.load(f)
289
+ durations = data_dict["duration"]
290
+ train_dataset = CustomDataset(
291
+ train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
292
+ )
293
+
294
+ elif dataset_type == "HFDataset":
295
+ print(
296
+ "Should manually modify the path of huggingface dataset to your need.\n"
297
+ + "May also the corresponding script cuz different dataset may have different format."
298
+ )
299
+ pre, post = dataset_name.split("_")
300
+ train_dataset = HFDataset(
301
+ load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
302
+ )
303
+
304
+ return train_dataset
305
+
306
+
307
+ # collation
308
+
309
+
310
+ def collate_fn(batch):
311
+ mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
312
+ mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
313
+ max_mel_length = mel_lengths.amax()
314
+
315
+ padded_mel_specs = []
316
+ for spec in mel_specs: # TODO. maybe records mask for attention here
317
+ padding = (0, max_mel_length - spec.size(-1))
318
+ padded_spec = F.pad(spec, padding, value=0)
319
+ padded_mel_specs.append(padded_spec)
320
+
321
+ mel_specs = torch.stack(padded_mel_specs)
322
+
323
+ text = [item["text"] for item in batch]
324
+ text_lengths = torch.LongTensor([len(item) for item in text])
325
+
326
+ return dict(
327
+ mel=mel_specs,
328
+ mel_lengths=mel_lengths,
329
+ text=text,
330
+ text_lengths=text_lengths,
331
+ )