# File size: 3,652 Bytes
# 98ab355
from collections import deque
import random
import torch
import torch
from engine import GameState
from move_finder import find_best_move_shallow
from infer_nnue import gs_to_nnue_features
from nnue_model import NNUE
from tqdm import tqdm
from infer_nnue import NNUEInfer
# Fixed feature-vector width expected by the NNUE input layer.
NNUE_FEATURES = 32


def pad_features(feats, size=NNUE_FEATURES):
    """Return `feats` zero-padded or truncated to exactly `size` entries.

    Args:
        feats: list of feature values/indices for one position.
        size: target length; defaults to the NNUE input width (was
            hard-coded to NNUE_FEATURES — parameterized, backward compatible).

    Returns:
        A new list of length `size`; the input list is never mutated.
    """
    if len(feats) < size:
        return feats + [0] * (size - len(feats))
    return feats[:size]
import pickle


def load_pgn_dataset(path):
    """Load consecutive pickle dumps from `path` and split them into trajectories.

    The file holds back-to-back pickle chunks, each a list of sample dicts.
    A trajectory boundary is assumed wherever a sample's side-to-move ("stm")
    differs from the previous sample's; the trailing partial trajectory, if
    any, is kept as well.
    """
    trajectories = []
    pending = []
    with open(path, "rb") as f:
        while True:
            try:
                chunk = pickle.load(f)
            except EOFError:
                # no more pickled chunks in the file
                break
            for record in chunk:
                pending.append(record)
                # heuristic: end trajectory on side-to-move flip
                flipped = (
                    len(pending) > 1
                    and pending[-1]["stm"] != pending[-2]["stm"]
                )
                if flipped:
                    trajectories.append(pending)
                    pending = []
    if pending:
        trajectories.append(pending)
    return trajectories
@torch.no_grad()  # fix: the decorator was applied twice in the original
def td_targets_from_traj(model, traj, gamma=0.99, device="cuda"):
    """Compute TD(0) value targets for one game trajectory.

    Each position's target is the discounted, sign-flipped value of the
    successor position (side to move alternates each ply), clamped to
    [-1, 1]. The terminal position bootstraps from its own evaluation.

    Args:
        model: callable mapping (features, stm) batch tensors to scalar values.
        traj: list of dicts, each with "features" (int list) and "stm".
        gamma: discount factor applied to the successor value.
        device: torch device for the batched forward pass (was hard-coded
            to "cuda"; parameterized, backward compatible).

    Returns:
        list[float] of per-position targets, same length as `traj`.
    """
    if len(traj) == 1:
        # Single position: no successor to bootstrap from.
        return [0.0]
    feats = [pad_features(x["features"]) for x in traj]
    stm = [x["stm"] for x in traj]
    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    values = model(feats, stm).view(-1)
    targets = torch.empty_like(values)
    # TD(0) with turn flip: the successor value is from the opponent's
    # perspective, so it is negated.
    targets[:-1] = gamma * (-values[1:])
    # Terminal position has no successor; use its own evaluation.
    # (The original's .detach() was a no-op under @torch.no_grad and was removed.)
    targets[-1] = values[-1]
    # value clipping (STOCKFISH STYLE): keep targets in the net's output range.
    targets = torch.clamp(targets, -1.0, 1.0)
    return targets.cpu().tolist()
from collections import deque
import random
class ReplayBuffer:
    """Bounded FIFO replay memory holding (features, stm, target) triples."""

    def __init__(self, capacity=300_000):
        # deque with maxlen silently evicts the oldest sample once full.
        self.buf = deque(maxlen=capacity)

    def add(self, f, stm, t):
        """Store a single training sample."""
        entry = (f, stm, t)
        self.buf.append(entry)

    def sample(self, n):
        """Draw `n` distinct samples uniformly at random."""
        return random.sample(self.buf, n)

    def __len__(self):
        return len(self.buf)
def train_from_replay(model, optimizer, replay, batch_size, device="cuda"):
    """Run one optimization step on a batch sampled from the replay buffer.

    Args:
        model: callable mapping (features, stm) batch tensors to scalar values.
        optimizer: torch optimizer over the model's parameters.
        replay: object exposing __len__ and sample(n) -> [(features, stm, target)].
        batch_size: number of samples drawn per step.
        device: torch device for the batch tensors (was hard-coded to "cuda";
            parameterized, backward compatible).

    Returns:
        The scalar loss as a float, or None when the buffer does not yet
        hold `batch_size` samples (no step is taken in that case).
    """
    if len(replay) < batch_size:
        return None
    batch = replay.sample(batch_size)
    feats, stm, targets = zip(*batch)
    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    targ = torch.tensor(targets, dtype=torch.float, device=device)
    preds = model(feats, stm).view(-1)
    # Huber loss is robust to the occasional wild TD target.
    loss = torch.nn.functional.smooth_l1_loss(preds, targ)
    optimizer.zero_grad(set_to_none=True)  # set_to_none frees grad memory
    loss.backward()
    optimizer.step()
    return loss.item()
from tqdm import tqdm

# ---- Script entry: TD(0) self-training loop (runs at module import) ----
device = "cuda"
model = NNUE().to(device)
# weights_only=True restricts torch.load to tensor data (no arbitrary pickle code).
model.load_state_dict(torch.load("nnue_model.pt", weights_only=True))
# Small LR: fine-tuning an already-trained net, not training from scratch.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
replay = ReplayBuffer()
trajectories = load_pgn_dataset("nnue_dataset.pkl")
for epoch in range(3):
    print(f"Epoch {epoch}")
    for traj in tqdm(trajectories):
        # Single-position trajectories carry no TD signal; skip them.
        if len(traj) < 2:
            continue
        # Targets are recomputed each epoch since the model changes between passes.
        targets = td_targets_from_traj(model, traj)
        for x, t in zip(traj, targets):
            replay.add(
                pad_features(x["features"]),
                x["stm"],
                t
            )
        # Three optimizer steps per trajectory, sampled from the replay buffer.
        # NOTE(review): indentation was ambiguous in the source paste; this
        # placement (inside the trajectory loop) matches the usual replay-
        # training pattern — confirm against the original file.
        for _ in range(3):
            train_from_replay(model, optimizer, replay, batch_size=512)
torch.save(model.state_dict(), "nnue_model_td.pt")