# sinhala-tts / scripts /validate_dataset.py
# outlawmold's picture
# Fix critical issues, migrate to IndicF5 fine-tuning, update pipeline
# dd75f48
#!/usr/bin/env python3
"""
=============================================================
Dataset Validator for F5-TTS Training
=============================================================
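Scans an F5-TTS dataset (raw.arrow + audio files) and checks that every
referenced audio file exists, can be decoded by soundfile, is non-empty, and
contains no NaN/Inf samples. Mismatches between the decoded duration and the
metadata duration are reported as non-fatal warnings. This catches
corrupted/malformed WAVs before training starts.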
Usage:
python scripts/validate_dataset.py --dataset data/sinhala_tts_batch03_custom
python scripts/validate_dataset.py --dataset data/sinhala_tts_custom
If any audio files are missing or corrupted, the script prints their paths and
exits with code 1 so you can remove or re-encode them; otherwise it exits with
code 0.
=============================================================
"""
import argparse
import sys
from pathlib import Path
def _read_arrow_table(arrow_path):
    """Load ``arrow_path`` as a pyarrow Table, trying IPC file format then stream format."""
    import pyarrow as pa
    import pyarrow.ipc as ipc

    with pa.memory_map(str(arrow_path), "r") as source:
        try:
            return ipc.RecordBatchFileReader(source).read_all()
        except pa.ArrowInvalid:
            # Not the random-access IPC file format; rewind and retry with the
            # streaming reader (datasets written by a stream writer).
            source.seek(0)
            return ipc.RecordBatchStreamReader(source).read_all()


def _check_audio_file(path, meta_dur, tolerance):
    """Classify the audio file of one dataset row.

    Args:
        path: Audio file path as stored in the ``audio_path`` column.
        meta_dur: Duration in seconds from the ``duration`` column (may be None).
        tolerance: Allowed |meta_dur - decoded duration| in seconds.

    Returns:
        ``None`` if the file passes every check, otherwise a ``(label, reason)``
        pair. ``label`` is ``"MISSING"``, ``"CORRUPTED"``, ``"EMPTY"`` or
        ``"NAN/INF"`` for fatal problems, or ``"WARN"`` for a non-fatal
        duration problem.
    """
    import numpy as np
    import soundfile as sf

    # Guard against null/empty paths, which would otherwise crash Path().
    if not path or not Path(path).exists():
        return "MISSING", "file not found"
    try:
        data, sr = sf.read(path, dtype="float32")
    except Exception as e:  # soundfile/libsndfile can raise several error types
        return "CORRUPTED", str(e)
    if len(data) == 0:
        return "EMPTY", "empty audio"
    # isfinite is False for NaN, +Inf and -Inf, so this is a single pass.
    if not np.isfinite(data).all():
        return "NAN/INF", "NaN/Inf in audio"
    if meta_dur is None:
        return "WARN", "duration missing in metadata"
    # len(data) counts frames for both mono (n,) and multi-channel (n, ch).
    actual_dur = len(data) / sr
    if abs(actual_dur - meta_dur) > tolerance:
        return "WARN", f"duration mismatch: metadata={meta_dur:.2f}s, actual={actual_dur:.2f}s"
    return None


def validate_dataset(dataset_dir: str, tolerance: float = 1.0, progress_every: int = 500):
    """Validate every audio file referenced by an F5-TTS ``raw.arrow`` dataset.

    Reads ``<dataset_dir>/raw.arrow`` (Arrow IPC file or stream format). For
    each row it checks that the file at ``audio_path`` exists, decodes with
    soundfile, is non-empty and has only finite samples (fatal problems). Rows
    whose decoded duration differs from ``duration`` by more than
    ``tolerance`` seconds, or whose ``duration`` is null, become non-fatal
    warnings.

    Args:
        dataset_dir: Directory containing ``raw.arrow``.
        tolerance: Max allowed duration mismatch in seconds (default 1.0).
        progress_every: Print progress every this many rows; ``<= 0`` disables
            progress output (default 500).

    This function does not return normally. It calls ``sys.exit(1)`` if
    ``raw.arrow`` is missing, lacks a required column, or any audio file is
    missing or corrupted. Otherwise it calls ``sys.exit(0)``.
    """
    dataset_path = Path(dataset_dir)
    arrow_path = dataset_path / "raw.arrow"
    # Checked before any third-party import so a bad path fails fast.
    if not arrow_path.exists():
        print(f"ERROR: {arrow_path} not found")
        sys.exit(1)

    print(f"Loading dataset from {dataset_path} ...")
    table = _read_arrow_table(arrow_path)

    # "text" is not inspected here, but training needs it, so require it.
    required = ("audio_path", "text", "duration")
    missing_cols = [c for c in required if c not in table.column_names]
    if missing_cols:
        print(f"ERROR: {arrow_path} is missing required column(s): {', '.join(missing_cols)}")
        sys.exit(1)

    paths = table.column("audio_path").to_pylist()
    durations = table.column("duration").to_pylist()
    total = len(paths)
    print(f"Total utterances: {total}")

    fatal_rows = []
    warning_rows = []
    for i, (p, dur) in enumerate(zip(paths, durations)):
        issue = _check_audio_file(p, dur, tolerance)
        if issue is not None:
            label, reason = issue
            if label == "WARN":
                warning_rows.append((i, p, reason))
            else:
                fatal_rows.append((i, p, reason))
                print(f" [{i}] {label}: {p}")
                if label == "CORRUPTED":
                    print(f" Error: {reason}")
        # Reached for every row, including failing ones, so progress is never skipped.
        if progress_every > 0 and (i + 1) % progress_every == 0:
            print(f" ... validated {i + 1}/{total}")

    print(f"\n{'=' * 60}")
    print("Validation complete.")
    print(f" Corrupted/missing: {len(fatal_rows)}")
    print(f" Warnings: {len(warning_rows)}")
    if warning_rows:
        print("\nWarnings (non-fatal):")
        for i, p, reason in warning_rows[:10]:
            print(f" [{i}] {reason}: {p}")
        if len(warning_rows) > 10:
            print(f" ... and {len(warning_rows) - 10} more")
    if fatal_rows:
        print("\nCorrupted files:")
        for i, p, reason in fatal_rows:
            print(f" [{i}] {reason}: {p}")
        print("\nFix: remove or re-encode these files, then rebuild the dataset.")
        sys.exit(1)
    print("All audio files are valid. Training should proceed without soundfile crashes.\n")
    sys.exit(0)
def main():
    """Command-line entry point: read ``--dataset`` and hand it to the validator."""
    cli = argparse.ArgumentParser(description="Validate F5-TTS dataset audio files")
    cli.add_argument(
        "--dataset",
        default="data/sinhala_tts_batch03_custom",
        help="Path to dataset dir containing raw.arrow",
    )
    options = cli.parse_args()
    validate_dataset(options.dataset)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()