import argparse
import math
import os
import sys

import torch.distributed as dist
import torch.multiprocessing as mp

import utils
from greedrl import Solver
|
|
|
|
def do_train(args, rank):
    """Run training in one (possibly distributed) worker process.

    Args:
        args: parsed CLI namespace with ``world_size``, ``model_filename``,
            ``problem_size`` and ``batch_size`` attributes.
        rank: integer rank of this worker in ``[0, world_size)``.

    Raises:
        Exception: if ``args.problem_size`` is not one of the supported sizes.
    """
    world_size = args.world_size
    model_filename = args.model_filename
    problem_size = args.problem_size
    batch_size = args.batch_size

    # Derive a per-rank log filename from the model filename.
    # os.path.splitext correctly handles filenames without an extension;
    # the previous rfind('.') approach returned -1 in that case and the
    # slice [0:-1] silently dropped the filename's last character.
    stem = os.path.splitext(model_filename)[0]
    if world_size > 1:
        stdout_filename = '{}_r{}.log'.format(stem, rank)
    else:
        stdout_filename = '{}.log'.format(stem)

    # Redirect this process's stdout/stderr to its own log file; the file
    # is intentionally kept open for the lifetime of the process.
    stdout = open(stdout_filename, 'a')
    sys.stdout = stdout
    sys.stderr = stdout

    print("args: {}".format(vars(args)))
    if world_size > 1:
        # Single-host NCCL rendezvous on a fixed loopback port.
        dist.init_process_group('NCCL', init_method='tcp://127.0.0.1:29500',
                                rank=rank, world_size=world_size)

    # Number of validation problems for each supported problem size;
    # batch_count is how many problem batches cover that many problems.
    problem_batch_size = 8
    valid_problem_counts = {100: 10000, 1000: 200, 2000: 100, 5000: 10}
    if problem_size not in valid_problem_counts:
        raise Exception("unsupported problem size: {}".format(problem_size))
    batch_count = math.ceil(valid_problem_counts[problem_size] / problem_batch_size)

    nn_args = {
        'encode_norm': 'instance',
        'encode_layers': 6,
        'decode_rnn': 'LSTM'
    }

    # With a single process let the Solver pick its device; otherwise pin
    # each rank to its own GPU.
    device = None if world_size == 1 else 'cuda:{}'.format(rank)
    solver = Solver(device, nn_args)

    # Training data is generated endlessly (batch_count=None); validation
    # uses a fixed number of batches.
    train_dataset = utils.Dataset(None, problem_batch_size, problem_size)
    valid_dataset = utils.Dataset(batch_count, problem_batch_size, problem_size)

    solver.train(model_filename, train_dataset, valid_dataset,
                 train_dataset_workers=5,
                 batch_size=batch_size,
                 memopt=10,
                 topk_size=1,
                 init_lr=1e-4,
                 valid_steps=500,
                 warmup_steps=0)
|
|
|
|
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    # required=True: without it a missing --model_filename only surfaced
    # later in the worker process (with stdout already redirected) as an
    # opaque AttributeError on None; argparse now reports a clear error.
    parser.add_argument('--model_filename', type=str, required=True,
                        help='model file name')
    parser.add_argument('--problem_size', default=100, type=int,
                        choices=[100, 1000, 2000, 5000],
                        help='problem size')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='batch size for training')

    args = parser.parse_args()

    # Launch one training process per rank, then wait for all of them.
    processes = []
    for rank in range(args.world_size):
        p = mp.Process(target=do_train, args=(args, rank))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
|
|