# Source-page artifact (preserved from the original upload header, not valid Python):
#   WildnerveAI's picture — "Upload 20 files" — commit 0861a59 (verified)
"""
Custom data collators for transformer training.
"""
import torch
import random
from typing import Dict, List, Any, Union
from dataclasses import dataclass
@dataclass
class DataCollatorForLanguageModeling:
"""
Data collator for masked (BERT-style) or causal language modeling.

Pads a batch of already-tokenized examples via ``tokenizer.pad`` and,
when ``mlm`` is True, dynamically masks a random subset of tokens,
producing ``labels`` with -100 at positions ignored by the loss.
NOTE(review): despite the original wording, this collator does not
tokenize — ``__call__`` only pads and (optionally) masks.
"""
tokenizer: Any # must provide .pad(); for MLM also mask_token_id and special-token info
mlm: bool = True # Whether to use masked language modeling (False -> causal LM, labels = inputs)
mlm_probability: float = 0.15 # Probability of selecting each token for masking
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""
Collate a batch of examples.
Args:
examples: List of examples from dataset
Returns:
Batch dictionary for model
"""
# Extract input_ids
input_ids = [example["input_ids"] for example in examples]
# Concatenate inputs
batch = self.tokenizer.pad(
{"input_ids": input_ids},
return_tensors="pt"
)
# If masked language modeling is enabled
if self.mlm:
inputs, labels = self.mask_tokens(batch["input_ids"])
return {"input_ids": inputs, "labels": labels}
else:
labels = batch["input_ids"].clone()
return {
"input_ids": batch["input_ids"],
"labels": labels,
"attention_mask": batch.get("attention_mask", None)
}
def mask_tokens(
self, inputs: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Prepare masked-token inputs/labels for masked language modeling.

NOTE(review): mutates ``inputs`` in place (masked positions are
overwritten with the mask-token id) in addition to returning it.

Args:
inputs: 2-D tensor of token ids (batch, seq_len) — assumed integer dtype; TODO confirm

Returns:
Tuple of (masked inputs, labels); labels are -100 everywhere except
at the positions selected for masking.
"""
labels = inputs.clone()
# Per-position masking probability, before zeroing out special tokens.
probability_matrix = torch.full(labels.shape, self.mlm_probability)
# Build a boolean mask marking special tokens (never masked).
if hasattr(self.tokenizer, "get_special_tokens_mask"):
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
# Fallback: per-token check via a helper not visible in this excerpt —
# presumably defined elsewhere on this class; verify it exists.
special_tokens_mask = torch.tensor(
[
[self._is_special_token(x) for x in val]
for val in labels.tolist()
],
dtype=torch.bool,
)
# Special tokens get probability 0, so they are never selected.
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
# Sample which positions are masked (Bernoulli draw per position).
masked_indices = torch.bernoulli(probability_matrix).bool()
# -100 marks positions ignored by the loss; only masked positions keep labels.
labels[~masked_indices] = -100
# Of the masked positions, 80% are replaced with the [MASK] token.
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
if hasattr(self.tokenizer, "mask_token_id") and self.tokenizer.mask_token_id is not None:
inputs[indices_replaced] = self.tokenizer.mask_token_id
# Set (truncated in this excerpt — presumably the remaining masked tokens get the
# standard BERT treatment: ~10% random token, ~10% unchanged, then return; verify below)