Maxlegrec commited on
Commit
5a34b85
·
verified ·
1 Parent(s): 5e5fa59

Update model architecture: d_ff=1024, new weights from merged7.pt

Browse files
Files changed (4) hide show
  1. attn.py +533 -0
  2. model.safetensors +2 -2
  3. modeling_chessbot.py +169 -480
  4. vocab.py +231 -0
attn.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+ import numpy as np
7
+ from typing import Optional, Tuple
8
+
9
+
10
class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention from "Attention Is All You Need".

    Computes ``softmax(Q K^T / sqrt(dim)) V``.

    Args:
        dim (int): dimension used for the 1/sqrt(dim) score scaling.

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): projection vector for the decoder.
        - **key** (batch, k_len, d_model): projection vector for the encoder.
        - **value** (batch, v_len, d_model): features of the encoded input sequence.
        - **mask**: positions to exclude; reshaped to the score tensor's shape.

    Returns: context, attn
        - **context**: attention-weighted combination of the values.
        - **attn**: the attention (alignment) weights.
    """
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # Raw similarities, scaled: (batch, q_len, k_len).
        scores = torch.bmm(query, key.transpose(1, 2))
        scores = scores / self.sqrt_dim

        if mask is not None:
            # Masked positions receive -inf so softmax assigns them zero weight.
            scores.masked_fill_(mask.view(scores.size()), -float('Inf'))

        weights = F.softmax(scores, -1)
        return torch.bmm(weights, value), weights
43
+
44
+
45
class DotProductAttention(nn.Module):
    """
    Unscaled dot-product attention: ``softmax(Q V^T) V``.

    The ``hidden_dim`` constructor argument is accepted for interface
    compatibility with the other attention modules but is not used.
    """
    def __init__(self, hidden_dim):
        super(DotProductAttention, self).__init__()

    def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size = query.size(0)
        input_size = value.size(1)

        # (batch, q_len, v_len) similarities between queries and values.
        scores = torch.bmm(query, value.transpose(1, 2))
        # Softmax over the value axis, computed on a flattened (rows, v_len) view.
        weights = F.softmax(scores.view(-1, input_size), dim=1).view(batch_size, -1, input_size)
        context = torch.bmm(weights, value)

        return context, weights
60
+
61
+
62
class AdditiveAttention(nn.Module):
    """
    Bahdanau (additive) attention.

    ``score = w^T tanh(W_k K + W_q Q + b)``; the softmaxed scores weight the
    values to produce the context vector.

    Args:
        hidden_dim (int): dimension of the hidden state vectors.

    Inputs: query, key, value
        - **query** (batch, q_len, hidden_dim): output features from the decoder.
        - **value** (batch, v_len, hidden_dim): features of the encoded input sequence.

    Returns: context, attn
        - **context**: attention-weighted combination of the values.
        - **attn**: alignment over the encoder outputs.

    Reference:
        - **Neural Machine Translation by Jointly Learning to Align and Translate**: https://arxiv.org/abs/1409.0473
    """
    def __init__(self, hidden_dim: int) -> None:
        super(AdditiveAttention, self).__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.score_proj = nn.Linear(hidden_dim, 1)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        # Additive energy, then collapse the per-position score to a scalar.
        energy = torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)
        score = self.score_proj(energy).squeeze(-1)
        attn = F.softmax(score, dim=-1)
        # (B, 1, T) x (B, T, H) -> (B, 1, H)
        context = torch.bmm(attn.unsqueeze(1), value)
        return context, attn
93
+
94
+
95
+ class LocationAwareAttention(nn.Module):
96
+ """
97
+ Applies a location-aware attention mechanism on the output features from the decoder.
98
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
99
+ The location-aware attention mechanism is performing well in speech recognition tasks.
100
+ We refer to implementation of ClovaCall Attention style.
101
+
102
+ Args:
103
+ hidden_dim (int): dimesion of hidden state vector
104
+ smoothing (bool): flag indication whether to use smoothing or not.
105
+
106
+ Inputs: query, value, last_attn, smoothing
107
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
108
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
109
+ - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s attention (alignment)
110
+
111
+ Returns: output, attn
112
+ - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
113
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
114
+
115
+ Reference:
116
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
117
+ - **ClovaCall**: https://github.com/clovaai/ClovaCall/blob/master/las.pytorch/models/attention.py
118
+ """
119
+ def __init__(self, hidden_dim: int, smoothing: bool = True) -> None:
120
+ super(LocationAwareAttention, self).__init__()
121
+ self.hidden_dim = hidden_dim
122
+ self.conv1d = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
123
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
124
+ self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
125
+ self.score_proj = nn.Linear(hidden_dim, 1, bias=True)
126
+ self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
127
+ self.smoothing = smoothing
128
+
129
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
130
+ batch_size, hidden_dim, seq_len = query.size(0), query.size(2), value.size(1)
131
+
132
+ # Initialize previous attention (alignment) to zeros
133
+ if last_attn is None:
134
+ last_attn = value.new_zeros(batch_size, seq_len)
135
+
136
+ conv_attn = torch.transpose(self.conv1d(last_attn.unsqueeze(1)), 1, 2)
137
+ score = self.score_proj(torch.tanh(
138
+ self.query_proj(query.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
139
+ + self.value_proj(value.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
140
+ + conv_attn
141
+ + self.bias
142
+ )).squeeze(dim=-1)
143
+
144
+ if self.smoothing:
145
+ score = torch.sigmoid(score)
146
+ attn = torch.div(score, score.sum(dim=-1).unsqueeze(dim=-1))
147
+ else:
148
+ attn = F.softmax(score, dim=-1)
149
+
150
+ context = torch.bmm(attn.unsqueeze(dim=1), value).squeeze(dim=1) # Bx1xT X BxTxD => Bx1xD => BxD
151
+
152
+ return context, attn
153
+
154
+
155
class MultiHeadLocationAwareAttention(nn.Module):
    """
    Multi-headed location-aware attention.

    Location-aware attention was proposed in "Attention-Based Models for Speech
    Recognition" with a single head; this module extends it to multiple heads.
    A 1-D convolution over the previous alignment supplies location features
    that enter an additive (tanh) scoring function.

    Args:
        hidden_dim (int): The number of expected features in the output
        num_heads (int): The number of heads. (default: 8)
        conv_out_channel (int): The number of out channels in the location convolution

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): output features from the decoder.
        - **value** (batch, v_len, hidden_dim): features of the encoded input sequence.
        - **last_attn** (batch, num_heads, v_len): previous timestep's alignment, or None on the first step.

    Returns: context, attn
        - **context** (batch, q_len, hidden_dim): attended encoder features.
        - **attn** (batch, num_heads, v_len): alignment over the encoder outputs.

    Reference:
        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
    """
    def __init__(self, hidden_dim: int, num_heads: int = 8, conv_out_channel: int = 10) -> None:
        super(MultiHeadLocationAwareAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Per-head dimension; assumes hidden_dim is divisible by num_heads (no assert here).
        self.dim = int(hidden_dim / num_heads)
        # Convolves the (num_heads, v_len) previous alignment into location features.
        self.conv1d = nn.Conv1d(num_heads, conv_out_channel, kernel_size=3, padding=1)
        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.score_proj = nn.Linear(self.dim, 1, bias=True)
        self.bias = nn.Parameter(torch.rand(self.dim).uniform_(-0.1, 0.1))

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, seq_len = value.size(0), value.size(1)

        # First decoding step: start from an all-zero alignment.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size, self.num_heads, seq_len)

        # Location energy from the previous alignment: (B, T, dim), then
        # replicated per head and flattened to (B*N, T, dim).
        loc_energy = torch.tanh(self.loc_proj(self.conv1d(last_attn).transpose(1, 2)))
        loc_energy = loc_energy.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(-1, seq_len, self.dim)

        # Split projections into heads: (B, T, N, dim) -> (B, N, T, dim).
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        # NOTE(review): the view to (-1, 1, dim) assumes q_len == 1 (one decoder
        # step per call) — confirm against callers.
        query = query.contiguous().view(-1, 1, self.dim)
        value = value.contiguous().view(-1, seq_len, self.dim)

        # Additive (tanh) scoring, then softmax over the time axis.
        score = self.score_proj(torch.tanh(value + query + loc_energy + self.bias)).squeeze(2)
        attn = F.softmax(score, dim=1)

        # NOTE(review): value is laid out (B, N, T, dim) flattened to (B*N, T, dim);
        # viewing it as (batch, seq_len, num_heads, dim) reinterprets the head and
        # time axes — verify this round-trip is intentional.
        value = value.view(batch_size, seq_len, self.num_heads, self.dim).permute(0, 2, 1, 3)
        value = value.contiguous().view(-1, seq_len, self.dim)

        # (B*N, 1, T) x (B*N, T, dim) -> re-concatenated heads: (B, -1, N*dim).
        context = torch.bmm(attn.unsqueeze(1), value).view(batch_size, -1, self.num_heads * self.dim)
        attn = attn.view(batch_size, self.num_heads, -1)

        return context, attn
216
+
217
+
218
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention from "Attention Is All You Need".

    Queries, keys and values are linearly projected once and split into
    ``num_heads`` heads; scaled dot-product attention runs per head and the
    per-head contexts are concatenated back together:

        MultiHead(Q, K, V) = Concat(head_1, ..., head_h)
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    (No output projection W_o is applied in this implementation.)

    Args:
        d_model (int): The dimension of keys / values / queries (default: 512)
        num_heads (int): The number of attention heads. (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model)
        - **key** (batch, k_len, d_model)
        - **value** (batch, v_len, d_model)
        - **mask**: positions to exclude; repeated across heads.

    Returns: context, attn
        - **context** (batch, q_len, d_model): attended output features.
        - **attn** (batch * num_heads, q_len, k_len): per-head attention weights.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)
        heads, d_head = self.num_heads, self.d_head

        def split_heads(x: Tensor) -> Tensor:
            # (B, T, N*D) -> (B*N, T, D), head-major along the flattened batch axis.
            x = x.view(batch_size, -1, heads, d_head)
            return x.permute(2, 0, 1, 3).contiguous().view(batch_size * heads, -1, d_head)

        q = split_heads(self.query_proj(query))
        k = split_heads(self.key_proj(key))
        v = split_heads(self.value_proj(value))

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, heads, 1, 1)  # BxNxQ_LENxK_LEN

        context, attn = self.scaled_dot_attn(q, k, v, mask)

        # Undo the head split: (B*N, T, D) -> (B, T, N*D).
        context = context.view(heads, batch_size, -1, d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, heads * d_head)

        return context, attn
294
+
295
+
296
class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with a learned global content bias, following
    "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
    Only the content term is computed here; no positional score is added.

    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout applied to the attention weights

    Inputs: query, key, value, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked

    Returns:
        - **context** (batch, time, d_model): attended output features.
    """
    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
        dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout_p)
        # Learned global content bias (the "u" vector of Transformer-XL).
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)
        heads, d_head = self.num_heads, self.d_head

        # Project and split into heads; key/value become (B, N, T, d_head).
        q = self.query_proj(query).view(batch_size, -1, heads, d_head)
        k = self.key_proj(key).view(batch_size, -1, heads, d_head).permute(0, 2, 1, 3)
        v = self.value_proj(value).view(batch_size, -1, heads, d_head).permute(0, 2, 1, 3)

        # Content scores with the global bias, scaled by 1/sqrt(d_model).
        score = torch.matmul((q + self.u_bias).transpose(1, 2), k.transpose(2, 3))
        score = score / self.sqrt_dim

        if mask is not None:
            # Broadcast the mask over the head axis before masking.
            score.masked_fill_(mask.unsqueeze(1), -1e9)

        weights = self.dropout(F.softmax(score, -1))

        # Weighted sum of values, heads re-concatenated to (B, T, d_model).
        context = torch.matmul(weights, v).transpose(1, 2)
        return context.contiguous().view(batch_size, -1, self.d_model)
