"""Command-line entry point for training a GNet graph-prediction model.

Parses hyper-parameters, seeds the Python/NumPy/torch RNGs, loads the dataset
through ``FileLoader`` and trains either a single fold (``-fold 1..10``) or
every fold in turn (``-fold 0``).
"""
import argparse
import random
import time

import torch
import numpy as np

from network import GNet
from trainer import Trainer
from utils.data_loader import FileLoader

# Number of data folds the script iterates over; ``-fold 0`` trains all of
# them. NOTE(review): assumes FileLoader always provides exactly this many
# folds (matches the original "fold (1..10)" help) -- confirm against the loader.
NUM_FOLDS = 10


def get_args():
    """Parse the command-line hyper-parameters and return the namespace.

    Unknown arguments are ignored rather than rejected (``parse_known_args``),
    but they are printed so a mistyped option is not silently replaced by its
    default value.
    """
    parser = argparse.ArgumentParser(description='Args for graph prediction')
    parser.add_argument('-seed', type=int, default=1, help='seed')
    parser.add_argument('-data', default='DD', help='data folder name')
    # 1-based fold index; 0 runs every fold. Restricting the range prevents a
    # negative/out-of-range index from reaching G_data.use_fold_data().
    parser.add_argument('-fold', type=int, default=1,
                        choices=range(NUM_FOLDS + 1),
                        help=f'fold (1..{NUM_FOLDS}), or 0 for all folds')
    parser.add_argument('-num_epochs', type=int, default=2, help='epochs')
    parser.add_argument('-batch', type=int, default=8, help='batch size')
    parser.add_argument('-lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('-deg_as_tag', type=int, default=0, help='1 or degree')
    parser.add_argument('-l_num', type=int, default=3, help='layer num')
    parser.add_argument('-h_dim', type=int, default=512, help='hidden dim')
    parser.add_argument('-l_dim', type=int, default=48, help='layer dim')
    parser.add_argument('-drop_n', type=float, default=0.3, help='drop net')
    parser.add_argument('-drop_c', type=float, default=0.2, help='drop output')
    parser.add_argument('-act_n', type=str, default='ELU', help='network act')
    parser.add_argument('-act_c', type=str, default='ELU', help='output act')
    # The default must be a list: argparse feeds a *string* default through
    # ``type`` as a single token, so the old '0.9 0.8 0.7' default made
    # float('0.9 0.8 0.7') fail whenever -ks was omitted.
    parser.add_argument('-ks', nargs='+', type=float, default=[0.9, 0.8, 0.7],
                        help='space-separated float values')
    parser.add_argument('-acc_file', type=str, default='re', help='acc file')
    args, unknown = parser.parse_known_args()
    if unknown:
        print('ignoring unknown arguments:', unknown)
    return args


def set_random(seed):
    """Seed the ``random``, NumPy and torch generators with ``seed``.

    cuDNN determinism flags are not touched here, so GPU runs may still vary.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def app_run(args, G_data, fold_idx):
    """Select fold ``fold_idx`` (0-based), build a GNet and train it."""
    G_data.use_fold_data(fold_idx)
    net = GNet(G_data.feat_dim, G_data.num_class, args)
    trainer = Trainer(args, net, G_data)
    trainer.train()


def main():
    """Parse arguments, load the data once, then train the requested fold(s)."""
    args = get_args()
    print(args)
    set_random(args.seed)
    # perf_counter is monotonic and intended for measuring elapsed time.
    start = time.perf_counter()
    G_data = FileLoader(args).load_data()
    print('load data using ------>', time.perf_counter() - start)
    if args.fold == 0:
        for fold_idx in range(NUM_FOLDS):
            print('start training ------> fold', fold_idx + 1)
            app_run(args, G_data, fold_idx)
    else:
        print('start training ------> fold', args.fold)
        # CLI folds are 1-based; the loader takes a 0-based index.
        app_run(args, G_data, args.fold - 1)


if __name__ == "__main__":
    main()