ashishkblink committed on
Commit
f36d2e1
·
verified ·
1 Parent(s): f916b29

Upload f5_tts/train/datasets/prepare_in22_en_10k.py with huggingface_hub

Browse files
f5_tts/train/datasets/prepare_in22_en_10k.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ import argparse
7
+ import csv
8
+ import json
9
+ import shutil
10
+ from importlib.resources import files
11
+ from pathlib import Path
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+
14
+ import torchaudio
15
+ from tqdm import tqdm
16
+ from datasets.arrow_writer import ArrowWriter
17
+
18
+ from f5_tts.model.utils import (
19
+ convert_char_to_pinyin,
20
+ )
21
+
22
+
23
+ # Increase the field size limit
24
+ csv.field_size_limit(sys.maxsize)
25
+
26
+ # PRETRAINED_VOCAB_PATH = Path("/projects/data/ttsteam/repos/f5/data/in22_5k/vocab.txt")
27
+
28
+
29
def is_csv_wavs_format(input_dataset_dir):
    """Return True if *input_dataset_dir* has the metadata.csv + wavs/ layout."""
    root = Path(input_dataset_dir)
    has_metadata = (root / "metadata.csv").is_file()  # is_file() implies exists()
    has_wavs_dir = (root / "wavs").is_dir()  # is_dir() implies exists()
    return has_metadata and has_wavs_dir
34
+
35
+
36
def prepare_csv_wavs_dir(input_dir, num_threads=32):
    """Prepare one csv+wavs dataset directory in parallel.

    Reads <input_dir>/metadata.csv, computes each clip's duration and converts
    its text with convert_char_to_pinyin, using a thread pool of *num_threads*.
    Clips shorter than 0.1 s or longer than 30 s are dropped.

    Returns:
        (sub_result, durations, vocab_set) where sub_result is a list of
        {"audio_path", "text", "duration"} dicts, durations is the matching
        list of seconds, and vocab_set is the set of text tokens seen.
    """
    print("Inside prepare csv wavs dir!")
    input_dir = Path(input_dir)
    metadata_path = input_dir / "metadata.csv"
    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())

    sub_result, durations = [], []
    vocab_set = set()
    polyphone = True

    def process_audio(audio_path_text):
        # Returns ({"audio_path", "text", "duration"}, duration) or None if missing.
        audio_path, text = audio_path_text
        if not Path(audio_path).exists():
            print(f"audio {audio_path} not found, skipping")
            return None
        audio_duration = get_audio_duration(audio_path)
        text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
        return {"audio_path": audio_path, "text": text, "duration": audio_duration}, audio_duration

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {executor.submit(process_audio, pair): pair for pair in tqdm(audio_path_text_pairs, desc='submit')}

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing audio files"):
            try:
                result = future.result()
            except Exception as exc:
                # A corrupt/unreadable audio file must not abort the whole run.
                print(f"failed to process {futures[future]}: {exc}")
                continue
            if result is None:
                print("Result not found: ", futures[future])
                continue
            entry, aud_dur = result
            # Drop clips outside the usable [0.1 s, 30 s] window.
            if aud_dur < 0.1 or aud_dur > 30:
                continue
            sub_result.append(entry)
            durations.append(aud_dur)
            vocab_set.update(list(entry["text"]))

    return sub_result, durations, vocab_set
75
+
76
+
77
def get_audio_duration(audio_path):
    """Return the duration of *audio_path* in seconds (frames / sample rate)."""
    waveform, sample_rate = torchaudio.load(audio_path)
    num_frames = waveform.size(1)
    return num_frames / sample_rate
80
+
81
+
82
def read_audio_text_pairs(csv_file_path):
    """Parse a pipe-delimited metadata CSV into (audio_path, text) pairs.

    Only rows with exactly 2 fields are kept: IN22 transcripts can themselves
    contain '|', producing >2 fields, and such rows are treated as noisy and
    skipped. The first (header) row is always discarded. Audio paths are used
    as given in the CSV (not joined to the CSV's parent directory).

    Returns:
        list of (posix_audio_path, text) tuples.
    """
    audio_text_pairs = []
    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        next(reader, None)  # Skip the header row; tolerate an empty file.
        for row in tqdm(reader):
            if len(row) == 2:
                audio_file = row[0].strip()  # First column: audio file path
                text = row[1].strip()  # Second column: transcript text
                audio_text_pairs.append((Path(audio_file).as_posix(), text))
            else:
                print("skipped", row)
    return audio_text_pairs
99
+
100
+
101
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Persist the prepared dataset to *out_dir*: raw.arrow, duration.json, new_vocab.txt.

    Args:
        out_dir: destination directory (created if missing).
        result: list of {"audio_path", "text", "duration"} entries.
        duration_list: per-sample durations in seconds, duplicated into a JSON
            for DynamicBatchSampler convenience.
        text_vocab_set: set of text tokens; written sorted, one per line.
        is_finetune: kept for interface compatibility; not used by this
            function at present (pretrained-vocab copying is disabled).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    # Stream rows one at a time (writer_batch_size=1) to keep memory flat;
    # building a whole Dataset.from_dict in memory OOMs on large sets.
    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # Vocab map, i.e. tokenizer. Explicit UTF-8: vocab tokens may be non-ASCII
    # (pinyin / Indic scripts), which would fail under a non-UTF-8 locale default.
    voca_out_path = out_dir / "new_vocab.txt"
    with open(voca_out_path.as_posix(), "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
146
+
147
+
148
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
    """Prepare the csv+wavs dataset in *inp_dir* and write it under *out_dir*."""
    if is_finetune:
        print("Inside finetuning ...")
    prepared, dur_list, vocab = prepare_csv_wavs_dir(inp_dir)
    save_prepped_dataset(out_dir, prepared, dur_list, vocab, is_finetune)
154
+
155
+
156
def cli():
    """Command-line entry point.

    finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
    pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
    """
    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
    ns = parser.parse_args()
    # --pretrain inverts the default fine-tune mode.
    prepare_and_save_set(ns.inp_dir, ns.out_dir, is_finetune=not ns.pretrain)
167
+
168
+
169
# Script entry point: parse CLI args and run the dataset preparation.
if __name__ == "__main__":
    cli()