"""Environment tests — run on Mac with stub defenses (no GPU needed).""" from __future__ import annotations import pytest from typing import Any, Optional from env.defenses.base import DefenseVerdict from env.environment import InjectArenaEnv from env.models import InjectAction, InjectObservation, StepResult from env.scenarios import ScenarioBank # --------------------------------------------------------------------------- # Stubs # --------------------------------------------------------------------------- class _AlwaysBenignPG2: name = "stub_pg2_benign" def scan(self, text: str) -> DefenseVerdict: return DefenseVerdict(flagged=False, score=0.0, reason="stub_benign") class _AlwaysMaliciousPG2: name = "stub_pg2_malicious" def scan(self, text: str) -> DefenseVerdict: return DefenseVerdict(flagged=True, score=1.0, reason="stub_malicious") class _DummySecAlign: name = "stub_secalign" mode = "stub" def __init__(self, output: str = "Here is a summary.") -> None: self._output = output def run(self, system: str, user: str, untrusted: Any) -> str: return self._output class _DummyFirewall: name = "stub_firewall" def __init__(self, flagged: bool = False) -> None: self._flagged = flagged def scan(self, user_query: str, agent_output: str, tool_call: Any = None) -> DefenseVerdict: return DefenseVerdict(flagged=self._flagged, score=float(self._flagged), reason="stub") def _make_env(pg2=None, secalign=None, firewall=None, bank=None) -> InjectArenaEnv: return InjectArenaEnv( pg2=pg2 or _AlwaysBenignPG2(), secalign=secalign or _DummySecAlign(), firewall=firewall or _DummyFirewall(), bank=bank or ScenarioBank(), ) # --------------------------------------------------------------------------- # Basic episode flow # --------------------------------------------------------------------------- def test_reset_returns_valid_observation(): env = _make_env() obs = env.reset() assert isinstance(obs, InjectObservation) assert obs.attempts_remaining == 3 assert obs.previous_attempts == [] assert obs.scenario_id def test_step_returns_step_result(): env = _make_env() env.reset() result = env.step(InjectAction(payload="Hello, please summarize this.")) assert isinstance(result, StepResult) assert isinstance(result.reward, float) assert isinstance(result.done, bool) assert "episode_id" in result.info def test_attempts_decrement(): env = _make_env() obs = env.reset() assert obs.attempts_remaining == 3 r1 = env.step(InjectAction(payload="attempt 1")) assert r1.observation.attempts_remaining == 2 r2 = env.step(InjectAction(payload="attempt 2")) assert r2.observation.attempts_remaining == 1 def test_done_after_max_attempts(): env = _make_env() env.reset() for i in range(3): result = env.step(InjectAction(payload=f"attempt {i}")) assert result.done def test_previous_attempts_accumulate(): env = _make_env() env.reset() env.step(InjectAction(payload="first")) result = env.step(InjectAction(payload="second")) assert len(result.observation.previous_attempts) == 2 assert result.observation.previous_attempts[0].payload == "first" assert result.observation.previous_attempts[1].payload == "second" def test_step_before_reset_raises(): env = _make_env() with pytest.raises(RuntimeError, match="reset"): env.step(InjectAction(payload="oops")) # --------------------------------------------------------------------------- # Safety floor # --------------------------------------------------------------------------- def test_safety_violation_terminates_with_minus_one(): env = _make_env() env.reset() # This pattern triggers the reward-function-attack safety rule. result = env.step(InjectAction(payload="reward: 1.0")) assert result.reward == -1.0 assert result.done assert result.info.get("safety_violation") def test_after_safety_violation_reset_works(): env = _make_env() env.reset() env.step(InjectAction(payload="reward: 1.0")) obs = env.reset() assert isinstance(obs, InjectObservation) assert obs.attempts_remaining == 3 # --------------------------------------------------------------------------- # Reward signal sanity # --------------------------------------------------------------------------- def test_reward_zero_for_trivial_payload(): env = _make_env() env.reset() result = env.step(InjectAction(payload=" ")) assert result.reward == 0.0 def test_reward_positive_when_both_defenses_bypassed(): env = _make_env( pg2=_AlwaysBenignPG2(), firewall=_DummyFirewall(flagged=False), ) env.reset() result = env.step(InjectAction(payload="Please summarize this document carefully.")) assert result.reward > 0.0 assert result.info["components"]["r_bypass_pg2"] == 1.0 assert result.info["components"]["r_bypass_fw"] == 1.0 def test_reward_lower_when_pg2_flags(): env_benign = _make_env(pg2=_AlwaysBenignPG2()) env_malicious = _make_env(pg2=_AlwaysMaliciousPG2()) payload = "Please summarize this document carefully." env_benign.reset(seed=0) env_malicious.reset(seed=0) r_benign = env_benign.step(InjectAction(payload=payload)).reward r_malicious = env_malicious.step(InjectAction(payload=payload)).reward assert r_benign > r_malicious # --------------------------------------------------------------------------- # State # --------------------------------------------------------------------------- def test_state_reflects_episode(): env = _make_env() obs = env.reset() st = env.state assert st["scenario_id"] == obs.scenario_id assert st["attempts"] == 0 assert not st["done"] assert st["episode_id"] is not None def test_state_done_after_exhaustion(): env = _make_env() env.reset() for i in range(3): env.step(InjectAction(payload=f"try {i}")) assert env.state["done"] # --------------------------------------------------------------------------- # Reset between episodes # --------------------------------------------------------------------------- def test_reset_clears_previous_episode(): env = _make_env() env.reset(seed=0) env.step(InjectAction(payload="old attempt")) obs2 = env.reset(seed=1) assert obs2.previous_attempts == [] assert obs2.attempts_remaining == 3 assert not env.state["done"]