Spaces:
Sleeping
Sleeping
File size: 2,877 Bytes
"""InterpArenaEnv — OpenEnv ``EnvClient`` for the Interpretability Arena.

Follows the OpenEnv packaging guide: a WebSocket ``EnvClient`` with explicit
``_step_payload`` / ``_parse_result`` / ``_parse_state`` implementations.

See: https://meta-pytorch.org/OpenEnv/auto_getting_started/environment-builder.html
"""
from __future__ import annotations
from typing import Any, Dict
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from models import InterpArenaAction, InterpArenaObservation, InterpArenaState
class InterpArenaEnv(EnvClient[InterpArenaAction, InterpArenaObservation, InterpArenaState]):
    """Client for the Interpretability Arena environment.

    Async usage::

        async with InterpArenaEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
            result = await env.step(InterpArenaAction(
                red_type="steer_residual",
                red_layer=8,
                red_direction_id="jailbreak",
                red_strength=3.0,
                blue_type="ablate_direction",
                blue_layer=8,
                blue_direction_id="jailbreak",
            ))

    Sync usage::

        with InterpArenaEnv(base_url="http://localhost:8000").sync() as env:
            result = env.reset()
            result = env.step(InterpArenaAction(
                red_type="append_suffix",
                blue_type="noop",
            ))
    """

    action_type = InterpArenaAction
    observation_type = InterpArenaObservation
    state_type = InterpArenaState

    def _step_payload(self, action: InterpArenaAction) -> dict[str, Any]:
        """Serialize *action* into the dict sent to the server.

        Uses ``action.model_dump()`` when the action provides it; otherwise
        falls back to ``dict(action)``.
        """
        if hasattr(action, "model_dump"):
            return action.model_dump()
        return dict(action)

    def _parse_result(
        self, payload: dict[str, Any]
    ) -> StepResult[InterpArenaObservation]:
        """Build a ``StepResult`` from a server step/reset response.

        Matches the openenv ``serialize_observation`` wire shape::

            {"observation": {...}, "reward": optional, "done": bool}

        Args:
            payload: Decoded response envelope. A missing or falsy
                ``"observation"`` entry is treated as an empty observation.

        Returns:
            A ``StepResult`` whose ``done``/``reward`` come from the top-level
            envelope when present, and otherwise from the observation's own
            ``done``/``reward`` fields. ``done`` defaults to ``False`` and
            ``reward`` to ``None`` when neither source has them.
        """
        obs_inner: dict[str, Any] = dict(payload.get("observation") or {})
        # Copy envelope-level done/reward into the observation only when the
        # observation does not already carry its own value.
        for key in ("done", "reward"):
            if key in payload:
                obs_inner.setdefault(key, payload[key])
        obs = InterpArenaObservation.model_validate(obs_inner)

        # Prefer the envelope values; fall back to the observation's fields so
        # the StepResult never contradicts its own observation when the server
        # omits the top-level keys.
        reward = payload["reward"] if "reward" in payload else obs_inner.get("reward")
        done = payload["done"] if "done" in payload else obs_inner.get("done", False)
        return StepResult(
            observation=obs,
            reward=reward,
            done=bool(done),
        )

    def _parse_state(self, payload: dict[str, Any]) -> InterpArenaState:
        """Build an ``InterpArenaState`` from a server state response.

        Accepts either ``{"state": {...}}`` or the state dict itself. An
        already-constructed ``InterpArenaState`` is returned unchanged.

        Raises:
            TypeError: If the state is neither an ``InterpArenaState`` nor a
                dict.
        """
        data = payload.get("state", payload)
        if isinstance(data, InterpArenaState):
            return data
        if isinstance(data, dict):
            return InterpArenaState.model_validate(data)
        raise TypeError(f"Cannot parse state from {type(data)!r}")
|