# pointcept_framework/pointcept/datasets/utils.py
import collections
import collections.abc

import numpy as np
import torch

from pointcept.models.utils.structure import Point


def point_collate_fn(batch, mix_prob=0):
    """Collate a list of point-cloud samples into one batched mapping.

    Samples are first merged with :func:`collate_fn`. Two post-processing
    steps then make the result usable by ``Point``-based models:

    * ``offset``: the ``Collect`` transform stores each sample's point count
      as ``offset=[num_points]``. After concatenation, the collated
      ``offset`` therefore holds per-sample lengths. Those lengths are
      converted here into cumulative end offsets.
    * ``batch``: a per-point sample-index tensor is built from those lengths.
      If there is no ``offset`` key, it is built from each sample's
      ``coord`` length instead.

    Args:
        batch: List of per-sample mappings (or ``Point`` objects).
        mix_prob: Accepted for interface compatibility with callers/configs
            that pass it. It is currently ignored: no sample mixing is done.

    Returns:
        The collated result. When ``offset`` or ``coord`` is present, it has
        a cumulative ``offset`` (if present) and an added ``batch`` tensor.
    """
    # NOTE(review): mix_prob is intentionally unused here; confirm whether
    # Mix3d-style offset merging is expected by the training configs.
    collated_dict = collate_fn(batch)

    if "offset" in collated_dict:
        # as_tensor also accepts a NumPy array, in case `offset` was not
        # converted to a tensor upstream (plain `.long()` would fail then).
        lengths = torch.as_tensor(collated_dict["offset"]).long()
        collated_dict["offset"] = torch.cumsum(lengths, dim=0)
        indices = torch.arange(
            lengths.numel(), dtype=torch.long, device=lengths.device
        )
        # Sample i contributes lengths[i] consecutive entries equal to i.
        collated_dict["batch"] = torch.repeat_interleave(indices, lengths)
    elif "coord" in collated_dict:
        # Fallback: derive per-sample lengths from the raw samples' coords.
        counts = torch.tensor(
            [d["coord"].shape[0] for d in batch], dtype=torch.long
        )
        indices = torch.arange(len(batch), dtype=torch.long)
        collated_dict["batch"] = torch.repeat_interleave(indices, counts)
    return collated_dict


def collate_fn(batch):
    """Recursively collate a list of samples.

    Dispatch is based on the type of the first element:

    * Mapping: returns a dict collated key by key (keys come from the first
      sample).
    * ``Point``: each element is converted with ``to_dict()`` and then
      collated as a mapping.
    * ``str``: the list is returned unchanged.
    * ``torch.Tensor`` / ``np.ndarray``: elements are concatenated along
      dim 0. Zero-dimensional elements are stacked instead, because scalars
      cannot be concatenated.
    * ``int`` / ``float``: returns ``torch.tensor(batch)``.
    * anything else: ``default_collate`` is tried. If it raises
      ``RuntimeError`` or ``TypeError``, the raw list is returned.

    Non-list inputs and empty lists are returned unchanged.
    """
    if not isinstance(batch, list) or not batch:
        return batch
    elem = batch[0]
    # Mapping is checked first, so a Point that is itself a Mapping (if it
    # is one) is collated through this branch.
    if isinstance(elem, collections.abc.Mapping):
        return {key: collate_fn([d[key] for d in batch]) for key in elem}
    if isinstance(elem, Point):
        return collate_fn([p.to_dict() for p in batch])
    if isinstance(elem, str):
        return batch
    if isinstance(elem, torch.Tensor):
        if elem.dim() == 0:
            return torch.stack(batch, 0)
        return torch.cat(batch, 0)
    if isinstance(elem, np.ndarray):
        if elem.ndim == 0:
            return np.stack(batch, 0)
        return np.concatenate(batch, 0)
    if isinstance(elem, (int, float)):
        return torch.tensor(batch)
    # Fallback for any other types.
    try:
        return torch.utils.data.dataloader.default_collate(batch)
    except (RuntimeError, TypeError):
        return batch