Commit
·
026a224
1
Parent(s):
e298a77
train
Browse files- train/README +11 -0
- train/dataloader.py +85 -0
- train/model.py +165 -0
- train/train.py +123 -0
train/README
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Train
|
| 2 |
+
|
| 3 |
+
## Dependencies
|
| 4 |
+
|
| 5 |
+
python 3.13
|
| 6 |
+
pytorch 2.6.0+cu126
|
| 7 |
+
libzmq
|
| 8 |
+
|
| 9 |
+
## Run
|
| 10 |
+
|
| 11 |
+
python train.py --game go9
|
train/dataloader.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from time import time
|
| 5 |
+
import torch
|
| 6 |
+
from rb import ReplayBuffer, SocketManager
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@contextmanager
def timer(header):
    """Print *header* followed by the elapsed wall-clock seconds on normal exit.

    Note: if the managed block raises, nothing is printed (no try/finally),
    matching the original behavior.
    """
    started = time()
    yield
    print(header, (time() - started))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DataLoader:
    """Yields training batches from a native ReplayBuffer fed by self-play games.

    One pass over the loader is one training "iteration": freshly generated
    games are ingested with ``rb.add_iter`` and then ``self._total`` batches
    are sampled from the buffer.  A SocketManager coordinates with the
    self-play workers over ``port`` (exact protocol lives in the ``rb``
    extension — verify details there).
    """

    def __init__(self, port, cpus, batch_size, sgf_prefix, is_selfplay):
        # Seed for the replay buffer's sampler, drawn from torch's RNG.
        seed = torch.randint(0, 2**32, [1]).item()
        rb = ReplayBuffer()
        # Sliding window: the buffer keeps at most the newest 20 iterations.
        rb.max_iters = 20
        # moves_per_iter = #games * #moves_per_game
        rb.moves_per_iter = 5000 * 60
        rb.run(seed, cpus, batch_size)
        sock = SocketManager()
        sock.run(port)
        self._rb = rb
        self._sock = sock
        # total = #moves_per_game * #isom / batch_size
        self._total = int(rb.moves_per_iter * 1 / batch_size)
        self._sgf_prefix = sgf_prefix
        self._iter = 0
        self._is_selfplay = is_selfplay
        if not is_selfplay:
            # Pretrain mode: there are no live self-play workers, so this
            # process registers itself as worker "0" over ZeroMQ and later
            # reports move counts on their behalf (see __iter__).
            import zmq

            self._ctx = zmq.Context()
            zsock = self._ctx.socket(zmq.DEALER)
            zsock.setsockopt(zmq.LINGER, 0)  # don't block on close with queued msgs
            zsock.setsockopt(zmq.ROUTING_ID, b"0")
            zsock.connect(f"tcp://127.0.0.1:{port}")
            zsock.send_multipart([b"0", b"0"])
            self._zsock = zsock

    def load(self, sgf_prefix, epoch_ckpt):
        """Fast-forward buffer and sampler state through the first ``epoch_ckpt`` epochs.

        Used when resuming from a checkpoint so the replay buffer ends up in
        the same state it had when the checkpoint was written.
        """
        rb, sock = self._rb, self._sock
        for i in range(epoch_ckpt):
            if self._iter > 0:
                sock.notify()
            # Only epochs still inside the buffer's sliding window (the last
            # rb.max_iters ones) need their games actually loaded; i == 0
            # presumably primes the buffer — TODO confirm against `rb`.
            if i == 0 or i + rb.max_iters >= epoch_ckpt:
                print(f"[{i:3d}] Load selfplay")
                # Game files are named "iter-<iter>-<index>.sgf"; the game
                # count for this iteration is 1 + the largest index found.
                pattern = re.compile(rf"iter-{self._iter}-(\d+).sgf")
                nodes = 1 + max(
                    int(m.group(1))
                    for f in os.listdir(sgf_prefix)
                    if (m := pattern.search(f))
                )
                rb.add_iter(sgf_prefix, self._iter, nodes)
                # Draw and discard the same number of samples training would
                # have drawn, reproducing the sampler's state.
                for _ in range(self._total):
                    rb.sample().free()
            self._iter += 1

    def __del__(self):
        # Shut down the native buffer/socket workers when the loader dies.
        self._rb.terminate()
        self._sock.terminate()

    def __iter__(self):
        """Yield ``self._total`` batches for the current iteration.

        Each yielded sample is freed (``sample.free()``) as soon as control
        returns to the generator, so consumers must not hold on to it across
        loop steps.
        """
        rb, sock = self._rb, self._sock
        if self._iter > 0:
            sock.notify()
        if self._is_selfplay:
            # Block until the self-play workers finish this iteration's games.
            with timer("[{:3d}] Time for selfplay:".format(self._iter)):
                if sock.wait():
                    exit(0)  # SIGINT
        else:
            # Pretrain mode: report the move count ourselves so the socket
            # manager proceeds as if self-play had finished.
            finished = rb.moves_per_iter + 1
            self._zsock.send_multipart(
                [bytes(str(self._iter), "utf-8"), bytes(str(finished), "utf-8")]
            )
        with timer("[{:3d}] Time for training:".format(self._iter)):
            rb.add_iter(self._sgf_prefix, self._iter, sock.nodes)
            for _ in range(self._total):
                sample = rb.sample()
                yield sample
                sample.free()
        self._iter += 1
