Ajayyy00 Claude Sonnet 4.6 commited on
Commit
5719ec3
·
1 Parent(s): eeadada

Add FSP multi-agent architecture: Red Team LLM action space + alternating turns

Browse files

Step 1 — models.py
• New Red action classes: LateralPivot, DeployPayload, EvadeDetection, PassTurn
• RedActionWrapper (mirrors SOCActionWrapper for WS/HTTP routing)
• RED_ACTION_TYPES frozenset for payload routing
• SOCObservation: active_turn + red_observation fields
• SOCState: active_turn field

Step 2 — play_environment.py
• fsp_mode=False constructor flag; True enables strict Blue/Red alternation
• step() dispatches to _step_blue() or _step_red() based on action type
• _step_blue(): executes Blue action; in fsp_mode auto-flips to red without
incrementing step_count; in legacy mode auto-PassTurns + increments (backward compat)
• _step_red(): executes Red action, increments step_count, flips to blue
• Red handlers: _handle_lateral_pivot, _handle_deploy_payload,
_handle_evade_detection, _handle_pass_turn
• _generate_red_observation(): compromised_hosts + blue_actions_detected
• Removed deterministic _execute_lateral_pivot and _maybe_reinfect;
_adversary_react is now a no-op (Red LLM drives all attack decisions)
• _build_observation: exposes active_turn and red_observation

Step 3 — dashboard_server.py
• Step handler routes RED_ACTION_TYPES to RedActionWrapper

Step 4 — inference.py
• RED_SYSTEM_PROMPT, format_red_observation, get_red_model_action
• run_episode(fsp=True): Blue LLM → env.step(SOCActionWrapper) →
if active_turn==red → Red LLM → env.step(RedActionWrapper) → repeat
• FSP_MODE env var to toggle; RED_MODEL_NAME for separate Red model

Step 5 — test_integration.py
• test_lateral_pivot_red_action replaces removed _execute_lateral_pivot test

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

dashboard_server.py CHANGED
@@ -176,9 +176,16 @@ async def ws_session(websocket: WebSocket, session_id: str):
176
  continue
177
 
178
  try:
179
- from models import SOCActionWrapper # noqa: PLC0415
180
  action_fields = {k: v for k, v in msg.items() if k != "type"}
181
- action = SOCActionWrapper.model_validate(action_fields)
 
 
 
 
 
 
 
182
  obs = await _run(env.step, action)
