| """ |
| Data collation for language modeling. |
| |
| Provides efficient batch collation with padding and label handling |
| for autoregressive language model training. |
| Integrated with EthioBBPE tokenizer for Ethiopian languages. |
| """ |
|
|
from functools import partial
from itertools import chain
from typing import Dict, List, Optional, Any, Union

import torch
from datasets import Dataset as HFDataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
|
|
|
|
class TextDataset(Dataset):
    """
    Character-level dataset backed by a plain-text file.

    Each non-blank line of the file, stripped of surrounding whitespace,
    is one sample. The vocabulary is built from the characters that occur
    in those samples; '<pad>' and '<unk>' are reserved as IDs 0 and 1.

    Args:
        file_path: Path to a UTF-8 text file (one sample per line)
        max_length: Maximum number of token IDs returned per sample
        tokenizer_name: Accepted by the constructor but not read anywhere
            in this class; tokenization is always character-level
    """

    def __init__(
        self,
        file_path: str,
        max_length: int = 512,
        tokenizer_name: str = "char",
    ):
        self.max_length = max_length

        # Read everything up front; blank / whitespace-only lines are dropped.
        with open(file_path, 'r', encoding='utf-8') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            self.texts = [sample for sample in stripped_lines if sample]

        self.char_to_idx = {}
        self.idx_to_char = {}
        self._build_vocab()

    def _build_vocab(self):
        """Populate the char<->ID mappings from the loaded samples."""
        alphabet = {ch for sample in self.texts for ch in sample}

        # Special tokens take the lowest IDs.
        self.char_to_idx.update({'<pad>': 0, '<unk>': 1})

        # Remaining characters get consecutive IDs in sorted order.
        next_id = len(self.char_to_idx)
        for ch in sorted(alphabet):
            if ch in self.char_to_idx:
                continue
            self.char_to_idx[ch] = next_id
            next_id += 1

        self.idx_to_char = {
            index: token for token, index in self.char_to_idx.items()
        }

    def _tokenize(self, text: str) -> List[int]:
        """Map every character of ``text`` to its ID ('<unk>' if unseen)."""
        lookup = self.char_to_idx
        ids = []
        for ch in text:
            ids.append(lookup.get(ch, lookup['<unk>']))
        return ids

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        token_ids = self._tokenize(self.texts[idx])[:self.max_length]

        # Never return an empty sequence: fall back to a single pad token.
        if not token_ids:
            token_ids = [self.char_to_idx['<pad>']]

        # Two independent tensors so that callers may mutate one safely.
        return {
            key: torch.tensor(token_ids, dtype=torch.long)
            for key in ("input_ids", "labels")
        }
|
|
|
|
def _stack_or_pad(tensors, padding_value):
    """Stack equally-shaped tensors; otherwise right-pad them to a common length."""
    first_shape = tensors[0].shape
    if all(t.shape == first_shape for t in tensors):
        # Same result as before for fixed-length samples.
        return torch.stack(tensors)
    return pad_sequence(tensors, batch_first=True, padding_value=padding_value)


def _pad_collate(batch, pad_token_id=0, label_pad_id=-100):
    """Collate 'input_ids' / 'labels' tensors into one batch, padding ragged samples."""
    return {
        "input_ids": _stack_or_pad(
            [item["input_ids"] for item in batch], pad_token_id
        ),
        # -100 is the default ignore_index of torch's cross-entropy loss, so
        # padded label positions do not contribute to the loss.
        "labels": _stack_or_pad(
            [item["labels"] for item in batch], label_pad_id
        ),
    }


def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = False,
    pad_token_id: int = 0,
) -> DataLoader:
    """
    Create a DataLoader from a dataset.

    Each sample must be a dict holding 'input_ids' and 'labels' tensors.
    Samples of identical shape are stacked; samples of different lengths
    are right-padded ('input_ids' with ``pad_token_id``, 'labels' with
    -100) instead of failing in ``torch.stack``.

    The collate function is a module-level callable rather than a lambda
    so that it can be pickled when worker processes are spawned.

    Args:
        dataset: Dataset to load from
        batch_size: Batch size
        shuffle: Whether to shuffle data
        num_workers: Number of worker processes
        pin_memory: Pin memory for faster GPU transfer
        pad_token_id: Value used to pad 'input_ids' of shorter samples

    Returns:
        DataLoader instance yielding dicts with 'input_ids' and 'labels'
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=partial(_pad_collate, pad_token_id=pad_token_id),
    )
|
|
|
|
class DataCollatorForLanguageModeling:
    """
    Data collator for language modeling tasks.

    Features:
    - Dynamic padding to max length in batch
    - Labels default to a copy of the inputs when none are supplied;
      padded label positions are set to -100
    - Attention mask derived from the true sequence lengths, so real
      tokens whose ID equals ``pad_token_id`` are still attended to
    - Accepts Python lists, NumPy arrays or tensors without re-copying
      tensors
    - Compatible with EthioBBPE and HuggingFace tokenizers

    Args:
        pad_token_id: Token ID for padding
        max_length: Maximum sequence length (None for dynamic)
        return_tensors: 'pt' for torch tensors (default) or 'np' for
            NumPy arrays; any other value yields torch tensors
    """

    def __init__(
        self,
        pad_token_id: int = 0,
        max_length: Optional[int] = None,
        return_tensors: str = "pt",
    ):
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.return_tensors = return_tensors

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a batch of examples.

        Args:
            examples: List of dictionaries with 'input_ids' and optional
                'labels'. Whether labels are present is decided from the
                first example.

        Returns:
            Batch dictionary with padded 'input_ids', 'attention_mask'
            and 'labels'

        Raises:
            ValueError: If ``examples`` is empty
        """
        if not examples:
            raise ValueError("Cannot collate an empty batch of examples")

        input_ids_list = [example["input_ids"] for example in examples]

        if "labels" in examples[0]:
            labels_list = [example["labels"] for example in examples]
        else:
            # No labels given: use the inputs themselves as labels.
            labels_list = input_ids_list

        if self.max_length is not None:
            input_ids_list = [ids[:self.max_length] for ids in input_ids_list]
            labels_list = [lbl[:self.max_length] for lbl in labels_list]

        # as_tensor avoids the copy (and the UserWarning) that torch.tensor
        # triggers when an example already holds a tensor.
        input_ids_tensors = [
            torch.as_tensor(ids, dtype=torch.long) for ids in input_ids_list
        ]
        labels_tensors = [
            torch.as_tensor(lbl, dtype=torch.long) for lbl in labels_list
        ]

        input_ids = pad_sequence(
            input_ids_tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
        )
        labels = pad_sequence(
            labels_tensors,
            batch_first=True,
            padding_value=-100,
        )

        # Build the mask from sequence lengths rather than comparing against
        # pad_token_id: a genuine token may share the pad ID (e.g. pad == eos).
        lengths = torch.tensor(
            [seq.size(0) for seq in input_ids_tensors],
            device=input_ids.device,
        )
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

        if self.return_tensors == "np":
            batch = {key: value.cpu().numpy() for key, value in batch.items()}

        return batch
