Spaces:
Sleeping
Sleeping
| """ | |
| Evaluation & Comparison Script — AI-SOAR Network Defense | |
| ========================================================= | |
| Runs three agents across all difficulty levels and produces comparison graphs: | |
| 1. Random agent (untrained baseline) | |
| 2. Greedy agent (hardcoded optimal policy) | |
| 3. LoRA agent (your GRPO-trained model — requires GPU + unsloth) | |
| Usage: | |
| # Without trained model (random vs greedy only): | |
| python eval.py --episodes 20 | |
| # With trained LoRA model: | |
| python eval.py --episodes 20 --lora-dir grpo_ai_soar_qwen_demo40 | |
| # Also log metrics and graphs to wandb (graphs are always saved to --out-dir): | |
| python eval.py --episodes 20 --lora-dir grpo_ai_soar_qwen_demo40 --wandb | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import re | |
| import textwrap | |
| import warnings | |
| warnings.filterwarnings("ignore", category=FutureWarning) | |
| warnings.filterwarnings("ignore", message=".*max_new_tokens.*max_length.*") | |
| warnings.filterwarnings("ignore", category=UserWarning, message=".*urllib3.*") | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| from typing import Any, Callable, Dict, List, Optional | |
| import matplotlib | |
| matplotlib.use("Agg") # headless — works without a display | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env"), override=False) | |
| except ImportError: | |
| pass | |
| from models import NetworkAction | |
| from server.vir_env_environment import NetworkEnvironment | |
# Difficulty levels passed to ``env.reset(task=...)``; this order is also the
# x-axis / row order of every chart and of the summary table.
TASKS = ["easy", "medium", "hard"]

# Default directory for the generated PNG figures (overridable via ``--out-dir``).
GRAPH_DIR = "eval_graphs"
| # --------------------------------------------------------------------------- | |
| # Agent policies | |
| # --------------------------------------------------------------------------- | |
def random_agent(obs: Any) -> NetworkAction:
    """Untrained baseline: pick a uniformly random action type and target.

    ``obs`` is ignored. ``scan_network`` takes no target; the other two action
    types get a uniformly random node from ``NETWORK_TOPOLOGY``. Randomness comes
    from the global ``random`` module, so ``random.seed`` makes it repeatable.
    """
    from server.vir_env_environment import NETWORK_TOPOLOGY

    kind = random.choice(["scan_network", "isolate_node", "deploy_patch"])
    if kind == "scan_network":
        node = None
    else:
        # Drawn only for targeted actions, so the RNG stream is consumed exactly
        # as often as before.
        node = random.choice(list(NETWORK_TOPOLOGY.keys()))
    return NetworkAction(action_type=kind, target=node, reasoning="random")
def greedy_agent(obs: Any) -> NetworkAction:
    """Hard-coded priority policy used as the "optimal" baseline.

    Rules are checked in order and the first match wins:
      1. DB_Primary infected          -> deploy_patch DB_Primary
      2. Auth infected                -> isolate Auth
      3. another infected node        -> isolate the first one (DB nodes excluded)
      4. DB_Replica infected          -> isolate DB_Replica
      5. some node isolated           -> deploy_patch the first one to restore it
      6. nothing to do                -> scan_network

    "First" follows the iteration order of ``obs.network_state``. Entries whose
    value is not a dict are ignored.
    """
    state = obs.network_state

    def nodes_in(status: str) -> List[str]:
        return [name for name, info in state.items()
                if isinstance(info, dict) and info.get("status") == status]

    infected = nodes_in("infected")

    if "DB_Primary" in infected:
        return NetworkAction(action_type="deploy_patch", target="DB_Primary",
                             reasoning="patch DB_Primary immediately")
    if "Auth" in infected:
        return NetworkAction(action_type="isolate_node", target="Auth",
                             reasoning="isolate Auth — doubles spread")

    db_nodes = {"DB_Primary", "DB_Replica"}
    first_other = next((name for name in infected if name not in db_nodes), None)
    if first_other is not None:
        return NetworkAction(action_type="isolate_node", target=first_other,
                             reasoning=f"isolate {first_other}")

    if "DB_Replica" in infected:
        return NetworkAction(action_type="isolate_node", target="DB_Replica",
                             reasoning="isolate DB_Replica")

    # Only looked up once no infected node needs attention.
    isolated = nodes_in("isolated")
    if isolated:
        return NetworkAction(action_type="deploy_patch", target=isolated[0],
                             reasoning="restore isolated node")

    return NetworkAction(action_type="scan_network", reasoning="all clear")
| def build_lora_agent(lora_dir: str) -> Callable: | |
| """Load LoRA weights and return an agent function. Requires GPU + unsloth.""" | |
| from unsloth import FastLanguageModel | |
| model, tokenizer = FastLanguageModel.from_pretrained( | |
| model_name=lora_dir, | |
| max_seq_length=1024, | |
| load_in_4bit=True, | |
| dtype=None, | |
| ) | |
| FastLanguageModel.for_inference(model) | |
| # Clear conflicting generation config defaults | |
| model.generation_config.max_length = None | |
| system_prompt = textwrap.dedent(""" | |
| You are an AI SOC analyst defending a 20-node enterprise network. | |
| Return only JSON: {"action_type": "scan_network|isolate_node|deploy_patch", | |
| "target": "NodeName or null", "reasoning": "one sentence"} | |
| Policy: only act on INFECTED nodes. DB_Primary → deploy_patch. Auth → isolate immediately. | |
| """).strip() | |
| def _obs_to_text(obs: Any) -> str: | |
| lines = [f"step={obs.step}/{obs.max_steps}", | |
| f"infected={obs.infected_count} isolated={obs.isolated_count}", | |
| f"db_infected_steps={getattr(obs,'db_infected_steps',0)}", | |
| f"auth_compromised={getattr(obs,'auth_compromised',False)}", ""] | |
| for node, info in obs.network_state.items(): | |
| st = info.get("status", "?") | |
| cn = ", ".join(info.get("connections", [])) or "none" | |
| lines.append(f"{node}: {st} links=[{cn}]") | |
| return "\n".join(lines) | |
| def lora_agent(obs: Any) -> NetworkAction: | |
| prompt = (f"<|im_start|>system\n{system_prompt}<|im_end|>\n" | |
| f"<|im_start|>user\n{_obs_to_text(obs)}\n<|im_end|>\n" | |
| f"<|im_start|>assistant\n") | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| with __import__("torch").no_grad(): | |
| out = model.generate(**inputs, max_new_tokens=128, temperature=0.3, | |
| do_sample=True, pad_token_id=tokenizer.eos_token_id) | |
| raw = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) | |
| m = re.search(r"\{.*\}", raw, re.DOTALL) | |
| if not m: | |
| return NetworkAction(action_type="scan_network", reasoning="parse error") | |
| try: | |
| p = json.loads(m.group()) | |
| return NetworkAction( | |
| action_type=p.get("action_type", "scan_network"), | |
| target=p.get("target") or None, | |
| reasoning=p.get("reasoning", ""), | |
| ) | |
| except Exception: | |
| return NetworkAction(action_type="scan_network", reasoning="json error") | |
| return lora_agent | |
| # --------------------------------------------------------------------------- | |
| # Episode runner | |
| # --------------------------------------------------------------------------- | |
def run_episode(agent_fn: Callable, task: str) -> Dict:
    """Play a single episode of ``task`` with ``agent_fn`` on a fresh environment.

    Returns a dict with:
      * ``won`` – no infected nodes on the final observation and no truthy
        ``data_breach`` attribute on it;
      * ``steps`` – the final ``obs.step``;
      * ``total_reward`` / ``rewards_per_step`` – sum and list of step rewards;
      * ``final_score`` – the env's ``cumulative_score`` at the end;
      * ``infected_curve`` – infected count after reset, then one entry per step.
    """
    env = NetworkEnvironment()
    obs = env.reset(task=task)

    step_rewards = []
    infected_curve = [obs.infected_count]
    # Accumulate incrementally (not via sum()) so float rounding is unchanged.
    running_total = 0.0

    while not obs.done:
        obs = env.step(agent_fn(obs))
        running_total += obs.reward
        step_rewards.append(obs.reward)
        infected_curve.append(obs.infected_count)

    won = obs.infected_count == 0 and not getattr(obs, "data_breach", False)
    return {
        "won": won,
        "steps": obs.step,
        "total_reward": running_total,
        "final_score": obs.cumulative_score,
        "rewards_per_step": step_rewards,
        "infected_curve": infected_curve,
    }
def run_agent(agent_fn: Callable, agent_name: str, n_episodes: int) -> Dict[str, Dict[str, List]]:
    """Run ``n_episodes`` episodes of ``agent_fn`` on every task in ``TASKS``.

    Prints one progress character per episode (``W`` = won, ``.`` = not won).

    Returns:
        ``{task: {"won": [...], "steps": [...], "score": [...],
        "infected_curve": [...]}}`` — ``won`` holds 1.0/0.0 floats, ``score``
        the episode's ``final_score`` and ``infected_curve`` one per-step list
        per episode. (The previous ``Dict[str, List]`` annotation was wrong:
        the values are per-task dicts of lists.)
    """
    results: Dict[str, Dict[str, List]] = {
        task: {"won": [], "steps": [], "score": [], "infected_curve": []}
        for task in TASKS
    }
    for task in TASKS:
        bucket = results[task]
        print(f" [{agent_name}] {task} — ", end="", flush=True)
        for _ in range(n_episodes):
            episode = run_episode(agent_fn, task)
            bucket["won"].append(float(episode["won"]))
            bucket["steps"].append(episode["steps"])
            bucket["score"].append(episode["final_score"])
            bucket["infected_curve"].append(episode["infected_curve"])
            print("W" if episode["won"] else ".", end="", flush=True)
        print()
    return results
| # --------------------------------------------------------------------------- | |
| # Graph generation | |
| # --------------------------------------------------------------------------- | |
# Fixed colour per agent label, so an agent keeps the same colour in every
# figure. Keys match the agent names registered in ``main``.
COLORS = {
    "Random (untrained)": "#e74c3c",  # red
    "Greedy": "#2ecc71",              # green
    "LoRA (trained)": "#3498db",      # blue
}
def _bar_chart(ax, metric_by_agent: Dict[str, Dict[str, float]], ylabel: str, title: str):
    """Draw a grouped bar chart on ``ax``: one group per task, one bar per agent.

    Args:
        ax: matplotlib Axes to draw on.
        metric_by_agent: ``{agent: {task: value}}``. A missing task counts as 0.
        ylabel, title: axis label and subplot title.

    Each bar is labelled with its value to 2 decimals. The label is offset a
    few *points* from the bar end: above positive bars and below negative ones.
    This keeps placement independent of the metric's scale (win rate in [0, 1]
    vs. scores/steps) and correct for negative values. The old fixed ``+0.01``
    data-unit offset put labels for negative bars inside the bar. Non-finite
    values (e.g. NaN means) get no label.
    """
    agents = list(metric_by_agent.keys())
    x = np.arange(len(TASKS))
    # Groups span 0.8 of the unit slot. max() avoids ZeroDivisionError with no agents.
    width = 0.8 / max(len(agents), 1)
    for i, agent in enumerate(agents):
        vals = [metric_by_agent[agent].get(t, 0) for t in TASKS]
        # Centre the group of bars on each tick.
        offset = (i - len(agents) / 2 + 0.5) * width
        bars = ax.bar(x + offset, vals, width, label=agent,
                      color=COLORS.get(agent, f"C{i}"), alpha=0.85, edgecolor="white")
        for bar, v in zip(bars, vals):
            if not np.isfinite(v):
                continue
            above = v >= 0
            ax.annotate(f"{v:.2f}",
                        xy=(bar.get_x() + bar.get_width() / 2, v),
                        xytext=(0, 2 if above else -2), textcoords="offset points",
                        ha="center", va="bottom" if above else "top", fontsize=8)
    ax.set_xticks(x)
    ax.set_xticklabels([t.upper() for t in TASKS])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(fontsize=8)
    ax.grid(axis="y", alpha=0.3)
def _infected_curve(ax, results_by_agent: Dict[str, Dict], task: str):
    """Plot mean ± std of the infected-node count per step for each agent on ``task``.

    Episodes end after different numbers of steps, so each shorter curve is
    padded with its final value before averaging: a finished episode keeps its
    last infected count. Agents with no recorded (non-empty) curves for
    ``task`` are skipped. Previously ``max()`` raised ``ValueError`` on an
    empty sequence.
    """
    for agent, results in results_by_agent.items():
        curves = [c for c in results[task]["infected_curve"] if c]
        if not curves:
            continue
        max_len = max(len(c) for c in curves)
        padded = [c + [c[-1]] * (max_len - len(c)) for c in curves]
        arr = np.array(padded, dtype=float)
        mean = arr.mean(axis=0)
        std = arr.std(axis=0)
        xs = np.arange(len(mean))
        color = COLORS.get(agent, None)
        ax.plot(xs, mean, label=agent, color=color, linewidth=2)
        ax.fill_between(xs, mean - std, mean + std, color=color, alpha=0.15)
    ax.set_xlabel("Step")
    ax.set_ylabel("Infected nodes")
    ax.set_title(f"Infected over time — {task.upper()}")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
def _save_figure(fig, path: str) -> str:
    """Save ``fig`` to ``path`` (150 dpi, tight bbox), report it, close it, return ``path``."""
    try:
        fig.savefig(path, dpi=150, bbox_inches="tight")
        print(f" Saved: {path}")
    finally:
        plt.close(fig)  # release the figure even if saving fails
    return path


def generate_graphs(results_by_agent: Dict[str, Dict], out_dir: str):
    """Write the three comparison figures to ``out_dir`` and return their paths.

    Figures (one subplot per task in ``TASKS``):
      1. ``comparison_bars.png`` – mean win rate, final score and steps.
      2. ``infected_curves.png`` – infected nodes over time (mean ± std).
      3. ``score_distribution.png`` – per-agent violin plots of final score.

    ``results_by_agent`` maps agent name to the per-task dict built by
    ``run_agent``. The dark theme is applied with ``plt.style.context`` so it
    is scoped to this call. ``plt.style.use`` changed the global rcParams for
    the rest of the process.
    """
    os.makedirs(out_dir, exist_ok=True)
    with plt.style.context("dark_background"):
        # --- Figure 1: Win Rate, Score, Steps (3 bar charts) ---
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle("Agent Comparison — AI-SOAR Network Defense", fontsize=14, fontweight="bold")
        win_rate = {a: {t: np.mean(r[t]["won"]) for t in TASKS} for a, r in results_by_agent.items()}
        avg_score = {a: {t: np.mean(r[t]["score"]) for t in TASKS} for a, r in results_by_agent.items()}
        avg_steps = {a: {t: np.mean(r[t]["steps"]) for t in TASKS} for a, r in results_by_agent.items()}
        _bar_chart(axes[0], win_rate, "Win Rate", "Win Rate by Difficulty")
        _bar_chart(axes[1], avg_score, "Avg Final Score", "Score by Difficulty")
        _bar_chart(axes[2], avg_steps, "Avg Steps", "Steps to Done by Difficulty")
        plt.tight_layout()
        path1 = _save_figure(fig, os.path.join(out_dir, "comparison_bars.png"))

        # --- Figure 2: Infected-over-time curves per difficulty ---
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        fig.suptitle("Infected Nodes Over Time (mean ± std)", fontsize=13, fontweight="bold")
        for ax, task in zip(axes, TASKS):
            _infected_curve(ax, results_by_agent, task)
        plt.tight_layout()
        path2 = _save_figure(fig, os.path.join(out_dir, "infected_curves.png"))

        # --- Figure 3: Score distribution violin plot ---
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle("Score Distribution per Difficulty", fontsize=13, fontweight="bold")
        labels = list(results_by_agent.keys())  # task-independent, computed once
        # Tick labels use only the first word ("Random", "Greedy", "LoRA").
        short_labels = [a.split()[0] for a in labels]
        for ax, task in zip(axes, TASKS):
            data = [results_by_agent[a][task]["score"] for a in labels]
            parts = ax.violinplot(data, showmedians=True, showextrema=True)
            for pc, agent in zip(parts["bodies"], labels):
                pc.set_facecolor(COLORS.get(agent, "grey"))
                pc.set_alpha(0.7)
            # violinplot places datasets at x = 1..N.
            ax.set_xticks(range(1, len(labels) + 1))
            ax.set_xticklabels(short_labels, fontsize=8)
            ax.set_ylabel("Final Score")
            ax.set_title(f"{task.upper()}")
            ax.grid(axis="y", alpha=0.3)
        plt.tight_layout()
        path3 = _save_figure(fig, os.path.join(out_dir, "score_distribution.png"))

    return [path1, path2, path3]
| # --------------------------------------------------------------------------- | |
| # WandB logging | |
| # --------------------------------------------------------------------------- | |
def log_to_wandb(results_by_agent: Dict[str, Dict], graph_paths: List[str]):
    """Log per-agent/per-task summary metrics and the saved graph images to wandb.

    Returns silently, without logging, when ``wandb`` cannot be imported or the
    imported ``wandb`` has no callable ``init``. Presumably this guards against
    a local ``wandb/`` directory shadowing the package; TODO confirm.

    Metric keys are ``"<agent>/<task>/{win_rate,avg_score,avg_steps}"``. The
    images go under ``"graphs"``. Everything is sent in ONE ``wandb.log``
    call, so all values share one step. Per-metric calls each advanced the
    step and scattered the comparison over unrelated x positions.
    ``wandb.finish()`` runs even if logging fails, so no run is left open.
    """
    try:
        import wandb
    except ImportError:
        return
    if not callable(getattr(wandb, "init", None)):
        return

    # Compute metrics before opening a run, so bad input can't leave a run dangling.
    payload: Dict[str, Any] = {}
    for agent, results in results_by_agent.items():
        for task in TASKS:
            prefix = f"{agent}/{task}"
            payload[f"{prefix}/win_rate"] = float(np.mean(results[task]["won"]))
            payload[f"{prefix}/avg_score"] = float(np.mean(results[task]["score"]))
            payload[f"{prefix}/avg_steps"] = float(np.mean(results[task]["steps"]))

    wandb.init(project="vir_env", name="eval-comparison", reinit=True)
    try:
        payload["graphs"] = [wandb.Image(p) for p in graph_paths]
        wandb.log(payload)
    finally:
        wandb.finish()
    print(" Logged to wandb.")
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the evaluation run.

    Args:
        argv: argument list to parse. ``None`` (the default) reads ``sys.argv``
            as before; passing a list makes the parser usable programmatically.

    ``--episodes`` must be at least 1. With 0 episodes the per-task result
    lists stay empty, and graph generation crashed (``max()`` of an empty
    sequence, empty violin data). The parser now rejects such values up front
    via ``parser.error`` (exit status 2).
    """
    p = argparse.ArgumentParser(
        description="Compare random, greedy and (optionally) LoRA agents on the AI-SOAR network environment.")
    p.add_argument("--episodes", type=int, default=20, help="Episodes per agent per task")
    p.add_argument("--lora-dir", type=str, default=None, help="Path to trained LoRA dir")
    p.add_argument("--out-dir", type=str, default=GRAPH_DIR, help="Directory for the generated PNG graphs")
    p.add_argument("--wandb", action="store_true", help="Log results to wandb")
    p.add_argument("--seed", type=int, default=42, help="Seed for random, numpy and (if loaded) torch")
    args = p.parse_args(argv)
    if args.episodes < 1:
        p.error(f"--episodes must be >= 1 (got {args.episodes})")
    return args
