"""AlphaZero-style training loop.

Streams self-play batches from a DataLoader, trains the policy/value
network with SGD, and checkpoints the model once per epoch.
"""
from dataloader import DataLoader
from model import AlphaZero, BasicBlock, Bottlenest

# from export_ait import save_ait
import argparse
import os
import re
import time

import torch
from torch import nn

# Per-game dimensions: input feature planes, policy output size (moves),
# flattened board size, and number of value heads.
kGames = dict(
    nogo=dict(num_features=4, moves=81, board_size=81, value_heads=1),
    go9=dict(num_features=20, moves=82, board_size=81, value_heads=31),
    go19=dict(num_features=20, moves=362, board_size=361, value_heads=31),
)


def save_model(model_prefix, epoch, net, optimizer, moves, board_size):
    """Checkpoint net + optimizer state to ``{model_prefix}/model-{epoch}.ckpt``.

    The net is put in eval mode while its state_dict is captured, then
    restored to train mode so the caller can keep training.
    """
    net.eval()
    net_state = net.state_dict()
    torch.save(
        {
            "epoch": epoch,
            "net": net_state,
            "optimizer": optimizer.state_dict(),
        },
        f"{model_prefix}/model-{epoch}.ckpt",
    )
    # save_ait(net_state, moves, board_size, f"{model_prefix}/model-{epoch}.ait")
    net.train()


def main(args):
    """Build the network from CLI args and run the training loop."""
    torch.backends.cudnn.benchmark = True
    game = kGames[args.game]
    moves, board_size = game["moves"], game["board_size"]

    # Architecture is encoded in the model prefix, e.g. "models_b6c96"
    # -> 6 residual layers, 96 channels; any trailing suffix after the
    # channel count selects the bottleneck block instead of BasicBlock.
    layers, channels, block = re.search(
        r"b(\d+)c(\d+)(.*)", args.model_prefix
    ).groups()
    block = BasicBlock if block == "" else Bottlenest
    net = AlphaZero(
        in_channels=game["num_features"],
        layers=int(layers),
        channels=int(channels),
        moves=moves,
        board_size=board_size,
        value_heads=game["value_heads"],
        bias=False,
        block=block,
    ).cuda()

    # Policy loss: cross-entropy against soft (distribution) labels.
    def p_criterion(p_logits, p_labels):
        return (-p_labels * torch.log_softmax(p_logits, dim=1)).sum(dim=1).mean()

    v_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001, nesterov=True
    )

    # Load checkpoint (optionally resuming the replay buffer as well).
    epoch_start = 0
    dataloader = DataLoader(
        args.port, args.cpus, args.batch_size, args.sgf_prefix, not args.pretrain
    )
    if args.load_ckpt:
        print("> Restore from", args.load_ckpt)
        ckpt = torch.load(args.load_ckpt, weights_only=True)
        net.load_state_dict(ckpt["net"])
        optimizer.load_state_dict(ckpt["optimizer"])
        if args.load_data:
            epoch_start = ckpt["epoch"]
            dataloader.load(args.load_data, epoch_start)
    # Save the starting weights so consumers can pick up model-{epoch_start}.
    # NOTE(review): the original file's whitespace was mangled; this call may
    # have been nested under `if args.load_data:` — confirm against history.
    save_model(args.model_prefix, epoch_start, net, optimizer, moves, board_size)
    print("> Start training")

    # Train for a fixed budget of 6000 epochs past the resume point.
    for epoch in range(epoch_start, epoch_start + 6000):
        net.train()
        time_start = time.time()
        for i, batch in enumerate(dataloader):
            inputs, p_labels, v_labels = batch.inputs, batch.policy, batch.value
            # forward + backward
            p_logits, v_logits = net(inputs)
            v_loss = v_criterion(v_logits, v_labels)
            p_loss = p_criterion(p_logits, p_labels)
            loss = v_loss * args.value_ratio + p_loss
            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Report training loss every 10 batches.
            if i % 10 == 0:
                print(
                    "[{:3d}:{:5d}] PN_Loss: {:.5f} VN_Loss: {:.5f}".format(
                        epoch, i, p_loss.item(), v_loss.item()
                    )
                )
        print("[{:3d}] Time per epoch: {}".format(epoch, time.time() - time_start))
        save_model(args.model_prefix, epoch + 1, net, optimizer, moves, board_size)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # game
    parser.add_argument("--game", default="nogo")
    # training
    parser.add_argument("--pretrain", action="store_true")
    parser.add_argument("--sgf-prefix", default="../selfplay/sp")
    parser.add_argument("--model-prefix", default="models_b6c96")
    parser.add_argument("--load-ckpt", default="")
    parser.add_argument("--load-data", default="")
    parser.add_argument("--cpus", default=32, type=int)
    parser.add_argument("--port", default=5566, type=int)
    # hyperparameters
    parser.add_argument("-lr", "--lr", default=0.01, type=float)
    parser.add_argument("-bs", "--batch-size", default=512, type=int)
    parser.add_argument("-vr", "--value-ratio", default=1, type=float)
    args = parser.parse_args()

    os.makedirs(args.sgf_prefix, exist_ok=True)
    os.makedirs(args.model_prefix, exist_ok=True)
    main(args)