|
|
from collections import deque
|
|
|
import random
|
|
|
import torch
|
|
|
import torch
|
|
|
from engine import GameState
|
|
|
from move_finder import find_best_move_shallow
|
|
|
from infer_nnue import gs_to_nnue_features
|
|
|
from nnue_model import NNUE
|
|
|
from tqdm import tqdm
|
|
|
from infer_nnue import NNUEInfer
|
|
|
# Fixed length of every feature list handed to the network.
NNUE_FEATURES = 32


def pad_features(feats):
    """Return ``feats`` as a list of exactly ``NNUE_FEATURES`` entries.

    Lists that are too short are right-padded with ``0``; lists that are long
    enough are cut down to their first ``NNUE_FEATURES`` entries. The input is
    never modified; a new list is returned in both cases.

    Args:
        feats: List of feature values.

    Returns:
        A new list of length ``NNUE_FEATURES``.
    """
    # NOTE(review): the pad value 0 may also be a real feature index —
    # confirm the model treats it as padding.
    missing = NNUE_FEATURES - len(feats)
    if missing <= 0:
        return feats[:NNUE_FEATURES]
    return feats + [0] * missing
|
|
|
|
|
|
import pickle
|
|
|
|
|
|
def load_pgn_dataset(path):
    """Load pickled position records from ``path`` and group them into trajectories.

    The file is read as a series of back-to-back pickle objects until the
    stream is exhausted. Each object is iterated and its items are appended to
    the trajectory currently being built. A trajectory is closed right after
    an item whose ``"stm"`` value differs from that of the item before it; the
    differing item stays as the last element of the closed trajectory.

    Args:
        path: Path of the pickle file to read.

    Returns:
        A list of trajectories, each a non-empty list of loaded items. Items
        still pending when the file ends form one final trajectory.

    Warning:
        Unpickling can execute arbitrary code. Only call this on files from a
        trusted source.
    """
    trajectories = []
    current_traj = []

    with open(path, "rb") as f:
        while True:
            # Guard only the read: an EOFError raised while processing a chunk
            # must not be mistaken for the normal end of the file.
            try:
                chunk = pickle.load(f)
            except EOFError:
                break

            for item in chunk:
                current_traj.append(item)

                # NOTE(review): if "stm" is the side to move and alternates
                # every ply, this closes a trajectory after every second item.
                # Confirm the split rule matches how the dataset was written.
                if len(current_traj) > 1 and item["stm"] != current_traj[-2]["stm"]:
                    trajectories.append(current_traj)
                    current_traj = []

    # Keep whatever was still being accumulated when the file ended.
    if current_traj:
        trajectories.append(current_traj)

    return trajectories
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def td_targets_from_traj(model, traj, gamma=0.99):
    """Compute one bootstrapped value target per position of a trajectory.

    The model is evaluated once on all positions. The target for position
    ``i`` is ``gamma * -value[i + 1]``, i.e. the negated, discounted value of
    the following position. The last position has no successor, so its target
    is the model's own value for it. All targets are clamped to ``[-1, 1]``.

    Runs under ``torch.no_grad()``, so no gradients are recorded.

    Args:
        model: Callable taking ``(feats, stm)`` long tensors and returning one
            value per position. Inputs are created on the device of the
            model's first parameter; if that cannot be determined, CUDA is
            used.
        traj: Sequence of dicts with ``"features"`` and ``"stm"`` entries.
        gamma: Discount applied to the successor value.

    Returns:
        A list of Python floats with one target per position. An empty
        trajectory yields ``[]`` and a single-position trajectory ``[0.0]``.
    """
    if len(traj) == 0:
        # Nothing to evaluate; indexing the last target would fail otherwise.
        return []
    if len(traj) == 1:
        return [0.0]

    # Build inputs where the model lives instead of assuming a GPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    feats = torch.tensor(
        [pad_features(x["features"]) for x in traj],
        dtype=torch.long,
        device=device,
    )
    stm = torch.tensor([x["stm"] for x in traj], dtype=torch.long, device=device)

    values = model(feats, stm).view(-1)

    targets = torch.empty_like(values)
    # Successor value is negated before discounting.
    targets[:-1] = gamma * (-values[1:])
    # No successor for the final position: reuse its own value.
    targets[-1] = values[-1]

    targets = torch.clamp(targets, -1.0, 1.0)
    return targets.cpu().tolist()
