jtowarek committed on
Commit
4550628
·
verified ·
1 Parent(s): 141582a

Delete train/splits.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train/splits.py +0 -77
train/splits.py DELETED
@@ -1,77 +0,0 @@
1
- """Deterministic stratified train/eval game split."""
2
-
3
- from __future__ import annotations
4
-
5
- import random
6
- from typing import Dict, FrozenSet, List, Set, Tuple
7
-
8
- from common.games_meta.game_tags import GAME_TAGS
9
- from constant_definitions.batch4.tag_constants import CATEGORIES
10
- from constant_definitions.game_constants import EVAL_ZERO, EVAL_ONE
11
- from constant_definitions.train.split_constants import (
12
- MIN_EVAL_TAG_FRACTION_DENOMINATOR,
13
- MIN_EVAL_TAG_FRACTION_NUMERATOR,
14
- SPLIT_SEED,
15
- TRAIN_FRACTION_DENOMINATOR,
16
- TRAIN_FRACTION_NUMERATOR,
17
- )
18
-
19
- # Domain tags (the "domain" entry of CATEGORIES) are the strata used for the train/eval split
20
- _DOMAIN_TAGS: List[str] = CATEGORIES["domain"]
21
-
22
-
23
def get_train_eval_split(
    seed: int = SPLIT_SEED,
) -> Tuple[FrozenSet[str], FrozenSet[str]]:
    """Partition every game key in ``GAME_TAGS`` into train and eval sets.

    The result is reproducible for a given ``seed``: game keys are sorted
    before any shuffling and all randomness comes from a private
    ``random.Random(seed)`` instance.

    The split is stratified by domain tag. For each tag in ``_DOMAIN_TAGS``
    that has at least one game, enough of its games are placed in eval to
    reach ``_min_eval_count`` for that tag. Afterwards eval is topped up
    with randomly chosen remaining games until it reaches the target size
    implied by ``TRAIN_FRACTION_NUMERATOR / TRAIN_FRACTION_DENOMINATOR``.
    If the stratification pass alone already meets or exceeds that target,
    no top-up happens and eval stays larger than the target.

    Args:
        seed: Seed for the private random number generator.

    Returns:
        ``(train_games, eval_games)`` as two disjoint frozen sets of game
        keys that together cover all keys of ``GAME_TAGS``.
    """
    ordered_games = sorted(GAME_TAGS.keys())
    shuffler = random.Random(seed)

    # Index the games under each domain tag, preserving sorted key order so
    # the later shuffles start from a reproducible sequence.
    games_by_domain: Dict[str, List[str]] = {
        domain: [] for domain in _DOMAIN_TAGS
    }
    for key in ordered_games:
        game_tags = GAME_TAGS[key]
        for domain in _DOMAIN_TAGS:
            if domain in game_tags:
                games_by_domain[domain].append(key)

    # Stratification pass: make sure each non-empty domain reaches its
    # minimum eval head-count. Games picked for an earlier domain count
    # towards later domains they also belong to.
    held_out: Set[str] = set()
    for domain in _DOMAIN_TAGS:
        members = games_by_domain[domain]
        if members:
            present = sum(1 for key in members if key in held_out)
            shortfall = _min_eval_count(len(members)) - present
            if shortfall > EVAL_ZERO:
                pool = [key for key in members if key not in held_out]
                shuffler.shuffle(pool)
                held_out.update(pool[:shortfall])

    # Top-up pass: fill eval to its target size from the games not yet chosen.
    game_count = len(ordered_games)
    train_quota = (
        game_count * TRAIN_FRACTION_NUMERATOR
    ) // TRAIN_FRACTION_DENOMINATOR
    eval_quota = game_count - train_quota
    leftovers = [key for key in ordered_games if key not in held_out]
    # Shuffled unconditionally, even when there are no slots left to fill.
    shuffler.shuffle(leftovers)
    open_slots = eval_quota - len(held_out)
    if open_slots > EVAL_ZERO:
        held_out.update(leftovers[:open_slots])

    eval_games = frozenset(held_out)
    train_games = frozenset(ordered_games) - eval_games
    return train_games, eval_games
71
-
72
-
73
def _min_eval_count(tag_total: int) -> int:
    """Return how many games carrying one tag must end up in eval.

    Scales ``tag_total`` by ``MIN_EVAL_TAG_FRACTION_NUMERATOR /
    MIN_EVAL_TAG_FRACTION_DENOMINATOR`` using integer arithmetic and never
    returns less than ``EVAL_ONE``.

    Args:
        tag_total: Number of games that carry the tag.

    Returns:
        The required eval count for that tag, at least ``EVAL_ONE``.
    """
    scaled = tag_total * MIN_EVAL_TAG_FRACTION_NUMERATOR
    # Integer round-up division of the form (a + b - k) // b; this is a true
    # ceiling when EVAL_ONE is 1 (presumably so -- confirm in game_constants).
    rounded_up = (
        scaled + MIN_EVAL_TAG_FRACTION_DENOMINATOR - EVAL_ONE
    ) // MIN_EVAL_TAG_FRACTION_DENOMINATOR
    return rounded_up if rounded_up > EVAL_ONE else EVAL_ONE