vir_env / eval.py
arun-misra's picture
Upload folder using huggingface_hub
94496c6 verified
"""
Evaluation & Comparison Script — AI-SOAR Network Defense
=========================================================
Runs three agents across all difficulty levels and produces comparison graphs:
1. Random agent (untrained baseline)
2. Greedy agent (hardcoded optimal policy)
3. LoRA agent (your GRPO-trained model — requires GPU + unsloth)
Usage:
# Without trained model (random vs greedy only):
python eval.py --episodes 20
# With trained LoRA model:
python eval.py --episodes 20 --lora-dir grpo_ai_soar_qwen_demo40
# Save graphs and log to wandb:
python eval.py --episodes 20 --lora-dir grpo_ai_soar_qwen_demo40 --wandb
"""
from __future__ import annotations
import argparse
import json
import os
import random
import re
import textwrap
import warnings
# Silence warning categories/messages that would otherwise be printed during a run
# (presumably emitted by the model-loading stack — confirm which library raises each one).
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*max_new_tokens.*max_length.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*urllib3.*")
# Set before any tokenizer code is imported; presumably disables tokenizer
# worker parallelism — TODO confirm against the tokenizers library docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import Any, Callable, Dict, List, Optional
import matplotlib
matplotlib.use("Agg") # headless — works without a display
import matplotlib.pyplot as plt
import numpy as np
try:
from dotenv import load_dotenv
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env"), override=False)
except ImportError:
pass
from models import NetworkAction
from server.vir_env_environment import NetworkEnvironment
# Difficulty levels evaluated for every agent; also the order used for plots and the summary table.
TASKS = ["easy", "medium", "hard"]
# Default output directory for the generated graphs (the --out-dir default).
GRAPH_DIR = "eval_graphs"
# ---------------------------------------------------------------------------
# Agent policies
# ---------------------------------------------------------------------------
def random_agent(obs: Any) -> NetworkAction:
    """Untrained baseline: choose an action type, and a target if needed, uniformly at random.

    The observation is ignored. ``scan_network`` is issued without a target;
    the other two action types get a node name drawn from ``NETWORK_TOPOLOGY``.
    """
    # Imported here rather than at module level, as in the original code.
    from server.vir_env_environment import NETWORK_TOPOLOGY

    chosen_kind = random.choice(["scan_network", "isolate_node", "deploy_patch"])
    if chosen_kind == "scan_network":
        chosen_node = None
    else:
        chosen_node = random.choice(list(NETWORK_TOPOLOGY.keys()))
    return NetworkAction(action_type=chosen_kind, target=chosen_node, reasoning="random")
def greedy_agent(obs: Any) -> NetworkAction:
    """Hand-written priority policy driven only by the per-node ``status`` field.

    Rules are tried in order and the first match wins:
      1. ``DB_Primary`` infected      -> deploy_patch on it
      2. ``Auth`` infected            -> isolate it
      3. any other non-DB infected    -> isolate the first one found
      4. ``DB_Replica`` infected      -> isolate it
      5. any isolated node            -> deploy_patch on the first one found
      6. otherwise                    -> scan_network

    Entries of ``obs.network_state`` whose value is not a dict are ignored.
    """

    def nodes_with_status(wanted: str) -> List[str]:
        # Preserves the iteration order of obs.network_state.
        matches = []
        for name, details in obs.network_state.items():
            if isinstance(details, dict) and details.get("status") == wanted:
                matches.append(name)
        return matches

    compromised = nodes_with_status("infected")

    if "DB_Primary" in compromised:
        return NetworkAction(
            action_type="deploy_patch",
            target="DB_Primary",
            reasoning="patch DB_Primary immediately",
        )

    if "Auth" in compromised:
        return NetworkAction(
            action_type="isolate_node",
            target="Auth",
            reasoning="isolate Auth — doubles spread",
        )

    # Database nodes are handled by their own rules, so skip them here.
    for name in compromised:
        if name not in ("DB_Primary", "DB_Replica"):
            return NetworkAction(
                action_type="isolate_node",
                target=name,
                reasoning=f"isolate {name}",
            )

    if "DB_Replica" in compromised:
        return NetworkAction(
            action_type="isolate_node",
            target="DB_Replica",
            reasoning="isolate DB_Replica",
        )

    # Nothing infected is left to handle: bring isolated nodes back one at a time.
    quarantined = nodes_with_status("isolated")
    if quarantined:
        return NetworkAction(
            action_type="deploy_patch",
            target=quarantined[0],
            reasoning="restore isolated node",
        )

    return NetworkAction(action_type="scan_network", reasoning="all clear")
