Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| IAMSentinel Baseline Inference Script | |
| ====================================== | |
| Runs a GPT-4o ReAct agent against all 3 tasks and reports scores. | |
| Usage: | |
| export OPENAI_API_KEY=sk-... | |
| python scripts/baseline_agent.py [--task all|task1|task2|task3] [--seed 42] [--model gpt-4o] | |
| Reproducible baseline scores (seed=42, complexity=medium, model=gpt-4o-mini): | |
    Task 1 (Easy):   ~0.55–0.70
    Task 2 (Medium): ~0.35–0.50
    Task 3 (Hard):   ~0.20–0.35
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from typing import Optional | |
| try: | |
| from openai import OpenAI | |
| except ImportError: | |
| print("ERROR: openai package not installed. Run: pip install openai") | |
| sys.exit(1) | |
| # Ensure package is importable | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from iamsentinel import IAMSentinelEnv | |
# ──────────────────────────────────────────────
# System prompt for the ReAct agent
# ──────────────────────────────────────────────
# NOTE: the action catalogue below must stay in sync with the actions
# accepted by IAMSentinelEnv.step(). Arrows/glyphs were previously
# mojibake'd to "β"; restored to "→" so the model sees readable examples.
SYSTEM_PROMPT = """You are an expert cloud security analyst specialising in AWS IAM security.
You are operating inside a simulated IAM environment and must complete security tasks.
You interact with the environment by outputting JSON actions. Each response must contain
EXACTLY ONE action as a JSON block in this format:
```json
{
  "action": "<action_name>",
  ... action parameters ...
}
```
Available actions:
1. list_principals → {"action": "list_principals", "kind": "all"|"user"|"role"}
2. list_policies → {"action": "list_policies", "principal_arn": "<arn or null>"}
3. get_policy → {"action": "get_policy", "policy_arn": "<arn>"}
4. get_principal → {"action": "get_principal", "principal_arn": "<arn>"}
5. get_role_trust → {"action": "get_role_trust", "role_arn": "<arn>"}
6. query_audit_log → {"action": "query_audit_log", "filter": {"event_name": "...", "severity": "...", "principal_arn": "...", "source_ip": "..."}, "limit": 20}
7. trace_escalation_path → {"action": "trace_escalation_path", "from_principal_arn": "<arn>", "to_principal_arn": null}
8. flag_finding → {
     "action": "flag_finding",
     "finding_type": "wildcard_policy"|"mfa_disabled"|"stale_admin_role"|"privilege_escalation_path"|"exposed_trust_policy"|"suspicious_event",
     "affected_principal_arn": "<arn or null>",
     "affected_policy_arn": "<arn or null>",
     "severity": "low"|"medium"|"high"|"critical",
     "description": "<description>",
     "mitre_technique": "<T-code or null>",
     "evidence": ["<arn or event_id>", ...]
   }
9. remediate → {"action": "remediate", "remediation_type": "detach_policy"|"delete_user"|"require_mfa"|"update_trust_policy", "target_arn": "<arn>", "policy_arn": "<arn or null>"}
10. attribute_attack → {
      "action": "attribute_attack",
      "compromised_principal_arn": "<arn>",
      "attack_technique": "<description>",
      "mitre_techniques": ["T1078.004", ...],
      "lateral_movement_path": ["<arn1>", "<arn2>"],
      "containment_actions": ["disable_user:<arn>", "delete_function:<name>", ...]
    }
Strategy guidelines:
- For Task 1: List all principals and their policies. Check for wildcards, MFA, stale roles, exposed trust policies.
- For Task 2: Find principals with iam:PassRole. Trace escalation paths. Look for lambda + createUser chains.
- For Task 3: Query audit logs by severity=critical first, then trace suspicious sequences. Look for CreateFunction→CreateUser chains from unusual IPs.
Be systematic. Think step by step before each action. Flag findings as you discover them.
For Task 3, finish with attribute_attack once you've gathered enough evidence.
"""
# ──────────────────────────────────────────────
# JSON action parser
# ──────────────────────────────────────────────
def extract_json_action(text: str) -> Optional[dict]:
    """Extract the first JSON action object from model output.

    Tries three strategies in order:
      1. A fenced ```json code block.
      2. A flat (non-nested) JSON object containing an "action" key.
      3. A brute-force scan: attempt a JSON parse starting at every "{".

    Args:
        text: Raw assistant message text.

    Returns:
        The parsed action dict, or None if no usable JSON object is found.
    """
    import re
    # 1. Fenced code block first (```json ... ``` or bare ``` ... ```).
    pattern = r"```(?:json)?\s*(\{.*?\})\s*```"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            pass
    # 2. Raw flat object mentioning "action" (no nested braces).
    pattern2 = r"\{[^{}]*\"action\"[^{}]*\}"
    match2 = re.search(pattern2, text, re.DOTALL)
    if match2:
        try:
            return json.loads(match2.group(0))
        except json.JSONDecodeError:
            pass
    # 3. Brute-force fallback. raw_decode parses the complete JSON value
    # starting at `start` in a single pass and ignores trailing text —
    # O(n) per candidate instead of the O(n^2) "try every end position"
    # scan, which made the old fallback O(n^3) overall.
    decoder = json.JSONDecoder()
    for start, ch in enumerate(text):
        if ch == "{":
            try:
                obj, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                continue
            # Only accept objects that actually carry an action.
            if isinstance(obj, dict) and "action" in obj:
                return obj
    return None
