# iamsentinel/scripts/rl_training_example.py
# Author: Nampally Tejasri
# Initial OpenEnv submission deploy (commit ca83593)
#!/usr/bin/env python3
"""
IAMSentinel RL Training Example
=================================
Demonstrates how to connect a local RL training loop to the
remote IAMSentinel OpenEnv server (Hugging Face Spaces or local Docker).
This implements a simple LLM-guided policy (REINFORCE-style) using the
OpenAI API as the policy network, with episode-level reward signals.
The same pattern works with any RL framework (Stable-Baselines3, RLlib,
CleanRL) β€” just replace the policy network.
Setup:
# Option A β€” local docker
docker build -t iamsentinel . && docker run -p 7860:7860 iamsentinel
# Option B β€” HF Space (set HF_SPACE_URL env var)
export HF_SPACE_URL=https://<username>-iamsentinel.hf.space
# Run training
export OPENAI_API_KEY=sk-...
python scripts/rl_training_example.py --episodes 20 --task task1
Architecture:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Local Machine (trainer) β”‚
β”‚ β”‚
β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚
β”‚ β”‚ Policy β”‚ β”‚ Replay β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ β”‚ (GPT-4o) β”‚ β”‚ Buffer β”‚ │◄───────►│ IAMSentinel Server β”‚
β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ HTTP β”‚ (HF Space / Docker) β”‚
β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚ β”‚ Episode Logger / Scorer β”‚ β”‚
β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
import argparse
import json
import os
import sys
import time
import statistics
from collections import defaultdict
from typing import Optional
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from iamsentinel.client import IAMSentinelClient, IAMSentinelClientError
try:
from openai import OpenAI
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
# ──────────────────────────────────────────────
# Replay buffer (stores episodes for training)
# ──────────────────────────────────────────────
class Episode:
    """Record of one complete rollout: per-step transitions plus aggregates."""

    def __init__(self, task_id: str, seed: int):
        self.task_id = task_id
        self.seed = seed
        # Each transition stores: obs, action, reward, step_reward, next_obs, done.
        self.transitions: list[dict] = []
        self.total_reward = 0.0
        self.final_score = 0.0
        self.steps = 0

    def add(self, obs: dict, action: dict, reward: dict,
            next_obs: dict, done: bool):
        """Append one environment transition and update running totals."""
        total = reward["total"]
        self.transitions.append({
            "obs": obs,
            "action": action,
            "reward": total,
            "step_reward": reward.get("step_reward", 0.0),
            "next_obs": next_obs,
            "done": done,
        })
        self.total_reward += total
        self.steps += 1
        # On the terminal step, the environment's "total" is the episode score.
        if done and total is not None:
            self.final_score = total
class ReplayBuffer:
    """Fixed-capacity FIFO store of completed episodes."""

    def __init__(self, max_episodes: int = 100):
        self.episodes: list[Episode] = []
        self.max_episodes = max_episodes

    def add(self, episode: Episode):
        """Store an episode, evicting the oldest once capacity is exceeded."""
        self.episodes.append(episode)
        while len(self.episodes) > self.max_episodes:
            del self.episodes[0]

    def mean_score(self, last_n: int = 10) -> float:
        """Mean final score of the most recent `last_n` episodes (0.0 if empty)."""
        window = self.episodes[-last_n:]
        if not window:
            return 0.0
        return statistics.mean(e.final_score for e in window)

    def task_scores(self) -> dict[str, list[float]]:
        """Final scores grouped by task id, preserving insertion order."""
        grouped: dict[str, list[float]] = defaultdict(list)
        for episode in self.episodes:
            grouped[episode.task_id].append(episode.final_score)
        return dict(grouped)
# ──────────────────────────────────────────────
# LLM Policy
# ──────────────────────────────────────────────
# System prompt for the LLM policy: fixes the action schema (one JSON object
# per turn), enumerates every legal action, and gives per-task strategy hints.
# NOTE: this string is sent verbatim to the model — keep the JSON examples valid.
POLICY_SYSTEM = """You are an IAM security AI agent. You interact with a cloud IAM
environment by outputting ONE JSON action per turn.
Your goal: identify security vulnerabilities and complete the assigned task.
Output ONLY a valid JSON action block like:
{"action": "list_principals", "kind": "all"}
Available actions:
- {"action": "list_principals", "kind": "all"|"user"|"role"}
- {"action": "list_policies", "principal_arn": null}
- {"action": "get_policy", "policy_arn": "<arn>"}
- {"action": "get_principal", "principal_arn": "<arn>"}
- {"action": "get_role_trust", "role_arn": "<arn>"}
- {"action": "query_audit_log", "filter": {"severity":"critical","event_name":"..."}, "limit": 20}
- {"action": "trace_escalation_path", "from_principal_arn": "<arn>", "to_principal_arn": null}
- {"action": "flag_finding", "finding_type": "wildcard_policy"|"mfa_disabled"|"stale_admin_role"|"privilege_escalation_path"|"exposed_trust_policy", "severity": "critical", "description": "...", "affected_principal_arn": null, "evidence": []}
- {"action": "attribute_attack", "compromised_principal_arn":"<arn>","attack_technique":"...","mitre_techniques":["T1078.004"],"lateral_movement_path":["<arn1>","<arn2>"],"containment_actions":["disable_user:<arn>"]}
Be systematic. For Task 1: scan all principals and policies for misconfigs.
For Task 2: find iam:PassRole chains. For Task 3: query critical/high severity logs first."""
def _format_obs_for_policy(obs: dict, step: int, prev_reward: float = 0.0) -> str:
"""Format observation into LLM-friendly text."""
lines = [
f"Step {step}/{obs.get('max_steps', '?')} | Budget: {obs.get('budget_remaining', '?')}",
f"Task: {obs.get('task_description', '')[:120]}",
]
if prev_reward != 0:
lines.append(f"Last reward signal: {prev_reward:+.3f}")
findings = obs.get("findings", [])
if findings:
lines.append(f"Findings logged ({len(findings)}):")
for f in findings[-3:]:
lines.append(f" [{f['severity']}] {f['finding_type']}: {f['description'][:60]}")
if obs.get("hints"):
lines.append("Hints: " + " | ".join(obs["hints"]))
if obs.get("principals"):
lines.append(f"Principals ({len(obs['principals'])}):")
for p in obs["principals"][:6]:
mfa = "MFAβœ“" if p.get("mfa_enabled") else "MFAβœ—"
lines.append(
f" {p['kind']}: {p['name']} | {mfa} | "
f"inactive={p['last_active_days']}d | "
f"policies={len(p.get('policies',[]))}"
)
if obs.get("policies"):
lines.append(f"Policies ({len(obs['policies'])}):")
for p in obs["policies"][:6]:
wc = "⚠WILDCARD" if p.get("is_wildcard") else ""
acts = []
for stmt in p.get("statements", []):
acts.extend(stmt.get("actions", []))
lines.append(f" {p['name']} {wc} | arn={p['arn']} | actions={acts[:4]}")
if obs.get("audit_events"):
lines.append(f"Audit events ({len(obs['audit_events'])}):")
for e in obs["audit_events"][:8]:
lines.append(
f" [{e.get('severity','?')}] {e['event_time'][-8:]} | "
f"{e['event_name']} | {e['principal_name']} | ip={e['source_ip']}"
)
if obs.get("escalation_paths"):
lines.append(f"Escalation paths found: {len(obs['escalation_paths'])}")
for ep in obs["escalation_paths"][:2]:
path_str = " β†’ ".join(a.split("/")[-1] for a in ep.get("path", []))
lines.append(f" {path_str} (risk={ep.get('risk_score', '?')})")
lines.append("\nOutput ONE JSON action:")
return "\n".join(lines)
def _extract_action(text: str) -> Optional[dict]:
"""Extract JSON action from LLM output."""
import re
for pattern in [
r"```(?:json)?\s*(\{.*?\})\s*```",
r"(\{[^{}]*\"action\"[^{}]*\})",
]:
m = re.search(pattern, text, re.DOTALL)
if m:
try:
return json.loads(m.group(1))
except Exception:
pass
# Greedy fallback
for s in range(len(text)):
if text[s] == "{":
for e in range(len(text), s, -1):
try:
obj = json.loads(text[s:e])
if "action" in obj:
return obj
except Exception:
continue
return None
# ──────────────────────────────────────────────
# Episode runner
# ──────────────────────────────────────────────
def run_episode(
    client: IAMSentinelClient,
    task_id: str,
    seed: int,
    model: str,
    openai_client,
    verbose: bool = False,
) -> Episode:
    """Run one complete episode and return the filled Episode object.

    The LLM (via `openai_client`) acts as the policy: each turn it receives a
    text rendering of the observation and must reply with one JSON action,
    which is forwarded to the environment through `client.step`.
    The episode ends when the env reports done, `max_steps` is reached, or
    the server raises an error.
    """
    episode = Episode(task_id=task_id, seed=seed)
    obs = client.reset(task_id=task_id, seed=seed, complexity="medium")
    # Full chat history; only a sliding window is sent to the API (below).
    messages = [{"role": "system", "content": POLICY_SYSTEM}]
    prev_reward = 0.0
    done = False
    step = 0
    # NOTE(review): assumes reset() returns a dict exposing "max_steps";
    # 40 is the local fallback cap — confirm against server episode limits.
    max_steps = obs.get("max_steps", 40)
    while not done and step < max_steps:
        step += 1
        user_msg = _format_obs_for_policy(obs, step, prev_reward)
        messages.append({"role": "user", "content": user_msg})
        # Get action from policy
        try:
            resp = openai_client.chat.completions.create(
                model=model,
                messages=messages[-20:],  # sliding window context
                temperature=0.3,
                max_tokens=400,
            )
            text = resp.choices[0].message.content
            messages.append({"role": "assistant", "content": text})
        except Exception as ex:
            # Transient API failure: back off and retry. The step counter has
            # already advanced, so repeated failures shrink the step budget.
            if verbose:
                print(f" LLM error: {ex}")
            time.sleep(2)
            continue
        action = _extract_action(text)
        if action is None:
            # Unparseable model output: ask it to resend valid JSON; this
            # also consumes a step without touching the environment.
            if verbose:
                print(f" [Step {step}] Failed to parse action")
            messages.append({
                "role": "user",
                "content": "Could not parse JSON. Output ONLY a valid JSON action."
            })
            continue
        # Execute action
        try:
            next_obs, reward, done, info = client.step(action)
        except IAMSentinelClientError as ex:
            # Server-side failure: end the episode with whatever was collected.
            if verbose:
                print(f" [Step {step}] Server error: {ex}")
            break
        prev_reward = reward.get("step_reward", 0.0)
        episode.add(obs, action, reward, next_obs, done)
        if verbose:
            final = f" | FINAL={reward['total']:.3f}" if done else ""
            print(
                f" [Step {step:02d}] {action.get('action','?'):<28} "
                f"r={prev_reward:+.3f}{final}"
            )
        obs = next_obs
        time.sleep(0.2)  # rate limit
    return episode
# ──────────────────────────────────────────────
# Training loop
# ──────────────────────────────────────────────
def train(
    server_url: str,
    tasks: list[str],
    n_episodes: int,
    seeds: list[int],
    model: str,
    verbose: bool,
    output_path: Optional[str],
):
    """Main training loop: run `n_episodes` episodes round-robin over tasks/seeds.

    Verifies the OpenAI SDK/key and server health, runs episodes via
    `run_episode`, tracks scores in a ReplayBuffer, prints progress, and
    optionally writes a JSON results summary to `output_path`.
    Exits the process (code 1) if preconditions fail.
    Returns the list of per-episode result dicts.
    """
    # Preconditions: OpenAI SDK importable and API key configured.
    if not HAS_OPENAI:
        print("ERROR: pip install openai")
        sys.exit(1)
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: Set OPENAI_API_KEY environment variable")
        sys.exit(1)
    openai_client = OpenAI(api_key=api_key)
    client = IAMSentinelClient(base_url=server_url)
    # Verify server is up
    try:
        health = client.health()
        print(f"✅ Connected to IAMSentinel server: {server_url}")
        print(f" Status: {health['status']} | Active sessions: {health['sessions']}")
    except IAMSentinelClientError as e:
        print(f"❌ Cannot reach server at {server_url}")
        print(f" Error: {e}")
        print("\nTo start a local server:")
        print(" docker build -t iamsentinel . && docker run -p 7860:7860 iamsentinel")
        sys.exit(1)
    buffer = ReplayBuffer(max_episodes=200)
    episode_num = 0
    all_results = []
    print(f"\n{'='*65}")
    print(f"IAMSentinel RL Training")
    print(f"Tasks: {tasks} | Episodes: {n_episodes} | Model: {model}")
    print(f"{'='*65}\n")
    for ep_idx in range(n_episodes):
        # Round-robin over tasks and seeds so each combination gets coverage.
        task_id = tasks[ep_idx % len(tasks)]
        seed = seeds[ep_idx % len(seeds)]
        episode_num += 1
        print(f"Episode {episode_num:03d}/{n_episodes} | task={task_id} | seed={seed}")
        episode = run_episode(
            client, task_id, seed, model, openai_client, verbose
        )
        buffer.add(episode)
        # Log results
        result = {
            "episode": episode_num,
            "task_id": task_id,
            "seed": seed,
            "steps": episode.steps,
            "total_reward": round(episode.total_reward, 4),
            "final_score": round(episode.final_score, 4),
        }
        all_results.append(result)
        mean_10 = buffer.mean_score(last_n=10)
        print(
            f" Score={episode.final_score:.3f} | "
            f"Steps={episode.steps} | "
            f"Moving avg(10)={mean_10:.3f}"
        )
        # Print per-task breakdown every 5 episodes
        if episode_num % 5 == 0:
            print("\n 📊 Per-task mean scores:")
            for tid, scores in buffer.task_scores().items():
                print(f" {tid}: mean={statistics.mean(scores):.3f} "
                      f"over {len(scores)} episodes")
            print()
    # ── Final summary ──────────────────────────
    print(f"\n{'='*65}")
    print("TRAINING COMPLETE — Final Summary")
    print(f"{'='*65}")
    task_scores = buffer.task_scores()
    for tid in tasks:
        scores = task_scores.get(tid, [])
        if scores:
            print(
                f" {tid}: mean={statistics.mean(scores):.3f} "
                f"| best={max(scores):.3f} "
                f"| worst={min(scores):.3f} "
                f"| n={len(scores)}"
            )
    # Persist run config, per-episode results, and per-task aggregates.
    if output_path:
        with open(output_path, "w") as f:
            json.dump({
                "config": {
                    "server_url": server_url,
                    "tasks": tasks,
                    "model": model,
                    "n_episodes": n_episodes,
                },
                "episodes": all_results,
                "final_task_scores": {
                    tid: {
                        "mean": round(statistics.mean(s), 4),
                        "best": round(max(s), 4),
                        "n": len(s),
                    }
                    for tid, s in task_scores.items()
                },
            }, f, indent=2)
        print(f"\nResults saved → {output_path}")
    return all_results
# ──────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────
def main():
    """CLI entry point: parse arguments and launch the training loop."""
    space_url = os.environ.get("HF_SPACE_URL", "")
    fallback_url = space_url if space_url else "http://localhost:7860"
    parser = argparse.ArgumentParser(description="IAMSentinel RL Training")
    parser.add_argument("--server", default=fallback_url,
                        help="Server URL (default: $HF_SPACE_URL or http://localhost:7860)")
    parser.add_argument("--task", default="all",
                        help="task1|task2|task3|all")
    parser.add_argument("--episodes", type=int, default=15,
                        help="Total training episodes")
    parser.add_argument("--seeds", default="42,123,456,789,1337",
                        help="Comma-separated seeds to cycle through")
    parser.add_argument("--model", default="gpt-4o-mini",
                        help="OpenAI model to use as policy")
    parser.add_argument("--output", default="training_results.json",
                        help="Output file for results")
    parser.add_argument("--verbose", action="store_true",
                        help="Print step-level details")
    args = parser.parse_args()
    # "all" cycles through every task; otherwise train on the single one given.
    if args.task == "all":
        task_list = ["task1", "task2", "task3"]
    else:
        task_list = [args.task]
    seed_list = [int(tok) for tok in args.seeds.split(",")]
    train(
        server_url=args.server,
        tasks=task_list,
        n_episodes=args.episodes,
        seeds=seed_list,
        model=args.model,
        verbose=args.verbose,
        output_path=args.output,
    )


if __name__ == "__main__":
    main()