def build_lora_agent(lora_dir: str) -> Callable:
    """Load LoRA weights and return an agent function. Requires GPU + unsloth.

    Args:
        lora_dir: Path (or model name) handed to
            ``FastLanguageModel.from_pretrained``.

    Returns:
        A callable ``lora_agent(obs) -> NetworkAction``. It never raises on
        malformed model output: any reply that cannot be turned into a
        ``NetworkAction`` is replaced by a ``scan_network`` action whose
        ``reasoning`` is ``"parse error"`` or ``"json error"``.

    Raises:
        ImportError: If ``unsloth`` or ``torch`` is not installed.
    """
    # Heavy optional dependencies are imported lazily so the random/greedy
    # comparison still works on machines without them.
    from unsloth import FastLanguageModel
    import torch

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=lora_dir,
        max_seq_length=1024,
        load_in_4bit=True,
        dtype=None,
    )
    FastLanguageModel.for_inference(model)
    # Clear conflicting generation config defaults
    model.generation_config.max_length = None

    system_prompt = textwrap.dedent("""
        You are an AI SOC analyst defending a 20-node enterprise network.
        Return only JSON: {"action_type": "scan_network|isolate_node|deploy_patch",
        "target": "NodeName or null", "reasoning": "one sentence"}
        Policy: only act on INFECTED nodes. DB_Primary → deploy_patch. Auth → isolate immediately.
    """).strip()

    json_decoder = json.JSONDecoder()

    def _obs_to_text(obs: Any) -> str:
        """Render the observation as the plain-text user message."""
        lines = [f"step={obs.step}/{obs.max_steps}",
                 f"infected={obs.infected_count} isolated={obs.isolated_count}",
                 f"db_infected_steps={getattr(obs, 'db_infected_steps', 0)}",
                 f"auth_compromised={getattr(obs, 'auth_compromised', False)}", ""]
        for node, info in obs.network_state.items():
            # Same guard as greedy_agent: ignore entries that are not per-node dicts.
            if not isinstance(info, dict):
                continue
            st = info.get("status", "?")
            # `or []` also covers an explicit null connection list.
            cn = ", ".join(info.get("connections") or []) or "none"
            lines.append(f"{node}: {st} links=[{cn}]")
        return "\n".join(lines)

    def _first_json_object(text: str) -> Optional[Dict[str, Any]]:
        """Return the first decodable JSON object in *text*, or None.

        A greedy ``{.*}`` match spans from the first ``{`` to the last ``}``,
        so trailing text containing a brace (or a second object) would make
        the whole reply unparseable. Decoding from each ``{`` in turn stops at
        the end of the first valid object instead.
        """
        start = text.find("{")
        while start != -1:
            try:
                parsed, _end = json_decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed
            start = text.find("{", start + 1)
        return None

    def lora_agent(obs: Any) -> NetworkAction:
        """Query the model once for *obs* and convert its reply into an action."""
        prompt = (f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
                  f"<|im_start|>user\n{_obs_to_text(obs)}\n<|im_end|>\n"
                  f"<|im_start|>assistant\n")
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=128, temperature=0.3,
                                 do_sample=True, pad_token_id=tokenizer.eos_token_id)
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        raw = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

        # No brace pair at all: keep the original "parse error" fallback.
        if not re.search(r"\{.*\}", raw, re.DOTALL):
            return NetworkAction(action_type="scan_network", reasoning="parse error")

        payload = _first_json_object(raw)
        if payload is None:
            return NetworkAction(action_type="scan_network", reasoning="json error")

        target = payload.get("target") or None
        # The prompt shows the target as a quoted "NodeName or null", so the
        # model may answer with the *string* "null"; treat that as no target.
        if isinstance(target, str) and target.strip().lower() in {"", "null", "none"}:
            target = None

        try:
            return NetworkAction(
                action_type=payload.get("action_type") or "scan_network",
                target=target,
                reasoning=payload.get("reasoning") or "",
            )
        except Exception:
            # NetworkAction's validation error type is not visible from this
            # file, so any construction failure falls back to a harmless scan.
            return NetworkAction(action_type="scan_network", reasoning="json error")

    return lora_agent
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
def run_episode(agent_fn: Callable, task: str) -> Dict:
    """Play one episode of *task* on a fresh environment and summarize it.

    Args:
        agent_fn: Policy called as ``agent_fn(obs)``; must return an action
            accepted by ``NetworkEnvironment.step``.
        task: Difficulty name passed to ``env.reset(task=...)``.

    Returns:
        Dict with keys ``won`` (no infected nodes at the end and no
        ``data_breach`` flag on the final observation), ``steps``,
        ``total_reward``, ``final_score``, ``rewards_per_step`` and
        ``infected_curve`` (infected count at reset and after every step).
    """
    environment = NetworkEnvironment()
    obs = environment.reset(task=task)

    reward_sum = 0.0
    step_rewards = []
    infection_trace = [obs.infected_count]

    # The loop ends only when the environment reports the episode as done.
    while True:
        if obs.done:
            break
        chosen_action = agent_fn(obs)
        obs = environment.step(chosen_action)
        reward_sum += obs.reward
        step_rewards.append(obs.reward)
        infection_trace.append(obs.infected_count)

    network_clean = obs.infected_count == 0
    summary = {
        "won": network_clean and not getattr(obs, "data_breach", False),
        "steps": obs.step,
        "total_reward": reward_sum,
        "final_score": obs.cumulative_score,
        "rewards_per_step": step_rewards,
        "infected_curve": infection_trace,
    }
    return summary
