# Hugging Face Space status at capture time: Runtime error
# File size: 3,942 bytes (revision 8c486a8)
"""Tests for OpenEnv-compatible model serialization."""
import json
from open_range.server.models import RangeAction, RangeObservation, RangeState
class TestRangeAction:
    """Construction and serialization checks for RangeAction."""

    def test_create(self):
        """Constructor keyword arguments are readable back as attributes."""
        action = RangeAction(command="nmap -sV web", mode="red")

        assert action.command == "nmap -sV web"
        assert action.mode == "red"

    def test_serialize_deserialize(self):
        """model_dump() produces a dict that rebuilds a matching action."""
        original = RangeAction(command="curl http://web/", mode="red")

        payload = original.model_dump()
        assert payload["command"] == "curl http://web/"
        assert payload["mode"] == "red"

        # Feed the dumped dict straight back into the constructor.
        rebuilt = RangeAction(**payload)
        assert rebuilt.command == original.command
        assert rebuilt.mode == original.mode

    def test_mode_literal(self):
        """mode must be 'red' or 'blue'.

        Only the accepted value 'blue' is exercised here; rejection of
        other values is not checked by this test.
        """
        blue_action = RangeAction(command="x", mode="blue")

        assert blue_action.mode == "blue"

    def test_json_roundtrip(self):
        """A JSON string from model_dump_json() validates back unchanged."""
        original = RangeAction(command="hydra -l admin web ssh", mode="red")

        encoded = original.model_dump_json()
        decoded = RangeAction.model_validate_json(encoded)

        assert decoded.command == original.command
        assert decoded.mode == original.mode
class TestRangeObservation:
    """RangeObservation inherits done/reward from Observation base."""

    def test_defaults(self):
        """An observation built with no arguments carries the empty defaults."""
        obs = RangeObservation()
        assert obs.done is False
        assert obs.reward is None
        assert obs.stdout == ""
        assert obs.stderr == ""
        assert obs.flags_captured == []
        assert obs.alerts == []

    def test_done_inherited(self):
        """done and reward can be set through the constructor."""
        obs = RangeObservation(done=True, reward=1.0)
        assert obs.done is True
        assert obs.reward == 1.0

    def test_custom_fields(self):
        """Range-specific fields keep the values passed to the constructor."""
        obs = RangeObservation(
            stdout="flag found",
            flags_captured=["FLAG{test}"],
            alerts=["IDS alert"],
        )
        # stdout is set above, so verify it too rather than leaving it unchecked.
        assert obs.stdout == "flag found"
        assert obs.flags_captured == ["FLAG{test}"]
        assert obs.alerts == ["IDS alert"]

    def test_json_roundtrip(self):
        """Every populated field survives model_dump_json/model_validate_json."""
        obs = RangeObservation(
            stdout="output",
            stderr="err",
            done=True,
            reward=0.5,
            flags_captured=["FLAG{a}"],
            alerts=["alert1"],
        )
        js = obs.model_dump_json()
        obs2 = RangeObservation.model_validate_json(js)
        assert obs2.stdout == obs.stdout
        # stderr and alerts are populated above; assert them so a field that
        # is dropped during (de)serialization cannot pass unnoticed.
        assert obs2.stderr == "err"
        assert obs2.done is True
        assert obs2.reward == 0.5
        assert obs2.flags_captured == ["FLAG{a}"]
        assert obs2.alerts == ["alert1"]
class TestRangeState:
    """RangeState inherits episode_id/step_count from State base."""

    def test_defaults(self):
        """A state built with no arguments carries the expected defaults."""
        s = RangeState()
        assert s.episode_id is None
        assert s.step_count == 0
        assert s.mode == ""
        assert s.tier == 1
        assert s.flags_found == []

    def test_inherited_fields(self):
        """episode_id and step_count can be set through the constructor."""
        s = RangeState(episode_id="ep_42", step_count=5)
        assert s.episode_id == "ep_42"
        assert s.step_count == 5

    def test_custom_fields(self):
        """Range-specific fields keep the values passed to the constructor."""
        s = RangeState(
            episode_id="ep1",
            mode="red",
            tier=3,
            flags_found=["FLAG{a}"],
            services_status={"web": "running"},
        )
        # episode_id and mode are set above; check them alongside the rest.
        assert s.episode_id == "ep1"
        assert s.mode == "red"
        assert s.tier == 3
        assert s.flags_found == ["FLAG{a}"]
        assert s.services_status["web"] == "running"

    def test_json_roundtrip(self):
        """Every populated field survives model_dump_json/model_validate_json."""
        s = RangeState(
            episode_id="ep99",
            step_count=10,
            mode="blue",
            flags_found=["FLAG{x}"],
            tier=2,
        )
        js = s.model_dump_json()
        s2 = RangeState.model_validate_json(js)
        assert s2.episode_id == "ep99"
        assert s2.step_count == 10
        assert s2.mode == "blue"
        assert s2.tier == 2
        # flags_found is populated above; assert it so a list that is lost
        # during (de)serialization cannot pass unnoticed.
        assert s2.flags_found == ["FLAG{x}"]

    def test_model_dump_and_back(self):
        """model_dump() produces a dict that rebuilds a matching state."""
        s = RangeState(episode_id="e1", step_count=3, mode="red", tier=1)
        d = s.model_dump()
        s2 = RangeState(**d)
        assert s2.episode_id == s.episode_id
        assert s2.step_count == s.step_count
        # mode and tier are set explicitly above, so compare them as well.
        assert s2.mode == s.mode
        assert s2.tier == s.tier
|