Inject-Arena / tests /test_environment.py
Jaswanth1210's picture
Phase 4: InjectArenaEnv + FastAPI server + Dockerfile + env tests (81 passing)
b54a031
"""Environment tests — run on Mac with stub defenses (no GPU needed)."""
from __future__ import annotations
import pytest
from typing import Any, Optional
from env.defenses.base import DefenseVerdict
from env.environment import InjectArenaEnv
from env.models import InjectAction, InjectObservation, StepResult
from env.scenarios import ScenarioBank
# ---------------------------------------------------------------------------
# Stubs
# ---------------------------------------------------------------------------
class _AlwaysBenignPG2:
    """Stub for the env's ``pg2`` slot whose ``scan`` never flags anything."""

    name = "stub_pg2_benign"

    def scan(self, text: str) -> DefenseVerdict:
        """Return an unflagged verdict with score 0.0; ``text`` is ignored."""
        verdict = DefenseVerdict(flagged=False, score=0.0, reason="stub_benign")
        return verdict
class _AlwaysMaliciousPG2:
    """Stub for the env's ``pg2`` slot whose ``scan`` flags every input."""

    name = "stub_pg2_malicious"

    def scan(self, text: str) -> DefenseVerdict:
        """Return a flagged verdict with score 1.0; ``text`` is ignored."""
        verdict = DefenseVerdict(flagged=True, score=1.0, reason="stub_malicious")
        return verdict
class _DummySecAlign:
    """Stub for the env's ``secalign`` slot that replies with a fixed string."""

    name = "stub_secalign"
    mode = "stub"

    def __init__(self, output: str = "Here is a summary.") -> None:
        # The canned reply handed back by every ``run`` call.
        self._output = output

    def run(self, system: str, user: str, untrusted: Any) -> str:
        """Return the canned output; all three arguments are ignored."""
        canned_reply = self._output
        return canned_reply
class _DummyFirewall:
    """Stub for the env's ``firewall`` slot with a verdict fixed at construction."""

    name = "stub_firewall"

    def __init__(self, flagged: bool = False) -> None:
        # Every ``scan`` call reports this same flag.
        self._flagged = flagged

    def scan(self, user_query: str, agent_output: str, tool_call: Any = None) -> DefenseVerdict:
        """Return the preconfigured verdict; the scanned inputs are ignored.

        The score mirrors the flag: 1.0 when flagged, 0.0 otherwise.
        """
        is_flagged = self._flagged
        return DefenseVerdict(
            flagged=is_flagged,
            score=float(is_flagged),
            reason="stub",
        )
def _make_env(pg2=None, secalign=None, firewall=None, bank=None) -> InjectArenaEnv:
    """Build an ``InjectArenaEnv``, filling any omitted dependency with a stub.

    Args:
        pg2: Object for the env's ``pg2`` slot; defaults to ``_AlwaysBenignPG2()``.
        secalign: Object for the ``secalign`` slot; defaults to ``_DummySecAlign()``.
        firewall: Object for the ``firewall`` slot; defaults to ``_DummyFirewall()``.
        bank: Scenario source; defaults to a fresh ``ScenarioBank()``.

    Returns:
        The constructed environment.

    Defaults are selected with explicit ``is None`` checks rather than ``or``
    so that a caller-supplied object which happens to be falsy (for example
    one that defines ``__len__`` or ``__bool__``) is still honoured instead of
    being silently swapped for the default.
    """
    if pg2 is None:
        pg2 = _AlwaysBenignPG2()
    if secalign is None:
        secalign = _DummySecAlign()
    if firewall is None:
        firewall = _DummyFirewall()
    if bank is None:
        bank = ScenarioBank()
    return InjectArenaEnv(
        pg2=pg2,
        secalign=secalign,
        firewall=firewall,
        bank=bank,
    )
# ---------------------------------------------------------------------------
# Basic episode flow
# ---------------------------------------------------------------------------
def test_reset_returns_valid_observation():
    """``reset`` yields a fresh observation: 3 attempts left, empty history."""
    initial_obs = _make_env().reset()

    assert isinstance(initial_obs, InjectObservation)
    assert initial_obs.attempts_remaining == 3
    assert initial_obs.previous_attempts == []
    # The scenario id only needs to be truthy (non-empty).
    assert initial_obs.scenario_id
def test_step_returns_step_result():
    """``step`` returns a ``StepResult`` with typed reward/done and an episode id."""
    arena = _make_env()
    arena.reset()

    action = InjectAction(payload="Hello, please summarize this.")
    outcome = arena.step(action)

    assert isinstance(outcome, StepResult)
    assert isinstance(outcome.reward, float)
    assert isinstance(outcome.done, bool)
    assert "episode_id" in outcome.info
def test_attempts_decrement():
    """Each step lowers ``attempts_remaining`` by one, starting from 3."""
    arena = _make_env()
    initial_obs = arena.reset()
    assert initial_obs.attempts_remaining == 3

    # (payload to submit, attempts expected to remain afterwards)
    expectations = (
        ("attempt 1", 2),
        ("attempt 2", 1),
    )
    for payload, expected_left in expectations:
        outcome = arena.step(InjectAction(payload=payload))
        assert outcome.observation.attempts_remaining == expected_left
