Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import sqlite3 | |
| import re | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Set | |
| from sim.attacker_policy import ( | |
| AttackerPolicyManager, | |
| ReplayCache, | |
| init_cache_db, | |
| resolve_attacker_policy, | |
| resolve_attacker_policy_config, | |
| resolve_replay_mode, | |
| ) | |
| from sim.attacker_state_machine import ( | |
| AttackerContext, | |
| ContainmentActions, | |
| ScenarioContext, | |
| advance_state, | |
| apply_attacker_action, | |
| ) | |
| from sim.log_compiler import compile_seed, emit_artifact | |
| from oracle.scoring import containment_to_dict, score_report | |
| from oracle.verifier import detect_injection_violations | |
| from .models import ActionResult, AgentAction, ContainmentState, EpisodeState, Observation, StepResult | |
class OpenSecEnvironment:
    """Step-based environment pairing a defender agent with a simulated attacker.

    Each episode loads a scenario seed from ``seed_path``, hands it to
    ``compile_seed`` to build a per-episode SQLite database under
    ``sqlite_dir``, and then alternates agent actions (``step``) with attacker
    actions chosen by the attacker policy manager.
    """

    # Row keys whose (non-NULL) values are recorded as content evidence when
    # they appear in the result of an agent ``query_logs`` action.
    _EVIDENCE_ID_COLUMNS = ("email_id", "alert_id", "auth_id", "flow_id", "event_id")

    def __init__(
        self,
        scenario_id: str = "seed-001",
        seed_path: str = "data/seeds/sample_seed.json",
        sqlite_dir: str = "data/sqlite",
        max_steps: int = 15,
        mask_injections: bool = False,
    ) -> None:
        """Store configuration and build the attacker policy machinery.

        Args:
            scenario_id: Initial scenario id; overwritten by the seed's own
                ``scenario_id`` once the scenario is loaded.
            seed_path: Path of the scenario seed JSON file.
            sqlite_dir: Directory in which per-episode databases are created.
            max_steps: Step budget; overridden by ``metadata.max_steps`` in the
                seed when present.
            mask_injections: When true, the loaded scenario's
                ``prompt_injection_payloads`` list is emptied.
        """
        self.scenario_id = scenario_id
        self.max_steps = max_steps
        self.seed_path = seed_path
        self.sqlite_dir = sqlite_dir
        self.mask_injections = mask_injections
        self.episode_id = ""
        self.step_count = 0
        self.attacker_state = "phish_sent"
        self.attacker_context = AttackerContext()
        self.containment = ContainmentState()
        self.scenario: Optional[Dict[str, Any]] = None
        self.db_path: Optional[str] = None
        # A replay cache is only created when a path is configured and replay
        # mode is not "off".
        cache_path = os.getenv("OPENSEC_REPLAY_CACHE_PATH")
        if cache_path and resolve_replay_mode() != "off":
            init_cache_db(cache_path)
            self.cache = ReplayCache(cache_path)
        else:
            self.cache = None
        self.policy_manager = AttackerPolicyManager(cache=self.cache)
        self.policy = resolve_attacker_policy()
        self.ground_truth: Optional[Dict[str, Any]] = None
        self.seen_evidence_ids: Set[str] = set()
        self.content_evidence_ids: Set[str] = set()
        self.injection_violations: List[str] = []

    def reset(self) -> StepResult:
        """Start a new episode and return the initial observation.

        Clears all per-episode state, assigns a fresh ``episode_id``, reloads
        the scenario, rebuilds the episode database and records the emails and
        alerts stored for step 0.
        """
        self.episode_id = str(uuid.uuid4())
        self.step_count = 0
        self.attacker_state = "phish_sent"
        self.attacker_context = AttackerContext()
        self.containment = ContainmentState()
        self.seen_evidence_ids = set()
        self.content_evidence_ids = set()
        self.injection_violations = []
        self._prepare_episode()
        new_emails = self._emails_for_step(0)
        new_alerts = self._alerts_for_step(0)
        self._record_evidence(new_emails, new_alerts)
        state = self.state()
        observation = Observation(
            scenario_id=self.scenario_id,
            step_index=self.step_count,
            attacker_state=self.attacker_state,
            containment=self.containment,
            new_emails=new_emails,
            new_alerts=new_alerts,
            evidence_seen_ids=sorted(self.seen_evidence_ids),
            evidence_content_ids=sorted(self.content_evidence_ids),
            last_action_result=ActionResult(ok=True, message="reset", data={}),
            hint="Environment ready",
            done=False,
            reward=None,
            metadata={"info": {}, "state": state.model_dump()},
        )
        return StepResult(
            observation=observation,
            reward=0.0,
            done=False,
            info={},
            state=state,
        )

    def step(self, action: AgentAction) -> StepResult:
        """Apply one agent action, let the attacker respond, and observe.

        The agent action is applied first, then checked for injection
        violations; the attacker policy then picks an action, the attacker
        state machine advances, matching artifacts are written to the episode
        database, and the emails/alerts stored for the new step are reported.

        A reward is only computed for ``submit_report`` actions when ground
        truth is loaded; every other step yields ``0.0``.
        """
        if self.scenario is None or self.db_path is None:
            # Lazy initialisation when step() is called before reset(): use the
            # same preparation as reset() so the scenario's start state and the
            # mask_injections setting are honoured here as well.
            self._prepare_episode()
        result = self.apply_action(action)
        violations = detect_injection_violations(
            action.model_dump(),
            self.content_evidence_ids,
            self.scenario.get("prompt_injection_payloads", []),
        )
        if violations:
            self.injection_violations.extend(violations)
        model, temperature = resolve_attacker_policy_config()
        attacker_context_payload = self._attacker_policy_context()
        attacker_action = self.policy_manager.decide(
            scenario_id=self.scenario_id,
            step=self.step_count,
            attacker_state=self.attacker_state,
            agent_action=action.model_dump(),
            policy=self.policy,
            scenario=self.scenario,
            attacker_context=attacker_context_payload,
            model=model,
            temperature=temperature,
        )
        # Remember the state the attacker acted from: graph artifacts are
        # looked up on that node, not on the node reached afterwards.
        prior_state = self.attacker_state
        ctx = ScenarioContext(
            attacker_domain=self.scenario["attacker_domain"],
            patient_zero_host=self.scenario["patient_zero_host"],
            compromised_user=self.scenario["compromised_user"],
        )
        containment = ContainmentActions(
            isolated_hosts=self.containment.isolated_hosts,
            blocked_domains=self.containment.blocked_domains,
            reset_users=self.containment.reset_users,
        )
        advance = advance_state(
            self.attacker_state,
            containment,
            ctx,
            attacker_action=attacker_action,
            attacker_context=self.attacker_context,
            attack_graph=self.scenario.get("attack_graph"),
        )
        self.attacker_state = advance.next_state
        self.step_count += 1
        if self._uses_attack_graph():
            self._emit_action_artifacts(self.step_count, prior_state, attacker_action)
        else:
            self._emit_variant_artifacts(self.step_count, attacker_action)
        if not advance.stalled:
            effects = None
            if advance.matched_action:
                effects = advance.matched_action.get("effects")
            apply_attacker_action(self.attacker_context, attacker_action, effects=effects)
        # NOTE(review): the episode only ends on the step budget; a
        # submit_report action is scored but does not set done — confirm that
        # this is intended.
        done = self.step_count >= self.max_steps
        new_emails = self._emails_for_step(self.step_count)
        new_alerts = self._alerts_for_step(self.step_count)
        self._record_evidence(new_emails, new_alerts)
        reward = 0.0
        if action.action_type == "submit_report" and self.ground_truth is not None:
            report = action.params.get("summary_json", {})
            executed_containment = containment_to_dict(self.containment)
            score = score_report(
                report,
                self.ground_truth,
                self.step_count,
                self.injection_violations,
                executed_containment=executed_containment,
            )
            reward = score.reward
        info = {
            "action_type": action.action_type,
            "attacker_action": attacker_action,
            "attacker_stalled": advance.stalled,
            "attacker_reason": advance.reason,
            "injection_violations": list(self.injection_violations),
        }
        state = self.state()
        observation = Observation(
            scenario_id=self.scenario_id,
            step_index=self.step_count,
            attacker_state=self.attacker_state,
            containment=self.containment,
            new_emails=new_emails,
            new_alerts=new_alerts,
            evidence_seen_ids=sorted(self.seen_evidence_ids),
            evidence_content_ids=sorted(self.content_evidence_ids),
            last_action_result=ActionResult(ok=True, message=action.action_type, data=result),
            done=done,
            reward=reward,
            metadata={"info": info, "state": state.model_dump()},
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
            info=info,
            state=state,
        )

    def _prepare_episode(self) -> None:
        """Load the scenario, apply start-state/masking, and build the database."""
        self._load_scenario()
        attack_graph = (self.scenario or {}).get("attack_graph") or {}
        start_state = attack_graph.get("start_state")
        if start_state:
            self.attacker_state = start_state
        if self.mask_injections and self.scenario:
            self.scenario["prompt_injection_payloads"] = []
        self._init_db()

    @staticmethod
    def _params_match(actual: Dict[str, Any], expected: Optional[Dict[str, Any]]) -> bool:
        """Return True when every expected key/value is present in ``actual``.

        A missing or empty ``expected`` mapping matches anything.
        """
        return all(actual.get(key) == value for key, value in (expected or {}).items())

    def _emit_artifacts(self, step: int, artifacts: List[Dict[str, Any]]) -> None:
        """Write ``artifacts`` to the episode database via ``emit_artifact``.

        Each artifact is committed individually. The connection is closed
        explicitly because sqlite3's connection context manager only commits
        or rolls back; it never closes the connection.
        """
        if not artifacts or self.scenario is None or self.db_path is None:
            return
        log_templates = {
            t["template_id"]: t for t in self.scenario["seed_artifacts"]["log_templates"]
        }
        conn = sqlite3.connect(self.db_path)
        try:
            for art in artifacts:
                emit_artifact(conn, self.scenario, step, art, log_templates, allow_variant=True)
                conn.commit()
        finally:
            conn.close()

    def _emit_variant_artifacts(self, step: int, attacker_action: Dict[str, Any]) -> None:
        """Emit timeline artifacts for ``step`` whose variant matches the action.

        Only artifacts that declare a ``variant_action_type`` equal to the
        attacker's action type, and whose ``variant_params`` (if any) all match
        the action parameters, are written.
        """
        if self.scenario is None or self.db_path is None:
            return
        action_type = attacker_action.get("action_type")
        if not action_type:
            return
        # ``or {}`` also covers an explicit ``"params": None``.
        action_params = attacker_action.get("params") or {}
        timeline = self.scenario["attack_plan"]["timeline"]
        selected: List[Dict[str, Any]] = []
        for item in timeline:
            if item["step"] != step:
                continue
            for art in item["artifacts"]:
                if art.get("variant_action_type") != action_type:
                    continue
                if self._params_match(action_params, art.get("variant_params")):
                    selected.append(art)
        self._emit_artifacts(step, selected)

    def _emit_action_artifacts(
        self,
        step: int,
        prior_state: str,
        attacker_action: Dict[str, Any],
    ) -> None:
        """Emit artifacts attached to the attack-graph action the attacker took.

        Candidate actions are read from the ``prior_state`` node of the
        scenario's attack graph. ``no_op`` actions emit nothing. Both the
        action's and each artifact's ``match_params`` must match the attacker
        action's parameters.
        """
        if self.scenario is None or self.db_path is None:
            return
        action_type = attacker_action.get("action_type")
        if not action_type or action_type == "no_op":
            return
        graph = self.scenario.get("attack_graph") or {}
        node = graph.get("states", {}).get(prior_state, {})
        params = attacker_action.get("params") or {}
        selected: List[Dict[str, Any]] = []
        for candidate in node.get("actions", []):
            if candidate.get("action_type") != action_type:
                continue
            if not self._params_match(params, candidate.get("match_params")):
                continue
            for art in candidate.get("artifacts", []):
                if self._params_match(params, art.get("match_params")):
                    selected.append(art)
        self._emit_artifacts(step, selected)

    def _uses_attack_graph(self) -> bool:
        """Return True when the loaded scenario has a non-empty ``attack_graph``."""
        return bool(self.scenario and self.scenario.get("attack_graph"))

    def _attacker_policy_context(self) -> Dict[str, Any]:
        """Build the context dictionary passed to the attacker policy.

        It combines the current containment lists, the hosts/users/attacker
        domains from the scenario entities that have not been contained, and
        the fields of the attacker context.
        """
        entities = (self.scenario or {}).get("entities", {})
        hosts = [h["host_id"] for h in entities.get("hosts", []) if h.get("host_id")]
        users = [u["user_id"] for u in entities.get("users", []) if u.get("user_id")]
        attacker_domains = [
            d["domain"]
            for d in entities.get("domains", [])
            if d.get("domain_type") == "attacker"
        ]
        available_hosts = [h for h in hosts if h not in self.containment.isolated_hosts]
        available_users = [u for u in users if u not in self.containment.reset_users]
        available_domains = [
            d for d in attacker_domains if d not in self.containment.blocked_domains
        ]
        return {
            "step": self.step_count,
            "containment": {
                "isolated_hosts": sorted(self.containment.isolated_hosts),
                "blocked_domains": sorted(self.containment.blocked_domains),
                "reset_users": sorted(self.containment.reset_users),
            },
            "available_hosts": sorted(available_hosts),
            "available_users": sorted(available_users),
            "available_attacker_domains": sorted(available_domains),
            "compromised_hosts": sorted(self.attacker_context.compromised_hosts),
            "compromised_users": sorted(self.attacker_context.compromised_users),
            "current_host": self.attacker_context.current_host,
            "current_user": self.attacker_context.current_user,
            "current_target": self.attacker_context.current_target,
            "current_exfil_domain": self.attacker_context.current_exfil_domain,
            "has_creds": self.attacker_context.has_creds,
            "has_admin": self.attacker_context.has_admin,
            "has_stage": self.attacker_context.has_stage,
            "has_persistence": self.attacker_context.has_persistence,
        }

    def state(self) -> EpisodeState:
        """Return a snapshot of the episode counters.

        ``terminated`` is always False; ``truncated`` becomes True once
        ``step_count`` reaches ``max_steps``.
        """
        return EpisodeState(
            episode_id=self.episode_id,
            scenario_id=self.scenario_id,
            step_count=self.step_count,
            max_steps=self.max_steps,
            terminated=False,
            truncated=self.step_count >= self.max_steps,
        )

    def apply_action(self, action: AgentAction) -> Dict[str, Any]:
        """Execute the agent's action and return its result payload.

        Containment actions append to the matching containment list (without
        duplicates). ``query_logs`` runs a SELECT against the episode database
        and reports any SQLite failure as ``{"ok": False, "error": ...}``
        instead of raising. ``fetch_email`` / ``fetch_alert`` look up a single
        row and mark its id as content evidence. Any other action type returns
        ``{"ok": True}``.
        """
        if action.action_type == "isolate_host":
            host_id = action.params.get("host_id")
            if host_id and host_id not in self.containment.isolated_hosts:
                self.containment.isolated_hosts.append(host_id)
            return {"ok": True, "isolated_host": host_id}
        if action.action_type == "block_domain":
            domain = action.params.get("domain")
            if domain and domain not in self.containment.blocked_domains:
                self.containment.blocked_domains.append(domain)
            return {"ok": True, "blocked_domain": domain}
        if action.action_type == "reset_user":
            user_id = action.params.get("user_id")
            if user_id and user_id not in self.containment.reset_users:
                self.containment.reset_users.append(user_id)
            return {"ok": True, "reset_user": user_id}
        if action.action_type == "query_logs":
            sql = action.params.get("sql", "")
            if not self._is_readonly_select(sql):
                return {"ok": False, "error": "only SELECT queries are allowed"}
            try:
                rows = self._query_logs(sql)
            except (sqlite3.Error, sqlite3.Warning) as exc:
                # Agent-supplied SQL can fail in more ways than
                # OperationalError: e.g. several statements in one string
                # raise sqlite3.Warning or ProgrammingError depending on the
                # Python version. None of these should abort the step.
                return {"ok": False, "error": str(exc)}
            self._record_content_evidence_from_rows(rows)
            return {"ok": True, "rows": rows}
        if action.action_type == "fetch_email":
            email_id = action.params.get("email_id")
            if not email_id:
                return {"ok": False, "error": "email_id required"}
            self.content_evidence_ids.add(email_id)
            email = self._fetch_email(email_id)
            return {"ok": True, "email_id": email_id, "email": email}
        if action.action_type == "fetch_alert":
            alert_id = action.params.get("alert_id")
            if not alert_id:
                return {"ok": False, "error": "alert_id required"}
            self.content_evidence_ids.add(alert_id)
            alert = self._fetch_alert(alert_id)
            parsed = self._parse_alert_fields(alert.get("message", "")) if alert else {}
            return {"ok": True, "alert_id": alert_id, "alert": alert, "parsed": parsed}
        return {"ok": True}

    def _load_scenario(self) -> None:
        """Load the seed JSON and, when its file exists, the ground truth.

        The seed's ``scenario_id`` replaces the constructor value, and
        ``metadata.max_steps`` (if present) replaces ``max_steps``.
        """
        path = Path(self.seed_path)
        self.scenario = json_load(path)
        self.scenario_id = self.scenario["scenario_id"]
        self.max_steps = self.scenario.get("metadata", {}).get("max_steps", self.max_steps)
        gt_path = _resolve_ground_truth_path(Path(self.seed_path))
        if gt_path is not None and gt_path.exists():
            self.ground_truth = json_load(gt_path)

    def _init_db(self) -> None:
        """Create the per-episode database file and refresh the policy objects."""
        sqlite_dir = Path(self.sqlite_dir)
        sqlite_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = str(sqlite_dir / f"{self.scenario_id}-{self.episode_id}.db")
        compile_seed(Path(self.seed_path), Path(self.db_path))
        if self.cache is None:
            # NOTE(review): with no replay cache configured, the episode
            # database itself is handed to ReplayCache — confirm intended.
            self.cache = ReplayCache(self.db_path)
        self.policy_manager = AttackerPolicyManager(cache=self.cache)
        self.policy = resolve_attacker_policy()

    def _query_logs(self, sql: str, params: tuple | None = None) -> List[Dict[str, Any]]:
        """Run ``sql`` against the episode database and return rows as dicts.

        Raises:
            RuntimeError: if the episode database has not been created yet.
        """
        if self.db_path is None:
            raise RuntimeError("episode database is not initialised")
        conn = sqlite3.connect(self.db_path)
        try:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(sql, params or ()).fetchall()
            return [dict(r) for r in rows]
        finally:
            # The sqlite3 connection context manager does not close the
            # connection, so close it explicitly to avoid leaking handles.
            conn.close()

    def _fetch_alert(self, alert_id: str) -> Dict[str, Any] | None:
        """Return the alert row with ``alert_id`` for this scenario, or None."""
        rows = self._query_logs(
            "SELECT * FROM alerts WHERE scenario_id = ? AND alert_id = ?",
            params=(self.scenario_id, alert_id),
        )
        return rows[0] if rows else None

    def _fetch_email(self, email_id: str) -> Dict[str, Any] | None:
        """Return the email row with ``email_id`` for this scenario, or None."""
        rows = self._query_logs(
            "SELECT * FROM email_logs WHERE scenario_id = ? AND email_id = ?",
            params=(self.scenario_id, email_id),
        )
        return rows[0] if rows else None

    def _parse_alert_fields(self, message: str) -> Dict[str, str]:
        """Extract ``key=value`` pairs from an alert message.

        When a key occurs more than once the last value wins.
        """
        if not message:
            return {}
        return dict(re.findall(r"([a-zA-Z_]+)=([a-zA-Z0-9_.:@-]+)", message))

    def _emails_for_step(self, step: int) -> List[str]:
        """Return the ids of the emails stored for ``step`` in this scenario."""
        rows = self._query_logs(
            "SELECT email_id FROM email_logs WHERE scenario_id = ? AND step = ?",
            params=(self.scenario_id, step),
        )
        return [r["email_id"] for r in rows]

    def _alerts_for_step(self, step: int) -> List[str]:
        """Return the ids of the alerts stored for ``step`` in this scenario."""
        rows = self._query_logs(
            "SELECT alert_id FROM alerts WHERE scenario_id = ? AND step = ?",
            params=(self.scenario_id, step),
        )
        return [r["alert_id"] for r in rows]

    def _is_readonly_select(self, sql: str) -> bool:
        """Return True when ``sql`` is a string that starts with ``SELECT``.

        Non-string input (for example an explicit ``None``) is rejected rather
        than raising.
        """
        return isinstance(sql, str) and sql.strip().lower().startswith("select")

    def _record_evidence(self, new_emails: List[str], new_alerts: List[str]) -> None:
        """Mark the given email and alert ids as seen by the agent."""
        self.seen_evidence_ids.update(new_emails)
        self.seen_evidence_ids.update(new_alerts)

    def _record_content_evidence_from_rows(self, rows: List[Dict[str, Any]]) -> None:
        """Record the evidence ids found in query result rows.

        NULL id values are skipped so the literal string ``"None"`` is never
        stored as an evidence id.
        """
        for row in rows:
            for column in self._EVIDENCE_ID_COLUMNS:
                value = row.get(column)
                if value is not None:
                    self.content_evidence_ids.add(str(value))
def json_load(path: Path) -> Dict[str, Any]:
    """Read the JSON document stored at ``path`` and return the parsed value.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default text encoding.
    """
    import json

    with path.open(encoding="utf-8") as f:
        return json.load(f)
def _resolve_ground_truth_path(seed_path: Path) -> Path | None:
    """Return the ground-truth file path that sits next to ``seed_path``.

    ``<stem>_seed.json`` maps to ``<stem>_ground_truth.json`` and
    ``<stem>seed.json`` maps to ``<stem>ground_truth.json``. Any other file
    name falls back to ``sample_ground_truth.json`` in the same directory.
    Only the trailing suffix is rewritten; earlier occurrences of the marker
    inside the file name are left untouched.
    """
    name = seed_path.name
    # The longer suffix is checked first because it also ends with the shorter one.
    for seed_suffix, truth_suffix in (
        ("_seed.json", "_ground_truth.json"),
        ("seed.json", "ground_truth.json"),
    ):
        if name.endswith(seed_suffix):
            return seed_path.with_name(name[: -len(seed_suffix)] + truth_suffix)
    return seed_path.with_name("sample_ground_truth.json")