LLM-fastAPI / models /transformer /module /subsequent_mask.py
Songyou's picture
add files
f3b11f9
raw
history blame contribute delete
254 Bytes
import numpy as np
import torch
def subsequent_mask(size):
    """Return a (1, size, size) boolean mask hiding subsequent positions.

    Entry [0, i, j] is True when position i may attend to position j
    (j <= i) and False for future positions (j > i), giving the causal
    mask used in autoregressive decoding.

    Args:
        size: sequence length (number of positions).

    Returns:
        torch.BoolTensor of shape (1, size, size); True on and below the
        diagonal, False strictly above it.
    """
    attn_shape = (1, size, size)
    # k=1 keeps only the strictly-upper triangle: 1 marks a forbidden
    # (future) cell, 0 marks an allowed one.  (Renamed local so it no
    # longer shadows the function name.)
    upper = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    # Comparing against 0 flips forbidden/allowed and yields a bool tensor.
    return torch.from_numpy(upper) == 0