import math
from typing import Union, List, Optional

import bulletchess
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel

from .attn_map import apm_map, apm_out
from .encoding_simple import encode_fen_to_tensor, encode_moves_to_tensor
from .vocab import policy_index
|
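# Learned per-element gate over the (64, embedding_size) square embeddings.
# BT4 applies a multiplicative gate followed by an additive gate ("ma-gating")
# directly after the input embedding (see gating_mult / gating_add in BT4).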
class Gating(nn.Module): |
|
|
def __init__(self, features_shape, additive=True, init_value=None): |
|
|
super(Gating, self).__init__() |
|
|
self.additive = additive |
|
|
if init_value is None: |
|
|
init_value = 0 if self.additive else 1 |
|
|
|
|
|
self.gate = nn.Parameter(torch.full(features_shape, float(init_value))) |
|
|
if not self.additive: |
|
|
self.gate.register_hook(lambda grad: torch.clamp(grad, min=0)) |
|
|
|
|
|
def forward(self, x): |
|
|
if self.additive: |
|
|
return x + self.gate |
|
|
else: |
|
|
return x * self.gate |
|
|
|
|
|
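# Functional form of multiplicative-then-additive gating. Note that this helper
# builds fresh Gating modules on every call, so their parameters are neither
# registered with the caller nor trained; BT4 itself uses its gating_mult and
# gating_add submodules instead.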
def ma_gating(x, in_features): |
|
|
x = Gating(in_features, additive=False)(x) |
|
|
x = Gating(in_features, additive=True)(x) |
|
|
return x |
|
|
|
|
|
class RMSNorm(nn.Module): |
|
|
def __init__(self, in_features, scale=True): |
|
|
super(RMSNorm, self).__init__() |
|
|
self.scale = scale |
|
|
if self.scale: |
|
|
self.gamma = nn.Parameter(torch.ones(in_features)) |
|
|
|
|
|
def forward(self, x): |
|
|
rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-5) |
|
|
x_normalized = x / rms |
|
|
if self.scale: |
|
|
return x_normalized * self.gamma |
|
|
return x_normalized |
|
|
|
|
|
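# Maps the 64x64 from/to-square attention logits plus the 8x24 promotion logits
# onto the flat policy vocabulary (the order of policy_index) via a precomputed
# gather index. The apm_map buffer ('fc1') is registered but forward() only uses
# the index-based gather.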
class ApplyAttentionPolicyMap(nn.Module): |
|
|
def __init__(self): |
|
|
super(ApplyAttentionPolicyMap, self).__init__() |
|
|
|
|
|
|
|
|
self.register_buffer('fc1', torch.from_numpy(apm_map).float()) |
|
|
self.register_buffer('idx', torch.from_numpy(apm_out).long()) |
|
|
|
|
|
def forward(self, logits, pp_logits): |
|
|
logits = torch.cat([logits.reshape(-1, 64 * 64), |
|
|
pp_logits.reshape(-1, 8 * 24)], |
|
|
dim=1) |
|
|
|
|
|
batch_size = logits.size(0) |
|
|
idx = self.idx.unsqueeze(0).expand(batch_size, -1) |
|
|
|
|
|
return torch.gather(logits, 1, idx) |
|
|
|
|
|
class Mish(nn.Module): |
|
|
def __init__(self): |
|
|
super(Mish, self).__init__() |
|
|
|
|
|
def forward(self, x): |
|
|
return x * torch.tanh(F.softplus(x)) |
|
|
|
|
|
class CustomMHA(nn.Module): |
|
|
def __init__(self, emb_size, d_model, num_heads, dropout=0.0, use_bias_qkv=True, use_bias_out=True, |
|
|
use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'): |
|
|
super(CustomMHA, self).__init__() |
|
|
assert d_model % num_heads == 0 |
|
|
self.emb_size = emb_size |
|
|
self.d_model = d_model |
|
|
self.num_heads = num_heads |
|
|
self.head_dim = d_model // num_heads |
|
|
self.wq = nn.Linear(emb_size, d_model, bias=use_bias_qkv) |
|
|
self.wk = nn.Linear(emb_size, d_model, bias=use_bias_qkv) |
|
|
self.wv = nn.Linear(emb_size, d_model, bias=use_bias_qkv) |
|
|
self.out_proj = nn.Linear(d_model, emb_size, bias=use_bias_out) |
|
|
self.attn_dropout = nn.Dropout(dropout) |
|
|
|
|
|
self.smol_compress = None |
|
|
self.smol_hidden1 = None |
|
|
self.smol_hidden1_ln = None |
|
|
self.smol_gen_from = None |
|
|
self.smol_gen_from_ln = None |
|
|
self.smol_weight_gen = None |
|
|
if use_smolgen: |
|
|
self.smol_compress = nn.Linear(emb_size, smol_hidden_channels, bias=False) |
|
|
self.smol_hidden1 = nn.Linear(64 * smol_hidden_channels, smol_hidden_sz, bias=True) |
|
|
self.smol_hidden1_ln = nn.LayerNorm(smol_hidden_sz, eps=1e-3) |
|
|
self.smol_gen_from = nn.Linear(smol_hidden_sz, num_heads * smol_gen_sz, bias=True) |
|
|
self.smol_gen_from_ln = nn.LayerNorm(num_heads * smol_gen_sz, eps=1e-3) |
|
|
self.smol_weight_gen = nn.Linear(smol_gen_sz, 64 * 64, bias=False) |
|
|
self.smol_activation = smol_activation |
|
|
|
|
|
def _shape(self, x): |
|
|
b, l, _ = x.shape |
|
|
return x.view(b, l, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
def forward(self, x, return_attn=False): |
|
|
|
|
|
q = self.wq(x) |
|
|
k = self.wk(x) |
|
|
v = self.wv(x) |
|
|
q = self._shape(q) |
|
|
k = self._shape(k) |
|
|
v = self._shape(v) |
|
|
scale = torch.sqrt(torch.tensor(self.head_dim, dtype=x.dtype, device=x.device)) |
|
|
attn_logits = torch.matmul(q, k.transpose(-2, -1)) / scale |
|
|
|
|
|
smol_w = None |
|
|
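        # Smolgen: derive per-head 64x64 additive attention biases from a compressed
        # view of the whole board and add them to the scaled dot-product logits.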
if self.smol_compress is not None: |
|
|
b, l, _ = x.shape |
|
|
compressed = self.smol_compress(x) |
|
|
compressed = compressed.reshape(b, l * compressed.shape[-1]) |
|
|
hidden_pre = self.smol_hidden1(compressed) |
|
|
            hidden = F.silu(hidden_pre)  # only the 'swish'/SiLU activation is implemented
|
|
hidden_ln = self.smol_hidden1_ln(hidden) |
|
|
gen_from_pre = self.smol_gen_from(hidden_ln) |
|
|
            gen_from_act = F.silu(gen_from_pre)
|
|
gen_from = self.smol_gen_from_ln(gen_from_act) |
|
|
gen_from = gen_from.view(b, self.num_heads, -1) |
|
|
smol_w = self.smol_weight_gen(gen_from) |
|
|
smol_w = smol_w.view(b, self.num_heads, l, l) |
|
|
attn_logits = attn_logits + smol_w |
|
|
|
|
|
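        # Numerically stable softmax over the key dimension (explicit max-subtraction).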
attn_logits = attn_logits - attn_logits.max(dim=-1, keepdim=True)[0] |
|
|
attn_weights = torch.exp(attn_logits) |
|
|
attn_weights = attn_weights / attn_weights.sum(dim=-1, keepdim=True) |
|
|
attn_weights = self.attn_dropout(attn_weights) |
|
|
attn_output = torch.matmul(attn_weights, v) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.d_model) |
|
|
out = self.out_proj(attn_output) |
|
|
if return_attn: |
|
|
return out, attn_weights, smol_w, attn_logits |
|
|
return out |
|
|
|
|
|
class FFN(nn.Module): |
|
|
def __init__(self, emb_size, dff, activation=Mish(), omit_other_biases=False): |
|
|
super(FFN, self).__init__() |
|
|
self.dense1 = nn.Linear(emb_size, dff, bias=not omit_other_biases) |
|
|
self.activation = activation |
|
|
self.dense2 = nn.Linear(dff, emb_size, bias=not omit_other_biases) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.dense1(x) |
|
|
x = self.activation(x) |
|
|
x = self.dense2(x) |
|
|
return x |
|
|
|
|
|
class EncoderLayer(nn.Module): |
|
|
def __init__(self, emb_size, d_model, num_heads, dff, dropout_rate, encoder_layers, skip_first_ln=False, encoder_rms_norm=False, omit_qkv_biases=False, omit_other_biases=False, |
|
|
use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'): |
|
|
super(EncoderLayer, self).__init__() |
|
|
self.mha = CustomMHA(emb_size, d_model, num_heads, dropout=dropout_rate, use_bias_qkv=not omit_qkv_biases, use_bias_out=not omit_other_biases, |
|
|
use_smolgen=use_smolgen, smol_hidden_channels=smol_hidden_channels, smol_hidden_sz=smol_hidden_sz, smol_gen_sz=smol_gen_sz, smol_activation=smol_activation) |
|
|
self.ffn = FFN(emb_size, dff, omit_other_biases=omit_other_biases) |
|
|
|
|
|
self.norm1 = RMSNorm(emb_size) if encoder_rms_norm else nn.LayerNorm(emb_size, eps=0.001) |
|
|
self.norm2 = RMSNorm(emb_size) if encoder_rms_norm else nn.LayerNorm(emb_size, eps=0.001) |
|
|
|
|
|
self.dropout1 = nn.Dropout(dropout_rate) |
|
|
self.dropout2 = nn.Dropout(dropout_rate) |
|
|
|
|
|
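        # Residual-branch scaling factor (2 * num_layers) ** -0.25, applied to the
        # attention and FFN outputs before their residual additions.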
self.alpha = (2. * encoder_layers)**-0.25 |
|
|
self.skip_first_ln = skip_first_ln |
|
|
|
|
|
def forward(self, x): |
|
|
attn_output = self.mha(x) |
|
|
attn_output = self.dropout1(attn_output) |
|
|
|
|
|
out1 = x + attn_output * self.alpha |
|
|
if not self.skip_first_ln: |
|
|
out1 = self.norm1(out1) |
|
|
ffn_output = self.ffn(out1) |
|
|
ffn_output = self.dropout2(ffn_output) |
|
|
|
|
|
out2 = self.norm2(out1 + ffn_output * self.alpha) |
|
|
return out2 |
|
|
|
|
|
class PolicyHead(nn.Module): |
|
|
def __init__(self, pol_embedding_size, policy_d_model, opponent=False): |
|
|
super(PolicyHead, self).__init__() |
|
|
self.opponent = opponent |
|
|
self.wq = nn.Linear(pol_embedding_size, policy_d_model) |
|
|
self.wk = nn.Linear(pol_embedding_size, policy_d_model) |
|
|
self.ppo = nn.Linear(policy_d_model, 4, bias=False) |
|
|
self.apply_map = ApplyAttentionPolicyMap() |
|
|
|
|
|
def forward(self, x): |
|
|
if self.opponent: |
|
|
x = torch.flip(x, [1]) |
|
|
|
|
|
queries = self.wq(x) |
|
|
keys = self.wk(x) |
|
|
|
|
|
matmul_qk = torch.matmul(queries, keys.transpose(-2, -1)) |
|
|
|
|
|
dk = torch.sqrt(torch.tensor(keys.shape[-1], dtype=keys.dtype, device=keys.device)) |
|
|
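        # Promotion handling: offsets for queen/rook/bishop promotions are derived from
        # the keys of the last 8 tokens (the promotion rank); the fourth ppo channel is
        # a shared offset added to the other three.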
promotion_keys = keys[:, -8:, :] |
|
|
promotion_offsets = self.ppo(promotion_keys).transpose(-2,-1) * dk |
|
|
promotion_offsets = promotion_offsets[:, :3, :] + promotion_offsets[:, 3:4, :] |
|
|
|
|
|
n_promo_logits = matmul_qk[:, -16:-8, -8:] |
|
|
q_promo_logits = (n_promo_logits + promotion_offsets[:, 0:1, :]).unsqueeze(3) |
|
|
r_promo_logits = (n_promo_logits + promotion_offsets[:, 1:2, :]).unsqueeze(3) |
|
|
b_promo_logits = (n_promo_logits + promotion_offsets[:, 2:3, :]).unsqueeze(3) |
|
|
        promotion_logits = torch.cat([q_promo_logits, r_promo_logits, b_promo_logits], dim=3).reshape(-1, 8, 24)
|
|
|
|
|
promotion_logits = promotion_logits / dk |
|
|
policy_attn_logits = matmul_qk / dk |
|
|
|
|
|
return self.apply_map(policy_attn_logits, promotion_logits) |
|
|
|
|
|
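# Value head: embed each square, flatten over the 64 squares, and map through a
# small MLP to 3 outcome logits.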
class ValueHead(nn.Module): |
|
|
def __init__(self, embedding_size, val_embedding_size, default_activation=Mish()): |
|
|
super(ValueHead, self).__init__() |
|
|
self.embedding = nn.Linear(embedding_size, val_embedding_size) |
|
|
self.activation = default_activation |
|
|
self.flatten = nn.Flatten() |
|
|
self.dense1 = nn.Linear(val_embedding_size * 64, 128) |
|
|
self.dense2 = nn.Linear(128, 3) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.embedding(x) |
|
|
x = self.activation(x) |
|
|
x = self.flatten(x) |
|
|
x = self.dense1(x) |
|
|
x = self.activation(x) |
|
|
x = self.dense2(x) |
|
|
return x |
|
|
|
|
|
class BT4Config(PretrainedConfig): |
|
|
"""Configuration class for BT4 model.""" |
|
|
model_type = "bt4" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
embedding_size=1024, |
|
|
embedding_dense_sz=512, |
|
|
encoder_layers=15, |
|
|
encoder_d_model=1024, |
|
|
encoder_heads=32, |
|
|
encoder_dff=1536, |
|
|
dropout_rate=0.0, |
|
|
pol_embedding_size=1024, |
|
|
policy_d_model=1024, |
|
|
val_embedding_size=128, |
|
|
use_smolgen=True, |
|
|
smol_hidden_channels=32, |
|
|
smol_hidden_sz=256, |
|
|
smol_gen_sz=256, |
|
|
smol_activation="swish", |
|
|
**kwargs |
|
|
): |
|
|
super().__init__(**kwargs) |
|
|
self.embedding_size = embedding_size |
|
|
self.embedding_dense_sz = embedding_dense_sz |
|
|
self.encoder_layers = encoder_layers |
|
|
self.encoder_d_model = encoder_d_model |
|
|
self.encoder_heads = encoder_heads |
|
|
self.encoder_dff = encoder_dff |
|
|
self.dropout_rate = dropout_rate |
|
|
self.pol_embedding_size = pol_embedding_size |
|
|
self.policy_d_model = policy_d_model |
|
|
self.val_embedding_size = val_embedding_size |
|
|
self.use_smolgen = use_smolgen |
|
|
self.smol_hidden_channels = smol_hidden_channels |
|
|
self.smol_hidden_sz = smol_hidden_sz |
|
|
self.smol_gen_sz = smol_gen_sz |
|
|
self.smol_activation = smol_activation |
|
|
|
|
|
class BT4(PreTrainedModel): |
|
|
config_class = BT4Config |
|
|
|
|
|
def __init__(self, config=None, embedding_size=1024, embedding_dense_sz=512, encoder_layers=15, encoder_d_model=1024, encoder_heads=32, encoder_dff=1536, dropout_rate=0.0, pol_embedding_size=1024, policy_d_model=1024, val_embedding_size=128, default_activation=Mish(), |
|
|
use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'): |
|
|
|
|
|
if config is None: |
|
|
config = BT4Config( |
|
|
embedding_size=embedding_size, |
|
|
embedding_dense_sz=embedding_dense_sz, |
|
|
encoder_layers=encoder_layers, |
|
|
encoder_d_model=encoder_d_model, |
|
|
encoder_heads=encoder_heads, |
|
|
encoder_dff=encoder_dff, |
|
|
dropout_rate=dropout_rate, |
|
|
pol_embedding_size=pol_embedding_size, |
|
|
policy_d_model=policy_d_model, |
|
|
val_embedding_size=val_embedding_size, |
|
|
use_smolgen=use_smolgen, |
|
|
smol_hidden_channels=smol_hidden_channels, |
|
|
smol_hidden_sz=smol_hidden_sz, |
|
|
smol_gen_sz=smol_gen_sz, |
|
|
smol_activation=smol_activation, |
|
|
) |
|
|
super(BT4, self).__init__(config) |
|
|
|
|
|
|
|
|
embedding_size = config.embedding_size |
|
|
embedding_dense_sz = config.embedding_dense_sz |
|
|
encoder_layers = config.encoder_layers |
|
|
encoder_d_model = config.encoder_d_model |
|
|
encoder_heads = config.encoder_heads |
|
|
encoder_dff = config.encoder_dff |
|
|
dropout_rate = config.dropout_rate |
|
|
pol_embedding_size = config.pol_embedding_size |
|
|
policy_d_model = config.policy_d_model |
|
|
val_embedding_size = config.val_embedding_size |
|
|
use_smolgen = config.use_smolgen |
|
|
smol_hidden_channels = config.smol_hidden_channels |
|
|
smol_hidden_sz = config.smol_hidden_sz |
|
|
smol_gen_sz = config.smol_gen_sz |
|
|
smol_activation = config.smol_activation |
|
|
self.embedding_dense_sz = embedding_dense_sz |
|
|
|
|
|
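        # Residual scaling factor; matches the per-layer alpha used inside EncoderLayer.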
self.deepnorm_alpha = (2. * encoder_layers) ** -0.25 |
|
|
|
|
|
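        # Dense preprocessing of the first 12 input planes, concatenated with the raw
        # 112 per-square features before the main embedding.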
self.embedding_preprocess = nn.Linear(64*12, 64*self.embedding_dense_sz) |
|
|
self.embedding = nn.Linear(112 + self.embedding_dense_sz, embedding_size) |
|
|
nn.init.xavier_uniform_(self.embedding.weight) |
|
|
nn.init.zeros_(self.embedding.bias) |
|
|
|
|
|
self.embedding_ln = nn.LayerNorm(embedding_size, eps=0.001) |
|
|
|
|
|
self.gating_mult = Gating((64, embedding_size), additive=False) |
|
|
self.gating_add = Gating((64, embedding_size), additive=True) |
|
|
|
|
|
self.embedding_ffn = FFN(embedding_size, encoder_dff) |
|
|
self.embedding_ffn_ln = nn.LayerNorm(embedding_size, eps=0.001) |
|
|
|
|
|
self.encoder_layers_list = nn.ModuleList([ |
|
|
EncoderLayer(embedding_size, encoder_d_model, encoder_heads, encoder_dff, dropout_rate, encoder_layers, |
|
|
use_smolgen=use_smolgen, smol_hidden_channels=smol_hidden_channels, smol_hidden_sz=smol_hidden_sz, smol_gen_sz=smol_gen_sz, smol_activation=smol_activation) |
|
|
for _ in range(encoder_layers) |
|
|
]) |
|
|
|
|
|
self.policy_embedding = nn.Linear(embedding_size, pol_embedding_size) |
|
|
self.policy_head = PolicyHead(pol_embedding_size, policy_d_model) |
|
|
self.value_head_winner = ValueHead(embedding_size, val_embedding_size) |
|
|
self.value_head_q = ValueHead(embedding_size, val_embedding_size) |
|
|
self.activation = default_activation |
|
|
|
|
|
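        # Note: this re-initializes every nn.Linear in the model, including
        # self.embedding, whose Xavier-uniform initialization above is overwritten.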
self.apply(self._init_weights) |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |
|
|
"""Load model from pretrained checkpoint (required by transformers).""" |
|
|
from transformers import AutoConfig |
|
|
from huggingface_hub import hf_hub_download |
|
|
from safetensors.torch import load_file |
|
|
import os |
|
|
|
|
|
|
|
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) |
|
|
|
|
|
|
|
|
model = cls(config=config) |
|
|
|
|
|
|
|
|
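        # Treat the identifier as a Hub repo id unless it resolves to a local directory.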
is_hf_hub = "/" in pretrained_model_name_or_path and not os.path.isdir(pretrained_model_name_or_path) |
|
|
|
|
|
if is_hf_hub: |
|
|
|
|
|
safetensors_path = hf_hub_download( |
|
|
repo_id=pretrained_model_name_or_path, |
|
|
filename="model.safetensors", |
|
|
cache_dir=kwargs.get("cache_dir", None), |
|
|
token=kwargs.get("token", None), |
|
|
) |
|
|
state_dict = load_file(safetensors_path) |
|
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) |
|
|
if missing_keys: |
|
|
print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys") |
|
|
if unexpected_keys: |
|
|
print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys") |
|
|
else: |
|
|
|
|
|
safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors") |
|
|
if os.path.exists(safetensors_path): |
|
|
state_dict = load_file(safetensors_path) |
|
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) |
|
|
if missing_keys: |
|
|
print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys") |
|
|
if unexpected_keys: |
|
|
print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys") |
|
|
else: |
|
|
|
|
|
pt_path = os.path.join(pretrained_model_name_or_path, "model.pt") |
|
|
checkpoint = torch.load(pt_path, map_location="cpu") |
|
|
if isinstance(checkpoint, dict): |
|
|
if "state_dict" in checkpoint: |
|
|
state_dict = checkpoint["state_dict"] |
|
|
elif "model" in checkpoint: |
|
|
state_dict = checkpoint["model"] |
|
|
else: |
|
|
state_dict = checkpoint |
|
|
else: |
|
|
state_dict = checkpoint |
|
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) |
|
|
if missing_keys: |
|
|
print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys") |
|
|
if unexpected_keys: |
|
|
print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys") |
|
|
|
|
|
return model |
|
|
|
|
|
@classmethod |
|
|
def register_for_auto_class(cls, auto_class): |
|
|
"""Register this class for auto class loading (required by transformers).""" |
|
|
|
|
|
pass |
|
|
|
|
|
def _init_weights(self, module): |
|
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
nn.init.xavier_normal_(module.weight) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
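        # Input arrives as (batch, 112, 8, 8) planes; flatten to 64 per-square tokens
        # of 112 features each.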
flow = x.permute(0, 2, 3, 1).reshape(-1, 64, 112) |
|
|
|
|
|
pos_info = flow[..., :12] |
|
|
pos_info_flat = pos_info.reshape(-1, 64 * 12) |
|
|
|
|
|
pos_info_processed = self.embedding_preprocess(pos_info_flat) |
|
|
pos_info = pos_info_processed.reshape(-1, 64, self.embedding_dense_sz) |
|
|
|
|
|
flow = torch.cat([flow, pos_info], dim=-1) |
|
|
|
|
|
flow = self.embedding(flow) |
|
|
|
|
|
flow = self.activation(flow) |
|
|
|
|
|
flow = self.embedding_ln(flow) |
|
|
|
|
|
flow = self.gating_mult(flow) |
|
|
flow = self.gating_add(flow) |
|
|
|
|
|
        ffn_output = self.embedding_ffn(flow)
|
|
|
|
|
residual = flow + ffn_output * self.deepnorm_alpha |
|
|
flow = self.embedding_ffn_ln(residual) |
|
|
|
|
|
        for layer in self.encoder_layers_list:
|
|
flow = layer(flow) |
|
|
|
|
|
policy_tokens = self.policy_embedding(flow) |
|
|
policy_tokens = self.activation(policy_tokens) |
|
|
|
|
|
policy_logits = self.policy_head(policy_tokens) |
|
|
|
|
|
value_winner = self.value_head_winner(flow) |
|
|
value_q = self.value_head_q(flow) |
|
|
|
|
|
return policy_logits, value_winner, value_q |
|
|
|
|
|
    def get_move_from_history(self, fen_or_moves: Union[str, List[str]], T: float, device: Optional[str] = None, **kwargs) -> Union[str, dict]:
|
|
""" |
|
|
Predict a move from a move history or FEN position. |
|
|
|
|
|
Args: |
|
|
fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves |
|
|
T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic) |
|
|
device: Device to run the model on (if None, uses model's device) |
|
|
            return_probs (via **kwargs): If True, return a dictionary mapping UCI moves to probabilities instead of a single move
|
|
|
|
|
Returns: |
|
|
UCI move string (e.g., 'e2e4') or dictionary of move probabilities if return_probs=True |
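
        Example (illustrative; assumes a loaded model instance):
            move = model.get_move_from_history(
                "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", T=0.0
            )  # deterministic (argmax) choice; returns a UCI string such as "e2e4"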
|
|
""" |
|
|
|
|
|
if device is None: |
|
|
device = next(self.parameters()).device |
|
|
else: |
|
|
device = torch.device(device) |
|
|
|
|
|
|
|
|
if isinstance(fen_or_moves, str): |
|
|
|
|
|
fen = fen_or_moves |
|
|
is_black_to_move = fen.split()[1] == 'b' |
|
|
input_tensor_112, legal_moves_mask = encode_fen_to_tensor(fen) |
|
|
castling_rights = fen.split()[2] if len(fen.split()) > 2 else "" |
|
|
elif isinstance(fen_or_moves, list): |
|
|
|
|
|
move_history = fen_or_moves |
|
|
input_tensor_112, legal_moves_mask = encode_moves_to_tensor(move_history) |
|
|
|
|
|
board = bulletchess.Board() |
|
|
for mv in move_history: |
|
|
move = bulletchess.Move.from_uci(mv) |
|
|
board.apply(move) |
|
|
is_black_to_move = (board.turn == bulletchess.BLACK) |
|
|
fen_parts = board.fen().split() |
|
|
castling_rights = fen_parts[2] if len(fen_parts) > 2 else "" |
|
|
else: |
|
|
raise ValueError("Input must be a FEN string or a list of UCI moves") |
|
|
|
|
|
input_tensor_112 = input_tensor_112.to(device, non_blocking=True) |
|
|
|
|
|
self.eval() |
|
|
with torch.inference_mode(): |
|
|
policy_logits,_,_ = self.forward(input_tensor_112) |
|
|
|
|
|
|
|
|
        logits0 = policy_logits[0] + torch.from_numpy(legal_moves_mask).to(policy_logits.device, dtype=policy_logits.dtype)
|
|
|
|
|
|
|
|
return_probs = kwargs.get('return_probs', False) |
|
|
|
|
|
if return_probs: |
|
|
|
|
|
scaled_logits = logits0 / T if T > 0 else logits0 |
|
|
probs = F.softmax(scaled_logits, dim=0) |
|
|
probs_dict = {} |
|
|
for idx, move in enumerate(policy_index): |
|
|
prob_val = probs[idx].item() |
|
|
if prob_val > 1e-6: |
|
|
probs_dict[move] = prob_val |
|
|
return probs_dict |
|
|
|
|
|
if T == 0.0: |
|
|
|
|
|
best_move_idx = torch.argmax(logits0).item() |
|
|
uci_move = policy_index[best_move_idx] |
|
|
else: |
|
|
|
|
|
|
|
|
scaled_logits = logits0 / T |
|
|
|
|
|
probs = F.softmax(scaled_logits, dim=0) |
|
|
|
|
|
move_idx = torch.multinomial(probs, 1).item() |
|
|
uci_move = policy_index[move_idx] |
|
|
|
|
|
|
|
|
|
|
|
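        # The position is encoded from the side to move's perspective, so when black is
        # to move the predicted move's ranks are mirrored back to absolute coordinates.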
if is_black_to_move: |
|
|
def mirror_rank(rank_char): |
|
|
rank = int(rank_char) |
|
|
return str(9 - rank) |
|
|
|
|
|
|
|
|
if len(uci_move) >= 4: |
|
|
from_file = uci_move[0] |
|
|
from_rank = uci_move[1] |
|
|
to_file = uci_move[2] |
|
|
to_rank = uci_move[3] |
|
|
promo = uci_move[4:] if len(uci_move) > 4 else "" |
|
|
|
|
|
uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
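        # The policy vocabulary encodes castling as the king moving to the rook's square;
        # convert back to standard UCI castling moves when castling rights confirm it.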
if uci_move == "e1h1" and "K" in castling_rights: |
|
|
uci_move = "e1g1" |
|
|
elif uci_move == "e1a1" and "Q" in castling_rights: |
|
|
uci_move = "e1c1" |
|
|
|
|
|
elif uci_move == "e8h8" and "k" in castling_rights: |
|
|
uci_move = "e8g8" |
|
|
elif uci_move == "e8a8" and "q" in castling_rights: |
|
|
uci_move = "e8c8" |
|
|
|
|
|
return uci_move |
|
|
|
|
|
    def get_best_move_value(self, fen_or_moves: Union[str, List[str]], T: float = 0.0, device: Optional[str] = None) -> tuple:
|
|
""" |
|
|
        Get a move for the position (argmax or temperature-sampled) together with the value-head evaluation.
|
|
|
|
|
Args: |
|
|
fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves |
|
|
T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic) |
|
|
device: Device to run the model on (if None, uses model's device) |
|
|
|
|
|
Returns: |
|
|
            Tuple of (uci_move, value), where value is the softmaxed 3-way outcome distribution from the Q value head as a numpy array
|
|
""" |
|
|
|
|
|
if device is None: |
|
|
device = next(self.parameters()).device |
|
|
else: |
|
|
device = torch.device(device) |
|
|
|
|
|
|
|
|
if isinstance(fen_or_moves, str): |
|
|
fen = fen_or_moves |
|
|
is_black_to_move = fen.split()[1] == 'b' |
|
|
input_tensor_112, legal_moves_mask = encode_fen_to_tensor(fen) |
|
|
castling_rights = fen.split()[2] if len(fen.split()) > 2 else "" |
|
|
elif isinstance(fen_or_moves, list): |
|
|
move_history = fen_or_moves |
|
|
input_tensor_112, legal_moves_mask = encode_moves_to_tensor(move_history) |
|
|
board = bulletchess.Board() |
|
|
for mv in move_history: |
|
|
move = bulletchess.Move.from_uci(mv) |
|
|
board.apply(move) |
|
|
is_black_to_move = (board.turn == bulletchess.BLACK) |
|
|
fen_parts = board.fen().split() |
|
|
castling_rights = fen_parts[2] if len(fen_parts) > 2 else "" |
|
|
else: |
|
|
raise ValueError("Input must be a FEN string or a list of UCI moves") |
|
|
|
|
|
input_tensor_112 = input_tensor_112.to(device, non_blocking=True) |
|
|
|
|
|
self.eval() |
|
|
with torch.inference_mode(): |
|
|
policy_logits, _, value_q = self.forward(input_tensor_112) |
|
|
|
|
|
|
|
|
        logits0 = policy_logits[0] + torch.from_numpy(legal_moves_mask).to(policy_logits.device, dtype=policy_logits.dtype)
|
|
|
|
|
|
|
|
if T == 0.0: |
|
|
best_move_idx = torch.argmax(logits0).item() |
|
|
else: |
|
|
scaled_logits = logits0 / T |
|
|
probs = F.softmax(scaled_logits, dim=0) |
|
|
move_idx = torch.multinomial(probs, 1).item() |
|
|
best_move_idx = move_idx |
|
|
|
|
|
uci_move = policy_index[best_move_idx] |
|
|
|
|
|
|
|
|
if is_black_to_move: |
|
|
def mirror_rank(rank_char): |
|
|
rank = int(rank_char) |
|
|
return str(9 - rank) |
|
|
|
|
|
if len(uci_move) >= 4: |
|
|
from_file = uci_move[0] |
|
|
from_rank = uci_move[1] |
|
|
to_file = uci_move[2] |
|
|
to_rank = uci_move[3] |
|
|
promo = uci_move[4:] if len(uci_move) > 4 else "" |
|
|
uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo |
|
|
|
|
|
|
|
|
if uci_move == "e1h1" and "K" in castling_rights: |
|
|
uci_move = "e1g1" |
|
|
elif uci_move == "e1a1" and "Q" in castling_rights: |
|
|
uci_move = "e1c1" |
|
|
elif uci_move == "e8h8" and "k" in castling_rights: |
|
|
uci_move = "e8g8" |
|
|
elif uci_move == "e8a8" and "q" in castling_rights: |
|
|
uci_move = "e8c8" |
|
|
|
|
|
|
|
|
value_probs = F.softmax(value_q[0], dim=0) |
|
|
value = value_probs.cpu().numpy() |
|
|
|
|
|
return uci_move, value |
|
|
|
|
|
    def get_position_value(self, fen_or_moves: Union[str, List[str]], device: Optional[str] = None) -> np.ndarray:
|
|
""" |
|
|
Get position evaluation using value_q. |
|
|
|
|
|
Args: |
|
|
fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves |
|
|
device: Device to run the model on (if None, uses model's device) |
|
|
|
|
|
Returns: |
|
|
Array of [black_win, draw, white_win] probabilities |
|
|
""" |
|
|
|
|
|
if device is None: |
|
|
device = next(self.parameters()).device |
|
|
else: |
|
|
device = torch.device(device) |
|
|
|
|
|
|
|
|
if isinstance(fen_or_moves, str): |
|
|
input_tensor_112, _ = encode_fen_to_tensor(fen_or_moves) |
|
|
elif isinstance(fen_or_moves, list): |
|
|
input_tensor_112, _ = encode_moves_to_tensor(fen_or_moves) |
|
|
else: |
|
|
raise ValueError("Input must be a FEN string or a list of UCI moves") |
|
|
|
|
|
input_tensor_112 = input_tensor_112.to(device, non_blocking=True) |
|
|
|
|
|
self.eval() |
|
|
with torch.inference_mode(): |
|
|
_, _, value_q = self.forward(input_tensor_112) |
|
|
|
|
|
|
|
|
value_probs = F.softmax(value_q[0], dim=0) |
|
|
return value_probs.cpu().numpy() |
|
|
|
|
|
    def batch_get_moves_from_fens(self, fens: List[str], T: float, device: Optional[str] = None, use_fp16: bool = False) -> List[str]:
|
|
""" |
|
|
Get moves for multiple FEN positions using batched inference. |
|
|
|
|
|
Args: |
|
|
fens: List of FEN strings representing chess positions |
|
|
T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic) |
|
|
            device: Device to run the model on (if None, uses model's device)
            use_fp16: If True and running on CUDA, cast inputs to float16 and run inference under autocast
|
|
|
|
|
Returns: |
|
|
List of UCI move strings |
|
|
""" |
|
|
if not fens: |
|
|
return [] |
|
|
|
|
|
|
|
|
if device is None: |
|
|
device = next(self.parameters()).device |
|
|
else: |
|
|
device = torch.device(device) |
|
|
|
|
|
batch_size = len(fens) |
|
|
|
|
|
|
|
|
input_tensors = [] |
|
|
legal_moves_masks = [] |
|
|
is_black_to_move_list = [] |
|
|
castling_rights_list = [] |
|
|
|
|
|
for fen in fens: |
|
|
input_tensor, legal_mask = encode_fen_to_tensor(fen) |
|
|
input_tensors.append(input_tensor.squeeze(0)) |
|
|
legal_moves_masks.append(legal_mask) |
|
|
is_black_to_move_list.append(fen.split()[1] == 'b') |
|
|
castling_rights_list.append(fen.split()[2] if len(fen.split()) > 2 else "") |
|
|
|
|
|
|
|
|
batch_tensor = torch.stack(input_tensors).to(device, non_blocking=True) |
|
|
if use_fp16 and device.type == 'cuda': |
|
|
batch_tensor = batch_tensor.half() |
|
|
|
|
|
|
|
|
self.eval() |
|
|
with torch.inference_mode(): |
|
|
if use_fp16 and device.type == 'cuda': |
|
|
with torch.autocast(device_type='cuda', dtype=torch.float16): |
|
|
policy_logits,_,_ = self.forward(batch_tensor) |
|
|
else: |
|
|
policy_logits,_,_ = self.forward(batch_tensor) |
|
|
|
|
|
|
|
|
moves = [] |
|
|
for i in range(batch_size): |
|
|
|
|
|
logits = policy_logits[i] + torch.from_numpy(legal_moves_masks[i]).to(policy_logits.device, dtype=policy_logits.dtype) |
|
|
|
|
|
|
|
|
if T == 0.0: |
|
|
best_move_idx = torch.argmax(logits).item() |
|
|
uci_move = policy_index[best_move_idx] |
|
|
else: |
|
|
scaled_logits = logits / T |
|
|
probs = F.softmax(scaled_logits, dim=0) |
|
|
move_idx = torch.multinomial(probs, 1).item() |
|
|
uci_move = policy_index[move_idx] |
|
|
|
|
|
|
|
|
if is_black_to_move_list[i]: |
|
|
def mirror_rank(rank_char): |
|
|
rank = int(rank_char) |
|
|
return str(9 - rank) |
|
|
|
|
|
if len(uci_move) >= 4: |
|
|
from_file = uci_move[0] |
|
|
from_rank = uci_move[1] |
|
|
to_file = uci_move[2] |
|
|
to_rank = uci_move[3] |
|
|
promo = uci_move[4:] if len(uci_move) > 4 else "" |
|
|
|
|
|
uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo |
|
|
|
|
|
|
|
|
castling_rights = castling_rights_list[i] |
|
|
if uci_move == "e1h1" and "K" in castling_rights: |
|
|
uci_move = "e1g1" |
|
|
elif uci_move == "e1a1" and "Q" in castling_rights: |
|
|
uci_move = "e1c1" |
|
|
elif uci_move == "e8h8" and "k" in castling_rights: |
|
|
uci_move = "e8g8" |
|
|
elif uci_move == "e8a8" and "q" in castling_rights: |
|
|
uci_move = "e8c8" |
|
|
|
|
|
moves.append(uci_move) |
|
|
|
|
|
return moves |
|
|
|
|
|
    def batch_get_moves_from_move_lists(self, move_lists: List[List[str]], T: float, device: Optional[str] = None, use_fp16: bool = False, fens: Optional[List[str]] = None) -> List[str]:
|
|
""" |
|
|
Get moves for multiple move histories using batched inference. |
|
|
|
|
|
Args: |
|
|
move_lists: List of move sequences, where each sequence is a list of UCI moves |
|
|
T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic) |
|
|
            device: Device to run the model on (if None, uses model's device)
            use_fp16: If True and running on CUDA, cast inputs to float16 and run inference under autocast
|
|
fens: Optional list of FEN strings that represent the board state prior to |
|
|
applying the corresponding move list. When provided, each move history |
|
|
is applied starting from the supplied FEN instead of the standard initial position. |
|
|
|
|
|
Returns: |
|
|
List of UCI move strings |
|
|
""" |
|
|
if not move_lists: |
|
|
return [] |
|
|
|
|
|
|
|
|
if device is None: |
|
|
device = next(self.parameters()).device |
|
|
else: |
|
|
device = torch.device(device) |
|
|
|
|
|
batch_size = len(move_lists) |
|
|
|
|
|
if fens is not None and len(fens) != len(move_lists): |
|
|
raise ValueError("Length of fens must match length of move_lists when provided.") |
|
|
|
|
|
|
|
|
input_tensors = [] |
|
|
legal_moves_masks = [] |
|
|
is_black_to_move_list = [] |
|
|
castling_rights_list = [] |
|
|
|
|
|
for idx, move_history in enumerate(move_lists): |
|
|
starting_fen = fens[idx] if fens is not None else None |
|
|
input_tensor, legal_mask = encode_moves_to_tensor(move_history, starting_fen=starting_fen) |
|
|
input_tensors.append(input_tensor.squeeze(0)) |
|
|
legal_moves_masks.append(legal_mask) |
|
|
|
|
|
board = bulletchess.Board.from_fen(starting_fen) if starting_fen is not None else bulletchess.Board() |
|
|
for mv in move_history: |
|
|
move = bulletchess.Move.from_uci(mv) |
|
|
board.apply(move) |
|
|
is_black_to_move_list.append(board.turn == bulletchess.BLACK) |
|
|
fen_parts = board.fen().split() |
|
|
castling_rights_list.append(fen_parts[2] if len(fen_parts) > 2 else "") |
|
|
|
|
|
|
|
|
batch_tensor = torch.stack(input_tensors).to(device, non_blocking=True) |
|
|
if use_fp16 and device.type == 'cuda': |
|
|
batch_tensor = batch_tensor.half() |
|
|
|
|
|
|
|
|
self.eval() |
|
|
with torch.inference_mode(): |
|
|
if use_fp16 and device.type == 'cuda': |
|
|
with torch.autocast(device_type='cuda', dtype=torch.float16): |
|
|
policy_logits,_,_ = self.forward(batch_tensor) |
|
|
else: |
|
|
policy_logits,_,_ = self.forward(batch_tensor) |
|
|
|
|
|
|
|
|
moves = [] |
|
|
for i in range(batch_size): |
|
|
|
|
|
logits = policy_logits[i] + torch.from_numpy(legal_moves_masks[i]).to(policy_logits.device, dtype=policy_logits.dtype) |
|
|
|
|
|
|
|
|
if T == 0.0: |
|
|
best_move_idx = torch.argmax(logits).item() |
|
|
uci_move = policy_index[best_move_idx] |
|
|
else: |
|
|
scaled_logits = logits / T |
|
|
probs = F.softmax(scaled_logits, dim=0) |
|
|
move_idx = torch.multinomial(probs, 1).item() |
|
|
uci_move = policy_index[move_idx] |
|
|
|
|
|
|
|
|
if is_black_to_move_list[i]: |
|
|
def mirror_rank(rank_char): |
|
|
rank = int(rank_char) |
|
|
return str(9 - rank) |
|
|
|
|
|
if len(uci_move) >= 4: |
|
|
from_file = uci_move[0] |
|
|
from_rank = uci_move[1] |
|
|
to_file = uci_move[2] |
|
|
to_rank = uci_move[3] |
|
|
promo = uci_move[4:] if len(uci_move) > 4 else "" |
|
|
|
|
|
uci_move = from_file + mirror_rank(from_rank) + to_file + mirror_rank(to_rank) + promo |
|
|
|
|
|
|
|
|
castling_rights = castling_rights_list[i] |
|
|
if uci_move == "e1h1" and "K" in castling_rights: |
|
|
uci_move = "e1g1" |
|
|
elif uci_move == "e1a1" and "Q" in castling_rights: |
|
|
uci_move = "e1c1" |
|
|
elif uci_move == "e8h8" and "k" in castling_rights: |
|
|
uci_move = "e8g8" |
|
|
elif uci_move == "e8a8" and "q" in castling_rights: |
|
|
uci_move = "e8c8" |
|
|
|
|
|
moves.append(uci_move) |
|
|
return moves |
|
|
|