Spaces:
Sleeping
Sleeping
| """ | |
| Download the MGB-3 Egyptian Arabic dataset from HuggingFace and convert it | |
| into the same audio + JSON transcript format used by the local data pipeline. | |
| Dataset: MightyStudent/Egyptian-ASR-MGB-3 | |
| 16 hours of Egyptian Arabic speech from YouTube (comedy, cooking, drama, | |
| sports, TEDx, etc.). Audio is already 16 kHz mono in most configs. | |
| What this script does: | |
| 1. Downloads all splits of the dataset via the HuggingFace datasets library. | |
| 2. Saves each audio segment as a 16 kHz mono WAV to data/raw/audio/. | |
| 3. Saves a matching JSON transcript file to data/raw/transcripts/ with the | |
| same stem, in the format expected by parse_transcripts.py: | |
| { | |
| "video_id": "mgb3_train_000000", | |
| "title": "MGB-3 Egyptian ASR - train", | |
| "transcript": [ | |
| {"start": 0.0, "duration": <seconds>, "text": "<raw text>"} | |
| ] | |
| } | |
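| 4. Writes records_index.json as a summary index into an mgb3/ folder next to | |
| the audio directory (data/raw/mgb3/records_index.json with the default | |
| paths). It is optional and for reference only — prepare_data.py reads the | |
| raw/ folder directly. | |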
| After running this script, run: | |
| python scripts/prepare_data.py | |
| which will pick up both local and MGB-3 data from data/raw/ automatically. | |
| Usage: | |
| python scripts/download_mgb3.py | |
| python scripts/download_mgb3.py --output_audio data/raw/audio --output_transcripts data/raw/transcripts | |
| python scripts/download_mgb3.py --max_samples 500 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import io | |
| import numpy as np | |
| import soundfile as sf | |
| from tqdm import tqdm | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.data_preparation.parse_transcripts import normalize_arabic | |
# --- Dataset and audio-filter settings -------------------------------------
# HuggingFace dataset identifier passed to load_dataset().
DATASET_ID = "MightyStudent/Egyptian-ASR-MGB-3"
# Sample rate (Hz) every saved WAV is brought to.
TARGET_SR = 16_000
# Segments shorter than this many seconds are dropped.
MIN_DURATION = 1.0
# Segments longer than this many seconds (Whisper's window) are dropped.
MAX_DURATION = 30.0

# --- Logging ----------------------------------------------------------------
# INFO-level records with a wall-clock timestamp and a padded level name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
| def _find_text_column(column_names: list[str]) -> str: | |
| """Return the name of the transcript column (varies by dataset version).""" | |
| for candidate in ("sentence", "text", "transcription", "transcript"): | |
| if candidate in column_names: | |
| return candidate | |
| raise ValueError( | |
| f"Cannot find a text column in dataset columns: {column_names}\n" | |
| "Update _find_text_column() with the correct column name." | |
| ) | |
| def _to_float32_mono(array: np.ndarray) -> np.ndarray: | |
| if array.ndim > 1: | |
| array = array.mean(axis=1) | |
| return array.astype(np.float32) | |
def _save_wav(array: np.ndarray, sr: int, path: Path) -> None:
    """Write *array* to *path* through soundfile as 16-bit PCM at rate *sr*.

    Missing parent directories of *path* are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    sf.write(str(path), array, sr, subtype="PCM_16")
| def _save_transcript_json(stem: str, split_name: str, duration: float, raw_text: str, path: Path) -> None: | |
| """ | |
| Save a single-entry JSON transcript file matching the format expected by | |
| parse_transcripts.parse_transcript_file(). | |
| Each HF example is already one audio segment, so the transcript array | |
| has a single entry spanning the full duration (start=0.0). | |
| """ | |
| data = { | |
| "video_id": stem, | |
| "title": f"MGB-3 Egyptian ASR - {split_name}", | |
| "transcript": [ | |
| { | |
| "start": 0.0, | |
| "duration": round(duration, 6), | |
| "text": raw_text, | |
| } | |
| ], | |
| } | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| with path.open("w", encoding="utf-8") as fh: | |
| json.dump(data, fh, ensure_ascii=False, indent=2) | |
def download_and_convert(
    audio_dir: Path,
    transcript_dir: Path,
    max_samples: int | None = None,
) -> list[dict]:
    """
    Download MGB-3 and save each usable example as a WAV + JSON transcript pair.

    Every split of the dataset is walked in order. An example is skipped when:
      * its transcript is empty after normalize_arabic(),
      * its audio is missing or cannot be decoded by soundfile,
      * its duration (measured after resampling to TARGET_SR) falls outside
        [MIN_DURATION, MAX_DURATION],
      * its RMS level is below 0.001 (treated as silence).
    Each skip reason has its own counter, reported in the final log line.

    Args:
        audio_dir: Directory that receives the ``<stem>.wav`` files.
        transcript_dir: Directory that receives the ``<stem>.json`` files.
        max_samples: Optional cap on the number of saved pairs across all
            splits; ``None`` means no cap.

    Returns:
        One dict per saved pair with the keys "stem", "audio_path",
        "json_path", "duration" and "split".

    Exits the process with status 1 if the dataset cannot be loaded.
    """
    from datasets import Audio as HFAudio  # type: ignore
    from datasets import load_dataset  # type: ignore

    logger.info("Downloading %s from HuggingFace ...", DATASET_ID)
    try:
        raw_ds = load_dataset(DATASET_ID, trust_remote_code=False)
    except Exception as exc:
        logger.error("Failed to load dataset: %s", exc)
        sys.exit(1)

    audio_dir.mkdir(parents=True, exist_ok=True)
    transcript_dir.mkdir(parents=True, exist_ok=True)

    index: list[dict] = []
    skipped_no_text = 0
    skipped_audio = 0
    skipped_duration = 0
    skipped_silent = 0

    def _cap_reached() -> bool:
        return max_samples is not None and len(index) >= max_samples

    for split_name, split_ds in raw_ds.items():
        # Once the cap is hit there is nothing left to do in later splits.
        if _cap_reached():
            break

        logger.info("Processing split '%s' (%d examples) ...", split_name, len(split_ds))
        text_col = _find_text_column(split_ds.column_names)
        logger.info("Using column '%s' as transcript", text_col)

        # Keep audio as raw bytes/path to avoid the torchcodec/FFmpeg dependency;
        # decoding is done below with soundfile.
        split_ds = split_ds.cast_column("audio", HFAudio(decode=False))

        progress = tqdm(split_ds, desc=split_name, unit="seg")
        for example_idx, example in enumerate(progress):
            if _cap_reached():
                break

            # --- Transcript ---
            raw_text = example.get(text_col, "") or ""
            raw_text = raw_text.replace("\n", " ").strip()
            if not normalize_arabic(raw_text):
                skipped_no_text += 1
                continue

            # --- Audio (decode with soundfile, bypassing torchcodec) ---
            audio_obj = example.get("audio") or {}
            audio_bytes = audio_obj.get("bytes")
            audio_path = audio_obj.get("path")
            if not audio_bytes and not audio_path:
                skipped_audio += 1
                continue
            try:
                if audio_bytes:
                    array, sr = sf.read(io.BytesIO(audio_bytes))
                else:
                    array, sr = sf.read(audio_path)
            except Exception as exc:
                logger.warning(
                    "Could not decode audio for example %d of split '%s': %s",
                    example_idx, split_name, exc,
                )
                skipped_audio += 1
                continue

            array = _to_float32_mono(np.asarray(array, dtype=np.float32))
            sr = int(sr)

            if sr != TARGET_SR:
                # Imported lazily so torch is only required when resampling is needed.
                import torch
                import torchaudio.functional as F_audio  # type: ignore

                waveform = torch.from_numpy(array).unsqueeze(0)
                resampled = F_audio.resample(waveform, sr, TARGET_SR)
                array = resampled.squeeze(0).numpy().astype(np.float32)
                sr = TARGET_SR

            duration = len(array) / sr
            if duration < MIN_DURATION or duration > MAX_DURATION:
                skipped_duration += 1
                continue

            rms = float(np.sqrt(np.mean(array ** 2)))
            if rms < 0.001:
                skipped_silent += 1
                continue

            # --- Save WAV + JSON pair ---
            # The running number is global across splits, so stems never collide.
            seg_id = len(index)
            stem = f"mgb3_{split_name}_{seg_id:06d}"
            wav_path = audio_dir / f"{stem}.wav"
            json_path = transcript_dir / f"{stem}.json"

            _save_wav(array, sr, wav_path)
            _save_transcript_json(stem, split_name, duration, raw_text, json_path)

            index.append({
                "stem": stem,
                "audio_path": str(wav_path),
                "json_path": str(json_path),
                "duration": duration,
                "split": split_name,
            })

    logger.info(
        "Done — saved %d pairs (skipped: %d no-text, %d bad-audio, %d duration, %d silent)",
        len(index), skipped_no_text, skipped_audio, skipped_duration, skipped_silent,
    )
    return index
def main(
    audio_dir: str,
    transcript_dir: str,
    max_samples: int | None,
) -> None:
    """Run the download/convert step and write the summary index file.

    The index is written to ``mgb3/records_index.json`` under the parent of
    *audio_dir*. If no pairs were produced, an error is logged and the
    process exits with status 1 before anything is written.
    """
    audio_root = Path(audio_dir)
    transcript_root = Path(transcript_dir)

    records = download_and_convert(audio_root, transcript_root, max_samples=max_samples)
    if not records:
        logger.error("No pairs produced — check the dataset or your internet connection.")
        sys.exit(1)

    # The optional summary index lives in an mgb3/ folder beside the audio folder.
    index_dir = audio_root.parent / "mgb3"
    index_dir.mkdir(parents=True, exist_ok=True)
    index_path = index_dir / "records_index.json"
    index_path.write_text(
        json.dumps(records, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    total_seconds = sum(entry["duration"] for entry in records)
    total_hours = total_seconds / 3600.0
    logger.info(
        "Saved %d WAV + JSON pairs (%.1f h)\n"
        " Audio → %s\n"
        " Transcripts → %s\n"
        " Index → %s\n"
        "Next step: python scripts/prepare_data.py",
        len(records), total_hours, audio_root, transcript_root, index_path,
    )
if __name__ == "__main__":
    # Default output folders are resolved two directory levels above this file.
    root = Path(__file__).parent.parent
    raw_dir = root / "data" / "raw"

    parser = argparse.ArgumentParser(
        description="Download MGB-3 and save as WAV + JSON transcript pairs"
    )
    # Output locations.
    parser.add_argument(
        "--output_audio",
        default=str(raw_dir / "audio"),
        help="Directory to save WAV files (default: data/raw/audio)",
    )
    parser.add_argument(
        "--output_transcripts",
        default=str(raw_dir / "transcripts"),
        help="Directory to save JSON transcript files (default: data/raw/transcripts)",
    )
    # Optional cap, handy for quick smoke runs.
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Cap on number of segments to download (default: all)",
    )

    cli_args = parser.parse_args()
    main(
        cli_args.output_audio,
        cli_args.output_transcripts,
        cli_args.max_samples,
    )