| import math |
| import struct |
| from io import BytesIO |
| from typing import Literal, Optional |
|
|
| import numpy as np |
| import torch |
|
|
|
|
def sh2rgb(sh: torch.Tensor) -> torch.Tensor:
    """Map degree-0 spherical-harmonic (DC) coefficients to RGB values.

    The coefficient is scaled by the zeroth-order SH basis constant
    ``1 / (2 * sqrt(pi))`` and offset by 0.5. The result is not clamped.

    Args:
        sh (torch.Tensor): SH DC coefficients (any shape).

    Returns:
        torch.Tensor: RGB values with the same shape as ``sh``.
    """
    sh_c0 = 0.28209479177387814  # 1 / (2 * sqrt(pi))
    return 0.5 + sh_c0 * sh
|
|
|
|
def part1by2_vec(x: torch.Tensor) -> torch.Tensor:
    """Spread the low 10 bits of each element so bit ``i`` lands on bit ``3 * i``.

    Two zero bits are left between consecutive input bits, so three spread
    values can be interleaved into one Morton code. Input bits above bit 9 are
    discarded by the initial mask.

    Args:
        x (torch.Tensor): Integer input tensor. Shape (N,)

    Returns:
        torch.Tensor: Spread values, same dtype and shape as ``x``.
    """
    spread = x & 0x000003FF
    # Classic shift/xor/mask cascade: each stage moves groups of bits apart.
    for shift, mask in (
        (16, 0xFF0000FF),
        (8, 0x0300F00F),
        (4, 0x030C30C3),
        (2, 0x09249249),
    ):
        spread = (spread ^ (spread << shift)) & mask
    return spread
|
|
|
|
def encode_morton3_vec(
    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
) -> torch.Tensor:
    """Interleave three 10-bit coordinates into 30-bit Morton (Z-order) codes.

    Bit ``i`` of x, y and z ends up at bits ``3i``, ``3i + 1`` and ``3i + 2``
    of the code respectively.

    Args:
        x (torch.Tensor): X coordinates. Shape (N,)
        y (torch.Tensor): Y coordinates. Shape (N,)
        z (torch.Tensor): Z coordinates. Shape (N,)

    Returns:
        torch.Tensor: Morton codes. Shape (N,)
    """
    spread_x = part1by2_vec(x)
    spread_y = part1by2_vec(y)
    spread_z = part1by2_vec(z)
    # After shifting, the three non-negative spread patterns occupy disjoint
    # bits, so OR-ing them is the same as adding them.
    return (spread_z << 2) | (spread_y << 1) | spread_x
