Spaces:
Runtime error
Runtime error
| """ | |
| Utils for Datasets | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| import random | |
| from collections.abc import Mapping, Sequence | |
| import numpy as np | |
| import torch | |
| from torch.utils.data.dataloader import default_collate | |
def collate_fn(batch):
    """
    Collate function for point clouds that supports dict and list samples.

    Dispatches on the type of the first element of ``batch``:

    * ``torch.Tensor`` -- the tensors are concatenated along dim 0.
    * ``str`` -- the strings are returned as a plain list.
    * ``Sequence`` (list/tuple sample) -- each sample is extended with a
      one-element tensor holding ``sample[0].shape[0]``. The samples are
      then collated position-wise, and the last collated element is
      replaced by its cumulative sum cast to ``int32`` (the batch offsets).
    * ``Mapping`` (dict sample) -- every key of the first sample is
      collated recursively. Every key whose name contains ``"offset"`` is
      replaced by its cumulative sum along dim 0 (dtype left unchanged).
    * anything else -- delegated to ``default_collate``.

    For list/tuple samples the first element (presumably the coordinates,
    'coord') must expose ``.shape``; its length determines the offset. For
    dict samples the "offset" entries are presumably per-sample counts
    supplied by the dataset -- confirm against the caller.

    Args:
        batch: a sequence of samples, all of the same kind. Dict samples
            must all contain the keys of ``batch[0]``.

    Returns:
        The collated batch: a tensor, list or dict mirroring the samples.

    Raises:
        TypeError: if ``batch`` is not a ``Sequence``.
    """
    if not isinstance(batch, Sequence):
        # Use type(): arbitrary objects have no ``dtype`` attribute, and
        # reading one here would raise AttributeError instead of TypeError.
        raise TypeError(f"{type(batch).__name__} is not supported.")
    if isinstance(batch[0], torch.Tensor):
        return torch.cat(list(batch))
    elif isinstance(batch[0], str):
        # str is also a kind of Sequence, so it must be checked before Sequence
        return list(batch)
    elif isinstance(batch[0], Sequence):
        # Build new per-sample lists with the point count appended. This
        # avoids mutating the caller's samples and also accepts tuples.
        batch = [[*data, torch.tensor([data[0].shape[0]])] for data in batch]
        batch = [collate_fn(samples) for samples in zip(*batch)]
        # The appended counts were collated into the last slot; turn them
        # into cumulative offsets.
        batch[-1] = torch.cumsum(batch[-1], dim=0).int()
        return batch
    elif isinstance(batch[0], Mapping):
        batch = {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
        for key in batch.keys():
            if "offset" in key:
                batch[key] = torch.cumsum(batch[key], dim=0)
        return batch
    else:
        return default_collate(batch)
def point_collate_fn(batch, mix_prob=0):
    """
    Collate a batch of dict samples with ``collate_fn``, optionally applying Mix3d.

    Mix3d (https://arxiv.org/pdf/2110.02210.pdf): when the collated batch
    has an "offset" entry and ``random.random() < mix_prob``, the cumulative
    offsets are rewritten to keep every second one plus the final one. This
    treats samples (0, 1), (2, 3), ... as single merged scenes; with an odd
    sample count the last sample stays on its own.

    Args:
        batch: a list of dict samples (list/tuple samples are not supported).
        mix_prob: probability of applying the Mix3d offset merge.

    Returns:
        The collated dict.
    """
    # currently, only support input_dict, rather than input_list
    assert isinstance(batch[0], Mapping)
    collated = collate_fn(batch)
    # Without offsets there is nothing to mix, and no random draw is made.
    if "offset" not in collated.keys():
        return collated
    if random.random() < mix_prob:
        offsets = collated["offset"]
        pair_ends = offsets[1:-1:2]
        final_end = offsets[-1].unsqueeze(0)
        collated["offset"] = torch.cat([pair_ends, final_end], dim=0)
    return collated
def gaussian_kernel(dist2: np.ndarray, a: float = 1, c: float = 5):
    """
    Evaluate an unnormalized Gaussian kernel on squared distances.

    Computes ``a * exp(-dist2 / (2 * c**2))`` element-wise.

    Args:
        dist2: squared distances. They are used as-is, not squared again.
            Any array or scalar accepted by ``np.exp`` works.
        a: peak value of the kernel, reached where ``dist2 == 0``.
        c: width of the Gaussian (acts as the standard deviation); must be
            non-zero.

    Returns:
        The kernel values, with the same shape as ``dist2`` when ``a`` and
        ``c`` are scalars.
    """
    return a * np.exp(-dist2 / (2 * c**2))