""" DataLoader utilities for SLM training. Provides efficient batching and data loading for training. """ import os from typing import Dict, Optional, List import torch from torch.utils.data import DataLoader, Dataset, DistributedSampler from .dataset import ConversationalDataset, StreamingTextDataset, PackedDataset from .tokenizer import SLMTokenizer def create_dataloader( dataset: Dataset, batch_size: int, shuffle: bool = True, num_workers: int = 4, pin_memory: bool = None, # Auto-detect based on device drop_last: bool = True, distributed: bool = False, world_size: int = 1, rank: int = 0, ) -> DataLoader: """Create a DataLoader with optimal settings. Args: dataset: The dataset to load from batch_size: Batch size per device shuffle: Whether to shuffle data num_workers: Number of data loading workers pin_memory: Pin memory for faster GPU transfer drop_last: Drop last incomplete batch distributed: Whether using distributed training world_size: Number of distributed processes rank: Current process rank Returns: Configured DataLoader """ sampler = None if distributed: sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, ) shuffle = False # Sampler handles shuffling # Auto-detect pin_memory: disable for MPS (not supported) if pin_memory is None: import torch pin_memory = torch.cuda.is_available() # Only True for CUDA return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle if sampler is None else False, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, collate_fn=default_collate_fn, ) def default_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: """Collate function for batching samples. Args: batch: List of sample dictionaries Returns: Batched dictionary with stacked tensors """ return { "input_ids": torch.stack([s["input_ids"] for s in batch]), "attention_mask": torch.stack([s["attention_mask"] for s in batch]), "labels": torch.stack([s["labels"] for s in batch]), } class DataModule: """Data module for managing train/val dataloaders. Provides a unified interface for data loading during training. """ def __init__( self, data_dir: str, tokenizer_path: str, max_length: int = 1024, batch_size: int = 32, num_workers: int = 4, val_batch_size: Optional[int] = None, ): """Initialize data module. 
class DataModule:
    """Data module for managing train/val dataloaders.

    Provides a unified interface for data loading during training.
    """

    def __init__(
        self,
        data_dir: str,
        tokenizer_path: str,
        max_length: int = 1024,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        """Initialize data module.

        Args:
            data_dir: Directory containing processed data
            tokenizer_path: Path to tokenizer.json
            max_length: Maximum sequence length
            batch_size: Training batch size
            num_workers: Number of data loading workers
            val_batch_size: Validation batch size (defaults to batch_size)
        """
        self.data_dir = data_dir
        self.max_length = max_length
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

        # Load tokenizer
        self.tokenizer = SLMTokenizer.from_file(tokenizer_path)

        # Datasets (created lazily on first access)
        self._train_dataset = None
        self._val_dataset = None

    @property
    def train_dataset(self) -> Dataset:
        """Get or create training dataset."""
        if self._train_dataset is None:
            self._train_dataset = ConversationalDataset(
                data_path=self.data_dir,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                split="train",
            )
        return self._train_dataset

    @property
    def val_dataset(self) -> Dataset:
        """Get or create validation dataset."""
        if self._val_dataset is None:
            self._val_dataset = ConversationalDataset(
                data_path=self.data_dir,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                split="val",
            )
        return self._val_dataset

    def train_dataloader(
        self,
        distributed: bool = False,
        world_size: int = 1,
        rank: int = 0,
    ) -> DataLoader:
        """Get training dataloader."""
        return create_dataloader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            distributed=distributed,
            world_size=world_size,
            rank=rank,
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader."""
        return create_dataloader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )


class StreamingDataModule:
    """Data module for streaming large datasets.

    Memory-efficient loading for large text corpora.
    """

    def __init__(
        self,
        data_files: List[str],
        tokenizer_path: str,
        max_length: int = 1024,
        batch_size: int = 32,
        num_workers: int = 4,
    ):
        """Initialize streaming data module.

        Args:
            data_files: List of text file paths
            tokenizer_path: Path to tokenizer.json
            max_length: Maximum sequence length
            batch_size: Batch size
            num_workers: Number of data loading workers
        """
        self.data_files = data_files
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Load tokenizer
        self.tokenizer = SLMTokenizer.from_file(tokenizer_path)

    def train_dataloader(self) -> DataLoader:
        """Get training dataloader for streaming data."""
        dataset = StreamingTextDataset(
            data_files=self.data_files,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=True,
        )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),  # pin_memory only helps on CUDA
            collate_fn=default_collate_fn,
        )
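
# Example usage (a minimal sketch; the directory and tokenizer paths are
# placeholders, and the training step itself is elided):
#
#   dm = DataModule(
#       data_dir="data/processed",
#       tokenizer_path="tokenizer.json",
#       max_length=1024,
#       batch_size=32,
#   )
#   for batch in dm.train_dataloader():
#       ...  # forward/backward pass on batch["input_ids"], batch["labels"], etc.
#
#   # In distributed training, pass the process-group info through so that a
#   # DistributedSampler is used under the hood:
#   #   dm.train_dataloader(distributed=True, world_size=world_size, rank=rank)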
def estimate_dataset_tokens(data_dir: str, tokenizer_path: str) -> Dict[str, float]:
    """Estimate total tokens in a dataset.

    Args:
        data_dir: Directory containing data files
        tokenizer_path: Path to tokenizer

    Returns:
        Dictionary with token counts
    """
    import json
    from pathlib import Path

    tokenizer = SLMTokenizer.from_file(tokenizer_path)

    total_tokens = 0
    total_samples = 0

    for file_path in Path(data_dir).glob("*.json*"):
        with open(file_path, "r") as f:
            if file_path.suffix == ".jsonl":
                samples = [json.loads(line) for line in f if line.strip()]
            else:
                samples = json.load(f)
                if not isinstance(samples, list):
                    samples = [samples]

        for sample in samples:
            if "user" in sample and "assistant" in sample:
                tokens = tokenizer.encode_conversation(
                    sample["user"], sample["assistant"]
                )
            elif "text" in sample:
                tokens = tokenizer.encode(sample["text"])
            else:
                continue

            total_tokens += len(tokens)
            total_samples += 1

    return {
        "total_tokens": total_tokens,
        "total_samples": total_samples,
        "avg_tokens_per_sample": total_tokens / max(total_samples, 1),
    }


def get_dataloader_stats(dataloader: DataLoader) -> Dict[str, float]:
    """Get statistics from a dataloader.

    Args:
        dataloader: The dataloader to analyze

    Returns:
        Dictionary with statistics
    """
    total_batches = 0
    total_tokens = 0
    total_non_pad_tokens = 0

    for batch in dataloader:
        total_batches += 1
        total_tokens += batch["input_ids"].numel()
        total_non_pad_tokens += batch["attention_mask"].sum().item()

        # Only sample the first 100 batches
        if total_batches >= 100:
            break

    return {
        "batches_sampled": total_batches,
        "tokens_per_batch": total_tokens / max(total_batches, 1),
        "non_pad_ratio": total_non_pad_tokens / max(total_tokens, 1),
    }
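
# Example usage (a minimal sketch; paths are placeholders):
#
#   counts = estimate_dataset_tokens("data/processed", "tokenizer.json")
#   print(counts["total_tokens"], counts["avg_tokens_per_sample"])
#
#   # Works with any DataLoader built by this module, e.g. a validation loader
#   # from a DataModule instance:
#   stats = get_dataloader_stats(data_module.val_dataloader())
#   print(f"Non-pad token ratio: {stats['non_pad_ratio']:.2%}")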