|
|
|
|
|
# Plane index for every non-king piece code: a colour prefix ('w'/'b')
# followed by a piece letter. Kings have no entry because
# gs_to_nnue_features skips them and uses the king's square as part of
# each feature index instead.
PIECE_TO_INDEX = {
    'wp': 0, 'wN': 1, 'wB': 2, 'wR': 3, 'wQ': 4,
    'bp': 5, 'bN': 6, 'bB': 7, 'bR': 8, 'bQ': 9,
}

# One plane per entry in PIECE_TO_INDEX (10).
NUM_PIECES = len(PIECE_TO_INDEX)

# Squares on the board, numbered row * 8 + col.
NUM_SQUARES = 64

# Size of the sparse feature space: piece plane x king square x piece square.
NUM_FEATURES = NUM_PIECES * NUM_SQUARES * NUM_SQUARES

# Sentinel index one past the last real feature. NNUEInfer substitutes it
# when a position yields no features. Presumably the model reserves an extra
# (padding) row for it -- confirm against the model definition.
PAD_IDX = NUM_FEATURES
|
|
|
def find_king_squares(board):
    """Return the square numbers of the white and black kings.

    Only the first 8 rows and 8 columns of ``board`` are read, as
    ``board[row][col]``. Squares are numbered ``row * 8 + col``.

    Args:
        board: grid of piece codes; kings are encoded as ``"wK"`` / ``"bK"``.

    Returns:
        tuple: ``(white_king_sq, black_king_sq)``. An entry is ``None`` when
        that king is not found. If a king code occurs more than once, the
        last occurrence in row-major scan order is reported.
    """
    white_sq = None
    black_sq = None
    for sq in range(64):
        row, col = divmod(sq, 8)
        cell = board[row][col]
        if cell == "wK":
            white_sq = sq
        elif cell == "bK":
            black_sq = sq
    return white_sq, black_sq
|
|
|
|
|
|
def gs_to_nnue_features(gs):
    """Encode the position held by ``gs`` as a list of sparse feature indices.

    Each non-king piece contributes one index, combining its piece plane
    (``PIECE_TO_INDEX``), the square of its *own* side's king and its own
    square (both numbered ``row * 8 + col``)::

        index = plane * 64 * 64 + own_king_sq * 64 + piece_sq

    Empty squares (``"--"``) and kings are skipped. A piece is also skipped
    when its own king is not on the board.

    Args:
        gs: game-state object whose ``board`` attribute is read as an 8x8
            grid ``board[row][col]`` of two-character piece codes.

    Returns:
        list[int]: feature indices, in row-major board order.

    Raises:
        KeyError: for a non-empty, non-king code missing from
            ``PIECE_TO_INDEX``.
    """
    board = gs.board
    white_king, black_king = find_king_squares(board)

    indices = []
    for sq in range(NUM_SQUARES):
        row, col = divmod(sq, 8)
        code = board[row][col]
        if code == "--" or code[1] == "K":
            continue

        # Look the plane up before the king check so unknown codes still
        # raise even when the relevant king is missing.
        plane = PIECE_TO_INDEX[code]
        own_king = white_king if code[0] == 'w' else black_king
        if own_king is None:
            continue

        indices.append((plane * NUM_SQUARES + own_king) * NUM_SQUARES + sq)

    return indices
|
|
|
|
|
|
import torch
|
|
|
|
|
|
class NNUEInfer:
    """Score single positions with an NNUE model.

    The model is moved to ``device`` and put into eval mode once, at
    construction. Each call builds a batch of one and returns the model
    output as a Python number, with autograd disabled.
    """

    def __init__(self, model, device="cpu"):
        """Wrap ``model`` for inference.

        Args:
            model: a ``torch.nn.Module`` taking ``(features, stm)`` tensors.
            device: device the model and all input tensors are placed on.
        """
        self.device = device
        placed = model.to(device)
        placed.eval()
        self.model = placed

    @torch.no_grad()
    def __call__(self, features, stm):
        """
        features : List[int]
        stm : 1 if white to move, 0 if black
        returns : float score

        An empty ``features`` list is replaced by ``[PAD_IDX]`` so the model
        always receives at least one index. The model output is converted
        with ``.item()``; presumably the model returns a single-element
        tensor -- confirm against the model definition.
        """
        indices = features if features else [PAD_IDX]

        # Batch of one: features become shape (1, len(indices)),
        # side-to-move becomes shape (1,).
        feat_batch = torch.tensor(
            indices, dtype=torch.long, device=self.device
        )[None]
        stm_batch = torch.tensor([stm], dtype=torch.long, device=self.device)

        output = self.model(feat_batch, stm_batch)
        return output.item()
|
|
|
|