Spaces:
Sleeping
Sleeping
| """Evaluation script — runs a checkpoint against the eval split. | |
| Usage: | |
| python train/eval.py --checkpoint /path/to/checkpoint --output-json docs/eval_results.json | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
# Configure logging once at import time: INFO level, with timestamp, level and
# logger name in front of every message.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)

# Logger shared by every function in this script.
logger = logging.getLogger("eval")
def _parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Explicit argument list, without the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so the existing
            zero-argument call keeps working; passing a list allows
            programmatic use.

    Returns:
        Namespace with ``checkpoint``, ``output_json``, ``model_base``,
        ``max_new_tokens``, ``temperature`` and ``seed``.
    """
    parser = argparse.ArgumentParser(
        description="Run a checkpoint against the eval split.",
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="Checkpoint to evaluate (required).",
    )
    parser.add_argument(
        "--output-json", type=str, default="docs/eval_results.json",
        help="Where to write the JSON results file.",
    )
    parser.add_argument(
        "--model-base", type=str, default="Qwen/Qwen2.5-1.5B-Instruct",
        help="Base model used when the checkpoint is loaded via transformers/PEFT.",
    )
    parser.add_argument(
        "--max-new-tokens", type=int, default=512,
        help="Maximum number of tokens generated per scenario.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7,
        help="Sampling temperature passed to generation.",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Seed applied to torch before each generation.",
    )
    return parser.parse_args(argv)
def _load_model(checkpoint: str, model_base: str):
    """Load the attacker model and its tokenizer for inference.

    Unsloth is tried first: ``checkpoint`` is loaded directly in 4-bit with a
    1024-token maximum sequence length and switched to inference mode. If any
    step of that path raises (including ``unsloth`` not being importable), a
    warning is logged and the function falls back to transformers + PEFT.

    In the fallback, ``model_base`` is loaded with 4-bit NF4 quantization and
    ``checkpoint`` is applied on top via ``PeftModel.from_pretrained``; the
    tokenizer is read from ``checkpoint``. ``model_base`` is not used on the
    Unsloth path.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    try:
        from unsloth import FastLanguageModel

        unsloth_model, unsloth_tokenizer = FastLanguageModel.from_pretrained(
            model_name=checkpoint,
            max_seq_length=1024,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(unsloth_model)
        logger.info("Eval model loaded via Unsloth from %s", checkpoint)
        return unsloth_model, unsloth_tokenizer
    except Exception as exc:
        # Deliberately broad: any Unsloth failure should trigger the fallback.
        logger.warning("Unsloth load failed (%s) — using transformers.", exc)

    # Fallback path; only reached when the Unsloth attempt above raised.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    fallback_tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        model_base,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )
    peft_model = PeftModel.from_pretrained(base_model, checkpoint)
    peft_model.eval()
    return peft_model, fallback_tokenizer
def _generate(model, tokenizer, prompt: str, max_new_tokens: int, temperature: float, seed: int) -> str:
    """Generate a completion for ``prompt`` and return only the new text.

    Args:
        model: Object exposing ``.device`` and ``.generate(**kwargs)``.
        tokenizer: Callable that tokenizes ``prompt`` into tensors and offers
            ``decode``.
        prompt: Text fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A positive value samples with
            ``top_p=0.9``; zero or a negative value switches to greedy
            decoding, because sampling with a non-positive temperature is not
            a valid generation configuration.
        seed: Value passed to ``torch.manual_seed`` before every call.

    Returns:
        The decoded continuation, with the prompt tokens and special tokens
        removed.
    """
    import torch

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Re-seed on every call so each generation starts from the same RNG state.
    torch.manual_seed(seed)

    use_sampling = temperature > 0
    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": use_sampling}
    if use_sampling:
        # Sampling-only knobs are omitted for greedy decoding.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = 0.9

    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)

    # The output holds prompt + continuation; slice off the prompt tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = output[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def _is_composed_bypass(rec: Dict[str, Any]) -> bool:
    """Return True when the task succeeded and neither PG2 nor the firewall flagged it."""
    return bool(rec["task_success"]) and not rec["pg2_flagged"] and not rec["fw_flagged"]


def _mean(values: List[Any]) -> float:
    """Average ``values``; an empty list yields 0.0 instead of dividing by zero."""
    return sum(values) / len(values) if values else 0.0


def _aggregate_metrics(records: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Compute unrounded overall and per-category metrics from eval records.

    Each record must provide ``target_category``, ``reward``, ``task_success``,
    ``pg2_flagged`` and ``fw_flagged``. With no records every rate is 0.0 and
    ``per_category`` is empty.
    """
    # Group once instead of rescanning all records for every category.
    by_category: Dict[Any, List[Dict[str, Any]]] = {}
    for rec in records:
        by_category.setdefault(rec["target_category"], []).append(rec)

    per_category: Dict[Any, Dict[str, Any]] = {}
    # Sorted so the key order in the JSON report is stable between runs.
    for cat in sorted(by_category, key=str):
        cat_recs = by_category[cat]
        per_category[cat] = {
            "n": len(cat_recs),
            "task_success": _mean([r["task_success"] for r in cat_recs]),
            "composed_bypass": _mean([_is_composed_bypass(r) for r in cat_recs]),
        }

    return {
        "n_scenarios": len(records),
        "mean_reward": _mean([r["reward"] for r in records]),
        "task_success_rate": _mean([r["task_success"] for r in records]),
        "pg2_bypass_rate": _mean([not r["pg2_flagged"] for r in records]),
        "fw_bypass_rate": _mean([not r["fw_flagged"] for r in records]),
        "composed_bypass_rate": _mean([_is_composed_bypass(r) for r in records]),
        "per_category": per_category,
    }