183
  await websocket.send_json({
184
  "type": "step_ok",
 
176
  continue
177
 
178
  try:
179
+ from models import SOCActionWrapper, RedActionWrapper, RED_ACTION_TYPES # noqa: PLC0415
180
  action_fields = {k: v for k, v in msg.items() if k != "type"}
181
+ action_type_str = action_fields.get("type", "")
182
+
183
+ # Route to Red or Blue wrapper based on action type
184
+ if action_type_str in RED_ACTION_TYPES:
185
+ action = RedActionWrapper.model_validate(action_fields)
186
+ else:
187
+ action = SOCActionWrapper.model_validate(action_fields)
188
+
189
  obs = await _run(env.step, action)
190
  await websocket.send_json({
191
  "type": "step_ok",
inference.py CHANGED
@@ -16,9 +16,11 @@ HACKATHON RULES:
16
  - Must work on vcpu=2, memory=8gb
17
 
18
  Environment Variables:
19
- API_BASE_URL - The API endpoint for the LLM
20
- MODEL_NAME - The model identifier to use for inference
21
- HF_TOKEN - Your Hugging Face / API key
 
 
22
  """
23
 
24
  import asyncio
@@ -29,29 +31,31 @@ from typing import Any, Dict, List, Optional
29
 
30
  from openai import OpenAI
31
 
32
- from models import SOCActionWrapper, SOCObservation
33
  from server.play_environment import CyberSOCEnvironment
34
 
35
  # =============================================================================
36
  # Configuration (from environment variables)
37
  # =============================================================================
38
 
39
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
40
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
41
- HF_TOKEN = os.getenv("HF_TOKEN")
 
 
42
 
43
  BENCHMARK = "cybersocenv"
44
- TASKS = ["easy", "medium", "hard"]
45
  MAX_STEPS = {"easy": 15, "medium": 25, "hard": 30}
 
46
  TEMPERATURE = 0.1
47
- MAX_TOKENS = 1024
48
 
49
- # Scoring: normalize rewards to [0, 1]
50
- MAX_POSSIBLE_REWARD = 2.0 # Approximate max reward per episode
51
  SUCCESS_SCORE_THRESHOLD = 0.3
52
 
53
  # =============================================================================
54
- # System Prompt
55
  # =============================================================================
56
 
57
  SYSTEM_PROMPT = textwrap.dedent("""
@@ -79,9 +83,43 @@ SYSTEM_PROMPT = textwrap.dedent("""
79
  - You have a limited number of steps. Be efficient.
80
  """).strip()
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # =============================================================================
84
- # Logging Helpers (EXACT hackathon format — lowercase booleans, null errors)
85
  # =============================================================================
86
 
87
  def log_start(task: str, env: str, model: str) -> None:
@@ -90,7 +128,7 @@ def log_start(task: str, env: str, model: str) -> None:
90
 
91
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
92
  error_val = error if error else "null"
93
- done_val = str(done).lower()
94
  print(
95
  f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
96
  flush=True,
@@ -104,16 +142,14 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
104
  flush=True,
105
  )
106
 
107
-
108
  # =============================================================================
109
- # Observation Formatting for LLM
110
  # =============================================================================
111
 
112
  def format_observation(obs: SOCObservation) -> str:
113
- """Format observation into readable text for the LLM."""
114
  parts = []
115
 
116
- # Alert queue
117
  if obs.alert_queue:
118
  parts.append(f"## Active Alerts ({len(obs.alert_queue)}):")
119
  for a in obs.alert_queue:
@@ -124,14 +160,14 @@ def format_observation(obs: SOCObservation) -> str:
124
  if a.ioc_indicators:
125
  parts.append(f" IOCs: {', '.join(a.ioc_indicators)}")
126
 
127
- # Network topology
128
  topo = obs.network_topology
129
  parts.append(f"\n## Network Status:")
130
- parts.append(f" Compromised: {topo.compromised_count} | "
131
- f"Isolated: {topo.isolated_count} | "
132
- f"Online: {topo.online_count}")
 
 
133
 
134
- # Forensics
135
  if obs.host_forensics:
136
  f = obs.host_forensics
137
  parts.append(f"\n## Forensics Result ({f.hostname}):")
@@ -141,26 +177,53 @@ def format_observation(obs: SOCObservation) -> str:
141
  parts.append(f" Network connections: {f.network_connections}")
142
  parts.append(f" Memory artifacts: {f.memory_artifacts}")
143
 
144
- # Active threats
145
  parts.append(f"\n## Active Threats: {obs.active_threats if obs.active_threats else 'None (all contained!)'}")
146
  parts.append(f"## Business Impact: {obs.business_impact_score:.2f}")
147
  parts.append(f"## Step: {obs.step_count} / {obs.max_steps}")
148
 
149
- # Timeline (last 5)
150
  if obs.timeline:
151
  parts.append(f"\n## Recent Actions:")
152
  for t in obs.timeline[-5:]:
153
- parts.append(f" Step {t.step}: {t.action_type} -> {t.target} (reward={t.reward:.2f})")
 
154
 
155
  return "\n".join(parts)
156
 
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def parse_llm_action(content: str) -> Dict[str, Any]:
159
  """Parse the LLM's response into a valid action dict."""
160
  content = content.strip()
161
  if content.startswith("```"):
162
- lines = content.split("\n")
163
- lines = [l for l in lines if not l.strip().startswith("```")]
164
  content = "\n".join(lines).strip()
165
 
166
  try:
@@ -170,7 +233,6 @@ def parse_llm_action(content: str) -> Dict[str, Any]:
170
  except json.JSONDecodeError:
171
  pass
172
 
173
- # Try to find JSON in the response
174
  for start in range(len(content)):
175
  if content[start] == "{":
176
  for end in range(len(content), start, -1):
@@ -185,6 +247,10 @@ def parse_llm_action(content: str) -> Dict[str, Any]:
185
  raise ValueError(f"Could not parse action from LLM response: {content[:200]}")
186
 
187
 
 
 
 
 
188
  def get_model_action(
189
  client: OpenAI,
190
  step: int,
@@ -192,7 +258,7 @@ def get_model_action(
192
  task_id: str,
193
  history: List[str],
194
  ) -> str:
195
- """Get the next action from the LLM."""
196
  obs_text = format_observation(obs)
197
 
198
  if step == 1:
@@ -213,7 +279,7 @@ def get_model_action(
213
  model=MODEL_NAME,
214
  messages=[
215
  {"role": "system", "content": SYSTEM_PROMPT},
216
- {"role": "user", "content": user_content},
217
  ],
218
  temperature=TEMPERATURE,
219
  max_tokens=MAX_TOKENS,
@@ -223,73 +289,155 @@ def get_model_action(
223
  return text if text else '{"type": "query_host", "hostname": "WS-001"}'
224
  except Exception as exc:
225
  if "429" in str(exc) or "RateLimit" in str(exc):
226
- raise # Let the batch runner handle rate limits
227
- print(f"[DEBUG] Model request failed: {exc}", flush=True)
228
  return '{"type": "query_host", "hostname": "WS-001"}'
229
 
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  # =============================================================================
232
  # Episode Runner
233
  # =============================================================================
234
 
235
- async def run_episode(client: OpenAI, task_id: str) -> tuple:
236
- """Run a single episode. Returns (success, steps, score, rewards)."""
237
- env = CyberSOCEnvironment()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  history: List[str] = []
239
  rewards: List[float] = []
240
  steps_taken = 0
241
- score = 0.0
242
  success = False
243
 
244
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
245
 
246
  try:
247
- # Reset environment
248
  obs = env.reset(task_id=task_id)
249
-
250
  max_steps = MAX_STEPS.get(task_id, 30)
251
 
252
  for step in range(1, max_steps + 1):
253
  if obs.done:
254
  break
255
 
256
- # Get action from LLM
257
- llm_response = get_model_action(client, step, obs, task_id, history)
258
 
259
- # Parse and execute
260
- error = None
261
  action_str = "unknown"
262
- reward = 0.0
 
263
 
264
  try:
265
- action_dict = parse_llm_action(llm_response)
266
- action_str = action_dict.get("type", "unknown")
267
- action = SOCActionWrapper(**action_dict)
268
- obs = env.step(action)
269
- reward = obs.reward or 0.0
270
- done = obs.done
271
  except Exception as exc:
272
  error = str(exc)[:200]
273
- done = False
274
- reward = 0.0
275
 
276
  rewards.append(reward)
277
  steps_taken = step
278
-
279
  log_step(step=step, action=action_str, reward=reward, done=done, error=error)
280
-
281
- history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}")
282
 
283
  if done:
284
  break
285
 
286
- # Calculate score from final_score if available, else normalize rewards
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  if obs.final_score is not None:
288
  score = obs.final_score
289
  else:
290
  score = sum(rewards) / MAX_POSSIBLE_REWARD if MAX_POSSIBLE_REWARD > 0 else 0.0
291
-
292
- score = min(max(score, 0.0), 1.0) # clamp to [0, 1]
293
  success = score >= SUCCESS_SCORE_THRESHOLD
294
 
295
  finally:
@@ -304,14 +452,19 @@ async def run_episode(client: OpenAI, task_id: str) -> tuple:
304
 
305
  async def main() -> None:
306
  """Run baseline inference across all tasks."""
307
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
 
308
 
309
- total_scores = {}
310
  for task_id in TASKS:
311
- success, steps, score, rewards = await run_episode(client, task_id)
 
 
 
 
 
312
  total_scores[task_id] = score
313
 
314
- # Print summary
315
  avg = sum(total_scores.values()) / len(total_scores) if total_scores else 0.0
316
  print(f"\n# Summary: avg_score={avg:.3f}", flush=True)
317
  for tid, s in total_scores.items():
 
16
  - Must work on vcpu=2, memory=8gb
17
 
18
  Environment Variables:
19
+ API_BASE_URL - The API endpoint for the LLM
20
+ MODEL_NAME - Blue Team model identifier
21
+ RED_MODEL_NAME - Red Team model identifier (defaults to MODEL_NAME)
22
+ HF_TOKEN - Your Hugging Face / API key
23
+ FSP_MODE - Set to "true" to enable Fictitious Self-Play (Blue+Red alternate)
24
  """
25
 
26
  import asyncio
 
31
 
32
  from openai import OpenAI
33
 
34
+ from models import SOCActionWrapper, RedActionWrapper, SOCObservation
35
  from server.play_environment import CyberSOCEnvironment
36
 
37
  # =============================================================================
38
  # Configuration (from environment variables)
39
  # =============================================================================
40
 
41
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
42
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
43
+ RED_MODEL_NAME = os.getenv("RED_MODEL_NAME", MODEL_NAME) # defaults to same model
44
+ HF_TOKEN = os.getenv("HF_TOKEN")
45
+ FSP_MODE = os.getenv("FSP_MODE", "false").lower() == "true"
46
 
47
  BENCHMARK = "cybersocenv"
48
+ TASKS = ["easy", "medium", "hard"]
49
  MAX_STEPS = {"easy": 15, "medium": 25, "hard": 30}
50
+
51
  TEMPERATURE = 0.1
52
+ MAX_TOKENS = 1024
53
 
54
+ MAX_POSSIBLE_REWARD = 2.0
 
55
  SUCCESS_SCORE_THRESHOLD = 0.3
56
 
57
  # =============================================================================
58
+ # Blue Team System Prompt
59
  # =============================================================================
60
 
61
  SYSTEM_PROMPT = textwrap.dedent("""
 
83
  - You have a limited number of steps. Be efficient.
84
  """).strip()
85
 
86
+ # =============================================================================
87
+ # Red Team System Prompt (FSP mode)
88
+ # =============================================================================
89
+
90
+ RED_SYSTEM_PROMPT = textwrap.dedent("""
91
+ You are an adversarial Red Team AI in a cybersecurity simulation.
92
+ You have already gained an initial foothold on the network and must expand your attack
93
+ while evading the Blue Team SOC analyst who is trying to contain you.
94
+
95
+ Your goal: spread to new hosts, deploy payloads to maximize business impact, and evade
96
+ detection — all before the Blue Team shuts you down.
97
+
98
+ Available Actions (respond with exactly ONE JSON object per turn):
99
+
100
+ 1. Lateral pivot to a new host:
101
+ {"type": "lateral_pivot", "source_host": "<COMPROMISED_HOST>", "target_host": "<TARGET>"}
102
+
103
+ 2. Deploy a payload on a host you control:
104
+ {"type": "deploy_payload", "hostname": "<HOST>", "payload_type": "ransomware|exfiltration|c2"}
105
+
106
+ 3. Evade detection on a host you control:
107
+ {"type": "evade_detection", "hostname": "<HOST>", "technique": "migrate_pid|clear_logs"}
108
+
109
+ 4. Stay stealthy (do nothing this turn):
110
+ {"type": "pass_turn"}
111
+
112
+ Rules:
113
+ - Respond with ONLY a valid JSON object. No markdown, no explanation.
114
+ - You can only pivot FROM a host listed in compromised_hosts.
115
+ - You cannot pivot TO an isolated host — Blue has cut that path.
116
+ - Use evade_detection when Blue runs forensics on your hosts.
117
+ - Use pass_turn when staying hidden is more valuable than acting.
118
+ - Ransomware causes the most business damage; use it on high-value hosts.
119
+ """).strip()
120
 
121
  # =============================================================================
122
+ # Logging Helpers (EXACT hackathon format)
123
  # =============================================================================
124
 
125
  def log_start(task: str, env: str, model: str) -> None:
 
128
 
129
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
130
  error_val = error if error else "null"
131
+ done_val = str(done).lower()
132
  print(
133
  f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
134
  flush=True,
 
142
  flush=True,
143
  )
144
 
 
145
  # =============================================================================
146
+ # Observation Formatting
147
  # =============================================================================
148
 
149
  def format_observation(obs: SOCObservation) -> str:
150
+ """Format Blue Team observation into readable text for the LLM."""
151
  parts = []
152
 
 
153
  if obs.alert_queue:
154
  parts.append(f"## Active Alerts ({len(obs.alert_queue)}):")
155
  for a in obs.alert_queue:
 
160
  if a.ioc_indicators:
161
  parts.append(f" IOCs: {', '.join(a.ioc_indicators)}")
162
 
 
163
  topo = obs.network_topology
164
  parts.append(f"\n## Network Status:")
165
+ parts.append(
166
+ f" Compromised: {topo.compromised_count} | "
167
+ f"Isolated: {topo.isolated_count} | "
168
+ f"Online: {topo.online_count}"
169
+ )
170
 
 
171
  if obs.host_forensics:
172
  f = obs.host_forensics
173
  parts.append(f"\n## Forensics Result ({f.hostname}):")
 
177
  parts.append(f" Network connections: {f.network_connections}")
178
  parts.append(f" Memory artifacts: {f.memory_artifacts}")
179
 
 
180
  parts.append(f"\n## Active Threats: {obs.active_threats if obs.active_threats else 'None (all contained!)'}")
181
  parts.append(f"## Business Impact: {obs.business_impact_score:.2f}")
182
  parts.append(f"## Step: {obs.step_count} / {obs.max_steps}")
183
 
 
184
  if obs.timeline:
185
  parts.append(f"\n## Recent Actions:")
186
  for t in obs.timeline[-5:]:
187
+ if not t.action_type.startswith("red:"):
188
+ parts.append(f" Step {t.step}: {t.action_type} -> {t.target} (reward={t.reward:.2f})")
189
 
190
  return "\n".join(parts)
191
 
192
 
193
+ def format_red_observation(red_obs: Dict[str, Any]) -> str:
194
+ """Format Red Team observation into readable text for the Red LLM."""
195
+ parts = []
196
+
197
+ parts.append(f"## Round: {red_obs.get('round', '?')}")
198
+
199
+ compromised = red_obs.get("compromised_hosts", [])
200
+ parts.append(f"\n## Your Compromised Hosts ({len(compromised)}):")
201
+ for h in compromised:
202
+ parts.append(f" - {h}")
203
+
204
+ blue_actions = red_obs.get("blue_actions_detected", [])
205
+ if blue_actions:
206
+ parts.append("\n## Blue Team's Last Action (detected):")
207
+ for ba in blue_actions:
208
+ parts.append(f" Step {ba['step']}: {ba['action']} -> {ba['target']}")
209
+ else:
210
+ parts.append("\n## Blue Team's Last Action: (none detected yet)")
211
+
212
+ parts.append(f"\n## Active Threats Still Live: {red_obs.get('active_threats', [])}")
213
+ parts.append(f"## Business Impact So Far: {red_obs.get('business_impact', 0.0):.2f}")
214
+
215
+ return "\n".join(parts)
216
+
217
+
218
+ # =============================================================================
219
+ # LLM Action Parsing
220
+ # =============================================================================
221
+
222
  def parse_llm_action(content: str) -> Dict[str, Any]:
223
  """Parse the LLM's response into a valid action dict."""
224
  content = content.strip()
225
  if content.startswith("```"):
226
+ lines = [l for l in content.split("\n") if not l.strip().startswith("```")]
 
227
  content = "\n".join(lines).strip()
228
 
229
  try:
 
233
  except json.JSONDecodeError:
234
  pass
235
 
 
236
  for start in range(len(content)):
237
  if content[start] == "{":
238
  for end in range(len(content), start, -1):
 
247
  raise ValueError(f"Could not parse action from LLM response: {content[:200]}")
248
 
249
 
250
+ # =============================================================================
251
+ # LLM Callers
252
+ # =============================================================================
253
+
254
  def get_model_action(
255
  client: OpenAI,
256
  step: int,
 
258
  task_id: str,
259
  history: List[str],
260
  ) -> str:
261
+ """Get the next Blue Team action from the LLM."""
262
  obs_text = format_observation(obs)
263
 
264
  if step == 1:
 
279
  model=MODEL_NAME,
280
  messages=[
281
  {"role": "system", "content": SYSTEM_PROMPT},
282
+ {"role": "user", "content": user_content},
283
  ],
284
  temperature=TEMPERATURE,
285
  max_tokens=MAX_TOKENS,
 
289
  return text if text else '{"type": "query_host", "hostname": "WS-001"}'
290
  except Exception as exc:
291
  if "429" in str(exc) or "RateLimit" in str(exc):
292
+ raise
293
+ print(f"[DEBUG] Blue model request failed: {exc}", flush=True)
294
  return '{"type": "query_host", "hostname": "WS-001"}'
295
 
296
 
297
+ def get_red_model_action(
298
+ client: OpenAI,
299
+ step: int,
300
+ red_obs: Dict[str, Any],
301
+ task_id: str,
302
+ ) -> str:
303
+ """Get the next Red Team action from the Red LLM."""
304
+ obs_text = format_red_observation(red_obs)
305
+
306
+ compromised = red_obs.get("compromised_hosts", [])
307
+ if not compromised:
308
+ return '{"type": "pass_turn"}'
309
+
310
+ if step == 1:
311
+ user_content = (
312
+ f"## Mission Briefing (Task: {task_id.upper()})\n\n"
313
+ f"{obs_text}\n\n"
314
+ f"You have initial footholds. Plan your next move. Respond with a single JSON action."
315
+ )
316
+ else:
317
+ user_content = (
318
+ f"## Situation Update:\n\n"
319
+ f"{obs_text}\n\n"
320
+ f"Choose your next Red Team action. Respond with a single JSON action."
321
+ )
322
+
323
+ try:
324
+ completion = client.chat.completions.create(
325
+ model=RED_MODEL_NAME,
326
+ messages=[
327
+ {"role": "system", "content": RED_SYSTEM_PROMPT},
328
+ {"role": "user", "content": user_content},
329
+ ],
330
+ temperature=TEMPERATURE,
331
+ max_tokens=512,
332
+ stream=False,
333
+ )
334
+ text = (completion.choices[0].message.content or "").strip()
335
+ return text if text else '{"type": "pass_turn"}'
336
+ except Exception as exc:
337
+ if "429" in str(exc) or "RateLimit" in str(exc):
338
+ raise
339
+ print(f"[DEBUG] Red model request failed: {exc}", flush=True)
340
+ return '{"type": "pass_turn"}'
341
+
342
+
343
  # =============================================================================
344
  # Episode Runner
345
  # =============================================================================
346
 
347
+ async def run_episode(
348
+ blue_client: OpenAI,
349
+ task_id: str,
350
+ red_client: Optional[OpenAI] = None,
351
+ fsp: bool = False,
352
+ ) -> tuple:
353
+ """Run a single episode. Returns (success, steps, score, rewards).
354
+
355
+ Args:
356
+ blue_client: OpenAI client for the Blue Team LLM.
357
+ task_id: Task difficulty ('easy', 'medium', 'hard').
358
+ red_client: OpenAI client for the Red Team LLM (FSP mode only).
359
+ Falls back to blue_client when None.
360
+ fsp: When True, enables Fictitious Self-Play (Blue + Red alternate).
361
+ """
362
+ if red_client is None:
363
+ red_client = blue_client
364
+
365
+ env = CyberSOCEnvironment(fsp_mode=fsp)
366
  history: List[str] = []
367
  rewards: List[float] = []
368
  steps_taken = 0
369
+ score = 0.0
370
  success = False
371
 
372
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
373
 
374
  try:
 
375
  obs = env.reset(task_id=task_id)
 
376
  max_steps = MAX_STEPS.get(task_id, 30)
377
 
378
  for step in range(1, max_steps + 1):
379
  if obs.done:
380
  break
381
 
382
+ # ── Blue Turn ────────────────────────────────────────────────────
383
+ blue_response = get_model_action(blue_client, step, obs, task_id, history)
384
 
385
+ error = None
 
386
  action_str = "unknown"
387
+ reward = 0.0
388
+ done = False
389
 
390
  try:
391
+ action_dict = parse_llm_action(blue_response)
392
+ action_str = action_dict.get("type", "unknown")
393
+ blue_action = SOCActionWrapper(**action_dict)
394
+ obs = env.step(blue_action)
395
+ reward = obs.reward or 0.0
396
+ done = obs.done
397
  except Exception as exc:
398
  error = str(exc)[:200]
399
+ done = False
 
400
 
401
  rewards.append(reward)
402
  steps_taken = step
 
403
  log_step(step=step, action=action_str, reward=reward, done=done, error=error)
404
+ history.append(f"Step {step} [Blue]: {action_str} -> reward {reward:+.2f}")
 
405
 
406
  if done:
407
  break
408
 
409
+ # ── Red Turn (FSP mode only) ──────────────────────────────────────
410
+ if fsp and getattr(obs, "active_turn", "blue") == "red":
411
+ red_obs_data = obs.red_observation or {}
412
+ red_response = get_red_model_action(red_client, step, red_obs_data, task_id)
413
+
414
+ try:
415
+ red_dict = parse_llm_action(red_response)
416
+ red_action = RedActionWrapper(**red_dict)
417
+ obs = env.step(red_action)
418
+ done = obs.done
419
+ except Exception as exc:
420
+ print(f"[DEBUG] Red action failed: {exc}", flush=True)
421
+ # Fall back to PassTurn to close the round
422
+ try:
423
+ obs = env.step(RedActionWrapper(type="pass_turn"))
424
+ done = obs.done
425
+ except Exception:
426
+ pass
427
+
428
+ history.append(
429
+ f"Step {step} [Red]: {red_dict.get('type', 'pass_turn')}"
430
+ )
431
+
432
+ if done:
433
+ break
434
+
435
+ # Final score
436
  if obs.final_score is not None:
437
  score = obs.final_score
438
  else:
439
  score = sum(rewards) / MAX_POSSIBLE_REWARD if MAX_POSSIBLE_REWARD > 0 else 0.0
440
+ score = min(max(score, 0.0), 1.0)
 
441
  success = score >= SUCCESS_SCORE_THRESHOLD
442
 
443
  finally:
 
452
 
453
  async def main() -> None:
454
  """Run baseline inference across all tasks."""
455
+ blue_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
456
+ red_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) if FSP_MODE else None
457
 
458
+ total_scores: Dict[str, float] = {}
459
  for task_id in TASKS:
460
+ success, steps, score, rewards = await run_episode(
461
+ blue_client=blue_client,
462
+ task_id=task_id,
463
+ red_client=red_client,
464
+ fsp=FSP_MODE,
465
+ )
466
  total_scores[task_id] = score
467
 
 
468
  avg = sum(total_scores.values()) / len(total_scores) if total_scores else 0.0
469
  print(f"\n# Summary: avg_score={avg:.3f}", flush=True)
470
  for tid, s in total_scores.items():
models.py CHANGED
@@ -234,6 +234,17 @@ class SOCObservation(Observation):
234
  "Keys match grade_breakdown (threat_containment, ioc_blocking, etc.)."
235
  ),
236
  )
 
 
 
 
 
 
 
 
 
 
 
237
 
238
 
239
  # =============================================================================
@@ -345,6 +356,80 @@ class QuarantineFile(Action):
345
  file_path: str = Field(..., description="File path to quarantine")
346
 
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  # Discriminated union of all SOC actions
349
  SOCAction = Annotated[
350
  Union[
@@ -438,3 +523,7 @@ class SOCState(State):
438
  default=None,
439
  description="Mutable copy of containment_requirements (for adaptive grading).",
440
  )
 
 
 
 
 
234
  "Keys match grade_breakdown (threat_containment, ioc_blocking, etc.)."
235
  ),
236
  )
237
+ active_turn: str = Field(
238
+ default="blue",
239
+ description="Whose turn it is next: 'blue' or 'red'. Used by FSP inference loops.",
240
+ )
241
+ red_observation: Optional[Dict[str, Any]] = Field(
242
+ default=None,
243
+ description=(
244
+ "Red Team's current view of the world (populated when active_turn='red'). "
245
+ "Contains compromised_hosts and blue_actions_detected."
246
+ ),
247
+ )
248
 
249
 
250
  # =============================================================================
 
356
  file_path: str = Field(..., description="File path to quarantine")
357
 
358
 
359
+ # =============================================================================
360
+ # Red Team Actions (FSP — Fictitious Self-Play)
361
+ # =============================================================================
362
+
363
+ class LateralPivot(Action):
364
+ """Red Team: move laterally from a compromised host to a new target."""
365
+ type: Literal["lateral_pivot"] = Field(default="lateral_pivot")
366
+ source_host: str = Field(..., description="Already-compromised host used as the pivot point")
367
+ target_host: str = Field(..., description="Destination host to compromise")
368
+
369
+
370
+ class DeployPayload(Action):
371
+ """Red Team: deploy a malicious payload on a host Red already controls."""
372
+ type: Literal["deploy_payload"] = Field(default="deploy_payload")
373
+ hostname: str = Field(..., description="Compromised host to deploy payload on")
374
+ payload_type: Literal["ransomware", "exfiltration", "c2"] = Field(
375
+ ..., description="Class of payload to deploy"
376
+ )
377
+
378
+
379
+ class EvadeDetection(Action):
380
+ """Red Team: apply an evasion technique on a compromised host."""
381
+ type: Literal["evade_detection"] = Field(default="evade_detection")
382
+ hostname: str = Field(..., description="Compromised host to apply evasion on")
383
+ technique: Literal["migrate_pid", "clear_logs"] = Field(
384
+ ...,
385
+ description=(
386
+ "migrate_pid: rename running malicious processes to blend with system names; "
387
+ "clear_logs: remove SIEM alerts originating from this host"
388
+ ),
389
+ )
390
+
391
+
392
+ class PassTurn(Action):
393
+ """Red Team: remain stealthy and take no action this turn."""
394
+ type: Literal["pass_turn"] = Field(default="pass_turn")
395
+
396
+
397
+ # Constant used by dashboard_server and inference to route payloads
398
+ RED_ACTION_TYPES: frozenset = frozenset(
399
+ {"lateral_pivot", "deploy_payload", "evade_detection", "pass_turn"}
400
+ )
401
+
402
+ # Discriminated union of all Red actions
403
+ RedAction = Annotated[
404
+ Union[LateralPivot, DeployPayload, EvadeDetection, PassTurn],
405
+ Field(discriminator="type"),
406
+ ]
407
+
408
+
409
+ class RedActionWrapper(Action):
410
+ """Wrapper for Red Team actions — mirrors SOCActionWrapper for the WS/HTTP layer."""
411
+
412
+ type: str = Field(..., description="Red action type discriminator")
413
+ model_config = ConfigDict(extra="allow")
414
+
415
+ def to_typed_action(self):
416
+ """Deserialize to the correctly-typed Red action."""
417
+ data = self.model_dump(exclude={"metadata"})
418
+ action_map = {
419
+ "lateral_pivot": LateralPivot,
420
+ "deploy_payload": DeployPayload,
421
+ "evade_detection": EvadeDetection,
422
+ "pass_turn": PassTurn,
423
+ }
424
+ cls = action_map.get(data["type"])
425
+ if cls is None:
426
+ raise ValueError(
427
+ f"Unknown red action type: {data['type']}. "
428
+ f"Valid types: {list(action_map)}"
429
+ )
430
+ return cls(**data)
431
+
432
+
433
  # Discriminated union of all SOC actions
434
  SOCAction = Annotated[
435
  Union[
 
523
  default=None,
524
  description="Mutable copy of containment_requirements (for adaptive grading).",
525
  )
526
+ active_turn: str = Field(
527
+ default="blue",
528
+ description="Current active turn in the FSP engine: 'blue' or 'red'.",
529
+ )
server/play_environment.py CHANGED
@@ -47,6 +47,12 @@ try:
47
  TerminatePID,
48
  CreateFirewallRule,
49
  QuarantineFile,
 
 
 
 
 
 
50
  )
51
  except ImportError:
52
  from models import (
@@ -69,6 +75,12 @@ except ImportError:
69
  TerminatePID,
70
  CreateFirewallRule,
71
  QuarantineFile,
 
 
 
 
 
 
72
  )
73
 
74
  from .tasks import get_task, build_network
@@ -149,12 +161,25 @@ class CyberSOCEnvironment(Environment):
149
  adaptive: bool = False,
150
  neural_red_policy: Optional[Any] = None,
151
  red_team_logger: Optional[Callable[[Dict[str, Any]], None]] = None,
 
152
  ):
153
- """Initialize the environment (actual state set in reset)."""
 
 
 
 
 
 
 
 
 
 
 
154
  super().__init__()
155
  self._adaptive = adaptive
156
  self._neural_red_policy = neural_red_policy
157
  self._red_team_logger = red_team_logger
 
158
  self._red_team_decisions: List[Dict[str, Any]] = []
159
  self._live_requirements: Dict[str, Any] = {}
160
  self._threat_graph = None # will be initialized on reset()
@@ -252,6 +277,7 @@ class CyberSOCEnvironment(Environment):
252
  timeline=[],
253
  is_done=False,
254
  submitted_plan=False,
 
255
  )
