"""
train_rl.py - Training + Evaluation Script for Adaptive Alert Triage
=====================================================================
Runs:
1. Rule-based baselines (RuleBasedAgent, ImprovedRuleBasedAgent)
2. PPO RL agent training across all 3 tasks
3. Saves all results to results/comparison_results.json for the comparison plot
Changes vs previous version:
- Per-task episode budgets: hard gets 3× the easy budget
because it has 40% chain probability and needs far more samples
to observe enough chain outcomes for the policy to converge.
- --episodes arg now sets the EASY budget; medium and hard scale up.
- Grader is now passed into trainer.train() so terminal scores are
injected into trajectories (previously only logged afterwards, not used).
Usage:
python train_rl.py [--episodes 300] [--eval-episodes 20] [--seed 42]
"""
from __future__ import annotations
import json
import sys
import os
import argparse
import time
from typing import Any, Dict
import numpy as np
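# Make the script's sibling packages importable regardless of the working directory.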
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from adaptive_alert_triage.env import AdaptiveAlertTriageEnv
from adaptive_alert_triage.models import Action
from tasks.easy import EasyTaskGrader, SUCCESS_THRESHOLD as EASY_THRESH
from tasks.medium import MediumTaskGrader, SUCCESS_THRESHOLD as MED_THRESH
from tasks.hard import HardTaskGrader, SUCCESS_THRESHOLD as HARD_THRESH
from rl_agent import PPOTrainer
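# Maps task_id -> (grader class, grader kwargs, success threshold).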
_GRADER_REGISTRY = {
"easy": (EasyTaskGrader, {}, EASY_THRESH),
"medium": (MediumTaskGrader, {"max_investigations_per_step": 3}, MED_THRESH),
"hard": (HardTaskGrader, {}, HARD_THRESH),
}
# Per-task episode budgets.
# Hard needs ~3× the easy budget: a 40% chain probability over 50 steps means
# the agent sees ~20 chain alerts per episode but needs hundreds of episodes
# to learn which observable features (age, type, severity pattern) correlate
# with chain membership. Medium gets 1.5× for the resource-constraint
# curriculum.
_EPISODE_SCALE = {
"easy": 1.0,
"medium": 1.5,
"hard": 3.0,
}
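# With the default --episodes 300 this yields 300 easy, 450 medium, and
# 900 hard training episodes.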
# ── Rule-based agents ─────────────────────────────────────────────────────
class SimpleRuleAgent:
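    """Greedy baseline: always acts on the highest-severity visible alert,
    choosing the action type from fixed severity/confidence thresholds."""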
name = "RuleBased"
def act(self, obs):
if not obs.alerts:
raise ValueError("No alerts")
alert = max(obs.alerts, key=lambda a: a.visible_severity)
sev, conf = alert.visible_severity, alert.confidence
budget = obs.resource_budget
if sev > 0.75 and conf > 0.70:
if budget is not None and budget <= 0:
return Action(alert_id=alert.id, action_type="ESCALATE")
return Action(alert_id=alert.id, action_type="INVESTIGATE")
if conf < 0.30:
return Action(alert_id=alert.id, action_type="IGNORE")
if sev > 0.55:
return Action(alert_id=alert.id, action_type="ESCALATE")
return Action(alert_id=alert.id, action_type="DELAY")
def reset(self):
pass
class ImprovedRuleAgent:
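    """Stronger baseline: ranks alerts by severity plus age (with a SECURITY
    bonus), prioritizes aged high-severity alerts, and defers work via DELAY
    when system load is high."""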
name = "ImprovedRule"
def act(self, obs):
if not obs.alerts:
raise ValueError("No alerts")
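        # Rank alerts by severity (dominant term) plus age, with a small
        # bonus for SECURITY alerts.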
def score(a):
s = a.visible_severity * 2.0 + a.age * 0.08
if a.alert_type == "SECURITY":
s += 0.1
return s
alert = max(obs.alerts, key=score)
sev, conf, age = alert.visible_severity, alert.confidence, alert.age
budget, sys_load = obs.resource_budget, obs.system_load
if age >= 3 and sev > 0.70:
if budget is not None and budget <= 0:
return Action(alert_id=alert.id, action_type="ESCALATE")
return Action(alert_id=alert.id, action_type="INVESTIGATE")
if sys_load > 0.85:
if sev > 0.85 and conf > 0.80:
return Action(alert_id=alert.id, action_type="INVESTIGATE")
return Action(alert_id=alert.id, action_type="DELAY")
if sev > 0.75 and conf > 0.70:
if budget is not None and budget <= 0:
return Action(alert_id=alert.id, action_type="ESCALATE")
return Action(alert_id=alert.id, action_type="INVESTIGATE")
if conf < 0.30:
return Action(alert_id=alert.id, action_type="IGNORE")
if sev > 0.55:
return Action(alert_id=alert.id, action_type="ESCALATE")
return Action(alert_id=alert.id, action_type="DELAY")
def reset(self):
pass
# ── Evaluation ────────────────────────────────────────────────────────────
def evaluate_agent(agent, task_id: str, n_episodes: int, seed_offset: int = 0) -> Dict[str, Any]:
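    """Roll out `agent` for `n_episodes` on `task_id` with a fresh grader per
    episode; return score statistics and the success rate against the task
    threshold. Fixed seeds (seed_offset + episode index) keep evaluations
    comparable across agents."""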
grader_cls, grader_kwargs, threshold = _GRADER_REGISTRY[task_id]
env = AdaptiveAlertTriageEnv(task_id=task_id)
scores = []
is_hard = (task_id == "hard")
for ep in range(n_episodes):
grader = grader_cls(**grader_kwargs)
if hasattr(agent, 'reset'):
agent.reset()
obs = env.reset(seed=seed_offset + ep)
done = False
while not done:
if not obs.alerts:
break
action = agent.act(obs)
obs, _r, done, info = env.step(action)
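            # The hard-task grader additionally tracks correlation groups and
            # per-step failure counts reported by the env.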
if is_hard:
grader.update_correlation_state(info.get("correlation_groups", []))
for ad in info.get("processed_alerts", []):
grader.process_step(ad, info)
if is_hard:
grader.record_failures(info.get("failures_this_step", 0))
scores.append(grader.get_episode_score())
arr = np.array(scores)
return {
"mean": float(arr.mean()),
"std": float(arr.std()),
"min": float(arr.min()),
"max": float(arr.max()),
"success_rate": float((arr >= threshold).mean()),
"scores": scores,
}
# ── PPO wrapper ───────────────────────────────────────────────────────────
class PPOAgentWrapper:
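    """Adapts PPOTrainer's act()/reset() interface for evaluate_agent."""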
def __init__(self, trainer: PPOTrainer):
self._trainer = trainer
self.name = "PPO_LSTM"
def act(self, obs):
return self._trainer.act(obs)
def reset(self):
self._trainer.reset()
# ── Main ──────────────────────────────────────────────────────────────────
def run(args):
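    """Train and evaluate all agents on every task, then write the combined
    results to results/comparison_results.json."""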
results = {}
for task_id in ["easy", "medium", "hard"]:
grader_cls, grader_kwargs, threshold = _GRADER_REGISTRY[task_id]
# Per-task episode budget
n_episodes = int(args.episodes * _EPISODE_SCALE[task_id])
print(f"\n{'='*60}")
print(f"TASK: {task_id.upper()} "
f"(threshold β‰₯ {threshold}, episodes = {n_episodes})")
print(f"{'='*60}")
# 1. Rule-based baselines
print(f"\n[1/3] Evaluating rule-based baselines…")
rb_basic_res = evaluate_agent(SimpleRuleAgent(), task_id, args.eval_episodes, seed_offset=100)
rb_improved_res = evaluate_agent(ImprovedRuleAgent(), task_id, args.eval_episodes, seed_offset=100)
print(f" RuleBased : mean={rb_basic_res['mean']:.3f} "
f"success={rb_basic_res['success_rate']:.0%}")
print(f" ImprovedRule : mean={rb_improved_res['mean']:.3f} "
f"success={rb_improved_res['success_rate']:.0%}")
# 2. PPO training (grader passed in so terminal scores hit trajectories)
print(f"\n[2/3] Training PPO agent ({n_episodes} episodes)…")
env = AdaptiveAlertTriageEnv(task_id=task_id)
trainer = PPOTrainer(task_id=task_id, seed=args.seed, lr=3e-4)
t0 = time.time()
        history = trainer.train(
            env,
            n_episodes=n_episodes,
            grader_cls=grader_cls,
            grader_kwargs=grader_kwargs,
            log_interval=max(1, n_episodes // 10),
            verbose=True,
        )
elapsed = time.time() - t0
print(f" Training done in {elapsed:.1f}s")
os.makedirs("weights", exist_ok=True)
trainer.save(f"weights/ppo_{task_id}.json")
# 3. PPO evaluation
print(f"\n[3/3] Evaluating PPO agent ({args.eval_episodes} episodes)…")
ppo_agent = PPOAgentWrapper(trainer)
ppo_res = evaluate_agent(ppo_agent, task_id, args.eval_episodes, seed_offset=200)
print(f" PPO : mean={ppo_res['mean']:.3f} "
f"success={ppo_res['success_rate']:.0%}")
results[task_id] = {
"threshold": threshold,
"n_episodes": n_episodes,
"rule_basic": rb_basic_res,
"rule_improved": rb_improved_res,
"ppo": ppo_res,
"training": {
"episode_rewards": history["episode_rewards"],
"episode_scores": history["episode_scores"],
"policy_losses": history["policy_losses"],
"entropies": history["entropies"],
},
}
# Save results
class _NumpyEncoder(json.JSONEncoder):
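        """json.dump cannot serialize the numpy scalars or arrays that may
        appear in the results; coerce them to native Python types."""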
        def default(self, obj):
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return super().default(obj)
os.makedirs("results", exist_ok=True)
out_path = "results/comparison_results.json"
with open(out_path, "w") as f:
json.dump(results, f, indent=2, cls=_NumpyEncoder)
print(f"\nβœ“ Results saved to {out_path}")
# Summary table
print(f"\n{'='*60}")
print("FINAL COMPARISON SUMMARY")
print(f"{'='*60}")
print(f"{'Task':<10} {'Agent':<16} {'Mean':>8} {'Std':>7} {'Pass%':>8}")
print("─" * 55)
for task_id, res in results.items():
for name, key in [("RuleBased", "rule_basic"),
("ImprovedRule", "rule_improved"),
("PPO+LSTM", "ppo")]:
r = res[key]
print(f"{task_id:<10} {name:<16} "
f"{r['mean']:>8.3f} "
f"{r['std']:>7.3f} "
f"{r['success_rate']*100:>7.1f}%")
print("─" * 55)
return results
def parse_args():
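    """--episodes sets the easy-task budget (medium and hard scale up);
    --eval-episodes and --seed control evaluation length and PPO seeding."""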
p = argparse.ArgumentParser()
p.add_argument("--episodes", type=int, default=300,
help="Episode budget for easy task; medium=1.5Γ—, hard=3Γ—")
p.add_argument("--eval-episodes", type=int, default=20)
p.add_argument("--seed", type=int, default=42)
return p.parse_args()
if __name__ == "__main__":
args = parse_args()
run(args)