def _print_summary(results_by_agent: Dict[str, Dict]) -> None:
    """Print a fixed-width table of win %, mean score and mean steps per agent and task."""
    print("\n" + "=" * 60)
    print(f"{'Agent':<22} {'Task':<8} {'Win%':>6} {'Score':>8} {'Steps':>7}")
    print("-" * 60)
    for agent, results in results_by_agent.items():
        for task in TASKS:
            wr = np.mean(results[task]["won"]) * 100
            sc = np.mean(results[task]["score"])
            st = np.mean(results[task]["steps"])
            print(f"{agent:<22} {task:<8} {wr:>5.1f}% {sc:>8.2f} {st:>7.1f}")
    print("=" * 60)


def main():
    """Entry point: evaluate every agent, write graphs, optionally log to wandb, print a summary.

    ``--seed`` seeds ``random`` and ``numpy``. If the LoRA agent loads, it also
    seeds torch, because that agent samples with ``do_sample=True``; without a
    torch seed, ``--seed`` did not make its results repeatable. GPU kernels may
    still add nondeterminism.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    agents: Dict[str, Callable] = {
        "Random (untrained)": random_agent,
        "Greedy": greedy_agent,
    }
    if args.lora_dir:
        print(f"Loading LoRA model from {args.lora_dir} ...")
        try:
            agents["LoRA (trained)"] = build_lora_agent(args.lora_dir)
            print(" LoRA model loaded.")
        except Exception as e:
            # Best-effort: the baseline comparison still runs without the model.
            print(f" Could not load LoRA model: {e}\n Skipping trained agent.")
        else:
            try:
                import torch
            except ImportError:
                pass
            else:
                torch.manual_seed(args.seed)

    print(f"\nRunning {args.episodes} episodes per agent per task ...\n")
    results_by_agent: Dict[str, Dict] = {}
    for name, fn in agents.items():
        print(f"Agent: {name}")
        results_by_agent[name] = run_agent(fn, name, args.episodes)

    print("\nGenerating graphs ...")
    graph_paths = generate_graphs(results_by_agent, args.out_dir)

    if args.wandb:
        print("Logging to wandb ...")
        log_to_wandb(results_by_agent, graph_paths)

    _print_summary(results_by_agent)
    print(f"\nGraphs saved to: {os.path.abspath(args.out_dir)}/")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()