Spaces:
Build error
Build error
Ryan Kingery commited on
Commit ·
89ff2ee
1
Parent(s): 2db9276
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -61,11 +61,11 @@ def generate_text(seed, model, vocab, max_len=20, temperature=1., device=device,
|
|
| 61 |
idxs = []
|
| 62 |
for _ in range(max_len):
|
| 63 |
yhat = model(x)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
top_probs = torch.topk(
|
| 67 |
-
|
| 68 |
-
idx = torch.multinomial(
|
| 69 |
idxs.append(idx)
|
| 70 |
x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
|
| 71 |
if itos[idx] == '<eos>':
|
|
|
|
| 61 |
idxs = []
|
| 62 |
for _ in range(max_len):
|
| 63 |
yhat = model(x)
|
| 64 |
+
probs = yhat[:, -1].softmax(dim=-1).squeeze()
|
| 65 |
+
probs /= temperature
|
| 66 |
+
top_probs = torch.topk(probs, top_k, dim=-1).indices
|
| 67 |
+
probs[~top_probs] = 0.
|
| 68 |
+
idx = torch.multinomial(probs, 1, replacement=True).item()
|
| 69 |
idxs.append(idx)
|
| 70 |
x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
|
| 71 |
if itos[idx] == '<eos>':
|