|
|
""" |
|
|
DataLoader utilities for SLM training. |
|
|
|
|
|
Provides efficient batching and data loading for training. |
|
|
""" |
|
|
|
|
|
import os |
|
|
from typing import Dict, Optional, List |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler |
|
|
|
|
|
from .dataset import ConversationalDataset, StreamingTextDataset, PackedDataset |
|
|
from .tokenizer import SLMTokenizer |
|
|
|
|
|
|
|
|
def create_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: Optional[bool] = None,
    drop_last: bool = True,
    distributed: bool = False,
    world_size: int = 1,
    rank: int = 0,
) -> DataLoader:
    """Create a DataLoader with optimal settings.

    Args:
        dataset: The dataset to load from
        batch_size: Batch size per device
        shuffle: Whether to shuffle data
        num_workers: Number of data loading workers
        pin_memory: Pin memory for faster GPU transfer; defaults to True
            when CUDA is available
        drop_last: Drop last incomplete batch
        distributed: Whether using distributed training
        world_size: Number of distributed processes
        rank: Current process rank

    Returns:
        Configured DataLoader
    """
    sampler = None
    if distributed:
        # DistributedSampler handles shuffling itself; the DataLoader must
        # then be created with shuffle=False (enforced below).
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )

    if pin_memory is None:
        # Pinning only helps when transferring batches to a CUDA device.
        pin_memory = torch.cuda.is_available()

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle if sampler is None else False,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=default_collate_fn,
    )
|
|
|
|
|
|
|
|
def default_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Collate function for batching samples.

    Args:
        batch: List of sample dictionaries

    Returns:
        Batched dictionary with stacked tensors
    """
    # Stack each field across the batch dimension; all samples are assumed
    # to share the same sequence length for a given key.
    keys = ("input_ids", "attention_mask", "labels")
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}
|
|
|
|
|
|
|
|
class DataModule:
    """Data module for managing train/val dataloaders.

    Provides a unified interface for data loading during training.
    """

    def __init__(
        self,
        data_dir: str,
        tokenizer_path: str,
        max_length: int = 1024,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        """Initialize data module.

        Args:
            data_dir: Directory containing processed data
            tokenizer_path: Path to tokenizer.json
            max_length: Maximum sequence length
            batch_size: Training batch size
            num_workers: Number of data loading workers
            val_batch_size: Validation batch size (defaults to batch_size)
        """
        self.data_dir = data_dir
        self.max_length = max_length
        self.batch_size = batch_size
        # Validation falls back to the training batch size when not given.
        self.val_batch_size = val_batch_size if val_batch_size else batch_size
        self.num_workers = num_workers

        self.tokenizer = SLMTokenizer.from_file(tokenizer_path)

        # Datasets are built lazily on first access (see properties below).
        self._train_dataset = None
        self._val_dataset = None

    def _build_dataset(self, split: str) -> Dataset:
        """Construct a ConversationalDataset for the given split."""
        return ConversationalDataset(
            data_path=self.data_dir,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            split=split,
        )

    @property
    def train_dataset(self) -> Dataset:
        """Get or create training dataset."""
        if self._train_dataset is None:
            self._train_dataset = self._build_dataset("train")
        return self._train_dataset

    @property
    def val_dataset(self) -> Dataset:
        """Get or create validation dataset."""
        if self._val_dataset is None:
            self._val_dataset = self._build_dataset("val")
        return self._val_dataset

    def train_dataloader(
        self,
        distributed: bool = False,
        world_size: int = 1,
        rank: int = 0,
    ) -> DataLoader:
        """Get training dataloader."""
        return create_dataloader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            distributed=distributed,
            world_size=world_size,
            rank=rank,
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader."""
        return create_dataloader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )
|
|
|
|
|
|
|
|
class StreamingDataModule:
    """Data module for streaming large datasets.

    Memory-efficient loading for large text corpora.
    """

    def __init__(
        self,
        data_files: List[str],
        tokenizer_path: str,
        max_length: int = 1024,
        batch_size: int = 32,
        num_workers: int = 4,
    ):
        """Initialize streaming data module.

        Args:
            data_files: List of text file paths
            tokenizer_path: Path to tokenizer.json
            max_length: Maximum sequence length
            batch_size: Batch size
            num_workers: Number of data loading workers
        """
        self.data_files = data_files
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.tokenizer = SLMTokenizer.from_file(tokenizer_path)

    def train_dataloader(self) -> DataLoader:
        """Get training dataloader for streaming data."""
        # Streaming dataset yields samples lazily; it shuffles internally,
        # so no sampler/shuffle is passed to the DataLoader here.
        stream = StreamingTextDataset(
            data_files=self.data_files,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=True,
        )
        return DataLoader(
            stream,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=default_collate_fn,
        )
|
|
|
|
|
|
|
|
def estimate_dataset_tokens(data_dir: str, tokenizer_path: str) -> Dict[str, int]:
    """Estimate total tokens in a dataset.

    Scans every ``*.json`` / ``*.jsonl`` file in ``data_dir``, tokenizes
    each sample, and accumulates token counts. Samples are recognized as
    either conversation pairs ({"user", "assistant"}) or plain text
    ({"text"}); anything else is skipped.

    Args:
        data_dir: Directory containing data files
        tokenizer_path: Path to tokenizer

    Returns:
        Dictionary with total_tokens, total_samples, and
        avg_tokens_per_sample (0 samples yields an average over 1 to
        avoid division by zero)
    """
    import json
    from pathlib import Path

    tokenizer = SLMTokenizer.from_file(tokenizer_path)

    total_tokens = 0
    total_samples = 0

    for file_path in Path(data_dir).glob("*.json*"):
        # JSON files are UTF-8 by spec; be explicit so the platform
        # default encoding cannot corrupt the read.
        with open(file_path, "r", encoding="utf-8") as f:
            if file_path.suffix == ".jsonl":
                samples = [json.loads(line) for line in f if line.strip()]
            else:
                samples = json.load(f)
                # A bare .json file may hold a single object instead of a list.
                if not isinstance(samples, list):
                    samples = [samples]

        for sample in samples:
            if "user" in sample and "assistant" in sample:
                tokens = tokenizer.encode_conversation(
                    sample["user"], sample["assistant"]
                )
            elif "text" in sample:
                tokens = tokenizer.encode(sample["text"])
            else:
                # Unknown schema; skip rather than guess.
                continue

            total_tokens += len(tokens)
            total_samples += 1

    return {
        "total_tokens": total_tokens,
        "total_samples": total_samples,
        "avg_tokens_per_sample": total_tokens / max(total_samples, 1),
    }
|
|
|
|
|
|
|
|
def get_dataloader_stats(dataloader: DataLoader) -> Dict[str, float]:
    """Get statistics from a dataloader.

    Args:
        dataloader: The dataloader to analyze

    Returns:
        Dictionary with statistics
    """
    # Sample at most 100 batches so the scan stays cheap on large datasets.
    max_batches = 100
    batches = 0
    tokens = 0
    non_pad = 0

    for batch in dataloader:
        batches += 1
        tokens += batch["input_ids"].numel()
        non_pad += batch["attention_mask"].sum().item()
        if batches >= max_batches:
            break

    # Guard the divisions against an empty dataloader.
    denom_batches = max(batches, 1)
    denom_tokens = max(tokens, 1)
    return {
        "batches_sampled": batches,
        "tokens_per_batch": tokens / denom_batches,
        "non_pad_ratio": non_pad / denom_tokens,
    }
|
|
|