File size: 3,911 Bytes
dd75f48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python3
"""
=============================================================
Dataset Validator for F5-TTS Training
=============================================================
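Scans an F5-TTS dataset (raw.arrow + audio files) and validates that every audio
file exists, can be read by soundfile, is non-empty and contains no NaN/Inf
samples. Files whose decoded length differs from the metadata duration by more
than one second are reported as non-fatal warnings. This catches
corrupted/malformed WAVs before training starts.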

Usage:
    python scripts/validate_dataset.py --dataset data/sinhala_tts_batch03_custom
    python scripts/validate_dataset.py --dataset data/sinhala_tts_custom

If corrupted files are found, the script prints their paths and exits with
code 1. You can then remove or re-encode them.
=============================================================
"""

import argparse
import sys
from pathlib import Path


def _load_arrow_table(arrow_path):
    """Read an Arrow table from ``arrow_path``.

    Tries the Arrow IPC *file* format first; if pyarrow rejects it, rewinds
    and retries as the IPC *stream* format.
    """
    import pyarrow as pa
    import pyarrow.ipc as ipc

    with pa.memory_map(str(arrow_path), "r") as source:
        try:
            return ipc.RecordBatchFileReader(source).read_all()
        except pa.ArrowInvalid:
            source.seek(0)
            return ipc.RecordBatchStreamReader(source).read_all()


def _check_utterance(index, audio_path, expected_dur, duration_tolerance):
    """Validate a single utterance's audio file.

    Fatal problems are printed immediately so they appear interleaved with the
    progress output; warnings are only returned.

    Args:
        index: Row index in raw.arrow, used in printed messages.
        audio_path: Value of the ``audio_path`` column (may be None).
        expected_dur: Value of the ``duration`` column in seconds (may be None).
        duration_tolerance: Maximum allowed |actual - expected| in seconds.

    Returns:
        ``(fatal_reason, warning_reason)``; each is a string or None.
    """
    import numpy as np
    import soundfile as sf

    if audio_path is None or not Path(audio_path).exists():
        print(f"  [{index}] MISSING: {audio_path}")
        return "file not found", None

    try:
        data, sr = sf.read(audio_path, dtype="float32")
    except Exception as e:
        # Deliberately broad: the goal is to flag any file soundfile cannot
        # decode, whatever exception type it surfaces as.
        print(f"  [{index}] CORRUPTED: {audio_path}")
        print(f"       Error: {e}")
        return str(e), None

    if len(data) == 0:
        print(f"  [{index}] EMPTY: {audio_path}")
        return "empty audio", None

    # Single pass over the samples instead of separate isnan/isinf scans.
    if not np.isfinite(data).all():
        print(f"  [{index}] NAN/INF: {audio_path}")
        return "NaN/Inf in audio", None

    if expected_dur is None:
        return None, "no duration in metadata"

    actual_dur = len(data) / sr
    if abs(actual_dur - expected_dur) > duration_tolerance:
        return None, (
            f"duration mismatch: metadata={expected_dur:.2f}s, "
            f"actual={actual_dur:.2f}s"
        )
    return None, None


def _print_report(corrupted, warned, max_warnings=10):
    """Print the summary, the first ``max_warnings`` warnings and all fatal issues.

    ``corrupted`` and ``warned`` are lists of ``(index, path, reason)`` tuples.
    """
    print(f"\n{'=' * 60}")
    print("Validation complete.")
    print(f"  Corrupted/missing: {len(corrupted)}")
    print(f"  Warnings: {len(warned)}")

    if warned:
        print("\nWarnings (non-fatal):")
        for i, p, reason in warned[:max_warnings]:
            print(f"  [{i}] {reason}: {p}")
        if len(warned) > max_warnings:
            print(f"  ... and {len(warned) - max_warnings} more")

    if corrupted:
        print("\nCorrupted files:")
        for i, p, reason in corrupted:
            print(f"  [{i}] {reason}: {p}")
        print("\nFix: remove or re-encode these files, then rebuild the dataset.")
    else:
        print("All audio files are valid. Training should proceed without soundfile crashes.\n")


def validate_dataset(dataset_dir: str, duration_tolerance: float = 1.0):
    """Validate every utterance listed in ``<dataset_dir>/raw.arrow``.

    Each row's ``audio_path`` must exist, decode with soundfile, be non-empty
    and contain only finite samples; otherwise it is reported as corrupted.
    A missing ``duration`` value, or a decoded length differing from it by
    more than ``duration_tolerance`` seconds, is recorded as a warning.

    Args:
        dataset_dir: Directory containing ``raw.arrow``.
        duration_tolerance: Allowed duration mismatch in seconds before a
            warning is recorded.

    Raises:
        SystemExit: Always. Code 1 if raw.arrow is missing, lacks a required
            column, or any utterance is corrupted/missing; code 0 otherwise.
    """
    dataset_path = Path(dataset_dir)
    arrow_path = dataset_path / "raw.arrow"

    # Checked before any third-party import so a wrong --dataset fails fast.
    if not arrow_path.exists():
        print(f"ERROR: {arrow_path} not found")
        sys.exit(1)

    print(f"Loading dataset from {dataset_path} ...")
    table = _load_arrow_table(arrow_path)

    # "text" is not inspected, only required to be present in the schema.
    required = ("audio_path", "text", "duration")
    missing_cols = [c for c in required if c not in table.column_names]
    if missing_cols:
        print(f"ERROR: {arrow_path} is missing column(s): {', '.join(missing_cols)}")
        sys.exit(1)

    paths = table.column("audio_path").to_pylist()
    durations = table.column("duration").to_pylist()
    total = len(paths)
    print(f"Total utterances: {total}")

    corrupted = []
    warned = []
    for i, (p, dur) in enumerate(zip(paths, durations)):
        fatal, warning = _check_utterance(i, p, dur, duration_tolerance)
        if fatal is not None:
            corrupted.append((i, p, fatal))
        if warning is not None:
            warned.append((i, p, warning))
        # Runs for every row, so progress is reported even when row i failed.
        if (i + 1) % 500 == 0:
            print(f"  ... validated {i + 1}/{total}")

    _print_report(corrupted, warned)
    sys.exit(1 if corrupted else 0)


def main():
    """Parse the command line and run the validator on the chosen dataset."""
    cli = argparse.ArgumentParser(description="Validate F5-TTS dataset audio files")
    cli.add_argument(
        "--dataset",
        default="data/sinhala_tts_batch03_custom",
        help="Path to dataset dir containing raw.arrow",
    )
    options = cli.parse_args()
    validate_dataset(options.dataset)


# Only run the CLI when executed as a script, not when imported.
if __name__ == "__main__":
    main()