def run_agent(agent_fn: Callable, agent_name: str, n_episodes: int) -> Dict[str, Dict[str, List]]:
    """Run *n_episodes* episodes of every task in ``TASKS`` with one agent.

    Prints one progress line per task: ``W`` for a won episode, ``.`` otherwise.

    Args:
        agent_fn: Policy forwarded to ``run_episode``.
        agent_name: Label used only in the progress output.
        n_episodes: Number of episodes per task.

    Returns:
        ``{task: {"won": [...], "steps": [...], "score": [...],
        "infected_curve": [...]}}`` with one list entry per episode.
        ``won`` is stored as ``0.0``/``1.0`` so it can be averaged directly.
    """
    results: Dict[str, Dict[str, List]] = {
        task: {"won": [], "steps": [], "score": [], "infected_curve": []}
        for task in TASKS
    }
    for task in TASKS:
        per_task = results[task]
        print(f" [{agent_name}] {task} — ", end="", flush=True)
        for _ in range(n_episodes):
            r = run_episode(agent_fn, task)
            per_task["won"].append(float(r["won"]))
            per_task["steps"].append(r["steps"])
            per_task["score"].append(r["final_score"])
            per_task["infected_curve"].append(r["infected_curve"])
            print("W" if r["won"] else ".", end="", flush=True)
        print()
    return results
# ---------------------------------------------------------------------------
# Graph generation
# ---------------------------------------------------------------------------
# Plot color per agent display name; names missing here get a fallback chosen at each call site.
COLORS = {"Random (untrained)": "#e74c3c", "Greedy": "#2ecc71", "LoRA (trained)": "#3498db"}
def _bar_chart(ax, metric_by_agent: Dict[str, Dict[str, float]], ylabel: str, title: str):
    """Draw a grouped bar chart on *ax*: one group per task, one bar per agent.

    Args:
        ax: Matplotlib axes to draw on.
        metric_by_agent: ``{agent: {task: value}}``; tasks missing for an
            agent are drawn as 0.
        ylabel: Y-axis label.
        title: Axes title.
    """
    n_agents = len(metric_by_agent)
    group_centers = np.arange(len(TASKS))
    # Each task group occupies 0.8 of its slot, shared equally between agents.
    bar_width = 0.8 / n_agents

    for idx, (agent, per_task) in enumerate(metric_by_agent.items()):
        heights = [per_task.get(task, 0) for task in TASKS]
        # Shift each agent's bars so the whole group is centered on its tick.
        shift = (idx - n_agents / 2 + 0.5) * bar_width
        rects = ax.bar(
            group_centers + shift,
            heights,
            bar_width,
            label=agent,
            color=COLORS.get(agent, f"C{idx}"),
            alpha=0.85,
            edgecolor="white",
        )
        # Print the numeric value just above every bar.
        for rect, height in zip(rects, heights):
            label_x = rect.get_x() + rect.get_width() / 2
            label_y = rect.get_height() + 0.01
            ax.text(label_x, label_y, f"{height:.2f}",
                    ha="center", va="bottom", fontsize=8)

    ax.set_xticks(group_centers)
    ax.set_xticklabels([task.upper() for task in TASKS])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(fontsize=8)
    ax.grid(axis="y", alpha=0.3)
