siege / server /web_playground.py
BART-ender's picture
Upload folder using huggingface_hub
38df389 verified
"""Custom Gradio playground for Interpretability Arena.
OpenEnv's default UI labels fields by snake_case → Title Case, which reads as
vague ("Red Type"). We use Pydantic ``Field(title=...)`` on ``InterpArenaAction``
and build the form from those titles, with Red / Blue section headers.
One ``env.step()`` always carries **both** agents: every field in
``InterpArenaAction`` is sent together (Red + Blue in the same payload).
"""
from __future__ import annotations
import asyncio
import json
from typing import Any, Dict, List, Optional, Type
import gradio as gr
from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect
from fastapi.responses import RedirectResponse
from openenv.core.env_server.gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
from openenv.core.env_server.gradio_ui import get_gradio_display_title, _format_observation
from openenv.core.env_server.http_server import create_fastapi_app
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import Action, EnvironmentMetadata, Observation
from openenv.core.env_server.web_interface import (
WebInterfaceManager,
_extract_action_fields,
_is_chat_env,
get_quick_start_markdown,
load_environment_metadata,
)
def _arena_enrich_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
"""Like OpenEnv's extractor, plus JSON Schema ``title`` / ``description``."""
fields = _extract_action_fields(action_cls)
props = action_cls.model_json_schema().get("properties", {})
for f in fields:
name = f["name"]
fi = props.get(name, {})
f["title"] = fi.get("title")
f["schema_description"] = (fi.get("description") or "").strip()
return fields
def _coerce_web_field(name: str, val: Any, field_type: str) -> Any:
"""Map Gradio widget values → JSON types ``InterpArenaAction`` expects."""
if val is None or val == "":
return None
if field_type == "checkbox":
return bool(val)
int_like = (
name.endswith("_layer")
or name.endswith("_head")
or name.endswith("_position")
or name.endswith("_token_ids")
or name.endswith("_ids")
)
if field_type == "number" and int_like and isinstance(val, float):
if val == int(val):
return int(val)
if name in ("red_target_token_ids", "blue_prohibited_token_ids") and isinstance(val, str):
s = val.strip()
if not s:
return None
try:
if s.startswith("["):
return json.loads(s)
return [int(x.strip()) for x in s.split(",") if x.strip()]
except (ValueError, json.JSONDecodeError):
return val
return val
def build_arena_gradio_app(
web_manager: Any,
action_fields: List[Dict[str, Any]],
metadata: Optional[EnvironmentMetadata],
is_chat_env: bool,
title: str = "OpenEnv Environment",
quick_start_md: Optional[str] = None,
) -> gr.Blocks:
"""Gradio Blocks mirroring OpenEnv's ``build_gradio_app`` with clearer labels."""
from openenv.core.env_server.gradio_ui import _readme_section
readme_content = _readme_section(metadata)
display_title = get_gradio_display_title(metadata, fallback=title)
async def reset_env():
try:
data = await web_manager.reset_environment()
obs_md = _format_observation(data)
return (
obs_md,
json.dumps(data, indent=2),
"Environment reset successfully.",
)
except Exception as e:
return ("", "", f"Error: {e}")
def _step_with_action(action_data: Dict[str, Any]):
async def _run():
try:
data = await web_manager.step_environment(action_data)
obs_md = _format_observation(data)
return (
obs_md,
json.dumps(data, indent=2),
"Step complete.",
)
except Exception as e:
return ("", "", f"Error: {e}")
return _run
async def step_chat(message: str):
if not (message or str(message).strip()):
return ("", "", "Please enter an action message.")
action = {"message": str(message).strip()}
return await _step_with_action(action)()
def get_state_sync():
try:
data = web_manager.get_state()
return json.dumps(data, indent=2)
except Exception as e:
return f"Error: {e}"
with gr.Blocks(title=display_title) as demo:
with gr.Row():
with gr.Column(scale=1, elem_classes="col-left"):
if quick_start_md:
with gr.Accordion("Quick Start", open=True):
gr.Markdown(quick_start_md)
with gr.Accordion("README", open=False):
gr.Markdown(readme_content)
with gr.Column(scale=2, elem_classes="col-right"):
obs_display = gr.Markdown(
value=(
"# Playground\n\n"
"Click **Reset** to start a new episode.\n\n"
"**Each Step sends Red and Blue together** — fill both sides "
"(or leave Blue on *noop*). One combined action → one model forward."
),
)
with gr.Group():
if is_chat_env:
action_input = gr.Textbox(
label="Action message",
placeholder="e.g. Enter your message...",
)
step_inputs = [action_input]
step_fn = step_chat
else:
step_inputs = []
seen_red = False
seen_blue = False
for field in action_fields:
name = field["name"]
if name.startswith("red_") and not seen_red:
gr.Markdown(
"#### Red (attacker)\n"
"Pick an intervention type, then only the parameters that action uses "
"(others can stay empty)."
)
seen_red = True
if name.startswith("blue_") and not seen_blue:
gr.Markdown(
"#### Blue (defender)\n"
"Pick how Blue intervenes on the **same** forward pass as Red."
)
seen_blue = True
field_type = field.get("type", "text")
label = field.get("title") or name.replace("_", " ").title()
placeholder = field.get("placeholder", "")
info = (field.get("schema_description") or "")[:500] or None
common_kw: dict[str, Any] = {"label": label}
if info:
common_kw["info"] = info
if field_type == "checkbox":
inp = gr.Checkbox(**common_kw)
elif field_type == "number":
inp = gr.Number(**common_kw)
elif field_type == "select":
choices = field.get("choices") or []
inp = gr.Dropdown(
choices=choices,
allow_custom_value=False,
**common_kw,
)
elif field_type in ("textarea", "tensor"):
inp = gr.Textbox(
placeholder=placeholder,
lines=3,
**common_kw,
)
else:
inp = gr.Textbox(placeholder=placeholder, **common_kw)
step_inputs.append(inp)
async def step_form(*values):
if not action_fields:
return await _step_with_action({})()
action_data: Dict[str, Any] = {}
for i, field in enumerate(action_fields):
if i >= len(values):
break
fname = field["name"]
raw = values[i]
ft = field.get("type", "text")
coerced = _coerce_web_field(fname, raw, ft)
if coerced is not None and coerced != "":
action_data[fname] = coerced
return await _step_with_action(action_data)()
step_fn = step_form
with gr.Row():
step_btn = gr.Button("Step", variant="primary")
reset_btn = gr.Button("Reset", variant="secondary")
state_btn = gr.Button("Get state", variant="secondary")
with gr.Row():
status = gr.Textbox(label="Status", interactive=False)
raw_json = gr.Code(
label="Raw JSON response",
language="json",
interactive=False,
)
reset_btn.click(fn=reset_env, outputs=[obs_display, raw_json, status])
step_btn.click(fn=step_fn, inputs=step_inputs, outputs=[obs_display, raw_json, status])
if is_chat_env:
action_input.submit(fn=step_fn, inputs=step_inputs, outputs=[obs_display, raw_json, status])
state_btn.click(fn=get_state_sync, outputs=[raw_json])
return demo
def create_arena_web_interface_app(
env: Environment | Type[Environment],
action_cls: Type[Action],
observation_cls: Type[Observation],
env_name: Optional[str] = None,
max_concurrent_envs: Optional[int] = None,
concurrency_config: Optional[Any] = None,
) -> FastAPI:
"""Same as OpenEnv ``create_web_interface_app`` but with arena-specific Gradio UI."""
app = create_fastapi_app(
env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
)
metadata = load_environment_metadata(env, env_name)
web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
@app.get("/", include_in_schema=False)
async def web_root():
return RedirectResponse(url="/web/")
@app.get("/web", include_in_schema=False)
async def web_root_no_slash():
return RedirectResponse(url="/web/")
@app.get("/web/metadata")
async def web_metadata():
return web_manager.metadata.model_dump()
@app.websocket("/ws/ui")
async def websocket_ui_endpoint(websocket: WebSocket):
await web_manager.connect_websocket(websocket)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
await web_manager.disconnect_websocket(websocket)
@app.post("/web/reset")
async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)):
return await web_manager.reset_environment(request)
@app.post("/web/step")
async def web_step(request: Dict[str, Any]):
if "message" in request:
message = request["message"]
if hasattr(web_manager.env, "message_to_action"):
action = web_manager.env.message_to_action(message)
if hasattr(action, "tokens"):
action_data = {"tokens": action.tokens.tolist()}
else:
action_data = action.model_dump(exclude={"metadata"})
else:
action_data = {"message": message}
else:
action_data = request.get("action", {})
return await web_manager.step_environment(action_data)
@app.get("/web/state")
async def web_state():
try:
return web_manager.get_state()
except RuntimeError as exc:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(exc),
) from exc
action_fields = _arena_enrich_action_fields(action_cls)
is_chat_env = _is_chat_env(action_cls)
quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls)
gradio_blocks = build_arena_gradio_app(
web_manager,
action_fields,
metadata,
is_chat_env,
title=metadata.name,
quick_start_md=quick_start_md,
)
return gr.mount_gradio_app(
app,
gradio_blocks,
path="/web",
theme=OPENENV_GRADIO_THEME,
css=OPENENV_GRADIO_CSS,
)