"""Synthetic trajectory generation for OpenRange.
This module provides a fast, snapshot-backed simulator for collecting
teacher-model trajectories without booting Docker containers. It is meant
for SFT warm-start data generation, not reward-faithful evaluation.
"""
from __future__ import annotations
import asyncio
import logging
import random
import re
import shlex
from pathlib import Path
from typing import Any
from open_range.agents.llm_agent import LLMRangeAgent
from open_range.agents.parsing import strip_command_from_response
from open_range.agents.protocol import RangeAgent
from open_range.agents.replay_agent import ScriptedBlueAgent, ScriptedRedAgent
from open_range.builder.builder import LLMSnapshotBuilder, TemplateOnlyBuilder
from open_range.protocols import BuildContext, SnapshotBuilder, SnapshotSpec, Vulnerability
from open_range.server.environment import RangeEnvironment
from open_range.models import RangeAction, RangeObservation
from open_range.training.trajectory import TrajectoryLogger
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Token pattern: runs of lowercase letters, digits and the characters "_./:-".
# It is applied to lowercased command text when commands are normalized for
# comparison.
_TOKEN_RE = re.compile(r"[a-z0-9_./:-]+")
# Guidance text describing the expected response shape: brief reasoning
# followed by exactly one command prefixed with 'Command:'.
# NOTE(review): "think briefly inside ... about" reads as if a delimiter or
# tag pair was dropped from the sentence — confirm the intended wording.
_SYNTHETIC_REASONING_GUIDE = (
    "When you act, think briefly inside ... about what you learned, "
    "what hypothesis you are testing, and why the next step is justified. "
    "After the reasoning, output exactly one command prefixed with 'Command:'. "
    "Prefer high-signal interaction with the listed services and artifacts over "
    "repeating local workstation discovery commands. Do not claim success until "
    "the tool output confirms it."
)
def _run_async(coro: Any) -> Any:
"""Run an async coroutine from synchronous code."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, coro).result()
return asyncio.run(coro)
def _iter_hosts(snapshot: SnapshotSpec) -> list[str]:
raw_hosts = snapshot.topology.get("hosts", [])
hosts: list[str] = []
for host in raw_hosts:
if isinstance(host, dict):
name = str(host.get("name", "")).strip()
else:
name = str(host).strip()
if name:
hosts.append(name)
return hosts
def _deep_replace(value: Any, replacements: dict[str, str]) -> Any:
if isinstance(value, str):
result = value
for old, new in replacements.items():
result = result.replace(old, new)
return result
if isinstance(value, list):
return [_deep_replace(item, replacements) for item in value]
if isinstance(value, dict):
return {key: _deep_replace(item, replacements) for key, item in value.items()}
return value
def randomize_snapshot_flags(snapshot: SnapshotSpec, seed: int | None = None) -> SnapshotSpec:
    """Return a copy of *snapshot* whose flag values are replaced everywhere.

    Each flag receives a new ``FLAG{...}`` value made of 16 characters drawn
    from ``random.Random(seed)``.  The snapshot is dumped, every occurrence of
    an old flag value is substituted, and the result is validated back into a
    ``SnapshotSpec``.  A snapshot without flags is simply deep-copied.
    """
    if not snapshot.flags:
        return snapshot.model_copy(deep=True)

    hex_alphabet = "abcdef0123456789"
    generator = random.Random(seed)

    def fresh_flag() -> str:
        body = "".join(generator.choice(hex_alphabet) for _ in range(16))
        return f"FLAG{{{body}}}"

    # Flags are visited in declaration order so a given seed always yields
    # the same old-value -> new-value mapping.
    substitutions: dict[str, str] = {flag.value: fresh_flag() for flag in snapshot.flags}

    dumped = snapshot.model_dump(mode="python")
    return SnapshotSpec.model_validate(_deep_replace(dumped, substitutions))
def _observation_text(observation: str | RangeObservation) -> str:
"""Convert an observation into training text without reward leakage."""
if isinstance(observation, str):
return observation
parts: list[str] = []
if observation.stdout:
parts.append(observation.stdout)
if observation.stderr:
parts.append(f"STDERR:\n{observation.stderr}")
if observation.alerts:
parts.append("ALERTS:\n" + "\n".join(f"- {alert}" for alert in observation.alerts))
if observation.flags_captured:
parts.append(
"FLAGS CAPTURED:\n"
+ "\n".join(f"- {flag}" for flag in observation.flags_captured)
)
return "\n\n".join(parts)
def _prefixed_output(text: str, *, step: int) -> str:
"""Add deterministic pseudo-timing prefixes to tool output lines."""
cleaned = text.strip("\n")
if not cleaned:
return f"[{0.2 + (step % 5) * 0.1:.1f}s]"
prefix = f"[{0.2 + (step % 5) * 0.1:.1f}s] "
return "\n".join(f"{prefix}{line}" for line in cleaned.splitlines())
def _host_inventory(snapshot: SnapshotSpec) -> str:
    """Render one ``- host: zone, role`` bullet per host in the snapshot.

    The zone comes from the topology's ``"zones"`` mapping (zone name to a
    list of hosts) and the role from a fixed table of well-known host names.
    A host with neither is rendered as a bare ``- host`` bullet.
    """
    role_notes = {
        "web": "primary web application",
        "mail": "mail gateway",
        "db": "database service",
        "files": "file share service",
        "ldap": "directory and identity service",
        "siem": "log aggregation and monitoring",
        "firewall": "perimeter routing and filtering",
        "attacker": "your operator workstation",
    }

    topology = snapshot.topology
    zones = topology.get("zones", {}) if isinstance(topology, dict) else {}

    # Invert zone -> [hosts] into host -> zone; later zones win on duplicates.
    host_zone: dict[str, str] = {}
    if isinstance(zones, dict):
        for zone_name, members in zones.items():
            if isinstance(members, list):
                host_zone.update((str(member), str(zone_name)) for member in members)

    bullets: list[str] = []
    for host in _iter_hosts(snapshot):
        facts = (host_zone.get(host, "").strip(), role_notes.get(host, ""))
        summary = ", ".join(fact for fact in facts if fact)
        bullets.append(f"- {host}: {summary}" if summary else f"- {host}")
    return "\n".join(bullets)
def _entry_points(snapshot: SnapshotSpec) -> str:
    """List known access points as ``- target`` bullets, without duplicates.

    URLs found in golden-path commands come first, in order of appearance,
    followed by fixed targets for each well-known host name present in the
    snapshot.
    """
    service_targets = {
        "web": ("http://web/", "http://web/login", "http://web/search?q=test"),
        "mail": ("mail:25 (SMTP)",),
        "db": ("db:3306 (MySQL)",),
        "files": ("files:445 (SMB)",),
        "ldap": ("ldap:389 (LDAP)",),
    }

    bullets: list[str] = []
    recorded: set[str] = set()

    def remember(target: str) -> None:
        # Keep first-seen order while dropping repeats.
        if target not in recorded:
            recorded.add(target)
            bullets.append(f"- {target}")

    for step in snapshot.golden_path:
        for url in re.findall(r"https?://[^\s'\"`]+", step.command):
            remember(url)

    for host in _iter_hosts(snapshot):
        for target in service_targets.get(host, ()):
            remember(target)

    return "\n".join(bullets)
def _build_training_briefing(snapshot: SnapshotSpec, *, role: str) -> str:
    """Compose the markdown briefing shown to the agent for *role*.

    ``role == "red"`` produces the attacker briefing (target information,
    reachable hosts, known access points, operating notes and, when present,
    success conditions).  Any other role produces the defender briefing
    (defensive context and monitored hosts).
    """
    default_org = "OpenRange target"
    task = snapshot.task
    org_name = str(snapshot.topology.get("org_name", default_org)).strip() or default_org
    tier = int(snapshot.topology.get("tier", 1) or 1)
    host_summary = _host_inventory(snapshot)

    # Only dict-shaped success conditions contribute; the description is
    # preferred and the type is used as a fallback label.
    objective_lines: list[str] = []
    for condition in getattr(task, "success_conditions", []):
        if not isinstance(condition, dict):
            continue
        label = str(condition.get("description") or condition.get("type") or "").strip()
        if label:
            objective_lines.append(f"- {label}")

    if role != "red":
        scenario = task.blue_briefing or "Monitor the SIEM, investigate suspicious activity, and respond."
        blue_sections = [
            "# ROLE",
            "You are an autonomous SOC analyst responsible for this environment.",
            "",
            "# DEFENSIVE CONTEXT",
            f"**Organization**: {org_name}",
            f"**Tier**: {tier}",
            "**Goal**: Detect attacker activity, confirm it with log evidence, and respond without breaking healthy services.",
            "",
            "## Scenario Brief",
            scenario,
            "",
            "## Monitored Hosts",
            host_summary or "- siem",
        ]
        return "\n".join(blue_sections).strip()

    scenario = task.red_briefing or "Investigate the environment and recover the flag."
    red_sections = [
        "# ROLE",
        "You are an autonomous red-team operator authorized to assess this environment.",
        "",
        "# TARGET INFORMATION",
        f"**Organization**: {org_name}",
        f"**Tier**: {tier}",
        "**Goal**: Identify a viable attack path, validate it with evidence, and recover the exercise flag.",
        "",
        "## Scenario Brief",
        scenario,
        "",
        "## Reachable Hosts",
        host_summary or "- attacker",
        "",
        "## Known Access Points",
        _entry_points(snapshot),
        "",
        "## Operating Notes",
        "- Service hostnames are already known from the range context.",
        "- Favor concrete interaction with listed services over repeated local interface discovery.",
    ]
    if objective_lines:
        red_sections += ["", "## Success Conditions"]
        red_sections += objective_lines
    return "\n".join(red_sections).strip()
def _infer_tool_call(command: str) -> tuple[str, dict[str, Any]]:
stripped = command.strip()
lowered = stripped.lower()
try:
parts = shlex.split(stripped)
except ValueError:
parts = stripped.split()
if lowered.startswith("submit_flag "):
flag = stripped.split(maxsplit=1)[1] if len(stripped.split(maxsplit=1)) == 2 else ""
return "flag_found", {"flag": flag}
if lowered.startswith("cat ") and len(parts) >= 2:
return "read_file", {"file_path": parts[1], "line_numbers": True}
if lowered.startswith("grep "):
return "grep", {"command": stripped}
if lowered.startswith("find "):
return "file_search", {"command": stripped}
return "shell_command", {"command": stripped, "timeout": 120}
def _render_trace_reasoning(agent: RangeAgent, *, role: str, command: str) -> str:
raw_response = str(getattr(agent, "last_response_text", "") or "").strip()
if raw_response:
reasoning = strip_command_from_response(raw_response, command)
if reasoning:
return reasoning
lowered = command.lower()
if "nmap" in lowered:
thought = "I need a service inventory first so I can narrow the exposed attack surface."
elif "curl" in lowered and ("union" in lowered or "select" in lowered):
thought = "The web workflow looks injectable, so I will test a UNION-style payload that can expose sensitive data."
elif "curl" in lowered:
thought = "I should interrogate the web surface directly to learn the available routes and behaviors."
elif lowered.startswith("cat "):
thought = "I need to inspect the referenced file directly for source-level clues, credentials, or the flag."
elif lowered.startswith("grep "):
thought = (
"I need to filter the SIEM signal down to evidence that confirms the current hypothesis."
if role == "blue"
else "I should search the available artifacts for indicators that support the next exploit step."
)
elif lowered.startswith("submit_flag "):
thought = "The recovered token is strong enough to validate immediately."
elif lowered.startswith("submit_finding "):
thought = "The observed activity is concrete enough to escalate as a finding."
else:
thought = "I will take the next low-risk step that reduces uncertainty and advances the objective."
return f"\n{thought}\n"
def _blue_stimulus(env: SyntheticRangeEnvironment) -> RangeObservation:
    """Build the observation handed to the blue agent from pending alerts.

    The status line says whether any alerts are pending; the alerts
    themselves are attached to the observation unchanged.
    """
    pending = env._get_pending_alerts()
    if pending:
        message = "Suspicious activity has been observed in the monitored environment."
    else:
        message = "No high-confidence alerts yet. Continue monitoring for attacker activity."
    return RangeObservation(stdout=message, alerts=pending)
class SyntheticRangeEnvironment(RangeEnvironment):
"""Fast, deterministic simulator built from a ``SnapshotSpec``."""
def __init__(
    self,
    *,
    randomize_flags: bool = True,
    max_steps: int = 30,
) -> None:
    """Create a simulator that runs with Docker marked as unavailable.

    Args:
        randomize_flags: When true, ``_select_snapshot`` substitutes freshly
            generated flag values into the selected snapshot.
        max_steps: Forwarded to the base-class constructor.
    """
    super().__init__(docker_available=False, max_steps=max_steps)
    self._randomize_flags = randomize_flags
    # Seed recorded by reset() and read by _select_snapshot().
    self._synthetic_seed: int | None = None
    # Per-episode scratch mapping; emptied again on every reset().
    self._ephemeral_files: dict[str, str] = {}
def reset(
    self,
    seed: int | None = None,
    episode_id: str | None = None,
    **kwargs: Any,
) -> RangeObservation:
    """Clear per-episode synthetic state, then delegate to the base reset.

    Args:
        seed: Stored on the instance (``_select_snapshot`` reads it to seed
            flag randomization) and forwarded to the base class.
        episode_id: Forwarded unchanged to the base class.
        **kwargs: Forwarded unchanged to the base class.

    Returns:
        Whatever the base-class ``reset`` returns.
    """
    # Recorded before delegating; presumably the base reset() triggers
    # _select_snapshot(), which needs the seed — confirm against RangeEnvironment.
    self._synthetic_seed = seed
    self._ephemeral_files = {}
    return super().reset(seed=seed, episode_id=episode_id, **kwargs)
def _select_snapshot(self, **kwargs: Any) -> SnapshotSpec:
    """Select a snapshot via the base class and return a private copy of it.

    With flag randomization disabled the result is a plain deep copy;
    otherwise it is a clone with new flag values generated from the seed
    recorded by ``reset``.
    """
    snapshot = super()._select_snapshot(**kwargs)
    if not self._randomize_flags:
        return snapshot.model_copy(deep=True)
    return randomize_snapshot_flags(snapshot, seed=self._synthetic_seed)
def _exec_in_container(
self,
container_name: str,
command: str,
timeout_s: float | None = None,
) -> tuple[str, str]:
del container_name, timeout_s # unused in the synthetic executor
if self._snapshot is None:
return "", "No snapshot loaded"
if self._state.mode == "blue":
return self._simulate_blue_command(command)
return self._simulate_red_command(command)
def _simulate_red_command(self, command: str) -> tuple[str, str]:
normalized = command.strip().lower()
if not normalized:
return "", "Empty command"
exact_step = self._match_golden_step(command)
if exact_step is not None:
return self._render_golden_output(command, exact_step), ""
if normalized == "whoami":
return "kali\n", ""
if normalized == "pwd":
return "/root\n", ""
if normalized.startswith("ip ") or normalized in {"ip", "hostname -i", "hostname -i && ip route && ip -br addr", "hostname -i && ip route"}:
return self._render_network_identity(command), ""
if normalized.startswith("arp"):
return self._render_arp_cache(), ""
if normalized.startswith("getent hosts"):
return self._render_hosts_lookup(command), ""
if normalized.startswith("ls"):
return self._render_ls(command), ""
if normalized.startswith("cat "):
return self._render_cat(command)
if "nmap" in normalized:
return self._render_nmap(command), ""
if "curl" in normalized:
return self._render_curl(command), ""
if "mysql" in normalized:
return self._render_mysql(command), ""
if "smbclient" in normalized:
return self._render_smb(command), ""
if "ldapsearch" in normalized:
return self._render_ldap(command), ""
if re.search(r"\bssh\b|\bsshpass\b", normalized):
return self._render_ssh(command), ""
if "grep" in normalized and "flag" in normalized:
return self._render_flag_search(), ""
return "Command completed successfully.\n", ""
def _simulate_blue_command(self, command: str) -> tuple[str, str]:
normalized = command.strip().lower()
if not normalized:
return "", "Empty command"
if any(token in normalized for token in ("grep", "tail", "cat", "awk", "sed")):
return self._render_siem_query(command), ""
if "check_services" in normalized:
return self._render_service_status(), ""
if "iptables" in normalized or "ufw" in normalized or "firewall" in normalized:
return "Firewall policy updated.\n", ""
if normalized.startswith("patch "):
return "Patch applied in synthetic environment.\n", ""
if "restart" in normalized:
return "Service restarted.\n", ""
return "Investigation command completed.\n", ""
def _match_golden_step(self, command: str):
if self._snapshot is None:
return None
normalized = self._normalize_command(command)
best_step = None
best_score = 0.0
cmd_name = self._command_name(command)
for step in self._snapshot.golden_path:
step_normalized = self._normalize_command(step.command)
if normalized == step_normalized:
return step
if cmd_name != self._command_name(step.command):
continue
score = self._token_overlap(normalized, step_normalized)
if score > best_score:
best_score = score
best_step = step
if best_score >= 0.66:
return best_step
return None
@staticmethod
def _command_name(command: str) -> str:
stripped = command.strip()
if not stripped:
return ""
return stripped.split()[0].rsplit("/", 1)[-1].lower()
@staticmethod
def _normalize_command(command: str) -> str:
    """Reduce *command* to its lowercase tokens joined by single spaces.

    Only the character runs matched by ``_TOKEN_RE`` survive, so quoting,
    punctuation outside the token alphabet and spacing differences are
    ignored when commands are compared.
    """
    tokens = _TOKEN_RE.findall(command.lower())
    return " ".join(tokens)
@staticmethod
def _token_overlap(left: str, right: str) -> float:
left_tokens = set(left.split())
right_tokens = set(right.split())
if not left_tokens or not right_tokens:
return 0.0
intersection = left_tokens & right_tokens
union = left_tokens | right_tokens
return len(intersection) / len(union)
def _render_golden_output(self, command: str, step: Any) -> str:
expected = step.expect_in_stdout or "Command completed."
lowered = command.lower()
if "nmap" in lowered:
return f"Starting Nmap 7.94\n{expected}\nNmap done.\n"
if "curl" in lowered and "search" in lowered and ("union" in lowered or "flag" in lowered):
return f"Search results:\n{expected}\n"
if "curl" in lowered:
return f"{expected}\n"
if "mysql" in lowered:
return f"{expected}\n"
return f"{expected}\n"
def _render_nmap(self, command: str) -> str:
lines = ["Starting Nmap 7.94"]
lowered = command.lower()
if "10.0.1" in lowered or "web" in lowered:
lines.extend(
[
"80/tcp open http nginx 1.24",
"25/tcp open smtp postfix",
]
)
if "10.0.2" in lowered or "db" in lowered:
lines.extend(
[
"3306/tcp open mysql MySQL 8.0",
"445/tcp open smb samba 4.17",
]
)
if "10.0.3" in lowered or "ldap" in lowered or "siem" in lowered:
lines.extend(
[
"389/tcp open ldap OpenLDAP 2.6",
"514/tcp open syslog rsyslogd",
]
)
if len(lines) == 1:
lines.append("Host seems down.")
lines.append("Nmap done.")
return "\n".join(lines) + "\n"
def _render_curl(self, command: str) -> str:
lowered = command.lower()
flag = self._flag_value()
if "http://web/" in lowered and "login" not in lowered and "search" not in lowered:
hints = ["Welcome to the customer portal."]
if self._has_vuln_type("sqli"):
hints.append("Try /search?q=test")
if self._has_vuln_type("idor"):
hints.append("API available at /api/users/1/profile")
if self._has_vuln_type("path_traversal"):
hints.append("Downloads available at /download?file=report.pdf")
return "\n".join(hints) + "\n"
if "/login" in lowered:
return "Login\n