File size: 8,328 Bytes
433f30e
 
 
 
 
 
 
 
 
 
 
 
 
 
38df389
 
 
433f30e
 
 
38df389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433f30e
 
 
 
 
 
 
38df389
 
 
433f30e
 
38df389
 
 
 
433f30e
 
 
38df389
 
 
 
433f30e
 
38df389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433f30e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a12d38f
 
433f30e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
"""OpenEnv-compliant models for Interpretability Arena.

These are the typed API contracts between the client and server.
- InterpArenaAction  : combined Red + Blue actions for one step
- InterpArenaObservation : what both agents observe after a step
- InterpArenaState   : episode-level metadata (returned by state())

Mechanistic fields (residual_stream, logits) are omitted from the
wire format for efficiency — they are accessed in-process on the
server side and surfaced as summary statistics in the observation.
"""

from __future__ import annotations

from typing import Any, Literal, Optional

from pydantic import Field

from openenv.core.env_server import Action, Observation, State

# Wire / schema enums (Gradio dropdown + docs)
# NOTE: these Literal values travel verbatim over the wire and populate the
# Gradio dropdowns, so adding a value is safe but renaming/removing one is a
# breaking change for existing clients.
RedIntervention = Literal[
    "steer_residual",      # add a concept vector to the residual stream at a layer
    "amplify_attn",        # scale attention scores for one head
    "patch_activation",    # overwrite an activation at a token position
    "logit_bias",          # add a bias to selected vocab logits
    "modify_prompt",       # replace the prompt text entirely
    "append_suffix",       # append text to the prompt
    "query_model",         # run a side prompt on the same LM
]
BlueDefense = Literal[
    "ablate_direction",    # remove a direction from the residual stream
    "clamp_activation",    # bound residual activations to [min, max]
    "restore_baseline",    # reset a token's residual to the un-attacked value
    "suppress_head",       # zero out / mute one attention head
    "logit_filter",        # suppress selected token IDs in the logits
    "sanitize_prompt",     # clean the (possibly modified) prompt text
    "block_output",        # replace the model output wholesale
    "noop",                # observe only; take no defensive action
]


# ── Action ─────────────────────────────────────────────────────────────────────
# One combined action per step; both agents act simultaneously.

