File size: 3,920 Bytes
b77ed28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
from __future__ import annotations

import argparse
from typing import Dict, Optional

from server.models import AgentAction

# The OpenEnv client ships as an optional extra, but this script cannot do
# anything without it: if the import fails for any reason, exit immediately
# with an install hint and keep the original exception chained as the cause.
try:
    from client.openenv_client import OpenSecEnv
except Exception as exc:  # pragma: no cover - optional dependency
    raise SystemExit(
        "OpenEnv client not available. Install with: pip install '.[openenv]'"
    ) from exc


def _best_effort_value(parsed: Dict[str, str], *keys: str) -> Optional[str]:
    for key in keys:
        value = parsed.get(key)
        if value:
            return value
    return None


def _fetch_parsed_alert(env: "OpenSecEnv", alert_id: str) -> Dict[str, str]:
    """Step ``env`` with a ``fetch_alert`` action and return the parsed fields.

    Returns an empty dict when the action result has no ``"parsed"`` entry or
    the entry is empty/``None``, so callers can always treat the result as a
    mapping.
    """
    result = env.step(
        AgentAction(action_type="fetch_alert", params={"alert_id": alert_id})
    )
    return result.observation.last_action_result.data.get("parsed") or {}


def run_green_agent(base_url: str, max_steps: int = 15) -> None:
    """Run one scripted episode against the environment at ``base_url``.

    The agent:

    1. Resets the environment and looks for an alert, either from the
       observation's ``new_alerts`` or, failing that, by querying the
       ``alerts`` table for the scenario's most recent alert id.
    2. Fetches that alert and reads best-effort fields from its parsed data.
    3. Issues ``isolate_host`` / ``block_domain`` / ``reset_user`` for the
       fields that were found.
    4. Submits a report and prints the final ``done`` flag and reward.

    Args:
        base_url: Base URL passed to ``OpenSecEnv``.
        max_steps: Maximum number of alert-discovery iterations. Each
            fallback iteration may issue two steps (query, then fetch); the
            containment and report steps are not counted against this limit.
    """
    with OpenSecEnv(base_url=base_url) as env:
        result = env.reset()
        parsed: Dict[str, str] = {}

        for _ in range(max_steps):
            obs = result.observation

            if obs.new_alerts:
                parsed = _fetch_parsed_alert(env, obs.new_alerts[0])
                break

            # Fallback: query for any alert id, then fetch.
            # NOTE(review): scenario_id is interpolated directly into the SQL
            # text; this assumes it is a trusted identifier supplied by the
            # environment - confirm before pointing at untrusted servers.
            sql = (
                "SELECT alert_id FROM alerts "
                f"WHERE scenario_id = '{obs.scenario_id}' "
                "ORDER BY step DESC LIMIT 1"
            )
            result = env.step(AgentAction(action_type="query_logs", params={"sql": sql}))
            # ``or []`` also covers a "rows" key that is present but null.
            rows = result.observation.last_action_result.data.get("rows") or []
            alert_id = rows[0].get("alert_id") if rows else None
            if alert_id:
                parsed = _fetch_parsed_alert(env, alert_id)
                break

        attacker_domain = _best_effort_value(parsed, "dst_domain", "domain")
        patient_zero_host = _best_effort_value(parsed, "src_host", "host")
        compromised_user = _best_effort_value(parsed, "compromised_user", "user")
        data_target = _best_effort_value(parsed, "data_target", "target")

        # Only take containment actions for entities that were identified.
        if patient_zero_host:
            env.step(
                AgentAction(action_type="isolate_host", params={"host_id": patient_zero_host})
            )
        if attacker_domain:
            env.step(
                AgentAction(action_type="block_domain", params={"domain": attacker_domain})
            )
        if compromised_user:
            env.step(
                AgentAction(action_type="reset_user", params={"user_id": compromised_user})
            )

        # Fields that could not be determined are reported as "unknown"; the
        # containment lists mirror exactly the actions issued above.
        report = {
            "patient_zero_host": patient_zero_host or "unknown",
            "compromised_user": compromised_user or "unknown",
            "attacker_domain": attacker_domain or "unknown",
            "data_target": data_target or "unknown",
            "initial_vector": "unknown",
            "containment_actions": {
                "isolated_hosts": [patient_zero_host] if patient_zero_host else [],
                "blocked_domains": [attacker_domain] if attacker_domain else [],
                "reset_users": [compromised_user] if compromised_user else [],
            },
        }

        result = env.step(
            AgentAction(action_type="submit_report", params={"summary_json": report})
        )
        print(f"done={result.done} reward={result.reward}")


def main() -> int:
    """Parse command-line options, run the agent once, and return exit code 0.

    Options:
        --base-url: environment server URL (default ``http://localhost:8000``).
        --max-steps: alert-discovery iteration limit (default ``15``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--base-url", default="http://localhost:8000")
    cli.add_argument("--max-steps", type=int, default=15)
    options = cli.parse_args()

    base_url = options.base_url
    step_limit = options.max_steps
    run_green_agent(base_url, step_limit)

    return 0


# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())