import os
import json
import torch
import time
import random
from typing import Iterable
from collections import OrderedDict
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset, IterableDataset, DistributedSampler, RandomSampler
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms import functional as F

from .bucket_loader import Bucketeer, TemporalLengthBucketeer


class IterLoader:
    """
    A wrapper that turns a DataLoader into an infinite iterator.

    Modified from:
    https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False, epoch: int = 0):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = epoch

    @property  # expose the epoch counter as a read-only attribute
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # The underlying loader is exhausted: bump the epoch, reshuffle the
            # distributed sampler if applicable, and start a fresh pass.
            # `use_distributed` is checked first so `.sampler` is never touched
            # on a wrapped object that does not have one.
            self._epoch += 1
            if self._use_distributed and hasattr(self._dataloader.sampler, "set_epoch"):
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)
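

# Usage sketch (illustrative, not part of the module): `toy_dataset` is a
# hypothetical map-style dataset; any finite DataLoader works.
#
#   loader = IterLoader(DataLoader(toy_dataset, batch_size=4))
#   for step in range(10_000):
#       batch = next(loader)  # transparently restarts the loader at epoch ends
#   print(loader.epoch)       # number of completed passes over the dataset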


def identity(x):
    # No-op collate_fn: hand back the raw list of samples untouched so that a
    # downstream bucketeer can regroup them (e.g. by aspect ratio) itself.
    return x


def create_image_text_dataloaders(
    dataset, batch_size, num_workers,
    multi_aspect_ratio=True, epoch=0,
    sizes=[(512, 512), (384, 640), (640, 384)],
    use_distributed=True, world_size=None, rank=None,
):
    """
    The dataset has already been split across the different ranks.
    """
    if use_distributed:
        assert world_size is not None
        assert rank is not None
        sampler = DistributedSampler(
            dataset,
            shuffle=True,
            num_replicas=world_size,
            rank=rank,
            seed=epoch,
        )
    else:
        sampler = RandomSampler(dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        sampler=sampler,
        # With multi-aspect-ratio training, the Bucketeer does the batching, so
        # the DataLoader must hand over raw sample lists instead of collating.
        collate_fn=identity if multi_aspect_ratio else default_collate,
        drop_last=True,
    )

    if multi_aspect_ratio:
        dataloader_iterator = Bucketeer(
            dataloader,
            sizes=sizes,
            is_infinite=True,
            epoch=epoch,
        )
    else:
        # Pass the DataLoader itself rather than `iter(dataloader)`: IterLoader
        # calls `iter()` again at each epoch boundary, and re-iterating an
        # already exhausted iterator would stop after the first epoch.
        dataloader_iterator = dataloader

    # To make it infinite
    loader = IterLoader(dataloader_iterator, use_distributed=False, epoch=epoch)
    return loader
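

# Usage sketch (assumes a 2-GPU torch.distributed run is already initialised;
# `image_text_dataset` is an illustrative dataset of image-text pairs):
#
#   loader = create_image_text_dataloaders(
#       image_text_dataset, batch_size=8, num_workers=4,
#       multi_aspect_ratio=True, epoch=0,
#       use_distributed=True, world_size=2, rank=0,
#   )
#   batch = next(loader)  # every sample in a batch shares one (H, W) bucket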


def create_length_grouped_video_text_dataloader(
    dataset, batch_size, num_workers, max_frames,
    world_size=None, rank=None, epoch=0, use_distributed=False,
):
    """
    Build an infinite video-text dataloader whose batches are grouped by
    temporal length (number of frames).
    """
    if use_distributed:
        assert world_size is not None
        assert rank is not None
        sampler = DistributedSampler(
            dataset,
            shuffle=True,
            num_replicas=world_size,
            rank=rank,
            seed=epoch,
        )
    else:
        sampler = RandomSampler(dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        sampler=sampler,
        collate_fn=identity,  # the bucketeer regroups the raw samples itself
        drop_last=True,
    )

    # The temporal bucketeer both groups clips by frame count and iterates
    # indefinitely, so no extra IterLoader wrapper is needed here.
    dataloader_iterator = TemporalLengthBucketeer(
        dataloader,
        max_frames=max_frames,
        epoch=epoch,
    )
    return dataloader_iterator
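

# Usage sketch (single process; `video_text_dataset` is an illustrative dataset
# whose samples expose a frame count for the bucketeer to group on):
#
#   loader = create_length_grouped_video_text_dataloader(
#       video_text_dataset, batch_size=2, num_workers=4, max_frames=16,
#       use_distributed=False,
#   )
#   batch = next(loader)  # all clips in a batch share the same temporal length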


def create_mixed_dataloaders(
    dataset, batch_size, num_workers, world_size=None, rank=None, epoch=0,
    image_mix_ratio=0.1, use_image_video_mixed_training=True,
):
    """
    Build the dataloader for video & image mixed training: the first
    `video_gpus` ranks consume video data and the remaining ranks consume
    image data, with `image_mix_ratio` controlling the split.
    """
    assert world_size is not None
    assert rank is not None

    image_gpus = max(1, int(world_size * image_mix_ratio))
    if use_image_video_mixed_training:
        video_gpus = world_size - image_gpus
    else:
        # Only use video data
        video_gpus = world_size
        image_gpus = 0

    print(f"{image_gpus} gpus for image, {video_gpus} gpus for video")

    # Each group of ranks gets its own DistributedSampler, so the video ranks
    # shard the data among themselves and the image ranks do likewise.
    if rank < video_gpus:
        sampler = DistributedSampler(
            dataset,
            shuffle=True,
            num_replicas=video_gpus,
            rank=rank,
            seed=epoch,
        )
    else:
        sampler = DistributedSampler(
            dataset,
            shuffle=True,
            num_replicas=image_gpus,
            rank=rank - video_gpus,
            seed=epoch,
        )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        sampler=sampler,
        collate_fn=default_collate,
        drop_last=True,
    )

    # To make it infinite
    loader = IterLoader(loader, use_distributed=True, epoch=epoch)
    return loader
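

# Usage sketch (assumes an 8-GPU torch.distributed job; with
# image_mix_ratio=0.25 this gives image_gpus=2 and video_gpus=6, so ranks 0-5
# train on video and ranks 6-7 on images; each rank should pass the dataset
# matching its role):
#
#   loader = create_mixed_dataloaders(
#       dataset, batch_size=4, num_workers=4,
#       world_size=8, rank=int(os.environ["RANK"]),
#       image_mix_ratio=0.25,
#   )
#   batch = next(loader)  # infinite; sampler epochs advance inside IterLoader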