rkingery committed on
Commit
435428f
·
1 Parent(s): ac144e1

fixed bug

Browse files
Files changed (2) hide show
  1. .gitignore +9 -0
  2. utils.py +4 -4
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ *.log
2
+ *.aux
3
+ *.tex
4
+ *.synctex.gz
5
+ .ipynb_checkpoints
6
+ **/*.ipynb_checkpoints
7
+ __pycache__
8
+ **/__pycache__
9
+ hidden
utils.py CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
5
  from model import EncoderLM
6
 
7
  MAX_LEN = 50
8
- TEMPERATURE = 5.0
9
 
10
  device = 'cpu'
11
  model_dir = Path().cwd() / 'models'
@@ -48,7 +48,7 @@ def clean_text(tokens):
48
  detokenizer = TreebankWordDetokenizer()
49
  return detokenizer.detokenize(text)
50
 
51
- def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=100):
52
  stoi, itos = vocab.get_stoi(), vocab.get_itos()
53
  stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
54
  tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
@@ -81,6 +81,6 @@ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device
81
  if __name__ == '__main__':
82
  vocab = get_vocab()
83
  model = get_model()
84
- seed = 'The entropy of the universe is'
85
- generated = generate_text(seed, model, vocab, max_len=20, temperature=0.1, device=device, skip_tokens=['<unk>'], top_k=100)
86
  print(generated)
 
5
  from model import EncoderLM
6
 
7
  MAX_LEN = 50
8
+ TEMPERATURE = 1.0
9
 
10
  device = 'cpu'
11
  model_dir = Path().cwd() / 'models'
 
48
  detokenizer = TreebankWordDetokenizer()
49
  return detokenizer.detokenize(text)
50
 
51
+ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=50):
52
  stoi, itos = vocab.get_stoi(), vocab.get_itos()
53
  stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
54
  tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
 
81
  if __name__ == '__main__':
82
  vocab = get_vocab()
83
  model = get_model()
84
+ seed = 'Tell me a story about'
85
+ generated = generate_text(seed, model, vocab, max_len=20, temperature=1.0, device=device, skip_tokens=['<unk>'], top_k=50)
86
  print(generated)