Spaces:
Runtime error
Runtime error
"""Curriculum runner for OpenRange training evaluation.

Orchestrates a matrix of (manifest x seed x episode) runs, collects
results, and saves them as JSONL for downstream analysis.

Usage::

    python -m open_range.training.runner \\
        --manifest manifests/tier1_basic.yaml \\
        --seeds 1-5 \\
        --episodes 3

Note: the command-line entry point currently only parses and prints the
run configuration; actual runs require constructing ``CurriculumRunner``
programmatically with an environment and agents.
"""
| from __future__ import annotations | |
| import asyncio | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Protocol, runtime_checkable | |
| import yaml | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Protocols for pluggable environment and agents | |
| # --------------------------------------------------------------------------- | |
@runtime_checkable
class RunnerEnvironment(Protocol):
    """Minimal environment interface for the curriculum runner.

    This is a structural protocol: any object that exposes ``reset``,
    ``step`` and ``state`` satisfies it, no inheritance required.  It is
    marked ``runtime_checkable`` so ``isinstance(env, RunnerEnvironment)``
    can be used to validate an environment (presence of the members only;
    signatures are not checked at runtime).
    """

    def reset(self, seed: int | None = None, **kwargs: Any) -> Any:
        """Reset the environment, optionally with a seed.

        Extra keyword arguments are passed through by the runner (it sends
        ``manifest_path`` and, when one could be built, ``snapshot``).
        Returns the initial observation.
        """
        ...

    def step(self, action: Any) -> Any:
        """Execute one step and return the resulting observation."""
        ...

    # Declared as a property because the runner reads it as an attribute
    # (``self.env.state``), not as a call.
    @property
    def state(self) -> Any:
        """Current environment state."""
        ...
@runtime_checkable
class RunnerAgent(Protocol):
    """Minimal agent interface for the curriculum runner.

    Structural protocol: any object with matching ``reset`` and ``act``
    methods qualifies.  ``runtime_checkable`` allows
    ``isinstance(agent, RunnerAgent)`` checks (member presence only).
    """

    def reset(self, briefing: str, role: str) -> None:
        """Prepare the agent for a new episode.

        Args:
            briefing: Text derived from the environment's initial observation.
            role: Side the agent plays; the runner passes ``"red"`` or ``"blue"``.
        """
        ...

    def act(self, observation: Any) -> str:
        """Return the next command string for the given observation."""
        ...
| # --------------------------------------------------------------------------- | |
| # Run config and result types | |
| # --------------------------------------------------------------------------- | |
@dataclass
class RunConfig:
    """Configuration for a curriculum run.

    The run matrix is ``manifests x seeds x episodes_per_seed``; each cell is
    one episode capped at ``max_steps`` steps.
    """

    # Manifest YAML paths, one matrix row per entry.
    manifests: list[str]
    # Seeds to run for every manifest.
    seeds: list[int]
    # Number of episodes for each (manifest, seed) pair.
    episodes_per_seed: int = 1
    # Upper bound on environment steps within one episode.
    max_steps: int = 100

    @classmethod
    def from_cli(
        cls,
        manifests: list[str],
        seeds_str: str,
        episodes: int = 1,
        max_steps: int = 100,
    ) -> RunConfig:
        """Build a RunConfig from CLI-style arguments.

        Args:
            manifests: Manifest paths as given on the command line.
            seeds_str: Seed specification such as ``'1-5'`` or ``'1,3,5'``;
                parsed by ``_parse_seed_range``.
            episodes: Episodes per (manifest, seed) pair.
            max_steps: Maximum steps per episode.

        Returns:
            A new configuration with ``seeds`` expanded to a list of ints.

        Raises:
            ValueError: If ``seeds_str`` contains a non-integer token.
        """
        seeds = _parse_seed_range(seeds_str)
        return cls(
            manifests=manifests,
            seeds=seeds,
            episodes_per_seed=episodes,
            max_steps=max_steps,
        )
