File size: 2,890 Bytes
0e85155
 
 
 
 
 
 
435428f
0e85155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac144e1
0e85155
 
 
 
 
 
 
 
 
 
 
 
 
ac144e1
2db9276
 
 
 
0e85155
07e8e52
0e85155
 
 
 
 
 
 
 
a9456d3
89ff2ee
 
 
 
0e85155
 
 
 
 
 
 
 
 
 
ac144e1
 
435428f
 
0e85155
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torchtext
from pathlib import Path

from model import EncoderLM

# Sampling defaults used by the __main__ entry point.
MAX_LEN = 50       # maximum number of tokens to sample per generation
TEMPERATURE = 1.0  # softmax temperature; 1.0 leaves the logits unscaled

# Checkpoint locations, resolved relative to the current working directory.
device = 'cpu'
model_dir = Path().cwd() / 'models'
model_path = model_dir / 'lm_8_layers.pth'    # trained EncoderLM state dict
vocab_path = model_dir / 'vocab_8_layers.pth' # serialized vocab object

# Architecture hyperparameters — must match the checkpoint being loaded.
emb_size = 768
dim_feedforward = 2048
num_layers = 8
nhead = 12  # NOTE(review): get_model() hard-codes nhead=8 and ignores this constant — confirm which value matches the checkpoint

def get_vocab():
    """Deserialize the saved vocabulary from ``vocab_path``, mapping storage onto ``device``."""
    target = torch.device(device)
    return torch.load(vocab_path, map_location=target)

# Module-level state: the vocabulary is loaded once at import time so that
# get_model() can size the embedding layer and look up the padding index.
vocab = get_vocab()
vocab_size = len(vocab)
padding_idx = vocab.get_stoi()['<pad>']  # integer id of the '<pad>' token

def get_model():
    """Build an EncoderLM with the module-level hyperparameters and load the
    trained weights from ``model_path`` onto ``device``.

    NOTE(review): ``nhead=8`` is hard-coded below while the module constant
    ``nhead`` is 12 — confirm which value the checkpoint was trained with
    before unifying them (changing it would silently alter attention).
    """
    model = EncoderLM(vocab_size, emb_size, num_layers=num_layers, nhead=8, dim_feedforward=dim_feedforward,
                  masking=True, padding_idx=padding_idx, dropout=0.1, max_len=525, deepnorm=True).to(device)
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    return model

def clean_text(tokens):
    """Convert a list of generated wordpiece tokens into readable text.

    Applies the corpus casing markers ('<up>' uppercases, '<cap>' title-cases
    the following token), restores '@-@' to a plain hyphen, drops the control
    markers themselves, then detokenizes and tidies punctuation spacing.
    """
    from nltk.tokenize.treebank import TreebankWordDetokenizer

    markers = ('<bos>', '<eos>', '<up>', '<cap>')
    kept = []
    prev = '<bos>'
    for tok in tokens:
        # Casing markers act on the *next* non-<unk> token.
        if tok != '<unk>':
            if prev == '<up>':
                tok = tok.upper()
            if prev == '<cap>':
                tok = tok.title()
        if tok == '@-@':
            tok = '-'
        if tok not in markers:
            kept.append(tok)
        prev = tok

    out = TreebankWordDetokenizer().detokenize(kept)
    # Tighten apostrophes and punctuation the detokenizer leaves spaced out.
    out = out.replace(" ' ", "' ").replace("' ", "'")
    out = out.replace(" . . . ", "...").replace(" . ", ". ").replace(" ? ", "? ").replace(" ! ", "! ")
    return out

def generate_text(seed, model, vocab, max_len=20, temperature=1., device=device, skip_tokens=('<unk>',), top_k=50):
    """Autoregressively sample a continuation of *seed* from the language model.

    Tokenizes the seed with torchtext's basic_english tokenizer, then samples
    up to *max_len* tokens from the temperature-scaled, top-k-filtered output
    distribution, stopping early at '<eos>'.

    Args:
        seed: prompt string to continue.
        model: a trained EncoderLM (or any callable returning per-position logits).
        vocab: torchtext vocab providing get_stoi()/get_itos().
        max_len: maximum number of tokens to sample.
        temperature: softmax temperature; >1 flattens, <1 sharpens.
        device: torch device string for the input tensors.
        skip_tokens: currently unused; kept for interface compatibility.
        top_k: sample only among the k most probable tokens.

    Returns:
        The seed string followed by the cleaned generated text.
    """
    stoi, itos = vocab.get_stoi(), vocab.get_itos()
    unk_idx = stoi['<unk>']
    tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
    model = model.eval()

    seed_tokens = ['<bos>'] + tokenizer(seed)
    x = torch.tensor([[stoi.get(word, unk_idx) for word in seed_tokens]],
                     dtype=torch.long, device=device)

    idxs = []
    with torch.no_grad():  # inference only — no autograd graph needed
        for _ in range(max_len):
            logits = model(x) / temperature
            probs = logits[:, -1].softmax(dim=-1).squeeze()
            # Top-k filtering. BUG FIX: the original did `probs[~top_indices] = 0.`,
            # which bitwise-negates the index tensor (~i == -i-1) and zeroes
            # arbitrary wrong entries instead of masking non-top-k tokens.
            k = min(top_k, probs.numel())
            top = torch.topk(probs, k)
            filtered = torch.zeros_like(probs)
            filtered[top.indices] = top.values
            idx = torch.multinomial(filtered, 1, replacement=True).item()
            idxs.append(idx)
            x = torch.cat([x, torch.tensor([[idx]], dtype=torch.long, device=device)], dim=1)
            if itos[idx] == '<eos>':
                break

    generated = [itos[i] for i in idxs]
    return seed + ' ' + clean_text(generated)


if __name__ == '__main__':
    # Smoke-test generation: load artifacts and print one sampled continuation.
    vocab = get_vocab()
    model = get_model()
    prompt = 'Tell me a story about'
    print(generate_text(prompt, model, vocab, max_len=20, temperature=1.0,
                        device=device, skip_tokens=['<unk>'], top_k=50))