chengscott committed on
Commit
026a224
·
1 Parent(s): e298a77
Files changed (4) hide show
  1. train/README +11 -0
  2. train/dataloader.py +85 -0
  3. train/model.py +165 -0
  4. train/train.py +123 -0
train/README ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Train
2
+
3
+ ## Dependencies
4
+
5
+ python 3.13
6
+ pytorch 2.6.0+cu126
7
+ libzmq
8
+
9
+ ## Run
10
+
11
+ python train.py --game go9
train/dataloader.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ import os
3
+ import re
4
+ from time import time
5
+ import torch
6
+ from rb import ReplayBuffer, SocketManager
7
+
8
+
9
@contextmanager
def timer(header):
    """Print ``header`` followed by the elapsed wall-clock seconds on normal exit."""
    begin = time()
    yield
    print(header, time() - begin)
14
+
15
+
16
class DataLoader:
    """Iterable source of training batches backed by a native replay buffer.

    Selfplay mode: waits (via ``SocketManager``) for workers to finish an
    iteration of games before ingesting their SGF output.  Pretrain mode:
    no workers exist, so a local ZeroMQ DEALER socket impersonates one and
    immediately reports every iteration as finished.

    NOTE(review): reconstructed from a whitespace-mangled paste — statement
    nesting in ``load`` (the sample-drain loop) should be verified against
    the original file.
    """

    def __init__(self, port, cpus, batch_size, sgf_prefix, is_selfplay):
        # Derive the buffer's RNG seed from torch's global RNG so a
        # torch-seeded run is reproducible end to end.
        seed = torch.randint(0, 2**32, [1]).item()
        rb = ReplayBuffer()
        # Keep only the most recent 20 selfplay iterations in the buffer.
        rb.max_iters = 20
        # moves_per_iter = #games * #moves_per_game
        rb.moves_per_iter = 5000 * 60
        rb.run(seed, cpus, batch_size)
        sock = SocketManager()
        sock.run(port)
        self._rb = rb
        self._sock = sock
        # Batches yielded per iteration:
        # total = moves_per_iter * #isom / batch_size (#isom is 1 here)
        self._total = int(rb.moves_per_iter * 1 / batch_size)
        self._sgf_prefix = sgf_prefix
        self._iter = 0
        self._is_selfplay = is_selfplay
        if not is_selfplay:
            # Pretraining: emulate a selfplay worker over ZeroMQ so the
            # SocketManager's handshake protocol still runs.
            import zmq

            self._ctx = zmq.Context()
            zsock = self._ctx.socket(zmq.DEALER)
            zsock.setsockopt(zmq.LINGER, 0)  # do not block on close
            zsock.setsockopt(zmq.ROUTING_ID, b"0")
            zsock.connect(f"tcp://127.0.0.1:{port}")
            zsock.send_multipart([b"0", b"0"])
            self._zsock = zsock

    def load(self, sgf_prefix, epoch_ckpt):
        """Fast-forward through ``epoch_ckpt`` past iterations when resuming.

        Only the first iteration and the trailing ``max_iters`` window are
        re-read from SGF (anything older would have been evicted anyway);
        the sampler is still advanced every iteration so its state mimics a
        continuous run — TODO confirm this matches the buffer's semantics.
        """
        rb, sock = self._rb, self._sock
        for i in range(epoch_ckpt):
            if self._iter > 0:
                sock.notify()
            if i == 0 or i + rb.max_iters >= epoch_ckpt:
                print(f"[{i:3d}] Load selfplay")
                # Files are named "iter-<iter>-<game>.sgf"; the node count is
                # one past the highest game index found for this iteration.
                pattern = re.compile(rf"iter-{self._iter}-(\d+).sgf")
                nodes = 1 + max(
                    int(m.group(1))
                    for f in os.listdir(sgf_prefix)
                    if (m := pattern.search(f))
                )
                rb.add_iter(sgf_prefix, self._iter, nodes)
            # Draw and immediately free one epoch's worth of samples.
            for _ in range(self._total):
                rb.sample().free()
            self._iter += 1

    def __del__(self):
        # Best-effort shutdown of the native worker threads/processes.
        self._rb.terminate()
        self._sock.terminate()

    def __iter__(self):
        """Yield ``self._total`` batches for the current iteration.

        Each yielded sample must not be retained by the caller: it is freed
        as soon as the loop resumes after the ``yield``.
        """
        rb, sock = self._rb, self._sock
        if self._iter > 0:
            sock.notify()
        if self._is_selfplay:
            with timer("[{:3d}] Time for selfplay:".format(self._iter)):
                if sock.wait():
                    exit(0)  # SIGINT
        else:
            # Pretraining: report the iteration as already complete.
            finished = rb.moves_per_iter + 1
            self._zsock.send_multipart(
                [bytes(str(self._iter), "utf-8"), bytes(str(finished), "utf-8")]
            )
        with timer("[{:3d}] Time for training:".format(self._iter)):
            rb.add_iter(self._sgf_prefix, self._iter, sock.nodes)
            for _ in range(self._total):
                sample = rb.sample()
                yield sample
                sample.free()
            self._iter += 1
