| import textwrap |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import pytest |
|
|
| from shinka.core import run_shinka_eval |
|
|
|
|
def _write_program(tmp_path: Path, source: str) -> str:
    """Write a dedented program file into *tmp_path* and return its path.

    The file is always named ``program_eval.py`` and is written as UTF-8.
    *source* may be an indented triple-quoted literal; its common leading
    whitespace is removed before writing. The path is returned as a plain
    ``str`` to match how ``program_path`` is passed to ``run_shinka_eval``
    in this module.
    """
    target = tmp_path.joinpath("program_eval.py")
    program_text = textwrap.dedent(source)
    target.write_text(program_text, encoding="utf-8")
    return str(target)
|
|
|
|
def _results_dir(tmp_path: Path, name: str) -> str:
    """Ensure the directory ``tmp_path / name`` exists and return it as a string.

    Missing parent directories are created. An already-existing directory
    is accepted, so calling this twice with the same arguments is harmless.
    """
    results = tmp_path.joinpath(name)
    results.mkdir(exist_ok=True, parents=True)
    return str(results)
|
|
|
|
def test_run_shinka_eval_parallel_matches_sequential(tmp_path: Path) -> None:
    """Parallel and sequential evaluation must produce identical aggregates.

    The generated program sleeps longer for smaller seeds, so with several
    workers the runs can complete out of order. The test still expects the
    aggregation callback to receive results in run order (seeds 1..5).
    """
    program_path = _write_program(
        tmp_path,
        """
        import time

        def run_experiment(seed):
            # Deliberately finish out of order to verify ordered aggregation.
            time.sleep(0.01 * (6 - seed))
            return {"seed": seed}
        """,
    )

    def get_kwargs(run_idx: int) -> Dict[str, Any]:
        # Run indices start at 0; shift by one so the seeds are 1..5.
        return dict(seed=run_idx + 1)

    def aggregate_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Record the seeds in the order received so ordering can be asserted.
        seeds_seen = [entry["seed"] for entry in results]
        total = float(sum(seeds_seen))
        return {"combined_score": total, "ordered_seeds": seeds_seen}

    def validate_result(result: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        # Reject even seeds (2 and 4), so each mode sees exactly two failures.
        if result["seed"] % 2 != 0:
            return True, None
        return False, "even seed invalid"

    # Everything except the results directory and the worker count is shared.
    shared_kwargs: Dict[str, Any] = dict(
        program_path=program_path,
        experiment_fn_name="run_experiment",
        num_runs=5,
        get_experiment_kwargs=get_kwargs,
        aggregate_metrics_fn=aggregate_metrics,
        validate_fn=validate_result,
    )
    seq_metrics, seq_correct, seq_err = run_shinka_eval(
        results_dir=_results_dir(tmp_path, "seq"),
        run_workers=1,
        **shared_kwargs,
    )
    par_metrics, par_correct, par_err = run_shinka_eval(
        results_dir=_results_dir(tmp_path, "par"),
        run_workers=3,
        **shared_kwargs,
    )

    expected_order = [1, 2, 3, 4, 5]
    assert seq_metrics["ordered_seeds"] == expected_order
    assert par_metrics["ordered_seeds"] == expected_order
    assert par_metrics["combined_score"] == seq_metrics["combined_score"]
    assert par_metrics["num_valid_runs"] == seq_metrics["num_valid_runs"] == 3
    assert par_metrics["num_invalid_runs"] == seq_metrics["num_invalid_runs"] == 2
    expected_errors = ["even seed invalid"]
    assert (
        par_metrics["all_validation_errors"]
        == seq_metrics["all_validation_errors"]
        == expected_errors
    )
    assert seq_correct is False
    assert par_correct is False
    expected_message = "Validation failed: even seed invalid"
    assert seq_err == expected_message
    assert par_err == expected_message
|
|
|
|
def test_run_shinka_eval_parallel_worker_error_surfaces(tmp_path: Path) -> None:
    """A run that raises inside a parallel worker is reported as a failure.

    The generated program raises only when ``seed == 3``. The asserted
    "Run 3/4" text presumes the default experiment kwargs give the third run
    seed 3 (NOTE(review): confirm in ``run_shinka_eval``).
    """
    program_path = _write_program(
        tmp_path,
        """
        def run_experiment(seed):
            if seed == 3:
                raise RuntimeError("boom seed 3")
            return seed
        """,
    )

    metrics, correct, error_msg = run_shinka_eval(
        program_path=program_path,
        results_dir=_results_dir(tmp_path, "worker_error"),
        experiment_fn_name="run_experiment",
        num_runs=4,
        run_workers=2,
    )

    assert correct is False
    assert error_msg is not None
    # The message must name the failing run and carry the original exception text.
    for fragment in ("Run 3/4 failed in parallel evaluation", "boom seed 3"):
        assert fragment in error_msg
    assert metrics["combined_score"] == 0.0
|
|
|
|
def test_parallel_mode_rejects_early_stop(tmp_path: Path) -> None:
    """Requesting early stopping with ``run_workers > 1`` raises ValueError."""
    program_path = _write_program(
        tmp_path,
        """
        def run_experiment(seed):
            return float(seed)
        """,
    )
    # ``match`` is applied with re.search; this text has no regex metacharacters.
    expected_message = "Early stopping is only supported in sequential mode"

    with pytest.raises(ValueError, match=expected_message):
        run_shinka_eval(
            program_path=program_path,
            results_dir=_results_dir(tmp_path, "early_stop_guard"),
            experiment_fn_name="run_experiment",
            num_runs=4,
            run_workers=2,
            early_stop_method="ci",
            early_stop_threshold=0.5,
        )
|
|
|
|
def test_invalid_worker_configuration_raises(tmp_path: Path) -> None:
    """A zero ``run_workers`` or ``max_workers_cap`` is rejected with ValueError.

    Each case gets its own results directory, and the cases run in a fixed
    order: bad worker count first, then bad cap.
    """
    program_path = _write_program(
        tmp_path,
        """
        def run_experiment(seed):
            return seed
        """,
    )

    # (results subdirectory, per-case keyword arguments, expected error text)
    cases = (
        ("bad_workers", {"run_workers": 0}, "run_workers must be >= 1"),
        (
            "bad_cap",
            {"run_workers": 1, "max_workers_cap": 0},
            "max_workers_cap must be >= 1",
        ),
    )
    for subdir, overrides, expected_message in cases:
        with pytest.raises(ValueError, match=expected_message):
            run_shinka_eval(
                program_path=program_path,
                results_dir=_results_dir(tmp_path, subdir),
                experiment_fn_name="run_experiment",
                num_runs=1,
                **overrides,
            )
|
|