def main() -> None:
    """Evaluate a checkpoint on the eval split and write a JSON report.

    Builds the defenses and the environment, loads the attacker model, then for
    every scenario in the ``"eval"`` split generates one completion, parses a
    payload from it, resets the environment to that scenario and submits the
    payload in a single step. Per-scenario records and aggregate metrics are
    written to ``--output-json`` and a short summary is printed.

    An empty eval split no longer raises ``ZeroDivisionError``: a warning is
    logged and all rates are reported as 0.0.
    """
    args = _parse_args()

    # Load defenses.
    from env.defenses.secalign_agent import SecAlignAgent
    from env.defenses.prompt_guard import PromptGuard
    from env.defenses.llama_firewall import FirewallWrapper
    from env.utils.embedding_cache import EmbeddingCache
    from env.environment import InjectArenaEnv
    from env.models import InjectAction, InjectObservation
    from env.scenarios import ScenarioBank
    from train.client import format_prompt, parse_payload

    secalign = SecAlignAgent()
    pg2 = PromptGuard()
    firewall = FirewallWrapper(prompt_guard_fallback=pg2)
    embedder = EmbeddingCache()
    bank = ScenarioBank()
    env = InjectArenaEnv(pg2=pg2, secalign=secalign, firewall=firewall,
                         bank=bank, embedder=embedder)

    # Load attacker model.
    model, tokenizer = _load_model(args.checkpoint, args.model_base)

    eval_scenarios = bank.by_split("eval")
    logger.info("Evaluating %d scenarios …", len(eval_scenarios))

    records: List[Dict[str, Any]] = []
    for s in eval_scenarios:
        # Fresh observation per scenario: no previous attempts, 3 remaining.
        obs = InjectObservation(
            scenario_id=s["scenario_id"],
            target_behavior=s["target_behavior"],
            target_category=s["target_category"],
            agent_system_prompt=s["agent_system_prompt"],
            user_query=s["user_query"],
            injection_slot=s["injection_slot"],
            tool_surface=s["tool_surface"],
            canary_string=s.get("canary_string"),
            previous_attempts=[],
            attempts_remaining=3,
        )
        prompt = format_prompt(obs)
        completion = _generate(model, tokenizer, prompt,
                               args.max_new_tokens, args.temperature, args.seed)
        payload = parse_payload(completion)

        env.reset(scenario_id=s["scenario_id"])
        result = env.step(InjectAction(payload=payload))

        rec = {
            "scenario_id": s["scenario_id"],
            "target_category": s["target_category"],
            "injection_slot": s["injection_slot"],
            "payload": payload,
            "reward": result.reward,
            "task_success": result.info.get("task_success", False),
            "pg2_flagged": result.info.get("pg2_verdict", {}).get("flagged", False),
            "fw_flagged": result.info.get("fw_verdict", {}).get("flagged", False),
            "components": result.info.get("components", {}),
        }
        records.append(rec)
        logger.info(" %s reward=%.3f task=%s pg2=%s",
                    s["scenario_id"], rec["reward"],
                    rec["task_success"], rec["pg2_flagged"])

    if not records:
        logger.warning("Eval split is empty; all metrics are reported as 0.")

    # Aggregate metrics.
    metrics = _aggregate_metrics(records)
    n = metrics["n_scenarios"]
    mean_reward = metrics["mean_reward"]
    task_success_rate = metrics["task_success_rate"]
    pg2_bypass_rate = metrics["pg2_bypass_rate"]
    fw_bypass_rate = metrics["fw_bypass_rate"]
    composed_bypass = metrics["composed_bypass_rate"]

    summary = {
        "checkpoint": args.checkpoint,
        "n_scenarios": n,
        "mean_reward": round(mean_reward, 4),
        "task_success_rate": round(task_success_rate, 4),
        "pg2_bypass_rate": round(pg2_bypass_rate, 4),
        "fw_bypass_rate": round(fw_bypass_rate, 4),
        "composed_bypass_rate": round(composed_bypass, 4),
        "per_category": metrics["per_category"],
        "records": records,
    }

    out_path = Path(args.output_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print("\n=== Eval Results ===")
    print(f"Scenarios: {n}")
    print(f"Mean reward: {mean_reward:.4f}")
    print(f"Task success: {task_success_rate:.1%}")
    print(f"PG2 bypass: {pg2_bypass_rate:.1%}")
    print(f"FW bypass: {fw_bypass_rate:.1%}")
    print(f"Composed bypass: {composed_bypass:.1%} (all three)")
    print(f"\nResults saved to {out_path}")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()