| | from pathlib import Path |
| | from jaxtyping import Float |
| | import numpy as np |
| | from scipy.spatial.transform import Rotation as R |
| | from plyfile import PlyData, PlyElement |
| | import torch |
| | from torch import Tensor |
| | from einops import rearrange, einsum |
| |
|
| |
|
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered vertex property names used for the PLY export.

    Layout, in order: position ``x, y, z``; normal ``nx, ny, nz``; three
    ``f_dc_*`` coefficients; ``num_rest`` ``f_rest_*`` coefficients;
    ``opacity``; three ``scale_*`` values; four ``rot_*`` values.
    A non-positive ``num_rest`` yields no ``f_rest_*`` entries.
    """
    position_and_normal = ["x", "y", "z", "nx", "ny", "nz"]
    dc_terms = [f"f_dc_{k}" for k in range(3)]
    rest_terms = [f"f_rest_{k}" for k in range(num_rest)]
    scale_terms = [f"scale_{k}" for k in range(3)]
    rot_terms = [f"rot_{k}" for k in range(4)]
    return [
        *position_and_normal,
        *dc_terms,
        *rest_terms,
        "opacity",
        *scale_terms,
        *rot_terms,
    ]
| |
|
| |
|
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, "gaussian"],
    path: Path,
):
    """Normalise a set of 3D Gaussians and write them to a PLY file.

    Normalisation steps:
      * The scene is shifted so the per-axis median of ``means`` is at the
        origin.
      * It is then divided by the largest per-axis 95th percentile of
        ``|means|``; a zero factor falls back to 1.
      * ``scales`` are divided by the same factor, multiplied by 4 and capped
        at 0.0075, then written as ``log(scale)``.

    The properties written per vertex are the ones returned by
    ``construct_list_of_attributes(0)``: position, zero normals, three DC
    coefficients, opacity, log-scales and a quaternion.

    Args:
        means: Gaussian centres, shape (N, 3).
        scales: Per-axis Gaussian scales, shape (N, 3); must be positive for
            a meaningful log.
        rotations: Quaternions, shape (N, 4). They are parsed with SciPy's
            ``Rotation.from_quat`` (scalar-last x, y, z, w) and written
            scalar-first as ``rot_0..rot_3`` = (w, x, y, z).
        harmonics: DC coefficients, shape (N, 3), or full coefficients of
            shape (N, 3, d_sh). For 3-D input only index 0 of the last axis
            is written.
        opacities: Shape (N,) or (N, 1); written unchanged.
        path: Destination file; missing parent directories are created.

    Raises:
        ValueError: if the assembled columns do not match the declared
            PLY properties.
    """
    path = Path(path)

    # Centre the scene on the per-axis median Gaussian position.
    means = means - means.median(dim=0).values

    # Rescale so that most Gaussians fall inside [-1, 1]. A zero factor (a
    # single Gaussian, or all centres identical) would turn every coordinate
    # into NaN, so use 1 instead in that case.
    scale_factor = means.abs().quantile(0.95, dim=0).max()
    scale_factor = torch.where(
        scale_factor > 0, scale_factor, torch.ones_like(scale_factor)
    )
    means = means / scale_factor

    # Shrink scales with the scene, enlarge them 4x and cap them. The lower
    # bound is the smallest positive normal float rather than 0, so the log()
    # below never writes -inf into the file.
    scales = torch.clamp(
        scales / scale_factor * 4.0,
        min=torch.finfo(scales.dtype).tiny,
        max=0.0075,
    )

    # Global re-orientation applied to both positions and Gaussian
    # orientations. It is currently the identity; edit it to change the
    # exported axes.
    world_rotation = torch.eye(3, dtype=torch.float32, device=means.device)
    means = einsum(world_rotation, means, "i j, ... j -> ... i")

    # NOTE(review): SciPy's from_quat reads quaternions as (x, y, z, w), and
    # the result is written as (w, x, y, z). If the incoming ``rotations``
    # are already scalar-first (w, x, y, z), the components end up
    # permuted. Confirm the convention used by the caller.
    rot_matrices = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    rot_matrices = world_rotation.detach().cpu().numpy() @ rot_matrices
    quats_xyzw = R.from_matrix(rot_matrices).as_quat()
    quats_wxyz = quats_xyzw[:, [3, 0, 1, 2]]

    # Only the DC band is exported (no f_rest_* properties are declared), so
    # full (N, 3, d_sh) input is reduced to its first coefficient.
    if harmonics.ndim == 3:
        harmonics = harmonics[..., 0]

    # NOTE(review): opacities are written unchanged while scales are written
    # as log(). If the incoming opacities are already sigmoid-activated and
    # the PLY consumer expects logits, an inverse sigmoid is needed here.
    # Confirm.
    columns = [
        means,
        torch.zeros_like(means),  # nx, ny, nz are always written as zeros
        harmonics,
        opacities.reshape(-1, 1),  # accept (N,) as annotated, or (N, 1)
        scales.log(),
    ]
    data = np.concatenate(
        [t.detach().cpu().contiguous().numpy() for t in columns] + [quats_wxyz],
        axis=1,
    )

    names = construct_list_of_attributes(0)
    if data.shape[1] != len(names):
        raise ValueError(
            f"expected {len(names)} attribute columns, got {data.shape[1]}"
        )

    # Fill the structured array one property at a time instead of building
    # a Python tuple per Gaussian; values are cast to float32 ("f4").
    elements = np.empty(data.shape[0], dtype=[(name, "f4") for name in names])
    for name, column in zip(names, data.T):
        elements[name] = column

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
| |
|
| |
|
def save_ply(outputs, path, num_gauss=3, pad=32):
    """Export the first batch element of a Gaussian prediction to PLY.

    Every map in ``outputs`` stacks ``num_gauss`` Gaussians per pixel along
    its leading ``(b v)`` axis. Each map has a border of ``pad`` pixels on
    every side. That border is cropped away, the per-pixel Gaussians of batch
    element 0 are flattened into one list, and the result is handed to
    ``export_ply``.

    Args:
        outputs: Mapping keyed by ``(name, 0, 0)`` tuples. Reads:
            ``gauss_means``, shape ``(b*v, c, H*W)``, of which the first
            three channels are used as xyz; and ``gauss_scaling``,
            ``gauss_rotation``, ``gauss_opacity`` and ``gauss_features_dc``,
            each shaped ``(b*v, c, H, W)``.
        path: Destination PLY path.
        num_gauss: Number of Gaussians per pixel (``v``).
        pad: Border width in pixels to crop from each side. The padded size
            ``H x W`` is read from ``gauss_scaling`` rather than hard-coded,
            so any resolution works.

    Raises:
        ValueError: if ``pad`` is negative or leaves no pixels to export.
    """

    def fetch(name):
        return outputs[(name, 0, 0)]

    padded_h, padded_w = fetch("gauss_scaling").shape[-2:]
    if pad < 0 or padded_h <= 2 * pad or padded_w <= 2 * pad:
        raise ValueError(
            f"pad={pad} leaves no pixels in a {padded_h}x{padded_w} map"
        )
    rows = slice(pad, padded_h - pad)
    cols = slice(pad, padded_w - pad)

    def crop(t):
        """Drop the padding border from a (..., H, W) map."""
        return t[..., rows, cols]

    def crop_flat(t):
        """Crop a map whose spatial axes are flattened to (b, c, H*W)."""
        t = rearrange(t, "b c (h w) -> b c h w", h=padded_h, w=padded_w)
        return rearrange(crop(t), "b c h w -> b c (h w)")

    def to_gaussian_list(t):
        """(b*v, c, h, w) -> (v*h*w, c) for batch element 0."""
        return rearrange(t, "(b v) c h w -> b (v h w) c", v=num_gauss)[0]

    means = rearrange(
        crop_flat(fetch("gauss_means")), "(b v) c n -> b (v n) c", v=num_gauss
    )[0, :, :3]
    scales = to_gaussian_list(crop(fetch("gauss_scaling")))
    rotations = to_gaussian_list(crop(fetch("gauss_rotation")))
    opacities = to_gaussian_list(crop(fetch("gauss_opacity")))
    harmonics = to_gaussian_list(crop(fetch("gauss_features_dc")))

    export_ply(
        means,
        scales,
        rotations,
        harmonics,
        opacities,
        path,
    )