"""Tests for baseline_runner and inference helpers."""

from baseline_runner import run_baseline_episodes, _heuristic_episode
from server.environment import CloudNativeDebugEnvironment
from server.tasks.task_registry import TASK_REGISTRY


def test_heuristic_baseline_scores_above_zero_on_most_scenarios():
    """Heuristic baseline should score > 0 on most scenarios.

    Some scenarios (e.g. reordering steps) can't be solved by simple
    contains-based heuristics, so we allow a few zeros.
    """
    total = 0
    nonzero = 0
    for task_id, task_cls in TASK_REGISTRY.items():
        for scenario in task_cls.SCENARIOS:
            env = CloudNativeDebugEnvironment()
            result = _heuristic_episode(env, task_id, scenario["id"])
            total += 1
            if result.score > 0.0:
                nonzero += 1
    # At least 80% of scenarios should get > 0
    assert nonzero / total >= 0.8, f"Only {nonzero}/{total} scenarios scored > 0"
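

# Hedged sketch (beyond the original assertions): assuming episode scores are
# normalized to [0, 1] -- the existing tests only check score > 0.0 and
# score >= 0.0 -- every heuristic score should also stay within the unit
# interval. Drop this test if scores can legitimately exceed 1.0.
def test_heuristic_scores_stay_in_unit_interval():
    for task_id, task_cls in TASK_REGISTRY.items():
        for scenario in task_cls.SCENARIOS:
            env = CloudNativeDebugEnvironment()
            result = _heuristic_episode(env, task_id, scenario["id"])
            assert 0.0 <= result.score <= 1.0, (
                f"{task_id}/{scenario['id']} score {result.score} out of [0, 1]"
            )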


def test_run_baseline_episodes_single_task():
    """Running one episode of a single task returns exactly one scored result."""
    results = run_baseline_episodes(task_id="dockerfile_syntax", num_episodes=1)
    assert len(results) == 1
    assert results[0].task_id == "dockerfile_syntax"
    assert results[0].score >= 0.0


def test_run_baseline_episodes_all_tasks():
    """task_id=None should run one episode for every registered task."""
    results = run_baseline_episodes(task_id=None, num_episodes=1)
    assert len(results) == len(TASK_REGISTRY)
    task_ids_seen = {r.task_id for r in results}
    assert task_ids_seen == set(TASK_REGISTRY.keys())
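

# Hedged sketch (beyond the original assertions): assuming num_episodes
# controls how many episodes run per task (the tests above only use
# num_episodes=1), requesting three episodes of one task should yield three
# results, all for that task.
def test_run_baseline_episodes_respects_num_episodes():
    results = run_baseline_episodes(task_id="dockerfile_syntax", num_episodes=3)
    assert len(results) == 3
    assert all(r.task_id == "dockerfile_syntax" for r in results)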


def test_heuristic_fixes_easy_tasks_well():
    """Each easy task should average a score >= 0.3 with the heuristic baseline."""
    easy_tasks = [
        tid for tid, cls in TASK_REGISTRY.items() if cls.DIFFICULTY.value == "easy"
    ]
    for task_id in easy_tasks:
        task_cls = TASK_REGISTRY[task_id]
        scores = []
        for scenario in task_cls.SCENARIOS:
            env = CloudNativeDebugEnvironment()
            result = _heuristic_episode(env, task_id, scenario["id"])
            scores.append(result.score)
        avg = sum(scores) / len(scores)
        assert avg >= 0.3, f"Easy task {task_id} avg score {avg:.2f} too low"
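

# Hedged sketch (beyond the original assertions): assuming _heuristic_episode
# is deterministic for a fixed scenario (nothing in this file guarantees that;
# drop this test if the baseline randomizes its actions), re-running the same
# scenario on a fresh environment should reproduce the score.
def test_heuristic_episode_score_is_reproducible():
    task_id, task_cls = next(iter(TASK_REGISTRY.items()))
    scenario_id = next(iter(task_cls.SCENARIOS))["id"]
    first = _heuristic_episode(CloudNativeDebugEnvironment(), task_id, scenario_id)
    second = _heuristic_episode(CloudNativeDebugEnvironment(), task_id, scenario_id)
    assert first.score == second.score, (
        f"Non-deterministic scores for {task_id}/{scenario_id}: "
        f"{first.score} != {second.score}"
    )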