# iamsentinel/scripts/baseline_agent.py
# Author: Nampally Tejasri
# Initial OpenEnv submission deploy (commit ca83593)
#!/usr/bin/env python3
"""
IAMSentinel Baseline Inference Script
======================================
Runs an OpenAI ReAct agent (default: gpt-4o-mini) against all 3 tasks and reports scores.
Usage:
export OPENAI_API_KEY=sk-...
python scripts/baseline_agent.py [--task all|task1|task2|task3] [--seed 42] [--model gpt-4o]
Reproducible baseline scores (seed=42, complexity=medium, model=gpt-4o-mini):
Task 1 (Easy): ~0.55–0.70
Task 2 (Medium): ~0.35–0.50
Task 3 (Hard): ~0.20–0.35
"""
import argparse
import json
import os
import sys
import time
from typing import Optional
try:
from openai import OpenAI
except ImportError:
print("ERROR: openai package not installed. Run: pip install openai")
sys.exit(1)
# Ensure package is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from iamsentinel import IAMSentinelEnv
# ──────────────────────────────────────────────
# System prompt for the ReAct agent
# ──────────────────────────────────────────────
# System prompt for the ReAct agent: defines the one-JSON-action-per-turn
# protocol, the available action schemas, and per-task strategy hints.
# NOTE: the em dashes in the action list were mojibake in the original
# file and have been repaired to proper "—" characters.
SYSTEM_PROMPT = """You are an expert cloud security analyst specialising in AWS IAM security.
You are operating inside a simulated IAM environment and must complete security tasks.
You interact with the environment by outputting JSON actions. Each response must contain
EXACTLY ONE action as a JSON block in this format:
```json
{
  "action": "<action_name>",
  ... action parameters ...
}
```
Available actions:
1. list_principals — {"action": "list_principals", "kind": "all"|"user"|"role"}
2. list_policies — {"action": "list_policies", "principal_arn": "<arn or null>"}
3. get_policy — {"action": "get_policy", "policy_arn": "<arn>"}
4. get_principal — {"action": "get_principal", "principal_arn": "<arn>"}
5. get_role_trust — {"action": "get_role_trust", "role_arn": "<arn>"}
6. query_audit_log — {"action": "query_audit_log", "filter": {"event_name": "...", "severity": "...", "principal_arn": "...", "source_ip": "..."}, "limit": 20}
7. trace_escalation_path — {"action": "trace_escalation_path", "from_principal_arn": "<arn>", "to_principal_arn": null}
8. flag_finding — {
  "action": "flag_finding",
  "finding_type": "wildcard_policy"|"mfa_disabled"|"stale_admin_role"|"privilege_escalation_path"|"exposed_trust_policy"|"suspicious_event",
  "affected_principal_arn": "<arn or null>",
  "affected_policy_arn": "<arn or null>",
  "severity": "low"|"medium"|"high"|"critical",
  "description": "<description>",
  "mitre_technique": "<T-code or null>",
  "evidence": ["<arn or event_id>", ...]
}
9. remediate — {"action": "remediate", "remediation_type": "detach_policy"|"delete_user"|"require_mfa"|"update_trust_policy", "target_arn": "<arn>", "policy_arn": "<arn or null>"}
10. attribute_attack — {
  "action": "attribute_attack",
  "compromised_principal_arn": "<arn>",
  "attack_technique": "<description>",
  "mitre_techniques": ["T1078.004", ...],
  "lateral_movement_path": ["<arn1>", "<arn2>"],
  "containment_actions": ["disable_user:<arn>", "delete_function:<name>", ...]
}
Strategy guidelines:
- For Task 1: List all principals and their policies. Check for wildcards, MFA, stale roles, exposed trust policies.
- For Task 2: Find principals with iam:PassRole. Trace escalation paths. Look for lambda + createUser chains.
- For Task 3: Query audit logs by severity=critical first, then trace suspicious sequences. Look for CreateFunction→CreateUser chains from unusual IPs.
Be systematic. Think step by step before each action. Flag findings as you discover them.
For Task 3, finish with attribute_attack once you've gathered enough evidence.
"""
# ──────────────────────────────────────────────
# JSON action parser
# ──────────────────────────────────────────────
def extract_json_action(text: str) -> Optional[dict]:
    """Extract the first JSON action object from model output.

    Tries three strategies in order:
      1. A fenced ```json code block containing a JSON object.
      2. A flat (non-nested) raw JSON object that mentions an "action" key.
      3. A scan that decodes the JSON value starting at each "{" and keeps
         the first dict containing an "action" key.

    Args:
        text: raw assistant message content.

    Returns:
        The parsed action dict, or None if no parseable action is found.
    """
    import re

    # 1. Fenced code block (```json ... ``` or plain ``` ... ```).
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if fenced:
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError:
            pass

    # 2. Flat raw JSON object containing "action" (no nested braces).
    flat = re.search(r"\{[^{}]*\"action\"[^{}]*\}", text, re.DOTALL)
    if flat:
        try:
            return json.loads(flat.group(0))
        except json.JSONDecodeError:
            pass

    # 3. Fallback: standard-decode the JSON value starting at each "{".
    # raw_decode is O(n) per candidate start, replacing the original
    # brute-force scan over every possible end position (O(n^2) per start).
    decoder = json.JSONDecoder()
    for start, ch in enumerate(text):
        if ch != "{":
            continue
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict) and "action" in obj:
            return obj
    return None
