Spaces:
Build error
Build error
rkingery committed on
Commit ·
ac144e1
1
Parent(s): 5cefcd7
changed temperature; added text detokenizer
Browse files- __pycache__/model.cpython-310.pyc +0 -0
- app.py +2 -2
- utils.py +17 -19
__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (5.65 kB). View file
|
|
|
app.py
CHANGED
|
@@ -9,12 +9,12 @@ st.title('Just a Dumb Language Model')
|
|
| 9 |
text = st.text_input('Enter some text (this will be used to seed the language model)', value='Tell me a story about')
|
| 10 |
max_len = st.number_input('Enter a max number of words to generate', min_value=0, max_value=512, value=MAX_LEN)
|
| 11 |
temperature = st.number_input('Enter a temperature (the higher it is, the more random the output will be)',
|
| 12 |
-
min_value=0., max_value=
|
| 13 |
|
| 14 |
if st.button('Click to run'):
|
| 15 |
vocab = get_vocab()
|
| 16 |
model = get_model()
|
| 17 |
-
generated = generate_text(text, model, vocab, max_len=
|
| 18 |
|
| 19 |
st.markdown('### Generated Text')
|
| 20 |
st.markdown(f'{generated}')
|
|
|
|
| 9 |
text = st.text_input('Enter some text (this will be used to seed the language model)', value='Tell me a story about')
|
| 10 |
max_len = st.number_input('Enter a max number of words to generate', min_value=0, max_value=512, value=MAX_LEN)
|
| 11 |
temperature = st.number_input('Enter a temperature (the higher it is, the more random the output will be)',
|
| 12 |
+
min_value=0., max_value=100., value=TEMPERATURE)
|
| 13 |
|
| 14 |
if st.button('Click to run'):
|
| 15 |
vocab = get_vocab()
|
| 16 |
model = get_model()
|
| 17 |
+
# Forward the values the user entered in the widgets above. MAX_LEN and
# TEMPERATURE only seed the widgets' initial values; passing the constants
# here would silently ignore whatever the user typed.
generated = generate_text(text, model, vocab, max_len=max_len, temperature=temperature)
|
| 18 |
|
| 19 |
st.markdown('### Generated Text')
|
| 20 |
st.markdown(f'{generated}')
|
utils.py
CHANGED
|
@@ -5,7 +5,7 @@ from pathlib import Path
|
|
| 5 |
from model import EncoderLM
|
| 6 |
|
| 7 |
MAX_LEN = 50
|
| 8 |
-
TEMPERATURE =
|
| 9 |
|
| 10 |
device = 'cpu'
|
| 11 |
model_dir = Path().cwd() / 'models'
|
|
@@ -31,6 +31,7 @@ def get_model():
|
|
| 31 |
return model
|
| 32 |
|
| 33 |
def clean_text(tokens):
|
|
|
|
| 34 |
text = []
|
| 35 |
prev_token = '<bos>'
|
| 36 |
for token in tokens:
|
|
@@ -44,9 +45,10 @@ def clean_text(tokens):
|
|
| 44 |
if token not in ['<bos>', '<eos>', '<up>', '<cap>']:
|
| 45 |
text.append(token)
|
| 46 |
prev_token = token
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
-
def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=
|
| 50 |
stoi, itos = vocab.get_stoi(), vocab.get_itos()
|
| 51 |
stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
|
| 52 |
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
|
|
@@ -54,25 +56,21 @@ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device
|
|
| 54 |
seed_tokens = ['<bos>'] + tokenizer(seed)
|
| 55 |
x = torch.tensor([stoi_map(word) for word in seed_tokens]).long().to(device)[None, :]
|
| 56 |
idxs = []
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
idx = idx_prev
|
| 62 |
for _ in range(max_len):
|
| 63 |
-
yhat = model(x)
|
| 64 |
prob = yhat[:, -1].softmax(dim=-1).squeeze()
|
| 65 |
top_probs = torch.topk(prob, top_k, dim=-1).indices
|
| 66 |
prob[~top_probs] = 0.
|
| 67 |
-
while (itos[idx] in skip_tokens) or (idx == idx_prev) or (idx == idx_prev_prev):
|
| 68 |
-
|
| 69 |
-
idx = prob.argmax(-1).item()
|
| 70 |
-
else:
|
| 71 |
-
idx = torch.multinomial(prob, 1, replacement=True).item()
|
| 72 |
idxs.append(idx)
|
| 73 |
x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
|
| 74 |
-
idx_prev_prev = idx_prev
|
| 75 |
-
idx_prev = idx
|
| 76 |
if itos[idx] == '<eos>':
|
| 77 |
break
|
| 78 |
generated = [itos[idx] for idx in idxs]
|
|
@@ -81,8 +79,8 @@ def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device
|
|
| 81 |
|
| 82 |
|
| 83 |
if __name__ == '__main__':
|
| 84 |
-
vocab = get_vocab(
|
| 85 |
-
model = get_model(
|
| 86 |
seed = 'The entropy of the universe is'
|
| 87 |
-
generated = generate_text(seed, model, vocab, max_len=20, temperature=0.
|
| 88 |
print(generated)
|
|
|
|
| 5 |
from model import EncoderLM
|
| 6 |
|
| 7 |
MAX_LEN = 50
|
| 8 |
+
TEMPERATURE = 5.0
|
| 9 |
|
| 10 |
device = 'cpu'
|
| 11 |
model_dir = Path().cwd() / 'models'
|
|
|
|
| 31 |
return model
|
| 32 |
|
| 33 |
def clean_text(tokens):
|
| 34 |
+
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
| 35 |
text = []
|
| 36 |
prev_token = '<bos>'
|
| 37 |
for token in tokens:
|
|
|
|
| 45 |
if token not in ['<bos>', '<eos>', '<up>', '<cap>']:
|
| 46 |
text.append(token)
|
| 47 |
prev_token = token
|
| 48 |
+
detokenizer = TreebankWordDetokenizer()
|
| 49 |
+
return detokenizer.detokenize(text)
|
| 50 |
|
| 51 |
+
def generate_text(seed, model, vocab, max_len=20, temperature=0.5, device=device, skip_tokens=['<unk>'], top_k=100):
|
| 52 |
stoi, itos = vocab.get_stoi(), vocab.get_itos()
|
| 53 |
stoi_map = lambda word: stoi[word] if word in stoi.keys() else stoi['<unk>']
|
| 54 |
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
|
|
|
|
| 56 |
seed_tokens = ['<bos>'] + tokenizer(seed)
|
| 57 |
x = torch.tensor([stoi_map(word) for word in seed_tokens]).long().to(device)[None, :]
|
| 58 |
idxs = []
|
| 59 |
+
temperature = 1e-5 if temperature < 1e-5 else temperature
|
| 60 |
+
# idx_prev = stoi['<unk>']
|
| 61 |
+
# idx_prev_prev = stoi['<unk>']
|
| 62 |
+
# idx = idx_prev
|
|
|
|
| 63 |
for _ in range(max_len):
|
| 64 |
+
yhat = model(x) / temperature
|
| 65 |
prob = yhat[:, -1].softmax(dim=-1).squeeze()
|
| 66 |
# Keep only the top_k most likely tokens and zero out the rest.
# NOTE: topk(...).indices holds integer positions, so applying `~` to it
# bitwise-negates the indices (i -> -i-1) instead of selecting "everything
# else". Build an explicit boolean mask and invert that instead.
# k is clamped because torch.topk raises when k exceeds the dimension size.
top_probs = torch.topk(prob, min(top_k, prob.size(-1)), dim=-1).indices
keep = torch.zeros_like(prob, dtype=torch.bool)
keep[top_probs] = True
prob[~keep] = 0.
|
| 68 |
+
#while (itos[idx] in skip_tokens) or (idx == idx_prev) or (idx == idx_prev_prev):
|
| 69 |
+
idx = torch.multinomial(prob, 1, replacement=True).item()
|
|
|
|
|
|
|
|
|
|
| 70 |
idxs.append(idx)
|
| 71 |
x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
|
| 72 |
+
# idx_prev_prev = idx_prev
|
| 73 |
+
# idx_prev = idx
|
| 74 |
if itos[idx] == '<eos>':
|
| 75 |
break
|
| 76 |
generated = [itos[idx] for idx in idxs]
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
if __name__ == '__main__':
|
| 82 |
+
vocab = get_vocab()
|
| 83 |
+
model = get_model()
|
| 84 |
seed = 'The entropy of the universe is'
|
| 85 |
+
generated = generate_text(seed, model, vocab, max_len=20, temperature=0.1, device=device, skip_tokens=['<unk>'], top_k=100)
|
| 86 |
print(generated)
|