Nexuss0781's picture
Upload data/train-00000-of-00001.parquet with huggingface_hub
7cb972e
"""
Data collation for language modeling.
Provides efficient batch collation with padding and label handling
for autoregressive language model training.
Integrated with EthioBBPE tokenizer for Ethiopian languages.
"""
from typing import Dict, List, Optional, Any, Union
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset
class TextDataset(Dataset):
"""
Simple text dataset that loads texts from a file.
Args:
file_path: Path to text file (one sample per line)
max_length: Maximum sequence length
tokenizer_name: Tokenizer to use ('char' for character-level)
"""
def __init__(
self,
file_path: str,
max_length: int = 512,
tokenizer_name: str = "char",
):
self.max_length = max_length
# Load texts from file
with open(file_path, 'r', encoding='utf-8') as f:
self.texts = [line.strip() for line in f if line.strip()]
# Create simple character-level tokenizer
self.char_to_idx = {}
self.idx_to_char = {}
self._build_vocab()
def _build_vocab(self):
"""Build character-level vocabulary."""
all_chars = set()
for text in self.texts:
all_chars.update(text)
# Special tokens
self.char_to_idx['<pad>'] = 0
self.char_to_idx['<unk>'] = 1
idx = len(self.char_to_idx)
for char in sorted(all_chars):
if char not in self.char_to_idx:
self.char_to_idx[char] = idx
idx += 1
self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
def _tokenize(self, text: str) -> List[int]:
"""Tokenize text to character IDs."""
return [self.char_to_idx.get(c, self.char_to_idx['<unk>']) for c in text]
def __len__(self) -> int:
return len(self.texts)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
text = self.texts[idx]
input_ids = self._tokenize(text)[:self.max_length]
# Ensure we have at least some tokens
if len(input_ids) == 0:
input_ids = [self.char_to_idx['<pad>']]
return {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"labels": torch.tensor(input_ids, dtype=torch.long),
}
def create_dataloader(
dataset: Dataset,
batch_size: int = 4,
shuffle: bool = True,
num_workers: int = 0,
pin_memory: bool = False,
) -> DataLoader:
"""
Create a DataLoader from a dataset.
Args:
dataset: Dataset to load from
batch_size: Batch size
shuffle: Whether to shuffle data
num_workers: Number of worker processes
pin_memory: Pin memory for faster GPU transfer
Returns:
DataLoader instance
"""
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=lambda x: {
"input_ids": torch.stack([item["input_ids"] for item in x]),
"labels": torch.stack([item["labels"] for item in x]),
},
)
class DataCollatorForLanguageModeling:
"""
Data collator for language modeling tasks.
Features:
- Dynamic padding to max length in batch
- Label creation for next-token prediction
- Optional masking for MLM (not used in decoder-only)
- Efficient tensor conversion
- Compatible with EthioBBPE and HuggingFace tokenizers
Args:
pad_token_id: Token ID for padding
max_length: Maximum sequence length (None for dynamic)
return_tensors: Type of tensors to return ('pt', 'np')
"""
def __init__(
self,
pad_token_id: int = 0,
max_length: Optional[int] = None,
return_tensors: str = "pt",
):
self.pad_token_id = pad_token_id
self.max_length = max_length
self.return_tensors = return_tensors
def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
"""
Collate a batch of examples.
Args:
examples: List of dictionaries with 'input_ids' and optional 'labels'
Returns:
Batch dictionary with padded tensors
"""
# Extract input_ids from examples
input_ids_list = [example["input_ids"] for example in examples]
# Handle labels
has_labels = "labels" in examples[0]
if has_labels:
labels_list = [example["labels"] for example in examples]
else:
# Use input_ids as labels (standard for LM)
labels_list = input_ids_list
# Truncate if max_length specified
if self.max_length is not None:
input_ids_list = [ids[:self.max_length] for ids in input_ids_list]
labels_list = [lbl[:self.max_length] for lbl in labels_list]
# Convert to tensors
input_ids_tensors = [torch.tensor(ids, dtype=torch.long) for ids in input_ids_list]
labels_tensors = [torch.tensor(lbl, dtype=torch.long) for lbl in labels_list]
# Pad sequences
input_ids = pad_sequence(
input_ids_tensors,
batch_first=True,
padding_value=self.pad_token_id,
)
labels = pad_sequence(
labels_tensors,
batch_first=True,
padding_value=-100, # Ignore padding in loss calculation
)
# Create attention mask
attention_mask = (input_ids != self.pad_token_id).long()
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
def create_training_dataset(
texts: List[str],
tokenizer,
max_length: int = 512,
stride: Optional[int] = None,
) -> torch.utils.data.Dataset:
"""
Create a dataset from raw texts.
Supports both EthioBBPE and HuggingFace tokenizers.
Args:
texts: List of text strings
tokenizer: Tokenizer instance (EthioBBPE or HF tokenizer)
max_length: Maximum sequence length
stride: Stride for chunking long texts (None for truncation)
Returns:
Dataset with tokenized examples
"""
from datasets import Dataset
# Detect tokenizer type
is_ethiobbpe = hasattr(tokenizer, 'encode_batch') or type(tokenizer).__name__ == 'EthioBBPE'
# Tokenize all texts
def tokenize_function(examples):
if is_ethiobbpe:
# EthioBBPE tokenizer
encoded = tokenizer.encode_batch(examples["text"])
input_ids = [item["ids"] for item in encoded]
return {"input_ids": input_ids}
else:
# HuggingFace tokenizer
return tokenizer(
examples["text"],
truncation=True,
max_length=max_length,
padding=False,
return_special_tokens_mask=False,
)
dataset = Dataset.from_dict({"text": texts})
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=["text"],
)
# Group texts if needed for longer context
if stride is not None:
tokenized_dataset = _group_texts(tokenized_dataset, max_length, stride)
return tokenized_dataset
def _group_texts(
dataset,
max_length: int,
stride: int,
) -> torch.utils.data.Dataset:
"""Group shorter sequences into longer ones."""
def group_function(examples):
# Concatenate all input_ids
concatenated_ids = sum(examples["input_ids"], [])
# Calculate number of chunks
total_length = len(concatenated_ids)
if total_length <= max_length:
return {"input_ids": [concatenated_ids]}
# Create chunks with stride
chunks = []
for i in range(0, total_length - max_length + 1, stride):
chunks.append(concatenated_ids[i : i + max_length])
# Add final chunk if there's remaining data
if total_length % stride != 0 or total_length < max_length:
remaining = total_length - ((total_length // stride) * stride)
if remaining > 0:
chunks.append(concatenated_ids[-max_length:])
return {"input_ids": chunks}
return dataset.map(
group_function,
batched=True,
desc="Grouping texts",
)