def obs_to_text(obs_dict: dict, step: int) -> str:
    """Convert an observation dict into a concise text summary for the LLM.

    Caps each section to keep the prompt small: last 3 findings, first 5
    principals/policies, first 8 audit events, first 3 escalation paths,
    and the first 300 chars of a trust policy.

    NOTE: the status markers (checkmark/cross/warning/arrow) were mojibake
    in the original file and have been repaired to the intended glyphs.

    Args:
        obs_dict: observation as a plain dict (e.g. from model_dump()).
        step: current step number, shown in the header line.

    Returns:
        A newline-joined, human/LLM-readable summary string.
    """
    parts = [f"[Step {step}] Budget remaining: {obs_dict.get('budget_remaining', '?')}"]
    if obs_dict.get("hints"):
        parts.append("Hints: " + " | ".join(obs_dict["hints"]))
    if obs_dict.get("findings"):
        parts.append(f"Findings so far ({len(obs_dict['findings'])}):")
        for f in obs_dict["findings"][-3:]:  # last 3 only
            parts.append(f" - [{f['severity']}] {f['finding_type']}: {f['description'][:80]}")
    if obs_dict.get("principals"):
        parts.append(f"Principals returned: {len(obs_dict['principals'])}")
        for p in obs_dict["principals"][:5]:
            mfa = "✓MFA" if p.get("mfa_enabled") else "✗MFA"
            parts.append(
                f" {p['kind']}: {p['name']} | {mfa} | "
                f"last_active={p['last_active_days']}d | "
                f"policies={len(p.get('policies', []))}"
            )
        if len(obs_dict["principals"]) > 5:
            parts.append(f" ... and {len(obs_dict['principals'])-5} more")
    if obs_dict.get("policies"):
        parts.append(f"Policies returned: {len(obs_dict['policies'])}")
        for p in obs_dict["policies"][:5]:
            wildcard = "⚠WILDCARD" if p.get("is_wildcard") else ""
            parts.append(f" {p['name']} {wildcard} | arn={p['arn']}")
            if p.get("statements"):
                # Show only the first statement's first few actions.
                actions = p["statements"][0].get("actions", [])
                parts.append(f" actions: {actions[:5]}")
        if len(obs_dict["policies"]) > 5:
            parts.append(f" ... and {len(obs_dict['policies'])-5} more")
    if obs_dict.get("audit_events"):
        parts.append(f"Audit events returned: {len(obs_dict['audit_events'])}")
        for e in obs_dict["audit_events"][:8]:
            parts.append(
                f" [{e.get('severity','?')}] {e['event_time']} | "
                f"{e['event_name']} | {e['principal_name']} | ip={e['source_ip']}"
            )
        if len(obs_dict["audit_events"]) > 8:
            parts.append(f" ... and {len(obs_dict['audit_events'])-8} more")
    if obs_dict.get("escalation_paths"):
        parts.append(f"Escalation paths found: {len(obs_dict['escalation_paths'])}")
        for ep in obs_dict["escalation_paths"][:3]:
            parts.append(f" Path (risk={ep.get('risk_score','?')}): {' → '.join(ep['path'])}")
    if obs_dict.get("role_trust_policy"):
        parts.append(f"Trust policy: {json.dumps(obs_dict['role_trust_policy'], indent=2)[:300]}")
    if obs_dict.get("done"):
        parts.append("EPISODE DONE.")
    return "\n".join(parts)
