Spaces:
Runtime error
Runtime error
File size: 8,523 Bytes
"""Curriculum tracker for OpenRange training.

Tracks per-vuln-class and per-tier solve/detection rates across episodes
and feeds runtime context to the Builder/Mutator so it can target agent
weaknesses and adjust difficulty.

Example::

    tracker = CurriculumTracker()
    tracker.record_episode(
        snapshot_id="snap-001",
        vuln_classes=["sqli", "weak_creds"],
        red_solved=True,
        blue_detected=False,
        tier=1,
    )
    ctx = tracker.get_build_context()
    # ctx = {
    #     "previous_vuln_classes": ["sqli", "weak_creds"],
    #     "red_solve_rate": 1.0,
    #     "blue_detect_rate": 0.0,
    #     "weak_areas": [],  # a class needs >= 3 attempts before it is flagged
    #     "recent_attack_surfaces": [],  # no attack_surfaces were passed above
    #     "episode_count": 1,
    # }
"""
from __future__ import annotations
from collections import defaultdict
from typing import Any
class CurriculumTracker:
"""Track episode outcomes for curriculum-driven snapshot generation.
Maintains per-vuln-class and per-tier statistics so the Builder
can target agent weaknesses and calibrate difficulty.
"""
def __init__(self, history_window: int = 20) -> None:
self.history_window = history_window
self.vuln_stats: dict[str, dict[str, int]] = defaultdict(
lambda: {"attempts": 0, "red_solves": 0, "blue_detects": 0}
)
self.tier_stats: dict[int, dict[str, Any]] = defaultdict(
lambda: {"episodes": 0, "red_solves": 0, "blue_detects": 0}
)
self.episode_history: list[dict[str, Any]] = []
def record_episode(
self,
snapshot_id: str,
vuln_classes: list[str],
red_solved: bool,
blue_detected: bool,
tier: int = 1,
attack_surfaces: list[str] | None = None,
extra: dict[str, Any] | None = None,
) -> None:
"""Record the outcome of a completed episode.
Args:
snapshot_id: Identifier of the snapshot used.
vuln_classes: Vulnerability classes planted in the episode.
red_solved: Whether Red captured a flag.
blue_detected: Whether Blue detected the attack.
tier: Difficulty tier of the episode.
attack_surfaces: Injection points used (e.g. "/search?q=").
extra: Additional metadata to store.
"""
record = {
"snapshot_id": snapshot_id,
"vuln_classes": vuln_classes,
"red_solved": red_solved,
"blue_detected": blue_detected,
"tier": tier,
"attack_surfaces": attack_surfaces or [],
**(extra or {}),
}
self.episode_history.append(record)
# Trim to window
if len(self.episode_history) > self.history_window * 2:
self.episode_history = self.episode_history[-self.history_window:]
# Update per-vuln stats
for vc in vuln_classes:
self.vuln_stats[vc]["attempts"] += 1
if red_solved:
self.vuln_stats[vc]["red_solves"] += 1
if blue_detected:
self.vuln_stats[vc]["blue_detects"] += 1
# Update per-tier stats
self.tier_stats[tier]["episodes"] += 1
if red_solved:
self.tier_stats[tier]["red_solves"] += 1
if blue_detected:
self.tier_stats[tier]["blue_detects"] += 1
def get_build_context(self) -> dict[str, Any]:
"""Generate runtime context for the Builder/Mutator.
Returns a dict suitable for passing as ``BuildContext`` fields:
previous_vuln_classes, red_solve_rate, blue_detect_rate,
weak_areas, recent_attack_surfaces, episode_count.
"""
recent = self.episode_history[-self.history_window:]
# Previous vuln classes (last 5 episodes for diversity enforcement)
last_5 = self.episode_history[-5:]
prev_vulns: list[str] = []
for ep in last_5:
prev_vulns.extend(ep.get("vuln_classes", []))
# Overall solve/detect rates over window
if recent:
red_solve_rate = sum(1 for e in recent if e["red_solved"]) / len(recent)
blue_detect_rate = sum(1 for e in recent if e["blue_detected"]) / len(recent)
else:
red_solve_rate = 0.0
blue_detect_rate = 0.0
# Weak areas: vuln classes where Red solves >80% or Blue detects <20%
weak_areas: list[str] = []
for vc, stats in self.vuln_stats.items():
if stats["attempts"] >= 3:
solve_rate = stats["red_solves"] / stats["attempts"]
detect_rate = stats["blue_detects"] / stats["attempts"]
# Red finds these too easy -- need harder variants
if solve_rate > 0.8:
weak_areas.append(vc)
# Blue can't detect these -- need more practice
if detect_rate < 0.2:
weak_areas.append(vc)
# Deduplicate
weak_areas = list(dict.fromkeys(weak_areas))
# Recent attack surfaces (last 5 episodes)
recent_surfaces: list[str] = []
for ep in last_5:
recent_surfaces.extend(ep.get("attack_surfaces", []))
return {
"previous_vuln_classes": prev_vulns,
"red_solve_rate": red_solve_rate,
"blue_detect_rate": blue_detect_rate,
"weak_areas": weak_areas,
"recent_attack_surfaces": recent_surfaces,
"episode_count": len(self.episode_history),
}
def should_escalate_tier(self, current_tier: int, threshold: float = 0.8) -> bool:
"""Check if the agent should move to a harder tier.
Escalation happens when Red solve rate exceeds ``threshold``
over the history window for the current tier.
"""
stats = self.tier_stats.get(current_tier)
if not stats or stats["episodes"] < 5:
return False
solve_rate = stats["red_solves"] / stats["episodes"]
return solve_rate >= threshold
def get_vuln_solve_rates(self) -> dict[str, float]:
"""Return per-vuln-class solve rates for analysis."""
rates: dict[str, float] = {}
for vc, stats in self.vuln_stats.items():
if stats["attempts"] > 0:
rates[vc] = stats["red_solves"] / stats["attempts"]
else:
rates[vc] = 0.0
return rates
def update_from_result(self, result: dict) -> None:
"""Update curriculum stats from an episode result.
Accepts a dict with the following optional keys:
- ``snapshot_id`` (str): episode/snapshot identifier
- ``vuln_classes`` (list[str]): vulnerability classes in the episode
- ``red_solved`` (bool): whether Red captured a flag
- ``blue_detected`` (bool): whether Blue detected the attack
- ``tier`` (int): difficulty tier
- ``attack_surfaces`` (list[str]): injection points used
- ``outcome`` (str): episode outcome (``red_win``, ``blue_win``, ``timeout``)
- ``flags_found`` (list[str]): captured flags
- ``steps`` (int): total steps taken
If ``red_solved`` / ``blue_detected`` are not provided they are
inferred from ``outcome`` and ``flags_found``.
"""
snapshot_id = result.get("snapshot_id", "")
vuln_classes = result.get("vuln_classes", [])
tier = result.get("tier", 1)
attack_surfaces = result.get("attack_surfaces", [])
# Infer solve/detect status if not explicitly provided
if "red_solved" in result:
red_solved = bool(result["red_solved"])
else:
outcome = result.get("outcome", "")
flags = result.get("flags_found", [])
red_solved = outcome == "red_win" or bool(flags)
if "blue_detected" in result:
blue_detected = bool(result["blue_detected"])
else:
blue_detected = result.get("outcome", "") == "blue_win"
# Collect extra metadata
extra_keys = {
"outcome", "flags_found", "steps",
"red_model", "blue_model",
}
extra = {k: result[k] for k in extra_keys if k in result}
self.record_episode(
snapshot_id=snapshot_id,
vuln_classes=vuln_classes,
red_solved=red_solved,
blue_detected=blue_detected,
tier=tier,
attack_surfaces=attack_surfaces,
extra=extra if extra else None,
)
|