|
|
|
|
def sort_centers(centers: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Reorder ``indices`` so that their centers follow Morton (Z-order).

    Each center is quantized onto a 1024^3 grid spanning the bounding box of
    ``centers`` (10 bits per axis). It is then encoded as a 30-bit Morton code,
    and ``indices`` is permuted by ascending code.

    Args:
        centers (torch.Tensor): Centers. Shape (N, 3)
        indices (torch.Tensor): Index associated with each center. Shape (N,)

    Returns:
        torch.Tensor: ``indices`` permuted into Morton order. Shape (N,).
            Returned unchanged when ``centers`` is empty.
    """
    if centers.shape[0] == 0:
        # torch.min / torch.max cannot reduce over an empty dimension.
        return indices

    min_vals = centers.min(dim=0).values
    max_vals = centers.max(dim=0).values

    # Avoid dividing by zero along axes where every center has the same value.
    lengths = max_vals - min_vals
    lengths[lengths == 0] = 1

    # Quantize to 10 bits per axis. Centers lying on the upper bound scale to
    # exactly 1024, which part1by2_vec masks (& 0x3FF) down to 0 -- the lower
    # corner of the grid -- so clamp into the valid range [0, 1023].
    grid = (
        ((centers - min_vals) / lengths * 1024)
        .floor()
        .clamp(0, 1023)
        .to(torch.int32)
    )

    morton = encode_morton3_vec(grid[:, 0], grid[:, 1], grid[:, 2])
    return indices[torch.argsort(morton).to(indices.device)]
|
|
|
|
def pack_unorm(value: torch.Tensor, bits: int) -> torch.Tensor:
    """Quantize values in [0, 1] to unsigned integers of ``bits`` bits.

    Values are scaled by ``2**bits - 1`` and rounded half-up. Results outside
    the representable range are clamped to ``[0, 2**bits - 1]``.

    Args:
        value (torch.Tensor): Floating point values to pack. Shape (N,)
        bits (int): Number of bits of the output code.

    Returns:
        torch.Tensor: Integer codes as int64. Shape (N,)
    """
    max_code = (1 << bits) - 1
    rounded = torch.floor(value * max_code + 0.5)
    return rounded.clamp(min=0, max=max_code).to(torch.int64)
|
|
|
|
def pack_111011(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Pack three [0, 1] values into one 32-bit word using an 11/10/11 bit split.

    Layout (most to least significant): x in bits 21-31, y in bits 11-20,
    z in bits 0-10.

    Args:
        x (torch.Tensor): X component. Shape (N,)
        y (torch.Tensor): Y component. Shape (N,)
        z (torch.Tensor): Z component. Shape (N,)

    Returns:
        torch.Tensor: Packed values as int64. Shape (N,)
    """
    layout = ((x, 11, 21), (y, 10, 11), (z, 11, 0))
    fields = [pack_unorm(component, bits) << shift for component, bits, shift in layout]
    return fields[0] | fields[1] | fields[2]
|
|
|
|
def pack_8888(
    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, w: torch.Tensor
) -> torch.Tensor:
    """Pack four [0, 1] values into one 32-bit word, 8 bits each.

    Layout (most to least significant byte): x, y, z, w.

    Args:
        x (torch.Tensor): X component. Shape (N,)
        y (torch.Tensor): Y component. Shape (N,)
        z (torch.Tensor): Z component. Shape (N,)
        w (torch.Tensor): W component. Shape (N,)

    Returns:
        torch.Tensor: Packed values as int64. Shape (N,)
    """
    word = pack_unorm(x, 8)
    # Shift the bytes accumulated so far up by 8 and append the next channel.
    for channel in (y, z, w):
        word = (word << 8) | pack_unorm(channel, 8)
    return word
|
|
|
|
def pack_rotation(q: torch.Tensor) -> torch.Tensor:
    """Pack quaternions into 32-bit words using "smallest three" encoding.

    Each quaternion is normalized and, if needed, negated so that its
    largest-magnitude component is non-negative.

    - The index (0-3) of that largest component is stored in the top 2 bits.
    - The other three components are stored in their original order, 10 bits
      each. Each is mapped from [-1/sqrt(2), 1/sqrt(2)] onto [0, 1] first.

    Args:
        q (torch.Tensor): Quaternions. Shape (N, 4)

    Returns:
        torch.Tensor: Packed values as int64. Shape (N,)
    """
    unit = q / torch.linalg.norm(q, dim=-1, keepdim=True)

    largest = torch.abs(unit).argmax(dim=-1)
    rows = torch.arange(unit.shape[0], device=unit.device)
    # q and -q are the same rotation; pick the sign that makes the dropped
    # (largest) component non-negative so it can be reconstructed.
    negative = unit[rows, largest] < 0
    unit = torch.where(negative.unsqueeze(-1), -unit, unit)

    # For each possible dropped index, the three component indices that remain.
    remaining = torch.tensor(
        [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]],
        dtype=torch.long,
        device=unit.device,
    )
    rest = torch.gather(unit, 1, remaining[largest])

    codes = pack_unorm(rest * (math.sqrt(2) * 0.5) + 0.5, 10)

    packed = largest.to(torch.int64)
    for column in range(3):
        packed = (packed << 10) | codes[:, column]
    return packed