# ──────────────────────────────────────────────
# Agent runner
# ──────────────────────────────────────────────
def run_agent(
    task_id: str,
    seed: int = 42,
    model: str = "gpt-4o-mini",
    complexity: str = "medium",
    verbose: bool = True,
) -> dict:
    """Run one ReAct episode of the LLM agent against a single task.

    Drives a chat loop: the model emits one JSON action per turn, the
    action is executed in IAMSentinelEnv, and a text summary of the
    resulting observation is fed back as the next user message, until
    the episode ends or the step budget is exhausted.

    Args:
        task_id: one of "task1", "task2", "task3".
        seed: environment RNG seed for reproducibility.
        model: OpenAI chat model name.
        complexity: environment complexity preset.
        verbose: print per-step progress to stdout.

    Returns:
        dict with the final score, accumulated reward, steps taken,
        the action-name history, and the terminal environment state.

    Raises:
        ValueError: if OPENAI_API_KEY is not set.
        KeyError: if task_id is not one of the three known tasks.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set")
    client = OpenAI(api_key=api_key)
    env = IAMSentinelEnv(task_id=task_id, seed=seed, complexity=complexity)
    obs = env.reset()
    # Display metadata per task; an unknown task_id raises KeyError here.
    task_cfg = {
        "task1": {"name": "Misconfiguration Scanner", "difficulty": "Easy"},
        "task2": {"name": "Privilege Escalation Path Detection", "difficulty": "Medium"},
        "task3": {"name": "Live Attack Attribution", "difficulty": "Hard"},
    }[task_id]
    if verbose:
        print(f"\n{'='*60}")
        print(f"Task: {task_cfg['name']} ({task_cfg['difficulty']})")
        print(f"Seed: {seed} | Model: {model} | Complexity: {complexity}")
        print(f"{'='*60}")
    # Build conversation history, starting from the fixed system prompt.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # Initial user message with the task description and (optional) hints.
    initial_msg = (
        f"Task: {obs.task_description}\n\n"
        f"Account ID: {obs.account_id}\n"
        f"Max steps: {obs.max_steps}\n"
    )
    if obs.hints:
        initial_msg += "\nHints:\n" + "\n".join(f"- {h}" for h in obs.hints)
    initial_msg += "\n\nBegin your investigation. Output one JSON action."
    messages.append({"role": "user", "content": initial_msg})
    episode_done = False
    step = 0
    final_score = 0.0  # stays 0.0 if the loop exits on the step limit
    total_reward = 0.0
    action_history = []
    # NOTE(review): relies on the private accessor env._max_steps() —
    # confirm whether a public equivalent (e.g. obs.max_steps) should be used.
    while not episode_done and step < env._max_steps():
        step += 1
        # ── Call LLM ──────────────────────────
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.2,
                max_tokens=800,
            )
            assistant_text = response.choices[0].message.content
        except Exception as e:
            # Transient API failure: back off and retry. Note this still
            # consumes one iteration of the step-limited loop.
            print(f" [Step {step}] LLM error: {e}")
            time.sleep(2)
            continue
        messages.append({"role": "assistant", "content": assistant_text})
        # ── Parse action ───────────────────────
        action_dict = extract_json_action(assistant_text)
        if action_dict is None:
            # Unparseable reply: ask the model to re-emit a clean JSON block.
            if verbose:
                print(f" [Step {step}] Could not parse action from: {assistant_text[:100]}")
            feedback = "ERROR: Could not parse a valid JSON action. Output ONLY a JSON block."
            messages.append({"role": "user", "content": feedback})
            continue
        action_name = action_dict.get("action", "unknown")
        action_history.append(action_name)
        if verbose:
            print(f" [Step {step}] Action: {action_name}", end="")
            # Show non-null parameters (truncated) alongside the action name.
            key_params = {k: v for k, v in action_dict.items()
                          if k != "action" and v is not None}
            if key_params:
                print(f" | params: {json.dumps(key_params)[:100]}", end="")
            print()
        # ── Step environment ───────────────────
        try:
            next_obs, reward, done, info = env.step(action_dict)
        except Exception as e:
            # Invalid action for the env: surface the error to the model
            # so it can pick a different action next turn.
            feedback = f"ERROR executing action: {e}. Try a different action."
            messages.append({"role": "user", "content": feedback})
            continue
        total_reward += reward.total
        episode_done = done
        if done and info.get("final_score") is not None:
            final_score = info["final_score"]
            if verbose:
                print(f" [Step {step}] Episode done. Final score: {final_score:.3f}")
        # ── Build feedback message ─────────────
        obs_dict = next_obs.model_dump()  # pydantic model -> plain dict
        feedback_text = obs_to_text(obs_dict, step)
        if reward.step_reward != 0:
            feedback_text += f"\n[Reward signal: {reward.step_reward:+.3f}]"
        if obs_dict.get("findings"):
            feedback_text += f"\n[Total findings logged: {len(obs_dict['findings'])}]"
        if not done:
            feedback_text += "\n\nContinue your investigation. Output one JSON action."
        messages.append({"role": "user", "content": feedback_text})
        # Small delay to respect rate limits
        time.sleep(0.3)
    return {
        "task_id": task_id,
        "task_name": task_cfg["name"],
        "difficulty": task_cfg["difficulty"],
        "seed": seed,
        "model": model,
        "final_score": final_score,
        "total_reward": total_reward,
        "steps_taken": step,
        "action_history": action_history,
        "state": env.state(),
    }
# ──────────────────────────────────────────────
# Main entry point
# ──────────────────────────────────────────────
def main():
    """CLI entry point: run the baseline agent and print a score summary.

    Parses arguments, runs each selected task via run_agent, prints a
    summary table, and optionally writes results to a JSON file.

    Returns:
        list[dict]: one result dict per task, as produced by run_agent.
    """
    parser = argparse.ArgumentParser(description="IAMSentinel Baseline Agent")
    # `choices` rejects invalid values at parse time with a clear message,
    # instead of a bare KeyError deep inside run_agent.
    parser.add_argument("--task", default="all",
                        choices=["all", "task1", "task2", "task3"],
                        help="task1|task2|task3|all")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--model", default="gpt-4o-mini")
    parser.add_argument("--complexity", default="medium",
                        choices=["easy", "medium", "hard"],
                        help="easy|medium|hard")
    parser.add_argument("--output", default=None, help="Save results to JSON file")
    parser.add_argument("--quiet", action="store_true")
    args = parser.parse_args()

    tasks = ["task1", "task2", "task3"] if args.task == "all" else [args.task]
    results = []
    for task_id in tasks:
        result = run_agent(
            task_id=task_id,
            seed=args.seed,
            model=args.model,
            complexity=args.complexity,
            verbose=not args.quiet,
        )
        results.append(result)

    # ── Print summary ──────────────────────────
    print("\n" + "=" * 60)
    print("BASELINE SCORES SUMMARY")
    print("=" * 60)
    print(f"{'Task':<35} {'Score':>6} {'Steps':>5} {'Difficulty'}")
    print("-" * 60)
    for r in results:
        print(
            f"{r['task_name']:<35} {r['final_score']:>6.3f} "
            f"{r['steps_taken']:>5} {r['difficulty']}"
        )
    print("-" * 60)
    avg = sum(r["final_score"] for r in results) / len(results)
    print(f"{'Average':<35} {avg:>6.3f}")
    print("=" * 60)

    if args.output:
        with open(args.output, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {args.output}")
    return results
if __name__ == "__main__":
main()