sinhala-tts / scripts /prepare_f5_data.py
outlawmold's picture
Add F5-TTS data prep script (Sinhala Arrow builder, bypasses pinyin)
76d096a verified
#!/usr/bin/env python3
"""
Prepare Sinhala dataset for F5-TTS fine-tuning.
This bypasses F5-TTS's built-in pinyin-based prep and builds
the Arrow dataset directly for Sinhala character-level tokenization.
Usage:
python scripts/prepare_f5_data.py \
--dataset_dir /path/to/cc_v1_tenvideo_baseline \
--output_dir data/sinhala_tts_custom \
--vocab_path data/sinhala_vocab/vocab.txt
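F5-TTS's dataset loader looks for data/<dataset_name>_<tokenizer>/, so with
--output_dir data/sinhala_tts_custom the finetune CLI needs
--dataset_name sinhala_tts --tokenizer custom --tokenizer_path data/sinhala_tts_custom/vocab.txt
(verify these flag names against your F5-TTS version).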
"""
import argparse
import json
import os
import sys
from pathlib import Path
import soundfile as sf
from tqdm import tqdm
def _resolve_wav_path(wavs_dir: Path, uttr_id: str, bucket_size: int = 5000) -> "Path | None":
    """Locate the wav file for one utterance ID.

    Tries the flat layout ``wavs/<id>.wav`` first. It then tries a bucketed
    layout ``wavs/<NN>/<id>.wav``, where NN is the integer after the first
    "_" divided by ``bucket_size`` and printed as two digits
    (e.g. ``wavs/00/si_000042.wav``).

    Returns:
        The existing wav path, or None when neither layout has the file.
    """
    flat_path = wavs_dir / f"{uttr_id}.wav"
    if flat_path.exists():
        return flat_path
    try:
        num = int(uttr_id.split("_")[1])
    except (IndexError, ValueError):
        # No numeric "_"-separated field, so the bucketed layout cannot apply.
        return None
    bucketed_path = wavs_dir / f"{num // bucket_size:02d}" / f"{uttr_id}.wav"
    return bucketed_path if bucketed_path.exists() else None


def _load_records(meta_csv: Path, wavs_dir: Path, min_dur: float, max_dur: float):
    """Parse LJSpeech-style metadata (``id|text[|normalized_text]``) into records.

    Returns:
        ``(records, skipped)``.
        ``records`` is a list of dicts with keys ``audio_path`` (absolute
        path as str), ``text`` and ``duration`` (seconds).
        ``skipped`` maps each skip reason to the number of lines dropped
        for it.
    """
    records = []
    skipped = {"malformed": 0, "empty_text": 0, "missing": 0, "unreadable": 0, "duration": 0}
    # "utf-8-sig" drops a leading BOM, which spreadsheet tools often add to
    # CSV exports. With plain "utf-8" the first ID would start with "\ufeff"
    # and its wav would never be found.
    with open(meta_csv, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) < 2:
                skipped["malformed"] += 1
                continue
            uttr_id = parts[0].strip()
            text = parts[1].strip()
            # Prefer the normalized column, but fall back to the raw text when
            # it is absent or left empty (e.g. a trailing "id|text|").
            norm_text = (parts[2].strip() if len(parts) > 2 else "") or text
            if not norm_text:
                skipped["empty_text"] += 1
                continue
            wav_path = _resolve_wav_path(wavs_dir, uttr_id)
            if wav_path is None:
                skipped["missing"] += 1
                continue
            try:
                dur = sf.info(str(wav_path)).duration
            except Exception as e:  # best-effort: log and skip unreadable audio
                print(f" WARNING: Can't read {wav_path}: {e}")
                skipped["unreadable"] += 1
                continue
            if dur < min_dur or dur > max_dur:
                skipped["duration"] += 1
                continue
            records.append({
                "audio_path": str(wav_path.resolve()),
                "text": norm_text,
                "duration": dur,
            })
    return records, skipped


def _write_vocab(records, output_path: Path, vocab_path: str) -> None:
    """Write ``output_path/vocab.txt``, copying ``vocab_path`` if it exists.

    Otherwise a vocab is built from the transcripts: one character per line,
    space first, then the remaining characters sorted. When a provided vocab
    is used, transcript characters missing from it are reported.
    """
    import shutil

    vocab_out = output_path / "vocab.txt"
    corpus_chars = {ch for r in records for ch in r["text"]}

    if vocab_path and Path(vocab_path).exists():
        print(f"\nUsing provided vocab: {vocab_path}")
        try:
            shutil.copy2(vocab_path, vocab_out)
        except shutil.SameFileError:
            pass  # the vocab already lives in output_dir; nothing to copy
        with open(vocab_path, encoding="utf-8") as f:
            known = {entry.rstrip("\n") for entry in f}
        unknown = sorted(corpus_chars - known)
        if unknown:
            # NOTE(review): how F5-TTS's "custom" tokenizer treats characters
            # missing from the vocab is not visible here. Confirm before
            # ignoring this warning.
            print(f" WARNING: {len(unknown)} transcript character(s) not in vocab: "
                  f"{''.join(unknown)!r}", file=sys.stderr)
        return

    if vocab_path:
        print(f"\nWARNING: vocab {vocab_path} not found; building vocab from data instead",
              file=sys.stderr)
    print("\nBuilding vocab from data...")
    # Space must be index 0
    vocab_sorted = [" "] + sorted(corpus_chars - {" "})
    with open(vocab_out, "w", encoding="utf-8") as f:
        f.writelines(ch + "\n" for ch in vocab_sorted)
    print(f" Vocab size: {len(vocab_sorted)}")


def build_arrow_dataset(dataset_dir: str, output_dir: str, vocab_path: str,
                        min_dur: float = 0.3, max_dur: float = 30.0):
    """Build Arrow dataset + duration.json + copy vocab.txt.

    Args:
        dataset_dir: Directory containing ``metadata.csv`` (``id|text[|normalized]``)
            and a ``wavs/`` folder. The folder may be flat or bucketed (see
            ``_resolve_wav_path``).
        output_dir: Destination directory. It is created only once at least
            one usable utterance has been found.
        vocab_path: Vocab file to copy into the output. If empty or
            nonexistent, a character vocab is built from the transcripts.
        min_dur: Minimum clip duration in seconds (inclusive).
        max_dur: Maximum clip duration in seconds (inclusive).

    Exits the process with status 1 when ``metadata.csv`` or ``wavs/`` is
    missing, or when no usable utterance remains after filtering.
    """
    # Imported up front so a missing `datasets` install fails before the scan.
    from datasets.arrow_writer import ArrowWriter

    dataset_path = Path(dataset_dir)
    output_path = Path(output_dir)

    # --- Validate inputs ---
    meta_csv = dataset_path / "metadata.csv"
    wavs_dir = dataset_path / "wavs"
    for required in (meta_csv, wavs_dir):
        if not required.exists():
            print(f"ERROR: {required} not found", file=sys.stderr)
            sys.exit(1)

    # --- Parse metadata ---
    print(f"Reading metadata from {meta_csv}...")
    records, skipped = _load_records(meta_csv, wavs_dir, min_dur, max_dur)

    print(f"\nParsed {len(records)} valid utterances")
    print(f" Skipped (malformed line): {skipped['malformed']}")
    print(f" Skipped (empty text): {skipped['empty_text']}")
    print(f" Skipped (missing wav): {skipped['missing']}")
    print(f" Skipped (unreadable wav): {skipped['unreadable']}")
    print(f" Skipped (duration): {skipped['duration']}")
    total_hours = sum(r["duration"] for r in records) / 3600
    print(f" Total audio: {total_hours:.2f} hours")
    if not records:
        print("ERROR: No valid records found!", file=sys.stderr)
        sys.exit(1)

    output_path.mkdir(parents=True, exist_ok=True)

    # --- Vocab ---
    _write_vocab(records, output_path, vocab_path)

    # --- Write duration.json ---
    durations = [r["duration"] for r in records]
    duration_file = output_path / "duration.json"
    with open(duration_file, "w", encoding="utf-8") as f:
        json.dump({"duration": durations}, f)
    print(f"Wrote duration.json ({len(durations)} entries)")

    # --- Write Arrow dataset ---
    # duration.json and raw.arrow are both written from `records` in order,
    # so entry i of one corresponds to row i of the other.
    print("Building Arrow dataset...")
    arrow_path = output_path / "raw.arrow"
    with ArrowWriter(path=str(arrow_path)) as writer:
        for record in tqdm(records, desc="Writing Arrow"):
            writer.write(record)
        # finalize() writes out any rows still buffered in the writer.
        writer.finalize()

    print(f"\nDone! Output directory: {output_path}")
    print(f" raw.arrow : {arrow_path.stat().st_size / 1024:.1f} KB")
    print(f" duration.json : {duration_file.stat().st_size / 1024:.1f} KB")
    print(f" vocab.txt : {(output_path / 'vocab.txt').stat().st_size / 1024:.1f} KB")
    print(f"\nTotal: {len(records)} utterances, {total_hours:.2f} hours")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Declare the command-line options accepted by this script."""
    parser = argparse.ArgumentParser(description="Prepare Sinhala dataset for F5-TTS")
    parser.add_argument(
        "--dataset_dir",
        required=True,
        help="Path to dataset directory (contains metadata.csv + wavs/)",
    )
    parser.add_argument(
        "--output_dir",
        default="data/sinhala_tts_custom",
        help="Output directory for Arrow dataset",
    )
    parser.add_argument(
        "--vocab_path",
        default="data/sinhala_vocab/vocab.txt",
        help="Path to vocab.txt (Sinhala character set)",
    )
    # Both duration bounds share the same shape: a float number of seconds.
    duration_options = (
        ("--min_dur", 0.3, "Min utterance duration in seconds"),
        ("--max_dur", 30.0, "Max utterance duration in seconds"),
    )
    for flag, default_secs, help_text in duration_options:
        parser.add_argument(flag, type=float, default=default_secs, help=help_text)
    return parser


def main():
    """CLI entry point: parse arguments and build the F5-TTS dataset directory."""
    opts = _build_arg_parser().parse_args()
    build_arrow_dataset(
        dataset_dir=opts.dataset_dir,
        output_dir=opts.output_dir,
        vocab_path=opts.vocab_path,
        min_dur=opts.min_dur,
        max_dur=opts.max_dur,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()