|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import onnx
|
|
|
import sys
|
|
|
# Force UTF-8 on stdout so later prints don't fail on platform default codecs
# (e.g. cp1252 on Windows consoles).
sys.stdout.reconfigure(encoding='utf-8')
|
|
|
|
|
|
# Size of the sparse input feature space: 10 * 64 * 64 = 40960.
# Presumably piece-plane x square x king-square style NNUE features —
# TODO confirm against the feature encoder that produces the indices.
NUM_FEATURES = 10 * 64 * 64



# Reserved padding index: one past the last real feature. Rows padded with
# this index map to the zeroed padding embedding and contribute nothing.
PAD_IDX = NUM_FEATURES
|
|
|
|
|
|
class NNUE(nn.Module):
    """Small NNUE-style evaluation network.

    A sparse set of active feature indices is embedded and summed into a
    256-dim accumulator, a side-to-move scalar is appended, and the result
    is passed through a small fully-connected tower yielding one scalar
    score per position.
    """

    def __init__(self, num_features=None):
        """Build the network.

        Args:
            num_features: Size of the input feature space. Defaults to the
                module-level ``NUM_FEATURES`` (resolved at call time so the
                original no-argument behavior is unchanged). Index
                ``num_features`` is reserved as the padding slot.
        """
        super().__init__()
        if num_features is None:
            num_features = NUM_FEATURES
        # One extra row is the padding slot; padding_idx pins its embedding
        # to zeros so padded entries add nothing to the feature sum.
        self.embed = nn.Embedding(num_features + 1, 256, padding_idx=num_features)
        # +1 input column for the side-to-move scalar appended in forward().
        self.fc1 = nn.Linear(256 + 1, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, feats, stm):
        """Score a batch of positions.

        Args:
            feats: Long tensor of active feature indices, shape
                (batch, n_active); unused slots should hold the padding index.
            stm: Long tensor of side-to-move flags in {0, 1}, shape (batch,).

        Returns:
            Float tensor of shape (batch,) with one evaluation per position.
        """
        # Sum embeddings of all active features (padding rows are zero).
        x = self.embed(feats).sum(dim=1)
        # Map {0, 1} -> {-1.0, +1.0} and make it a (batch, 1) column.
        stm = stm.float().unsqueeze(1) * 2 - 1
        x = torch.cat([x, stm], dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x).squeeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instantiate the network with freshly initialised (untrained) weights.
torch_model = NNUE()



# Switch to inference mode before export; a no-op for this architecture
# (no dropout/batchnorm layers), but good hygiene.
torch_model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dummy inputs for tracing/export: one position with a fixed count of
# active feature indices, plus its side-to-move flag.
BATCH = 1
ACTIVE_FEATURES = 32

# Random feature indices in [0, NUM_FEATURES), shape (BATCH, ACTIVE_FEATURES).
feats = torch.randint(0, NUM_FEATURES, (BATCH, ACTIVE_FEATURES), dtype=torch.long)

# Side to move: 0 or 1 per batch element.
stm = torch.randint(0, 2, (BATCH,), dtype=torch.long)

# Bundled in the argument order NNUE.forward() expects.
example_inputs = (feats, stm)
|
|
|
|