File size: 6,479 Bytes
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769dd2e
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769dd2e
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fedc25
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49d1c75
8c486a8
 
 
 
 
49d1c75
8c486a8
 
 
 
 
 
 
 
49d1c75
8c486a8
 
 
 
 
49d1c75
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769dd2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c486a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""Episode orchestration loop.

``run_episode`` connects to an OpenRange environment and alternates
Red / Blue turns until the episode ends or ``max_steps`` is reached.
Agents only see observations -- they cannot control episode lifecycle.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from open_range.agents.protocol import EpisodeMetrics, EpisodeResult

if TYPE_CHECKING:
    from open_range.agents.protocol import RangeAgent
    from open_range.training.curriculum import CurriculumTracker

logger = logging.getLogger(__name__)


def _determine_outcome(
    flags_found: list[str],
    total_flags: int,
    steps: int,
    max_steps: int,
) -> str:
    """Determine episode outcome from final state."""
    if flags_found and len(flags_found) >= total_flags > 0:
        return "red_win"
    if steps >= max_steps:
        return "timeout"
    # If done without all flags and not timeout, Blue held
    return "blue_win"


def _compute_metrics(
    result: EpisodeResult,
    total_flags: int,
) -> EpisodeMetrics:
    """Compute episode metrics from the recorded trajectories.

    Args:
        result: Episode result whose ``flags_found``, ``red_trajectory`` and
            ``blue_trajectory`` are read. Trajectory entries are dicts; the
            ``"stdout"`` key of Red entries and the ``"command"`` key of Blue
            entries are inspected (missing or ``None`` values count as empty).
        total_flags: Number of flags available in the episode. When it is not
            positive, ``flag_capture_rate`` is left at its default.

    Returns:
        A new ``EpisodeMetrics`` with ``flag_capture_rate``,
        ``steps_to_first_flag``, ``stealth`` and ``detection_tp`` set where
        they can be computed from the trajectories.
    """
    metrics = EpisodeMetrics()

    # Flag capture rate
    if total_flags > 0:
        metrics.flag_capture_rate = len(result.flags_found) / total_flags

    # Steps to first flag: 1-based position of the first Red step whose output
    # contains one of the acceptance markers.
    for position, entry in enumerate(result.red_trajectory, start=1):
        stdout = entry.get("stdout") or ""
        if "Flag accepted" in stdout or "Correct" in stdout:
            metrics.steps_to_first_flag = position
            break

    total_red = len(result.red_trajectory)
    if total_red == 0:
        # Red never acted, so there was nothing to detect.
        metrics.stealth = 1.0
        return metrics

    # Approximation: every Blue ``submit_finding`` command is treated as one
    # detected Red action; no matching against specific Red actions is done.
    blue_findings = sum(
        1
        for entry in result.blue_trajectory
        if (entry.get("command") or "").startswith("submit_finding")
    )

    # Capping the count at total_red keeps the ratio within [0, 1].
    detected_ratio = min(blue_findings, total_red) / total_red

    # Stealth: 1 - (detected_actions / total_red_actions)
    metrics.stealth = 1.0 - detected_ratio
    # Detection TP: fraction of Red actions Blue reported
    metrics.detection_tp = detected_ratio

    return metrics


def _trajectory_entry(command: str, obs: object) -> dict:
    """Build one trajectory record from a command and the observation it produced."""
    return {
        "command": command,
        "stdout": obs.stdout,
        "stderr": getattr(obs, "stderr", ""),
        "alerts": getattr(obs, "alerts", []),
        "reward": obs.reward,
    }


