| from typing import Dict |
| import torch |
|
|
|
|
| class Voxel: |
| def __init__( |
| self, |
| origin: list, |
| voxel_size: float, |
| coords: torch.Tensor = None, |
| attrs: torch.Tensor = None, |
| layout: Dict = {}, |
| device: torch.device = 'cuda' |
| ): |
| self.origin = torch.tensor(origin, dtype=torch.float32, device=device) |
| self.voxel_size = voxel_size |
| self.coords = coords |
| self.attrs = attrs |
| self.layout = layout |
| self.device = device |
| |
| @property |
| def position(self): |
| return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] |
| |
| def split_attrs(self): |
| return { |
| k: self.attrs[:, self.layout[k]] |
| for k in self.layout |
| } |
| |
| def save(self, path): |
| |
| if 'o_voxel' not in globals(): |
| import o_voxel |
| o_voxel.io.write( |
| path, |
| self.coords, |
| self.split_attrs(), |
| ) |
| |
| def load(self, path): |
| |
| if 'o_voxel' not in globals(): |
| import o_voxel |
| coord, attrs = o_voxel.io.read(path) |
| self.coords = coord.int().to(self.device) |
| self.attrs = torch.cat([attrs[k] for k in attrs], dim=1).to(self.device) |
| |
| start = 0 |
| self.layout = {} |
| for k in attrs: |
| self.layout[k] = slice(start, start + attrs[k].shape[1]) |
| start += attrs[k].shape[1] |
|
|