vlm-demo / src /data /data_utils.py
tiltaf's picture
Upload 29 files
cf3d756 verified
import torch
import torch.distributed as dist
from collections.abc import Iterator
def _is_batch_valid(batch):
"""
Check if a batch is valid for training/evaluation.
A valid batch must have input_ids and at least one image.
"""
if not batch:
return False
# The collator can return a batch with empty lists
if len(batch['input_ids']) == 0:
return False
if len(batch['images']) == 0:
return False
# `images` is a list of lists of tensors. Check that at least one image is not None.
if len([img for sublist in batch['images'] for img in sublist]) == 0:
# During training, not having images creates gradients computed without all model parameters.
# This creates deadlocks in DDP.
return False
return True
def synchronized_dataloader_step(train_loader, is_dist):
    """
    Yield valid batches from ``train_loader``, stopping all DDP ranks together.

    Batches rejected by ``_is_batch_valid`` (e.g. empty or image-less batches
    returned by the collator) are skipped locally. Each rank keeps pulling
    from its own iterator until it finds a valid batch or runs out of data.

    In distributed mode, every step ends with an all-reduce (MIN) of a
    "has data" flag. As soon as any rank is out of data, all ranks stop.
    This handles uneven data distribution: when packing a presharded dataset,
    a rank might end up with fewer groups than the others.

    Args:
        train_loader: An iterable of batches (e.g. a DataLoader) or an
            iterator. An iterator is consumed in place.
        is_dist: Whether torch.distributed is in use. When False, batches are
            only filtered, with no cross-rank communication.

    Yields:
        Batches for which ``_is_batch_valid`` returns True.
    """
    if not is_dist:
        # Single process: no synchronization needed, just filter invalid batches.
        for batch in train_loader:
            if _is_batch_valid(batch):
                yield batch
        return

    # iter() returns an iterator unchanged, so this covers both a DataLoader
    # and an iterator that may already be partially consumed.
    train_iter = iter(train_loader)

    # Keep the flag on the current CUDA device when one exists (NCCL only
    # reduces CUDA tensors). Fall back to CPU so CPU-only backends such as
    # gloo work instead of crashing in torch.cuda.current_device().
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

    while True:
        try:
            batch = next(train_iter)
            while not _is_batch_valid(batch):
                batch = next(train_iter)
        except StopIteration:
            batch = None
            has_data = torch.tensor(0, device=device)
        else:
            has_data = torch.tensor(1, device=device)

        # Every rank reaches this collective exactly once per step. If any rank
        # is out of data the MIN is 0, and all ranks stop together.
        dist.all_reduce(has_data, op=dist.ReduceOp.MIN)
        if has_data.item() == 0:
            break
        yield batch