""" Data collation for language modeling. Provides efficient batch collation with padding and label handling for autoregressive language model training. Integrated with NTFTokenizer (EthioBBPE-based) for Ethiopian languages. """ from typing import Dict, List, Optional, Any, Union import torch from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset, DataLoader from datasets import Dataset as HFDataset try: from tokenizer import NTFTokenizer NTF_TOKENIZER_AVAILABLE = True except ImportError: NTF_TOKENIZER_AVAILABLE = False NTFTokenizer = None class TextDataset(Dataset): """ Simple text dataset that loads texts from a file. Args: file_path: Path to text file (one sample per line) max_length: Maximum sequence length tokenizer_name: Tokenizer to use ('char' for character-level) """ def __init__( self, file_path: str, max_length: int = 512, tokenizer_name: str = "char", ): self.max_length = max_length # Load texts from file with open(file_path, 'r', encoding='utf-8') as f: self.texts = [line.strip() for line in f if line.strip()] # Create simple character-level tokenizer self.char_to_idx = {} self.idx_to_char = {} self._build_vocab() def _build_vocab(self): """Build character-level vocabulary.""" all_chars = set() for text in self.texts: all_chars.update(text) # Special tokens self.char_to_idx[''] = 0 self.char_to_idx[''] = 1 idx = len(self.char_to_idx) for char in sorted(all_chars): if char not in self.char_to_idx: self.char_to_idx[char] = idx idx += 1 self.idx_to_char = {v: k for k, v in self.char_to_idx.items()} def _tokenize(self, text: str) -> List[int]: """Tokenize text to character IDs.""" return [self.char_to_idx.get(c, self.char_to_idx['']) for c in text] def __len__(self) -> int: return len(self.texts) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: text = self.texts[idx] input_ids = self._tokenize(text)[:self.max_length] # Ensure we have at least some tokens if len(input_ids) == 0: input_ids = [self.char_to_idx['']] return { "input_ids": torch.tensor(input_ids, dtype=torch.long), "labels": torch.tensor(input_ids, dtype=torch.long), } def create_dataloader( dataset: Dataset, batch_size: int = 4, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = False, ) -> DataLoader: """ Create a DataLoader from a dataset. Args: dataset: Dataset to load from batch_size: Batch size shuffle: Whether to shuffle data num_workers: Number of worker processes pin_memory: Pin memory for faster GPU transfer Returns: DataLoader instance """ return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, collate_fn=lambda x: { "input_ids": torch.stack([item["input_ids"] for item in x]), "labels": torch.stack([item["labels"] for item in x]), }, ) class DataCollatorForLanguageModeling: """ Data collator for language modeling tasks. Features: - Dynamic padding to max length in batch - Label creation for next-token prediction - Optional masking for MLM (not used in decoder-only) - Efficient tensor conversion - Compatible with EthioBBPE and HuggingFace tokenizers Args: pad_token_id: Token ID for padding max_length: Maximum sequence length (None for dynamic) return_tensors: Type of tensors to return ('pt', 'np') """ def __init__( self, pad_token_id: int = 0, max_length: Optional[int] = None, return_tensors: str = "pt", ): self.pad_token_id = pad_token_id self.max_length = max_length self.return_tensors = return_tensors def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: """ Collate a batch of examples. Args: examples: List of dictionaries with 'input_ids' and optional 'labels' Returns: Batch dictionary with padded tensors """ # Extract input_ids from examples input_ids_list = [example["input_ids"] for example in examples] # Handle labels has_labels = "labels" in examples[0] if has_labels: labels_list = [example["labels"] for example in examples] else: # Use input_ids as labels (standard for LM) labels_list = input_ids_list # Truncate if max_length specified if self.max_length is not None: input_ids_list = [ids[:self.max_length] for ids in input_ids_list] labels_list = [lbl[:self.max_length] for lbl in labels_list] # Convert to tensors input_ids_tensors = [torch.tensor(ids, dtype=torch.long) for ids in input_ids_list] labels_tensors = [torch.tensor(lbl, dtype=torch.long) for lbl in labels_list] # Pad sequences input_ids = pad_sequence( input_ids_tensors, batch_first=True, padding_value=self.pad_token_id, ) labels = pad_sequence( labels_tensors, batch_first=True, padding_value=-100, # Ignore padding in loss calculation ) # Create attention mask attention_mask = (input_ids != self.pad_token_id).long() return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } def create_training_dataset( texts: List[str], tokenizer, max_length: int = 512, stride: Optional[int] = None, ) -> torch.utils.data.Dataset: """ Create a dataset from raw texts. Supports NTFTokenizer, EthioBBPE, and HuggingFace tokenizers. Args: texts: List of text strings tokenizer: Tokenizer instance (NTFTokenizer, EthioBBPE, or HF tokenizer) max_length: Maximum sequence length stride: Stride for chunking long texts (None for truncation) Returns: Dataset with tokenized examples """ from datasets import Dataset # Detect tokenizer type is_ntf_tokenizer = NTF_TOKENIZER_AVAILABLE and isinstance(tokenizer, NTFTokenizer) is_ethiobbpe = hasattr(tokenizer, 'encode_batch') or type(tokenizer).__name__ in ['EthioBBPE', 'NTFTokenizer'] # Tokenize all texts def tokenize_function(examples): if is_ntf_tokenizer: # NTFTokenizer - use callable interface encoded = tokenizer(examples["text"], add_special_tokens=True, padding=False, truncation=True, max_length=max_length) return {"input_ids": encoded["input_ids"]} elif is_ethiobbpe: # EthioBBPE-style tokenizer with encode_batch encoded = tokenizer.encode_batch(examples["text"]) # Handle both TokenizerOutput objects and dict returns if hasattr(encoded[0], 'ids'): input_ids = [item.ids for item in encoded] else: input_ids = [item["ids"] for item in encoded] return {"input_ids": input_ids} else: # HuggingFace tokenizer return tokenizer( examples["text"], truncation=True, max_length=max_length, padding=False, return_special_tokens_mask=False, ) dataset = Dataset.from_dict({"text": texts}) tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=["text"], ) # Group texts if needed for longer context if stride is not None: tokenized_dataset = _group_texts(tokenized_dataset, max_length, stride) return tokenized_dataset def _group_texts( dataset, max_length: int, stride: int, ) -> torch.utils.data.Dataset: """Group shorter sequences into longer ones.""" def group_function(examples): # Concatenate all input_ids concatenated_ids = sum(examples["input_ids"], []) # Calculate number of chunks total_length = len(concatenated_ids) if total_length <= max_length: return {"input_ids": [concatenated_ids]} # Create chunks with stride chunks = [] for i in range(0, total_length - max_length + 1, stride): chunks.append(concatenated_ids[i : i + max_length]) # Add final chunk if there's remaining data if total_length % stride != 0 or total_length < max_length: remaining = total_length - ((total_length // stride) * stride) if remaining > 0: chunks.append(concatenated_ids[-max_length:]) return {"input_ids": chunks} return dataset.map( group_function, batched=True, desc="Grouping texts", )