File size: 2,185 Bytes
fad16c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations

import json
import random
from pathlib import Path
from typing import Iterable


def load_jsonl(path: Path) -> list[dict]:
    """Read a JSON-lines file into a list of dicts, skipping blank lines.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file contains no non-blank rows.
    """
    if not path.exists():
        raise FileNotFoundError(f"JSONL file not found: {path}")
    with path.open(encoding="utf-8") as handle:
        # Strip each physical line first so whitespace-only lines are dropped.
        parsed = [json.loads(raw) for raw in map(str.strip, handle) if raw]
    if not parsed:
        raise ValueError(f"No rows found in: {path}")
    return parsed


def split_scenarios(
    scenarios: list[dict],
    train_fraction: float = 0.8,
    seed: int = 42,
) -> tuple[list[dict], list[dict]]:
    """Deterministically shuffle ``scenarios`` and split into (train, eval).

    The cut point is clamped so neither side is ever empty. Rows are dicts
    carrying an ``"id"`` key; if a duplicate id ends up on both sides the
    split is rejected as leakage.

    Raises:
        ValueError: on an out-of-range fraction, too few scenarios, or
            id overlap between the two sides.
    """
    if not 0.1 <= train_fraction <= 0.95:
        raise ValueError("train_fraction must be in [0.1, 0.95]")
    if len(scenarios) < 4:
        raise ValueError("Need at least 4 scenarios to split train/eval")

    # Shuffle a copy with a dedicated RNG so the split is reproducible
    # and the caller's list is untouched.
    pool = list(scenarios)
    random.Random(seed).shuffle(pool)

    boundary = int(round(len(pool) * train_fraction))
    boundary = max(1, min(len(pool) - 1, boundary))  # keep both sides non-empty
    train_rows, eval_rows = pool[:boundary], pool[boundary:]

    # Duplicate scenario ids could land on both sides; treat that as leakage.
    shared = {row["id"] for row in train_rows} & {row["id"] for row in eval_rows}
    if shared:
        raise ValueError(f"Scenario leakage between train/eval split: {sorted(shared)}")
    return train_rows, eval_rows


def write_json(path: Path, payload: dict) -> None:
    """Serialize ``payload`` as pretty-printed JSON at ``path``.

    Missing parent directories are created on the way.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2)
    path.write_text(serialized, encoding="utf-8")


def mean(values: Iterable[float]) -> float:
    """Arithmetic mean of ``values``; 0.0 for an empty iterable."""
    items = tuple(values)
    return sum(items) / len(items) if items else 0.0


def ensure_metric_keys(metrics: dict, required_keys: list[str]) -> None:
    """Raise ValueError naming every required key absent from ``metrics``."""
    absent = []
    for key in required_keys:
        if key not in metrics:
            absent.append(key)
    if absent:
        raise ValueError(f"Missing metric keys: {absent}")


def collect_rewards_with_autosubmit(environments) -> list[float]:
    """Submit every unfinished environment, then gather each reward as float.

    Environments are duck-typed: a missing ``done`` attribute is treated as
    False (so ``submit()`` is called), and a missing ``reward`` as 0.0.
    """
    collected: list[float] = []
    for environment in environments:
        is_done = getattr(environment, "done", False)
        if not is_done:
            environment.submit()
        collected.append(float(getattr(environment, "reward", 0.0)))
    return collected