| |
| |
| from .utils.transforms import * |
| from .base.batched_sampler import BatchedRandomSampler |
| from .arkitscenes import ARKitScenes |
| from .blendedmvs import BlendedMVS |
| from .co3d import Co3d |
| from .habitat import Habitat |
| from .megadepth import MegaDepth |
| from .scannetpp import ScanNetpp |
| from .staticthings3d import StaticThings3D |
| from .waymo import Waymo |
| from .wildrgbd import WildRGBD |
|
|
|
|
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: A dataset object, or a Python expression string that is
            passed to ``eval`` (evaluated with this function's globals and
            locals, so names imported at module level, such as the dataset
            classes, can be referenced in the string).
        batch_size: Number of samples per batch. Also forwarded to
            ``dataset.make_sampler`` when the dataset provides one.
        num_workers: Number of DataLoader worker processes.
        shuffle: Whether samples should be drawn in random order.
        drop_last: Whether to drop the final incomplete batch. Forwarded to
            ``make_sampler`` / ``DistributedSampler`` and to the DataLoader.
        pin_mem: Forwarded to ``DataLoader(pin_memory=...)``.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.

    Sampler selection:
        1. ``dataset.make_sampler(batch_size, shuffle=..., world_size=...,
           rank=..., drop_last=...)`` if the dataset implements it.
        2. Otherwise (``AttributeError`` or ``NotImplementedError``):
           ``DistributedSampler`` when a ``torch.distributed`` process group
           is available and initialized, else ``RandomSampler`` when
           ``shuffle`` is true, else ``SequentialSampler``.

    Note:
        When a ``DistributedSampler`` is chosen, PyTorch requires the caller to
        call ``loader.sampler.set_epoch(epoch)`` each epoch for shuffling to
        differ across epochs.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    if isinstance(dataset, str):
        # SECURITY: eval() executes arbitrary code. This is intended for
        # trusted expressions from the project's own configuration only;
        # never pass externally supplied strings here.
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        sampler = dataset.make_sampler(
            batch_size, shuffle=shuffle, world_size=world_size, rank=rank, drop_last=drop_last
        )
    except (AttributeError, NotImplementedError):
        # The dataset has no custom sampler, or declined this configuration.
        # NOTE: an AttributeError raised *inside* make_sampler also lands here.
        #
        # When torch.distributed.is_available() is False, the package exposes
        # no other API (is_initialized included), so check availability first
        # instead of letting an AttributeError escape from this handler.
        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if distributed:
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
|
|