Maxlegrec committed on
Commit
d7ecccf
·
verified ·
1 Parent(s): ae675a7

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +756 -0
model.py ADDED
@@ -0,0 +1,756 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .attn_map import apm_map, apm_out
5
+ import math
6
+ from .encoding_simple import encode_fen_to_tensor, encode_moves_to_tensor
7
+ from .vocab import policy_index
8
+ from typing import Union, List, Optional
9
+ import bulletchess
10
+ import numpy as np
11
+
12
class Gating(nn.Module):
    """Learnable per-feature gate: additive (``x + g``) or multiplicative (``x * g``).

    With the default init the gate is the identity (0 for additive, 1 for
    multiplicative). For the multiplicative variant, gradients flowing into
    the gate are clamped to be non-negative via a backward hook.
    """

    def __init__(self, features_shape, additive=True, init_value=None):
        super().__init__()
        self.additive = additive
        # Identity-gate default: adding 0 or multiplying by 1 leaves x unchanged.
        if init_value is None:
            init_value = 0 if additive else 1
        self.gate = nn.Parameter(torch.full(features_shape, float(init_value)))
        if not additive:
            # Clamp gradients of the multiplicative gate to be non-negative.
            self.gate.register_hook(lambda g: torch.clamp(g, min=0))

    def forward(self, x):
        return x + self.gate if self.additive else x * self.gate
28
+
29
def ma_gating(x, in_features):
    """Apply multiplicative-then-additive gating to ``x``.

    NOTE(review): this constructs fresh ``Gating`` modules on every call, so
    their parameters are never registered with any parent module and are
    never trained; with the default init this is effectively an identity
    (``x * 1 + 0``). ``BT4`` uses its own ``gating_mult``/``gating_add``
    members instead — confirm whether this helper is still needed.
    """
    x = Gating(in_features, additive=False)(x)
    x = Gating(in_features, additive=True)(x)
    return x
33
+
34
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Divides by ``sqrt(mean(x^2) + 1e-5)``; when ``scale`` is True, also
    applies a learned per-feature scale ``gamma`` (initialized to ones).
    """

    def __init__(self, in_features, scale=True):
        super().__init__()
        self.scale = scale
        if scale:
            self.gamma = nn.Parameter(torch.ones(in_features))

    def forward(self, x):
        mean_sq = torch.mean(x ** 2, dim=-1, keepdim=True)
        normed = x / torch.sqrt(mean_sq + 1e-5)
        return normed * self.gamma if self.scale else normed
47
+
48
class ApplyAttentionPolicyMap(nn.Module):
    """Map raw (64x64 move + 8x24 promotion) logits onto the policy vocabulary.

    The gather indices come from the attention-policy-map tables and are
    registered as buffers (buffer names kept for checkpoint compatibility)
    so they move with the module on ``.to(device)``.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer('fc1', torch.from_numpy(apm_map).float())
        self.register_buffer('idx', torch.from_numpy(apm_out).long())

    def forward(self, logits, pp_logits):
        # Flatten and concatenate into (B, 64*64 + 8*24) = (B, 4288).
        flat = torch.cat(
            [logits.reshape(-1, 4096), pp_logits.reshape(-1, 192)],
            dim=1,
        )
        # Broadcast the index table over the batch and gather per position.
        gather_idx = self.idx.unsqueeze(0).expand(flat.size(0), -1)
        return torch.gather(flat, 1, gather_idx)
