"""DataLoader utilities."""

from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset

from taoTrain.config import TrainingConfig


# Fill value used when right-padding each variable-length 1D field of a batch.
# Keys not listed here are stacked as-is (they must already share a shape).
# -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``;
# presumably chosen so padded label positions are ignored by the loss — confirm
# against the training loss. NOTE(review): padding input_ids with 0 assumes
# token id 0 is acceptable as filler (those positions get attention_mask 0).
_PAD_VALUES = {
    "input_ids": 0,
    "labels": -100,
    "attention_mask": 0,
}


def _pad_and_stack(items: list, pad_value: int) -> torch.Tensor:
    """Right-pad 1D tensors to the longest first dimension, then stack them.

    Args:
        items: Tensors for one key, one per example in the batch.
        pad_value: Value written into the padded tail positions.

    Returns:
        A single tensor of shape ``(len(items), max_len)`` for 1D inputs.

    Tensors that are not 1D are passed to ``torch.stack`` unpadded, so they
    must already have identical shapes or ``torch.stack`` will raise.
    """
    max_len = max(item.shape[0] for item in items)
    padded = []
    for item in items:
        if item.dim() == 1:
            pad_len = max_len - item.shape[0]
            if pad_len > 0:
                # (0, pad_len) pads only the right end of the last dimension.
                item = torch.nn.functional.pad(item, (0, pad_len), value=pad_value)
        padded.append(item)
    return torch.stack(padded)


def _collate_fn(batch: list) -> dict:
    """Collate a list of per-example dicts into one batch dict.

    For every key of the first example:
      * non-tensor values are returned as a plain list;
      * ``input_ids``, ``labels`` and ``attention_mask`` tensors are
        right-padded to a common length (see ``_PAD_VALUES``) and stacked;
      * any other tensor field is stacked directly.

    Defined at module level rather than nested inside ``get_dataloader`` so
    that it can be pickled. DataLoader worker processes started with the
    "spawn" method (the default on macOS and Windows) pickle the collate
    function, and a local closure fails there whenever ``num_workers > 0``.
    """
    collated = {}
    for key in batch[0].keys():
        items = [example[key] for example in batch]
        if not isinstance(items[0], torch.Tensor):
            collated[key] = items
        elif key in _PAD_VALUES:
            collated[key] = _pad_and_stack(items, _PAD_VALUES[key])
        else:
            collated[key] = torch.stack(items)
    return collated


def get_dataloader(
    dataset: Dataset,
    config: TrainingConfig,
    shuffle: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Create a DataLoader from a dataset.

    **NOTE**: For JSONL-based datasets (PretrainJSONLDataset, SFTJSONLDataset, etc.),
    this function is now deprecated in favor of AsyncBatchIterator for better performance.
    AsyncBatchIterator enables tokenization to happen in parallel with training,
    avoiding the startup bottleneck of tokenizing all data upfront.

    See: taoTrain/data/async_loader.py for the new async loading approach.
    The trainer automatically uses AsyncBatchIterator for JSONL datasets.

    Args:
        dataset: PyTorch Dataset instance whose items are dicts of fields.
        config: Training configuration; ``batch_size``, ``num_workers`` and
            ``pin_memory`` are read from it.
        shuffle: Whether to shuffle data
        drop_last: Whether to drop last incomplete batch

    Returns:
        DataLoader instance that batches examples with ``_collate_fn``
        (per-key padding for ``input_ids`` / ``labels`` / ``attention_mask``).
    """
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        collate_fn=_collate_fn,
    )