"""Tests for TrajectoryLogger -- JSONL export for SFT warm-start."""
import json
from pathlib import Path
import pytest
from open_range.training.trajectory import (
Episode,
TrajectoryLogger,
Turn,
RED_SYSTEM_PROMPT,
BLUE_SYSTEM_PROMPT,
)
# ---------------------------------------------------------------------------
# Turn dataclass
# ---------------------------------------------------------------------------
class TestTurn:
    """Checks on the ``Turn`` record: field storage and timestamping."""

    def test_basic_construction(self):
        """Constructor arguments come back unchanged as attributes."""
        turn = Turn(
            role="red",
            observation="Range ready.",
            action="nmap -sV web",
            reward=0.1,
        )
        stored = (turn.role, turn.observation, turn.action, turn.reward)
        assert stored == ("red", "Range ready.", "nmap -sV web", 0.1)
        # No timestamp was passed, yet one must be present and positive.
        assert turn.timestamp > 0

    def test_timestamp_auto_set(self):
        """A turn built without an explicit timestamp still gets a positive one."""
        blue_turn = Turn(
            role="blue",
            observation="ok",
            action="grep alert",
            reward=0.0,
        )
        assert blue_turn.timestamp > 0
# ---------------------------------------------------------------------------
# Episode dataclass
# ---------------------------------------------------------------------------
class TestEpisode:
    """Checks on ``Episode``: per-role views, reward totals and chat/JSONL export."""

    def _make_episode(self) -> Episode:
        """Build a five-turn episode (three red turns, two blue) ending in a flag capture."""
        episode = Episode(episode_id="ep-1", snapshot_id="snap-1", tier=1)
        episode.briefings = {"red": "Red briefing", "blue": "Blue briefing"}

        red_recon = Turn(
            role="red",
            observation="[0.2s] 80/tcp open http",
            action="nmap -sV web",
            reward=0.1,
            assistant_content="\nRecon first.\n",
        )
        blue_finding = Turn(
            role="blue",
            observation="[0.3s] suspicious nmap",
            action="submit_finding nmap scan",
            reward=0.2,
            assistant_content="\nThis is actionable.\n",
        )
        red_probe = Turn(
            role="red",
            observation="[0.4s] products",
            action="curl http://web/search?q=test",
            reward=0.15,
            assistant_content="\nInspect the search route.\n",
        )
        blue_hunt = Turn(
            role="blue",
            observation="[0.5s] SQLi in web log",
            action="grep UNION /var/log/siem/web.log",
            reward=0.05,
            assistant_content="\nI need evidence.\n",
        )
        # The only turn that names an explicit tool instead of relying on the default.
        red_capture = Turn(
            role="red",
            observation="[0.6s] correct",
            action="submit_flag FLAG{sqli_123}",
            reward=0.5,
            assistant_content="\nThis token is worth validating.\n",
            tool_name="flag_found",
            tool_arguments={"flag": "FLAG{sqli_123}"},
        )

        episode.turns = [red_recon, blue_finding, red_probe, blue_hunt, red_capture]
        episode.outcome = "flag_captured"
        return episode

    def test_red_turns(self):
        """``red_turns`` yields exactly the three red turns."""
        red_only = self._make_episode().red_turns
        assert len(red_only) == 3
        assert all(turn.role == "red" for turn in red_only)

    def test_blue_turns(self):
        """``blue_turns`` yields exactly the two blue turns."""
        blue_only = self._make_episode().blue_turns
        assert len(blue_only) == 2
        assert all(turn.role == "blue" for turn in blue_only)

    def test_total_red_reward(self):
        """Red's total equals the sum of the red turn rewards."""
        episode = self._make_episode()
        expected = 0.1 + 0.15 + 0.5
        drift = abs(episode.total_red_reward - expected)
        assert drift < 1e-6

    def test_total_blue_reward(self):
        """Blue's total equals the sum of the blue turn rewards."""
        episode = self._make_episode()
        expected = 0.2 + 0.05
        drift = abs(episode.total_blue_reward - expected)
        assert drift < 1e-6

    def test_to_chat_messages_red(self):
        """Red chat export: system prompt, briefing, then assistant/tool pairs."""
        chat = self._make_episode().to_chat_messages("red")
        # system + briefing + 3 * (assistant + tool) = 8 messages
        assert len(chat) == 8

        system_msg, briefing_msg, first_call, first_result = chat[:4]
        assert system_msg["role"] == "system"
        assert system_msg["content"] == RED_SYSTEM_PROMPT
        assert briefing_msg["role"] == "user"
        assert briefing_msg["content"] == "Red briefing"
        assert first_call["role"] == "assistant"
        assert "tool_calls" in first_call
        assert first_call["tool_calls"][0]["function"]["name"] == "shell_command"
        assert first_result["role"] == "tool"

    def test_to_chat_messages_blue(self):
        """Blue chat export opens with the blue system prompt and briefing."""
        chat = self._make_episode().to_chat_messages("blue")
        # system + briefing + 2 * (assistant + tool) = 6 messages
        assert len(chat) == 6

        system_msg, briefing_msg = chat[0], chat[1]
        assert system_msg["role"] == "system"
        assert system_msg["content"] == BLUE_SYSTEM_PROMPT
        assert briefing_msg["role"] == "user"
        assert briefing_msg["content"] == "Blue briefing"

    def test_to_jsonl_record(self):
        """The red JSONL record carries the episode metadata and is JSON-serializable."""
        record = self._make_episode().to_jsonl_record("red")

        expected_fields = {
            "episode_id": "ep-1",
            "snapshot_id": "snap-1",
            "tier": 1,
            "role": "red",
            "outcome": "flag_captured",
        }
        for field_name, wanted in expected_fields.items():
            assert record[field_name] == wanted

        assert isinstance(record["messages"], list)
        assert isinstance(record["reward"], float)
        # Serialization must not raise.
        json.dumps(record)
