# File size: 3,652 Bytes
# 98ab355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from collections import deque
import random
import torch
import torch
from engine import GameState
from move_finder import find_best_move_shallow
from infer_nnue import gs_to_nnue_features
from nnue_model import NNUE
from tqdm import tqdm
from infer_nnue import NNUEInfer
NNUE_FEATURES = 32


def pad_features(feats, size=NNUE_FEATURES):
    """Pad or truncate a feature-index list to exactly *size* entries.

    Args:
        feats: list of integer feature indices.
        size: target length; defaults to NNUE_FEATURES so existing
            callers are unaffected.

    Returns:
        A list of length *size*: right-padded with zeros when *feats*
        is short, truncated when it is long.
    """
    if len(feats) < size:
        return feats + [0] * (size - len(feats))
    return feats[:size]

import pickle

def load_pgn_dataset(path):
    """Load training samples from a file of pickled chunks.

    The file holds consecutive pickled lists of sample dicts. Samples
    are accumulated in order; whenever the side-to-move ("stm") of the
    newest sample differs from the one before it, the accumulated run
    (including the newest sample) is closed out as one trajectory.
    Any trailing samples form a final trajectory.

    Returns:
        list of trajectories, each a list of sample dicts.
    """
    trajectories = []
    pending = []

    with open(path, "rb") as handle:
        while True:
            try:
                chunk = pickle.load(handle)
            except EOFError:
                break

            for sample in chunk:
                pending.append(sample)

                # Heuristic trajectory boundary: side-to-move flipped.
                flipped = (
                    len(pending) > 1
                    and pending[-1]["stm"] != pending[-2]["stm"]
                )
                if flipped:
                    trajectories.append(pending)
                    pending = []

    if pending:
        trajectories.append(pending)

    return trajectories


@torch.no_grad()
def td_targets_from_traj(model, traj, gamma=0.99):
    """Compute TD(0) value targets for one trajectory.

    Each position's target is the discounted, sign-flipped value of
    the next position (the side to move alternates, so the successor
    value is from the opponent's point of view). The last position
    bootstraps from its own evaluation.

    Args:
        model: NNUE network mapping (features, stm) -> value.
        traj: list of sample dicts with "features" and "stm" keys.
        gamma: TD discount factor.

    Returns:
        list of float targets, one per trajectory entry (clamped to
        [-1, 1]); [0.0] for a single-entry trajectory, [] for empty.
    """
    if not traj:
        return []
    if len(traj) == 1:
        return [0.0]

    # Run on the model's own device instead of assuming CUDA is present.
    device = next(model.parameters()).device

    feats = [pad_features(x["features"]) for x in traj]
    stm = [x["stm"] for x in traj]

    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)

    values = model(feats, stm).view(-1)

    targets = torch.empty_like(values)

    # TD(0) with turn flip: successor value is negated (opponent's view).
    # No .detach() needed — the whole function runs under no_grad.
    targets[:-1] = gamma * (-values[1:])
    targets[-1] = values[-1]

    # value clipping (Stockfish style): keep targets in the net's range
    targets = torch.clamp(targets, -1.0, 1.0)

    return targets.cpu().tolist()



from collections import deque
import random

class ReplayBuffer:
    """Fixed-capacity FIFO store of (features, stm, target) samples.

    Once *capacity* samples are held, each new addition silently
    evicts the oldest sample (bounded-deque semantics).
    """

    def __init__(self, capacity=300_000):
        self._samples = deque(maxlen=capacity)

    def add(self, f, stm, t):
        """Append one (features, stm, target) sample."""
        self._samples.append((f, stm, t))

    def sample(self, n):
        """Return *n* samples drawn uniformly without replacement."""
        return random.sample(self._samples, n)

    def __len__(self):
        return len(self._samples)


def train_from_replay(model, optimizer, replay, batch_size):
    """Run one optimization step on a batch drawn from the replay buffer.

    No-op when the buffer holds fewer than *batch_size* samples.

    Args:
        model: NNUE network mapping (features, stm) -> value.
        optimizer: optimizer over *model*'s parameters.
        replay: buffer exposing ``__len__`` and ``sample(n)`` returning
            (features, stm, target) tuples.
        batch_size: number of samples per gradient step.
    """
    if len(replay) < batch_size:
        return

    batch = replay.sample(batch_size)
    feats, stm, targets = zip(*batch)

    # Train on the model's own device rather than assuming CUDA.
    device = next(model.parameters()).device

    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    targ = torch.tensor(targets, dtype=torch.float, device=device)

    preds = model(feats, stm).view(-1)

    # Huber-style loss is robust to noisy TD targets.
    loss = torch.nn.functional.smooth_l1_loss(preds, targ)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


from tqdm import tqdm

# TD fine-tuning entry point: load the pretrained net, build replay
# data from the pickled dataset, and train with TD(0) targets.
# Prefer CUDA when present but don't hard-require it.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = NNUE().to(device)
# map_location keeps a CUDA-saved checkpoint loadable on CPU hosts.
model.load_state_dict(
    torch.load("nnue_model.pt", map_location=device, weights_only=True)
)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

replay = ReplayBuffer()
trajectories = load_pgn_dataset("nnue_dataset.pkl")

for epoch in range(3):
    print(f"Epoch {epoch}")

    for traj in tqdm(trajectories):
        # Single-position trajectories carry no TD signal.
        if len(traj) < 2:
            continue

        targets = td_targets_from_traj(model, traj)

        for x, t in zip(traj, targets):
            replay.add(
                pad_features(x["features"]),
                x["stm"],
                t
            )

        # A few gradient steps per trajectory mixes fresh and replayed data.
        for _ in range(3):
            train_from_replay(model, optimizer, replay, batch_size=512)

    # Checkpoint after every epoch so progress survives interruption.
    torch.save(model.state_dict(), "nnue_model_td.pt")