| |
| """ |
| Deterministic train/validation/test split assignment. |
| |
| The policy groups clips before splitting to avoid leakage. Speaker clusters |
| take priority when present; otherwise video_id is used as the grouping key. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import hashlib |
| from collections import Counter |
| from typing import Any, Dict, List |
|
|
|
|
def _stable_fraction(value: str) -> float:
    """Map *value* to a reproducible float in the closed interval [0.0, 1.0].

    The string is UTF-8 encoded and hashed with SHA-256; the leading 64 bits
    of the digest are read as an unsigned integer and scaled by the largest
    64-bit value. The same input always yields the same fraction.
    """
    leading_bytes = hashlib.sha256(value.encode("utf-8")).digest()[:8]
    # Big-endian read of the first 8 bytes equals the first 16 hex digits.
    numerator = int.from_bytes(leading_bytes, "big")
    denominator = float(0xFFFFFFFFFFFFFFFF)
    return numerator / denominator
|
|
|
|
def _split_for_group(group_key: str, seed: str, train_ratio: float, val_ratio: float) -> str:
    """Pick the split name for one group from its seeded hash fraction.

    The fraction of ``"{seed}:{group_key}"`` is compared against cumulative
    upper bounds: below ``train_ratio`` is "train", below
    ``train_ratio + val_ratio`` is "val", and everything else is "test".
    """
    fraction = _stable_fraction(f"{seed}:{group_key}")
    cumulative_bounds = (
        ("train", train_ratio),
        ("val", train_ratio + val_ratio),
    )
    for split_name, upper_bound in cumulative_bounds:
        if fraction < upper_bound:
            return split_name
    return "test"
|
|
|
|
def _group_key(row: Dict[str, Any]) -> str:
    """Return the leakage-avoidance grouping key for *row*.

    A usable ``speaker_cluster_id`` wins and yields ``"speaker:<id>"``;
    otherwise the key is ``"video:<video_id>"``, with ``"unknown"`` standing
    in when no usable ``video_id`` exists either.

    A field counts as absent only when it is missing, ``None``, or blank after
    stripping whitespace. Falsy-but-real identifiers such as the integer ``0``
    are kept, so speaker cluster 0 is grouped by speaker rather than silently
    falling back to per-video grouping.
    """

    def _clean(raw: Any) -> str:
        # Test for None explicitly: a plain truthiness check would discard 0.
        return "" if raw is None else str(raw).strip()

    speaker = _clean(row.get("speaker_cluster_id"))
    if speaker:
        return f"speaker:{speaker}"
    # Apply the fallback after stripping so whitespace-only ids also map to
    # "unknown" instead of producing the bare key "video:".
    video_id = _clean(row.get("video_id")) or "unknown"
    return f"video:{video_id}"
|
|
|
|
def assign_splits(rows: List[Dict[str, Any]], config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Tag every row, in place, with deterministic split metadata.

    When ``config["enable_split_assignment"]`` is falsy, each row keeps any
    existing ``split`` value (defaulting to ``"unsplit"``) and is marked with
    ``split_policy = "disabled"``.

    Otherwise each row receives ``split``, ``split_group_key`` and
    ``split_policy``. Rows that share a group key always receive the same
    split. Ratios come from ``split_train_ratio`` (default 0.90) and
    ``split_val_ratio`` (default 0.05); the hash seed comes from
    ``split_seed``.

    Returns:
        The same ``rows`` list that was passed in, after mutation.

    Raises:
        ValueError: If the train ratio is not positive, the validation ratio
            is negative, or the two together reach 1.0 or more.
    """
    if not config.get("enable_split_assignment", True):
        for record in rows:
            # Preserve a pre-existing label; only fill in the missing ones.
            record.setdefault("split", "unsplit")
            record["split_policy"] = "disabled"
        return rows

    train_ratio = float(config.get("split_train_ratio", 0.90))
    val_ratio = float(config.get("split_val_ratio", 0.05))
    ratios_invalid = train_ratio <= 0 or val_ratio < 0 or train_ratio + val_ratio >= 1.0
    if ratios_invalid:
        raise ValueError("Invalid split ratios: require train_ratio > 0 and train_ratio + val_ratio < 1.0")

    seed = str(config.get("split_seed", "sinhala_tts_llm_cc_v1"))

    # Remember each group's split so it is hashed once per group, not per row.
    split_by_group: Dict[str, str] = {}
    for record in rows:
        group = _group_key(record)
        chosen = split_by_group.get(group)
        if chosen is None:
            chosen = _split_for_group(group, seed, train_ratio, val_ratio)
            split_by_group[group] = chosen
        record.update(
            {
                "split": chosen,
                "split_group_key": group,
                "split_policy": "speaker_cluster_else_video_hash_v1",
            }
        )

    return rows
|
|
|
|
def split_distribution(rows: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count how many rows carry each ``split`` label.

    Labels are stringified. A missing or falsy ``split`` value is counted
    under ``"unknown"``. Keys appear in order of first occurrence.
    """
    tallies: Dict[str, int] = {}
    for record in rows:
        label = str(record.get("split", "unknown") or "unknown")
        tallies[label] = tallies.get(label, 0) + 1
    return tallies
|
|