""" Build a Hugging Face DatasetDict from processed audio segments. Expected input: a list of records produced by segment_audio.process_pair() or download_mgb3.py. Output: DatasetDict with "train", "eval", and "test" splits saved to disk. Split ratios default to 80 / 10 / 10. """ from __future__ import annotations import json import logging import random from pathlib import Path from typing import List import datasets from datasets import Audio, Dataset, DatasetDict logger = logging.getLogger(__name__) def records_to_dataset(records: List[dict]) -> Dataset: """Convert a flat list of segment records into a Dataset.""" data = { "audio": [r["audio_path"] for r in records], "sentence": [r["sentence"] for r in records], "duration": [r["duration"] for r in records], "source_audio": [r["source_audio"] for r in records], } ds = Dataset.from_dict(data) # Cast the audio column so HF loads + resamples automatically ds = ds.cast_column("audio", Audio(sampling_rate=16_000)) return ds def split_dataset( records: List[dict], eval_ratio: float = 0.1, test_ratio: float = 0.1, seed: int = 42, ) -> DatasetDict: """ Shuffle and split records into train / eval / test at the SOURCE AUDIO level. All segments that came from the same original recording are kept together in the same split, preventing data leakage across splits. Split order (applied to shuffled source groups): 1. test — first `test_ratio` fraction of sources (held out) 2. eval — next `eval_ratio` fraction of sources (validation) 3. train — remainder (typically ~80 %) Both `eval_ratio` and `test_ratio` are expressed as fractions of the total. """ # Group segments by their source audio file from collections import defaultdict groups: dict[str, List[dict]] = defaultdict(list) for r in records: groups[r["source_audio"]].append(r) sources = list(groups.keys()) rng = random.Random(seed) rng.shuffle(sources) n_total = len(sources) n_test = max(1, int(n_total * test_ratio)) n_eval = max(1, int(n_total * eval_ratio)) test_sources = sources[:n_test] eval_sources = sources[n_test : n_test + n_eval] train_sources = sources[n_test + n_eval :] def flatten(src_list): result = [] for s in src_list: result.extend(groups[s]) return result train_records = flatten(train_sources) eval_records = flatten(eval_sources) test_records = flatten(test_sources) logger.info( "Dataset split (by source) — " "train: %d segments (%d sources) " "eval: %d segments (%d sources) " "test: %d segments (%d sources)", len(train_records), len(train_sources), len(eval_records), len(eval_sources), len(test_records), len(test_sources), ) return DatasetDict({ "train": records_to_dataset(train_records), "eval": records_to_dataset(eval_records), "test": records_to_dataset(test_records), }) def save_manifest(records: List[dict], path: Path | str) -> None: """Save records as JSON for reproducibility / inspection.""" path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as fh: json.dump(records, fh, ensure_ascii=False, indent=2) logger.info("Manifest saved to %s", path) def load_manifest(path: Path | str) -> List[dict]: with Path(path).open(encoding="utf-8") as fh: return json.load(fh) def build_and_save( records: List[dict], dataset_dir: Path | str, eval_ratio: float = 0.1, test_ratio: float = 0.1, ) -> DatasetDict: """Build, split (train/eval/test), and save the DatasetDict to disk.""" dataset_dir = Path(dataset_dir) logger.info("Building dataset from %d records (eval=%.0f%%, test=%.0f%%) ...", len(records), eval_ratio * 100, test_ratio * 100) dd = split_dataset(records, eval_ratio=eval_ratio, test_ratio=test_ratio) hf_path = dataset_dir / "hf_dataset" logger.info("Saving HuggingFace DatasetDict to %s ...", hf_path) dd.save_to_disk(str(hf_path)) save_manifest(records, dataset_dir / "manifest.json") logger.info("Dataset build complete — saved to %s", dataset_dir) return dd def load_saved_dataset(dataset_dir: Path | str) -> DatasetDict: hf_path = Path(dataset_dir) / "hf_dataset" logger.info("Loading dataset from %s ...", hf_path) dd = datasets.load_from_disk(str(hf_path)) for split_name, split_ds in dd.items(): logger.info(" %-8s: %d samples", split_name, len(split_ds)) logger.info("Dataset loaded successfully") return dd