|
|
|
|
def _compressed_ply_header(n_chunks: int, num_splats: int, num_sh: int) -> bytes:
    """Build the ASCII header of a compressed PLY file.

    Args:
        n_chunks (int): Number of ``chunk`` elements.
        num_splats (int): Number of ``vertex`` (and ``sh``) elements.
        num_sh (int): Quantized SH coefficients per splat. When 0, the ``sh``
            element is omitted rather than declared with no properties.

    Returns:
        bytes: The header, terminated by the ``end_header`` line.
    """
    chunk_props = (
        [f"{bound}_{axis}" for bound in ("min", "max") for axis in "xyz"]
        + [f"{bound}_scale_{axis}" for bound in ("min", "max") for axis in "xyz"]
        + [f"{bound}_{channel}" for bound in ("min", "max") for channel in "rgb"]
    )
    vertex_props = ["packed_position", "packed_rotation", "packed_scale", "packed_color"]

    lines = ["ply", "format binary_little_endian 1.0", f"element chunk {n_chunks}"]
    lines += [f"property float {name}" for name in chunk_props]
    lines.append(f"element vertex {num_splats}")
    lines += [f"property uint {name}" for name in vertex_props]
    if num_sh > 0:
        lines.append(f"element sh {num_splats}")
        lines += [f"property uchar f_rest_{j}" for j in range(num_sh)]
    lines.append("end_header")
    return "".join(f"{line}\n" for line in lines).encode()


def _normalize_to_bounds(
    values: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor
) -> torch.Tensor:
    """Map ``values`` into [0, 1] relative to per-column ``lower``/``upper`` bounds.

    Columns with zero range map to 0 instead of dividing by zero. Otherwise a
    value sitting on the bound would yield NaN (0/0), which has no meaningful
    integer code.
    """
    extent = upper - lower
    flat = extent == 0
    scaled = (values - lower) / torch.where(flat, torch.ones_like(extent), extent)
    return torch.where(flat, torch.zeros_like(scaled), scaled)