|
|
|
|
def create_training_dataset(
    texts: List[str],
    tokenizer,
    max_length: int = 512,
    stride: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """
    Create a dataset from raw texts.

    Supports both EthioBBPE and HuggingFace tokenizers.

    With ``stride=None`` every text is truncated to ``max_length`` tokens,
    for both tokenizer kinds. With a stride, texts are tokenized in full
    and then cut into windows of ``max_length`` tokens by ``_group_texts``,
    so nothing beyond ``max_length`` is thrown away before chunking.

    Args:
        texts: List of text strings
        tokenizer: Tokenizer instance (EthioBBPE or HF tokenizer)
        max_length: Maximum sequence length
        stride: Stride for chunking long texts (None for truncation)

    Returns:
        Dataset with tokenized examples (a HuggingFace ``datasets.Dataset``)
    """
    # Anything exposing encode_batch is driven through the EthioBBPE-style
    # API; everything else is called like a HuggingFace tokenizer.
    is_ethiobbpe = hasattr(tokenizer, 'encode_batch') or type(tokenizer).__name__ == 'EthioBBPE'

    # Truncate only when no chunking was requested; otherwise the grouping
    # step needs the complete token stream.
    truncate = stride is None

    def tokenize_function(examples):
        if is_ethiobbpe:
            encoded = tokenizer.encode_batch(examples["text"])
            input_ids = [item["ids"] for item in encoded]
            if truncate:
                # encode_batch is called without a length limit, so apply
                # max_length here to match the HuggingFace branch.
                input_ids = [ids[:max_length] for ids in input_ids]
            return {"input_ids": input_ids}

        if truncate:
            return tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_length,
                padding=False,
                return_special_tokens_mask=False,
            )

        encoded = tokenizer(
            examples["text"],
            truncation=False,
            padding=False,
            return_special_tokens_mask=False,
        )
        # Only the token IDs survive re-chunking; per-example columns such as
        # attention_mask would no longer line up with the new rows.
        return {"input_ids": encoded["input_ids"]}

    dataset = HFDataset.from_dict({"text": texts})
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    if stride is not None:
        tokenized_dataset = _group_texts(tokenized_dataset, max_length, stride)

    return tokenized_dataset
|
|
|
|
def _group_texts(
    dataset,
    max_length: int,
    stride: int,
) -> torch.utils.data.Dataset:
    """
    Concatenate tokenized examples and re-cut them into fixed-size windows.

    Within each batch handed over by ``dataset.map`` the 'input_ids' are
    joined into one token stream. A stream no longer than ``max_length``
    is returned as a single example. Otherwise windows of ``max_length``
    tokens are taken every ``stride`` tokens, and if the last regular
    window stops short of the end, one extra window covering the final
    ``max_length`` tokens is appended so that no token is dropped.

    Args:
        dataset: Object with ``map`` and ``column_names`` (a tokenized
            HuggingFace dataset containing an 'input_ids' column)
        max_length: Window size in tokens; must be positive
        stride: Distance between window starts; must be positive

    Returns:
        Dataset whose only column is the re-chunked 'input_ids'

    Raises:
        ValueError: If ``max_length`` or ``stride`` is not positive
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    def group_function(examples):
        # Linear-time flattening (sum(lists, []) is quadratic).
        concatenated_ids = list(chain.from_iterable(examples["input_ids"]))

        total_length = len(concatenated_ids)
        if total_length <= max_length:
            return {"input_ids": [concatenated_ids]}

        # Start offsets of all full windows on the stride grid.
        last_start = total_length - max_length
        starts = list(range(0, last_start + 1, stride))

        # Add a tail window only when the grid did not reach the end of the
        # stream; this avoids both lost tokens and duplicated windows.
        if starts[-1] != last_start:
            starts.append(last_start)

        return {
            "input_ids": [
                concatenated_ids[start:start + max_length] for start in starts
            ]
        }

    return dataset.map(
        group_function,
        batched=True,
        # The number of rows changes, so columns of the old rows must go.
        remove_columns=dataset.column_names,
        desc="Grouping texts",
    )
|
|