#!/usr/bin/env python3 """ Prepare Sinhala dataset for F5-TTS fine-tuning. This bypasses F5-TTS's built-in pinyin-based prep and builds the Arrow dataset directly for Sinhala character-level tokenization. Usage: python scripts/prepare_f5_data.py \ --dataset_dir /path/to/cc_v1_tenvideo_baseline \ --output_dir data/sinhala_tts_custom \ --vocab_path data/sinhala_vocab/vocab.txt The output goes to data// which must match: --dataset_name --tokenizer custom in the finetune CLI. """ import argparse import json import os import sys from pathlib import Path import soundfile as sf from tqdm import tqdm def build_arrow_dataset(dataset_dir: str, output_dir: str, vocab_path: str, min_dur: float = 0.3, max_dur: float = 30.0): """Build Arrow dataset + duration.json + copy vocab.txt.""" from datasets.arrow_writer import ArrowWriter dataset_path = Path(dataset_dir) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) # --- Find metadata CSV --- meta_csv = dataset_path / "metadata.csv" if not meta_csv.exists(): print(f"ERROR: {meta_csv} not found") sys.exit(1) # --- Find wavs directory --- wavs_dir = dataset_path / "wavs" if not wavs_dir.exists(): print(f"ERROR: {wavs_dir} not found") sys.exit(1) # --- Parse metadata (LJSpeech format: filename|text|normalized_text) --- records = [] skipped_missing = 0 skipped_duration = 0 print(f"Reading metadata from {meta_csv}...") with open(meta_csv, encoding="utf-8") as f: for line_num, line in enumerate(f): line = line.strip() if not line: continue parts = line.split("|") if len(parts) < 2: continue uttr_id = parts[0].strip() text = parts[1].strip() norm_text = parts[2].strip() if len(parts) > 2 else text # Resolve wav path — handle both flat and subdirectory layouts wav_path = wavs_dir / f"{uttr_id}.wav" if not wav_path.exists(): # Try with subdirectory structure (e.g., wavs/00/si_000042.wav) try: num = int(uttr_id.split("_")[1]) subdir = f"{num // 5000:02d}" wav_path = wavs_dir / subdir / f"{uttr_id}.wav" except (IndexError, ValueError): pass if not wav_path.exists(): skipped_missing += 1 continue # Get duration try: info = sf.info(str(wav_path)) dur = info.duration except Exception as e: print(f" WARNING: Can't read {wav_path}: {e}") skipped_missing += 1 continue if dur < min_dur or dur > max_dur: skipped_duration += 1 continue records.append({ "audio_path": str(wav_path.resolve()), "text": norm_text, "duration": dur, }) print(f"\nParsed {len(records)} valid utterances") print(f" Skipped (missing wav): {skipped_missing}") print(f" Skipped (duration): {skipped_duration}") total_hours = sum(r["duration"] for r in records) / 3600 print(f" Total audio: {total_hours:.2f} hours") if len(records) == 0: print("ERROR: No valid records found!") sys.exit(1) # --- Build vocab from data (or use provided) --- if vocab_path and Path(vocab_path).exists(): print(f"\nUsing provided vocab: {vocab_path}") import shutil shutil.copy2(vocab_path, output_path / "vocab.txt") else: print("\nBuilding vocab from data...") vocab_set = set() for r in records: vocab_set.update(list(r["text"])) # Space must be index 0 vocab_sorted = [" "] + sorted(v for v in vocab_set if v != " ") with open(output_path / "vocab.txt", "w", encoding="utf-8") as f: for ch in vocab_sorted: f.write(ch + "\n") print(f" Vocab size: {len(vocab_sorted)}") # --- Write duration.json --- durations = [r["duration"] for r in records] with open(output_path / "duration.json", "w") as f: json.dump({"duration": durations}, f) print(f"Wrote duration.json ({len(durations)} entries)") # --- Write Arrow dataset --- print("Building Arrow dataset...") arrow_path = output_path / "raw.arrow" with ArrowWriter(path=str(arrow_path)) as writer: for record in tqdm(records, desc="Writing Arrow"): writer.write(record) writer.finalize() print(f"\nDone! Output directory: {output_path}") print(f" raw.arrow : {arrow_path.stat().st_size / 1024:.1f} KB") print(f" duration.json : {(output_path / 'duration.json').stat().st_size / 1024:.1f} KB") print(f" vocab.txt : {(output_path / 'vocab.txt').stat().st_size / 1024:.1f} KB") print(f"\nTotal: {len(records)} utterances, {total_hours:.2f} hours") def main(): parser = argparse.ArgumentParser(description="Prepare Sinhala dataset for F5-TTS") parser.add_argument("--dataset_dir", required=True, help="Path to dataset directory (contains metadata.csv + wavs/)") parser.add_argument("--output_dir", default="data/sinhala_tts_custom", help="Output directory for Arrow dataset") parser.add_argument("--vocab_path", default="data/sinhala_vocab/vocab.txt", help="Path to vocab.txt (Sinhala character set)") parser.add_argument("--min_dur", type=float, default=0.3, help="Min utterance duration in seconds") parser.add_argument("--max_dur", type=float, default=30.0, help="Max utterance duration in seconds") args = parser.parse_args() build_arrow_dataset(args.dataset_dir, args.output_dir, args.vocab_path, args.min_dur, args.max_dur) if __name__ == "__main__": main()