soc-openenv / tests /test_environment.py
Battlecon's picture
Flatten directory structure
03b6159
Raw
History Blame Contribute Delete
2.91 kB
"""Core environment contract tests."""
import pytest
from soc_env import SOCEnv, Action
from soc_env.models import ActionType, Observation, Reward, EnvState
@pytest.mark.parametrize("task_id", SOCEnv.TASK_IDS)
def test_reset_returns_observation(task_id):
    """reset() returns a step-0 Observation with at least one alert and host."""
    soc = SOCEnv(task_id=task_id, seed=42)
    initial = soc.reset()

    assert isinstance(initial, Observation)
    assert initial.step == 0
    # The initial observation is expected to be populated for every task.
    assert len(initial.active_alerts) > 0
    assert len(initial.hosts) > 0
@pytest.mark.parametrize("task_id", SOCEnv.TASK_IDS)
def test_reset_is_reproducible(task_id):
    """Two environments built with the same task and seed reset to equal dumps."""
    # Each environment is created and reset before the next one is built,
    # matching the order in which the observations are compared.
    first_env = SOCEnv(task_id=task_id, seed=42)
    first_obs = first_env.reset()

    second_env = SOCEnv(task_id=task_id, seed=42)
    second_obs = second_env.reset()

    assert first_obs.model_dump() == second_obs.model_dump()
@pytest.mark.parametrize("task_id", SOCEnv.TASK_IDS)
def test_step_returns_correct_types(task_id):
    """step() returns (Observation, Reward, bool, dict) with total in [-1, 1]."""
    soc = SOCEnv(task_id=task_id, seed=42)
    start = soc.reset()

    # Enrich the first active alert; any accepted action would do for a
    # type check of the step() return values.
    enrich = Action(
        action_type=ActionType.ENRICH_ALERT,
        alert_id=start.active_alerts[0].alert_id,
        source="threat_intel",
    )
    next_obs, reward, done, info = soc.step(enrich)

    expected_types = (
        (next_obs, Observation),
        (reward, Reward),
        (done, bool),
        (info, dict),
    )
    for value, expected in expected_types:
        assert isinstance(value, expected)

    total = reward.total
    assert total >= -1.0 and total <= 1.0
@pytest.mark.parametrize("task_id", SOCEnv.TASK_IDS)
def test_state_returns_envstate(task_id):
    """state() returns an EnvState carrying the task_id the env was built with."""
    soc = SOCEnv(task_id=task_id, seed=42)
    soc.reset()

    snapshot = soc.state()

    assert isinstance(snapshot, EnvState)
    assert snapshot.task_id == task_id
def test_step_raises_before_reset():
    """step() on a never-reset env raises RuntimeError whose message matches 'reset'."""
    unstarted = SOCEnv()
    with pytest.raises(RuntimeError, match="reset"):
        unstarted.step(
            Action(action_type=ActionType.CREATE_TICKET, priority="P1")
        )
def test_episode_terminates():
    """Stepping the alert_triage task reaches done within MAX_STEPS + 2 steps."""
    task = "alert_triage"
    soc = SOCEnv(task_id=task, seed=42)
    soc.reset()

    # Allow two iterations beyond the configured limit before giving up.
    step_budget = SOCEnv.MAX_STEPS[task] + 2
    for _ in range(step_budget):
        current = soc.state()
        if current.done:
            break

        pending = current.observation.active_alerts
        if pending:
            # Keep enriching the first alert while any alert is active.
            action = Action(
                action_type=ActionType.ENRICH_ALERT,
                alert_id=pending[0].alert_id,
                source="threat_intel",
            )
        else:
            # No alerts left: file a ticket instead.
            action = Action(
                action_type=ActionType.CREATE_TICKET,
                priority="P3",
                summary="done",
            )

        _obs, _reward, finished, _info = soc.step(action)
        if finished:
            break

    assert soc.state().done
def test_hard_block_prevents_isolation():
    """Isolating HOST-CEO in constrained_incident_response is reported as BLOCKED."""
    soc = SOCEnv(task_id="constrained_incident_response", seed=42)
    soc.reset()

    isolate_ceo = Action(
        action_type=ActionType.ISOLATE_ENDPOINT,
        host_id="HOST-CEO",
    )
    outcome, _reward, _done, _info = soc.step(isolate_ceo)

    # last_action_result may be falsy (e.g. None); treat that as an empty message.
    message = outcome.last_action_result or ""
    assert "BLOCKED" in message
    assert outcome.last_action_success is False
def test_ground_truth_hidden_in_observation():
    """No alert in the model_dump_safe() view of the observation has 'ground_truth'."""
    soc = SOCEnv(task_id="alert_triage", seed=42)
    initial = soc.reset()
    safe_view = initial.model_dump_safe()

    # A missing "active_alerts" key is treated as an empty alert list.
    dumped_alerts = safe_view.get("active_alerts", [])
    assert all(
        "ground_truth" not in entry for entry in dumped_alerts
    ), "ground_truth must never reach the agent"