File size: 1,791 Bytes
1595dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ec70de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from models import PythonCodeReviewAction
from server.env import PythonCodeReviewEnvironment


# Source text sent with the ``edit_code`` action in the five-step reward test.
# It defines ``normalize_username``: strip and lower-case the input, fall back
# to "anonymous" when nothing is left, otherwise replace spaces with
# underscores.
# NOTE(review): presumably the repaired solution for the "syntax-fix-easy"
# task - confirm against the task definition.
FIXED_SYNTAX_CODE = (
    "def normalize_username(raw_name: str) -> str:\n"
    "    cleaned = raw_name.strip().lower()\n"
    "    if not cleaned:\n"
    '        return "anonymous"\n'
    '    return cleaned.replace(" ", "_")\n'
)


def test_reward_changes_across_five_steps():
    """Check that rewards vary over a five-action episode on ``syntax-fix-easy``.

    The episode is: two ``analyze_code`` actions, one ``run_tests``, one
    ``edit_code`` carrying ``FIXED_SYNTAX_CODE``, then ``submit_solution``.
    A falsy reward on an observation (e.g. ``None``) is counted as ``0.0``.

    Asserts, in order, that every reward lies in ``[-1.0, 1.0]``, that the
    rewards are not all identical, that at least one is positive and at least
    one is negative, and that no three consecutive rewards are equal.
    """
    environment = PythonCodeReviewEnvironment(verbose=False)
    environment.reset(task_id="syntax-fix-easy")

    # All actions are built up front, before the first step is taken.
    episode = (
        PythonCodeReviewAction(action_type="analyze_code"),
        PythonCodeReviewAction(action_type="analyze_code"),
        PythonCodeReviewAction(action_type="run_tests"),
        PythonCodeReviewAction(action_type="edit_code", code=FIXED_SYNTAX_CODE),
        PythonCodeReviewAction(action_type="submit_solution"),
    )

    # Step through the episode in order, mapping falsy rewards to 0.0.
    rewards = [float(environment.step(move).reward or 0.0) for move in episode]

    assert all(-1.0 <= value <= 1.0 for value in rewards)
    assert len(set(rewards)) > 1
    assert any(value > 0 for value in rewards)
    assert any(value < 0 for value in rewards)

    # Slide a window of three over the rewards; no window may be a flat run.
    windows = zip(rewards, rewards[1:], rewards[2:])
    assert all(not (first == second == third) for first, second, third in windows)


def test_repeated_no_progress_actions_do_not_flatline_three_steps():
    """Check that repeating ``analyze_code`` does not yield a flat reward run.

    Resets the environment to ``bug-fix-medium``, issues five consecutive
    ``analyze_code`` actions (a fresh action object per step), and records
    each observation's reward, counting a falsy reward as ``0.0``.

    Asserts that every reward lies in ``[-1.0, 1.0]`` and that no three
    consecutive rewards are equal.
    """
    environment = PythonCodeReviewEnvironment(verbose=False)
    environment.reset(task_id="bug-fix-medium")

    repeat_count = 5
    rewards = [
        float(
            environment.step(
                PythonCodeReviewAction(action_type="analyze_code")
            ).reward
            or 0.0
        )
        for _ in range(repeat_count)
    ]

    assert all(-1.0 <= value <= 1.0 for value in rewards)

    # Slide a window of three over the rewards; no window may be a flat run.
    windows = zip(rewards, rewards[1:], rewards[2:])
    assert all(not (first == second == third) for first, second, third in windows)