AbeBhatti commited on
Commit
5103eea
·
1 Parent(s): f8cde5c

graceful SentenceTransformer fallback on HF Spaces

Browse files
envs/arbitragent_env.py CHANGED
@@ -45,7 +45,10 @@ class ArbitrAgentEnv(Env):
45
 
46
  def __init__(self, data_path: str = "training/data/selfplay_states.json", seed=None):
47
  self.data_path = data_path
48
- self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
 
49
  if seed is not None:
50
  random.seed(seed)
51
  np.random.seed(seed)
@@ -104,6 +107,8 @@ class ArbitrAgentEnv(Env):
104
 
105
  def _accuracy_reward(self, action: str) -> float:
106
  """Cosine similarity between action embedding and human action embedding."""
 
 
107
  state_text = self.current_state.get("state_text", "")
108
  human_action_text = _extract_human_orders(state_text)
109
  action_emb = self.encoder.encode(action, convert_to_numpy=True)
@@ -204,6 +209,8 @@ Your task: Propose a move. If you detect a bluff, use coalition pressure; otherw
204
 
205
  def _get_observation(self):
206
  text = self._get_state_text()
 
 
207
  emb = self.encoder.encode(text, convert_to_numpy=True)
208
  return emb.astype(np.float32)
209
 
 
45
 
46
  def __init__(self, data_path: str = "training/data/selfplay_states.json", seed=None):
47
  self.data_path = data_path
48
+ try:
49
+ self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
50
+ except Exception:
51
+ self.encoder = None
52
  if seed is not None:
53
  random.seed(seed)
54
  np.random.seed(seed)
 
107
 
108
  def _accuracy_reward(self, action: str) -> float:
109
  """Cosine similarity between action embedding and human action embedding."""
110
+ if self.encoder is None:
111
+ return 0.0
112
  state_text = self.current_state.get("state_text", "")
113
  human_action_text = _extract_human_orders(state_text)
114
  action_emb = self.encoder.encode(action, convert_to_numpy=True)
 
209
 
210
  def _get_observation(self):
211
  text = self._get_state_text()
212
+ if self.encoder is None:
213
+ return np.zeros(384, dtype=np.float32)
214
  emb = self.encoder.encode(text, convert_to_numpy=True)
215
  return emb.astype(np.float32)
216
 
envs/contractor_env.py CHANGED
@@ -24,7 +24,10 @@ class ContractorNegotiationEnv(Env):
24
  def __init__(self, n_contractors=5, budget=10000, seed=None):
25
  self.n_contractors = n_contractors
26
  self.budget = budget
27
- self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
 
28
  if seed:
29
  random.seed(seed)
30
  np.random.seed(seed)
@@ -145,6 +148,8 @@ class ContractorNegotiationEnv(Env):
145
 
146
  def _get_observation(self):
147
  text = self._get_state_text()
 
 
148
  emb = self.encoder.encode(text, convert_to_numpy=True)
149
  return emb.astype(np.float32)
150
 
 
24
  def __init__(self, n_contractors=5, budget=10000, seed=None):
25
  self.n_contractors = n_contractors
26
  self.budget = budget
27
+ try:
28
+ self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
29
+ except Exception:
30
+ self.encoder = None
31
  if seed:
32
  random.seed(seed)
33
  np.random.seed(seed)
 
148
 
149
  def _get_observation(self):
150
  text = self._get_state_text()
151
+ if self.encoder is None:
152
+ return np.zeros(384, dtype=np.float32)
153
  emb = self.encoder.encode(text, convert_to_numpy=True)
154
  return emb.astype(np.float32)
155
 
envs/diplomacy_env.py CHANGED
@@ -19,7 +19,10 @@ class DiplomacyNegotiationEnv(Env):
19
  def __init__(self, power_name: str = "ENGLAND", seed: int | None = None):
20
  self._reset_random_power = power_name.upper() == "ENGLAND" # default: vary power on reset for non-hardcoded obs
21
  self.power_name = power_name.upper()
22
- self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
 
23
  self.game: Game | None = None
