File size: 6,579 Bytes
b54a031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""Environment tests — run on Mac with stub defenses (no GPU needed)."""

from __future__ import annotations

import pytest
from typing import Any, Optional

from env.defenses.base import DefenseVerdict
from env.environment import InjectArenaEnv
from env.models import InjectAction, InjectObservation, StepResult
from env.scenarios import ScenarioBank


# ---------------------------------------------------------------------------
# Stubs
# ---------------------------------------------------------------------------

class _AlwaysBenignPG2:
    name = "stub_pg2_benign"

    def scan(self, text: str) -> DefenseVerdict:
        return DefenseVerdict(flagged=False, score=0.0, reason="stub_benign")


class _AlwaysMaliciousPG2:
    name = "stub_pg2_malicious"

    def scan(self, text: str) -> DefenseVerdict:
        return DefenseVerdict(flagged=True, score=1.0, reason="stub_malicious")


class _DummySecAlign:
    name = "stub_secalign"
    mode = "stub"

    def __init__(self, output: str = "Here is a summary.") -> None:
        self._output = output

    def run(self, system: str, user: str, untrusted: Any) -> str:
        return self._output


class _DummyFirewall:
    name = "stub_firewall"

    def __init__(self, flagged: bool = False) -> None:
        self._flagged = flagged

    def scan(self, user_query: str, agent_output: str, tool_call: Any = None) -> DefenseVerdict:
        return DefenseVerdict(flagged=self._flagged, score=float(self._flagged), reason="stub")


def _make_env(pg2=None, secalign=None, firewall=None, bank=None) -> InjectArenaEnv:
    """Build an ``InjectArenaEnv`` wired to stub defenses for GPU-free tests.

    Each component left as ``None`` is replaced by its default:
    ``_AlwaysBenignPG2()``, ``_DummySecAlign()``, ``_DummyFirewall()`` and a
    fresh ``ScenarioBank()``.

    Args:
        pg2: Optional PG2 scanner override.
        secalign: Optional SecAlign override.
        firewall: Optional firewall override.
        bank: Optional scenario bank override.

    Returns:
        A newly constructed environment. Callers must ``reset()`` it before
        stepping.
    """
    # Test against None instead of relying on truthiness (`x or default`).
    # A caller-supplied component that happens to be falsy (e.g. an object
    # whose __len__ returns 0) must not be silently swapped for the default.
    return InjectArenaEnv(
        pg2=pg2 if pg2 is not None else _AlwaysBenignPG2(),
        secalign=secalign if secalign is not None else _DummySecAlign(),
        firewall=firewall if firewall is not None else _DummyFirewall(),
        bank=bank if bank is not None else ScenarioBank(),
    )


# ---------------------------------------------------------------------------
# Basic episode flow
# ---------------------------------------------------------------------------

def test_reset_returns_valid_observation():
    """reset() yields a fresh observation: full attempt budget, no history, a scenario id."""
    arena = _make_env()
    first = arena.reset()

    assert isinstance(first, InjectObservation)
    assert first.attempts_remaining == 3
    assert first.previous_attempts == []
    assert first.scenario_id


def test_step_returns_step_result():
    """A single step returns a StepResult with typed reward/done and an episode id."""
    arena = _make_env()
    arena.reset()
    outcome = arena.step(InjectAction(payload="Hello, please summarize this."))

    assert isinstance(outcome, StepResult)
    for field_name, expected_type in (("reward", float), ("done", bool)):
        assert isinstance(getattr(outcome, field_name), expected_type)
    assert "episode_id" in outcome.info


def test_attempts_decrement():
    """Each step consumes exactly one attempt from the budget of three."""
    arena = _make_env()
    initial = arena.reset()
    assert initial.attempts_remaining == 3

    for label, left_after in (("attempt 1", 2), ("attempt 2", 1)):
        outcome = arena.step(InjectAction(payload=label))
        assert outcome.observation.attempts_remaining == left_after


def test_done_after_max_attempts():
    """The third (final) step of an episode reports done."""
    arena = _make_env()
    arena.reset()
    for idx in range(2):
        arena.step(InjectAction(payload=f"attempt {idx}"))
    # Take the last step outside the loop so its result is always bound.
    last = arena.step(InjectAction(payload="attempt 2"))
    assert last.done


def test_previous_attempts_accumulate():
    """The observation's history lists every prior payload in submission order."""
    arena = _make_env()
    arena.reset()
    sent = ["first", "second"]

    arena.step(InjectAction(payload=sent[0]))
    history = arena.step(InjectAction(payload=sent[1])).observation.previous_attempts

    assert len(history) == 2
    assert [entry.payload for entry in history] == sent