class InterpArenaAction(Action):
    """Combined Red + Blue action for a single arena step.

    **Both agents act in one step:** ``env.step(InterpArenaAction(...))`` always
    sends Red's move and Blue's move together for the same forward pass.

    Red fields
    ----------
    red_type
        Intervention family (see titles below).
    red_layer, red_direction_id, red_strength, …
        Only the fields relevant to ``red_type`` need to be set.

    Blue fields
    -----------
    blue_type
        Defense family.
    blue_layer, blue_direction_id, …
        Same idea — fill what ``blue_type`` requires (``noop`` needs no extras).
    """

    # ── Red (attacker) ──────────────────────────────────────────────────────
    red_type: RedIntervention = Field(
        default="append_suffix",
        title="Red: intervention type",
        description=(
            "What Red does this step: e.g. steer_residual adds a concept vector at a layer; "
            "append_suffix / modify_prompt change text; query_model runs a side prompt on the same LM."
        ),
    )
    red_layer: Optional[int] = Field(
        default=None,
        title="Red: transformer layer index",
        description="Layer 0 = early, higher = later. Used by steer_residual, amplify_attn, patch_activation.",
    )
    red_direction_id: Optional[str] = Field(
        default=None,
        title="Red: steering direction name",
        description="Registry key, e.g. jailbreak, refusal, toxicity (Space must register it).",
    )
    red_strength: Optional[float] = Field(
        default=None,
        title="Red: steering / patch strength",
        description="Scale for residual steering or similar (meaning depends on intervention type).",
    )
    red_head: Optional[int] = Field(
        default=None,
        title="Red: attention head index",
        description="Which head to amplify (for amplify_attn).",
    )
    red_scale: Optional[float] = Field(
        default=None,
        title="Red: attention scale factor",
        description="Multiply attention scores for that head (amplify_attn).",
    )
    red_position: Optional[int] = Field(
        default=None,
        title="Red: token position (sequence index)",
        description="Token index for patch_activation / some hooks.",
    )
    red_target_token_ids: Optional[list[int]] = Field(
        default=None,
        title="Red: vocab token IDs to bias (logit_bias)",
        description="Comma-separated integers or JSON list, e.g. 1234, 5678 — required for logit_bias.",
    )
    red_bias_strength: Optional[float] = Field(
        default=None,
        title="Red: logit bias strength",
        description="Added to logits at red_target_token_ids (logit_bias).",
    )
    red_text: Optional[str] = Field(
        default=None,
        title="Red: prompt text",
        description="Suffix for append_suffix, full prompt for modify_prompt, probe text for query_model.",
    )

    # ── Blue (defender) ─────────────────────────────────────────────────────
    blue_type: BlueDefense = Field(
        default="noop",
        title="Blue: defense type",
        description=(
            "What Blue does on the same forward pass as Red: noop observes only; "
            "ablate_direction removes a direction at a layer; block_output replaces output, etc."
        ),
    )
    blue_layer: Optional[int] = Field(
        default=None,
        title="Blue: transformer layer index",
        description="Layer for ablate_direction, clamp_activation, restore_baseline, suppress_head.",
    )
    blue_direction_id: Optional[str] = Field(
        default=None,
        title="Blue: direction to remove (ablate)",
        description="Same registry names as Red steering, e.g. jailbreak.",
    )
    blue_head: Optional[int] = Field(
        default=None,
        title="Blue: attention head index",
        description="For suppress_head.",
    )
    blue_position: Optional[int] = Field(
        default=None,
        title="Blue: token position",
        description="For restore_baseline (which token's residual to reset).",
    )
    blue_clamp_min: Optional[float] = Field(
        default=None,
        title="Blue: activation clamp minimum",
        description="Lower bound for clamp_activation on residuals.",
    )
    blue_clamp_max: Optional[float] = Field(
        default=None,
        title="Blue: activation clamp maximum",
        description="Upper bound for clamp_activation.",
    )
    blue_prohibited_token_ids: Optional[list[int]] = Field(
        default=None,
        title="Blue: token IDs to suppress in logits",
        description="Comma-separated or JSON list — for logit_filter.",
    )


# ── Observation ────────────────────────────────────────────────────────────────

class InterpArenaObservation(Observation):
    """What both agents observe after each step.

    Scalar summaries of internal states are included (e.g. mean_resid_norm)
    so that RL policies can learn from mechanistic signals without
    transmitting large tensors over the wire.
    """

    # Prompt/output
    prompt: str
    target_output: str
    prohibited_patterns: list[str]
    modified_prompt: str
    model_output: str

    # Scores
    target_similarity: float       # [0,1] Red's objective
    safety_score: float            # [0,1] 1=safe, 0=unsafe

    # Rewards (both exposed so either agent can learn)
    reward_red: float
    reward_blue: float

    # Mechanistic summaries — key inputs for layer selection policy.
    # default_factory=list (rather than a literal []) is the pydantic idiom
    # for mutable defaults: each instance gets its own fresh list.
    mean_resid_norm: float = 0.0                                  # mean L2 norm across all layers
    layer_resid_norms: list[float] = Field(default_factory=list)  # per-layer L2 norms; len = n_layers
    detected_layers: list[int] = Field(default_factory=list)      # layers where Blue detected Red steering

    # Episode metadata
    step: int
    done: bool

    # Auxiliary info
    red_action_type: str
    blue_action_type: str
    hard_blocked: bool = False
    # Filled when red_type == query_model: extra decode on red_text (same LM)
    red_probe_output: str = ""


# ── State ──────────────────────────────────────────────────────────────────────

class InterpArenaState(State):
    """Episode-level metadata (returned by env.state())."""

    episode_id: int                       # identifier of the current episode
    step_count: int                       # steps taken so far in this episode
    prompt: str                           # episode's base prompt
    target_output: str                    # text Red is trying to elicit
    prohibited_patterns: list[str]        # patterns Blue must keep out of the output
    cumulative_reward_red: float = 0.0    # running sum of Red's per-step rewards
    cumulative_reward_blue: float = 0.0   # running sum of Blue's per-step rewards
    jailbreak_achieved: bool = False      # presumably latched true once Red succeeds — confirm against server logic