Spaces:
Build error
Build error
rkingery commited on
Commit ·
435428f
1
Parent(s): ac144e1
fixed bug
Browse files- .gitignore +9 -0
- utils.py +4 -4
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.log
|
| 2 |
+
*.aux
|
| 3 |
+
*.tex
|
| 4 |
+
*.synctex.gz
|
| 5 |
+
.ipynb_checkpoints
|
| 6 |
+
**/*.ipynb_checkpoints
|
| 7 |
+
__pycache__
|
| 8 |
+
**/__pycache__
|
| 9 |
+
hidden
|
utils.py
CHANGED
|
@@ -5,7 +5,7 @@ from pathlib import Path
|
|
| 5 |
from model import EncoderLM
|
| 6 |
|
| 7 |
MAX_LEN = 50
|
| 8 |
-
TEMPERATURE =
|
| 9 |
|
| 10 |
device = 'cpu'
|
| 11 |
model_dir = Path().cwd() / 'models'
|
|
@@ -48,7 +48,7 @@ def clean_text(tokens):
|
|
| 48 |
detokenizer = TreebankWordDetokenizer()
|
| 49 |
return detokenizer.detokenize(text)
|
| 50 |
|
| 51 |
-
def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=
|
| 52 |
stoi, itos = vocab.get_stoi(), vocab.get_itos()
|
| 53 |
stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
|
| 54 |
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
|
|
@@ -81,6 +81,6 @@ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device
|
|
| 81 |
if __name__ == '__main__':
|
| 82 |
vocab = get_vocab()
|
| 83 |
model = get_model()
|
| 84 |
-
seed = '
|
| 85 |
-
generated = generate_text(seed, model, vocab, max_len=20, temperature=
|
| 86 |
print(generated)
|
|
|
|
| 5 |
from model import EncoderLM
|
| 6 |
|
| 7 |
MAX_LEN = 50
|
| 8 |
+
TEMPERATURE = 1.0
|
| 9 |
|
| 10 |
device = 'cpu'
|
| 11 |
model_dir = Path().cwd() / 'models'
|
|
|
|
| 48 |
detokenizer = TreebankWordDetokenizer()
|
| 49 |
return detokenizer.detokenize(text)
|
| 50 |
|
| 51 |
+
def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=50):
|
| 52 |
stoi, itos = vocab.get_stoi(), vocab.get_itos()
|
| 53 |
stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
|
| 54 |
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
|
|
|
|
| 81 |
if __name__ == '__main__':
|
| 82 |
vocab = get_vocab()
|
| 83 |
model = get_model()
|
| 84 |
+
seed = 'Tell me a story about'
|
| 85 |
+
generated = generate_text(seed, model, vocab, max_len=20, temperature=1.0, device=device, skip_tokens=['<unk>'], top_k=50)
|
| 86 |
print(generated)
|