File size: 15,161 Bytes
ca83593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
#!/usr/bin/env python3
"""
IAMSentinel Baseline Inference Script
======================================
Runs a GPT-4o ReAct agent against all 3 tasks and reports scores.

Usage:
    export OPENAI_API_KEY=sk-...
    python scripts/baseline_agent.py [--task all|task1|task2|task3] [--seed 42] [--model gpt-4o]

Reproducible baseline scores (seed=42, complexity=medium, model=gpt-4o-mini):
    Task 1 (Easy):   ~0.55–0.70
    Task 2 (Medium): ~0.35–0.50
    Task 3 (Hard):   ~0.20–0.35
"""

import argparse
import json
import os
import sys
import time
from typing import Optional

try:
    from openai import OpenAI
except ImportError:
    print("ERROR: openai package not installed. Run: pip install openai")
    sys.exit(1)

# Ensure package is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from iamsentinel import IAMSentinelEnv


# ──────────────────────────────────────────────
# System prompt for the ReAct agent
# ──────────────────────────────────────────────

# Sent verbatim as the system message every episode. Previously contained
# mojibake ("β€”") where em-dashes were intended; fixed to proper "—".
SYSTEM_PROMPT = """You are an expert cloud security analyst specialising in AWS IAM security.
You are operating inside a simulated IAM environment and must complete security tasks.

You interact with the environment by outputting JSON actions. Each response must contain
EXACTLY ONE action as a JSON block in this format:

```json
{
  "action": "<action_name>",
  ... action parameters ...
}
```

Available actions:
1. list_principals — {"action": "list_principals", "kind": "all"|"user"|"role"}
2. list_policies   — {"action": "list_policies", "principal_arn": "<arn or null>"}
3. get_policy      — {"action": "get_policy", "policy_arn": "<arn>"}
4. get_principal   — {"action": "get_principal", "principal_arn": "<arn>"}
5. get_role_trust  — {"action": "get_role_trust", "role_arn": "<arn>"}
6. query_audit_log — {"action": "query_audit_log", "filter": {"event_name": "...", "severity": "...", "principal_arn": "...", "source_ip": "..."}, "limit": 20}
7. trace_escalation_path — {"action": "trace_escalation_path", "from_principal_arn": "<arn>", "to_principal_arn": null}
8. flag_finding    — {
     "action": "flag_finding",
     "finding_type": "wildcard_policy"|"mfa_disabled"|"stale_admin_role"|"privilege_escalation_path"|"exposed_trust_policy"|"suspicious_event",
     "affected_principal_arn": "<arn or null>",
     "affected_policy_arn": "<arn or null>",
     "severity": "low"|"medium"|"high"|"critical",
     "description": "<description>",
     "mitre_technique": "<T-code or null>",
     "evidence": ["<arn or event_id>", ...]
   }
9. remediate       — {"action": "remediate", "remediation_type": "detach_policy"|"delete_user"|"require_mfa"|"update_trust_policy", "target_arn": "<arn>", "policy_arn": "<arn or null>"}
10. attribute_attack — {
      "action": "attribute_attack",
      "compromised_principal_arn": "<arn>",
      "attack_technique": "<description>",
      "mitre_techniques": ["T1078.004", ...],
      "lateral_movement_path": ["<arn1>", "<arn2>"],
      "containment_actions": ["disable_user:<arn>", "delete_function:<name>", ...]
    }

Strategy guidelines:
- For Task 1: List all principals and their policies. Check for wildcards, MFA, stale roles, exposed trust policies.
- For Task 2: Find principals with iam:PassRole. Trace escalation paths. Look for lambda + createUser chains.
- For Task 3: Query audit logs by severity=critical first, then trace suspicious sequences. Look for CreateFunction→CreateUser chains from unusual IPs.

Be systematic. Think step by step before each action. Flag findings as you discover them.
For Task 3, finish with attribute_attack once you've gathered enough evidence.
"""


# ──────────────────────────────────────────────
# JSON action parser
# ──────────────────────────────────────────────

