File size: 4,813 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
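Build a Hugging Face DatasetDict from processed audio segments.

Expected input: a list of segment records (dicts with "audio_path",
"sentence", "duration", and "source_audio" keys), as produced by
segment_audio.process_pair() or download_mgb3.py.

Output: a DatasetDict with "train", "eval", and "test" splits, saved to disk
together with a JSON manifest of the input records. Splits are made per source
recording, so segments from one recording never appear in more than one split.
Split ratios default to roughly 80 / 10 / 10, measured in source recordings
rather than segments.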
"""

from __future__ import annotations

import json
import logging
import random
from pathlib import Path
from typing import List

import datasets
from datasets import Audio, Dataset, DatasetDict

logger = logging.getLogger(__name__)


def records_to_dataset(records: List[dict], sampling_rate: int = 16_000) -> Dataset:
    """Convert a flat list of segment records into a Dataset.

    Each record must provide the keys ``audio_path``, ``sentence``,
    ``duration`` and ``source_audio``; a missing key raises ``KeyError``.

    Args:
        records: segment records to convert, in output row order.
        sampling_rate: sampling rate (Hz) declared on the ``audio`` column's
            ``Audio`` feature. Defaults to 16 kHz.

    Returns:
        A Dataset with columns ``audio`` (cast to ``datasets.Audio``),
        ``sentence``, ``duration`` and ``source_audio``.
    """
    data = {
        "audio": [r["audio_path"] for r in records],
        "sentence": [r["sentence"] for r in records],
        "duration": [r["duration"] for r in records],
        "source_audio": [r["source_audio"] for r in records],
    }
    ds = Dataset.from_dict(data)
    # Cast the audio column so HF loads + resamples to `sampling_rate` automatically.
    return ds.cast_column("audio", Audio(sampling_rate=sampling_rate))


def _split_counts(
    n_total: int,
    *,
    eval_ratio: float,
    test_ratio: float,
) -> tuple[int, int]:
    """Return ``(n_test, n_eval)``: how many of *n_total* sources each held-out split gets.

    A split with a positive ratio gets ``int(n_total * ratio)`` sources, raised
    to at least 1. A ratio of 0 disables that split entirely. The held-out
    splits are capped (test first, then eval) so that train always keeps at
    least one source whenever any source exists.

    Raises:
        ValueError: if either ratio is outside [0, 1) or the two sum to 1 or more.
    """
    # Written as a positive range check so NaN is rejected too.
    if not (0.0 <= eval_ratio < 1.0 and 0.0 <= test_ratio < 1.0):
        raise ValueError(
            "eval_ratio and test_ratio must be in [0, 1); "
            f"got eval_ratio={eval_ratio}, test_ratio={test_ratio}"
        )
    if eval_ratio + test_ratio >= 1.0:
        raise ValueError(
            "eval_ratio + test_ratio must be < 1 to leave data for training; "
            f"got {eval_ratio} + {test_ratio}"
        )

    def held_out(ratio: float) -> int:
        return max(1, int(n_total * ratio)) if ratio > 0 else 0

    # Reserve one source for train so tiny corpora never yield an empty train split.
    budget = max(0, n_total - 1)
    n_test = min(held_out(test_ratio), budget)
    n_eval = min(held_out(eval_ratio), budget - n_test)
    return n_test, n_eval


def split_dataset(
    records: List[dict],
    eval_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """
    Shuffle and split records into train / eval / test at the SOURCE AUDIO level.

    All segments that came from the same original recording are kept together
    in the same split, preventing data leakage across splits.

    Split order (applied to shuffled source groups):
      1. test  — first `test_ratio` fraction of sources (held out)
      2. eval  — next  `eval_ratio` fraction of sources (validation)
      3. train — remainder                              (typically ~80 %)

    Both `eval_ratio` and `test_ratio` are expressed as fractions of the total
    number of source recordings. A ratio of 0 leaves that split empty. Each
    held-out split with a positive ratio gets at least one source, but never at
    the cost of leaving train with no sources at all.

    Raises:
        ValueError: if a ratio is outside [0, 1) or the ratios sum to 1 or more.
    """
    from collections import defaultdict

    # Group segments by their source audio file.
    groups: dict[str, List[dict]] = defaultdict(list)
    for r in records:
        groups[r["source_audio"]].append(r)

    sources = list(groups.keys())
    rng = random.Random(seed)
    rng.shuffle(sources)

    n_test, n_eval = _split_counts(
        len(sources), eval_ratio=eval_ratio, test_ratio=test_ratio
    )

    test_sources = sources[:n_test]
    eval_sources = sources[n_test : n_test + n_eval]
    train_sources = sources[n_test + n_eval :]

    def flatten(src_list):
        # Concatenate the segments of each source, preserving source order.
        return [seg for s in src_list for seg in groups[s]]

    train_records = flatten(train_sources)
    eval_records = flatten(eval_sources)
    test_records = flatten(test_sources)

    logger.info(
        "Dataset split (by source) — "
        "train: %d segments (%d sources)  "
        "eval: %d segments (%d sources)  "
        "test: %d segments (%d sources)",
        len(train_records), len(train_sources),
        len(eval_records),  len(eval_sources),
        len(test_records),  len(test_sources),
    )

    return DatasetDict({
        "train": records_to_dataset(train_records),
        "eval":  records_to_dataset(eval_records),
        "test":  records_to_dataset(test_records),
    })


def save_manifest(records: List[dict], path: Path | str) -> None:
    """Save records as pretty-printed UTF-8 JSON for reproducibility / inspection.

    The records are serialised *before* the target file is opened. If a value
    cannot be encoded as JSON, ``json.dumps`` raises and any existing manifest
    at *path* is left untouched rather than truncated. Parent directories are
    created as needed.

    Args:
        records: JSON-serialisable records to write.
        path: destination file path.
    """
    path = Path(path)
    payload = json.dumps(records, ensure_ascii=False, indent=2)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(payload, encoding="utf-8")
    logger.info("Manifest saved to %s", path)


def load_manifest(path: Path | str) -> List[dict]:
    """Read a JSON manifest (as written by save_manifest) and return its records.

    Args:
        path: location of the UTF-8 encoded JSON file.

    Returns:
        The decoded JSON document, normally a list of record dicts.
    """
    manifest_file = Path(path)
    with manifest_file.open(encoding="utf-8") as handle:
        loaded_records = json.load(handle)
    return loaded_records


def build_and_save(
    records: List[dict],
    dataset_dir: Path | str,
    eval_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """Build, split (train/eval/test), and save the DatasetDict to disk.

    Writes the DatasetDict to ``<dataset_dir>/hf_dataset`` and the raw input
    records to ``<dataset_dir>/manifest.json``.

    Args:
        records: segment records; split by their ``source_audio`` value.
        dataset_dir: output root directory.
        eval_ratio: fraction of source recordings held out for eval.
        test_ratio: fraction of source recordings held out for test.
        seed: RNG seed forwarded to split_dataset(), so callers can reproduce
            or deliberately vary the split. The default matches split_dataset's.

    Returns:
        The DatasetDict that was saved.
    """
    dataset_dir = Path(dataset_dir)
    logger.info("Building dataset from %d records (eval=%.0f%%, test=%.0f%%) ...",
                len(records), eval_ratio * 100, test_ratio * 100)
    dd = split_dataset(
        records, eval_ratio=eval_ratio, test_ratio=test_ratio, seed=seed
    )
    hf_path = dataset_dir / "hf_dataset"
    logger.info("Saving HuggingFace DatasetDict to %s ...", hf_path)
    dd.save_to_disk(str(hf_path))
    save_manifest(records, dataset_dir / "manifest.json")
    logger.info("Dataset build complete — saved to %s", dataset_dir)
    return dd


def load_saved_dataset(dataset_dir: Path | str) -> DatasetDict:
    """Load the DatasetDict stored under ``<dataset_dir>/hf_dataset``.

    This is the location build_and_save() writes to. The sample count of
    every split is logged after loading.

    Args:
        dataset_dir: the root directory that was passed to build_and_save().

    Returns:
        Whatever ``datasets.load_from_disk`` returns for that path.
    """
    source_dir = Path(dataset_dir).joinpath("hf_dataset")
    logger.info("Loading dataset from %s ...", source_dir)
    loaded = datasets.load_from_disk(str(source_dir))
    for name, split in loaded.items():
        logger.info("  %-8s: %d samples", name, len(split))
    logger.info("Dataset loaded successfully")
    return loaded