Spaces:
Sleeping
Sleeping
| """Custom Gradio playground for Interpretability Arena. | |
| OpenEnv's default UI labels fields by snake_case → Title Case, which reads as | |
| vague ("Red Type"). We use Pydantic ``Field(title=...)`` on ``InterpArenaAction`` | |
| and build the form from those titles, with Red / Blue section headers. | |
| One ``env.step()`` always carries **both** agents: every field in | |
| ``InterpArenaAction`` is sent together (Red + Blue in the same payload). | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| from typing import Any, Dict, List, Optional, Type | |
| import gradio as gr | |
| from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import RedirectResponse | |
| from openenv.core.env_server.gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME | |
| from openenv.core.env_server.gradio_ui import get_gradio_display_title, _format_observation | |
| from openenv.core.env_server.http_server import create_fastapi_app | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import Action, EnvironmentMetadata, Observation | |
| from openenv.core.env_server.web_interface import ( | |
| WebInterfaceManager, | |
| _extract_action_fields, | |
| _is_chat_env, | |
| get_quick_start_markdown, | |
| load_environment_metadata, | |
| ) | |
def _arena_enrich_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
    """Return OpenEnv's extracted action fields, annotated with schema metadata.

    Starts from ``_extract_action_fields(action_cls)``. Each field dict (looked
    up by its ``name``) then gains two keys read from the ``properties`` of
    ``action_cls.model_json_schema()``:

    - ``title``: the property's ``title``, or ``None`` when the property or its
      title is missing.
    - ``schema_description``: the property's ``description`` with surrounding
      whitespace stripped, or ``""`` when missing or empty.

    The field dicts are updated in place and the same list is returned.
    """
    enriched = _extract_action_fields(action_cls)
    schema_props: Dict[str, Any] = action_cls.model_json_schema().get("properties", {})
    for entry in enriched:
        prop = schema_props.get(entry["name"], {})
        entry["title"] = prop.get("title")
        description = prop.get("description") or ""
        entry["schema_description"] = description.strip()
    return enriched
| def _coerce_web_field(name: str, val: Any, field_type: str) -> Any: | |
| """Map Gradio widget values → JSON types ``InterpArenaAction`` expects.""" | |
| if val is None or val == "": | |
| return None | |
| if field_type == "checkbox": | |
| return bool(val) | |
| int_like = ( | |
| name.endswith("_layer") | |
| or name.endswith("_head") | |
| or name.endswith("_position") | |
| or name.endswith("_token_ids") | |
| or name.endswith("_ids") | |
| ) | |
| if field_type == "number" and int_like and isinstance(val, float): | |
| if val == int(val): | |
| return int(val) | |
| if name in ("red_target_token_ids", "blue_prohibited_token_ids") and isinstance(val, str): | |
| s = val.strip() | |
| if not s: | |
| return None | |
| try: | |
| if s.startswith("["): | |
| return json.loads(s) | |
| return [int(x.strip()) for x in s.split(",") if x.strip()] | |
| except (ValueError, json.JSONDecodeError): | |
| return val | |
| return val | |
def build_arena_gradio_app(
    web_manager: Any,
    action_fields: List[Dict[str, Any]],
    metadata: Optional[EnvironmentMetadata],
    is_chat_env: bool,
    title: str = "OpenEnv Environment",
    quick_start_md: Optional[str] = None,
) -> gr.Blocks:
    """Build the arena playground: OpenEnv's ``build_gradio_app`` with clearer labels.

    Layout: Quick Start / README accordions on the left. The right column holds
    the observation panel, the action form, the Step / Reset / Get state buttons
    and the Status / raw JSON outputs. Each form widget is labelled with its
    field's ``title`` (falling back to the Title-Cased field name). A
    "Red (attacker)" / "Blue (defender)" header is inserted before the first
    ``red_*`` / ``blue_*`` field.

    Args:
        web_manager: Object providing awaitable ``reset_environment()`` and
            ``step_environment(action_data)`` plus a synchronous
            ``get_state()`` (e.g. OpenEnv's ``WebInterfaceManager``).
        action_fields: Field descriptors (e.g. from
            ``_arena_enrich_action_fields``). Each must have ``name`` and may
            carry ``type``, ``title``, ``placeholder``, ``choices`` and
            ``schema_description``.
        metadata: Environment metadata used for the README and display title.
        is_chat_env: When true, render a single "Action message" textbox
            instead of the per-field form.
        title: Passed as ``fallback`` to ``get_gradio_display_title``.
        quick_start_md: Optional Markdown shown in an open "Quick Start" accordion.

    Returns:
        The (unlaunched) ``gr.Blocks`` app.
    """
    from openenv.core.env_server.gradio_ui import _readme_section

    readme_content = _readme_section(metadata)
    display_title = get_gradio_display_title(metadata, fallback=title)

    def _render(data: Dict[str, Any], message: str) -> tuple[str, str, str]:
        """Format a server response as (observation markdown, raw JSON, status)."""
        return (_format_observation(data), json.dumps(data, indent=2), message)

    async def reset_env():
        try:
            data = await web_manager.reset_environment()
            return _render(data, "Environment reset successfully.")
        except Exception as e:  # UI boundary: report any failure in the Status box
            return ("", "", f"Error: {e}")

    async def _step(action_data: Dict[str, Any]):
        try:
            data = await web_manager.step_environment(action_data)
            return _render(data, "Step complete.")
        except Exception as e:  # UI boundary: report any failure in the Status box
            return ("", "", f"Error: {e}")

    async def step_chat(message: Optional[str]):
        # Reject None and whitespace-only input. The previous guard,
        # ``not (message or str(message).strip())``, let "   " through as an
        # empty message and sent None as the literal string "None".
        text = "" if message is None else str(message).strip()
        if not text:
            return ("", "", "Please enter an action message.")
        return await _step({"message": text})

    async def step_form(*values: Any):
        # One Step sends every field (Red + Blue) in a single payload. Widgets
        # whose coerced value is None or "" are left out of the payload.
        action_data: Dict[str, Any] = {}
        for field, raw in zip(action_fields, values):
            fname = field["name"]
            coerced = _coerce_web_field(fname, raw, field.get("type", "text"))
            if coerced is not None and coerced != "":
                action_data[fname] = coerced
        return await _step(action_data)

    def get_state_sync():
        try:
            data = web_manager.get_state()
            return json.dumps(data, indent=2)
        except Exception as e:  # UI boundary: show the error in the JSON panel
            return f"Error: {e}"

    def _make_widget(field: Dict[str, Any]):
        """Create the Gradio input component for one action field."""
        field_type = field.get("type", "text")
        placeholder = field.get("placeholder", "")
        kwargs: Dict[str, Any] = {
            "label": field.get("title") or field["name"].replace("_", " ").title()
        }
        info = (field.get("schema_description") or "")[:500]  # keep hints short
        if info:
            kwargs["info"] = info
        if field_type == "checkbox":
            return gr.Checkbox(**kwargs)
        if field_type == "number":
            return gr.Number(**kwargs)
        if field_type == "select":
            return gr.Dropdown(
                choices=field.get("choices") or [],
                allow_custom_value=False,
                **kwargs,
            )
        if field_type in ("textarea", "tensor"):
            return gr.Textbox(placeholder=placeholder, lines=3, **kwargs)
        return gr.Textbox(placeholder=placeholder, **kwargs)

    with gr.Blocks(title=display_title) as demo:
        with gr.Row():
            with gr.Column(scale=1, elem_classes="col-left"):
                if quick_start_md:
                    with gr.Accordion("Quick Start", open=True):
                        gr.Markdown(quick_start_md)
                with gr.Accordion("README", open=False):
                    gr.Markdown(readme_content)

            with gr.Column(scale=2, elem_classes="col-right"):
                obs_display = gr.Markdown(
                    value=(
                        "# Playground\n\n"
                        "Click **Reset** to start a new episode.\n\n"
                        "**Each Step sends Red and Blue together** — fill both sides "
                        "(or leave Blue on *noop*). One combined action → one model forward."
                    ),
                )
                with gr.Group():
                    if is_chat_env:
                        action_input = gr.Textbox(
                            label="Action message",
                            placeholder="e.g. Enter your message...",
                        )
                        step_inputs = [action_input]
                        step_fn = step_chat
                    else:
                        step_inputs = []
                        seen_red = seen_blue = False
                        for field in action_fields:
                            name = field["name"]
                            if name.startswith("red_") and not seen_red:
                                gr.Markdown(
                                    "#### Red (attacker)\n"
                                    "Pick an intervention type, then only the parameters that action uses "
                                    "(others can stay empty)."
                                )
                                seen_red = True
                            if name.startswith("blue_") and not seen_blue:
                                gr.Markdown(
                                    "#### Blue (defender)\n"
                                    "Pick how Blue intervenes on the **same** forward pass as Red."
                                )
                                seen_blue = True
                            # Created while the Blocks context is active, so the
                            # widget is placed here inside the Group.
                            step_inputs.append(_make_widget(field))
                        step_fn = step_form

                with gr.Row():
                    step_btn = gr.Button("Step", variant="primary")
                    reset_btn = gr.Button("Reset", variant="secondary")
                    state_btn = gr.Button("Get state", variant="secondary")
                with gr.Row():
                    # Not named ``status``: that would shadow ``fastapi.status``.
                    status_box = gr.Textbox(label="Status", interactive=False)
                raw_json = gr.Code(
                    label="Raw JSON response",
                    language="json",
                    interactive=False,
                )

        outputs = [obs_display, raw_json, status_box]
        reset_btn.click(fn=reset_env, outputs=outputs)
        step_btn.click(fn=step_fn, inputs=step_inputs, outputs=outputs)
        if is_chat_env:
            action_input.submit(fn=step_fn, inputs=step_inputs, outputs=outputs)
        state_btn.click(fn=get_state_sync, outputs=[raw_json])

    return demo
def create_arena_web_interface_app(
    env: Environment | Type[Environment],
    action_cls: Type[Action],
    observation_cls: Type[Observation],
    env_name: Optional[str] = None,
    max_concurrent_envs: Optional[int] = None,
    concurrency_config: Optional[Any] = None,
) -> FastAPI:
    """Build the FastAPI app: OpenEnv's HTTP API plus the arena Gradio UI.

    Same as OpenEnv ``create_web_interface_app`` but mounts the arena-specific
    playground from ``build_arena_gradio_app`` at ``/web``.

    Args:
        env: Environment instance or class, forwarded to ``create_fastapi_app``
            and ``WebInterfaceManager``.
        action_cls: Action model. Its fields drive the Gradio form.
        observation_cls: Observation model.
        env_name: Passed to ``load_environment_metadata``.
        max_concurrent_envs: Forwarded to ``create_fastapi_app``.
        concurrency_config: Forwarded to ``create_fastapi_app``.

    Returns:
        The app returned by ``gr.mount_gradio_app`` with the UI at ``/web``.
    """
    app = create_fastapi_app(
        env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
    )
    metadata = load_environment_metadata(env, env_name)
    web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)

    # These handlers were previously defined without route decorators and were
    # therefore unreachable. Paths mirror OpenEnv's ``create_web_interface_app``.
    # NOTE(review): confirm against the installed OpenEnv version.
    # They are registered before the Gradio mount, so e.g. GET /web/metadata is
    # matched here rather than by the mounted UI.
    @app.get("/", include_in_schema=False)
    async def web_root():
        return RedirectResponse(url="/web/")

    @app.get("/web", include_in_schema=False)
    async def web_root_no_slash():
        return RedirectResponse(url="/web/")

    @app.get("/web/metadata")
    async def web_metadata():
        return web_manager.metadata.model_dump()

    @app.websocket("/ws/ui")
    async def websocket_ui_endpoint(websocket: WebSocket):
        await web_manager.connect_websocket(websocket)
        try:
            # Incoming text is read and discarded, only to keep the socket open.
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass
        finally:
            # Unregister on any exit, not only a clean disconnect, so the
            # manager never keeps a dead socket.
            await web_manager.disconnect_websocket(websocket)

    @app.post("/web/reset")
    async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)):
        return await web_manager.reset_environment(request)

    @app.post("/web/step")
    async def web_step(request: Dict[str, Any]):
        # Chat-style payload {"message": ...}: converted via the env's
        # ``message_to_action`` when available. Otherwise the action dict is
        # read from request["action"].
        if "message" in request:
            message = request["message"]
            if hasattr(web_manager.env, "message_to_action"):
                action = web_manager.env.message_to_action(message)
                if hasattr(action, "tokens"):
                    # presumably a tensor; ``tolist()`` makes it JSON-serializable
                    action_data = {"tokens": action.tokens.tolist()}
                else:
                    action_data = action.model_dump(exclude={"metadata"})
            else:
                action_data = {"message": message}
        else:
            action_data = request.get("action", {})
        return await web_manager.step_environment(action_data)

    @app.get("/web/state")
    async def web_state():
        try:
            return web_manager.get_state()
        except RuntimeError as exc:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=str(exc),
            ) from exc

    action_fields = _arena_enrich_action_fields(action_cls)
    is_chat_env = _is_chat_env(action_cls)
    quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls)
    gradio_blocks = build_arena_gradio_app(
        web_manager,
        action_fields,
        metadata,
        is_chat_env,
        title=metadata.name,
        quick_start_md=quick_start_md,
    )
    return gr.mount_gradio_app(
        app,
        gradio_blocks,
        path="/web",
        theme=OPENENV_GRADIO_THEME,
        css=OPENENV_GRADIO_CSS,
    )