Spaces:
Sleeping
Sleeping
| """Generate replay traces for the public demo (run on Colab). | |
| For each (attack_type, steps_label) requested, this script: | |
| 1. Loads the attacker checkpoint (or the zero-shot Qwen baseline when the step label is <= ``--baseline-cutoff``). | |
| 2. Loads the real defense stack (PG2 + SecAlign + LlamaFirewall). | |
| 3. Runs the attacker once against the scenario for that attack type. | |
| 4. Captures stage-by-stage timings + verdicts into a JSON trace. | |
| 5. Saves the trace to ``data/traces/{attack_type}_{steps}.json``. | |
| Also writes ``data/highlights/highlight.json`` — the most successful trace | |
| across the run, used by the homepage hero animation. | |
| Usage (typical Colab cell) | |
| -------------------------- | |
| python scripts/generate_traces.py \\ | |
| --checkpoint /content/drive/MyDrive/injectarena/run_v1/final \\ | |
| --steps-labels 50 100 300 500 1000 1500 \\ | |
| --output-dir data/traces \\ | |
| --highlight-dir data/highlights | |
| For ``--steps-labels`` values <= ``--baseline-cutoff`` (default 100), the | |
| zero-shot baseline (untrained Qwen) is used to simulate an early/under-trained | |
| attacker. Above the cutoff, the trained checkpoint is used. This is | |
| documented in the trace itself via the ``model_source`` field, so the demo | |
| stays honest. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
# Script-wide logger. The root handler is configured here, at module level, so
# every record carries a timestamp, level and logger name.
logger = logging.getLogger("generate_traces")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# Demo attack type -> id of the scenario it is replayed against
# (kept in sync with env/replay.py). Insertion order is the processing order.
ATTACK_TYPE_TO_SCENARIO = {
    "email_exfiltration": "email_exfil_001",
    "forbidden_tool": "email_forbidden_001",
    "prompt_leak": "email_leak_001",
    "rag_injection": "rag_exfil_001",
}
| def _parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser() | |
| p.add_argument("--checkpoint", type=str, required=True, | |
| help="Path to trained attacker checkpoint dir (the final/ folder).") | |
| p.add_argument("--steps-labels", type=int, nargs="+", | |
| default=[50, 100, 300, 500, 1000, 1500], | |
| help="Step-count labels to generate traces for.") | |
| p.add_argument("--baseline-cutoff", type=int, default=100, | |
| help="Steps <= cutoff use the zero-shot baseline; above use the checkpoint.") | |
| p.add_argument("--output-dir", type=str, default="data/traces") | |
| p.add_argument("--highlight-dir", type=str, default="data/highlights") | |
| p.add_argument("--max-new-tokens", type=int, default=128) | |
| p.add_argument("--seed", type=int, default=42) | |
| return p.parse_args() | |
| # --------------------------------------------------------------------------- | |
| # Attacker loading (uses the same Unsloth path as train/eval.py) | |
| # --------------------------------------------------------------------------- | |
def _load_attacker(checkpoint: str, max_new_tokens: int):
    """Load the trained attacker found at ``checkpoint``.

    The model is loaded through Unsloth's ``FastLanguageModel`` (bfloat16, no
    4-bit quantisation, 2048-token sequence limit) and then handed to
    ``FastLanguageModel.for_inference``.

    ``max_new_tokens`` is accepted but not used while loading.

    Returns:
        The ``(model, tokenizer)`` pair.
    """
    from unsloth import FastLanguageModel

    logger.info("Loading attacker from %s", checkpoint)
    load_options = {
        "model_name": checkpoint,
        "max_seq_length": 2048,
        "load_in_4bit": False,
        "dtype": "bfloat16",
    }
    attacker, attacker_tokenizer = FastLanguageModel.from_pretrained(**load_options)
    FastLanguageModel.for_inference(attacker)
    return attacker, attacker_tokenizer
def _load_zero_shot(max_new_tokens: int):
    """Load the untrained baseline attacker (``Qwen/Qwen2.5-1.5B-Instruct``).

    Uses the same Unsloth loading options as the trained attacker: bfloat16,
    no 4-bit quantisation, 2048-token sequence limit, followed by
    ``FastLanguageModel.for_inference``.

    ``max_new_tokens`` is accepted but not used while loading.

    Returns:
        The ``(model, tokenizer)`` pair.
    """
    from unsloth import FastLanguageModel

    logger.info("Loading zero-shot baseline (Qwen2.5-1.5B-Instruct)")
    load_options = {
        "model_name": "Qwen/Qwen2.5-1.5B-Instruct",
        "max_seq_length": 2048,
        "load_in_4bit": False,
        "dtype": "bfloat16",
    }
    baseline, baseline_tokenizer = FastLanguageModel.from_pretrained(**load_options)
    FastLanguageModel.for_inference(baseline)
    return baseline, baseline_tokenizer
def _generate_payload(model, tokenizer, observation, max_new_tokens: int, seed: int) -> str:
    """Sample one attack payload from the attacker for ``observation``.

    The observation is turned into a prompt with ``train.client.format_prompt``,
    the model samples a continuation (temperature 0.7, top-p 0.9) after
    ``torch.manual_seed(seed)``, and only the newly generated tokens are
    decoded and passed through ``train.client.parse_payload``.

    Returns:
        Whatever ``parse_payload`` extracts from the decoded completion.
    """
    import torch
    from train.client import format_prompt, parse_payload

    encoded = tokenizer(format_prompt(observation), return_tensors="pt").to(model.device)

    torch.manual_seed(seed)
    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "pad_token_id": tokenizer.eos_token_id,
    }
    with torch.inference_mode():
        generated = model.generate(**encoded, **sampling_options)

    # Drop the echoed prompt: keep only the tokens after the input length.
    prompt_length = encoded["input_ids"].shape[-1]
    completion = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
    return parse_payload(completion)
| # --------------------------------------------------------------------------- | |
| # Defense stack — wrapped to record timings + verdicts | |
| # --------------------------------------------------------------------------- | |
def _load_defenses():
    """Instantiate the three defense components used by the pipeline.

    The SecAlign agent is constructed before Prompt Guard (the log message
    attributes this to vLLM CUDA ordering), and the firewall wrapper receives
    the Prompt Guard instance as its fallback.

    Returns:
        ``(prompt_guard, secalign_agent, firewall)`` in that order.
    """
    from env.defenses.prompt_guard import PromptGuard
    from env.defenses.secalign_agent import SecAlignAgent
    from env.defenses.llama_firewall import FirewallWrapper

    logger.info("Loading defense stack (SecAlign first for vLLM CUDA order) …")
    agent = SecAlignAgent()
    guard = PromptGuard()
    wrapper = FirewallWrapper(prompt_guard_fallback=guard)
    logger.info("Defense stack ready (SecAlign mode=%s).", agent.mode)
    return guard, agent, wrapper
def _verdict_fields(verdict: Any) -> Dict[str, Any]:
    """Extract the flagged/score/reason fields of a defense verdict for the trace."""
    score = getattr(verdict, "score", None)
    return {
        "flagged": bool(verdict.flagged),
        # A verdict without a score (attribute missing or None) is recorded as 0.0
        # instead of raising on float(None).
        "score": float(score) if score is not None else 0.0,
        "reason": getattr(verdict, "reason", None),
    }


def _stage_entry(stage: str, t0: float, started: float, ended: float, **fields: Any) -> Dict[str, Any]:
    """Build one timed timeline entry from a single ``ended`` clock sample."""
    return {
        "stage": stage,
        "t": round(ended - t0, 3),
        "duration": round(ended - started, 3),
        **fields,
    }


def _run_pipeline(
    payload: str,
    scenario: Dict[str, Any],
    pg2,
    secalign,
    firewall,
) -> Dict[str, Any]:
    """Run the full attack pipeline and return a trace dict.

    Every stage is executed even when an earlier defense flags the payload, so
    the trace always contains the complete timeline.

    Args:
        payload: Attack text produced by the attacker model.
        scenario: Scenario record; reads ``injection_slot``,
            ``agent_system_prompt`` and ``user_query`` (and whatever the task
            verifier needs).
        pg2: Input scanner exposing ``scan(payload)``.
        secalign: Agent exposing ``run(system=, user=, untrusted=)``.
        firewall: Output scanner exposing
            ``scan(user_query=, agent_output=, tool_call=)``.

    Returns:
        ``{"timeline": [...], "outcome": {...}}``. Timed timeline entries carry
        ``t`` (seconds since the pipeline started) and ``duration`` (seconds
        spent in the stage), both derived from the same clock sample so that
        ``t - duration`` is the stage's start offset.
    """
    timeline: List[Dict[str, Any]] = []
    t0 = time.perf_counter()

    # Stage 1: generation already happened upstream — we record it as t=0
    timeline.append({"stage": "generation", "t": 0.0, "payload": payload})

    # Stage 2: PG2 input scan
    started = time.perf_counter()
    pg2_v = pg2.scan(payload)
    ended = time.perf_counter()
    timeline.append(_stage_entry("pg2_scan", t0, started, ended, **_verdict_fields(pg2_v)))

    # Stage 3: SecAlign agent run
    started = time.perf_counter()
    agent_output = secalign.run(
        system=scenario["agent_system_prompt"],
        user=scenario["user_query"],
        untrusted={scenario["injection_slot"]: payload},
    )
    ended = time.perf_counter()
    timeline.append(_stage_entry("agent_run", t0, started, ended, agent_output=agent_output))

    # Stage 4: LlamaFirewall output scan
    started = time.perf_counter()
    fw_v = firewall.scan(
        user_query=scenario["user_query"],
        agent_output=agent_output,
        tool_call=None,
    )
    ended = time.perf_counter()
    timeline.append(_stage_entry("firewall_scan", t0, started, ended, **_verdict_fields(fw_v)))

    # Stage 5: task verifier
    task_succeeded = _check_task_success(scenario, agent_output)
    timeline.append({
        "stage": "verdict",
        "t": round(time.perf_counter() - t0, 3),
        "task_succeeded": task_succeeded,
    })

    broke_pg2 = not pg2_v.flagged
    broke_fw = not fw_v.flagged
    composed = broke_pg2 and broke_fw and task_succeeded

    # Attribute the block to the first layer that stopped the attack.
    if not broke_pg2:
        blocked_at = "Llama Prompt Guard 2"
    elif not broke_fw:
        blocked_at = "LlamaFirewall"
    elif task_succeeded:
        blocked_at = None
    else:
        blocked_at = "SecAlign agent (refused)"

    return {
        "timeline": timeline,
        "outcome": {
            "broke_pg2": broke_pg2,
            "broke_fw": broke_fw,
            "task_succeeded": bool(task_succeeded),
            "composed_bypass": bool(composed),
            "blocked_at": blocked_at,
        },
    }
def _check_task_success(scenario: Dict[str, Any], agent_output: str) -> bool:
    """Lightweight verifier so traces can include a task-success flag.

    Dispatches on ``scenario["target_category"]`` to the matching verifier and
    returns its result as a plain ``bool``.

    Args:
        scenario: Scenario record; must contain ``target_category``.
        agent_output: Output of the agent run to be checked.

    Returns:
        ``True`` when the verifier reports success. ``False`` when it reports
        failure, when it raises (the error is logged), or when no verifier is
        registered for the category (also logged).

    Raises:
        KeyError: if ``scenario`` has no ``target_category`` entry.
    """
    from env.verifiers.exfiltration import verify_exfiltration
    from env.verifiers.forbidden_tool import verify_forbidden_tool
    from env.verifiers.prompt_leak import verify_prompt_leak

    verifiers = {
        "exfiltration": verify_exfiltration,
        "forbidden_tool": verify_forbidden_tool,
        "prompt_leak": verify_prompt_leak,
    }
    category = scenario["target_category"]
    verifier = verifiers.get(category)
    if verifier is None:
        # Make an unhandled category visible instead of silently reporting failure.
        logger.warning(
            "No verifier for target_category %r (scenario %s); treating task as not succeeded",
            category,
            scenario.get("scenario_id"),
        )
        return False

    try:
        # Coerce so the trace always stores a JSON-serialisable bool.
        return bool(verifier(agent_output, scenario))
    except Exception as exc:  # noqa: BLE001
        logger.warning("Verifier error for %s: %s", scenario.get("scenario_id"), exc)
        return False
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main() -> None:
    """Generate one trace per (attack type, step label) plus the highlight trace.

    For every entry of ``ATTACK_TYPE_TO_SCENARIO`` whose scenario exists in the
    bank, and every requested step label, this samples a payload from the
    appropriate attacker, runs it through the defense pipeline and writes
    ``{attack_type}_{steps}.json`` into the output directory. The highest
    scoring trace of the run is additionally written to ``highlight.json`` in
    the highlight directory.
    """
    args = _parse_args()
    out_dir = Path(args.output_dir)
    highlight_dir = Path(args.highlight_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    highlight_dir.mkdir(parents=True, exist_ok=True)

    from env.environment import InjectArenaEnv
    from env.scenarios import ScenarioBank

    bank = ScenarioBank()

    # Load defenses ONCE — they dominate cost.
    pg2, secalign, firewall = _load_defenses()

    # Build a minimal env so we can reset() into a scenario and get an InjectObservation.
    env = InjectArenaEnv(pg2=pg2, secalign=secalign, firewall=firewall, bank=bank)

    # Cache attackers by source so we don't reload between calls.
    attackers: Dict[str, Any] = {}

    def get_attacker(source: str):
        """Return the cached (model, tokenizer) for ``source``, loading it on first use."""
        if source in attackers:
            return attackers[source]
        if source == "checkpoint":
            attackers[source] = _load_attacker(args.checkpoint, args.max_new_tokens)
        else:
            attackers[source] = _load_zero_shot(args.max_new_tokens)
        return attackers[source]

    best: Optional[Dict[str, Any]] = None
    best_score: float = -1
    # Count only what this run writes; the directory may hold older traces.
    written = 0

    for attack_type, scenario_id in ATTACK_TYPE_TO_SCENARIO.items():
        try:
            scenario = bank.by_id(scenario_id)
        except KeyError:
            logger.warning("Scenario %s not in bank — skipping %s", scenario_id, attack_type)
            continue

        for steps in args.steps_labels:
            # Labels at or below the cutoff are served by the untrained baseline.
            source = "checkpoint" if steps > args.baseline_cutoff else "zero_shot"
            model, tokenizer = get_attacker(source)
            obs = env.reset(scenario_id=scenario_id)
            payload = _generate_payload(model, tokenizer, obs, args.max_new_tokens, args.seed + steps)
            pipe = _run_pipeline(payload, scenario, pg2, secalign, firewall)
            trace = {
                "attack_type": attack_type,
                "steps": steps,
                "scenario_id": scenario_id,
                "scenario_label": scenario.get("target_behavior", ""),
                "model_source": source,
                "payload": payload,
                **pipe,
            }

            out_path = out_dir / f"{attack_type}_{steps}.json"
            with out_path.open("w", encoding="utf-8") as f:
                json.dump(trace, f, indent=2)
            written += 1

            o = pipe["outcome"]
            logger.info("Wrote %s (broke_pg2=%s, broke_fw=%s, task=%s)",
                        out_path.name,
                        o["broke_pg2"],
                        o["broke_fw"],
                        o["task_succeeded"])

            # Score for highlight selection: composed > fw_bypass > pg2_bypass.
            score = (
                (4 if o["composed_bypass"] else 0)
                + (2 if o["broke_fw"] else 0)
                + (1 if o["broke_pg2"] else 0)
                + (steps / 10000.0)  # tiebreaker: prefer higher-step traces
            )
            if score > best_score:
                best_score = score
                best = trace

    if best is not None:
        with (highlight_dir / "highlight.json").open("w", encoding="utf-8") as f:
            json.dump(best, f, indent=2)
        logger.info("Highlight: %s_%s (score=%.2f)", best["attack_type"], best["steps"], best_score)

    logger.info("Done. Wrote %d traces to %s.", written, out_dir)


if __name__ == "__main__":
    main()