File size: 12,509 Bytes
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f016eb7
 
 
 
8c486a8
f016eb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f016eb7
 
8c486a8
 
 
f016eb7
8c486a8
f016eb7
 
 
8c486a8
 
 
 
f016eb7
 
8c486a8
 
f016eb7
 
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f016eb7
 
 
 
 
 
8c486a8
 
 
f016eb7
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f016eb7
8c486a8
 
 
f016eb7
 
 
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
"""Tests for TrajectoryLogger -- JSONL export for SFT warm-start."""

import json
from pathlib import Path

import pytest

from open_range.training.trajectory import (
    Episode,
    TrajectoryLogger,
    Turn,
    RED_SYSTEM_PROMPT,
    BLUE_SYSTEM_PROMPT,
)


# ---------------------------------------------------------------------------
# Turn dataclass
# ---------------------------------------------------------------------------


class TestTurn:
    """Construction checks for the ``Turn`` record."""

    def test_basic_construction(self):
        """Constructor keyword values are readable back from the instance."""
        fields = {
            "role": "red",
            "observation": "Range ready.",
            "action": "nmap -sV web",
            "reward": 0.1,
        }
        turn = Turn(**fields)
        for attr_name, expected_value in fields.items():
            assert getattr(turn, attr_name) == expected_value
        # No timestamp was supplied, yet the test expects a positive one.
        assert turn.timestamp > 0

    def test_timestamp_auto_set(self):
        """A turn built without an explicit timestamp still reports one > 0."""
        blue_turn = Turn(
            role="blue",
            observation="ok",
            action="grep alert",
            reward=0.0,
        )
        assert blue_turn.timestamp > 0


# ---------------------------------------------------------------------------
# Episode dataclass
# ---------------------------------------------------------------------------


class TestEpisode:
    """Checks for ``Episode`` role filtering, reward totals and chat/JSONL export."""

    def _make_episode(self) -> Episode:
        """Build a five-turn episode: three red turns, two blue, outcome ``flag_captured``."""
        episode = Episode(episode_id="ep-1", snapshot_id="snap-1", tier=1)
        episode.briefings = {"red": "Red briefing", "blue": "Blue briefing"}

        red_recon = Turn(
            role="red",
            observation="[0.2s] 80/tcp open http",
            action="nmap -sV web",
            reward=0.1,
            assistant_content="<think>\nRecon first.\n</think>",
        )
        blue_finding = Turn(
            role="blue",
            observation="[0.3s] suspicious nmap",
            action="submit_finding nmap scan",
            reward=0.2,
            assistant_content="<think>\nThis is actionable.\n</think>",
        )
        red_probe = Turn(
            role="red",
            observation="[0.4s] products",
            action="curl http://web/search?q=test",
            reward=0.15,
            assistant_content="<think>\nInspect the search route.\n</think>",
        )
        blue_grep = Turn(
            role="blue",
            observation="[0.5s] SQLi in web log",
            action="grep UNION /var/log/siem/web.log",
            reward=0.05,
            assistant_content="<think>\nI need evidence.\n</think>",
        )
        # The only turn that carries an explicit tool name and arguments.
        red_flag = Turn(
            role="red",
            observation="[0.6s] correct",
            action="submit_flag FLAG{sqli_123}",
            reward=0.5,
            assistant_content="<think>\nThis token is worth validating.\n</think>",
            tool_name="flag_found",
            tool_arguments={"flag": "FLAG{sqli_123}"},
        )

        episode.turns = [red_recon, blue_finding, red_probe, blue_grep, red_flag]
        episode.outcome = "flag_captured"
        return episode

    def test_red_turns(self):
        """``red_turns`` yields exactly the three red entries."""
        episode = self._make_episode()
        assert len(episode.red_turns) == 3
        assert all(turn.role == "red" for turn in episode.red_turns)

    def test_blue_turns(self):
        """``blue_turns`` yields exactly the two blue entries."""
        episode = self._make_episode()
        assert len(episode.blue_turns) == 2
        assert all(turn.role == "blue" for turn in episode.blue_turns)

    def test_total_red_reward(self):
        """Red total equals the sum of the three red turn rewards."""
        episode = self._make_episode()
        red_sum = 0.1 + 0.15 + 0.5
        assert abs(episode.total_red_reward - red_sum) < 1e-6

    def test_total_blue_reward(self):
        """Blue total equals the sum of the two blue turn rewards."""
        episode = self._make_episode()
        blue_sum = 0.2 + 0.05
        assert abs(episode.total_blue_reward - blue_sum) < 1e-6

    def test_to_chat_messages_red(self):
        """Red chat export: system prompt, briefing, then assistant/tool pairs."""
        messages = self._make_episode().to_chat_messages("red")
        # system + briefing + 3 * (assistant + tool) = 8 messages
        assert len(messages) == 8

        system_msg, briefing_msg, first_call, first_result = messages[:4]
        assert system_msg["role"] == "system"
        assert system_msg["content"] == RED_SYSTEM_PROMPT
        assert briefing_msg["role"] == "user"
        assert briefing_msg["content"] == "Red briefing"
        assert first_call["role"] == "assistant"
        assert "tool_calls" in first_call
        assert first_call["tool_calls"][0]["function"]["name"] == "shell_command"
        assert first_result["role"] == "tool"

    def test_to_chat_messages_blue(self):
        """Blue chat export: blue system prompt and blue briefing lead the list."""
        messages = self._make_episode().to_chat_messages("blue")
        # system + briefing + 2 * (assistant + tool) = 6 messages
        assert len(messages) == 6

        system_msg, briefing_msg = messages[:2]
        assert system_msg["role"] == "system"
        assert system_msg["content"] == BLUE_SYSTEM_PROMPT
        assert briefing_msg["role"] == "user"
        assert briefing_msg["content"] == "Blue briefing"

    def test_to_jsonl_record(self):
        """The red JSONL record carries episode metadata and serializes to JSON."""
        record = self._make_episode().to_jsonl_record("red")
        expected_fields = {
            "episode_id": "ep-1",
            "snapshot_id": "snap-1",
            "tier": 1,
            "role": "red",
            "outcome": "flag_captured",
        }
        for key, expected_value in expected_fields.items():
            assert record[key] == expected_value
        assert isinstance(record["messages"], list)
        assert isinstance(record["reward"], float)
        # Verify it's JSON-serializable (json.dumps raises otherwise).
        json.dumps(record)


