# Speach-To-Text / scripts/prepare_data.py
# MIP-Tech — "Deploy to HF Spaces" (commit 0db822c)
"""
Data preparation pipeline.
Usage:
python scripts/prepare_data.py
python scripts/prepare_data.py --config config/training_config.yaml
Expected input layout (data/raw/):
data/raw/audio/ ← audio files (.wav, .mp3, .m4a, .flac, .ogg)
data/raw/transcripts/ ← matching .json transcript files (same stem)
MGB-3 data: run python scripts/download_mgb3.py first — it saves MGB-3
audio and transcripts directly into data/raw/, so they are picked up here
automatically alongside any other local files.
Steps performed:
1. Scan data/raw/audio/ for all audio files.
2. For each audio file, find a matching .json in data/raw/transcripts/.
3. Parse the JSON transcript → List[TranscriptEntry].
4. Group entries into ≤25-second segments → List[TranscriptSegment].
5. Slice the audio into aligned WAV chunks → data/processed/audio_segments/.
6. Build a HuggingFace DatasetDict (train / eval / test split, grouped by
source audio to prevent leakage) → data/processed/.
"""
from __future__ import annotations

import argparse
import logging
import sys
from collections.abc import Iterable
from pathlib import Path

import yaml

# Make the project root importable so the ``src.*`` imports below resolve
# when this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.data_preparation.build_dataset import build_and_save
from src.data_preparation.parse_transcripts import build_segments, parse_transcript_file
from src.data_preparation.segment_audio import process_pair
# Progress reporting: timestamped, INFO-level records on the root logger.
# This is configured at import time, so importing this module also applies it.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Accepted audio suffixes, stored lower-case with the leading dot so they can
# be compared against ``Path.suffix.lower()``.
AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".flac", ".ogg"}
def main(config_path: str = "config/training_config.yaml") -> None:
root = Path(__file__).parent.parent
with (root / config_path).open() as fh:
cfg = yaml.safe_load(fh)
audio_dir = root / cfg["data"]["raw_audio_dir"]
transcript_dir = root / cfg["data"]["raw_transcripts_dir"]
processed_dir = root / cfg["data"]["processed_dir"]
wav_dir = processed_dir / "audio_segments"
max_dur: float = cfg["data"]["max_segment_duration"]
min_dur: float = cfg["data"]["min_segment_duration"]
sr: int = cfg["data"]["sample_rate"]
eval_ratio: float = cfg["data"]["eval_split_ratio"]
test_ratio: float = cfg["data"].get("test_split_ratio", 0.1)
min_amplitude: float = cfg["data"].get("min_audio_amplitude", 0.001)
# ------------------------------------------------------------------ #
# 1. Process local audio + transcript pairs
# ------------------------------------------------------------------ #
audio_files = [
p for p in audio_dir.iterdir()
if p.suffix.lower() in AUDIO_EXTENSIONS
]
local_records: list[dict] = []
if not audio_files:
logger.warning("No audio files found in %s", audio_dir)
else:
logger.info("Found %d local audio file(s)", len(audio_files))
skipped = 0
for audio_path in sorted(audio_files):
transcript_path = transcript_dir / (audio_path.stem + ".json")
if not transcript_path.exists():
logger.warning(
"No JSON transcript for '%s' (expected '%s') — skipping",
audio_path.name, transcript_path.name,
)
skipped += 1
continue
entries = parse_transcript_file(transcript_path)
if not entries:
logger.warning(
"Transcript '%s' produced no entries — skipping",
transcript_path.name,
)
skipped += 1
continue
segments = build_segments(
entries,
source_audio=audio_path.stem,
max_duration=max_dur,
min_duration=min_dur,
)
if not segments:
logger.warning(
"No segments produced for '%s' — check min_segment_duration",
audio_path.name,
)
skipped += 1
continue
records = process_pair(
audio_path, segments, wav_dir,
sample_rate=sr,
min_amplitude=min_amplitude,
)
local_records.extend(records)
logger.info(
"Local segments: %d | Skipped files: %d",
len(local_records), skipped,
)
all_records = local_records
if not all_records:
logger.error(
"No segments produced.\n"
" • Check data/raw/audio/ for audio files with matching JSON transcripts.\n"
" • Run python scripts/download_mgb3.py to add MGB-3 data."
)
sys.exit(1)
total_hours = sum(r["duration"] for r in all_records) / 3600.0
logger.info("Total dataset: %d segments (%.1f h)", len(all_records), total_hours)
# ------------------------------------------------------------------ #
# 4. Build and save HuggingFace DatasetDict (train / eval / test)
# ------------------------------------------------------------------ #
build_and_save(
all_records,
processed_dir,
eval_ratio=eval_ratio,
test_ratio=test_ratio,
)
logger.info("Data preparation complete.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Prepare combined dataset for Whisper Arabic fine-tuning"
)
parser.add_argument(
"--config",
default="config/training_config.yaml",
help="Path to training_config.yaml (relative to project root)",
)
args = parser.parse_args()
main(args.config)