def splat2ply_bytes_compressed(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    chunk_max_size: int = 256,
    opacity_threshold: float = 1 / 255,
) -> bytes:
    """Return the binary compressed Ply file. Used by Supersplat viewer.

    Processing steps:

    - Splats whose ``sigmoid(opacity)`` is not above ``opacity_threshold`` are
      dropped.
    - The rest are sorted into Morton order and split into chunks of at most
      ``chunk_max_size`` splats.
    - Each chunk stores float32 min/max bounds for position, scale (bounds
      clamped to [-20, 20]) and color.
    - Each splat stores four uint32 words quantized relative to its chunk's
      bounds: position, rotation, scale, and color + opacity.
    - Each splat also stores one uint8 per higher-order SH coefficient.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities (pre-sigmoid). Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)
        chunk_max_size (int): Maximum number of splats per chunk. Must be
            positive. Default: 256
        opacity_threshold (float): Opacity threshold. Default: 1 / 255

    Returns:
        bytes: Binary compressed Ply file representing the model. When no
            splat passes the opacity filter, only the header is returned.

    Raises:
        ValueError: If ``chunk_max_size`` is not positive.
    """
    if chunk_max_size <= 0:
        raise ValueError(f"chunk_max_size must be positive, got {chunk_max_size}")

    # Drop splats that are (nearly) fully transparent.
    keep = torch.sigmoid(opacities) > opacity_threshold
    means = means[keep]
    scales = scales[keep]
    quats = quats[keep]
    opacities = opacities[keep]
    sh0 = sh0[keep]
    shN = shN[keep]

    num_splats = means.shape[0]
    num_sh = shN.shape[1]
    n_chunks = (num_splats + chunk_max_size - 1) // chunk_max_size
    header = _compressed_ply_header(n_chunks, num_splats, num_sh)
    if num_splats == 0:
        # Nothing to encode; also avoids reducing over empty tensors below.
        return header

    sh0_colors = sh2rgb(sh0)
    # Morton ordering groups nearby splats together, so chunk bounds tend to
    # be tight and the per-chunk quantization more precise.
    order = sort_centers(means, torch.arange(num_splats))

    chunk_bounds, packed_splats, packed_sh = [], [], []
    for start in range(0, num_splats, chunk_max_size):
        idx = order[start : start + chunk_max_size]

        chunk_means = means[idx]
        min_means = chunk_means.min(dim=0).values
        max_means = chunk_means.max(dim=0).values

        chunk_scales = scales[idx]
        min_scales = chunk_scales.min(dim=0).values.clamp(-20, 20)
        max_scales = chunk_scales.max(dim=0).values.clamp(-20, 20)

        chunk_colors = sh0_colors[idx]
        min_colors = chunk_colors.min(dim=0).values
        max_colors = chunk_colors.max(dim=0).values

        # One float record per chunk, in the header's chunk-property order.
        chunk_bounds.append(
            torch.cat(
                [min_means, max_means, min_scales, max_scales, min_colors, max_colors]
            )
        )

        pos = _normalize_to_bounds(chunk_means, min_means, max_means)
        scl = _normalize_to_bounds(chunk_scales, min_scales, max_scales)
        col = _normalize_to_bounds(chunk_colors, min_colors, max_colors)
        alpha = torch.sigmoid(opacities[idx])

        words = torch.stack(
            [
                pack_111011(pos[:, 0], pos[:, 1], pos[:, 2]),
                pack_rotation(quats[idx]),
                pack_111011(scl[:, 0], scl[:, 1], scl[:, 2]),
                pack_8888(col[:, 0], col[:, 1], col[:, 2], alpha),
            ],
            dim=1,
        )
        packed_splats.append(words.ravel().to(torch.int64))

        if num_sh > 0:
            # Map coefficients in [-4, 4) onto 0..255; anything outside saturates.
            sh_q = torch.trunc((shN[idx] / 8 + 0.5) * 256).clamp(0, 255)
            packed_sh.append(sh_q.to(torch.uint8).ravel())

    f32_le = np.dtype(np.float32).newbyteorder("<")
    u32_le = np.dtype(np.uint32).newbyteorder("<")
    parts = [
        header,
        torch.cat(chunk_bounds).detach().cpu().numpy().astype(f32_le).tobytes(),
        torch.cat(packed_splats).detach().cpu().numpy().astype(u32_le).tobytes(),
    ]
    if packed_sh:
        parts.append(
            torch.cat(packed_sh).detach().cpu().numpy().astype(np.uint8).tobytes()
        )
    return b"".join(parts)