def _infected_curve(ax, results_by_agent: Dict[str, Dict], task: str):
    """Plot mean ± std of the infected-node count over time for every agent.

    Episodes end at different steps, so shorter curves are extended with
    their final value before averaging.

    Args:
        ax: Matplotlib axes to draw on.
        results_by_agent: ``{agent: {task: {"infected_curve": [[...], ...]}}}``.
        task: Task whose curves are plotted.
    """
    for agent, results in results_by_agent.items():
        curves = results[task]["infected_curve"]
        if not curves:
            # No episodes recorded for this agent/task: nothing to average.
            continue
        max_len = max(len(c) for c in curves)
        padded = [list(c) + [c[-1]] * (max_len - len(c)) for c in curves]
        arr = np.array(padded, dtype=float)
        mean = arr.mean(axis=0)
        std = arr.std(axis=0)
        xs = range(len(mean))
        (line,) = ax.plot(xs, mean, label=agent, color=COLORS.get(agent), linewidth=2)
        # Reuse the line's actual color so the band matches it even when the
        # agent has no COLORS entry and matplotlib picked the color itself.
        ax.fill_between(xs, mean - std, mean + std,
                        color=line.get_color(), alpha=0.15)
    ax.set_xlabel("Step")
    ax.set_ylabel("Infected nodes")
    ax.set_title(f"Infected over time — {task.upper()}")
    # Only build a legend when at least one curve was actually drawn.
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
def _mean_by_task(results_by_agent: Dict[str, Dict], key: str) -> Dict[str, Dict[str, float]]:
    """Average the per-episode list stored under *key* for every agent and task."""
    return {
        agent: {task: np.mean(per_task[task][key]) for task in TASKS}
        for agent, per_task in results_by_agent.items()
    }


def _save_figure(fig, out_dir: str, filename: str) -> str:
    """Lay out, save and close *fig*; return the path it was written to."""
    fig.tight_layout()
    path = os.path.join(out_dir, filename)
    fig.savefig(path, dpi=150, bbox_inches="tight")
    print(f" Saved: {path}")
    plt.close(fig)
    return path


def generate_graphs(results_by_agent: Dict[str, Dict], out_dir: str):
    """Render the three comparison figures and write them to *out_dir* as PNGs.

    Figures: grouped bars (win rate, score, steps), infected-over-time curves,
    and per-task score distributions (violin plots).

    Args:
        results_by_agent: ``{agent: results}`` where ``results`` is the
            structure returned by ``run_agent``.
        out_dir: Output directory; created if it does not exist.

    Returns:
        List of the three saved file paths, in the order listed above.
    """
    os.makedirs(out_dir, exist_ok=True)
    agent_names = list(results_by_agent.keys())

    # Apply the dark style only while these figures are built and saved,
    # instead of changing matplotlib's global style for the whole process.
    with plt.style.context("dark_background"):
        # --- Figure 1: Win Rate, Score, Steps (3 bar charts) ---
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle("Agent Comparison — AI-SOAR Network Defense", fontsize=14, fontweight="bold")
        _bar_chart(axes[0], _mean_by_task(results_by_agent, "won"),
                   "Win Rate", "Win Rate by Difficulty")
        _bar_chart(axes[1], _mean_by_task(results_by_agent, "score"),
                   "Avg Final Score", "Score by Difficulty")
        _bar_chart(axes[2], _mean_by_task(results_by_agent, "steps"),
                   "Avg Steps", "Steps to Done by Difficulty")
        path1 = _save_figure(fig, out_dir, "comparison_bars.png")

        # --- Figure 2: Infected-over-time curves per difficulty ---
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        fig.suptitle("Infected Nodes Over Time (mean ± std)", fontsize=13, fontweight="bold")
        for ax, task in zip(axes, TASKS):
            _infected_curve(ax, results_by_agent, task)
        path2 = _save_figure(fig, out_dir, "infected_curves.png")

        # --- Figure 3: Score distribution violin plot ---
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle("Score Distribution per Difficulty", fontsize=13, fontweight="bold")
        for ax, task in zip(axes, TASKS):
            data = [results_by_agent[a][task]["score"] for a in agent_names]
            parts = ax.violinplot(data, showmedians=True, showextrema=True)
            for body, agent in zip(parts["bodies"], agent_names):
                body.set_facecolor(COLORS.get(agent, "grey"))
                body.set_alpha(0.7)
            # Violins are placed at positions 1..N; label each with the first
            # word of the agent name to keep the ticks short.
            ax.set_xticks(range(1, len(agent_names) + 1))
            ax.set_xticklabels([a.split()[0] for a in agent_names], fontsize=8)
            ax.set_ylabel("Final Score")
            ax.set_title(f"{task.upper()}")
            ax.grid(axis="y", alpha=0.3)
        path3 = _save_figure(fig, out_dir, "score_distribution.png")

    return [path1, path2, path3]
