File size: 2,090 Bytes
557930c 4de7d31 557930c 4de7d31 557930c 4de7d31 557930c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | """Tests for baseline_runner and inference helpers."""
from baseline_runner import run_baseline_episodes, _heuristic_episode
from server.environment import CloudNativeDebugEnvironment
from server.tasks.task_registry import TASK_REGISTRY
def test_heuristic_baseline_scores_above_zero_on_most_scenarios():
    """Heuristic baseline should score > 0 on most scenarios.

    Every scenario of every task in ``TASK_REGISTRY`` is run through
    ``_heuristic_episode`` and the fraction of scenarios with a strictly
    positive score must be at least 80%.

    Some scenarios (e.g. reordering steps) can't be solved by simple
    contains-based heuristics, so we allow a few zeros.
    """
    total = 0
    nonzero = 0
    for task_id, task_cls in TASK_REGISTRY.items():
        for scenario in task_cls.SCENARIOS:
            # A new environment is created for every scenario.
            env = CloudNativeDebugEnvironment()
            result = _heuristic_episode(env, task_id, scenario["id"])
            total += 1
            if result.score > 0.0:
                nonzero += 1

    # Fail with a readable message instead of a ZeroDivisionError when the
    # registry is empty or no task declares any scenario.
    assert total > 0, "No scenarios found in TASK_REGISTRY"

    # At least 80% of scenarios should get > 0
    assert nonzero / total >= 0.8, f"Only {nonzero}/{total} scenarios scored > 0"
def test_run_baseline_episodes_single_task():
    """One episode for a single task yields one result tagged with that task.

    The returned result must carry the requested ``task_id`` and a
    non-negative score.
    """
    requested_task = "dockerfile_syntax"
    episodes = run_baseline_episodes(task_id=requested_task, num_episodes=1)

    assert len(episodes) == 1

    first = episodes[0]
    assert first.task_id == requested_task
    assert first.score >= 0.0
def test_run_baseline_episodes_all_tasks():
    """With ``task_id=None`` and one episode, every registered task appears once.

    The number of results must equal the number of entries in
    ``TASK_REGISTRY`` and the set of reported task ids must match the
    registry keys exactly.
    """
    episodes = run_baseline_episodes(task_id=None, num_episodes=1)
    assert len(episodes) == len(TASK_REGISTRY)

    # Collect the distinct task ids reported by the runner.
    reported_ids = set()
    for episode in episodes:
        reported_ids.add(episode.task_id)

    registered_ids = set(TASK_REGISTRY.keys())
    assert reported_ids == registered_ids
def test_heuristic_fixes_easy_tasks_well():
    """Easy tasks should average a score >= 0.3 with the heuristic baseline.

    For each task in ``TASK_REGISTRY`` whose ``DIFFICULTY.value`` is
    ``"easy"``, every scenario is run through ``_heuristic_episode`` in a
    newly created ``CloudNativeDebugEnvironment``; the mean score per task
    is then compared against the threshold.
    """
    # NOTE(review): 0.3 is the bar this test enforces; confirm it is the
    # intended one before tightening it.
    min_avg_score = 0.3

    # Keep the class alongside the id so it is not looked up a second time.
    easy_tasks = {
        tid: cls
        for tid, cls in TASK_REGISTRY.items()
        if cls.DIFFICULTY.value == "easy"
    }

    for task_id, task_cls in easy_tasks.items():
        scores = []
        for scenario in task_cls.SCENARIOS:
            env = CloudNativeDebugEnvironment()
            result = _heuristic_episode(env, task_id, scenario["id"])
            scores.append(result.score)

        # Fail with a readable message instead of a ZeroDivisionError when
        # an easy task declares no scenarios.
        assert scores, f"Easy task {task_id} has no scenarios"

        avg = sum(scores) / len(scores)
        assert avg >= min_avg_score, f"Easy task {task_id} avg score {avg:.2f} too low"
|