@dataclass
class EpisodeRecord:
    """Result of a single episode run."""

    # Manifest path the episode was run against.
    manifest: str
    # Seed passed to the environment reset.
    seed: int
    # 1-based episode number within the (manifest, seed) pair.
    episode: int
    # One of the outcome labels assigned by the runner
    # ("flag_captured", "timeout", "done", "error").
    outcome: str = "timeout"
    # Number of environment steps executed.
    steps: int = 0
    # Flags reported by the environment state at the end of the episode.
    flags_found: list[str] = field(default_factory=list)
    # Sum of per-step rewards.
    reward: float = 0.0
    # Wall-clock duration of the episode in seconds.
    duration_s: float = 0.0
    # Free-form extra fields (e.g. error details); flattened into to_dict().
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a flat, JSON-friendly dict for this record.

        ``reward`` is rounded to 4 decimals and ``duration_s`` to 2.
        ``metadata`` entries are merged in last, so a metadata key that
        collides with a core field name overrides that field.
        """
        return {
            "manifest": self.manifest,
            "seed": self.seed,
            "episode": self.episode,
            "outcome": self.outcome,
            "steps": self.steps,
            "flags_found": self.flags_found,
            "reward": round(self.reward, 4),
            "duration_s": round(self.duration_s, 2),
            **self.metadata,
        }
def _parse_seed_range(seeds_str: str) -> list[int]:
    """Parse a seed specification into a list of ints.

    Accepts comma-separated tokens, each either a single integer (``'7'``)
    or an inclusive range (``'1-5'``).  Whitespace around tokens is ignored
    and empty tokens (e.g. from a trailing comma) are skipped.  A leading
    ``-`` is treated as a sign, so ``'-2'`` and ``'-2-1'`` are valid.

    Args:
        seeds_str: Specification such as ``'1-5'``, ``'1,3,5'`` or ``'1-3,7'``.

    Returns:
        Seeds in the order written; a range whose end is below its start
        contributes nothing.

    Raises:
        ValueError: If a token is not an integer or an integer range.
    """
    seeds: list[int] = []
    for raw in seeds_str.split(","):
        token = raw.strip()
        if not token:
            # Tolerate "1,2," and "1,,2" instead of failing on int("").
            continue
        # Search from index 1 so a leading minus sign is not mistaken for
        # the range separator.
        sep = token.find("-", 1)
        if sep == -1:
            seeds.append(int(token))
        else:
            first = int(token[:sep])
            last = int(token[sep + 1:])
            seeds.extend(range(first, last + 1))
    return seeds
| # --------------------------------------------------------------------------- | |
| # Curriculum Runner | |
| # --------------------------------------------------------------------------- | |
class CurriculumRunner:
    """Orchestrate a matrix of (manifest x seed x episode) runs.

    Args:
        env: Environment instance exposing ``reset(**kwargs)``,
            ``step(action)`` and ``state``.
        red: Red team agent.
        blue: Blue team agent.
        config: Run configuration specifying the matrix.
    """

    def __init__(
        self,
        env: Any,
        red: RunnerAgent,
        blue: RunnerAgent,
        config: RunConfig,
    ) -> None:
        self.env = env
        self.red = red
        self.blue = blue
        self.config = config
        self._results: list[EpisodeRecord] = []
        # Parsed manifest YAML keyed by path, so each file is read once.
        self._manifest_cache: dict[str, dict[str, Any]] = {}
        # Created lazily on the first snapshot build.
        self._snapshot_builder: Any = None

    # NOTE(review): kept as a plain method because that is what is visible
    # here; confirm whether callers expect ``runner.results`` as a property.
    def results(self) -> list[EpisodeRecord]:
        """Return a copy of all episode results collected so far."""
        return list(self._results)

    def run(self) -> list[EpisodeRecord]:
        """Run the full (manifest x seed x episode) matrix.

        Previously collected results are discarded before the run starts.

        Returns:
            List of EpisodeRecord for each completed episode.
        """
        self._results = []
        total = (
            len(self.config.manifests)
            * len(self.config.seeds)
            * self.config.episodes_per_seed
        )
        run_idx = 0
        for manifest_path in self.config.manifests:
            for seed in self.config.seeds:
                for ep in range(self.config.episodes_per_seed):
                    run_idx += 1
                    logger.info(
                        "Run %d/%d: manifest=%s seed=%d episode=%d",
                        run_idx,
                        total,
                        manifest_path,
                        seed,
                        ep + 1,
                    )
                    record = self._run_episode(manifest_path, seed, ep + 1)
                    self._results.append(record)
        return list(self._results)

    def _run_episode(
        self, manifest_path: str, seed: int, episode_num: int
    ) -> EpisodeRecord:
        """Run a single episode and return an EpisodeRecord.

        Red and blue alternate turns (red first) until the environment
        reports ``done``, ``config.max_steps`` steps have been taken, or a
        step raises.  A failed reset or step is recorded in the returned
        record (``outcome`` / ``metadata``) rather than raised; errors from
        building the snapshot propagate to the caller.
        """
        import inspect

        from open_range.models import RangeAction

        start = time.time()
        reset_kwargs: dict[str, Any] = {"seed": seed, "manifest_path": manifest_path}
        snapshot = self._build_snapshot_for_manifest(manifest_path, seed)
        if snapshot is not None:
            reset_kwargs["snapshot"] = snapshot

        try:
            obs = self.env.reset(**reset_kwargs)
        except Exception as exc:
            logger.error("Reset failed: %s", exc)
            return EpisodeRecord(
                manifest=manifest_path,
                seed=seed,
                episode=episode_num,
                outcome="error",
                metadata={"error": str(exc), "reset_kwargs": sorted(reset_kwargs.keys())},
            )

        briefing = getattr(obs, "stdout", str(obs))
        self.red.reset(briefing=briefing, role="red")
        self.blue.reset(briefing=briefing, role="blue")

        max_steps = self.config.max_steps
        step = 0
        total_reward = 0.0
        done = getattr(obs, "done", False)
        failure: dict[str, Any] = {}
        turns = (("red", "Red", self.red), ("blue", "Blue", self.blue))

        while not failure and not done and step < max_steps:
            for mode, label, agent in turns:
                try:
                    command = agent.act(obs)
                    obs = self.env.step(RangeAction(command=command, mode=mode))
                    total_reward += getattr(obs, "reward", 0.0) or 0.0
                    step += 1
                    done = getattr(obs, "done", False)
                except Exception as exc:
                    logger.warning("%s step failed: %s", label, exc)
                    # Keep the cause so an "error" outcome is diagnosable,
                    # mirroring what the reset-failure path records.
                    failure = {"error": str(exc), "failed_role": mode}
                    break
                # Re-check the budget after every turn so the blue turn
                # cannot push the episode past max_steps.
                if done or step >= max_steps:
                    break

        duration = time.time() - start

        # Gather state.  ``state`` is normally an attribute/property; also
        # accept environments that implement it as a plain method.
        env_state = self.env.state
        if inspect.ismethod(env_state):
            env_state = env_state()
        flags_found = list(getattr(env_state, "flags_found", []) or [])

        # Determine outcome: captured flags win, then budget exhaustion,
        # then a clean finish; anything else means a step failed.
        if flags_found:
            outcome = "flag_captured"
        elif step >= max_steps:
            outcome = "timeout"
        elif done:
            outcome = "done"
        else:
            outcome = "error"

        return EpisodeRecord(
            manifest=manifest_path,
            seed=seed,
            episode=episode_num,
            outcome=outcome,
            steps=step,
            flags_found=flags_found,
            reward=total_reward,
            duration_s=duration,
            metadata=failure,
        )

    def _build_snapshot_for_manifest(self, manifest_path: str, seed: int) -> Any | None:
        """Build a deterministic snapshot from the manifest for this episode.

        This grounds the manifest axis in real environment input rather than
        using manifest names only for reporting metadata.

        Returns:
            The built snapshot, or ``None`` when ``manifest_path`` does not
            exist on disk (the caller then relies on reset kwargs only).

        Raises:
            ValueError: If the manifest file does not parse to a mapping.
        """
        from open_range.builder.builder import TemplateOnlyBuilder
        from open_range.protocols import BuildContext

        manifest = self._manifest_cache.get(manifest_path)
        if manifest is None:
            manifest_file = Path(manifest_path)
            if not manifest_file.exists():
                logger.debug(
                    "Manifest path %s not found; skipping snapshot build and relying on env.reset kwargs only",
                    manifest_path,
                )
                return None
            # Explicit encoding so parsing does not depend on the locale.
            with open(manifest_file, encoding="utf-8") as f:
                manifest = yaml.safe_load(f)
            if not isinstance(manifest, dict):
                raise ValueError(f"Manifest {manifest_path!r} did not parse to a mapping")
            self._manifest_cache[manifest_path] = manifest

        if self._snapshot_builder is None:
            self._snapshot_builder = TemplateOnlyBuilder()

        # A missing, null or zero tier falls back to 1.
        tier = int(manifest.get("tier", 1) or 1)
        context = BuildContext(seed=seed, tier=tier)
        return self._run_coro_sync(self._snapshot_builder.build(manifest, context))

    @staticmethod
    def _run_coro_sync(coro: Any) -> Any:
        """Run an async coroutine in sync code, including notebook event loops.

        With no event loop running in this thread the coroutine is run
        directly via ``asyncio.run``.  Otherwise (e.g. inside a notebook) it
        is run with ``asyncio.run`` on a one-off worker thread, and this call
        blocks until it finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(coro)).result()

    def save_results(self, path: str | Path) -> int:
        """Save results to a JSONL file.

        Parent directories are created as needed and an existing file is
        overwritten.

        Args:
            path: Output file path.

        Returns:
            Number of records written.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            for record in self._results:
                f.write(json.dumps(record.to_dict()) + "\n")
        return len(self._results)

    def results_summary(self) -> dict[str, Any]:
        """Aggregate pass/fail and reward statistics.

        Returns:
            Dict with:
            - total_runs: total episodes run
            - outcomes: outcome -> count
            - avg_reward: mean reward
            - avg_steps: mean steps
            - avg_duration_s: mean duration
            - by_manifest: per-manifest stats
            - by_seed: per-seed stats
        """
        if not self._results:
            return {
                "total_runs": 0,
                "outcomes": {},
                "avg_reward": 0.0,
                "avg_steps": 0.0,
                "avg_duration_s": 0.0,
                "by_manifest": {},
                "by_seed": {},
            }

        outcomes: dict[str, int] = {}
        rewards: list[float] = []
        steps: list[int] = []
        durations: list[float] = []
        by_manifest: dict[str, list[EpisodeRecord]] = {}
        by_seed: dict[int, list[EpisodeRecord]] = {}

        for r in self._results:
            outcomes[r.outcome] = outcomes.get(r.outcome, 0) + 1
            rewards.append(r.reward)
            steps.append(r.steps)
            durations.append(r.duration_s)
            by_manifest.setdefault(r.manifest, []).append(r)
            by_seed.setdefault(r.seed, []).append(r)

        def _agg(records: list[EpisodeRecord]) -> dict[str, Any]:
            # Per-group stats; groups are never empty because they are
            # only created when a record is appended.
            oc: dict[str, int] = {}
            for rec in records:
                oc[rec.outcome] = oc.get(rec.outcome, 0) + 1
            return {
                "count": len(records),
                "outcomes": oc,
                "avg_reward": round(
                    sum(rec.reward for rec in records) / len(records), 4
                ),
                "avg_steps": round(
                    sum(rec.steps for rec in records) / len(records), 2
                ),
            }

        return {
            "total_runs": len(self._results),
            "outcomes": outcomes,
            "avg_reward": round(sum(rewards) / len(rewards), 4),
            "avg_steps": round(sum(steps) / len(steps), 2),
            "avg_duration_s": round(sum(durations) / len(durations), 2),
            "by_manifest": {m: _agg(recs) for m, recs in sorted(by_manifest.items())},
            "by_seed": {s: _agg(recs) for s, recs in sorted(by_seed.items())},
        }
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse arguments and report the run matrix.

    No episodes are executed here; the function prints the resolved
    configuration and exits with status 0, because real runs need an
    environment and agents supplied programmatically.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        SystemExit: Always. Status 0 after printing the configuration, or
            status 2 (via ``parser.error``) for invalid arguments, including
            a malformed ``--seeds`` value.
    """
    parser = argparse.ArgumentParser(
        description="Run OpenRange curriculum evaluation matrix",
    )
    parser.add_argument(
        "--manifest",
        nargs="+",
        required=True,
        help="One or more manifest YAML paths",
    )
    parser.add_argument(
        "--seeds",
        default="1-3",
        help="Seed range, e.g. '1-5' or '1,3,7' (default: 1-3)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=1,
        help="Episodes per (manifest, seed) pair (default: 1)",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=100,
        help="Maximum steps per episode (default: 100)",
    )
    parser.add_argument(
        "--output",
        default="results.jsonl",
        help="Output JSONL path (default: results.jsonl)",
    )
    args = parser.parse_args(argv)

    try:
        config = RunConfig.from_cli(
            manifests=args.manifest,
            seeds_str=args.seeds,
            episodes=args.episodes,
            max_steps=args.max_steps,
        )
    except ValueError as exc:
        # Report a bad --seeds value as a usage error, not a traceback.
        parser.error(f"invalid --seeds value {args.seeds!r}: {exc}")

    total_runs = len(config.manifests) * len(config.seeds) * config.episodes_per_seed
    print(
        f"Curriculum run: {len(config.manifests)} manifests x "
        f"{len(config.seeds)} seeds x {config.episodes_per_seed} episodes "
        f"= {total_runs} total runs"
    )
    print(f"Max steps per episode: {config.max_steps}")
    print(f"Output: {args.output}")
    print()

    # In CLI mode, we need a real environment and agents.
    # For now, report the config and exit -- actual runs require
    # a running environment instance.
    print(
        "Note: CLI mode requires a running environment. "
        "Use CurriculumRunner programmatically with your environment and agents."
    )
    print(f"Config: manifests={config.manifests}, seeds={config.seeds}, "
          f"episodes={config.episodes_per_seed}")
    sys.exit(0)


if __name__ == "__main__":
    main()