|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
TextArena Environment HTTP Client. |
|
|
|
|
|
This module provides the client for connecting to a TextArena Environment server |
|
|
over HTTP. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any, Dict |
|
|
|
|
|
from openenv.core.client_types import StepResult |
|
|
from openenv.core.env_client import EnvClient |
|
|
|
|
|
from .models import ( |
|
|
TextArenaAction, |
|
|
TextArenaMessage, |
|
|
TextArenaObservation, |
|
|
TextArenaState, |
|
|
) |
|
|
|
|
|
|
|
|
class TextArenaEnv(EnvClient[TextArenaAction, TextArenaObservation, TextArenaState]):
    """
    HTTP client for the TextArena Environment.

    This client connects to a TextArenaEnvironment HTTP server and provides
    methods to interact with it: reset(), step(), and state access.

    The three hooks defined here translate between the typed models and the
    JSON payloads exchanged with the server:

    * ``_step_payload`` serializes a ``TextArenaAction``.
    * ``_parse_result`` builds a ``StepResult[TextArenaObservation]``.
    * ``_parse_state`` builds a ``TextArenaState``.

    Example:
        >>> # Connect to a running server
        >>> client = TextArenaEnv(base_url="http://localhost:8000")
        >>> result = client.reset()
        >>> print(result.observation.prompt)
        >>>
        >>> # Send a message
        >>> result = client.step(TextArenaAction(message="Hello!"))
        >>> print(result.observation.prompt)
        >>> print(result.reward)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = TextArenaEnv.from_docker_image("textarena-env:latest")
        >>> result = client.reset()
        >>> result = client.step(TextArenaAction(message="Test"))
    """

    def _step_payload(self, action: TextArenaAction) -> Dict[str, Any]:
        """
        Convert TextArenaAction to JSON payload for step request.

        Args:
            action: TextArenaAction instance

        Returns:
            Dictionary with a single ``"message"`` key holding
            ``action.message``, suitable for JSON encoding
        """
        return {"message": action.message}

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[TextArenaObservation]:
        """
        Parse server response into StepResult[TextArenaObservation].

        Missing keys fall back to defaults. ``"observation"`` and
        ``"messages"`` additionally tolerate an explicit JSON ``null``.
        Entries of ``"messages"`` that are not dictionaries are skipped.

        Args:
            payload: JSON response from server

        Returns:
            StepResult with TextArenaObservation; ``reward`` and ``done`` are
            taken from the top level of ``payload`` and mirrored onto both
            the observation and the StepResult
        """
        # dict.get(key, default) only covers a *missing* key; a JSON null
        # arrives as None, so use `or` to fall back in that case as well.
        obs_data = payload.get("observation") or {}
        messages_payload = obs_data.get("messages") or []

        messages = [
            TextArenaMessage(
                sender_id=item.get("sender_id", -1),
                content=item.get("content", ""),
                category=item.get("category", "MESSAGE"),
            )
            for item in messages_payload
            if isinstance(item, dict)
        ]

        # Read once so the observation and the StepResult cannot disagree.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = TextArenaObservation(
            prompt=obs_data.get("prompt", ""),
            messages=messages,
            current_player_id=obs_data.get("current_player_id", 0),
            legal_players=obs_data.get("legal_players", []),
            info=obs_data.get("info", {}),
            reward=reward,
            done=done,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> TextArenaState:
        """
        Parse server state response into TextArenaState.

        Args:
            payload: JSON response from server

        Returns:
            TextArenaState built from ``payload``. Missing keys use these
            defaults: ``step_count=0``, ``env_id="unknown"``,
            ``num_players=1``, ``turn=0``, ``last_reward=0.0``,
            ``last_info={}``, ``raw_state={}``; ``episode_id`` and
            ``max_turns`` default to ``None``
        """
        return TextArenaState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            env_id=payload.get("env_id", "unknown"),
            num_players=payload.get("num_players", 1),
            max_turns=payload.get("max_turns"),
            turn=payload.get("turn", 0),
            last_reward=payload.get("last_reward", 0.0),
            last_info=payload.get("last_info", {}),
            raw_state=payload.get("raw_state", {}),
        )
|
|
|