| |
| """ |
| Prepare Sinhala dataset for F5-TTS fine-tuning. |
| |
| This bypasses F5-TTS's built-in pinyin-based prep and builds |
| the Arrow dataset directly for Sinhala character-level tokenization. |
| |
| Usage: |
| python scripts/prepare_f5_data.py \ |
| --dataset_dir /path/to/cc_v1_tenvideo_baseline \ |
| --output_dir data/sinhala_tts_custom \ |
| --vocab_path data/sinhala_vocab/vocab.txt |
| |
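| The output directory must be the one the finetune CLI loads for |
| --dataset_name <name> --tokenizer custom (plus --tokenizer_path pointing |
| at the generated vocab.txt). F5-TTS appears to resolve that to |
| data/<name>_custom/, so the default data/sinhala_tts_custom corresponds |
| to --dataset_name sinhala_tts; confirm against your F5-TTS version. |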
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import soundfile as sf |
| from tqdm import tqdm |
|
|
|
|
def _resolve_wav_path(wavs_dir: Path, uttr_id: str) -> Path:
    """Locate the wav file for ``uttr_id`` under ``wavs_dir``.

    The flat layout ``wavs/<uttr_id>.wav`` is tried first. If that file does
    not exist and the ID has an integer after its first underscore (e.g.
    ``spk_12345``), the sharded layout ``wavs/<NN>/<uttr_id>.wav`` is returned,
    where ``NN`` is that integer // 5000, zero-padded to two digits.

    The returned path is not guaranteed to exist; the caller must check.
    """
    flat_path = wavs_dir / f"{uttr_id}.wav"
    if flat_path.exists():
        return flat_path
    try:
        num = int(uttr_id.split("_")[1])
    except (IndexError, ValueError):
        # ID does not follow the <prefix>_<number> pattern: no sharded fallback.
        return flat_path
    return wavs_dir / f"{num // 5000:02d}" / f"{uttr_id}.wav"


def _write_vocab(records: list, vocab_path: str, output_path: Path) -> None:
    """Write ``output_path/vocab.txt`` (one character per line).

    If ``vocab_path`` names an existing file, it is copied, and any characters
    used in the transcripts but absent from it are reported. Otherwise the
    vocab is built from the transcripts: a space first, then every other
    character in sorted order.
    """
    import shutil

    target = output_path / "vocab.txt"
    used_chars = set()
    for r in records:
        used_chars.update(r["text"])

    if vocab_path and Path(vocab_path).is_file():
        print(f"\nUsing provided vocab: {vocab_path}")
        # copy2 raises shutil.SameFileError when source and target are one file.
        if Path(vocab_path).resolve() != target.resolve():
            shutil.copy2(vocab_path, target)
        with open(target, encoding="utf-8") as f:
            known_chars = {entry.rstrip("\n") for entry in f}
        missing = sorted(used_chars - known_chars)
        if missing:
            print(f"  WARNING: {len(missing)} transcript character(s) not in vocab: "
                  + " ".join(repr(ch) for ch in missing))
        return

    if vocab_path:
        # A typo'd path used to fall back silently; make the fallback visible.
        print(f"\nWARNING: vocab file {vocab_path} not found; building one instead.")
    print("\nBuilding vocab from data...")
    # Space is always the first entry; the remaining characters are sorted.
    vocab_sorted = [" "] + sorted(used_chars - {" "})
    with open(target, "w", encoding="utf-8") as f:
        f.writelines(ch + "\n" for ch in vocab_sorted)
    print(f"  Vocab size: {len(vocab_sorted)}")


def build_arrow_dataset(dataset_dir: str, output_dir: str, vocab_path: str,
                        min_dur: float = 0.3, max_dur: float = 30.0):
    """Build raw.arrow + duration.json + vocab.txt for F5-TTS fine-tuning.

    Reads ``<dataset_dir>/metadata.csv`` (pipe-separated
    ``id|text[|normalized_text]`` lines) and the matching wav files under
    ``<dataset_dir>/wavs``, then writes into ``output_dir``:

    * ``raw.arrow`` -- one row per kept utterance with ``audio_path``
      (absolute path), ``text`` (normalized text, or the raw text if the
      normalized column is absent or empty) and ``duration`` (seconds);
    * ``duration.json`` -- ``{"duration": [...]}`` in the same row order;
    * ``vocab.txt`` -- copied from ``vocab_path`` if it is an existing file,
      otherwise built from the characters of the kept transcripts.

    Args:
        dataset_dir: Directory containing ``metadata.csv`` and ``wavs/``.
        output_dir: Destination directory (created if needed).
        vocab_path: Optional path to a prepared vocab.txt; may be empty.
        min_dur: Utterances shorter than this (seconds) are skipped.
        max_dur: Utterances longer than this (seconds) are skipped.

    Exits the process with status 1 if ``metadata.csv`` or ``wavs/`` is
    missing, or if no utterance survives filtering.
    """
    from datasets.arrow_writer import ArrowWriter

    dataset_path = Path(dataset_dir)
    output_path = Path(output_dir)

    meta_csv = dataset_path / "metadata.csv"
    if not meta_csv.exists():
        print(f"ERROR: {meta_csv} not found", file=sys.stderr)
        sys.exit(1)

    wavs_dir = dataset_path / "wavs"
    if not wavs_dir.exists():
        print(f"ERROR: {wavs_dir} not found", file=sys.stderr)
        sys.exit(1)

    # Create the output directory only once the inputs are known to exist.
    output_path.mkdir(parents=True, exist_ok=True)

    records = []
    skipped_malformed = 0
    skipped_missing = 0
    skipped_unreadable = 0
    skipped_duration = 0

    print(f"Reading metadata from {meta_csv}...")
    # utf-8-sig drops a leading BOM that would otherwise corrupt the first ID.
    with open(meta_csv, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) < 2:
                skipped_malformed += 1
                continue

            uttr_id = parts[0].strip()
            text = parts[1].strip()
            # Prefer the normalized column when present AND non-empty: an empty
            # third column ("id|text|") must not erase the transcript.
            norm_text = (parts[2].strip() if len(parts) > 2 else "") or text
            if not uttr_id or not norm_text:
                skipped_malformed += 1
                continue

            wav_path = _resolve_wav_path(wavs_dir, uttr_id)
            if not wav_path.exists():
                skipped_missing += 1
                continue

            try:
                dur = sf.info(str(wav_path)).duration
            except Exception as e:  # deliberately broad: a bad file is skipped, not fatal
                print(f"  WARNING: Can't read {wav_path}: {e}")
                skipped_unreadable += 1
                continue

            if dur < min_dur or dur > max_dur:
                skipped_duration += 1
                continue

            records.append({
                "audio_path": str(wav_path.resolve()),
                "text": norm_text,
                "duration": dur,
            })

    print(f"\nParsed {len(records)} valid utterances")
    print(f"  Skipped (malformed line): {skipped_malformed}")
    print(f"  Skipped (missing wav): {skipped_missing}")
    print(f"  Skipped (unreadable wav): {skipped_unreadable}")
    print(f"  Skipped (duration): {skipped_duration}")
    total_hours = sum(r["duration"] for r in records) / 3600
    print(f"  Total audio: {total_hours:.2f} hours")

    if not records:
        print("ERROR: No valid records found!", file=sys.stderr)
        sys.exit(1)

    _write_vocab(records, vocab_path, output_path)

    # duration.json lists durations in the same order as the rows of raw.arrow.
    durations = [r["duration"] for r in records]
    with open(output_path / "duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": durations}, f)
    print(f"Wrote duration.json ({len(durations)} entries)")

    print("Building Arrow dataset...")
    arrow_path = output_path / "raw.arrow"
    with ArrowWriter(path=str(arrow_path)) as writer:
        for record in tqdm(records, desc="Writing Arrow"):
            writer.write(record)
        writer.finalize()

    print(f"\nDone! Output directory: {output_path}")
    print(f"  raw.arrow     : {arrow_path.stat().st_size / 1024:.1f} KB")
    print(f"  duration.json : {(output_path / 'duration.json').stat().st_size / 1024:.1f} KB")
    print(f"  vocab.txt     : {(output_path / 'vocab.txt').stat().st_size / 1024:.1f} KB")
    print(f"\nTotal: {len(records)} utterances, {total_hours:.2f} hours")
|
|
|
|
def main():
    """Parse command-line arguments and run ``build_arrow_dataset``.

    Exits with status 2 (argparse usage error) if ``--min_dur`` is greater
    than ``--max_dur``. With that range, every utterance would be filtered out.
    """
    parser = argparse.ArgumentParser(description="Prepare Sinhala dataset for F5-TTS")
    parser.add_argument("--dataset_dir", required=True,
                        help="Path to dataset directory (contains metadata.csv + wavs/)")
    parser.add_argument("--output_dir", default="data/sinhala_tts_custom",
                        help="Output directory for Arrow dataset")
    parser.add_argument("--vocab_path", default="data/sinhala_vocab/vocab.txt",
                        help="Path to vocab.txt (Sinhala character set); if missing "
                             "or empty (''), a vocab is built from the transcripts")
    parser.add_argument("--min_dur", type=float, default=0.3,
                        help="Min utterance duration in seconds")
    parser.add_argument("--max_dur", type=float, default=30.0,
                        help="Max utterance duration in seconds")

    args = parser.parse_args()
    # An inverted range would drop every utterance and end in a confusing
    # "No valid records" error; fail fast with a usage message instead.
    if args.min_dur > args.max_dur:
        parser.error(f"--min_dur ({args.min_dur}) must not exceed --max_dur ({args.max_dur})")
    build_arrow_dataset(args.dataset_dir, args.output_dir, args.vocab_path,
                        args.min_dur, args.max_dur)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|