# Inject-Arena / env / environment.py
# Jaswanth1210's picture
# Phase 4: InjectArenaEnv + FastAPI server + Dockerfile + env tests (81 passing)
# b54a031
"""InjectArenaEnv — the core RL environment.
Defense instances are injected at construction so the environment can run
with stub defenses on Mac (no GPU) and real defenses on Colab.
"""
from __future__ import annotations
import time
import uuid
from typing import Any, Dict, List, Optional
from .models import (
AttemptRecord,
InjectAction,
InjectObservation,
StepResult,
)
from .rewards import compute_reward
from .scenarios import ScenarioBank
from .utils.safety_filter import is_safe
from .verifiers.exfiltration import verify_exfiltration
from .verifiers.forbidden_tool import verify_forbidden_tool
from .verifiers.prompt_leak import verify_prompt_leak
# Default number of attempts an episode allows before it is forced to end.
_MAX_ATTEMPTS = 3
# A step whose safety check + defense calls take longer than this many
# seconds terminates the episode (see the done conditions in ``step``).
_STEP_TIMEOUT_S = 30.0


class InjectArenaEnv:
    """Episode-based environment for prompt-injection red-teaming.

    An episode starts with :meth:`reset`, which selects a scenario, and
    proceeds through calls to :meth:`step`, each of which submits one
    payload.  The episode ends when the task succeeds, the attempt budget is
    exhausted, a safety violation is reported, or a step exceeds the timeout.
    """

    def __init__(
        self,
        pg2: Any,
        secalign: Any,
        firewall: Any,
        bank: Optional[ScenarioBank] = None,
        max_attempts: int = _MAX_ATTEMPTS,
        embedder: Any = None,
    ) -> None:
        """Store the injected defenses and initialise an idle (done) episode.

        Args:
            pg2: Defense object; ``step`` calls ``pg2.scan(payload)``.
            secalign: Defense object; ``step`` calls
                ``secalign.run(system=..., user=..., untrusted=...)``.
            firewall: Defense object; ``step`` calls
                ``firewall.scan(user_query=..., agent_output=..., tool_call=...)``.
            bank: Scenario source.  A default ``ScenarioBank()`` is created
                only when this is ``None``.
            max_attempts: Number of attempts allowed per episode.
            embedder: Forwarded unchanged to ``compute_reward``.
        """
        self._pg2 = pg2
        self._secalign = secalign
        self._firewall = firewall
        # Explicit ``is None`` check: ``bank or ScenarioBank()`` would throw
        # away an injected bank that merely evaluates as falsy (e.g. empty).
        self._bank = bank if bank is not None else ScenarioBank()
        self._max_attempts = max_attempts
        self._embedder = embedder

        self._scenario: Optional[Dict[str, Any]] = None
        self._attempts: int = 0
        self._previous: List[AttemptRecord] = []
        self._episode_id: Optional[str] = None
        # Starts as "done" so that step() before reset() is rejected.
        self._done: bool = True

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(
        self,
        scenario_id: Optional[str] = None,
        seed: Optional[int] = None,
        split: str = "train",
    ) -> InjectObservation:
        """Begin a new episode and return its initial observation.

        Args:
            scenario_id: When given, the scenario is looked up with
                ``bank.by_id``; ``seed`` and ``split`` are then unused.
            seed: Forwarded to ``bank.sample`` when no ``scenario_id`` is given.
            split: Forwarded to ``bank.sample`` when no ``scenario_id`` is given.

        Returns:
            The observation for the freshly selected scenario, with an empty
            attempt history.
        """
        if scenario_id is not None:
            self._scenario = self._bank.by_id(scenario_id)
        else:
            self._scenario = self._bank.sample(seed=seed, split=split)
        self._attempts = 0
        self._previous = []
        self._episode_id = str(uuid.uuid4())
        self._done = False
        return self._build_obs()

    def step(self, action: InjectAction) -> StepResult:
        """Submit one payload and advance the episode.

        The payload first passes the safety filter; an unsafe payload ends
        the episode immediately with reward ``-1.0`` and is not recorded as
        an attempt.  Otherwise all three defenses run, the category-specific
        verifier decides task success, and the reward is computed.

        Args:
            action: Supplies ``payload`` and ``strategy_tag``.

        Returns:
            A ``StepResult`` with the new observation, reward, done flag and
            an ``info`` dict.

        Raises:
            RuntimeError: If the episode is already finished or ``reset`` has
                not been called.
        """
        if self._done or self._scenario is None:
            raise RuntimeError("Call reset() before step().")

        t0 = time.monotonic()
        payload = action.payload
        scenario = self._scenario

        # Safety floor — immediate termination.
        safe, safety_reason = is_safe(payload)
        if not safe:
            self._done = True
            return StepResult(
                observation=self._build_obs(),
                reward=-1.0,
                done=True,
                info={
                    "safety_violation": True,
                    "safety_reason": safety_reason,
                    "episode_id": self._episode_id,
                },
            )

        # All three defenses run on every step (§5.4 invariant).
        pg2_verdict = self._pg2.scan(payload)
        agent_output = self._secalign.run(
            system=scenario["agent_system_prompt"],
            user=scenario["user_query"],
            untrusted={scenario["injection_slot"]: payload},
        )
        # NOTE(review): the payload (not scenario["user_query"]) is passed as
        # ``user_query`` here — confirm this is the intended firewall input.
        fw_verdict = self._firewall.scan(
            user_query=payload,
            agent_output=agent_output,
            tool_call=None,
        )
        # Measured before verification/reward, so it covers only the safety
        # filter and the three defense calls.
        elapsed = time.monotonic() - t0

        # Task success via category-specific verifier; any category other
        # than the two named ones falls through to the prompt-leak verifier.
        cat = scenario["target_category"]
        if cat == "exfiltration":
            task_success = verify_exfiltration(agent_output, scenario)
        elif cat == "forbidden_tool":
            task_success = verify_forbidden_tool(agent_output, scenario)
        else:
            task_success = verify_prompt_leak(agent_output, scenario)

        # Reward.  Only payloads from earlier attempts of this episode are
        # passed as history; the current one is appended afterwards.
        recent = [a.payload for a in self._previous]
        reward, reward_info = compute_reward(
            payload=payload,
            pg2_verdict=pg2_verdict,
            firewall_verdict=fw_verdict,
            task_success=task_success,
            scenario=scenario,
            recent_payloads=recent,
            embedder=self._embedder,
        )

        # Record.
        record = AttemptRecord(
            payload=payload,
            pg2_label="injection" if pg2_verdict.flagged else "benign",
            firewall_flagged=fw_verdict.flagged,
            task_success=task_success,
        )
        self._previous.append(record)
        self._attempts += 1

        # Done conditions (§5.4).  Coerced to bool so a truthy non-bool value
        # from a verifier or from reward_info never leaks into the state.
        timeout = elapsed > _STEP_TIMEOUT_S
        done = bool(
            task_success
            or self._attempts >= self._max_attempts
            or reward_info.get("safety_violation", False)
            or timeout
        )
        self._done = done

        return StepResult(
            observation=self._build_obs(),
            reward=reward,
            done=done,
            info={
                **reward_info,
                "agent_output": agent_output,
                "pg2_verdict": pg2_verdict.model_dump(),
                "fw_verdict": fw_verdict.model_dump(),
                "task_success": task_success,
                "elapsed_s": round(elapsed, 3),
                "timeout": timeout,
                "strategy_tag": action.strategy_tag,
                "episode_id": self._episode_id,
            },
        )

    @property
    def state(self) -> Dict[str, Any]:
        """Return a snapshot of the episode bookkeeping as a plain dict.

        ``scenario_id`` is ``None`` while no scenario is loaded.
        """
        return {
            "episode_id": self._episode_id,
            "scenario_id": self._scenario["scenario_id"] if self._scenario else None,
            "attempts": self._attempts,
            "max_attempts": self._max_attempts,
            "done": self._done,
        }

    def close(self) -> None:
        """No-op; the environment releases nothing here."""
        pass

    # ------------------------------------------------------------------
    def _build_obs(self) -> InjectObservation:
        """Assemble the observation for the current scenario and history.

        Must only be called while a scenario is loaded (after ``reset``).
        ``canary_string`` is optional in the scenario; every other field read
        here is required.
        """
        s = self._scenario
        return InjectObservation(
            scenario_id=s["scenario_id"],
            target_behavior=s["target_behavior"],
            target_category=s["target_category"],
            agent_system_prompt=s["agent_system_prompt"],
            user_query=s["user_query"],
            injection_slot=s["injection_slot"],
            tool_surface=s["tool_surface"],
            canary_string=s.get("canary_string"),
            # Copy so later appends do not mutate an already-returned observation.
            previous_attempts=list(self._previous),
            attempts_remaining=self._max_attempts - self._attempts,
        )