File size: 8,523 Bytes
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769dd2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""Curriculum tracker for OpenRange training.

Tracks per-vuln-class and per-tier solve/detection rates across episodes.
Feeds runtime context to the Builder/Mutator so it can target agent
weaknesses and adjust difficulty.

Example::

    tracker = CurriculumTracker()
    tracker.record_episode(
        snapshot_id="snap-001",
        vuln_classes=["sqli", "weak_creds"],
        red_solved=True,
        blue_detected=False,
        tier=1,
    )
    ctx = tracker.get_build_context()
    # ctx = {
    #     "previous_vuln_classes": ["sqli", "weak_creds"],
    #     "red_solve_rate": 1.0,
    #     "blue_detect_rate": 0.0,
    #     "weak_areas": [],  # a class needs >= 3 attempts before it is flagged
    #     "recent_attack_surfaces": [],  # none were passed above
    #     "episode_count": 1,
    # }
"""

from __future__ import annotations

from collections import defaultdict
from typing import Any


class CurriculumTracker:
    """Track episode outcomes for curriculum-driven snapshot generation.

    Maintains per-vuln-class and per-tier statistics so the Builder
    can target agent weaknesses and calibrate difficulty.

    The per-vuln and per-tier counters are cumulative over every recorded
    episode. Only ``episode_history`` is bounded: it is trimmed in batches,
    so it holds at most ``2 * history_window`` of the most recent records.
    """

    # Optional result keys that ``update_from_result`` copies into the
    # episode record as extra metadata (tuple => deterministic key order).
    _EXTRA_RESULT_KEYS = ("outcome", "flags_found", "steps", "red_model", "blue_model")

    def __init__(self, history_window: int = 20) -> None:
        """Create an empty tracker.

        Args:
            history_window: Number of most recent episodes used for the
                windowed solve/detect rates. A value <= 0 disables the
                history entirely (no episode records are retained).
        """
        self.history_window = history_window
        # vuln class -> {"attempts", "red_solves", "blue_detects"} counters
        self.vuln_stats: dict[str, dict[str, int]] = defaultdict(
            lambda: {"attempts": 0, "red_solves": 0, "blue_detects": 0}
        )
        # tier -> {"episodes", "red_solves", "blue_detects"} counters
        self.tier_stats: dict[int, dict[str, int]] = defaultdict(
            lambda: {"episodes": 0, "red_solves": 0, "blue_detects": 0}
        )
        self.episode_history: list[dict[str, Any]] = []

    @staticmethod
    def _tail(items: list[Any], count: int) -> list[Any]:
        """Return the last ``count`` items, or ``[]`` when ``count <= 0``.

        A plain ``items[-count:]`` returns the *whole* list for
        ``count == 0``, which is the edge case this helper guards against.
        """
        if count <= 0:
            return []
        return items[-count:]

    def record_episode(
        self,
        snapshot_id: str,
        vuln_classes: list[str],
        red_solved: bool,
        blue_detected: bool,
        tier: int = 1,
        attack_surfaces: list[str] | None = None,
        extra: dict[str, Any] | None = None,
    ) -> None:
        """Record the outcome of a completed episode.

        Args:
            snapshot_id: Identifier of the snapshot used.
            vuln_classes: Vulnerability classes planted in the episode.
            red_solved: Whether Red captured a flag.
            blue_detected: Whether Blue detected the attack.
            tier: Difficulty tier of the episode.
            attack_surfaces: Injection points used (e.g. "/search?q=").
            extra: Additional metadata to store. Keys that collide with the
                canonical record fields above are ignored, so the stored
                history always agrees with the aggregate counters.
        """
        record: dict[str, Any] = {
            "snapshot_id": snapshot_id,
            # Copy so later mutation of the caller's lists can't rewrite history.
            "vuln_classes": list(vuln_classes),
            "red_solved": red_solved,
            "blue_detected": blue_detected,
            "tier": tier,
            "attack_surfaces": list(attack_surfaces or []),
        }
        for key, value in (extra or {}).items():
            record.setdefault(key, value)
        self.episode_history.append(record)

        # Trim in batches: once the history exceeds twice the window, keep
        # only the most recent window's worth (avoids re-slicing every call).
        if len(self.episode_history) > self.history_window * 2:
            self.episode_history = self._tail(self.episode_history, self.history_window)

        # Update per-vuln stats
        for vc in vuln_classes:
            vc_stats = self.vuln_stats[vc]
            vc_stats["attempts"] += 1
            if red_solved:
                vc_stats["red_solves"] += 1
            if blue_detected:
                vc_stats["blue_detects"] += 1

        # Update per-tier stats
        tier_entry = self.tier_stats[tier]
        tier_entry["episodes"] += 1
        if red_solved:
            tier_entry["red_solves"] += 1
        if blue_detected:
            tier_entry["blue_detects"] += 1

    def get_build_context(self) -> dict[str, Any]:
        """Generate runtime context for the Builder/Mutator.

        Returns a dict suitable for passing as ``BuildContext`` fields:

        - ``previous_vuln_classes``: vuln classes from the last 5 episodes.
        - ``red_solve_rate`` / ``blue_detect_rate``: fractions over the last
          ``history_window`` episodes (0.0 when there are none).
        - ``weak_areas``: vuln classes with >= 3 attempts where Red solves
          more than 80% or Blue detects fewer than 20% (cumulative counts).
        - ``recent_attack_surfaces``: attack surfaces from the last 5 episodes.
        - ``episode_count``: total episodes recorded, independent of trimming.
        """
        recent = self._tail(self.episode_history, self.history_window)

        # Last 5 episodes, used for diversity enforcement.
        last_5 = self.episode_history[-5:]
        prev_vulns = [vc for ep in last_5 for vc in ep.get("vuln_classes", [])]
        recent_surfaces = [s for ep in last_5 for s in ep.get("attack_surfaces", [])]

        # Overall solve/detect rates over window
        if recent:
            n_recent = len(recent)
            red_solve_rate = sum(1 for e in recent if e["red_solved"]) / n_recent
            blue_detect_rate = sum(1 for e in recent if e["blue_detected"]) / n_recent
        else:
            red_solve_rate = 0.0
            blue_detect_rate = 0.0

        weak_areas: list[str] = []
        for vc, stats in self.vuln_stats.items():
            attempts = stats["attempts"]
            if attempts < 3:
                continue
            # Red finds these too easy -> need harder variants;
            # Blue can't detect these -> need more practice.
            too_easy_for_red = stats["red_solves"] / attempts > 0.8
            hard_for_blue = stats["blue_detects"] / attempts < 0.2
            if too_easy_for_red or hard_for_blue:
                weak_areas.append(vc)

        # The history list is trimmed, so count episodes from the cumulative
        # per-tier counters (every episode increments exactly one tier).
        episode_count = sum(stats["episodes"] for stats in self.tier_stats.values())

        return {
            "previous_vuln_classes": prev_vulns,
            "red_solve_rate": red_solve_rate,
            "blue_detect_rate": blue_detect_rate,
            "weak_areas": weak_areas,
            "recent_attack_surfaces": recent_surfaces,
            "episode_count": episode_count,
        }

    def should_escalate_tier(self, current_tier: int, threshold: float = 0.8) -> bool:
        """Check if the agent should move to a harder tier.

        Escalation happens when the cumulative Red solve rate over all
        recorded episodes at ``current_tier`` is at least ``threshold``.
        At least 5 episodes at that tier are required before escalating.
        """
        # .get() so querying an unseen tier doesn't insert a defaultdict entry.
        stats = self.tier_stats.get(current_tier)
        if not stats or stats["episodes"] < 5:
            return False
        solve_rate = stats["red_solves"] / stats["episodes"]
        return solve_rate >= threshold

    def get_vuln_solve_rates(self) -> dict[str, float]:
        """Return cumulative per-vuln-class Red solve rates for analysis."""
        return {
            vc: (stats["red_solves"] / stats["attempts"] if stats["attempts"] > 0 else 0.0)
            for vc, stats in self.vuln_stats.items()
        }

    def update_from_result(self, result: dict[str, Any]) -> None:
        """Update curriculum stats from an episode result.

        Accepts a dict with the following optional keys:

        - ``snapshot_id`` (str): episode/snapshot identifier
        - ``vuln_classes`` (list[str]): vulnerability classes in the episode
        - ``red_solved`` (bool): whether Red captured a flag
        - ``blue_detected`` (bool): whether Blue detected the attack
        - ``tier`` (int): difficulty tier
        - ``attack_surfaces`` (list[str]): injection points used
        - ``outcome`` (str): episode outcome (``red_win``, ``blue_win``, ``timeout``)
        - ``flags_found`` (list[str]): captured flags
        - ``steps`` (int): total steps taken
        - ``red_model`` / ``blue_model``: stored as extra metadata if present

        If ``red_solved`` / ``blue_detected`` are not provided they are
        inferred from ``outcome`` and ``flags_found``. List-valued keys that
        are present but ``None`` are treated as empty.
        """
        snapshot_id = result.get("snapshot_id", "")
        vuln_classes = result.get("vuln_classes") or []
        tier = result.get("tier", 1)
        attack_surfaces = result.get("attack_surfaces") or []

        # Infer solve/detect status if not explicitly provided
        if "red_solved" in result:
            red_solved = bool(result["red_solved"])
        else:
            outcome = result.get("outcome", "")
            flags = result.get("flags_found") or []
            red_solved = outcome == "red_win" or bool(flags)

        if "blue_detected" in result:
            blue_detected = bool(result["blue_detected"])
        else:
            blue_detected = result.get("outcome", "") == "blue_win"

        extra = {k: result[k] for k in self._EXTRA_RESULT_KEYS if k in result}

        self.record_episode(
            snapshot_id=snapshot_id,
            vuln_classes=vuln_classes,
            red_solved=red_solved,
            blue_detected=blue_detected,
            tier=tier,
            attack_surfaces=attack_surfaces,
            extra=extra or None,
        )