Spaces:
Paused
Paused
Delete train/splits.py with huggingface_hub
Browse files- train/splits.py +0 -77
train/splits.py
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
"""Deterministic stratified train/eval game split."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import random
|
| 6 |
-
from typing import Dict, FrozenSet, List, Set, Tuple
|
| 7 |
-
|
| 8 |
-
from common.games_meta.game_tags import GAME_TAGS
|
| 9 |
-
from constant_definitions.batch4.tag_constants import CATEGORIES
|
| 10 |
-
from constant_definitions.game_constants import EVAL_ZERO, EVAL_ONE
|
| 11 |
-
from constant_definitions.train.split_constants import (
|
| 12 |
-
MIN_EVAL_TAG_FRACTION_DENOMINATOR,
|
| 13 |
-
MIN_EVAL_TAG_FRACTION_NUMERATOR,
|
| 14 |
-
SPLIT_SEED,
|
| 15 |
-
TRAIN_FRACTION_DENOMINATOR,
|
| 16 |
-
TRAIN_FRACTION_NUMERATOR,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
# Domain tags are used for stratification
|
| 20 |
-
# The "domain" entry of CATEGORIES; get_train_eval_split iterates this list
# both to index games per domain tag and to enforce the per-tag eval minimum.
_DOMAIN_TAGS: List[str] = CATEGORIES["domain"]
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def get_train_eval_split(
    seed: int = SPLIT_SEED,
) -> Tuple[FrozenSet[str], FrozenSet[str]]:
    """Split the keys of ``GAME_TAGS`` into (train_games, eval_games).

    The game keys are sorted before any random draw and every draw goes
    through a single ``random.Random(seed)``, so a given seed always
    produces the same split.

    The split is built in two passes:

    1. For each domain tag, enough games carrying that tag are moved into
       eval to reach the count returned by ``_min_eval_count``.
    2. Eval is then topped up with randomly chosen leftover games until it
       reaches the size implied by the train fraction
       (``TRAIN_FRACTION_NUMERATOR / TRAIN_FRACTION_DENOMINATOR``).  If the
       first pass already met or exceeded that size, nothing is added.

    Returns:
        A ``(train_games, eval_games)`` pair of disjoint frozen sets that
        together contain every key of ``GAME_TAGS``.
    """
    ordered_keys = sorted(GAME_TAGS.keys())
    rng = random.Random(seed)

    # Index: domain tag -> game keys carrying that tag (in sorted key order).
    games_by_domain: Dict[str, List[str]] = {
        domain: [] for domain in _DOMAIN_TAGS
    }
    for key in ordered_keys:
        key_tags = GAME_TAGS[key]
        for domain in _DOMAIN_TAGS:
            if domain in key_tags:
                games_by_domain[domain].append(key)

    # Pass 1: make sure each domain reaches its minimum eval count.
    held_out: Set[str] = set()
    for domain in _DOMAIN_TAGS:
        tagged = games_by_domain[domain]
        if not tagged:
            continue
        quota = _min_eval_count(len(tagged))
        # Games picked for an earlier domain count towards this one too.
        shortfall = quota - sum(1 for key in tagged if key in held_out)
        if shortfall <= EVAL_ZERO:
            continue
        pool = [key for key in tagged if key not in held_out]
        rng.shuffle(pool)
        held_out.update(pool[:shortfall])

    # Pass 2: top eval up to the target size with random leftover games.
    game_count = len(ordered_keys)
    train_target = (
        game_count * TRAIN_FRACTION_NUMERATOR
    ) // TRAIN_FRACTION_DENOMINATOR
    eval_target = game_count - train_target
    leftovers = [key for key in ordered_keys if key not in held_out]
    # Shuffled unconditionally so the RNG is consumed the same way on
    # every path, whether or not any slots remain to be filled.
    rng.shuffle(leftovers)
    open_slots = eval_target - len(held_out)
    if open_slots > EVAL_ZERO:
        held_out.update(leftovers[:open_slots])

    train_games = frozenset(ordered_keys) - held_out
    return train_games, frozenset(held_out)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _min_eval_count(tag_total: int) -> int:
    """Return how many of ``tag_total`` same-tag games eval must contain.

    The count is ``tag_total`` scaled by
    ``MIN_EVAL_TAG_FRACTION_NUMERATOR / MIN_EVAL_TAG_FRACTION_DENOMINATOR``
    using integer arithmetic, and is never less than ``EVAL_ONE``.
    """
    scaled = tag_total * MIN_EVAL_TAG_FRACTION_NUMERATOR
    # Adding (denominator - EVAL_ONE) before the floor division bumps any
    # non-exact quotient up to the next integer.
    rounded_up = (
        scaled + MIN_EVAL_TAG_FRACTION_DENOMINATOR - EVAL_ONE
    ) // MIN_EVAL_TAG_FRACTION_DENOMINATOR
    if rounded_up > EVAL_ONE:
        return rounded_up
    return EVAL_ONE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|