File size: 254 Bytes
f3b11f9
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import numpy as np
import torch


def subsequent_mask(size):
    """Mask out subsequent positions.

    Build a causal mask in which position ``i`` may attend only to
    positions ``j <= i``.

    Args:
        size: Length of the sequence along both mask axes.

    Returns:
        A CPU ``torch.bool`` tensor of shape ``(1, size, size)``. Entries on
        and below the main diagonal are ``True`` (allowed). Entries strictly
        above it are ``False`` (masked).
    """
    attn_shape = (1, size, size)
    # Ones strictly above the diagonal (k=1) mark the "future" positions.
    future = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    # Invert so that allowed (past/current) positions become True.
    return torch.from_numpy(future) == 0