| from typing import *
|
| from fractions import Fraction
|
| import torch
|
| from . import config
|
|
|
|
|
# Public names exported by this module: the two tensor containers and their
# concatenate / unbind helpers.
__all__ = [
    'VarLenTensor', 'varlen_cat', 'varlen_unbind',
    'SparseTensor', 'sparse_cat', 'sparse_unbind',
]
|
|
|
|
|
class VarLenTensor:
    """
    Sequential tensor with variable length.

    ``feats`` holds the rows of every sequence in one tensor and ``layout[i]``
    is the slice of rows (along dim 0) that makes up sequence ``i``.

    Args:
        feats (torch.Tensor): Features of the varlen tensor.
        layout (List[slice]): Layout of the varlen tensor for each batch.
            Defaults to a single sequence covering all rows of ``feats``.
    """
    def __init__(self, feats: torch.Tensor, layout: Optional[List[slice]] = None):
        self.feats = feats
        self.layout = layout if layout is not None else [slice(0, feats.shape[0])]
        # Lazily filled helper tensors derived from the layout
        # ('seqlen', 'cum_seqlen', 'batch_boardcast_map').
        self._cache = {}

    @staticmethod
    def layout_from_seqlen(seqlen: list) -> List[slice]:
        """
        Create a layout from a list (or 1-D tensor) of sequence lengths.

        Consecutive slices are produced, the first one starting at row 0.
        """
        layout = []
        start = 0
        for length in seqlen:
            # Store plain Python ints even when tensor scalars are passed in.
            length = int(length)
            layout.append(slice(start, start + length))
            start += length
        return layout

    @staticmethod
    def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
        """
        Create a VarLenTensor from a list of tensors.

        The tensors are concatenated along dim 0; each one becomes a sequence.
        """
        feats = torch.cat(tensor_list, dim=0)
        layout = VarLenTensor.layout_from_seqlen([t.shape[0] for t in tensor_list])
        return VarLenTensor(feats, layout)

    def to_tensor_list(self) -> List[torch.Tensor]:
        """
        Convert a VarLenTensor to a list of tensors (one per sequence).
        """
        return [self.feats[s] for s in self.layout]

    def __len__(self) -> int:
        return len(self.layout)

    @property
    def shape(self) -> torch.Size:
        # (number of sequences, *feature dims); sequence lengths are not part of it.
        return torch.Size([len(self.layout), *self.feats.shape[1:]])

    def dim(self) -> int:
        return len(self.shape)

    @property
    def ndim(self) -> int:
        return self.dim()

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @property
    def seqlen(self) -> torch.LongTensor:
        """Length of every sequence as a 1-D long tensor (cached)."""
        if 'seqlen' not in self._cache:
            self._cache['seqlen'] = torch.tensor(
                [l.stop - l.start for l in self.layout],
                dtype=torch.long,
                device=self.device,
            )
        return self._cache['seqlen']

    @property
    def cum_seqlen(self) -> torch.LongTensor:
        """Cumulative sequence lengths with a leading 0 (cached)."""
        if 'cum_seqlen' not in self._cache:
            self._cache['cum_seqlen'] = torch.cat([
                torch.tensor([0], dtype=torch.long, device=self.device),
                self.seqlen.cumsum(dim=0)
            ], dim=0)
        return self._cache['cum_seqlen']

    @property
    def batch_boardcast_map(self) -> torch.LongTensor:
        """
        Get the broadcast map for the varlen tensor.

        Entry ``j`` is the index of the sequence that row ``j`` belongs to,
        assuming the sequences are stored consecutively (cached).
        """
        if 'batch_boardcast_map' not in self._cache:
            self._cache['batch_boardcast_map'] = torch.repeat_interleave(
                torch.arange(len(self.layout), device=self.device),
                self.seqlen,
            )
        return self._cache['batch_boardcast_map']

    @overload
    def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...

    def to(self, *args, **kwargs) -> 'VarLenTensor':
        """
        Move and/or cast the features, mirroring the ``torch.Tensor.to`` call forms
        ``to(dtype)``, ``to(device)``, ``to(device, dtype)`` and the keyword variants.
        """
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            # A single positional argument is either a dtype or a device.
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']
        non_blocking = kwargs.get('non_blocking', False)
        copy = kwargs.get('copy', False)

        new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
        return self.replace(new_feats)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'VarLenTensor':
        new_feats = self.feats.cpu()
        return self.replace(new_feats)

    def cuda(self) -> 'VarLenTensor':
        new_feats = self.feats.cuda()
        return self.replace(new_feats)

    def half(self) -> 'VarLenTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'VarLenTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'VarLenTensor':
        new_feats = self.feats.detach()
        return self.replace(new_feats)

    def reshape(self, *shape) -> 'VarLenTensor':
        # Only the feature dims are reshaped; the row dimension is kept.
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['VarLenTensor']:
        return varlen_unbind(self, dim)

    def replace(self, feats: torch.Tensor) -> 'VarLenTensor':
        """
        Return a new VarLenTensor with the same layout but different features.

        The helper-tensor cache is shared with the new object only when the new
        features live on the same device; the cached tensors are created on
        ``self.device``, so reusing them after a device move would hand out
        tensors on the wrong device.
        """
        new_tensor = VarLenTensor(
            feats=feats,
            layout=self.layout,
        )
        if feats.device == self.feats.device:
            new_tensor._cache = self._cache
        return new_tensor

    def to_dense(self, max_length=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a VarLenTensor to a dense representation without for-loop.

        Args:
            max_length: Length ``L`` of the dense sequence axis. Defaults to the
                longest sequence. Sequences longer than ``max_length`` are
                truncated to their first ``max_length`` rows.

        Returns:
            dense (torch.Tensor): (N, L, C) dense tensor
            mask (torch.BoolTensor): (N, L) mask indicating valid positions

        Positions where ``mask`` is False are padding; their content is a
        repeated feature row and must not be relied upon.
        """
        N = len(self)
        L = max_length or self.seqlen.max().item()
        idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L)
        seqlen = self.seqlen.unsqueeze(1)
        mask = idx < seqlen
        # Row of feats for position (n, l): start of sequence n plus l. Padding
        # positions are clamped to the last row of the sequence. Using the real
        # start offsets (rather than a running count of valid positions) keeps
        # rows aligned when max_length truncates a sequence.
        mapping = self.cum_seqlen[:-1].unsqueeze(1) + torch.minimum(idx, seqlen - 1)
        dense = self.feats[mapping]
        return dense, mask

    def __neg__(self) -> 'VarLenTensor':
        return self.replace(-self.feats)

    def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor':
        """
        Apply the binary ``op`` between ``self.feats`` and ``other``.

        A plain tensor that can be broadcast to ``self.shape`` is treated as
        per-sequence data and expanded to one row per feature row first.
        Anything else is handed to ``op`` unchanged and relies on regular
        broadcasting against ``feats``.
        """
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)[self.batch_boardcast_map]
            except Exception:
                # Best effort: not per-sequence data, let ``op`` broadcast it.
                pass
        if isinstance(other, VarLenTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        return self.replace(new_feats)

    def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.add)

    def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.add)

    def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.sub)

    def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        # Operator form so that a Python scalar on the left is handled by the
        # tensor's reflected method.
        return self.__elemwise__(other, lambda x, y: y - x)

    def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.mul)

    def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.mul)

    def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        # Operator form, see __rsub__.
        return self.__elemwise__(other, lambda x, y: y / x)

    def __getitem__(self, idx):
        """
        Select sequences by int, slice, list of ints, or a 1-D bool/int tensor.

        Returns a new VarLenTensor holding the selected sequences in the
        requested order. An empty selection yields an empty VarLenTensor.
        """
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, list):
            assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1).tolist()
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
                idx = idx.tolist()
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        if not idx:
            # torch.cat rejects an empty list, so build the empty result directly.
            return VarLenTensor(feats=self.feats[:0], layout=[])

        new_feats = []
        new_layout = []
        start = 0
        for old_idx in idx:
            piece = self.feats[self.layout[old_idx]]
            new_feats.append(piece)
            new_layout.append(slice(start, start + len(piece)))
            start += len(piece)
        new_feats = torch.cat(new_feats, dim=0).contiguous()
        return VarLenTensor(feats=new_feats, layout=new_layout)

    def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        """
        Reduce the features with ``op`` ('mean', 'sum' or 'prod').

        If ``dim`` is None or contains dimension 0, the plain tensor reduction
        of ``feats`` over ``dim`` is returned. Otherwise the given feature
        dims are reduced row by row and the rows of every sequence are then
        combined with ``torch.segment_reduce``, giving one entry per sequence.

        Raises:
            ValueError: If ``op`` is not supported.
        """
        if op not in ('mean', 'sum', 'prod'):
            raise ValueError(f"Unsupported reduce operation: {op}")

        if isinstance(dim, int):
            dim = (dim,)
        if dim is not None:
            # Resolve negative dims so that e.g. -feats.dim() is recognised as 0.
            ndim = self.feats.dim()
            dim = tuple(d + ndim if d < 0 else d for d in dim)

        if op == 'mean':
            red = self.feats.mean(dim=dim, keepdim=keepdim)
        elif op == 'sum':
            red = self.feats.sum(dim=dim, keepdim=keepdim)
        else:
            # Tensor.prod reduces a single dimension per call, so reduce one
            # dim at a time, highest first so remaining indices stay valid.
            dims = range(self.feats.dim()) if dim is None else dim
            red = self.feats
            for d in sorted(set(dims), reverse=True):
                red = red.prod(dim=d, keepdim=keepdim)

        if dim is None or 0 in dim:
            return red

        red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen)
        return red

    def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        return self.reduce(op='mean', dim=dim, keepdim=keepdim)

    def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        return self.reduce(op='sum', dim=dim, keepdim=keepdim)

    def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        return self.reduce(op='prod', dim=dim, keepdim=keepdim)

    def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        """
        Standard deviation computed as ``sqrt(E[x^2] - E[x]^2)`` via :meth:`mean`.

        NOTE: both means are always taken with ``keepdim=True``; the
        ``keepdim`` argument is currently not used.
        """
        mean = self.mean(dim=dim, keepdim=True)
        mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True)
        # Rounding can push the variance slightly below zero; clamp so that
        # sqrt does not produce NaN.
        std = (mean2 - mean ** 2).clamp_min(0).sqrt()
        return std

    def __repr__(self) -> str:
        return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
|
|
|
|
|
def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor:
    """
    Concatenate a list of varlen tensors.

    With ``dim == 0`` the sequences of all inputs are appended one after
    another and a fresh consecutive layout is built. For any other ``dim`` the
    features are concatenated along that dimension and the layout of the
    first input is reused.

    Args:
        inputs (List[VarLenTensor]): List of varlen tensors to concatenate.
        dim (int): Dimension to concatenate along.
    """
    if dim != 0:
        merged = torch.cat([item.feats for item in inputs], dim=dim)
        return inputs[0].replace(merged)

    new_feats = torch.cat([item.feats for item in inputs], dim=0)
    new_layout = []
    cursor = 0
    for item in inputs:
        for segment in item.layout:
            end = cursor + segment.stop - segment.start
            new_layout.append(slice(cursor, end))
            cursor = end
    return VarLenTensor(feats=new_feats, layout=new_layout)
|
|
|
|
|
def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]:
    """
    Unbind a varlen tensor along a dimension.

    ``dim == 0`` yields one varlen tensor per sequence; any other ``dim``
    unbinds the features along that dimension and keeps the layout.

    Args:
        input (VarLenTensor): Varlen tensor to unbind.
        dim (int): Dimension to unbind.
    """
    if dim != 0:
        return [input.replace(piece) for piece in input.feats.unbind(dim)]
    return [input[i] for i in range(len(input))]
|
|
|
|
|
class SparseTensor(VarLenTensor):
    """
    Sparse tensor with support for both torchsparse and spconv backends.

    When ``config.CONV`` selects neither backend, the data is kept in a plain
    dict with the keys ``'feats'`` and ``'coords'``.

    Parameters:
    - feats (torch.Tensor): Features of the sparse tensor.
    - coords (torch.Tensor): Coordinates of the sparse tensor. Column 0 is the
      batch index, the remaining columns are spatial coordinates.
    - shape (torch.Size): Shape of the sparse tensor.
    - layout (List[slice]): Layout of the sparse tensor for each batch
    - data (SparseTensorData): Sparse tensor data used for convolution

    NOTE:
    - Data corresponding to a same batch should be contiguous.
    - Coords should be in [0, 1023]
    """
    # Backend container class, resolved on first construction.
    SparseTensorData = None

    @overload
    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...

    @overload
    def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...

    def __init__(self, *args, **kwargs):
        """
        Build either from ``(feats, coords[, shape])`` or from an existing
        backend ``data`` object (``(data[, shape])``).

        Recognised extra keyword arguments: ``scale`` and ``spatial_cache``.
        """
        if self.SparseTensorData is None:
            import importlib
            if config.CONV == 'torchsparse':
                self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif config.CONV == 'spconv':
                self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # Pick the call form: 0 = (feats, coords, shape), 1 = (data, shape).
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            feats, coords, shape = args + (None,) * (3 - len(args))
            feats = kwargs.pop('feats', feats)
            coords = kwargs.pop('coords', coords)
            shape = kwargs.pop('shape', shape)

            # NOTE(review): the remaining kwargs (including 'scale' and
            # 'spatial_cache' when given) are forwarded to the backend
            # constructor below; confirm the backend accepts them.
            if config.CONV == 'torchsparse':
                self.data = self.SparseTensorData(feats, coords, **kwargs)
            elif config.CONV == 'spconv':
                spatial_shape = list(coords.max(0)[0] + 1)
                self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
                # The backend is built with flattened features; keep the
                # original (possibly multi-dim) features on the data object.
                self.data._features = feats
            else:
                self.data = {
                    'feats': feats,
                    'coords': coords,
                }
        elif method_id == 1:
            data, shape = args + (None,) * (2 - len(args))
            data = kwargs.pop('data', data)
            shape = kwargs.pop('shape', shape)

            self.data = data

        self._shape = shape
        self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)))
        self._spatial_cache = kwargs.get('spatial_cache', {})

        if config.DEBUG:
            try:
                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
            except Exception:
                print('Debugging information:')
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise

    @staticmethod
    def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor':
        """
        Create a SparseTensor from a list of tensors.

        Column 0 of every coords tensor is overwritten with the position of
        that tensor in the list (its batch index).
        """
        feats = torch.cat(feats_list, dim=0)
        coords = []
        for i, coord in enumerate(coords_list):
            coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1)
            coords.append(coord)
        coords = torch.cat(coords, dim=0)
        return SparseTensor(feats, coords)

    def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Convert a SparseTensor to list of tensors.

        Returns one features tensor and one coords tensor per batch element.
        """
        layout = self.layout
        feats_list = [self.feats[s] for s in layout]
        coords_list = [self.coords[s] for s in layout]
        return feats_list, coords_list

    def __len__(self) -> int:
        return len(self.layout)

    def __cal_shape(self, feats, coords):
        # (batch size, *feature dims); batch size is the largest batch index + 1.
        shape = []
        shape.append(coords[:, 0].max().item() + 1)
        shape.extend([*feats.shape[1:]])
        return torch.Size(shape)

    def __cal_layout(self, coords, batch_size):
        # Rows of a batch element are assumed contiguous, so counting the rows
        # per batch index is enough to derive the slices.
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        # Two bulk transfers instead of two .item() calls per batch element.
        offset = torch.cumsum(seq_len, dim=0).tolist()
        seq_len = seq_len.tolist()
        return [slice(offset[i] - seq_len[i], offset[i]) for i in range(batch_size)]

    def __cal_spatial_shape(self, coords):
        # Largest coordinate + 1 along every spatial column.
        return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist())

    @property
    def shape(self) -> torch.Size:
        if self._shape is None:
            self._shape = self.__cal_shape(self.feats, self.coords)
        return self._shape

    @property
    def layout(self) -> List[slice]:
        layout = self.get_spatial_cache('layout')
        if layout is None:
            layout = self.__cal_layout(self.coords, self.shape[0])
            self.register_spatial_cache('layout', layout)
        return layout

    @property
    def spatial_shape(self) -> torch.Size:
        spatial_shape = self.get_spatial_cache('shape')
        if spatial_shape is None:
            spatial_shape = self.__cal_spatial_shape(self.coords)
            self.register_spatial_cache('shape', spatial_shape)
        return spatial_shape

    @property
    def feats(self) -> torch.Tensor:
        if config.CONV == 'torchsparse':
            return self.data.F
        elif config.CONV == 'spconv':
            return self.data.features
        else:
            return self.data['feats']

    @feats.setter
    def feats(self, value: torch.Tensor):
        if config.CONV == 'torchsparse':
            self.data.F = value
        elif config.CONV == 'spconv':
            self.data.features = value
        else:
            self.data['feats'] = value

    @property
    def coords(self) -> torch.Tensor:
        if config.CONV == 'torchsparse':
            return self.data.C
        elif config.CONV == 'spconv':
            return self.data.indices
        else:
            return self.data['coords']

    @coords.setter
    def coords(self, value: torch.Tensor):
        if config.CONV == 'torchsparse':
            self.data.C = value
        elif config.CONV == 'spconv':
            self.data.indices = value
        else:
            self.data['coords'] = value

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @property
    def seqlen(self) -> torch.LongTensor:
        """Number of rows of every batch element (kept in the spatial cache)."""
        seqlen = self.get_spatial_cache('seqlen')
        if seqlen is None:
            seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
            self.register_spatial_cache('seqlen', seqlen)
        return seqlen

    @property
    def cum_seqlen(self) -> torch.LongTensor:
        """Cumulative row counts with a leading 0 (kept in the spatial cache)."""
        cum_seqlen = self.get_spatial_cache('cum_seqlen')
        if cum_seqlen is None:
            cum_seqlen = torch.cat([
                torch.tensor([0], dtype=torch.long, device=self.device),
                self.seqlen.cumsum(dim=0)
            ], dim=0)
            self.register_spatial_cache('cum_seqlen', cum_seqlen)
        return cum_seqlen

    @property
    def batch_boardcast_map(self) -> torch.LongTensor:
        """
        Get the broadcast map for the varlen tensor.

        Entry ``j`` is the batch index that row ``j`` belongs to.
        """
        batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map')
        if batch_boardcast_map is None:
            batch_boardcast_map = torch.repeat_interleave(
                torch.arange(len(self.layout), device=self.device),
                self.seqlen,
            )
            self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map)
        return batch_boardcast_map

    @overload
    def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...

    def to(self, *args, **kwargs) -> 'SparseTensor':
        """
        Move and/or cast the tensor. The dtype is applied to the features
        only; the coords are only moved to the target device.
        """
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            # A single positional argument is either a dtype or a device.
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']
        non_blocking = kwargs.get('non_blocking', False)
        copy = kwargs.get('copy', False)

        new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
        new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'SparseTensor':
        new_feats = self.feats.cpu()
        new_coords = self.coords.cpu()
        return self.replace(new_feats, new_coords)

    def cuda(self) -> 'SparseTensor':
        new_feats = self.feats.cuda()
        new_coords = self.coords.cuda()
        return self.replace(new_feats, new_coords)

    def half(self) -> 'SparseTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'SparseTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'SparseTensor':
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def reshape(self, *shape) -> 'SparseTensor':
        # Only the feature dims are reshaped; the row dimension is kept.
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['SparseTensor']:
        return sparse_unbind(self, dim)

    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        """
        Return a new SparseTensor with new features (and optionally new
        coords) that keeps the scale and shares the spatial cache of ``self``.
        """
        if config.CONV == 'torchsparse':
            new_data = self.SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            new_data._caches = self.data._caches
        elif config.CONV == 'spconv':
            new_data = self.SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            # Built from the old (flattened) features; swap in the new ones and
            # carry over the remaining backend state.
            new_data._features = feats
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        else:
            new_data = {
                'feats': feats,
                'coords': self.data['coords'] if coords is None else coords,
            }
        new_tensor = SparseTensor(
            new_data,
            # Batch size is kept; feature dims follow the new features.
            shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
            scale=self._scale,
            spatial_cache=self._spatial_cache
        )
        return new_tensor

    def to_dense(self) -> torch.Tensor:
        """
        Convert to a dense tensor.

        With a backend the backend's ``dense()`` is used. Otherwise a zero
        tensor of shape ``(*self.shape, *self.spatial_shape)`` is filled with
        the features at their coordinates.
        """
        if config.CONV == 'torchsparse':
            return self.data.dense()
        elif config.CONV == 'spconv':
            return self.data.dense()
        else:
            spatial_shape = self.spatial_shape
            ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
            # Index tensors are cast to long for advanced indexing.
            coords = self.coords.long()
            # One full slice per feature dim, between the batch index and the
            # spatial indices. ``unbind`` returns a tuple, so build the index
            # as a tuple (list + tuple concatenation raises TypeError).
            feat_slices = (slice(None),) * (len(self.shape) - 1)
            idx = (coords[:, 0], *feat_slices, *coords[:, 1:].unbind(1))
            ret[idx] = self.feats
            return ret

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        """
        Create a SparseTensor that covers every integer coordinate of an
        axis-aligned box with a constant feature value.

        Args:
            aabb: Box bounds read as ``(x0, y0, z0, x1, y1, z1)``, both ends inclusive.
            dim: ``(N, C)`` - batch size and number of feature channels.
            value: Fill value of the features.
            dtype: Feature dtype.
            device: Device of the created tensors.
        """
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        coords = torch.cat([
            # Batch index column: every index repeated once per grid point.
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
        """
        Merge the spatial caches of ``self`` and ``other`` key by key.

        For a key present in both, the entry of ``self`` is updated in place
        with the entry of ``other`` and reused in the result.
        """
        new_cache = {}
        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
            if k in self._spatial_cache:
                new_cache[k] = self._spatial_cache[k]
            if k in other._spatial_cache:
                if k not in new_cache:
                    new_cache[k] = other._spatial_cache[k]
                else:
                    new_cache[k].update(other._spatial_cache[k])
        return new_cache

    def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
        """
        Apply the binary ``op`` between ``self.feats`` and ``other``.

        A plain tensor that can be broadcast to ``self.shape`` is treated as
        per-batch data and expanded to one row per feature row first. When
        ``other`` is a SparseTensor, the result gets the merged spatial cache
        of both operands.
        """
        # Remember the sparse operand before ``other`` is replaced by its
        # features; checking afterwards would never match.
        other_sparse = other if isinstance(other, SparseTensor) else None
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)[self.batch_boardcast_map]
            except Exception:
                # Best effort: not per-batch data, let ``op`` broadcast it.
                pass
        if isinstance(other, VarLenTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        if other_sparse is not None:
            new_tensor._spatial_cache = self.__merge_sparse_cache(other_sparse)
        return new_tensor

    def __getitem__(self, idx):
        """
        Select batch elements by int, slice, list of ints, or a 1-D bool/int
        tensor. The batch index column of the selected coords is renumbered to
        the position in the selection.
        """
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, list):
            assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1).tolist()
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
                idx = idx.tolist()
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        layout = self.layout
        new_coords = []
        new_feats = []
        new_layout = []
        new_shape = torch.Size([len(idx)] + list(self.shape[1:]))
        start = 0
        for new_idx, old_idx in enumerate(idx):
            rows = layout[old_idx]
            # Clone so that renumbering does not touch the coords of self.
            picked = self.coords[rows].clone()
            picked[:, 0] = new_idx
            new_coords.append(picked)
            new_feats.append(self.feats[rows])
            new_layout.append(slice(start, start + len(picked)))
            start += len(picked)
        new_coords = torch.cat(new_coords, dim=0).contiguous()
        new_feats = torch.cat(new_feats, dim=0).contiguous()
        new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape)
        new_tensor.register_spatial_cache('layout', new_layout)
        return new_tensor

    def clear_spatial_cache(self) -> None:
        """
        Clear all spatial caches.
        """
        self._spatial_cache = {}

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can be anything you want to cache.
        The registry and retrieval of the cache is based on current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Get a spatial cache.

        Returns the whole cache of the current scale when ``key`` is None,
        otherwise the entry for ``key`` (None if it does not exist).
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)

    def __repr__(self) -> str:
        return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
|
|
|
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    With ``dim == 0`` the inputs are stacked along the batch dimension: the
    batch index column of each input's coords is shifted by the total batch
    size of the inputs before it. For any other ``dim`` the features are
    concatenated along that dimension and the coords of the first input are kept.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
        dim (int): Dimension to concatenate along.
    """
    if dim != 0:
        merged = torch.cat([item.feats for item in inputs], dim=dim)
        return inputs[0].replace(merged)

    batch_offset = 0
    shifted = []
    for item in inputs:
        item_coords = item.coords.clone()
        item_coords[:, 0] += batch_offset
        shifted.append(item_coords)
        batch_offset += item.shape[0]
    coords = torch.cat(shifted, dim=0)
    feats = torch.cat([item.feats for item in inputs], dim=0)
    return SparseTensor(
        coords=coords,
        feats=feats,
    )
|
|
|
|
|
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    ``dim == 0`` yields one sparse tensor per batch element; any other ``dim``
    unbinds the features along that dimension and keeps the coords.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): Dimension to unbind.
    """
    if dim != 0:
        return [input.replace(piece) for piece in input.feats.unbind(dim)]
    return [input[i] for i in range(input.shape[0])]
|
|
|