File size: 6,180 Bytes
76d096a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | #!/usr/bin/env python3
"""
Prepare Sinhala dataset for F5-TTS fine-tuning.
This bypasses F5-TTS's built-in pinyin-based prep and builds
the Arrow dataset directly for Sinhala character-level tokenization.
Usage:
python scripts/prepare_f5_data.py \
--dataset_dir /path/to/cc_v1_tenvideo_baseline \
--output_dir data/sinhala_tts_custom \
--vocab_path data/sinhala_vocab/vocab.txt
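The output directory must be data/<dataset_name>_custom/. The finetune CLI
loads data/<dataset_name>_<tokenizer>/, so the default output
data/sinhala_tts_custom pairs with:
--dataset_name sinhala_tts --tokenizer custom --tokenizer_path <vocab.txt>
in the finetune CLI (verify this naming against your F5-TTS version).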
"""
import argparse
import json
import os
import sys
from pathlib import Path
import soundfile as sf
from tqdm import tqdm
def _resolve_wav_path(wavs_dir: Path, uttr_id: str):
    """Return the WAV file for ``uttr_id``, or None if no such file exists.

    Two layouts are checked in order:
      * flat:    wavs/<uttr_id>.wav
      * sharded: wavs/<NN>/<uttr_id>.wav, where NN is the integer after the
        first "_" in the ID divided by 5000, zero-padded to two digits
        (e.g. si_000042 -> wavs/00/si_000042.wav).
    """
    flat_path = wavs_dir / f"{uttr_id}.wav"
    if flat_path.exists():
        return flat_path
    try:
        num = int(uttr_id.split("_")[1])
    except (IndexError, ValueError):
        # No numeric shard component in the ID; only the flat layout applies.
        return None
    sharded_path = wavs_dir / f"{num // 5000:02d}" / f"{uttr_id}.wav"
    return sharded_path if sharded_path.exists() else None


def _parse_metadata(meta_csv: Path, wavs_dir: Path, min_dur: float, max_dur: float):
    """Parse an LJSpeech-style metadata file into dataset records.

    Each non-blank line is ``id|text[|normalized_text]``. The normalized text
    is used when present and non-empty; otherwise the raw text is used.

    Returns:
        ``(records, skipped)``. ``records`` is a list of dicts with keys
        ``audio_path`` (absolute path string), ``text`` and ``duration``
        (seconds, as reported by soundfile). ``skipped`` maps a skip reason
        to the number of lines dropped for it.
    """
    records = []
    skipped = {"malformed": 0, "empty_text": 0, "missing": 0, "duration": 0}
    # utf-8-sig drops a leading BOM, which would otherwise become part of the
    # first utterance ID and make its WAV lookup fail.
    with open(meta_csv, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) < 2:
                skipped["malformed"] += 1
                continue
            uttr_id = parts[0].strip()
            text = parts[1].strip()
            # An empty normalized column ("id|text|") falls back to the raw
            # text instead of producing an empty transcript.
            norm_text = (parts[2].strip() if len(parts) > 2 else "") or text
            if not norm_text:
                skipped["empty_text"] += 1
                continue

            wav_path = _resolve_wav_path(wavs_dir, uttr_id)
            if wav_path is None:
                skipped["missing"] += 1
                continue
            try:
                dur = sf.info(str(wav_path)).duration
            except Exception as e:  # best-effort: one bad file must not abort the run
                print(f" WARNING: Can't read {wav_path}: {e}")
                skipped["missing"] += 1
                continue
            if not (min_dur <= dur <= max_dur):
                skipped["duration"] += 1
                continue

            records.append({
                "audio_path": str(wav_path.resolve()),
                "text": norm_text,
                "duration": dur,
            })
    return records, skipped


def _write_vocab(records, vocab_path, dest: Path) -> None:
    """Write ``dest`` (vocab.txt): copy ``vocab_path`` if it is a file, else build one.

    When copying, warns about transcript characters the provided vocab lacks.
    When building, writes one character per line: a space first, then every
    other transcript character in sorted order.
    """
    import shutil

    data_chars = set()
    for r in records:
        data_chars.update(r["text"])

    if vocab_path and Path(vocab_path).is_file():
        src = Path(vocab_path)
        print(f"\nUsing provided vocab: {vocab_path}")
        # Copying a file onto itself raises shutil.SameFileError.
        if src.resolve() != dest.resolve():
            shutil.copy2(src, dest)
        # Assumes the one-character-per-line format produced by the build
        # branch below; the empty entry left by the final newline is harmless.
        known = set(src.read_text(encoding="utf-8").split("\n"))
        missing = sorted(data_chars - known)
        if missing:
            print(f" WARNING: {len(missing)} transcript character(s) not in the "
                  f"provided vocab: {''.join(missing)!r}")
        return

    if vocab_path:
        print(f"\nWARNING: vocab file {vocab_path} not found; building from data")
    print("\nBuilding vocab from data...")
    # Space must be index 0
    vocab_sorted = [" "] + sorted(data_chars - {" "})
    with open(dest, "w", encoding="utf-8") as f:
        f.writelines(f"{ch}\n" for ch in vocab_sorted)
    print(f" Vocab size: {len(vocab_sorted)}")


def build_arrow_dataset(dataset_dir: str, output_dir: str, vocab_path: str,
                        min_dur: float = 0.3, max_dur: float = 30.0):
    """Build raw.arrow + duration.json + vocab.txt from an LJSpeech-style dataset.

    Reads ``<dataset_dir>/metadata.csv`` and ``<dataset_dir>/wavs/``. It keeps
    utterances that have a non-empty transcript, a readable WAV and a duration
    within ``[min_dur, max_dur]`` seconds. It then writes to ``output_dir``:

      * ``raw.arrow``: one row per utterance (audio_path, text, duration).
      * ``duration.json``: ``{"duration": [...]}`` in the same row order.
      * ``vocab.txt``: copied from ``vocab_path`` if that file exists,
        otherwise built from the transcripts.

    Exits the process with status 1 (message on stderr) if metadata.csv or
    wavs/ is missing, or if no utterance survives filtering.
    """
    from datasets.arrow_writer import ArrowWriter

    dataset_path = Path(dataset_dir)
    output_path = Path(output_dir)

    # --- Validate inputs ---
    meta_csv = dataset_path / "metadata.csv"
    if not meta_csv.exists():
        sys.exit(f"ERROR: {meta_csv} not found")
    wavs_dir = dataset_path / "wavs"
    if not wavs_dir.exists():
        sys.exit(f"ERROR: {wavs_dir} not found")

    # --- Parse metadata (LJSpeech format: filename|text|normalized_text) ---
    print(f"Reading metadata from {meta_csv}...")
    records, skipped = _parse_metadata(meta_csv, wavs_dir, min_dur, max_dur)

    print(f"\nParsed {len(records)} valid utterances")
    print(f" Skipped (malformed line): {skipped['malformed']}")
    print(f" Skipped (empty text): {skipped['empty_text']}")
    print(f" Skipped (missing wav): {skipped['missing']}")
    print(f" Skipped (duration): {skipped['duration']}")
    total_hours = sum(r["duration"] for r in records) / 3600
    print(f" Total audio: {total_hours:.2f} hours")
    if not records:
        sys.exit("ERROR: No valid records found!")

    # Only create the output directory once there is something to write.
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Vocab (provided or built from data) ---
    vocab_file = output_path / "vocab.txt"
    _write_vocab(records, vocab_path, vocab_file)

    # --- duration.json ---
    duration_file = output_path / "duration.json"
    durations = [r["duration"] for r in records]
    with open(duration_file, "w", encoding="utf-8") as f:
        json.dump({"duration": durations}, f)
    print(f"Wrote duration.json ({len(durations)} entries)")

    # --- raw.arrow ---
    print("Building Arrow dataset...")
    arrow_path = output_path / "raw.arrow"
    with ArrowWriter(path=str(arrow_path)) as writer:
        for record in tqdm(records, desc="Writing Arrow"):
            writer.write(record)
        writer.finalize()

    print(f"\nDone! Output directory: {output_path}")
    print(f" raw.arrow : {arrow_path.stat().st_size / 1024:.1f} KB")
    print(f" duration.json : {duration_file.stat().st_size / 1024:.1f} KB")
    print(f" vocab.txt : {vocab_file.stat().st_size / 1024:.1f} KB")
    print(f"\nTotal: {len(records)} utterances, {total_hours:.2f} hours")
def main():
    """Parse command-line arguments and build the F5-TTS dataset directory."""
    parser = argparse.ArgumentParser(description="Prepare Sinhala dataset for F5-TTS")
    parser.add_argument("--dataset_dir", required=True,
                        help="Path to dataset directory (contains metadata.csv + wavs/)")
    parser.add_argument("--output_dir", default="data/sinhala_tts_custom",
                        help="Output directory for Arrow dataset")
    parser.add_argument("--vocab_path", default="data/sinhala_vocab/vocab.txt",
                        help="Path to vocab.txt (Sinhala character set)")
    parser.add_argument("--min_dur", type=float, default=0.3,
                        help="Min utterance duration in seconds")
    parser.add_argument("--max_dur", type=float, default=30.0,
                        help="Max utterance duration in seconds")
    args = parser.parse_args()

    # An inverted range filters out every utterance and ends in a misleading
    # "No valid records found!" error, so reject it up front (exit code 2).
    if args.min_dur > args.max_dur:
        parser.error(f"--min_dur ({args.min_dur}) must not exceed "
                     f"--max_dur ({args.max_dur})")

    build_arrow_dataset(args.dataset_dir, args.output_dir, args.vocab_path,
                        args.min_dur, args.max_dur)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|