Speach-To-Text / scripts /download_mgb3.py
MIP-Tech's picture
Deploy to HF Spaces
0db822c
"""
Download the MGB-3 Egyptian Arabic dataset from HuggingFace and convert it
into the same audio + JSON transcript format used by the local data pipeline.
Dataset: MightyStudent/Egyptian-ASR-MGB-3
16 hours of Egyptian Arabic speech from YouTube (comedy, cooking, drama,
sports, TEDx, etc.). Audio is already 16 kHz mono in most configs.
What this script does:
1. Downloads all splits of the dataset via the HuggingFace datasets library.
2. Saves each audio segment as a 16 kHz mono WAV to data/raw/audio/.
3. Saves a matching JSON transcript file to data/raw/transcripts/ with the
same stem, in the format expected by parse_transcripts.py:
{
"video_id": "mgb3_train_000000",
"title": "MGB-3 Egyptian ASR - train",
"transcript": [
{"start": 0.0, "duration": <seconds>, "text": "<raw text>"}
]
}
4. Writes records_index.json into an mgb3/ folder beside the audio directory
(data/raw/mgb3/records_index.json with the default paths) as a summary index
(optional, for reference only — prepare_data.py reads the raw/ folder directly).
After running this script, run:
python scripts/prepare_data.py
which will pick up both local and MGB-3 data from data/raw/ automatically.
Usage:
python scripts/download_mgb3.py
python scripts/download_mgb3.py --output_audio data/raw/audio --output_transcripts data/raw/transcripts
python scripts/download_mgb3.py --max_samples 500
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
import io
import numpy as np
import soundfile as sf
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.data_preparation.parse_transcripts import normalize_arabic
# Configure logging once for the whole script: INFO level, with a
# time-of-day stamp and a padded level name in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)

# HuggingFace Hub identifier passed to load_dataset().
DATASET_ID = "MightyStudent/Egyptian-ASR-MGB-3"
# Sample rate (Hz) that every saved segment is written at; audio arriving at
# any other rate is resampled to this value before saving.
TARGET_SR = 16_000
MIN_DURATION = 1.0 # seconds — drop segments shorter than this
MAX_DURATION = 30.0 # seconds — drop segments longer than Whisper's window
def _find_text_column(column_names: list[str]) -> str:
"""Return the name of the transcript column (varies by dataset version)."""
for candidate in ("sentence", "text", "transcription", "transcript"):
if candidate in column_names:
return candidate
raise ValueError(
f"Cannot find a text column in dataset columns: {column_names}\n"
"Update _find_text_column() with the correct column name."
)
def _to_float32_mono(array: np.ndarray) -> np.ndarray:
if array.ndim > 1:
array = array.mean(axis=1)
return array.astype(np.float32)
def _save_wav(array: np.ndarray, sr: int, path: Path) -> None:
    """Write *array* to *path* through soundfile as 16-bit PCM at sample rate *sr*.

    Any missing parent directories of *path* are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    sf.write(file=str(path), data=array, samplerate=sr, subtype="PCM_16")
