# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Server implementation for the generic TextArena environment."""

from __future__ import annotations

import logging
import sys
from typing import Any, Dict, Iterable, List, Optional
from uuid import uuid4

import nltk

from openenv.core.env_server.interfaces import Environment

try:
    # When running as installed package
    from textarena_env.models import (
        TextArenaAction,
        TextArenaMessage,
        TextArenaObservation,
        TextArenaState,
    )
    from textarena_env.rewards import RewardProvider, build_reward_providers
except ImportError:
    # When running uvicorn directly from textarena_env/
    from models import (
        TextArenaAction,
        TextArenaMessage,
        TextArenaObservation,
        TextArenaState,
    )
    from rewards import RewardProvider, build_reward_providers


logger = logging.getLogger(__name__)

# Lazily-populated cache for the ``textarena`` import (see ``_import_textarena``).
_TEXTARENA_MODULE: Any | None = None
_TEXTARENA_IMPORT_ERROR: Exception | None = None


def _import_textarena() -> Any:
    """Import ``textarena`` lazily and cache the module reference.

    Both outcomes are cached at module level. After a successful import, later
    calls return the cached module. After a failed import, later calls re-raise
    the same exception without retrying.

    Returns:
        The imported ``textarena`` module.

    Raises:
        RuntimeError: If the interpreter is older than Python 3.10.
        Exception: Whatever ``import textarena`` raised on the first attempt.
    """
    global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR

    if _TEXTARENA_MODULE is not None:
        return _TEXTARENA_MODULE
    if _TEXTARENA_IMPORT_ERROR is not None:
        raise _TEXTARENA_IMPORT_ERROR

    if sys.version_info < (3, 10):
        _TEXTARENA_IMPORT_ERROR = RuntimeError(
            "TextArena environments require Python 3.10 or newer; "
            f"current interpreter is {sys.version_info.major}.{sys.version_info.minor}"
        )
        raise _TEXTARENA_IMPORT_ERROR

    try:
        import textarena as ta  # type: ignore[import]
    except Exception as exc:  # pragma: no cover - surfaced to caller
        _TEXTARENA_IMPORT_ERROR = exc
        raise

    _TEXTARENA_MODULE = ta
    return ta


