Spaces:
Sleeping
Sleeping
File size: 254 Bytes
f3b11f9 | 1 2 3 4 5 6 7 8 9 10 11 | import numpy as np
import torch
def subsequent_mask(size):
    """Mask out subsequent positions.

    Build a causal (look-ahead) mask for a sequence of length ``size``.

    Args:
        size: Sequence length; must be a non-negative integer.

    Returns:
        A ``torch.bool`` tensor of shape ``(1, size, size)``. Entry
        ``[0, i, j]`` is ``True`` when ``j <= i`` (position ``i`` may see
        position ``j``). It is ``False`` for later positions ``j > i``.

    Raises:
        TypeError: If ``size`` is not an integer.
        ValueError: If ``size`` is negative.
    """
    import operator

    # Reject non-integers the same way np.ones((1, size, size)) did.
    n = operator.index(size)
    if n < 0:
        # Preserve the ValueError the NumPy-based construction raised.
        raise ValueError("negative dimensions are not allowed")
    # The lower triangle, including the diagonal, is "allowed". Building it
    # directly as a bool tensor avoids a float64 NumPy intermediate and a
    # uint8 copy.
    allowed = torch.ones((1, n, n), dtype=torch.bool)
    return torch.tril(allowed)
|