def run_episode(
    env: object,
    red: RangeAgent,
    blue: RangeAgent,
    max_steps: int = 100,
    red_model: str = "",
    blue_model: str = "",
    curriculum: CurriculumTracker | None = None,
) -> EpisodeResult:
    """Run one tandem Red + Blue episode.

    The orchestration layer calls ``reset()`` and alternates agent turns.
    Agents only see observations -- they cannot control episode lifecycle.

    This function works with the ``RangeEnvironment`` directly (no HTTP).
    For remote environments, use the async variant or call through the client.

    Args:
        env: A ``RangeEnvironment`` instance (or anything with ``reset``/``step``/``state``).
        red: Red team agent (satisfies ``RangeAgent`` protocol).
        blue: Blue team agent (satisfies ``RangeAgent`` protocol).
        max_steps: Maximum total steps (Red + Blue combined). The budget is
            checked before every turn, so it is never exceeded even when it
            is odd.
        red_model: Model identifier for logging. Falls back to ``red.model``
            (or ``""``) when empty.
        blue_model: Model identifier for logging. Falls back to ``blue.model``
            (or ``""``) when empty.
        curriculum: Optional tracker; when given, its ``update_from_result``
            is called once with a summary dict of the finished episode.

    Returns:
        ``EpisodeResult`` with trajectories, metrics, and outcome.
    """
    from open_range.models import RangeAction

    # Reset environment; the first observation's stdout is the briefing.
    obs = env.reset()
    briefing = obs.stdout

    # Initialize agents
    red.reset(briefing=briefing, role="red")
    blue.reset(briefing=briefing, role="blue")

    red_trajectory: list[dict] = []
    blue_trajectory: list[dict] = []
    step = 0

    while not obs.done and step < max_steps:
        # Red's turn
        red_cmd = red.act(obs)
        obs = env.step(RangeAction(command=red_cmd, mode="red"))
        red_trajectory.append(_trajectory_entry(red_cmd, obs))
        step += 1

        # Re-check the budget before Blue acts: the loop condition is only
        # evaluated once per Red/Blue pair, so without this an odd
        # ``max_steps`` would be overshot by one Blue step.
        if obs.done or step >= max_steps:
            break

        # Blue's turn
        blue_cmd = blue.act(obs)
        obs = env.step(RangeAction(command=blue_cmd, mode="blue"))
        blue_trajectory.append(_trajectory_entry(blue_cmd, obs))
        step += 1

    # Gather final state
    env_state = env.state
    flags_found = getattr(env_state, "flags_found", None) or []
    tier = getattr(env_state, "tier", 1)
    snapshot_id = getattr(env_state, "episode_id", "")

    # Determine total flags available (0 when no snapshot or no flags on it).
    snapshot = getattr(env, "snapshot", None) or getattr(env, "_snapshot", None)
    snapshot_flags = getattr(snapshot, "flags", None) if snapshot else None
    total_flags = len(snapshot_flags) if snapshot_flags else 0

    outcome = _determine_outcome(flags_found, total_flags, step, max_steps)

    # Explicit identifiers win; otherwise fall back to the agent's attribute.
    red_model_id = red_model or getattr(red, "model", "")
    blue_model_id = blue_model or getattr(blue, "model", "")

    result = EpisodeResult(
        red_trajectory=red_trajectory,
        blue_trajectory=blue_trajectory,
        flags_found=list(flags_found),
        steps=step,
        tier=tier,
        snapshot_id=snapshot_id,
        red_model=red_model_id,
        blue_model=blue_model_id,
        outcome=outcome,
    )

    result.metrics = _compute_metrics(result, total_flags)

    logger.info(
        "Episode %s complete: outcome=%s, steps=%d, flags=%d/%d",
        snapshot_id,
        outcome,
        step,
        len(flags_found),
        total_flags,
    )

    # Curriculum feedback wiring (#34)
    if curriculum is not None:
        # Extract vuln classes from snapshot truth graph if available
        vuln_classes: list[str] = []
        truth_graph = getattr(snapshot, "truth_graph", None) if snapshot else None
        if truth_graph:
            vuln_types = (
                getattr(vuln, "type", "")
                for vuln in getattr(truth_graph, "vulns", None) or []
            )
            # Skip vulns without a (non-empty) type.
            vuln_classes = [vuln_type for vuln_type in vuln_types if vuln_type]

        curriculum.update_from_result({
            "snapshot_id": snapshot_id,
            "vuln_classes": vuln_classes,
            "outcome": outcome,
            # Separate copy so the tracker cannot alias the result's list.
            "flags_found": list(flags_found),
            "steps": step,
            "tier": tier,
            "red_model": red_model_id,
            "blue_model": blue_model_id,
        })

    return result