File size: 3,942 Bytes
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Tests for OpenEnv-compatible model serialization."""

import json

from open_range.server.models import RangeAction, RangeObservation, RangeState


class TestRangeAction:
    """Construction and round-trip serialization checks for ``RangeAction``."""

    def test_create(self):
        """Keyword arguments passed to the constructor are readable as attributes."""
        action = RangeAction(command="nmap -sV web", mode="red")
        assert (action.command, action.mode) == ("nmap -sV web", "red")

    def test_serialize_deserialize(self):
        """``model_dump()`` yields a dict that rebuilds an equivalent action."""
        original = RangeAction(command="curl http://web/", mode="red")
        payload = original.model_dump()

        # The dumped dict carries the same values under the field names.
        assert payload["command"] == "curl http://web/"
        assert payload["mode"] == "red"

        # Feeding the dict back through the constructor restores both fields.
        rebuilt = RangeAction(**payload)
        for field in ("command", "mode"):
            assert getattr(rebuilt, field) == getattr(original, field)

    def test_mode_literal(self):
        """``'blue'`` is accepted as a mode value (happy path only)."""
        action = RangeAction(command="x", mode="blue")
        assert action.mode == "blue"

    def test_json_roundtrip(self):
        """A JSON string from ``model_dump_json()`` validates back to the same values."""
        original = RangeAction(command="hydra -l admin web ssh", mode="red")
        restored = RangeAction.model_validate_json(original.model_dump_json())
        assert restored.command == original.command
        assert restored.mode == original.mode


class TestRangeObservation:
    """RangeObservation inherits done/reward from Observation base."""

    def test_defaults(self):
        """An observation built with no arguments exposes the expected defaults."""
        obs = RangeObservation()
        assert obs.done is False
        assert obs.reward is None
        assert obs.stdout == ""
        assert obs.stderr == ""
        assert obs.flags_captured == []
        assert obs.alerts == []

    def test_done_inherited(self):
        """``done`` and ``reward`` can be set through the constructor."""
        obs = RangeObservation(done=True, reward=1.0)
        assert obs.done is True
        assert obs.reward == 1.0

    def test_custom_fields(self):
        """Every range-specific field passed to the constructor is stored."""
        obs = RangeObservation(
            stdout="flag found",
            flags_captured=["FLAG{test}"],
            alerts=["IDS alert"],
        )
        # Check each field that was set, including stdout.
        assert obs.stdout == "flag found"
        assert obs.flags_captured == ["FLAG{test}"]
        assert obs.alerts == ["IDS alert"]

    def test_json_roundtrip(self):
        """All populated fields survive ``model_dump_json`` / ``model_validate_json``."""
        obs = RangeObservation(
            stdout="output",
            stderr="err",
            done=True,
            reward=0.5,
            flags_captured=["FLAG{a}"],
            alerts=["alert1"],
        )
        js = obs.model_dump_json()
        obs2 = RangeObservation.model_validate_json(js)
        # Assert every field that was populated, so that a field lost during
        # serialization (e.g. stderr or alerts) makes the test fail.
        assert obs2.stdout == obs.stdout
        assert obs2.stderr == "err"
        assert obs2.done is True
        assert obs2.reward == 0.5
        assert obs2.flags_captured == ["FLAG{a}"]
        assert obs2.alerts == ["alert1"]


class TestRangeState:
    """RangeState inherits episode_id/step_count from State base."""

    def test_defaults(self):
        """A state built with no arguments exposes the expected defaults."""
        s = RangeState()
        assert s.episode_id is None
        assert s.step_count == 0
        assert s.mode == ""
        assert s.tier == 1
        assert s.flags_found == []

    def test_inherited_fields(self):
        """``episode_id`` and ``step_count`` can be set through the constructor."""
        s = RangeState(episode_id="ep_42", step_count=5)
        assert s.episode_id == "ep_42"
        assert s.step_count == 5

    def test_custom_fields(self):
        """Every field passed to the constructor is stored."""
        s = RangeState(
            episode_id="ep1",
            mode="red",
            tier=3,
            flags_found=["FLAG{a}"],
            services_status={"web": "running"},
        )
        # Check each field that was set, including episode_id and mode.
        assert s.episode_id == "ep1"
        assert s.mode == "red"
        assert s.tier == 3
        assert s.flags_found == ["FLAG{a}"]
        assert s.services_status["web"] == "running"

    def test_json_roundtrip(self):
        """All populated fields survive ``model_dump_json`` / ``model_validate_json``."""
        s = RangeState(
            episode_id="ep99",
            step_count=10,
            mode="blue",
            flags_found=["FLAG{x}"],
            tier=2,
        )
        js = s.model_dump_json()
        s2 = RangeState.model_validate_json(js)
        assert s2.episode_id == "ep99"
        assert s2.step_count == 10
        assert s2.mode == "blue"
        assert s2.tier == 2
        # flags_found is a list field; make sure it is not dropped in the
        # round-trip.
        assert s2.flags_found == ["FLAG{x}"]

    def test_model_dump_and_back(self):
        """``model_dump()`` yields a dict that rebuilds an equivalent state."""
        s = RangeState(episode_id="e1", step_count=3, mode="red", tier=1)
        d = s.model_dump()
        s2 = RangeState(**d)
        # Compare every field that was explicitly set on the original.
        assert s2.episode_id == s.episode_id
        assert s2.step_count == s.step_count
        assert s2.mode == s.mode
        assert s2.tier == s.tier