File size: 7,634 Bytes
db0dcb9
 
 
 
 
 
 
1e53095
 
db0dcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
1e53095
 
db0dcb9
ffb7d49
1e53095
 
 
 
 
 
db0dcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
d17b8f3
 
db0dcb9
 
 
d17b8f3
 
db0dcb9
 
 
 
 
1e53095
db0dcb9
 
 
ffb7d49
1e53095
 
db0dcb9
1e53095
db0dcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e53095
db0dcb9
1e53095
 
 
db0dcb9
ffb7d49
db0dcb9
 
 
 
1e53095
 
 
 
db0dcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e53095
db0dcb9
 
 
 
 
 
 
 
 
 
 
 
 
1e53095
 
 
db0dcb9
 
 
 
 
1e53095
 
 
 
db0dcb9
1e53095
 
 
 
db0dcb9
1e53095
 
db0dcb9
1e53095
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR
from tqdm.auto import tqdm
import torch.nn.functional as F
import pandas as pd

# Reproducibility and placement: seed the global RNG, then make CUDA the
# default device so tensors built without an explicit device live on the GPU.
torch.manual_seed(114514)
torch.set_default_device('cuda')

# Special token ids; index 0 of ``vocab`` is '<pad>' and is used as padding.
SOS_token = 1
EOS_token = 2

# Character-level vocabulary: three special tokens, then one entry per
# character of the katakana string below (order defines the token ids).
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
vocab = ['<pad>', '<sos>', '<eos>'] + katakana
vocab_size = len(vocab)
vocab_dict = {token: index for index, token in enumerate(vocab)}

# Training corpus: first column of a header-less CSV, one name per row.
texts = pd.read_csv('rolename.txt', header=None)[0].tolist()

# Model dimensions.
h = 192          # GRU hidden size (encoder and decoder)
h_latent = 64    # latent vector size
max_len = 40     # padded sequence length; also the greedy sampling length

# Optimisation schedule.
bs = 128              # mini-batch size
lr = 0.2              # initial SGD learning rate
lr_step_size = 32     # epochs between learning-rate decays
lr_decay = 0.5        # multiplicative decay applied every lr_step_size epochs
momentum = 0.9
epochs = 192
grad_max_norm = 1     # gradient-norm clipping threshold

def tokenize(text):
    """Return the list of vocabulary ids for the characters of ``text``.

    Raises KeyError for any character that is not in ``vocab_dict``.
    """
    return list(map(vocab_dict.__getitem__, text))

def detokenize(tokens):
    """Turn a sequence of token ids back into a string.

    Everything from the first ``EOS_token`` onward is dropped; the remaining
    ids are looked up in ``vocab`` and concatenated.
    """
    body = tokens[:tokens.index(EOS_token)] if EOS_token in tokens else tokens
    return ''.join(map(vocab.__getitem__, body))

class BatchNormVAE(nn.Module): # https://spaces.ac.cn/archives/7381/
    """Batch-normalise the posterior mean and scale, then rescale them.

    Both inputs go through an affine-free ``BatchNorm1d``. They are then
    multiplied by the gains

        g_mu    = sqrt(TAU + (1 - TAU) * sigmoid(theta))
        g_sigma = sqrt((1 - TAU) * sigmoid(-theta))

    Because sigmoid(t) + sigmoid(-t) = 1, g_mu**2 + g_sigma**2 == 1. During
    training, mu's share of that budget is therefore at least ``TAU``,
    whatever the learnable ``theta`` becomes (see the linked article for the
    motivation).
    """

    def __init__(self, num_features, **kwargs):
        super().__init__()
        # Any caller-supplied 'affine' is overridden: the gains below replace it.
        bn_kwargs = {**kwargs, 'affine': False}
        self.TAU = 0.5  # lower bound on mu's share of the gain budget
        self.bn_mu = nn.BatchNorm1d(num_features, **bn_kwargs)
        self.bn_sigma = nn.BatchNorm1d(num_features, **bn_kwargs)
        self.theta = nn.Parameter(torch.zeros(1))

    def forward(self, mu, sigma):
        """Return ``(bn(mu) * g_mu, bn(sigma) * g_sigma)``."""
        tau = self.TAU
        mu_gain = torch.sqrt(tau + (1 - tau) * F.sigmoid(self.theta))
        sigma_gain = torch.sqrt((1 - tau) * F.sigmoid(-self.theta))
        normed_mu = self.bn_mu(mu)
        normed_sigma = self.bn_sigma(sigma)
        return normed_mu * mu_gain, normed_sigma * sigma_gain