# ---------------------------------------------------------------------------
# TrajectoryLogger
# ---------------------------------------------------------------------------


class TestTrajectoryLogger:
    """Episode lifecycle checks for ``TrajectoryLogger``: start, log, end, clear."""

    def test_start_episode(self):
        """``start_episode`` returns the new episode and makes it current."""
        recorder = TrajectoryLogger()
        episode = recorder.start_episode(
            "ep-1",
            snapshot_id="snap-1",
            tier=2,
            briefings={"red": "brief"},
        )
        assert episode.episode_id == "ep-1"
        assert episode.snapshot_id == "snap-1"
        assert episode.tier == 2
        assert episode.briefings["red"] == "brief"
        assert recorder.current_episode is episode

    def test_log_turn(self):
        """A logged turn is returned and appended to the current episode."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        logged = recorder.log_turn(
            role="red",
            observation="Ready.",
            action="nmap web",
            reward=0.1,
        )
        assert logged.role == "red"
        assert len(recorder.current_episode.turns) == 1

    def test_log_turn_without_episode_raises(self):
        """Logging with no episode started is expected to raise ``RuntimeError``."""
        recorder = TrajectoryLogger()
        with pytest.raises(RuntimeError, match="No active episode"):
            recorder.log_turn(role="red", observation="x", action="y")

    def test_end_episode(self):
        """Ending stores outcome and metrics, clears the current slot, archives the episode."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        recorder.log_turn(
            role="red",
            observation="Ready.",
            action="nmap web",
            reward=0.1,
        )
        finished = recorder.end_episode(outcome="timeout", metrics={"steps": 1})
        assert finished.outcome == "timeout"
        assert finished.metrics == {"steps": 1}
        assert recorder.current_episode is None
        assert len(recorder.episodes) == 1

    def test_end_episode_without_active_raises(self):
        """Ending with no episode started is expected to raise ``RuntimeError``."""
        recorder = TrajectoryLogger()
        with pytest.raises(RuntimeError, match="No active episode"):
            recorder.end_episode()

    def test_start_new_episode_abandons_current(self):
        """Starting a second episode archives the first with outcome ``abandoned``."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        recorder.log_turn(role="red", observation="x", action="y")
        recorder.start_episode("ep-2")

        archived = recorder.episodes
        assert len(archived) == 1
        assert archived[0].outcome == "abandoned"
        assert recorder.current_episode.episode_id == "ep-2"

    def test_clear(self):
        """``clear`` leaves no archived episodes and no current episode."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        recorder.log_turn(role="red", observation="x", action="y")
        recorder.end_episode()

        recorder.clear()
        assert len(recorder.episodes) == 0
        assert recorder.current_episode is None


# ---------------------------------------------------------------------------
# JSONL export
# ---------------------------------------------------------------------------


