Spaces:
Sleeping
Sleeping
| """InterpArenaEnv — OpenEnv EnvClient for Interpretability Arena. | |
| Follows the OpenEnv packaging guide: WebSocket `EnvClient` with explicit | |
| ``_step_payload`` / ``_parse_result`` / ``_parse_state`` implementations. | |
| See: https://meta-pytorch.org/OpenEnv/auto_getting_started/environment-builder.html | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from models import InterpArenaAction, InterpArenaObservation, InterpArenaState | |
| class InterpArenaEnv(EnvClient[InterpArenaAction, InterpArenaObservation, InterpArenaState]): | |
| """Client for the Interpretability Arena environment. | |
| Async usage:: | |
| async with InterpArenaEnv(base_url="http://localhost:8000") as env: | |
| result = await env.reset() | |
| result = await env.step(InterpArenaAction( | |
| red_type="steer_residual", | |
| red_layer=8, | |
| red_direction_id="jailbreak", | |
| red_strength=3.0, | |
| blue_type="ablate_direction", | |
| blue_layer=8, | |
| blue_direction_id="jailbreak", | |
| )) | |
| Sync usage:: | |
| with InterpArenaEnv(base_url="http://localhost:8000").sync() as env: | |
| result = env.reset() | |
| result = env.step(InterpArenaAction( | |
| red_type="append_suffix", | |
| blue_type="noop", | |
| )) | |
| """ | |
| action_type = InterpArenaAction | |
| observation_type = InterpArenaObservation | |
| state_type = InterpArenaState | |
| def _step_payload(self, action: InterpArenaAction) -> dict[str, Any]: | |
| if hasattr(action, "model_dump"): | |
| return action.model_dump() | |
| return dict(action) | |
| def _parse_result( | |
| self, payload: dict[str, Any] | |
| ) -> StepResult[InterpArenaObservation]: | |
| # Matches openenv…serialization.serialize_observation wire shape: | |
| # { "observation": {...}, "reward": optional, "done": bool } | |
| obs_inner: Dict[str, Any] = dict(payload.get("observation") or {}) | |
| if "done" in payload and "done" not in obs_inner: | |
| obs_inner["done"] = payload["done"] | |
| if "reward" in payload and "reward" not in obs_inner: | |
| obs_inner["reward"] = payload["reward"] | |
| obs = InterpArenaObservation.model_validate(obs_inner) | |
| return StepResult( | |
| observation=obs, | |
| reward=payload.get("reward"), | |
| done=bool(payload.get("done", False)), | |
| ) | |
| def _parse_state(self, payload: dict[str, Any]) -> InterpArenaState: | |
| data = payload.get("state", payload) | |
| if isinstance(data, InterpArenaState): | |
| return data | |
| if isinstance(data, dict): | |
| return InterpArenaState.model_validate(data) | |
| raise TypeError(f"Cannot parse state from {type(data)!r}") | |