siege / models.py
BART-ender's picture
Upload folder using huggingface_hub
38df389 verified
"""OpenEnv-compliant models for Interpretability Arena.
These are the typed API contracts between the client and server.
- InterpArenaAction : combined Red + Blue actions for one step
- InterpArenaObservation : what both agents observe after a step
- InterpArenaState : episode-level metadata (returned by state())
Mechanistic fields (residual_stream, logits) are omitted from the
wire format for efficiency — they are accessed in-process on the
server side and surfaced as summary statistics in the observation.
"""
from __future__ import annotations
from typing import Any, Literal, Optional
from pydantic import Field
from openenv.core.env_server import Action, Observation, State
# Closed string vocabularies for the two ``*_type`` action fields.
# They double as the wire/schema enums (Gradio dropdown choices + docs).

# Red (attacker) intervention families.
RedIntervention = Literal[
    "steer_residual",    # add a concept vector at a layer
    "amplify_attn",      # scale the attention scores of one head
    "patch_activation",  # layer/position-addressed activation patch
    "logit_bias",        # add a bias to the logits of chosen token IDs
    "modify_prompt",     # replace the prompt text
    "append_suffix",     # append text to the prompt
    "query_model",       # run a side prompt on the same LM
]

# Blue (defender) defense families.
BlueDefense = Literal[
    "ablate_direction",  # remove a direction at a layer
    "clamp_activation",  # bound residual activations to [min, max]
    "restore_baseline",  # reset a token's residual
    "suppress_head",     # head-addressed defense
    "logit_filter",      # suppress chosen token IDs in the logits
    "sanitize_prompt",
    "block_output",      # replace the output
    "noop",              # observe only; needs no extra fields
]
# ── Action ─────────────────────────────────────────────────────────────────────
# One combined action per step; both agents act simultaneously.
class InterpArenaAction(Action):
    """Combined Red + Blue action for a single arena step.

    **Both agents act in one step:** ``env.step(InterpArenaAction(...))`` always
    sends Red's move and Blue's move together for the same forward pass.

    Red fields
    ----------
    red_type
        Intervention family (see the field titles below).
    red_layer, red_direction_id, red_strength, …
        Only the fields relevant to ``red_type`` need to be set.

    Blue fields
    -----------
    blue_type
        Defense family.
    blue_layer, blue_direction_id, …
        Same idea — fill what ``blue_type`` requires (``noop`` needs no extras).

    Notes
    -----
    Every type-specific field is optional and defaults to ``None``. No
    cross-field validators are declared on this class, so checking that the
    extras match the chosen ``red_type`` / ``blue_type`` is left to whatever
    consumes the action.
    """

    # ── Red (attacker) ──────────────────────────────────────────────────────
    red_type: RedIntervention = Field(
        default="append_suffix",
        title="Red: intervention type",
        description=(
            "What Red does this step: e.g. steer_residual adds a concept vector at a layer; "
            "append_suffix / modify_prompt change text; query_model runs a side prompt on the same LM."
        ),
    )
    red_layer: Optional[int] = Field(
        default=None,
        title="Red: transformer layer index",
        description="Layer 0 = early, higher = later. Used by steer_residual, amplify_attn, patch_activation.",
    )
    red_direction_id: Optional[str] = Field(
        default=None,
        title="Red: steering direction name",
        description="Registry key, e.g. jailbreak, refusal, toxicity (Space must register it).",
    )
    red_strength: Optional[float] = Field(
        default=None,
        title="Red: steering / patch strength",
        description="Scale for residual steering or similar (meaning depends on intervention type).",
    )
    red_head: Optional[int] = Field(
        default=None,
        title="Red: attention head index",
        description="Which head to amplify (for amplify_attn).",
    )
    red_scale: Optional[float] = Field(
        default=None,
        title="Red: attention scale factor",
        description="Multiply attention scores for that head (amplify_attn).",
    )
    red_position: Optional[int] = Field(
        default=None,
        title="Red: token position (sequence index)",
        description="Token index for patch_activation / some hooks.",
    )
    red_target_token_ids: Optional[list[int]] = Field(
        default=None,
        title="Red: vocab token IDs to bias (logit_bias)",
        # Mis-decoded em dash repaired; this text is shown in the schema/UI.
        description="Comma-separated integers or JSON list, e.g. 1234, 5678 — required for logit_bias.",
    )
    red_bias_strength: Optional[float] = Field(
        default=None,
        title="Red: logit bias strength",
        description="Added to logits at red_target_token_ids (logit_bias).",
    )
    red_text: Optional[str] = Field(
        default=None,
        title="Red: prompt text",
        description="Suffix for append_suffix, full prompt for modify_prompt, probe text for query_model.",
    )

    # ── Blue (defender) ─────────────────────────────────────────────────────
    blue_type: BlueDefense = Field(
        default="noop",
        title="Blue: defense type",
        description=(
            "What Blue does on the same forward pass as Red: noop observes only; "
            "ablate_direction removes a direction at a layer; block_output replaces output, etc."
        ),
    )
    blue_layer: Optional[int] = Field(
        default=None,
        title="Blue: transformer layer index",
        description="Layer for ablate_direction, clamp_activation, restore_baseline, suppress_head.",
    )
    blue_direction_id: Optional[str] = Field(
        default=None,
        title="Blue: direction to remove (ablate)",
        description="Same registry names as Red steering, e.g. jailbreak.",
    )
    blue_head: Optional[int] = Field(
        default=None,
        title="Blue: attention head index",
        description="For suppress_head.",
    )
    blue_position: Optional[int] = Field(
        default=None,
        title="Blue: token position",
        # Mis-decoded apostrophe repaired.
        description="For restore_baseline (which token’s residual to reset).",
    )
    blue_clamp_min: Optional[float] = Field(
        default=None,
        title="Blue: activation clamp minimum",
        description="Lower bound for clamp_activation on residuals.",
    )
    blue_clamp_max: Optional[float] = Field(
        default=None,
        title="Blue: activation clamp maximum",
        description="Upper bound for clamp_activation.",
    )
    blue_prohibited_token_ids: Optional[list[int]] = Field(
        default=None,
        title="Blue: token IDs to suppress in logits",
        # Mis-decoded em dash repaired; this text is shown in the schema/UI.
        description="Comma-separated or JSON list — for logit_filter.",
    )
