import numpy as np
import pickle
import os
import torch
import torch.nn as nn
from gensim.models.word2vec import Word2Vec
from torch.utils.data import Dataset
# Right-pad each poem with `pad` so every sequence in a batch has length `maxlen`
def padding(poetries, maxlen, pad):
    batch_seq = [poetry + pad * (maxlen - len(poetry)) for poetry in poetries]
    return batch_seq


# The target is the input shifted right by one character, i.e. predict the next character
def split_input_target(seq):
    inputs = seq[:-1]
    targets = seq[1:]
    return inputs, targets
# Load, filter and normalize the poetry corpus
def get_poetry(arg):
    poetrys = []
    if arg.Augmented_dataset:
        path = arg.Augmented_data
    else:
        path = arg.data
    with open(path, "r", encoding='UTF-8') as f:
        for line in f:
            try:
                line = line.strip(u'\n')
                if arg.Augmented_dataset:
                    content = line.strip(u' ')
                else:
                    title, content = line.strip(u' ').split(u':')
                content = content.replace(u' ', u'')
                # Skip poems containing underscores or brackets (usually corrupt or annotated text)
                if u'_' in content or u'(' in content or u'（' in content or u'《' in content or u'[' in content:
                    continue
                if arg.strict_dataset:
                    if len(content) < 12 or len(content) > 79:
                        continue
                else:
                    if len(content) < 5 or len(content) > 79:
                        continue
                content = u'[' + content + u']'
                poetrys.append(content)
            except Exception:
                # Skip malformed lines (e.g. missing title separator)
                continue
    # Sort poems by length (number of characters)
    poetrys = sorted(poetrys, key=lambda line: len(line))
    with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
            f.write(poetry)
    return poetrys
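
# Usage sketch (assumption: `arg` is an argparse.Namespace or config object exposing the
# fields used above; the path below is hypothetical):
#   arg = argparse.Namespace(data="data/poetry.txt", Augmented_dataset=False,
#                            Augmented_data="", strict_dataset=False)
#   poetrys = get_poetry(arg)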
# Split each poem into space-separated characters and write the result to disk
def split_text(poetrys):
    with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
            split_data = " ".join(poetry)
            f.write(split_data)
    return open("data/split_poetry.txt", "r", encoding='UTF-8').read()
# Train character vectors with Word2Vec and cache the parameters to disk
def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
    param_file = "data/word_vec.pkl"
    org_data = open(org_file, "r", encoding="utf-8").read().split("\n")
    if os.path.exists(split_file):
        all_data_split = open(split_file, "r", encoding="utf-8").read().split("\n")
    else:
        all_data_split = split_text(org_data).split("\n")
    # Return cached parameters if they were already trained
    if os.path.exists(param_file):
        return org_data, pickle.load(open(param_file, "rb"))
    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
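
# Usage sketch: train_vec() returns the raw poems plus the cached Word2Vec parameters,
# where w1 is the (vocab_size, 256) weight matrix taken from models.syn1neg above:
#   org_data, (w1, word_2_index, index_2_word) = train_vec()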
class Poetry_Dataset(Dataset):
    def __init__(self, w1, word_2_index, all_data, Word2Vec):
        self.Word2Vec = Word2Vec
        self.w1 = w1
        self.word_2_index = word_2_index
        word_size, embedding_num = w1.shape
        self.embedding = nn.Embedding(word_size, embedding_num)
        # Maximum sequence length in the corpus
        maxlen = max([len(seq) for seq in all_data])
        pad = ' '
        self.all_data = padding(all_data[:-1], maxlen, pad)

    def __getitem__(self, index):
        a_poetry = self.all_data[index]
        a_poetry_index = [self.word_2_index[i] for i in a_poetry]
        xs, ys = split_input_target(a_poetry_index)
        if self.Word2Vec:
            # Look up pre-trained vectors for the input characters
            xs_embedding = self.w1[xs]
        else:
            xs_embedding = np.array(xs)
        return xs_embedding, np.array(ys).astype(np.int64)

    def __len__(self):
        return len(self.all_data)
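

# A minimal end-to-end sketch (assumes the default data files produced by get_poetry /
# split_text already exist; the batch size and DataLoader settings are illustrative only):
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    org_data, (w1, word_2_index, index_2_word) = train_vec()
    dataset = Poetry_Dataset(w1, word_2_index, org_data, Word2Vec=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    for xs_embedding, ys in loader:
        # xs_embedding: (batch, seq_len - 1, 256) pre-trained character vectors
        # ys:           (batch, seq_len - 1) int64 indices of the next character
        print(xs_embedding.shape, ys.shape)
        break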