rkingery committed on
Commit
ac144e1
·
1 Parent(s): 5cefcd7

changed temperature; added text detokenizer

Browse files
Files changed (3) hide show
  1. __pycache__/model.cpython-310.pyc +0 -0
  2. app.py +2 -2
  3. utils.py +17 -19
__pycache__/model.cpython-310.pyc ADDED
Binary file (5.65 kB). View file
 
app.py CHANGED
@@ -9,12 +9,12 @@ st.title('Just a Dumb Language Model')
9
  text = st.text_input('Enter some text (this will be used to seed the language model)', value='Tell me a story about')
10
  max_len = st.number_input('Enter a max number of words to generate', min_value=0, max_value=512, value=MAX_LEN)
11
  temperature = st.number_input('Enter a temperature (the higher it is, the more random the output will be)',
12
- min_value=0., max_value=1., value=TEMPERATURE)
13
 
14
  if st.button('Click to run'):
15
  vocab = get_vocab()
16
  model = get_model()
17
- generated = generate_text(text, model, vocab, max_len=max_len, temperature=temperature)
18
 
19
  st.markdown('### Generated Text')
20
  st.markdown(f'{generated}')
 
9
  text = st.text_input('Enter some text (this will be used to seed the language model)', value='Tell me a story about')
10
  max_len = st.number_input('Enter a max number of words to generate', min_value=0, max_value=512, value=MAX_LEN)
11
  temperature = st.number_input('Enter a temperature (the higher it is, the more random the output will be)',
12
+ min_value=0., max_value=100., value=TEMPERATURE)
13
 
14
  if st.button('Click to run'):
15
  vocab = get_vocab()
16
  model = get_model()
17
+ generated = generate_text(text, model, vocab, max_len=MAX_LEN, temperature=TEMPERATURE)
18
 
19
  st.markdown('### Generated Text')
20
  st.markdown(f'{generated}')
utils.py CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
5
  from model import EncoderLM
6
 
7
  MAX_LEN = 50
8
- TEMPERATURE = 0.5
9
 
10
  device = 'cpu'
11
  model_dir = Path().cwd() / 'models'
@@ -31,6 +31,7 @@ def get_model():
31
  return model
32
 
33
  def clean_text(tokens):
 
34
  text = []
35
  prev_token = '<bos>'
36
  for token in tokens:
@@ -44,9 +45,10 @@ def clean_text(tokens):
44
  if token not in ['<bos>', '<eos>', '<up>', '<cap>']:
45
  text.append(token)
46
  prev_token = token
47
- return ' '.join(text)
 
48
 
49
- def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=50):
50
  stoi, itos = vocab.get_stoi(), vocab.get_itos()
51
  stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
52
  tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
@@ -54,25 +56,21 @@ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device
54
  seed_tokens = ['<bos>'] + tokenizer(seed)
55
  x = torch.tensor([stoi_map(word) for word in seed_tokens]).long().to(device)[None, :]
56
  idxs = []
57
- if not temperature > 0:
58
- temperature += 1e-3 # keeps while loop from getting stuck on special tokens
59
- idx_prev = stoi['<unk>']
60
- idx_prev_prev = stoi['<unk>']
61
- idx = idx_prev
62
  for _ in range(max_len):
63
- yhat = model(x)
64
  prob = yhat[:, -1].softmax(dim=-1).squeeze()
65
  top_probs = torch.topk(prob, top_k, dim=-1).indices
66
  prob[~top_probs] = 0.
67
- while (itos[idx] in skip_tokens) or (idx == idx_prev) or (idx == idx_prev_prev):
68
- if (torch.rand(1) > temperature):
69
- idx = prob.argmax(-1).item()
70
- else:
71
- idx = torch.multinomial(prob, 1, replacement=True).item()
72
  idxs.append(idx)
73
  x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
74
- idx_prev_prev = idx_prev
75
- idx_prev = idx
76
  if itos[idx] == '<eos>':
77
  break
78
  generated = [itos[idx] for idx in idxs]
@@ -81,8 +79,8 @@ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device
81
 
82
 
83
  if __name__ == '__main__':
84
- vocab = get_vocab(vocab_path)
85
- model = get_model(model_path)
86
  seed = 'The entropy of the universe is'
87
- generated = generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'])
88
  print(generated)
 
5
  from model import EncoderLM
6
 
7
  MAX_LEN = 50
8
+ TEMPERATURE = 5.0
9
 
10
  device = 'cpu'
11
  model_dir = Path().cwd() / 'models'
 
31
  return model
32
 
33
  def clean_text(tokens):
34
+ from nltk.tokenize.treebank import TreebankWordDetokenizer
35
  text = []
36
  prev_token = '<bos>'
37
  for token in tokens:
 
45
  if token not in ['<bos>', '<eos>', '<up>', '<cap>']:
46
  text.append(token)
47
  prev_token = token
48
+ detokenizer = TreebankWordDetokenizer()
49
+ return detokenizer.detokenize(text)
50
 
51
+ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=100):
52
  stoi, itos = vocab.get_stoi(), vocab.get_itos()
53
  stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
54
  tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
 
56
  seed_tokens = ['<bos>'] + tokenizer(seed)
57
  x = torch.tensor([stoi_map(word) for word in seed_tokens]).long().to(device)[None, :]
58
  idxs = []
59
+ temperature = 1e-5 if temperature < 1e-5 else temperature
60
+ # idx_prev = stoi['<unk>']
61
+ # idx_prev_prev = stoi['<unk>']
62
+ # idx = idx_prev
 
63
  for _ in range(max_len):
64
+ yhat = model(x) / temperature
65
  prob = yhat[:, -1].softmax(dim=-1).squeeze()
66
  top_probs = torch.topk(prob, top_k, dim=-1).indices
67
  prob[~top_probs] = 0.
68
+ #while (itos[idx] in skip_tokens) or (idx == idx_prev) or (idx == idx_prev_prev):
69
+ idx = torch.multinomial(prob, 1, replacement=True).item()
 
 
 
70
  idxs.append(idx)
71
  x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
72
+ # idx_prev_prev = idx_prev
73
+ # idx_prev = idx
74
  if itos[idx] == '<eos>':
75
  break
76
  generated = [itos[idx] for idx in idxs]
 
79
 
80
 
81
  if __name__ == '__main__':
82
+ vocab = get_vocab()
83
+ model = get_model()
84
  seed = 'The entropy of the universe is'
85
+ generated = generate_text(seed, model, vocab, max_len=20, temperature=0.1, device=device, skip_tokens=['<unk>'], top_k=100)
86
  print(generated)