# Source: ntf-space/training/data.py
# Commit 9e6020b by Nexuss0781: "Upload folder using huggingface_hub"
"""
Data collation for language modeling.
Provides efficient batch collation with padding and label handling
for autoregressive language model training.
Integrated with NTFTokenizer (EthioBBPE-based) for Ethiopian languages.
"""
from typing import Dict, List, Optional, Any, Union
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset
try:
from tokenizer import NTFTokenizer
NTF_TOKENIZER_AVAILABLE = True
except ImportError:
NTF_TOKENIZER_AVAILABLE = False
NTFTokenizer = None
class TextDataset(Dataset):
    """
    Line-oriented text dataset with a built-in character-level vocabulary.

    Every non-blank line of the file (after stripping surrounding whitespace)
    becomes one sample. The vocabulary reserves id 0 for ``<pad>`` and id 1
    for ``<unk>``; the distinct characters of the corpus follow in sorted order.

    Args:
        file_path: Path to a UTF-8 text file with one sample per line.
        max_length: Each sample is truncated to at most this many ids.
        tokenizer_name: Accepted for API compatibility; only character-level
            tokenization is implemented and this value is not read.
    """

    def __init__(
        self,
        file_path: str,
        max_length: int = 512,
        tokenizer_name: str = "char",
    ):
        self.max_length = max_length

        # Keep only lines that still contain text once whitespace is stripped.
        with open(file_path, 'r', encoding='utf-8') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            self.texts = [line for line in stripped_lines if line]

        self.char_to_idx = {}
        self.idx_to_char = {}
        self._build_vocab()

    def _build_vocab(self):
        """Fill ``char_to_idx`` / ``idx_to_char`` from the loaded texts."""
        charset = set().union(*self.texts)

        vocab = self.char_to_idx
        vocab['<pad>'] = 0
        vocab['<unk>'] = 1

        next_id = len(vocab)
        for ch in sorted(charset):
            if ch in vocab:
                continue
            vocab[ch] = next_id
            next_id += 1

        # Reverse lookup: id -> character (or special-token string).
        self.idx_to_char = dict(zip(vocab.values(), vocab.keys()))

    def _tokenize(self, text: str) -> List[int]:
        """Map each character of ``text`` to its id; unseen characters map to ``<unk>``."""
        unk_id = self.char_to_idx['<unk>']
        lookup = self.char_to_idx.get
        return [lookup(ch, unk_id) for ch in text]

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        token_ids = self._tokenize(self.texts[idx])[:self.max_length]

        # Never hand back an empty sequence; fall back to a single <pad> id.
        if not token_ids:
            token_ids = [self.char_to_idx['<pad>']]

        return {
            "input_ids": torch.tensor(token_ids, dtype=torch.long),
            "labels": torch.tensor(token_ids, dtype=torch.long),
        }
