# nnue/final/games_play.py
# Uploaded by hash-map ("Upload 40 files", commit 98ab355, verified).
from collections import deque
import random
import torch
import torch
from engine import GameState
from move_finder import find_best_move_shallow
from infer_nnue import gs_to_nnue_features
from nnue_model import NNUE
from tqdm import tqdm
from infer_nnue import NNUEInfer
NNUE_FEATURES = 32


def pad_features(feats, size=NNUE_FEATURES, pad_value=0):
    """Return *feats* as a new list of exactly ``size`` feature indices.

    Shorter inputs are right-padded with ``pad_value``. Longer inputs are
    truncated to their first ``size`` entries. The input is never mutated and
    may be any iterable (list, tuple, ...), not only a list.

    Args:
        feats: iterable of feature indices for one position.
        size: target length; defaults to ``NNUE_FEATURES``.
        pad_value: filler for missing slots. NOTE(review): with the default
            of 0, padding cannot be told apart from a real feature index 0
            unless the model reserves that index -- confirm.

    Returns:
        A new list of length ``size``.
    """
    # list(...) copies the input and accepts tuples. The original
    # `feats + [0] * k` raised TypeError for non-list sequences.
    padded = list(feats)[:size]
    padded.extend([pad_value] * (size - len(padded)))
    return padded
import pickle
def load_pgn_dataset(path):
    """Read a stream of pickled record batches and split it into trajectories.

    ``path`` must hold zero or more consecutive ``pickle`` dumps, each an
    iterable of record dicts that carry at least an ``"stm"`` key. Records are
    concatenated in file order, so a trajectory may span batch boundaries. A
    trajectory is closed as soon as its newest record's ``"stm"`` differs from
    the record before it. Whatever remains at end of file becomes the final
    trajectory.

    NOTE(review): if consecutive records always alternate ``"stm"`` (one record
    per ply), this rule closes a trajectory after every second record, so all
    trajectories have length <= 2 -- confirm this split is intended.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    trusted dataset files.

    Returns:
        A list of trajectories, each a list of the original record objects.
    """
    finished = []
    pending = []
    with open(path, "rb") as fh:
        while True:
            # The stream is a sequence of independent dumps; EOF ends it.
            try:
                batch = pickle.load(fh)
            except EOFError:
                break
            for record in batch:
                pending.append(record)
                # Heuristic: end the trajectory on a side-to-move flip.
                flipped = len(pending) >= 2 and pending[-1]["stm"] != pending[-2]["stm"]
                if flipped:
                    finished.append(pending)
                    pending = []
    if pending:
        finished.append(pending)
    return finished
@torch.no_grad()
def td_targets_from_traj(model, traj, gamma=0.99, device=None):
    """Compute clipped TD(0) value targets for each position of ``traj``.

    For every position except the last, the target is the model's value of
    the next position. That value is negated to account for the
    side-to-move flip and discounted by ``gamma``. The last position is
    bootstrapped from the model's own current value. All targets are clamped
    to [-1, 1]. No autograd graph is built.

    Args:
        model: callable ``model(features, stm)`` returning one value per row
            (its output is flattened with ``view(-1)``). When it exposes
            ``parameters()`` (e.g. a ``torch.nn.Module``), its device is used
            for the inputs.
        traj: sequence of dicts with ``"features"`` and ``"stm"`` entries.
        gamma: discount factor applied to the bootstrapped successor value.
        device: device for the input tensors. Defaults to the device of the
            model's first parameter, or ``"cuda"`` if that cannot be found.

    Returns:
        A list of floats with one target per position. An empty trajectory
        gives ``[]``; a single position gives ``[0.0]``.
    """
    if not traj:
        return []
    if len(traj) == 1:
        return [0.0]
    if device is None:
        # Follow the model's device instead of assuming CUDA, so CPU models work.
        params = getattr(model, "parameters", None)
        first_param = next(iter(params()), None) if callable(params) else None
        device = first_param.device if first_param is not None else "cuda"

    feats = torch.tensor(
        [pad_features(x["features"]) for x in traj], dtype=torch.long, device=device
    )
    stm = torch.tensor([x["stm"] for x in traj], dtype=torch.long, device=device)
    values = model(feats, stm).view(-1)

    targets = torch.empty_like(values)
    # TD(0) with turn flip: the successor's value, negated and discounted.
    targets[:-1] = gamma * (-values[1:])
    # NOTE(review): the final target is the model's own current estimate, so
    # this position contributes (almost) no learning signal, and no game
    # result is used anywhere.
    targets[-1] = values[-1]
    # Clip targets to the [-1, 1] value range.
    targets = torch.clamp(targets, -1.0, 1.0)
    return targets.cpu().tolist()
