import torch
import torch.distributed as dist
from collections.abc import Iterator


def _is_batch_valid(batch):
    """
    Check if a batch is valid for training/evaluation.
    A valid batch must have input_ids and at least one image.
    """
    if not batch:
        return False
    # The collator can return a batch with empty lists
    if len(batch['input_ids']) == 0:
        return False
    if len(batch['images']) == 0:
        return False
    # `images` is a list of lists of tensors. Check that at least one image is present.
    if len([img for sublist in batch['images'] for img in sublist]) == 0:
        # During training, a batch without images produces gradients that do not involve
        # all model parameters, which causes deadlocks in DDP.
        return False
    return True
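

# A minimal sketch (not from this file) of the batch structure `_is_batch_valid` expects:
# the collator is assumed to produce a dict with `input_ids` and `images`, where `images`
# is a list (one entry per sample) of lists of image tensors. Shapes are illustrative only.
#
#     batch = {
#         "input_ids": torch.randint(0, 100, (2, 16)),   # (batch, seq_len) token ids
#         "images": [[torch.zeros(3, 224, 224)], []],    # sample 0 has one image, sample 1 has none
#     }
#     assert _is_batch_valid(batch)                      # True: at least one image is present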


def synchronized_dataloader_step(train_loader, is_dist):
    """
    Create a synchronized iterator that handles uneven data distribution in DDP.
    All ranks stop as soon as the first rank runs out of data. This can happen because,
    when packing a pre-sharded dataset, a rank may end up with fewer groups than the others.
    It also handles the collator returning an empty/invalid batch on some ranks:
    each rank skips its own invalid batches and fetches new ones until it either finds
    a valid batch or exhausts its data.
    """
    if not is_dist:
        # For single GPU, we don't need synchronization, just filter invalid batches.
        for batch in train_loader:
            if _is_batch_valid(batch):
                yield batch
        return

    # For DDP, we need synchronization.
    if isinstance(train_loader, Iterator):
        train_iter = train_loader
    else:
        train_iter = iter(train_loader)

    while True:
        is_valid = False
        try:
            while not is_valid:
                batch = next(train_iter)
                is_valid = _is_batch_valid(batch)
            has_data = torch.tensor(1, device=torch.cuda.current_device())
        except StopIteration:
            batch = None
            has_data = torch.tensor(0, device=torch.cuda.current_device())

        # We synchronize across all ranks. If any rank is out of data, all ranks stop.
        dist.all_reduce(has_data, op=dist.ReduceOp.MIN)
        if has_data.item() == 0:
            # At least one rank is out of data. All ranks should stop.
            break

        yield batch

    return None
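
# Usage sketch (hypothetical; `model`, `optimizer`, and `train_loader` are assumed to be
# defined elsewhere in the training script). The MIN all_reduce above guarantees that every
# rank leaves the loop in the same iteration, so no rank blocks in DDP waiting for peers.
#
#     is_dist = dist.is_available() and dist.is_initialized()
#     for batch in synchronized_dataloader_step(train_loader, is_dist):
#         loss = model(**batch)          # assumed to return a scalar loss
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()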