def extract_json_action(text: str) -> Optional[dict]:
    """Extract the first parseable JSON action object from model output.

    Tries, in order:
      1. every ```json fenced code block (not just the first — a model may
         emit a malformed block followed by a valid one);
      2. every flat (non-nested) ``{...}`` object mentioning an "action" key;
      3. a brute-force widest-to-narrowest scan from each ``{`` for any
         parseable object containing an "action" key.

    Returns:
        The parsed dict, or None if no candidate parses.
    """
    import re

    # 1. Fenced code blocks. finditer (vs. search) lets a later, valid
    # block win when an earlier one fails to parse.
    for match in re.finditer(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL):
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            continue

    # 2. Flat objects — [^{}] cannot span nested braces, so this only
    # catches single-level JSON, which is the common case for tool calls.
    for match in re.finditer(r"\{[^{}]*\"action\"[^{}]*\}", text, re.DOTALL):
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            continue

    # 3. Last resort: O(n^2) candidate spans, longest first per start brace.
    for start in range(len(text)):
        if text[start] != "{":
            continue
        for end in range(len(text), start, -1):
            if text[end - 1] != "}":
                continue
            try:
                obj = json.loads(text[start:end])
            except json.JSONDecodeError:
                continue
            if "action" in obj:
                return obj
    return None


def obs_to_text(obs_dict: dict, step: int) -> str:
    """Convert observation dict to a concise text summary for the LLM.

    Long lists (principals, policies, audit events, escalation paths) are
    truncated so the feedback prompt stays small; only the last 3 findings
    are echoed back. Marker glyphs were previously mojibake ("βœ“", "β†’")
    and are fixed to the intended ✓/✗/→ characters.
    """
    parts = [f"[Step {step}] Budget remaining: {obs_dict.get('budget_remaining', '?')}"]

    if obs_dict.get("hints"):
        parts.append("Hints: " + " | ".join(obs_dict["hints"]))

    if obs_dict.get("findings"):
        parts.append(f"Findings so far ({len(obs_dict['findings'])}):")
        for f in obs_dict["findings"][-3:]:  # last 3
            parts.append(f"  - [{f['severity']}] {f['finding_type']}: {f['description'][:80]}")

    if obs_dict.get("principals"):
        parts.append(f"Principals returned: {len(obs_dict['principals'])}")
        for p in obs_dict["principals"][:5]:
            mfa = "✓MFA" if p.get("mfa_enabled") else "✗MFA"
            parts.append(
                f"  {p['kind']}: {p['name']} | {mfa} | "
                f"last_active={p['last_active_days']}d | "
                f"policies={len(p.get('policies', []))}"
            )
        if len(obs_dict["principals"]) > 5:
            parts.append(f"  ... and {len(obs_dict['principals'])-5} more")

    if obs_dict.get("policies"):
        parts.append(f"Policies returned: {len(obs_dict['policies'])}")
        for p in obs_dict["policies"][:5]:
            wildcard = "⚠WILDCARD" if p.get("is_wildcard") else ""
            parts.append(f"  {p['name']} {wildcard} | arn={p['arn']}")
            if p.get("statements"):
                # Only the first statement's actions are surfaced.
                actions = p["statements"][0].get("actions", [])
                parts.append(f"    actions: {actions[:5]}")
        if len(obs_dict["policies"]) > 5:
            parts.append(f"  ... and {len(obs_dict['policies'])-5} more")

    if obs_dict.get("audit_events"):
        parts.append(f"Audit events returned: {len(obs_dict['audit_events'])}")
        for e in obs_dict["audit_events"][:8]:
            parts.append(
                f"  [{e.get('severity','?')}] {e['event_time']} | "
                f"{e['event_name']} | {e['principal_name']} | ip={e['source_ip']}"
            )
        if len(obs_dict["audit_events"]) > 8:
            parts.append(f"  ... and {len(obs_dict['audit_events'])-8} more")

    if obs_dict.get("escalation_paths"):
        parts.append(f"Escalation paths found: {len(obs_dict['escalation_paths'])}")
        for ep in obs_dict["escalation_paths"][:3]:
            parts.append(f"  Path (risk={ep.get('risk_score','?')}): {' → '.join(ep['path'])}")

    if obs_dict.get("role_trust_policy"):
        # Trust policy JSON is capped at 300 chars to bound prompt size.
        parts.append(f"Trust policy: {json.dumps(obs_dict['role_trust_policy'], indent=2)[:300]}")

    if obs_dict.get("done"):
        parts.append("EPISODE DONE.")

    return "\n".join(parts)