65
+
66
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.tanh(F.softplus(x)) * x
72
+
73
class CustomMHA(nn.Module):
    """Multi-head self-attention with an optional "Smolgen" additive bias.

    When ``use_smolgen`` is True, the input tokens are compressed and passed
    through two dense layers to generate a per-head 64x64 bias that is added
    to the attention logits before the softmax.
    """

    def __init__(self, emb_size, d_model, num_heads, dropout=0.0, use_bias_qkv=True, use_bias_out=True,
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super(CustomMHA, self).__init__()
        assert d_model % num_heads == 0
        self.emb_size = emb_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.wq = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.wk = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.wv = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.out_proj = nn.Linear(d_model, emb_size, bias=use_bias_out)
        self.attn_dropout = nn.Dropout(dropout)
        # Optional Smolgen components (None when use_smolgen is False; forward
        # keys its Smolgen path off smol_compress being non-None).
        self.smol_compress = None
        self.smol_hidden1 = None
        self.smol_hidden1_ln = None
        self.smol_gen_from = None
        self.smol_gen_from_ln = None
        self.smol_weight_gen = None
        if use_smolgen:
            self.smol_compress = nn.Linear(emb_size, smol_hidden_channels, bias=False)
            self.smol_hidden1 = nn.Linear(64 * smol_hidden_channels, smol_hidden_sz, bias=True)
            self.smol_hidden1_ln = nn.LayerNorm(smol_hidden_sz, eps=1e-3)
            self.smol_gen_from = nn.Linear(smol_hidden_sz, num_heads * smol_gen_sz, bias=True)
            self.smol_gen_from_ln = nn.LayerNorm(num_heads * smol_gen_sz, eps=1e-3)
            self.smol_weight_gen = nn.Linear(smol_gen_sz, 64 * 64, bias=False)
            # NOTE(review): smol_activation is only stored when use_smolgen is
            # True; forward() only reads it on the Smolgen path, so this is safe.
            self.smol_activation = smol_activation

    def _shape(self, x):
        # (B, L, d_model) -> (B, H, L, head_dim)
        b, l, _ = x.shape
        return x.view(b, l, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, return_attn=False):
        # x: (B, L, emb_size)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        q = self._shape(q) # (B, H, L, D)
        k = self._shape(k)
        v = self._shape(v)
        # Standard scaled dot-product attention logits.
        scale = torch.sqrt(torch.tensor(self.head_dim, dtype=x.dtype, device=x.device))
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / scale
        # Add Smolgen weights if present
        smol_w = None
        if self.smol_compress is not None:
            b, l, _ = x.shape
            compressed = self.smol_compress(x) # (B, L, hidden_channels)
            compressed = compressed.reshape(b, l * compressed.shape[-1]) # (B, 64*hidden_channels)
            hidden_pre = self.smol_hidden1(compressed)
            # NOTE(review): both branches of this conditional are F.silu, so the
            # non-'swish' fallback is a no-op — confirm the intended alternative
            # activation (looks like an unfinished branch).
            hidden = F.silu(hidden_pre) if self.smol_activation == 'swish' else F.silu(hidden_pre)
            hidden_ln = self.smol_hidden1_ln(hidden)
            gen_from_pre = self.smol_gen_from(hidden_ln)
            # NOTE(review): same redundant conditional as above.
            gen_from_act = F.silu(gen_from_pre) if self.smol_activation == 'swish' else F.silu(gen_from_pre)
            gen_from = self.smol_gen_from_ln(gen_from_act)
            gen_from = gen_from.view(b, self.num_heads, -1) # (B, H, gen_sz)
            smol_w = self.smol_weight_gen(gen_from) # (B, H, 64*64)
            smol_w = smol_w.view(b, self.num_heads, l, l)
            attn_logits = attn_logits + smol_w
        # Numerically stable softmax matching TF (float32, subtract max)
        attn_logits = attn_logits - attn_logits.max(dim=-1, keepdim=True)[0]
        attn_weights = torch.exp(attn_logits)
        attn_weights = attn_weights / attn_weights.sum(dim=-1, keepdim=True)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, v) # (B, H, L, D)
        attn_output = attn_output.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.d_model)
        out = self.out_proj(attn_output)
        if return_attn:
            # Note: attn_logits returned here is the max-subtracted version.
            return out, attn_weights, smol_w, attn_logits
        return out
