import torch
import torch.distributed as dist
from collections.abc import Iterator


def _is_batch_valid(batch):
    """
    Check if a batch is valid for training/evaluation.
    A valid batch must have input_ids and at least one image.
    """
    if not batch:
        return False

    if len(batch['input_ids']) == 0:
        return False

    if len(batch['images']) == 0:
        return False

    # batch['images'] is a list of per-sample image lists; flatten it to make
    # sure at least one sample actually carries an image.
    if len([img for sublist in batch['images'] for img in sublist]) == 0:
        return False

    return True
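
# Illustrative note (an assumption for documentation, not from the original
# module): the collator is expected to return a dict shaped roughly like
#   {"input_ids": LongTensor[batch, seq], "images": [[img, ...], [], ...]}
# where "images" holds one (possibly empty) list of image tensors per sample.
# A batch such as {"input_ids": torch.zeros(2, 8, dtype=torch.long), "images": [[], []]}
# has input_ids but no images anywhere, so _is_batch_valid rejects it and it is
# skipped downstream.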


def synchronized_dataloader_step(train_loader, is_dist):
    """
    Create a synchronized iterator that handles uneven data distribution in DDP.
    All ranks stop as soon as the first rank runs out of data.
    This can happen because, when packing a pre-sharded dataset, one rank may end
    up with fewer groups than the others.
    It also handles collators that return an empty/invalid batch on some ranks:
    a rank that receives an invalid batch skips it and fetches a new one, so every
    yielded batch is valid on every rank.
    """
    if not is_dist:
        # Single-process case: just filter out invalid batches.
        for batch in train_loader:
            if _is_batch_valid(batch):
                yield batch
        return

    if isinstance(train_loader, Iterator):
        train_iter = train_loader
    else:
        train_iter = iter(train_loader)

    while True:
        is_valid = False
        try:
            # Keep fetching until this rank holds a valid batch.
            while not is_valid:
                batch = next(train_iter)
                is_valid = _is_batch_valid(batch)
            has_data = torch.tensor(1, device=torch.cuda.current_device())
        except StopIteration:
            batch = None
            has_data = torch.tensor(0, device=torch.cuda.current_device())

        # If any rank has run out of data, the MIN reduction drops to 0 everywhere.
        dist.all_reduce(has_data, op=dist.ReduceOp.MIN)

        if has_data.item() == 0:
            # At least one rank is exhausted, so all ranks stop together.
            break

        yield batch

    return None
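

# Usage sketch (illustrative, not part of the original module). In the
# non-distributed case the generator simply filters out invalid batches, so a
# plain list of batch dicts is enough to demonstrate it. With is_dist=True it
# additionally requires torch.distributed to be initialized and a CUDA device,
# because it all-reduces a per-rank "has_data" flag before every yield.
# Tensor shapes below are arbitrary assumptions for the demo.
if __name__ == "__main__":
    demo_loader = [
        {
            "input_ids": torch.ones(2, 4, dtype=torch.long),
            "images": [[torch.zeros(3, 32, 32)], []],
        },
        {"input_ids": [], "images": []},  # invalid batch, silently skipped
    ]
    for demo_batch in synchronized_dataloader_step(demo_loader, is_dist=False):
        print("yielded batch with input_ids of shape", tuple(demo_batch["input_ids"].shape))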