import torch
import random
MASK_TOKEN = 0
PAD_TOKEN = 1
CLS_TOKEN = 2


def complete_masking(batch, masking_p, n_tokens):
    """Apply BERT-style masking to an input batch for masked language modeling.

    Each non-special token is selected for masking with probability
    ``masking_p``. Of the selected tokens, 80% are replaced with
    ``MASK_TOKEN``, 10% with a random token id drawn from ``[3, n_tokens)``
    (the range skips the special tokens), and the remaining 10% are left
    unchanged.

    Args:
        batch (dict): Input batch containing 'input_ids' and 'attention_mask'
            integer tensors. PAD and CLS positions are never masked.
        masking_p (float): Probability of selecting a token for masking.
        n_tokens (int): Total number of tokens in the vocabulary.

    Returns:
        dict: 'masked_indices' (corrupted ids), 'attention_mask' (passed
        through), 'mask' (bool tensor marking selected positions), and the
        original 'input_ids'.
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    device = input_ids.device

    # Select positions to corrupt: Bernoulli(masking_p), excluding PAD/CLS.
    prob = torch.rand(input_ids.shape, device=device)
    mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN)

    masked_indices = input_ids.clone()
    num_tokens_to_mask = int(mask.sum())

    # Partition the selected positions: 80% -> MASK, 10% -> random, 10% -> keep.
    # P(random) = P(rand < 0.5) * P(not MASK) = 0.5 * 0.2 = 0.1 overall.
    mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8
    random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask

    # BUG FIX: the original code did `masked_indices[mask][random_mask] = ...`.
    # Chained advanced indexing assigns into a temporary copy, so the
    # random-token replacement was silently never applied. Instead, mutate a
    # flat copy of the selected values and write it back in one assignment.
    replacements = masked_indices[mask]
    replacements[mask_mask] = MASK_TOKEN
    replacements[random_mask] = torch.randint(
        3, n_tokens,  # start from 3 to avoid special tokens
        (int(random_mask.sum()),),
        device=device,
        dtype=torch.long,
    )
    masked_indices[mask] = replacements
    # Positions selected by `mask` but in neither sub-mask keep their token.

    return {
        'masked_indices': masked_indices,
        'attention_mask': attention_mask,
        'mask': mask,
        'input_ids': input_ids
    }