"""OpenEnv-compliant models for Interpretability Arena.

These are the typed API contracts between the client and server.

- InterpArenaAction : combined Red + Blue actions for one step
- InterpArenaObservation : what both agents observe after a step
- InterpArenaState : episode-level metadata (returned by state())

Mechanistic fields (residual_stream, logits) are omitted from the
wire format for efficiency — they are accessed in-process on the
server side and surfaced as summary statistics in the observation.
"""
from __future__ import annotations
from typing import Any, Literal, Optional
from pydantic import Field
from openenv.core.env_server import Action, Observation, State
# Wire / schema enums (Gradio dropdown + docs)
# Closed vocabulary of Red (attacker) moves. The chosen value determines
# which red_* fields of InterpArenaAction are relevant for the step.
RedIntervention = Literal[
    "steer_residual",
    "amplify_attn",
    "patch_activation",
    "logit_bias",
    "modify_prompt",
    "append_suffix",
    "query_model",
]
# Closed vocabulary of Blue (defender) moves. "noop" observes without
# intervening and needs no extra blue_* fields.
BlueDefense = Literal[
    "ablate_direction",
    "clamp_activation",
    "restore_baseline",
    "suppress_head",
    "logit_filter",
    "sanitize_prompt",
    "block_output",
    "noop",
]
# ── Action ─────────────────────────────────────────────────────────────────────
# One combined action per step; both agents act simultaneously.
class InterpArenaAction(Action):
    """Combined Red + Blue action for a single arena step.

    **Both agents act in one step:** ``env.step(InterpArenaAction(...))`` always
    sends Red's move and Blue's move together for the same forward pass.

    Red fields
    ----------
    red_type
        Intervention family (see titles below).
    red_layer, red_direction_id, red_strength, …
        Only the fields relevant to ``red_type`` need to be set.

    Blue fields
    -----------
    blue_type
        Defense family.
    blue_layer, blue_direction_id, …
        Same idea — fill what ``blue_type`` requires (``noop`` needs no extras).
    """

    # ── Red (attacker) ──────────────────────────────────────────────────────
    red_type: RedIntervention = Field(
        default="append_suffix",
        title="Red: intervention type",
        description=(
            "What Red does this step: e.g. steer_residual adds a concept vector at a layer; "
            "append_suffix / modify_prompt change text; query_model runs a side prompt on the same LM."
        ),
    )
    red_layer: Optional[int] = Field(
        default=None,
        title="Red: transformer layer index",
        description="Layer 0 = early, higher = later. Used by steer_residual, amplify_attn, patch_activation.",
    )
    red_direction_id: Optional[str] = Field(
        default=None,
        title="Red: steering direction name",
        description="Registry key, e.g. jailbreak, refusal, toxicity (Space must register it).",
    )
    red_strength: Optional[float] = Field(
        default=None,
        title="Red: steering / patch strength",
        description="Scale for residual steering or similar (meaning depends on intervention type).",
    )
    red_head: Optional[int] = Field(
        default=None,
        title="Red: attention head index",
        description="Which head to amplify (for amplify_attn).",
    )
    red_scale: Optional[float] = Field(
        default=None,
        title="Red: attention scale factor",
        description="Multiply attention scores for that head (amplify_attn).",
    )
    red_position: Optional[int] = Field(
        default=None,
        title="Red: token position (sequence index)",
        description="Token index for patch_activation / some hooks.",
    )
    red_target_token_ids: Optional[list[int]] = Field(
        default=None,
        title="Red: vocab token IDs to bias (logit_bias)",
        description="Comma-separated integers or JSON list, e.g. 1234, 5678 — required for logit_bias.",
    )
    red_bias_strength: Optional[float] = Field(
        default=None,
        title="Red: logit bias strength",
        description="Added to logits at red_target_token_ids (logit_bias).",
    )
    red_text: Optional[str] = Field(
        default=None,
        title="Red: prompt text",
        description="Suffix for append_suffix, full prompt for modify_prompt, probe text for query_model.",
    )

    # ── Blue (defender) ─────────────────────────────────────────────────────
    blue_type: BlueDefense = Field(
        default="noop",
        title="Blue: defense type",
        description=(
            "What Blue does on the same forward pass as Red: noop observes only; "
            "ablate_direction removes a direction at a layer; block_output replaces output, etc."
        ),
    )
    blue_layer: Optional[int] = Field(
        default=None,
        title="Blue: transformer layer index",
        description="Layer for ablate_direction, clamp_activation, restore_baseline, suppress_head.",
    )
    blue_direction_id: Optional[str] = Field(
        default=None,
        title="Blue: direction to remove (ablate)",
        description="Same registry names as Red steering, e.g. jailbreak.",
    )
    blue_head: Optional[int] = Field(
        default=None,
        title="Blue: attention head index",
        description="For suppress_head.",
    )
    blue_position: Optional[int] = Field(
        default=None,
        title="Blue: token position",
        description="For restore_baseline (which token's residual to reset).",
    )
    blue_clamp_min: Optional[float] = Field(
        default=None,
        title="Blue: activation clamp minimum",
        description="Lower bound for clamp_activation on residuals.",
    )
    blue_clamp_max: Optional[float] = Field(
        default=None,
        title="Blue: activation clamp maximum",
        description="Upper bound for clamp_activation.",
    )
    blue_prohibited_token_ids: Optional[list[int]] = Field(
        default=None,
        title="Blue: token IDs to suppress in logits",
        description="Comma-separated or JSON list — for logit_filter.",
    )
# ── Observation ────────────────────────────────────────────────────────────────
class InterpArenaObservation(Observation):
    """What both agents observe after each step.

    Scalar summaries of internal states are included (e.g. mean_resid_norm)
    so that RL policies can learn from mechanistic signals without
    transmitting large tensors over the wire.
    """

    # Prompt/output
    prompt: str
    target_output: str
    prohibited_patterns: list[str]
    modified_prompt: str
    model_output: str
    # Scores
    target_similarity: float  # [0,1] Red's objective
    safety_score: float  # [0,1] 1=safe, 0=unsafe
    # Rewards (both exposed so either agent can learn)
    reward_red: float
    reward_blue: float
    # Mechanistic summaries — key inputs for layer selection policy.
    # default_factory (not `= []`) so no mutable default is shared; pydantic
    # copies plain defaults anyway, but the explicit factory is unambiguous.
    mean_resid_norm: float = 0.0  # mean L2 norm across all layers
    layer_resid_norms: list[float] = Field(default_factory=list)  # per-layer L2 norms; len = n_layers
    detected_layers: list[int] = Field(default_factory=list)  # layers where Blue detected Red steering
    # Episode metadata
    step: int
    done: bool
    # Auxiliary info
    red_action_type: str
    blue_action_type: str
    hard_blocked: bool = False
    # Filled when red_type == query_model: extra decode on red_text (same LM)
    red_probe_output: str = ""
# ── State ──────────────────────────────────────────────────────────────────────
class InterpArenaState(State):
    """Episode-level metadata (returned by env.state())."""

    # Identity and progress of the current episode.
    episode_id: int
    step_count: int
    # Task definition for the episode (mirrors the observation's fields).
    prompt: str
    target_output: str
    prohibited_patterns: list[str]
    # Running reward totals for each side over the episode.
    cumulative_reward_red: float = 0.0
    cumulative_reward_blue: float = 0.0
    # Set once Red's objective has been met at any step this episode.
    jailbreak_achieved: bool = False