"""Curriculum runner for OpenRange training evaluation. Orchestrates a matrix of (manifest x seed x episode) runs, collects results, and saves them as JSONL for downstream analysis. Usage:: python -m open_range.training.runner \\ --manifest manifests/tier1_basic.yaml \\ --seeds 1-5 \\ --episodes 3 """ from __future__ import annotations import asyncio import argparse import json import logging import sys import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Protocol, runtime_checkable import yaml logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Protocols for pluggable environment and agents # --------------------------------------------------------------------------- @runtime_checkable class RunnerEnvironment(Protocol): """Minimal environment interface for the curriculum runner.""" def reset(self, seed: int | None = None, **kwargs: Any) -> Any: """Reset the environment, optionally with a seed.""" ... def step(self, action: Any) -> Any: """Execute one step.""" ... @property def state(self) -> Any: """Current environment state.""" ... @runtime_checkable class RunnerAgent(Protocol): """Minimal agent interface for the curriculum runner.""" def reset(self, briefing: str, role: str) -> None: ... def act(self, observation: Any) -> str: ... # --------------------------------------------------------------------------- # Run config and result types # --------------------------------------------------------------------------- @dataclass class RunConfig: """Configuration for a curriculum run.""" manifests: list[str] seeds: list[int] episodes_per_seed: int = 1 max_steps: int = 100 @classmethod def from_cli( cls, manifests: list[str], seeds_str: str, episodes: int = 1, max_steps: int = 100, ) -> RunConfig: """Parse CLI seed range (e.g. '1-5') into a RunConfig.""" seeds = _parse_seed_range(seeds_str) return cls( manifests=manifests, seeds=seeds, episodes_per_seed=episodes, max_steps=max_steps, ) @dataclass class EpisodeRecord: """Result of a single episode run.""" manifest: str seed: int episode: int outcome: str = "timeout" steps: int = 0 flags_found: list[str] = field(default_factory=list) reward: float = 0.0 duration_s: float = 0.0 metadata: dict[str, Any] = field(default_factory=dict) def to_dict(self) -> dict[str, Any]: return { "manifest": self.manifest, "seed": self.seed, "episode": self.episode, "outcome": self.outcome, "steps": self.steps, "flags_found": self.flags_found, "reward": round(self.reward, 4), "duration_s": round(self.duration_s, 2), **self.metadata, } def _parse_seed_range(seeds_str: str) -> list[int]: """Parse a seed range like '1-5' or '1,3,5' into a list of ints.""" seeds: list[int] = [] for part in seeds_str.split(","): part = part.strip() if "-" in part: start, end = part.split("-", 1) seeds.extend(range(int(start), int(end) + 1)) else: seeds.append(int(part)) return seeds # --------------------------------------------------------------------------- # Curriculum Runner # --------------------------------------------------------------------------- class CurriculumRunner: """Orchestrate a matrix of (manifest x seed x episode) runs. Args: env: Environment instance (or factory callable). red: Red team agent. blue: Blue team agent. config: Run configuration specifying the matrix. """ def __init__( self, env: Any, red: RunnerAgent, blue: RunnerAgent, config: RunConfig, ) -> None: self.env = env self.red = red self.blue = blue self.config = config self._results: list[EpisodeRecord] = [] self._manifest_cache: dict[str, dict[str, Any]] = {} self._snapshot_builder: Any = None @property def results(self) -> list[EpisodeRecord]: """All episode results collected so far.""" return list(self._results) def run(self) -> list[EpisodeRecord]: """Run the full (manifest x seed x episode) matrix. Returns: List of EpisodeRecord for each completed episode. """ self._results = [] total = ( len(self.config.manifests) * len(self.config.seeds) * self.config.episodes_per_seed ) run_idx = 0 for manifest_path in self.config.manifests: for seed in self.config.seeds: for ep in range(self.config.episodes_per_seed): run_idx += 1 logger.info( "Run %d/%d: manifest=%s seed=%d episode=%d", run_idx, total, manifest_path, seed, ep + 1, ) record = self._run_episode(manifest_path, seed, ep + 1) self._results.append(record) return list(self._results) def _run_episode( self, manifest_path: str, seed: int, episode_num: int ) -> EpisodeRecord: """Run a single episode and return an EpisodeRecord.""" from open_range.models import RangeAction start = time.time() reset_kwargs: dict[str, Any] = {"seed": seed, "manifest_path": manifest_path} snapshot = self._build_snapshot_for_manifest(manifest_path, seed) if snapshot is not None: reset_kwargs["snapshot"] = snapshot try: obs = self.env.reset(**reset_kwargs) except Exception as exc: logger.error("Reset failed: %s", exc) return EpisodeRecord( manifest=manifest_path, seed=seed, episode=episode_num, outcome="error", metadata={"error": str(exc), "reset_kwargs": sorted(reset_kwargs.keys())}, ) briefing = getattr(obs, "stdout", str(obs)) self.red.reset(briefing=briefing, role="red") self.blue.reset(briefing=briefing, role="blue") step = 0 total_reward = 0.0 done = getattr(obs, "done", False) while not done and step < self.config.max_steps: # Red turn try: red_cmd = self.red.act(obs) obs = self.env.step(RangeAction(command=red_cmd, mode="red")) total_reward += getattr(obs, "reward", 0.0) or 0.0 step += 1 done = getattr(obs, "done", False) except Exception as exc: logger.warning("Red step failed: %s", exc) break if done: break # Blue turn try: blue_cmd = self.blue.act(obs) obs = self.env.step(RangeAction(command=blue_cmd, mode="blue")) total_reward += getattr(obs, "reward", 0.0) or 0.0 step += 1 done = getattr(obs, "done", False) except Exception as exc: logger.warning("Blue step failed: %s", exc) break duration = time.time() - start # Gather state env_state = self.env.state flags_found = list(getattr(env_state, "flags_found", [])) # Determine outcome if flags_found: outcome = "flag_captured" elif step >= self.config.max_steps: outcome = "timeout" elif done: outcome = "done" else: outcome = "error" return EpisodeRecord( manifest=manifest_path, seed=seed, episode=episode_num, outcome=outcome, steps=step, flags_found=flags_found, reward=total_reward, duration_s=duration, ) def _build_snapshot_for_manifest(self, manifest_path: str, seed: int) -> Any | None: """Build a deterministic snapshot from the manifest for this episode. This grounds the manifest axis in real environment input rather than using manifest names only for reporting metadata. """ from open_range.builder.builder import TemplateOnlyBuilder from open_range.protocols import BuildContext manifest = self._manifest_cache.get(manifest_path) if manifest is None: manifest_file = Path(manifest_path) if not manifest_file.exists(): logger.debug( "Manifest path %s not found; skipping snapshot build and relying on env.reset kwargs only", manifest_path, ) return None with open(manifest_file) as f: manifest = yaml.safe_load(f) if not isinstance(manifest, dict): raise ValueError(f"Manifest {manifest_path!r} did not parse to a mapping") self._manifest_cache[manifest_path] = manifest if self._snapshot_builder is None: self._snapshot_builder = TemplateOnlyBuilder() tier = int(manifest.get("tier", 1) or 1) context = BuildContext(seed=seed, tier=tier) return self._run_coro_sync(self._snapshot_builder.build(manifest, context)) @staticmethod def _run_coro_sync(coro: Any) -> Any: """Run an async coroutine in sync code, including notebook event loops.""" try: asyncio.get_running_loop() except RuntimeError: return asyncio.run(coro) from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=1) as pool: return pool.submit(lambda: asyncio.run(coro)).result() def save_results(self, path: str | Path) -> int: """Save results to a JSONL file. Args: path: Output file path. Returns: Number of records written. """ path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with open(path, "w") as f: for record in self._results: f.write(json.dumps(record.to_dict()) + "\n") return len(self._results) def results_summary(self) -> dict[str, Any]: """Aggregate pass/fail and reward statistics. Returns: Dict with: - total_runs: total episodes run - outcomes: outcome -> count - avg_reward: mean reward - avg_steps: mean steps - avg_duration_s: mean duration - by_manifest: per-manifest stats - by_seed: per-seed stats """ if not self._results: return { "total_runs": 0, "outcomes": {}, "avg_reward": 0.0, "avg_steps": 0.0, "avg_duration_s": 0.0, "by_manifest": {}, "by_seed": {}, } outcomes: dict[str, int] = {} rewards: list[float] = [] steps: list[int] = [] durations: list[float] = [] by_manifest: dict[str, list[EpisodeRecord]] = {} by_seed: dict[int, list[EpisodeRecord]] = {} for r in self._results: outcomes[r.outcome] = outcomes.get(r.outcome, 0) + 1 rewards.append(r.reward) steps.append(r.steps) durations.append(r.duration_s) by_manifest.setdefault(r.manifest, []).append(r) by_seed.setdefault(r.seed, []).append(r) def _agg(records: list[EpisodeRecord]) -> dict[str, Any]: oc: dict[str, int] = {} for rec in records: oc[rec.outcome] = oc.get(rec.outcome, 0) + 1 return { "count": len(records), "outcomes": oc, "avg_reward": round( sum(rec.reward for rec in records) / len(records), 4 ), "avg_steps": round( sum(rec.steps for rec in records) / len(records), 2 ), } return { "total_runs": len(self._results), "outcomes": outcomes, "avg_reward": round(sum(rewards) / len(rewards), 4), "avg_steps": round(sum(steps) / len(steps), 2), "avg_duration_s": round(sum(durations) / len(durations), 2), "by_manifest": {m: _agg(recs) for m, recs in sorted(by_manifest.items())}, "by_seed": {s: _agg(recs) for s, recs in sorted(by_seed.items())}, } # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def main() -> None: parser = argparse.ArgumentParser( description="Run OpenRange curriculum evaluation matrix", ) parser.add_argument( "--manifest", nargs="+", required=True, help="One or more manifest YAML paths", ) parser.add_argument( "--seeds", default="1-3", help="Seed range, e.g. '1-5' or '1,3,7' (default: 1-3)", ) parser.add_argument( "--episodes", type=int, default=1, help="Episodes per (manifest, seed) pair (default: 1)", ) parser.add_argument( "--max-steps", type=int, default=100, help="Maximum steps per episode (default: 100)", ) parser.add_argument( "--output", default="results.jsonl", help="Output JSONL path (default: results.jsonl)", ) args = parser.parse_args() config = RunConfig.from_cli( manifests=args.manifest, seeds_str=args.seeds, episodes=args.episodes, max_steps=args.max_steps, ) print( f"Curriculum run: {len(config.manifests)} manifests x " f"{len(config.seeds)} seeds x {config.episodes_per_seed} episodes " f"= {len(config.manifests) * len(config.seeds) * config.episodes_per_seed} total runs" ) print(f"Max steps per episode: {config.max_steps}") print(f"Output: {args.output}") print() # In CLI mode, we need a real environment and agents. # For now, report the config and exit -- actual runs require # a running environment instance. print( "Note: CLI mode requires a running environment. " "Use CurriculumRunner programmatically with your environment and agents." ) print(f"Config: manifests={config.manifests}, seeds={config.seeds}, " f"episodes={config.episodes_per_seed}") sys.exit(0) if __name__ == "__main__": main()