""" Data collation for language modeling. Provides efficient batch collation with padding and label handling for autoregressive language model training. Integrated with EthioBBPE tokenizer for Ethiopian languages. """ from typing import Dict, List, Optional, Any, Union import torch from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset, DataLoader from datasets import Dataset as HFDataset class TextDataset(Dataset): """ Simple text dataset that loads texts from a file. Args: file_path: Path to text file (one sample per line) max_length: Maximum sequence length tokenizer_name: Tokenizer to use ('char' for character-level) """ def __init__( self, file_path: str, max_length: int = 512, tokenizer_name: str = "char", ): self.max_length = max_length # Load texts from file with open(file_path, 'r', encoding='utf-8') as f: self.texts = [line.strip() for line in f if line.strip()] # Create simple character-level tokenizer self.char_to_idx = {} self.idx_to_char = {} self._build_vocab() def _build_vocab(self): """Build character-level vocabulary.""" all_chars = set() for text in self.texts: all_chars.update(text) # Special tokens self.char_to_idx[''] = 0 self.char_to_idx[''] = 1 idx = len(self.char_to_idx) for char in sorted(all_chars): if char not in self.char_to_idx: self.char_to_idx[char] = idx idx += 1 self.idx_to_char = {v: k for k, v in self.char_to_idx.items()} def _tokenize(self, text: str) -> List[int]: """Tokenize text to character IDs.""" return [self.char_to_idx.get(c, self.char_to_idx['']) for c in text] def __len__(self) -> int: return len(self.texts) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: text = self.texts[idx] input_ids = self._tokenize(text)[:self.max_length] # Ensure we have at least some tokens if len(input_ids) == 0: input_ids = [self.char_to_idx['']] return { "input_ids": torch.tensor(input_ids, dtype=torch.long), "labels": torch.tensor(input_ids, dtype=torch.long), } def create_dataloader( dataset: Dataset, batch_size: int = 4, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = False, ) -> DataLoader: """ Create a DataLoader from a dataset. Args: dataset: Dataset to load from batch_size: Batch size shuffle: Whether to shuffle data num_workers: Number of worker processes pin_memory: Pin memory for faster GPU transfer Returns: DataLoader instance """ return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, collate_fn=lambda x: { "input_ids": torch.stack([item["input_ids"] for item in x]), "labels": torch.stack([item["labels"] for item in x]), }, ) class DataCollatorForLanguageModeling: """ Data collator for language modeling tasks. Features: - Dynamic padding to max length in batch - Label creation for next-token prediction - Optional masking for MLM (not used in decoder-only) - Efficient tensor conversion - Compatible with EthioBBPE and HuggingFace tokenizers Args: pad_token_id: Token ID for padding max_length: Maximum sequence length (None for dynamic) return_tensors: Type of tensors to return ('pt', 'np') """ def __init__( self, pad_token_id: int = 0, max_length: Optional[int] = None, return_tensors: str = "pt", ): self.pad_token_id = pad_token_id self.max_length = max_length self.return_tensors = return_tensors def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: """ Collate a batch of examples. Args: examples: List of dictionaries with 'input_ids' and optional 'labels' Returns: Batch dictionary with padded tensors """ # Extract input_ids from examples input_ids_list = [example["input_ids"] for example in examples] # Handle labels has_labels = "labels" in examples[0] if has_labels: labels_list = [example["labels"] for example in examples] else: # Use input_ids as labels (standard for LM) labels_list = input_ids_list # Truncate if max_length specified if self.max_length is not None: input_ids_list = [ids[:self.max_length] for ids in input_ids_list] labels_list = [lbl[:self.max_length] for lbl in labels_list] # Convert to tensors input_ids_tensors = [torch.tensor(ids, dtype=torch.long) for ids in input_ids_list] labels_tensors = [torch.tensor(lbl, dtype=torch.long) for lbl in labels_list] # Pad sequences input_ids = pad_sequence( input_ids_tensors, batch_first=True, padding_value=self.pad_token_id, ) labels = pad_sequence( labels_tensors, batch_first=True, padding_value=-100, # Ignore padding in loss calculation ) # Create attention mask attention_mask = (input_ids != self.pad_token_id).long() return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } def create_training_dataset( texts: List[str], tokenizer, max_length: int = 512, stride: Optional[int] = None, ) -> torch.utils.data.Dataset: """ Create a dataset from raw texts. Supports both EthioBBPE and HuggingFace tokenizers. Args: texts: List of text strings tokenizer: Tokenizer instance (EthioBBPE or HF tokenizer) max_length: Maximum sequence length stride: Stride for chunking long texts (None for truncation) Returns: Dataset with tokenized examples """ from datasets import Dataset # Detect tokenizer type is_ethiobbpe = hasattr(tokenizer, 'encode_batch') or type(tokenizer).__name__ == 'EthioBBPE' # Tokenize all texts def tokenize_function(examples): if is_ethiobbpe: # EthioBBPE tokenizer encoded = tokenizer.encode_batch(examples["text"]) input_ids = [item["ids"] for item in encoded] return {"input_ids": input_ids} else: # HuggingFace tokenizer return tokenizer( examples["text"], truncation=True, max_length=max_length, padding=False, return_special_tokens_mask=False, ) dataset = Dataset.from_dict({"text": texts}) tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=["text"], ) # Group texts if needed for longer context if stride is not None: tokenized_dataset = _group_texts(tokenized_dataset, max_length, stride) return tokenized_dataset def _group_texts( dataset, max_length: int, stride: int, ) -> torch.utils.data.Dataset: """Group shorter sequences into longer ones.""" def group_function(examples): # Concatenate all input_ids concatenated_ids = sum(examples["input_ids"], []) # Calculate number of chunks total_length = len(concatenated_ids) if total_length <= max_length: return {"input_ids": [concatenated_ids]} # Create chunks with stride chunks = [] for i in range(0, total_length - max_length + 1, stride): chunks.append(concatenated_ids[i : i + max_length]) # Add final chunk if there's remaining data if total_length % stride != 0 or total_length < max_length: remaining = total_length - ((total_length // stride) * stride) if remaining > 0: chunks.append(concatenated_ids[-max_length:]) return {"input_ids": chunks} return dataset.map( group_function, batched=True, desc="Grouping texts", )