sergiopaniego HF Staff commited on
Commit
4e38d2f
·
verified ·
1 Parent(s): 2928bc6

Upload folder using huggingface_hub

Browse files
Files changed (13) hide show
  1. Dockerfile +84 -0
  2. README.md +198 -5
  3. __init__.py +26 -0
  4. client.py +117 -0
  5. models.py +57 -0
  6. openenv.yaml +7 -0
  7. pyproject.toml +52 -0
  8. rewards.py +129 -0
  9. server/__init__.py +11 -0
  10. server/app.py +90 -0
  11. server/environment.py +320 -0
  12. server/run_local.sh +7 -0
  13. uv.lock +0 -0
Dockerfile ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local src/core)
10
+ # - Standalone environments (with openenv-core from pip)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Build argument to control whether we're building standalone or in-repo
19
+ ARG BUILD_MODE=in-repo
20
+ ARG ENV_NAME=textarena
21
+
22
+ # Copy environment code (always at root of build context)
23
+ COPY . /app/env
24
+
25
+ # For in-repo builds, openenv-core is already in the pyproject.toml dependencies
26
+ # For standalone builds, openenv-core will be installed from pip via pyproject.toml
27
+ WORKDIR /app/env
28
+
29
+ # Ensure uv is available (for local builds where base image lacks it)
30
+ RUN if ! command -v uv >/dev/null 2>&1; then \
31
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
32
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
33
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
34
+ fi
35
+
36
+ # Install system libraries required by TextArena (cv2 needs libGL, glib)
37
+ # Also install git for building from git repos
38
+ RUN apt-get update && apt-get install -y --no-install-recommends \
39
+ libgl1 \
40
+ libglib2.0-0 \
41
+ git \
42
+ && rm -rf /var/lib/apt/lists/*
43
+
44
+ # Install dependencies using uv sync
45
+ # If uv.lock exists, use it; otherwise resolve on the fly
46
+ RUN --mount=type=cache,target=/root/.cache/uv \
47
+ if [ -f uv.lock ]; then \
48
+ uv sync --frozen --no-install-project --no-editable; \
49
+ else \
50
+ uv sync --no-install-project --no-editable; \
51
+ fi
52
+
53
+ RUN --mount=type=cache,target=/root/.cache/uv \
54
+ if [ -f uv.lock ]; then \
55
+ uv sync --frozen --no-editable; \
56
+ else \
57
+ uv sync --no-editable; \
58
+ fi
59
+
60
+ # Final runtime stage
61
+ FROM ${BASE_IMAGE}
62
+
63
+ WORKDIR /app
64
+
65
+ # Copy the virtual environment from builder
66
+ COPY --from=builder /app/env/.venv /app/.venv
67
+
68
+ # Copy the environment code
69
+ COPY --from=builder /app/env /app/env
70
+
71
+ # Set PATH to use the virtual environment
72
+ ENV PATH="/app/.venv/bin:$PATH"
73
+
74
+ # Set PYTHONPATH so imports work correctly
75
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
76
+
77
+ # Health check
78
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
79
+ CMD curl -f http://localhost:8000/health || exit 1
80
+
81
+ # Run the FastAPI server
82
+ # The module path is constructed to work with the /app/env structure
83
+ ENV ENABLE_WEB_INTERFACE=true
84
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,203 @@
1
  ---
2
- title: Textarena2
3
- emoji: 🦀
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: docker
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: TextArena Environment Server
3
+ emoji: 🎮
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # TextArena Environment
15
+
16
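+ An OpenEnv environment server that wraps TextArena text games (`Wordle-v0` by default) behind the standard reset/step API, with auxiliary reward signals for Wordle. Useful for testing the env APIs as well as demonstrating environment usage patterns.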
17
+
18
+ ## Quick Start
19
+
20
+ The simplest way to use the TextArena environment is through the `TextArenaEnv` class:
21
+
22
+ ```python
23
+ from textarena_env import TextArenaAction, TextArenaEnv
24
+
25
+ try:
26
+ # Create environment from Docker image
27
+ textArenaEnv = TextArenaEnv.from_docker_image("textarena-env:latest")
28
+
29
+ # Reset
30
+ result = textArenaEnv.reset()
31
+ print(f"Reset: {result.observation.prompt}")
32
+
33
+ # Send multiple messages
34
+ messages = ["Hello, World!", "Testing echo", "Final message"]
35
+
36
+ for msg in messages:
37
+ result = textArenaEnv.step(TextArenaAction(message=msg))
38
+ print(f"Sent: '{msg}'")
39
+ print(f" → Prompt: '{result.observation.prompt}'")
40
+ print(f" → Done: {result.done}")
41
+ print(f" → Reward: {result.reward}")
42
+
43
+ finally:
44
+ # Always clean up
45
+ textArenaEnv.close()
46
+ ```
47
+
48
+ That's it! The `TextArenaEnv.from_docker_image()` method handles:
49
+ - Starting the Docker container
50
+ - Waiting for the server to be ready
51
+ - Connecting to the environment
52
+ - Container cleanup when you call `close()`
53
+
54
+ ## Building the Docker Image
55
+
56
+ Before using the environment, you need to build the Docker image:
57
+
58
+ ```bash
59
+ # From project root
60
+ docker build -t textarena-env:latest -f server/Dockerfile .
61
+ ```
62
+
63
+ ## Deploying to Hugging Face Spaces
64
+
65
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
66
+
67
+ ```bash
68
+ # From the environment directory (where openenv.yaml is located)
69
+ openenv push
70
+
71
+ # Or specify options
72
+ openenv push --namespace my-org --private
73
+ ```
74
+
75
+ The `openenv push` command will:
76
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
77
+ 2. Prepare a custom build for Hugging Face Docker space (enables web interface)
78
+ 3. Upload to Hugging Face (ensuring you're logged in)
79
+
80
+ ### Prerequisites
81
+
82
+ - Authenticate with Hugging Face: The command will prompt for login if not already authenticated
83
+
84
+ ### Options
85
+
86
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
87
+ - `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
88
+ - `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
89
+ - `--private`: Deploy the space as private (default: public)
90
+
91
+ ### Examples
92
+
93
+ ```bash
94
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
95
+ openenv push
96
+
97
+ # Push to a specific repository
98
+ openenv push --repo-id my-org/my-env
99
+
100
+ # Push with a custom base image
101
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
102
+
103
+ # Push as a private space
104
+ openenv push --private
105
+
106
+ # Combine options
107
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
108
+ ```
109
+
110
+ After deployment, your space will be available at:
111
+ `https://huggingface.co/spaces/<repo-id>`
112
+
113
+ The deployed space includes:
114
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
115
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
116
+ - **Health Check** at `/health` - Container health monitoring
117
+
118
+ ## Environment Details
119
+
120
+ ### Action
121
+ **TextArenaAction**: Contains a single field
122
+ - `message` (str) - The text sent to the game (for Wordle, a guess such as `[crane]`)
123
+
124
+ ### Observation
125
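+ **TextArenaObservation**: Contains the game prompt and message history, plus `current_player_id` (int), `legal_players` (list of int) and `info` (dict)
126
+ - `prompt` (str) - Prompt text returned with the observation
127
+ - `messages` (list of TextArenaMessage) - Messages observed by the player (`sender_id`, `content`, `category`)
128
+ - `reward` (float) - Reward for the step (0.0 right after `reset()`)
129
+ - `done` (bool) - Whether the game reported the episode as finished
130
+ - `metadata` (dict) - Additional info such as `reward_signals`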
131
+
132
+ ### Reward
133
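+ The step reward is extracted by the server from the wrapped TextArena game. For `Wordle-v0`, auxiliary signals are also returned under `reward_signals` in `info` and `metadata`:
134
+ - `wordle.greens` / `wordle.yellows` → green / yellow letters in the latest feedback, divided by 5
135
+ - `wordle.repetitions` → 1.0 minus the number of times the same guess was already made in the episode
136
+ - `wordle.correct` → the reward attached to the observation (0.0 when absent)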
137
+
138
+ ## Advanced Usage
139
+
140
+ ### Connecting to an Existing Server
141
+
142
+ If you already have a TextArena environment server running, you can connect directly:
143
+
144
+ ```python
145
+ from textarena_env import TextArenaAction, TextArenaEnv
146
+
147
+ # Connect to existing server
148
+ textarenaenv = TextArenaEnv(base_url="<ENV_HTTP_URL_HERE>")
149
+
150
+ # Use as normal
151
+ result = textarenaenv.reset()
152
+ result = textarenaenv.step(TextArenaAction(message="Hello!"))
153
+ ```
154
+
155
+ Note: When connecting to an existing server, `textarenaenv.close()` will NOT stop the server.
156
+
157
+ ## Development & Testing
158
+
159
+ ### Direct Environment Testing
160
+
161
+ Test the environment logic directly without starting the HTTP server:
162
+
163
+ ```bash
164
+ # From the server directory
165
+ python3 server/textarena_environment.py
166
+ ```
167
+
168
+ This verifies that:
169
+ - Environment resets correctly
170
+ - Step executes actions properly
171
+ - State tracking works
172
+ - Rewards are calculated correctly
173
+
174
+ ### Running Locally
175
+
176
+ Run the server locally for development:
177
+
178
+ ```bash
179
+ # Install dependencies
180
+ uv venv && source .venv/bin/activate
181
+ uv pip install -e .
182
+
183
+ # Start the server (use python -m to ensure venv Python is used)
184
+ python -m uvicorn server.app:app --reload
185
+ ```
186
+
187
+ ## Project Structure
188
+
189
+ ```
190
+ textarena/
191
+ ├── __init__.py # Module exports
192
+ ├── README.md # This file
193
+ ├── openenv.yaml # OpenEnv manifest
194
+ ├── pyproject.toml # Project metadata and dependencies
195
+ ├── uv.lock # Locked dependencies (generated)
196
+ ├── client.py # TextArenaEnv client implementation
197
+ ├── models.py # Action and Observation models
198
+ └── server/
199
+ ├── __init__.py # Server module exports
200
+ ├── environment.py # Core environment logic
201
+ ├── app.py # FastAPI application
202
+ └── Dockerfile # Container image definition
203
+ ```
__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """TextArena environment integration for OpenEnv."""
8
+
9
+ from .client import TextArenaEnv
10
+ from .models import (
11
+ TextArenaAction,
12
+ TextArenaMessage,
13
+ TextArenaObservation,
14
+ TextArenaState,
15
+ )
16
+ from .rewards import RewardProvider, build_reward_providers
17
+
18
+ __all__ = [
19
+ "TextArenaEnv",
20
+ "TextArenaAction",
21
+ "TextArenaObservation",
22
+ "TextArenaState",
23
+ "TextArenaMessage",
24
+ "RewardProvider",
25
+ "build_reward_providers",
26
+ ]
client.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ TextArena Environment HTTP Client.
9
+
10
+ This module provides the client for connecting to a TextArena Environment server
11
+ over HTTP.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, Dict
17
+
18
+ from openenv.core.client_types import StepResult
19
+ from openenv.core.env_client import EnvClient
20
+
21
+ from .models import (
22
+ TextArenaAction,
23
+ TextArenaMessage,
24
+ TextArenaObservation,
25
+ TextArenaState,
26
+ )
27
+
28
+
29
class TextArenaEnv(EnvClient[TextArenaAction, TextArenaObservation, TextArenaState]):
    """
    HTTP client for the TextArena Environment.

    Connects to a TextArenaEnvironment HTTP server and converts between the
    typed models (``TextArenaAction``, ``TextArenaObservation``,
    ``TextArenaState``) and the JSON payloads exchanged with it.

    Example:
        >>> # Connect to a running server
        >>> client = TextArenaEnv(base_url="http://localhost:8000")
        >>> result = client.reset()
        >>> print(result.observation.prompt)
        >>>
        >>> # Send a message to the game
        >>> result = client.step(TextArenaAction(message="[crane]"))
        >>> print(result.observation.messages)
        >>> print(result.reward)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = TextArenaEnv.from_docker_image("textarena-env:latest")
        >>> result = client.reset()
        >>> result = client.step(TextArenaAction(message="[crane]"))
    """

    def _step_payload(self, action: TextArenaAction) -> Dict:
        """
        Convert TextArenaAction to JSON payload for step request.

        Args:
            action: TextArenaAction instance

        Returns:
            Dictionary with the single ``"message"`` key, suitable for JSON encoding
        """
        return {"message": action.message}

    def _parse_result(self, payload: Dict) -> StepResult[TextArenaObservation]:
        """
        Parse server response into StepResult[TextArenaObservation].

        Missing keys fall back to defaults. Container fields that arrive as
        JSON ``null`` are treated like missing ones so that a sparse response
        does not crash the client.

        Args:
            payload: JSON response from server

        Returns:
            StepResult with TextArenaObservation
        """
        # ``or {}`` / ``or []`` also covers keys that are present but null.
        obs_data = payload.get("observation") or {}
        messages = [
            TextArenaMessage(
                sender_id=item.get("sender_id", -1),
                content=item.get("content", ""),
                category=item.get("category", "MESSAGE"),
            )
            for item in obs_data.get("messages") or []
            if isinstance(item, dict)  # skip malformed entries
        ]

        # reward/done live at the top level of the payload and are mirrored
        # onto both the observation and the StepResult.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = TextArenaObservation(
            prompt=obs_data.get("prompt", ""),
            messages=messages,
            current_player_id=obs_data.get("current_player_id", 0),
            legal_players=obs_data.get("legal_players") or [],
            info=obs_data.get("info") or {},
            reward=reward,
            done=done,
            metadata=obs_data.get("metadata") or {},
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> TextArenaState:
        """
        Parse the server's state payload into a TextArenaState.

        Args:
            payload: JSON state response from server

        Returns:
            TextArenaState with defaults substituted for missing keys
        """
        return TextArenaState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            env_id=payload.get("env_id", "unknown"),
            num_players=payload.get("num_players", 1),
            max_turns=payload.get("max_turns"),
            turn=payload.get("turn", 0),
            last_reward=payload.get("last_reward", 0.0),
            last_info=payload.get("last_info") or {},
            raw_state=payload.get("raw_state") or {},
        )
models.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the TextArena Environment.
9
+
10
+ Typed messages, actions, observations, and state exchanged with TextArena games.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from pydantic import Field
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ from pydantic import BaseModel, Field
19
+
20
+ from openenv.core.env_server.types import Action, Observation, State
21
+
22
+
23
class TextArenaMessage(BaseModel):
    """One message entry carried in an observation's ``messages`` list."""

    # Id of the sender; the HTTP client substitutes -1 when the server omits it.
    sender_id: int
    # Message text.
    content: str
    # Message category label; the HTTP client defaults it to "MESSAGE".
    category: str
29
+
30
+
31
class TextArenaAction(Action):
    """Agent action for a TextArena game: a single text message."""

    # Text forwarded verbatim to the wrapped game's ``step`` call.
    message: str
35
+
36
+
37
class TextArenaObservation(Observation):
    """Observation produced by a TextArena game after ``reset`` or ``step``."""

    # Prompt text for the observation (required; no default).
    prompt: str
    # Messages observed by the player; empty list when none are supplied.
    messages: List[TextArenaMessage] = Field(default_factory=list)
    # Id of the player this observation is addressed to (presumably the
    # player to act next -- confirm against the server's observation builder).
    current_player_id: int = 0
    # Player ids considered legal; semantics are defined by the server.
    legal_players: List[int] = Field(default_factory=list)
    # Free-form extra data; the server stores "reward_signals" here.
    info: Dict[str, Any] = Field(default_factory=dict)
45
+
46
+
47
class TextArenaState(State):
    """Server-side state snapshot describing the running TextArena episode."""

    # TextArena game id the server was created with (e.g. "Wordle-v0").
    env_id: str
    # Number of players passed to the game on reset.
    num_players: int
    # Optional turn limit configured for the server; None when unset.
    max_turns: Optional[int] = None
    # Turn counter mirrored from the wrapped game's state.
    turn: int = 0
    # Reward recorded for the most recent step.
    last_reward: float = 0.0
    # Info dict recorded for the most recent step.
    last_info: Dict[str, Any] = Field(default_factory=dict)
    # Snapshot of the wrapped game's state, refreshed on reset and step.
    raw_state: Dict[str, Any] = Field(default_factory=dict)
57
+
openenv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: textarena
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
pyproject.toml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-textarena"
13
+ version = "0.1.0"
14
+ description = "TextArena environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv dependencies (required for server functionality)
18
+ "openenv-core @ git+https://github.com/meta-pytorch/OpenEnv.git@main",
19
+ "fastapi>=0.115.0",
20
+ "pydantic>=2.0.0",
21
+ "uvicorn>=0.24.0",
22
+ "requests>=2.31.0",
23
+ # Environment-specific dependencies
24
+ # Add all dependencies needed for your environment here
25
+ # Examples:
26
+ # "numpy>=1.19.0",
27
+ # "torch>=2.0.0",
28
+ # "gymnasium>=0.29.0",
29
+ # "openspiel>=1.0.0",
30
+ # "smolagents>=1.22.0,<2",
31
+ "textarena>=0.6.1",
32
+ "nltk>=3.9.2",
33
+ ]
34
+
35
+ [project.optional-dependencies]
36
+ dev = [
37
+ "pytest>=8.0.0",
38
+ "pytest-cov>=4.0.0",
39
+ ]
40
+
41
+ [project.scripts]
42
+ # Server entry point - enables running via: uv run --project . server
43
+ # or: python -m textarena_env.server.app
44
+ server = "textarena_env.server.app:main"
45
+
46
+ [tool.setuptools]
47
+ # Explicitly list packages - "textarena_env" maps to current dir
48
+ packages = ["textarena_env", "textarena_env.server"]
49
+ package-dir = {"textarena_env" = ".", "textarena_env.server" = "server"}
50
+
51
+
52
+
rewards.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward provider utilities for TextArena environments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from typing import Dict, List, Protocol, Tuple
7
+
8
+ try:
9
+ from textarena_env.models import TextArenaAction, TextArenaObservation
10
+ except ImportError:
11
+ from models import TextArenaAction, TextArenaObservation
12
+
13
+
14
class RewardProvider(Protocol):
    """Structural interface for objects that compute auxiliary reward signals."""

    def reset(self) -> None:
        """Discard any per-episode state before a new episode begins."""

    def compute(self, *, action: TextArenaAction, observation: TextArenaObservation) -> Dict[str, float]:
        """Compute this step's signals as a ``{signal name: float value}`` mapping."""
22
+
23
+
24
def build_reward_providers(env_id: str) -> List[RewardProvider]:
    """Create the reward providers that apply to ``env_id``.

    Only ``"Wordle-v0"`` has a provider; every other id yields a new empty list.
    """
    if env_id != "Wordle-v0":
        return []
    return [_WordleRewardProvider()]
31
+
32
+
33
# A bracketed five-letter guess such as "[crane]" (either case).
_WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]")


def extract_guess(text: str) -> str:
    """Turn free-form text into a normalized ``"[guess]"`` string.

    The first bracketed five-letter word wins and is lower-cased. Otherwise
    the first five ASCII letters found in the lower-cased text are used.
    When fewer than five letters are available the sentinel ``"[dunno]"``
    is returned.
    """
    bracketed = _WORDLE_GUESS_PATTERN.search(text)
    if bracketed is not None:
        return bracketed.group(0).lower()

    # Keep only a-z, mirroring a strip of everything outside that range.
    letters = "".join(ch for ch in text.lower() if "a" <= ch <= "z")
    return f"[{letters[:5]}]" if len(letters) >= 5 else "[dunno]"
47
+
48
+
49
def extract_wordle_feedback(observation: TextArenaObservation) -> str:
    """Return the text following ``"Feedback:"`` in the newest matching message.

    Messages are scanned from last to first. The result is whitespace-trimmed,
    and an empty string is returned when no message contains the marker.
    """
    marker = "Feedback:"
    for message in reversed(observation.messages):
        # partition splits on the first occurrence, like split(marker, 1).
        _, found, tail = message.content.strip().partition(marker)
        if found:
            return tail.strip()
    return ""
57
+
58
+
59
def extract_feedback_counts(feedback: str) -> Tuple[int, int]:
    """Count green (``G``) and yellow (``Y``) markers in Wordle feedback.

    The feedback must contain at least two non-blank lines. The last line
    made up solely of ``G``/``Y``/``X`` (spaces ignored) is used. Returns
    ``(greens, yellows)``, or ``(0, 0)`` when no such line exists.
    """
    rows = [row.strip() for row in feedback.split("\n")] if feedback else []
    rows = [row for row in rows if row]
    if len(rows) < 2:
        return (0, 0)

    marker_chars = set("GYX")
    for row in reversed(rows):
        compact = row.replace(" ", "")
        if compact and marker_chars.issuperset(compact):
            return (compact.count("G"), compact.count("Y"))

    return (0, 0)
77
+
78
+
79
class _WordleRewardProvider:
    """Wordle reward signals modelled on the GRPO Wordle heuristics.

    Emits four signals per step: the green and yellow fractions from the
    latest feedback, a repetition score that drops by 1.0 for each earlier
    occurrence of the same guess, and the reward carried by the observation.
    """

    # Short signal key -> name under which the signal is reported.
    SIGNAL_MAP = {
        "greens": "wordle.greens",
        "yellows": "wordle.yellows",
        "repetitions": "wordle.repetitions",
        "correct": "wordle.correct",
    }

    def __init__(self) -> None:
        # Normalized guess -> number of times it has been seen this episode.
        self._guess_history: Dict[str, int] = {}

    def reset(self) -> None:
        """Forget every guess seen so far."""
        self._guess_history.clear()

    def compute(self, *, action: TextArenaAction, observation: TextArenaObservation) -> Dict[str, float]:
        """Return the four Wordle signals for this step and record the guess."""
        guess = extract_guess(action.message)
        feedback = extract_wordle_feedback(observation)

        # The "[dunno]" sentinel (and an empty guess) is never tracked.
        tracked_guess = "" if guess in ("", "[dunno]") else guess
        times_seen = self._guess_history.get(tracked_guess, 0) if tracked_guess else 0

        if feedback:
            greens, yellows = extract_feedback_counts(feedback)
        else:
            greens, yellows = 0, 0

        names = self.SIGNAL_MAP
        signals = {
            names["greens"]: float(greens / 5.0),
            names["yellows"]: float(yellows / 5.0),
            names["repetitions"]: float(1.0 - times_seen),
            names["correct"]: float(observation.reward or 0.0),
        }

        # Record the guess only after the scores were computed from the old count.
        if tracked_guess:
            self._guess_history[tracked_guess] = times_seen + 1

        return signals
121
+
122
+
123
+ __all__ = [
124
+ "RewardProvider",
125
+ "build_reward_providers",
126
+ "extract_feedback_counts",
127
+ "extract_guess",
128
+ "extract_wordle_feedback",
129
+ ]
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """TextArena environment server components."""
8
+
9
+ from .environment import TextArenaEnvironment
10
+
11
+ __all__ = ["TextArenaEnvironment"]
server/app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """FastAPI application entrypoint for the TextArena environment."""
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+
13
+ from openenv.core.env_server.http_server import create_app
14
+
15
+ try:
16
+ # When running as installed package
17
+ from textarena_env.models import TextArenaAction, TextArenaObservation
18
+ from textarena_env.server.environment import TextArenaEnvironment
19
+ except ImportError:
20
+ # When running uvicorn directly from textarena_env/
21
+ from models import TextArenaAction, TextArenaObservation
22
+ from .environment import TextArenaEnvironment
23
+
24
+
25
def _parse_env_kwargs(prefix: str = "TEXTARENA_KW_") -> dict[str, str]:
    """Gather keyword arguments encoded as prefixed process environment variables.

    Every variable whose name starts with ``prefix`` contributes one entry:
    the key is the remainder of the name, lower-cased, and the value is the
    raw string (no type conversion is attempted).
    """
    return {
        name[len(prefix):].lower(): value
        for name, value in os.environ.items()
        if name.startswith(prefix)
    }
34
+
35
+
36
# Server configuration, read once at import time from the process environment.
env_id = os.getenv("TEXTARENA_ENV_ID", "Wordle-v0")
num_players = int(os.getenv("TEXTARENA_NUM_PLAYERS", "1"))

# An unset TEXTARENA_MAX_TURNS means "no explicit turn limit" (None).
max_turns_env = os.getenv("TEXTARENA_MAX_TURNS")
max_turns = None if max_turns_env is None else int(max_turns_env)

# Only the exact strings "1", "true" and "True" enable the NLTK download.
download_nltk = os.getenv("TEXTARENA_DOWNLOAD_NLTK", "1") in {"1", "true", "True"}

# Extra game kwargs supplied through TEXTARENA_KW_* variables.
extra_kwargs = _parse_env_kwargs()
43
+
44
+
45
# Factory used by the HTTP server to build TextArenaEnvironment instances.
def create_textarena_environment():
    """Build a TextArenaEnvironment from the module-level configuration."""
    options = dict(
        num_players=num_players,
        max_turns=max_turns,
        download_nltk=download_nltk,
        env_kwargs=extra_kwargs,
    )
    return TextArenaEnvironment(env_id, **options)


# Create the FastAPI app.
# The factory (not an instance) is passed for WebSocket session support.
app = create_app(
    create_textarena_environment,
    TextArenaAction,
    TextArenaObservation,
    env_name="textarena_env",
)
65
+
66
+
67
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Serve ``app`` with uvicorn; entry point for running without Docker.

    Typical invocations:
        uv run --project . server
        python -m textarena_env.server.app

    This function does not parse command-line arguments, so a different
    host or port must be passed by calling ``main`` from Python.

    Args:
        host: Host address to bind to (default: "0.0.0.0")
        port: Port number to listen on (default: 8000)

    For production deployments, consider using uvicorn directly with
    multiple workers:
        uvicorn textarena_env.server.app:app --workers 4
    """
    # Imported lazily so importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
server/environment.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Server implementation for the generic TextArena environment."""
8
+
9
+ from __future__ import annotations
10
+
11
+ import sys
12
+ from typing import Any, Dict, Iterable, List, Optional
13
+ from uuid import uuid4
14
+
15
+ import nltk
16
+
17
+ from openenv.core.env_server.interfaces import Environment
18
+
19
+ try:
20
+ # When running as installed package
21
+ from textarena_env.models import (
22
+ TextArenaAction,
23
+ TextArenaMessage,
24
+ TextArenaObservation,
25
+ TextArenaState,
26
+ )
27
+ from textarena_env.rewards import RewardProvider, build_reward_providers
28
+ except ImportError:
29
+ # When running uvicorn directly from textarena_env/
30
+ from models import (
31
+ TextArenaAction,
32
+ TextArenaMessage,
33
+ TextArenaObservation,
34
+ TextArenaState,
35
+ )
36
+ from rewards import RewardProvider, build_reward_providers
37
+
38
+
39
+ _TEXTARENA_MODULE: Any | None = None
40
+ _TEXTARENA_IMPORT_ERROR: Exception | None = None
41
+
42
+
43
+ def _import_textarena() -> Any:
44
+ """Import ``textarena`` lazily and cache the module reference."""
45
+
46
+ global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR
47
+
48
+ if _TEXTARENA_MODULE is not None:
49
+ return _TEXTARENA_MODULE
50
+
51
+ if _TEXTARENA_IMPORT_ERROR is not None:
52
+ raise _TEXTARENA_IMPORT_ERROR
53
+
54
+ if sys.version_info < (3, 10):
55
+ _TEXTARENA_IMPORT_ERROR = RuntimeError(
56
+ "TextArena environments require Python 3.10 or newer; "
57
+ f"current interpreter is {sys.version_info.major}.{sys.version_info.minor}"
58
+ )
59
+ raise _TEXTARENA_IMPORT_ERROR
60
+
61
+ try:
62
+ import textarena as ta # type: ignore[import]
63
+ except Exception as exc: # pragma: no cover - surfaced to caller
64
+ _TEXTARENA_IMPORT_ERROR = exc
65
+ raise
66
+
67
+ _TEXTARENA_MODULE = ta
68
+ return ta
69
+
70
+
71
+ class TextArenaEnvironment(Environment):
72
+ """Wrap any TextArena game behind the OpenEnv ``Environment`` API."""
73
+
74
+ def __init__(
75
+ self,
76
+ env_id: str = "Wordle-v0",
77
+ *,
78
+ num_players: int = 1,
79
+ max_turns: Optional[int] = None,
80
+ download_nltk: bool = True,
81
+ env_kwargs: Optional[Dict[str, Any]] = None,
82
+ ) -> None:
83
+ super().__init__()
84
+
85
+ ta = _import_textarena()
86
+
87
+ if download_nltk:
88
+ nltk.download("words", quiet=True)
89
+ nltk.download("averaged_perceptron_tagger_eng", quiet=True)
90
+
91
+ self.env_id = env_id
92
+ self.num_players = num_players
93
+ self.max_turns = max_turns
94
+ self._env_kwargs = env_kwargs or {}
95
+
96
+ self._ta_env = ta.make(env_id=env_id, **self._env_kwargs)
97
+
98
+ self._state = TextArenaState(
99
+ env_id=env_id,
100
+ num_players=num_players,
101
+ max_turns=max_turns,
102
+ )
103
+
104
+ self._reward_providers: List[RewardProvider] = build_reward_providers(env_id)
105
+ self._last_reward_signals: Dict[str, float] = {}
106
+
107
+ # ------------------------------------------------------------------
108
+ # Environment interface
109
+ # ------------------------------------------------------------------
110
def reset(self) -> TextArenaObservation:
    """Start a new episode and return its first observation.

    Clears observation history cached by TextArena wrappers, resets the
    game and reward providers, reinitializes the tracked state with a new
    episode id, and returns an observation with ``reward=0.0`` and
    ``done=False``.
    """
    # TextArena observation wrappers (LLMObservationWrapper, etc.) keep
    # accumulating ``full_observations`` across resets. TextArena cannot be
    # modified, so walk the wrapper chain (including the innermost env) and
    # clear that history by hand.
    layer = self._ta_env
    while True:
        if hasattr(layer, "full_observations"):
            layer.full_observations = {}
        if not hasattr(layer, "env"):
            break
        layer = layer.env

    self._ta_env.reset(num_players=self.num_players)

    for provider in self._reward_providers:
        provider.reset()

    state = self._state
    state.episode_id = str(uuid4())
    state.step_count = 0
    state.turn = 0
    state.last_reward = 0.0
    state.last_info = {}
    state.raw_state = self._snapshot_state()
    self._last_reward_signals = {}

    first_observation = self._build_observation()
    first_observation.reward = 0.0
    first_observation.done = False
    return first_observation
141
+
142
+ def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore[override]
143
+ if not isinstance(action, TextArenaAction):
144
+ raise TypeError(f"Expected TextArenaAction, received {type(action)!r}")
145
+
146
+ done, info = self._ta_env.step(action.message)
147
+
148
+ self._state.step_count += 1
149
+ self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1)
150
+ self._state.last_info = info or {}
151
+
152
+ observation = self._build_observation()
153
+ observation.done = done
154
+
155
+ reward = self._extract_reward()
156
+ observation.reward = reward
157
+ self._state.last_reward = reward
158
+
159
+ reward_signals = self._compute_reward_signals(action=action, observation=observation)
160
+ if reward_signals:
161
+ observation.info.setdefault("reward_signals", {}).update(reward_signals)
162
+ observation.metadata.setdefault("reward_signals", {}).update(reward_signals)
163
+ self._last_reward_signals = reward_signals
164
+ if reward_signals:
165
+ self._state.last_info = {
166
+ **(self._state.last_info or {}),
167
+ "reward_signals": reward_signals,
168
+ }
169
+ self._state.raw_state = self._snapshot_state()
170
+
171
+ return observation
172
+
173
+ @property
174
+ def state(self) -> TextArenaState:
175
+ return self._state
176
+
177
+ # ------------------------------------------------------------------
178
+ # Helpers
179
+ # ------------------------------------------------------------------
180
+ def _build_observation(self) -> TextArenaObservation:
181
+ player_id, messages = self._ta_env.get_observation()
182
+
183
+ ta_messages = self._convert_messages(messages)
184
+
185
+ # Extract prompt from the appropriate messages.
186
+ # TextArena PROMPT type messages contain the game instructions added during reset.
187
+ # As a fallback for environments that don't use typed messages, use only the first
188
+ # message if we're at turn 0 (fresh reset).
189
+ prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"]
190
+
191
+ if not prompt_lines:
192
+ # Fallback: use the first message only if at turn 0 (just after reset)
193
+ # DO NOT use all messages as this causes history accumulation
194
+ current_turn = getattr(self._ta_env.state, "turn", 0)
195
+ if current_turn == 0 and ta_messages:
196
+ prompt_lines = [ta_messages[0].content]
197
+ else:
198
+ # Use env_id as final fallback to avoid including game history
199
+ prompt_lines = [self.env_id]
200
+
201
+ prompt = "\n".join(prompt_lines).strip()
202
+
203
+ info: Dict[str, Any] = {}
204
+ info.update(getattr(self._ta_env.state, "step_info", {}))
205
+
206
+ observation = TextArenaObservation(
207
+ prompt=prompt,
208
+ messages=ta_messages,
209
+ current_player_id=player_id,
210
+ legal_players=self._legal_players(),
211
+ info=info,
212
+ metadata={
213
+ "env_id": self.env_id,
214
+ "turn": getattr(self._ta_env.state, "turn", 0),
215
+ "raw_messages": [
216
+ {
217
+ "sender_id": msg.sender_id,
218
+ "content": msg.content,
219
+ "category": msg.category,
220
+ }
221
+ for msg in ta_messages
222
+ ],
223
+ },
224
+ )
225
+
226
+ return observation
227
+
228
+ def _legal_players(self) -> List[int]:
229
+ role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {}
230
+ players = [pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0]
231
+ return sorted(players)
232
+
233
+ def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
234
+ converted: List[TextArenaMessage] = []
235
+ buffered_sender: int | None = None
236
+ buffered_category: str | None = None
237
+ buffered_content: List[str] = []
238
+
239
+ def flush_buffer() -> None:
240
+ nonlocal buffered_content, buffered_sender, buffered_category
241
+ if not buffered_content:
242
+ return
243
+ converted.append(
244
+ TextArenaMessage(
245
+ sender_id=buffered_sender if buffered_sender is not None else -1,
246
+ content="".join(buffered_content),
247
+ category=buffered_category or "MESSAGE",
248
+ )
249
+ )
250
+ buffered_content = []
251
+ buffered_category = None
252
+ buffered_sender = None
253
+
254
+ for entry in messages:
255
+ if isinstance(entry, tuple) and len(entry) == 3:
256
+ sender, content, category = entry
257
+ elif isinstance(entry, tuple) and len(entry) == 2:
258
+ sender, content = entry
259
+ category = "MESSAGE"
260
+ else:
261
+ sender, content, category = -1, str(entry), "MESSAGE"
262
+
263
+ category_name = getattr(category, "name", str(category))
264
+ sender_id = int(sender) if isinstance(sender, (int, float)) else -1
265
+ text = str(content)
266
+
267
+ if buffered_content and buffered_category == category_name and buffered_sender == sender_id:
268
+ buffered_content.append(text)
269
+ else:
270
+ flush_buffer()
271
+ buffered_sender = sender_id
272
+ buffered_category = category_name
273
+ buffered_content = [text]
274
+
275
+ flush_buffer()
276
+
277
+ return converted
278
+
279
+ def _extract_reward(self) -> float:
280
+ rewards = getattr(self._ta_env.state, "rewards", None)
281
+ if isinstance(rewards, dict):
282
+ # Use current player reward if available, otherwise default to player 0.
283
+ player_id = getattr(self._ta_env.state, "current_player_id", 0)
284
+ if player_id in rewards:
285
+ return float(rewards[player_id])
286
+ if 0 in rewards:
287
+ return float(rewards[0])
288
+ return 0.0
289
+
290
+ def _snapshot_state(self) -> Dict[str, Any]:
291
+ state = self._ta_env.state
292
+ snapshot: Dict[str, Any] = {
293
+ "turn": getattr(state, "turn", 0),
294
+ "game_state": getattr(state, "game_state", {}),
295
+ "logs": list(getattr(state, "logs", [])),
296
+ "rewards": getattr(state, "rewards", None),
297
+ "done": getattr(state, "done", False),
298
+ "role_mapping": getattr(state, "role_mapping", {}),
299
+ "game_info": getattr(state, "game_info", {}),
300
+ "step_info": getattr(state, "step_info", {}),
301
+ }
302
+ if self._last_reward_signals:
303
+ snapshot["reward_signals"] = dict(self._last_reward_signals)
304
+ return snapshot
305
+
306
+ def _compute_reward_signals(
307
+ self, *, action: TextArenaAction, observation: TextArenaObservation
308
+ ) -> Dict[str, float]:
309
+ if not self._reward_providers:
310
+ return {}
311
+
312
+ aggregated: Dict[str, float] = {}
313
+ for provider in self._reward_providers:
314
+ try:
315
+ result = provider.compute(action=action, observation=observation)
316
+ except Exception: # pragma: no cover - defensive
317
+ continue
318
+ for key, value in result.items():
319
+ aggregated[key] = float(value)
320
+ return aggregated
server/run_local.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the server app locally with uvicorn on port 8001.
# Abort on the first error, on use of an unset variable, or on a failed pipe.
set -euo pipefail

# Defaults match the previous hard-coded values; export either variable
# before calling this script to run a different game or player count.
# NOTE(review): presumably these are read by server/app.py - confirm.
export TEXTARENA_ENV_ID="${TEXTARENA_ENV_ID:-Wordle-v0}"
export TEXTARENA_NUM_PLAYERS="${TEXTARENA_NUM_PLAYERS:-1}"

# Run the server
exec uvicorn envs.textarena_env.server.app:app --host 0.0.0.0 --port 8001
6
+
7
+
uv.lock ADDED
The diff for this file is too large to render. See raw diff