Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from jaxtyping import Float | |
| from plyfile import PlyData, PlyElement | |
| from torch import Tensor | |
| from optgs.model.types import Gaussians | |
| from optgs.scene_trainer.gaussian_module import GaussiansModule | |
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered PLY vertex property names used for a Gaussian export.

    The order is: position (x, y, z), normal (nx, ny, nz), three ``f_dc_*``
    coefficients, ``num_rest`` ``f_rest_*`` coefficients, ``opacity``, three
    ``scale_*`` entries and four ``rot_*`` entries.

    Args:
        num_rest: Number of ``f_rest_*`` properties to emit (values <= 0 emit none).

    Returns:
        A fresh list of property names in file column order.
    """
    return [
        "x", "y", "z",
        "nx", "ny", "nz",
        *(f"f_dc_{channel}" for channel in range(3)),
        *(f"f_rest_{k}" for k in range(num_rest)),
        "opacity",
        *(f"scale_{axis}" for axis in range(3)),
        *(f"rot_{component}" for component in range(4)),
    ]
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, "gaussian"],
    path: Path,
):
    """Write Gaussians to a binary PLY file with one ``vertex`` element.

    The inputs are expected in *activated* form: ``scales`` are stored as
    ``log(scales)`` and ``opacities`` as ``logit(opacities)``. Normals are
    written as zeros. Every property is written as float32 ("f4").

    Args:
        means: Gaussian centres, [N, 3].
        scales: Activated (positive) scales, [N, 3]; ``log`` is applied here.
        rotations: Rotation quaternions, [N, 4], written in the given order
            as ``rot_0..rot_3``.
        harmonics: SH coefficients, [N, 3, d_sh]. Index 0 of the last axis is
            written as ``f_dc_*``. The rest is flattened channel-major into
            ``f_rest_*``: all coefficients of channel 0, then channel 1, then
            channel 2.
        opacities: Activated opacities in [0, 1], [N]; ``logit`` is applied here.
            Values of exactly 0 or 1 produce -inf/inf in the file.
        path: Output file path; a ``str`` is also accepted. Missing parent
            directories are created.

    Raises:
        ValueError: if the concatenated columns do not match the expected
            PLY property list (e.g. rotations not of width 4).
    """
    path = Path(path)  # callers may pass a plain str

    def _to_numpy(tensor: Tensor) -> np.ndarray:
        # Drop autograd history and move to host memory before NumPy conversion.
        return tensor.detach().cpu().numpy()

    means_np = _to_numpy(means)
    columns = (
        means_np,
        np.zeros_like(means_np),  # normals (unused, kept for 3DGS compatibility)
        _to_numpy(harmonics[..., 0]),
        _to_numpy(harmonics[..., 1:].flatten(start_dim=1)),
        _to_numpy(torch.logit(opacities[..., None])),  # inverse of sigmoid
        _to_numpy(scales.log()),  # inverse of exp
        _to_numpy(rotations),
    )
    attributes = np.concatenate(columns, axis=1)

    names = construct_list_of_attributes(3 * (harmonics.shape[-1] - 1))
    if attributes.shape[1] != len(names):
        raise ValueError(
            f"export_ply: got {attributes.shape[1]} attribute columns but the PLY "
            f"layout expects {len(names)}."
        )

    elements = np.empty(attributes.shape[0], dtype=[(name, "f4") for name in names])
    # Vectorized per-field copy; avoids building one Python tuple per Gaussian.
    for column_index, name in enumerate(names):
        elements[name] = attributes[:, column_index]

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
def save_gaussian_ply(
    gaussians: Gaussians | GaussiansModule,
    save_path,
    save_all_gaussians=True,  # no trim
):
    """Save Gaussians to a .ply file for visualization.

    In the written file, opacities and scales are in the pre-activation space,
    i.e. before the activation functions (sigmoid for opacity, exp for scales)
    are applied. ``export_ply`` performs that inversion, so this function hands
    it activated values.

    Args:
        gaussians: A ``GaussiansModule`` (no batch dimension) or a batched
            ``Gaussians`` with batch size 1.
        save_path: Destination path (``str`` or ``Path``).
        save_all_gaussians: Only ``True`` is supported.

    Raises:
        NotImplementedError: if ``save_all_gaussians`` is False.
        ValueError: if a ``Gaussians`` batch has more than one element, or if
            ``gaussians`` is of an unsupported type.
    """
    if not save_all_gaussians:
        raise NotImplementedError("Not implemented yet.")

    if isinstance(gaussians, GaussiansModule):
        # no batch dimension
        means = gaussians.means  # [H*W, 3]
        rotations = gaussians.rotations  # [H*W, 4] in xyzw
        scales = gaussians.scales  # [H*W, 3]
        opacities = gaussians.opacities  # [H*W]
        harmonics = gaussians.harmonics  # [H*W, 3, d_sh]
    elif isinstance(gaussians, Gaussians):
        # An explicit check instead of `assert`: asserts vanish under `python -O`,
        # which would silently save only the first batch element.
        if gaussians.means.shape[0] != 1:
            raise ValueError("Batch size > 1 not supported for saving ply.")
        means = gaussians.means[0]  # [H*W, 3]
        rotations = F.normalize(gaussians.rotations_unnorm[0], dim=-1)  # [H*W, 4] in xyzw
        scales = gaussians.scales[0]  # [H*W, 3]
        opacities = gaussians.opacities[0]  # [H*W]
        harmonics = gaussians.harmonics[0]  # [H*W, 3, d_sh]
        # export_ply expects activated values (post-exp scales, post-sigmoid
        # opacities) and applies the inverse activation itself. If this object
        # stores deactivated values, activate them first to avoid a double inversion.
        if not gaussians.stores_activated:
            scales = torch.exp(scales)
            opacities = torch.sigmoid(opacities)
    else:
        raise ValueError(f"Unknown type of gaussians: {type(gaussians)}")

    # Reorder quaternion components xyzw -> wxyz, the order written to the file.
    rotations = rotations[:, [3, 0, 1, 2]]  # [H*W, 4] in wxyz

    # export_ply inverts the opacity/scale activations (standard 3DGS viewer format).
    export_ply(
        means=means,
        scales=scales,
        rotations=rotations,
        harmonics=harmonics,  # [H*W, 3, d_sh]
        opacities=opacities,
        path=Path(save_path),  # export_ply uses Path API (path.parent)
    )
def _stack_vertex_properties(vertex, prefix: str, num_points: int) -> np.ndarray:
    """Stack every vertex property named ``<prefix><int>`` into a [num_points, K] array.

    Columns are ordered by the integer suffix, so ``f_rest_10`` follows
    ``f_rest_9``. Returns an empty [num_points, 0] array when nothing matches.
    """
    names = sorted(
        (prop.name for prop in vertex.properties if prop.name.startswith(prefix)),
        key=lambda name: int(name[len(prefix):]),
    )
    if not names:
        return np.zeros((num_points, 0), dtype=np.float32)
    return np.stack([np.asarray(vertex[name]) for name in names], axis=1)


def load_gaussians_ply(path, max_sh_degree=3) -> Gaussians:
    """Load Gaussians from a .ply file saved by export_ply().

    The file stores opacities as logits and scales as log-scales. This loader
    applies sigmoid / exp, so the returned object holds *post-activation*
    opacities and scales.

    The file may hold spherical harmonics up to any degree <= ``max_sh_degree``.
    Missing higher-degree coefficients are zero-filled. The ``f_rest_*``
    properties are read channel-major: all coefficients of channel 0, then
    channel 1, then channel 2 (the layout export_ply writes).

    Args:
        path: Path to the .ply file; the first PLY element is read.
        max_sh_degree: SH degree of the returned harmonics tensor,
            which has (max_sh_degree + 1) ** 2 coefficients per channel.

    Returns:
        Gaussians with a leading batch dimension of 1. Quaternions are converted
        from the file's wxyz order to xyzw and normalized.

    Raises:
        ValueError: if the number of ``f_rest_*`` properties does not correspond
            to an SH degree in [0, max_sh_degree].
    """
    plydata = PlyData.read(path)
    vertex = plydata.elements[0]

    xyz = np.stack([np.asarray(vertex[axis]) for axis in ("x", "y", "z")], axis=1)  # [P, 3]
    num_points = xyz.shape[0]
    raw_opacities = np.asarray(vertex["opacity"])  # [P], logits
    features_dc = np.stack([np.asarray(vertex[f"f_dc_{c}"]) for c in range(3)], axis=1)  # [P, 3]

    # Higher-order SH: accept any degree 0..max_sh_degree, zero-padding the rest.
    d_sh = (max_sh_degree + 1) ** 2
    features_extra = _stack_vertex_properties(vertex, "f_rest_", num_points)
    num_extra = features_extra.shape[1]
    valid_counts = {3 * ((degree + 1) ** 2 - 1) for degree in range(max_sh_degree + 1)}
    if num_extra not in valid_counts:
        raise ValueError(
            f"Mismatch in number of SH coefficients: found {num_extra} f_rest_* "
            f"properties, expected one of {sorted(valid_counts)} for "
            f"max_sh_degree={max_sh_degree}."
        )
    if num_extra == 0:
        # TODO: does this mean that features_dc probably encodes RGB which needs
        # to be converted to SH0? All higher-order features stay zero.
        print("Loaded PLY has no SH coefficients, only DC features.")

    scales = _stack_vertex_properties(vertex, "scale_", num_points)  # [P, 3], log-space
    # "rot_" (not "rot") so unrelated properties like "rotation" are not picked up.
    rots = _stack_vertex_properties(vertex, "rot_", num_points)  # [P, 4], wxyz

    # Create Gaussian object
    means = torch.tensor(xyz, dtype=torch.float32)  # [P, 3]
    opacities = torch.sigmoid(torch.tensor(raw_opacities, dtype=torch.float32))  # [P], post-activation

    harmonics = torch.zeros((num_points, 3, d_sh), dtype=torch.float32)  # [P, 3, d_sh]
    harmonics[:, :, 0] = torch.tensor(features_dc, dtype=torch.float32)
    if num_extra:
        rest_per_channel = num_extra // 3
        # Reshape (P, 3 * rest) -> (P, 3, rest); coefficients of lower degrees come first.
        harmonics[:, :, 1:1 + rest_per_channel] = torch.tensor(
            features_extra.reshape(num_points, 3, rest_per_channel), dtype=torch.float32
        )

    scales = torch.exp(torch.tensor(scales, dtype=torch.float32))  # post-activation
    quats = torch.tensor(rots, dtype=torch.float32)  # in wxyz
    quats = quats[:, [1, 2, 3, 0]]  # convert to xyzw
    quats = F.normalize(quats, dim=-1)  # match 3DGS-LM get_rotation which normalizes before rendering

    return Gaussians(
        means=means.unsqueeze(0),
        harmonics=harmonics.unsqueeze(0),
        opacities=opacities.unsqueeze(0),
        scales=scales.unsqueeze(0),
        rotations=quats.unsqueeze(0),
        rotations_unnorm=quats.unsqueeze(0),
    )