File size: 5,735 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Data preparation pipeline.

Usage:
    python scripts/prepare_data.py
    python scripts/prepare_data.py --config config/training_config.yaml

Expected input layout (data/raw/):
    data/raw/audio/           ← audio files (.wav, .mp3, .m4a, .flac, .ogg)
    data/raw/transcripts/     ← matching .json transcript files (same stem)

MGB-3 data: run  python scripts/download_mgb3.py  first — it saves MGB-3
audio and transcripts directly into data/raw/, so they are picked up here
automatically alongside any other local files.

Steps performed:
    1. Scan data/raw/audio/ for all audio files.
    2. For each audio file, find a matching .json in data/raw/transcripts/.
    3. Parse the JSON transcript → List[TranscriptEntry].
    4. Group entries into ≤25-second segments → List[TranscriptSegment].
    5. Slice the audio into aligned WAV chunks → data/processed/audio_segments/.
    6. Build a HuggingFace DatasetDict (train / eval / test split, grouped by
       source audio to prevent leakage) → data/processed/.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import yaml

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.data_preparation.build_dataset import build_and_save
from src.data_preparation.parse_transcripts import build_segments, parse_transcript_file
from src.data_preparation.segment_audio import process_pair

# Console logging for this script: INFO and above, each line prefixed with a
# wall-clock timestamp and a level name padded to eight characters.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# File suffixes (lower-case, with leading dot) recognised as audio input.
AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}


def main(config_path: str = "config/training_config.yaml") -> None:
    """Run the data-preparation pipeline described by a YAML config file.

    For every audio file in the configured raw-audio directory that has a
    same-stem ``.json`` transcript, the transcript is parsed, grouped into
    segments, and the audio is sliced into WAV chunks under
    ``<processed_dir>/audio_segments``. All resulting records are then passed
    to ``build_and_save`` together with the eval/test split ratios.

    Args:
        config_path: Path to the YAML config, joined onto the project root
            (the parent of this script's directory).

    Raises:
        OSError: If the config file cannot be opened.
        KeyError: If a required key is missing from the config's ``data``
            section (``test_split_ratio`` and ``min_audio_amplitude`` are
            optional and default to 0.1 and 0.001).
        SystemExit: With status 1 when no segment records were produced.
    """
    root = Path(__file__).parent.parent

    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding, so non-ASCII config values load the same on every OS.
    with (root / config_path).open(encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    data_cfg = cfg["data"]

    audio_dir = root / data_cfg["raw_audio_dir"]
    transcript_dir = root / data_cfg["raw_transcripts_dir"]
    processed_dir = root / data_cfg["processed_dir"]
    wav_dir = processed_dir / "audio_segments"

    max_dur: float = data_cfg["max_segment_duration"]
    min_dur: float = data_cfg["min_segment_duration"]
    sr: int = data_cfg["sample_rate"]
    eval_ratio: float = data_cfg["eval_split_ratio"]
    test_ratio: float = data_cfg.get("test_split_ratio", 0.1)
    min_amplitude: float = data_cfg.get("min_audio_amplitude", 0.001)

    # ------------------------------------------------------------------ #
    # Step 1: process local audio + transcript pairs
    # ------------------------------------------------------------------ #
    # A missing audio directory is treated like an empty one so the user gets
    # the actionable "No segments produced" message below instead of a
    # FileNotFoundError traceback. Sub-directories whose names happen to end
    # in an audio suffix are ignored.
    if audio_dir.is_dir():
        audio_files = sorted(
            p for p in audio_dir.iterdir()
            if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
        )
    else:
        audio_files = []

    all_records: list[dict] = []

    if not audio_files:
        logger.warning("No audio files found in %s", audio_dir)
    else:
        logger.info("Found %d local audio file(s)", len(audio_files))
        skipped = 0

        for audio_path in audio_files:
            # Transcripts are matched purely by file stem.
            transcript_path = transcript_dir / (audio_path.stem + ".json")

            if not transcript_path.exists():
                logger.warning(
                    "No JSON transcript for '%s' (expected '%s') — skipping",
                    audio_path.name, transcript_path.name,
                )
                skipped += 1
                continue

            entries = parse_transcript_file(transcript_path)
            if not entries:
                logger.warning(
                    "Transcript '%s' produced no entries — skipping",
                    transcript_path.name,
                )
                skipped += 1
                continue

            segments = build_segments(
                entries,
                source_audio=audio_path.stem,
                max_duration=max_dur,
                min_duration=min_dur,
            )

            if not segments:
                logger.warning(
                    "No segments produced for '%s' — check min_segment_duration",
                    audio_path.name,
                )
                skipped += 1
                continue

            all_records.extend(
                process_pair(
                    audio_path, segments, wav_dir,
                    sample_rate=sr,
                    min_amplitude=min_amplitude,
                )
            )

        logger.info(
            "Local segments: %d  |  Skipped files: %d",
            len(all_records), skipped,
        )

    if not all_records:
        logger.error(
            "No segments produced.\n"
            "  • Check data/raw/audio/ for audio files with matching JSON transcripts.\n"
            "  • Run  python scripts/download_mgb3.py  to add MGB-3 data."
        )
        sys.exit(1)

    # Each record carries a "duration" value; it is presumably in seconds,
    # given the division by 3600 and the "h" label — confirm in process_pair.
    total_hours = sum(r["duration"] for r in all_records) / 3600.0
    logger.info("Total dataset: %d segments (%.1f h)", len(all_records), total_hours)

    # ------------------------------------------------------------------ #
    # Step 2: build and save HuggingFace DatasetDict (train / eval / test)
    # ------------------------------------------------------------------ #
    build_and_save(
        all_records,
        processed_dir,
        eval_ratio=eval_ratio,
        test_ratio=test_ratio,
    )
    logger.info("Data preparation complete.")

if __name__ == "__main__":
    # Script entry point: read the optional --config flag and run the pipeline.
    cli = argparse.ArgumentParser(
        description="Prepare combined dataset for Whisper Arabic fine-tuning",
    )
    cli.add_argument(
        "--config",
        help="Path to training_config.yaml (relative to project root)",
        default="config/training_config.yaml",
    )
    main(cli.parse_args().config)