File size: 11,575 Bytes
6e194fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Server implementation for the generic TextArena environment."""

from __future__ import annotations

import logging
import sys
from typing import Any, Dict, Iterable, List, Optional
from uuid import uuid4

import nltk

from openenv.core.env_server.interfaces import Environment

try:
    # When running as installed package
    from textarena_env.models import (
        TextArenaAction,
        TextArenaMessage,
        TextArenaObservation,
        TextArenaState,
    )
    from textarena_env.rewards import RewardProvider, build_reward_providers
except ImportError:
    # When running uvicorn directly from textarena_env/
    from models import (
        TextArenaAction,
        TextArenaMessage,
        TextArenaObservation,
        TextArenaState,
    )
    from rewards import RewardProvider, build_reward_providers


_TEXTARENA_MODULE: Any | None = None
_TEXTARENA_IMPORT_ERROR: Exception | None = None


def _import_textarena() -> Any:
    """Import ``textarena`` lazily and cache the module reference."""

    global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR

    if _TEXTARENA_MODULE is not None:
        return _TEXTARENA_MODULE

    if _TEXTARENA_IMPORT_ERROR is not None:
        raise _TEXTARENA_IMPORT_ERROR

    if sys.version_info < (3, 10):
        _TEXTARENA_IMPORT_ERROR = RuntimeError(
            "TextArena environments require Python 3.10 or newer; "
            f"current interpreter is {sys.version_info.major}.{sys.version_info.minor}"
        )
        raise _TEXTARENA_IMPORT_ERROR

    try:
        import textarena as ta  # type: ignore[import]
    except Exception as exc:  # pragma: no cover - surfaced to caller
        _TEXTARENA_IMPORT_ERROR = exc
        raise

    _TEXTARENA_MODULE = ta
    return ta


class TextArenaEnvironment(Environment):
    """Wrap any TextArena game behind the OpenEnv ``Environment`` API.

    The instance owns one TextArena environment created through ``ta.make`` and
    translates its observations, rewards and state into the ``TextArena*``
    models. Reward providers returned by ``build_reward_providers`` contribute
    additional named reward signals on every step.
    """

    def __init__(
        self,
        env_id: str = "Wordle-v0",
        *,
        num_players: int = 1,
        max_turns: Optional[int] = None,
        download_nltk: bool = True,
        env_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create the underlying TextArena environment and the initial state.

        Args:
            env_id: TextArena environment id handed to ``ta.make``.
            num_players: Player count handed to the TextArena ``reset`` call.
            max_turns: Recorded on the instance and in the state object; this
                class does not enforce it itself.
            download_nltk: When true, fetch the NLTK ``words`` and
                ``averaged_perceptron_tagger_eng`` resources before the
                environment is created.
            env_kwargs: Extra keyword arguments forwarded to ``ta.make``.

        Raises:
            RuntimeError: If the interpreter is older than Python 3.10.
            Exception: Whatever importing ``textarena`` raised.
        """
        super().__init__()

        ta = _import_textarena()

        if download_nltk:
            nltk.download("words", quiet=True)
            nltk.download("averaged_perceptron_tagger_eng", quiet=True)

        self.env_id = env_id
        self.num_players = num_players
        self.max_turns = max_turns
        # Shallow copy so later mutation of the caller's dict cannot leak in.
        self._env_kwargs: Dict[str, Any] = dict(env_kwargs) if env_kwargs else {}

        self._ta_env = ta.make(env_id=env_id, **self._env_kwargs)

        self._state = TextArenaState(
            env_id=env_id,
            num_players=num_players,
            max_turns=max_turns,
        )

        self._reward_providers: List[RewardProvider] = build_reward_providers(env_id)
        self._last_reward_signals: Dict[str, float] = {}

    # ------------------------------------------------------------------
    # Environment interface
    # ------------------------------------------------------------------
    def reset(self) -> TextArenaObservation:
        """Start a new episode and return its first observation.

        Clears accumulated wrapper history, resets the TextArena environment
        and every reward provider, re-initialises the tracked state with a
        fresh ``episode_id``, and returns an observation whose ``reward`` is
        ``0.0`` and whose ``done`` flag is ``False``.
        """
        # TextArena observation wrappers (LLMObservationWrapper, etc.) accumulate
        # observations in self.full_observations across resets. Since we can't modify
        # TextArena, we manually clear this state on every layer of the wrapper chain
        # (including the innermost env) to prevent history accumulation.
        layer = self._ta_env
        while layer is not None:
            if hasattr(layer, "full_observations"):
                layer.full_observations = {}
            layer = getattr(layer, "env", None)

        self._ta_env.reset(num_players=self.num_players)

        for provider in self._reward_providers:
            provider.reset()

        self._state.episode_id = str(uuid4())
        self._state.step_count = 0
        self._state.turn = 0
        self._state.last_reward = 0.0
        self._state.last_info = {}
        # Clear the cached signals first so the snapshot below cannot pick up
        # values from the previous episode.
        self._last_reward_signals = {}
        self._state.raw_state = self._snapshot_state()

        observation = self._build_observation()
        observation.reward = 0.0
        observation.done = False

        return observation

    def step(self, action: TextArenaAction) -> TextArenaObservation:  # type: ignore[override]
        """Apply ``action.message`` to the game and return the next observation.

        Args:
            action: The action whose ``message`` is forwarded to TextArena.

        Returns:
            The observation after the move, with ``done``, ``reward`` and (when
            any provider produced output) ``reward_signals`` filled in.

        Raises:
            TypeError: If ``action`` is not a ``TextArenaAction``.
        """
        if not isinstance(action, TextArenaAction):
            raise TypeError(f"Expected TextArenaAction, received {type(action)!r}")

        done, info = self._ta_env.step(action.message)

        self._state.step_count += 1
        # Prefer TextArena's own turn counter; otherwise advance ours by one.
        self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1)
        self._state.last_info = info or {}

        observation = self._build_observation()
        observation.done = done

        reward = self._extract_reward()
        observation.reward = reward
        self._state.last_reward = reward

        reward_signals = self._compute_reward_signals(action=action, observation=observation)
        # Stored before the snapshot so ``raw_state`` reflects this step's signals.
        self._last_reward_signals = reward_signals
        if reward_signals:
            observation.info.setdefault("reward_signals", {}).update(reward_signals)
            observation.metadata.setdefault("reward_signals", {}).update(reward_signals)
            self._state.last_info = {
                **(self._state.last_info or {}),
                "reward_signals": reward_signals,
            }
        self._state.raw_state = self._snapshot_state()

        return observation

    @property
    def state(self) -> TextArenaState:
        """The state object tracked by this wrapper (updated in place)."""
        return self._state

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _build_observation(self) -> TextArenaObservation:
        """Build a ``TextArenaObservation`` from the current TextArena observation.

        ``reward`` and ``done`` are not set here; callers fill them in.
        """
        player_id, messages = self._ta_env.get_observation()
        ta_state = self._ta_env.state

        ta_messages = self._convert_messages(messages)

        # Extract prompt from the appropriate messages.
        # TextArena PROMPT type messages contain the game instructions added during reset.
        # As a fallback for environments that don't use typed messages, use only the first
        # message if we're at turn 0 (fresh reset).
        prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"]

        if not prompt_lines:
            # Fallback: use the first message only if at turn 0 (just after reset).
            # DO NOT use all messages as this causes history accumulation.
            if getattr(ta_state, "turn", 0) == 0 and ta_messages:
                prompt_lines = [ta_messages[0].content]
            else:
                # Use env_id as final fallback to avoid including game history.
                prompt_lines = [self.env_id]

        prompt = "\n".join(prompt_lines).strip()

        # ``step_info`` may be missing or None; either way start from an empty dict.
        info: Dict[str, Any] = {}
        info.update(getattr(ta_state, "step_info", None) or {})

        raw_messages = [
            {
                "sender_id": msg.sender_id,
                "content": msg.content,
                "category": msg.category,
            }
            for msg in ta_messages
        ]

        return TextArenaObservation(
            prompt=prompt,
            messages=ta_messages,
            current_player_id=player_id,
            legal_players=self._legal_players(),
            info=info,
            metadata={
                "env_id": self.env_id,
                "turn": getattr(ta_state, "turn", 0),
                "raw_messages": raw_messages,
            },
        )

    def _legal_players(self) -> List[int]:
        """Return the sorted non-negative integer keys of the state's ``role_mapping``."""
        role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {}
        return sorted(pid for pid in role_mapping if isinstance(pid, int) and pid >= 0)

    def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
        """Convert raw TextArena entries into ``TextArenaMessage`` objects.

        Accepted entry shapes are ``(sender, content, category)``,
        ``(sender, content)`` (category ``"MESSAGE"``) and anything else, which
        is stringified and attributed to sender ``-1``. Consecutive entries
        with the same sender and category are merged into a single message
        whose content is the parts concatenated with no separator.
        """
        # Each group is (sender_id, category_name, [content parts]).
        groups: list[tuple[int, str, list[str]]] = []

        for entry in messages:
            if isinstance(entry, tuple) and len(entry) == 3:
                sender, content, category = entry
            elif isinstance(entry, tuple) and len(entry) == 2:
                sender, content = entry
                category = "MESSAGE"
            else:
                sender, content, category = -1, str(entry), "MESSAGE"

            # Enum-like categories expose ``.name``; anything else is stringified.
            category_name = getattr(category, "name", str(category))
            sender_id = int(sender) if isinstance(sender, (int, float)) else -1
            text = str(content)

            if groups and groups[-1][0] == sender_id and groups[-1][1] == category_name:
                groups[-1][2].append(text)
            else:
                groups.append((sender_id, category_name, [text]))

        return [
            TextArenaMessage(
                sender_id=sender_id,
                content="".join(parts),
                # An empty category name falls back to the generic label.
                category=category_name or "MESSAGE",
            )
            for sender_id, category_name, parts in groups
        ]

    def _extract_reward(self) -> float:
        """Return the scalar reward read from the TextArena state.

        Uses the entry for the state's ``current_player_id`` when present,
        otherwise the entry for player ``0``; returns ``0.0`` when ``rewards``
        is not a dict or contains neither key.
        """
        ta_state = self._ta_env.state
        rewards = getattr(ta_state, "rewards", None)
        if isinstance(rewards, dict):
            # Use current player reward if available, otherwise default to player 0.
            player_id = getattr(ta_state, "current_player_id", 0)
            if player_id in rewards:
                return float(rewards[player_id])
            if 0 in rewards:
                return float(rewards[0])
        return 0.0

    def _snapshot_state(self) -> Dict[str, Any]:
        """Collect selected attributes of the TextArena state into a dict.

        ``logs`` is copied into a new list and ``reward_signals`` (included only
        when non-empty) into a new dict; every other value is the object found
        on the state, not a copy.
        """
        state = self._ta_env.state
        snapshot: Dict[str, Any] = {
            "turn": getattr(state, "turn", 0),
            "game_state": getattr(state, "game_state", {}),
            # ``logs`` may be missing or None; both become an empty list.
            "logs": list(getattr(state, "logs", None) or []),
            "rewards": getattr(state, "rewards", None),
            "done": getattr(state, "done", False),
            "role_mapping": getattr(state, "role_mapping", {}),
            "game_info": getattr(state, "game_info", {}),
            "step_info": getattr(state, "step_info", {}),
        }
        if self._last_reward_signals:
            snapshot["reward_signals"] = dict(self._last_reward_signals)
        return snapshot

    def _compute_reward_signals(
        self, *, action: TextArenaAction, observation: TextArenaObservation
    ) -> Dict[str, float]:
        """Merge the outputs of all reward providers into one mapping.

        Providers are best-effort: one that raises is logged and skipped, and
        one that returns nothing contributes nothing. When two providers emit
        the same key, the later provider wins.

        Args:
            action: The action just applied; passed through to each provider.
            observation: The resulting observation; passed through to each provider.

        Returns:
            Signal name to value (converted with ``float``); empty when there
            are no providers or none produced output.
        """
        aggregated: Dict[str, float] = {}
        for provider in self._reward_providers:
            try:
                result = provider.compute(action=action, observation=observation)
            except Exception:  # pragma: no cover - defensive
                # Keep the episode running, but leave a trace instead of
                # silently dropping the failure.
                logging.getLogger(__name__).exception(
                    "Reward provider %r failed; skipping its signals", provider
                )
                continue
            if not result:
                continue
            for key, value in result.items():
                aggregated[key] = float(value)
        return aggregated