"""Tests for PatchHawkEnv (OpenEnv compliance + reward logic)."""
import pytest
from patchhawk.agent.environment import PatchHawkEnv
from patchhawk.env_models import PatchHawkAction, PatchHawkObservation, PatchHawkState
@pytest.fixture
def env():
    """Yield a PatchHawkEnv built with ``use_docker=False`` and close it afterwards.

    Only ``use_docker`` is passed explicitly; every other constructor
    setting is left at the environment's own default.
    """
    environment = PatchHawkEnv(use_docker=False)
    yield environment
    # Teardown: executed after the requesting test has finished.
    environment.close()
# ── Basic API ─────────────────────────────────────────────────────
def test_reset_returns_observation(env):
    """reset() hands back a PatchHawkObservation exposing the expected attributes."""
    observation = env.reset()
    assert isinstance(observation, PatchHawkObservation)

    # Checked in the same order as the fields are listed here.
    expected_attrs = (
        "code_snippet",
        "static_flags",
        "risk_score",
        "done",
        "reward",
        "metadata",
    )
    for attr in expected_attrs:
        assert hasattr(observation, attr)
def test_observation_fields(env):
    """Each checked field of the observation from reset() has the expected type."""
    observation = env.reset()

    # Field name -> type the test requires for that field.
    expected_types = {
        "code_snippet": str,
        "static_flags": list,
        "risk_score": float,
        "done": bool,
        "metadata": dict,
    }
    for field_name, field_type in expected_types.items():
        assert isinstance(getattr(observation, field_name), field_type)
def test_step_returns_observation(env):
    """step() hands back a PatchHawkObservation with typed reward/done/metadata."""
    env.reset()
    result = env.step(PatchHawkAction(action_type=env.ACTION_ANALYZE))

    assert isinstance(result, PatchHawkObservation)
    assert isinstance(result.reward, (int, float))
    assert isinstance(result.done, bool)
    assert isinstance(result.metadata, dict)
def test_state_property(env):
    """After reset(), ``env.state`` is a PatchHawkState with the expected attributes."""
    env.reset()
    current = env.state

    assert isinstance(current, PatchHawkState)
    for attr in ("episode_id", "step_count", "scenario_id"):
        assert hasattr(current, attr)
def test_all_action_types_accepted(env):
    """Every action type in ``range(5)`` (0-4) yields an observation from step().

    The environment is reset before each action, so every action type is
    tried as the first step of a freshly reset episode.
    """
    for action_type in range(5):
        # The observation returned by reset() is not needed here; only the
        # result of the subsequent step() is checked.
        env.reset()
        result = env.step(PatchHawkAction(action_type=action_type))
        # Name the offending action type so a failure inside the loop is
        # immediately attributable.
        assert isinstance(result, PatchHawkObservation), (
            f"step() returned {type(result).__name__} "
            f"for action_type={action_type}"
        )
# ── Reward logic ──────────────────────────────────────────────────
def test_block_malicious_positive_reward(env):
    """Blocking a scenario labelled "malicious" gives reward 2.0 and ends the episode.

    Skipped when the environment has no scenario with that label.
    """
    candidates = [
        item for item in env.scenarios if item.get("label") == "malicious"
    ]
    if not candidates:
        pytest.skip("No malicious scenarios available")

    env.reset(scenario=candidates[0])
    result = env.step(PatchHawkAction(action_type=env.ACTION_BLOCK_PR))

    assert result.reward == 2.0
    assert result.done is True
def test_block_benign_negative_reward(env):
    """Blocking a scenario labelled "benign" gives reward -1.0 and ends the episode.

    Skipped when the environment has no scenario with that label.
    """
    candidates = [
        item for item in env.scenarios if item.get("label") == "benign"
    ]
    if not candidates:
        pytest.skip("No benign scenarios available")

    env.reset(scenario=candidates[0])
    result = env.step(PatchHawkAction(action_type=env.ACTION_BLOCK_PR))

    assert result.reward == -1.0
    assert result.done is True
def test_execute_sandbox_reward(env):
    """A sandbox-execution step yields a reward of 0.1 and leaves the episode open."""
    env.reset()
    result = env.step(PatchHawkAction(action_type=env.ACTION_EXECUTE_SANDBOX))

    # 0.1 has no exact binary representation, so compare with a tolerance
    # instead of relying on bit-for-bit float equality.
    assert result.reward == pytest.approx(0.1)
    assert result.done is False
def test_analyze_no_reward(env):
    """A single analyze step yields reward 0.0 and does not end the episode."""
    env.reset()
    result = env.step(PatchHawkAction(action_type=env.ACTION_ANALYZE))

    assert result.reward == 0.0
    assert result.done is False
def test_request_review_terminates(env):
    """A request-review step yields reward 0.0 and ends the episode."""
    env.reset()
    result = env.step(PatchHawkAction(action_type=env.ACTION_REQUEST_REVIEW))

    assert result.reward == 0.0
    assert result.done is True
def test_max_steps_penalty(env):
    """Only analysing a malicious scenario until the episode ends gives reward -5.0.

    The analyze action is repeated up to ``env.max_steps`` times, stopping
    early once an observation reports ``done``. Skipped when the environment
    has no scenario labelled "malicious".
    """
    malicious = [s for s in env.scenarios if s.get("label") == "malicious"]
    if not malicious:
        pytest.skip("No malicious scenarios available")

    env.reset(scenario=malicious[0])
    action = PatchHawkAction(action_type=env.ACTION_ANALYZE)

    obs = None
    for _ in range(env.max_steps):
        obs = env.step(action)
        if obs.done:
            break

    # If max_steps is not positive the loop body never runs; fail with a
    # clear message instead of an AttributeError on None below.
    assert obs is not None, "env.max_steps must be at least 1 for this test"

    # Last step on malicious without block/patch → -5.0
    assert obs.reward == -5.0
    assert obs.done is True
def test_episode_with_scenario_kwarg(env):
    """A scenario dict passed as ``reset(scenario=...)`` shows up in the observation."""
    snippet = "x = 42"
    scenario_id = "test_custom"
    custom_scenario = {
        "id": scenario_id,
        "type": "functional",
        "label": "benign",
        "code_snippet": snippet,
        "patch": None,
        "unit_test_code": None,
        "attack_type": None,
    }

    first_obs = env.reset(scenario=custom_scenario)

    assert first_obs.code_snippet == snippet
    assert first_obs.metadata["scenario_id"] == scenario_id
def test_step_counter_increments(env):
    """Three analyze steps after reset() leave ``state.step_count`` at 3."""
    env.reset()
    step_total = 3
    for _ in range(step_total):
        env.step(PatchHawkAction(action_type=env.ACTION_ANALYZE))

    assert env.state.step_count == step_total
def test_close_resets_scenario(env):
    """After reset() followed by close(), ``current_scenario`` is None."""
    env.reset()
    env.close()

    active_scenario = env.current_scenario
    assert active_scenario is None
|