# open-range/tests/test_runner.py
# Author: Lars Talian
# Commit 7106e5f — fix(runner): ground manifest axis in reset inputs
"""Tests for the curriculum runner (issue #23)."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import pytest
from open_range.training.runner import (
CurriculumRunner,
EpisodeRecord,
RunConfig,
_parse_seed_range,
)
# ---------------------------------------------------------------------------
# Mock environment and agents
# ---------------------------------------------------------------------------
@dataclass
class _MockObs:
    """Observation returned by ``MockEnvironment.reset``/``step``."""

    stdout: str = "Range ready."  # simulated command output shown to agents
    stderr: str = ""  # simulated error stream (always empty in these mocks)
    done: bool = False  # True once the mock episode has terminated
    reward: float = 0.0  # per-step scalar reward
@dataclass
class _MockState:
    """Environment state snapshot exposed via ``MockEnvironment.state``."""

    flags_found: list[str] = field(default_factory=list)  # flags captured so far
    episode_id: str = "mock-ep"  # fixed identifier for the mock episode
    tier: int = 1  # curriculum tier of the mock scenario
    step_count: int = 0  # step counter field (not updated by the mock itself)
class MockEnvironment:
    """Deterministic stand-in environment that terminates after a fixed step budget.

    Every ``reset`` invocation (seed plus keyword arguments) is recorded in
    ``reset_calls`` so tests can inspect exactly what the runner passed in.
    """

    def __init__(self, max_env_steps: int = 4, flags: list[str] | None = None):
        self._max_env_steps = max_env_steps
        self._flags = flags or []
        self._step_count = 0
        self._state = _MockState()
        self.reset_calls: list[dict[str, Any]] = []

    def reset(self, seed: int | None = None, **kwargs: Any) -> _MockObs:
        # Begin a fresh episode and remember how we were invoked.
        self._step_count = 0
        self._state = _MockState()
        call = {"seed": seed}
        call.update(kwargs)
        self.reset_calls.append(call)
        return _MockObs(stdout=f"Range ready. Seed={seed}")

    def step(self, action: Any) -> _MockObs:
        self._step_count += 1
        finished = self._step_count >= self._max_env_steps
        if finished and self._flags:
            # Award all configured flags on the terminal step.
            self._state.flags_found = list(self._flags)
        observation = _MockObs(
            stdout=f"Step {self._step_count} result",
            done=finished,
            reward=0.1,
        )
        return observation

    @property
    def state(self) -> _MockState:
        return self._state
class MockAgent:
"""Minimal agent that returns fixed commands."""
def __init__(self, commands: list[str] | None = None):
self._commands = commands or ["nmap -sV web"]
self._idx = 0
def reset(self, briefing: str, role: str) -> None:
self._idx = 0
def act(self, observation: str) -> str:
cmd = self._commands[self._idx % len(self._commands)]
self._idx += 1
return cmd
# ---------------------------------------------------------------------------
# Tests: seed range parsing
# ---------------------------------------------------------------------------
class TestSeedParsing:
    """Unit tests for the seed-spec parser."""

    def test_range(self):
        expected = list(range(1, 6))
        assert _parse_seed_range("1-5") == expected

    def test_single(self):
        assert _parse_seed_range("3") == [3]

    def test_comma_separated(self):
        assert _parse_seed_range("1,3,5") == [1, 3, 5]

    def test_mixed(self):
        expected = [1, 2, 3, 7, 10, 11, 12]
        assert _parse_seed_range("1-3,7,10-12") == expected
# ---------------------------------------------------------------------------
# Tests: RunConfig
# ---------------------------------------------------------------------------
class TestRunConfig:
    """Tests for building a RunConfig from CLI-style arguments."""

    def test_from_cli(self):
        cfg = RunConfig.from_cli(
            manifests=["tier1.yaml"],
            seeds_str="1-3",
            episodes=2,
            max_steps=50,
        )
        assert cfg.manifests == ["tier1.yaml"]
        assert cfg.seeds == [1, 2, 3]
        assert cfg.episodes_per_seed == 2
        assert cfg.max_steps == 50
# ---------------------------------------------------------------------------
# Tests: CurriculumRunner
# ---------------------------------------------------------------------------
class TestCurriculumRunner:
    """Behavioral tests for CurriculumRunner across manifests, seeds, episodes."""

    def test_run_returns_results(self):
        environment = MockEnvironment(max_env_steps=4)
        cfg = RunConfig(
            manifests=["tier1.yaml"],
            seeds=[1, 2],
            episodes_per_seed=1,
            max_steps=100,
        )
        runner = CurriculumRunner(
            environment,
            MockAgent(),
            MockAgent(commands=["tail /var/log/syslog"]),
            cfg,
        )
        records = runner.run()
        # 1 manifest x 2 seeds x 1 episode = 2 runs
        assert len(records) == 2
        for record in records:
            assert isinstance(record, EpisodeRecord)

    def test_results_have_correct_seeds(self):
        cfg = RunConfig(
            manifests=["m.yaml"],
            seeds=[10, 20],
            episodes_per_seed=1,
        )
        runner = CurriculumRunner(
            MockEnvironment(max_env_steps=2), MockAgent(), MockAgent(), cfg
        )
        seen_seeds = {record.seed for record in runner.run()}
        assert seen_seeds == {10, 20}

    def test_results_have_correct_manifests(self):
        cfg = RunConfig(
            manifests=["a.yaml", "b.yaml"],
            seeds=[1],
            episodes_per_seed=1,
        )
        runner = CurriculumRunner(
            MockEnvironment(max_env_steps=2), MockAgent(), MockAgent(), cfg
        )
        seen_manifests = {record.manifest for record in runner.run()}
        assert seen_manifests == {"a.yaml", "b.yaml"}

    def test_flag_capture_outcome(self):
        environment = MockEnvironment(max_env_steps=2, flags=["FLAG{test}"])
        cfg = RunConfig(
            manifests=["m.yaml"],
            seeds=[1],
            episodes_per_seed=1,
        )
        runner = CurriculumRunner(environment, MockAgent(), MockAgent(), cfg)
        first = runner.run()[0]
        assert first.outcome == "flag_captured"
        assert first.flags_found == ["FLAG{test}"]

    def test_timeout_outcome(self):
        # Environment never terminates on its own, so the runner's step cap must fire.
        environment = MockEnvironment(max_env_steps=9999)
        cfg = RunConfig(
            manifests=["m.yaml"],
            seeds=[1],
            episodes_per_seed=1,
            max_steps=4,
        )
        runner = CurriculumRunner(environment, MockAgent(), MockAgent(), cfg)
        first = runner.run()[0]
        assert first.outcome == "timeout"
        assert first.steps == 4

    def test_multiple_episodes_per_seed(self):
        cfg = RunConfig(
            manifests=["m.yaml"],
            seeds=[1],
            episodes_per_seed=3,
        )
        runner = CurriculumRunner(
            MockEnvironment(max_env_steps=2), MockAgent(), MockAgent(), cfg
        )
        records = runner.run()
        assert len(records) == 3
        assert all(record.episode in {1, 2, 3} for record in records)

    def test_manifest_axis_is_passed_to_reset_with_snapshot(self):
        project_root = Path(__file__).resolve().parent.parent
        manifest_paths = [
            str(project_root / "manifests" / "tier1_basic.yaml"),
            str(project_root / "manifests" / "tier2_corporate.yaml"),
        ]
        environment = MockEnvironment(max_env_steps=1)
        cfg = RunConfig(
            manifests=manifest_paths,
            seeds=[7],
            episodes_per_seed=1,
            max_steps=1,
        )
        CurriculumRunner(environment, MockAgent(), MockAgent(), cfg).run()
        assert len(environment.reset_calls) == 2
        first_call, second_call = environment.reset_calls
        assert first_call["manifest_path"].endswith("tier1_basic.yaml")
        assert second_call["manifest_path"].endswith("tier2_corporate.yaml")
        assert first_call["snapshot"].topology["tier"] == 1
        assert second_call["snapshot"].topology["tier"] == 2
# ---------------------------------------------------------------------------
# Tests: results collection
# ---------------------------------------------------------------------------
class TestResultsCollection:
    """Tests for persisting episode records and computing run summaries."""

    def test_save_and_load(self, tmp_path):
        cfg = RunConfig(
            manifests=["m.yaml"],
            seeds=[1, 2],
            episodes_per_seed=1,
        )
        runner = CurriculumRunner(
            MockEnvironment(max_env_steps=2), MockAgent(), MockAgent(), cfg
        )
        runner.run()
        out_file = tmp_path / "results.jsonl"
        written = runner.save_results(out_file)
        assert written == 2
        # Each JSONL line must round-trip through json.loads.
        with open(out_file) as handle:
            rows = [json.loads(raw) for raw in handle if raw.strip()]
        assert len(rows) == 2
        assert rows[0]["seed"] == 1
        assert rows[1]["seed"] == 2

    def test_results_summary(self):
        environment = MockEnvironment(max_env_steps=2, flags=["FLAG{x}"])
        cfg = RunConfig(
            manifests=["a.yaml", "b.yaml"],
            seeds=[1, 2],
            episodes_per_seed=1,
        )
        runner = CurriculumRunner(environment, MockAgent(), MockAgent(), cfg)
        runner.run()
        summary = runner.results_summary()
        assert summary["total_runs"] == 4
        assert "flag_captured" in summary["outcomes"]
        assert summary["avg_reward"] > 0
        for manifest_name in ("a.yaml", "b.yaml"):
            assert manifest_name in summary["by_manifest"]
        for seed in (1, 2):
            assert seed in summary["by_seed"]

    def test_empty_results_summary(self):
        # Summary with no recorded results: run() is never invoked here.
        cfg = RunConfig(manifests=[], seeds=[], episodes_per_seed=1)
        runner = CurriculumRunner(MockEnvironment(), MockAgent(), MockAgent(), cfg)
        summary = runner.results_summary()
        assert summary["total_runs"] == 0
        assert summary["avg_reward"] == 0.0

    def test_episode_record_to_dict(self):
        rec = EpisodeRecord(
            manifest="m.yaml",
            seed=42,
            episode=1,
            outcome="flag_captured",
            steps=5,
            flags_found=["FLAG{x}"],
            reward=1.5,
            duration_s=3.14,
            metadata={"extra": "info"},
        )
        payload = rec.to_dict()
        expectations = {
            "manifest": "m.yaml",
            "seed": 42,
            "outcome": "flag_captured",
            "extra": "info",
        }
        for key, expected in expectations.items():
            assert payload[key] == expected