Update model architecture: d_ff=1024, new weights from merged7.pt
Browse files- attn.py +533 -0
- model.safetensors +2 -2
- modeling_chessbot.py +169 -480
- vocab.py +231 -0
attn.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need".

    Scores are the dot products of the query with every key, scaled by
    1/sqrt(dim); a softmax over the scores yields the weights applied to
    the values.

    Args:
        dim (int): dimension of the attention head

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): projection vector for the decoder.
        - **key** (batch, k_len, d_model): projection vector for the encoder.
        - **value** (batch, v_len, d_model): features of the encoded input sequence.
        - **mask**: tensor containing indices to be masked.

    Returns: context, attn
        - **context**: context vector produced by the attention mechanism.
        - **attn**: attention (alignment) weights over the encoder outputs.
    """
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # (batch, q_len, k_len): similarity of every query with every key.
        energy = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            # Masked positions receive -inf so softmax assigns them zero weight.
            energy.masked_fill_(mask.view(energy.size()), -float('Inf'))

        weights = F.softmax(energy, dim=-1)
        context = torch.bmm(weights, value)
        return context, weights
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DotProductAttention(nn.Module):
    """
    Dot-product attention: scores are the raw (unscaled) dot products of the
    query with every value position; a softmax over the scores yields the
    weights applied to the values.
    """
    def __init__(self, hidden_dim):
        # hidden_dim is accepted for interface symmetry with the other
        # attention modules; this variant has no learnable parameters.
        super(DotProductAttention, self).__init__()

    def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        # (batch, q_len, v_len): raw dot-product similarity.
        energy = torch.bmm(query, value.transpose(1, 2))
        # Softmax over the value positions (last axis).
        weights = F.softmax(energy, dim=-1)
        context = torch.bmm(weights, value)
        return context, weights
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class AdditiveAttention(nn.Module):
    """
    Additive (Bahdanau) attention over the output features from the decoder,
    proposed in "Neural Machine Translation by Jointly Learning to Align and
    Translate".

    Args:
        hidden_dim (int): dimension of the hidden state vector

    Inputs: query, key, value
        - **query** (batch_size, q_len, hidden_dim): output features from the decoder.
        - **key** (batch_size, k_len, hidden_dim): projection source for the score.
        - **value** (batch_size, v_len, hidden_dim): features of the encoded input sequence.

    Returns: context, attn
        - **context**: context vector produced by the attention mechanism.
        - **attn**: alignment over the encoder outputs.

    Reference:
        - **Neural Machine Translation by Jointly Learning to Align and Translate**: https://arxiv.org/abs/1409.0473
    """
    def __init__(self, hidden_dim: int) -> None:
        super(AdditiveAttention, self).__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.score_proj = nn.Linear(hidden_dim, 1)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        # Projected key and query are summed (query broadcasts over positions),
        # squashed through tanh, then reduced to one scalar score per position.
        combined = torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)
        score = self.score_proj(combined).squeeze(-1)
        attn = F.softmax(score, dim=-1)
        # (batch, 1, v_len) @ (batch, v_len, hidden_dim) -> (batch, 1, hidden_dim)
        context = torch.bmm(attn.unsqueeze(1), value)
        return context, attn
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class LocationAwareAttention(nn.Module):
    """
    Location-aware attention over the output features from the decoder, as
    proposed in "Attention-Based Models for Speech Recognition". The previous
    timestep's alignment is passed through a convolution so the model can
    track where it attended before; performs well in speech recognition.
    Implementation follows the ClovaCall attention style.

    Args:
        hidden_dim (int): dimension of the hidden state vector
        smoothing (bool): whether to use sigmoid smoothing instead of softmax

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): output features from the decoder.
        - **value** (batch, v_len, hidden_dim): features of the encoded input sequence.
        - **last_attn** (batch, v_len): previous timestep's attention (alignment),
          or None on the first step.

    Returns: context, attn
        - **context** (batch, hidden_dim): context vector from the attention mechanism.
        - **attn** (batch, v_len): attention (alignment) over the encoder outputs.

    Reference:
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
        - **ClovaCall**: https://github.com/clovaai/ClovaCall/blob/master/las.pytorch/models/attention.py
    """
    def __init__(self, hidden_dim: int, smoothing: bool = True) -> None:
        super(LocationAwareAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.conv1d = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.score_proj = nn.Linear(hidden_dim, 1, bias=True)
        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.smoothing = smoothing

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, seq_len = value.size(0), value.size(1)

        # First decoding step: start from an all-zero alignment.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size, seq_len)

        # Convolve the previous alignment -> (batch, v_len, hidden_dim).
        loc_feat = self.conv1d(last_attn.unsqueeze(1)).transpose(1, 2)

        # Additive scoring over projected query, value and location features.
        energy = self.query_proj(query) + self.value_proj(value) + loc_feat + self.bias
        score = self.score_proj(torch.tanh(energy)).squeeze(dim=-1)

        if self.smoothing:
            # Sigmoid then renormalize: a smoother alternative to softmax.
            score = torch.sigmoid(score)
            attn = score / score.sum(dim=-1, keepdim=True)
        else:
            attn = F.softmax(score, dim=-1)

        # Bx1xT X BxTxD => Bx1xD => BxD
        context = torch.bmm(attn.unsqueeze(dim=1), value).squeeze(dim=1)

        return context, attn
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class MultiHeadLocationAwareAttention(nn.Module):
    """
    Multi-headed location-aware attention over the output features from the
    decoder. Location-aware attention was proposed in "Attention-Based Models
    for Speech Recognition" with a single head; this module extends it to
    multiple heads.

    Args:
        hidden_dim (int): the number of expected features in the output
        num_heads (int): the number of heads (default: 8)
        conv_out_channel (int): the number of out channels in the convolution

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): output features from the decoder.
        - **value** (batch, v_len, hidden_dim): features of the encoded input sequence.
        - **last_attn** (batch, num_heads, v_len): previous timestep's attention
          (alignment), or None on the first step.

    Returns: context, attn
        - **context** (batch, q_len, hidden_dim): feature from the encoder outputs.
        - **attn** (batch, num_heads, v_len): attention (alignment) over the encoder outputs.

    Reference:
        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
    """
    def __init__(self, hidden_dim: int, num_heads: int = 8, conv_out_channel: int = 10) -> None:
        super(MultiHeadLocationAwareAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Per-head feature dimension.
        self.dim = int(hidden_dim / num_heads)
        # Convolution over the previous per-head alignments (location features).
        self.conv1d = nn.Conv1d(num_heads, conv_out_channel, kernel_size=3, padding=1)
        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.score_proj = nn.Linear(self.dim, 1, bias=True)
        self.bias = nn.Parameter(torch.rand(self.dim).uniform_(-0.1, 0.1))

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, seq_len = value.size(0), value.size(1)

        # First decoding step: start from an all-zero alignment per head.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size, self.num_heads, seq_len)

        # Location energy from the convolved previous alignment: (B, T, dim),
        # then replicated per head and folded into the batch: (B*N, T, dim).
        loc_energy = torch.tanh(self.loc_proj(self.conv1d(last_attn).transpose(1, 2)))
        loc_energy = loc_energy.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(-1, seq_len, self.dim)

        # Project and split into heads: (B, T, N*dim) -> (B, N, T, dim).
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        # NOTE(review): the view below assumes q_len == 1 (one decoder step per
        # call) — confirm against callers.
        query = query.contiguous().view(-1, 1, self.dim)
        value = value.contiguous().view(-1, seq_len, self.dim)

        # Additive score per encoder position; softmax over time (dim=1).
        score = self.score_proj(torch.tanh(value + query + loc_energy + self.bias)).squeeze(2)
        attn = F.softmax(score, dim=1)

        # NOTE(review): this view reinterprets the head-major (B*N, T, dim)
        # buffer as (B, T, N, dim) rather than permuting back — verify this
        # reordering is intended before restructuring.
        value = value.view(batch_size, seq_len, self.num_heads, self.dim).permute(0, 2, 1, 3)
        value = value.contiguous().view(-1, seq_len, self.dim)

        # Weighted sum per head, heads re-concatenated: (B, 1, N*dim).
        context = torch.bmm(attn.unsqueeze(1), value).view(batch_size, -1, self.num_heads * self.dim)
        attn = attn.view(batch_size, self.num_heads, -1)

        return context, attn
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need".

    Instead of a single attention function over d_model-dimensional keys,
    values and queries, the inputs are projected num_heads times with
    different learned linear maps to d_head dimensions, attended
    independently, and the head outputs concatenated. This lets the model
    jointly attend to information from different representation subspaces
    at different positions.

        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Args:
        d_model (int): dimension of keys / values / queries (default: 512)
        num_heads (int): number of attention heads (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): decoder output, input embedding,
          or masked output embedding depending on where the layer sits.
        - **key** (batch, k_len, d_model): encoder output or embeddings.
        - **value** (batch, v_len, d_model): encoder output or embeddings.
        - **mask**: tensor containing indices to be masked.

    Returns: output, attn
        - **output** (batch, output_len, dimensions): attended output features.
        - **attn** (batch * num_heads, q_len, k_len): attention (alignment) weights.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)

        def split_heads(projected: Tensor) -> Tensor:
            # (B, T, N*d_head) -> (B*N, T, d_head): separate the heads and
            # fold them into the batch axis for batched matrix products.
            projected = projected.view(batch_size, -1, self.num_heads, self.d_head)
            return projected.permute(2, 0, 1, 3).contiguous().view(
                batch_size * self.num_heads, -1, self.d_head)

        query = split_heads(self.query_proj(query))
        key = split_heads(self.key_proj(key))
        value = split_heads(self.value_proj(value))

        if mask is not None:
            # Replicate the mask once per head: (B, N, q_len, k_len).
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # (B*N, T, d_head) -> (B, T, N*d_head): undo the head folding.
        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(
            batch_size, -1, self.num_heads * self.d_head)

        return context, attn
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with a learned content bias, following the attention
    decomposition of "Transformer-XL: Attentive Language Models Beyond a
    Fixed-Length Context". Only the content term (query + u_bias against the
    keys) is applied here; no positional embeddings are consumed.

    Args:
        d_model (int): dimension of the model
        num_heads (int): number of attention heads
        dropout_p (float): dropout probability applied to the attention weights

    Inputs: query, key, value, mask
        - **query** (batch, time, dim): query vectors
        - **key** (batch, time, dim): key vectors
        - **value** (batch, time, dim): value vectors
        - **mask** (batch, 1, time2) or (batch, time1, time2): positions to mask

    Returns:
        - **context** (batch, time, d_model): attended output features.
    """
    def __init__(
            self,
            d_model: int = 512,
            num_heads: int = 16,
            dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout_p)
        # Learned bias added to every query before scoring (Transformer-XL's "u").
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        # Project and split into heads; key/value end up as (B, N, T, d_head).
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)

        # Content score with the learned query bias, scaled by sqrt(d_model).
        score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3)) / self.sqrt_dim

        if mask is not None:
            # Large negative fill so softmax drives masked weights to ~0.
            score.masked_fill_(mask.unsqueeze(1), -1e9)

        attn = self.dropout(F.softmax(score, -1))

        # Weighted sum of values, then merge the heads back to d_model.
        context = torch.matmul(attn, value).transpose(1, 2)
        return context.contiguous().view(batch_size, -1, self.d_model)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class CustomizingAttention(nn.Module):
    r"""
    Customizing Attention

    Applies a multi-head + location-aware attention mechanism on the output
    features from the decoder. Multi-head attention was proposed in
    "Attention Is All You Need", location-aware attention in "Attention-Based
    Models for Speech Recognition"; the two mechanisms are combined here.

    Args:
        hidden_dim (int): The number of expected features in the output
        num_heads (int): The number of heads (default: 4)
        conv_out_channel (int): The number of out channels in the convolution

    Inputs: query, value, last_attn
        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
        - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep's alignment,
          or None on the first step.

    Returns: output, attn
        - **output** (batch, q_len, hidden_dim): tensor containing the attended output features from the decoder.
        - **attn**: tensor containing the alignment from the encoder outputs.

    Reference:
        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
    """

    def __init__(self, hidden_dim: int, num_heads: int = 4, conv_out_channel: int = 10) -> None:
        super(CustomizingAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Per-head feature dimension.
        self.dim = int(hidden_dim / num_heads)
        self.scaled_dot_attn = ScaledDotProductAttention(self.dim)
        self.conv1d = nn.Conv1d(1, conv_out_channel, kernel_size=3, padding=1)
        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=True)
        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
        self.bias = nn.Parameter(torch.rand(self.dim * num_heads).uniform_(-0.1, 0.1))

    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size, q_len, v_len = value.size(0), query.size(1), value.size(1)

        # First decoding step: start from an all-zero alignment.
        if last_attn is None:
            last_attn = value.new_zeros(batch_size * self.num_heads, v_len)

        loc_energy = self.get_loc_energy(last_attn, batch_size, v_len)  # get location energy

        query = self.query_proj(query).view(batch_size, q_len, self.num_heads * self.dim)
        value = self.value_proj(value).view(batch_size, v_len, self.num_heads * self.dim) + loc_energy + self.bias

        # Split heads: (B, T, N*dim) -> (B*N, T, dim).
        query = query.view(batch_size, q_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
        value = value.view(batch_size, v_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
        query = query.contiguous().view(-1, q_len, self.dim)
        value = value.contiguous().view(-1, v_len, self.dim)

        # BUG FIX: ScaledDotProductAttention.forward takes (query, key, value);
        # the original two-argument call always raised TypeError. The
        # location-augmented value serves as both key and value here.
        context, attn = self.scaled_dot_attn(query, value, value)
        attn = attn.squeeze()

        # Merge heads back: (B*N, q_len, dim) -> (B, q_len, N*dim).
        context = context.view(self.num_heads, batch_size, q_len, self.dim).permute(1, 2, 0, 3)
        context = context.contiguous().view(batch_size, q_len, -1)

        return context, attn

    def get_loc_energy(self, last_attn: Tensor, batch_size: int, v_len: int) -> Tensor:
        """Project the convolved previous alignment to (batch, v_len, num_heads * dim)."""
        conv_feat = self.conv1d(last_attn.unsqueeze(1))
        conv_feat = conv_feat.view(batch_size, self.num_heads, -1, v_len).permute(0, 1, 3, 2)

        loc_energy = self.loc_proj(conv_feat).view(batch_size, self.num_heads, v_len, self.dim)
        loc_energy = loc_energy.permute(0, 2, 1, 3).reshape(batch_size, v_len, self.num_heads * self.dim)

        return loc_energy
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class RelativeMultiHeadAttention2(nn.Module):
    """
    Multi-head attention with relative positional encoding, following
    "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
    Content and position terms use the learned u/v biases and the scores are
    combined via the standard Transformer-XL relative shift.

    Args:
        d_model (int): dimension of the model
        num_heads (int): number of attention heads
        dropout_p (float): dropout probability applied to the attention weights

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): query vectors
        - **key** (batch, time, dim): key vectors
        - **value** (batch, time, dim): value vectors
        - **pos_embedding** (batch, time, dim): positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): positions to mask

    Returns:
        - **outputs** (batch, time, d_model): attended, output-projected features.
    """
    def __init__(
            self,
            d_model: int = 512,
            num_heads: int = 16,
            dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention2, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        # Transformer-XL's learned "u" (content) and "v" (position) biases.
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        # Project and split into heads; key/value go to (B, N, T, d_head).
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        # Content term (query + u).key and position term (query + v).pos,
        # the latter re-indexed by the relative shift.
        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        pos_score = self._compute_relative_positional_encoding(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            # Large negative fill so softmax drives masked weights to ~0.
            score.masked_fill_(mask.unsqueeze(1), -1e9)

        attn = self.dropout(F.softmax(score, -1))

        # Weighted sum of values, merge heads, then final output projection.
        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _compute_relative_positional_encoding(self, pos_score: Tensor) -> Tensor:
        # Standard Transformer-XL "relative shift": pad a zero column, reshape
        # so entries slide one step per row, then crop back to the原 shape.
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zero_pad = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded = torch.cat([zero_pad, pos_score], dim=-1)

        padded = padded.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        return padded[:, :, 1:].view_as(pos_score)
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b33c63dd6fd9936f63e976118e2c4301328d8017edaaa90892335e61ffed6929
|
| 3 |
+
size 138640496
|
modeling_chessbot.py
CHANGED
|
@@ -1,376 +1,81 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Updated HuggingFace Compatible ChessBot Chess Model
|
| 3 |
-
|
| 4 |
-
This file contains the updated architecture with d_ff=1024 and new weights
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
import torch.nn.functional as F
|
| 10 |
-
import numpy as np
|
| 11 |
-
import chess
|
| 12 |
-
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
|
| 13 |
-
from transformers.modeling_outputs import BaseModelOutput
|
| 14 |
-
from typing import Optional, Tuple
|
| 15 |
-
import math
|
| 16 |
import sys
|
| 17 |
import os
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
super().__init__()
|
| 22 |
-
assert d_model % num_heads == 0
|
| 23 |
-
self.d_model = d_model
|
| 24 |
-
self.num_heads = num_heads
|
| 25 |
-
self.d_head = d_model // num_heads
|
| 26 |
-
self.sqrt_dim = math.sqrt(d_model)
|
| 27 |
-
|
| 28 |
-
self.query_proj = nn.Linear(d_model, d_model)
|
| 29 |
-
self.key_proj = nn.Linear(d_model, d_model)
|
| 30 |
-
self.value_proj = nn.Linear(d_model, d_model)
|
| 31 |
-
self.pos_proj = nn.Linear(d_model, d_model)
|
| 32 |
-
self.out_proj = nn.Linear(d_model, d_model)
|
| 33 |
-
|
| 34 |
-
self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 35 |
-
self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 36 |
-
torch.nn.init.xavier_uniform_(self.u_bias)
|
| 37 |
-
torch.nn.init.xavier_uniform_(self.v_bias)
|
| 38 |
-
self.dropout = nn.Dropout(dropout_p)
|
| 39 |
-
|
| 40 |
-
def forward(self, query, key, value, pos_embedding, mask=None):
|
| 41 |
-
batch_size = value.size(0)
|
| 42 |
-
|
| 43 |
-
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
|
| 44 |
-
key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 45 |
-
value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 46 |
-
|
| 47 |
-
pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
|
| 48 |
-
|
| 49 |
-
content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
|
| 50 |
-
pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
|
| 51 |
-
pos_score = self._compute_relative_positional_encoding(pos_score)
|
| 52 |
-
|
| 53 |
-
score = (content_score + pos_score) / self.sqrt_dim
|
| 54 |
-
|
| 55 |
-
if mask is not None:
|
| 56 |
-
mask = mask.unsqueeze(1)
|
| 57 |
-
score.masked_fill_(mask, -1e9)
|
| 58 |
-
|
| 59 |
-
attn = F.softmax(score, -1)
|
| 60 |
-
attn = self.dropout(attn)
|
| 61 |
-
|
| 62 |
-
context = torch.matmul(attn, value).transpose(1, 2)
|
| 63 |
-
context = context.contiguous().view(batch_size, -1, self.d_model)
|
| 64 |
-
|
| 65 |
-
return self.out_proj(context)
|
| 66 |
-
|
| 67 |
-
def _compute_relative_positional_encoding(self, pos_score):
|
| 68 |
-
batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
|
| 69 |
-
zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
|
| 70 |
-
padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
|
| 71 |
-
|
| 72 |
-
padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
|
| 73 |
-
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
|
| 74 |
-
|
| 75 |
-
return pos_score
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
def fen_to_tensor(fen: str):
|
| 79 |
-
"""Convert FEN string to tensor representation for the model."""
|
| 80 |
board = chess.Board(fen)
|
| 81 |
-
|
|
|
|
| 82 |
|
| 83 |
-
# Piece mapping
|
| 84 |
piece_map = {
|
| 85 |
'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
|
| 86 |
'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
|
| 87 |
}
|
| 88 |
|
| 89 |
-
#
|
| 90 |
-
for square in
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
col = square % 8
|
| 95 |
-
tensor[row, col, piece_map[piece.symbol()]] = 1.0
|
| 96 |
-
|
| 97 |
-
# Add metadata channels
|
| 98 |
-
# Channel 12: White to move
|
| 99 |
-
if board.turn == chess.WHITE:
|
| 100 |
-
tensor[:, :, 12] = 1.0
|
| 101 |
|
| 102 |
-
#
|
| 103 |
-
if board.turn == chess.
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
-
# Castling rights
|
| 107 |
-
if board.has_kingside_castling_rights(chess.WHITE)
|
| 108 |
-
|
| 109 |
-
if board.
|
| 110 |
-
|
| 111 |
-
if board.has_kingside_castling_rights(chess.BLACK):
|
| 112 |
-
tensor[:, :, 16] = 1.0
|
| 113 |
-
if board.has_queenside_castling_rights(chess.BLACK):
|
| 114 |
-
tensor[:, :, 17] = 1.0
|
| 115 |
|
| 116 |
-
#
|
| 117 |
-
|
| 118 |
-
ep_row = 7 - (board.ep_square // 8)
|
| 119 |
-
ep_col = board.ep_square % 8
|
| 120 |
-
tensor[ep_row, ep_col, 18] = 1.0
|
| 121 |
|
| 122 |
return tensor
|
| 123 |
|
| 124 |
-
|
| 125 |
-
policy_index = [
|
| 126 |
-
"a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
|
| 127 |
-
"a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
|
| 128 |
-
"a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1",
|
| 129 |
-
"b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3",
|
| 130 |
-
"b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7",
|
| 131 |
-
"b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1",
|
| 132 |
-
"c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3",
|
| 133 |
-
"c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
|
| 134 |
-
"d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2",
|
| 135 |
-
"d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4",
|
| 136 |
-
"d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1",
|
| 137 |
-
"e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2",
|
| 138 |
-
"e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4",
|
| 139 |
-
"e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1",
|
| 140 |
-
"f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3",
|
| 141 |
-
"f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
|
| 142 |
-
"f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1",
|
| 143 |
-
"g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3",
|
| 144 |
-
"g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8",
|
| 145 |
-
"h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2",
|
| 146 |
-
"h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6",
|
| 147 |
-
"h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2",
|
| 148 |
-
"a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3",
|
| 149 |
-
"a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
|
| 150 |
-
"a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2",
|
| 151 |
-
"b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4",
|
| 152 |
-
"b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7",
|
| 153 |
-
"b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2",
|
| 154 |
-
"c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3",
|
| 155 |
-
"c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6",
|
| 156 |
-
"c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1",
|
| 157 |
-
"d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
|
| 158 |
-
"d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5",
|
| 159 |
-
"d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1",
|
| 160 |
-
"e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2",
|
| 161 |
-
"e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4",
|
| 162 |
-
"e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1",
|
| 163 |
-
"f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2",
|
| 164 |
-
"f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4",
|
| 165 |
-
"f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
|
| 166 |
-
"f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2",
|
| 167 |
-
"g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4",
|
| 168 |
-
"g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8",
|
| 169 |
-
"g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2",
|
| 170 |
-
"h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5",
|
| 171 |
-
"h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1",
|
| 172 |
-
"a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3",
|
| 173 |
-
"a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
|
| 174 |
-
"a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1",
|
| 175 |
-
"b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3",
|
| 176 |
-
"b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5",
|
| 177 |
-
"b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1",
|
| 178 |
-
"c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3",
|
| 179 |
-
"c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4",
|
| 180 |
-
"c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6",
|
| 181 |
-
"c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
|
| 182 |
-
"d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3",
|
| 183 |
-
"d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5",
|
| 184 |
-
"d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7",
|
| 185 |
-
"d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2",
|
| 186 |
-
"e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3",
|
| 187 |
-
"e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5",
|
| 188 |
-
"e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1",
|
| 189 |
-
"f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
|
| 190 |
-
"f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4",
|
| 191 |
-
"f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6",
|
| 192 |
-
"f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2",
|
| 193 |
-
"g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3",
|
| 194 |
-
"g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5",
|
| 195 |
-
"g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1",
|
| 196 |
-
"h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3",
|
| 197 |
-
"h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
|
| 198 |
-
"h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2",
|
| 199 |
-
"a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4",
|
| 200 |
-
"a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7",
|
| 201 |
-
"a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3",
|
| 202 |
-
"b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4",
|
| 203 |
-
"b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6",
|
| 204 |
-
"b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2",
|
| 205 |
-
"c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
|
| 206 |
-
"c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5",
|
| 207 |
-
"c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8",
|
| 208 |
-
"c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2",
|
| 209 |
-
"d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4",
|
| 210 |
-
"d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6",
|
| 211 |
-
"d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8",
|
| 212 |
-
"e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3",
|
| 213 |
-
"e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
|
| 214 |
-
"e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6",
|
| 215 |
-
"e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1",
|
| 216 |
-
"f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3",
|
| 217 |
-
"f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4",
|
| 218 |
-
"f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6",
|
| 219 |
-
"f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2",
|
| 220 |
-
"g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4",
|
| 221 |
-
"g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
|
| 222 |
-
"g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1",
|
| 223 |
-
"h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4",
|
| 224 |
-
"h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6",
|
| 225 |
-
"h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2",
|
| 226 |
-
"a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5",
|
| 227 |
-
"a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7",
|
| 228 |
-
"a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3",
|
| 229 |
-
"b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
|
| 230 |
-
"b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7",
|
| 231 |
-
"b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2",
|
| 232 |
-
"c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4",
|
| 233 |
-
"c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6",
|
| 234 |
-
"c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7",
|
| 235 |
-
"c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3",
|
| 236 |
-
"d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5",
|
| 237 |
-
"d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
|
| 238 |
-
"d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8",
|
| 239 |
-
"d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3",
|
| 240 |
-
"e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5",
|
| 241 |
-
"e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6",
|
| 242 |
-
"e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8",
|
| 243 |
-
"f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3",
|
| 244 |
-
"f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5",
|
| 245 |
-
"f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
|
| 246 |
-
"f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2",
|
| 247 |
-
"g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4",
|
| 248 |
-
"g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6",
|
| 249 |
-
"g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1",
|
| 250 |
-
"h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4",
|
| 251 |
-
"h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6",
|
| 252 |
-
"h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2",
|
| 253 |
-
"a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
|
| 254 |
-
"a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7",
|
| 255 |
-
"a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3",
|
| 256 |
-
"b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5",
|
| 257 |
-
"b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7",
|
| 258 |
-
"b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2",
|
| 259 |
-
"c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5",
|
| 260 |
-
"c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6",
|
| 261 |
-
"c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
|
| 262 |
-
"c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3",
|
| 263 |
-
"d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5",
|
| 264 |
-
"d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7",
|
| 265 |
-
"d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8",
|
| 266 |
-
"e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4",
|
| 267 |
-
"e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6",
|
| 268 |
-
"e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7",
|
| 269 |
-
"e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
|
| 270 |
-
"f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5",
|
| 271 |
-
"f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6",
|
| 272 |
-
"f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8",
|
| 273 |
-
"f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3",
|
| 274 |
-
"g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6",
|
| 275 |
-
"g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7",
|
| 276 |
-
"g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2",
|
| 277 |
-
"h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
|
| 278 |
-
"h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7",
|
| 279 |
-
"h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3",
|
| 280 |
-
"a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7",
|
| 281 |
-
"a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8",
|
| 282 |
-
"b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5",
|
| 283 |
-
"b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7",
|
| 284 |
-
"b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8",
|
| 285 |
-
"c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
|
| 286 |
-
"c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7",
|
| 287 |
-
"c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8",
|
| 288 |
-
"c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4",
|
| 289 |
-
"d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6",
|
| 290 |
-
"d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8",
|
| 291 |
-
"d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4",
|
| 292 |
-
"e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6",
|
| 293 |
-
"e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
|
| 294 |
-
"e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2",
|
| 295 |
-
"f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5",
|
| 296 |
-
"f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7",
|
| 297 |
-
"f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1",
|
| 298 |
-
"g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5",
|
| 299 |
-
"g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7",
|
| 300 |
-
"g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1",
|
| 301 |
-
"h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
|
| 302 |
-
"h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7",
|
| 303 |
-
"h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2",
|
| 304 |
-
"a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6",
|
| 305 |
-
"a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8",
|
| 306 |
-
"a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5",
|
| 307 |
-
"b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7",
|
| 308 |
-
"b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2",
|
| 309 |
-
"c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
|
| 310 |
-
"c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8",
|
| 311 |
-
"c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4",
|
| 312 |
-
"d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6",
|
| 313 |
-
"d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8",
|
| 314 |
-
"d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5",
|
| 315 |
-
"e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7",
|
| 316 |
-
"e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8",
|
| 317 |
-
"e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
|
| 318 |
-
"f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7",
|
| 319 |
-
"f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1",
|
| 320 |
-
"g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6",
|
| 321 |
-
"g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8",
|
| 322 |
-
"g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2",
|
| 323 |
-
"h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6",
|
| 324 |
-
"h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8",
|
| 325 |
-
"h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q",
|
| 326 |
-
"b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b",
|
| 327 |
-
"c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r",
|
| 328 |
-
"c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q",
|
| 329 |
-
"d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b",
|
| 330 |
-
"e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r",
|
| 331 |
-
"f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q",
|
| 332 |
-
"g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b",
|
| 333 |
-
"h7h8q", "h7h8r", "h7h8b", #add the promotions for black
|
| 334 |
-
"a2a1q","a2a1r","a2a1b","a2b1q","a2b1r","a2b1b",
|
| 335 |
-
"b2a1q","b2a1r","b2a1b","b2b1q","b2b1r","b2b1b","b2c1q","b2c1r","b2c1b",
|
| 336 |
-
"c2b1q","c2b1r","c2b1b","c2c1q","c2c1r","c2c1b","c2d1q","c2d1r","c2d1b",
|
| 337 |
-
"d2c1q","d2c1r","d2c1b","d2d1q","d2d1r","d2d1b","d2e1q","d2e1r","d2e1b",
|
| 338 |
-
"e2d1q","e2d1r","e2d1b","e2e1q","e2e1r","e2e1b","e2f1q","e2f1r","e2f1b",
|
| 339 |
-
"f2e1q","f2e1r","f2e1b","f2f1q","f2f1r","f2f1b","f2g1q","f2g1r","f2g1b",
|
| 340 |
-
"g2f1q","g2f1r","g2f1b","g2g1q","g2g1r","g2g1b","g2h1q","g2h1r","g2h1b",
|
| 341 |
-
"h2g1q","h2g1r","h2g1b","h2h1q","h2h1r","h2h1b",#add special tokens
|
| 342 |
-
"<thinking>","</thinking>","end_variation","end","padding_token"
|
| 343 |
-
]
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
# Configuration class
|
| 347 |
class ChessBotConfig(PretrainedConfig):
|
| 348 |
-
"""
|
| 349 |
-
Configuration class for ChessBot model.
|
| 350 |
-
"""
|
| 351 |
-
|
| 352 |
model_type = "chessbot"
|
| 353 |
|
| 354 |
def __init__(
|
| 355 |
self,
|
| 356 |
-
num_layers
|
| 357 |
-
d_model
|
| 358 |
-
d_ff
|
| 359 |
-
num_heads
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
|
|
|
| 363 |
):
|
| 364 |
-
super().__init__(**kwargs)
|
| 365 |
self.num_layers = num_layers
|
| 366 |
self.d_model = d_model
|
| 367 |
self.d_ff = d_ff
|
| 368 |
self.num_heads = num_heads
|
| 369 |
-
self.vocab_size = vocab_size
|
| 370 |
self.max_position_embeddings = max_position_embeddings
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
|
| 373 |
-
# Model components
|
| 374 |
class MaGating(nn.Module):
|
| 375 |
def __init__(self, d_model):
|
| 376 |
super().__init__()
|
|
@@ -395,38 +100,54 @@ class EncoderLayer(nn.Module):
|
|
| 395 |
attn_out = self.attention(x, x, x, pos_enc)
|
| 396 |
x = attn_out + x
|
| 397 |
x = self.norm1(x)
|
| 398 |
-
|
| 399 |
y = self.ff1(x)
|
| 400 |
y = self.gelu(y)
|
| 401 |
y = self.ff2(y)
|
| 402 |
y = y + x
|
| 403 |
y = self.norm2(y)
|
| 404 |
-
|
| 405 |
return y
|
| 406 |
|
| 407 |
|
| 408 |
class AbsolutePositionalEncoder(nn.Module):
|
| 409 |
def __init__(self, d_model):
|
| 410 |
-
super().__init__()
|
| 411 |
-
self.
|
| 412 |
-
self.
|
|
|
|
|
|
|
| 413 |
_2i = torch.arange(0, d_model, step=2).float()
|
| 414 |
-
|
| 415 |
-
|
| 416 |
|
| 417 |
-
|
| 418 |
-
self.register_buffer('pos_encoding', self.positional_encoding)
|
| 419 |
|
| 420 |
def forward(self, x):
|
| 421 |
batch_size, _, _ = x.size()
|
| 422 |
-
return self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
|
| 425 |
class ValueHead(nn.Module):
|
| 426 |
def __init__(self, d_model):
|
| 427 |
super().__init__()
|
| 428 |
self.dense1 = nn.Linear(d_model, 128)
|
| 429 |
-
self.dense2 = nn.Linear(128*64, 128)
|
| 430 |
self.dense3 = nn.Linear(128, 3)
|
| 431 |
|
| 432 |
def forward(self, x):
|
|
@@ -438,13 +159,13 @@ class ValueHead(nn.Module):
|
|
| 438 |
x = F.gelu(x)
|
| 439 |
x = self.dense3(x)
|
| 440 |
return x
|
| 441 |
-
|
| 442 |
|
| 443 |
class ValueHeadQ(nn.Module):
|
| 444 |
def __init__(self, d_model):
|
| 445 |
super().__init__()
|
| 446 |
self.dense1 = nn.Linear(d_model, 128)
|
| 447 |
-
self.dense2 = nn.Linear(128*64, 128)
|
| 448 |
self.dense3 = nn.Linear(128, 3)
|
| 449 |
|
| 450 |
def forward(self, x):
|
|
@@ -458,49 +179,21 @@ class ValueHeadQ(nn.Module):
|
|
| 458 |
return x
|
| 459 |
|
| 460 |
|
| 461 |
-
|
| 462 |
-
class ChessBotPreTrainedModel(PreTrainedModel):
|
| 463 |
-
"""
|
| 464 |
-
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
| 465 |
-
"""
|
| 466 |
-
|
| 467 |
config_class = ChessBotConfig
|
| 468 |
-
base_model_prefix = "chessbot"
|
| 469 |
-
supports_gradient_checkpointing = True
|
| 470 |
-
|
| 471 |
-
def _init_weights(self, module):
|
| 472 |
-
"""Initialize the weights"""
|
| 473 |
-
if isinstance(module, nn.Linear):
|
| 474 |
-
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 475 |
-
if module.bias is not None:
|
| 476 |
-
module.bias.data.zero_()
|
| 477 |
-
elif isinstance(module, nn.Embedding):
|
| 478 |
-
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 479 |
-
elif isinstance(module, nn.LayerNorm):
|
| 480 |
-
module.bias.data.zero_()
|
| 481 |
-
module.weight.data.fill_(1.0)
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
class ChessBotModel(ChessBotPreTrainedModel):
|
| 485 |
-
"""
|
| 486 |
-
Updated HuggingFace compatible ChessBot Chess model with d_ff=1024
|
| 487 |
-
"""
|
| 488 |
|
| 489 |
def __init__(self, config):
|
| 490 |
super().__init__(config)
|
| 491 |
self.config = config
|
| 492 |
-
|
| 493 |
-
# Initialize exactly like the updated BT4 model
|
| 494 |
self.is_thinking_model = False
|
| 495 |
self.d_model = config.d_model
|
| 496 |
self.num_layers = config.num_layers
|
| 497 |
-
|
| 498 |
-
# Model layers - same as updated model
|
| 499 |
self.layers = nn.ModuleList([
|
| 500 |
EncoderLayer(config.d_model, config.d_ff, config.num_heads)
|
| 501 |
for _ in range(config.num_layers)
|
| 502 |
])
|
| 503 |
-
|
| 504 |
self.linear1 = nn.Linear(19, config.d_model)
|
| 505 |
self.layernorm1 = nn.LayerNorm(config.d_model)
|
| 506 |
self.policy_tokens_lin = nn.Linear(config.d_model, config.d_model)
|
|
@@ -508,29 +201,33 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 508 |
self.keys_pol = nn.Linear(config.d_model, config.d_model)
|
| 509 |
self.positional = AbsolutePositionalEncoder(config.d_model)
|
| 510 |
self.ma_gating = MaGating(config.d_model)
|
| 511 |
-
self.policy_head = nn.Linear(64*64,
|
| 512 |
self.value_head = ValueHead(config.d_model)
|
| 513 |
self.value_head_q = ValueHeadQ(config.d_model)
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
self
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
|
|
|
| 524 |
inp = input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
x = inp[0]
|
| 526 |
-
compute_loss = compute_loss or len(inp) > 1
|
| 527 |
else:
|
| 528 |
-
x =
|
| 529 |
-
|
| 530 |
-
|
| 531 |
b, seq_len, _, _, emb = x.size()
|
| 532 |
x = x.view(b * seq_len, 64, emb)
|
| 533 |
-
|
| 534 |
x = self.linear1(x)
|
| 535 |
x = F.gelu(x)
|
| 536 |
x = self.layernorm1(x)
|
|
@@ -539,7 +236,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 539 |
pos_enc = self.positional(x)
|
| 540 |
for i in range(self.num_layers):
|
| 541 |
x = self.layers[i](x, pos_enc)
|
| 542 |
-
|
| 543 |
value_h = self.value_head(x)
|
| 544 |
value_h = value_h.view(b, seq_len, 3)
|
| 545 |
value_h_q = self.value_head_q(x)
|
|
@@ -548,64 +245,64 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 548 |
policy_tokens = self.policy_tokens_lin(x)
|
| 549 |
policy_tokens = F.gelu(policy_tokens)
|
| 550 |
policy_tokens = policy_tokens + pos_enc
|
| 551 |
-
|
| 552 |
queries = self.queries_pol(policy_tokens)
|
| 553 |
keys = self.keys_pol(policy_tokens)
|
| 554 |
-
|
| 555 |
matmul_qk = torch.matmul(queries, torch.transpose(keys, -2, -1))
|
| 556 |
-
dk = torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
|
| 557 |
-
|
| 558 |
policy_attn_logits = matmul_qk / dk
|
| 559 |
-
policy_attn_logits = policy_attn_logits.view(b, seq_len, 64*64)
|
| 560 |
-
|
| 561 |
policy = self.policy_head(policy_attn_logits)
|
| 562 |
-
|
| 563 |
-
if compute_loss:
|
| 564 |
targets = inp[1]
|
| 565 |
-
true_values = inp[3]
|
| 566 |
-
q_values = inp[4]
|
| 567 |
-
true_values = q_values
|
| 568 |
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
# Create mask for samples where true_values/q_values is not [0,0,0]
|
| 577 |
-
valid_mask = (true_values.sum(dim=-1) != 0) & (q_values.sum(dim=-1) != 0)
|
| 578 |
-
|
| 579 |
-
# Only compute value losses if we have valid samples
|
| 580 |
-
if valid_mask.any():
|
| 581 |
-
# Filter to only valid samples
|
| 582 |
-
valid_value_h = value_h[valid_mask]
|
| 583 |
-
valid_value_h_q = value_h_q_softmax[valid_mask]
|
| 584 |
-
valid_z = z[valid_mask]
|
| 585 |
-
valid_q_values = q_values[valid_mask]
|
| 586 |
|
| 587 |
-
|
| 588 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
else:
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
return policy, value_h, value_h_q, loss_policy, loss_value, loss_q, targets, z, q
|
| 595 |
-
|
| 596 |
return BaseModelOutput(
|
| 597 |
-
last_hidden_state=
|
| 598 |
-
hidden_states=
|
| 599 |
-
|
| 600 |
-
), policy, value_h, value_h_q
|
| 601 |
|
| 602 |
def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
|
| 603 |
-
|
| 604 |
-
board
|
| 605 |
x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
|
| 606 |
x = x.view(1, 1, 8, 8, 19)
|
| 607 |
|
| 608 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
logits = logits.view(-1, 1929)
|
| 610 |
legal_move_mask = torch.zeros((1, 1929), device=device)
|
| 611 |
for legal_move in board.legal_moves:
|
|
@@ -613,21 +310,25 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 613 |
legal_move_uci = legal_move.uci()[:-1]
|
| 614 |
else:
|
| 615 |
legal_move_uci = legal_move.uci()
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
# Set all illegal moves to -inf
|
| 620 |
if force_legal:
|
| 621 |
-
logits = logits + (1-legal_move_mask) * -999
|
| 622 |
|
| 623 |
if T == 0:
|
| 624 |
sampled = torch.argmax(logits, dim=-1, keepdim=True)
|
| 625 |
else:
|
| 626 |
-
probs = F.softmax(logits/T, dim=-1)
|
| 627 |
sampled = torch.multinomial(probs, num_samples=1)
|
| 628 |
if return_probs:
|
| 629 |
-
|
| 630 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
move = policy_index[sampled.item()]
|
| 632 |
return move
|
| 633 |
|
|
@@ -651,6 +352,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 651 |
value_logits = self.value_head_q(x_processed)
|
| 652 |
value_logits = value_logits.view(b, seq_len, 3)
|
| 653 |
value = torch.softmax(value_logits, dim=-1)
|
|
|
|
| 654 |
return value.squeeze()
|
| 655 |
|
| 656 |
def get_batch_position_values(self, fens, device="cuda"):
|
|
@@ -681,6 +383,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 681 |
value_logits = self.value_head_q(x_processed)
|
| 682 |
value_logits = value_logits.view(b, seq_len, 3)
|
| 683 |
value_logits = torch.softmax(value_logits, dim=-1)
|
|
|
|
| 684 |
return value_logits.squeeze(1)
|
| 685 |
|
| 686 |
def calculate_move_values(self, fen, device="cuda"):
|
|
@@ -689,7 +392,6 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 689 |
board.set_fen(fen)
|
| 690 |
|
| 691 |
is_white_turn = board.turn == chess.WHITE
|
| 692 |
-
|
| 693 |
legal_moves = list(board.legal_moves)
|
| 694 |
if len(legal_moves) == 0:
|
| 695 |
return [], torch.empty(0, device=device)
|
|
@@ -701,9 +403,8 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 701 |
board.pop()
|
| 702 |
|
| 703 |
batch_value_q = self.get_batch_position_values(resulting_fens, device)
|
|
|
|
| 704 |
|
| 705 |
-
# Calculate values from the current player's perspective
|
| 706 |
-
batch_value_q = batch_value_q[:,2]-batch_value_q[:,0]
|
| 707 |
if is_white_turn:
|
| 708 |
player_values = batch_value_q
|
| 709 |
else:
|
|
@@ -713,20 +414,20 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 713 |
|
| 714 |
def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False, to_fall_back_to_policy=False):
|
| 715 |
"""Determine the best move based on the value of resulting positions using efficient batching."""
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
else:
|
| 726 |
-
value = value[0]-value[2]
|
| 727 |
|
| 728 |
-
|
| 729 |
-
|
|
|
|
|
|
|
| 730 |
|
| 731 |
legal_moves, move_values = self.calculate_move_values(fen, device)
|
| 732 |
|
|
@@ -749,19 +450,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 749 |
probs[best_idx] = 1.0
|
| 750 |
else:
|
| 751 |
probs = F.softmax(move_values / T, dim=0)
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
for i, move in enumerate(legal_moves):
|
| 755 |
-
move_dict[move.uci()] = probs[i].item()
|
| 756 |
-
return move_uci, move_dict
|
| 757 |
|
| 758 |
return move_uci
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
# Register the configuration and model with transformers
|
| 762 |
-
AutoConfig.register("chessbot", ChessBotConfig)
|
| 763 |
-
AutoModel.register(ChessBotConfig, ChessBotModel)
|
| 764 |
-
|
| 765 |
-
# For backward compatibility
|
| 766 |
-
ChessBot = ChessBotModel
|
| 767 |
-
BT4Model = ChessBotModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import sys
|
| 5 |
import os
|
| 6 |
+
import chess
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
|
| 10 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 11 |
+
from transformers.modeling_outputs import BaseModelOutput
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
from attn import RelativeMultiHeadAttention2
|
| 14 |
+
from vocab import policy_index
|
| 15 |
+
|
| 16 |
+
# Make policy_index available for imports from this module
|
| 17 |
+
__all__ = ['ChessBotConfig', 'ChessBotModel', 'policy_index', 'fen_to_tensor']
|
| 18 |
|
| 19 |
def fen_to_tensor(fen: str):
    """Encode a FEN position as an (8, 8, 19) float32 numpy tensor.

    Plane layout:
      0-11  one-hot piece occupancy (white P N B R Q K, then black p n b r q k)
      12    side to move (all ones when white to move)
      13    en-passant target square
      14-17 castling rights (white kingside, white queenside, black kingside,
            black queenside)
      18    half-move clock scaled into [0, 1] (50-move rule)

    Ranks are flipped so row 0 of the tensor corresponds to rank 8.
    """
    board = chess.Board(fen)
    planes = np.zeros((8, 8, 19), dtype=np.float32)

    piece_to_plane = {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,   # White pieces
        'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,  # Black pieces
    }

    # One-hot piece planes; divmod(square, 8) -> (rank, file).
    for sq, piece in board.piece_map().items():
        rank, file = divmod(sq, 8)
        planes[7 - rank, file, piece_to_plane[piece.symbol()]] = 1.0

    # Side-to-move plane (zero-initialized, so only set when white moves).
    if board.turn == chess.WHITE:
        planes[:, :, 12] = 1.0

    # En-passant plane.
    if board.ep_square is not None:
        ep_rank, ep_file = divmod(board.ep_square, 8)
        planes[7 - ep_rank, ep_file, 13] = 1.0

    # Castling-rights planes, in the fixed WK, WQ, BK, BQ order.
    castling_flags = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for offset, allowed in enumerate(castling_flags):
        if allowed:
            planes[:, :, 14 + offset] = 1.0

    # 50-move-rule plane, normalized and clamped to [0, 1].
    planes[:, :, 18] = min(board.halfmove_clock / 100.0, 1.0)

    return planes
|
| 53 |
|
| 54 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
class ChessBotConfig(PretrainedConfig):
    """Hyper-parameter container for the ChessBot transformer.

    Defaults describe the released checkpoint: a 10-layer, 512-dim
    encoder with 1024-dim feed-forward blocks, 8 attention heads,
    64 board-square positions and a 1929-move policy vocabulary.
    """

    model_type = "chessbot"

    def __init__(
        self,
        num_layers=10,
        d_model=512,
        d_ff=1024,
        num_heads=8,
        max_position_embeddings=64,
        vocab_size=1929,
        torch_dtype="float32",
        **kwargs
    ):
        # Record the architecture dimensions before handing any remaining
        # HF-generic options (id2label, etc.) to the base config.
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.torch_dtype = torch_dtype
        super().__init__(**kwargs)
|
| 77 |
|
| 78 |
|
|
|
|
| 79 |
class MaGating(nn.Module):
|
| 80 |
def __init__(self, d_model):
|
| 81 |
super().__init__()
|
|
|
|
| 100 |
attn_out = self.attention(x, x, x, pos_enc)
|
| 101 |
x = attn_out + x
|
| 102 |
x = self.norm1(x)
|
| 103 |
+
|
| 104 |
y = self.ff1(x)
|
| 105 |
y = self.gelu(y)
|
| 106 |
y = self.ff2(y)
|
| 107 |
y = y + x
|
| 108 |
y = self.norm2(y)
|
| 109 |
+
|
| 110 |
return y
|
| 111 |
|
| 112 |
|
| 113 |
class AbsolutePositionalEncoder(nn.Module):
    """Fixed sinusoidal positional encoding for the 64 board squares.

    The sin/cos table is computed once at construction and registered as
    a buffer, so it moves with the module across devices but is never
    trained. The forward pass simply broadcasts it over the batch.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        positions = torch.arange(64).unsqueeze(1)
        self.register_buffer('position', positions)

        # Even channels carry sin, odd channels carry cos, using the
        # standard 10000^(2i/d_model) frequency schedule.
        freq_idx = torch.arange(0, d_model, step=2).float()
        divisor = 10000 ** (freq_idx / d_model)
        table = torch.zeros(1, 64, d_model)
        table[:, :, 0::2] = torch.sin(positions / divisor)
        table[:, :, 1::2] = torch.cos(positions / divisor)
        self.register_buffer('positional_encoding', table)

    def forward(self, x):
        # The encoding depends only on position; expand it to the batch
        # without copying memory.
        bsz, _, _ = x.size()
        return self.positional_encoding.expand(bsz, -1, -1)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class LearnedPositionalEncoder(nn.Module):
    """Trainable positional embedding: one learned vector per position.

    Unlike the sinusoidal encoder, rows of the embedding table are free
    parameters updated during training.
    """

    def __init__(self, d_model=1929, max_len=64):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # One learned row per position index in [0, max_len).
        self.positional_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        # Look up each position once, then broadcast across the batch.
        idx = torch.arange(seq_len, device=x.device).unsqueeze(0)
        looked_up = self.positional_embedding(idx)
        return looked_up.expand(bsz, -1, -1)
|
| 144 |
|
| 145 |
|
| 146 |
class ValueHead(nn.Module):
|
| 147 |
def __init__(self, d_model):
|
| 148 |
super().__init__()
|
| 149 |
self.dense1 = nn.Linear(d_model, 128)
|
| 150 |
+
self.dense2 = nn.Linear(128 * 64, 128)
|
| 151 |
self.dense3 = nn.Linear(128, 3)
|
| 152 |
|
| 153 |
def forward(self, x):
|
|
|
|
| 159 |
x = F.gelu(x)
|
| 160 |
x = self.dense3(x)
|
| 161 |
return x
|
| 162 |
+
|
| 163 |
|
| 164 |
class ValueHeadQ(nn.Module):
|
| 165 |
def __init__(self, d_model):
|
| 166 |
super().__init__()
|
| 167 |
self.dense1 = nn.Linear(d_model, 128)
|
| 168 |
+
self.dense2 = nn.Linear(128 * 64, 128)
|
| 169 |
self.dense3 = nn.Linear(128, 3)
|
| 170 |
|
| 171 |
def forward(self, x):
|
|
|
|
| 179 |
return x
|
| 180 |
|
| 181 |
|
| 182 |
+
class ChessBotModel(PreTrainedModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
config_class = ChessBotConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
def __init__(self, config):
|
| 186 |
super().__init__(config)
|
| 187 |
self.config = config
|
|
|
|
|
|
|
| 188 |
self.is_thinking_model = False
|
| 189 |
self.d_model = config.d_model
|
| 190 |
self.num_layers = config.num_layers
|
| 191 |
+
|
|
|
|
| 192 |
self.layers = nn.ModuleList([
|
| 193 |
EncoderLayer(config.d_model, config.d_ff, config.num_heads)
|
| 194 |
for _ in range(config.num_layers)
|
| 195 |
])
|
| 196 |
+
|
| 197 |
self.linear1 = nn.Linear(19, config.d_model)
|
| 198 |
self.layernorm1 = nn.LayerNorm(config.d_model)
|
| 199 |
self.policy_tokens_lin = nn.Linear(config.d_model, config.d_model)
|
|
|
|
| 201 |
self.keys_pol = nn.Linear(config.d_model, config.d_model)
|
| 202 |
self.positional = AbsolutePositionalEncoder(config.d_model)
|
| 203 |
self.ma_gating = MaGating(config.d_model)
|
| 204 |
+
self.policy_head = nn.Linear(64 * 64, 1929, bias=False)
|
| 205 |
self.value_head = ValueHead(config.d_model)
|
| 206 |
self.value_head_q = ValueHeadQ(config.d_model)
|
| 207 |
+
|
| 208 |
+
def forward(
|
| 209 |
+
self,
|
| 210 |
+
input_ids=None,
|
| 211 |
+
inputs_embeds=None,
|
| 212 |
+
compute_loss=False,
|
| 213 |
+
step=None,
|
| 214 |
+
**kwargs
|
| 215 |
+
):
|
| 216 |
+
# Handle both old-style input format and new HF format
|
| 217 |
+
if input_ids is not None:
|
| 218 |
inp = input_ids
|
| 219 |
+
elif inputs_embeds is not None:
|
| 220 |
+
inp = inputs_embeds
|
| 221 |
+
else:
|
| 222 |
+
raise ValueError("Either input_ids or inputs_embeds must be provided")
|
| 223 |
+
|
| 224 |
+
if isinstance(inp, (list, tuple)):
|
| 225 |
x = inp[0]
|
|
|
|
| 226 |
else:
|
| 227 |
+
x = inp
|
| 228 |
+
|
|
|
|
| 229 |
b, seq_len, _, _, emb = x.size()
|
| 230 |
x = x.view(b * seq_len, 64, emb)
|
|
|
|
| 231 |
x = self.linear1(x)
|
| 232 |
x = F.gelu(x)
|
| 233 |
x = self.layernorm1(x)
|
|
|
|
| 236 |
pos_enc = self.positional(x)
|
| 237 |
for i in range(self.num_layers):
|
| 238 |
x = self.layers[i](x, pos_enc)
|
| 239 |
+
|
| 240 |
value_h = self.value_head(x)
|
| 241 |
value_h = value_h.view(b, seq_len, 3)
|
| 242 |
value_h_q = self.value_head_q(x)
|
|
|
|
| 245 |
policy_tokens = self.policy_tokens_lin(x)
|
| 246 |
policy_tokens = F.gelu(policy_tokens)
|
| 247 |
policy_tokens = policy_tokens + pos_enc
|
|
|
|
| 248 |
queries = self.queries_pol(policy_tokens)
|
| 249 |
keys = self.keys_pol(policy_tokens)
|
| 250 |
+
|
| 251 |
matmul_qk = torch.matmul(queries, torch.transpose(keys, -2, -1))
|
| 252 |
+
dk = torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32, device=x.device))
|
|
|
|
| 253 |
policy_attn_logits = matmul_qk / dk
|
| 254 |
+
policy_attn_logits = policy_attn_logits.view(b, seq_len, 64 * 64)
|
|
|
|
| 255 |
policy = self.policy_head(policy_attn_logits)
|
| 256 |
+
|
| 257 |
+
if compute_loss and isinstance(inp, (list, tuple)) and len(inp) > 1:
|
| 258 |
targets = inp[1]
|
| 259 |
+
true_values = inp[3] if len(inp) > 3 else None
|
| 260 |
+
q_values = inp[4] if len(inp) > 4 else None
|
|
|
|
| 261 |
|
| 262 |
+
if true_values is not None and q_values is not None:
|
| 263 |
+
true_values = q_values
|
| 264 |
+
z = torch.argmax(true_values, dim=-1)
|
| 265 |
+
q = torch.argmax(q_values, dim=-1)
|
| 266 |
+
value_h_q_softmax = torch.softmax(value_h_q, dim=-1)
|
| 267 |
+
|
| 268 |
+
loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
+
valid_mask = (true_values.sum(dim=-1) != 0) & (q_values.sum(dim=-1) != 0)
|
| 271 |
+
|
| 272 |
+
if valid_mask.any():
|
| 273 |
+
valid_value_h = value_h[valid_mask]
|
| 274 |
+
valid_value_h_q = value_h_q_softmax[valid_mask]
|
| 275 |
+
valid_z = z[valid_mask]
|
| 276 |
+
valid_q_values = q_values[valid_mask]
|
| 277 |
+
|
| 278 |
+
loss_value = F.cross_entropy(valid_value_h.view(-1, valid_value_h.size(-1)), valid_z.view(-1))
|
| 279 |
+
loss_q = F.mse_loss(valid_value_h_q.view(-1, valid_value_h_q.size(-1)), valid_q_values.view(-1, 3))
|
| 280 |
+
else:
|
| 281 |
+
loss_value = torch.tensor(0.0, device=value_h.device, requires_grad=True)
|
| 282 |
+
loss_q = torch.tensor(0.0, device=value_h_q.device, requires_grad=True)
|
| 283 |
+
|
| 284 |
+
return policy, value_h, value_h_q, loss_policy, loss_value, loss_q, targets, z, q
|
| 285 |
else:
|
| 286 |
+
loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
|
| 287 |
+
return policy, value_h, value_h_q, loss_policy
|
| 288 |
+
|
|
|
|
|
|
|
|
|
|
| 289 |
return BaseModelOutput(
|
| 290 |
+
last_hidden_state=policy,
|
| 291 |
+
hidden_states=(value_h, value_h_q)
|
| 292 |
+
)
|
|
|
|
| 293 |
|
| 294 |
def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
|
| 295 |
+
board = chess.Board()
|
| 296 |
+
board.set_fen(fen)
|
| 297 |
x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
|
| 298 |
x = x.view(1, 1, 8, 8, 19)
|
| 299 |
|
| 300 |
+
output = self(x)
|
| 301 |
+
if hasattr(output, 'last_hidden_state'):
|
| 302 |
+
logits = output.last_hidden_state
|
| 303 |
+
else:
|
| 304 |
+
logits = output
|
| 305 |
+
|
| 306 |
logits = logits.view(-1, 1929)
|
| 307 |
legal_move_mask = torch.zeros((1, 1929), device=device)
|
| 308 |
for legal_move in board.legal_moves:
|
|
|
|
| 310 |
legal_move_uci = legal_move.uci()[:-1]
|
| 311 |
else:
|
| 312 |
legal_move_uci = legal_move.uci()
|
| 313 |
+
legal_move_mask[0][policy_index.index(legal_move_uci)] = 1
|
| 314 |
+
|
|
|
|
|
|
|
| 315 |
if force_legal:
|
| 316 |
+
logits = logits + (1 - legal_move_mask) * -999
|
| 317 |
|
| 318 |
if T == 0:
|
| 319 |
sampled = torch.argmax(logits, dim=-1, keepdim=True)
|
| 320 |
else:
|
| 321 |
+
probs = F.softmax(logits / T, dim=-1)
|
| 322 |
sampled = torch.multinomial(probs, num_samples=1)
|
| 323 |
if return_probs:
|
| 324 |
+
# Map to legal moves
|
| 325 |
+
legal_move_probs = {}
|
| 326 |
+
for move in board.legal_moves:
|
| 327 |
+
idx = policy_index.index(move.uci())
|
| 328 |
+
legal_move_probs[move.uci()] = probs[0,idx].item()
|
| 329 |
+
|
| 330 |
+
return legal_move_probs
|
| 331 |
+
|
| 332 |
move = policy_index[sampled.item()]
|
| 333 |
return move
|
| 334 |
|
|
|
|
| 352 |
value_logits = self.value_head_q(x_processed)
|
| 353 |
value_logits = value_logits.view(b, seq_len, 3)
|
| 354 |
value = torch.softmax(value_logits, dim=-1)
|
| 355 |
+
|
| 356 |
return value.squeeze()
|
| 357 |
|
| 358 |
def get_batch_position_values(self, fens, device="cuda"):
|
|
|
|
| 383 |
value_logits = self.value_head_q(x_processed)
|
| 384 |
value_logits = value_logits.view(b, seq_len, 3)
|
| 385 |
value_logits = torch.softmax(value_logits, dim=-1)
|
| 386 |
+
|
| 387 |
return value_logits.squeeze(1)
|
| 388 |
|
| 389 |
def calculate_move_values(self, fen, device="cuda"):
|
|
|
|
| 392 |
board.set_fen(fen)
|
| 393 |
|
| 394 |
is_white_turn = board.turn == chess.WHITE
|
|
|
|
| 395 |
legal_moves = list(board.legal_moves)
|
| 396 |
if len(legal_moves) == 0:
|
| 397 |
return [], torch.empty(0, device=device)
|
|
|
|
| 403 |
board.pop()
|
| 404 |
|
| 405 |
batch_value_q = self.get_batch_position_values(resulting_fens, device)
|
| 406 |
+
batch_value_q = batch_value_q[:, 2] - batch_value_q[:, 0]
|
| 407 |
|
|
|
|
|
|
|
| 408 |
if is_white_turn:
|
| 409 |
player_values = batch_value_q
|
| 410 |
else:
|
|
|
|
| 414 |
|
| 415 |
def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False, to_fall_back_to_policy=False):
|
| 416 |
"""Determine the best move based on the value of resulting positions using efficient batching."""
|
| 417 |
+
value = self.get_position_value(fen, device)
|
| 418 |
+
board = chess.Board()
|
| 419 |
+
board.set_fen(fen)
|
| 420 |
+
|
| 421 |
+
is_white_turn = board.turn == chess.WHITE
|
| 422 |
+
if is_white_turn:
|
| 423 |
+
value = value[2] - value[0]
|
| 424 |
+
else:
|
| 425 |
+
value = value[0] - value[2]
|
|
|
|
|
|
|
| 426 |
|
| 427 |
+
if value > 0.9 and to_fall_back_to_policy:
|
| 428 |
+
self.fall_back_to_policy = True
|
| 429 |
+
if to_fall_back_to_policy and hasattr(self, 'fall_back_to_policy') and self.fall_back_to_policy:
|
| 430 |
+
return self.get_move_from_fen_no_thinking(fen, T, device, force_legal=True, return_probs=return_probs)
|
| 431 |
|
| 432 |
legal_moves, move_values = self.calculate_move_values(fen, device)
|
| 433 |
|
|
|
|
| 450 |
probs[best_idx] = 1.0
|
| 451 |
else:
|
| 452 |
probs = F.softmax(move_values / T, dim=0)
|
| 453 |
+
|
| 454 |
+
return probs.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
| 455 |
|
| 456 |
return move_uci
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vocab.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Vocabulary for the chess policy head.
#
# The vocabulary consists of:
#   * all "queen-move or knight-move" (from, to) square pairs in UCI
#     notation (1792 entries); no "n"-suffixed promotions appear, so
#     knight promotions presumably map onto the plain pawn move
#     (Leela-style policy layout) -- TODO confirm against the encoder;
#   * under-promotions with q/r/b suffixes for white (66) and black (66);
#   * special tokens: "<thinking>", "</thinking>", "end_variation",
#     "end" and "padding_token" (5 entries).
#
# Total: 1929 tokens.  Token id == position in this list.

policy_index = [
    "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
    "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
    "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1",
    "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3",
    "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7",
    "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1",
    "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3",
    "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
    "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2",
    "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4",
    "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1",
    "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2",
    "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4",
    "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1",
    "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3",
    "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
    "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1",
    "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3",
    "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8",
    "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2",
    "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6",
    "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2",
    "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3",
    "a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
    "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2",
    "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4",
    "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7",
    "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2",
    "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3",
    "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6",
    "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1",
    "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
    "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5",
    "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1",
    "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2",
    "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4",
    "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1",
    "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2",
    "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4",
    "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
    "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2",
    "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4",
    "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8",
    "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2",
    "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5",
    "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1",
    "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3",
    "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
    "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1",
    "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3",
    "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5",
    "b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1",
    "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3",
    "c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4",
    "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6",
    "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
    "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3",
    "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5",
    "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7",
    "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2",
    "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3",
    "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5",
    "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1",
    "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
    "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4",
    "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6",
    "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2",
    "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3",
    "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5",
    "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1",
    "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3",
    "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
    "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2",
    "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4",
    "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7",
    "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3",
    "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4",
    "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6",
    "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2",
    "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
    "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5",
    "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8",
    "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2",
    "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4",
    "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6",
    "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8",
    "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3",
    "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
    "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6",
    "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1",
    "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3",
    "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4",
    "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6",
    "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2",
    "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4",
    "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
    "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1",
    "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4",
    "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6",
    "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2",
    "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5",
    "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7",
    "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3",
    "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
    "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7",
    "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2",
    "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4",
    "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6",
    "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7",
    "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3",
    "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5",
    "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
    "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8",
    "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3",
    "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5",
    "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6",
    "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8",
    "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3",
    "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5",
    "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
    "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2",
    "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4",
    "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6",
    "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1",
    "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4",
    "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6",
    "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2",
    "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
    "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7",
    "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3",
    "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5",
    "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7",
    "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2",
    "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5",
    "c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6",
    "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
    "c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3",
    "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5",
    "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7",
    "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8",
    "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4",
    "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6",
    "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7",
    "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
    "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5",
    "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6",
    "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8",
    "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3",
    "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6",
    "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7",
    "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2",
    "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
    "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7",
    "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3",
    "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7",
    "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8",
    "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5",
    "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7",
    "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8",
    "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
    "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7",
    "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8",
    "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4",
    "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6",
    "d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8",
    "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4",
    "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6",
    "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
    "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2",
    "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5",
    "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7",
    "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1",
    "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5",
    "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7",
    "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1",
    "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
    "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7",
    "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2",
    "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6",
    "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8",
    "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5",
    "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7",
    "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2",
    "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
    "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8",
    "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4",
    "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6",
    "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8",
    "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5",
    "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7",
    "e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8",
    "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
    "f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7",
    "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1",
    "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6",
    "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8",
    "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2",
    "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6",
    "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8",
    # White under-promotions (q/r/b suffixes), 66 entries.
    "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q",
    "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b",
    "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r",
    "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q",
    "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b",
    "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r",
    "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q",
    "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b",
    "h7h8q", "h7h8r", "h7h8b",
    # Black under-promotions (q/r/b suffixes), 66 entries.
    "a2a1q", "a2a1r", "a2a1b", "a2b1q", "a2b1r", "a2b1b",
    "b2a1q", "b2a1r", "b2a1b", "b2b1q", "b2b1r", "b2b1b", "b2c1q", "b2c1r", "b2c1b",
    "c2b1q", "c2b1r", "c2b1b", "c2c1q", "c2c1r", "c2c1b", "c2d1q", "c2d1r", "c2d1b",
    "d2c1q", "d2c1r", "d2c1b", "d2d1q", "d2d1r", "d2d1b", "d2e1q", "d2e1r", "d2e1b",
    "e2d1q", "e2d1r", "e2d1b", "e2e1q", "e2e1r", "e2e1b", "e2f1q", "e2f1r", "e2f1b",
    "f2e1q", "f2e1r", "f2e1b", "f2f1q", "f2f1r", "f2f1b", "f2g1q", "f2g1r", "f2g1b",
    "g2f1q", "g2f1r", "g2f1b", "g2g1q", "g2g1r", "g2g1b", "g2h1q", "g2h1r", "g2h1b",
    "h2g1q", "h2g1r", "h2g1b", "h2h1q", "h2h1r", "h2h1b",
    # Special tokens.
    "<thinking>", "</thinking>", "end_variation", "end", "padding_token",
]


if __name__ == "__main__":
    # Debug helper only: the original printed unconditionally at import
    # time, which is a module-level side effect; now it runs only when
    # this file is executed directly.
    print("Number of unique tokens: ", len(policy_index))
|