open-range / tests /test_demo.py
Aaron Brown
Add episode CLI, synthetic data pipeline, NPC generalization, service manifest
f016eb7
"""Tests for the end-to-end replay demo."""
import json
from pathlib import Path
import pytest
from examples.demo import (
ScriptedAgent,
make_demo_snapshot,
run_demo,
)
from open_range.server.environment import RangeEnvironment
from open_range.server.models import RangeAction
# ---------------------------------------------------------------------------
# ScriptedAgent
# ---------------------------------------------------------------------------
class TestScriptedAgent:
def test_follows_script(self):
agent = ScriptedAgent(["cmd1", "cmd2", "cmd3"])
assert agent.act("obs") == "cmd1"
assert agent.act("obs") == "cmd2"
assert agent.act("obs") == "cmd3"
def test_exhausted(self):
agent = ScriptedAgent(["cmd1"])
assert not agent.exhausted
agent.act("obs")
assert agent.exhausted
def test_fallback_after_exhausted(self):
agent = ScriptedAgent(["cmd1"])
agent.act("obs")
assert agent.act("obs") == "echo done"
def test_reset(self):
agent = ScriptedAgent(["cmd1", "cmd2"])
agent.act("obs")
agent.reset(briefing="new episode", role="red")
assert agent.act("obs") == "cmd1"
assert not agent.exhausted
# ---------------------------------------------------------------------------
# Demo snapshot
# ---------------------------------------------------------------------------
class TestDemoSnapshot:
def test_snapshot_has_flags(self):
snap = make_demo_snapshot()
assert len(snap.flags) == 1
assert snap.flags[0].value == "FLAG{test_sqli_123}"
def test_snapshot_has_golden_path(self):
snap = make_demo_snapshot()
assert len(snap.golden_path) >= 3
def test_snapshot_has_vulnerability(self):
snap = make_demo_snapshot()
assert len(snap.truth_graph.vulns) == 1
assert snap.truth_graph.vulns[0].type == "sqli"
def test_snapshot_has_task(self):
snap = make_demo_snapshot()
assert snap.task.red_briefing
assert snap.task.blue_briefing
# ---------------------------------------------------------------------------
# Full demo run (mocked, no Docker)
# ---------------------------------------------------------------------------
class TestRunDemo:
def test_demo_completes(self):
result = run_demo(quiet=True)
assert result["outcome"] in ("flag_captured", "timeout", "completed")
assert result["steps"] > 0
def test_demo_captures_flag(self):
result = run_demo(quiet=True)
assert "FLAG{test_sqli_123}" in result["flags_found"]
assert result["outcome"] == "flag_captured"
def test_demo_produces_trajectories(self):
result = run_demo(quiet=True)
episode = result["episode"]
assert len(episode.red_turns) > 0
assert len(episode.blue_turns) > 0
def test_demo_trajectory_export(self, tmp_path: Path):
result = run_demo(quiet=True)
traj = result["trajectory_logger"]
out = tmp_path / "demo.jsonl"
count = traj.export_jsonl(out)
assert count >= 2 # at least red + blue
lines = out.read_text().strip().split("\n")
for line in lines:
record = json.loads(line)
assert record["episode_id"] == "demo-001"
assert record["messages"][0]["role"] == "system"
def test_demo_with_custom_env(self):
env = RangeEnvironment(docker_available=False, max_steps=50)
result = run_demo(env=env, quiet=True)
assert result["outcome"] == "flag_captured"
def test_demo_rewards_are_numeric(self):
result = run_demo(quiet=True)
episode = result["episode"]
for turn in episode.turns:
assert isinstance(turn.reward, (int, float))
def test_demo_red_and_blue_alternate(self):
"""Verify Red and Blue take alternating turns."""
result = run_demo(quiet=True)
episode = result["episode"]
# The first turn should be Red
assert episode.turns[0].role == "red"
# After the first Red turn, turns should alternate (with some
# flexibility since Red might finish before Blue exhausts)
for i in range(1, min(len(episode.turns), 8)):
# Alternation pattern: even index = red, odd = blue
# (within the interleaved portion)
expected = "blue" if i % 2 == 1 else "red"
assert episode.turns[i].role == expected, (
f"Turn {i}: expected {expected}, got {episode.turns[i].role}"
)