train/model.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
class BasicBlock(nn.Module):
    """Post-activation residual block: two k x k conv+BN layers plus identity skip.

    The identity skip requires ``in_channels == channels``; no projection is
    applied on the shortcut path.
    """

    def __init__(self, in_channels, channels, bias, k=3, p=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, k, stride=1, padding=p, bias=bias)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, k, stride=1, padding=p, bias=bias)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        # Residual branch: conv-bn-relu-conv-bn, then add and final ReLU.
        residual = self.relu1(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return self.relu2(x + residual)
23
+
24
+
25
class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, with identity skip.

    The middle stage runs at ``channels // 2``; the identity skip requires
    ``in_channels == channels``.
    """

    def __init__(self, in_channels, channels, bias):
        super().__init__()
        mid_channels = channels // 2
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
        self.bn3 = nn.BatchNorm2d(channels)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        # Branch: reduce, transform, expand; last BN has no ReLU before the add.
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu3(x + out)
51
+
52
+
53
class Bottlenest(nn.Module):
    """Bottleneck whose middle stage nests two BasicBlock-style residual pairs.

    Structure: 1x1 reduce (conv0/bn0, no activation) -> two inner residual
    pairs of 3x3 convs at ``channels // 2`` -> 1x1 expand (conv5/bn5) ->
    outer identity skip.  Requires ``in_channels == channels``.
    """

    def __init__(self, in_channels, channels, bias):
        super().__init__()
        mid_channels = channels // 2
        # Entry 1x1 projection (activated only after the first inner add).
        self.conv0 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
        self.bn0 = nn.BatchNorm2d(mid_channels)
        # Four 3x3 convs forming the two inner residual pairs.  Registration
        # order (conv{i}, bn{i}, relu{i}) matches a hand-unrolled version, so
        # parameter init and state_dict keys are unchanged.
        for i in range(1, 5):
            setattr(
                self,
                f"conv{i}",
                nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias),
            )
            setattr(self, f"bn{i}", nn.BatchNorm2d(mid_channels))
            setattr(self, f"relu{i}", nn.ReLU())
        # Exit 1x1 expansion back to `channels`.
        self.conv5 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
        self.bn5 = nn.BatchNorm2d(channels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        inner = self.bn0(self.conv0(x))
        # Inner residual pair 1.
        branch = self.relu1(self.bn1(self.conv1(inner)))
        branch = self.bn2(self.conv2(branch))
        inner = self.relu2(inner + branch)
        # Inner residual pair 2.
        branch = self.relu3(self.bn3(self.conv3(inner)))
        branch = self.bn4(self.conv4(branch))
        inner = self.relu4(inner + branch)
        # Expand and apply the outer skip.
        out = self.bn5(self.conv5(inner))
        return self.relu5(x + out)
97
+
98
+
99
class ResNet(nn.Module):
    """Convolutional trunk: a 5x5 stem followed by ``layers`` residual blocks.

    ``block`` is a factory called as ``block(channels, channels, bias)`` for
    each residual layer (e.g. BasicBlock, Bottleneck, Bottlenest).
    """

    def __init__(self, block, in_channels, layers, channels, bias):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, channels, kernel_size=5, stride=1, padding=2, bias=bias
            ),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        # Sequential indexes submodules numerically just like ModuleList, so
        # checkpoint state_dict keys ("convs.0. ...") are unchanged.
        self.convs = nn.Sequential(
            *(block(channels, channels, bias) for _ in range(layers))
        )

    def forward(self, x):
        return self.convs(self.conv1(x))
118
+
119
+
120
class AlphaZero(nn.Module):
    """AlphaZero-style network: shared ResNet trunk with policy and value heads.

    ``board_size`` is the board area (e.g. 81 for 9x9 — presumably; confirm
    against the caller), ``moves`` the policy-head width, and ``value_heads``
    the number of tanh-bounded value outputs.
    """

    def __init__(
        self,
        in_channels,
        layers,
        channels,
        moves,
        board_size,
        value_heads=1,
        bias=False,
        block=BasicBlock,
    ):
        super().__init__()
        self.board_size = board_size
        self.resnet = ResNet(block, in_channels, layers, channels, bias)
        # Policy head: 2-channel 1x1 conv, then a linear map over the
        # flattened planes to move logits.
        self.policy_head_front = nn.Sequential(
            nn.Conv2d(channels, 2, 1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
        )
        self.policy_head_end = nn.Linear(2 * board_size, moves)
        # Value head: 1-channel 1x1 conv, then a small MLP with tanh output
        # in [-1, 1].
        self.value_head_front = nn.Sequential(
            nn.Conv2d(channels, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        self.value_head_end = nn.Sequential(
            nn.Linear(board_size, channels),
            nn.ReLU(),
            nn.Linear(channels, value_heads),
            nn.Tanh(),
        )

    def forward(self, x):
        features = self.resnet(x)
        # Policy head.
        policy = self.policy_head_front(features)
        policy = self.policy_head_end(policy.view(-1, 2 * self.board_size))
        # Value head.
        value = self.value_head_front(features)
        value = self.value_head_end(value.view(-1, self.board_size))
        return policy, value
train/train.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataloader import DataLoader
2
+ from model import AlphaZero, BasicBlock, Bottlenest
3
+ #from export_ait import save_ait
4
+ import argparse
5
+ import os
6
+ import re
7
+ import time
8
+ import torch
9
+ from torch import nn
10
+
11
# Per-game model/encoding configuration.
# num_features: input feature planes (assumes the encoder emits this many —
#   TODO confirm against the selfplay side)
# moves: policy-head width; for Go this is board area + 1 (pass move),
#   for NoGo it equals the area (no pass)
# board_size: number of intersections, i.e. the board area (81 = 9x9)
# value_heads: number of value outputs; 31 presumably covers a range of
#   komi/score targets — verify against the consumer
kGames = dict(
    nogo=dict(num_features=4, moves=81, board_size=81, value_heads=1),
    go9=dict(num_features=20, moves=82, board_size=81, value_heads=31),
    go19=dict(num_features=20, moves=362, board_size=361, value_heads=31),
)
16
+
17
+
18
def save_model(model_prefix, epoch, net, optimizer, moves, board_size):
    """Checkpoint network and optimizer state to ``<model_prefix>/model-<epoch>.ckpt``.

    The net is switched to eval mode for serialization and back to train mode
    afterwards.  ``moves`` and ``board_size`` are only consumed by the
    (currently disabled) AIT export below.
    """
    net.eval()
    state = {
        "epoch": epoch,
        "net": net.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(state, f"{model_prefix}/model-{epoch}.ckpt")
    # save_ait(state["net"], moves, board_size, f"{model_prefix}/model-{epoch}.ait")
    net.train()
31
+
32
+
33
def main(args):
    """Build the network from CLI args and run the training loop.

    Requires a CUDA device (the model is moved with ``.cuda()``).  The
    architecture (depth/width/block type) is parsed out of
    ``args.model_prefix``, e.g. "models_b6c96" -> 6 blocks, 96 channels,
    BasicBlock; any trailing suffix after the digits selects Bottlenest.
    """
    torch.backends.cudnn.benchmark = True
    game = kGames[args.game]
    moves, board_size = game["moves"], game["board_size"]
    # "b<layers>c<channels><suffix>" encoded in the model directory name.
    layers, channels, block = re.search(r"b(\d+)c(\d+)(.*)", args.model_prefix).groups()
    block = BasicBlock if block == "" else Bottlenest
    net = AlphaZero(
        in_channels=game["num_features"],
        layers=int(layers),
        channels=int(channels),
        moves=moves,
        board_size=board_size,
        value_heads=game["value_heads"],
        bias=False,
        block=block,
    ).cuda()
    # loss fn: policy loss is cross-entropy against a (soft) label
    # distribution; value loss is plain MSE on the tanh outputs.
    p_criterion = lambda p_logits, p_labels: (
        (-p_labels * torch.log_softmax(p_logits, dim=1)).sum(dim=1).mean()
    )
    v_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001, nesterov=True
    )
    # load checkpoint (optionally also replaying past selfplay data so the
    # replay buffer resumes in a comparable state)
    epoch_start = 0
    dataloader = DataLoader(
        args.port, args.cpus, args.batch_size, args.sgf_prefix, not args.pretrain
    )
    if args.load_ckpt:
        print("> Restore from", args.load_ckpt)
        ckpt = torch.load(args.load_ckpt, weights_only=True)
        net.load_state_dict(ckpt["net"])
        optimizer.load_state_dict(ckpt["optimizer"])
        if args.load_data:
            epoch_start = ckpt["epoch"]
            dataloader.load(args.load_data, epoch_start)
    # Publish the starting checkpoint so selfplay workers can pick it up.
    save_model(args.model_prefix, epoch_start, net, optimizer, moves, board_size)
    print("> Start training")
    # train: one epoch per selfplay iteration; the DataLoader blocks until
    # the iteration's games are available.
    for epoch in range(epoch_start, epoch_start + 6000):
        net.train()
        time_start = time.time()
        for i, batch in enumerate(dataloader):
            inputs, p_labels, v_labels = batch.inputs, batch.policy, batch.value

            # forward + backward
            p_logits, v_logits = net(inputs)

            v_loss = v_criterion(v_logits, v_labels)
            p_loss = p_criterion(p_logits, p_labels)
            loss = v_loss * args.value_ratio + p_loss

            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # train loss
            if i % 10 == 0:
                print(
                    "[{:3d}:{:5d}] PN_Loss: {:.5f} VN_Loss: {:.5f}".format(
                        epoch, i, p_loss.item(), v_loss.item()
                    )
                )

        print("[{:3d}] Time per epoch: {}".format(epoch, time.time() - time_start))
        save_model(args.model_prefix, epoch + 1, net, optimizer, moves, board_size)
101
+
102
+
103
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser()
105
+ # game
106
+ parser.add_argument("--game", default="nogo")
107
+ # training
108
+ parser.add_argument("--pretrain", action="store_true")
109
+ parser.add_argument("--sgf-prefix", default="../selfplay/sp")
110
+ parser.add_argument("--model-prefix", default="models_b6c96")
111
+ parser.add_argument("--load-ckpt", default="")
112
+ parser.add_argument("--load-data", default="")
113
+ parser.add_argument("--cpus", default=32, type=int)
114
+ parser.add_argument("--port", default=5566, type=int)
115
+ # hyperparameters
116
+ parser.add_argument("-lr", "--lr", default=0.01, type=float)
117
+ parser.add_argument("-bs", "--batch-size", default=512, type=int)
118
+ parser.add_argument("-vr", "--value-ratio", default=1, type=float)
119
+
120
+ args = parser.parse_args()
121
+ os.makedirs(args.sgf_prefix, exist_ok=True)
122
+ os.makedirs(args.model_prefix, exist_ok=True)
123
+ main(args)