ashishkblink commited on
Commit
f916b29
·
verified ·
1 Parent(s): 9ed490b

Upload f5_tts/train/datasets/prepare_optimized.py with huggingface_hub

Browse files
f5_tts/train/datasets/prepare_optimized.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys

# Must run BEFORE the f5_tts import below: when executed as a script, the
# project root (cwd) is not on sys.path.  The original appended it only
# after importing f5_tts, which defeats the purpose.
sys.path.append(os.getcwd())

import argparse
import csv
import json
import subprocess
from multiprocessing import Pool
from pathlib import Path

from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm

from f5_tts.model.utils import convert_char_to_pinyin

# Transcript fields in metadata.csv can exceed the default csv limit.
csv.field_size_limit(sys.maxsize)
25
def get_audio_duration(audio_path):
    """Return the duration of *audio_path* in seconds via ffprobe.

    Falls back to ``0`` when ffprobe is missing, exits non-zero, or emits
    no parseable duration, so callers can filter the clip out instead of
    crashing mid-run.
    """
    try:
        # List-form argv (shell=False) avoids the quoting/injection issues
        # the earlier soxi/os.popen approach had with unusual file names.
        result = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-show_entries", "format=duration",
                "-of", "default=noprint_wrappers=1:nokey=1",
                str(audio_path),
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        out = result.stdout.strip()
        # Non-zero exit (e.g. unreadable file) or empty output -> unusable.
        if result.returncode != 0 or not out:
            return 0
        return float(out)
    except (OSError, ValueError) as e:
        # OSError: ffprobe binary missing; ValueError: non-numeric output.
        print(f"Error processing {audio_path}: {e}")
        return 0
39
+
40
+
41
+
42
def read_audio_text_pairs(csv_file_path):
    """Parse a pipe-delimited ``metadata.csv`` into (audio_path, text) pairs.

    The first row is treated as a header and skipped; rows with fewer than
    two fields are dropped.  Relative audio paths are resolved against the
    CSV file's directory.
    """
    parent = Path(csv_file_path).parent
    pairs = []
    # csv.reader replaces the previous awk/os.popen pipeline: no shell
    # dependency, no breakage on paths containing spaces or shell
    # metacharacters, and no loss of repeated whitespace in the text.
    # QUOTE_NONE mirrors awk's literal (non-quote-aware) field splitting.
    with open(csv_file_path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="|", quoting=csv.QUOTE_NONE)
        next(reader, None)  # skip header row
        for row in reader:
            if len(row) < 2:
                continue
            audio_rel, text = row[0], row[1]
            pairs.append((str(parent / audio_rel), text))
    return pairs
49
+
50
+
51
def process_audio(audio_path_text):
    """Validate one (audio_path, text) pair for inclusion in the dataset.

    Returns ``(sample_dict, duration)`` on success, or ``None`` when the
    file is missing or its duration falls outside the 0.1–30 s window.
    The text is converted to pinyin before being stored.
    """
    path, raw_text = audio_path_text
    if not Path(path).exists():
        return None

    dur = get_audio_duration(path)
    # Reject clips that are too short or too long for training.
    if not 0.1 <= dur <= 30:
        return None

    pinyin_text = convert_char_to_pinyin([raw_text], polyphone=True)[0]
    sample = {"audio_path": path, "text": pinyin_text, "duration": dur}
    return sample, dur
63
+
64
+
65
def prepare_csv_wavs_dir(input_dir, num_processes=32):
    """Process *input_dir*'s metadata.csv audio-text pairs in parallel.

    Returns ``(samples, durations, vocab)`` where *samples* is a list of
    dicts ready for Arrow serialization, *durations* the matching clip
    lengths, and *vocab* the set of text tokens seen across all samples.
    """
    metadata_path = Path(input_dir) / "metadata.csv"
    pairs = read_audio_text_pairs(metadata_path.as_posix())

    with Pool(num_processes) as pool:
        processed = list(
            tqdm(
                pool.imap(process_audio, pairs),
                total=len(pairs),
                desc="Processing audio files",
            )
        )

    samples, durations, vocab = [], [], set()
    # process_audio yields None for rejected clips; keep only valid ones.
    for item in processed:
        if not item:
            continue
        sample, duration = item
        samples.append(sample)
        durations.append(duration)
        vocab.update(sample["text"])

    return samples, durations, vocab
82
+
83
+
84
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set):
    """Persist the prepared dataset: raw.arrow, duration.json, new_vocab.txt."""
    out_path = Path(out_dir)
    out_path.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_path} ...")

    # Stream samples one record at a time straight into the Arrow file.
    arrow_file = out_path / "raw.arrow"
    with ArrowWriter(path=arrow_file.as_posix(), writer_batch_size=1) as writer:
        for sample in tqdm(result, desc="Writing to raw.arrow"):
            writer.write(sample)

    duration_file = out_path / "duration.json"
    with open(duration_file.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    vocab_file = out_path / "new_vocab.txt"
    with open(vocab_file.as_posix(), "w") as f:
        f.writelines(f"{token}\n" for token in sorted(text_vocab_set))

    dataset_name = out_path.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
106
+
107
+
108
def prepare_and_save_set(inp_dir, out_dir):
    """End-to-end pipeline: prepare the dataset in *inp_dir*, save to *out_dir*."""
    samples, durations, vocab = prepare_csv_wavs_dir(inp_dir)
    save_prepped_dataset(out_dir, samples, durations, vocab)
112
+
113
+
114
def cli():
    """Parse command-line arguments and kick off dataset preparation."""
    arg_parser = argparse.ArgumentParser(description="Prepare and save dataset.")
    for arg_name, help_text in (
        ("inp_dir", "Input directory containing the data."),
        ("out_dir", "Output directory to save the prepared data."),
    ):
        arg_parser.add_argument(arg_name, type=str, help=help_text)

    parsed = arg_parser.parse_args()
    prepare_and_save_set(parsed.inp_dir, parsed.out_dir)


if __name__ == "__main__":
    cli()