burtenshaw HF Staff committed on
Commit
7207595
·
verified ·
1 Parent(s): 73071e0

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. models.py +0 -1
  2. rewards.py +9 -3
  3. server/environment.py +12 -2
models.py CHANGED
@@ -16,7 +16,6 @@ from pydantic import BaseModel, Field
16
  from typing import Any, Dict, List, Optional
17
 
18
 
19
-
20
  from openenv.core.env_server.types import Action, Observation, State
21
 
22
 
 
16
  from typing import Any, Dict, List, Optional
17
 
18
 
 
19
  from openenv.core.env_server.types import Action, Observation, State
20
 
21
 
rewards.py CHANGED
@@ -17,7 +17,9 @@ class RewardProvider(Protocol):
17
  def reset(self) -> None:
18
  """Clear any internal state before a new episode."""
19
 
20
- def compute(self, *, action: TextArenaAction, observation: TextArenaObservation) -> Dict[str, float]:
 
 
21
  """Return a mapping of reward names to float values for the step."""
22
 
23
 
@@ -92,12 +94,16 @@ class _WordleRewardProvider:
92
  def reset(self) -> None:
93
  self._guess_history.clear()
94
 
95
- def compute(self, *, action: TextArenaAction, observation: TextArenaObservation) -> Dict[str, float]:
 
 
96
  guess = extract_guess(action.message)
97
  feedback = extract_wordle_feedback(observation)
98
 
99
  normalized_guess = guess if guess and guess != "[dunno]" else ""
100
- previous_occurrences = self._guess_history.get(normalized_guess, 0) if normalized_guess else 0
 
 
101
 
102
  green_score = 0.0
103
  yellow_score = 0.0
 
17
  def reset(self) -> None:
18
  """Clear any internal state before a new episode."""
19
 
20
+ def compute(
21
+ self, *, action: TextArenaAction, observation: TextArenaObservation
22
+ ) -> Dict[str, float]:
23
  """Return a mapping of reward names to float values for the step."""
24
 
25
 
 
94
  def reset(self) -> None:
95
  self._guess_history.clear()
96
 
97
+ def compute(
98
+ self, *, action: TextArenaAction, observation: TextArenaObservation
99
+ ) -> Dict[str, float]:
100
  guess = extract_guess(action.message)
101
  feedback = extract_wordle_feedback(observation)
102
 
103
  normalized_guess = guess if guess and guess != "[dunno]" else ""
104
+ previous_occurrences = (
105
+ self._guess_history.get(normalized_guess, 0) if normalized_guess else 0
106
+ )
107
 
108
  green_score = 0.0
109
  yellow_score = 0.0
server/environment.py CHANGED
@@ -38,6 +38,17 @@ except ImportError:
38
 
39
  _TEXTARENA_MODULE: Any | None = None
40
  _TEXTARENA_IMPORT_ERROR: Exception | None = None
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  def _import_textarena() -> Any:
@@ -85,8 +96,7 @@ class TextArenaEnvironment(Environment):
85
  ta = _import_textarena()
86
 
87
  if download_nltk:
88
- nltk.download("words", quiet=True)
89
- nltk.download("averaged_perceptron_tagger_eng", quiet=True)
90
 
91
  self.env_id = env_id
92
  self.num_players = num_players
 
38
 
39
  _TEXTARENA_MODULE: Any | None = None
40
  _TEXTARENA_IMPORT_ERROR: Exception | None = None
41
+ _NLTK_DOWNLOADED: bool = False
42
+
43
+
44
+ def _ensure_nltk_data() -> None:
45
+ """Download NLTK data once per process."""
46
+ global _NLTK_DOWNLOADED
47
+ if _NLTK_DOWNLOADED:
48
+ return
49
+ nltk.download("words", quiet=True)
50
+ nltk.download("averaged_perceptron_tagger_eng", quiet=True)
51
+ _NLTK_DOWNLOADED = True
52
 
53
 
54
  def _import_textarena() -> Any:
 
96
  ta = _import_textarena()
97
 
98
  if download_nltk:
99
+ _ensure_nltk_data()
 
100
 
101
  self.env_id = env_id
102
  self.num_players = num_players