burtenshaw HF Staff committed on
Commit
5536759
·
verified ·
1 Parent(s): b64ae1e

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. models.py +4 -2
  2. rewards.py +4 -1
  3. server/app.py +6 -1
  4. server/environment.py +5 -0
models.py CHANGED
@@ -12,6 +12,7 @@ The textarena environment is a simple test environment that echoes back messages
12
 
13
  from __future__ import annotations
14
 
 
15
  from typing import Any, Dict, List, Optional
16
 
17
  from pydantic import BaseModel, Field
@@ -19,7 +20,7 @@ from pydantic import BaseModel, Field
19
  from openenv.core.env_server.types import Action, Observation, State
20
 
21
 
22
- class TextArenaMessage(BaseModel):
23
  """Single message observed by a player."""
24
 
25
  sender_id: int
@@ -42,7 +43,7 @@ class TextArenaObservation(Observation):
42
  legal_players: List[int] = Field(default_factory=list)
43
  info: Dict[str, Any] = Field(default_factory=dict)
44
 
45
-
46
  class TextArenaState(State):
47
  """Structured state snapshot for the server."""
48
 
@@ -53,3 +54,4 @@ class TextArenaState(State):
53
  last_reward: float = 0.0
54
  last_info: Dict[str, Any] = Field(default_factory=dict)
55
  raw_state: Dict[str, Any] = Field(default_factory=dict)
 
 
12
 
13
  from __future__ import annotations
14
 
15
+ from pydantic import Field
16
  from typing import Any, Dict, List, Optional
17
 
18
  from pydantic import BaseModel, Field
 
20
  from openenv.core.env_server.types import Action, Observation, State
21
 
22
 
23
+ class TextArenaMessage:
24
  """Single message observed by a player."""
25
 
26
  sender_id: int
 
43
  legal_players: List[int] = Field(default_factory=list)
44
  info: Dict[str, Any] = Field(default_factory=dict)
45
 
46
+
47
  class TextArenaState(State):
48
  """Structured state snapshot for the server."""
49
 
 
54
  last_reward: float = 0.0
55
  last_info: Dict[str, Any] = Field(default_factory=dict)
56
  raw_state: Dict[str, Any] = Field(default_factory=dict)
57
+
rewards.py CHANGED
@@ -5,7 +5,10 @@ from __future__ import annotations
5
  import re
6
  from typing import Dict, List, Protocol, Tuple
7
 
8
- from .models import TextArenaAction, TextArenaObservation
 
 
 
9
 
10
 
11
  class RewardProvider(Protocol):
 
5
  import re
6
  from typing import Dict, List, Protocol, Tuple
7
 
8
+ try:
9
+ from textarena_env.models import TextArenaAction, TextArenaObservation
10
+ except ImportError:
11
+ from models import TextArenaAction, TextArenaObservation
12
 
13
 
14
  class RewardProvider(Protocol):
server/app.py CHANGED
@@ -56,7 +56,12 @@ def create_textarena_environment():
56
 
57
  # Create the FastAPI app
58
  # Pass the factory function instead of an instance for WebSocket session support
59
- app = create_app(create_textarena_environment, TextArenaAction, TextArenaObservation, env_name="textarena_env")
 
 
 
 
 
60
 
61
 
62
  def main(host: str = "0.0.0.0", port: int = 8000):
 
56
 
57
  # Create the FastAPI app
58
  # Pass the factory function instead of an instance for WebSocket session support
59
+ app = create_app(
60
+ create_textarena_environment,
61
+ TextArenaAction,
62
+ TextArenaObservation,
63
+ env_name="textarena_env",
64
+ )
65
 
66
 
67
  def main(host: str = "0.0.0.0", port: int = 8000):
server/environment.py CHANGED
@@ -104,6 +104,11 @@ class TextArenaEnvironment(Environment):
104
  self._reward_providers: List[RewardProvider] = build_reward_providers(env_id)
105
  self._last_reward_signals: Dict[str, float] = {}
106
 
 
 
 
 
 
107
  # ------------------------------------------------------------------
108
  # Environment interface
109
  # ------------------------------------------------------------------
 
104
  self._reward_providers: List[RewardProvider] = build_reward_providers(env_id)
105
  self._last_reward_signals: Dict[str, float] = {}
106
 
107
+ # Initialize environment state - TextArena envs require reset() to be called
108
+ # before step() can be used, as the internal state object isn't created until reset.
109
+ # This ensures the environment is always in a valid state after construction.
110
+ self._ta_env.reset(num_players=self.num_players)
111
+
112
  # ------------------------------------------------------------------
113
  # Environment interface
114
  # ------------------------------------------------------------------