256
 
257
  self._plan_entries = []
@@ -347,27 +373,44 @@ class CyberSOCEnvironment(Environment):
347
 
348
  def step(
349
  self,
350
- action: SOCActionWrapper, # type: ignore[override]
351
  timeout_s: Optional[float] = None,
352
  **kwargs: Any,
353
  ) -> SOCObservation:
354
- """Process one agent action.
355
 
356
- Args:
357
- action: SOCActionWrapper containing the typed action.
358
- timeout_s: Ignored.
 
 
 
 
359
 
360
  Returns:
361
- SOCObservation with updated state, reward, and done flag.
362
  """
363
  if self._state.is_done:
364
  return self._build_observation(reward=0.0, done=True)
365
 
366
- # Convert wrapper to typed action (before consuming a step)
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  typed_action = action.to_typed_action()
368
  args = typed_action.model_dump(exclude={"metadata", "type"})
369
 
370
- # Pre-flight validation — invalid actions are penalised without consuming a step
371
  current_phase = self._get_current_phase()
372
  validation_error = self._middleware.validate(
373
  current_phase, typed_action.type, args, self._threat_graph
@@ -378,16 +421,13 @@ class CyberSOCEnvironment(Environment):
378
  self._state.total_reward += penalty
379
  return self._build_observation(reward=penalty, done=False)
380
 
381
- # Action is valid — now consume the step
382
- self._state.step_count += 1
383
 
384
- # Dispatch to handler
385
  reward = 0.0
386
  result_description = "unknown action"
387
 
388
- # Reset per-step observation extras at the start of every step
389
- self._last_obs_extras = {}
390
-
391
  if isinstance(typed_action, QueryHost):
392
  reward, result_description = self._handle_query_host(typed_action)
393
  elif isinstance(typed_action, IsolateSegment):
@@ -422,9 +462,11 @@ class CyberSOCEnvironment(Environment):
422
  elif isinstance(typed_action, QuarantineFile):
423
  reward, result_description = self._handle_quarantine_file(typed_action)
424
 
425
- # Step reward (idempotent per triple)
426
  target = self._get_action_target(typed_action)
427
- step_r = self._get_step_reward(phase="investigation", action_type=typed_action.type, target=target)
 
 
428
  reward += step_r
429
  self._step_reward_total += step_r
430
 
@@ -436,26 +478,26 @@ class CyberSOCEnvironment(Environment):
436
  if len(self._recent_actions) >= 3:
437
  last_three = self._recent_actions[-3:]
438
  if last_three[0] == last_three[1] == last_three[2]:
439
- reward -= 0.05 # stall penalty
440
-
441
- # Adaptive adversary reaction (deterministic by default, optional neural override)
442
- self._apply_red_team_dynamics(action_type=typed_action.type, target=target)
443
 
444
  # Business impact grows each step (attacker progresses)
445
  if not self._state.is_done:
446
  impact_rate = self._task_def.get("impact_per_step", 0.02)
447
- # Reduce impact growth if threats are being contained
448
- active_ratio = len(self._state.active_threats) / max(1, len(self._task_def["attack_chain"]))
 
449
  self._state.business_impact = min(
450
- 1.0,
451
- self._state.business_impact + impact_rate * active_ratio,
452
  )
453
 
 
 
 
454
  # Record timeline
455
  self._state.timeline.append({
456
- "step": self._state.step_count,
457
  "action_type": typed_action.type,
458
- "target": self._get_action_target(typed_action),
459
  "result": result_description,
460
  "reward": reward,
461
  })
@@ -463,16 +505,77 @@ class CyberSOCEnvironment(Environment):
463
  # Accumulate reward
464
  self._state.total_reward += reward
465
 
466
- # Check termination
467
  done = False
468
  if self._state.submitted_plan:
469
  done = True
470
  self._state.is_done = True
471
- elif self._state.step_count >= self._state.max_steps:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  done = True
473
  self._state.is_done = True
474
- reward -= 0.20 # Penalty for running out of time
475
- self._state.total_reward += (-0.20)
476
 
477
  return self._build_observation(reward=reward, done=done)
478
 
@@ -822,9 +925,6 @@ class CyberSOCEnvironment(Environment):
822
  reward = -0.08 # Penalty: killing legitimate process = downtime
823
  self._state.business_impact = min(1.0, self._state.business_impact + 0.03)
824
 
825
- if was_malicious:
826
- self._maybe_reinfect(hostname, process)
827
-
828
  return reward, f"Killed '{process}' on {hostname}. Malicious: {was_malicious}"
829
 
830
  def _handle_terminate_pid(self, action: TerminatePID) -> tuple[float, str]:
@@ -878,7 +978,6 @@ class CyberSOCEnvironment(Environment):
878
  self._state.business_impact = min(1.0, self._state.business_impact + 0.04)
879
  return reward, f"Terminated benign PID '{pid}' on {hostname} - business disruption"
880
 
881
- self._maybe_reinfect(hostname, process_name)
882
  return reward, f"Terminated PID '{pid}' on {hostname}. Malicious: True"
883
 
884
  def _handle_create_firewall_rule(self, action: CreateFirewallRule) -> tuple[float, str]:
@@ -1097,6 +1196,203 @@ class CyberSOCEnvironment(Environment):
1097
  "description": f"Scanned {hostname}: found {len(vuln_results)} CVEs",
1098
  }
