# opensec-env / scripts / eval_checkpoint.py
# Uploaded by Jarrodbarnes using huggingface_hub (commit 3f434eb, verified)
#!/usr/bin/env python3
"""Evaluate a trained checkpoint on the OpenSec-Env eval split.
Uses the same evaluation protocol as run_llm_baseline.py but with a local model.
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, List
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from eval_utils import extract_json, injection_evidence_ids, load_env, load_json
from oracle.scoring import containment_to_dict, score_report
from server.environment import OpenSecEnvironment
from server.models import AgentAction
from training.prompts import SYSTEM_PROMPT, build_observation_prompt
# Action types the defender agent is allowed to emit; _normalize_action()
# coerces anything outside this list to "query_logs".
ALLOWED_ACTIONS = [
    "query_logs",
    "fetch_email",
    "fetch_alert",
    "isolate_host",
    "block_domain",
    "reset_user",
    "submit_report",
]
def _normalize_action(data: Dict[str, Any]) -> AgentAction:
    """Coerce raw model JSON into a valid ``AgentAction``.

    Unknown or missing ``action_type`` values fall back to ``query_logs``,
    and the fetch/query actions get safe placeholder params when the model
    omitted the required keys.

    Args:
        data: Parsed JSON emitted by the model (may be any shape).

    Returns:
        An ``AgentAction`` guaranteed to have an allowed type and the
        minimum required params.
    """
    action_type = data.get("action_type") if isinstance(data, dict) else None
    if action_type not in ALLOWED_ACTIONS:
        action_type = "query_logs"
    raw_params = data.get("params") if isinstance(data, dict) else None
    # Copy so we never mutate the caller's parsed-JSON dict in place
    # (the original added default keys directly into data["params"]).
    params: Dict[str, Any] = dict(raw_params) if isinstance(raw_params, dict) else {}
    if action_type == "query_logs":
        params.setdefault("sql", "SELECT 1")
    elif action_type == "fetch_email":
        params.setdefault("email_id", "")
    elif action_type == "fetch_alert":
        params.setdefault("alert_id", "")
    return AgentAction(action_type=action_type, params=params)
def _call_sglang(
    url: str,
    model_id: str,
    messages: List[Dict[str, str]],
    max_tokens: int,
    temperature: float,
) -> str:
    """Call SGLang server via OpenAI-compatible API.

    POSTs the chat request to ``{url}/v1/chat/completions`` and returns the
    first choice's message content. Raises ``requests.HTTPError`` on a
    non-2xx response.
    """
    payload = {
        "model": model_id,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    resp = requests.post(f"{url}/v1/chat/completions", json=payload, timeout=120)
    resp.raise_for_status()
    body = resp.json()
    return body["choices"][0]["message"]["content"]
def _call_local_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    max_tokens: int,
    temperature: float,
) -> str:
    """Call local model directly.

    Renders ``messages`` through the tokenizer's chat template, generates on
    the model's device, and returns only the newly generated text (prompt
    tokens are sliced off before decoding).
    """
    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
            # Qwen3 templates accept enable_thinking; disable it so the model
            # emits the JSON action directly instead of a reasoning block.
            # For other models None is passed — presumably ignored by their
            # templates; TODO confirm.
            enable_thinking=False if "qwen3" in tokenizer.name_or_path.lower() else None
        )
    else:
        # NOTE(review): this fallback still calls apply_chat_template even
        # though chat_template is falsy — relies on a tokenizer-provided
        # default template; verify against the tokenizers actually used.
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            # Greedy decoding when temperature == 0; temperature must be a
            # positive value when sampling, hence the 1.0 placeholder.
            do_sample=temperature > 0,
            temperature=temperature if temperature > 0 else 1.0,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens generated past the prompt length.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
def _invoke_model(
    model: AutoModelForCausalLM | None,
    tokenizer: AutoTokenizer | None,
    sglang_url: str | None,
    sglang_model: str | None,
    messages: List[Dict[str, str]],
    temperature: float = 0.2,
    max_tokens: int = 600,
) -> tuple[str, Dict[str, Any]]:
    """Invoke model for a single step using full conversation history.

    Routes to the SGLang server when ``sglang_url`` is set, otherwise to the
    local model. Returns the raw completion text together with its parsed
    JSON action (or a safe fallback action when parsing fails and strict
    mode is off).
    """
    if sglang_url:
        if not sglang_model:
            raise ValueError("sglang_model must be provided when using sglang_url")
        raw = _call_sglang(sglang_url, sglang_model, messages, max_tokens, temperature)
    elif model is None or tokenizer is None:
        raise ValueError("Either sglang_url or model/tokenizer must be provided")
    else:
        raw = _call_local_model(model, tokenizer, messages, max_tokens, temperature)
    try:
        parsed = extract_json(raw)
    except Exception as exc:
        # Strict mode surfaces malformed model output; otherwise fall back
        # to a harmless no-op query so the episode can continue.
        if os.getenv("OPENSEC_DEFENDER_STRICT", "0") == "1":
            raise RuntimeError("defender_invalid_json") from exc
        parsed = {"action_type": "query_logs", "params": {"sql": "SELECT 1"}}
    return raw, parsed
def _default_report() -> Dict[str, Any]:
return {
"patient_zero_host": "unknown",
"compromised_user": "unknown",
"attacker_domain": "unknown",
"data_target": "unknown",
"initial_vector": "phish",
"containment_actions": {"isolated_hosts": [], "blocked_domains": [], "reset_users": []},
}
def run_episode(
    seed_path: Path,
    model: AutoModelForCausalLM | None,
    tokenizer: AutoTokenizer | None,
    sglang_url: str | None,
    sglang_model: str | None,
    max_steps: int,
    temperature: float = 0.2,
) -> Dict[str, Any]:
    """Run a single evaluation episode.

    Resets the environment from ``seed_path``, then alternates model calls
    and environment steps (accumulating the full chat history) until the
    model submits a report or ``max_steps`` is exhausted. If no report was
    submitted, a default "unknown" report is force-submitted so scoring
    always has something to grade.

    Returns:
        Dict with scenario id, per-step trace, final report, executed
        containment actions, and diagnostic counters.
    """
    env = OpenSecEnvironment(seed_path=str(seed_path))
    reset_result = env.reset()
    observation = reset_result.observation.model_dump()
    # Conversation seed: system prompt plus the initial observation.
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": SYSTEM_PROMPT.format(max_steps=max_steps)}
    ]
    messages.append({"role": "user", "content": build_observation_prompt(observation)})
    steps: List[Dict[str, Any]] = []
    report: Dict[str, Any] | None = None
    submitted_report = False
    for step_idx in range(max_steps):
        model_text, model_action = _invoke_model(
            model,
            tokenizer,
            sglang_url,
            sglang_model,
            messages,
            temperature,
        )
        messages.append({"role": "assistant", "content": model_text})
        action = _normalize_action(model_action)
        result = env.step(action)
        observation = result.observation.model_dump()
        messages.append({"role": "user", "content": build_observation_prompt(observation)})
        # Record a compact trace entry for this step.
        steps.append({
            "step": step_idx,
            "action": action.model_dump(),
            "attacker_action": result.info.get("attacker_action"),
            "injection_violations": result.info.get("injection_violations", []),
        })
        if action.action_type == "submit_report":
            # NOTE(review): if the model omitted "summary_json", report stays
            # None here and the default report below is submitted instead,
            # while submitted_report remains True.
            report = action.params.get("summary_json")
            submitted_report = True
            break
    if report is None:
        # Model never produced a report: force-submit the default so
        # score_report() always has one to grade.
        report = _default_report()
        result = env.step(AgentAction(action_type="submit_report", params={"summary_json": report}))
    # Injection exposure: intersection of the scenario's injected evidence
    # ids with the evidence whose content the agent actually viewed.
    injection_evidence = injection_evidence_ids(env.scenario)
    injection_encountered = injection_evidence.intersection(env.content_evidence_ids)
    containment_attempted = bool(
        env.containment.isolated_hosts
        or env.containment.blocked_domains
        or env.containment.reset_users
    )
    executed_containment = containment_to_dict(env.containment)
    return {
        "scenario_id": env.scenario_id,
        "seed_path": str(seed_path),
        "step_count": env.step_count,
        "report": report,
        "executed_containment": executed_containment,
        "submitted_report": submitted_report,
        "steps": steps,
        # Violations from the last env.step() executed above.
        "injection_violations": result.info.get("injection_violations", []),
        "diagnostics": {
            "evidence_seen_count": len(env.seen_evidence_ids),
            "evidence_content_count": len(env.content_evidence_ids),
            "injection_evidence_total": len(injection_evidence),
            "injection_evidence_seen": len(injection_encountered),
            "containment_attempted": containment_attempted,
            "report_submitted": submitted_report,
        },
    }
def main() -> int:
    """CLI entry point: evaluate a checkpoint over the selected seed split.

    Runs one episode per seed, scores each report against its ground truth,
    writes per-episode JSONL rows to --output and aggregate statistics to
    --summary. Always returns 0.
    """
    parser = argparse.ArgumentParser(description="Evaluate trained checkpoint on OpenSec-Env")
    parser.add_argument("--checkpoint", required=True, help="Path to trained checkpoint")
    parser.add_argument("--sglang-url", default=None, help="SGLang server URL (faster than local)")
    parser.add_argument("--sglang-model", default=os.getenv("OPENSEC_SGLANG_MODEL"),
                        help="SGLang model ID (e.g., HF repo or checkpoint path)")
    parser.add_argument("--manifest", default="data/seeds/manifest.json")
    parser.add_argument("--split", default="eval", choices=["train", "eval"])
    parser.add_argument("--tier", default="standard", choices=["trivial", "easy", "standard", "all"])
    parser.add_argument("--limit", type=int, default=0, help="Limit seeds (0 = all)")
    parser.add_argument("--max-steps", type=int, default=15)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--output", default="outputs/checkpoint_eval.jsonl")
    parser.add_argument("--summary", default="outputs/checkpoint_eval_summary.json")
    args = parser.parse_args()
    load_env()
    # Load model if not using SGLang
    model = None
    tokenizer = None
    if not args.sglang_url:
        print(f"Loading checkpoint: {args.checkpoint}", flush=True)
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            args.checkpoint,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()
        print("Model loaded", flush=True)
    else:
        # SGLang path: default the served model id to the checkpoint path.
        if not args.sglang_model:
            args.sglang_model = args.checkpoint
        print(f"Using SGLang server: {args.sglang_url} (model={args.sglang_model})", flush=True)
    # Load seeds
    manifest = load_json(Path(args.manifest))
    seeds = manifest[args.split]
    if args.tier != "all":
        seeds = [entry for entry in seeds if entry.get("tier") == args.tier]
    seeds = [Path(entry["seed_path"]) for entry in seeds]
    if args.limit:
        seeds = seeds[:args.limit]
    print(f"Evaluating on {len(seeds)} seeds ({args.split}/{args.tier})", flush=True)
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Aggregate counters accumulated across episodes.
    rewards: List[float] = []
    containment_executed_count = 0
    report_submitted_count = 0
    correct_containment_count = 0
    false_positive_count = 0
    injection_exposure_count = 0
    injection_violation_count = 0
    with output_path.open("w") as f:
        for i, seed_path in enumerate(seeds):
            print(f"[{i+1}/{len(seeds)}] {seed_path.name}...", end=" ", flush=True)
            result = run_episode(
                seed_path, model, tokenizer,
                args.sglang_url, args.sglang_model,
                args.max_steps, args.temperature
            )
            # Ground truth lives next to the seed:
            # *_seed.json -> *_ground_truth.json.
            gt_path = seed_path.with_name(seed_path.name.replace("_seed.json", "_ground_truth.json"))
            ground_truth = load_json(gt_path)
            score = score_report(
                result["report"],
                ground_truth,
                result["step_count"],
                result["injection_violations"],
                executed_containment=result.get("executed_containment"),
            )
            # Tally correct / false-positive containment items across fields.
            containment_details = score.details.get("containment", {})
            correct_total = 0
            false_total = 0
            for field in ("isolated_hosts", "blocked_domains", "reset_users"):
                field_data = containment_details.get(field, {})
                correct_total += len(field_data.get("correct", []))
                false_total += len(field_data.get("false_positive", []))
            # Episodes that never submitted a report are scored zero.
            if not result["submitted_report"]:
                score.reward = 0.0
                score.details["failure"] = "no_report_submitted"
            # One JSONL row per episode.
            row = {
                "scenario_id": result["scenario_id"],
                "seed_path": result["seed_path"],
                "step_count": result["step_count"],
                "reward": score.reward,
                "details": score.details,
                "submitted_report": result["submitted_report"],
                "executed_containment": result["executed_containment"],
                "containment_correct_total": correct_total,
                "containment_false_positive_total": false_total,
                "diagnostics": result["diagnostics"],
            }
            f.write(json.dumps(row) + "\n")
            rewards.append(score.reward)
            if result["diagnostics"]["containment_attempted"]:
                containment_executed_count += 1
            if result["submitted_report"]:
                report_submitted_count += 1
            if correct_total > 0:
                correct_containment_count += 1
            if false_total > 0:
                false_positive_count += 1
            if result["diagnostics"]["injection_evidence_seen"] > 0:
                injection_exposure_count += 1
            if score.details.get("injection", {}).get("violations"):
                injection_violation_count += 1
            print(f"reward={score.reward:.2f}, containment={result['diagnostics']['containment_attempted']}", flush=True)
    # Summary statistics
    # Local import: statistics is only needed for this final aggregation.
    import statistics
    summary = {
        "checkpoint": args.checkpoint,
        "split": args.split,
        "tier": args.tier,
        "runs": len(rewards),
        "mean_reward": statistics.mean(rewards) if rewards else 0,
        "std_reward": statistics.stdev(rewards) if len(rewards) > 1 else 0,
        "min_reward": min(rewards) if rewards else 0,
        "max_reward": max(rewards) if rewards else 0,
        "containment_executed_rate": containment_executed_count / len(rewards) if rewards else 0,
        "report_submitted_rate": report_submitted_count / len(rewards) if rewards else 0,
        "correct_containment_rate": correct_containment_count / len(rewards) if rewards else 0,
        "false_positive_rate": false_positive_count / len(rewards) if rewards else 0,
        "injection_exposure_rate": injection_exposure_count / len(rewards) if rewards else 0,
        "injection_violation_rate": injection_violation_count / len(rewards) if rewards else 0,
    }
    Path(args.summary).write_text(json.dumps(summary, indent=2))
    print("\n" + "="*60)
    print("EVALUATION SUMMARY")
    print("="*60)
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Seeds: {len(rewards)} ({args.split}/{args.tier})")
    print(f"Reward: {summary['mean_reward']:.2f} ± {summary['std_reward']:.2f}")
    print(f"Containment Executed: {summary['containment_executed_rate']*100:.1f}%")
    print(f"Report Submitted: {summary['report_submitted_rate']*100:.1f}%")
    print("="*60)
    print(f"\nWrote: {output_path} and {args.summary}")
    return 0
# Script entry point: the process exit status is main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())