Spaces:
Build error
Build error
rkingery committed on
Commit ·
07e8e52
1
Parent(s): cce1c2a
fixed temperature again
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ st.title('Dumb Language Model')
|
|
| 9 |
text = st.text_input('Enter some text (this will be used to seed the language model)', value='Tell me a story about')
|
| 10 |
max_len = st.number_input('Enter a max number of words to generate', min_value=0, max_value=512, value=MAX_LEN)
|
| 11 |
temperature = st.number_input('Enter a temperature (the higher it is, the more random the output will be)',
|
| 12 |
-
min_value=
|
| 13 |
|
| 14 |
if st.button('Click to run'):
|
| 15 |
vocab = get_vocab()
|
|
|
|
| 9 |
text = st.text_input('Enter some text (this will be used to seed the language model)', value='Tell me a story about')
|
| 10 |
max_len = st.number_input('Enter a max number of words to generate', min_value=0, max_value=512, value=MAX_LEN)
|
| 11 |
temperature = st.number_input('Enter a temperature (the higher it is, the more random the output will be)',
|
| 12 |
+
min_value=1., max_value=100., value=TEMPERATURE)
|
| 13 |
|
| 14 |
if st.button('Click to run'):
|
| 15 |
vocab = get_vocab()
|
utils.py
CHANGED
|
@@ -48,7 +48,7 @@ def clean_text(tokens):
|
|
| 48 |
detokenizer = TreebankWordDetokenizer()
|
| 49 |
return detokenizer.detokenize(text).replace("' ", "'")
|
| 50 |
|
| 51 |
-
def generate_text(seed, model, vocab, max_len=20, temperature=
|
| 52 |
stoi, itos = vocab.get_stoi(), vocab.get_itos()
|
| 53 |
stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
|
| 54 |
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
|
|
|
|
| 48 |
detokenizer = TreebankWordDetokenizer()
|
| 49 |
return detokenizer.detokenize(text).replace("' ", "'")
|
| 50 |
|
| 51 |
+
def generate_text(seed, model, vocab, max_len=20, temperature=1., device=device, skip_tokens=['<unk>'], top_k=50):
|
| 52 |
stoi, itos = vocab.get_stoi(), vocab.get_itos()
|
| 53 |
stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
|
| 54 |
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
|