1099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1100
  # ===========================================================================
1101
  # Helpers
1102
  # ===========================================================================
@@ -1326,6 +1622,13 @@ class CyberSOCEnvironment(Environment):
1326
  # Per-step partial reward dimensions for GRPO credit assignment
1327
  reward_dimensions = self._compute_reward_dimensions()
1328
 
 
 
 
 
 
 
 
1329
  return SOCObservation(
1330
  episode_id=self._state.episode_id or "",
1331
  alert_queue=alerts,
@@ -1349,6 +1652,8 @@ class CyberSOCEnvironment(Environment):
1349
  threat_graph_summary=threat_graph_summary,
1350
  available_playbooks=[],
1351
  reward_dimensions=reward_dimensions,
 
 
1352
  )
1353
 
1354
  def _get_action_target(self, action: Any) -> str:
@@ -1383,18 +1688,38 @@ class CyberSOCEnvironment(Environment):
1383
  # Adaptive Red Team + Step Rewards (Task 10)
1384
  # ===========================================================================
1385
 
1386
- def _build_red_observation(self, action_type: str, target: str) -> Dict[str, Any]:
1387
- """Compact red-side view used for imitation logs and neural policies."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1388
  return {
1389
  "episode_id": self._state.episode_id,
1390
- "task_id": self._state.task_id,
1391
- "step_count": self._state.step_count,
1392
- "blue_action_type": action_type,
1393
- "blue_action_target": target,
1394
  "active_threats": list(self._state.active_threats),
1395
- "contained_threats": list(self._state.contained_threats),
1396
- "business_impact": self._state.business_impact,
1397
- "adaptive_enabled": self._adaptive,
1398
  }
1399
 
1400
  def _log_red_decision(self, observation: Dict[str, Any], action: Dict[str, Any]) -> None:
@@ -1409,77 +1734,16 @@ class CyberSOCEnvironment(Environment):
1409
  pass
1410
 
1411
  def _apply_red_team_dynamics(self, action_type: str, target: str) -> None:
1412
- """
1413
- Route red-team behavior through deterministic logic (default) or neural policy.
1414
 
