|
|
import torch |
|
|
|
|
|
|
|
|
def mask_tensor(input_tensor, mask_prob=0.15, padding_value=-1, add_cls=False):
    """Randomly corrupt entries of ``input_tensor`` and report what was selected.

    Every element is independently selected with probability ``mask_prob``:

    * a selected non-zero element is overwritten with ``0``;
    * a selected zero element is overwritten with a value drawn uniformly
      (with replacement) from the non-zero, non-padding elements of
      ``input_tensor``.

    Afterwards every position equal to ``padding_value`` is reset to
    ``padding_value`` and, when ``add_cls`` is true, column ``0`` is restored
    from the input.

    Args:
        input_tensor: Tensor to corrupt. It is not modified; a clone is
            returned. Integer and floating dtypes are both accepted. When
            ``add_cls`` is true it is indexed as ``[:, 0]`` and therefore
            needs at least two dimensions.
        mask_prob: Per-element selection probability.
        padding_value: Value that marks padding positions.
        add_cls: If true, column ``0`` of the result is copied back from
            ``input_tensor``.

    Returns:
        A tuple ``(masked_tensor, apply_mask, padding_mask)``:

        * ``masked_tensor`` -- corrupted clone, same shape and dtype as the
          input.
        * ``apply_mask`` -- float32 tensor, ``1.0`` where an element was
          selected for corruption.
        * ``padding_mask`` -- float32 tensor, ``1.0`` where the input equals
          ``padding_value``.

    Note:
        ``apply_mask`` records the random selection, not the final diff:
        padding positions (when ``padding_value`` is non-zero) and column
        ``0`` (when ``add_cls`` is true) may be flagged even though their
        values are restored. Combine it with ``padding_mask`` if only real
        changes are wanted.

        If the input has no non-zero, non-padding element there is nothing
        to draw replacements from; zeros are then left untouched and are not
        flagged in ``apply_mask``.
    """
    is_zero = input_tensor == 0
    is_padding = input_tensor == padding_value

    # torch.rand_like cannot sample into an integer dtype, so fall back to the
    # default floating dtype for integer inputs. Floating inputs keep their
    # own dtype, which leaves the random draws for them unchanged.
    if input_tensor.is_floating_point():
        rand_dtype = input_tensor.dtype
    else:
        rand_dtype = torch.get_default_dtype()

    # Two independent draws, zero positions first, so that seeded runs stay
    # reproducible.
    zero_draw = torch.rand_like(input_tensor, dtype=rand_dtype) < mask_prob
    nonzero_draw = torch.rand_like(input_tensor, dtype=rand_dtype) < mask_prob

    zero_selected = is_zero & zero_draw
    nonzero_selected = ~is_zero & nonzero_draw

    # Pool of genuine values that may be written into selected zero slots.
    donor_pool = input_tensor[~is_zero & ~is_padding]

    masked_tensor = input_tensor.clone()
    masked_tensor[nonzero_selected] = 0

    if donor_pool.numel() > 0:
        # One candidate replacement per position; only the selected zero
        # positions are actually used.
        picks = torch.randint(0, donor_pool.numel(), input_tensor.shape)
        replacements = donor_pool[picks]
        masked_tensor[zero_selected] = replacements[zero_selected]
    else:
        # Nothing to sample from: keep zeros as they are and do not report
        # them as corrupted.
        zero_selected = torch.zeros_like(zero_selected)

    # Padding may have been hit by either branch above; put it back.
    masked_tensor[is_padding] = padding_value

    if add_cls:
        masked_tensor[:, 0] = input_tensor[:, 0]

    apply_mask = zero_selected.float() + nonzero_selected.float()

    return masked_tensor, apply_mask, is_padding.float()
|
|
|