| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from einops import einsum, rearrange |
| from jaxtyping import Float |
| from plyfile import PlyData, PlyElement |
| from scipy.spatial.transform import Rotation as R |
| from torch import Tensor |
|
|
|
|
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the PLY vertex property names, in the order they are written.

    Args:
        num_rest: How many ``f_rest_<k>`` properties to include (0 omits them;
            negative values also yield none).

    Returns:
        ``x, y, z, nx, ny, nz``, then ``f_dc_0..2``, then ``f_rest_0..num_rest-1``,
        then ``opacity``, ``scale_0..2`` and ``rot_0..3``.
    """
    names = ["x", "y", "z", "nx", "ny", "nz"]
    names.extend(f"f_dc_{k}" for k in range(3))
    names.extend(f"f_rest_{k}" for k in range(num_rest))
    names.append("opacity")
    names.extend(f"scale_{k}" for k in range(3))
    names.extend(f"rot_{k}" for k in range(4))
    return names
|
|
|
|
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
):
    """Write a set of 3D Gaussians to ``path`` as a PLY ``"vertex"`` element.

    Properties are written as float32 in the order returned by
    ``construct_list_of_attributes``: position, zero normals, DC spherical-harmonic
    coefficients, optionally the remaining SH coefficients, opacity, log-scales and
    rotation quaternions stored scalar-first as ``(w, x, y, z)``.

    Args:
        means: Gaussian centres, shape (gaussian, 3).
        scales: Per-axis scales, shape (gaussian, 3); written as ``log(scales)``.
        rotations: Quaternions, shape (gaussian, 4), interpreted by SciPy in
            scalar-last ``(x, y, z, w)`` order.
        harmonics: SH coefficients, shape (gaussian, 3, d_sh); index 0 of the last
            axis is the DC band.
        opacities: Per-Gaussian opacity, shape (gaussian,); written unchanged.
        path: Output file; missing parent directories are created.
        shift_and_scale: If True, translate so the per-axis median centre is at the
            origin, then divide means and scales by the largest per-axis 95th
            percentile of the absolute centred coordinates.
        save_sh_dc_only: If True, omit the ``f_rest_*`` (higher-order SH) properties.

    Raises:
        ValueError: If the inputs do not produce one value per PLY property.
    """

    def to_numpy(tensor: Tensor) -> np.ndarray:
        # NumPy has no bfloat16; widening to float32 is exact and matches the
        # "f4" storage type used below.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.detach().cpu().numpy()

    if shift_and_scale:
        # Centre the scene on the per-axis median Gaussian position.
        means = means - means.median(dim=0).values
        # Rescale so that on every axis at least 95% of the centred coordinates
        # fall within [-1, 1].
        scale_factor = means.abs().quantile(0.95, dim=0).max()
        means = means / scale_factor
        scales = scales / scale_factor

    num_gaussians = means.shape[0]

    # Round-trip through rotation matrices. SciPy returns unit quaternions for the
    # same rotations (the sign may flip, which describes the same rotation).
    # SciPy is scalar-last (x, y, z, w); reorder to scalar-first (w, x, y, z).
    quats = R.from_quat(to_numpy(rotations)).as_matrix()
    quats = R.from_matrix(quats).as_quat()
    quats = quats[:, [3, 0, 1, 2]]

    # DC band is SH coefficient 0 for each of the 3 channels.
    f_dc = harmonics[..., 0]

    columns = [
        to_numpy(means),
        # The normal properties (nx, ny, nz) are written as zeros.
        np.zeros((num_gaussians, 3), dtype=np.float32),
        to_numpy(f_dc),
    ]
    num_rest = 0
    if not save_sh_dc_only:
        # Higher-order bands, flattened channel-major: all coefficients of
        # channel 0, then channel 1, then channel 2.
        f_rest = harmonics[..., 1:].flatten(start_dim=1)
        num_rest = f_rest.shape[1]
        columns.append(to_numpy(f_rest))
    columns += [
        # NOTE(review): opacities are stored as given. Common 3DGS viewers apply a
        # sigmoid to this property, so callers presumably pass logits; confirm.
        to_numpy(opacities[..., None]),
        to_numpy(scales.log()),
        quats,
    ]
    table = np.concatenate(columns, axis=1)

    names = construct_list_of_attributes(num_rest)
    if table.shape[1] != len(names):
        raise ValueError(
            f"expected {len(names)} values per Gaussian, got {table.shape[1]}; "
            "check the shapes of the input tensors"
        )

    elements = np.empty(num_gaussians, dtype=[(name, "f4") for name in names])
    # Fill one field at a time (vectorised) instead of building a Python tuple
    # per Gaussian; values are cast to float32 on assignment.
    for name, column in zip(names, table.T):
        elements[name] = column

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|
|