# ---------------------------------------------------------------------------
# WandB logging
# ---------------------------------------------------------------------------
def log_to_wandb(results_by_agent: Dict[str, Dict], graph_paths: List[str]):
    """Log per-agent/per-task summary metrics and the graph images to wandb.

    Logging is best-effort: if wandb cannot be used, a message is printed and
    the function returns without raising.

    Args:
        results_by_agent: ``{agent: results}`` as returned by ``run_agent``.
        graph_paths: Image files to upload under the ``graphs`` key.
    """
    try:
        import wandb
    except ImportError:
        # The user explicitly asked for wandb, so say why nothing was logged.
        print(" wandb is not installed; skipping wandb logging.")
        return
    # NOTE(review): presumably guards against something other than the real
    # package being imported under the name `wandb` — confirm the exact case.
    if not callable(getattr(wandb, "init", None)):
        print(" Imported 'wandb' has no usable init(); skipping wandb logging.")
        return

    wandb.init(project="vir_env", name="eval-comparison", reinit=True)
    try:
        for agent, results in results_by_agent.items():
            for task in TASKS:
                wandb.log({
                    f"{agent}/{task}/win_rate": np.mean(results[task]["won"]),
                    f"{agent}/{task}/avg_score": np.mean(results[task]["score"]),
                    f"{agent}/{task}/avg_steps": np.mean(results[task]["steps"]),
                })
        wandb.log({"graphs": [wandb.Image(p) for p in graph_paths]})
    finally:
        # Close the run even if a log call fails, so it is not left open.
        wandb.finish()
    print(" Logged to wandb.")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _positive_int(value: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {number}")
    return number


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior of a plain
            ``parse_args()`` call.

    Returns:
        Namespace with ``episodes``, ``lora_dir``, ``out_dir``, ``wandb`` and
        ``seed``.
    """
    p = argparse.ArgumentParser()
    # Zero or negative episode counts leave nothing to average or plot, so
    # they are rejected here with a normal argparse usage error.
    p.add_argument("--episodes", type=_positive_int, default=20,
                   help="Episodes per agent per task")
    p.add_argument("--lora-dir", type=str, default=None, help="Path to trained LoRA dir")
    p.add_argument("--out-dir", type=str, default=GRAPH_DIR,
                   help="Directory the graphs are written to")
    p.add_argument("--wandb", action="store_true", help="Log results to wandb")
    p.add_argument("--seed", type=int, default=42, help="Random seed")
    return p.parse_args(argv)
def main():
    """Script entry point: evaluate every agent, draw graphs, print a summary.

    The random and greedy agents always run. The LoRA agent is added only when
    ``--lora-dir`` is given and the model loads; a load failure is reported
    and the evaluation continues without it.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    agents: Dict[str, Callable] = {
        "Random (untrained)": random_agent,
        "Greedy": greedy_agent,
    }
    if args.lora_dir:
        print(f"Loading LoRA model from {args.lora_dir} ...")
        try:
            agents["LoRA (trained)"] = build_lora_agent(args.lora_dir)
            print(" LoRA model loaded.")
        except Exception as e:
            # Deliberately broad: model loading can fail in many ways (missing
            # package, no GPU, bad path) and the baselines should still run.
            print(f" Could not load LoRA model: {e}\n Skipping trained agent.")
        else:
            # The LoRA agent samples its replies (do_sample=True), so --seed
            # must also seed torch for its episodes to be repeatable.
            try:
                import torch
                torch.manual_seed(args.seed)
            except ImportError:
                pass

    print(f"\nRunning {args.episodes} episodes per agent per task ...\n")
    results_by_agent: Dict[str, Dict] = {}
    for name, fn in agents.items():
        print(f"Agent: {name}")
        results_by_agent[name] = run_agent(fn, name, args.episodes)

    print("\nGenerating graphs ...")
    graph_paths = generate_graphs(results_by_agent, args.out_dir)

    if args.wandb:
        print("Logging to wandb ...")
        log_to_wandb(results_by_agent, graph_paths)

    # Print summary table
    rule = "=" * 60
    print("\n" + rule)
    print(f"{'Agent':<22} {'Task':<8} {'Win%':>6} {'Score':>8} {'Steps':>7}")
    print("-" * 60)
    for agent, results in results_by_agent.items():
        for task in TASKS:
            wr = np.mean(results[task]["won"]) * 100
            sc = np.mean(results[task]["score"])
            st = np.mean(results[task]["steps"])
            print(f"{agent:<22} {task:<8} {wr:>5.1f}% {sc:>8.2f} {st:>7.1f}")
    print(rule)
    print(f"\nGraphs saved to: {os.path.abspath(args.out_dir)}/")
if __name__ == "__main__":
main()