Spaces:
Sleeping
Sleeping
File size: 1,297 Bytes
f3b11f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | import torch
from torch.autograd import Variable
from models.transformer.module.subsequent_mask import subsequent_mask
def decode(model, src, src_mask, max_len, type, *, start_symbol=1, end_symbol=2):
    """Autoregressively decode a batch of token sequences from ``model``.

    Every row is seeded with ``start_symbol`` and extended by one token per
    step, for at most ``max_len - 1`` steps. Tokens are chosen either greedily
    or by sampling. Decoding stops early once every row has emitted
    ``end_symbol`` at least once. Rows that finished earlier keep receiving
    tokens until then, so callers presumably truncate each row at its first
    ``end_symbol`` (NOTE(review): confirm against callers).

    Args:
        model: object exposing ``encode(src, src_mask)``,
            ``decode(memory, src_mask, ys, ys_mask)`` and ``generator(x)``.
            The output of ``generator`` is treated as log-probabilities over
            the vocabulary: it is exponentiated before sampling.
        src: source token tensor. ``src.shape[0]`` is the batch size, and the
            returned tensor takes its dtype and device.
        src_mask: mask forwarded unchanged to ``model.encode`` and
            ``model.decode``.
        max_len: maximum number of columns in the output, including the start
            token.
        type: ``'greedy'`` (argmax) or ``'multinomial'`` (sampling). The name
            shadows the builtin but is kept for caller compatibility.
        start_symbol: token id placed in column 0 of every row (default 1).
        end_symbol: token id that marks a row as finished (default 2).

    Returns:
        Tensor of shape ``[batch_size, L]``, with ``1 <= L <= max(max_len, 1)``.
        Column 0 is ``start_symbol``; the remaining columns are generated
        tokens.

    Raises:
        ValueError: if ``type`` is neither ``'greedy'`` nor ``'multinomial'``.
    """
    # Fail fast. Previously an unknown value crashed mid-loop on an unbound
    # `next_word`, or returned silently when the loop did not run.
    if type not in ('greedy', 'multinomial'):
        raise ValueError(
            "type must be 'greedy' or 'multinomial', got {!r}".format(type))

    batch_size = src.shape[0]
    # Inference only: keep the encoder pass out of autograd as well, not just
    # the per-step decoder calls.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        # [batch_size, 1], with the same dtype/device as src
        # (the original used ones().repeat().view().type_as(src)).
        ys = torch.full((batch_size, 1), start_symbol,
                        dtype=src.dtype, device=src.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)

        for _ in range(max_len - 1):
            ys_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = model.decode(memory, src_mask, ys, ys_mask)
            # Only the last position's prediction is needed for the next token.
            log_prob = model.generator(out[:, -1])

            if type == 'greedy':
                # exp is monotonic, so argmax over log-probabilities selects
                # the most likely token without exp underflow/rounding ties.
                next_word = log_prob.argmax(dim=1, keepdim=True)  # [batch, 1]
            else:
                # multinomial needs non-negative weights, hence exp.
                next_word = torch.multinomial(torch.exp(log_prob), 1)  # [batch, 1]

            ys = torch.cat([ys, next_word.to(ys.dtype)], dim=1)
            # squeeze(1) keeps a [batch] shape even when batch_size == 1.
            finished |= next_word.squeeze(1) == end_symbol
            if finished.all():
                break
    return ys
|