class TestExportJsonl:
    """Checks for ``TrajectoryLogger.export_jsonl``: counts, filtering and record shape."""

    def _build_logger_with_episodes(self) -> TrajectoryLogger:
        """Return a logger holding two finished episodes, one high-reward and one low-reward."""
        # Each turn is (role, observation, action, reward).
        # Episode 1: Red succeeds (high reward)
        winning_turns = (
            ("red", "Range ready.", "nmap -sV web", 0.1),
            ("blue", "Alert: nmap", "submit_finding nmap", 0.2),
            ("red", "80/tcp open", "curl http://web/search?q=1' OR 1=1--", 0.3),
            ("red", "FLAG{sqli}", "submit_flag FLAG{sqli}", 0.5),
        )
        # Episode 2: Both low reward (below typical threshold)
        idle_turns = (
            ("red", "Range ready.", "nmap -sV web", 0.01),
            ("blue", "No alerts", "tail /var/log/siem/web.log", 0.01),
        )
        scripted_episodes = (
            ("ep-1", "snap-1", "flag_captured", winning_turns),
            ("ep-2", "snap-2", "timeout", idle_turns),
        )

        recorder = TrajectoryLogger()
        for episode_id, snapshot_id, outcome, turns in scripted_episodes:
            recorder.start_episode(episode_id, snapshot_id=snapshot_id, tier=1)
            for role, observation, action, reward in turns:
                recorder.log_turn(
                    role=role,
                    observation=observation,
                    action=action,
                    reward=reward,
                )
            recorder.end_episode(outcome=outcome)
        return recorder

    def test_export_creates_file(self, tmp_path: Path):
        """Exporting writes the target file and reports a positive record count."""
        recorder = self._build_logger_with_episodes()
        target = tmp_path / "trajectories.jsonl"
        written = recorder.export_jsonl(target)
        assert target.exists()
        assert written > 0

    def test_export_all_no_filter(self, tmp_path: Path):
        """With a zero threshold every episode/role pair is exported."""
        recorder = self._build_logger_with_episodes()
        target = tmp_path / "all.jsonl"
        written = recorder.export_jsonl(target, reward_threshold=0.0)
        # 2 episodes * 2 roles = 4 lines
        assert written == 4
        exported_lines = target.read_text().strip().split("\n")
        assert len(exported_lines) == 4

    def test_export_with_reward_filter(self, tmp_path: Path):
        """A 0.1 threshold keeps only the two ep-1 records."""
        recorder = self._build_logger_with_episodes()
        target = tmp_path / "filtered.jsonl"
        written = recorder.export_jsonl(target, reward_threshold=0.1)
        # ep-1 red reward = 0.9, ep-1 blue reward = 0.2 -> both pass
        # ep-2 red reward = 0.01, ep-2 blue reward = 0.01 -> both filtered
        assert written == 2

    def test_export_single_role(self, tmp_path: Path):
        """Restricting ``roles`` to red yields one record per episode."""
        recorder = self._build_logger_with_episodes()
        target = tmp_path / "red_only.jsonl"
        written = recorder.export_jsonl(target, roles=("red",))
        # 2 episodes, red only = 2 lines
        assert written == 2

    def test_export_jsonl_format(self, tmp_path: Path):
        """Every exported line is a JSON object with metadata and chat-format messages."""
        required_keys = (
            "episode_id",
            "snapshot_id",
            "tier",
            "role",
            "messages",
            "reward",
            "outcome",
        )
        recorder = self._build_logger_with_episodes()
        target = tmp_path / "format.jsonl"
        recorder.export_jsonl(target)

        for raw_line in target.read_text().strip().split("\n"):
            record = json.loads(raw_line)
            for key in required_keys:
                assert key in record

            # Messages must follow chat format
            messages = record["messages"]
            assert messages[0]["role"] == "system"
            assert messages[1]["role"] == "user"
            for message in messages:
                assert "role" in message
                assert "content" in message

            # Every assistant message must carry a non-empty tool_calls entry.
            assistant_messages = [m for m in messages if m["role"] == "assistant"]
            for message in assistant_messages:
                assert message["tool_calls"]

    def test_export_creates_parent_dirs(self, tmp_path: Path):
        """Exporting to a path under missing directories still produces the file."""
        recorder = self._build_logger_with_episodes()
        target = tmp_path / "nested" / "deep" / "trajectories.jsonl"
        written = recorder.export_jsonl(target)
        assert target.exists()
        assert written > 0

    def test_export_empty_logger(self, tmp_path: Path):
        """A logger with no episodes exports zero records into an empty file."""
        recorder = TrajectoryLogger()
        target = tmp_path / "empty.jsonl"
        written = recorder.export_jsonl(target)
        assert written == 0
        assert target.exists()
        assert target.read_text() == ""

    def test_red_and_blue_are_independent_examples(self, tmp_path: Path):
        """Red and Blue trajectories produce separate JSONL lines."""
        recorder = self._build_logger_with_episodes()
        target = tmp_path / "independent.jsonl"
        recorder.export_jsonl(target)
        records = [
            json.loads(raw_line)
            for raw_line in target.read_text().strip().split("\n")
        ]

        # Find ep-1 records
        first_episode = [rec for rec in records if rec["episode_id"] == "ep-1"]
        assert {rec["role"] for rec in first_episode} == {"red", "blue"}

        # Red and Blue have different system prompts
        red_marker = {"red": "penetration tester"}
        for rec in first_episode:
            system_content = rec["messages"][0]["content"]
            # Any role other than red is checked against the SOC analyst marker.
            marker = red_marker.get(rec["role"], "SOC analyst")
            assert marker in system_content