def _pad_collate(
    batch: List[Dict[str, Any]],
    pad_token_id: int = 0,
    label_pad_token_id: int = -100,
) -> Dict[str, torch.Tensor]:
    """Right-pad the ``input_ids`` and ``labels`` of a batch to a common length.

    Defined at module level (not as a lambda) so DataLoader worker processes
    can pickle it when ``num_workers > 0``.
    """
    input_ids = pad_sequence(
        [torch.as_tensor(item["input_ids"], dtype=torch.long) for item in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    labels = pad_sequence(
        [torch.as_tensor(item["labels"], dtype=torch.long) for item in batch],
        batch_first=True,
        padding_value=label_pad_token_id,
    )
    return {"input_ids": input_ids, "labels": labels}


def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = False,
    pad_token_id: int = 0,
    label_pad_token_id: int = -100,
) -> DataLoader:
    """
    Create a DataLoader that batches variable-length LM examples.

    Each dataset item must be a mapping with ``"input_ids"`` and ``"labels"``
    (tensors or lists of ints). Sequences in a batch are right-padded to the
    longest one rather than stacked, so datasets that yield sequences of
    different lengths (such as ``TextDataset``) work with ``batch_size > 1``.

    Args:
        dataset: Dataset to load from
        batch_size: Batch size
        shuffle: Whether to shuffle data
        num_workers: Number of worker processes
        pin_memory: Pin memory for faster GPU transfer
        pad_token_id: Value used to pad ``input_ids`` (``TextDataset`` uses 0
            for ``<pad>``).
        label_pad_token_id: Value used to pad ``labels``; -100 is the default
            ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    Returns:
        DataLoader yielding ``{"input_ids", "labels"}`` LongTensor batches.
    """
    import functools

    # functools.partial over a module-level function stays picklable.
    collate_fn = functools.partial(
        _pad_collate,
        pad_token_id=pad_token_id,
        label_pad_token_id=label_pad_token_id,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
class DataCollatorForLanguageModeling:
    """
    Data collator for causal (decoder-only) language modeling.

    Features:
    - Dynamic right-padding to the longest sequence in the batch
    - Labels default to a copy of ``input_ids``; they are not shifted here,
      so any next-token shift must happen downstream
    - Padded label positions are set to -100 so they can be ignored by the loss
    - Attention mask is derived from the true sequence lengths, so real tokens
      whose id happens to equal ``pad_token_id`` are still attended to
    - Accepts token ids as lists or tensors (EthioBBPE / HuggingFace output)

    Args:
        pad_token_id: Token ID used to pad ``input_ids``.
        max_length: If set, sequences (and labels) are truncated to this many
            tokens before padding; ``None`` keeps full sequences.
        return_tensors: ``'pt'`` for torch tensors (default) or ``'np'`` for
            NumPy arrays.
    """

    def __init__(
        self,
        pad_token_id: int = 0,
        max_length: Optional[int] = None,
        return_tensors: str = "pt",
    ):
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.return_tensors = return_tensors

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of examples.

        Args:
            examples: List of dictionaries with 'input_ids' and optional
                'labels'. Whether labels are present is decided from the
                first example.
        Returns:
            Batch dictionary with padded ``input_ids``, ``attention_mask``
            and ``labels``.
        Raises:
            ValueError: If ``examples`` is empty.
        """
        if not examples:
            raise ValueError("Cannot collate an empty list of examples.")

        input_ids_list = [example["input_ids"] for example in examples]

        if "labels" in examples[0]:
            labels_list = [example["labels"] for example in examples]
        else:
            # Use input_ids as labels (standard for LM)
            labels_list = input_ids_list

        if self.max_length is not None:
            input_ids_list = [ids[:self.max_length] for ids in input_ids_list]
            labels_list = [lbl[:self.max_length] for lbl in labels_list]

        # as_tensor avoids the copy-construct warning when inputs are tensors;
        # pad_sequence below allocates fresh output tensors anyway.
        input_ids_tensors = [torch.as_tensor(ids, dtype=torch.long) for ids in input_ids_list]
        labels_tensors = [torch.as_tensor(lbl, dtype=torch.long) for lbl in labels_list]

        input_ids = pad_sequence(
            input_ids_tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
        )
        labels = pad_sequence(
            labels_tensors,
            batch_first=True,
            padding_value=-100,  # Ignore padding in loss calculation
        )

        # Mask from real lengths, not from comparing against pad_token_id,
        # which would hide genuine tokens that share the pad id.
        lengths = torch.tensor([t.size(0) for t in input_ids_tensors], dtype=torch.long)
        positions = torch.arange(input_ids.size(1), dtype=torch.long)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        if self.return_tensors == "np":
            return {key: value.numpy() for key, value in batch.items()}
        return batch
def create_training_dataset(
    texts: List[str],
    tokenizer,
    max_length: int = 512,
    stride: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """
    Create a dataset from raw texts.
    Supports NTFTokenizer, EthioBBPE, and HuggingFace tokenizers.

    Without ``stride`` every text is truncated to ``max_length`` tokens. With
    ``stride`` texts are tokenized in full and then re-chunked by
    ``_group_texts`` into ``max_length`` windows, so no tokens are dropped by
    early truncation.

    Args:
        texts: List of text strings
        tokenizer: Tokenizer instance (NTFTokenizer, EthioBBPE, or HF tokenizer)
        max_length: Maximum sequence length
        stride: Stride for chunking long texts (None for truncation)
    Returns:
        Dataset with tokenized examples (a ``datasets.Dataset``; HF tokenizers
        may add extra columns such as ``attention_mask``)
    """
    from datasets import Dataset

    # Only truncate in plain mode; chunking needs the complete token stream.
    truncate = stride is None

    # Detect tokenizer type
    is_ntf_tokenizer = NTF_TOKENIZER_AVAILABLE and isinstance(tokenizer, NTFTokenizer)
    is_ethiobbpe = hasattr(tokenizer, 'encode_batch') or type(tokenizer).__name__ in ['EthioBBPE', 'NTFTokenizer']

    def tokenize_function(examples):
        if is_ntf_tokenizer:
            # NTFTokenizer - use callable interface
            encoded = tokenizer(
                examples["text"],
                add_special_tokens=True,
                padding=False,
                truncation=truncate,
                max_length=max_length,
            )
            return {"input_ids": encoded["input_ids"]}
        if is_ethiobbpe:
            # EthioBBPE-style tokenizer with encode_batch
            encoded = tokenizer.encode_batch(examples["text"])
            if len(encoded) == 0:
                return {"input_ids": []}
            # Handle both TokenizerOutput objects and dict returns
            if hasattr(encoded[0], 'ids'):
                input_ids = [item.ids for item in encoded]
            else:
                input_ids = [item["ids"] for item in encoded]
            # encode_batch has no max_length argument, so truncate here.
            if truncate:
                input_ids = [ids[:max_length] for ids in input_ids]
            return {"input_ids": input_ids}
        # HuggingFace tokenizer
        return tokenizer(
            examples["text"],
            truncation=truncate,
            max_length=max_length,
            padding=False,
            return_special_tokens_mask=False,
        )

    dataset = Dataset.from_dict({"text": texts})
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    # Group texts if needed for longer context
    if stride is not None:
        tokenized_dataset = _group_texts(tokenized_dataset, max_length, stride)
    return tokenized_dataset
def _chunk_token_ids(token_ids: List[int], max_length: int, stride: int) -> List[List[int]]:
    """Split one flat id stream into ``max_length`` windows starting every ``stride`` ids.

    A stream no longer than ``max_length`` is returned as a single chunk and
    an empty stream yields no chunks. If the regular windows stop short of
    the end, one extra window holding the last ``max_length`` ids is appended,
    so trailing tokens are never dropped and no window is duplicated. When
    ``stride > max_length`` tokens between windows are skipped by design.
    """
    total_length = len(token_ids)
    if total_length == 0:
        return []
    if total_length <= max_length:
        return [token_ids]

    chunks = [
        token_ids[start : start + max_length]
        for start in range(0, total_length - max_length + 1, stride)
    ]

    # Start of the last regular window; add a tail window only if it falls
    # short of the end of the stream.
    last_start = ((total_length - max_length) // stride) * stride
    if last_start + max_length < total_length:
        chunks.append(token_ids[-max_length:])
    return chunks


def _group_texts(
    dataset,
    max_length: int,
    stride: int,
) -> torch.utils.data.Dataset:
    """
    Concatenate tokenized rows and re-split them into fixed-size windows.

    Rows are concatenated per ``map`` batch (not across batches) and chunked
    with ``_chunk_token_ids``. Every existing column is removed, because the
    number of output rows differs from the input and per-row columns such as
    ``attention_mask`` would no longer line up.

    Args:
        dataset: Tokenized ``datasets.Dataset`` with an ``input_ids`` column.
        max_length: Window length in tokens (must be positive).
        stride: Step between window starts in tokens (must be positive).
    Returns:
        Dataset with a single ``input_ids`` column.
    Raises:
        ValueError: If ``max_length`` or ``stride`` is not positive.
    """
    from itertools import chain

    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    def group_function(examples):
        # chain avoids the quadratic cost of sum(list_of_lists, []).
        concatenated_ids = list(chain.from_iterable(examples["input_ids"]))
        return {"input_ids": _chunk_token_ids(concatenated_ids, max_length, stride)}

    return dataset.map(
        group_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Grouping texts",
    )