Speach-To-Text / src /data_preparation /build_dataset.py
MIP-Tech's picture
Deploy to HF Spaces
0db822c
"""
Build a Hugging Face DatasetDict from processed audio segments.
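Expected input: a list of segment records (dicts with "audio_path",
"sentence", "duration" and "source_audio" keys) produced by
segment_audio.process_pair() or download_mgb3.py.
Output: DatasetDict with "train", "eval", and "test" splits saved to disk.
Split ratios default to 80 / 10 / 10 (train / eval / test) and are applied
per source recording, so segments from one recording never span two splits.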
"""
from __future__ import annotations
import json
import logging
import random
from pathlib import Path
from typing import List
import datasets
from datasets import Audio, Dataset, DatasetDict
logger = logging.getLogger(__name__)
def records_to_dataset(records: List[dict], sampling_rate: int = 16_000) -> Dataset:
    """Convert a flat list of segment records into a Dataset.

    Args:
        records: Segment dicts. Each must provide the keys ``audio_path``,
            ``sentence``, ``duration`` and ``source_audio``.
        sampling_rate: Target sampling rate for the ``audio`` column.
            Defaults to 16 kHz, the value this function always used before
            the parameter existed.

    Returns:
        A Dataset with one row per record, in input order, and the columns
        ``audio`` (cast to ``Audio(sampling_rate=sampling_rate)``),
        ``sentence``, ``duration`` and ``source_audio``.

    Raises:
        KeyError: if any record lacks one of the required keys.
    """
    # Output column name -> key to read from each input record.
    column_sources = {
        "audio": "audio_path",
        "sentence": "sentence",
        "duration": "duration",
        "source_audio": "source_audio",
    }
    data = {
        column: [record[key] for record in records]
        for column, key in column_sources.items()
    }
    ds = Dataset.from_dict(data)
    # Cast the audio column so HF loads + resamples automatically
    return ds.cast_column("audio", Audio(sampling_rate=sampling_rate))
def _split_counts(
    n_total: int, eval_ratio: float, test_ratio: float
) -> tuple[int, int]:
    """Return ``(n_test, n_eval)`` source counts, keeping >= 1 source for train.

    A positive ratio is rounded down but guaranteed at least one source; a
    zero ratio yields an empty split. If the held-out splits would consume
    every source, eval is shrunk first and then test, so that train keeps one
    source whenever any exist.
    """
    n_test = max(1, int(n_total * test_ratio)) if test_ratio > 0 else 0
    n_eval = max(1, int(n_total * eval_ratio)) if eval_ratio > 0 else 0
    # How many held-out sources exceed the budget that leaves one for train.
    overflow = n_test + n_eval - max(n_total - 1, 0)
    if overflow > 0:
        cut = min(overflow, n_eval)
        n_eval -= cut
        n_test -= min(overflow - cut, n_test)
    return n_test, n_eval


def split_dataset(
    records: List[dict],
    eval_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """
    Shuffle and split records into train / eval / test at the SOURCE AUDIO level.

    All segments that came from the same original recording (same
    ``source_audio`` value) are kept together in the same split, preventing
    data leakage across splits.

    Split order (applied to the source groups after shuffling with ``seed``):
    1. test  — first `test_ratio` fraction of sources (held out)
    2. eval  — next `eval_ratio` fraction of sources (validation)
    3. train — remainder (typically ~80 %)

    Both ratios are fractions of the number of distinct SOURCES, not of
    segments. A positive ratio always yields at least one source, and a ratio
    of 0 yields an empty split. When there are too few sources for every
    split, eval and then test are shrunk so that train keeps at least one
    source. A warning is logged in that case.

    Raises:
        ValueError: if a ratio is outside [0, 1) or the ratios sum to >= 1.
    """
    from collections import defaultdict

    for name, ratio in (("eval_ratio", eval_ratio), ("test_ratio", test_ratio)):
        if not 0.0 <= ratio < 1.0:
            raise ValueError(f"{name} must be in [0, 1), got {ratio!r}")
    if eval_ratio + test_ratio >= 1.0:
        raise ValueError(
            f"eval_ratio + test_ratio must be < 1 to leave data for training, "
            f"got {eval_ratio!r} + {test_ratio!r}"
        )

    # Group segments by their source audio file
    groups: dict[str, List[dict]] = defaultdict(list)
    for r in records:
        groups[r["source_audio"]].append(r)

    sources = list(groups)
    random.Random(seed).shuffle(sources)

    n_total = len(sources)
    n_test, n_eval = _split_counts(n_total, eval_ratio, test_ratio)
    if (test_ratio > 0 and n_test == 0) or (eval_ratio > 0 and n_eval == 0):
        logger.warning(
            "Only %d source recording(s) available; shrinking held-out splits "
            "so that the train split is not empty.",
            n_total,
        )

    test_sources = sources[:n_test]
    eval_sources = sources[n_test : n_test + n_eval]
    train_sources = sources[n_test + n_eval :]

    def flatten(src_list):
        # Concatenate the segment lists of the given sources, in order.
        return [rec for src in src_list for rec in groups[src]]

    train_records = flatten(train_sources)
    eval_records = flatten(eval_sources)
    test_records = flatten(test_sources)

    logger.info(
        "Dataset split (by source) — "
        "train: %d segments (%d sources) "
        "eval: %d segments (%d sources) "
        "test: %d segments (%d sources)",
        len(train_records), len(train_sources),
        len(eval_records), len(eval_sources),
        len(test_records), len(test_sources),
    )
    return DatasetDict({
        "train": records_to_dataset(train_records),
        "eval": records_to_dataset(eval_records),
        "test": records_to_dataset(test_records),
    })
def save_manifest(records: List[dict], path: Path | str) -> None:
    """Write *records* to *path* as indented UTF-8 JSON.

    Missing parent directories are created and any existing file at *path*
    is overwritten. Non-ASCII characters are written verbatim
    (``ensure_ascii=False``). The written location is logged at INFO level.
    """
    target = Path(path)
    target.parent.mkdir(exist_ok=True, parents=True)
    with target.open("w", encoding="utf-8") as out:
        json.dump(records, out, indent=2, ensure_ascii=False)
    logger.info("Manifest saved to %s", target)
def load_manifest(path: Path | str) -> List[dict]:
    """Read back a manifest written by save_manifest().

    The file at *path* is decoded as UTF-8 and parsed as JSON. The parsed
    value is returned unchanged; it is expected to be the list of record
    dicts that was saved.
    """
    manifest_text = Path(path).read_text(encoding="utf-8")
    return json.loads(manifest_text)
def build_and_save(
    records: List[dict],
    dataset_dir: Path | str,
    eval_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """Build, split (train/eval/test), and save the DatasetDict to disk.

    Args:
        records: Segment records to split; passed to split_dataset().
        dataset_dir: Output directory. The DatasetDict is written to
            ``<dataset_dir>/hf_dataset`` and the full, unsplit record list
            to ``<dataset_dir>/manifest.json``.
        eval_ratio: Fraction of sources for the eval split; forwarded to
            split_dataset().
        test_ratio: Fraction of sources for the test split; forwarded to
            split_dataset().
        seed: Shuffle seed forwarded to split_dataset(), so callers can
            reproduce or vary the split. The default, 42, matches
            split_dataset()'s own default.

    Returns:
        The DatasetDict that was saved.
    """
    dataset_dir = Path(dataset_dir)
    logger.info("Building dataset from %d records (eval=%.0f%%, test=%.0f%%) ...",
                len(records), eval_ratio * 100, test_ratio * 100)
    dd = split_dataset(
        records, eval_ratio=eval_ratio, test_ratio=test_ratio, seed=seed
    )

    hf_path = dataset_dir / "hf_dataset"
    logger.info("Saving HuggingFace DatasetDict to %s ...", hf_path)
    dd.save_to_disk(str(hf_path))
    # The manifest holds every record (not per-split) for inspection.
    save_manifest(records, dataset_dir / "manifest.json")
    logger.info("Dataset build complete — saved to %s", dataset_dir)
    return dd
def load_saved_dataset(dataset_dir: Path | str) -> DatasetDict:
    """Load the DatasetDict stored under ``<dataset_dir>/hf_dataset``.

    This is the location build_and_save() writes to. Each split's name and
    row count are logged at INFO level after loading.
    """
    source_dir = Path(dataset_dir) / "hf_dataset"
    logger.info("Loading dataset from %s ...", source_dir)
    loaded = datasets.load_from_disk(str(source_dir))
    for name, split in loaded.items():
        row_count = len(split)
        logger.info(" %-8s: %d samples", name, row_count)
    logger.info("Dataset loaded successfully")
    return loaded