from collections import deque
import random
class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(features, stm, target)`` training tuples.

    Once ``capacity`` entries are held, each new entry evicts the oldest one.
    """

    def __init__(self, capacity=300_000):
        # A bounded deque discards from the left automatically when full.
        self.buf = deque((), capacity)

    def add(self, f, stm, t):
        """Append one training example: features, side to move, target."""
        record = (f, stm, t)
        self.buf.append(record)

    def sample(self, n):
        """Return a list of ``n`` stored entries drawn without replacement.

        ``random.sample`` raises ``ValueError`` if ``n`` exceeds the buffer size.
        """
        return random.sample(population=self.buf, k=n)

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.buf)
def train_from_replay(model, optimizer, replay, batch_size):
    """Run one optimizer step on a random minibatch drawn from ``replay``.

    Args:
        model: ``torch.nn.Module`` called as ``model(features, stm)``. Its
            output is flattened with ``view(-1)`` and regressed onto the
            stored targets with Smooth-L1 loss.
        optimizer: optimizer over ``model``'s parameters.
        replay: buffer exposing ``len()`` and ``sample(n)``, where each sample
            is a ``(features, stm, target)`` tuple.
        batch_size: number of samples per step.

    Returns:
        The detached loss tensor of the step taken, or ``None`` when no step
        is taken (non-positive ``batch_size`` or too few stored samples).
    """
    # batch_size <= 0 used to crash when unpacking an empty sample.
    if batch_size <= 0 or len(replay) < batch_size:
        return None
    # Put inputs on the model's device instead of assuming CUDA.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else "cuda"

    feats, stm, targets = zip(*replay.sample(batch_size))
    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    targ = torch.tensor(targets, dtype=torch.float, device=device)

    preds = model(feats, stm).view(-1)
    loss = torch.nn.functional.smooth_l1_loss(preds, targ)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Detached so callers can log it without keeping the graph alive.
    return loss.detach()
from tqdm import tqdm
# --- Training script: TD fine-tuning of a pretrained NNUE on stored trajectories ---
# Fall back to CPU when no CUDA device is present instead of failing in .to().
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NNUE().to(device)
# map_location lets a checkpoint saved on a different device load here.
model.load_state_dict(
    torch.load("nnue_model.pt", map_location=device, weights_only=True)
)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
replay = ReplayBuffer()
trajectories = load_pgn_dataset("nnue_dataset.pkl")

for epoch in range(3):
    print(f"Epoch {epoch}")
    for traj in tqdm(trajectories):
        # A single position has no successor to bootstrap a TD target from.
        if len(traj) < 2:
            continue
        # Targets use the model's current weights at the time the trajectory is visited.
        targets = td_targets_from_traj(model, traj)
        for x, t in zip(traj, targets):
            replay.add(pad_features(x["features"]), x["stm"], t)
        # A few gradient steps per trajectory; these are no-ops until the buffer
        # holds a full batch.
        for _ in range(3):
            train_from_replay(model, optimizer, replay, batch_size=512)
    # Checkpoint after every epoch so an interrupted run keeps its progress.
    # The file left after the last epoch is the same one the original wrote.
    torch.save(model.state_dict(), "nnue_model_td.pt")