import numpy as np
import pickle
import os
import torch
import torch.nn as nn
from gensim.models.word2vec import Word2Vec
from torch.utils.data import Dataset


def padding(poetries, maxlen, pad):
    # Right-pad every poem with the pad character so all sequences share one length
    batch_seq = [poetry + pad * (maxlen - len(poetry)) for poetry in poetries]
    return batch_seq


# The target is the input shifted one character to the right,
# i.e. the model learns to predict the next character
def split_input_target(seq):
    inputs = seq[:-1]
    targets = seq[1:]
    return inputs, targets


# Load and filter the raw poems (the corpus the vocabulary is built from)
def get_poetry(arg):
    poetrys = []
    if arg.Augmented_dataset:
        path = arg.Augmented_data
    else:
        path = arg.data
    with open(path, "r", encoding='UTF-8') as f:
        for line in f:
            try:
                line = line.strip(u'\n')
                if arg.Augmented_dataset:
                    content = line.strip(u' ')
                else:
                    title, content = line.strip(u' ').split(u':')
                content = content.replace(u' ', u'')
                # Skip poems containing markup or half-/full-width brackets
                if u'_' in content or u'(' in content or u'（' in content or u'《' in content or u'[' in content:
                    continue
                if arg.strict_dataset:
                    if len(content) < 12 or len(content) > 79:
                        continue
                else:
                    if len(content) < 5 or len(content) > 79:
                        continue
                content = u'[' + content + u']'
                poetrys.append(content)
            except Exception:
                pass
    # Sort the poems by character count
    poetrys = sorted(poetrys, key=lambda line: len(line))
    with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
            f.write(poetry)
    return poetrys
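
# Hedged usage sketch: `arg` is expected to be an argparse-style namespace
# exposing the attributes referenced above (data, Augmented_dataset,
# Augmented_data, strict_dataset). The path below is an illustrative
# assumption, not part of this module:
#   from types import SimpleNamespace
#   arg = SimpleNamespace(data="data/poetry.txt", Augmented_dataset=False,
#                         Augmented_data=None, strict_dataset=False)
#   poetrys = get_poetry(arg)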


# Split each poem into space-separated characters and save the result
def split_text(poetrys):
    with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
            split_data = " ".join(poetry)
            f.write(split_data)
    return open("data/split_poetry.txt", "r", encoding='UTF-8').read()


# Train character-level word vectors
def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
    param_file = "data/word_vec.pkl"
    org_data = open(org_file, "r", encoding="utf-8").read().split("\n")
    if os.path.exists(split_file):
        all_data_split = open(split_file, "r", encoding="utf-8").read().split("\n")
    else:
        all_data_split = split_text(org_data).split("\n")
    if os.path.exists(param_file):
        return org_data, pickle.load(open(param_file, "rb"))
    # Each line is passed as a plain string, so gensim treats it as a sequence of
    # single-character tokens; the space used for padding is part of the vocabulary
    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
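
# Note: models.syn1neg is the negative-sampling output-weight matrix with shape
# (vocab_size, 256); Poetry_Dataset below indexes into it directly as a fixed
# character-embedding lookup table.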


class Poetry_Dataset(Dataset):
    def __init__(self, w1, word_2_index, all_data, Word2Vec):
        self.Word2Vec = Word2Vec
        self.w1 = w1
        self.word_2_index = word_2_index
        word_size, embedding_num = w1.shape
        self.embedding = nn.Embedding(word_size, embedding_num)
        # Pad every poem to the length of the longest one
        maxlen = max([len(seq) for seq in all_data])
        pad = ' '
        self.all_data = padding(all_data[:-1], maxlen, pad)

    def __getitem__(self, index):
        a_poetry = self.all_data[index]
        a_poetry_index = [self.word_2_index[i] for i in a_poetry]
        xs, ys = split_input_target(a_poetry_index)
        if self.Word2Vec:
            xs_embedding = self.w1[xs]
        else:
            xs_embedding = np.array(xs)
        return xs_embedding, np.array(ys).astype(np.int64)

    def __len__(self):
        return len(self.all_data)
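

# Hedged end-to-end sketch (the SimpleNamespace arguments, the data file path
# and the batch size are illustrative assumptions, not part of the original module):
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    args = SimpleNamespace(data="data/poetry.txt", Augmented_dataset=False,
                           Augmented_data=None, strict_dataset=False)
    poetrys = get_poetry(args)            # read, filter and save the raw poems
    split_text(poetrys)                   # write the character-split corpus
    org_data, (w1, word_2_index, index_2_word) = train_vec()
    dataset = Poetry_Dataset(w1, word_2_index, org_data, Word2Vec=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    xs_embedding, ys = next(iter(loader))
    # xs_embedding: (batch, maxlen - 1, 256) embeddings; ys: (batch, maxlen - 1) int64 targets
    print(xs_embedding.shape, ys.shape)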