# ── Observation ────────────────────────────────────────────────────────────────
class InterpArenaObservation(Observation):
    """Per-step observation shared by the Red and Blue agents.

    Rather than shipping large tensors over the wire, the observation
    carries scalar/list summaries of internal state (for example
    ``mean_resid_norm``) so RL policies can still learn from mechanistic
    signals.
    """

    # ── Prompt / output text ────────────────────────────────────────────────
    prompt: str
    target_output: str
    prohibited_patterns: list[str]
    modified_prompt: str
    model_output: str

    # ── Scores ──────────────────────────────────────────────────────────────
    # In [0, 1]; this is Red's objective.
    target_similarity: float
    # In [0, 1]; 1 = safe, 0 = unsafe.
    safety_score: float

    # ── Rewards ─────────────────────────────────────────────────────────────
    # Both rewards are exposed so that either agent can learn from them.
    reward_red: float
    reward_blue: float

    # ── Mechanistic summaries (key inputs for a layer-selection policy) ─────
    # Mean L2 norm across all layers.
    mean_resid_norm: float = 0.0
    # Per-layer L2 norms; length equals the number of layers.
    layer_resid_norms: list[float] = []
    # Layers where Blue detected Red steering.
    detected_layers: list[int] = []

    # ── Episode metadata ────────────────────────────────────────────────────
    step: int
    done: bool

    # ── Auxiliary info ──────────────────────────────────────────────────────
    red_action_type: str
    blue_action_type: str
    hard_blocked: bool = False
    # Populated when red_type == query_model: the extra decode of red_text
    # on the same LM. Empty string otherwise.
    red_probe_output: str = ""
# ── State ──────────────────────────────────────────────────────────────────────
class InterpArenaState(State):
    """Episode-scoped metadata, as returned by ``env.state()``."""

    # Episode identity and progress.
    episode_id: int
    step_count: int

    # Task definition for the episode.
    prompt: str
    target_output: str
    prohibited_patterns: list[str]

    # Running totals and outcome flag; each starts at its zero/False default.
    cumulative_reward_red: float = 0.0
    cumulative_reward_blue: float = 0.0
    jailbreak_achieved: bool = False