""" Evaluation & Comparison Script — AI-SOAR Network Defense ========================================================= Runs three agents across all difficulty levels and produces comparison graphs: 1. Random agent (untrained baseline) 2. Greedy agent (hardcoded optimal policy) 3. LoRA agent (your GRPO-trained model — requires GPU + unsloth) Usage: # Without trained model (random vs greedy only): python eval.py --episodes 20 # With trained LoRA model: python eval.py --episodes 20 --lora-dir grpo_ai_soar_qwen_demo40 # Save graphs and log to wandb: python eval.py --episodes 20 --lora-dir grpo_ai_soar_qwen_demo40 --wandb """ from __future__ import annotations import argparse import json import os import random import re import textwrap import warnings warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", message=".*max_new_tokens.*max_length.*") warnings.filterwarnings("ignore", category=UserWarning, message=".*urllib3.*") os.environ["TOKENIZERS_PARALLELISM"] = "false" from typing import Any, Callable, Dict, List, Optional import matplotlib matplotlib.use("Agg") # headless — works without a display import matplotlib.pyplot as plt import numpy as np try: from dotenv import load_dotenv load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env"), override=False) except ImportError: pass from models import NetworkAction from server.vir_env_environment import NetworkEnvironment TASKS = ["easy", "medium", "hard"] GRAPH_DIR = "eval_graphs" # --------------------------------------------------------------------------- # Agent policies # --------------------------------------------------------------------------- def random_agent(obs: Any) -> NetworkAction: from server.vir_env_environment import NETWORK_TOPOLOGY action_type = random.choice(["scan_network", "isolate_node", "deploy_patch"]) target = random.choice(list(NETWORK_TOPOLOGY.keys())) if action_type != "scan_network" else None return NetworkAction(action_type=action_type, target=target, reasoning="random") def greedy_agent(obs: Any) -> NetworkAction: infected = [n for n, info in obs.network_state.items() if isinstance(info, dict) and info.get("status") == "infected"] if "DB_Primary" in infected: return NetworkAction(action_type="deploy_patch", target="DB_Primary", reasoning="patch DB_Primary immediately") if "Auth" in infected: return NetworkAction(action_type="isolate_node", target="Auth", reasoning="isolate Auth — doubles spread") non_db = [n for n in infected if n not in {"DB_Primary", "DB_Replica"}] if non_db: return NetworkAction(action_type="isolate_node", target=non_db[0], reasoning=f"isolate {non_db[0]}") if "DB_Replica" in infected: return NetworkAction(action_type="isolate_node", target="DB_Replica", reasoning="isolate DB_Replica") isolated = [n for n, info in obs.network_state.items() if isinstance(info, dict) and info.get("status") == "isolated"] if isolated: return NetworkAction(action_type="deploy_patch", target=isolated[0], reasoning="restore isolated node") return NetworkAction(action_type="scan_network", reasoning="all clear") def build_lora_agent(lora_dir: str) -> Callable: """Load LoRA weights and return an agent function. Requires GPU + unsloth.""" from unsloth import FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name=lora_dir, max_seq_length=1024, load_in_4bit=True, dtype=None, ) FastLanguageModel.for_inference(model) # Clear conflicting generation config defaults model.generation_config.max_length = None system_prompt = textwrap.dedent(""" You are an AI SOC analyst defending a 20-node enterprise network. Return only JSON: {"action_type": "scan_network|isolate_node|deploy_patch", "target": "NodeName or null", "reasoning": "one sentence"} Policy: only act on INFECTED nodes. DB_Primary → deploy_patch. Auth → isolate immediately. """).strip() def _obs_to_text(obs: Any) -> str: lines = [f"step={obs.step}/{obs.max_steps}", f"infected={obs.infected_count} isolated={obs.isolated_count}", f"db_infected_steps={getattr(obs,'db_infected_steps',0)}", f"auth_compromised={getattr(obs,'auth_compromised',False)}", ""] for node, info in obs.network_state.items(): st = info.get("status", "?") cn = ", ".join(info.get("connections", [])) or "none" lines.append(f"{node}: {st} links=[{cn}]") return "\n".join(lines) def lora_agent(obs: Any) -> NetworkAction: prompt = (f"<|im_start|>system\n{system_prompt}<|im_end|>\n" f"<|im_start|>user\n{_obs_to_text(obs)}\n<|im_end|>\n" f"<|im_start|>assistant\n") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with __import__("torch").no_grad(): out = model.generate(**inputs, max_new_tokens=128, temperature=0.3, do_sample=True, pad_token_id=tokenizer.eos_token_id) raw = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) m = re.search(r"\{.*\}", raw, re.DOTALL) if not m: return NetworkAction(action_type="scan_network", reasoning="parse error") try: p = json.loads(m.group()) return NetworkAction( action_type=p.get("action_type", "scan_network"), target=p.get("target") or None, reasoning=p.get("reasoning", ""), ) except Exception: return NetworkAction(action_type="scan_network", reasoning="json error") return lora_agent # --------------------------------------------------------------------------- # Episode runner # --------------------------------------------------------------------------- def run_episode(agent_fn: Callable, task: str) -> Dict: env = NetworkEnvironment() obs = env.reset(task=task) total_reward = 0.0 rewards_per_step = [] infected_per_step = [obs.infected_count] while not obs.done: action = agent_fn(obs) obs = env.step(action) total_reward += obs.reward rewards_per_step.append(obs.reward) infected_per_step.append(obs.infected_count) return { "won": obs.infected_count == 0 and not getattr(obs, "data_breach", False), "steps": obs.step, "total_reward": total_reward, "final_score": obs.cumulative_score, "rewards_per_step": rewards_per_step, "infected_curve": infected_per_step, } def run_agent(agent_fn: Callable, agent_name: str, n_episodes: int) -> Dict[str, List]: results: Dict[str, Dict[str, List]] = {t: {"won": [], "steps": [], "score": [], "infected_curve": []} for t in TASKS} for task in TASKS: print(f" [{agent_name}] {task} — ", end="", flush=True) for ep in range(n_episodes): r = run_episode(agent_fn, task) results[task]["won"].append(float(r["won"])) results[task]["steps"].append(r["steps"]) results[task]["score"].append(r["final_score"]) results[task]["infected_curve"].append(r["infected_curve"]) print("W" if r["won"] else ".", end="", flush=True) print() return results # --------------------------------------------------------------------------- # Graph generation # --------------------------------------------------------------------------- COLORS = {"Random (untrained)": "#e74c3c", "Greedy": "#2ecc71", "LoRA (trained)": "#3498db"} def _bar_chart(ax, metric_by_agent: Dict[str, Dict[str, float]], ylabel: str, title: str): agents = list(metric_by_agent.keys()) x = np.arange(len(TASKS)) width = 0.8 / len(agents) for i, agent in enumerate(agents): vals = [metric_by_agent[agent].get(t, 0) for t in TASKS] offset = (i - len(agents) / 2 + 0.5) * width bars = ax.bar(x + offset, vals, width, label=agent, color=COLORS.get(agent, f"C{i}"), alpha=0.85, edgecolor="white") for bar, v in zip(bars, vals): ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f"{v:.2f}", ha="center", va="bottom", fontsize=8) ax.set_xticks(x) ax.set_xticklabels([t.upper() for t in TASKS]) ax.set_ylabel(ylabel) ax.set_title(title) ax.legend(fontsize=8) ax.grid(axis="y", alpha=0.3) def _infected_curve(ax, results_by_agent: Dict[str, Dict], task: str): for agent, results in results_by_agent.items(): curves = results[task]["infected_curve"] max_len = max(len(c) for c in curves) padded = [c + [c[-1]] * (max_len - len(c)) for c in curves] arr = np.array(padded, dtype=float) mean = arr.mean(axis=0) std = arr.std(axis=0) xs = range(len(mean)) ax.plot(xs, mean, label=agent, color=COLORS.get(agent, None), linewidth=2) ax.fill_between(xs, mean - std, mean + std, color=COLORS.get(agent, None), alpha=0.15) ax.set_xlabel("Step") ax.set_ylabel("Infected nodes") ax.set_title(f"Infected over time — {task.upper()}") ax.legend(fontsize=8) ax.grid(alpha=0.3) def generate_graphs(results_by_agent: Dict[str, Dict], out_dir: str): os.makedirs(out_dir, exist_ok=True) plt.style.use("dark_background") # --- Figure 1: Win Rate, Score, Steps (3 bar charts) --- fig, axes = plt.subplots(1, 3, figsize=(15, 5)) fig.suptitle("Agent Comparison — AI-SOAR Network Defense", fontsize=14, fontweight="bold") win_rate = {a: {t: np.mean(r[t]["won"]) for t in TASKS} for a, r in results_by_agent.items()} avg_score = {a: {t: np.mean(r[t]["score"]) for t in TASKS} for a, r in results_by_agent.items()} avg_steps = {a: {t: np.mean(r[t]["steps"]) for t in TASKS} for a, r in results_by_agent.items()} _bar_chart(axes[0], win_rate, "Win Rate", "Win Rate by Difficulty") _bar_chart(axes[1], avg_score, "Avg Final Score", "Score by Difficulty") _bar_chart(axes[2], avg_steps, "Avg Steps", "Steps to Done by Difficulty") plt.tight_layout() path1 = os.path.join(out_dir, "comparison_bars.png") fig.savefig(path1, dpi=150, bbox_inches="tight") print(f" Saved: {path1}") plt.close(fig) # --- Figure 2: Infected-over-time curves per difficulty --- fig, axes = plt.subplots(1, 3, figsize=(15, 4)) fig.suptitle("Infected Nodes Over Time (mean ± std)", fontsize=13, fontweight="bold") for ax, task in zip(axes, TASKS): _infected_curve(ax, results_by_agent, task) plt.tight_layout() path2 = os.path.join(out_dir, "infected_curves.png") fig.savefig(path2, dpi=150, bbox_inches="tight") print(f" Saved: {path2}") plt.close(fig) # --- Figure 3: Score distribution violin plot --- fig, axes = plt.subplots(1, 3, figsize=(15, 5)) fig.suptitle("Score Distribution per Difficulty", fontsize=13, fontweight="bold") for ax, task in zip(axes, TASKS): data = [results_by_agent[a][task]["score"] for a in results_by_agent] labels = list(results_by_agent.keys()) parts = ax.violinplot(data, showmedians=True, showextrema=True) for pc, agent in zip(parts["bodies"], labels): pc.set_facecolor(COLORS.get(agent, "grey")) pc.set_alpha(0.7) ax.set_xticks(range(1, len(labels) + 1)) ax.set_xticklabels([a.split()[0] for a in labels], fontsize=8) ax.set_ylabel("Final Score") ax.set_title(f"{task.upper()}") ax.grid(axis="y", alpha=0.3) plt.tight_layout() path3 = os.path.join(out_dir, "score_distribution.png") fig.savefig(path3, dpi=150, bbox_inches="tight") print(f" Saved: {path3}") plt.close(fig) return [path1, path2, path3] # --------------------------------------------------------------------------- # WandB logging # --------------------------------------------------------------------------- def log_to_wandb(results_by_agent: Dict[str, Dict], graph_paths: List[str]): try: import wandb if not callable(getattr(wandb, "init", None)): return except ImportError: return wandb.init(project="vir_env", name="eval-comparison", reinit=True) for agent, results in results_by_agent.items(): for task in TASKS: wandb.log({ f"{agent}/{task}/win_rate": np.mean(results[task]["won"]), f"{agent}/{task}/avg_score": np.mean(results[task]["score"]), f"{agent}/{task}/avg_steps": np.mean(results[task]["steps"]), }) wandb.log({"graphs": [wandb.Image(p) for p in graph_paths]}) wandb.finish() print(" Logged to wandb.") # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--episodes", type=int, default=20, help="Episodes per agent per task") p.add_argument("--lora-dir", type=str, default=None, help="Path to trained LoRA dir") p.add_argument("--out-dir", type=str, default=GRAPH_DIR) p.add_argument("--wandb", action="store_true", help="Log results to wandb") p.add_argument("--seed", type=int, default=42) return p.parse_args() def main(): args = parse_args() random.seed(args.seed) np.random.seed(args.seed) agents: Dict[str, Callable] = { "Random (untrained)": random_agent, "Greedy": greedy_agent, } if args.lora_dir: print(f"Loading LoRA model from {args.lora_dir} ...") try: agents["LoRA (trained)"] = build_lora_agent(args.lora_dir) print(" LoRA model loaded.") except Exception as e: print(f" Could not load LoRA model: {e}\n Skipping trained agent.") print(f"\nRunning {args.episodes} episodes per agent per task ...\n") results_by_agent: Dict[str, Dict] = {} for name, fn in agents.items(): print(f"Agent: {name}") results_by_agent[name] = run_agent(fn, name, args.episodes) print("\nGenerating graphs ...") graph_paths = generate_graphs(results_by_agent, args.out_dir) if args.wandb: print("Logging to wandb ...") log_to_wandb(results_by_agent, graph_paths) # Print summary table print("\n" + "=" * 60) print(f"{'Agent':<22} {'Task':<8} {'Win%':>6} {'Score':>8} {'Steps':>7}") print("-" * 60) for agent, results in results_by_agent.items(): for task in TASKS: wr = np.mean(results[task]["won"]) * 100 sc = np.mean(results[task]["score"]) st = np.mean(results[task]["steps"]) print(f"{agent:<22} {task:<8} {wr:>5.1f}% {sc:>8.2f} {st:>7.1f}") print("=" * 60) print(f"\nGraphs saved to: {os.path.abspath(args.out_dir)}/") if __name__ == "__main__": main()