24
  self.current_phase: int = 0
25
  self.prev_sc_count: int = 0
@@ -149,6 +152,8 @@ class DiplomacyNegotiationEnv(Env):
149
  def _get_observation(self) -> np.ndarray:
150
  """Return a 384-dim MiniLM embedding of the current game state text."""
151
  text = self._get_state_text()
 
 
152
  embedding = self.encoder.encode(text, convert_to_numpy=True)
153
  # Ensure consistent dtype for downstream RL code.
154
  return embedding.astype(np.float32)
 
19
  def __init__(self, power_name: str = "ENGLAND", seed: int | None = None):
20
  self._reset_random_power = power_name.upper() == "ENGLAND" # default: vary power on reset for non-hardcoded obs
21
  self.power_name = power_name.upper()
22
+ try:
23
+ self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
24
+ except Exception:
25
+ self.encoder = None
26
  self.game: Game | None = None
27
  self.current_phase: int = 0
28
  self.prev_sc_count: int = 0
 
152
  def _get_observation(self) -> np.ndarray:
153
  """Return a 384-dim MiniLM embedding of the current game state text."""
154
  text = self._get_state_text()
155
+ if self.encoder is None:
156
+ return np.zeros(384, dtype=np.float32)
157
  embedding = self.encoder.encode(text, convert_to_numpy=True)
158
  # Ensure consistent dtype for downstream RL code.
159
  return embedding.astype(np.float32)
envs/human_imitation_env.py CHANGED
@@ -17,7 +17,10 @@ from sentence_transformers import SentenceTransformer
17
  class HumanImitationEnv(Env):
18
  def __init__(self, data_path="training/data/selfplay_states.json", seed=None):
19
  self.data_path = data_path
20
- self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
 
21
  if seed is not None:
22
  random.seed(seed)
23
  np.random.seed(seed)
@@ -115,6 +118,8 @@ Explain your reasoning and state your intended orders."""
115
 
116
  def _get_observation(self):
117
  text = self._get_state_text()
 
 
118
  emb = self.encoder.encode(text, convert_to_numpy=True)
119
  return emb.astype(np.float32)
120
 
@@ -156,7 +161,10 @@ from sentence_transformers import SentenceTransformer
156
  class HumanImitationEnv(Env):
157
  def __init__(self, data_path="training/data/selfplay_states.json", seed=None):
158
  self.data_path = data_path
159
- self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
 
160
  if seed is not None:
161
  random.seed(seed)
162
  np.random.seed(seed)
@@ -254,6 +262,8 @@ Explain your reasoning and state your intended orders."""
254
 
255
  def _get_observation(self):
256
  text = self._get_state_text()
 
 
257
  emb = self.encoder.encode(text, convert_to_numpy=True)
258
  return emb.astype(np.float32)
259
 
 
17
  class HumanImitationEnv(Env):
18
  def __init__(self, data_path="training/data/selfplay_states.json", seed=None):
19
  self.data_path = data_path
20
+ try:
21
+ self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
22
+ except Exception:
23
+ self.encoder = None
24
  if seed is not None:
25
  random.seed(seed)
26
  np.random.seed(seed)
 
118
 
119
  def _get_observation(self):
120
  text = self._get_state_text()
121
+ if self.encoder is None:
122
+ return np.zeros(384, dtype=np.float32)
123
  emb = self.encoder.encode(text, convert_to_numpy=True)
124
  return emb.astype(np.float32)
125
 
 
161
  class HumanImitationEnv(Env):
162
  def __init__(self, data_path="training/data/selfplay_states.json", seed=None):
163
  self.data_path = data_path
164
+ try:
165
+ self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
166
+ except Exception:
167
+ self.encoder = None
168
  if seed is not None:
169
  random.seed(seed)
170
  np.random.seed(seed)
 
262
 
263
  def _get_observation(self):
264
  text = self._get_state_text()
265
+ if self.encoder is None:
266
+ return np.zeros(384, dtype=np.float32)
267
  emb = self.encoder.encode(text, convert_to_numpy=True)
268
  return emb.astype(np.float32)
269