Ryan Kingery commited on
Commit
89ff2ee
·
1 Parent(s): 2db9276

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +5 -5
utils.py CHANGED
@@ -61,11 +61,11 @@ def generate_text(seed, model, vocab, max_len=20, temperature=1., device=device,
61
  idxs = []
62
  for _ in range(max_len):
63
  yhat = model(x)
64
- prob = yhat[:, -1].softmax(dim=-1).squeeze()
65
- prob /= temperature
66
- top_probs = torch.topk(prob, top_k, dim=-1).indices
67
- prob[~top_probs] = 0.
68
- idx = torch.multinomial(prob, 1, replacement=True).item()
69
  idxs.append(idx)
70
  x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
71
  if itos[idx] == '<eos>':
 
61
  idxs = []
62
  for _ in range(max_len):
63
  yhat = model(x)
64
+ probs = yhat[:, -1].softmax(dim=-1).squeeze()
65
+ probs /= temperature
66
+ top_probs = torch.topk(probs, top_k, dim=-1).indices
67
+ probs[~top_probs] = 0.
68
+ idx = torch.multinomial(probs, 1, replacement=True).item()
69
  idxs.append(idx)
70
  x = torch.cat([x, torch.ones(1, 1).fill_(idx).long().to(device)], dim=1)
71
  if itos[idx] == '<eos>':