1415
- When no neural policy is provided, behavior is unchanged from the legacy
1416
- deterministic `_adversary_react` implementation.
1417
  """
1418
- red_obs = self._build_red_observation(action_type=action_type, target=target)
1419
-
1420
- if self._neural_red_policy is None:
1421
- result = self._adversary_react(action_type=action_type, target=target)
1422
- self._log_red_decision(
1423
- red_obs,
1424
- result or {"policy": "deterministic", "action_type": "noop"},
1425
- )
1426
- return
1427
-
1428
- policy_fn = None
1429
- if callable(self._neural_red_policy):
1430
- policy_fn = self._neural_red_policy
1431
- elif hasattr(self._neural_red_policy, "act"):
1432
- policy_fn = self._neural_red_policy.act
1433
-
1434
- if policy_fn is None:
1435
- result = self._adversary_react(action_type=action_type, target=target)
1436
- self._log_red_decision(
1437
- red_obs,
1438
- result or {"policy": "deterministic_fallback", "action_type": "noop"},
1439
- )
1440
- return
1441
-
1442
- try:
1443
- proposed = policy_fn(red_obs)
1444
- except Exception as exc:
1445
- result = self._adversary_react(action_type=action_type, target=target)
1446
- self._log_red_decision(
1447
- red_obs,
1448
- {
1449
- "policy": "neural_fallback",
1450
- "action_type": "noop",
1451
- "error": f"{type(exc).__name__}: {exc}",
1452
- },
1453
- )
1454
- if result is not None:
1455
- self._log_red_decision(red_obs, result)
1456
- return
1457
-
1458
- if not isinstance(proposed, dict):
1459
- self._log_red_decision(
1460
- red_obs,
1461
- {"policy": "neural_invalid", "action_type": "noop"},
1462
- )
1463
- return
1464
-
1465
- red_action_type = str(proposed.get("action_type", "noop"))
1466
- if red_action_type == "lateral_pivot":
1467
- source_host = str(proposed.get("source_host") or target)
1468
- outcome = self._execute_lateral_pivot(source_host=source_host)
1469
- self._log_red_decision(
1470
- red_obs,
1471
- {
1472
- "policy": "neural",
1473
- "action_type": "lateral_pivot",
1474
- "source_host": source_host,
1475
- "executed": bool(outcome and outcome.get("executed")),
1476
- },
1477
- )
1478
- return
1479
-
1480
  self._log_red_decision(
1481
  red_obs,
1482
- {"policy": "neural", "action_type": red_action_type},
 
1483
  )
1484
 
1485
  def export_red_team_decisions(self) -> List[Dict[str, Any]]:
@@ -1578,123 +1842,8 @@ class CyberSOCEnvironment(Environment):
1578
  })
1579
 
1580
  def _adversary_react(self, action_type: str, target: str) -> Optional[Dict[str, Any]]:
1581
- """Adaptive red team response fires after each step when adaptive=True."""
1582
- if not self._adaptive:
1583
- return None
1584
-
1585
- difficulty = self._task_def.get("difficulty") or getattr(self._state, "task_id", "easy")
1586
- # Reduced medium base probability for better GRPO credit assignment
1587
- pivot_probability = {"easy": 0.0, "medium": 0.3, "hard": 1.0}.get(difficulty, 0.0)
1588
-
1589
- # Time-pressure escalation: attacker moves faster when uncontained and late in episode
1590
- if self._state.step_count > 10 and len(self._state.contained_threats) == 0:
1591
- pivot_probability += 0.2
1592
-
1593
- # Trigger on isolate_segment OR kill_process (extended pivot trigger)
1594
- if action_type in ("isolate_segment", "kill_process") and pivot_probability > 0:
1595
- if self._rng.random() < pivot_probability:
1596
- source_host = target.split("/")[0] if "/" in target else target
1597
- return self._execute_lateral_pivot(source_host=source_host)
1598
- return {"policy": "deterministic", "action_type": "noop", "executed": False}
1599
-
1600
- def _execute_lateral_pivot(self, source_host: str) -> Optional[Dict[str, Any]]:
1601
- """Copy-not-move lateral pivot: spread to an adjacent healthy host.
1602
-
1603
- Rubric is capped at MAX_RUBRIC_ITEMS to prevent competent agents from
1604
- being penalised by an impossible-to-complete rubric.
1605
- """
1606
- MAX_RUBRIC_ITEMS = 12
1607
- graph = self._threat_graph
1608
- if graph is None:
1609
- return None
1610
-
1611
- # Rubric cap: stop pivoting once live_requirements is full
1612
- if self._live_requirements:
1613
- current_items = (
1614
- len(self._live_requirements.get("must_kill", []))
1615
- + len(self._live_requirements.get("must_isolate", []))
1616
- )
1617
- if current_items >= MAX_RUBRIC_ITEMS:
1618
- return {"policy": "deterministic", "action_type": "lateral_pivot", "executed": False}
1619
-
1620
- adjacent_hosts = [
1621
- e.target_id for e in graph.edges
1622
- if e.source_id == source_host and e.target_id in graph.hosts
1623
- and graph.hosts[e.target_id].status == "healthy"
1624
- ]
1625
- if not adjacent_hosts:
1626
- # Try graph hosts first, then fall back to full host_index
1627
- healthy_hosts = [
1628
- h for h, node in graph.hosts.items()
1629
- if node.status == "healthy" and h != source_host
1630
- ]
1631
- if not healthy_hosts:
1632
- # Expand search to the full network
1633
- healthy_hosts = [
1634
- h for h, hd in self._host_index.items()
1635
- if hd.get("status", "online") not in ("compromised", "isolated")
1636
- and h != source_host
1637
- and h not in graph.hosts
1638
- ]
1639
- if not healthy_hosts:
1640
- return {"policy": "deterministic", "action_type": "lateral_pivot", "executed": False}
1641
- adjacent_hosts = healthy_hosts
1642
-
1643
- dest_host = self._rng.choice(adjacent_hosts)
1644
-
1645
- # Ensure destination host is in graph
1646
- if dest_host not in graph.hosts:
1647
- hd = self._host_index.get(dest_host, {})
1648
- graph.add_host(HostNode(
1649
- hostname=dest_host,
1650
- subnet=hd.get("subnet", "corporate"),
1651
- business_criticality="medium",
1652
- status="healthy",
1653
- ))
1654
-
1655
- source_processes = [p for p in graph.processes.values() if p.hostname == source_host]
1656
- if not source_processes:
1657
- return {"policy": "deterministic", "action_type": "lateral_pivot", "executed": False}
1658
- original = source_processes[0]
1659
-
1660
- new_pid = str(uuid.uuid4())[:8] # uuid imported at module level
1661
- new_process = ProcessNode(
1662
- process_id=f"{dest_host}:{new_pid}",
1663
- hostname=dest_host,
1664
- process_name=original.process_name,
1665
- killed=False,
1666
- )
1667
- graph.add_process(new_process)
1668
-
1669
- graph.add_edge(Edge(
1670
- edge_type="pivoted_from",
1671
- source_id=dest_host,
1672
- target_id=source_host,
1673
- evidence={"trigger_action": "isolate_segment", "step": self._state.step_count},
1674
- ))
1675
-
1676
- if self._live_requirements is None:
1677
- self._live_requirements = {}
1678
- self._live_requirements.setdefault("must_kill", []).append(
1679
- f"{dest_host}:{original.process_name}"
1680
- )
1681
- self._live_requirements.setdefault("must_isolate", []).append(dest_host)
1682
-
1683
- new_alert = AlertNode(
1684
- alert_id=f"PIVOT-{new_pid}",
1685
- severity="critical",
1686
- priority_score=15.0,
1687
- source_host=dest_host,
1688
- )
1689
- graph.add_alert(new_alert)
1690
- return {
1691
- "policy": "deterministic",
1692
- "action_type": "lateral_pivot",
1693
- "executed": True,
1694
- "source_host": source_host,
1695
- "dest_host": dest_host,
1696
- "alert_id": new_alert.alert_id,
1697
- }
1698
 
1699
  @property
1700
  def state(self) -> SOCState:
 
47
  TerminatePID,
48
  CreateFirewallRule,
49
  QuarantineFile,
50
+ RedActionWrapper,
51
+ LateralPivot,
52
+ DeployPayload,
53
+ EvadeDetection,
54
+ PassTurn,
55
+ RED_ACTION_TYPES,
56
  )
57
  except ImportError:
58
  from models import (
 
75
  TerminatePID,
76
  CreateFirewallRule,
77
  QuarantineFile,
78
+ RedActionWrapper,
79
+ LateralPivot,
80
+ DeployPayload,
81
+ EvadeDetection,
82
+ PassTurn,
83
+ RED_ACTION_TYPES,
84
  )
85
 
86
  from .tasks import get_task, build_network
 
161
  adaptive: bool = False,
162
  neural_red_policy: Optional[Any] = None,
163
  red_team_logger: Optional[Callable[[Dict[str, Any]], None]] = None,
164
+ fsp_mode: bool = False,
165
  ):
166
+ """Initialize the environment (actual state set in reset).
167
+
168
+ Args:
169
+ adaptive: Legacy adaptive-adversary flag (kept for backward compat).
170
+ neural_red_policy: Optional callable for neural Red policy (legacy hook).
171
+ red_team_logger: Optional callback for recording Red decisions.
172
+ fsp_mode: When True, step() uses strict alternating turns and
173
+ step_count only increments after BOTH Blue and Red have acted.
174
+ When False (default), step(SOCActionWrapper) behaves exactly as
175
+ before — Red's PassTurn is applied automatically so existing code
176
+ and tests remain unaffected.
177
+ """
178
  super().__init__()
179
  self._adaptive = adaptive
180
  self._neural_red_policy = neural_red_policy
181
  self._red_team_logger = red_team_logger
182
+ self._fsp_mode = fsp_mode
183
  self._red_team_decisions: List[Dict[str, Any]] = []
184
  self._live_requirements: Dict[str, Any] = {}
185
  self._threat_graph = None # will be initialized on reset()
 
277
  timeline=[],
278
  is_done=False,
279
  submitted_plan=False,
280
+ active_turn="blue",
281
  )
282
 
283
  self._plan_entries = []
 
373
 
374
  def step(
375
  self,
376
+ action, # SOCActionWrapper | RedActionWrapper
377
  timeout_s: Optional[float] = None,
378
  **kwargs: Any,
379
  ) -> SOCObservation:
380
+ """Process one agent action — Blue (SOCActionWrapper) or Red (RedActionWrapper).
381
 
