Spaces:
Running
Running
Commit
·
d7e086f
1
Parent(s):
4727612
allow for multi tokens to generate a single move
Browse files- src/evaluate.py +109 -19
src/evaluate.py
CHANGED
|
@@ -131,6 +131,102 @@ class ChessEvaluator:
|
|
| 131 |
|
| 132 |
return " ".join(moves)
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
def _get_model_move(
|
| 135 |
self,
|
| 136 |
board,
|
|
@@ -140,6 +236,12 @@ class ChessEvaluator:
|
|
| 140 |
"""
|
| 141 |
Get the model's next move prediction.
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
Returns:
|
| 144 |
Tuple of (UCI move string, number of retries used).
|
| 145 |
"""
|
|
@@ -160,26 +262,17 @@ class ChessEvaluator:
|
|
| 160 |
input_text,
|
| 161 |
return_tensors="pt",
|
| 162 |
truncation=True,
|
| 163 |
-
max_length=max_len -
|
| 164 |
).to(self.device)
|
| 165 |
|
| 166 |
# Try to generate a legal move
|
| 167 |
for retry in range(self.max_retries):
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 175 |
-
logits[indices_to_remove] = float("-inf")
|
| 176 |
-
|
| 177 |
-
# Sample
|
| 178 |
-
probs = torch.softmax(logits, dim=-1)
|
| 179 |
-
next_token = torch.multinomial(probs, num_samples=1)
|
| 180 |
-
|
| 181 |
-
# Decode the move
|
| 182 |
-
move_token = self.tokenizer.decode(next_token[0])
|
| 183 |
|
| 184 |
# Convert to UCI
|
| 185 |
if len(move_token) >= 6:
|
|
@@ -196,9 +289,6 @@ class ChessEvaluator:
|
|
| 196 |
return uci_move, retry
|
| 197 |
except (ValueError, self.chess.InvalidMoveError):
|
| 198 |
pass
|
| 199 |
-
|
| 200 |
-
# Mask out the tried token for next retry
|
| 201 |
-
logits[0, next_token[0]] = float("-inf")
|
| 202 |
|
| 203 |
return None, self.max_retries
|
| 204 |
|
|
|
|
| 131 |
|
| 132 |
return " ".join(moves)
|
| 133 |
|
| 134 |
+
def _is_separator_token(self, token_str: str) -> bool:
|
| 135 |
+
"""
|
| 136 |
+
Check if a token represents a separator (whitespace, EOS, etc.).
|
| 137 |
+
|
| 138 |
+
This allows the evaluator to work with different tokenization strategies:
|
| 139 |
+
- Move-level tokenizers: each move is one token, no separators generated
|
| 140 |
+
- Character-level tokenizers: space character marks end of move
|
| 141 |
+
- BPE/subword tokenizers: may generate partial moves
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
token_str: The decoded token string.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
True if this token indicates end of a move.
|
| 148 |
+
"""
|
| 149 |
+
# Check for EOS token
|
| 150 |
+
if hasattr(self.tokenizer, 'eos_token') and token_str == self.tokenizer.eos_token:
|
| 151 |
+
return True
|
| 152 |
+
|
| 153 |
+
# Check for whitespace (space, newline, etc.)
|
| 154 |
+
if token_str.strip() == "" and len(token_str) > 0:
|
| 155 |
+
return True
|
| 156 |
+
|
| 157 |
+
# Check if the token ends with whitespace (some tokenizers include trailing space)
|
| 158 |
+
if token_str != token_str.rstrip():
|
| 159 |
+
return True
|
| 160 |
+
|
| 161 |
+
return False
|
| 162 |
+
|
| 163 |
+
def _generate_move_tokens(
|
| 164 |
+
self,
|
| 165 |
+
input_ids: torch.Tensor,
|
| 166 |
+
temperature: float = 0.7,
|
| 167 |
+
top_k: int = 10,
|
| 168 |
+
max_tokens: int = 20,
|
| 169 |
+
) -> str:
|
| 170 |
+
"""
|
| 171 |
+
Generate tokens until a separator (whitespace/EOS) is encountered.
|
| 172 |
+
|
| 173 |
+
This method supports different tokenization strategies:
|
| 174 |
+
- For move-level tokenizers: generates one token (the full move)
|
| 175 |
+
- For character/subword tokenizers: generates until whitespace
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
input_ids: The input token IDs.
|
| 179 |
+
temperature: Sampling temperature.
|
| 180 |
+
top_k: Top-k filtering parameter.
|
| 181 |
+
max_tokens: Maximum tokens to generate for a single move.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
The generated move string (without trailing separator).
|
| 185 |
+
"""
|
| 186 |
+
generated_tokens = []
|
| 187 |
+
current_ids = input_ids.clone()
|
| 188 |
+
|
| 189 |
+
for _ in range(max_tokens):
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
outputs = self.model(input_ids=current_ids)
|
| 192 |
+
logits = outputs.logits[:, -1, :] / temperature
|
| 193 |
+
|
| 194 |
+
# Apply top-k filtering
|
| 195 |
+
if top_k > 0:
|
| 196 |
+
top_k_values = torch.topk(logits, min(top_k, logits.size(-1)))[0]
|
| 197 |
+
indices_to_remove = logits < top_k_values[..., -1, None]
|
| 198 |
+
logits[indices_to_remove] = float("-inf")
|
| 199 |
+
|
| 200 |
+
# Sample
|
| 201 |
+
probs = torch.softmax(logits, dim=-1)
|
| 202 |
+
next_token = torch.multinomial(probs, num_samples=1) # Shape: [1, 1]
|
| 203 |
+
|
| 204 |
+
# Decode the token
|
| 205 |
+
token_str = self.tokenizer.decode(next_token[0])
|
| 206 |
+
|
| 207 |
+
# Check if this is a separator token
|
| 208 |
+
if self._is_separator_token(token_str):
|
| 209 |
+
break
|
| 210 |
+
|
| 211 |
+
generated_tokens.append(next_token[0]) # Store [1] tensor
|
| 212 |
+
|
| 213 |
+
# Append to input for next iteration (next_token is already [1, 1])
|
| 214 |
+
current_ids = torch.cat([current_ids, next_token], dim=-1)
|
| 215 |
+
|
| 216 |
+
# For move-level tokenizers, a single non-separator token is the full move
|
| 217 |
+
# We can detect this by checking if the token looks like a complete move
|
| 218 |
+
# (starts with W or B, has enough characters for a move)
|
| 219 |
+
if len(token_str) >= 6 and token_str[0] in "WB":
|
| 220 |
+
break
|
| 221 |
+
|
| 222 |
+
# Decode all generated tokens together
|
| 223 |
+
if generated_tokens:
|
| 224 |
+
all_tokens = torch.cat(generated_tokens, dim=0)
|
| 225 |
+
move_str = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
|
| 226 |
+
return move_str.strip()
|
| 227 |
+
|
| 228 |
+
return ""
|
| 229 |
+
|
| 230 |
def _get_model_move(
|
| 231 |
self,
|
| 232 |
board,
|
|
|
|
| 236 |
"""
|
| 237 |
Get the model's next move prediction.
|
| 238 |
|
| 239 |
+
This method generates tokens until a separator (whitespace/EOS) is produced,
|
| 240 |
+
allowing it to work with different tokenization strategies:
|
| 241 |
+
- Move-level tokenizers: each move is a single token
|
| 242 |
+
- Character-level tokenizers: moves are generated character by character
|
| 243 |
+
- BPE/subword tokenizers: moves may be split into subwords
|
| 244 |
+
|
| 245 |
Returns:
|
| 246 |
Tuple of (UCI move string, number of retries used).
|
| 247 |
"""
|
|
|
|
| 262 |
input_text,
|
| 263 |
return_tensors="pt",
|
| 264 |
truncation=True,
|
| 265 |
+
max_length=max_len - 10, # Leave room for generated tokens
|
| 266 |
).to(self.device)
|
| 267 |
|
| 268 |
# Try to generate a legal move
|
| 269 |
for retry in range(self.max_retries):
|
| 270 |
+
# Generate tokens until separator
|
| 271 |
+
move_token = self._generate_move_tokens(
|
| 272 |
+
inputs["input_ids"],
|
| 273 |
+
temperature=temperature,
|
| 274 |
+
top_k=top_k,
|
| 275 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
# Convert to UCI
|
| 278 |
if len(move_token) >= 6:
|
|
|
|
| 289 |
return uci_move, retry
|
| 290 |
except (ValueError, self.chess.InvalidMoveError):
|
| 291 |
pass
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
return None, self.max_retries
|
| 294 |
|