nathanael-fijalkow committed on
Commit
d7e086f
·
1 Parent(s): 4727612

allow multiple tokens to generate a single move

Browse files
Files changed (1) hide show
  1. src/evaluate.py +109 -19
src/evaluate.py CHANGED
@@ -131,6 +131,102 @@ class ChessEvaluator:
131
 
132
  return " ".join(moves)
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def _get_model_move(
135
  self,
136
  board,
@@ -140,6 +236,12 @@ class ChessEvaluator:
140
  """
141
  Get the model's next move prediction.
142
 
 
 
 
 
 
 
143
  Returns:
144
  Tuple of (UCI move string, number of retries used).
145
  """
@@ -160,26 +262,17 @@ class ChessEvaluator:
160
  input_text,
161
  return_tensors="pt",
162
  truncation=True,
163
- max_length=max_len - 1,
164
  ).to(self.device)
165
 
166
  # Try to generate a legal move
167
  for retry in range(self.max_retries):
168
- with torch.no_grad():
169
- outputs = self.model(**inputs)
170
- logits = outputs.logits[:, -1, :] / temperature
171
-
172
- # Apply top-k filtering
173
- if top_k > 0:
174
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
175
- logits[indices_to_remove] = float("-inf")
176
-
177
- # Sample
178
- probs = torch.softmax(logits, dim=-1)
179
- next_token = torch.multinomial(probs, num_samples=1)
180
-
181
- # Decode the move
182
- move_token = self.tokenizer.decode(next_token[0])
183
 
184
  # Convert to UCI
185
  if len(move_token) >= 6:
@@ -196,9 +289,6 @@ class ChessEvaluator:
196
  return uci_move, retry
197
  except (ValueError, self.chess.InvalidMoveError):
198
  pass
199
-
200
- # Mask out the tried token for next retry
201
- logits[0, next_token[0]] = float("-inf")
202
 
203
  return None, self.max_retries
204
 
 
131
 
132
  return " ".join(moves)
133
 
134
+ def _is_separator_token(self, token_str: str) -> bool:
135
+ """
136
+ Check if a token represents a separator (whitespace, EOS, etc.).
137
+
138
+ This allows the evaluator to work with different tokenization strategies:
139
+ - Move-level tokenizers: each move is one token, no separators generated
140
+ - Character-level tokenizers: space character marks end of move
141
+ - BPE/subword tokenizers: may generate partial moves
142
+
143
+ Args:
144
+ token_str: The decoded token string.
145
+
146
+ Returns:
147
+ True if this token indicates end of a move.
148
+ """
149
+ # Check for EOS token
150
+ if hasattr(self.tokenizer, 'eos_token') and token_str == self.tokenizer.eos_token:
151
+ return True
152
+
153
+ # Check for whitespace (space, newline, etc.)
154
+ if token_str.strip() == "" and len(token_str) > 0:
155
+ return True
156
+
157
+ # Check if the token ends with whitespace (some tokenizers include trailing space)
158
+ if token_str != token_str.rstrip():
159
+ return True
160
+
161
+ return False
162
+
163
+ def _generate_move_tokens(
164
+ self,
165
+ input_ids: torch.Tensor,
166
+ temperature: float = 0.7,
167
+ top_k: int = 10,
168
+ max_tokens: int = 20,
169
+ ) -> str:
170
+ """
171
+ Generate tokens until a separator (whitespace/EOS) is encountered.
172
+
173
+ This method supports different tokenization strategies:
174
+ - For move-level tokenizers: generates one token (the full move)
175
+ - For character/subword tokenizers: generates until whitespace
176
+
177
+ Args:
178
+ input_ids: The input token IDs.
179
+ temperature: Sampling temperature.
180
+ top_k: Top-k filtering parameter.
181
+ max_tokens: Maximum tokens to generate for a single move.
182
+
183
+ Returns:
184
+ The generated move string (without trailing separator).
185
+ """
186
+ generated_tokens = []
187
+ current_ids = input_ids.clone()
188
+
189
+ for _ in range(max_tokens):
190
+ with torch.no_grad():
191
+ outputs = self.model(input_ids=current_ids)
192
+ logits = outputs.logits[:, -1, :] / temperature
193
+
194
+ # Apply top-k filtering
195
+ if top_k > 0:
196
+ top_k_values = torch.topk(logits, min(top_k, logits.size(-1)))[0]
197
+ indices_to_remove = logits < top_k_values[..., -1, None]
198
+ logits[indices_to_remove] = float("-inf")
199
+
200
+ # Sample
201
+ probs = torch.softmax(logits, dim=-1)
202
+ next_token = torch.multinomial(probs, num_samples=1) # Shape: [1, 1]
203
+
204
+ # Decode the token
205
+ token_str = self.tokenizer.decode(next_token[0])
206
+
207
+ # Check if this is a separator token
208
+ if self._is_separator_token(token_str):
209
+ break
210
+
211
+ generated_tokens.append(next_token[0]) # Store [1] tensor
212
+
213
+ # Append to input for next iteration (next_token is already [1, 1])
214
+ current_ids = torch.cat([current_ids, next_token], dim=-1)
215
+
216
+ # For move-level tokenizers, a single non-separator token is the full move
217
+ # We can detect this by checking if the token looks like a complete move
218
+ # (starts with W or B, has enough characters for a move)
219
+ if len(token_str) >= 6 and token_str[0] in "WB":
220
+ break
221
+
222
+ # Decode all generated tokens together
223
+ if generated_tokens:
224
+ all_tokens = torch.cat(generated_tokens, dim=0)
225
+ move_str = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
226
+ return move_str.strip()
227
+
228
+ return ""
229
+
230
  def _get_model_move(
231
  self,
232
  board,
 
236
  """
237
  Get the model's next move prediction.
238
 
239
+ This method generates tokens until a separator (whitespace/EOS) is produced,
240
+ allowing it to work with different tokenization strategies:
241
+ - Move-level tokenizers: each move is a single token
242
+ - Character-level tokenizers: moves are generated character by character
243
+ - BPE/subword tokenizers: moves may be split into subwords
244
+
245
  Returns:
246
  Tuple of (UCI move string, number of retries used).
247
  """
 
262
  input_text,
263
  return_tensors="pt",
264
  truncation=True,
265
+ max_length=max_len - 10, # Leave room for generated tokens
266
  ).to(self.device)
267
 
268
  # Try to generate a legal move
269
  for retry in range(self.max_retries):
270
+ # Generate tokens until separator
271
+ move_token = self._generate_move_tokens(
272
+ inputs["input_ids"],
273
+ temperature=temperature,
274
+ top_k=top_k,
275
+ )
 
 
 
 
 
 
 
 
 
276
 
277
  # Convert to UCI
278
  if len(move_token) >= 6:
 
289
  return uci_move, retry
290
  except (ValueError, self.chess.InvalidMoveError):
291
  pass
 
 
 
292
 
293
  return None, self.max_retries
294