File size: 6,180 Bytes
76d096a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#!/usr/bin/env python3
"""
Prepare Sinhala dataset for F5-TTS fine-tuning.

This bypasses F5-TTS's built-in pinyin-based prep and builds
the Arrow dataset directly for Sinhala character-level tokenization.

Usage:
    python scripts/prepare_f5_data.py \
        --dataset_dir /path/to/cc_v1_tenvideo_baseline \
        --output_dir data/sinhala_tts_custom \
        --vocab_path data/sinhala_vocab/vocab.txt

The output directory must be named data/<dataset_name>_custom/ so that it
matches:
    --dataset_name <dataset_name> --tokenizer custom
in the finetune CLI (e.g. data/sinhala_tts_custom for --dataset_name sinhala_tts).
"""

import argparse
import json
import os
import sys
from pathlib import Path

import soundfile as sf
from tqdm import tqdm


def _resolve_wav_path(wavs_dir: Path, uttr_id: str) -> Path:
    """Return the candidate wav path for *uttr_id*; the result may not exist.

    The flat layout (``wavs/<id>.wav``) is tried first.  If that file is
    absent and the id looks like ``<prefix>_<number>``, the sharded layout
    ``wavs/<number // 5000, zero-padded to 2 digits>/<id>.wav`` is returned
    instead (e.g. ``wavs/00/si_000042.wav``).
    """
    flat = wavs_dir / f"{uttr_id}.wav"
    if flat.exists():
        return flat
    try:
        num = int(uttr_id.split("_")[1])
    except (IndexError, ValueError):
        # Id has no numeric second component; the caller will see that the
        # flat path does not exist and count the utterance as missing.
        return flat
    return wavs_dir / f"{num // 5000:02d}" / f"{uttr_id}.wav"


