Spaces:
Runtime error
Runtime error
| """Tests for OpenEnv-compatible model serialization.""" | |
| import json | |
| from open_range.server.models import RangeAction, RangeObservation, RangeState | |
class TestRangeAction:
    """RangeAction serialization and validation."""

    def test_create(self):
        a = RangeAction(command="nmap -sV web", mode="red")
        assert a.command == "nmap -sV web"
        assert a.mode == "red"

    def test_serialize_deserialize(self):
        """model_dump() yields a dict that rebuilds an equivalent action."""
        a = RangeAction(command="curl http://web/", mode="red")
        data = a.model_dump()
        assert data["command"] == "curl http://web/"
        assert data["mode"] == "red"
        a2 = RangeAction(**data)
        assert a2.command == a.command
        assert a2.mode == a.mode

    def test_mode_literal(self):
        """mode must be 'red' or 'blue'; any other value is rejected."""
        assert RangeAction(command="x", mode="red").mode == "red"
        assert RangeAction(command="x", mode="blue").mode == "blue"
        # Assumes RangeAction is a pydantic v2 model (model_dump /
        # model_validate_json are used in this class); pydantic's
        # ValidationError subclasses ValueError, so catching ValueError
        # avoids importing pydantic or pytest here.
        try:
            RangeAction(command="x", mode="green")
        except ValueError:
            pass
        else:
            raise AssertionError("RangeAction accepted invalid mode 'green'")

    def test_json_roundtrip(self):
        """model_dump_json() emits valid JSON that validates back unchanged."""
        a = RangeAction(command="hydra -l admin web ssh", mode="red")
        js = a.model_dump_json()
        # Check the wire payload itself, not only the round-tripped object.
        payload = json.loads(js)
        assert payload["command"] == "hydra -l admin web ssh"
        assert payload["mode"] == "red"
        a2 = RangeAction.model_validate_json(js)
        assert a2.command == a.command
        assert a2.mode == a.mode
class TestRangeObservation:
    """RangeObservation inherits done/reward from Observation base."""

    def test_defaults(self):
        obs = RangeObservation()
        assert obs.done is False
        assert obs.reward is None
        assert obs.stdout == ""
        assert obs.stderr == ""
        assert obs.flags_captured == []
        assert obs.alerts == []

    def test_done_inherited(self):
        obs = RangeObservation(done=True, reward=1.0)
        assert obs.done is True
        assert obs.reward == 1.0

    def test_custom_fields(self):
        """Explicitly passed fields are stored; the rest keep their defaults."""
        obs = RangeObservation(
            stdout="flag found",
            flags_captured=["FLAG{test}"],
            alerts=["IDS alert"],
        )
        assert obs.stdout == "flag found"
        assert obs.flags_captured == ["FLAG{test}"]
        assert obs.alerts == ["IDS alert"]
        # Fields not passed above should be untouched defaults.
        assert obs.stderr == ""
        assert obs.done is False

    def test_json_roundtrip(self):
        """Every field set on the observation survives a JSON round trip."""
        obs = RangeObservation(
            stdout="output",
            stderr="err",
            done=True,
            reward=0.5,
            flags_captured=["FLAG{a}"],
            alerts=["alert1"],
        )
        js = obs.model_dump_json()
        obs2 = RangeObservation.model_validate_json(js)
        assert obs2.stdout == obs.stdout
        assert obs2.stderr == obs.stderr
        assert obs2.done is True
        assert obs2.reward == 0.5
        assert obs2.flags_captured == ["FLAG{a}"]
        assert obs2.alerts == ["alert1"]
class TestRangeState:
    """RangeState inherits episode_id/step_count from State base."""

    def test_defaults(self):
        s = RangeState()
        assert s.episode_id is None
        assert s.step_count == 0
        assert s.mode == ""
        assert s.tier == 1
        assert s.flags_found == []

    def test_inherited_fields(self):
        s = RangeState(episode_id="ep_42", step_count=5)
        assert s.episode_id == "ep_42"
        assert s.step_count == 5

    def test_custom_fields(self):
        """Explicitly passed fields are stored; step_count keeps its default."""
        s = RangeState(
            episode_id="ep1",
            mode="red",
            tier=3,
            flags_found=["FLAG{a}"],
            services_status={"web": "running"},
        )
        assert s.episode_id == "ep1"
        assert s.mode == "red"
        assert s.tier == 3
        assert s.flags_found == ["FLAG{a}"]
        assert s.services_status["web"] == "running"
        assert s.step_count == 0

    def test_json_roundtrip(self):
        """Every field set on the state survives a JSON round trip."""
        s = RangeState(
            episode_id="ep99",
            step_count=10,
            mode="blue",
            flags_found=["FLAG{x}"],
            tier=2,
        )
        js = s.model_dump_json()
        s2 = RangeState.model_validate_json(js)
        assert s2.episode_id == "ep99"
        assert s2.step_count == 10
        assert s2.mode == "blue"
        assert s2.tier == 2
        assert s2.flags_found == ["FLAG{x}"]

    def test_model_dump_and_back(self):
        """model_dump() yields a dict that rebuilds an equivalent state."""
        s = RangeState(episode_id="e1", step_count=3, mode="red", tier=1)
        d = s.model_dump()
        s2 = RangeState(**d)
        assert s2.episode_id == s.episode_id
        assert s2.step_count == s.step_count
        assert s2.mode == s.mode
        assert s2.tier == s.tier