StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""DataLoader utilities."""
from typing import Optional
import torch
from torch.utils.data import DataLoader, Dataset
from taoTrain.config import TrainingConfig
# Padding value used for each variable-length sequence field. Labels are
# padded with -100; token ids and attention masks are padded with 0.
_PAD_VALUES = {
    "input_ids": 0,
    "labels": -100,
    "attention_mask": 0,
}


def _pad_and_stack(tensors, pad_value):
    """Right-pad 1D tensors to the longest length in the list and stack them.

    Tensors that are not 1D are passed to ``torch.stack`` unchanged, so they
    must already share a common shape.
    """
    max_len = max(tensor.shape[0] for tensor in tensors)
    padded = []
    for tensor in tensors:
        pad_len = max_len - tensor.shape[0]
        if tensor.dim() == 1 and pad_len > 0:
            tensor = torch.nn.functional.pad(tensor, (0, pad_len), value=pad_value)
        padded.append(tensor)
    return torch.stack(padded)


def _collate_batch(batch):
    """Collate a list of sample dicts into a single batch dict.

    The keys of the first sample decide which fields are collated. For each
    key, the type of the first sample's value selects the strategy:

    * tensor under ``input_ids`` / ``labels`` / ``attention_mask``: right-pad
      to the longest sequence in the batch, then stack;
    * any other tensor: stack as-is;
    * non-tensor: keep as a plain list, one entry per sample.

    An empty batch yields an empty dict.
    """
    if not batch:
        return {}

    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if not isinstance(values[0], torch.Tensor):
            collated[key] = values
        elif key in _PAD_VALUES:
            collated[key] = _pad_and_stack(values, _PAD_VALUES[key])
        else:
            collated[key] = torch.stack(values)
    return collated


def get_dataloader(
    dataset: Dataset,
    config: TrainingConfig,
    shuffle: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Create a DataLoader from a dataset.

    **NOTE**: For JSONL-based datasets (PretrainJSONLDataset, SFTJSONLDataset, etc.),
    this function is now deprecated in favor of AsyncBatchIterator for better performance.
    AsyncBatchIterator enables tokenization to happen in parallel with training,
    avoiding the startup bottleneck of tokenizing all data upfront.

    See: taoTrain/data/async_loader.py for the new async loading approach.
    The trainer automatically uses AsyncBatchIterator for JSONL datasets.

    Samples are expected to be dicts. 1D tensors stored under ``input_ids``,
    ``labels`` and ``attention_mask`` are right-padded per batch (with 0,
    -100 and 0 respectively) before stacking; other tensors are stacked
    directly and non-tensor values are gathered into lists.

    Args:
        dataset: PyTorch Dataset instance
        config: Training configuration; ``batch_size``, ``num_workers`` and
            ``pin_memory`` are read from it
        shuffle: Whether to shuffle data
        drop_last: Whether to drop last incomplete batch

    Returns:
        DataLoader instance
    """
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        # Module-level function rather than a closure: it has to be picklable
        # when worker processes are started with the "spawn" method.
        collate_fn=_collate_batch,
    )