Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from einops import einsum, rearrange | |
| from jaxtyping import Float | |
| from plyfile import PlyData, PlyElement | |
| from scipy.spatial.transform import Rotation as R | |
| from torch import Tensor | |
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered list of per-vertex attribute names.

    The order is: position (``x, y, z``), normal (``nx, ny, nz``), three
    ``f_dc_*`` entries, ``num_rest`` ``f_rest_*`` entries, ``opacity``,
    three ``scale_*`` entries and four ``rot_*`` entries.

    Args:
        num_rest: Number of ``f_rest_*`` names to include; ``0`` yields none.

    Returns:
        The attribute names in the order listed above.
    """
    position_and_normal = ["x", "y", "z", "nx", "ny", "nz"]
    sh_dc = [f"f_dc_{k}" for k in range(3)]
    sh_rest = [f"f_rest_{k}" for k in range(num_rest)]
    scale = [f"scale_{k}" for k in range(3)]
    rotation = [f"rot_{k}" for k in range(4)]
    return [*position_and_normal, *sh_dc, *sh_rest, "opacity", *scale, *rotation]
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
):
    """Write a set of Gaussians to ``path`` as a PLY file with one "vertex" element.

    Per-vertex fields are float32 and follow the order produced by
    ``construct_list_of_attributes``: position, (all-zero) normal, SH DC band,
    optional higher-order SH coefficients, opacity, log-scale and rotation.

    Args:
        means: Gaussian centres.
        scales: Per-axis Gaussian scales; their natural log is what gets stored.
        rotations: Quaternions, interpreted in x, y, z, w order on input and
            written in w, x, y, z order.
        harmonics: Spherical-harmonic coefficients; index 0 of the last axis
            is the DC band.
        opacities: Per-Gaussian opacity, stored exactly as given.
            NOTE(review): many 3DGS viewers expect logit-space opacity here;
            confirm what the caller passes.
        path: Output file; missing parent directories are created. A ``str``
            is accepted as well.
        shift_and_scale: If true, translate the scene so the per-axis median
            centre is at the origin and rescale so the largest per-axis 95th
            percentile of ``|means|`` becomes 1.
        save_sh_dc_only: If true, only the DC band is written and the
            higher-order coefficients are never copied to host memory.

    Raises:
        ValueError: If the assembled columns do not match the attribute list.
    """
    path = Path(path)

    def to_numpy(tensor):
        return tensor.detach().cpu().contiguous().numpy()

    if shift_and_scale:
        # Shift the scene so that the median Gaussian is at the origin.
        means = means - means.median(dim=0).values

        # Rescale the scene so that most Gaussians are within range [-1, 1].
        scale_factor = means.abs().quantile(0.95, dim=0).max()
        # A degenerate scene (e.g. a single Gaussian) has zero extent; dividing
        # by it would fill the file with NaN/inf, so leave the scale untouched.
        if scale_factor > 0:
            means = means / scale_factor
            scales = scales / scale_factor

    # Round-trip the quaternions through rotation matrices (this normalises
    # them), then reorder scipy's x, y, z, w layout to w, x, y, z for the file.
    quats_xyzw = R.from_matrix(
        R.from_quat(to_numpy(rotations)).as_matrix()
    ).as_quat()
    quats_wxyz = quats_xyzw[:, [3, 0, 1, 2]]

    positions = to_numpy(means)
    columns = [
        positions,
        np.zeros_like(positions),  # normals are not modelled; written as zeros
        to_numpy(harmonics[..., 0]),
    ]

    # The higher-order SH bands dominate memory, so only flatten and copy them
    # to the host when they are actually going to be written.
    num_rest = 0
    if not save_sh_dc_only:
        f_rest = harmonics[..., 1:].flatten(start_dim=1)
        num_rest = f_rest.shape[1]
        columns.append(to_numpy(f_rest))

    columns.extend(
        [
            to_numpy(opacities[..., None]),
            to_numpy(scales.log()),
            quats_wxyz,
        ]
    )

    # Cast to float32 before concatenating: the fields are "f4" anyway, and this
    # keeps scipy's float64 quaternions from upcasting the whole table.
    table = np.concatenate(
        [column.astype(np.float32, copy=False) for column in columns], axis=1
    )

    names = construct_list_of_attributes(num_rest)
    if table.shape[1] != len(names):
        raise ValueError(
            f"expected {len(names)} attribute columns, got {table.shape[1]}"
        )

    elements = np.empty(table.shape[0], dtype=[(name, "f4") for name in names])
    # Fill field by field (vectorised) instead of building one tuple per Gaussian.
    for name, column in zip(names, table.T):
        elements[name] = column

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)