import numpy as np
import pickle
import os
import torch
import torch.nn as nn
from gensim.models.word2vec import Word2Vec
from torch.utils.data import Dataset


# Right-pad each poem with the pad character so a batch shares one length.
def padding(poetries, maxlen, pad):
    batch_seq = [poetry + pad * (maxlen - len(poetry)) for poetry in poetries]
    return batch_seq
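# Example: padding(["ab", "abcd"], 4, " ") -> ["ab  ", "abcd"]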


# Shift the input one step so the target is the next character, i.e. next-character prediction.
def split_input_target(seq):
    inputs = seq[:-1]
    targets = seq[1:]
    return inputs, targets
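# Example: split_input_target([5, 8, 2, 9]) -> ([5, 8, 2], [8, 2, 9])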


# Load the poetry corpus and filter out malformed or out-of-range poems.
def get_poetry(arg):
    poetrys = []
    if arg.Augmented_dataset:
        path = arg.Augmented_data
    else:
        path = arg.data
    with open(path, "r", encoding='UTF-8') as f:
        for line in f:
            try:
                line = line.strip(u'\n')
                if arg.Augmented_dataset:
                    content = line.strip(u' ')
                else:
                    # Each raw line is formatted as "title:content".
                    title, content = line.strip(u' ').split(u':')
                content = content.replace(u' ', u'')
                # Skip poems containing annotations or bracket characters.
                if u'_' in content or u'(' in content or u'（' in content or u'《' in content or u'[' in content:
                    continue
                if arg.strict_dataset:
                    if len(content) < 12 or len(content) > 79:
                        continue
                else:
                    if len(content) < 5 or len(content) > 79:
                        continue
                # Wrap each poem in '[' and ']' start/end markers.
                content = u'[' + content + u']'
                poetrys.append(content)
            except Exception:
                # Skip lines that fail to parse (e.g. no ':' separator).
                pass

    # Sort poems by length.
    poetrys = sorted(poetrys, key=len)

    with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
            f.write(poetry)

    return poetrys


# Tokenize the corpus: write every poem as space-separated characters for Word2Vec.
def split_text(poetrys):
    with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = poetry.strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
            split_data = " ".join(poetry)
            f.write(split_data)
    with open("data/split_poetry.txt", "r", encoding='UTF-8') as f:
        return f.read()
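# Example: " ".join("春眠不觉晓") -> "春 眠 不 觉 晓", so each character becomes one token.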


# Train character-level word vectors, or load the cached ones if they exist.
def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
    param_file = "data/word_vec.pkl"
    with open(org_file, "r", encoding="utf-8") as f:
        org_data = f.read().split("\n")
    if os.path.exists(split_file):
        with open(split_file, "r", encoding="utf-8") as f:
            all_data_split = f.read().split("\n")
    else:
        # split_text needs the poem list, so rebuild the split file from org_data.
        all_data_split = split_text(org_data).split("\n")

    # Reuse cached vectors when available.
    if os.path.exists(param_file):
        with open(param_file, "rb") as f:
            return org_data, pickle.load(f)

    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
    with open(param_file, "wb") as f:
        pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], f)
    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
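# Note: models.syn1neg has shape (vocab_size, vector_size); Poetry_Dataset below
# indexes into it as a frozen (vocab_size, 256) embedding table.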


class Poetry_Dataset(Dataset):
    def __init__(self, w1, word_2_index, all_data, use_word2vec):
        # use_word2vec: return pretrained embeddings if True, raw token indices if False.
        self.use_word2vec = use_word2vec
        self.w1 = w1
        self.word_2_index = word_2_index
        word_size, embedding_num = w1.shape
        self.embedding = nn.Embedding(word_size, embedding_num)
        # Length of the longest poem; every sample is right-padded to it.
        maxlen = max(len(seq) for seq in all_data)
        pad = ' '
        # Drop the final element, an empty string left over by split("\n").
        self.all_data = padding(all_data[:-1], maxlen, pad)

    def __getitem__(self, index):
        a_poetry = self.all_data[index]

        # Map every character to its vocabulary index.
        a_poetry_index = [self.word_2_index[i] for i in a_poetry]
        xs, ys = split_input_target(a_poetry_index)
        if self.use_word2vec:
            # Look up the pretrained vector of every input character.
            xs_embedding = self.w1[xs]
        else:
            xs_embedding = np.array(xs)

        return xs_embedding, np.array(ys).astype(np.int64)

    def __len__(self):
        return len(self.all_data)
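

# A minimal end-to-end usage sketch, not part of the original pipeline: it
# assumes a one-poem-per-line corpus at data/poetry.txt with "title:content"
# lines, and that the data/ directory exists.
if __name__ == "__main__":
    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(data="data/poetry.txt", Augmented_dataset=False,
                     Augmented_data=None, strict_dataset=False)

    poetrys = get_poetry(args)   # writes data/org_poetry.txt
    split_text(poetrys)          # writes data/split_poetry.txt
    org_data, (w1, word_2_index, index_2_word) = train_vec()

    dataset = Poetry_Dataset(w1, word_2_index, org_data, use_word2vec=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    xs_embedding, ys = next(iter(loader))
    # e.g. torch.Size([32, maxlen - 1, 256]) and torch.Size([32, maxlen - 1])
    print(xs_embedding.shape, ys.shape)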