#!/usr/bin/env python3
"""
=============================================================
Dataset Validator for F5-TTS Training
=============================================================
Scans an F5-TTS dataset (raw.arrow + audio files) and validates that
every audio file can be read by soundfile. This catches corrupted or
malformed WAVs before training starts.

Each entry is checked for:
  - a null audio path or a file that does not exist,
  - a soundfile read error,
  - zero-length audio,
  - NaN/Inf samples,
and, as a non-fatal warning, a difference between the metadata duration
and the decoded audio length larger than --duration-tolerance seconds.

Usage:
    python scripts/validate_dataset.py --dataset data/sinhala_tts_batch03_custom
    python scripts/validate_dataset.py --dataset data/sinhala_tts_custom

If corrupted files are found, the script prints their paths and exits
with code 1. You can then remove or re-encode them.
=============================================================
"""

import argparse
import sys
from pathlib import Path

# A progress line is printed after every this many dataset entries.
_PROGRESS_EVERY = 500
# At most this many non-fatal warnings are listed in the final summary.
_MAX_WARNINGS_SHOWN = 10


def _load_arrow_table(arrow_path):
    """Load ``arrow_path`` into a pyarrow Table.

    The Arrow IPC *file* format is tried first; if pyarrow rejects it as
    invalid, the source is rewound and read as the IPC *stream* format.
    """
    import pyarrow as pa
    import pyarrow.ipc as ipc

    with pa.memory_map(str(arrow_path), "r") as source:
        try:
            return ipc.RecordBatchFileReader(source).read_all()
        except pa.ArrowInvalid:
            # Not in IPC file format: rewind and retry as a stream.
            source.seek(0)
            return ipc.RecordBatchStreamReader(source).read_all()


def _inspect_audio(audio_path, expected_duration, duration_tolerance):
    """Read one audio file with soundfile (as float32) and classify it.

    Args:
        audio_path: Value from the ``audio_path`` column; may be None.
        expected_duration: Value (seconds) from the ``duration`` column, or
            None when the metadata has no value.
        duration_tolerance: Maximum allowed absolute difference, in
            seconds, between ``expected_duration`` and the decoded length
            before a warning is produced.

    Returns:
        ``(problem, warning)``. ``problem`` is ``(label, reason)`` when the
        file is unusable (label is MISSING, CORRUPTED, EMPTY or NAN/INF),
        otherwise None. ``warning`` is a message describing a non-fatal
        duration mismatch, otherwise None. At most one of the two is set.
    """
    if audio_path is None:
        return ("MISSING", "no audio_path in metadata"), None
    if not Path(audio_path).exists():
        return ("MISSING", "file not found"), None

    # Function-local imports, like the rest of this script, so the module
    # itself can be imported without the audio stack installed.
    import numpy as np
    import soundfile as sf

    try:
        data, sr = sf.read(audio_path, dtype="float32")
    except Exception as e:  # Any read failure means the file is unusable.
        return ("CORRUPTED", str(e)), None

    if len(data) == 0:
        return ("EMPTY", "empty audio"), None
    # isfinite is False for both NaN and +/-Inf.
    if not np.isfinite(data).all():
        return ("NAN/INF", "NaN/Inf in audio"), None

    if expected_duration is None:
        return None, None
    actual_duration = len(data) / sr
    if abs(actual_duration - expected_duration) > duration_tolerance:
        return None, (
            f"duration mismatch: metadata={expected_duration:.2f}s, "
            f"actual={actual_duration:.2f}s"
        )
    return None, None


def validate_dataset(dataset_dir: str, duration_tolerance: float = 1.0):
    """Validate every audio file referenced by ``<dataset_dir>/raw.arrow``.

    Prints a per-file report and a summary, then terminates the process:
    exit code 1 if raw.arrow is missing or any entry is missing, corrupted,
    empty or contains NaN/Inf; exit code 0 otherwise. Duration mismatches
    are reported as warnings only and do not affect the exit code.

    Args:
        dataset_dir: Directory containing ``raw.arrow`` with ``audio_path``
            and ``duration`` columns.
        duration_tolerance: Allowed difference, in seconds, between the
            metadata ``duration`` and the decoded audio length.

    Raises:
        SystemExit: always, via ``sys.exit``.
    """
    dataset_path = Path(dataset_dir)
    arrow_path = dataset_path / "raw.arrow"
    # Checked before pyarrow/soundfile are imported, so this error is
    # reported even when those packages are not installed.
    if not arrow_path.exists():
        print(f"ERROR: {arrow_path} not found")
        sys.exit(1)

    print(f"Loading dataset from {dataset_path} ...")
    table = _load_arrow_table(arrow_path)
    paths = table["audio_path"].to_pylist()
    durations = table["duration"].to_pylist()
    total = len(paths)
    print(f"Total utterances: {total}")

    corrupted = []
    warnings = []
    for i, (p, dur) in enumerate(zip(paths, durations)):
        problem, warning = _inspect_audio(p, dur, duration_tolerance)
        if problem is not None:
            label, reason = problem
            corrupted.append((i, p, reason))
            print(f" [{i}] {label}: {p}")
            if label == "CORRUPTED":
                print(f" Error: {reason}")
        elif warning is not None:
            warnings.append((i, p, warning))

        # Counted for every entry, including ones that failed a check.
        if (i + 1) % _PROGRESS_EVERY == 0:
            print(f" ... validated {i + 1}/{total}")

    print(f"\n{'=' * 60}")
    print("Validation complete.")
    print(f" Corrupted/missing: {len(corrupted)}")
    print(f" Warnings: {len(warnings)}")

    if warnings:
        print("\nWarnings (non-fatal):")
        for i, p, reason in warnings[:_MAX_WARNINGS_SHOWN]:
            print(f" [{i}] {reason}: {p}")
        if len(warnings) > _MAX_WARNINGS_SHOWN:
            print(f" ... and {len(warnings) - _MAX_WARNINGS_SHOWN} more")

    if corrupted:
        print("\nCorrupted files:")
        for i, p, reason in corrupted:
            print(f" [{i}] {reason}: {p}")
        print("\nFix: remove or re-encode these files, then rebuild the dataset.")
        sys.exit(1)

    print("All audio files are valid. Training should proceed without soundfile crashes.\n")
    sys.exit(0)


def main():
    """Parse command-line arguments and run the validator."""
    parser = argparse.ArgumentParser(description="Validate F5-TTS dataset audio files")
    parser.add_argument(
        "--dataset",
        default="data/sinhala_tts_batch03_custom",
        help="Path to dataset dir containing raw.arrow",
    )
    parser.add_argument(
        "--duration-tolerance",
        type=float,
        default=1.0,
        help="Max difference in seconds between metadata duration and actual "
        "audio length before a warning is reported",
    )
    args = parser.parse_args()
    validate_dataset(args.dataset, duration_tolerance=args.duration_tolerance)


if __name__ == "__main__":
    main()