# ugb_zero / train / train.py
# Page metadata carried over from the hosting site:
#   author avatar: chengscott; commit message: "train"; commit: 026a224
from dataloader import DataLoader
from model import AlphaZero, BasicBlock, Bottlenest
#from export_ait import save_ait
import argparse
import os
import re
import time
import torch
from torch import nn
# Per-game network configuration, keyed by the value of the --game flag.
# Each entry supplies the input feature-plane count, the number of moves,
# the board size and the number of value heads handed to the network.
kGames = {
    "nogo": {
        "num_features": 4,
        "moves": 81,
        "board_size": 81,
        "value_heads": 1,
    },
    "go9": {
        "num_features": 20,
        "moves": 82,
        "board_size": 81,
        "value_heads": 31,
    },
    "go19": {
        "num_features": 20,
        "moves": 362,
        "board_size": 361,
        "value_heads": 31,
    },
}
def save_model(model_prefix, epoch, net, optimizer, moves, board_size):
    """Write a training checkpoint to ``{model_prefix}/model-{epoch}.ckpt``.

    The checkpoint is a dict with three keys: ``"epoch"`` (the value passed
    in), ``"net"`` (``net.state_dict()``) and ``"optimizer"``
    (``optimizer.state_dict()``).

    The network is switched to eval mode while the state is captured and is
    put back into whatever mode it was in on entry, even if saving fails.

    Args:
        model_prefix: Directory that receives the checkpoint file; it must
            already exist.
        epoch: Epoch number stored in the checkpoint and used in the file name.
        net: Network whose parameters are saved.
        optimizer: Optimizer whose state is saved alongside the network.
        moves: Unused here; only needed by the disabled ``save_ait`` export.
        board_size: Unused here; only needed by the disabled ``save_ait``
            export.
    """
    # Remember the caller's mode so saving does not silently flip an
    # evaluation-mode network back into training mode.
    was_training = net.training
    net.eval()
    try:
        net_state = net.state_dict()
        torch.save(
            {
                "epoch": epoch,
                "net": net_state,
                "optimizer": optimizer.state_dict(),
            },
            f"{model_prefix}/model-{epoch}.ckpt",
        )
        # save_ait(net_state, moves, board_size, f"{model_prefix}/model-{epoch}.ait")
    finally:
        net.train(was_training)
def _parse_architecture(model_prefix):
    """Return ``(layers, channels, suffix)`` encoded in ``model_prefix``.

    The prefix must contain ``b<layers>c<channels>`` (for example
    ``models_b6c96``); everything after the channel count is the suffix.

    Raises:
        ValueError: If no ``b<digits>c<digits>`` tag is present.
    """
    match = re.search(r"b(\d+)c(\d+)(.*)", model_prefix)
    if match is None:
        raise ValueError(
            f"model prefix {model_prefix!r} must contain 'b<layers>c<channels>', "
            "e.g. 'models_b6c96'"
        )
    layers, channels, suffix = match.groups()
    return int(layers), int(channels), suffix