382
+ Turn semantics (fsp_mode=True):
383
+ • Blue step: execute, flip active_turn 'red', do NOT increment step_count.
384
+ • Red step: execute, flip active_turn → 'blue', increment step_count.
385
+
386
+ When fsp_mode=False (default / backward-compat):
387
+ • Blue step auto-applies a Red PassTurn so step_count always increments,
388
+ preserving all existing test and dashboard behaviour.
389
 
390
  Returns:
391
+ SOCObservation; includes active_turn and red_observation fields.
392
  """
393
  if self._state.is_done:
394
  return self._build_observation(reward=0.0, done=True)
395
 
396
+ if isinstance(action, RedActionWrapper):
397
+ return self._step_red(action)
398
+ return self._step_blue(action)
399
+
400
+ # ------------------------------------------------------------------
401
+ # _step_blue — execute a Blue (SOC analyst) action
402
+ # ------------------------------------------------------------------
403
+
404
+ def _step_blue(
405
+ self,
406
+ action: SOCActionWrapper,
407
+ ) -> SOCObservation:
408
+ """Execute one Blue turn."""
409
+ # Convert wrapper to typed action
410
  typed_action = action.to_typed_action()
411
  args = typed_action.model_dump(exclude={"metadata", "type"})
412
 
413
+ # Pre-flight validation — penalise without consuming a step
414
  current_phase = self._get_current_phase()
415
  validation_error = self._middleware.validate(
416
  current_phase, typed_action.type, args, self._threat_graph
 
421
  self._state.total_reward += penalty
422
  return self._build_observation(reward=penalty, done=False)
423
 
424
+ # Reset per-step extras
425
+ self._last_obs_extras = {}
426
 
427
+ # Dispatch to Blue handler
428
  reward = 0.0
429
  result_description = "unknown action"
430
 
 
 
 
431
  if isinstance(typed_action, QueryHost):
432
  reward, result_description = self._handle_query_host(typed_action)
433
  elif isinstance(typed_action, IsolateSegment):
 
462
  elif isinstance(typed_action, QuarantineFile):
463
  reward, result_description = self._handle_quarantine_file(typed_action)
464
 
465
+ # Idempotent step reward
466
  target = self._get_action_target(typed_action)
467
+ step_r = self._get_step_reward(
468
+ phase="investigation", action_type=typed_action.type, target=target
469
+ )
470
  reward += step_r
471
  self._step_reward_total += step_r
472
 
 
478
  if len(self._recent_actions) >= 3:
479
  last_three = self._recent_actions[-3:]
480
  if last_three[0] == last_three[1] == last_three[2]:
481
+ reward -= 0.05
 
 
 
482
 
483
  # Business impact grows each step (attacker progresses)
484
  if not self._state.is_done:
485
  impact_rate = self._task_def.get("impact_per_step", 0.02)
486
+ active_ratio = len(self._state.active_threats) / max(
487
+ 1, len(self._task_def["attack_chain"])
488
+ )
489
  self._state.business_impact = min(
490
+ 1.0, self._state.business_impact + impact_rate * active_ratio
 
491
  )
492
 
493
+ # Round label: step_count+1 = current round being played (not yet closed)
494
+ round_label = self._state.step_count + 1
495
+
496
  # Record timeline
497
  self._state.timeline.append({
498
+ "step": round_label,
499
  "action_type": typed_action.type,
500
+ "target": target,
501
  "result": result_description,
502
  "reward": reward,
503
  })
 
505
  # Accumulate reward
506
  self._state.total_reward += reward
507
 
508
+ # Check if episode ends due to Blue action (plan submission)
509
  done = False
510
  if self._state.submitted_plan:
511
  done = True
512
  self._state.is_done = True
513
+ self._state.active_turn = "blue" # episode over — keep at blue
514
+ # In non-FSP mode, still increment step_count for consistency
515
+ if not self._fsp_mode:
516
+ self._state.step_count += 1
517
+ return self._build_observation(reward=reward, done=done)
518
+
519
+ # Flip turn to Red
520
+ self._state.active_turn = "red"
521
+
522
+ # fsp_mode=False (backward compat): auto-apply Red PassTurn so
523
+ # callers that only drive Blue see step_count increment as before.
524
+ if not self._fsp_mode:
525
+ self._state.step_count += 1
526
+ self._state.active_turn = "blue"
527
+ # Timeout check (done after Red's "auto turn")
528
+ if self._state.step_count >= self._state.max_steps:
529
+ reward -= 0.20
530
+ self._state.total_reward -= 0.20
531
+ self._state.is_done = True
532
+ done = True
533
+
534
+ return self._build_observation(reward=reward, done=done)
535
+
536
+ # ------------------------------------------------------------------
537
+ # _step_red — execute a Red Team action
538
+ # ------------------------------------------------------------------
539
+
540
+ def _step_red(self, action: RedActionWrapper) -> SOCObservation:
541
+ """Execute one Red turn. Only valid when active_turn == 'red'."""
542
+ if self._state.active_turn != "red":
543
+ # Wrong turn — return current obs with 0 reward (no state change)
544
+ return self._build_observation(reward=0.0, done=False)
545
+
546
+ typed_action = action.to_typed_action()
547
+ self._last_obs_extras = {}
548
+
549
+ reward = 0.0
550
+ result_description = "red: noop"
551
+
552
+ if isinstance(typed_action, LateralPivot):
553
+ reward, result_description = self._handle_lateral_pivot(typed_action)
554
+ elif isinstance(typed_action, DeployPayload):
555
+ reward, result_description = self._handle_deploy_payload(typed_action)
556
+ elif isinstance(typed_action, EvadeDetection):
557
+ reward, result_description = self._handle_evade_detection(typed_action)
558
+ elif isinstance(typed_action, PassTurn):
559
+ reward, result_description = self._handle_pass_turn(typed_action)
560
+
561
+ # Close the round: increment step_count, flip turn back to Blue
562
+ self._state.step_count += 1
563
+ self._state.active_turn = "blue"
564
+
565
+ # Record Red's action in timeline (prefixed with "red:" to distinguish)
566
+ self._state.timeline.append({
567
+ "step": self._state.step_count,
568
+ "action_type": f"red:{typed_action.type}",
569
+ "target": self._get_red_action_target(typed_action),
570
+ "result": result_description,
571
+ "reward": 0.0, # Red actions don't add to Blue's reward total
572
+ })
573
+
574
+ # Timeout check after the full round
575
+ done = False
576
+ if self._state.step_count >= self._state.max_steps:
577
  done = True
578
  self._state.is_done = True
 
 
579
 
580
  return self._build_observation(reward=reward, done=done)
581
 
 
925
  reward = -0.08 # Penalty: killing legitimate process = downtime
926
  self._state.business_impact = min(1.0, self._state.business_impact + 0.03)
927
 
 
 
 
928
  return reward, f"Killed '{process}' on {hostname}. Malicious: {was_malicious}"
929
 
930
  def _handle_terminate_pid(self, action: TerminatePID) -> tuple[float, str]:
 
978
  self._state.business_impact = min(1.0, self._state.business_impact + 0.04)
979
  return reward, f"Terminated benign PID '{pid}' on {hostname} - business disruption"
980
 
 
981
  return reward, f"Terminated PID '{pid}' on {hostname}. Malicious: True"
982
 
983
  def _handle_create_firewall_rule(self, action: CreateFirewallRule) -> tuple[float, str]:
 
1196
  "description": f"Scanned {hostname}: found {len(vuln_results)} CVEs",
1197
  }
1198
 
1199
+ # ===========================================================================
1200
+ # Red Team Action Handlers
1201
+ # ===========================================================================
1202
+
1203
+ def _handle_lateral_pivot(self, action: LateralPivot) -> tuple[float, str]:
1204
+ """Red: spread from a compromised host to a new target."""
1205
+ src = action.source_host
1206
+ dst = action.target_host
1207
+
1208
+ if src not in self._host_index:
1209
+ return 0.0, f"red: lateral_pivot — source '{src}' not in network"
1210
+ if self._host_index[src].get("status") != "compromised":
1211
+ return 0.0, f"red: lateral_pivot — '{src}' not under Red control"
1212
+ if dst not in self._host_index:
1213
+ return 0.0, f"red: lateral_pivot — target '{dst}' not in network"
1214
+
1215
+ dst_status = self._host_index[dst].get("status", "online")
1216
+ if dst_status == "isolated":
1217
+ return 0.0, f"red: lateral_pivot — '{dst}' is isolated, pivot blocked by Blue"
1218
+ if dst_status == "compromised":
1219
+ return 0.0, f"red: lateral_pivot — '{dst}' already compromised"
1220
+
1221
+ # Compromise target and copy a process from source
1222
+ self._host_index[dst]["status"] = "compromised"
1223
+ src_procs = (
1224
+ [p for p in self._threat_graph.processes.values() if p.hostname == src]
1225
+ if self._threat_graph else []
1226
+ )
1227
+ proc_name = src_procs[0].process_name if src_procs else "cmd.exe"
1228
+ self._host_index[dst].setdefault("running_processes", [])
1229
+ if proc_name not in self._host_index[dst]["running_processes"]:
1230
+ self._host_index[dst]["running_processes"].append(proc_name)
1231
+
1232
+ # Update threat graph
1233
+ if self._threat_graph is not None:
1234
+ if dst not in self._threat_graph.hosts:
1235
+ hd = self._host_index[dst]
1236
+ self._threat_graph.add_host(HostNode(
1237
+ hostname=dst,
1238
+ subnet=hd.get("subnet", "corporate"),
1239
+ business_criticality="medium",
1240
+ status="compromised",
1241
+ ))
1242
+ else:
1243
+ self._threat_graph.hosts[dst].status = "compromised"
1244
+
1245
+ pid = f"{dst}:{proc_name}"
1246
+ if pid not in self._threat_graph.processes:
1247
+ self._threat_graph.add_process(ProcessNode(
1248
+ process_id=pid, hostname=dst, process_name=proc_name
1249
+ ))
1250
+ self._threat_graph.add_edge(Edge(
1251
+ edge_type="pivoted_from", source_id=dst, target_id=src
1252
+ ))
1253
+
1254
+ # Generate SIEM alert for Blue
1255
+ alert_id = f"PIVOT-{uuid.uuid4().hex[:6].upper()}"
1256
+ subnet = self._host_index.get(dst, {}).get("subnet", "unknown")
1257
+ self._alert_queue.append({
1258
+ "alert_id": alert_id,
1259
+ "timestamp": "2024-01-01T00:00:00Z",
1260
+ "source_host": dst,
1261
+ "severity": "critical",
1262
+ "threat_type": "lateral_movement",
1263
+ "description": (
1264
+ f"Lateral movement detected: {proc_name} spawned on {dst} "
1265
+ f"(pivot from {src})"
1266
+ ),
1267
+ "ioc_indicators": [],
1268
+ "subnet": subnet,
1269
+ "is_acknowledged": False,
1270
+ })
1271
+ if self._threat_graph is not None:
1272
+ self._threat_graph.add_alert(AlertNode(
1273
+ alert_id=alert_id, severity="critical",
1274
+ priority_score=15.0, source_host=dst,
1275
+ ))
1276
+
1277
+ # Update live rubric
1278
+ if self._live_requirements is not None:
1279
+ self._live_requirements.setdefault("must_kill", []).append({
1280
+ "hostname": dst, "process": proc_name, "threat_id": "FSP_PIVOT",
1281
+ })
1282
+
1283
+ return 0.0, f"red: lateral_pivot {src} → {dst} (proc={proc_name})"
1284
+
1285
+ def _handle_deploy_payload(self, action: DeployPayload) -> tuple[float, str]:
1286
+ """Red: deploy a malicious payload on a host Red controls."""
1287
+ hostname = action.hostname
1288
+ payload_type = action.payload_type
1289
+
1290
+ if hostname not in self._host_index:
1291
+ return 0.0, f"red: deploy_payload — '{hostname}' not in network"
1292
+ if self._host_index[hostname].get("status") != "compromised":
1293
+ return 0.0, f"red: deploy_payload — no shell on '{hostname}'"
1294
+
1295
+ proc_name = {
1296
+ "ransomware": "ransomware.exe",
1297
+ "exfiltration": "exfil_agent.exe",
1298
+ "c2": "c2_beacon.exe",
1299
+ }[payload_type]
1300
+
1301
+ host = self._host_index[hostname]
1302
+ if proc_name not in host.get("running_processes", []):
1303
+ host.setdefault("running_processes", []).append(proc_name)
1304
+
1305
+ if self._threat_graph is not None:
1306
+ pid = f"{hostname}:{proc_name}"
1307
+ if pid not in self._threat_graph.processes:
1308
+ self._threat_graph.add_process(ProcessNode(
1309
+ process_id=pid, hostname=hostname, process_name=proc_name
1310
+ ))
1311
+
1312
+ impact_delta = {"ransomware": 0.15, "exfiltration": 0.08, "c2": 0.05}[payload_type]
1313
+ self._state.business_impact = min(1.0, self._state.business_impact + impact_delta)
1314
+
1315
+ severity = {"ransomware": "critical", "exfiltration": "high", "c2": "high"}[payload_type]
1316
+ alert_id = f"PAYLOAD-{uuid.uuid4().hex[:6].upper()}"
1317
+ self._alert_queue.append({
1318
+ "alert_id": alert_id,
1319
+ "timestamp": "2024-01-01T00:00:00Z",
1320
+ "source_host": hostname,
1321
+ "severity": severity,
1322
+ "threat_type": payload_type,
1323
+ "description": (
1324
+ f"{payload_type.capitalize()} payload deployed on {hostname}: {proc_name}"
1325
+ ),
1326
+ "ioc_indicators": [],
1327
+ "subnet": host.get("subnet", "unknown"),
1328
+ "is_acknowledged": False,
1329
+ })
1330
+ if self._threat_graph is not None:
1331
+ self._threat_graph.add_alert(AlertNode(
1332
+ alert_id=alert_id, severity=severity,
1333
+ priority_score=18.0, source_host=hostname,
1334
+ ))
1335
+
1336
+ return 0.0, f"red: deployed {payload_type} payload on {hostname}"
1337
+
1338
+ def _handle_evade_detection(self, action: EvadeDetection) -> tuple[float, str]:
1339
+ """Red: apply a detection-evasion technique on a controlled host."""
1340
+ hostname = action.hostname
1341
+ technique = action.technique
1342
+
1343
+ if hostname not in self._host_index:
1344
+ return 0.0, f"red: evade_detection — '{hostname}' not in network"
1345
+ if self._host_index[hostname].get("status") != "compromised":
1346
+ return 0.0, f"red: evade_detection — no shell on '{hostname}'"
1347
+
1348
+ if technique == "migrate_pid":
1349
+ host = self._host_index[hostname]
1350
+ malicious_procs = {
1351
+ proc
1352
+ for threat in self._task_def.get("attack_chain", [])
1353
+ if hostname in threat.get("compromised_hosts", [])
1354
+ for proc in threat.get("malicious_processes", [])
1355
+ }
1356
+ for i, proc in enumerate(list(host.get("running_processes", []))):
1357
+ if proc in malicious_procs:
1358
+ new_name = f"svchost_{i}.exe"
1359
+ host["running_processes"][i] = new_name
1360
+ if self._threat_graph:
1361
+ old_pid = f"{hostname}:{proc}"
1362
+ if old_pid in self._threat_graph.processes:
1363
+ self._threat_graph.processes.pop(old_pid)
1364
+ new_pid = f"{hostname}:{new_name}"
1365
+ self._threat_graph.add_process(ProcessNode(
1366
+ process_id=new_pid, hostname=hostname,
1367
+ process_name=new_name,
1368
+ ))
1369
+ return 0.0, f"red: migrated PIDs on {hostname} to blend with system processes"
1370
+
1371
+ if technique == "clear_logs":
1372
+ before = len(self._alert_queue)
1373
+ self._alert_queue = [
1374
+ a for a in self._alert_queue
1375
+ if a.get("source_host") != hostname
1376
+ ]
1377
+ removed = before - len(self._alert_queue)
1378
+ return 0.0, f"red: cleared {removed} SIEM alert(s) from {hostname}"
1379
+
1380
+ return 0.0, f"red: evasion '{technique}' applied on {hostname}"
1381
+
1382
+ def _handle_pass_turn(self, action: PassTurn) -> tuple[float, str]: # noqa: ARG002
1383
+ """Red: remain stealthy, take no action."""
1384
+ return 0.0, "red: pass_turn (stealth)"
1385
+
1386
+ def _get_red_action_target(self, action: Any) -> str:
1387
+ """Extract a compact target string from a Red action for timeline logging."""
1388
+ if isinstance(action, LateralPivot):
1389
+ return f"{action.source_host}→{action.target_host}"
1390
+ if isinstance(action, DeployPayload):
1391
+ return f"{action.hostname}/{action.payload_type}"
1392
+ if isinstance(action, EvadeDetection):
1393
+ return f"{action.hostname}/{action.technique}"
1394
+ return "—"
1395
+
1396
  # ===========================================================================
1397
  # Helpers
1398
  # ===========================================================================
 
1622
  # Per-step partial reward dimensions for GRPO credit assignment
1623
  reward_dimensions = self._compute_reward_dimensions()
1624
 
1625
+ # Red observation — only populated when it is Red's turn next
1626
+ red_obs = (
1627
+ self._generate_red_observation()
1628
+ if self._state.active_turn == "red"
1629
+ else None
1630
+ )
1631
+
1632
  return SOCObservation(
1633
  episode_id=self._state.episode_id or "",
1634
  alert_queue=alerts,
 
1652
  threat_graph_summary=threat_graph_summary,
1653
  available_playbooks=[],
1654
  reward_dimensions=reward_dimensions,
1655
+ active_turn=self._state.active_turn,
1656
+ red_observation=red_obs,
1657
  )
1658
 
1659
  def _get_action_target(self, action: Any) -> str:
 
1688
  # Adaptive Red Team + Step Rewards (Task 10)
1689
  # ===========================================================================
1690
 
1691
+ def _generate_red_observation(self) -> Dict[str, Any]:
1692
+ """What the Red Team LLM sees: footholds it controls + Blue's last action.
1693
+
1694
+ Returned as the ``red_observation`` field in SOCObservation whenever
1695
+ ``active_turn == 'red'``, so inference.py can feed it straight to the
1696
+ Red LLM without a separate API call.
1697
+ """
1698
+ compromised_hosts = [
1699
+ h for h, hd in self._host_index.items()
1700
+ if hd.get("status") == "compromised"
1701
+ ]
1702
+
1703
+ # Most recent Blue action from the timeline (exclude Red's own entries)
1704
+ blue_actions_detected: List[Dict[str, Any]] = []
1705
+ for entry in reversed(self._state.timeline):
1706
+ action_type = entry.get("action_type", "")
1707
+ if not action_type.startswith("red:"):
1708
+ blue_actions_detected.append({
1709
+ "step": entry["step"],
1710
+ "action": action_type,
1711
+ "target": entry["target"],
1712
+ "result": entry["result"],
1713
+ })
1714
+ break # Only the single most recent Blue action
1715
+
1716
  return {
1717
  "episode_id": self._state.episode_id,
1718
+ "round": self._state.step_count + 1,
1719
+ "compromised_hosts": compromised_hosts,
1720
+ "blue_actions_detected": blue_actions_detected,
 
1721
  "active_threats": list(self._state.active_threats),
1722
+ "business_impact": round(self._state.business_impact, 4),
 
 
1723
  }
1724
 
1725
  def _log_red_decision(self, observation: Dict[str, Any], action: Dict[str, Any]) -> None:
 
1734
  pass
1735
 
1736
  def _apply_red_team_dynamics(self, action_type: str, target: str) -> None:
1737
+ """Log a Red-side observation record (imitation data for offline SFT).
 