|
train/model.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BasicBlock(nn.Module):
|
| 5 |
+
def __init__(self, in_channels, channels, bias, k=3, p=1):
|
| 6 |
+
super().__init__()
|
| 7 |
+
self.conv1 = nn.Conv2d(in_channels, channels, k, stride=1, padding=p, bias=bias)
|
| 8 |
+
self.bn1 = nn.BatchNorm2d(channels)
|
| 9 |
+
self.relu1 = nn.ReLU()
|
| 10 |
+
self.conv2 = nn.Conv2d(channels, channels, k, stride=1, padding=p, bias=bias)
|
| 11 |
+
self.bn2 = nn.BatchNorm2d(channels)
|
| 12 |
+
self.relu2 = nn.ReLU()
|
| 13 |
+
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
y = self.conv1(x)
|
| 16 |
+
y = self.bn1(y)
|
| 17 |
+
y = self.relu1(y)
|
| 18 |
+
y = self.conv2(y)
|
| 19 |
+
y = self.bn2(y)
|
| 20 |
+
x = x + y
|
| 21 |
+
x = self.relu2(x)
|
| 22 |
+
return x
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Bottleneck(nn.Module):
|
| 26 |
+
def __init__(self, in_channels, channels, bias):
|
| 27 |
+
super().__init__()
|
| 28 |
+
mid_channels = channels // 2
|
| 29 |
+
self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
|
| 30 |
+
self.bn1 = nn.BatchNorm2d(mid_channels)
|
| 31 |
+
self.relu1 = nn.ReLU()
|
| 32 |
+
self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
|
| 33 |
+
self.bn2 = nn.BatchNorm2d(mid_channels)
|
| 34 |
+
self.relu2 = nn.ReLU()
|
| 35 |
+
self.conv3 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
|
| 36 |
+
self.bn3 = nn.BatchNorm2d(channels)
|
| 37 |
+
self.relu3 = nn.ReLU()
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
y = self.conv1(x)
|
| 41 |
+
y = self.bn1(y)
|
| 42 |
+
y = self.relu1(y)
|
| 43 |
+
y = self.conv2(y)
|
| 44 |
+
y = self.bn2(y)
|
| 45 |
+
y = self.relu2(y)
|
| 46 |
+
y = self.conv3(y)
|
| 47 |
+
y = self.bn3(y)
|
| 48 |
+
x = x + y
|
| 49 |
+
x = self.relu3(x)
|
| 50 |
+
return x
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Bottlenest(nn.Module):
|
| 54 |
+
def __init__(self, in_channels, channels, bias):
|
| 55 |
+
super().__init__()
|
| 56 |
+
mid_channels = channels // 2
|
| 57 |
+
self.conv0 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
|
| 58 |
+
self.bn0 = nn.BatchNorm2d(mid_channels)
|
| 59 |
+
self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
|
| 60 |
+
self.bn1 = nn.BatchNorm2d(mid_channels)
|
| 61 |
+
self.relu1 = nn.ReLU()
|
| 62 |
+
self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
|
| 63 |
+
self.bn2 = nn.BatchNorm2d(mid_channels)
|
| 64 |
+
self.relu2 = nn.ReLU()
|
| 65 |
+
self.conv3 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
|
| 66 |
+
self.bn3 = nn.BatchNorm2d(mid_channels)
|
| 67 |
+
self.relu3 = nn.ReLU()
|
| 68 |
+
self.conv4 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
|
| 69 |
+
self.bn4 = nn.BatchNorm2d(mid_channels)
|
| 70 |
+
self.relu4 = nn.ReLU()
|
| 71 |
+
self.conv5 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
|
| 72 |
+
self.bn5 = nn.BatchNorm2d(channels)
|
| 73 |
+
self.relu5 = nn.ReLU()
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
y = self.conv0(x)
|
| 77 |
+
y = self.bn0(y)
|
| 78 |
+
z = self.conv1(y)
|
| 79 |
+
z = self.bn1(z)
|
| 80 |
+
z = self.relu1(z)
|
| 81 |
+
z = self.conv2(z)
|
| 82 |
+
z = self.bn2(z)
|
| 83 |
+
y = y + z
|
| 84 |
+
y = self.relu2(y)
|
| 85 |
+
z = self.conv3(y)
|
| 86 |
+
z = self.bn3(z)
|
| 87 |
+
z = self.relu3(z)
|
| 88 |
+
z = self.conv4(z)
|
| 89 |
+
z = self.bn4(z)
|
| 90 |
+
y = y + z
|
| 91 |
+
y = self.relu4(y)
|
| 92 |
+
y = self.conv5(y)
|
| 93 |
+
y = self.bn5(y)
|
| 94 |
+
x = x + y
|
| 95 |
+
x = self.relu5(x)
|
| 96 |
+
return x
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ResNet(nn.Module):
|
| 100 |
+
def __init__(self, block, in_channels, layers, channels, bias):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.conv1 = nn.Sequential(
|
| 103 |
+
nn.Conv2d(
|
| 104 |
+
in_channels, channels, kernel_size=5, stride=1, padding=2, bias=bias
|
| 105 |
+
),
|
| 106 |
+
nn.BatchNorm2d(channels),
|
| 107 |
+
nn.ReLU(),
|
| 108 |
+
)
|
| 109 |
+
self.convs = nn.ModuleList(
|
| 110 |
+
[block(channels, channels, bias) for _ in range(layers)]
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
x = self.conv1(x)
|
| 115 |
+
for conv in self.convs:
|
| 116 |
+
x = conv(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class AlphaZero(nn.Module):
    """AlphaZero-style two-headed net: shared ResNet trunk + policy and value heads.

    ``board_size`` is the number of board points (side * side, e.g. 81 for
    9x9), not the side length.  ``forward`` returns ``(policy_logits, value)``
    with shapes ``(N, moves)`` and ``(N, value_heads)``; value is tanh-squashed.
    """

    def __init__(
        self,
        in_channels,
        layers,
        channels,
        moves,
        board_size,
        value_heads=1,
        bias=False,
        block=BasicBlock,
    ):
        super().__init__()
        self.board_size = board_size
        self.resnet = ResNet(block, in_channels, layers, channels, bias)
        # Policy head: 2-plane 1x1 conv, flattened, then one logit per move.
        self.policy_head_front = nn.Sequential(
            nn.Conv2d(channels, 2, 1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
        )
        self.policy_head_end = nn.Linear(2 * board_size, moves)
        # Value head: 1-plane 1x1 conv, flattened, then an MLP ending in tanh.
        self.value_head_front = nn.Sequential(
            nn.Conv2d(channels, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        self.value_head_end = nn.Sequential(
            nn.Linear(board_size, channels),
            nn.ReLU(),
            nn.Linear(channels, value_heads),
            nn.Tanh(),
        )

    def forward(self, x):
        trunk = self.resnet(x)
        # Policy head.
        p = self.policy_head_front(trunk)
        p = self.policy_head_end(p.view(-1, 2 * self.board_size))
        # Value head.
        v = self.value_head_front(trunk)
        v = self.value_head_end(v.view(-1, self.board_size))
        return p, v
|
train/train.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataloader import DataLoader
|
| 2 |
+
from model import AlphaZero, BasicBlock, Bottlenest
|
| 3 |
+
#from export_ait import save_ait
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import time
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
# Per-game network configuration.
#   num_features: input feature planes fed to the network
#   moves:        policy outputs (board points; go adds one for pass)
#   board_size:   number of board points (side * side), not the side length
#   value_heads:  number of value outputs (31 for go — presumably multiple
#                 score/komi heads; confirm against the selfplay side)
kGames = dict(
    nogo=dict(num_features=4, moves=81, board_size=81, value_heads=1),
    go9=dict(num_features=20, moves=82, board_size=81, value_heads=31),
    go19=dict(num_features=20, moves=362, board_size=361, value_heads=31),
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def save_model(model_prefix, epoch, net, optimizer, moves, board_size):
    """Write a checkpoint to ``{model_prefix}/model-{epoch}.ckpt``.

    The net is put into eval mode while its state is captured and switched
    back to train mode afterwards.  ``moves`` and ``board_size`` are kept
    for the (currently disabled) AIT export below.
    """
    net.eval()
    net_state = net.state_dict()
    checkpoint = {
        "epoch": epoch,
        "net": net_state,
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, f"{model_prefix}/model-{epoch}.ckpt")
    #save_ait(net_state, moves, board_size, f"{model_prefix}/model-{epoch}.ait")
    net.train()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _policy_loss(p_logits, p_labels):
    """Soft cross-entropy of policy logits against a target distribution."""
    return (-p_labels * torch.log_softmax(p_logits, dim=1)).sum(dim=1).mean()


def main(args):
    """Run the self-play training loop.

    The architecture is parsed from ``args.model_prefix``: "b<layers>c<channels>"
    selects depth and width (e.g. "models_b6c96"), and any trailing suffix
    switches the residual block from BasicBlock to Bottlenest.  Requires CUDA.
    """
    torch.backends.cudnn.benchmark = True
    game = kGames[args.game]
    moves, board_size = game["moves"], game["board_size"]
    # Fail with a clear message instead of an AttributeError on a bad prefix.
    match = re.search(r"b(\d+)c(\d+)(.*)", args.model_prefix)
    if match is None:
        raise ValueError(
            f"model prefix {args.model_prefix!r} must contain 'b<layers>c<channels>'"
        )
    layers, channels, block = match.groups()
    block = BasicBlock if block == "" else Bottlenest
    net = AlphaZero(
        in_channels=game["num_features"],
        layers=int(layers),
        channels=int(channels),
        moves=moves,
        board_size=board_size,
        value_heads=game["value_heads"],
        bias=False,
        block=block,
    ).cuda()
    # loss fns: soft cross-entropy for the policy head, MSE for the value heads
    p_criterion = _policy_loss
    v_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001, nesterov=True
    )
    # load checkpoint (optionally replaying the dataloader to the same epoch)
    epoch_start = 0
    dataloader = DataLoader(
        args.port, args.cpus, args.batch_size, args.sgf_prefix, not args.pretrain
    )
    if args.load_ckpt:
        print("> Restore from", args.load_ckpt)
        ckpt = torch.load(args.load_ckpt, weights_only=True)
        net.load_state_dict(ckpt["net"])
        optimizer.load_state_dict(ckpt["optimizer"])
        if args.load_data:
            epoch_start = ckpt["epoch"]
            dataloader.load(args.load_data, epoch_start)
    # Save the starting state so selfplay workers can pick up model-0 (or the
    # restored epoch) immediately.
    save_model(args.model_prefix, epoch_start, net, optimizer, moves, board_size)
    print("> Start training")
    # train
    for epoch in range(epoch_start, epoch_start + 6000):
        net.train()
        time_start = time.time()
        for i, batch in enumerate(dataloader):
            inputs, p_labels, v_labels = batch.inputs, batch.policy, batch.value

            # forward + backward
            p_logits, v_logits = net(inputs)

            v_loss = v_criterion(v_logits, v_labels)
            p_loss = p_criterion(p_logits, p_labels)
            loss = v_loss * args.value_ratio + p_loss

            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # periodic train-loss logging
            if i % 10 == 0:
                print(
                    "[{:3d}:{:5d}] PN_Loss: {:.5f} VN_Loss: {:.5f}".format(
                        epoch, i, p_loss.item(), v_loss.item()
                    )
                )

        print("[{:3d}] Time per epoch: {}".format(epoch, time.time() - time_start))
        save_model(args.model_prefix, epoch + 1, net, optimizer, moves, board_size)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # game
    parser.add_argument("--game", default="nogo")  # key into kGames
    # training
    parser.add_argument("--pretrain", action="store_true")  # train from existing sgf data instead of live selfplay
    parser.add_argument("--sgf-prefix", default="../selfplay/sp")  # directory for selfplay .sgf games
    parser.add_argument("--model-prefix", default="models_b6c96")  # also encodes the net size, e.g. b6c96
    parser.add_argument("--load-ckpt", default="")  # checkpoint file to resume from
    parser.add_argument("--load-data", default="")  # sgf dir used to replay dataloader state on resume
    parser.add_argument("--cpus", default=32, type=int)  # replay-buffer worker threads
    parser.add_argument("--port", default=5566, type=int)  # ZeroMQ port shared with selfplay workers
    # hyperparameters
    parser.add_argument("-lr", "--lr", default=0.01, type=float)
    parser.add_argument("-bs", "--batch-size", default=512, type=int)
    parser.add_argument("-vr", "--value-ratio", default=1, type=float)  # weight of value loss vs policy loss

    args = parser.parse_args()
    # Make sure the output directories exist before training/selfplay touch them.
    os.makedirs(args.sgf_prefix, exist_ok=True)
    os.makedirs(args.model_prefix, exist_ok=True)
    main(args)
|