opensec-env / scripts /green_agent.py
Jarrodbarnes's picture
Upload folder using huggingface_hub
b77ed28 verified
#!/usr/bin/env python3
from __future__ import annotations
import argparse
from typing import Dict, Optional
from server.models import AgentAction
try:
from client.openenv_client import OpenSecEnv
except Exception as exc: # pragma: no cover - optional dependency
raise SystemExit(
"OpenEnv client not available. Install with: pip install '.[openenv]'"
) from exc
def _best_effort_value(parsed: Dict[str, str], *keys: str) -> Optional[str]:
for key in keys:
value = parsed.get(key)
if value:
return value
return None
def run_green_agent(base_url: str, max_steps: int = 15) -> None:
with OpenSecEnv(base_url=base_url) as env:
result = env.reset()
parsed: Dict[str, str] = {}
for _ in range(max_steps):
obs = result.observation
if obs.new_alerts:
alert_id = obs.new_alerts[0]
result = env.step(
AgentAction(action_type="fetch_alert", params={"alert_id": alert_id})
)
parsed = result.observation.last_action_result.data.get("parsed", {})
break
# Fallback: query for any alert id, then fetch
sql = (
"SELECT alert_id FROM alerts "
f"WHERE scenario_id = '{obs.scenario_id}' "
"ORDER BY step DESC LIMIT 1"
)
result = env.step(AgentAction(action_type="query_logs", params={"sql": sql}))
rows = result.observation.last_action_result.data.get("rows", [])
if rows:
alert_id = rows[0].get("alert_id")
if alert_id:
result = env.step(
AgentAction(
action_type="fetch_alert", params={"alert_id": alert_id}
)
)
parsed = result.observation.last_action_result.data.get("parsed", {})
break
attacker_domain = _best_effort_value(parsed, "dst_domain", "domain")
patient_zero_host = _best_effort_value(parsed, "src_host", "host")
compromised_user = _best_effort_value(parsed, "compromised_user", "user")
data_target = _best_effort_value(parsed, "data_target", "target")
if patient_zero_host:
env.step(
AgentAction(action_type="isolate_host", params={"host_id": patient_zero_host})
)
if attacker_domain:
env.step(
AgentAction(action_type="block_domain", params={"domain": attacker_domain})
)
if compromised_user:
env.step(
AgentAction(action_type="reset_user", params={"user_id": compromised_user})
)
report = {
"patient_zero_host": patient_zero_host or "unknown",
"compromised_user": compromised_user or "unknown",
"attacker_domain": attacker_domain or "unknown",
"data_target": data_target or "unknown",
"initial_vector": "unknown",
"containment_actions": {
"isolated_hosts": [patient_zero_host] if patient_zero_host else [],
"blocked_domains": [attacker_domain] if attacker_domain else [],
"reset_users": [compromised_user] if compromised_user else [],
},
}
result = env.step(
AgentAction(action_type="submit_report", params={"summary_json": report})
)
print(f"done={result.done} reward={result.reward}")
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", default="http://localhost:8000")
parser.add_argument("--max-steps", type=int, default=15)
args = parser.parse_args()
run_green_agent(args.base_url, args.max_steps)
return 0
if __name__ == "__main__":
raise SystemExit(main())