|
|
|
|
def splat2ply_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
) -> bytes:
    """Return the binary Ply file. Supported by almost all viewers.

    Every splat becomes one vertex row of little-endian float32 values, in
    this order:

    - x, y, z
    - f_dc_*, then f_rest_*
    - opacity
    - scale_*
    - rot_*

    Values are written as given, without any activation.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)

    Returns:
        bytes: Binary Ply file representing the model.
    """
    num_splats = means.shape[0]

    property_names = ["x", "y", "z"]
    property_names += [f"f_dc_{j}" for j in range(sh0.shape[1])]
    property_names += [f"f_rest_{j}" for j in range(shN.shape[1])]
    property_names.append("opacity")
    property_names += [f"scale_{i}" for i in range(scales.shape[1])]
    property_names += [f"rot_{i}" for i in range(quats.shape[1])]

    header_lines = [
        "ply",
        "format binary_little_endian 1.0",
        f"element vertex {num_splats}",
        *(f"property float {name}" for name in property_names),
        "end_header",
    ]
    header = "".join(f"{line}\n" for line in header_lines).encode()

    # Column order must match the property order declared in the header.
    rows = torch.cat(
        [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
    ).to(torch.float32)
    little_endian_f32 = np.dtype(np.float32).newbyteorder("<")
    payload = rows.detach().cpu().numpy().astype(little_endian_f32).tobytes()

    return header + payload
|
|
|
|
def splat2splat_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
) -> bytes:
    """Return the binary Splat file. Supported by antimatter15 viewer.

    Each splat is written as one 32-byte record, and records are ordered by
    the Morton code of their positions. Record layout:

    - position: 3 x little-endian float32
    - scale: 3 x little-endian float32, after ``exp``
    - color: 4 x uint8 RGBA, with alpha = ``sigmoid(opacity)``
    - rotation: 4 x uint8, the normalized quaternion mapped by ``q * 128 + 128``

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)

    Returns:
        bytes: Binary Splat file representing the model (empty when N == 0).
    """
    num_splats = means.shape[0]
    if num_splats == 0:
        # No records to write; sort_centers cannot reduce over an empty set.
        return b""

    linear_scales = torch.exp(scales)
    rgba = torch.cat([sh2rgb(sh0), torch.sigmoid(opacities).unsqueeze(-1)], dim=1)
    colors = (rgba * 255).clamp(0, 255).to(torch.uint8)
    unit_quats = quats / torch.linalg.norm(quats, dim=1, keepdim=True)
    rots = (unit_quats * 128 + 128).clamp(0, 255).to(torch.uint8)

    order = sort_centers(means, torch.arange(num_splats))

    float_dtype = np.dtype(np.float32).newbyteorder("<")

    def as_row_bytes(t: torch.Tensor, dtype) -> np.ndarray:
        """Encode ``t`` in ``dtype`` and reinterpret it as (N, bytes_per_row) uint8."""
        encoded = np.ascontiguousarray(t.detach().cpu().numpy().astype(dtype))
        return encoded.view(np.uint8)

    # Build every record at once instead of issuing four writes per splat.
    records = np.concatenate(
        [
            as_row_bytes(means[order], float_dtype),
            as_row_bytes(linear_scales[order], float_dtype),
            as_row_bytes(colors[order], np.uint8),
            as_row_bytes(rots[order], np.uint8),
        ],
        axis=1,
    )
    return records.tobytes()
|
|
|
|
def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply", "splat", "ply_compressed"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes.

    The three supported formats are:
    - ply: A standard PLY file format. Supported by most viewers.
    - splat: A custom Splat file format. Supported by antimatter15 viewer.
    - ply_compressed: A compressed PLY file format. Used by Supersplat viewer.

    Any splat with a NaN or infinite value in any input is dropped before
    encoding.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 1, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K, 3); K may be 0.
        format (str): Export format. Options: "ply", "splat", "ply_compressed". Default: "ply"
        save_to (str): Output file path. If provided, the bytes are also written to this file.

    Returns:
        bytes: The encoded model.

    Raises:
        AssertionError: If an input tensor does not have the documented shape.
        ValueError: If ``format`` is not one of the supported formats.
    """
    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (total_splats, 4), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (total_splats,), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    num_rest = shN.shape[1]
    sh0 = sh0.squeeze(1)
    # (N, K, 3) -> (N, 3 * K) in channel-major order: all K coefficients of the
    # first channel, then the second, then the third. The width is given
    # explicitly because reshape cannot infer -1 for a tensor with no elements
    # (K == 0 or N == 0).
    shN = shN.permute(0, 2, 1).reshape(total_splats, 3 * num_rest)

    # Keep a splat only if all of its values are finite. Opacities are 1-D, so
    # they are checked elementwise; reducing them with .any(dim=0) would yield
    # a single flag that discards every splat as soon as one opacity is bad.
    valid_mask = (
        torch.isfinite(means).all(dim=1)
        & torch.isfinite(scales).all(dim=1)
        & torch.isfinite(quats).all(dim=1)
        & torch.isfinite(opacities)
        & torch.isfinite(sh0).all(dim=1)
        & torch.isfinite(shN).all(dim=1)
    )
    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    elif format == "splat":
        data = splat2splat_bytes(means, scales, quats, opacities, sh0)
    elif format == "ply_compressed":
        data = splat2ply_bytes_compressed(means, scales, quats, opacities, sh0, shN)
    else:
        raise ValueError(f"Unsupported format: {format}")

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)

    return data
|
|