def main(args):
    """Train an AlphaZero network for the game selected by ``args.game``.

    The network depth and width come from the ``b<layers>c<channels>`` tag in
    ``args.model_prefix``; an empty suffix after the tag selects
    ``BasicBlock``, anything else selects ``Bottlenest``. A checkpoint is
    written before training starts and again after every epoch.

    Args:
        args: Parsed command-line namespace. Reads ``game``, ``model_prefix``,
            ``lr``, ``port``, ``cpus``, ``batch_size``, ``sgf_prefix``,
            ``pretrain``, ``load_ckpt``, ``load_data`` and ``value_ratio``.

    Raises:
        ValueError: If ``args.model_prefix`` has no ``b<layers>c<channels>``
            tag.
    """
    # Validate the architecture tag first so a bad prefix fails fast with a
    # clear message, before any GPU or dataloader resources are touched.
    layers, channels, suffix = _parse_architecture(args.model_prefix)
    block = BasicBlock if suffix == "" else Bottlenest

    torch.backends.cudnn.benchmark = True
    game = kGames[args.game]
    moves, board_size = game["moves"], game["board_size"]
    net = AlphaZero(
        in_channels=game["num_features"],
        layers=layers,
        channels=channels,
        moves=moves,
        board_size=board_size,
        value_heads=game["value_heads"],
        bias=False,
        block=block,
    ).cuda()

    # loss fn
    def p_criterion(p_logits, p_labels):
        # Cross-entropy between p_labels and softmax(p_logits), averaged
        # over the batch.
        return (-p_labels * torch.log_softmax(p_logits, dim=1)).sum(dim=1).mean()

    v_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001, nesterov=True
    )

    # load checkpoint
    epoch_start = 0
    dataloader = DataLoader(
        args.port, args.cpus, args.batch_size, args.sgf_prefix, not args.pretrain
    )
    if args.load_ckpt:
        print("> Restore from", args.load_ckpt)
        ckpt = torch.load(args.load_ckpt, weights_only=True)
        net.load_state_dict(ckpt["net"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # --load-data only takes effect together with --load-ckpt because the
        # resume epoch is read from the checkpoint.
        if args.load_data:
            epoch_start = ckpt["epoch"]
            dataloader.load(args.load_data, epoch_start)
    save_model(args.model_prefix, epoch_start, net, optimizer, moves, board_size)

    print("> Start training")
    # train
    for epoch in range(epoch_start, epoch_start + 6000):
        net.train()
        time_start = time.time()
        for i, batch in enumerate(dataloader):
            inputs, p_labels, v_labels = batch.inputs, batch.policy, batch.value
            # forward + backward
            p_logits, v_logits = net(inputs)
            v_loss = v_criterion(v_logits, v_labels)
            p_loss = p_criterion(p_logits, p_labels)
            loss = v_loss * args.value_ratio + p_loss
            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # train loss, reported every 10th batch
            if i % 10 == 0:
                print(
                    "[{:3d}:{:5d}] PN_Loss: {:.5f} VN_Loss: {:.5f}".format(
                        epoch, i, p_loss.item(), v_loss.item()
                    )
                )
        print("[{:3d}] Time per epoch: {}".format(epoch, time.time() - time_start))
        save_model(args.model_prefix, epoch + 1, net, optimizer, moves, board_size)
# Command-line entry point: parse options, make sure the data and model
# directories exist, then start training.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # game
    # Restricting the choices makes argparse reject an unknown game with a
    # usage message instead of a KeyError traceback inside main().
    parser.add_argument(
        "--game",
        default="nogo",
        choices=sorted(kGames),
        help="game configuration to train",
    )
    # training
    parser.add_argument(
        "--pretrain",
        action="store_true",
        help="pretraining mode (its negation is passed to the DataLoader)",
    )
    parser.add_argument(
        "--sgf-prefix",
        default="../selfplay/sp",
        help="directory passed to the DataLoader; created if missing",
    )
    parser.add_argument(
        "--model-prefix",
        default="models_b6c96",
        help="checkpoint directory; its name must contain b<layers>c<channels>",
    )
    parser.add_argument(
        "--load-ckpt", default="", help="checkpoint file to restore before training"
    )
    parser.add_argument(
        "--load-data",
        default="",
        help="data to reload into the DataLoader; only used with --load-ckpt",
    )
    parser.add_argument(
        "--cpus", default=32, type=int, help="CPU count passed to the DataLoader"
    )
    parser.add_argument(
        "--port", default=5566, type=int, help="port passed to the DataLoader"
    )
    # hyperparameters
    parser.add_argument("-lr", "--lr", default=0.01, type=float, help="SGD learning rate")
    parser.add_argument("-bs", "--batch-size", default=512, type=int, help="batch size")
    parser.add_argument(
        "-vr",
        "--value-ratio",
        default=1,
        type=float,
        help="weight of the value loss relative to the policy loss",
    )
    args = parser.parse_args()
    os.makedirs(args.sgf_prefix, exist_ok=True)
    os.makedirs(args.model_prefix, exist_ok=True)
    main(args)