def _save_transcript_json(stem: str, split_name: str, duration: float, raw_text: str, path: Path) -> None:
"""
Save a single-entry JSON transcript file matching the format expected by
parse_transcripts.parse_transcript_file().
Each HF example is already one audio segment, so the transcript array
has a single entry spanning the full duration (start=0.0).
"""
data = {
"video_id": stem,
"title": f"MGB-3 Egyptian ASR - {split_name}",
"transcript": [
{
"start": 0.0,
"duration": round(duration, 6),
"text": raw_text,
}
],
}
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as fh:
json.dump(data, fh, ensure_ascii=False, indent=2)
def download_and_convert(
    audio_dir: Path,
    transcript_dir: Path,
    max_samples: int | None = None,
) -> list[dict]:
    """
    Download MGB-3 and save each usable example as a WAV + JSON transcript pair.

    Every split of the dataset is walked in order. An example is skipped when
    its normalized transcript is empty, when it carries no audio or the audio
    cannot be decoded, when its duration falls outside
    [MIN_DURATION, MAX_DURATION], or when its RMS level is below 0.001
    (treated as silence). Each skip reason is counted separately and reported
    in the final log line.

    Args:
        audio_dir: Directory that receives the ``<stem>.wav`` files.
        transcript_dir: Directory that receives the ``<stem>.json`` files.
        max_samples: Upper bound on the number of saved pairs across all
            splits; ``None`` means no limit.

    Returns:
        One dict per saved pair with the keys ``stem``, ``audio_path``,
        ``json_path``, ``duration`` (seconds) and ``split``.

    Note:
        Calls ``sys.exit(1)`` if the dataset cannot be loaded.
    """
    # Imported lazily so the module can be imported without `datasets` installed.
    from datasets import Audio as HFAudio  # type: ignore
    from datasets import load_dataset  # type: ignore

    logger.info("Downloading %s from HuggingFace ...", DATASET_ID)
    try:
        raw_ds = load_dataset(DATASET_ID, trust_remote_code=False)
    except Exception as exc:
        logger.error("Failed to load dataset: %s", exc)
        sys.exit(1)

    audio_dir.mkdir(parents=True, exist_ok=True)
    transcript_dir.mkdir(parents=True, exist_ok=True)

    index: list[dict] = []
    # Running number of saved pairs; it is shared across splits, so stems
    # keep counting up from one split to the next.
    seg_id = 0
    skipped_no_text = 0
    skipped_bad_audio = 0
    skipped_duration = 0
    skipped_silent = 0

    def _limit_reached() -> bool:
        return max_samples is not None and len(index) >= max_samples

    for split_name, split_ds in raw_ds.items():
        # Stop before touching further splits once the cap has been hit.
        if _limit_reached():
            break

        logger.info("Processing split '%s' (%d examples) ...", split_name, len(split_ds))
        text_col = _find_text_column(split_ds.column_names)
        logger.info("Using column '%s' as transcript", text_col)

        # Decode audio as raw bytes to avoid torchcodec/FFmpeg dependency
        split_ds = split_ds.cast_column("audio", HFAudio(decode=False))

        progress = tqdm(split_ds, desc=split_name, unit="seg")
        for example_idx, example in enumerate(progress):
            if _limit_reached():
                break

            # --- Transcript ---
            raw_text = example.get(text_col, "") or ""
            raw_text = raw_text.replace("\n", " ").strip()
            if not normalize_arabic(raw_text):
                skipped_no_text += 1
                continue

            # --- Audio (decode with soundfile, bypassing torchcodec) ---
            audio_obj = example["audio"] or {}
            audio_bytes = audio_obj.get("bytes")
            audio_path = audio_obj.get("path")
            if audio_bytes:
                audio_source = io.BytesIO(audio_bytes)
            elif audio_path:
                audio_source = audio_path
            else:
                skipped_bad_audio += 1
                continue

            try:
                array, sr = sf.read(audio_source)
            except Exception as exc:
                logger.warning(
                    "Could not decode audio for %s example %d: %s",
                    split_name, example_idx, exc,
                )
                skipped_bad_audio += 1
                continue

            # Convert to float32 first, then down-mix, as the pipeline expects.
            array = _to_float32_mono(np.asarray(array, dtype=np.float32))
            sr = int(sr)

            if sr != TARGET_SR:
                # torch/torchaudio are only needed for off-rate audio.
                import torch
                import torchaudio.functional as F_audio  # type: ignore

                waveform = torch.from_numpy(array).unsqueeze(0)
                resampled = F_audio.resample(waveform, sr, TARGET_SR)
                array = resampled.squeeze(0).numpy().astype(np.float32)
                sr = TARGET_SR

            duration = len(array) / sr
            if duration < MIN_DURATION or duration > MAX_DURATION:
                skipped_duration += 1
                continue

            rms = float(np.sqrt(np.mean(array ** 2)))
            if rms < 0.001:
                skipped_silent += 1
                continue

            # --- Save WAV + JSON pair ---
            stem = f"mgb3_{split_name}_{seg_id:06d}"
            wav_path = audio_dir / f"{stem}.wav"
            json_path = transcript_dir / f"{stem}.json"

            _save_wav(array, sr, wav_path)
            _save_transcript_json(stem, split_name, duration, raw_text, json_path)

            index.append({
                "stem": stem,
                "audio_path": str(wav_path),
                "json_path": str(json_path),
                "duration": duration,
                "split": split_name,
            })
            seg_id += 1

    logger.info(
        "Done — saved %d pairs (skipped: %d no-text, %d bad-audio, %d duration, %d silent)",
        len(index), skipped_no_text, skipped_bad_audio, skipped_duration, skipped_silent,
    )
    return index
def main(
    audio_dir: str,
    transcript_dir: str,
    max_samples: int | None,
) -> None:
    """Run the download/convert step and write the summary index.

    Converts the two directory arguments to paths, delegates to
    download_and_convert(), and exits with status 1 if nothing was saved.
    Otherwise the returned index is dumped to
    ``<audio_dir>/../mgb3/records_index.json`` and a summary (pair count,
    total hours, output locations) is logged.
    """
    audio_root = Path(audio_dir)
    transcript_root = Path(transcript_dir)

    records = download_and_convert(audio_root, transcript_root, max_samples=max_samples)
    if not records:
        logger.error("No pairs produced — check the dataset or your internet connection.")
        sys.exit(1)

    # Write optional summary index next to the audio folder
    summary_dir = audio_root.parent / "mgb3"
    summary_dir.mkdir(parents=True, exist_ok=True)
    index_path = summary_dir / "records_index.json"
    with index_path.open("w", encoding="utf-8") as index_file:
        json.dump(records, index_file, ensure_ascii=False, indent=2)

    total_seconds = sum(record["duration"] for record in records)
    total_hours = total_seconds / 3600.0
    logger.info(
        "Saved %d WAV + JSON pairs (%.1f h)\n"
        " Audio → %s\n"
        " Transcripts → %s\n"
        " Index → %s\n"
        "Next step: python scripts/prepare_data.py",
        len(records), total_hours, audio_root, transcript_root, index_path,
    )
if __name__ == "__main__":
    # Default output locations are resolved relative to the directory two
    # levels above this script file.
    project_root = Path(__file__).parent.parent
    default_audio_dir = project_root / "data" / "raw" / "audio"
    default_transcript_dir = project_root / "data" / "raw" / "transcripts"

    cli = argparse.ArgumentParser(
        description="Download MGB-3 and save as WAV + JSON transcript pairs"
    )
    cli.add_argument(
        "--output_audio",
        default=str(default_audio_dir),
        help="Directory to save WAV files (default: data/raw/audio)",
    )
    cli.add_argument(
        "--output_transcripts",
        default=str(default_transcript_dir),
        help="Directory to save JSON transcript files (default: data/raw/transcripts)",
    )
    cli.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Cap on number of segments to download (default: all)",
    )

    options = cli.parse_args()
    main(options.output_audio, options.output_transcripts, options.max_samples)