qingy2024's picture
Upload folder using huggingface_hub
bf620c6 verified
import argparse
import random
import time
import torch
import numpy as np
from network import GNet
from trainer import Trainer
from utils.data_loader import FileLoader
def get_args():
parser = argparse.ArgumentParser(description='Args for graph predition')
parser.add_argument('-seed', type=int, default=1, help='seed')
parser.add_argument('-data', default='DD', help='data folder name')
parser.add_argument('-fold', type=int, default=1, help='fold (1..10)')
parser.add_argument('-num_epochs', type=int, default=2, help='epochs')
parser.add_argument('-batch', type=int, default=8, help='batch size')
parser.add_argument('-lr', type=float, default=0.001, help='learning rate')
parser.add_argument('-deg_as_tag', type=int, default=0, help='1 or degree')
parser.add_argument('-l_num', type=int, default=3, help='layer num')
parser.add_argument('-h_dim', type=int, default=512, help='hidden dim')
parser.add_argument('-l_dim', type=int, default=48, help='layer dim')
parser.add_argument('-drop_n', type=float, default=0.3, help='drop net')
parser.add_argument('-drop_c', type=float, default=0.2, help='drop output')
parser.add_argument('-act_n', type=str, default='ELU', help='network act')
parser.add_argument('-act_c', type=str, default='ELU', help='output act')
parser.add_argument('-ks', nargs='+', type=float, default='0.9 0.8 0.7')
parser.add_argument('-acc_file', type=str, default='re', help='acc file')
args, _ = parser.parse_known_args()
return args
def set_random(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def app_run(args, G_data, fold_idx):
G_data.use_fold_data(fold_idx)
net = GNet(G_data.feat_dim, G_data.num_class, args)
trainer = Trainer(args, net, G_data)
trainer.train()
def main():
args = get_args()
print(args)
set_random(args.seed)
start = time.time()
G_data = FileLoader(args).load_data()
print('load data using ------>', time.time()-start)
if args.fold == 0:
for fold_idx in range(10):
print('start training ------> fold', fold_idx+1)
app_run(args, G_data, fold_idx)
else:
print('start training ------> fold', args.fold)
app_run(args, G_data, args.fold-1)
if __name__ == "__main__":
main()