File size: 2,034 Bytes
98ab355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

# Maps a two-character piece code (colour prefix + piece letter) to the index
# of its feature plane. Kings are intentionally absent: they are never emitted
# as features themselves, only used as the anchor square of other pieces.
PIECE_TO_INDEX = {
    'wp': 0, 'wN': 1, 'wB': 2, 'wR': 3, 'wQ': 4,
    'bp': 5, 'bN': 6, 'bB': 7, 'bR': 8, 'bQ': 9,
}


# Derived from the table above so adding/removing a piece type cannot leave
# the feature-space size stale.
NUM_PIECES = len(PIECE_TO_INDEX)
NUM_SQUARES = 64
# One feature per (piece type, king square, piece square) triple.
NUM_FEATURES = NUM_PIECES * NUM_SQUARES * NUM_SQUARES

# One past the last real feature index; used as a filler entry when a
# position yields no features at all.
PAD_IDX = NUM_FEATURES


def find_king_squares(board):
    """Locate both kings on an 8x8 board.

    ``board`` is indexed as ``board[row][col]`` and holds piece codes such as
    ``"wK"`` / ``"bK"``.

    Returns a ``(white, black)`` tuple of square indices ``row * 8 + col``
    (0-63). An entry is ``None`` when that king is not on the board. If a
    colour somehow has several kings, the one found last in row-major order
    is reported.
    """
    white_sq = None
    black_sq = None
    for square in range(64):
        row, col = divmod(square, 8)
        cell = board[row][col]
        if cell == "wK":
            white_sq = square
        elif cell == "bK":
            black_sq = square
    return white_sq, black_sq

def gs_to_nnue_features(gs):
    """Convert a game state into a list of sparse NNUE feature indices.

    ``gs.board`` is read as an 8x8 grid (``board[row][col]``) of two-character
    piece codes, where ``"--"`` marks an empty square.

    Every non-king piece contributes exactly one index::

        piece_idx * 64 * 64 + king_sq * 64 + square

    ``piece_idx`` comes from ``PIECE_TO_INDEX``. ``square`` is
    ``row * 8 + col``. ``king_sq`` is the square of the king of the *same*
    colour as the piece. Pieces whose own king is missing from the board are
    skipped. Indices are produced in row-major square order.

    Raises KeyError if the board contains a piece code that is not in
    ``PIECE_TO_INDEX``.

    NOTE(review): each piece is keyed only to its own side's king; confirm
    this matches the feature encoding the model was trained with.
    """
    board = gs.board
    white_king, black_king = find_king_squares(board)

    indices = []
    for square in range(64):
        row, col = divmod(square, 8)
        piece = board[row][col]

        # Empty squares and kings never produce a feature.
        if piece == "--" or piece[1] == "K":
            continue

        piece_idx = PIECE_TO_INDEX[piece]
        anchor = white_king if piece[0] == 'w' else black_king
        if anchor is None:
            continue

        # Equivalent to piece_idx * 4096 + anchor * 64 + square.
        indices.append((piece_idx * 64 + anchor) * 64 + square)

    return indices

import torch

class NNUEInfer:
    """Inference-only wrapper that evaluates an NNUE model on one position.

    The model is moved to ``device`` and switched to eval mode once, at
    construction time.
    """

    def __init__(self, model, device="cpu"):
        self.device = device
        self.model = model.to(device)
        self.model.eval()

    @torch.no_grad()
    def __call__(self, features, stm):
        """Score a single position without tracking gradients.

        features : List[int]
            Sparse feature indices. An empty (falsy) value is replaced by a
            single ``PAD_IDX`` entry so the model never receives an empty row.

        stm      : 1 if white to move, 0 if black

        returns  : float score (the model output converted with ``.item()``)
        """
        indices = features if features else [PAD_IDX]

        # Add a leading batch dimension of size 1.
        batch = torch.tensor(indices, dtype=torch.long, device=self.device).unsqueeze(0)
        side_to_move = torch.tensor([stm], dtype=torch.long, device=self.device)

        output = self.model(batch, side_to_move)
        return output.item()