# ChromFound/src/utils/mask_utils.py
# Uploaded by YifengJiao via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 534e5a3 (verified).
import torch
def mask_tensor(input_tensor, mask_prob=0.15, padding_value=-1, add_cls=False):
    """Randomly corrupt a tensor for masked-value pre-training.

    Each non-padding position is selected independently with probability
    ``mask_prob``. Selected nonzero entries are replaced by 0. Selected zero
    entries are replaced by a value drawn uniformly (with replacement) from
    all non-padding nonzero entries of the whole ``input_tensor``. Padding
    entries (``== padding_value``) are never modified and never selected.

    Args:
        input_tensor: Tensor to corrupt. Floating and integer dtypes are
            supported. When ``add_cls`` is True it must have at least two
            dimensions; ``[:, 0]`` is treated as a CLS slot.
        mask_prob: Per-position selection probability.
        padding_value: Value that marks padding positions.
        add_cls: If True, ``[:, 0]`` keeps its original value and is never
            reported as masked.

    Returns:
        Tuple ``(masked_tensor, apply_mask, padding_mask)``:
        ``masked_tensor`` has the same shape and dtype as the input;
        ``apply_mask`` is a float tensor with 1.0 at corrupted positions;
        ``padding_mask`` is a float tensor with 1.0 at padding positions.
    """
    padding_sel = input_tensor == padding_value
    # Exclude padding explicitly from both categories: with the default
    # padding_value=-1 padding entries are "nonzero" and would otherwise be
    # selected and reported in apply_mask.
    zero_candidates = (input_tensor == 0) & ~padding_sel
    nonzero_candidates = (input_tensor != 0) & ~padding_sel

    # torch.rand_like does not support integer dtypes; draw in the input's
    # dtype when it is floating point, else in float32.
    rand_dtype = (
        input_tensor.dtype if input_tensor.is_floating_point() else torch.float32
    )
    zero_sel = zero_candidates & (
        torch.rand(input_tensor.shape, dtype=rand_dtype, device=input_tensor.device)
        < mask_prob
    )
    nonzero_sel = nonzero_candidates & (
        torch.rand(input_tensor.shape, dtype=rand_dtype, device=input_tensor.device)
        < mask_prob
    )

    # Pool of replacement values: every non-padding nonzero entry in the
    # whole input (sampling is global, not per row).
    pool = input_tensor[nonzero_candidates]

    masked_tensor = input_tensor.clone()
    # Selected nonzero values are masked by zero.
    masked_tensor[nonzero_sel] = 0
    if pool.numel() > 0:
        # Selected zeros are masked by randomly sampled nonzero values; only
        # draw as many indices as there are selected zero positions.
        n_zero = int(zero_sel.sum())
        idx = torch.randint(0, pool.numel(), (n_zero,), device=pool.device)
        masked_tensor[zero_sel] = pool[idx]
    else:
        # Nothing to sample from: zeros cannot be corrupted, so leave them
        # untouched and do not report them as masked.
        zero_sel = torch.zeros_like(zero_sel)

    if add_cls:
        # Keep the CLS slot intact and exclude it from the loss mask.
        masked_tensor[:, 0] = input_tensor[:, 0]
        zero_sel[:, 0] = False
        nonzero_sel[:, 0] = False

    # zero_sel and nonzero_sel are disjoint, so OR equals their 0/1 sum.
    apply_mask = (zero_sel | nonzero_sel).float()
    padding_mask = padding_sel.float()
    return masked_tensor, apply_mask, padding_mask