|
|
"""
|
|
|
Custom data collators for transformer training.
|
|
|
"""
|
|
|
import torch
|
|
|
import random
|
|
|
from typing import Dict, List, Any, Union
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator for language modeling.

    This collator will tokenize inputs and dynamically mask tokens
    for masked language modeling tasks.
    """

    # Tokenizer used to pad batches and (for MLM) to look up special tokens
    # and the mask token id.  Typed ``Any`` because only a duck-typed subset
    # of the tokenizer API is used: ``pad`` and, optionally,
    # ``get_special_tokens_mask`` / ``mask_token_id``.
    tokenizer: Any
    # When True, apply dynamic masked-language-model masking via
    # ``mask_tokens``; when False, labels are a copy of the padded input ids.
    mlm: bool = True
    # Probability that any given non-special token is selected for masking.
    mlm_probability: float = 0.15
|
|
|
|
|
|
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
|
|
"""
|
|
|
Collate a batch of examples.
|
|
|
|
|
|
Args:
|
|
|
examples: List of examples from dataset
|
|
|
|
|
|
Returns:
|
|
|
Batch dictionary for model
|
|
|
"""
|
|
|
|
|
|
input_ids = [example["input_ids"] for example in examples]
|
|
|
|
|
|
|
|
|
batch = self.tokenizer.pad(
|
|
|
{"input_ids": input_ids},
|
|
|
return_tensors="pt"
|
|
|
)
|
|
|
|
|
|
|
|
|
if self.mlm:
|
|
|
inputs, labels = self.mask_tokens(batch["input_ids"])
|
|
|
return {"input_ids": inputs, "labels": labels}
|
|
|
else:
|
|
|
labels = batch["input_ids"].clone()
|
|
|
return {
|
|
|
"input_ids": batch["input_ids"],
|
|
|
"labels": labels,
|
|
|
"attention_mask": batch.get("attention_mask", None)
|
|
|
}
|
|
|
|
|
|
    def mask_tokens(
        self, inputs: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling.

        Args:
            inputs: Input tensor of token ids (mutated in place where the
                mask token is substituted)

        Returns:
            Tuple of (masked inputs, labels)
        """
        labels = inputs.clone()

        # Start from a uniform masking probability at every position.
        probability_matrix = torch.full(labels.shape, self.mlm_probability)

        # Never mask special tokens ([CLS]/[SEP]/pad, etc.): build a boolean
        # matrix marking special positions, preferring the tokenizer's own
        # helper when it exists.
        if hasattr(self.tokenizer, "get_special_tokens_mask"):
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            # NOTE(review): ``_is_special_token`` is not defined anywhere in
            # the visible part of this file — confirm it exists on this
            # class (or a subclass) or this fallback raises AttributeError.
            special_tokens_mask = torch.tensor(
                [
                    [self._is_special_token(x) for x in val]
                    for val in labels.tolist()
                ],
                dtype=torch.bool,
            )

        # Zero the masking probability at special-token positions.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Sample which positions are masked on this pass (dynamic masking:
        # a fresh draw every batch).
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Compute loss only on masked positions; -100 is the ignore_index
        # of torch.nn.CrossEntropyLoss.
        labels[~masked_indices] = -100

        # Of the masked positions, replace 80% with the mask token
        # (mutates ``inputs`` in place).
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices

        if hasattr(self.tokenizer, "mask_token_id") and self.tokenizer.mask_token_id is not None:
            inputs[indices_replaced] = self.tokenizer.mask_token_id
        # NOTE(review): no ``return inputs, labels`` and no
        # 10%-random-token step is visible at this point, although the
        # signature promises a tuple and ``__call__`` unpacks the result —
        # confirm the remainder of this function exists past the end of
        # this chunk, or add the missing return.
|
|
|
|
|
|
|