def build_arrow_dataset(dataset_dir: str, output_dir: str, vocab_path: str,
                        min_dur: float = 0.3, max_dur: float = 30.0) -> None:
    """Build Arrow dataset + duration.json + copy (or generate) vocab.txt.

    Reads ``<dataset_dir>/metadata.csv`` (pipe-separated, LJSpeech style:
    ``filename|text|normalized_text``) and the audio under
    ``<dataset_dir>/wavs``, then writes three files into ``output_dir``:

    * ``raw.arrow``     - one record per utterance with the keys
      ``audio_path`` (absolute path), ``text`` and ``duration`` (seconds);
    * ``duration.json`` - ``{"duration": [...]}`` in the same order;
    * ``vocab.txt``     - copied from ``vocab_path`` when that file exists,
      otherwise built from the characters seen in the kept texts.

    Args:
        dataset_dir: Directory containing ``metadata.csv`` and ``wavs/``.
        output_dir: Directory to write the outputs to (created if needed).
        vocab_path: Existing vocab file to copy; a falsy or non-existent
            path triggers vocab generation from the data.
        min_dur: Utterances shorter than this many seconds are skipped.
        max_dur: Utterances longer than this many seconds are skipped.

    Raises:
        SystemExit: With status 1 when ``metadata.csv`` or ``wavs/`` is
            missing, or when no utterance survives filtering.
    """
    import shutil

    from datasets.arrow_writer import ArrowWriter

    dataset_path = Path(dataset_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Find metadata CSV ---
    meta_csv = dataset_path / "metadata.csv"
    if not meta_csv.exists():
        print(f"ERROR: {meta_csv} not found")
        sys.exit(1)

    # --- Find wavs directory ---
    wavs_dir = dataset_path / "wavs"
    if not wavs_dir.exists():
        print(f"ERROR: {wavs_dir} not found")
        sys.exit(1)

    # --- Parse metadata (LJSpeech format: filename|text|normalized_text) ---
    records = []
    skipped_missing = 0
    skipped_duration = 0
    skipped_empty_text = 0

    print(f"Reading metadata from {meta_csv}...")
    # "utf-8-sig" transparently drops a leading BOM (and reads BOM-less files
    # unchanged); with plain "utf-8" the BOM would stick to the first
    # utterance id and that wav would be reported as missing.
    with open(meta_csv, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) < 2:
                continue

            uttr_id = parts[0].strip()
            text = parts[1].strip()
            # Prefer the normalized column, but fall back to the raw text
            # when that column is absent *or empty* (e.g. "id|text|").
            norm_text = (parts[2].strip() if len(parts) > 2 else "") or text
            if not norm_text:
                # Nothing to train on for this utterance.
                skipped_empty_text += 1
                continue

            # Resolve wav path — handle both flat and subdirectory layouts
            wav_path = _resolve_wav_path(wavs_dir, uttr_id)
            if not wav_path.exists():
                skipped_missing += 1
                continue

            # Get duration
            try:
                dur = sf.info(str(wav_path)).duration
            except Exception as e:
                # Best effort: an unreadable file is skipped, not fatal.
                print(f"  WARNING: Can't read {wav_path}: {e}")
                skipped_missing += 1
                continue

            if dur < min_dur or dur > max_dur:
                skipped_duration += 1
                continue

            records.append({
                "audio_path": str(wav_path.resolve()),
                "text": norm_text,
                "duration": dur,
            })

    print(f"\nParsed {len(records)} valid utterances")
    print(f"  Skipped (missing wav): {skipped_missing}")
    print(f"  Skipped (duration): {skipped_duration}")
    print(f"  Skipped (empty text): {skipped_empty_text}")
    total_hours = sum(r["duration"] for r in records) / 3600
    print(f"  Total audio: {total_hours:.2f} hours")

    if not records:
        print("ERROR: No valid records found!")
        sys.exit(1)

    # --- Build vocab from data (or use provided) ---
    vocab_out = output_path / "vocab.txt"
    if vocab_path and Path(vocab_path).exists():
        print(f"\nUsing provided vocab: {vocab_path}")
        try:
            shutil.copy2(vocab_path, vocab_out)
        except shutil.SameFileError:
            # The provided vocab already is <output_dir>/vocab.txt.
            pass
    else:
        if vocab_path:
            # Make the fallback visible instead of silently ignoring the
            # path the caller asked for.
            print(f"\nWARNING: vocab file {vocab_path} not found")
        print("\nBuilding vocab from data...")
        vocab_set = set()
        for r in records:
            vocab_set.update(r["text"])
        # Space must be index 0
        vocab_sorted = [" "] + sorted(vocab_set - {" "})
        with open(vocab_out, "w", encoding="utf-8") as f:
            f.writelines(f"{ch}\n" for ch in vocab_sorted)
        print(f"  Vocab size: {len(vocab_sorted)}")

    # --- Write duration.json ---
    durations = [r["duration"] for r in records]
    with open(output_path / "duration.json", "w") as f:
        json.dump({"duration": durations}, f)
    print(f"Wrote duration.json ({len(durations)} entries)")

    # --- Write Arrow dataset ---
    print("Building Arrow dataset...")
    arrow_path = output_path / "raw.arrow"

    with ArrowWriter(path=str(arrow_path)) as writer:
        for record in tqdm(records, desc="Writing Arrow"):
            writer.write(record)
        writer.finalize()

    print(f"\nDone! Output directory: {output_path}")
    print(f"  raw.arrow     : {arrow_path.stat().st_size / 1024:.1f} KB")
    print(f"  duration.json : {(output_path / 'duration.json').stat().st_size / 1024:.1f} KB")
    print(f"  vocab.txt     : {vocab_out.stat().st_size / 1024:.1f} KB")
    print(f"\nTotal: {len(records)} utterances, {total_hours:.2f} hours")


def main():
    """Parse the command-line options and run the dataset preparation."""
    cli = argparse.ArgumentParser(description="Prepare Sinhala dataset for F5-TTS")

    # Input / output locations.
    cli.add_argument(
        "--dataset_dir",
        required=True,
        help="Path to dataset directory (contains metadata.csv + wavs/)",
    )
    cli.add_argument(
        "--output_dir",
        default="data/sinhala_tts_custom",
        help="Output directory for Arrow dataset",
    )
    cli.add_argument(
        "--vocab_path",
        default="data/sinhala_vocab/vocab.txt",
        help="Path to vocab.txt (Sinhala character set)",
    )

    # Duration bounds share the same shape, so register them from a table.
    duration_flags = (
        ("--min_dur", 0.3, "Min utterance duration in seconds"),
        ("--max_dur", 30.0, "Max utterance duration in seconds"),
    )
    for flag, fallback, blurb in duration_flags:
        cli.add_argument(flag, type=float, default=fallback, help=blurb)

    opts = cli.parse_args()
    build_arrow_dataset(
        opts.dataset_dir,
        opts.output_dir,
        opts.vocab_path,
        opts.min_dur,
        opts.max_dur,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()