class EncoderVAEBiGRU(nn.Module):
    """Bidirectional 2-layer GRU encoder that emits a reparameterised latent.

    ``forward`` returns ``(z, mu, var)``:
    - ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``;
    - ``var = sigma ** 2``.
    ``sigma`` is a signed scale (it may be negative), so it is squared to give
    the variance.
    """

    def __init__(self, input_size, hidden_size, latent_size, dropout_p=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        # Submodules are built in a fixed order. Under a seeded RNG this order
        # determines the initial weights, so keep it stable.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2,
                          batch_first=True, bidirectional=True)
        # Final hidden states of 2 layers x 2 directions are concatenated.
        summary_size = 4 * hidden_size
        self.proj_mu = nn.Linear(summary_size, latent_size)
        self.proj_sigma = nn.Linear(summary_size, latent_size)
        self.dropout = nn.Dropout(dropout_p)
        self.bn = BatchNormVAE(latent_size)

    def forward(self, input, input_lengths):
        """Encode a padded, batch-first id tensor into ``(z, mu, var)``.

        ``input_lengths`` is a tensor of the unpadded lengths (moved to CPU
        because ``pack_padded_sequence`` requires CPU lengths). The sequences
        need not be sorted by length.
        """
        embedded = self.dropout(self.embedding(input))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded,
            input_lengths.to('cpu'),
            batch_first=True,
            enforce_sorted=False,
        )
        _, h_n = self.gru(packed)
        # (layers * directions, batch, hidden) -> (batch, layers * directions * hidden)
        summary = h_n.transpose(0, 1).flatten(start_dim=1)
        mu, sigma = self.bn(self.proj_mu(summary), self.proj_sigma(summary))
        return self._reparameterize(mu, sigma), mu, sigma ** 2

    def _reparameterize(self, mu, sigma):
        """Sample ``mu + sigma * eps`` with ``eps ~ N(0, I)`` (variance ``sigma ** 2``)."""
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

class DecoderGRU(nn.Module):
    """2-layer GRU decoder mapping a latent vector to per-step log-probabilities.

    The latent is projected (Linear -> ReLU -> Linear) into the initial hidden
    state of both GRU layers. There are two decoding modes:
    - With ``target_tensor``, the whole input sequence is decoded in one
      teacher-forced pass.
    - Without it, decoding is greedy for ``max_length`` steps, starting from
      ``SOS_token``.
    """

    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        # Construction order is kept fixed: under a seeded RNG it determines
        # the initial weights.
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)  # one h_0 slice per GRU layer
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_sample, target_tensor=None, max_length=16):
        """Decode a batch of latents.

        Args:
            encoder_sample: latent batch whose first dimension is the batch
                (assumed ``(batch, latent_size)``).
            target_tensor: optional batch-first token ids used as the decoder
                input for teacher forcing.
            max_length: number of greedy steps when ``target_tensor`` is None.

        Returns:
            ``(log_probs, hidden)``. ``log_probs`` is the log-softmax over the
            vocabulary with shape ``(batch, steps, output_size)``. ``hidden``
            is the final GRU state, shape ``(2, batch, hidden_size)``.
        """
        batch_size = encoder_sample.size(0)
        decoder_hidden = self.proj2(self.proj_activation(self.proj1(encoder_sample)))
        # (batch, 2 * hidden) -> (num_layers=2, batch, hidden): the h_0 layout nn.GRU expects.
        decoder_hidden = decoder_hidden.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()
        if target_tensor is not None:
            decoder_outputs, decoder_hidden = self.forward_step(target_tensor, decoder_hidden)
        else:
            # Build the start tokens on the latent's device, not torch's global
            # default device, so a saved/reloaded decoder works on any device.
            decoder_input = torch.full(
                (batch_size, 1), SOS_token, dtype=torch.long, device=encoder_sample.device
            )
            step_logits = []
            for _ in range(max_length):
                logits, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                step_logits.append(logits)
                # Greedy: feed back the highest-scoring token; no gradient flows through the choice.
                _, topi = logits.topk(1)
                decoder_input = topi.squeeze(-1).detach()
            decoder_outputs = torch.cat(step_logits, dim=1)
        return F.log_softmax(decoder_outputs, dim=-1), decoder_hidden

    def forward_step(self, input, hidden):
        """Run the GRU over ``input`` ids from state ``hidden``; return ``(logits, new_hidden)``."""
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden

class KatakanaDataset(Dataset):
    """Map-style dataset turning raw strings into fixed-length training tensors.

    Each item is ``(enc_text, enc_len, input_text, target_text)``:
    - ``enc_text``: the token ids;
    - ``enc_len``: the unpadded token count;
    - ``input_text``: ``[SOS_token] + tokens`` (decoder input);
    - ``target_text``: ``tokens + [EOS_token]`` (decoder target).

    All three tensors are right-padded with 0 to ``max_length``.
    """

    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def _pad(self, tokens):
        """Right-pad ``tokens`` with 0 up to ``max_length`` and return a long tensor."""
        return torch.tensor(tokens + [0] * (self.max_length - len(tokens)), dtype=torch.long)

    def __getitem__(self, idx):
        """Return the padded tensors for ``texts[idx]``.

        Raises:
            ValueError: if the text has ``max_length`` or more tokens. The
                SOS/EOS-augmented sequences would not fit, and the resulting
                unequal tensor lengths would break batch collation.
        """
        tokens = list(self.tokenizer(self.texts[idx]))
        # +1 for the SOS (input) / EOS (target) marker.
        if len(tokens) + 1 > self.max_length:
            raise ValueError(
                f"text at index {idx} has {len(tokens)} tokens; at most "
                f"{self.max_length - 1} fit with max_length={self.max_length}"
            )
        return (
            self._pad(tokens),
            len(tokens),
            self._pad([SOS_token] + tokens),
            self._pad(tokens + [EOS_token]),
        )

# Shuffled mini-batches over the whole corpus. The shuffle generator is created
# on 'cuda' to match the default device set at the top of the script. Presumably
# the sampler's random permutation is produced on that device; confirm before
# changing either setting.
dataloader = DataLoader(
    dataset=KatakanaDataset(texts, tokenize, max_len),
    shuffle=True,
    batch_size=bs,
    generator=torch.Generator(device='cuda'),
)

def train_epoch(dataloader, encoder, decoder, optimizer, max_norm, norm_p=2, ignore_index=-100):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch's loss is the reconstruction NLL plus the KL divergence of the
    encoder posterior from N(0, I). Gradients of both models are clipped to
    ``max_norm`` (``norm_p``-norm) before each optimizer step.

    Args:
        dataloader: iterable of ``(enc_text, enc_len, input_text, target_text)``
            batches. ``len()`` is not required.
        encoder: called as ``encoder(enc_text, enc_len)``; must return
            ``(sample, mu, var)``.
        decoder: called as ``decoder(sample, input_text)``; its first output
            must be log-probabilities ``(batch, steps, vocab)``.
        optimizer: optimizer over the encoder and decoder parameters.
        max_norm: gradient clipping threshold.
        norm_p: norm type used for clipping.
        ignore_index: target id excluded from the NLL. The default -100
            (NLLLoss's default) scores every position, padding included. Pass
            the padding id (0 for KatakanaDataset) to skip padding.

    Returns:
        Average per-batch loss, or 0.0 if the dataloader yielded no batches.
    """
    nll = nn.NLLLoss(ignore_index=ignore_index)
    # The parameter set does not change during the epoch; collect it once.
    params = list(encoder.parameters()) + list(decoder.parameters())
    total_loss = 0.0
    num_batches = 0
    for enc_text, enc_len, input_text, target_text in dataloader:
        optimizer.zero_grad()

        encoder_sample, mu, var = encoder(enc_text, enc_len)
        decoder_outputs, _ = decoder(encoder_sample, input_text)

        loss_recons = nll(
            decoder_outputs.reshape(-1, decoder_outputs.size(-1)),
            target_text.reshape(-1),
        )
        # Closed-form KL(N(mu, var) || N(0, 1)), averaged over batch and latent
        # dims. Requires var > 0 because of the log.
        loss_kld = 0.5 * torch.mean(mu ** 2 + var - var.log() - 1)
        loss = loss_recons + loss_kld
        loss.backward()

        nn.utils.clip_grad_norm_(params, max_norm, norm_type=norm_p)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    # Guard against an empty dataloader instead of dividing by zero.
    return total_loss / num_batches if num_batches else 0.0

encoder = EncoderVAEBiGRU(vocab_size, h, h_latent).train()
decoder = DecoderGRU(h_latent, h, vocab_size).train()
# Momentum SGD over both halves of the VAE. The learning rate is multiplied by
# lr_decay every lr_step_size epochs.
optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=lr, momentum=momentum) # momentum
scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_decay)

with tqdm(range(epochs), desc='Training') as pbar:
    for _ in pbar:
        pbar.set_postfix(loss=train_epoch(dataloader, encoder, decoder, optimizer, grad_max_norm))
        scheduler.step()

# Sample 8 names: draw latents from N(0, I) and decode greedily. Inference
# needs no autograd graph.
decoder.eval()
with torch.no_grad():
    sample_log_probs, _ = decoder(torch.randn(8, h_latent), max_length=max_len)
# squeeze(-1) drops only the top-k axis, so the sequence axis survives even when max_len == 1.
sampled_ids = sample_log_probs.topk(1)[1].squeeze(-1).tolist()
for name in [detokenize(seq) for seq in sampled_ids]:
    print(name)
torch.save(decoder, 'decoder.pt')