"""
Data preparation pipeline.
Usage:
python scripts/prepare_data.py
python scripts/prepare_data.py --config config/training_config.yaml
Expected input layout (data/raw/):
    data/raw/audio/        — audio files (.wav, .mp3, .m4a, .flac, .ogg)
    data/raw/transcripts/  — matching .json transcript files (same stem)
MGB-3 data: run python scripts/download_mgb3.py first — it saves MGB-3
audio and transcripts directly into data/raw/, so they are picked up here
automatically alongside any other local files.
Steps performed:
    1. Scan data/raw/audio/ for all audio files.
    2. For each audio file, find a matching .json in data/raw/transcripts/.
    3. Parse the JSON transcript → List[TranscriptEntry].
    4. Group entries into ≤25-second segments → List[TranscriptSegment].
    5. Slice the audio into aligned WAV chunks → data/processed/audio_segments/.
    6. Build a HuggingFace DatasetDict (train / eval / test split, grouped by
       source audio to prevent leakage) → data/processed/.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import yaml
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.data_preparation.build_dataset import build_and_save
from src.data_preparation.parse_transcripts import build_segments, parse_transcript_file
from src.data_preparation.segment_audio import process_pair
# File suffixes (lower-case, leading dot) that count as audio input; main()
# compares them against ``Path.suffix.lower()``.
AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".flac", ".ogg"}

# Root logger setup for the script: INFO and above, one time-stamped line
# per message with the level name padded to eight characters.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Module-scoped logger used by everything below.
logger = logging.getLogger(__name__)
def _find_audio_files(audio_dir: Path) -> list[Path]:
    """Return the sorted regular files in *audio_dir* with a supported suffix.

    A missing or non-directory *audio_dir* yields an empty list instead of
    raising, so the caller can report "no audio files" in its own words.
    Sub-directories are ignored even when their name ends in an audio suffix.
    """
    if not audio_dir.is_dir():
        return []
    return sorted(
        entry
        for entry in audio_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in AUDIO_EXTENSIONS
    )


def main(config_path: str = "config/training_config.yaml") -> None:
    """Run the data-preparation pipeline described by a YAML config.

    Reads the ``data`` section of the config, pairs every audio file in the
    raw audio directory with the same-stem ``.json`` transcript, turns each
    pair into segment records (``parse_transcript_file`` ->
    ``build_segments`` -> ``process_pair``) and passes all records to
    ``build_and_save``.

    Args:
        config_path: Location of the YAML config, resolved against the
            project root (the parent of this script's directory).

    Raises:
        SystemExit: With status 1 when no segment records were produced.
    """
    root = Path(__file__).parent.parent
    # Read the config as UTF-8 explicitly rather than relying on the
    # platform's locale encoding.
    with (root / config_path).open(encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    data_cfg = cfg["data"]
    audio_dir = root / data_cfg["raw_audio_dir"]
    transcript_dir = root / data_cfg["raw_transcripts_dir"]
    processed_dir = root / data_cfg["processed_dir"]
    wav_dir = processed_dir / "audio_segments"
    max_dur: float = data_cfg["max_segment_duration"]
    min_dur: float = data_cfg["min_segment_duration"]
    sr: int = data_cfg["sample_rate"]
    eval_ratio: float = data_cfg["eval_split_ratio"]
    # The two settings below are optional in the config.
    test_ratio: float = data_cfg.get("test_split_ratio", 0.1)
    min_amplitude: float = data_cfg.get("min_audio_amplitude", 0.001)

    # ------------------------------------------------------------------ #
    # 1. Process local audio + transcript pairs
    # ------------------------------------------------------------------ #
    audio_files = _find_audio_files(audio_dir)
    all_records: list[dict] = []

    if not audio_files:
        logger.warning("No audio files found in %s", audio_dir)
    else:
        logger.info("Found %d local audio file(s)", len(audio_files))
        skipped = 0
        for audio_path in audio_files:
            transcript_path = transcript_dir / (audio_path.stem + ".json")
            if not transcript_path.exists():
                logger.warning(
                    "No JSON transcript for '%s' (expected '%s') — skipping",
                    audio_path.name, transcript_path.name,
                )
                skipped += 1
                continue

            entries = parse_transcript_file(transcript_path)
            if not entries:
                logger.warning(
                    "Transcript '%s' produced no entries — skipping",
                    transcript_path.name,
                )
                skipped += 1
                continue

            segments = build_segments(
                entries,
                source_audio=audio_path.stem,
                max_duration=max_dur,
                min_duration=min_dur,
            )
            if not segments:
                logger.warning(
                    "No segments produced for '%s' — check min_segment_duration",
                    audio_path.name,
                )
                skipped += 1
                continue

            all_records.extend(
                process_pair(
                    audio_path, segments, wav_dir,
                    sample_rate=sr,
                    min_amplitude=min_amplitude,
                )
            )

        logger.info(
            "Local segments: %d | Skipped files: %d",
            len(all_records), skipped,
        )

    if not all_records:
        logger.error(
            "No segments produced.\n"
            " • Check data/raw/audio/ for audio files with matching JSON transcripts.\n"
            " • Run python scripts/download_mgb3.py to add MGB-3 data."
        )
        sys.exit(1)

    # Each record carries its length in seconds under "duration".
    total_hours = sum(r["duration"] for r in all_records) / 3600.0
    logger.info("Total dataset: %d segments (%.1f h)", len(all_records), total_hours)

    # ------------------------------------------------------------------ #
    # 2. Build and save HuggingFace DatasetDict (train / eval / test)
    # ------------------------------------------------------------------ #
    build_and_save(
        all_records,
        processed_dir,
        eval_ratio=eval_ratio,
        test_ratio=test_ratio,
    )
    logger.info("Data preparation complete.")
if __name__ == "__main__":
    # Command-line entry point: the only option is the config location,
    # which main() resolves against the project root.
    cli = argparse.ArgumentParser(
        description="Prepare combined dataset for Whisper Arabic fine-tuning",
    )
    cli.add_argument(
        "--config",
        help="Path to training_config.yaml (relative to project root)",
        default="config/training_config.yaml",
    )
    main(cli.parse_args().config)
|