open-range / src /open_range /agents /episode.py
Aaron Brown
Remove hardcoded fallbacks, add snapshot-driven service lifecycle
7fedc25
"""Episode orchestration loop.
``run_episode`` connects to an OpenRange environment and alternates
Red / Blue turns until the episode ends or ``max_steps`` is reached.
Agents only see observations -- they cannot control episode lifecycle.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from open_range.agents.protocol import EpisodeMetrics, EpisodeResult
if TYPE_CHECKING:
from open_range.agents.protocol import RangeAgent
from open_range.training.curriculum import CurriculumTracker
logger = logging.getLogger(__name__)
def _determine_outcome(
flags_found: list[str],
total_flags: int,
steps: int,
max_steps: int,
) -> str:
"""Determine episode outcome from final state."""
if flags_found and len(flags_found) >= total_flags > 0:
return "red_win"
if steps >= max_steps:
return "timeout"
# If done without all flags and not timeout, Blue held
return "blue_win"
def _compute_metrics(
    result: EpisodeResult,
    total_flags: int,
) -> EpisodeMetrics:
    """Compute episode metrics from trajectories.

    Args:
        result: Finished episode. ``red_trajectory`` / ``blue_trajectory`` are
            lists of per-turn dicts (keys such as ``command`` and ``stdout``),
            and ``flags_found`` lists the captured flags.
        total_flags: Number of flags defined by the snapshot; ``0`` when unknown.

    Returns:
        A fresh ``EpisodeMetrics``. ``flag_capture_rate`` is set only when
        ``total_flags > 0``. ``steps_to_first_flag`` is set only when a Red
        turn's stdout contains a flag-acceptance marker. ``stealth`` is always
        set. ``detection_tp`` is set only when Red took at least one turn.
    """
    metrics = EpisodeMetrics()

    # Flag capture rate
    if total_flags > 0:
        metrics.flag_capture_rate = len(result.flags_found) / total_flags

    # Steps to first flag: 1-based index into Red's own turns (not combined
    # steps). ``or ""`` tolerates a recorded stdout of None.
    for turn_number, turn in enumerate(result.red_trajectory, start=1):
        stdout = turn.get("stdout") or ""
        if "Flag accepted" in stdout or "Correct" in stdout:
            metrics.steps_to_first_flag = turn_number
            break

    total_red = len(result.red_trajectory)
    if total_red == 0:
        # Nothing Red did could have been detected.
        metrics.stealth = 1.0
        return metrics

    # Approximation: every Blue ``submit_finding`` counts as one detection.
    # Findings are not matched against specific Red actions, and the count is
    # capped at the number of Red actions.
    blue_findings = sum(
        1
        for turn in result.blue_trajectory
        if (turn.get("command") or "").startswith("submit_finding")
    )
    detected = min(blue_findings, total_red)

    # Stealth: 1 - (detected_actions / total_red_actions)
    metrics.stealth = 1.0 - (detected / total_red)
    # Detection TP: fraction of Red actions Blue detected, in [0, 1].
    metrics.detection_tp = detected / total_red
    return metrics
def _record_turn(trajectory: list[dict], command: str, obs: object) -> None:
    """Append one agent turn (its command and the resulting observation) to *trajectory*.

    ``stderr`` and ``alerts`` are read with defaults (``""`` / ``[]``) in case
    the observation object does not carry them.
    """
    trajectory.append({
        "command": command,
        "stdout": obs.stdout,
        "stderr": getattr(obs, "stderr", ""),
        "alerts": getattr(obs, "alerts", []),
        "reward": obs.reward,
    })


def _vuln_classes(snapshot: object) -> list[str]:
    """Return the non-empty ``type`` of each vuln in ``snapshot.truth_graph.vulns``.

    Returns an empty list when the snapshot, its truth graph, or its vulns are missing.
    """
    truth_graph = getattr(snapshot, "truth_graph", None) if snapshot else None
    if not truth_graph:
        return []
    vulns = getattr(truth_graph, "vulns", None) or []
    return [getattr(v, "type", "") for v in vulns if getattr(v, "type", "")]


def run_episode(
    env: object,
    red: RangeAgent,
    blue: RangeAgent,
    max_steps: int = 100,
    red_model: str = "",
    blue_model: str = "",
    curriculum: CurriculumTracker | None = None,
) -> EpisodeResult:
    """Run one tandem Red + Blue episode.

    The orchestration layer calls ``reset()`` and alternates agent turns,
    Red first. Each agent acts on the observation produced by the previous
    action, whichever team issued it. Agents only see observations -- they
    cannot control episode lifecycle.

    This function works with the ``RangeEnvironment`` directly (no HTTP).
    For remote environments, use the async variant or call through the client.

    Args:
        env: A ``RangeEnvironment`` instance (or anything with ``reset``/``step``/``state``).
        red: Red team agent (satisfies ``RangeAgent`` protocol).
        blue: Blue team agent (satisfies ``RangeAgent`` protocol).
        max_steps: Maximum total steps (Red + Blue combined). The budget is
            checked after every single turn, so it is never exceeded, even
            when it is odd.
        red_model: Model identifier for logging; falls back to ``red.model``.
        blue_model: Model identifier for logging; falls back to ``blue.model``.
        curriculum: Optional tracker. When given, its ``update_from_result``
            is called with a summary dict of the finished episode.

    Returns:
        ``EpisodeResult`` with trajectories, metrics, and outcome.
    """
    from open_range.models import RangeAction

    # Reset environment and hand both agents the same briefing.
    obs = env.reset()
    briefing = obs.stdout
    red.reset(briefing=briefing, role="red")
    blue.reset(briefing=briefing, role="blue")

    red_trajectory: list[dict] = []
    blue_trajectory: list[dict] = []
    turn_order = (
        (red, "red", red_trajectory),
        (blue, "blue", blue_trajectory),
    )

    step = 0
    while not obs.done and step < max_steps:
        for agent, mode, trajectory in turn_order:
            command = agent.act(obs)
            obs = env.step(RangeAction(command=command, mode=mode))
            _record_turn(trajectory, command, obs)
            step += 1
            # Checked after every turn, not once per Red/Blue round, so that
            # Blue's turn cannot push the total past an odd max_steps.
            if obs.done or step >= max_steps:
                break

    # Gather final state. Copy flags so the result does not alias the env's
    # list; ``or []`` also tolerates an explicit None.
    env_state = env.state
    flags_found = list(getattr(env_state, "flags_found", None) or [])
    tier = getattr(env_state, "tier", 1)
    # NOTE(review): this is the env state's ``episode_id`` but it is reported as
    # ``snapshot_id`` in the result and curriculum feedback -- confirm intended.
    snapshot_id = getattr(env_state, "episode_id", "")

    # Total flags available, from the public or private snapshot attribute; 0 if unknown.
    snapshot = getattr(env, "snapshot", None) or getattr(env, "_snapshot", None)
    snapshot_flags = getattr(snapshot, "flags", None) if snapshot else None
    total_flags = len(snapshot_flags) if snapshot_flags else 0

    outcome = _determine_outcome(flags_found, total_flags, step, max_steps)
    resolved_red_model = red_model or getattr(red, "model", "")
    resolved_blue_model = blue_model or getattr(blue, "model", "")

    result = EpisodeResult(
        red_trajectory=red_trajectory,
        blue_trajectory=blue_trajectory,
        flags_found=flags_found,
        steps=step,
        tier=tier,
        snapshot_id=snapshot_id,
        red_model=resolved_red_model,
        blue_model=resolved_blue_model,
        outcome=outcome,
    )
    result.metrics = _compute_metrics(result, total_flags)

    logger.info(
        "Episode %s complete: outcome=%s, steps=%d, flags=%d/%d",
        snapshot_id,
        outcome,
        step,
        len(flags_found),
        total_flags,
    )

    # Curriculum feedback wiring (#34)
    if curriculum is not None:
        curriculum.update_from_result({
            "snapshot_id": snapshot_id,
            "vuln_classes": _vuln_classes(snapshot),
            "outcome": outcome,
            "flags_found": list(flags_found),
            "steps": step,
            "tier": tier,
            "red_model": resolved_red_model,
            "blue_model": resolved_blue_model,
        })

    return result