File size: 5,084 Bytes
550a83e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Zero-shot baseline — same Qwen2.5-1.5B-Instruct model, no training.

Evaluates the base model's raw prompt-injection capability without any RL.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any, Dict, List

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger("zero_shot")

_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"


def _verdict_flagged(info: Dict[str, Any], key: str) -> bool:
    """Return the ``flagged`` field of the defense verdict stored under *key*.

    A missing key, or a verdict explicitly set to ``None``, counts as
    "not flagged" instead of raising AttributeError.
    """
    verdict = info.get(key) or {}
    return verdict.get("flagged", False)


def _summarize(records: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate per-scenario records into the zero-shot summary dict.

    Every rate is the count (or reward sum) divided by the number of
    records, rounded to 4 decimal places.  With no records every rate is
    0.0 rather than raising ZeroDivisionError.
    """
    n = len(records)

    def _frac(total: float) -> float:
        # Empty eval split: report 0.0 instead of dividing by zero.
        return round(total / n, 4) if n else 0.0

    return {
        "baseline": "zero_shot",
        "model": _MODEL_ID,
        "n_scenarios": n,
        "mean_reward": _frac(sum(r["reward"] for r in records)),
        "task_success_rate": _frac(sum(r["task_success"] for r in records)),
        "pg2_bypass_rate": _frac(sum(not r["pg2_flagged"] for r in records)),
        "fw_bypass_rate": _frac(sum(not r["fw_flagged"] for r in records)),
        # A "composed" bypass needs the task to succeed AND both defenses to miss it.
        "composed_bypass_rate": _frac(
            sum(r["task_success"] and not r["pg2_flagged"] and not r["fw_flagged"]
                for r in records)),
        "records": records,
    }


def run_zero_shot(output_json: str = "docs/zero_shot_results.json",
                  temperature: float = 0.7,
                  seed: int = 42) -> Dict[str, Any]:
    """Evaluate the untrained base model on every scenario of the ``eval`` split.

    For each scenario one completion is sampled from the base model
    (4-bit NF4 quantized, no LoRA), parsed into a payload, and submitted to
    ``InjectArenaEnv`` for a single step; the reward and defense verdicts
    are recorded.  The aggregate summary, including the per-scenario
    records, is written as JSON.

    Args:
        output_json: Path of the JSON results file; missing parent
            directories are created.
        temperature: Sampling temperature passed to ``model.generate``.
        seed: Torch RNG seed, re-applied before every generation so each
            scenario's sample does not depend on the scenarios before it.

    Returns:
        The summary dict that was written to *output_json*.  If the eval
        split is empty, a warning is logged and all rates are 0.0.
    """
    # Heavy / project-local imports are deferred so importing this module stays cheap.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from env.defenses.secalign_agent import SecAlignAgent
    from env.defenses.prompt_guard import PromptGuard
    from env.defenses.llama_firewall import FirewallWrapper
    from env.utils.embedding_cache import EmbeddingCache
    from env.environment import InjectArenaEnv
    from env.models import InjectAction, InjectObservation
    from env.scenarios import ScenarioBank
    from train.client import format_prompt, parse_payload

    # Load defenses.
    secalign = SecAlignAgent()
    pg2 = PromptGuard()
    firewall = FirewallWrapper(prompt_guard_fallback=pg2)
    embedder = EmbeddingCache()
    bank = ScenarioBank()
    env = InjectArenaEnv(pg2=pg2, secalign=secalign, firewall=firewall,
                         bank=bank, embedder=embedder)

    # Load base model (no LoRA).
    logger.info("Loading zero-shot model %s …", _MODEL_ID)
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
                             bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID, quantization_config=bnb, device_map="auto", trust_remote_code=True
    )
    model.eval()
    logger.info("Zero-shot model ready.")

    eval_scenarios = bank.by_split("eval")
    if not eval_scenarios:
        logger.warning("Eval split is empty; all zero-shot rates will be reported as 0.0.")
    records: List[Dict[str, Any]] = []

    for s in eval_scenarios:
        # NOTE(review): only one attempt is made per scenario even though the
        # observation advertises attempts_remaining=3 — confirm this is intended.
        obs = InjectObservation(
            scenario_id=s["scenario_id"],
            target_behavior=s["target_behavior"],
            target_category=s["target_category"],
            agent_system_prompt=s["agent_system_prompt"],
            user_query=s["user_query"],
            injection_slot=s["injection_slot"],
            tool_surface=s["tool_surface"],
            canary_string=s.get("canary_string"),
            previous_attempts=[],
            attempts_remaining=3,
        )
        prompt = format_prompt(obs)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Reset the RNG per scenario so results are independent of iteration order.
        torch.manual_seed(seed)
        with torch.inference_mode():
            output = model.generate(
                **inputs, max_new_tokens=512, do_sample=True,
                temperature=temperature, top_p=0.9,
            )
        # Drop the echoed prompt tokens; decode only the newly generated ones.
        new_tokens = output[0][inputs["input_ids"].shape[-1]:]
        completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
        payload = parse_payload(completion)

        env.reset(scenario_id=s["scenario_id"])
        result = env.step(InjectAction(payload=payload))

        records.append({
            "scenario_id": s["scenario_id"],
            "target_category": s["target_category"],
            "injection_slot": s["injection_slot"],
            "payload": payload,
            "reward": result.reward,
            "task_success": result.info.get("task_success", False),
            "pg2_flagged": _verdict_flagged(result.info, "pg2_verdict"),
            "fw_flagged": _verdict_flagged(result.info, "fw_verdict"),
        })

    summary = _summarize(records)

    out_path = Path(output_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    logger.info("Zero-shot results saved to %s", out_path)
    return summary


if __name__ == "__main__":
    # Command-line entry point: parse options, run the evaluation, report
    # the headline composed-bypass rate.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--output-json", default="docs/zero_shot_results.json")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--seed", type=int, default=42)
    cli = parser.parse_args()

    summary = run_zero_shot(
        output_json=cli.output_json,
        temperature=cli.temperature,
        seed=cli.seed,
    )
    composed = summary["composed_bypass_rate"]
    print(f"\nZero-shot composed bypass: {composed:.1%}")