File size: 2,467 Bytes
cf3d756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.distributed as dist
from collections.abc import Iterator

def _is_batch_valid(batch):
    """

    Check if a batch is valid for training/evaluation.

    A valid batch must have input_ids and at least one image.

    """
    if not batch:
        return False
    # The collator can return a batch with empty lists
    if len(batch['input_ids']) == 0:
        return False
    
    if len(batch['images']) == 0:
        return False
    
    # `images` is a list of lists of tensors. Check that at least one image is not None.
    if len([img for sublist in batch['images'] for img in sublist]) == 0:
        # During training, not having images creates gradients computed without all model parameters.
        # This creates deadlocks in DDP.
        return False

    return True


def synchronized_dataloader_step(train_loader, is_dist):
    """Iterate over ``train_loader``, keeping all DDP ranks in lockstep.

    All ranks stop as soon as the first rank runs out of data. This happens
    because when packing a presharded dataset, a rank might have fewer
    groups than the others. It also handles cases where a collator returns
    an empty/invalid batch on some ranks, by having every rank skip invalid
    batches locally and re-synchronize before each yield.

    Args:
        train_loader: An iterable (or iterator) of batches.
        is_dist: True when running under torch.distributed; enables the
            per-step all_reduce handshake.

    Yields:
        Valid batches (as judged by ``_is_batch_valid``).
    """
    if not is_dist:
        # Single process: no synchronization needed, just skip invalid batches.
        for batch in train_loader:
            if _is_batch_valid(batch):
                yield batch
        return

    # For DDP, we need synchronization.
    train_iter = train_loader if isinstance(train_loader, Iterator) else iter(train_loader)
    # The device is loop-invariant; look it up once instead of per iteration.
    device = torch.cuda.current_device()

    while True:
        batch = None
        is_valid = False
        # Keep the try body minimal: only next() can raise StopIteration.
        try:
            while not is_valid:
                batch = next(train_iter)
                is_valid = _is_batch_valid(batch)
        except StopIteration:
            pass

        has_data = torch.tensor(1 if is_valid else 0, device=device)
        # Every rank must participate in this all_reduce each step, or DDP
        # deadlocks. MIN == 0 means at least one rank is out of data.
        dist.all_reduce(has_data, op=dist.ReduceOp.MIN)

        if has_data.item() == 0:
            # At least one rank is exhausted; all ranks stop together.
            break
        yield batch