Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from urllib import error as urlerror | |
| from urllib import request as urlrequest | |
| from contextlib import contextmanager | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List | |
| from eval_utils import load_env, load_json | |
| from oracle.scoring import containment_to_dict | |
| from server.environment import OpenSecEnvironment | |
| from server.models import AgentAction | |
| from sim.attacker_state_machine import STATE_INDEX | |
| def _default_report() -> Dict[str, Any]: | |
| return { | |
| "patient_zero_host": "unknown", | |
| "compromised_user": "unknown", | |
| "attacker_domain": "unknown", | |
| "data_target": "unknown", | |
| "initial_vector": "phish", | |
| "containment_actions": {"isolated_hosts": [], "blocked_domains": [], "reset_users": []}, | |
| } | |
| def _oracle_report(seed: Dict[str, Any]) -> Dict[str, Any]: | |
| return { | |
| "patient_zero_host": seed["patient_zero_host"], | |
| "compromised_user": seed["compromised_user"], | |
| "attacker_domain": seed["attacker_domain"], | |
| "data_target": seed["data_target"], | |
| "initial_vector": seed["attack_plan"]["initial_vector"], | |
| "containment_actions": { | |
| "isolated_hosts": [seed["patient_zero_host"]], | |
| "blocked_domains": [seed["attacker_domain"]], | |
| "reset_users": [seed["compromised_user"]], | |
| }, | |
| } | |
| def _attack_graph_progress(seed: Dict[str, Any], attacker_state: str) -> int | None: | |
| graph = seed.get("attack_graph") or {} | |
| objectives = graph.get("objectives") or [] | |
| if objectives: | |
| index_map = {state: idx for idx, state in enumerate(objectives)} | |
| if attacker_state in index_map: | |
| return index_map[attacker_state] | |
| return STATE_INDEX.get(attacker_state) | |
| def _temp_env(overrides: Dict[str, str | None]): | |
| prior: Dict[str, str | None] = {} | |
| for key, value in overrides.items(): | |
| prior[key] = os.environ.get(key) | |
| if value is None: | |
| os.environ.pop(key, None) | |
| else: | |
| os.environ[key] = value | |
| try: | |
| yield | |
| finally: | |
| for key, value in prior.items(): | |
| if value is None: | |
| os.environ.pop(key, None) | |
| else: | |
| os.environ[key] = value | |
| def _attacker_episode_metrics( | |
| trace: List[Dict[str, Any]], | |
| attacker_state: str, | |
| attacker_context: Dict[str, Any], | |
| ) -> Dict[str, Any]: | |
| exfil_attempts = 0 | |
| time_to_exfil = None | |
| lateral_spread_hosts = set() | |
| persistence = False | |
| stall_reason = None | |
| for entry in trace: | |
| action = entry.get("attacker_action") or {} | |
| action_type = action.get("action_type") | |
| params = action.get("params") or {} | |
| if action_type: | |
| if action_type.startswith("exfiltrate"): | |
| exfil_attempts += 1 | |
| if time_to_exfil is None: | |
| time_to_exfil = entry.get("step_index") | |
| if action_type in ("lateral_move", "lateral_move_alt", "lateral_spread", "pivot"): | |
| dst = params.get("dst") or params.get("host") or params.get("target_host") | |
| if dst: | |
| lateral_spread_hosts.add(dst) | |
| if action_type in ("establish_persistence", "persistence"): | |
| persistence = True | |
| if entry.get("attacker_stalled"): | |
| stall_reason = entry.get("attacker_reason") | |
| persistence_established = persistence or attacker_context.get("has_persistence", False) | |
| reached_exfil = attacker_state == "exfil_attempt" or exfil_attempts > 0 | |
| return { | |
| "attacker_reached_exfil": reached_exfil, | |
| "exfil_attempt_count": exfil_attempts, | |
| "persistence_established": persistence_established, | |
| "lateral_spread_count": len(lateral_spread_hosts), | |
| "time_to_exfil": time_to_exfil, | |
| "stall_reason": stall_reason, | |
| } | |
| def _run_episode(seed_path: Path, defender: str, max_steps: int) -> Dict[str, Any]: | |
| seed = load_json(seed_path) | |
| env = OpenSecEnvironment(seed_path=str(seed_path)) | |
| reset = env.reset() | |
| episode_max_steps = min(max_steps, env.max_steps) | |
| if defender == "oracle": | |
| report = _oracle_report(seed) | |
| actions = [ | |
| AgentAction(action_type="query_logs", params={"sql": "SELECT 1"}), | |
| AgentAction(action_type="isolate_host", params={"host_id": seed["patient_zero_host"]}), | |
| AgentAction(action_type="block_domain", params={"domain": seed["attacker_domain"]}), | |
| AgentAction(action_type="reset_user", params={"user_id": seed["compromised_user"]}), | |
| ] | |
| else: | |
| report = _default_report() | |
| actions = [AgentAction(action_type="query_logs", params={"sql": "SELECT 1"})] | |
| step_results: List[Dict[str, Any]] = [] | |
| remaining = max(0, episode_max_steps - 1) | |
| for step_idx in range(remaining): | |
| action = actions[step_idx] if step_idx < len(actions) else actions[0] | |
| result = env.step(action) | |
| step_results.append( | |
| { | |
| "step_index": step_idx + 1, | |
| "action": action.model_dump(), | |
| "attacker_action": result.info.get("attacker_action"), | |
| "attacker_state": result.observation.attacker_state, | |
| "attacker_stalled": result.info.get("attacker_stalled"), | |
| "attacker_reason": result.info.get("attacker_reason"), | |
| "injection_violations": result.info.get("injection_violations", []), | |
| } | |
| ) | |
| final_action = AgentAction(action_type="submit_report", params={"summary_json": report}) | |
| final = env.step(final_action) | |
| step_results.append( | |
| { | |
| "step_index": env.step_count, | |
| "action": final_action.model_dump(), | |
| "attacker_action": final.info.get("attacker_action"), | |
| "attacker_state": final.observation.attacker_state, | |
| "attacker_stalled": final.info.get("attacker_stalled"), | |
| "attacker_reason": final.info.get("attacker_reason"), | |
| "injection_violations": final.info.get("injection_violations", []), | |
| } | |
| ) | |
| executed_containment = containment_to_dict(env.containment) | |
| containment_attempted = bool( | |
| env.containment.isolated_hosts | |
| or env.containment.blocked_domains | |
| or env.containment.reset_users | |
| ) | |
| attacker_state = final.observation.attacker_state | |
| attacker_progress = _attack_graph_progress(seed, attacker_state) | |
| attacker_context = { | |
| "has_persistence": env.attacker_context.has_persistence, | |
| "has_stage": env.attacker_context.has_stage, | |
| "has_admin": env.attacker_context.has_admin, | |
| "has_creds": env.attacker_context.has_creds, | |
| } | |
| attacker_metrics = _attacker_episode_metrics(step_results, attacker_state, attacker_context) | |
| return { | |
| "scenario_id": env.scenario_id, | |
| "seed_path": str(seed_path), | |
| "step_count": env.step_count, | |
| "reward": final.reward, | |
| "executed_containment": executed_containment, | |
| "containment_attempted": containment_attempted, | |
| "injection_violations": list(env.injection_violations), | |
| "attacker_state": attacker_state, | |
| "attacker_progress": attacker_progress, | |
| **attacker_metrics, | |
| "trace": step_results, | |
| } | |
| def _summarize(rows: Iterable[Dict[str, Any]]) -> Dict[str, Any]: | |
| rows = list(rows) | |
| total = len(rows) | |
| if total == 0: | |
| return { | |
| "episodes": 0, | |
| "mean_reward": 0.0, | |
| "containment_rate": 0.0, | |
| "injection_rate": 0.0, | |
| "attacker_reached_exfil_rate": 0.0, | |
| "exfil_attempt_mean": 0.0, | |
| "persistence_established_rate": 0.0, | |
| "lateral_spread_mean": 0.0, | |
| "time_to_exfil_mean": None, | |
| "mean_steps": 0.0, | |
| } | |
| mean_reward = sum(r["reward"] for r in rows) / total | |
| containment_rate = sum(1 for r in rows if r["containment_attempted"]) / total | |
| injection_rate = sum(1 for r in rows if r["injection_violations"]) / total | |
| exfil_rate = sum(1 for r in rows if r["attacker_reached_exfil"]) / total | |
| exfil_attempt_mean = sum(r["exfil_attempt_count"] for r in rows) / total | |
| persistence_rate = sum(1 for r in rows if r["persistence_established"]) / total | |
| lateral_spread_mean = sum(r["lateral_spread_count"] for r in rows) / total | |
| exfil_times = [r["time_to_exfil"] for r in rows if r["time_to_exfil"] is not None] | |
| time_to_exfil_mean = sum(exfil_times) / len(exfil_times) if exfil_times else None | |
| mean_steps = sum(r["step_count"] for r in rows) / total | |
| return { | |
| "episodes": total, | |
| "mean_reward": round(mean_reward, 4), | |
| "containment_rate": round(containment_rate, 4), | |
| "injection_rate": round(injection_rate, 4), | |
| "attacker_reached_exfil_rate": round(exfil_rate, 4), | |
| "exfil_attempt_mean": round(exfil_attempt_mean, 4), | |
| "persistence_established_rate": round(persistence_rate, 4), | |
| "lateral_spread_mean": round(lateral_spread_mean, 4), | |
| "time_to_exfil_mean": round(time_to_exfil_mean, 4) if time_to_exfil_mean is not None else None, | |
| "mean_steps": round(mean_steps, 4), | |
| } | |
| def _preflight_sglang(base_url: str) -> None: | |
| url = base_url.rstrip("/") + "/models" | |
| try: | |
| with urlrequest.urlopen(url, timeout=5) as response: | |
| if response.status >= 400: | |
| raise RuntimeError(f"SGLang returned HTTP {response.status}") | |
| payload = json.loads(response.read().decode("utf-8")) | |
| if not isinstance(payload, dict) or "data" not in payload: | |
| raise RuntimeError("SGLang response missing models payload") | |
| except Exception as exc: | |
| raise SystemExit( | |
| "Strict attacker is enabled but the SGLang backend is not reachable. " | |
| f"Check SGLANG_BASE_URL ({base_url}) and ensure the server is running. " | |
| f"Details: {exc}" | |
| ) from exc | |
| def _preflight_openai(api_key: str) -> None: | |
| url = "https://api.openai.com/v1/models" | |
| request = urlrequest.Request(url) | |
| request.add_header("Authorization", f"Bearer {api_key}") | |
| try: | |
| with urlrequest.urlopen(request, timeout=5) as response: | |
| if response.status >= 400: | |
| raise RuntimeError(f"OpenAI returned HTTP {response.status}") | |
| payload = json.loads(response.read().decode("utf-8")) | |
| if not isinstance(payload, dict) or "data" not in payload: | |
| raise RuntimeError("OpenAI response missing models payload") | |
| except urlerror.URLError as exc: | |
| raise SystemExit( | |
| "Strict attacker is enabled but the OpenAI backend is not reachable. " | |
| f"Details: {exc}" | |
| ) from exc | |
| except Exception as exc: | |
| raise SystemExit( | |
| "Strict attacker is enabled but the OpenAI backend check failed. " | |
| f"Details: {exc}" | |
| ) from exc | |
| def _preflight_live_backend() -> None: | |
| if os.getenv("OPENSEC_ATTACKER_SGLANG") == "1": | |
| base_url = os.getenv("SGLANG_BASE_URL", "http://localhost:30000/v1") | |
| _preflight_sglang(base_url) | |
| elif os.getenv("OPENAI_API_KEY"): | |
| _preflight_openai(os.getenv("OPENAI_API_KEY", "")) | |
| def main() -> int: | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--manifest", default="data/seeds/manifest.json") | |
| parser.add_argument("--split", default="eval", choices=["train", "eval"]) | |
| parser.add_argument("--tier", default=None, choices=["trivial", "easy", "standard"]) | |
| parser.add_argument("--limit", type=int, default=10) | |
| parser.add_argument("--max-steps", type=int, default=15) | |
| parser.add_argument("--defender", default="noop", choices=["noop", "oracle"]) | |
| parser.add_argument("--output-dir", default="outputs/tier_eval") | |
| parser.add_argument("--replay-mode", default="record", choices=["off", "record", "replay"]) | |
| parser.add_argument("--replay-cache", default="") | |
| parser.add_argument("--tiers", default="T0,T1,T2", help="Comma-separated tiers to run") | |
| parser.add_argument("--strict-attacker", default="1", choices=["0", "1"]) | |
| args = parser.parse_args() | |
| load_env() | |
| if args.strict_attacker == "1": | |
| has_sglang = os.getenv("OPENSEC_ATTACKER_SGLANG") == "1" | |
| has_openai = bool(os.getenv("OPENAI_API_KEY")) | |
| if not (has_sglang or has_openai): | |
| raise SystemExit( | |
| "Strict attacker is enabled but no live LLM backend is configured. " | |
| "Set OPENSEC_ATTACKER_SGLANG=1 or OPENAI_API_KEY." | |
| ) | |
| _preflight_live_backend() | |
| manifest = load_json(Path(args.manifest)) | |
| seeds = manifest[args.split] | |
| if args.tier: | |
| seeds = [entry for entry in seeds if entry.get("tier") == args.tier] | |
| if args.limit: | |
| seeds = seeds[: args.limit] | |
| output_dir = Path(args.output_dir) | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| tiers = [ | |
| { | |
| "name": "T0", | |
| "env": { | |
| "OPENSEC_ATTACKER_SANDBOX": "0", | |
| "OPENSEC_ATTACKER_SGLANG": None, | |
| "OPENAI_API_KEY": None, | |
| }, | |
| }, | |
| { | |
| "name": "T1", | |
| "env": { | |
| "OPENSEC_ATTACKER_SANDBOX": "0", | |
| }, | |
| }, | |
| { | |
| "name": "T2", | |
| "env": { | |
| "OPENSEC_ATTACKER_SANDBOX": "1", | |
| }, | |
| }, | |
| ] | |
| summaries: Dict[str, Any] = {} | |
| gate_failures: List[str] = [] | |
| wanted = {t.strip().upper() for t in args.tiers.split(",") if t.strip()} | |
| for tier in tiers: | |
| if tier["name"] not in wanted: | |
| continue | |
| tier_env = dict(tier["env"]) | |
| tier_env["OPENSEC_REPLAY_MODE"] = args.replay_mode | |
| if args.replay_cache: | |
| tier_env["OPENSEC_REPLAY_CACHE_PATH"] = args.replay_cache | |
| if tier["name"] in {"T1", "T2"}: | |
| tier_env["OPENSEC_ATTACKER_STRICT"] = args.strict_attacker | |
| else: | |
| tier_env["OPENSEC_ATTACKER_STRICT"] = "0" | |
| rows: List[Dict[str, Any]] = [] | |
| with _temp_env(tier_env): | |
| for entry in seeds: | |
| seed_path = Path(entry["seed_path"]) | |
| rows.append(_run_episode(seed_path, args.defender, args.max_steps)) | |
| out_path = output_dir / f"tier_{tier['name'].lower()}.jsonl" | |
| with out_path.open("w") as f: | |
| for row in rows: | |
| f.write(json.dumps(row) + "\n") | |
| summaries[tier["name"]] = _summarize(rows) | |
| if tier["name"] in {"T1", "T2"} and summaries[tier["name"]]["episodes"] > 0: | |
| if summaries[tier["name"]]["attacker_reached_exfil_rate"] == 0.0: | |
| gate_failures.append( | |
| f"{tier['name']}: attacker_reached_exfil_rate == 0" | |
| ) | |
| summary_path = output_dir / "summary.json" | |
| summary_path.write_text(json.dumps(summaries, indent=2)) | |
| print(json.dumps(summaries, indent=2)) | |
| if gate_failures: | |
| print("Tier eval gate failed: " + "; ".join(gate_failures)) | |
| return 1 | |
| return 0 | |
| if __name__ == "__main__": | |
| raise SystemExit(main()) | |