File size: 3,166 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""DataLoader utilities."""

from typing import Optional
import torch
from torch.utils.data import DataLoader, Dataset
from taoTrain.config import TrainingConfig


# Padding value per variable-length key. Keys listed here are right-padded to
# the longest sequence in the batch before stacking.
# NOTE(review): -100 for "labels" presumably matches the loss function's
# ignore index, and 0 for "input_ids" presumably matches the tokenizer's pad
# id -- confirm against the trainer/tokenizer.
_PAD_VALUES = {"input_ids": 0, "labels": -100, "attention_mask": 0}


def _pad_and_stack(tensors, pad_value):
    """Right-pad 1D tensors to the longest first dimension, then stack them."""
    max_len = max(t.shape[0] for t in tensors)
    padded = [
        torch.nn.functional.pad(t, (0, max_len - t.shape[0]), value=pad_value)
        if t.dim() == 1 and t.shape[0] < max_len
        else t  # non-1D tensors are passed through unpadded
        for t in tensors
    ]
    return torch.stack(padded)


def _collate_batch(batch):
    """Collate a list of sample dicts into a single batch dict.

    The keys of the first sample define the keys of the batch. Tensor values
    under "input_ids", "labels" and "attention_mask" are padded (see
    ``_PAD_VALUES``) and stacked; other tensor values are stacked as-is;
    non-tensor values are returned as a plain list.

    Defined at module level (rather than as a closure) so that it can be
    pickled when DataLoader worker processes are started with the ``spawn``
    method.
    """
    if not batch:
        return {}

    collated = {}
    for key in batch[0]:
        items = [sample[key] for sample in batch]
        if not isinstance(items[0], torch.Tensor):
            collated[key] = items
        elif key in _PAD_VALUES:
            collated[key] = _pad_and_stack(items, _PAD_VALUES[key])
        else:
            collated[key] = torch.stack(items)
    return collated


def get_dataloader(
    dataset: Dataset,
    config: "TrainingConfig",
    shuffle: bool = True,
    drop_last: bool = True,
) -> DataLoader:
    """
    Create a DataLoader from a dataset.

    **NOTE**: For JSONL-based datasets (PretrainJSONLDataset, SFTJSONLDataset, etc.),
    this function is now deprecated in favor of AsyncBatchIterator for better performance.
    AsyncBatchIterator enables tokenization to happen in parallel with training,
    avoiding the startup bottleneck of tokenizing all data upfront.

    See: taoTrain/data/async_loader.py for the new async loading approach.
    The trainer automatically uses AsyncBatchIterator for JSONL datasets.

    Samples are expected to be dicts. Within a batch, 1D tensors under
    "input_ids" and "attention_mask" are right-padded with 0 and those under
    "labels" with -100; other tensors are stacked unchanged and non-tensor
    values are collected into lists.

    Args:
        dataset: PyTorch Dataset instance
        config: Training configuration; ``batch_size``, ``num_workers`` and
            ``pin_memory`` are read from it
        shuffle: Whether to shuffle data
        drop_last: Whether to drop last incomplete batch

    Returns:
        DataLoader instance
    """
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        collate_fn=_collate_batch,
    )