"""
Custom data collators for transformer training.
"""

import random
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch


@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator for language modeling.

    This collator pads already-tokenized examples (their ``input_ids``) into
    a single batch via ``tokenizer.pad`` and, when ``mlm`` is enabled,
    dynamically masks tokens for masked language modeling tasks.
    """

    tokenizer: Any
    mlm: bool = True  # Whether to use masked language modeling
    mlm_probability: float = 0.15  # Probability of masking a token

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of examples.

        Args:
            examples: List of examples from dataset; each must provide
                an ``"input_ids"`` entry.

        Returns:
            Batch dictionary for model with ``"input_ids"`` and ``"labels"``.
            ``"attention_mask"`` is included whenever the tokenizer's ``pad``
            produced one; in the non-MLM branch the key is always present
            (possibly ``None``).
        """
        # Extract input_ids
        input_ids = [example["input_ids"] for example in examples]

        # Pad to a common length and stack into one batch
        batch = self.tokenizer.pad(
            {"input_ids": input_ids},
            return_tensors="pt"
        )
        attention_mask = batch.get("attention_mask", None)

        # If masked language modeling is enabled
        if self.mlm:
            inputs, labels = self.mask_tokens(batch["input_ids"])
            result = {"input_ids": inputs, "labels": labels}
            # Keep the padding mask so the model does not attend to padding.
            if attention_mask is not None:
                result["attention_mask"] = attention_mask
            return result

        labels = batch["input_ids"].clone()
        # Padded positions must not contribute to the loss (-100 is ignored).
        if attention_mask is not None:
            # Mask-based so that real tokens sharing the pad id are kept.
            labels[attention_mask == 0] = -100
        else:
            pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
            if pad_token_id is not None:
                labels[labels == pad_token_id] = -100

        return {
            "input_ids": batch["input_ids"],
            "labels": labels,
            "attention_mask": attention_mask
        }

    def mask_tokens(
        self,
        inputs: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling.

        Args:
            inputs: Input tensor

        Returns:
            Tuple of (masked inputs, labels)
        """
        labels = inputs.clone()

        # Get probability mask
        probability_matrix = torch.full(labels.shape, self.mlm_probability)

        # Create special tokens mask
        if hasattr(self.tokenizer, "get_special_tokens_mask"):
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = torch.tensor(
                [
                    [self._is_special_token(x) for x in val]
                    for val in labels.tolist()
                ],
                dtype=torch.bool,
            )

        # Don't mask special tokens
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Get mask indices
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Set labels for non-masked tokens to -100 (ignored in loss)
        labels[~masked_indices] = -100

        # Set 80% of masked tokens to [MASK]
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        if hasattr(self.tokenizer, "mask_token_id") and self.tokenizer.mask_token_id is not None:
            inputs[indices_replaced] = self.tokenizer.mask_token_id

        # Set