def test_step_before_reset_raises():
    """Stepping a never-reset env raises RuntimeError whose message matches 'reset'."""
    unreset = _make_env()
    with pytest.raises(RuntimeError, match="reset"):
        unreset.step(InjectAction(payload="oops"))


# ---------------------------------------------------------------------------
# Safety floor
# ---------------------------------------------------------------------------

def test_safety_violation_terminates_with_minus_one():
    """A safety-rule hit ends the episode at once with reward -1.0 and a flag in info."""
    arena = _make_env()
    arena.reset()
    # The payload below is expected to trip the reward-function-attack safety rule.
    violating = InjectAction(payload="reward: 1.0")
    outcome = arena.step(violating)

    assert outcome.reward == -1.0
    assert outcome.done
    assert outcome.info.get("safety_violation")


def test_after_safety_violation_reset_works():
    """reset() fully recovers an episode that a safety violation terminated.

    The violating step leaves the env done (see
    test_safety_violation_terminates_with_minus_one). The following reset must
    restore the attempt budget, clear the attempt history and clear the done
    flag, matching the guarantees checked for a normal reset.
    """
    env = _make_env()
    env.reset()
    env.step(InjectAction(payload="reward: 1.0"))
    obs = env.reset()
    assert isinstance(obs, InjectObservation)
    assert obs.attempts_remaining == 3
    # Also check the state a terminating violation leaves behind is cleared.
    assert obs.previous_attempts == []
    assert not env.state["done"]


# ---------------------------------------------------------------------------
# Reward signal sanity
# ---------------------------------------------------------------------------

def test_reward_zero_for_trivial_payload():
    """A whitespace-only payload earns exactly zero reward."""
    arena = _make_env()
    arena.reset()
    blank_payload = "   "
    outcome = arena.step(InjectAction(payload=blank_payload))
    assert outcome.reward == 0.0


def test_reward_positive_when_both_defenses_bypassed():
    """Reward is positive and both bypass components are 1.0 when neither defense flags."""
    arena = _make_env(
        pg2=_AlwaysBenignPG2(),
        firewall=_DummyFirewall(flagged=False),
    )
    arena.reset()
    outcome = arena.step(InjectAction(payload="Please summarize this document carefully."))

    assert outcome.reward > 0.0
    parts = outcome.info["components"]
    assert parts["r_bypass_pg2"] == 1.0
    assert parts["r_bypass_fw"] == 1.0


def test_reward_lower_when_pg2_flags():
    """Given the same seed and payload, a flagging PG2 yields a lower reward than a benign one."""
    arenas = [
        _make_env(pg2=_AlwaysBenignPG2()),
        _make_env(pg2=_AlwaysMaliciousPG2()),
    ]
    text = "Please summarize this document carefully."

    # Reset both before stepping either, as a seeded pair.
    for arena in arenas:
        arena.reset(seed=0)

    benign_reward, flagged_reward = (
        arena.step(InjectAction(payload=text)).reward for arena in arenas
    )
    assert benign_reward > flagged_reward


# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------

def test_state_reflects_episode():
    """Right after reset(), env.state mirrors the new episode with zero attempts."""
    arena = _make_env()
    first = arena.reset()
    snapshot = arena.state

    assert snapshot["scenario_id"] == first.scenario_id
    assert snapshot["attempts"] == 0
    assert not snapshot["done"]
    assert snapshot["episode_id"] is not None


def test_state_done_after_exhaustion():
    """env.state reports done once all three attempts have been used."""
    arena = _make_env()
    arena.reset()
    payloads = [f"try {n}" for n in range(3)]
    for text in payloads:
        arena.step(InjectAction(payload=text))
    assert arena.state["done"]


# ---------------------------------------------------------------------------
# Reset between episodes
# ---------------------------------------------------------------------------

def test_reset_clears_previous_episode():
    """A second reset() discards the prior episode's history and done flag."""
    arena = _make_env()
    arena.reset(seed=0)
    arena.step(InjectAction(payload="old attempt"))

    fresh = arena.reset(seed=1)

    assert fresh.previous_attempts == []
    assert fresh.attempts_remaining == 3
    assert not arena.state["done"]