| |
|
|
| import torch |
| import collections |
| import numpy as np |
| from pointcept.models.utils.structure import Point |
|
|
def point_collate_fn(batch, mix_prob=0):
    """Collate a list of point-cloud samples and attach per-point batch ids.

    Samples are merged with ``collate_fn`` (tensors/arrays are concatenated
    along dim 0). If the merged result has an ``"offset"`` entry, its
    per-sample values are converted to cumulative end offsets with
    ``torch.cumsum``, and a ``"batch"`` tensor is built that repeats each
    sample index by that sample's value. Otherwise, if ``"coord"`` is present,
    ``"batch"`` is built from the number of rows in each sample's ``coord``.

    Args:
        batch: list of per-sample mappings (or ``Point`` objects).
        mix_prob: accepted for interface compatibility; currently unused.

    Returns:
        The collated mapping, with ``"offset"`` (if present) in cumulative form
        and a ``"batch"`` long tensor holding one sample index per point.
        If the collated result is not a mapping, it is returned unchanged.
    """
    collated_dict = collate_fn(batch)

    # Only dict-like results can carry "offset"/"coord". Anything else, such
    # as a concatenated tensor, cannot be queried by key, so pass it through.
    if not isinstance(collated_dict, collections.abc.Mapping):
        return collated_dict

    if "offset" in collated_dict:
        # Presumably each sample stores its own point count (TODO confirm with
        # the dataset), so these raw values are the per-sample counts.
        # as_tensor also accepts numpy arrays and lists, not just tensors.
        counts = torch.as_tensor(collated_dict["offset"]).long()
        collated_dict["offset"] = torch.cumsum(counts, dim=0)
        indices = torch.arange(
            counts.numel(), dtype=torch.long, device=counts.device
        )
        collated_dict["batch"] = torch.repeat_interleave(indices, counts)
    elif "coord" in collated_dict:
        counts = torch.tensor(
            [d["coord"].shape[0] for d in batch], dtype=torch.long
        )
        indices = torch.arange(len(batch), dtype=torch.long)
        collated_dict["batch"] = torch.repeat_interleave(indices, counts)

    return collated_dict
|
|
def collate_fn(batch):
    """Recursively collate a list of samples into a single batch.

    Dispatch is on the type of the first element:

    * mapping -> dict whose values are collated key by key (keys are taken
      from the first element);
    * ``Point`` -> converted with ``to_dict()`` and collated as a mapping;
    * ``str`` -> the list is returned unchanged;
    * ``torch.Tensor`` / ``np.ndarray`` -> concatenated along dim 0, or
      stacked into a new leading dim when the elements are 0-dimensional
      (``torch.cat`` / ``np.concatenate`` reject 0-d inputs);
    * ``int`` / ``float`` -> ``torch.tensor(batch)``;
    * anything else -> PyTorch's ``default_collate``, falling back to the raw
      list if that raises ``RuntimeError`` or ``TypeError``.

    Args:
        batch: list of samples. A non-list input, or an empty list, is
            returned unchanged.

    Returns:
        The collated batch.
    """
    # Guard clause: non-lists pass through; an empty list has no first
    # element to dispatch on.
    if not isinstance(batch, list) or not batch:
        return batch

    elem = batch[0]
    if isinstance(elem, collections.abc.Mapping):
        return {key: collate_fn([d[key] for d in batch]) for key in elem}

    # NOTE(review): if Point is itself a Mapping subclass, the branch above
    # already handles it and this branch is never reached. Confirm.
    if isinstance(elem, Point):
        return collate_fn([p.to_dict() for p in batch])

    if isinstance(elem, str):
        return batch

    if isinstance(elem, torch.Tensor):
        # torch.cat cannot concatenate 0-d tensors; stack scalars into [B].
        if elem.dim() == 0:
            return torch.stack(batch, 0)
        return torch.cat(batch, 0)

    if isinstance(elem, np.ndarray):
        # np.concatenate rejects 0-d arrays; stack them into shape (B,).
        if elem.ndim == 0:
            return np.stack(batch, 0)
        return np.concatenate(batch, 0)

    if isinstance(elem, (int, float)):
        return torch.tensor(batch)

    # Best effort for any other type: defer to PyTorch's default collation,
    # and keep the raw list if it cannot handle the elements.
    try:
        return torch.utils.data.dataloader.default_collate(batch)
    except (RuntimeError, TypeError):
        return batch
|
|