Spaces:
Runtime error
Runtime error
Aaron Brown
Add task engine, exposure policy, auth scenario, pivot mechanics, curriculum wiring
769dd2e | """Curriculum tracker for OpenRange training. | |
| Tracks per-vuln-class and per-tier solve/detection rates across episodes. | |
| Feeds runtime context to the Builder/Mutator so it can target agent | |
| weaknesses and adjust difficulty. | |
| Example:: | |
| tracker = CurriculumTracker() | |
| tracker.record_episode( | |
| snapshot_id="snap-001", | |
| vuln_classes=["sqli", "weak_creds"], | |
| red_solved=True, | |
| blue_detected=False, | |
| tier=1, | |
| ) | |
| ctx = tracker.get_build_context() | |
| # ctx = { | |
| # "previous_vuln_classes": ["sqli", "weak_creds"], | |
| # "red_solve_rate": 1.0, | |
| # "blue_detect_rate": 0.0, | |
| # "weak_areas": [],  # a class is only flagged after >= 3 attempts | |
| # "recent_attack_surfaces": [...], | |
| # "episode_count": 1, | |
| # } | |
| """ | |
| from __future__ import annotations | |
| from collections import defaultdict | |
| from typing import Any | |
class CurriculumTracker:
    """Track episode outcomes for curriculum-driven snapshot generation.

    Maintains per-vuln-class and per-tier statistics so the Builder
    can target agent weaknesses and calibrate difficulty.

    Attributes:
        history_window: Number of most recent episodes used for the rolling
            solve/detect rates.
        vuln_stats: Lifetime counters per vulnerability class
            (``attempts``, ``red_solves``, ``blue_detects``).
        tier_stats: Lifetime counters per difficulty tier
            (``episodes``, ``red_solves``, ``blue_detects``).
        episode_history: Most recent episode records. Trimmed lazily, so it
            holds at most ``2 * history_window`` entries.
        total_episodes: Number of episodes recorded through
            ``record_episode``; unaffected by history trimming.
    """

    # A vuln class needs this many attempts before it can be a "weak area".
    _WEAK_AREA_MIN_ATTEMPTS = 3
    # Red solving more often than this means the class is too easy for Red.
    _RED_TOO_EASY_RATE = 0.8
    # Blue detecting less often than this means Blue needs more practice.
    _BLUE_BLIND_RATE = 0.2
    # How many trailing episodes feed the diversity-related context fields.
    _DIVERSITY_EPISODES = 5
    # Minimum episodes at a tier before escalation is considered.
    _MIN_TIER_EPISODES = 5

    def __init__(self, history_window: int = 20) -> None:
        """Create an empty tracker.

        Args:
            history_window: Size of the rolling window of episodes used for
                the overall solve/detect rates.
        """
        self.history_window = history_window
        self.vuln_stats: dict[str, dict[str, int]] = defaultdict(
            lambda: {"attempts": 0, "red_solves": 0, "blue_detects": 0}
        )
        self.tier_stats: dict[int, dict[str, int]] = defaultdict(
            lambda: {"episodes": 0, "red_solves": 0, "blue_detects": 0}
        )
        self.episode_history: list[dict[str, Any]] = []
        self.total_episodes: int = 0

    def record_episode(
        self,
        snapshot_id: str,
        vuln_classes: list[str],
        red_solved: bool,
        blue_detected: bool,
        tier: int = 1,
        attack_surfaces: list[str] | None = None,
        extra: dict[str, Any] | None = None,
    ) -> None:
        """Record the outcome of a completed episode.

        Args:
            snapshot_id: Identifier of the snapshot used.
            vuln_classes: Vulnerability classes planted in the episode.
            red_solved: Whether Red captured a flag.
            blue_detected: Whether Blue detected the attack.
            tier: Difficulty tier of the episode.
            attack_surfaces: Injection points used (e.g. "/search?q=").
            extra: Additional metadata to store. Keys that collide with the
                core record fields are ignored so the stored history always
                agrees with the counters.
        """
        # Copy the lists so later mutation by the caller cannot rewrite history.
        classes = list(vuln_classes)
        record: dict[str, Any] = {
            "snapshot_id": snapshot_id,
            "vuln_classes": classes,
            "red_solved": red_solved,
            "blue_detected": blue_detected,
            "tier": tier,
            "attack_surfaces": list(attack_surfaces or []),
        }
        if extra:
            for key, value in extra.items():
                # Core fields win over same-named extras.
                record.setdefault(key, value)

        self.episode_history.append(record)
        self.total_episodes += 1

        # Trim lazily: only once the list has grown to twice the window, so
        # the slice copy is amortised over many appends.
        if len(self.episode_history) > self.history_window * 2:
            self.episode_history = self.episode_history[-self.history_window:]

        # Update per-vuln stats.
        for vc in classes:
            vc_stats = self.vuln_stats[vc]
            vc_stats["attempts"] += 1
            if red_solved:
                vc_stats["red_solves"] += 1
            if blue_detected:
                vc_stats["blue_detects"] += 1

        # Update per-tier stats.
        tier_entry = self.tier_stats[tier]
        tier_entry["episodes"] += 1
        if red_solved:
            tier_entry["red_solves"] += 1
        if blue_detected:
            tier_entry["blue_detects"] += 1

    def get_build_context(self) -> dict[str, Any]:
        """Generate runtime context for the Builder/Mutator.

        Returns:
            A dict suitable for passing as ``BuildContext`` fields:

            - ``previous_vuln_classes``: vuln classes of the last 5 episodes
              (with repeats), for diversity enforcement.
            - ``red_solve_rate`` / ``blue_detect_rate``: fraction of episodes
              in the last ``history_window`` episodes that Red solved / Blue
              detected; ``0.0`` when nothing has been recorded.
            - ``weak_areas``: vuln classes with at least 3 lifetime attempts
              where Red solves more than 80% or Blue detects less than 20%.
            - ``recent_attack_surfaces``: attack surfaces of the last 5
              episodes.
            - ``episode_count``: total number of episodes recorded.
        """
        recent = self.episode_history[-self.history_window:]
        last_few = self.episode_history[-self._DIVERSITY_EPISODES:]

        prev_vulns: list[str] = []
        recent_surfaces: list[str] = []
        for ep in last_few:
            prev_vulns.extend(ep.get("vuln_classes", []))
            recent_surfaces.extend(ep.get("attack_surfaces", []))

        # Overall solve/detect rates over the rolling window.
        if recent:
            red_solve_rate = sum(1 for e in recent if e["red_solved"]) / len(recent)
            blue_detect_rate = sum(1 for e in recent if e["blue_detected"]) / len(recent)
        else:
            red_solve_rate = 0.0
            blue_detect_rate = 0.0

        # A class is weak when Red finds it too easy (needs harder variants)
        # or Blue rarely detects it (needs more practice). vuln_stats keys
        # are unique, so one append per class needs no deduplication.
        weak_areas: list[str] = []
        for vc, stats in self.vuln_stats.items():
            attempts = stats["attempts"]
            if attempts < self._WEAK_AREA_MIN_ATTEMPTS:
                continue
            solve_rate = stats["red_solves"] / attempts
            detect_rate = stats["blue_detects"] / attempts
            if solve_rate > self._RED_TOO_EASY_RATE or detect_rate < self._BLUE_BLIND_RATE:
                weak_areas.append(vc)

        return {
            "previous_vuln_classes": prev_vulns,
            "red_solve_rate": red_solve_rate,
            "blue_detect_rate": blue_detect_rate,
            "weak_areas": weak_areas,
            "recent_attack_surfaces": recent_surfaces,
            # Lifetime count, not the trimmed history length. The max() keeps
            # the value sensible if episode_history was populated directly.
            "episode_count": max(self.total_episodes, len(self.episode_history)),
        }

    def should_escalate_tier(self, current_tier: int, threshold: float = 0.8) -> bool:
        """Check if the agent should move to a harder tier.

        Escalation happens when at least 5 episodes have been recorded at
        ``current_tier`` and Red's solve rate over all recorded episodes at
        that tier is greater than or equal to ``threshold``.

        Args:
            current_tier: Tier to evaluate.
            threshold: Minimum Red solve rate required to escalate.

        Returns:
            ``True`` if the tier should be escalated, ``False`` otherwise.
        """
        # .get() avoids creating an empty entry in the defaultdict.
        stats = self.tier_stats.get(current_tier)
        if not stats or stats["episodes"] < self._MIN_TIER_EPISODES:
            return False
        return stats["red_solves"] / stats["episodes"] >= threshold

    def get_vuln_solve_rates(self) -> dict[str, float]:
        """Return per-vuln-class Red solve rates for analysis.

        Returns:
            Mapping of vuln class to ``red_solves / attempts`` (``0.0`` for a
            class with no attempts).
        """
        return {
            vc: (stats["red_solves"] / stats["attempts"]) if stats["attempts"] > 0 else 0.0
            for vc, stats in self.vuln_stats.items()
        }

    def update_from_result(self, result: dict[str, Any]) -> None:
        """Update curriculum stats from an episode result.

        Accepts a dict with the following optional keys:

        - ``snapshot_id`` (str): episode/snapshot identifier
        - ``vuln_classes`` (list[str]): vulnerability classes in the episode
        - ``red_solved`` (bool): whether Red captured a flag
        - ``blue_detected`` (bool): whether Blue detected the attack
        - ``tier`` (int): difficulty tier
        - ``attack_surfaces`` (list[str]): injection points used
        - ``outcome`` (str): episode outcome (``red_win``, ``blue_win``, ``timeout``)
        - ``flags_found`` (list[str]): captured flags
        - ``steps`` (int): total steps taken
        - ``red_model`` / ``blue_model``: stored as extra metadata if present

        If ``red_solved`` / ``blue_detected`` are not provided they are
        inferred: Red solved when ``outcome`` is ``red_win`` or
        ``flags_found`` is non-empty; Blue detected when ``outcome`` is
        ``blue_win``. ``None`` for ``vuln_classes``, ``attack_surfaces`` or
        ``tier`` is treated the same as the key being absent.
        """
        outcome = result.get("outcome", "")

        # Infer solve/detect status if not explicitly provided.
        if "red_solved" in result:
            red_solved = bool(result["red_solved"])
        else:
            red_solved = outcome == "red_win" or bool(result.get("flags_found", []))

        if "blue_detected" in result:
            blue_detected = bool(result["blue_detected"])
        else:
            blue_detected = outcome == "blue_win"

        tier = result.get("tier")
        if tier is None:
            tier = 1

        # Collect extra metadata; a tuple keeps the stored key order stable.
        extra_keys = ("outcome", "flags_found", "steps", "red_model", "blue_model")
        extra = {k: result[k] for k in extra_keys if k in result}

        self.record_episode(
            snapshot_id=result.get("snapshot_id", ""),
            vuln_classes=result.get("vuln_classes") or [],
            red_solved=red_solved,
            blue_detected=blue_detected,
            tier=tier,
            attack_surfaces=result.get("attack_surfaces") or [],
            extra=extra or None,
        )