# YYYYYYUUU's picture
# Backup of the FULL local core code, including the libs/ CUDA extensions and all configs
# 3499c27 verified
# Raw
# History Blame Contribute Delete
# 2.72 kB
# pointcept_framework/pointcept/datasets/utils.py
import torch
import collections
import numpy as np
from pointcept.models.utils.structure import Point
def point_collate_fn(batch, mix_prob=0):
    """Collate point-cloud samples and attach a per-point ``batch`` index.

    The samples are first merged with ``collate_fn``. Afterwards:

    * If the merged result has an ``"offset"`` entry, it is cast to long and
      replaced by its cumulative sum along dim 0. The per-sample counts are
      then recovered from the cumulative offsets and expanded into
      ``"batch"``, which holds the sample index for each counted element.
    * Otherwise, if a ``"coord"`` entry exists, the counts are taken from
      ``coord.shape[0]`` of every input sample.

    Args:
        batch: The list of samples handed over by the data loader.
        mix_prob: Accepted for interface compatibility; this implementation
            never reads it.

    Returns:
        The merged result of ``collate_fn(batch)``, with ``"offset"``
        rewritten and ``"batch"`` added when the keys above are present.
    """
    merged = collate_fn(batch)

    if "offset" in merged:
        # Per the original note, `Collect` stores per-sample lengths as
        # offset=[num_points]; after concatenation they are turned into
        # cumulative offsets for downstream Point/PTv3 utilities.
        cumulative = torch.cumsum(merged["offset"].long(), dim=0)
        merged["offset"] = cumulative

        # Differences of the cumulative offsets (with a leading zero) give
        # back the number of elements belonging to each sample.
        leading_zero = cumulative.new_zeros(1)
        sizes = torch.diff(cumulative, prepend=leading_zero)
        sample_ids = torch.arange(sizes.numel(), dtype=torch.long)
        merged["batch"] = torch.repeat_interleave(sample_ids, sizes)
        return merged

    if "coord" in merged:
        # No offsets available: fall back to the raw coord lengths.
        sizes = torch.tensor(
            [sample["coord"].shape[0] for sample in batch],
            dtype=torch.long,
        )
        sample_ids = torch.arange(len(batch), dtype=torch.long)
        merged["batch"] = torch.repeat_interleave(sample_ids, sizes)

    return merged
def collate_fn(batch):
    """Recursively merge a list of samples into one batched structure.

    Dispatch is based on the type of the first element of ``batch``:

    * mapping        -> ``dict`` whose values are collated recursively, using
                        the keys of the first sample;
    * ``Point``      -> every item is converted with ``to_dict()`` and the
                        resulting mappings are collated;
    * ``str``        -> the list is returned unchanged;
    * ``torch.Tensor`` -> ``torch.cat`` along dim 0;
    * ``np.ndarray`` -> ``np.concatenate`` along axis 0;
    * ``int``/``float`` -> ``torch.tensor(batch)``;
    * anything else  -> ``default_collate``; if that raises ``RuntimeError``
                        or ``TypeError`` the list is returned unchanged.

    Args:
        batch: A list of samples. Any value that is not a list is returned
            as is.

    Returns:
        The collated structure, or ``batch`` itself when it is not a list or
        is an empty list.
    """
    # Nothing to merge for non-list input. The emptiness check prevents the
    # IndexError that `batch[0]` would otherwise raise on an empty list.
    if not isinstance(batch, list) or not batch:
        return batch

    elem = batch[0]

    if isinstance(elem, collections.abc.Mapping):
        # Main dictionary structure: collate every field independently.
        return {key: collate_fn([d[key] for d in batch]) for key in elem}

    # Special 'Point' object: convert to plain dictionaries first.
    # NOTE(review): if Point is itself a mapping type, the branch above
    # already handles it and this one is never reached — confirm.
    if isinstance(elem, Point):
        return collate_fn([p.to_dict() for p in batch])

    if isinstance(elem, str):
        return batch
    if isinstance(elem, torch.Tensor):
        return torch.cat(batch, 0)
    if isinstance(elem, np.ndarray):
        return np.concatenate(batch, 0)
    if isinstance(elem, (int, float)):
        return torch.tensor(batch)

    # Best-effort fallback for any other type: keep the raw list when the
    # default collation cannot handle it.
    try:
        return torch.utils.data.dataloader.default_collate(batch)
    except (RuntimeError, TypeError):
        return batch