def obs_to_text(obs_dict: dict, step: int) -> str:
    """Convert an observation dict into a concise text summary for the LLM.

    Only a bounded prefix of each list (principals, policies, audit events,
    escalation paths) is rendered, with a "... and N more" marker, to keep
    the conversation context small.

    Args:
        obs_dict: Plain-dict dump of the environment observation.
        step: Current step number, echoed in the header line.

    Returns:
        A newline-joined human/LLM-readable summary.

    Note: the MFA check/cross marks and the path arrow were previously
    mojibake'd to "β"; restored to "✓"/"✗"/"→".
    """
    parts = [f"[Step {step}] Budget remaining: {obs_dict.get('budget_remaining', '?')}"]
    if obs_dict.get("hints"):
        parts.append("Hints: " + " | ".join(obs_dict["hints"]))
    if obs_dict.get("findings"):
        parts.append(f"Findings so far ({len(obs_dict['findings'])}):")
        for f in obs_dict["findings"][-3:]:  # last 3 only
            parts.append(f" - [{f['severity']}] {f['finding_type']}: {f['description'][:80]}")
    if obs_dict.get("principals"):
        parts.append(f"Principals returned: {len(obs_dict['principals'])}")
        for p in obs_dict["principals"][:5]:
            mfa = "✓MFA" if p.get("mfa_enabled") else "✗MFA"
            parts.append(
                f" {p['kind']}: {p['name']} | {mfa} | "
                f"last_active={p['last_active_days']}d | "
                f"policies={len(p.get('policies', []))}"
            )
        if len(obs_dict["principals"]) > 5:
            parts.append(f" ... and {len(obs_dict['principals'])-5} more")
    if obs_dict.get("policies"):
        parts.append(f"Policies returned: {len(obs_dict['policies'])}")
        for p in obs_dict["policies"][:5]:
            wildcard = "⚠ WILDCARD" if p.get("is_wildcard") else ""
            parts.append(f" {p['name']} {wildcard} | arn={p['arn']}")
            if p.get("statements"):
                # Only the first statement's actions are surfaced as a preview.
                actions = p["statements"][0].get("actions", [])
                parts.append(f" actions: {actions[:5]}")
        if len(obs_dict["policies"]) > 5:
            parts.append(f" ... and {len(obs_dict['policies'])-5} more")
    if obs_dict.get("audit_events"):
        parts.append(f"Audit events returned: {len(obs_dict['audit_events'])}")
        for e in obs_dict["audit_events"][:8]:
            parts.append(
                f" [{e.get('severity','?')}] {e['event_time']} | "
                f"{e['event_name']} | {e['principal_name']} | ip={e['source_ip']}"
            )
        if len(obs_dict["audit_events"]) > 8:
            parts.append(f" ... and {len(obs_dict['audit_events'])-8} more")
    if obs_dict.get("escalation_paths"):
        parts.append(f"Escalation paths found: {len(obs_dict['escalation_paths'])}")
        for ep in obs_dict["escalation_paths"][:3]:
            parts.append(f" Path (risk={ep.get('risk_score','?')}): {' → '.join(ep['path'])}")
    if obs_dict.get("role_trust_policy"):
        # Truncate to 300 chars so a large trust document can't blow up context.
        parts.append(f"Trust policy: {json.dumps(obs_dict['role_trust_policy'], indent=2)[:300]}")
    if obs_dict.get("done"):
        parts.append("EPISODE DONE.")
    return "\n".join(parts)