|
|
|
|
|
|
|
|
|
|
|
|
from collections import deque
|
|
|
import random
|
|
|
|
|
|
class ReplayBuffer:
    """Bounded FIFO store of ``(f, stm, t)`` training tuples."""

    def __init__(self, capacity=300_000):
        """Create an empty buffer holding at most ``capacity`` entries."""
        # A deque with ``maxlen`` discards its oldest entry when full.
        self.buf = deque(maxlen=capacity)

    def add(self, f, stm, t):
        """Append one ``(f, stm, t)`` tuple to the buffer."""
        entry = (f, stm, t)
        self.buf.append(entry)

    def sample(self, n):
        """Return a list of ``n`` entries drawn without replacement.

        Raises:
            ValueError: If ``n`` exceeds the number of stored entries.
        """
        return random.sample(self.buf, k=n)

    def __len__(self):
        """Return the number of entries currently stored."""
        return len(self.buf)
|
|
|
|
|
|
|
|
|
def train_from_replay(model, optimizer, replay, batch_size):
    """Run one optimizer step on a batch sampled from ``replay``.

    Does nothing when the buffer holds fewer than ``batch_size`` entries.
    Otherwise a batch of ``(features, stm, target)`` tuples is sampled, the
    model's predictions are fitted to the targets with a smooth L1 loss, and
    a single optimizer step is taken.

    Args:
        model: Callable taking ``(feats, stm)`` long tensors and returning one
            prediction per row. Batch tensors are created on the device of
            the model's first parameter; if that cannot be determined, CUDA
            is used.
        optimizer: Optimizer over the model's parameters.
        replay: Object supporting ``len()`` and ``sample(n)``.
        batch_size: Number of entries to sample.

    Returns:
        None.
    """
    if len(replay) < batch_size:
        return

    # Build the batch where the model lives instead of assuming a GPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    batch = replay.sample(batch_size)
    # Transpose the list of tuples into three parallel columns.
    feats, stm, targets = zip(*batch)

    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    targ = torch.tensor(targets, dtype=torch.float, device=device)

    preds = model(feats, stm).view(-1)
    loss = torch.nn.functional.smooth_l1_loss(preds, targ)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
|
|
|
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
# ---------------------------------------------------------------------------
# TD fine-tuning script. Everything below runs when the module is executed.
# NOTE(review): the nesting of the loops was reconstructed from statement
# order — confirm that training happens once per trajectory and that the
# model is saved once after all epochs.
# ---------------------------------------------------------------------------

device = "cuda"

# Start from the weights stored in "nnue_model.pt".
model = NNUE().to(device)
model.load_state_dict(torch.load("nnue_model.pt", weights_only=True))
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

replay = ReplayBuffer()
trajectories = load_pgn_dataset("nnue_dataset.pkl")

for epoch in range(3):
    print(f"Epoch {epoch}")

    for traj in tqdm(trajectories):
        # Trajectories with fewer than two positions are skipped.
        if len(traj) < 2:
            continue

        # Targets are computed with the current model weights.
        targets = td_targets_from_traj(model, traj)

        for x, t in zip(traj, targets):
            replay.add(pad_features(x["features"]), x["stm"], t)

        # Three optimizer steps per processed trajectory.
        for _ in range(3):
            train_from_replay(model, optimizer, replay, batch_size=512)

torch.save(model.state_dict(), "nnue_model_td.pt")
|
|
|
|