|
|
from dataloader import DataLoader |
|
|
from model import AlphaZero, BasicBlock, Bottlenest |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import re |
|
|
import time |
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
# Per-game settings, looked up by game name. Each entry supplies the number of
# input feature planes, the size of the move space, the board size and the
# number of value heads that are handed to the network constructor.
kGames = {
    "nogo": {"num_features": 4, "moves": 81, "board_size": 81, "value_heads": 1},
    "go9": {"num_features": 20, "moves": 82, "board_size": 81, "value_heads": 31},
    "go19": {"num_features": 20, "moves": 362, "board_size": 361, "value_heads": 31},
}
|
|
|
|
|
|
|
|
def save_model(model_prefix, epoch, net, optimizer, moves, board_size):
    """Write a training checkpoint to ``{model_prefix}/model-{epoch}.ckpt``.

    The checkpoint is a dict with the keys ``"epoch"``, ``"net"`` (the
    network's ``state_dict``) and ``"optimizer"`` (the optimizer's
    ``state_dict``), serialized with ``torch.save``.

    Args:
        model_prefix: Directory that receives the checkpoint file. This
            function does not create it.
        epoch: Epoch number stored in the checkpoint and used in the file name.
        net: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        moves: Unused; kept so existing callers keep working.
        board_size: Unused; kept so existing callers keep working.
    """
    # Remember the caller's mode: saving must not flip an evaluating network
    # into training mode as a side effect.
    was_training = net.training
    net.eval()
    try:
        torch.save(
            {
                "epoch": epoch,
                "net": net.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            f"{model_prefix}/model-{epoch}.ckpt",
        )
    finally:
        # Restore the original mode even if serialization fails, so a failed
        # save cannot leave the network stuck in eval mode.
        net.train(was_training)
|
|
|
|
|
|
|
|
def _parse_model_prefix(model_prefix):
    """Return ``(layers, channels, block_suffix)`` parsed from *model_prefix*.

    The prefix must contain ``b<digits>c<digits>``; the text following that
    tag is returned as ``block_suffix`` (possibly empty).

    Raises:
        ValueError: If the prefix does not contain the tag.
    """
    match = re.search(r"b(\d+)c(\d+)(.*)", model_prefix)
    if match is None:
        raise ValueError(
            f"model prefix {model_prefix!r} must contain 'b<layers>c<channels>', "
            "e.g. 'models_b6c96'"
        )
    layers, channels, block_suffix = match.groups()
    return int(layers), int(channels), block_suffix


def _policy_loss(p_logits, p_labels):
    """Return the batch mean of ``-(p_labels * log_softmax(p_logits)).sum(dim=1)``.

    This is a cross-entropy when every row of *p_labels* is a probability
    distribution over dimension 1.
    """
    return (-p_labels * torch.log_softmax(p_logits, dim=1)).sum(dim=1).mean()


def main(args):
    """Build the network for ``args.game`` and run the training loop on CUDA.

    Reads from *args*: ``game``, ``model_prefix``, ``lr``, ``port``, ``cpus``,
    ``batch_size``, ``sgf_prefix``, ``pretrain``, ``load_ckpt``, ``load_data``
    and ``value_ratio``.

    A checkpoint is written through ``save_model`` once before training starts
    and again after every epoch.

    Raises:
        KeyError: If ``args.game`` is not a key of ``kGames``.
        ValueError: If ``args.model_prefix`` does not contain
            ``b<layers>c<channels>``.
    """
    torch.backends.cudnn.benchmark = True

    game = kGames[args.game]
    moves, board_size = game["moves"], game["board_size"]

    # The network depth/width are encoded in the model directory name.
    layers, channels, block_suffix = _parse_model_prefix(args.model_prefix)
    # An empty suffix selects BasicBlock; any other suffix selects Bottlenest.
    block = BasicBlock if block_suffix == "" else Bottlenest
    net = AlphaZero(
        in_channels=game["num_features"],
        layers=layers,
        channels=channels,
        moves=moves,
        board_size=board_size,
        value_heads=game["value_heads"],
        bias=False,
        block=block,
    ).cuda()

    p_criterion = _policy_loss
    v_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001, nesterov=True
    )

    epoch_start = 0
    dataloader = DataLoader(
        args.port, args.cpus, args.batch_size, args.sgf_prefix, not args.pretrain
    )

    if args.load_ckpt:
        print("> Restore from", args.load_ckpt)
        ckpt = torch.load(args.load_ckpt, weights_only=True)
        net.load_state_dict(ckpt["net"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # The epoch counter is only resumed when data is restored as well;
        # with a checkpoint alone, numbering starts again from 0.
        if args.load_data:
            epoch_start = ckpt["epoch"]
            dataloader.load(args.load_data, epoch_start)

    # Write a checkpoint for the starting epoch before any update is applied.
    save_model(args.model_prefix, epoch_start, net, optimizer, moves, board_size)
    print("> Start training")

    for epoch in range(epoch_start, epoch_start + 6000):
        net.train()
        time_start = time.time()
        for i, batch in enumerate(dataloader):
            inputs, p_labels, v_labels = batch.inputs, batch.policy, batch.value

            p_logits, v_logits = net(inputs)

            v_loss = v_criterion(v_logits, v_labels)
            p_loss = p_criterion(p_logits, p_labels)
            # Only the value term is scaled; the policy term has weight 1.
            loss = v_loss * args.value_ratio + p_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log every 10th batch.
            if i % 10 == 0:
                print(
                    "[{:3d}:{:5d}] PN_Loss: {:.5f} VN_Loss: {:.5f}".format(
                        epoch, i, p_loss.item(), v_loss.item()
                    )
                )

        print("[{:3d}] Time per epoch: {}".format(epoch, time.time() - time_start))
        # Checkpoints are numbered by the count of completed epochs.
        save_model(args.model_prefix, epoch + 1, net, optimizer, moves, board_size)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Restricting --game to the configured games makes argparse reject a typo
    # with a usage message instead of a KeyError raised later inside main().
    parser.add_argument("--game", default="nogo", choices=sorted(kGames))

    parser.add_argument("--pretrain", action="store_true")
    parser.add_argument("--sgf-prefix", default="../selfplay/sp")
    parser.add_argument("--model-prefix", default="models_b6c96")
    parser.add_argument("--load-ckpt", default="")
    parser.add_argument("--load-data", default="")
    parser.add_argument("--cpus", default=32, type=int)
    parser.add_argument("--port", default=5566, type=int)

    parser.add_argument("-lr", "--lr", default=0.01, type=float)
    parser.add_argument("-bs", "--batch-size", default=512, type=int)
    # argparse does not pass non-string defaults through ``type``, so the
    # default is written as a float to match values given on the command line.
    parser.add_argument("-vr", "--value-ratio", default=1.0, type=float)

    args = parser.parse_args()

    # Both directories are created up front so later writes do not fail on a
    # missing path.
    os.makedirs(args.sgf_prefix, exist_ok=True)
    os.makedirs(args.model_prefix, exist_ok=True)

    main(args)
|
|
|