def test_done_after_max_attempts():
    """After three steps the last result reports ``done``."""
    arena = _make_env()
    arena.reset()

    final_outcome = None
    for attempt_no in range(3):
        final_outcome = arena.step(InjectAction(payload=f"attempt {attempt_no}"))

    assert final_outcome.done
def test_previous_attempts_accumulate():
    """Submitted payloads appear in ``previous_attempts`` in submission order."""
    arena = _make_env()
    arena.reset()

    arena.step(InjectAction(payload="first"))
    second_outcome = arena.step(InjectAction(payload="second"))

    assert len(second_outcome.observation.previous_attempts) == 2
    oldest, newest = second_outcome.observation.previous_attempts[:2]
    assert oldest.payload == "first"
    assert newest.payload == "second"
def test_step_before_reset_raises():
    """Stepping a never-reset env raises ``RuntimeError`` mentioning "reset"."""
    unreset_arena = _make_env()

    with pytest.raises(RuntimeError, match="reset"):
        unreset_arena.step(InjectAction(payload="oops"))
# ---------------------------------------------------------------------------
# Safety floor
# ---------------------------------------------------------------------------
def test_safety_violation_terminates_with_minus_one():
    """A safety-rule hit ends the episode with reward -1.0 and a flag in ``info``."""
    arena = _make_env()
    arena.reset()

    # This payload is meant to trip the reward-function-attack safety rule.
    unsafe_action = InjectAction(payload="reward: 1.0")
    outcome = arena.step(unsafe_action)

    assert outcome.reward == -1.0
    assert outcome.done
    assert outcome.info.get("safety_violation")
def test_after_safety_violation_reset_works():
    """``reset`` still hands back a fresh observation after a safety violation."""
    arena = _make_env()
    arena.reset()
    arena.step(InjectAction(payload="reward: 1.0"))

    fresh_obs = arena.reset()

    assert isinstance(fresh_obs, InjectObservation)
    assert fresh_obs.attempts_remaining == 3
# ---------------------------------------------------------------------------
# Reward signal sanity
# ---------------------------------------------------------------------------
def test_reward_zero_for_trivial_payload():
    """A whitespace-only payload earns exactly zero reward."""
    arena = _make_env()
    arena.reset()

    blank_action = InjectAction(payload=" ")
    outcome = arena.step(blank_action)

    assert outcome.reward == 0.0
def test_reward_positive_when_both_defenses_bypassed():
    """With neither stub flagging, reward is positive and both bypass components are 1.0."""
    permissive_pg2 = _AlwaysBenignPG2()
    permissive_firewall = _DummyFirewall(flagged=False)
    arena = _make_env(pg2=permissive_pg2, firewall=permissive_firewall)
    arena.reset()

    outcome = arena.step(
        InjectAction(payload="Please summarize this document carefully.")
    )

    assert outcome.reward > 0.0
    assert outcome.info["components"]["r_bypass_pg2"] == 1.0
    assert outcome.info["components"]["r_bypass_fw"] == 1.0
def test_reward_lower_when_pg2_flags():
    """The same payload scores strictly lower when the pg2 stub flags it."""
    permissive_arena = _make_env(pg2=_AlwaysBenignPG2())
    flagging_arena = _make_env(pg2=_AlwaysMaliciousPG2())
    shared_payload = "Please summarize this document carefully."

    # Both envs are reset with the same seed before either one is stepped.
    permissive_arena.reset(seed=0)
    flagging_arena.reset(seed=0)

    permissive_outcome = permissive_arena.step(InjectAction(payload=shared_payload))
    reward_unflagged = permissive_outcome.reward
    flagging_outcome = flagging_arena.step(InjectAction(payload=shared_payload))
    reward_flagged = flagging_outcome.reward

    assert reward_unflagged > reward_flagged
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
def test_state_reflects_episode():
    """Right after ``reset``, ``state`` matches the observation and shows no progress."""
    arena = _make_env()
    initial_obs = arena.reset()

    # Read the state once and check every field against that one snapshot.
    snapshot = arena.state

    assert snapshot["scenario_id"] == initial_obs.scenario_id
    assert snapshot["attempts"] == 0
    assert not snapshot["done"]
    assert snapshot["episode_id"] is not None
def test_state_done_after_exhaustion():
    """After three steps, ``state["done"]`` is truthy."""
    arena = _make_env()
    arena.reset()

    for attempt_no in range(3):
        action = InjectAction(payload=f"try {attempt_no}")
        arena.step(action)

    assert arena.state["done"]
# ---------------------------------------------------------------------------
# Reset between episodes
# ---------------------------------------------------------------------------
def test_reset_clears_previous_episode():
    """A second ``reset`` discards the earlier episode's attempts and done flag."""
    arena = _make_env()

    # First episode: take one step so there is history to be cleared.
    arena.reset(seed=0)
    arena.step(InjectAction(payload="old attempt"))

    second_obs = arena.reset(seed=1)

    assert second_obs.previous_attempts == []
    assert second_obs.attempts_remaining == 3
    assert not arena.state["done"]