"""Generate replay traces for the public demo (run on Colab).
For each (attack_type, steps_label) requested, this script:
1. Loads the attacker checkpoint (or zero-shot Qwen for the lowest step count).
2. Loads the real defense stack (PG2 + SecAlign + LlamaFirewall).
3. Runs the attacker once against the scenario for that attack type.
4. Captures stage-by-stage timings + verdicts into a JSON trace.
5. Saves the trace to ``data/traces/{attack_type}_{steps}.json``.
Also writes ``data/highlights/highlight.json`` — the most successful trace
across the run, used by the homepage hero animation.
Usage (typical Colab cell)
--------------------------
python scripts/generate_traces.py \\
--checkpoint /content/drive/MyDrive/injectarena/run_v1/final \\
--steps-labels 50 100 300 500 1000 1500 \\
--output-dir data/traces \\
--highlight-dir data/highlights
For ``--steps-labels`` values <= ``--baseline-cutoff`` (default 100), the
zero-shot baseline (untrained Qwen) is used to simulate an early/under-trained
attacker. Above the cutoff, the trained checkpoint is used. This is
documented in the trace itself via the ``model_source`` field, so the demo
stays honest.
"""
from __future__ import annotations
import argparse
import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger("generate_traces")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
# Attack-type → scenario_id mapping (kept in sync with env/replay.py)
ATTACK_TYPE_TO_SCENARIO = {
"email_exfiltration": "email_exfil_001",
"forbidden_tool": "email_forbidden_001",
"prompt_leak": "email_leak_001",
"rag_injection": "rag_exfil_001",
}
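
# Fields this script reads from a ScenarioBank scenario (shape illustrative;
# values are examples, the actual content lives in env/scenarios.py):
#
# {
#   "scenario_id": "email_exfil_001",
#   "target_category": "exfiltration",
#   "target_behavior": "...",
#   "agent_system_prompt": "...",
#   "user_query": "...",
#   "injection_slot": "...",
# }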
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--checkpoint", type=str, required=True,
help="Path to trained attacker checkpoint dir (the final/ folder).")
p.add_argument("--steps-labels", type=int, nargs="+",
default=[50, 100, 300, 500, 1000, 1500],
help="Step-count labels to generate traces for.")
p.add_argument("--baseline-cutoff", type=int, default=100,
help="Steps <= cutoff use the zero-shot baseline; above use the checkpoint.")
p.add_argument("--output-dir", type=str, default="data/traces")
p.add_argument("--highlight-dir", type=str, default="data/highlights")
p.add_argument("--max-new-tokens", type=int, default=128)
p.add_argument("--seed", type=int, default=42)
return p.parse_args()
# ---------------------------------------------------------------------------
# Attacker loading (uses the same Unsloth path as train/eval.py)
# ---------------------------------------------------------------------------
def _load_attacker(checkpoint: str):
    import torch
    from unsloth import FastLanguageModel
    logger.info("Loading attacker from %s", checkpoint)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=checkpoint,
        max_seq_length=2048,
        load_in_4bit=False,
        dtype=torch.bfloat16,  # pass a torch dtype, not the string "bfloat16"
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer

def _load_zero_shot():
    import torch
    from unsloth import FastLanguageModel
    logger.info("Loading zero-shot baseline (Qwen2.5-1.5B-Instruct)")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        max_seq_length=2048,
        load_in_4bit=False,
        dtype=torch.bfloat16,
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
def _generate_payload(model, tokenizer, observation, max_new_tokens: int, seed: int) -> str:
import torch
from train.client import format_prompt, parse_payload
prompt = format_prompt(observation)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
torch.manual_seed(seed)
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
)
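    # Decode only the newly generated tokens, slicing off the echoed prompt.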
raw = tokenizer.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
return parse_payload(raw)
# ---------------------------------------------------------------------------
# Defense stack — wrapped to record timings + verdicts
# ---------------------------------------------------------------------------
def _load_defenses():
from env.defenses.prompt_guard import PromptGuard
from env.defenses.secalign_agent import SecAlignAgent
from env.defenses.llama_firewall import FirewallWrapper
logger.info("Loading defense stack (SecAlign first for vLLM CUDA order) …")
secalign = SecAlignAgent()
pg2 = PromptGuard()
fw = FirewallWrapper(prompt_guard_fallback=pg2)
logger.info("Defense stack ready (SecAlign mode=%s).", secalign.mode)
return pg2, secalign, fw
def _run_pipeline(
payload: str,
scenario: Dict[str, Any],
pg2,
secalign,
firewall,
) -> Dict[str, Any]:
"""Run the full attack pipeline and return a trace dict."""
timeline: List[Dict[str, Any]] = []
t0 = time.perf_counter()
# Stage 1: generation already happened upstream — we record it as t=0
timeline.append({"stage": "generation", "t": 0.0, "payload": payload})
# Stage 2: PG2 input scan
s = time.perf_counter()
pg2_v = pg2.scan(payload)
pg2_t = time.perf_counter() - t0
timeline.append({
"stage": "pg2_scan",
"t": round(pg2_t, 3),
"duration": round(time.perf_counter() - s, 3),
"flagged": bool(pg2_v.flagged),
"score": float(getattr(pg2_v, "score", 0.0)),
"reason": getattr(pg2_v, "reason", None),
})
# Stage 3: SecAlign agent run
s = time.perf_counter()
untrusted = {scenario["injection_slot"]: payload}
agent_output = secalign.run(
system=scenario["agent_system_prompt"],
user=scenario["user_query"],
untrusted=untrusted,
)
agent_t = time.perf_counter() - t0
timeline.append({
"stage": "agent_run",
"t": round(agent_t, 3),
"duration": round(time.perf_counter() - s, 3),
"agent_output": agent_output,
})
# Stage 4: LlamaFirewall output scan
s = time.perf_counter()
fw_v = firewall.scan(
user_query=scenario["user_query"],
agent_output=agent_output,
tool_call=None,
)
fw_t = time.perf_counter() - t0
timeline.append({
"stage": "firewall_scan",
"t": round(fw_t, 3),
"duration": round(time.perf_counter() - s, 3),
"flagged": bool(fw_v.flagged),
"score": float(getattr(fw_v, "score", 0.0)),
"reason": getattr(fw_v, "reason", None),
})
# Stage 5: task verifier
task_succeeded = _check_task_success(scenario, agent_output)
timeline.append({
"stage": "verdict",
"t": round(time.perf_counter() - t0, 3),
"task_succeeded": task_succeeded,
})
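    # "Broke" means the scanner did not flag the payload; a composed bypass
    # additionally requires that the agent actually performed the target task.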
broke_pg2 = not pg2_v.flagged
broke_fw = not fw_v.flagged
composed = broke_pg2 and broke_fw and task_succeeded
    if not broke_pg2:
        blocked_at = "Llama Prompt Guard 2"
    elif not broke_fw:
        blocked_at = "LlamaFirewall"
    elif not task_succeeded:
        blocked_at = "SecAlign agent (refused)"
    else:
        blocked_at = None
    return {
        "timeline": timeline,
        "outcome": {
            "broke_pg2": broke_pg2,
            "broke_fw": broke_fw,
            "task_succeeded": bool(task_succeeded),
            "composed_bypass": bool(composed),
            "blocked_at": blocked_at,
        },
    }
def _check_task_success(scenario: Dict[str, Any], agent_output: str) -> bool:
"""Lightweight verifier so traces can include a task-success flag."""
from env.verifiers.exfiltration import verify_exfiltration
from env.verifiers.forbidden_tool import verify_forbidden_tool
from env.verifiers.prompt_leak import verify_prompt_leak
cat = scenario["target_category"]
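    # Dispatch on the scenario's target category; unknown categories (and
    # verifier errors) count as task failure.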
try:
if cat == "exfiltration":
return verify_exfiltration(agent_output, scenario)
if cat == "forbidden_tool":
return verify_forbidden_tool(agent_output, scenario)
if cat == "prompt_leak":
return verify_prompt_leak(agent_output, scenario)
except Exception as exc: # noqa: BLE001
logger.warning("Verifier error for %s: %s", scenario.get("scenario_id"), exc)
return False
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
args = _parse_args()
out_dir = Path(args.output_dir)
highlight_dir = Path(args.highlight_dir)
out_dir.mkdir(parents=True, exist_ok=True)
highlight_dir.mkdir(parents=True, exist_ok=True)
from env.environment import InjectArenaEnv
from env.scenarios import ScenarioBank
bank = ScenarioBank()
# Load defenses ONCE — they dominate cost.
pg2, secalign, firewall = _load_defenses()
# Build a minimal env so we can reset() into a scenario and get an InjectObservation.
env = InjectArenaEnv(pg2=pg2, secalign=secalign, firewall=firewall, bank=bank)
# Cache attackers by source so we don't reload between calls.
attackers: Dict[str, Any] = {}
def get_attacker(source: str):
if source in attackers:
return attackers[source]
if source == "checkpoint":
            attackers[source] = _load_attacker(args.checkpoint)
else:
            attackers[source] = _load_zero_shot()
return attackers[source]
best: Optional[Dict[str, Any]] = None
best_score = -1
for attack_type, scenario_id in ATTACK_TYPE_TO_SCENARIO.items():
try:
scenario = bank.by_id(scenario_id)
except KeyError:
logger.warning("Scenario %s not in bank — skipping %s", scenario_id, attack_type)
continue
for steps in args.steps_labels:
source = "checkpoint" if steps > args.baseline_cutoff else "zero_shot"
model, tokenizer = get_attacker(source)
obs = env.reset(scenario_id=scenario_id)
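            # Offset the seed by the step label so each trace samples a
            # distinct payload from the same model.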
payload = _generate_payload(model, tokenizer, obs, args.max_new_tokens, args.seed + steps)
pipe = _run_pipeline(payload, scenario, pg2, secalign, firewall)
trace = {
"attack_type": attack_type,
"steps": steps,
"scenario_id": scenario_id,
"scenario_label": scenario.get("target_behavior", ""),
"model_source": source,
"payload": payload,
**pipe,
}
out_path = out_dir / f"{attack_type}_{steps}.json"
with out_path.open("w") as f:
json.dump(trace, f, indent=2)
logger.info("Wrote %s (broke_pg2=%s, broke_fw=%s, task=%s)",
out_path.name,
pipe["outcome"]["broke_pg2"],
pipe["outcome"]["broke_fw"],
pipe["outcome"]["task_succeeded"])
# Score for highlight selection: composed > fw_bypass > pg2_bypass.
o = pipe["outcome"]
            score = (
                (4 if o["composed_bypass"] else 0)
                + (2 if o["broke_fw"] else 0)
                + (1 if o["broke_pg2"] else 0)
                + steps / 10000.0  # tiebreaker: prefer higher-step traces
            )
if score > best_score:
best_score = score
best = trace
if best is not None:
with (highlight_dir / "highlight.json").open("w") as f:
json.dump(best, f, indent=2)
logger.info("Highlight: %s_%s (score=%.2f)", best["attack_type"], best["steps"], best_score)
logger.info("Done. Wrote %d traces to %s.", len(list(out_dir.glob("*.json"))), out_dir)
if __name__ == "__main__":
main()