File size: 2,668 Bytes
433f30e
 
 
a12d38f
 
38df389
 
 
 
a12d38f
38df389
a12d38f
38df389
a12d38f
 
38df389
 
 
 
a12d38f
 
433f30e
38df389
433f30e
 
 
 
38df389
 
 
 
 
 
 
 
 
 
 
 
 
 
433f30e
 
 
38df389
 
 
433f30e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""FastAPI server entry point for Interpretability Arena."""

import os
import sys

# OpenEnv runs the sync reset() in a thread pool, so without this block the first
# `import transformer_lens` would happen there. Importing the stack eagerly on the main
# import path means: (1) the server refuses to start unless it runs in an environment
# with the same torch / transformers / transformer-lens stack as the training client, and
# (2) a broken stack stops the process at startup instead of surfacing as a cryptic
# remote error on the first WebSocket reset.
try:
    import torch  # noqa: F401 — transformer-lens / HF expect torch imported first
    from transformers import BertForPreTraining  # noqa: F401
    import transformer_lens  # noqa: F401
except Exception as e:  # noqa: BLE001 — startup boundary: report and exit, never continue
    # Report the exception type as well as its message: str(e) alone can be empty or
    # ambiguous (e.g. an AttributeError vs an ImportError raised deep inside the stack).
    sys.exit(
        "Arena server: could not import torch + transformer-lens (same stack as the training client). "
        "The process running uvicorn must use the project venv: `uv run uvicorn server.app:app --host 0.0.0.0 --port 8000`. "
        "If the error mentions Bert, reinstall pins: `uv sync` or `pip install -r server/requirements.txt --force-reinstall`, "
        "then stop any other process on that port and start again. "
        f"Original: {type(e).__name__}: {e}"
    )

from openenv.core.env_server import create_fastapi_app

from models import InterpArenaAction, InterpArenaObservation
from server.interp_arena_environment import InterpArenaEnvironment

# Prefer the arena-specific Gradio UI when server/web_playground.py is importable;
# otherwise fall back to OpenEnv's default web UI factory (so clones without that file,
# e.g. from before it was added, still work).
try:
    from server.web_playground import create_arena_web_interface_app
except ModuleNotFoundError:
    from openenv.core.env_server import (
        create_web_interface_app as create_arena_web_interface_app,
    )

# Accept the common truthy spellings (same set OpenEnv uses); anything else, including an
# unset variable, serves the plain API app.
_ENABLE_WEB_INTERFACE = os.environ.get("ENABLE_WEB_INTERFACE", "false").strip().lower() in {
    "true",
    "1",
    "yes",
}

# OpenEnv expects the environment *class* (or factory), not an instance—
# the HTTP server instantiates it per its own lifecycle.
if _ENABLE_WEB_INTERFACE:
    # The fallback factory can be None (presumably when Gradio is not installed). Only
    # treat that as fatal when the web UI was actually requested; the API-only server
    # below does not need it.
    if create_arena_web_interface_app is None:
        sys.exit(
            "ENABLE_WEB_INTERFACE=true requires Gradio. Install `gradio` (or use the default "
            "deps from `uv sync`) or add `server/web_playground.py` from the repo."
        )
    app = create_arena_web_interface_app(
        InterpArenaEnvironment, InterpArenaAction, InterpArenaObservation
    )
else:
    app = create_fastapi_app(InterpArenaEnvironment, InterpArenaAction, InterpArenaObservation)


def main() -> None:
    """Serve the module-level ``app`` with uvicorn.

    Console entry for OpenEnv multi-mode / `openenv validate` and `uv run server`.
    Binds to ``HOST`` (default ``0.0.0.0``) and ``PORT`` (default ``8000``); a variable
    that is set but empty is treated as unset.

    Exits with a readable message (instead of a traceback) when ``PORT`` is not an
    integer in the range 0-65535.
    """
    import uvicorn

    # `or` (not a .get default) so that PORT="" / HOST="" fall back to the defaults.
    raw_port = os.environ.get("PORT") or "8000"
    try:
        port = int(raw_port)
    except ValueError:
        sys.exit(f"Arena server: PORT must be an integer, got {raw_port!r}.")
    if not 0 <= port <= 65535:
        sys.exit(f"Arena server: PORT must be between 0 and 65535, got {port}.")
    host = os.environ.get("HOST") or "0.0.0.0"
    uvicorn.run(app, host=host, port=port)


# Allow `python server/app.py` in addition to the console-script entry point.
if __name__ == "__main__":
    main()