class TextArenaEnvironment(Environment):
    """Wrap any TextArena game behind the OpenEnv ``Environment`` API."""

    def __init__(
        self,
        env_id: str = "Wordle-v0",
        *,
        num_players: int = 1,
        max_turns: Optional[int] = None,
        download_nltk: bool = True,
        env_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create the underlying TextArena env and the per-episode bookkeeping.

        Args:
            env_id: TextArena environment id passed to ``textarena.make``.
            num_players: Player count passed to the TextArena ``reset`` call.
            max_turns: Recorded on the OpenEnv state. This class does not pass
                it to TextArena.
            download_nltk: When true, download the NLTK ``words`` and
                ``averaged_perceptron_tagger_eng`` resources first.
            env_kwargs: Extra keyword arguments forwarded to ``textarena.make``.
        """
        super().__init__()

        ta = _import_textarena()

        if download_nltk:
            nltk.download("words", quiet=True)
            nltk.download("averaged_perceptron_tagger_eng", quiet=True)

        self.env_id = env_id
        self.num_players = num_players
        self.max_turns = max_turns
        # Take a copy so later mutation of the caller's dict cannot change the
        # configuration this instance records.
        self._env_kwargs: Dict[str, Any] = dict(env_kwargs or {})
        self._ta_env = ta.make(env_id=env_id, **self._env_kwargs)

        self._state = TextArenaState(
            env_id=env_id,
            num_players=num_players,
            max_turns=max_turns,
        )
        self._reward_providers: List[RewardProvider] = build_reward_providers(env_id)
        self._last_reward_signals: Dict[str, float] = {}

    # ------------------------------------------------------------------
    # Environment interface
    # ------------------------------------------------------------------

    def reset(self) -> TextArenaObservation:
        """Start a new episode and return its initial observation.

        Returns:
            The first observation of the episode, with ``reward == 0.0`` and
            ``done`` set to ``False``.
        """
        self._clear_wrapper_observations()
        self._ta_env.reset(num_players=self.num_players)

        for provider in self._reward_providers:
            provider.reset()

        self._state.episode_id = str(uuid4())
        self._state.step_count = 0
        self._state.turn = 0
        self._state.last_reward = 0.0
        self._state.last_info = {}
        # Clear the previous episode's signals BEFORE snapshotting. Otherwise
        # ``_snapshot_state`` would copy stale ``reward_signals`` into the new
        # episode's ``raw_state``.
        self._last_reward_signals = {}
        self._state.raw_state = self._snapshot_state()

        observation = self._build_observation()
        observation.reward = 0.0
        observation.done = False
        return observation

    def step(self, action: TextArenaAction) -> TextArenaObservation:  # type: ignore[override]
        """Apply ``action.message`` to the TextArena env and return the next observation.

        Args:
            action: The action to play. Its ``message`` is forwarded to TextArena.

        Returns:
            The observation after the move. ``done``, ``reward`` and any
            ``reward_signals`` from the configured providers are filled in.

        Raises:
            TypeError: If ``action`` is not a ``TextArenaAction``.
        """
        if not isinstance(action, TextArenaAction):
            raise TypeError(f"Expected TextArenaAction, received {type(action)!r}")

        done, info = self._ta_env.step(action.message)

        self._state.step_count += 1
        self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1)
        self._state.last_info = info or {}

        observation = self._build_observation()
        observation.done = done

        reward = self._extract_reward()
        observation.reward = reward
        self._state.last_reward = reward

        reward_signals = self._compute_reward_signals(action=action, observation=observation)
        self._last_reward_signals = reward_signals
        if reward_signals:
            observation.info.setdefault("reward_signals", {}).update(reward_signals)
            observation.metadata.setdefault("reward_signals", {}).update(reward_signals)
            self._state.last_info = {
                **(self._state.last_info or {}),
                "reward_signals": reward_signals,
            }

        self._state.raw_state = self._snapshot_state()
        return observation

    @property
    def state(self) -> TextArenaState:
        """Return the OpenEnv-side state object for the current episode."""
        return self._state

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _clear_wrapper_observations(self) -> None:
        """Reset ``full_observations`` on every layer of the wrapped TextArena env."""
        # TextArena observation wrappers (LLMObservationWrapper, etc.) accumulate
        # observations in self.full_observations across resets. Since we can't modify
        # TextArena, we need to manually clear this state to prevent history accumulation.
        env = self._ta_env
        while hasattr(env, "env"):
            if hasattr(env, "full_observations"):
                env.full_observations = {}
            env = env.env

        # Also check the final unwrapped env.
        if hasattr(env, "full_observations"):
            env.full_observations = {}

    def _build_observation(self) -> TextArenaObservation:
        """Build an OpenEnv observation from the current TextArena observation/state.

        The prompt comes from ``PROMPT``-category messages when there are any.
        Otherwise it is the first message at turn 0, and ``env_id`` after that.
        """
        player_id, messages = self._ta_env.get_observation()
        ta_messages = self._convert_messages(messages)
        ta_state = self._ta_env.state

        # Extract prompt from the appropriate messages.
        # TextArena PROMPT type messages contain the game instructions added during reset.
        # As a fallback for environments that don't use typed messages, use only the first
        # message if we're at turn 0 (fresh reset).
        prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"]
        if not prompt_lines:
            # Fallback: use the first message only if at turn 0 (just after reset).
            # DO NOT use all messages as this causes history accumulation.
            current_turn = getattr(ta_state, "turn", 0)
            if current_turn == 0 and ta_messages:
                prompt_lines = [ta_messages[0].content]
            else:
                # Use env_id as final fallback to avoid including game history.
                prompt_lines = [self.env_id]

        prompt = "\n".join(prompt_lines).strip()

        # ``step_info`` may be missing or ``None``. Copy it into a fresh dict so
        # later in-place additions (``step`` adds reward signals) never touch
        # TextArena's own state.
        info: Dict[str, Any] = dict(getattr(ta_state, "step_info", None) or {})

        return TextArenaObservation(
            prompt=prompt,
            messages=ta_messages,
            current_player_id=player_id,
            legal_players=self._legal_players(),
            info=info,
            metadata={
                "env_id": self.env_id,
                "turn": getattr(ta_state, "turn", 0),
                "raw_messages": [
                    {
                        "sender_id": msg.sender_id,
                        "content": msg.content,
                        "category": msg.category,
                    }
                    for msg in ta_messages
                ],
            },
        )

    def _legal_players(self) -> List[int]:
        """Return the sorted non-negative integer player ids from ``role_mapping``."""
        role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {}
        players = [pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0]
        return sorted(players)

    def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
        """Normalise raw TextArena messages into ``TextArenaMessage`` objects.

        Each entry may be a ``(sender, content, category)`` tuple, a
        ``(sender, content)`` tuple (category ``"MESSAGE"``), or anything else.
        Anything else is stringified with sender ``-1``. Consecutive entries
        with the same sender and category are merged into one message, with
        their text concatenated and no separator.
        """
        converted: List[TextArenaMessage] = []
        buffered_sender: int | None = None
        buffered_category: str | None = None
        buffered_content: List[str] = []

        def flush_buffer() -> None:
            # Emit the pending run of same-sender/same-category text, if any.
            nonlocal buffered_content, buffered_sender, buffered_category
            if not buffered_content:
                return
            converted.append(
                TextArenaMessage(
                    sender_id=buffered_sender if buffered_sender is not None else -1,
                    content="".join(buffered_content),
                    category=buffered_category or "MESSAGE",
                )
            )
            buffered_content = []
            buffered_category = None
            buffered_sender = None

        for entry in messages:
            if isinstance(entry, tuple) and len(entry) == 3:
                sender, content, category = entry
            elif isinstance(entry, tuple) and len(entry) == 2:
                sender, content = entry
                category = "MESSAGE"
            else:
                sender, content, category = -1, str(entry), "MESSAGE"

            # Enum-like categories expose ``.name``; plain strings are used as-is.
            category_name = getattr(category, "name", str(category))
            sender_id = int(sender) if isinstance(sender, (int, float)) else -1
            text = str(content)

            if buffered_content and buffered_category == category_name and buffered_sender == sender_id:
                buffered_content.append(text)
            else:
                flush_buffer()
                buffered_sender = sender_id
                buffered_category = category_name
                buffered_content = [text]

        flush_buffer()
        return converted

    def _extract_reward(self) -> float:
        """Return a scalar reward from TextArena's ``state.rewards`` dict.

        The reward of ``state.current_player_id`` is used when present. Failing
        that, player 0's reward is used, and ``0.0`` when neither is available.
        """
        rewards = getattr(self._ta_env.state, "rewards", None)
        if isinstance(rewards, dict):
            # NOTE(review): this is read after ``step`` has run. In multi-player
            # games ``current_player_id`` may already point at the next player
            # rather than the one who acted — confirm this is intended.
            player_id = getattr(self._ta_env.state, "current_player_id", 0)
            if player_id in rewards:
                return float(rewards[player_id])
            if 0 in rewards:
                return float(rewards[0])
        return 0.0

    def _snapshot_state(self) -> Dict[str, Any]:
        """Return a dict view of the TextArena state for ``TextArenaState.raw_state``.

        Only ``logs`` and the reward signals are copied. The other values are
        references to the live TextArena objects.
        """
        state = self._ta_env.state
        snapshot: Dict[str, Any] = {
            "turn": getattr(state, "turn", 0),
            "game_state": getattr(state, "game_state", {}),
            # ``logs`` may be missing or ``None``; ``list(None)`` would raise.
            "logs": list(getattr(state, "logs", None) or []),
            "rewards": getattr(state, "rewards", None),
            "done": getattr(state, "done", False),
            "role_mapping": getattr(state, "role_mapping", {}),
            "game_info": getattr(state, "game_info", {}),
            "step_info": getattr(state, "step_info", {}),
        }
        if self._last_reward_signals:
            snapshot["reward_signals"] = dict(self._last_reward_signals)
        return snapshot

    def _compute_reward_signals(
        self, *, action: TextArenaAction, observation: TextArenaObservation
    ) -> Dict[str, float]:
        """Run every reward provider and merge their outputs into one dict.

        This is best-effort. A provider that raises, or returns nothing, is
        skipped (failures are logged). A value that cannot be converted to
        ``float`` is skipped with a warning. When two providers emit the same
        key, the later provider wins.

        Returns:
            Mapping of signal name to float value. Empty if there are no providers.
        """
        if not self._reward_providers:
            return {}

        aggregated: Dict[str, float] = {}
        for provider in self._reward_providers:
            try:
                result = provider.compute(action=action, observation=observation)
            except Exception:  # pragma: no cover - defensive
                # One faulty provider must not abort the step, but the failure
                # should be visible rather than silently dropped.
                logger.exception(
                    "Reward provider %s failed for env %s",
                    type(provider).__name__,
                    self.env_id,
                )
                continue

            if not result:
                continue

            for key, value in result.items():
                try:
                    aggregated[key] = float(value)
                except (TypeError, ValueError):
                    logger.warning(
                        "Reward provider %s returned non-numeric value %r for %r; skipping",
                        type(provider).__name__,
                        value,
                        key,
                    )
        return aggregated