# ──────────────────────────────────────────────
# Agent runner
# ──────────────────────────────────────────────

def run_agent(
    task_id: str,
    seed: int = 42,
    model: str = "gpt-4o-mini",
    complexity: str = "medium",
    verbose: bool = True,
) -> dict:
    """Run one full ReAct episode of the LLM agent against a single task.

    Drives a conversation loop: each turn the model emits one JSON action,
    the environment executes it, and a text summary of the observation is
    fed back as the next user message — until the environment reports done
    or the step budget is exhausted.

    Args:
        task_id: One of "task1", "task2", "task3" (KeyError on anything else).
        seed: Environment RNG seed for reproducibility.
        model: OpenAI chat model name.
        complexity: Environment complexity preset (passed straight to the env).
        verbose: If True, print per-step progress to stdout.

    Returns:
        A result dict with final score, cumulative reward, steps taken,
        the action-name history, and the final environment state.

    Raises:
        ValueError: If OPENAI_API_KEY is not set in the environment.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set")

    client = OpenAI(api_key=api_key)
    env = IAMSentinelEnv(task_id=task_id, seed=seed, complexity=complexity)
    obs = env.reset()

    # Static display metadata per task; intentionally fails fast (KeyError)
    # on an unknown task_id.
    task_cfg = {
        "task1": {"name": "Misconfiguration Scanner",          "difficulty": "Easy"},
        "task2": {"name": "Privilege Escalation Path Detection","difficulty": "Medium"},
        "task3": {"name": "Live Attack Attribution",           "difficulty": "Hard"},
    }[task_id]

    if verbose:
        print(f"\n{'='*60}")
        print(f"Task: {task_cfg['name']} ({task_cfg['difficulty']})")
        print(f"Seed: {seed} | Model: {model} | Complexity: {complexity}")
        print(f"{'='*60}")

    # Build conversation history
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Initial user message with task description
    initial_msg = (
        f"Task: {obs.task_description}\n\n"
        f"Account ID: {obs.account_id}\n"
        f"Max steps: {obs.max_steps}\n"
    )
    if obs.hints:
        initial_msg += "\nHints:\n" + "\n".join(f"- {h}" for h in obs.hints)
    initial_msg += "\n\nBegin your investigation. Output one JSON action."

    messages.append({"role": "user", "content": initial_msg})

    episode_done = False
    step = 0
    # Defaults persist if the episode never reports done / a final score.
    final_score = 0.0
    total_reward = 0.0
    action_history = []

    # NOTE(review): _max_steps() is a private env method — confirm whether a
    # public accessor (e.g. obs.max_steps) should be used instead.
    while not episode_done and step < env._max_steps():
        step += 1

        # ── Call LLM ──────────────────────────
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.2,
                max_tokens=800,
            )
            assistant_text = response.choices[0].message.content
        except Exception as e:
            # A failed API call still consumes a loop iteration (step was
            # already incremented); back off briefly and retry.
            print(f"  [Step {step}] LLM error: {e}")
            time.sleep(2)
            continue

        messages.append({"role": "assistant", "content": assistant_text})

        # ── Parse action ───────────────────────
        action_dict = extract_json_action(assistant_text)
        if action_dict is None:
            # Unparseable output: tell the model and retry (consumes a step).
            if verbose:
                print(f"  [Step {step}] Could not parse action from: {assistant_text[:100]}")
            feedback = "ERROR: Could not parse a valid JSON action. Output ONLY a JSON block."
            messages.append({"role": "user", "content": feedback})
            continue

        action_name = action_dict.get("action", "unknown")
        action_history.append(action_name)

        if verbose:
            print(f"  [Step {step}] Action: {action_name}", end="")
            key_params = {k: v for k, v in action_dict.items()
                         if k != "action" and v is not None}
            if key_params:
                print(f" | params: {json.dumps(key_params)[:100]}", end="")
            print()

        # ── Step environment ───────────────────
        try:
            next_obs, reward, done, info = env.step(action_dict)
        except Exception as e:
            # Surface env-side failures to the model rather than crashing.
            feedback = f"ERROR executing action: {e}. Try a different action."
            messages.append({"role": "user", "content": feedback})
            continue

        total_reward += reward.total
        episode_done = done

        if done and info.get("final_score") is not None:
            final_score = info["final_score"]
            if verbose:
                print(f"  [Step {step}] Episode done. Final score: {final_score:.3f}")

        # ── Build feedback message ─────────────
        # model_dump() is pydantic-style serialization of the observation —
        # assumed the env's observation type is a pydantic model.
        obs_dict = next_obs.model_dump()
        feedback_text = obs_to_text(obs_dict, step)
        if reward.step_reward != 0:
            feedback_text += f"\n[Reward signal: {reward.step_reward:+.3f}]"
        if obs_dict.get("findings"):
            feedback_text += f"\n[Total findings logged: {len(obs_dict['findings'])}]"

        if not done:
            feedback_text += "\n\nContinue your investigation. Output one JSON action."

        messages.append({"role": "user", "content": feedback_text})

        # Small delay to respect rate limits
        time.sleep(0.3)

    return {
        "task_id":       task_id,
        "task_name":     task_cfg["name"],
        "difficulty":    task_cfg["difficulty"],
        "seed":          seed,
        "model":         model,
        "final_score":   final_score,
        "total_reward":  total_reward,
        "steps_taken":   step,
        "action_history":action_history,
        "state":         env.state(),
    }


# ──────────────────────────────────────────────
# Main entry point
# ──────────────────────────────────────────────

def main():
    """CLI entry point: run the baseline agent on the selected task(s),
    print a score summary table, and optionally persist results to JSON.
    """
    arg_parser = argparse.ArgumentParser(description="IAMSentinel Baseline Agent")
    arg_parser.add_argument("--task", default="all", help="task1|task2|task3|all")
    arg_parser.add_argument("--seed", type=int, default=42)
    arg_parser.add_argument("--model", default="gpt-4o-mini")
    arg_parser.add_argument("--complexity", default="medium", help="easy|medium|hard")
    arg_parser.add_argument("--output", default=None, help="Save results to JSON file")
    arg_parser.add_argument("--quiet", action="store_true")
    opts = arg_parser.parse_args()

    selected = ["task1", "task2", "task3"] if opts.task == "all" else [opts.task]
    results = [
        run_agent(
            task_id=tid,
            seed=opts.seed,
            model=opts.model,
            complexity=opts.complexity,
            verbose=not opts.quiet,
        )
        for tid in selected
    ]

    # ── Print summary ──────────────────────────
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    print("\n" + heavy_rule)
    print("BASELINE SCORES SUMMARY")
    print(heavy_rule)
    print(f"{'Task':<35} {'Score':>6}  {'Steps':>5}  {'Difficulty'}")
    print(light_rule)
    for entry in results:
        row = (
            f"{entry['task_name']:<35} {entry['final_score']:>6.3f}  "
            f"{entry['steps_taken']:>5}  {entry['difficulty']}"
        )
        print(row)
    print(light_rule)
    mean_score = sum(entry["final_score"] for entry in results) / len(results)
    print(f"{'Average':<35} {mean_score:>6.3f}")
    print(heavy_rule)

    if opts.output:
        with open(opts.output, "w") as fh:
            json.dump(results, fh, indent=2)
        print(f"\nResults saved to {opts.output}")

    return results


# Allow importing this module without side effects; run only when executed directly.
if __name__ == "__main__":
    main()