Spaces:
Runtime error
Runtime error
| import torch | |
| import argparse | |
| import numpy as np | |
| from src.models.LSTM.model import Poetry_Model_lstm | |
| from src.datasets.dataloader import train_vec | |
| from src.utils.utils import make_cuda | |
def _str2bool(value):
    """Convert a command-line token such as ``true`` or ``0`` into a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def parse_arguments():
    """Parse the experiment settings from the command line.

    Returns:
        argparse.Namespace with the attributes ``model``, ``Word2Vec``,
        ``strict_dataset``, ``n_hidden`` and ``save_path``.
    """
    parser = argparse.ArgumentParser(
        description="Specify Params for Experimental Setting")
    parser.add_argument('--model', type=str, default='lstm',
                        help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
    # Without a converter argparse stores the raw string, so
    # "--Word2Vec False" would become the truthy string "False".
    parser.add_argument('--Word2Vec', type=_str2bool, default=True,
                        help="true/false")
    parser.add_argument('--strict_dataset', type=_str2bool, default=False,
                        help="strict dataset")
    parser.add_argument('--n_hidden', type=int, default=128)
    parser.add_argument('--save_path', type=str,
                        default='save_models/lstm_50.pth')
    return parser.parse_args()
def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def generate_poetry(model, head_string, w1, word_2_index, index_2_word, args):
    """Generate an acrostic poem: one line per character of ``head_string``.

    Each line starts with its head character and is extended greedily
    (argmax of the model output) until the model emits '。' or 20 characters
    have been generated.

    Args:
        model: callable invoked as ``model(x, h, c)`` that returns
            ``(scores, (h, c))``.
        head_string: characters that open the successive lines.
        w1: embedding table indexed by word index; only read when
            ``args.Word2Vec`` is truthy.
        word_2_index: mapping from character to vocabulary index.
        index_2_word: sequence mapping vocabulary index back to character.
        args: namespace providing ``n_hidden`` and ``Word2Vec``.

    Returns:
        The poem as a string in which every line is preceded by a newline,
        or ``None`` when a head character is not in ``word_2_index``.
    """
    print("藏头诗生成中...., {}".format(head_string))

    # Reject unknown head characters before spending any model calls.
    for head in head_string:
        if head not in word_2_index:
            print("抱歉,不能生成以{}开头的诗".format(head))
            return None

    max_sent_len = 20
    # Build inputs on the model's device so a model moved to the GPU
    # does not hit a CPU/GPU tensor mismatch.
    device = _model_device(model)
    state_shape = (2, 1, args.n_hidden)
    lines = []

    # Pure inference: no autograd graph is needed, and without this the graph
    # would keep growing through the carried hidden state.
    with torch.no_grad():
        # Generate one line starting from each head character.
        for head in head_string:
            sentence = head
            h_0 = torch.zeros(state_shape, dtype=torch.float32, device=device)
            c_0 = torch.zeros(state_shape, dtype=torch.float32, device=device)
            input_eval = word_2_index[head]
            for _ in range(max_sent_len):
                if args.Word2Vec:
                    # Look up the pre-trained vector and add two leading
                    # singleton dimensions.
                    word_embedding = torch.tensor(
                        w1[input_eval][None][None], device=device)
                else:
                    word_embedding = torch.tensor(
                        [input_eval], device=device).unsqueeze(dim=0)
                pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
                char_generated = index_2_word[int(torch.argmax(pre))]
                if char_generated == '。':
                    break
                # Feed the newly generated character back in as the next input.
                input_eval = word_2_index[char_generated]
                sentence += char_generated
            lines.append(sentence)

    # Every line, including the first, is preceded by a newline.
    return "".join('\n' + line for line in lines)
def infer(model, head):
    """Load a trained checkpoint and generate an acrostic poem for ``head``.

    Args:
        model: model name, one of 'lstm', 'GRU', 'Seq2Seq', 'Transformer'
            or 'GPT-2'.
        head: characters that open the successive lines of the poem.

    Returns:
        Whatever ``generate_poetry`` returns for the loaded model.

    Raises:
        ValueError: if ``model`` is not one of the supported names.
    """
    supported = ('lstm', 'GRU', 'Seq2Seq', 'Transformer', 'GPT-2')
    # Fail fast with a clear error instead of falling through and calling
    # load_state_dict on the model-name string.
    if model not in supported:
        raise ValueError(
            "Please choose a model! Got {!r}, expected one of {}".format(
                model, "/".join(supported)))

    args = parse_arguments()
    _, (w1, word_2_index, index_2_word) = train_vec()
    args.word_size, args.embedding_num = w1.shape
    args.model = model

    # Only 'lstm' and 'GRU' pin a checkpoint; the other names keep the
    # save_path that came from argument parsing.
    checkpoints = {
        'lstm': 'save_models/lstm_50.pth',
        'GRU': 'save_models/GRU_50.pth',
    }
    args.save_path = checkpoints.get(args.model, args.save_path)

    # NOTE(review): every supported name currently builds Poetry_Model_lstm;
    # the non-LSTM architectures look like placeholders, so confirm before
    # relying on them.
    network = Poetry_Model_lstm(args.n_hidden, args.word_size,
                                args.embedding_num, args.Word2Vec)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    network.load_state_dict(
        torch.load(args.save_path, map_location=torch.device('cpu')))
    # Switch to evaluation mode so train-only layers (e.g. dropout, if the
    # model has any) do not perturb generation.
    network.eval()
    return generate_poetry(network, head, w1, word_2_index, index_2_word, args)
if __name__ == '__main__':
    # Command-line entry point: load the checkpoint named by --save_path and
    # print an acrostic poem for a fixed demo head string.
    args = parse_arguments()
    if args.model not in ('lstm', 'GRU', 'Seq2Seq', 'Transformer', 'GPT-2'):
        # Stop here; continuing would fail later on an undefined model.
        raise SystemExit("Please choose a model!\n")

    _, (w1, word_2_index, index_2_word) = train_vec()
    args.word_size, args.embedding_num = w1.shape
    string = '自然语言'

    # NOTE(review): every supported model name currently builds
    # Poetry_Model_lstm.
    model = Poetry_Model_lstm(args.n_hidden, args.word_size,
                              args.embedding_num, args.Word2Vec)
    # Load onto the CPU first so a GPU-saved checkpoint also works on a
    # CPU-only machine; make_cuda is applied to the model afterwards.
    model.load_state_dict(
        torch.load(args.save_path, map_location=torch.device('cpu')))
    model = make_cuda(model)
    # generate_poetry requires args (n_hidden, Word2Vec) as its last argument.
    poem = generate_poetry(model, string, w1, word_2_index, index_2_word, args)
    print(poem)