# ---------------------------------------------------------------------------
# TrajectoryLogger
# ---------------------------------------------------------------------------
class TestTrajectoryLogger:
    """Checks on the ``TrajectoryLogger`` episode lifecycle: start, log, end, clear."""

    def test_start_episode(self):
        """``start_episode`` returns the new episode and makes it the current one."""
        recorder = TrajectoryLogger()
        started = recorder.start_episode(
            "ep-1",
            snapshot_id="snap-1",
            tier=2,
            briefings={"red": "brief"},
        )
        assert started.episode_id == "ep-1"
        assert started.snapshot_id == "snap-1"
        assert started.tier == 2
        assert started.briefings["red"] == "brief"
        assert recorder.current_episode is started

    def test_log_turn(self):
        """``log_turn`` returns the turn and appends it to the active episode."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        logged = recorder.log_turn(
            role="red",
            observation="Ready.",
            action="nmap web",
            reward=0.1,
        )
        assert logged.role == "red"
        active_turns = recorder.current_episode.turns
        assert len(active_turns) == 1

    def test_log_turn_without_episode_raises(self):
        """Logging a turn with no episode started raises ``RuntimeError``."""
        recorder = TrajectoryLogger()
        with pytest.raises(RuntimeError, match="No active episode"):
            recorder.log_turn(role="red", observation="x", action="y")

    def test_end_episode(self):
        """``end_episode`` records outcome and metrics, then archives the episode."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        recorder.log_turn(
            role="red",
            observation="Ready.",
            action="nmap web",
            reward=0.1,
        )
        finished = recorder.end_episode(outcome="timeout", metrics={"steps": 1})
        assert finished.outcome == "timeout"
        assert finished.metrics == {"steps": 1}
        assert recorder.current_episode is None
        assert len(recorder.episodes) == 1

    def test_end_episode_without_active_raises(self):
        """Ending with no episode started raises ``RuntimeError``."""
        recorder = TrajectoryLogger()
        with pytest.raises(RuntimeError, match="No active episode"):
            recorder.end_episode()

    def test_start_new_episode_abandons_current(self):
        """Starting a second episode archives the unfinished first as "abandoned"."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        recorder.log_turn(role="red", observation="x", action="y")
        recorder.start_episode("ep-2")

        archived = recorder.episodes
        assert len(archived) == 1
        assert archived[0].outcome == "abandoned"
        assert recorder.current_episode.episode_id == "ep-2"

    def test_clear(self):
        """``clear`` leaves no archived episodes and no current episode."""
        recorder = TrajectoryLogger()
        recorder.start_episode("ep-1")
        recorder.log_turn(role="red", observation="x", action="y")
        recorder.end_episode()
        recorder.clear()
        assert len(recorder.episodes) == 0
        assert recorder.current_episode is None
# ---------------------------------------------------------------------------
# JSONL export
# ---------------------------------------------------------------------------
class TestExportJsonl:
    """Checks on ``TrajectoryLogger.export_jsonl``: filtering, format and file handling."""

    @staticmethod
    def _read_lines(path: Path) -> list:
        """Return the exported file's lines, ignoring leading/trailing whitespace."""
        return path.read_text().strip().split("\n")

    def _build_logger_with_episodes(self) -> TrajectoryLogger:
        """Return a logger holding two finished tier-1 episodes.

        ``ep-1`` logs red rewards summing to 0.9 and a blue reward of 0.2;
        ``ep-2`` logs 0.01 for each role.
        """
        # (episode_id, snapshot_id, outcome, [(role, observation, action, reward), ...])
        script = (
            # Episode 1: Red succeeds (high reward)
            (
                "ep-1",
                "snap-1",
                "flag_captured",
                (
                    ("red", "Range ready.", "nmap -sV web", 0.1),
                    ("blue", "Alert: nmap", "submit_finding nmap", 0.2),
                    ("red", "80/tcp open", "curl http://web/search?q=1' OR 1=1--", 0.3),
                    ("red", "FLAG{sqli}", "submit_flag FLAG{sqli}", 0.5),
                ),
            ),
            # Episode 2: Both low reward (below typical threshold)
            (
                "ep-2",
                "snap-2",
                "timeout",
                (
                    ("red", "Range ready.", "nmap -sV web", 0.01),
                    ("blue", "No alerts", "tail /var/log/siem/web.log", 0.01),
                ),
            ),
        )

        recorder = TrajectoryLogger()
        for episode_id, snapshot_id, outcome, steps in script:
            recorder.start_episode(episode_id, snapshot_id=snapshot_id, tier=1)
            for role, observation, action, reward in steps:
                recorder.log_turn(
                    role=role,
                    observation=observation,
                    action=action,
                    reward=reward,
                )
            recorder.end_episode(outcome=outcome)
        return recorder

    def test_export_creates_file(self, tmp_path: Path):
        """A default export writes the target file and reports a positive count."""
        target = tmp_path / "trajectories.jsonl"
        written = self._build_logger_with_episodes().export_jsonl(target)
        assert target.exists()
        assert written > 0

    def test_export_all_no_filter(self, tmp_path: Path):
        """With a zero threshold every (episode, role) pair is exported."""
        target = tmp_path / "all.jsonl"
        recorder = self._build_logger_with_episodes()
        written = recorder.export_jsonl(target, reward_threshold=0.0)
        # 2 episodes * 2 roles = 4 lines
        assert written == 4
        assert len(self._read_lines(target)) == 4

    def test_export_with_reward_filter(self, tmp_path: Path):
        """A 0.1 threshold keeps both ep-1 records and drops both ep-2 records."""
        target = tmp_path / "filtered.jsonl"
        recorder = self._build_logger_with_episodes()
        written = recorder.export_jsonl(target, reward_threshold=0.1)
        # ep-1 red reward = 0.9, ep-1 blue reward = 0.2 -> both pass
        # ep-2 red reward = 0.01, ep-2 blue reward = 0.01 -> both filtered
        assert written == 2

    def test_export_single_role(self, tmp_path: Path):
        """Restricting ``roles`` to red yields one record per episode."""
        target = tmp_path / "red_only.jsonl"
        recorder = self._build_logger_with_episodes()
        written = recorder.export_jsonl(target, roles=("red",))
        # 2 episodes, red only = 2 lines
        assert written == 2

    def test_export_jsonl_format(self, tmp_path: Path):
        """Each exported line is a JSON object with metadata and chat-format messages."""
        target = tmp_path / "format.jsonl"
        self._build_logger_with_episodes().export_jsonl(target)

        required_keys = (
            "episode_id",
            "snapshot_id",
            "tier",
            "role",
            "messages",
            "reward",
            "outcome",
        )
        for raw in self._read_lines(target):
            record = json.loads(raw)
            for key in required_keys:
                assert key in record

            # Messages must follow chat format
            chat = record["messages"]
            assert chat[0]["role"] == "system"
            assert chat[1]["role"] == "user"
            for message in chat:
                assert "role" in message
                assert "content" in message
                # Every assistant message must carry a non-empty tool call list.
                if message["role"] == "assistant":
                    assert message["tool_calls"]

    def test_export_creates_parent_dirs(self, tmp_path: Path):
        """Exporting to a path under missing directories still produces the file."""
        target = tmp_path / "nested" / "deep" / "trajectories.jsonl"
        written = self._build_logger_with_episodes().export_jsonl(target)
        assert target.exists()
        assert written > 0

    def test_export_empty_logger(self, tmp_path: Path):
        """A logger with no episodes writes an empty file and reports zero records."""
        target = tmp_path / "empty.jsonl"
        written = TrajectoryLogger().export_jsonl(target)
        assert written == 0
        assert target.exists()
        assert target.read_text() == ""

    def test_red_and_blue_are_independent_examples(self, tmp_path: Path):
        """Red and Blue trajectories produce separate JSONL lines."""
        target = tmp_path / "independent.jsonl"
        self._build_logger_with_episodes().export_jsonl(target)
        records = [json.loads(raw) for raw in self._read_lines(target)]

        # Find ep-1 records
        first_episode = [rec for rec in records if rec["episode_id"] == "ep-1"]
        assert {rec["role"] for rec in first_episode} == {"red", "blue"}

        # Red and Blue have different system prompts
        for rec in first_episode:
            expected_phrase = (
                "penetration tester" if rec["role"] == "red" else "SOC analyst"
            )
            assert expected_phrase in rec["messages"][0]["content"]