# Source: opensec-env/server/environment.py
# (uploaded by Jarrodbarnes via huggingface_hub, commit 3f434eb, verified)
from __future__ import annotations
import os
import sqlite3
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from sim.attacker_policy import (
AttackerPolicyManager,
ReplayCache,
init_cache_db,
resolve_attacker_policy,
resolve_attacker_policy_config,
resolve_replay_mode,
)
from sim.attacker_state_machine import (
AttackerContext,
ContainmentActions,
ScenarioContext,
advance_state,
apply_attacker_action,
)
from sim.log_compiler import compile_seed, emit_artifact
from oracle.scoring import containment_to_dict, score_report
from oracle.verifier import detect_injection_violations
from .models import ActionResult, AgentAction, ContainmentState, EpisodeState, Observation, StepResult
class OpenSecEnvironment:
    """Step-based incident-response environment backed by a seed scenario.

    ``reset()`` loads the scenario JSON at ``seed_path`` and compiles it into
    a per-episode SQLite database under ``sqlite_dir``.  ``step()`` applies
    one agent action, asks the attacker policy for its move, advances the
    attacker state and returns the resulting observation.
    """

    # Row keys that identify a piece of evidence in query results.
    _EVIDENCE_ID_COLUMNS = ("email_id", "alert_id", "auth_id", "flow_id", "event_id")

    def __init__(
        self,
        scenario_id: str = "seed-001",
        seed_path: str = "data/seeds/sample_seed.json",
        sqlite_dir: str = "data/sqlite",
        max_steps: int = 15,
        mask_injections: bool = False,
    ) -> None:
        """Create an environment; no scenario is loaded until ``reset()``/``step()``.

        Args:
            scenario_id: Initial scenario id; replaced by the seed file's
                ``scenario_id`` once the scenario is loaded.
            seed_path: Path to the scenario seed JSON.
            sqlite_dir: Directory in which per-episode databases are created.
            max_steps: Step budget; replaced by ``metadata.max_steps`` from
                the seed when that key is present.
            mask_injections: When true, ``reset()`` empties the scenario's
                ``prompt_injection_payloads`` list.
        """
        self.scenario_id = scenario_id
        self.max_steps = max_steps
        self.seed_path = seed_path
        self.sqlite_dir = sqlite_dir
        self.mask_injections = mask_injections
        self.episode_id = ""
        self.step_count = 0
        self.attacker_state = "phish_sent"
        self.attacker_context = AttackerContext()
        self.containment = ContainmentState()
        self.scenario: Optional[Dict[str, Any]] = None
        self.db_path: Optional[str] = None
        # An external replay cache is only set up when a path is configured
        # and the replay mode is not "off"; otherwise _init_db() supplies one.
        cache_path = os.getenv("OPENSEC_REPLAY_CACHE_PATH")
        if cache_path and resolve_replay_mode() != "off":
            init_cache_db(cache_path)
            self.cache = ReplayCache(cache_path)
        else:
            self.cache = None
        self.policy_manager = AttackerPolicyManager(cache=self.cache)
        self.policy = resolve_attacker_policy()
        self.ground_truth: Optional[Dict[str, Any]] = None
        self.seen_evidence_ids: Set[str] = set()
        self.content_evidence_ids: Set[str] = set()
        self.injection_violations: List[str] = []

    def reset(self) -> StepResult:
        """Start a new episode and return the step-0 result.

        Clears all per-episode state, reloads the scenario, builds a fresh
        database and records the emails/alerts visible at step 0.
        """
        self.episode_id = str(uuid.uuid4())
        self.step_count = 0
        self.attacker_state = "phish_sent"
        self.attacker_context = AttackerContext()
        self.containment = ContainmentState()
        self.seen_evidence_ids = set()
        self.content_evidence_ids = set()
        self.injection_violations = []
        self._load_scenario()
        # A scenario attack graph may override the default start state.
        if self.scenario and self.scenario.get("attack_graph", {}).get("start_state"):
            self.attacker_state = self.scenario["attack_graph"]["start_state"]
        if self.mask_injections and self.scenario:
            self.scenario["prompt_injection_payloads"] = []
        self._init_db()

        new_emails = self._emails_for_step(0)
        new_alerts = self._alerts_for_step(0)
        self._record_evidence(new_emails, new_alerts)

        state = self.state()
        observation = Observation(
            scenario_id=self.scenario_id,
            step_index=self.step_count,
            attacker_state=self.attacker_state,
            containment=self.containment,
            new_emails=new_emails,
            new_alerts=new_alerts,
            evidence_seen_ids=sorted(self.seen_evidence_ids),
            evidence_content_ids=sorted(self.content_evidence_ids),
            last_action_result=ActionResult(ok=True, message="reset", data={}),
            hint="Environment ready",
            done=False,
            reward=None,
            metadata={"info": {}, "state": state.model_dump()},
        )
        return StepResult(
            observation=observation,
            reward=0.0,
            done=False,
            info={},
            state=state,
        )

    def step(self, action: AgentAction) -> StepResult:
        """Apply one agent action, advance the attacker and return the result.

        The reward is non-zero only for a ``submit_report`` action when a
        ground truth is loaded; ``done`` becomes true once ``step_count``
        reaches ``max_steps``.
        """
        # Allow step() without a prior reset() by initialising lazily.
        if self.scenario is None or self.db_path is None:
            self._load_scenario()
            self._init_db()

        result = self.apply_action(action)

        violations = detect_injection_violations(
            action.model_dump(),
            self.content_evidence_ids,
            self.scenario.get("prompt_injection_payloads", []),
        )
        if violations:
            self.injection_violations.extend(violations)

        model, temperature = resolve_attacker_policy_config()
        attacker_context_payload = self._attacker_policy_context()
        attacker_action = self.policy_manager.decide(
            scenario_id=self.scenario_id,
            step=self.step_count,
            attacker_state=self.attacker_state,
            agent_action=action.model_dump(),
            policy=self.policy,
            scenario=self.scenario,
            attacker_context=attacker_context_payload,
            model=model,
            temperature=temperature,
        )

        # Remember the pre-transition state: graph artifacts are looked up
        # on the node the attacker acted from.
        prior_state = self.attacker_state
        ctx = ScenarioContext(
            attacker_domain=self.scenario["attacker_domain"],
            patient_zero_host=self.scenario["patient_zero_host"],
            compromised_user=self.scenario["compromised_user"],
        )
        containment = ContainmentActions(
            isolated_hosts=self.containment.isolated_hosts,
            blocked_domains=self.containment.blocked_domains,
            reset_users=self.containment.reset_users,
        )
        advance = advance_state(
            self.attacker_state,
            containment,
            ctx,
            attacker_action=attacker_action,
            attacker_context=self.attacker_context,
            attack_graph=self.scenario.get("attack_graph"),
        )
        self.attacker_state = advance.next_state
        self.step_count += 1

        if self._uses_attack_graph():
            self._emit_action_artifacts(self.step_count, prior_state, attacker_action)
        else:
            self._emit_variant_artifacts(self.step_count, attacker_action)

        # Attacker-side effects are only applied when the move was not stalled.
        if not advance.stalled:
            effects = None
            if advance.matched_action:
                effects = advance.matched_action.get("effects")
            apply_attacker_action(self.attacker_context, attacker_action, effects=effects)

        done = self.step_count >= self.max_steps
        new_emails = self._emails_for_step(self.step_count)
        new_alerts = self._alerts_for_step(self.step_count)
        self._record_evidence(new_emails, new_alerts)

        reward = 0.0
        if action.action_type == "submit_report" and self.ground_truth is not None:
            report = action.params.get("summary_json", {})
            executed_containment = containment_to_dict(self.containment)
            score = score_report(
                report,
                self.ground_truth,
                self.step_count,
                self.injection_violations,
                executed_containment=executed_containment,
            )
            reward = score.reward

        info = {
            "action_type": action.action_type,
            "attacker_action": attacker_action,
            "attacker_stalled": advance.stalled,
            "attacker_reason": advance.reason,
            "injection_violations": list(self.injection_violations),
        }
        state = self.state()
        observation = Observation(
            scenario_id=self.scenario_id,
            step_index=self.step_count,
            attacker_state=self.attacker_state,
            containment=self.containment,
            new_emails=new_emails,
            new_alerts=new_alerts,
            evidence_seen_ids=sorted(self.seen_evidence_ids),
            evidence_content_ids=sorted(self.content_evidence_ids),
            last_action_result=ActionResult(ok=True, message=action.action_type, data=result),
            done=done,
            reward=reward,
            metadata={"info": info, "state": state.model_dump()},
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
            info=info,
            state=state,
        )

    def _emit_artifacts(
        self,
        step: int,
        artifacts: List[Dict[str, Any]],
        log_templates: Dict[str, Any],
    ) -> None:
        """Write ``artifacts`` to the episode database over one connection.

        Each artifact is committed individually.  The connection is closed
        explicitly because sqlite3's ``with conn:`` only ends the transaction.
        """
        if not artifacts:
            return
        conn = sqlite3.connect(self.db_path)
        try:
            for art in artifacts:
                emit_artifact(conn, self.scenario, step, art, log_templates, allow_variant=True)
                conn.commit()
        finally:
            conn.close()

    def _emit_variant_artifacts(self, step: int, attacker_action: Dict[str, Any]) -> None:
        """Emit timeline artifacts for ``step`` whose variant matches the attacker action.

        An artifact matches when its ``variant_action_type`` equals the
        action type and every ``variant_params`` entry equals the
        corresponding action parameter.
        """
        if self.scenario is None or self.db_path is None:
            return
        action_type = attacker_action.get("action_type")
        # ``or {}`` also covers an explicit ``"params": None``.
        action_params = attacker_action.get("params") or {}
        if not action_type:
            return
        timeline = self.scenario["attack_plan"]["timeline"]
        log_templates = {t["template_id"]: t for t in self.scenario["seed_artifacts"]["log_templates"]}

        matched: List[Dict[str, Any]] = []
        for item in timeline:
            if item["step"] != step:
                continue
            for art in item["artifacts"]:
                # action_type is truthy here, so a missing variant type never matches.
                if art.get("variant_action_type") != action_type:
                    continue
                variant_params = art.get("variant_params") or {}
                if any(action_params.get(k) != v for k, v in variant_params.items()):
                    continue
                matched.append(art)
        self._emit_artifacts(step, matched, log_templates)

    def _emit_action_artifacts(
        self,
        step: int,
        prior_state: str,
        attacker_action: Dict[str, Any],
    ) -> None:
        """Emit artifacts attached to the attack-graph action the attacker took.

        Actions are looked up on the ``prior_state`` node of the scenario's
        attack graph; both the action and each artifact may restrict
        themselves with ``match_params``.  ``no_op`` emits nothing.
        """
        if self.scenario is None or self.db_path is None:
            return
        action_type = attacker_action.get("action_type")
        if not action_type or action_type == "no_op":
            return
        graph = self.scenario.get("attack_graph") or {}
        node = graph.get("states", {}).get(prior_state, {})
        actions = node.get("actions", [])
        if not actions:
            return
        log_templates = {t["template_id"]: t for t in self.scenario["seed_artifacts"]["log_templates"]}
        params = attacker_action.get("params") or {}

        def _params_match(expected: Dict[str, Any]) -> bool:
            return all(params.get(k) == v for k, v in expected.items())

        matched: List[Dict[str, Any]] = []
        for action in actions:
            if action.get("action_type") != action_type:
                continue
            if not _params_match(action.get("match_params") or {}):
                continue
            for art in action.get("artifacts", []):
                if _params_match(art.get("match_params") or {}):
                    matched.append(art)
        self._emit_artifacts(step, matched, log_templates)

    def _uses_attack_graph(self) -> bool:
        """Return True when the loaded scenario has a non-empty ``attack_graph``."""
        return bool(self.scenario and self.scenario.get("attack_graph"))

    def _attacker_policy_context(self) -> Dict[str, Any]:
        """Build the context dict handed to the attacker policy.

        Lists the scenario's hosts, users and attacker domains that have not
        been contained, plus the current containment and attacker-context
        fields.  Entities lacking their id/domain key are skipped.
        """
        entities = (self.scenario or {}).get("entities", {})
        hosts = [h["host_id"] for h in entities.get("hosts", []) if h.get("host_id")]
        users = [u["user_id"] for u in entities.get("users", []) if u.get("user_id")]
        attacker_domains = [
            d["domain"]
            for d in entities.get("domains", [])
            if d.get("domain_type") == "attacker" and d.get("domain")
        ]
        containment = self.containment
        attacker = self.attacker_context
        available_hosts = [h for h in hosts if h not in containment.isolated_hosts]
        available_users = [u for u in users if u not in containment.reset_users]
        available_domains = [d for d in attacker_domains if d not in containment.blocked_domains]
        return {
            "step": self.step_count,
            "containment": {
                "isolated_hosts": sorted(containment.isolated_hosts),
                "blocked_domains": sorted(containment.blocked_domains),
                "reset_users": sorted(containment.reset_users),
            },
            "available_hosts": sorted(available_hosts),
            "available_users": sorted(available_users),
            "available_attacker_domains": sorted(available_domains),
            "compromised_hosts": sorted(attacker.compromised_hosts),
            "compromised_users": sorted(attacker.compromised_users),
            "current_host": attacker.current_host,
            "current_user": attacker.current_user,
            "current_target": attacker.current_target,
            "current_exfil_domain": attacker.current_exfil_domain,
            "has_creds": attacker.has_creds,
            "has_admin": attacker.has_admin,
            "has_stage": attacker.has_stage,
            "has_persistence": attacker.has_persistence,
        }

    def state(self) -> EpisodeState:
        """Return a snapshot of the episode counters.

        ``terminated`` is always False here; ``truncated`` turns true once
        the step budget is used up.
        """
        return EpisodeState(
            episode_id=self.episode_id,
            scenario_id=self.scenario_id,
            step_count=self.step_count,
            max_steps=self.max_steps,
            terminated=False,
            truncated=self.step_count >= self.max_steps,
        )

    def apply_action(self, action: AgentAction) -> Dict[str, Any]:
        """Execute the agent's action and return a result dict.

        Containment actions update ``self.containment`` (without
        duplicates), ``query_logs`` runs a SELECT against the episode
        database, and ``fetch_email`` / ``fetch_alert`` look up a single
        record.  Unknown action types return ``{"ok": True}``.
        """
        kind = action.action_type
        params = action.params

        if kind == "isolate_host":
            host_id = params.get("host_id")
            if host_id and host_id not in self.containment.isolated_hosts:
                self.containment.isolated_hosts.append(host_id)
            return {"ok": True, "isolated_host": host_id}

        if kind == "block_domain":
            domain = params.get("domain")
            if domain and domain not in self.containment.blocked_domains:
                self.containment.blocked_domains.append(domain)
            return {"ok": True, "blocked_domain": domain}

        if kind == "reset_user":
            user_id = params.get("user_id")
            if user_id and user_id not in self.containment.reset_users:
                self.containment.reset_users.append(user_id)
            return {"ok": True, "reset_user": user_id}

        if kind == "query_logs":
            sql = params.get("sql", "")
            if not self._is_readonly_select(sql):
                return {"ok": False, "error": "only SELECT queries are allowed"}
            try:
                rows = self._query_logs(sql)
            except (sqlite3.Error, sqlite3.Warning) as exc:
                # Agent-written SQL can fail in ways other than
                # OperationalError (e.g. several statements in one string);
                # report those as a failed action instead of crashing step().
                return {"ok": False, "error": str(exc)}
            self._record_content_evidence_from_rows(rows)
            return {"ok": True, "rows": rows}

        if kind == "fetch_email":
            email_id = params.get("email_id")
            if not email_id:
                return {"ok": False, "error": "email_id required"}
            # The id is recorded as viewed before the lookup, found or not.
            self.content_evidence_ids.add(email_id)
            email = self._fetch_email(email_id)
            return {"ok": True, "email_id": email_id, "email": email}

        if kind == "fetch_alert":
            alert_id = params.get("alert_id")
            if not alert_id:
                return {"ok": False, "error": "alert_id required"}
            self.content_evidence_ids.add(alert_id)
            alert = self._fetch_alert(alert_id)
            parsed = self._parse_alert_fields(alert.get("message", "")) if alert else {}
            return {"ok": True, "alert_id": alert_id, "alert": alert, "parsed": parsed}

        return {"ok": True}

    def _load_scenario(self) -> None:
        """Load the seed JSON and, when its file exists, the matching ground truth.

        Also adopts the seed's ``scenario_id`` and, if present, its
        ``metadata.max_steps``.
        """
        path = Path(self.seed_path)
        self.scenario = json_load(path)
        self.scenario_id = self.scenario["scenario_id"]
        self.max_steps = self.scenario.get("metadata", {}).get("max_steps", self.max_steps)
        gt_path = _resolve_ground_truth_path(Path(self.seed_path))
        if gt_path is not None and gt_path.exists():
            self.ground_truth = json_load(gt_path)

    def _init_db(self) -> None:
        """Compile the seed into this episode's SQLite database.

        When no replay cache was configured, one is created on the episode
        database and the policy manager is rebuilt to use it.
        """
        sqlite_dir = Path(self.sqlite_dir)
        sqlite_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = str(sqlite_dir / f"{self.scenario_id}-{self.episode_id}.db")
        compile_seed(Path(self.seed_path), Path(self.db_path))
        if self.cache is None:
            # NOTE(review): the original file's indentation was lost; the two
            # assignments below are assumed to belong to this branch - confirm.
            self.cache = ReplayCache(self.db_path)
            self.policy_manager = AttackerPolicyManager(cache=self.cache)
            self.policy = resolve_attacker_policy()

    def _query_logs(self, sql: str, params: tuple | None = None) -> List[Dict[str, Any]]:
        """Run ``sql`` against the episode database and return rows as dicts.

        Raises:
            RuntimeError: If the database has not been initialised yet.
            sqlite3.Error: If SQLite rejects the statement.
        """
        if self.db_path is None:
            raise RuntimeError("database not initialised; call reset() first")
        conn = sqlite3.connect(self.db_path)
        try:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(sql, params or ()).fetchall()
            return [dict(r) for r in rows]
        finally:
            # ``with conn:`` would not close the handle, so do it explicitly.
            conn.close()

    def _fetch_alert(self, alert_id: str) -> Dict[str, Any] | None:
        """Return the alert row with ``alert_id`` for this scenario, or None."""
        rows = self._query_logs(
            "SELECT * FROM alerts WHERE scenario_id = ? AND alert_id = ?",
            params=(self.scenario_id, alert_id),
        )
        return rows[0] if rows else None

    def _fetch_email(self, email_id: str) -> Dict[str, Any] | None:
        """Return the email row with ``email_id`` for this scenario, or None."""
        rows = self._query_logs(
            "SELECT * FROM email_logs WHERE scenario_id = ? AND email_id = ?",
            params=(self.scenario_id, email_id),
        )
        return rows[0] if rows else None

    def _parse_alert_fields(self, message: str) -> Dict[str, str]:
        """Extract ``key=value`` pairs from an alert message.

        When a key occurs more than once, the last value wins.
        """
        if not message:
            return {}
        return dict(re.findall(r"([a-zA-Z_]+)=([a-zA-Z0-9_.:@-]+)", message))

    def _emails_for_step(self, step: int) -> List[str]:
        """Return the ids of this scenario's emails logged at ``step``."""
        rows = self._query_logs(
            "SELECT email_id FROM email_logs WHERE scenario_id = ? AND step = ?",
            params=(self.scenario_id, step),
        )
        return [r["email_id"] for r in rows]

    def _alerts_for_step(self, step: int) -> List[str]:
        """Return the ids of this scenario's alerts raised at ``step``."""
        rows = self._query_logs(
            "SELECT alert_id FROM alerts WHERE scenario_id = ? AND step = ?",
            params=(self.scenario_id, step),
        )
        return [r["alert_id"] for r in rows]

    def _is_readonly_select(self, sql: str) -> bool:
        """Return True when ``sql`` is a string starting with ``SELECT``.

        The check is case-insensitive and ignores leading whitespace.
        Non-string input is rejected rather than raising.
        """
        if not isinstance(sql, str):
            return False
        return sql.strip().lower().startswith("select")

    def _record_evidence(self, new_emails: List[str], new_alerts: List[str]) -> None:
        """Mark the given email and alert ids as seen."""
        self.seen_evidence_ids.update(new_emails)
        self.seen_evidence_ids.update(new_alerts)

    def _record_content_evidence_from_rows(self, rows: List[Dict[str, Any]]) -> None:
        """Record every evidence id found in the query result rows.

        Values of the known id columns are stored as strings.
        """
        for row in rows:
            for column in self._EVIDENCE_ID_COLUMNS:
                if column in row:
                    self.content_evidence_ids.add(str(row[column]))
def json_load(path: Path) -> Dict[str, Any]:
    """Read the file at *path* and return its decoded JSON content.

    The file is opened as UTF-8 explicitly so the result does not depend on
    the platform's default text encoding.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    import json

    with path.open(encoding="utf-8") as f:
        return json.load(f)
def _resolve_ground_truth_path(seed_path: Path) -> Path | None:
name = seed_path.name
if name.endswith("_seed.json"):
return seed_path.with_name(name.replace("_seed.json", "_ground_truth.json"))
if name.endswith("seed.json"):
return seed_path.with_name(name.replace("seed.json", "ground_truth.json"))
return seed_path.with_name("sample_ground_truth.json")