Upload model.py with huggingface_hub
model.py
ADDED
@@ -0,0 +1,756 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attn_map import apm_map, apm_out
from .encoding_simple import encode_fen_to_tensor, encode_moves_to_tensor
from .vocab import policy_index
from typing import Union, List, Optional
import bulletchess
import numpy as np

class Gating(nn.Module):
    def __init__(self, features_shape, additive=True, init_value=None):
        super(Gating, self).__init__()
        self.additive = additive
        if init_value is None:
            init_value = 0 if self.additive else 1

        self.gate = nn.Parameter(torch.full(features_shape, float(init_value)))
        if not self.additive:
            # Zero out any negative gradient components so multiplicative gates
            # only receive non-negative updates.
            self.gate.register_hook(lambda grad: torch.clamp(grad, min=0))

    def forward(self, x):
        if self.additive:
            return x + self.gate
        else:
            return x * self.gate


def ma_gating(x, in_features):
    # Note: this helper constructs fresh Gating modules on every call, so their
    # parameters are never registered with a parent module. BT4 below keeps its
    # own gating_mult/gating_add submodules instead of calling this.
    x = Gating(in_features, additive=False)(x)
    x = Gating(in_features, additive=True)(x)
    return x

class RMSNorm(nn.Module):
    def __init__(self, in_features, scale=True):
        super(RMSNorm, self).__init__()
        self.scale = scale
        if self.scale:
            self.gamma = nn.Parameter(torch.ones(in_features))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-5)
        x_normalized = x / rms
        if self.scale:
            return x_normalized * self.gamma
        return x_normalized

class ApplyAttentionPolicyMap(nn.Module):
    def __init__(self):
        super(ApplyAttentionPolicyMap, self).__init__()
        # Register as buffers so they move with the model when .to(device) is called.
        # Use the same names as before for backward compatibility with saved models;
        # 'fc1' (apm_map) is not used in forward(), it is kept only so older
        # checkpoints still load.
        self.register_buffer('fc1', torch.from_numpy(apm_map).float())
        self.register_buffer('idx', torch.from_numpy(apm_out).long())

    def forward(self, logits, pp_logits):
        logits = torch.cat([logits.reshape(-1, 64 * 64),
                            pp_logits.reshape(-1, 8 * 24)],
                           dim=1)

        batch_size = logits.size(0)
        idx = self.idx.unsqueeze(0).expand(batch_size, -1)

        return torch.gather(logits, 1, idx)
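
# Shape sketch for ApplyAttentionPolicyMap (descriptive only): the 64*64 = 4096
# from-square/to-square attention logits and the 8*24 = 192 underpromotion
# logits are concatenated into a (batch, 4288) tensor, from which `idx` gathers
# one logit per entry of `policy_index` (presumably the lc0-style move
# vocabulary), giving an output of shape (batch, len(policy_index)).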

class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

class CustomMHA(nn.Module):
    def __init__(self, emb_size, d_model, num_heads, dropout=0.0, use_bias_qkv=True, use_bias_out=True,
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super(CustomMHA, self).__init__()
        assert d_model % num_heads == 0
        self.emb_size = emb_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.wq = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.wk = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.wv = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.out_proj = nn.Linear(d_model, emb_size, bias=use_bias_out)
        self.attn_dropout = nn.Dropout(dropout)
        # Optional Smolgen components (sized for a fixed sequence length of 64 squares)
        self.smol_compress = None
        self.smol_hidden1 = None
        self.smol_hidden1_ln = None
        self.smol_gen_from = None
        self.smol_gen_from_ln = None
        self.smol_weight_gen = None
        if use_smolgen:
            self.smol_compress = nn.Linear(emb_size, smol_hidden_channels, bias=False)
            self.smol_hidden1 = nn.Linear(64 * smol_hidden_channels, smol_hidden_sz, bias=True)
            self.smol_hidden1_ln = nn.LayerNorm(smol_hidden_sz, eps=1e-3)
            self.smol_gen_from = nn.Linear(smol_hidden_sz, num_heads * smol_gen_sz, bias=True)
            self.smol_gen_from_ln = nn.LayerNorm(num_heads * smol_gen_sz, eps=1e-3)
            self.smol_weight_gen = nn.Linear(smol_gen_sz, 64 * 64, bias=False)
        self.smol_activation = smol_activation

    def _shape(self, x):
        b, l, _ = x.shape
        return x.view(b, l, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, return_attn=False):
        # x: (B, L, emb_size)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        q = self._shape(q)  # (B, H, L, D)
        k = self._shape(k)
        v = self._shape(v)
        scale = torch.sqrt(torch.tensor(self.head_dim, dtype=x.dtype, device=x.device))
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / scale
        # Add Smolgen weights if present: a per-head 64x64 attention bias
        # generated dynamically from the (compressed) board content.
        smol_w = None
        if self.smol_compress is not None:
            b, l, _ = x.shape
            compressed = self.smol_compress(x)  # (B, L, hidden_channels)
            compressed = compressed.reshape(b, l * compressed.shape[-1])  # (B, 64*hidden_channels)
            hidden_pre = self.smol_hidden1(compressed)
            # Only the 'swish' (SiLU) activation is implemented.
            hidden = F.silu(hidden_pre)
            hidden_ln = self.smol_hidden1_ln(hidden)
            gen_from_pre = self.smol_gen_from(hidden_ln)
            gen_from_act = F.silu(gen_from_pre)
            gen_from = self.smol_gen_from_ln(gen_from_act)
            gen_from = gen_from.view(b, self.num_heads, -1)  # (B, H, gen_sz)
            smol_w = self.smol_weight_gen(gen_from)  # (B, H, 64*64)
            smol_w = smol_w.view(b, self.num_heads, l, l)
            attn_logits = attn_logits + smol_w
        # Numerically stable softmax matching the TF reference (subtract the row max)
        attn_logits = attn_logits - attn_logits.max(dim=-1, keepdim=True)[0]
        attn_weights = torch.exp(attn_logits)
        attn_weights = attn_weights / attn_weights.sum(dim=-1, keepdim=True)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, v)  # (B, H, L, D)
        attn_output = attn_output.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.d_model)
        out = self.out_proj(attn_output)
        if return_attn:
            return out, attn_weights, smol_w, attn_logits
        return out
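
# A minimal shape sketch for CustomMHA (illustrative only; these sizes are
# assumptions for the example, not BT4's defaults):
#
#     mha = CustomMHA(emb_size=256, d_model=256, num_heads=8)
#     x = torch.randn(2, 64, 256)                      # (batch, 64 squares, emb_size)
#     out = mha(x)                                     # (2, 64, 256)
#     out, attn, smol_w, _ = mha(x, return_attn=True)  # attn, smol_w: (2, 8, 64, 64)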

class FFN(nn.Module):
    def __init__(self, emb_size, dff, activation=Mish(), omit_other_biases=False):
        super(FFN, self).__init__()
        self.dense1 = nn.Linear(emb_size, dff, bias=not omit_other_biases)
        self.activation = activation
        self.dense2 = nn.Linear(dff, emb_size, bias=not omit_other_biases)

    def forward(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, emb_size, d_model, num_heads, dff, dropout_rate, encoder_layers, skip_first_ln=False, encoder_rms_norm=False, omit_qkv_biases=False, omit_other_biases=False,
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super(EncoderLayer, self).__init__()
        self.mha = CustomMHA(emb_size, d_model, num_heads, dropout=dropout_rate, use_bias_qkv=not omit_qkv_biases, use_bias_out=not omit_other_biases,
                             use_smolgen=use_smolgen, smol_hidden_channels=smol_hidden_channels, smol_hidden_sz=smol_hidden_sz, smol_gen_sz=smol_gen_sz, smol_activation=smol_activation)
        self.ffn = FFN(emb_size, dff, omit_other_biases=omit_other_biases)

        self.norm1 = RMSNorm(emb_size) if encoder_rms_norm else nn.LayerNorm(emb_size, eps=0.001)
        self.norm2 = RMSNorm(emb_size) if encoder_rms_norm else nn.LayerNorm(emb_size, eps=0.001)

        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        # DeepNorm-style residual scaling: alpha = (2 * N) ** -0.25 for N encoder layers
        self.alpha = (2. * encoder_layers)**-0.25
        self.skip_first_ln = skip_first_ln

    def forward(self, x):
        attn_output = self.mha(x)
        attn_output = self.dropout1(attn_output)

        out1 = x + attn_output * self.alpha
        if not self.skip_first_ln:
            out1 = self.norm1(out1)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)

        out2 = self.norm2(out1 + ffn_output * self.alpha)
        return out2

class PolicyHead(nn.Module):
    def __init__(self, pol_embedding_size, policy_d_model, opponent=False):
        super(PolicyHead, self).__init__()
        self.opponent = opponent
        self.wq = nn.Linear(pol_embedding_size, policy_d_model)
        self.wk = nn.Linear(pol_embedding_size, policy_d_model)
        self.ppo = nn.Linear(policy_d_model, 4, bias=False)
        self.apply_map = ApplyAttentionPolicyMap()

    def forward(self, x):
        if self.opponent:
            x = torch.flip(x, [1])

        queries = self.wq(x)
        keys = self.wk(x)

        # Move logits are query/key dot products: logit(from, to) = q_from . k_to / sqrt(d)
        matmul_qk = torch.matmul(queries, keys.transpose(-2, -1))

        dk = torch.sqrt(torch.tensor(keys.shape[-1], dtype=keys.dtype, device=keys.device))
        # Promotion offsets come from the last-rank keys; the fourth output of
        # ppo is a shared term added to the queen/rook/bishop offsets.
        promotion_keys = keys[:, -8:, :]
        promotion_offsets = self.ppo(promotion_keys).transpose(-2, -1) * dk
        promotion_offsets = promotion_offsets[:, :3, :] + promotion_offsets[:, 3:4, :]

        # Pawn-advance logits from the 7th to the 8th rank (in the model's frame),
        # used as the base for each promotion piece.
        n_promo_logits = matmul_qk[:, -16:-8, -8:]
        q_promo_logits = (n_promo_logits + promotion_offsets[:, 0:1, :]).unsqueeze(3)
        r_promo_logits = (n_promo_logits + promotion_offsets[:, 1:2, :]).unsqueeze(3)
        b_promo_logits = (n_promo_logits + promotion_offsets[:, 2:3, :]).unsqueeze(3)
        promotion_logits = torch.cat([q_promo_logits, r_promo_logits, b_promo_logits], dim=3).reshape(-1, 8, 24)

        promotion_logits = promotion_logits / dk
        policy_attn_logits = matmul_qk / dk

        return self.apply_map(policy_attn_logits, promotion_logits)
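
# This head follows the attention-policy design used by Leela-style networks:
# every (from, to) square pair gets a logit from a query/key dot product, and
# an 8x24 block of queen/rook/bishop promotion logits is appended before
# ApplyAttentionPolicyMap gathers the final move ordering.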

class ValueHead(nn.Module):
    def __init__(self, embedding_size, val_embedding_size, default_activation=Mish()):
        super(ValueHead, self).__init__()
        self.embedding = nn.Linear(embedding_size, val_embedding_size)
        self.activation = default_activation
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(val_embedding_size * 64, 128)
        self.dense2 = nn.Linear(128, 3)

    def forward(self, x):
        x = self.embedding(x)
        x = self.activation(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        return x

def _finalize_uci_move(uci_move, is_black_to_move, castling_rights):
    """Map a predicted move from the model's frame back to the board frame.

    The encoder mirrors the board when black is to move, so predicted moves
    have their ranks mirrored back (1<->8, 2<->7, ...) while files stay the
    same. King-to-rook-square castling outputs are then converted to standard
    UCI castling moves, gated on the FEN castling-rights field.
    """
    if is_black_to_move and len(uci_move) >= 4:
        def mirror_rank(rank_char):
            return str(9 - int(rank_char))

        promo = uci_move[4:] if len(uci_move) > 4 else ""
        uci_move = (uci_move[0] + mirror_rank(uci_move[1])
                    + uci_move[2] + mirror_rank(uci_move[3]) + promo)

    if uci_move == "e1h1" and "K" in castling_rights:
        uci_move = "e1g1"
    elif uci_move == "e1a1" and "Q" in castling_rights:
        uci_move = "e1c1"
    elif uci_move == "e8h8" and "k" in castling_rights:
        uci_move = "e8g8"
    elif uci_move == "e8a8" and "q" in castling_rights:
        uci_move = "e8c8"
    return uci_move


class BT4(nn.Module):
    def __init__(self, embedding_size=1024, embedding_dense_sz=512, encoder_layers=15, encoder_d_model=1024, encoder_heads=32, encoder_dff=1536, dropout_rate=0.0, pol_embedding_size=1024, policy_d_model=1024, val_embedding_size=128, default_activation=Mish(),
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super(BT4, self).__init__()
        self.embedding_dense_sz = embedding_dense_sz
        # DeepNorm alpha used in the embedding residual; default uses the provided encoder_layers
        self.deepnorm_alpha = (2. * encoder_layers) ** -0.25

        self.embedding_preprocess = nn.Linear(64 * 12, 64 * self.embedding_dense_sz)
        self.embedding = nn.Linear(112 + self.embedding_dense_sz, embedding_size)
        # Explicitly set initializer; note that self.apply(self._init_weights)
        # below re-initializes every Linear with xavier_normal_, overriding this.
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.zeros_(self.embedding.bias)

        self.embedding_ln = nn.LayerNorm(embedding_size, eps=0.001)

        self.gating_mult = Gating((64, embedding_size), additive=False)
        self.gating_add = Gating((64, embedding_size), additive=True)

        self.embedding_ffn = FFN(embedding_size, encoder_dff)
        self.embedding_ffn_ln = nn.LayerNorm(embedding_size, eps=0.001)

        self.encoder_layers_list = nn.ModuleList([
            EncoderLayer(embedding_size, encoder_d_model, encoder_heads, encoder_dff, dropout_rate, encoder_layers,
                         use_smolgen=use_smolgen, smol_hidden_channels=smol_hidden_channels, smol_hidden_sz=smol_hidden_sz, smol_gen_sz=smol_gen_sz, smol_activation=smol_activation)
            for _ in range(encoder_layers)
        ])

        self.policy_embedding = nn.Linear(embedding_size, pol_embedding_size)
        self.policy_head = PolicyHead(pol_embedding_size, policy_d_model)
        self.value_head_winner = ValueHead(embedding_size, val_embedding_size)
        self.value_head_q = ValueHead(embedding_size, val_embedding_size)
        self.activation = default_activation

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Keras' glorot_normal is equivalent to PyTorch's xavier_normal_
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        # x shape: (batch, 112, 8, 8)
        flow = x.permute(0, 2, 3, 1).reshape(-1, 64, 112)

        pos_info = flow[..., :12]
        pos_info_flat = pos_info.reshape(-1, 64 * 12)

        pos_info_processed = self.embedding_preprocess(pos_info_flat)
        pos_info = pos_info_processed.reshape(-1, 64, self.embedding_dense_sz)

        flow = torch.cat([flow, pos_info], dim=-1)

        flow = self.embedding(flow)
        flow = self.activation(flow)
        flow = self.embedding_ln(flow)

        flow = self.gating_mult(flow)
        flow = self.gating_add(flow)

        ffn_dense1_pre = self.embedding_ffn.dense1(flow)
        ffn_dense1 = self.embedding_ffn.activation(ffn_dense1_pre)
        ffn_output = self.embedding_ffn.dense2(ffn_dense1)

        residual = flow + ffn_output * self.deepnorm_alpha
        flow = self.embedding_ffn_ln(residual)

        for layer in self.encoder_layers_list:
            flow = layer(flow)

        policy_tokens = self.policy_embedding(flow)
        policy_tokens = self.activation(policy_tokens)

        policy_logits = self.policy_head(policy_tokens)

        value_winner = self.value_head_winner(flow)
        value_q = self.value_head_q(flow)

        return policy_logits, value_winner, value_q
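
    # Forward-pass shape sketch (descriptive): the (B, 112, 8, 8) input planes
    # are flattened to 64 square tokens, embedded to (B, 64, embedding_size),
    # run through the encoder stack, and mapped to policy logits of shape
    # (B, len(policy_index)) plus two (B, 3) value-head outputs.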

    def get_move_from_fen_no_thinking(self, fen_or_moves: Union[str, List[str]], T: float, device: str = None, **kwargs) -> str:
        """
        Predict a move from a FEN position or move history without thinking/search.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses the model's device)
            **kwargs: Pass return_probs=True to get a {uci_move: probability}
                dict instead of a single move.

        Returns:
            UCI move string (e.g., 'e2e4')
        """
        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        # Determine whether the input is a FEN string or a list of moves
        if isinstance(fen_or_moves, str):
            # FEN string input
            fen = fen_or_moves
            is_black_to_move = fen.split()[1] == 'b'
            input_tensor_112, legal_moves_mask = encode_fen_to_tensor(fen)
            castling_rights = fen.split()[2] if len(fen.split()) > 2 else ""
        elif isinstance(fen_or_moves, list):
            # List of UCI moves input
            move_history = fen_or_moves
            input_tensor_112, legal_moves_mask = encode_moves_to_tensor(move_history)
            # Replay the moves to determine the side to move and castling rights
            board = bulletchess.Board()
            for mv in move_history:
                move = bulletchess.Move.from_uci(mv)
                board.apply(move)
            is_black_to_move = (board.turn == bulletchess.BLACK)
            fen_parts = board.fen().split()
            castling_rights = fen_parts[2] if len(fen_parts) > 2 else ""
        else:
            raise ValueError("Input must be a FEN string or a list of UCI moves")

        input_tensor_112 = input_tensor_112.to(device, non_blocking=True)

        self.eval()
        with torch.inference_mode():
            policy_logits, _, _ = self.forward(input_tensor_112)

        # Apply the legal-moves mask without in-place ops (inference tensor)
        logits0 = policy_logits[0] + torch.from_numpy(legal_moves_mask).to(policy_logits.device, dtype=policy_logits.dtype)

        if kwargs.get('return_probs', False):
            # Return a dictionary of move probabilities
            scaled_logits = logits0 / T if T > 0 else logits0
            probs = F.softmax(scaled_logits, dim=0)
            probs_dict = {}
            for idx, move in enumerate(policy_index):
                prob_val = probs[idx].item()
                if prob_val > 1e-6:  # Only include moves with non-negligible probability
                    probs_dict[move] = prob_val
            return probs_dict

        if T == 0.0:
            # Deterministic: return the best move
            move_idx = torch.argmax(logits0).item()
        else:
            # Stochastic sampling with temperature scaling
            scaled_logits = logits0 / T
            probs = F.softmax(scaled_logits, dim=0)
            move_idx = torch.multinomial(probs, 1).item()
        uci_move = policy_index[move_idx]

        # Mirror the move back if black is to move, then normalize castling
        return _finalize_uci_move(uci_move, is_black_to_move, castling_rights)
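
    # Example (illustrative): with return_probs=True the method returns a dict
    # of {uci_move: probability} over moves with non-negligible mass:
    #     probs = model.get_move_from_fen_no_thinking(fen, T=1.0, return_probs=True)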

    def get_best_move_value(self, fen_or_moves: Union[str, List[str]], T: float = 0.0, device: str = None) -> tuple:
        """
        Get the best move together with the position's value estimate.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses the model's device)

        Returns:
            Tuple of (best_move, value), where value is the softmaxed 3-way
            outcome distribution from the q value head (see get_position_value)
        """
        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        # Determine whether the input is a FEN string or a list of moves
        if isinstance(fen_or_moves, str):
            fen = fen_or_moves
            is_black_to_move = fen.split()[1] == 'b'
            input_tensor_112, legal_moves_mask = encode_fen_to_tensor(fen)
            castling_rights = fen.split()[2] if len(fen.split()) > 2 else ""
        elif isinstance(fen_or_moves, list):
            move_history = fen_or_moves
            input_tensor_112, legal_moves_mask = encode_moves_to_tensor(move_history)
            board = bulletchess.Board()
            for mv in move_history:
                move = bulletchess.Move.from_uci(mv)
                board.apply(move)
            is_black_to_move = (board.turn == bulletchess.BLACK)
            fen_parts = board.fen().split()
            castling_rights = fen_parts[2] if len(fen_parts) > 2 else ""
        else:
            raise ValueError("Input must be a FEN string or a list of UCI moves")

        input_tensor_112 = input_tensor_112.to(device, non_blocking=True)

        self.eval()
        with torch.inference_mode():
            policy_logits, _, value_q = self.forward(input_tensor_112)

        # Apply the legal-moves mask
        logits0 = policy_logits[0] + torch.from_numpy(legal_moves_mask).to(policy_logits.device, dtype=policy_logits.dtype)

        # Pick the move
        if T == 0.0:
            move_idx = torch.argmax(logits0).item()
        else:
            scaled_logits = logits0 / T
            probs = F.softmax(scaled_logits, dim=0)
            move_idx = torch.multinomial(probs, 1).item()
        uci_move = policy_index[move_idx]

        # Mirror the move back if black is to move, then normalize castling
        uci_move = _finalize_uci_move(uci_move, is_black_to_move, castling_rights)

        # Softmax over value_q gives the outcome probabilities
        value_probs = F.softmax(value_q[0], dim=0)
        value = value_probs.cpu().numpy()

        return uci_move, value

    def get_position_value(self, fen_or_moves: Union[str, List[str]], device: str = None) -> np.ndarray:
        """
        Get the position evaluation from the q value head.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            device: Device to run the model on (if None, uses the model's device)

        Returns:
            Array of [black_win, draw, white_win] probabilities
        """
        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        # Determine whether the input is a FEN string or a list of moves
        if isinstance(fen_or_moves, str):
            input_tensor_112, _ = encode_fen_to_tensor(fen_or_moves)
        elif isinstance(fen_or_moves, list):
            input_tensor_112, _ = encode_moves_to_tensor(fen_or_moves)
        else:
            raise ValueError("Input must be a FEN string or a list of UCI moves")

        input_tensor_112 = input_tensor_112.to(device, non_blocking=True)

        self.eval()
        with torch.inference_mode():
            _, _, value_q = self.forward(input_tensor_112)

        # Apply softmax to get probabilities [black_win, draw, white_win]
        value_probs = F.softmax(value_q[0], dim=0)
        return value_probs.cpu().numpy()
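
    # Example (illustrative):
    #     wdl = model.get_position_value("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
    #     # -> array of three probabilities summing to 1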

    def batch_get_moves_from_fens(self, fens: List[str], T: float, device: str = None, use_fp16: bool = False) -> List[str]:
        """
        Get moves for multiple FEN positions using batched inference.

        Args:
            fens: List of FEN strings representing chess positions
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses the model's device)
            use_fp16: Run inference in float16 autocast on CUDA devices

        Returns:
            List of UCI move strings
        """
        if not fens:
            return []

        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        batch_size = len(fens)

        # Batch-encode all FENs
        input_tensors = []
        legal_moves_masks = []
        is_black_to_move_list = []
        castling_rights_list = []

        for fen in fens:
            input_tensor, legal_mask = encode_fen_to_tensor(fen)
            input_tensors.append(input_tensor.squeeze(0))  # Remove batch dim
            legal_moves_masks.append(legal_mask)
            is_black_to_move_list.append(fen.split()[1] == 'b')
            castling_rights_list.append(fen.split()[2] if len(fen.split()) > 2 else "")

        # Stack into a batch tensor: (batch_size, 112, 8, 8)
        batch_tensor = torch.stack(input_tensors).to(device, non_blocking=True)
        if use_fp16 and device.type == 'cuda':
            batch_tensor = batch_tensor.half()

        # Run batched inference
        self.eval()
        with torch.inference_mode():
            if use_fp16 and device.type == 'cuda':
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    policy_logits, _, _ = self.forward(batch_tensor)
            else:
                policy_logits, _, _ = self.forward(batch_tensor)

        # Process each position in the batch
        moves = []
        for i in range(batch_size):
            # Apply the legal-moves mask
            logits = policy_logits[i] + torch.from_numpy(legal_moves_masks[i]).to(policy_logits.device, dtype=policy_logits.dtype)

            # Sample a move
            if T == 0.0:
                move_idx = torch.argmax(logits).item()
            else:
                scaled_logits = logits / T
                probs = F.softmax(scaled_logits, dim=0)
                move_idx = torch.multinomial(probs, 1).item()
            uci_move = policy_index[move_idx]

            # Mirror back for black to move, then normalize castling
            moves.append(_finalize_uci_move(uci_move, is_black_to_move_list[i], castling_rights_list[i]))

        return moves

    def batch_get_moves_from_move_lists(self, move_lists: List[List[str]], T: float, device: str = None, use_fp16: bool = False, fens: Optional[List[str]] = None) -> List[str]:
        """
        Get moves for multiple move histories using batched inference.

        Args:
            move_lists: List of move sequences, where each sequence is a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device to run the model on (if None, uses the model's device)
            use_fp16: Run inference in float16 autocast on CUDA devices
            fens: Optional list of FEN strings giving the board state prior to
                applying the corresponding move list. When provided, each move
                history is applied starting from the supplied FEN instead of
                the standard initial position.

        Returns:
            List of UCI move strings
        """
        if not move_lists:
            return []

        # Detect device from model if not provided
        if device is None:
            device = next(self.parameters()).device
        else:
            device = torch.device(device)

        batch_size = len(move_lists)

        if fens is not None and len(fens) != len(move_lists):
            raise ValueError("Length of fens must match length of move_lists when provided.")

        # Batch-encode all move histories
        input_tensors = []
        legal_moves_masks = []
        is_black_to_move_list = []
        castling_rights_list = []

        for idx, move_history in enumerate(move_lists):
            starting_fen = fens[idx] if fens is not None else None
            input_tensor, legal_mask = encode_moves_to_tensor(move_history, starting_fen=starting_fen)
            input_tensors.append(input_tensor.squeeze(0))  # Remove batch dim
            legal_moves_masks.append(legal_mask)

            # Replay the moves to determine the side to move and castling rights
            board = bulletchess.Board.from_fen(starting_fen) if starting_fen is not None else bulletchess.Board()
            for mv in move_history:
                move = bulletchess.Move.from_uci(mv)
                board.apply(move)
            is_black_to_move_list.append(board.turn == bulletchess.BLACK)
            fen_parts = board.fen().split()
            castling_rights_list.append(fen_parts[2] if len(fen_parts) > 2 else "")

        # Stack into a batch tensor: (batch_size, 112, 8, 8)
        batch_tensor = torch.stack(input_tensors).to(device, non_blocking=True)
        if use_fp16 and device.type == 'cuda':
            batch_tensor = batch_tensor.half()

        # Run batched inference
        self.eval()
        with torch.inference_mode():
            if use_fp16 and device.type == 'cuda':
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    policy_logits, _, _ = self.forward(batch_tensor)
            else:
                policy_logits, _, _ = self.forward(batch_tensor)

        # Process each position in the batch
        moves = []
        for i in range(batch_size):
            # Apply the legal-moves mask
            logits = policy_logits[i] + torch.from_numpy(legal_moves_masks[i]).to(policy_logits.device, dtype=policy_logits.dtype)

            # Sample a move
            if T == 0.0:
                move_idx = torch.argmax(logits).item()
            else:
                scaled_logits = logits / T
                probs = F.softmax(scaled_logits, dim=0)
                move_idx = torch.multinomial(probs, 1).item()
            uci_move = policy_index[move_idx]

            # Mirror back for black to move, then normalize castling
            moves.append(_finalize_uci_move(uci_move, is_black_to_move_list[i], castling_rights_list[i]))

        return moves
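
# Minimal usage sketch (illustrative; run from the package this module lives in,
# since the imports above are relative, and load trained weights first):
#
#     model = BT4()
#     model.load_state_dict(torch.load("bt4.pt"))   # hypothetical checkpoint path
#     start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
#     model.get_move_from_fen_no_thinking(start_fen, T=0.0)     # e.g. 'e2e4'
#     model.batch_get_moves_from_fens([start_fen] * 8, T=1.0)   # batched sampling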