144
+
145
class FFN(nn.Module):
    """Two-layer position-wise feed-forward block: Linear -> activation -> Linear.

    Args:
        emb_size: input and output feature size.
        dff: hidden-layer size.
        activation: activation module; a fresh ``Mish()`` is created when None.
            (The previous default ``activation=Mish()`` was a single module
            instance evaluated at class-definition time and shared by every
            FFN built with the default; Mish is stateless so behavior is
            unchanged, but each FFN now owns its activation.)
        omit_other_biases: when True, both linear layers are created bias-free.
    """

    def __init__(self, emb_size, dff, activation=None, omit_other_biases=False):
        super(FFN, self).__init__()
        self.dense1 = nn.Linear(emb_size, dff, bias=not omit_other_biases)
        # Avoid the mutable (module-instance) default-argument pitfall.
        self.activation = Mish() if activation is None else activation
        self.dense2 = nn.Linear(dff, emb_size, bias=not omit_other_biases)

    def forward(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        return x
157
+
158
class EncoderLayer(nn.Module):
    """One transformer encoder block: MHA and FFN with DeepNorm-style residuals.

    Residual branches are scaled by ``alpha = (2 * encoder_layers) ** -0.25``;
    the first normalization can be skipped via ``skip_first_ln``, and RMSNorm
    can replace LayerNorm via ``encoder_rms_norm``.
    """

    def __init__(self, emb_size, d_model, num_heads, dff, dropout_rate, encoder_layers,
                 skip_first_ln=False, encoder_rms_norm=False, omit_qkv_biases=False,
                 omit_other_biases=False, use_smolgen=True, smol_hidden_channels=32,
                 smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super().__init__()
        self.mha = CustomMHA(
            emb_size, d_model, num_heads, dropout=dropout_rate,
            use_bias_qkv=not omit_qkv_biases, use_bias_out=not omit_other_biases,
            use_smolgen=use_smolgen, smol_hidden_channels=smol_hidden_channels,
            smol_hidden_sz=smol_hidden_sz, smol_gen_sz=smol_gen_sz,
            smol_activation=smol_activation,
        )
        self.ffn = FFN(emb_size, dff, omit_other_biases=omit_other_biases)

        def make_norm():
            # RMSNorm or LayerNorm, per configuration.
            return RMSNorm(emb_size) if encoder_rms_norm else nn.LayerNorm(emb_size, eps=0.001)

        self.norm1 = make_norm()
        self.norm2 = make_norm()

        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        # DeepNorm residual scaling factor.
        self.alpha = (2. * encoder_layers) ** -0.25
        self.skip_first_ln = skip_first_ln

    def forward(self, x):
        # Attention sub-block with scaled residual.
        attn = self.dropout1(self.mha(x))
        hidden = x + attn * self.alpha
        if not self.skip_first_ln:
            hidden = self.norm1(hidden)
        # Feed-forward sub-block with scaled residual.
        ffn_out = self.dropout2(self.ffn(hidden))
        return self.norm2(hidden + ffn_out * self.alpha)
187
+
188
class PolicyHead(nn.Module):
    """Attention-style policy head.

    Move logits are scaled query-key dot products between the 64 square
    tokens; promotion logits are derived from the last 8 tokens and appended,
    then everything is mapped onto the policy vocabulary by
    ``ApplyAttentionPolicyMap``.
    """

    def __init__(self, pol_embedding_size, policy_d_model, opponent=False):
        super(PolicyHead, self).__init__()
        # When True, the token (square) order is reversed before scoring.
        self.opponent = opponent
        self.wq = nn.Linear(pol_embedding_size, policy_d_model)
        self.wk = nn.Linear(pol_embedding_size, policy_d_model)
        # Projects each promotion key to 4 offsets; the 4th is folded into
        # the first three in forward().
        self.ppo = nn.Linear(policy_d_model, 4, bias=False)
        self.apply_map = ApplyAttentionPolicyMap()

    def forward(self, x):
        if self.opponent:
            x = torch.flip(x, [1])

        queries = self.wq(x)
        keys = self.wk(x)

        # (B, 64, 64): from-square x to-square raw logits.
        matmul_qk = torch.matmul(queries, keys.transpose(-2, -1))

        dk = torch.sqrt(torch.tensor(keys.shape[-1], dtype=keys.dtype, device=keys.device))
        # Last 8 tokens -> promotion-rank squares (assumes a1..h8 token
        # ordering so the last two token octets are ranks 7 and 8 — TODO
        # confirm against the board encoder).
        promotion_keys = keys[:, -8:, :]
        # Pre-multiplied by dk so the final division by dk below leaves the
        # offsets at their original scale.
        promotion_offsets = self.ppo(promotion_keys).transpose(-2,-1) * dk
        # Fold the 4th offset channel into the first three (named q/r/b below).
        promotion_offsets = promotion_offsets[:, :3, :] + promotion_offsets[:, 3:4, :]

        # Base logits for rank-7 -> rank-8 moves.
        n_promo_logits = matmul_qk[:, -16:-8, -8:]
        q_promo_logits = (n_promo_logits + promotion_offsets[:, 0:1, :]).unsqueeze(3)
        r_promo_logits = (n_promo_logits + promotion_offsets[:, 1:2, :]).unsqueeze(3)
        b_promo_logits = (n_promo_logits + promotion_offsets[:, 2:3, :]).unsqueeze(3)
        # Interleave the three promotion piece types per target file: (B, 8, 24).
        promotion_logits = torch.cat([q_promo_logits, r_promo_logits, b_promo_logits], axis=3).reshape(-1, 8, 24)

        # Scale both logit groups by sqrt(d_k), as in standard attention.
        promotion_logits = promotion_logits / dk
        policy_attn_logits = matmul_qk / dk

        return self.apply_map(policy_attn_logits, promotion_logits)
221
+
222
class ValueHead(nn.Module):
    """Value head: per-square embedding -> flatten -> 2-layer MLP -> 3 logits.

    The 3 outputs are game-outcome logits (softmaxed by the callers in BT4).

    Args:
        embedding_size: feature size of each incoming square token.
        val_embedding_size: per-square embedding size before flattening.
        default_activation: activation module; a fresh ``Mish()`` is created
            when None. (The previous default ``Mish()`` was one instance
            shared across every ValueHead built with the default; Mish is
            stateless, so behavior is unchanged.)
    """

    def __init__(self, embedding_size, val_embedding_size, default_activation=None):
        super(ValueHead, self).__init__()
        self.embedding = nn.Linear(embedding_size, val_embedding_size)
        # Avoid the mutable (module-instance) default-argument pitfall.
        self.activation = Mish() if default_activation is None else default_activation
        self.flatten = nn.Flatten()
        # Input tokens are the 64 squares, hence the *64 here.
        self.dense1 = nn.Linear(val_embedding_size * 64, 128)
        self.dense2 = nn.Linear(128, 3)

    def forward(self, x):
        x = self.embedding(x)
        x = self.activation(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        return x
239
+
240
class BT4(nn.Module):
    """BT4 transformer chess network.

    Encodes a (batch, 112, 8, 8) input into 64 square tokens, runs them
    through a stack of Smolgen-augmented encoder layers, and emits policy
    logits plus two 3-way value heads (``value_winner`` and ``value_q``).
    """

    # NOTE(review): ``default_activation=Mish()`` is a single module instance
    # created at class-definition time and shared by every BT4 built with the
    # default; Mish is stateless so this is harmless, but a None-sentinel
    # default would be cleaner.
    def __init__(self, embedding_size=1024, embedding_dense_sz=512, encoder_layers=15, encoder_d_model=1024, encoder_heads=32, encoder_dff=1536, dropout_rate=0.0, pol_embedding_size=1024, policy_d_model=1024, val_embedding_size=128, default_activation=Mish(),
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super(BT4, self).__init__()
        self.embedding_dense_sz = embedding_dense_sz
        # DeepNorm alpha used in embedding residual; default uses provided encoder_layers
        self.deepnorm_alpha = (2. * encoder_layers) ** -0.25

        # Board "preprocess": dense over the 64*12 piece planes, expanded to
        # 64 per-square feature vectors in forward().
        self.embedding_preprocess = nn.Linear(64*12, 64*self.embedding_dense_sz)
        self.embedding = nn.Linear(112 + self.embedding_dense_sz, embedding_size)
        # NOTE(review): this explicit init is overwritten at the end of
        # __init__ by ``self.apply(self._init_weights)``, which re-initializes
        # every nn.Linear with xavier_normal_ — the uniform init is dead code.
        nn.init.xavier_uniform_(self.embedding.weight) # Explicitly set initializer
        nn.init.zeros_(self.embedding.bias)

        self.embedding_ln = nn.LayerNorm(embedding_size, eps=0.001)

        # Per-square multiplicative then additive gates applied after the
        # embedding LayerNorm.
        self.gating_mult = Gating((64, embedding_size), additive=False)
        self.gating_add = Gating((64, embedding_size), additive=True)

        self.embedding_ffn = FFN(embedding_size, encoder_dff)
        self.embedding_ffn_ln = nn.LayerNorm(embedding_size, eps=0.001)

        self.encoder_layers_list = nn.ModuleList([
            EncoderLayer(embedding_size, encoder_d_model, encoder_heads, encoder_dff, dropout_rate, encoder_layers,
                         use_smolgen=use_smolgen, smol_hidden_channels=smol_hidden_channels, smol_hidden_sz=smol_hidden_sz, smol_gen_sz=smol_gen_sz, smol_activation=smol_activation)
            for _ in range(encoder_layers)
        ])

        self.policy_embedding = nn.Linear(embedding_size, pol_embedding_size)
        self.policy_head = PolicyHead(pol_embedding_size, policy_d_model)
        self.value_head_winner = ValueHead(embedding_size, val_embedding_size)
        self.value_head_q = ValueHead(embedding_size, val_embedding_size)
        self.activation = default_activation

        # Re-initialize every Linear submodule (see NOTE above).
        self.apply(self._init_weights)
274
+
275
+ def _init_weights(self, module):
276
+ if isinstance(module, nn.Linear):
277
+ # Keras' glorot_normal is equivalent to PyTorch's xavier_normal_
278
+ nn.init.xavier_normal_(module.weight)
279
+ if module.bias is not None:
280
+ nn.init.zeros_(module.bias)
281
+
282
    def forward(self, x):
        """Run the full network.

        Args:
            x: input planes of shape (batch, 112, 8, 8).

        Returns:
            Tuple of (policy_logits, value_winner, value_q) — policy logits
            over the move vocabulary plus two 3-logit value heads.
        """
        # x shape: (batch, 112, 8, 8) -> one 112-feature token per square.
        flow = x.permute(0, 2, 3, 1).reshape(-1, 64, 112)

        # First 12 planes per square (presumably the piece-placement planes —
        # TODO confirm against the encoder).
        pos_info = flow[..., :12]
        pos_info_flat = pos_info.reshape(-1, 64 * 12)

        # Dense preprocess over the whole flattened board, then back to
        # (B, 64, embedding_dense_sz) per-square features.
        pos_info_processed = self.embedding_preprocess(pos_info_flat)
        pos_info = pos_info_processed.reshape(-1, 64, self.embedding_dense_sz)

        # Concatenate raw planes with processed features: (B, 64, 112 + dense_sz).
        flow = torch.cat([flow, pos_info], dim=-1)

        flow = self.embedding(flow)

        flow = self.activation(flow)

        flow = self.embedding_ln(flow)

        # Multiplicative then additive per-square gating.
        flow = self.gating_mult(flow)
        flow = self.gating_add(flow)

        # Embedding FFN applied via its submodules (same computation as
        # self.embedding_ffn(flow), spelled out step by step).
        ffn_dense1_pre = self.embedding_ffn.dense1(flow)
        ffn_dense1 = self.embedding_ffn.activation(ffn_dense1_pre)
        ffn_output = self.embedding_ffn.dense2(ffn_dense1)

        # DeepNorm-style scaled residual, then LayerNorm.
        residual = flow + ffn_output * self.deepnorm_alpha
        flow = self.embedding_ffn_ln(residual)

        for i, layer in enumerate(self.encoder_layers_list):
            flow = layer(flow)

        policy_tokens = self.policy_embedding(flow)
        policy_tokens = self.activation(policy_tokens)

        policy_logits = self.policy_head(policy_tokens)

        value_winner = self.value_head_winner(flow)
        value_q = self.value_head_q(flow)

        return policy_logits, value_winner, value_q
322
+
323
    def get_move_from_fen_no_thinking(self, fen_or_moves: Union[str, List[str]], T: float, device: str = None, **kwargs) -> str:
        """
        Predict a move from a FEN position or move history without thinking/search.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses model's device)

        Keyword Args:
            return_probs: when True, return a dict mapping UCI move -> probability
                (moves with probability <= 1e-6 are omitted) instead of a move.

        Returns:
            UCI move string (e.g., 'e2e4'), or a probability dict when
            ``return_probs=True``.
        """
        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        # Determine if input is FEN string or list of moves
        if isinstance(fen_or_moves, str):
            # FEN string input
            fen = fen_or_moves
            is_black_to_move = fen.split()[1] == 'b'
            input_tensor_112, legal_moves_mask = encode_fen_to_tensor(fen)
            castling_rights = fen.split()[2] if len(fen.split()) > 2 else ""
        elif isinstance(fen_or_moves, list):
            # List of UCI moves input
            move_history = fen_or_moves
            input_tensor_112, legal_moves_mask = encode_moves_to_tensor(move_history)
            # Create board to check if black is to move and for castling rights
            board = bulletchess.Board()
            for mv in move_history:
                move = bulletchess.Move.from_uci(mv)
                board.apply(move)
            is_black_to_move = (board.turn == bulletchess.BLACK)
            fen_parts = board.fen().split()
            castling_rights = fen_parts[2] if len(fen_parts) > 2 else ""
        else:
            raise ValueError("Input must be a FEN string or a list of UCI moves")

        input_tensor_112 = input_tensor_112.to(device, non_blocking=True)

        self.eval()
        with torch.inference_mode():
            policy_logits,_,_ = self.forward(input_tensor_112)

        # Apply legal moves mask without in-place ops (inference tensor)
        # NOTE(review): unlike the batch methods, the mask here is added
        # without an explicit dtype cast — fine for fp32, but would mismatch
        # under fp16; confirm if this path ever runs in half precision.
        logits0 = policy_logits[0] + torch.from_numpy(legal_moves_mask).to(policy_logits.device)

        # Check if return_probs is requested
        return_probs = kwargs.get('return_probs', False)

        if return_probs:
            # Return probabilities dictionary
            scaled_logits = logits0 / T if T > 0 else logits0
            probs = F.softmax(scaled_logits, dim=0)
            probs_dict = {}
            for idx, move in enumerate(policy_index):
                prob_val = probs[idx].item()
                if prob_val > 1e-6:  # Only include moves with non-negligible probability
                    probs_dict[move] = prob_val
            return probs_dict

        if T == 0.0:
            # Deterministic: return best move
            best_move_idx = torch.argmax(logits0).item()
            uci_move = policy_index[best_move_idx]
        else:
            # Stochastic sampling with temperature
            # Apply temperature scaling
            scaled_logits = logits0 / T
            # Apply softmax to get probabilities
            probs = F.softmax(scaled_logits, dim=0)
            # Sample from the distribution
            move_idx = torch.multinomial(probs, 1).item()
            uci_move = policy_index[move_idx]

        # If black is to move, the board was mirrored during encoding, so we need to mirror the move back
        # Mirror ranks: 1↔8, 2↔7, 3↔6, 4↔5 (keep file letters the same)
        if is_black_to_move:
            def mirror_rank(rank_char):
                rank = int(rank_char)
                return str(9 - rank)

            # UCI format: e2e4, e7e8q, etc.
            if len(uci_move) >= 4:
                from_file = uci_move[0]
                from_rank = uci_move[1]
                to_file = uci_move[2]
                to_rank = uci_move[3]
                promo = uci_move[4:] if len(uci_move) > 4 else ""

                uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo

        # Convert castling moves from king-to-rook-square format to standard castling format
        # Only if castling rights are available (check FEN castling rights)
        # Check and convert white castling moves
        if uci_move == "e1h1" and "K" in castling_rights:
            uci_move = "e1g1"
        elif uci_move == "e1a1" and "Q" in castling_rights:
            uci_move = "e1c1"
        # Check and convert black castling moves
        elif uci_move == "e8h8" and "k" in castling_rights:
            uci_move = "e8g8"
        elif uci_move == "e8a8" and "q" in castling_rights:
            uci_move = "e8c8"

        return uci_move
431
+
432
    def get_best_move_value(self, fen_or_moves: Union[str, List[str]], T: float = 0.0, device: str = None) -> tuple:
        """
        Get the best move and its value using value analysis.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses model's device)

        Returns:
            Tuple of (best_move, value) where ``value`` is a numpy array of the
            three softmaxed ``value_q`` outcome probabilities (not a scalar).
        """
        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        # Determine if input is FEN string or list of moves
        if isinstance(fen_or_moves, str):
            fen = fen_or_moves
            is_black_to_move = fen.split()[1] == 'b'
            input_tensor_112, legal_moves_mask = encode_fen_to_tensor(fen)
            castling_rights = fen.split()[2] if len(fen.split()) > 2 else ""
        elif isinstance(fen_or_moves, list):
            move_history = fen_or_moves
            input_tensor_112, legal_moves_mask = encode_moves_to_tensor(move_history)
            # Replay the history to determine side to move and castling rights.
            board = bulletchess.Board()
            for mv in move_history:
                move = bulletchess.Move.from_uci(mv)
                board.apply(move)
            is_black_to_move = (board.turn == bulletchess.BLACK)
            fen_parts = board.fen().split()
            castling_rights = fen_parts[2] if len(fen_parts) > 2 else ""
        else:
            raise ValueError("Input must be a FEN string or a list of UCI moves")

        input_tensor_112 = input_tensor_112.to(device, non_blocking=True)

        self.eval()
        with torch.inference_mode():
            policy_logits, _, value_q = self.forward(input_tensor_112)

        # Apply legal moves mask
        logits0 = policy_logits[0] + torch.from_numpy(legal_moves_mask).to(policy_logits.device)

        # Get best move
        if T == 0.0:
            best_move_idx = torch.argmax(logits0).item()
        else:
            scaled_logits = logits0 / T
            probs = F.softmax(scaled_logits, dim=0)
            move_idx = torch.multinomial(probs, 1).item()
            best_move_idx = move_idx

        uci_move = policy_index[best_move_idx]

        # Mirror move if black is to move
        # (the encoder mirrors the board for black, so mirror ranks 1<->8 back)
        if is_black_to_move:
            def mirror_rank(rank_char):
                rank = int(rank_char)
                return str(9 - rank)

            if len(uci_move) >= 4:
                from_file = uci_move[0]
                from_rank = uci_move[1]
                to_file = uci_move[2]
                to_rank = uci_move[3]
                promo = uci_move[4:] if len(uci_move) > 4 else ""
                uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo

        # Convert castling moves (king-to-rook-square -> standard UCI castling)
        if uci_move == "e1h1" and "K" in castling_rights:
            uci_move = "e1g1"
        elif uci_move == "e1a1" and "Q" in castling_rights:
            uci_move = "e1c1"
        elif uci_move == "e8h8" and "k" in castling_rights:
            uci_move = "e8g8"
        elif uci_move == "e8a8" and "q" in castling_rights:
            uci_move = "e8c8"

        # Get value (softmax over value_q)
        value_probs = F.softmax(value_q[0], dim=0)
        value = value_probs.cpu().numpy()

        return uci_move, value
518
+
519
+ def get_position_value(self, fen_or_moves: Union[str, List[str]], device: str = None) -> np.ndarray:
520
+ """
521
+ Get position evaluation using value_q.
522
+
523
+ Args:
524
+ fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
525
+ device: Device to run the model on (if None, uses model's device)
526
+
527
+ Returns:
528
+ Array of [black_win, draw, white_win] probabilities
529
+ """
530
+ # Detect device from model if not provided
531
+ if device is None:
532
+ device = next(self.parameters()).device
533
+ else:
534
+ device = torch.device(device)
535
+
536
+ # Determine if input is FEN string or list of moves
537
+ if isinstance(fen_or_moves, str):
538
+ input_tensor_112, _ = encode_fen_to_tensor(fen_or_moves)
539
+ elif isinstance(fen_or_moves, list):
540
+ input_tensor_112, _ = encode_moves_to_tensor(fen_or_moves)
541
+ else:
542
+ raise ValueError("Input must be a FEN string or a list of UCI moves")
543
+
544
+ input_tensor_112 = input_tensor_112.to(device, non_blocking=True)
545
+
546
+ self.eval()
547
+ with torch.inference_mode():
548
+ _, _, value_q = self.forward(input_tensor_112)
549
+
550
+ # Apply softmax to get probabilities [black_win, draw, white_win]
551
+ value_probs = F.softmax(value_q[0], dim=0)
552
+ return value_probs.cpu().numpy()
553
+
554
    def batch_get_moves_from_fens(self, fens: List[str], T: float, device: str = None, use_fp16: bool = False) -> List[str]:
        """
        Get moves for multiple FEN positions using batched inference.

        Args:
            fens: List of FEN strings representing chess positions
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses model's device)
            use_fp16: when True and running on CUDA, cast the batch to half
                precision and run under autocast.

        Returns:
            List of UCI move strings
        """
        if not fens:
            return []

        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        batch_size = len(fens)

        # Batch encode all FENs
        input_tensors = []
        legal_moves_masks = []
        is_black_to_move_list = []
        castling_rights_list = []

        for fen in fens:
            input_tensor, legal_mask = encode_fen_to_tensor(fen)
            input_tensors.append(input_tensor.squeeze(0))  # Remove batch dim
            legal_moves_masks.append(legal_mask)
            is_black_to_move_list.append(fen.split()[1] == 'b')
            castling_rights_list.append(fen.split()[2] if len(fen.split()) > 2 else "")

        # Stack into batch tensor: (batch_size, 112, 8, 8)
        batch_tensor = torch.stack(input_tensors).to(device, non_blocking=True)
        if use_fp16 and device.type == 'cuda':
            batch_tensor = batch_tensor.half()

        # Run batched inference
        self.eval()
        with torch.inference_mode():
            if use_fp16 and device.type == 'cuda':
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    policy_logits,_,_ = self.forward(batch_tensor)
            else:
                policy_logits,_,_ = self.forward(batch_tensor)

        # Process each position in the batch
        # NOTE(review): the mirror/castling post-processing below is duplicated
        # across all four move-prediction methods — a shared private helper
        # would remove the repetition.
        moves = []
        for i in range(batch_size):
            # Apply legal moves mask (cast to the logits dtype for fp16 runs)
            logits = policy_logits[i] + torch.from_numpy(legal_moves_masks[i]).to(policy_logits.device, dtype=policy_logits.dtype)

            # Sample move
            if T == 0.0:
                best_move_idx = torch.argmax(logits).item()
                uci_move = policy_index[best_move_idx]
            else:
                scaled_logits = logits / T
                probs = F.softmax(scaled_logits, dim=0)
                move_idx = torch.multinomial(probs, 1).item()
                uci_move = policy_index[move_idx]

            # Mirror move if black is to move
            # (the encoder mirrors the board for black; mirror ranks 1<->8 back)
            if is_black_to_move_list[i]:
                def mirror_rank(rank_char):
                    rank = int(rank_char)
                    return str(9 - rank)

                if len(uci_move) >= 4:
                    from_file = uci_move[0]
                    from_rank = uci_move[1]
                    to_file = uci_move[2]
                    to_rank = uci_move[3]
                    promo = uci_move[4:] if len(uci_move) > 4 else ""

                    uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo

            # Convert castling moves (king-to-rook-square -> standard UCI castling)
            castling_rights = castling_rights_list[i]
            if uci_move == "e1h1" and "K" in castling_rights:
                uci_move = "e1g1"
            elif uci_move == "e1a1" and "Q" in castling_rights:
                uci_move = "e1c1"
            elif uci_move == "e8h8" and "k" in castling_rights:
                uci_move = "e8g8"
            elif uci_move == "e8a8" and "q" in castling_rights:
                uci_move = "e8c8"

            moves.append(uci_move)

        return moves
649
+
650
    def batch_get_moves_from_move_lists(self, move_lists: List[List[str]], T: float, device: str = None, use_fp16: bool = False, fens: Optional[List[str]] = None):
        """
        Get moves for multiple move histories using batched inference.

        Args:
            move_lists: List of move sequences, where each sequence is a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses model's device)
            use_fp16: when True and running on CUDA, cast the batch to half
                precision and run under autocast.
            fens: Optional list of FEN strings that represent the board state prior to
                applying the corresponding move list. When provided, each move history
                is applied starting from the supplied FEN instead of the standard initial position.

        Returns:
            List of UCI move strings

        Raises:
            ValueError: if ``fens`` is provided with a different length than
                ``move_lists``.
        """
        if not move_lists:
            return []

        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        batch_size = len(move_lists)

        if fens is not None and len(fens) != len(move_lists):
            raise ValueError("Length of fens must match length of move_lists when provided.")

        # Batch encode all move histories
        input_tensors = []
        legal_moves_masks = []
        is_black_to_move_list = []
        castling_rights_list = []

        for idx, move_history in enumerate(move_lists):
            starting_fen = fens[idx] if fens is not None else None
            input_tensor, legal_mask = encode_moves_to_tensor(move_history, starting_fen=starting_fen)
            input_tensors.append(input_tensor.squeeze(0))  # Remove batch dim
            legal_moves_masks.append(legal_mask)

            # Replay the history to determine side to move and castling rights.
            board = bulletchess.Board.from_fen(starting_fen) if starting_fen is not None else bulletchess.Board()
            for mv in move_history:
                move = bulletchess.Move.from_uci(mv)
                board.apply(move)
            is_black_to_move_list.append(board.turn == bulletchess.BLACK)
            fen_parts = board.fen().split()
            castling_rights_list.append(fen_parts[2] if len(fen_parts) > 2 else "")

        # Stack into batch tensor: (batch_size, 112, 8, 8)
        batch_tensor = torch.stack(input_tensors).to(device, non_blocking=True)
        if use_fp16 and device.type == 'cuda':
            batch_tensor = batch_tensor.half()

        # Run batched inference
        self.eval()
        with torch.inference_mode():
            if use_fp16 and device.type == 'cuda':
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    policy_logits,_,_ = self.forward(batch_tensor)
            else:
                policy_logits,_,_ = self.forward(batch_tensor)

        # Process each position in the batch
        # NOTE(review): identical post-processing to batch_get_moves_from_fens —
        # a shared private helper would remove the duplication.
        moves = []
        for i in range(batch_size):
            # Apply legal moves mask (cast to the logits dtype for fp16 runs)
            logits = policy_logits[i] + torch.from_numpy(legal_moves_masks[i]).to(policy_logits.device, dtype=policy_logits.dtype)

            # Sample move
            if T == 0.0:
                best_move_idx = torch.argmax(logits).item()
                uci_move = policy_index[best_move_idx]
            else:
                scaled_logits = logits / T
                probs = F.softmax(scaled_logits, dim=0)
                move_idx = torch.multinomial(probs, 1).item()
                uci_move = policy_index[move_idx]

            # Mirror move if black is to move
            # (the encoder mirrors the board for black; mirror ranks 1<->8 back)
            if is_black_to_move_list[i]:
                def mirror_rank(rank_char):
                    rank = int(rank_char)
                    return str(9 - rank)

                if len(uci_move) >= 4:
                    from_file = uci_move[0]
                    from_rank = uci_move[1]
                    to_file = uci_move[2]
                    to_rank = uci_move[3]
                    promo = uci_move[4:] if len(uci_move) > 4 else ""

                    uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo

            # Convert castling moves (king-to-rook-square -> standard UCI castling)
            castling_rights = castling_rights_list[i]
            if uci_move == "e1h1" and "K" in castling_rights:
                uci_move = "e1g1"
            elif uci_move == "e1a1" and "Q" in castling_rights:
                uci_move = "e1c1"
            elif uci_move == "e8h8" and "k" in castling_rights:
                uci_move = "e8g8"
            elif uci_move == "e8a8" and "q" in castling_rights:
                uci_move = "e8c8"

            moves.append(uci_move)
        return moves