import numpy as np
import pickle
import os
import torch
import torch.nn as nn
from gensim.models.word2vec import Word2Vec
from torch.utils.data import Dataset


def padding(poetries, maxlen, pad):
    # Right-pad each poem with the pad character so every sequence in the batch has the same length.
    batch_seq = [poetry + pad * (maxlen - len(poetry)) for poetry in poetries]
    return batch_seq
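
# Quick illustration (the 5-character string gains three trailing spaces):
#   padding(['[静夜思]'], 8, ' ')  ->  ['[静夜思]   ']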


# Shift the input one character to the right to form the target, i.e. predict the next character
def split_input_target(seq):
    inputs = seq[:-1]
    targets = seq[1:]
    return inputs, targets
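
# Example: for seq = [3, 7, 2, 9],
#   inputs  = [3, 7, 2]
#   targets = [7, 2, 9]
# so each position learns to predict the token that follows it.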


# Build the poem corpus: read, filter, and normalize the raw dataset
def get_poetry(arg):
    poetrys = []
    if arg.Augmented_dataset:
        path = arg.Augmented_data
    else:
        path = arg.data
    with open(path, "r", encoding='UTF-8') as f:
        for line in f:
            try:
                # line = line.decode('UTF-8')
                line = line.strip(u'\n')
                if arg.Augmented_dataset:
                    content = line.strip(u' ')
                else:
                    title, content = line.strip(u' ').split(u':')
                content = content.replace(u' ', u'')
                # Skip poems containing markup or bracket characters
                if u'_' in content or u'(' in content or u'（' in content or u'《' in content or u'[' in content:
                    continue
                # Keep only poems within a sensible length range
                if arg.strict_dataset:
                    if len(content) < 12 or len(content) > 79:
                        continue
                else:
                    if len(content) < 5 or len(content) > 79:
                        continue
                # Mark the start and end of every poem with '[' and ']'
                content = u'[' + content + u']'
                poetrys.append(content)
            except Exception as e:
                pass
    # Sort poems by number of characters
    poetrys = sorted(poetrys, key=lambda line: len(line))
    with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
            f.write(poetry)
    return poetrys
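
# Illustrative call (hypothetical argparse-style config; the attribute names are
# taken from the accesses above):
#   args = argparse.Namespace(data='data/poetry.txt', Augmented_dataset=False,
#                             Augmented_data=None, strict_dataset=False)
#   poetrys = get_poetry(args)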


# Split each poem into space-separated characters, one poem per line
def split_text(poetrys):
    with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
            split_data = " ".join(poetry)
            f.write(split_data)
    return open("data/split_poetry.txt", "r", encoding='UTF-8').read()
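
# For example, " ".join('床前明月光') == '床 前 明 月 光', so every character
# becomes its own token for the word2vec step below.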


# Train character-level word vectors with gensim's Word2Vec
def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
    param_file = "data/word_vec.pkl"
    org_data = open(org_file, "r", encoding="utf-8").read().split("\n")
    if os.path.exists(split_file):
        all_data_split = open(split_file, "r", encoding="utf-8").read().split("\n")
    else:
        all_data_split = split_text(org_data).split("\n")
    # Reuse cached vectors if they were already trained
    if os.path.exists(param_file):
        return org_data, pickle.load(open(param_file, "rb"))
    # Each "sentence" here is a string of space-separated characters; gensim iterates it
    # character by character, which also keeps the pad character ' ' in the vocabulary.
    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
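
# The returned triple is (w1, word_2_index, index_2_word):
#   w1           - (vocab_size, 256) weight matrix (Word2Vec's syn1neg), used below as the embedding table
#   word_2_index - maps a character to its row in w1
#   index_2_word - maps an index back to its character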


class Poetry_Dataset(Dataset):
    # Note: the `Word2Vec` parameter is a boolean flag (it shadows the imported class):
    # True -> __getitem__ returns pre-trained embedding vectors, False -> raw indices.
    def __init__(self, w1, word_2_index, all_data, Word2Vec):
        self.Word2Vec = Word2Vec
        self.w1 = w1
        self.word_2_index = word_2_index
        word_size, embedding_num = w1.shape
        self.embedding = nn.Embedding(word_size, embedding_num)
        # Longest poem length; every poem is padded to it
        maxlen = max([len(seq) for seq in all_data])
        pad = ' '
        # all_data[:-1] drops the trailing empty line produced by split("\n")
        self.all_data = padding(all_data[:-1], maxlen, pad)

    def __getitem__(self, index):
        a_poetry = self.all_data[index]
        a_poetry_index = [self.word_2_index[i] for i in a_poetry]
        xs, ys = split_input_target(a_poetry_index)
        if self.Word2Vec:
            xs_embedding = self.w1[xs]
        else:
            xs_embedding = np.array(xs)
        return xs_embedding, np.array(ys).astype(np.int64)

    def __len__(self):
        return len(self.all_data)
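

# Minimal end-to-end sketch (illustrative; assumes the files under data/ already
# exist or can be built by the functions above, and hold at least a few poems):
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    org_data, (w1, word_2_index, index_2_word) = train_vec()
    dataset = Poetry_Dataset(w1, word_2_index, org_data, Word2Vec=True)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    xs_embedding, ys = next(iter(loader))
    # xs_embedding: (4, maxlen - 1, 256) float batch; ys: (4, maxlen - 1) int64 targets
    print(xs_embedding.shape, ys.shape)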