| |
| """ |
| ============================================================= |
| Dataset Validator for F5-TTS Training |
| ============================================================= |
| Scans an F5-TTS dataset (raw.arrow + audio files) and validates every audio |
| file can be read by soundfile. This catches corrupted/malformed WAVs before |
| training starts. |
| |
| Usage: |
| python scripts/validate_dataset.py --dataset data/sinhala_tts_batch03_custom |
| python scripts/validate_dataset.py --dataset data/sinhala_tts_custom |
| |
| If corrupted files are found, the script prints their paths and exits with |
| code 1. You can then remove or re-encode them. |
| ============================================================= |
| """ |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
|
|
def _inspect_audio(audio_path):
    """Decode one audio file and classify any problem found.

    Args:
        audio_path: Path string taken from the dataset's ``audio_path`` column
            (may be ``None``/empty if the metadata row is incomplete).

    Returns:
        A ``(label, reason, actual_duration)`` tuple. For a healthy file
        ``label`` and ``reason`` are ``None`` and ``actual_duration`` is the
        decoded length in seconds. For an unusable file ``label`` is one of
        ``"MISSING"``, ``"CORRUPTED"``, ``"EMPTY"`` or ``"NAN/INF"``,
        ``reason`` is the human-readable explanation, and
        ``actual_duration`` is ``None``.
    """
    # Local imports keep module import cheap; repeated imports are just a
    # sys.modules lookup, negligible next to decoding an audio file.
    import numpy as np
    import soundfile as sf

    # A null/empty path cannot be opened; report it like a missing file
    # instead of letting Path(None) raise TypeError.
    if not audio_path or not Path(audio_path).exists():
        return "MISSING", "file not found", None

    try:
        samples, sample_rate = sf.read(audio_path, dtype="float32")
    except Exception as exc:
        # Deliberately broad: any decode failure, whatever its type, means
        # the file is unusable for training and must be reported, not raised.
        return "CORRUPTED", str(exc), None

    if len(samples) == 0:
        return "EMPTY", "empty audio", None

    if not np.isfinite(samples).all():
        return "NAN/INF", "NaN/Inf in audio", None

    return None, None, len(samples) / sample_rate


def validate_dataset(dataset_dir: str, *, duration_tolerance: float = 1.0):
    """Validate every audio file referenced by an F5-TTS dataset.

    Loads ``raw.arrow`` from ``dataset_dir`` (Arrow IPC file format, falling
    back to the stream format), then checks that each referenced audio file
    exists, can be decoded by soundfile, is non-empty, and contains only
    finite samples. Durations that differ from the metadata by more than
    ``duration_tolerance`` seconds, or metadata durations that are not finite
    numbers, are reported as non-fatal warnings.

    Args:
        dataset_dir: Directory containing ``raw.arrow``.
        duration_tolerance: Maximum allowed absolute difference, in seconds,
            between the metadata duration and the decoded duration before a
            warning is recorded.

    Raises:
        SystemExit: Always. Exit code 0 when every file is valid; exit code 1
            when ``raw.arrow`` is missing, a required column is absent, or at
            least one audio file is missing/corrupted. This function never
            returns normally.
    """
    dataset_path = Path(dataset_dir)
    arrow_path = dataset_path / "raw.arrow"

    # Fail fast on a wrong path before paying for the heavy imports below.
    if not arrow_path.exists():
        print(f"ERROR: {arrow_path} not found")
        sys.exit(1)

    import numpy as np
    import pyarrow as pa
    import pyarrow.ipc as ipc

    print(f"Loading dataset from {dataset_path} ...")

    required_columns = ("audio_path", "text", "duration")
    with pa.memory_map(str(arrow_path), "r") as f:
        try:
            table = ipc.RecordBatchFileReader(f).read_all()
        except pa.ArrowInvalid:
            # Not the random-access file format; rewind and retry as a stream.
            f.seek(0)
            table = ipc.RecordBatchStreamReader(f).read_all()

        absent = [c for c in required_columns if c not in table.column_names]
        if absent:
            print(f"ERROR: {arrow_path} is missing column(s): {', '.join(absent)}")
            sys.exit(1)

        # Materialize Python values while the memory map is still open.
        paths = table.column("audio_path").to_pylist()
        durations = table.column("duration").to_pylist()

    total = len(paths)
    print(f"Total utterances: {total}")

    corrupted = []
    warnings = []

    for i, (p, dur) in enumerate(zip(paths, durations)):
        label, reason, actual_dur = _inspect_audio(p)

        if label is not None:
            corrupted.append((i, p, reason))
            print(f" [{i}] {label}: {p}")
            if label == "CORRUPTED":
                print(f" Error: {reason}")
        elif not isinstance(dur, (int, float)) or not np.isfinite(dur):
            # Null / NaN metadata would otherwise crash or silently pass the
            # comparison below; surface it as a warning instead.
            warnings.append((i, p, f"invalid duration metadata: {dur!r}"))
        elif abs(actual_dur - dur) > duration_tolerance:
            warnings.append(
                (i, p, f"duration mismatch: metadata={dur:.2f}s, actual={actual_dur:.2f}s")
            )

        # Reported for every 500th row, including rows that failed validation.
        if (i + 1) % 500 == 0:
            print(f" ... validated {i + 1}/{total}")

    print(f"\n{'=' * 60}")
    print("Validation complete.")
    print(f" Corrupted/missing: {len(corrupted)}")
    print(f" Warnings: {len(warnings)}")

    if warnings:
        print("\nWarnings (non-fatal):")
        # Cap the listing so a systematic metadata problem does not flood
        # the terminal; the remainder is summarized as a count.
        for i, p, reason in warnings[:10]:
            print(f" [{i}] {reason}: {p}")
        if len(warnings) > 10:
            print(f" ... and {len(warnings) - 10} more")

    if corrupted:
        print("\nCorrupted files:")
        for i, p, reason in corrupted:
            print(f" [{i}] {reason}: {p}")
        print("\nFix: remove or re-encode these files, then rebuild the dataset.")
        sys.exit(1)

    print("All audio files are valid. Training should proceed without soundfile crashes.\n")
    sys.exit(0)
|
|
|
|
def main():
    """Command-line entry point: read ``--dataset`` and validate that dataset.

    The dataset directory defaults to ``data/sinhala_tts_batch03_custom`` when
    the flag is omitted. The parsed path is handed to ``validate_dataset``.
    """
    cli = argparse.ArgumentParser(
        description="Validate F5-TTS dataset audio files",
    )
    cli.add_argument(
        "--dataset",
        default="data/sinhala_tts_batch03_custom",
        help="Path to dataset dir containing raw.arrow",
    )
    options = cli.parse_args()
    validate_dataset(options.dataset)
|
|
|
|
# Script entry guard: run the CLI only when this file is executed directly,
# so importing the module never triggers a validation run.
if __name__ == "__main__":
    main()
|
|