| import torch | |
| from torch.utils.data.dataloader import default_collate | |
def collate_tensor_fn(batch):
    """Stack a sequence of equally-shaped tensors along a new leading dim 0.

    When running inside a DataLoader worker process (i.e.
    ``torch.utils.data.get_worker_info()`` is not ``None``), the result tensor
    is preallocated in shared memory so it can be handed back to the main
    process without an extra copy. Otherwise ``torch.stack`` allocates the
    output normally.

    Args:
        batch: Non-empty sequence of tensors with identical shape, dtype and
            device (``torch.stack`` requirements).

    Returns:
        Tensor of shape ``(len(batch), *batch[0].shape)``.

    Raises:
        IndexError: If ``batch`` is empty (the first element is read eagerly).
    """
    first = batch[0]
    shared_out = None

    in_worker = torch.utils.data.get_worker_info() is not None
    if in_worker:
        # Background worker: build the destination tensor on top of a shared
        # storage sized for every element, then let torch.stack fill it.
        total_numel = sum(t.numel() for t in batch)
        shm = first._typed_storage()._new_shared(total_numel, device=first.device)
        shared_out = first.new(shm).resize_(len(batch), *list(first.size()))

    return torch.stack(batch, 0, out=shared_out)
def collate_fn_pad_lidar(batch):
    """Collate samples of the form ``(features[, targets])`` into batched dicts.

    Each sample is indexed as ``sample[0]``, a mapping of feature name to
    value. It may optionally have ``sample[1]``, a mapping of target name to
    value. Keys are taken from the first sample.

    Despite the name, no padding is performed here:

    * ``'lidar'`` and ``'lidars_warped'`` features are returned as plain
      Python lists (one entry per sample), so they are never stacked.
    * Other tensor features are stacked along a new dim 0 via
      ``collate_tensor_fn``.
    * Other non-tensor features (numbers, strings, numpy arrays, ...) are
      collated with ``default_collate`` instead of being forced through
      tensor stacking.
    * Every target is returned as a plain list, one entry per sample. The
      original comments indicate this is for ground-truth boxes.

    Args:
        batch: Non-empty sequence of samples as described above.

    Returns:
        Tuple ``(feats, targets)`` of dicts. ``targets`` is empty when the
        samples have no second element.
    """
    first = batch[0]

    feats = {}
    for k in first[0]:
        values = [sample[0][k] for sample in batch]
        if k in ('lidar', 'lidars_warped'):
            # Point clouds are left uncollated; downstream code gets a list.
            feats[k] = values
        elif isinstance(values[0], torch.Tensor):
            feats[k] = collate_tensor_fn(values)
        else:
            # Non-tensor metadata would crash collate_tensor_fn (no .numel()
            # / torch.stack on non-tensors); let torch's generic collate
            # handle it.
            feats[k] = default_collate(values)

    targets = {}
    # Samples carry ground truth only when they have a second element.
    if len(first) > 1:
        for k in first[1]:
            # Targets are kept as per-sample lists rather than stacked.
            targets[k] = [sample[1][k] for sample in batch]

    return feats, targets