367
+
368
+
369
+
370
+
371
class CustomizingAttention(nn.Module):
    r"""
    Customizing Attention

    Combines multi-head attention ("Attention Is All You Need") with the
    location-aware attention of "Attention-Based Models for Speech Recognition":
    a 1-D convolution over the previous alignment produces a location energy
    that is added to the projected values before scaled dot-product attention.

    Args:
        hidden_dim (int): The number of expected features in the output
        num_heads (int): The number of heads. (default: 4)
        conv_out_channel (int): The number of output channels of the location convolution

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
        - **last_attn** (batch_size * num_heads, v_len): previous timestep's alignment, or None on the first step.

    Returns: context, attn
        - **context** (batch, q_len, hidden_dim): tensor containing the attended output features.
        - **attn**: alignment from the encoder outputs (squeezed).

    Reference:
        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
    """

    def __init__(self, hidden_dim: int, num_heads: int = 4, conv_out_channel: int = 10) -> None:
        super(CustomizingAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dim = int(hidden_dim / num_heads)
        self.scaled_dot_attn = ScaledDotProductAttention(self.dim)
        self.conv1d = nn.Conv1d(1, conv_out_channel, kernel_size=3, padding=1)
        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=True)
        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
        self.bias = nn.Parameter(torch.rand(self.dim * num_heads).uniform_(-0.1, 0.1))

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, q_len, v_len = value.size(0), query.size(1), value.size(1)

        # First decoding step: start from an all-zero alignment.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size * self.num_heads, v_len)

        loc_energy = self.get_loc_energy(last_attn, batch_size, v_len)  # get location energy

        query = self.query_proj(query).view(batch_size, q_len, self.num_heads * self.dim)
        value = self.value_proj(value).view(batch_size, v_len, self.num_heads * self.dim) + loc_energy + self.bias

        # Split heads: (B, T, N*D) -> (B*N, T, D).
        query = query.view(batch_size, q_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
        value = value.view(batch_size, v_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
        query = query.contiguous().view(-1, q_len, self.dim)
        value = value.contiguous().view(-1, v_len, self.dim)

        # BUG FIX: ScaledDotProductAttention.forward takes (query, key, value[, mask]);
        # the original call passed only (query, value), which raised a TypeError at
        # runtime. The location-augmented values serve as both keys and values.
        context, attn = self.scaled_dot_attn(query, value, value)
        attn = attn.squeeze()

        # Re-concatenate heads: (N, B, q_len, D) -> (B, q_len, N*D).
        context = context.view(self.num_heads, batch_size, q_len, self.dim).permute(1, 2, 0, 3)
        context = context.contiguous().view(batch_size, q_len, -1)

        return context, attn

    def get_loc_energy(self, last_attn: Tensor, batch_size: int, v_len: int) -> Tensor:
        """Project the convolved previous alignment into a (batch, v_len, N*D) energy tensor."""
        conv_feat = self.conv1d(last_attn.unsqueeze(1))
        conv_feat = conv_feat.view(batch_size, self.num_heads, -1, v_len).permute(0, 1, 3, 2)

        loc_energy = self.loc_proj(conv_feat).view(batch_size, self.num_heads, v_len, self.dim)
        loc_energy = loc_energy.permute(0, 2, 1, 3).reshape(batch_size, v_len, self.num_heads * self.dim)

        return loc_energy
443
+
444
+
445
class RelativeMultiHeadAttention2(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".

    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout applied to the attention weights

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked

    Returns:
        - **outputs**: Tensor produced by the relative multi-head attention module.
    """
    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
        dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention2, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        # Shared score scale of 1/sqrt(d_model), applied to the per-head scores.
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        # Bias-free projection for the positional embeddings.
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        # u: global content bias; v: global positional bias (Transformer-XL).
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_embedding: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        # Project and split heads; key/value go to (B, N, T, d_head).
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)

        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        # Content term (q + u)·k and positional term (q + v)·pos, as in Transformer-XL.
        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        pos_score = self._compute_relative_positional_encoding(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the head axis
            score.masked_fill_(mask, -1e9)

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        # Weighted sum of values, heads re-concatenated to (B, T, d_model).
        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _compute_relative_positional_encoding(self, pos_score: Tensor) -> Tensor:
        # Transformer-XL "relative shift" trick: pad one zero column, reshape so
        # the padding rotates the rows, then drop the first row to realign each
        # query's scores with its relative positions.
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)

        return pos_score
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:824ed0a0d945ebf519eee41755d7a7d29487bbdc92e49b94aaf97de6105f7b17
3
- size 138793144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b33c63dd6fd9936f63e976118e2c4301328d8017edaaa90892335e61ffed6929
3
+ size 138640496
modeling_chessbot.py CHANGED
@@ -1,376 +1,81 @@
1
- """
2
- Updated HuggingFace Compatible ChessBot Chess Model
3
-
4
- This file contains the updated architecture with d_ff=1024 and new weights
5
- """
6
-
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
- import numpy as np
11
- import chess
12
- from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
13
- from transformers.modeling_outputs import BaseModelOutput
14
- from typing import Optional, Tuple
15
- import math
16
  import sys
17
  import os
 
 
 
18
 
19
- class RelativeMultiHeadAttention2(nn.Module):
20
- def __init__(self, d_model: int = 512, num_heads: int = 8, dropout_p: float = 0.1):
21
- super().__init__()
22
- assert d_model % num_heads == 0
23
- self.d_model = d_model
24
- self.num_heads = num_heads
25
- self.d_head = d_model // num_heads
26
- self.sqrt_dim = math.sqrt(d_model)
27
-
28
- self.query_proj = nn.Linear(d_model, d_model)
29
- self.key_proj = nn.Linear(d_model, d_model)
30
- self.value_proj = nn.Linear(d_model, d_model)
31
- self.pos_proj = nn.Linear(d_model, d_model)
32
- self.out_proj = nn.Linear(d_model, d_model)
33
-
34
- self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
35
- self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
36
- torch.nn.init.xavier_uniform_(self.u_bias)
37
- torch.nn.init.xavier_uniform_(self.v_bias)
38
- self.dropout = nn.Dropout(dropout_p)
39
-
40
- def forward(self, query, key, value, pos_embedding, mask=None):
41
- batch_size = value.size(0)
42
-
43
- query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
44
- key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
45
- value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
46
-
47
- pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
48
-
49
- content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
50
- pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
51
- pos_score = self._compute_relative_positional_encoding(pos_score)
52
-
53
- score = (content_score + pos_score) / self.sqrt_dim
54
-
55
- if mask is not None:
56
- mask = mask.unsqueeze(1)
57
- score.masked_fill_(mask, -1e9)
58
-
59
- attn = F.softmax(score, -1)
60
- attn = self.dropout(attn)
61
-
62
- context = torch.matmul(attn, value).transpose(1, 2)
63
- context = context.contiguous().view(batch_size, -1, self.d_model)
64
-
65
- return self.out_proj(context)
66
-
67
- def _compute_relative_positional_encoding(self, pos_score):
68
- batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
69
- zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
70
- padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
71
-
72
- padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
73
- pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
74
-
75
- return pos_score
76
 
 
 
 
 
 
77
 
78
  def fen_to_tensor(fen: str):
79
- """Convert FEN string to tensor representation for the model."""
80
  board = chess.Board(fen)
81
- tensor = np.zeros((8, 8, 19), dtype=np.float32)
 
82
 
83
- # Piece mapping
84
  piece_map = {
85
  'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
86
  'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
87
  }
88
 
89
- # Fill piece positions
90
- for square in chess.SQUARES:
91
- piece = board.piece_at(square)
92
- if piece:
93
- row = 7 - (square // 8) # Flip vertically for proper orientation
94
- col = square % 8
95
- tensor[row, col, piece_map[piece.symbol()]] = 1.0
96
-
97
- # Add metadata channels
98
- # Channel 12: White to move
99
- if board.turn == chess.WHITE:
100
- tensor[:, :, 12] = 1.0
101
 
102
- # Channel 13: Black to move
103
- if board.turn == chess.BLACK:
104
- tensor[:, :, 13] = 1.0
 
 
 
 
105
 
106
- # Castling rights
107
- if board.has_kingside_castling_rights(chess.WHITE):
108
- tensor[:, :, 14] = 1.0
109
- if board.has_queenside_castling_rights(chess.WHITE):
110
- tensor[:, :, 15] = 1.0
111
- if board.has_kingside_castling_rights(chess.BLACK):
112
- tensor[:, :, 16] = 1.0
113
- if board.has_queenside_castling_rights(chess.BLACK):
114
- tensor[:, :, 17] = 1.0
115
 
116
- # En passant
117
- if board.ep_square is not None:
118
- ep_row = 7 - (board.ep_square // 8)
119
- ep_col = board.ep_square % 8
120
- tensor[ep_row, ep_col, 18] = 1.0
121
 
122
  return tensor
123
 
124
- # Complete policy index with all 1929 moves
125
- policy_index = [
126
- "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
127
- "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
128
- "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1",
129
- "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3",
130
- "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7",
131
- "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1",
132
- "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3",
133
- "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
134
- "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2",
135
- "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4",
136
- "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1",
137
- "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2",
138
- "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4",
139
- "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1",
140
- "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3",
141
- "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
142
- "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1",
143
- "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3",
144
- "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8",
145
- "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2",
146
- "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6",
147
- "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2",
148
- "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3",
149
- "a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
150
- "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2",
151
- "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4",
152
- "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7",
153
- "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2",
154
- "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3",
155
- "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6",
156
- "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1",
157
- "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
158
- "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5",
159
- "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1",
160
- "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2",
161
- "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4",
162
- "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1",
163
- "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2",
164
- "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4",
165
- "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
166
- "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2",
167
- "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4",
168
- "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8",
169
- "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2",
170
- "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5",
171
- "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1",
172
- "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3",
173
- "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
174
- "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1",
175
- "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3",
176
- "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5",
177
- "b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1",
178
- "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3",
179
- "c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4",
180
- "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6",
181
- "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
182
- "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3",
183
- "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5",
184
- "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7",
185
- "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2",
186
- "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3",
187
- "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5",
188
- "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1",
189
- "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
190
- "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4",
191
- "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6",
192
- "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2",
193
- "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3",
194
- "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5",
195
- "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1",
196
- "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3",
197
- "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
198
- "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2",
199
- "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4",
200
- "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7",
201
- "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3",
202
- "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4",
203
- "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6",
204
- "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2",
205
- "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
206
- "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5",
207
- "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8",
208
- "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2",
209
- "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4",
210
- "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6",
211
- "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8",
212
- "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3",
213
- "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
214
- "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6",
215
- "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1",
216
- "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3",
217
- "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4",
218
- "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6",
219
- "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2",
220
- "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4",
221
- "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
222
- "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1",
223
- "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4",
224
- "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6",
225
- "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2",
226
- "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5",
227
- "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7",
228
- "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3",
229
- "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
230
- "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7",
231
- "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2",
232
- "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4",
233
- "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6",
234
- "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7",
235
- "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3",
236
- "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5",
237
- "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
238
- "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8",
239
- "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3",
240
- "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5",
241
- "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6",
242
- "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8",
243
- "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3",
244
- "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5",
245
- "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
246
- "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2",
247
- "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4",
248
- "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6",
249
- "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1",
250
- "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4",
251
- "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6",
252
- "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2",
253
- "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
254
- "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7",
255
- "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3",
256
- "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5",
257
- "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7",
258
- "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2",
259
- "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5",
260
- "c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6",
261
- "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
262
- "c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3",
263
- "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5",
264
- "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7",
265
- "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8",
266
- "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4",
267
- "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6",
268
- "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7",
269
- "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
270
- "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5",
271
- "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6",
272
- "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8",
273
- "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3",
274
- "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6",
275
- "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7",
276
- "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2",
277
- "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
278
- "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7",
279
- "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3",
280
- "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7",
281
- "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8",
282
- "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5",
283
- "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7",
284
- "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8",
285
- "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
286
- "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7",
287
- "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8",
288
- "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4",
289
- "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6",
290
- "d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8",
291
- "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4",
292
- "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6",
293
- "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
294
- "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2",
295
- "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5",
296
- "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7",
297
- "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1",
298
- "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5",
299
- "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7",
300
- "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1",
301
- "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
302
- "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7",
303
- "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2",
304
- "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6",
305
- "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8",
306
- "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5",
307
- "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7",
308
- "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2",
309
- "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
310
- "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8",
311
- "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4",
312
- "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6",
313
- "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8",
314
- "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5",
315
- "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7",
316
- "e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8",
317
- "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
318
- "f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7",
319
- "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1",
320
- "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6",
321
- "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8",
322
- "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2",
323
- "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6",
324
- "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8",
325
- "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q",
326
- "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b",
327
- "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r",
328
- "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q",
329
- "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b",
330
- "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r",
331
- "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q",
332
- "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b",
333
- "h7h8q", "h7h8r", "h7h8b", #add the promotions for black
334
- "a2a1q","a2a1r","a2a1b","a2b1q","a2b1r","a2b1b",
335
- "b2a1q","b2a1r","b2a1b","b2b1q","b2b1r","b2b1b","b2c1q","b2c1r","b2c1b",
336
- "c2b1q","c2b1r","c2b1b","c2c1q","c2c1r","c2c1b","c2d1q","c2d1r","c2d1b",
337
- "d2c1q","d2c1r","d2c1b","d2d1q","d2d1r","d2d1b","d2e1q","d2e1r","d2e1b",
338
- "e2d1q","e2d1r","e2d1b","e2e1q","e2e1r","e2e1b","e2f1q","e2f1r","e2f1b",
339
- "f2e1q","f2e1r","f2e1b","f2f1q","f2f1r","f2f1b","f2g1q","f2g1r","f2g1b",
340
- "g2f1q","g2f1r","g2f1b","g2g1q","g2g1r","g2g1b","g2h1q","g2h1r","g2h1b",
341
- "h2g1q","h2g1r","h2g1b","h2h1q","h2h1r","h2h1b",#add special tokens
342
- "<thinking>","</thinking>","end_variation","end","padding_token"
343
- ]
344
-
345
-
346
- # Configuration class
347
  class ChessBotConfig(PretrainedConfig):
348
- """
349
- Configuration class for ChessBot model.
350
- """
351
-
352
  model_type = "chessbot"
353
 
354
  def __init__(
355
  self,
356
- num_layers: int = 10,
357
- d_model: int = 512,
358
- d_ff: int = 1024, # Updated to match new architecture
359
- num_heads: int = 8,
360
- vocab_size: int = 1929,
361
- max_position_embeddings: int = 64,
362
- **kwargs,
 
363
  ):
364
- super().__init__(**kwargs)
365
  self.num_layers = num_layers
366
  self.d_model = d_model
367
  self.d_ff = d_ff
368
  self.num_heads = num_heads
369
- self.vocab_size = vocab_size
370
  self.max_position_embeddings = max_position_embeddings
 
 
 
371
 
372
 
373
- # Model components
374
  class MaGating(nn.Module):
375
  def __init__(self, d_model):
376
  super().__init__()
@@ -395,38 +100,54 @@ class EncoderLayer(nn.Module):
395
  attn_out = self.attention(x, x, x, pos_enc)
396
  x = attn_out + x
397
  x = self.norm1(x)
398
-
399
  y = self.ff1(x)
400
  y = self.gelu(y)
401
  y = self.ff2(y)
402
  y = y + x
403
  y = self.norm2(y)
404
-
405
  return y
406
 
407
 
408
  class AbsolutePositionalEncoder(nn.Module):
409
  def __init__(self, d_model):
410
- super().__init__()
411
- self.position = torch.arange(64).unsqueeze(1)
412
- self.positional_encoding = torch.zeros(1, 64, d_model)
 
 
413
  _2i = torch.arange(0, d_model, step=2).float()
414
- self.positional_encoding[:, :, 0::2] = torch.sin(self.position / (10000 ** (_2i / d_model)))
415
- self.positional_encoding[:, :, 1::2] = torch.cos(self.position / (10000 ** (_2i / d_model)))
416
 
417
- # Register as buffer so it moves with the model
418
- self.register_buffer('pos_encoding', self.positional_encoding)
419
 
420
  def forward(self, x):
421
  batch_size, _, _ = x.size()
422
- return self.pos_encoding.expand(batch_size, -1, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
 
425
  class ValueHead(nn.Module):
426
  def __init__(self, d_model):
427
  super().__init__()
428
  self.dense1 = nn.Linear(d_model, 128)
429
- self.dense2 = nn.Linear(128*64, 128)
430
  self.dense3 = nn.Linear(128, 3)
431
 
432
  def forward(self, x):
@@ -438,13 +159,13 @@ class ValueHead(nn.Module):
438
  x = F.gelu(x)
439
  x = self.dense3(x)
440
  return x
441
-
442
 
443
  class ValueHeadQ(nn.Module):
444
  def __init__(self, d_model):
445
  super().__init__()
446
  self.dense1 = nn.Linear(d_model, 128)
447
- self.dense2 = nn.Linear(128*64, 128)
448
  self.dense3 = nn.Linear(128, 3)
449
 
450
  def forward(self, x):
@@ -458,49 +179,21 @@ class ValueHeadQ(nn.Module):
458
  return x
459
 
460
 
461
- # Main HuggingFace compatible model class
462
- class ChessBotPreTrainedModel(PreTrainedModel):
463
- """
464
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
465
- """
466
-
467
  config_class = ChessBotConfig
468
- base_model_prefix = "chessbot"
469
- supports_gradient_checkpointing = True
470
-
471
- def _init_weights(self, module):
472
- """Initialize the weights"""
473
- if isinstance(module, nn.Linear):
474
- module.weight.data.normal_(mean=0.0, std=0.02)
475
- if module.bias is not None:
476
- module.bias.data.zero_()
477
- elif isinstance(module, nn.Embedding):
478
- module.weight.data.normal_(mean=0.0, std=0.02)
479
- elif isinstance(module, nn.LayerNorm):
480
- module.bias.data.zero_()
481
- module.weight.data.fill_(1.0)
482
-
483
-
484
- class ChessBotModel(ChessBotPreTrainedModel):
485
- """
486
- Updated HuggingFace compatible ChessBot Chess model with d_ff=1024
487
- """
488
 
489
  def __init__(self, config):
490
  super().__init__(config)
491
  self.config = config
492
-
493
- # Initialize exactly like the updated BT4 model
494
  self.is_thinking_model = False
495
  self.d_model = config.d_model
496
  self.num_layers = config.num_layers
497
-
498
- # Model layers - same as updated model
499
  self.layers = nn.ModuleList([
500
  EncoderLayer(config.d_model, config.d_ff, config.num_heads)
501
  for _ in range(config.num_layers)
502
  ])
503
-
504
  self.linear1 = nn.Linear(19, config.d_model)
505
  self.layernorm1 = nn.LayerNorm(config.d_model)
506
  self.policy_tokens_lin = nn.Linear(config.d_model, config.d_model)
@@ -508,29 +201,33 @@ class ChessBotModel(ChessBotPreTrainedModel):
508
  self.keys_pol = nn.Linear(config.d_model, config.d_model)
509
  self.positional = AbsolutePositionalEncoder(config.d_model)
510
  self.ma_gating = MaGating(config.d_model)
511
- self.policy_head = nn.Linear(64*64, config.vocab_size, bias=False)
512
  self.value_head = ValueHead(config.d_model)
513
  self.value_head_q = ValueHeadQ(config.d_model)
514
-
515
- # Initialize weights
516
- self.post_init()
517
-
518
- def forward(self, input_ids, attention_mask=None, compute_loss=False):
519
- """
520
- Forward pass compatible with both HuggingFace interface and original interface
521
- """
522
- # Handle both HF interface (input_ids) and original interface (tuple)
523
- if isinstance(input_ids, tuple):
 
524
  inp = input_ids
 
 
 
 
 
 
525
  x = inp[0]
526
- compute_loss = compute_loss or len(inp) > 1
527
  else:
528
- x = input_ids
529
- inp = (x,)
530
-
531
  b, seq_len, _, _, emb = x.size()
532
  x = x.view(b * seq_len, 64, emb)
533
-
534
  x = self.linear1(x)
535
  x = F.gelu(x)
536
  x = self.layernorm1(x)
@@ -539,7 +236,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
539
  pos_enc = self.positional(x)
540
  for i in range(self.num_layers):
541
  x = self.layers[i](x, pos_enc)
542
-
543
  value_h = self.value_head(x)
544
  value_h = value_h.view(b, seq_len, 3)
545
  value_h_q = self.value_head_q(x)
@@ -548,64 +245,64 @@ class ChessBotModel(ChessBotPreTrainedModel):
548
  policy_tokens = self.policy_tokens_lin(x)
549
  policy_tokens = F.gelu(policy_tokens)
550
  policy_tokens = policy_tokens + pos_enc
551
-
552
  queries = self.queries_pol(policy_tokens)
553
  keys = self.keys_pol(policy_tokens)
554
-
555
  matmul_qk = torch.matmul(queries, torch.transpose(keys, -2, -1))
556
- dk = torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
557
-
558
  policy_attn_logits = matmul_qk / dk
559
- policy_attn_logits = policy_attn_logits.view(b, seq_len, 64*64)
560
-
561
  policy = self.policy_head(policy_attn_logits)
562
-
563
- if compute_loss:
564
  targets = inp[1]
565
- true_values = inp[3]
566
- q_values = inp[4]
567
- true_values = q_values
568
 
569
- z = torch.argmax(true_values, dim=-1)
570
- q = torch.argmax(q_values, dim=-1)
571
- value_h_q_softmax = torch.softmax(value_h_q, dim=-1)
572
-
573
- # Always compute policy loss
574
- loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
575
-
576
- # Create mask for samples where true_values/q_values is not [0,0,0]
577
- valid_mask = (true_values.sum(dim=-1) != 0) & (q_values.sum(dim=-1) != 0)
578
-
579
- # Only compute value losses if we have valid samples
580
- if valid_mask.any():
581
- # Filter to only valid samples
582
- valid_value_h = value_h[valid_mask]
583
- valid_value_h_q = value_h_q_softmax[valid_mask]
584
- valid_z = z[valid_mask]
585
- valid_q_values = q_values[valid_mask]
586
 
587
- loss_value = F.cross_entropy(valid_value_h.view(-1, valid_value_h.size(-1)), valid_z.view(-1))
588
- loss_q = F.mse_loss(valid_value_h_q.view(-1, valid_value_h_q.size(-1)), valid_q_values.view(-1, 3))
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  else:
590
- # No valid samples, set losses to zero
591
- loss_value = torch.tensor(0.0, device=value_h.device, requires_grad=True)
592
- loss_q = torch.tensor(0.0, device=value_h_q.device, requires_grad=True)
593
-
594
- return policy, value_h, value_h_q, loss_policy, loss_value, loss_q, targets, z, q
595
-
596
  return BaseModelOutput(
597
- last_hidden_state=x,
598
- hidden_states=None,
599
- attentions=None,
600
- ), policy, value_h, value_h_q
601
 
602
  def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
603
- """Get a move from FEN string without thinking"""
604
- board = chess.Board(fen)
605
  x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
606
  x = x.view(1, 1, 8, 8, 19)
607
 
608
- _, logits, _, _ = self.forward(x)
 
 
 
 
 
609
  logits = logits.view(-1, 1929)
610
  legal_move_mask = torch.zeros((1, 1929), device=device)
611
  for legal_move in board.legal_moves:
@@ -613,21 +310,25 @@ class ChessBotModel(ChessBotPreTrainedModel):
613
  legal_move_uci = legal_move.uci()[:-1]
614
  else:
615
  legal_move_uci = legal_move.uci()
616
- if legal_move_uci in policy_index:
617
- legal_move_mask[0][policy_index.index(legal_move_uci)] = 1
618
-
619
- # Set all illegal moves to -inf
620
  if force_legal:
621
- logits = logits + (1-legal_move_mask) * -999
622
 
623
  if T == 0:
624
  sampled = torch.argmax(logits, dim=-1, keepdim=True)
625
  else:
626
- probs = F.softmax(logits/T, dim=-1)
627
  sampled = torch.multinomial(probs, num_samples=1)
628
  if return_probs:
629
- return probs
630
-
 
 
 
 
 
 
631
  move = policy_index[sampled.item()]
632
  return move
633
 
@@ -651,6 +352,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
651
  value_logits = self.value_head_q(x_processed)
652
  value_logits = value_logits.view(b, seq_len, 3)
653
  value = torch.softmax(value_logits, dim=-1)
 
654
  return value.squeeze()
655
 
656
  def get_batch_position_values(self, fens, device="cuda"):
@@ -681,6 +383,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
681
  value_logits = self.value_head_q(x_processed)
682
  value_logits = value_logits.view(b, seq_len, 3)
683
  value_logits = torch.softmax(value_logits, dim=-1)
 
684
  return value_logits.squeeze(1)
685
 
686
  def calculate_move_values(self, fen, device="cuda"):
@@ -689,7 +392,6 @@ class ChessBotModel(ChessBotPreTrainedModel):
689
  board.set_fen(fen)
690
 
691
  is_white_turn = board.turn == chess.WHITE
692
-
693
  legal_moves = list(board.legal_moves)
694
  if len(legal_moves) == 0:
695
  return [], torch.empty(0, device=device)
@@ -701,9 +403,8 @@ class ChessBotModel(ChessBotPreTrainedModel):
701
  board.pop()
702
 
703
  batch_value_q = self.get_batch_position_values(resulting_fens, device)
 
704
 
705
- # Calculate values from the current player's perspective
706
- batch_value_q = batch_value_q[:,2]-batch_value_q[:,0]
707
  if is_white_turn:
708
  player_values = batch_value_q
709
  else:
@@ -713,20 +414,20 @@ class ChessBotModel(ChessBotPreTrainedModel):
713
 
714
  def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False, to_fall_back_to_policy=False):
715
  """Determine the best move based on the value of resulting positions using efficient batching."""
716
- # Check if we should fall back to policy
717
- if to_fall_back_to_policy:
718
- value = self.get_position_value(fen, device)
719
- board = chess.Board()
720
- board.set_fen(fen)
721
-
722
- is_white_turn = board.turn == chess.WHITE
723
- if is_white_turn:
724
- value = value[2]-value[0]
725
- else:
726
- value = value[0]-value[2]
727
 
728
- if value > 0.9:
729
- return self.get_move_from_fen_no_thinking(fen, T, device, force_legal=True, return_probs=return_probs)
 
 
730
 
731
  legal_moves, move_values = self.calculate_move_values(fen, device)
732
 
@@ -749,19 +450,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
749
  probs[best_idx] = 1.0
750
  else:
751
  probs = F.softmax(move_values / T, dim=0)
752
-
753
- move_dict = {}
754
- for i, move in enumerate(legal_moves):
755
- move_dict[move.uci()] = probs[i].item()
756
- return move_uci, move_dict
757
 
758
  return move_uci
759
-
760
-
761
- # Register the configuration and model with transformers
762
- AutoConfig.register("chessbot", ChessBotConfig)
763
- AutoModel.register(ChessBotConfig, ChessBotModel)
764
-
765
- # For backward compatibility
766
- ChessBot = ChessBotModel
767
- BT4Model = ChessBotModel
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
 
 
 
 
 
4
  import sys
5
  import os
6
+ import chess
7
+ import numpy as np
8
+ from typing import Optional, Tuple, Union
9
 
10
+ from transformers import PreTrainedModel, PretrainedConfig
11
+ from transformers.modeling_outputs import BaseModelOutput
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ from attn import RelativeMultiHeadAttention2
14
+ from vocab import policy_index
15
+
16
+ # Make policy_index available for imports from this module
17
+ __all__ = ['ChessBotConfig', 'ChessBotModel', 'policy_index', 'fen_to_tensor']
18
 
19
  def fen_to_tensor(fen: str):
 
20
  board = chess.Board(fen)
21
+ P = 19 # 12 planes for pieces + 1 for side to play + 1 for en passant + 4 for castling + 1 for 50-move rule
22
+ tensor = np.zeros((8, 8, P), dtype=np.float32)
23
 
 
24
  piece_map = {
25
  'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
26
  'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
27
  }
28
 
29
+ # Populate piece planes
30
+ for square, piece in board.piece_map().items():
31
+ rank, file = divmod(square, 8)
32
+ plane = piece_map[piece.symbol()]
33
+ tensor[7 - rank, file, plane] = 1.0 # Flip rank to align with standard board representation
 
 
 
 
 
 
 
34
 
35
+ # Side to play plane
36
+ tensor[:, :, 12] = 1.0 if board.turn == chess.WHITE else 0.0
37
+
38
+ # En passant plane
39
+ if board.ep_square is not None:
40
+ rank, file = divmod(board.ep_square, 8)
41
+ tensor[7 - rank, file, 13] = 1.0
42
 
43
+ # Castling rights planes (4 total: white kingside, white queenside, black kingside, black queenside)
44
+ tensor[:, :, 14] = 1.0 if board.has_kingside_castling_rights(chess.WHITE) else 0.0
45
+ tensor[:, :, 15] = 1.0 if board.has_queenside_castling_rights(chess.WHITE) else 0.0
46
+ tensor[:, :, 16] = 1.0 if board.has_kingside_castling_rights(chess.BLACK) else 0.0
47
+ tensor[:, :, 17] = 1.0 if board.has_queenside_castling_rights(chess.BLACK) else 0.0
 
 
 
 
48
 
49
+ # 50-move rule plane (normalized to [0,1])
50
+ tensor[:, :, 18] = min(board.halfmove_clock / 100.0, 1.0)
 
 
 
51
 
52
  return tensor
53
 
54
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  class ChessBotConfig(PretrainedConfig):
 
 
 
 
56
  model_type = "chessbot"
57
 
58
  def __init__(
59
  self,
60
+ num_layers=10,
61
+ d_model=512,
62
+ d_ff=1024,
63
+ num_heads=8,
64
+ max_position_embeddings=64,
65
+ vocab_size=1929,
66
+ torch_dtype="float32",
67
+ **kwargs
68
  ):
 
69
  self.num_layers = num_layers
70
  self.d_model = d_model
71
  self.d_ff = d_ff
72
  self.num_heads = num_heads
 
73
  self.max_position_embeddings = max_position_embeddings
74
+ self.vocab_size = vocab_size
75
+ self.torch_dtype = torch_dtype
76
+ super().__init__(**kwargs)
77
 
78
 
 
79
  class MaGating(nn.Module):
80
  def __init__(self, d_model):
81
  super().__init__()
 
100
  attn_out = self.attention(x, x, x, pos_enc)
101
  x = attn_out + x
102
  x = self.norm1(x)
103
+
104
  y = self.ff1(x)
105
  y = self.gelu(y)
106
  y = self.ff2(y)
107
  y = y + x
108
  y = self.norm2(y)
109
+
110
  return y
111
 
112
 
113
  class AbsolutePositionalEncoder(nn.Module):
114
  def __init__(self, d_model):
115
+ super(AbsolutePositionalEncoder, self).__init__()
116
+ self.d_model = d_model
117
+ self.register_buffer('position', torch.arange(64).unsqueeze(1))
118
+
119
+ positional_encoding = torch.zeros(1, 64, d_model)
120
  _2i = torch.arange(0, d_model, step=2).float()
121
+ positional_encoding[:, :, 0::2] = torch.sin(self.position / (10000 ** (_2i / d_model)))
122
+ positional_encoding[:, :, 1::2] = torch.cos(self.position / (10000 ** (_2i / d_model)))
123
 
124
+ self.register_buffer('positional_encoding', positional_encoding)
 
125
 
126
  def forward(self, x):
127
  batch_size, _, _ = x.size()
128
+ return self.positional_encoding.expand(batch_size, -1, -1)
129
+
130
+
131
+ class LearnedPositionalEncoder(nn.Module):
132
+ def __init__(self, d_model=1929, max_len=64):
133
+ super(LearnedPositionalEncoder, self).__init__()
134
+ self.d_model = d_model
135
+ self.max_len = max_len
136
+ self.positional_embedding = nn.Embedding(max_len, d_model)
137
+
138
+ def forward(self, x):
139
+ batch_size, seq_len, _ = x.size()
140
+ positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
141
+ pos_embed = self.positional_embedding(positions)
142
+ pos_embed = pos_embed.expand(batch_size, -1, -1)
143
+ return pos_embed
144
 
145
 
146
  class ValueHead(nn.Module):
147
  def __init__(self, d_model):
148
  super().__init__()
149
  self.dense1 = nn.Linear(d_model, 128)
150
+ self.dense2 = nn.Linear(128 * 64, 128)
151
  self.dense3 = nn.Linear(128, 3)
152
 
153
  def forward(self, x):
 
159
  x = F.gelu(x)
160
  x = self.dense3(x)
161
  return x
162
+
163
 
164
  class ValueHeadQ(nn.Module):
165
  def __init__(self, d_model):
166
  super().__init__()
167
  self.dense1 = nn.Linear(d_model, 128)
168
+ self.dense2 = nn.Linear(128 * 64, 128)
169
  self.dense3 = nn.Linear(128, 3)
170
 
171
  def forward(self, x):
 
179
  return x
180
 
181
 
182
+ class ChessBotModel(PreTrainedModel):
 
 
 
 
 
183
  config_class = ChessBotConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  def __init__(self, config):
186
  super().__init__(config)
187
  self.config = config
 
 
188
  self.is_thinking_model = False
189
  self.d_model = config.d_model
190
  self.num_layers = config.num_layers
191
+
 
192
  self.layers = nn.ModuleList([
193
  EncoderLayer(config.d_model, config.d_ff, config.num_heads)
194
  for _ in range(config.num_layers)
195
  ])
196
+
197
  self.linear1 = nn.Linear(19, config.d_model)
198
  self.layernorm1 = nn.LayerNorm(config.d_model)
199
  self.policy_tokens_lin = nn.Linear(config.d_model, config.d_model)
 
201
  self.keys_pol = nn.Linear(config.d_model, config.d_model)
202
  self.positional = AbsolutePositionalEncoder(config.d_model)
203
  self.ma_gating = MaGating(config.d_model)
204
+ self.policy_head = nn.Linear(64 * 64, 1929, bias=False)
205
  self.value_head = ValueHead(config.d_model)
206
  self.value_head_q = ValueHeadQ(config.d_model)
207
+
208
+ def forward(
209
+ self,
210
+ input_ids=None,
211
+ inputs_embeds=None,
212
+ compute_loss=False,
213
+ step=None,
214
+ **kwargs
215
+ ):
216
+ # Handle both old-style input format and new HF format
217
+ if input_ids is not None:
218
  inp = input_ids
219
+ elif inputs_embeds is not None:
220
+ inp = inputs_embeds
221
+ else:
222
+ raise ValueError("Either input_ids or inputs_embeds must be provided")
223
+
224
+ if isinstance(inp, (list, tuple)):
225
  x = inp[0]
 
226
  else:
227
+ x = inp
228
+
 
229
  b, seq_len, _, _, emb = x.size()
230
  x = x.view(b * seq_len, 64, emb)
 
231
  x = self.linear1(x)
232
  x = F.gelu(x)
233
  x = self.layernorm1(x)
 
236
  pos_enc = self.positional(x)
237
  for i in range(self.num_layers):
238
  x = self.layers[i](x, pos_enc)
239
+
240
  value_h = self.value_head(x)
241
  value_h = value_h.view(b, seq_len, 3)
242
  value_h_q = self.value_head_q(x)
 
245
  policy_tokens = self.policy_tokens_lin(x)
246
  policy_tokens = F.gelu(policy_tokens)
247
  policy_tokens = policy_tokens + pos_enc
 
248
  queries = self.queries_pol(policy_tokens)
249
  keys = self.keys_pol(policy_tokens)
250
+
251
  matmul_qk = torch.matmul(queries, torch.transpose(keys, -2, -1))
252
+ dk = torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32, device=x.device))
 
253
  policy_attn_logits = matmul_qk / dk
254
+ policy_attn_logits = policy_attn_logits.view(b, seq_len, 64 * 64)
 
255
  policy = self.policy_head(policy_attn_logits)
256
+
257
+ if compute_loss and isinstance(inp, (list, tuple)) and len(inp) > 1:
258
  targets = inp[1]
259
+ true_values = inp[3] if len(inp) > 3 else None
260
+ q_values = inp[4] if len(inp) > 4 else None
 
261
 
262
+ if true_values is not None and q_values is not None:
263
+ true_values = q_values
264
+ z = torch.argmax(true_values, dim=-1)
265
+ q = torch.argmax(q_values, dim=-1)
266
+ value_h_q_softmax = torch.softmax(value_h_q, dim=-1)
267
+
268
+ loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
 
 
 
 
 
 
 
 
 
 
269
 
270
+ valid_mask = (true_values.sum(dim=-1) != 0) & (q_values.sum(dim=-1) != 0)
271
+
272
+ if valid_mask.any():
273
+ valid_value_h = value_h[valid_mask]
274
+ valid_value_h_q = value_h_q_softmax[valid_mask]
275
+ valid_z = z[valid_mask]
276
+ valid_q_values = q_values[valid_mask]
277
+
278
+ loss_value = F.cross_entropy(valid_value_h.view(-1, valid_value_h.size(-1)), valid_z.view(-1))
279
+ loss_q = F.mse_loss(valid_value_h_q.view(-1, valid_value_h_q.size(-1)), valid_q_values.view(-1, 3))
280
+ else:
281
+ loss_value = torch.tensor(0.0, device=value_h.device, requires_grad=True)
282
+ loss_q = torch.tensor(0.0, device=value_h_q.device, requires_grad=True)
283
+
284
+ return policy, value_h, value_h_q, loss_policy, loss_value, loss_q, targets, z, q
285
  else:
286
+ loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
287
+ return policy, value_h, value_h_q, loss_policy
288
+
 
 
 
289
  return BaseModelOutput(
290
+ last_hidden_state=policy,
291
+ hidden_states=(value_h, value_h_q)
292
+ )
 
293
 
294
  def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
295
+ board = chess.Board()
296
+ board.set_fen(fen)
297
  x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
298
  x = x.view(1, 1, 8, 8, 19)
299
 
300
+ output = self(x)
301
+ if hasattr(output, 'last_hidden_state'):
302
+ logits = output.last_hidden_state
303
+ else:
304
+ logits = output
305
+
306
  logits = logits.view(-1, 1929)
307
  legal_move_mask = torch.zeros((1, 1929), device=device)
308
  for legal_move in board.legal_moves:
 
310
  legal_move_uci = legal_move.uci()[:-1]
311
  else:
312
  legal_move_uci = legal_move.uci()
313
+ legal_move_mask[0][policy_index.index(legal_move_uci)] = 1
314
+
 
 
315
  if force_legal:
316
+ logits = logits + (1 - legal_move_mask) * -999
317
 
318
  if T == 0:
319
  sampled = torch.argmax(logits, dim=-1, keepdim=True)
320
  else:
321
+ probs = F.softmax(logits / T, dim=-1)
322
  sampled = torch.multinomial(probs, num_samples=1)
323
  if return_probs:
324
+ # Map to legal moves
325
+ legal_move_probs = {}
326
+ for move in board.legal_moves:
327
+ idx = policy_index.index(move.uci())
328
+ legal_move_probs[move.uci()] = probs[0,idx].item()
329
+
330
+ return legal_move_probs
331
+
332
  move = policy_index[sampled.item()]
333
  return move
334
 
 
352
  value_logits = self.value_head_q(x_processed)
353
  value_logits = value_logits.view(b, seq_len, 3)
354
  value = torch.softmax(value_logits, dim=-1)
355
+
356
  return value.squeeze()
357
 
358
  def get_batch_position_values(self, fens, device="cuda"):
 
383
  value_logits = self.value_head_q(x_processed)
384
  value_logits = value_logits.view(b, seq_len, 3)
385
  value_logits = torch.softmax(value_logits, dim=-1)
386
+
387
  return value_logits.squeeze(1)
388
 
389
  def calculate_move_values(self, fen, device="cuda"):
 
392
  board.set_fen(fen)
393
 
394
  is_white_turn = board.turn == chess.WHITE
 
395
  legal_moves = list(board.legal_moves)
396
  if len(legal_moves) == 0:
397
  return [], torch.empty(0, device=device)
 
403
  board.pop()
404
 
405
  batch_value_q = self.get_batch_position_values(resulting_fens, device)
406
+ batch_value_q = batch_value_q[:, 2] - batch_value_q[:, 0]
407
 
 
 
408
  if is_white_turn:
409
  player_values = batch_value_q
410
  else:
 
414
 
415
  def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False, to_fall_back_to_policy=False):
416
  """Determine the best move based on the value of resulting positions using efficient batching."""
417
+ value = self.get_position_value(fen, device)
418
+ board = chess.Board()
419
+ board.set_fen(fen)
420
+
421
+ is_white_turn = board.turn == chess.WHITE
422
+ if is_white_turn:
423
+ value = value[2] - value[0]
424
+ else:
425
+ value = value[0] - value[2]
 
 
426
 
427
+ if value > 0.9 and to_fall_back_to_policy:
428
+ self.fall_back_to_policy = True
429
+ if to_fall_back_to_policy and hasattr(self, 'fall_back_to_policy') and self.fall_back_to_policy:
430
+ return self.get_move_from_fen_no_thinking(fen, T, device, force_legal=True, return_probs=return_probs)
431
 
432
  legal_moves, move_values = self.calculate_move_values(fen, device)
433
 
 
450
  probs[best_idx] = 1.0
451
  else:
452
  probs = F.softmax(move_values / T, dim=0)
453
+
454
+ return probs.cpu().numpy()
 
 
 
455
 
456
  return move_uci
 
 
 
 
 
 
 
 
 
vocab.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #vocab consists of
2
+
3
+ #all the possible moves
4
+ #special tokens start_think and end_think
5
+ #special token end
6
+ #special token end_variation
7
+
8
+
9
+ policy_index = [
10
+ "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
11
+ "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
12
+ "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1",
13
+ "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3",
14
+ "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7",
15
+ "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1",
16
+ "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3",
17
+ "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
18
+ "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2",
19
+ "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4",
20
+ "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1",
21
+ "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2",
22
+ "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4",
23
+ "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1",
24
+ "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3",
25
+ "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
26
+ "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1",
27
+ "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3",
28
+ "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8",
29
+ "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2",
30
+ "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6",
31
+ "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2",
32
+ "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3",
33
+ "a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
34
+ "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2",
35
+ "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4",
36
+ "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7",
37
+ "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2",
38
+ "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3",
39
+ "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6",
40
+ "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1",
41
+ "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
42
+ "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5",
43
+ "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1",
44
+ "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2",
45
+ "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4",
46
+ "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1",
47
+ "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2",
48
+ "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4",
49
+ "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
50
+ "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2",
51
+ "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4",
52
+ "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8",
53
+ "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2",
54
+ "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5",
55
+ "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1",
56
+ "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3",
57
+ "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
58
+ "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1",
59
+ "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3",
60
+ "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5",
61
+ "b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1",
62
+ "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3",
63
+ "c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4",
64
+ "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6",
65
+ "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
66
+ "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3",
67
+ "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5",
68
+ "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7",
69
+ "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2",
70
+ "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3",
71
+ "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5",
72
+ "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1",
73
+ "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
74
+ "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4",
75
+ "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6",
76
+ "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2",
77
+ "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3",
78
+ "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5",
79
+ "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1",
80
+ "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3",
81
+ "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
82
+ "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2",
83
+ "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4",
84
+ "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7",
85
+ "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3",
86
+ "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4",
87
+ "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6",
88
+ "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2",
89
+ "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
90
+ "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5",
91
+ "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8",
92
+ "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2",
93
+ "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4",
94
+ "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6",
95
+ "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8",
96
+ "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3",
97
+ "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
98
+ "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6",
99
+ "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1",
100
+ "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3",
101
+ "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4",
102
+ "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6",
103
+ "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2",
104
+ "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4",
105
+ "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
106
+ "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1",
107
+ "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4",
108
+ "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6",
109
+ "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2",
110
+ "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5",
111
+ "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7",
112
+ "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3",
113
+ "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
114
+ "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7",
115
+ "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2",
116
+ "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4",
117
+ "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6",
118
+ "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7",
119
+ "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3",
120
+ "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5",
121
+ "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
122
+ "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8",
123
+ "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3",
124
+ "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5",
125
+ "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6",
126
+ "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8",
127
+ "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3",
128
+ "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5",
129
+ "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
130
+ "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2",
131
+ "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4",
132
+ "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6",
133
+ "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1",
134
+ "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4",
135
+ "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6",
136
+ "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2",
137
+ "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
138
+ "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7",
139
+ "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3",
140
+ "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5",
141
+ "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7",
142
+ "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2",
143
+ "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5",
144
+ "c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6",
145
+ "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
146
+ "c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3",
147
+ "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5",
148
+ "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7",
149
+ "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8",
150
+ "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4",
151
+ "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6",
152
+ "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7",
153
+ "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
154
+ "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5",
155
+ "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6",
156
+ "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8",
157
+ "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3",
158
+ "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6",
159
+ "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7",
160
+ "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2",
161
+ "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
162
+ "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7",
163
+ "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3",
164
+ "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7",
165
+ "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8",
166
+ "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5",
167
+ "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7",
168
+ "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8",
169
+ "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
170
+ "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7",
171
+ "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8",
172
+ "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4",
173
+ "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6",
174
+ "d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8",
175
+ "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4",
176
+ "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6",
177
+ "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
178
+ "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2",
179
+ "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5",
180
+ "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7",
181
+ "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1",
182
+ "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5",
183
+ "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7",
184
+ "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1",
185
+ "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
186
+ "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7",
187
+ "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2",
188
+ "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6",
189
+ "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8",
190
+ "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5",
191
+ "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7",
192
+ "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2",
193
+ "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
194
+ "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8",
195
+ "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4",
196
+ "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6",
197
+ "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8",
198
+ "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5",
199
+ "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7",
200
+ "e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8",
201
+ "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
202
+ "f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7",
203
+ "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1",
204
+ "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6",
205
+ "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8",
206
+ "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2",
207
+ "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6",
208
+ "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8",
209
+ "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q",
210
+ "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b",
211
+ "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r",
212
+ "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q",
213
+ "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b",
214
+ "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r",
215
+ "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q",
216
+ "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b",
217
+ "h7h8q", "h7h8r", "h7h8b", #add the promotions for black
218
+ "a2a1q","a2a1r","a2a1b","a2b1q","a2b1r","a2b1b",
219
+ "b2a1q","b2a1r","b2a1b","b2b1q","b2b1r","b2b1b","b2c1q","b2c1r","b2c1b",
220
+ "c2b1q","c2b1r","c2b1b","c2c1q","c2c1r","c2c1b","c2d1q","c2d1r","c2d1b",
221
+ "d2c1q","d2c1r","d2c1b","d2d1q","d2d1r","d2d1b","d2e1q","d2e1r","d2e1b",
222
+ "e2d1q","e2d1r","e2d1b","e2e1q","e2e1r","e2e1b","e2f1q","e2f1r","e2f1b",
223
+ "f2e1q","f2e1r","f2e1b","f2f1q","f2f1r","f2f1b","f2g1q","f2g1r","f2g1b",
224
+ "g2f1q","g2f1r","g2f1b","g2g1q","g2g1r","g2g1b","g2h1q","g2h1r","g2h1b",
225
+ "h2g1q","h2g1r","h2g1b","h2h1q","h2h1r","h2h1b",#add special tokens
226
+ "<thinking>","</thinking>","end_variation","end","padding_token"
227
+ ]
228
+
229
+
230
+
231
+ print("Number of unique tokens: ", len(policy_index))