# Training entry point for the poetry-generation experiments.
import argparse
import os

import torch
from torch.utils.data import DataLoader

from src.apis.train import train, evaluate
from src.datasets.dataloader import Poetry_Dataset, train_vec, get_poetry, split_text
from src.models.LSTM.model import Poetry_Model_lstm
from src.utils.utils import make_cuda
def _str2bool(value):
    """Parse a command-line boolean such as 'True'/'false'/'1'/'0'.

    argparse with no ``type=`` (or ``type=bool``) treats any non-empty
    string -- including the literal 'False' -- as truthy, so options like
    ``--Word2Vec False`` would silently remain True without this helper.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('true', '1', 'yes', 'y')


def parse_arguments():
    """Build and parse the command-line arguments for an experiment run.

    Returns:
        argparse.Namespace with model choice, dataset paths, and
        training hyper-parameters (batch size, epochs, learning rate, ...).
    """
    parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
    parser.add_argument('--model', type=str, default='lstm',
                        help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
    # Boolean options go through _str2bool so '--flag False' actually
    # disables them (a bare untyped default would keep them truthy).
    parser.add_argument('--Word2Vec', type=_str2bool, default=True,
                        help="use Word2Vec embeddings (True/False)")
    parser.add_argument('--Augmented_dataset', type=_str2bool, default=False,
                        help="augmented dataset")
    parser.add_argument('--strict_dataset', type=_str2bool, default=False,
                        help="strict dataset")
    parser.add_argument('--batch_size', type=int, default=64,
                        help="Specify batch size")
    parser.add_argument('--num_epochs', type=int, default=50,
                        help="Specify the number of epochs for competitive search")
    parser.add_argument('--log_step', type=int, default=100,
                        help="Specify log step size for training")
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help="Learning rate")
    parser.add_argument('--data', type=str, default='data/poetry.txt',
                        help="Path to the dataset")
    parser.add_argument('--Augmented_data', type=str, default='data/poetry_7.txt',
                        help="Path to the Augmented_dataset")
    parser.add_argument('--n_hidden', type=int, default=128)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--save_path', type=str, default='save_models/')
    return parser.parse_args()
def main():
    """End-to-end training entry point: prepare data, train, save, evaluate.

    Steps:
      1. Build the pre-processed poetry files if they are missing.
      2. Train/load word vectors and construct the dataset.
      3. Split 80/20 into train/validation ``DataLoader``s.
      4. Train the selected model, save its weights, then evaluate.

    Raises:
        ValueError: if ``--model`` is not one of the supported names.
    """
    args = parse_arguments()

    # To switch between the original and augmented data, delete
    # 'split_poetry.txt' and 'org_poetry.txt' so they get rebuilt.
    if os.path.exists("data/split_poetry.txt") and os.path.exists("data/org_poetry.txt"):
        print("pre_file exit!")
    else:
        split_text(get_poetry(args))

    all_data, (w1, word_2_index, index_2_word) = train_vec()
    # w1 is the embedding matrix: rows = vocabulary size, cols = embedding dim.
    args.word_size, args.embedding_num = w1.shape

    dataset = Poetry_Dataset(w1, word_2_index, all_data, args.Word2Vec)
    train_size = int(len(dataset) * 0.8)
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    # max(1, ...) guards against a zero batch size when batch_size < 4,
    # which DataLoader rejects.
    valid_data_loader = DataLoader(test_dataset, batch_size=max(1, args.batch_size // 4),
                                   shuffle=True)

    # NOTE(review): every branch of the original if/elif chain built the
    # same Poetry_Model_lstm; the other architectures are presumably not
    # implemented yet, so the chain is collapsed to one construction.
    supported = ('lstm', 'GRU', 'Seq2Seq', 'Transformer', 'GPT-2')
    if args.model not in supported:
        # The original printed "Please choose a model!" and then crashed
        # with a NameError on the undefined best_model; fail fast instead.
        raise ValueError(f"Unknown model '{args.model}'. Please choose one of: {supported}")
    best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num,
                                   args.Word2Vec)

    best_model = make_cuda(best_model)
    best_model = train(args, best_model, train_data_loader)

    # Ensure the output directory exists before torch.save writes to it.
    os.makedirs(args.save_path, exist_ok=True)
    torch.save(best_model.state_dict(),
               os.path.join(args.save_path, f"{args.model}_{args.num_epochs}.pth"))

    print('test evaluation:')
    evaluate(args, best_model, valid_data_loader)


if __name__ == '__main__':
    main()