# Learn2Splat / optgs/model/ply_export.py
# Last commit: 78d2329 (verified) by SteEsp — "Add Docker-based Learn2Splat demo (viser GUI)"
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from jaxtyping import Float
from plyfile import PlyData, PlyElement
from torch import Tensor
from optgs.model.types import Gaussians
from optgs.scene_trainer.gaussian_module import GaussiansModule
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered PLY vertex property names for a Gaussian point cloud.

    The order is: position (``x, y, z``), normals (``nx, ny, nz``), three DC
    SH coefficients (``f_dc_*``), ``num_rest`` higher-order SH coefficients
    (``f_rest_*``), ``opacity``, three scales (``scale_*``) and four rotation
    components (``rot_*``).

    Args:
        num_rest: Number of ``f_rest_*`` properties to emit (non-positive -> none).

    Returns:
        The property names, in the column order used by ``export_ply``.
    """
    names = ["x", "y", "z", "nx", "ny", "nz"]
    names += [f"f_dc_{k}" for k in range(3)]
    names += [f"f_rest_{k}" for k in range(num_rest)]
    names.append("opacity")
    names += [f"scale_{k}" for k in range(3)]
    names += [f"rot_{k}" for k in range(4)]
    return names
def export_ply(
    # extrinsics: Float[Tensor, "4 4"],
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, "gaussian"],
    path: Path,
    # align_to_view: bool = False,  # whether to align world space to the view space (camera space) of the extrinsics
):
    """Write Gaussians to a ``.ply`` file with one ``vertex`` element.

    Inputs are expected in *activated* space (positive scales, opacities in
    (0, 1)). This function converts them to pre-activation space before writing:
    scales are stored as ``log(scales)`` and opacities as ``logit(opacities)``.

    Args:
        means: Gaussian centres, ``[N, 3]``.
        scales: Positive scales, ``[N, 3]``; written as ``log(scales)``.
        rotations: Quaternions, ``[N, 4]``, written unchanged to ``rot_0..rot_3``
            (``save_gaussian_ply`` passes them in wxyz order).
        harmonics: SH coefficients, ``[N, 3, d_sh]``. Coefficient 0 goes to
            ``f_dc_0..2``; the remaining coefficients are flattened channel-major
            (all of channel 0, then channel 1, then channel 2) into ``f_rest_*``.
        opacities: Opacities, ``[N]``; written as ``logit(opacities)``.
        path: Output file (``Path`` or ``str``). Parent directories are created.

    Raises:
        ValueError: if the assembled columns do not match the PLY property list
            (inconsistent input shapes).

    Note:
        A scale of exactly 0 or an opacity of exactly 0/1 is written as +/-inf,
        because ``log``/``logit`` are applied without clamping.
    """
    # Accept plain strings as well; callers may forward an untyped path.
    path = Path(path)

    num_rest = 3 * (harmonics.shape[-1] - 1)
    attribute_names = construct_list_of_attributes(num_rest)

    harmonics = harmonics.detach().cpu()
    means_np = means.detach().cpu().numpy()
    columns = (
        means_np,  # x, y, z
        np.zeros_like(means_np),  # nx, ny, nz (written as zeros)
        harmonics[..., 0].numpy(),  # f_dc_0..2
        harmonics[..., 1:].flatten(start_dim=1).numpy(),  # f_rest_*, channel-major
        torch.logit(opacities.detach()[..., None]).cpu().numpy(),  # opacity (logit)
        scales.detach().log().cpu().numpy(),  # scale_* (log)
        rotations.detach().cpu().numpy(),  # rot_0..3
    )
    attributes = np.concatenate(columns, axis=1)
    if attributes.shape[1] != len(attribute_names):
        raise ValueError(
            f"export_ply: assembled {attributes.shape[1]} columns but expected "
            f"{len(attribute_names)} PLY properties; check the input tensor shapes."
        )

    elements = np.empty(attributes.shape[0], dtype=[(name, "f4") for name in attribute_names])
    # Column-wise assignment is vectorised (casts to float32 per field) and
    # avoids converting every row into a Python tuple.
    for index, name in enumerate(attribute_names):
        elements[name] = attributes[:, index]

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
def save_gaussian_ply(
    gaussians: Gaussians | GaussiansModule,
    save_path,
    save_all_gaussians=True,  # no trim
):
    """Save Gaussians to a ``.ply`` file for visualization.

    The saved file stores opacities and scales in pre-activation space, i.e.
    before applying the activation functions (sigmoid for opacity, exp for
    scales). ``export_ply`` performs that inverse activation, so this function
    hands it activated values.

    Args:
        gaussians: A ``GaussiansModule`` (no batch dimension) or a ``Gaussians``
            object with batch size 1.
        save_path: Destination file (``str`` or ``Path``).
        save_all_gaussians: Only ``True`` (no trimming) is supported.

    Raises:
        NotImplementedError: if ``save_all_gaussians`` is False.
        ValueError: if a ``Gaussians`` batch size is not 1, or the type is unknown.
    """
    if not save_all_gaussians:
        raise NotImplementedError("Not implemented yet.")

    if isinstance(gaussians, GaussiansModule):
        # No batch dimension.
        # NOTE(review): this branch assumes the module exposes activated
        # scales/opacities and normalised xyzw rotations — confirm against GaussiansModule.
        means = gaussians.means  # [H*W, 3]
        rotations = gaussians.rotations  # [H*W, 4] in xyzw
        scales = gaussians.scales  # [H*W, 3]
        opacities = gaussians.opacities  # [H*W]
        harmonics = gaussians.harmonics  # [H*W, 3, d_sh]
    elif isinstance(gaussians, Gaussians):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        batch_size = gaussians.means.shape[0]
        if batch_size != 1:
            raise ValueError(f"Only batch size 1 is supported for saving ply (got {batch_size}).")
        means = gaussians.means[0]  # [H*W, 3]
        rotations = F.normalize(gaussians.rotations_unnorm[0], dim=-1)  # [H*W, 4] in xyzw
        scales = gaussians.scales[0]  # [H*W, 3]
        opacities = gaussians.opacities[0]  # [H*W]
        harmonics = gaussians.harmonics[0]  # [H*W, 3, d_sh]
        # export_ply expects activated values (post-exp scales, post-sigmoid
        # opacities) and applies the inverse activation itself. Activate
        # deactivated values first to avoid a double inverse activation.
        if not gaussians.stores_activated:
            scales = torch.exp(scales)
            opacities = torch.sigmoid(opacities)
    else:
        raise ValueError(f"Unknown type of gaussians: {type(gaussians)}")

    # Reorder xyzw -> wxyz, the quaternion layout written to rot_0..rot_3.
    rotations = rotations[:, [3, 0, 1, 2]]  # [H*W, 4] in wxyz

    export_ply(
        means=means,
        scales=scales,
        rotations=rotations,
        harmonics=harmonics,  # [H*W, 3, d_sh]
        opacities=opacities,
        path=Path(save_path),  # export_ply needs a Path for `.parent`
    )
def _read_indexed_properties(vertex, prefix: str, num_points: int) -> np.ndarray:
    """Return the properties ``{prefix}<int>`` of a PLY element as a ``[num_points, K]`` array, ordered by index."""
    names = [prop.name for prop in vertex.properties if prop.name.startswith(prefix)]
    names.sort(key=lambda name: int(name.split("_")[-1]))
    values = np.zeros((num_points, len(names)))
    for index, name in enumerate(names):
        values[:, index] = np.asarray(vertex[name])
    return values


def load_gaussians_ply(path, max_sh_degree=3) -> Gaussians:
    """Load Gaussians from a ``.ply`` file such as one written by ``export_ply()``.

    The file stores opacities and scales in pre-activation space (logit / log).
    This loader applies ``sigmoid`` / ``exp``, so the returned object holds
    *post-activation* opacities and scales. Rotations are read as wxyz,
    reordered to xyzw and normalised.

    Args:
        path: Path to the ``.ply`` file.
        max_sh_degree: SH degree of the returned harmonics, which have
            ``(max_sh_degree + 1) ** 2`` coefficients per colour channel. Files
            with a lower SH degree (including DC-only files) are zero-padded.

    Returns:
        ``Gaussians`` with a leading batch dimension of 1: means ``[1, P, 3]``,
        harmonics ``[1, P, 3, d_sh]``, opacities ``[1, P]``, scales ``[1, P, S]``
        and rotations / rotations_unnorm ``[1, P, 4]`` (xyzw).

    Raises:
        ValueError: if the number of ``f_rest_*`` properties does not match
            an SH degree in ``[0, max_sh_degree]``.
    """
    vertex = PlyData.read(path).elements[0]

    xyz = np.stack([np.asarray(vertex[axis]) for axis in ("x", "y", "z")], axis=1)  # [P, 3]
    num_points = xyz.shape[0]
    opacity_logits = np.asarray(vertex["opacity"])  # [P], pre-sigmoid
    features_dc = np.stack([np.asarray(vertex[f"f_dc_{c}"]) for c in range(3)], axis=1)  # [P, 3]

    # A file of SH degree d has 3 * ((d + 1) ** 2 - 1) f_rest_* values, stored channel-major.
    features_rest = _read_indexed_properties(vertex, "f_rest_", num_points)
    num_rest = features_rest.shape[1]
    if num_rest == 0:
        # TODO: does this mean that features_dc probably encodes RGB which needs to be converted to SH0?
        print("Loaded PLY has no SH coefficients, only DC features.")
    valid_rest_per_channel = {(degree + 1) ** 2 - 1 for degree in range(max_sh_degree + 1)}
    if num_rest % 3 != 0 or num_rest // 3 not in valid_rest_per_channel:
        raise ValueError(
            f"Mismatch in number of SH coefficients: found {num_rest} f_rest_* properties, "
            f"expected 3 * ((d + 1) ** 2 - 1) for some SH degree d <= {max_sh_degree}."
        )
    rest_per_channel = num_rest // 3

    harmonics = torch.zeros((num_points, 3, (max_sh_degree + 1) ** 2), dtype=torch.float32)  # [P, 3, d_sh]
    harmonics[:, :, 0] = torch.tensor(features_dc, dtype=torch.float32)
    # Lower-degree files fill only the leading coefficients; higher orders stay zero.
    harmonics[:, :, 1 : 1 + rest_per_channel] = torch.tensor(
        features_rest.reshape(num_points, 3, rest_per_channel), dtype=torch.float32
    )

    means = torch.tensor(xyz, dtype=torch.float32)  # [P, 3]
    opacities = torch.sigmoid(torch.tensor(opacity_logits, dtype=torch.float32))  # [P], post-activation
    log_scales = _read_indexed_properties(vertex, "scale_", num_points)
    scales = torch.exp(torch.tensor(log_scales, dtype=torch.float32))  # post-activation
    # "rot_" (not "rot") so unrelated properties such as "rotation*" are not picked up.
    rots_wxyz = torch.tensor(_read_indexed_properties(vertex, "rot_", num_points), dtype=torch.float32)
    quats = rots_wxyz[:, [1, 2, 3, 0]]  # wxyz -> xyzw
    quats = F.normalize(quats, dim=-1)  # match 3DGS-LM get_rotation which normalizes before rendering

    return Gaussians(
        means=means.unsqueeze(0),
        harmonics=harmonics.unsqueeze(0),
        opacities=opacities.unsqueeze(0),
        scales=scales.unsqueeze(0),
        rotations=quats.unsqueeze(0),
        rotations_unnorm=quats.unsqueeze(0),
    )