# ──────────────────────────────────────────────
# Agent runner
# ──────────────────────────────────────────────
def run_agent(
    task_id: str,
    seed: int = 42,
    model: str = "gpt-4o-mini",
    complexity: str = "medium",
    verbose: bool = True,
) -> dict:
    """Run one ReAct episode of the given chat model against one task.

    Loop per step: call the chat API, parse a single JSON action out of the
    reply, execute it in the environment, and feed a text summary of the
    observation (plus reward signal) back into the conversation, until the
    episode ends or the environment's step budget is exhausted.

    Args:
        task_id: One of "task1", "task2", "task3" (raises KeyError otherwise).
        seed: Environment seed for reproducibility.
        model: OpenAI chat model name.
        complexity: Scenario complexity passed to IAMSentinelEnv.
        verbose: If True, print per-step progress to stdout.

    Returns:
        Result dict: final_score, total_reward, steps_taken, action_history,
        task metadata, and the terminal environment state.

    Raises:
        ValueError: If OPENAI_API_KEY is not set.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set")
    client = OpenAI(api_key=api_key)
    env = IAMSentinelEnv(task_id=task_id, seed=seed, complexity=complexity)
    obs = env.reset()
    # Display metadata only; an unknown task_id raises KeyError here.
    task_cfg = {
        "task1": {"name": "Misconfiguration Scanner", "difficulty": "Easy"},
        "task2": {"name": "Privilege Escalation Path Detection", "difficulty": "Medium"},
        "task3": {"name": "Live Attack Attribution", "difficulty": "Hard"},
    }[task_id]
    if verbose:
        print(f"\n{'='*60}")
        print(f"Task: {task_cfg['name']} ({task_cfg['difficulty']})")
        print(f"Seed: {seed} | Model: {model} | Complexity: {complexity}")
        print(f"{'='*60}")
    # Build conversation history, starting from the shared system prompt.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # Initial user message with the task description and any starter hints.
    initial_msg = (
        f"Task: {obs.task_description}\n\n"
        f"Account ID: {obs.account_id}\n"
        f"Max steps: {obs.max_steps}\n"
    )
    if obs.hints:
        initial_msg += "\nHints:\n" + "\n".join(f"- {h}" for h in obs.hints)
    initial_msg += "\n\nBegin your investigation. Output one JSON action."
    messages.append({"role": "user", "content": initial_msg})
    episode_done = False
    step = 0
    final_score = 0.0
    total_reward = 0.0
    action_history = []
    # NOTE(review): relies on the private env._max_steps() accessor — check
    # whether the env exposes a public max-steps attribute instead.
    while not episode_done and step < env._max_steps():
        step += 1
        # ── Call LLM ──────────────────────────
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.2,
                max_tokens=800,
            )
            assistant_text = response.choices[0].message.content
        except Exception as e:
            # API failure: back off and retry. Note the step counter was
            # already incremented, so repeated failures still consume budget.
            print(f" [Step {step}] LLM error: {e}")
            time.sleep(2)
            continue
        messages.append({"role": "assistant", "content": assistant_text})
        # ── Parse action ──────────────────────
        action_dict = extract_json_action(assistant_text)
        if action_dict is None:
            # Unparseable reply: tell the model to emit JSON only and retry.
            if verbose:
                print(f" [Step {step}] Could not parse action from: {assistant_text[:100]}")
            feedback = "ERROR: Could not parse a valid JSON action. Output ONLY a JSON block."
            messages.append({"role": "user", "content": feedback})
            continue
        action_name = action_dict.get("action", "unknown")
        action_history.append(action_name)
        if verbose:
            print(f" [Step {step}] Action: {action_name}", end="")
            # Show non-null parameters (truncated) next to the action name.
            key_params = {k: v for k, v in action_dict.items()
                          if k != "action" and v is not None}
            if key_params:
                print(f" | params: {json.dumps(key_params)[:100]}", end="")
            print()
        # ── Step environment ──────────────────
        try:
            next_obs, reward, done, info = env.step(action_dict)
        except Exception as e:
            # Invalid/failed action: surface the error back to the model.
            feedback = f"ERROR executing action: {e}. Try a different action."
            messages.append({"role": "user", "content": feedback})
            continue
        total_reward += reward.total
        episode_done = done
        if done and info.get("final_score") is not None:
            final_score = info["final_score"]
            if verbose:
                print(f" [Step {step}] Episode done. Final score: {final_score:.3f}")
        # ── Build feedback message ────────────
        obs_dict = next_obs.model_dump()  # pydantic-style model → plain dict
        feedback_text = obs_to_text(obs_dict, step)
        if reward.step_reward != 0:
            feedback_text += f"\n[Reward signal: {reward.step_reward:+.3f}]"
        if obs_dict.get("findings"):
            feedback_text += f"\n[Total findings logged: {len(obs_dict['findings'])}]"
        if not done:
            feedback_text += "\n\nContinue your investigation. Output one JSON action."
        messages.append({"role": "user", "content": feedback_text})
        # Small delay to respect rate limits
        time.sleep(0.3)
    return {
        "task_id": task_id,
        "task_name": task_cfg["name"],
        "difficulty": task_cfg["difficulty"],
        "seed": seed,
        "model": model,
        "final_score": final_score,
        "total_reward": total_reward,
        "steps_taken": step,
        "action_history": action_history,
        "state": env.state(),
    }
# ──────────────────────────────────────────────
# Main entry point
# ──────────────────────────────────────────────
def main():
    """CLI entry point: run the baseline agent on the selected task(s).

    Prints a per-task score table plus the average, optionally writes the
    raw results to a JSON file, and returns the list of result dicts.
    """
    parser = argparse.ArgumentParser(description="IAMSentinel Baseline Agent")
    # `choices` makes argparse reject a bad value (e.g. "task4") with a clear
    # usage error instead of the opaque KeyError run_agent would raise.
    parser.add_argument("--task", default="all",
                        choices=["task1", "task2", "task3", "all"],
                        help="task1|task2|task3|all")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--model", default="gpt-4o-mini")
    parser.add_argument("--complexity", default="medium",
                        choices=["easy", "medium", "hard"],
                        help="easy|medium|hard")
    parser.add_argument("--output", default=None, help="Save results to JSON file")
    parser.add_argument("--quiet", action="store_true")
    args = parser.parse_args()
    tasks = ["task1", "task2", "task3"] if args.task == "all" else [args.task]
    results = []
    for task_id in tasks:
        result = run_agent(
            task_id=task_id,
            seed=args.seed,
            model=args.model,
            complexity=args.complexity,
            verbose=not args.quiet,
        )
        results.append(result)
    # ── Print summary ──────────────────────────
    print("\n" + "="*60)
    print("BASELINE SCORES SUMMARY")
    print("="*60)
    print(f"{'Task':<35} {'Score':>6} {'Steps':>5} {'Difficulty'}")
    print("-"*60)
    for r in results:
        print(
            f"{r['task_name']:<35} {r['final_score']:>6.3f} "
            f"{r['steps_taken']:>5} {r['difficulty']}"
        )
    print("-"*60)
    # `tasks` is never empty (at least one task always runs), so no ZeroDivisionError.
    avg = sum(r["final_score"] for r in results) / len(results)
    print(f"{'Average':<35} {avg:>6.3f}")
    print("="*60)
    if args.output:
        with open(args.output, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {args.output}")
    return results


if __name__ == "__main__":
    main()