File size: 5,068 Bytes
d7b3a74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import textwrap
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pytest

from shinka.core import run_shinka_eval


def _write_program(tmp_path: Path, source: str) -> str:
    program_path = tmp_path / "program_eval.py"
    program_path.write_text(textwrap.dedent(source), encoding="utf-8")
    return str(program_path)


def _results_dir(tmp_path: Path, name: str) -> str:
    path = tmp_path / name
    path.mkdir(parents=True, exist_ok=True)
    return str(path)


def test_run_shinka_eval_parallel_matches_sequential(tmp_path: Path) -> None:
    """Parallel evaluation must reproduce the sequential result, in run order.

    Each run sleeps longer for smaller seeds, so with several workers the
    runs complete out of order. Aggregation and validation bookkeeping must
    still match a single-worker evaluation exactly.
    """
    program_path = _write_program(
        tmp_path,
        """
        import time

        def run_experiment(seed):
            # Deliberately finish out of order to verify ordered aggregation.
            time.sleep(0.01 * (6 - seed))
            return {"seed": seed}
        """,
    )

    def get_kwargs(run_idx: int) -> Dict[str, Any]:
        # Seeds are 1-based so run i gets seed i + 1.
        return {"seed": run_idx + 1}

    def aggregate_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
        seeds_in_order = [entry["seed"] for entry in results]
        return {
            "combined_score": float(sum(seeds_in_order)),
            "ordered_seeds": seeds_in_order,
        }

    def validate_result(result: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        # Odd seeds pass; even seeds are rejected with a fixed message.
        if result["seed"] % 2:
            return True, None
        return False, "even seed invalid"

    def run_with_workers(subdir: str, workers: int):
        # Only the output directory and worker count differ between modes.
        return run_shinka_eval(
            program_path=program_path,
            results_dir=_results_dir(tmp_path, subdir),
            experiment_fn_name="run_experiment",
            num_runs=5,
            get_experiment_kwargs=get_kwargs,
            aggregate_metrics_fn=aggregate_metrics,
            validate_fn=validate_result,
            run_workers=workers,
        )

    seq_metrics, seq_correct, seq_err = run_with_workers("seq", 1)
    par_metrics, par_correct, par_err = run_with_workers("par", 3)

    for metrics in (seq_metrics, par_metrics):
        assert metrics["ordered_seeds"] == [1, 2, 3, 4, 5]
    assert par_metrics["combined_score"] == seq_metrics["combined_score"]
    assert par_metrics["num_valid_runs"] == seq_metrics["num_valid_runs"] == 3
    assert par_metrics["num_invalid_runs"] == seq_metrics["num_invalid_runs"] == 2
    assert (
        par_metrics["all_validation_errors"]
        == seq_metrics["all_validation_errors"]
        == ["even seed invalid"]
    )

    expected_error = "Validation failed: even seed invalid"
    for correct, err in ((seq_correct, seq_err), (par_correct, par_err)):
        assert correct is False
        assert err == expected_error


def test_run_shinka_eval_parallel_worker_error_surfaces(tmp_path: Path) -> None:
    """An exception raised inside one parallel run must be reported, not lost.

    The program raises only for seed 3. The evaluation is expected to be
    marked incorrect, carry the failing run index and original message in
    its error string, and report a zero combined score.
    """
    failing_source = """
        def run_experiment(seed):
            if seed == 3:
                raise RuntimeError("boom seed 3")
            return seed
        """
    program_path = _write_program(tmp_path, failing_source)

    metrics, correct, error_msg = run_shinka_eval(
        program_path=program_path,
        results_dir=_results_dir(tmp_path, "worker_error"),
        experiment_fn_name="run_experiment",
        num_runs=4,
        run_workers=2,
    )

    assert correct is False
    assert error_msg is not None
    for fragment in ("Run 3/4 failed in parallel evaluation", "boom seed 3"):
        assert fragment in error_msg
    assert metrics["combined_score"] == 0.0


def test_parallel_mode_rejects_early_stop(tmp_path: Path) -> None:
    """Requesting early stopping together with several workers is an error."""
    trivial_source = """
        def run_experiment(seed):
            return float(seed)
        """
    program_path = _write_program(tmp_path, trivial_source)

    expected_message = "Early stopping is only supported in sequential mode"
    with pytest.raises(ValueError, match=expected_message):
        run_shinka_eval(
            program_path=program_path,
            results_dir=_results_dir(tmp_path, "early_stop_guard"),
            experiment_fn_name="run_experiment",
            num_runs=4,
            run_workers=2,
            early_stop_method="ci",
            early_stop_threshold=0.5,
        )


def test_invalid_worker_configuration_raises(tmp_path: Path) -> None:
    """Non-positive ``run_workers`` or ``max_workers_cap`` must raise ValueError."""
    echo_source = """
        def run_experiment(seed):
            return seed
        """
    program_path = _write_program(tmp_path, echo_source)

    # (expected error regex, results subdirectory, offending keyword args)
    bad_configurations = [
        ("run_workers must be >= 1", "bad_workers", {"run_workers": 0}),
        (
            "max_workers_cap must be >= 1",
            "bad_cap",
            {"run_workers": 1, "max_workers_cap": 0},
        ),
    ]

    for expected_message, subdir, overrides in bad_configurations:
        with pytest.raises(ValueError, match=expected_message):
            run_shinka_eval(
                program_path=program_path,
                results_dir=_results_dir(tmp_path, subdir),
                experiment_fn_name="run_experiment",
                num_runs=1,
                **overrides,
            )