"""TD(0) fine-tuning of an NNUE model from a pickled position dataset."""

import pickle
import random
from collections import deque

import torch
from tqdm import tqdm

from engine import GameState
from infer_nnue import NNUEInfer
from infer_nnue import gs_to_nnue_features
from move_finder import find_best_move_shallow
from nnue_model import NNUE

NNUE_FEATURES = 32


def pad_features(feats):
    """Return *feats* as a list of exactly ``NNUE_FEATURES`` entries.

    Shorter inputs are right-padded with ``0``; longer inputs are truncated.
    """
    if len(feats) < NNUE_FEATURES:
        return feats + [0] * (NNUE_FEATURES - len(feats))
    return feats[:NNUE_FEATURES]


def load_pgn_dataset(path):
    """Load pickled chunks from *path* and group their items into trajectories.

    The file is read as a sequence of consecutive pickle records until EOF.
    Each record is iterated, and every item must expose an ``"stm"`` key.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only use this on
    dataset files you generated yourself, never on untrusted input.

    Args:
        path: Path of the pickle file to read.

    Returns:
        A list of trajectories, each a non-empty list of items.
    """
    trajectories = []
    current_traj = []

    with open(path, "rb") as f:
        while True:
            # Only the load itself signals end-of-file; keep the try minimal.
            try:
                chunk = pickle.load(f)
            except EOFError:
                break

            for item in chunk:
                current_traj.append(item)

                # Heuristic: end the trajectory on a side-to-move flip.
                # NOTE(review): if "stm" alternates on every item, this cuts
                # the data into two-position trajectories -- confirm that is
                # the intended segmentation for this dataset.
                if (
                    len(current_traj) > 1
                    and current_traj[-1]["stm"] != current_traj[-2]["stm"]
                ):
                    trajectories.append(current_traj)
                    current_traj = []

    # Keep whatever is left over after the last flush.
    if current_traj:
        trajectories.append(current_traj)

    return trajectories


def _model_device(model):
    """Return the device of *model*'s first parameter (CUDA if it has none)."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cuda")


@torch.no_grad()
def td_targets_from_traj(model, traj, gamma=0.99):
    """Compute clipped TD(0) targets for every position of a trajectory.

    For all but the last position the target is ``gamma * -V(next)``; the
    sign flip accounts for the turn change between consecutive positions.
    The last position has no successor, so its target is the model's own
    prediction for it. All targets are clamped to ``[-1, 1]``.

    Args:
        model: Callable taking ``(features, stm)`` long tensors and returning
            one value per position.
        traj: Sequence of items with ``"features"`` and ``"stm"`` keys.
        gamma: Discount factor applied to the successor value.

    Returns:
        A list of floats, one target per position. An empty trajectory gives
        ``[]`` and a single-position trajectory gives ``[0.0]``.
    """
    if not traj:
        return []
    if len(traj) == 1:
        return [0.0]

    # Build inputs on the device the model lives on instead of assuming CUDA.
    device = _model_device(model)

    feats = torch.tensor(
        [pad_features(x["features"]) for x in traj],
        dtype=torch.long,
        device=device,
    )
    stm = torch.tensor(
        [x["stm"] for x in traj],
        dtype=torch.long,
        device=device,
    )

    values = model(feats, stm).view(-1)

    targets = torch.empty_like(values)

    # TD(0) with turn flip
    targets[:-1] = gamma * (-values[1:])
    # No detach needed: gradients are already disabled by the decorator.
    targets[-1] = values[-1]

    # value clipping (STOCKFISH STYLE)
    targets = torch.clamp(targets, -1.0, 1.0)

    return targets.cpu().tolist()


class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(features, stm, target)`` samples."""

    def __init__(self, capacity=300_000):
        """Create an empty buffer that evicts the oldest entry past *capacity*."""
        self.buf = deque(maxlen=capacity)

    def add(self, f, stm, t):
        """Append one ``(features, stm, target)`` sample."""
        self.buf.append((f, stm, t))

    def sample(self, n):
        """Return *n* distinct samples chosen uniformly at random.

        Raises:
            ValueError: If *n* exceeds the number of stored samples.
        """
        return random.sample(self.buf, n)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.buf)


def train_from_replay(model, optimizer, replay, batch_size):
    """Run one optimizer step on a random batch drawn from *replay*.

    Does nothing while the buffer holds fewer than *batch_size* samples.
    The loss is smooth L1 between the model's predictions and the stored
    targets.

    Args:
        model: Callable taking ``(features, stm)`` long tensors.
        optimizer: Optimizer stepping the model's parameters.
        replay: Buffer providing ``sample(n)`` and ``len()``.
        batch_size: Number of samples per training step.
    """
    if len(replay) < batch_size:
        return

    batch = replay.sample(batch_size)
    feats, stm, targets = zip(*batch)

    # Build inputs on the device the model lives on instead of assuming CUDA.
    device = _model_device(model)

    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    targ = torch.tensor(targets, dtype=torch.float, device=device)

    preds = model(feats, stm).view(-1)
    loss = torch.nn.functional.smooth_l1_loss(preds, targ)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


device = "cuda"

model = NNUE().to(device)
model.load_state_dict(torch.load("nnue_model.pt", weights_only=True))

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
replay = ReplayBuffer()

trajectories = load_pgn_dataset("nnue_dataset.pkl")

for epoch in range(3):
    print(f"Epoch {epoch}")

    for traj in tqdm(trajectories):
        if len(traj) < 2:
            continue

        # Targets come from the current model, so they are refreshed each epoch.
        targets = td_targets_from_traj(model, traj)

        for x, t in zip(traj, targets):
            replay.add(
                pad_features(x["features"]),
                x["stm"],
                t
            )

        # A few gradient steps per ingested trajectory.
        for _ in range(3):
            train_from_replay(model, optimizer, replay, batch_size=512)

torch.save(model.state_dict(), "nnue_model_td.pt")