1738
 
1739
+ In FSP mode the Red LLM acts via explicit RedActionWrapper steps, so
1740
+ this method only records observations rather than executing any attack.
1741
  """
1742
+ red_obs = self._generate_red_observation()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1743
  self._log_red_decision(
1744
  red_obs,
1745
+ {"policy": "fsp_turn_engine", "action_type": "noop",
1746
+ "blue_action": action_type, "blue_target": target},
1747
  )
1748
 
1749
  def export_red_team_decisions(self) -> List[Dict[str, Any]]:
 
1842
  })
1843
 
1844
  def _adversary_react(self, action_type: str, target: str) -> Optional[Dict[str, Any]]:
1845
+ """Legacy hook disabled; Red Team now acts via explicit RedActionWrapper steps."""
1846
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1847
 
1848
  @property
1849
  def state(self) -> SOCState:
tests/test_integration.py CHANGED
@@ -13,7 +13,7 @@ if _PROJECT_ROOT not in sys.path:
13
  from server.play_environment import CyberSOCEnvironment
14
  from server.episode_sandbox import EpisodeTimeout
15
  from server.graders import grade_episode
16
- from models import SOCActionWrapper
17
 
18
 
19
  # ---------------------------------------------------------------------------
@@ -95,18 +95,35 @@ def test_phase_violation_returns_error():
95
  assert obs is not None
96
 
97
 
98
- def test_adaptive_pivot_fires_on_hard():
99
- env = CyberSOCEnvironment(adaptive=True)
 
100
  env.reset(task_id="hard")
101
 
102
- # Force pivot probability to 1.0 (hard task)
103
- # We need to isolate_segment where the host is the source_host for an edge
104
- # OR just call _execute_lateral_pivot directly for test certainty
105
- hostname = _first_host(env)
106
- env._execute_lateral_pivot(source_host=hostname)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  pivot_edges = [e for e in env._threat_graph.edges if e.edge_type == "pivoted_from"]
109
  assert len(pivot_edges) >= 1
 
 
110
 
111
 
112
  def test_step_reward_accumulates():
 
13
  from server.play_environment import CyberSOCEnvironment
14
  from server.episode_sandbox import EpisodeTimeout
15
  from server.graders import grade_episode
16
+ from models import SOCActionWrapper, RedActionWrapper
17
 
18
 
19
  # ---------------------------------------------------------------------------
 
95
  assert obs is not None
96
 
97
 
98
+ def test_lateral_pivot_red_action():
99
+ """LateralPivot RedActionWrapper creates a pivoted_from edge and a SIEM alert."""
100
+ env = CyberSOCEnvironment(fsp_mode=True)
101
  env.reset(task_id="hard")
102
 
103
+ # Find a compromised host to pivot from and a healthy one to pivot to
104
+ src = next(
105
+ (h for h, hd in env._host_index.items() if hd.get("status") == "compromised"),
106
+ None,
107
+ )
108
+ dst = next(
109
+ (h for h, hd in env._host_index.items()
110
+ if hd.get("status") not in ("compromised", "isolated") and h != src),
111
+ None,
112
+ )
113
+ if src is None or dst is None:
114
+ pytest.skip("No suitable host pair for lateral pivot test")
115
+
116
+ # Blue takes a PassTurn-equivalent (query) so active_turn flips to red
117
+ env.step(_valid_action("query_host", hostname=src))
118
+ assert env._state.active_turn == "red"
119
+
120
+ alerts_before = len(env._alert_queue)
121
+ env.step(RedActionWrapper(type="lateral_pivot", source_host=src, target_host=dst))
122
 
123
  pivot_edges = [e for e in env._threat_graph.edges if e.edge_type == "pivoted_from"]
124
  assert len(pivot_edges) >= 1
125
+ assert env._host_index[dst]["status"] == "compromised"
126
+ assert len(env._alert_queue) > alerts_before # SIEM alert generated
127
 
128
 
129
  def test_step_reward_accumulates():