Spaces:
Runtime error
Runtime error
| import math | |
| import struct | |
| from io import BytesIO | |
| from typing import Literal, Optional | |
| import numpy as np | |
| import torch | |
| def sh2rgb(sh: torch.Tensor) -> torch.Tensor: | |
| """Convert Sphere Harmonics to RGB | |
| Args: | |
| sh (torch.Tensor): SH tensor | |
| Returns: | |
| torch.Tensor: RGB tensor | |
| """ | |
| C0 = 0.28209479177387814 | |
| return sh * C0 + 0.5 | |
| def part1by2_vec(x: torch.Tensor) -> torch.Tensor: | |
| """Interleave bits of x with 0s | |
| Args: | |
| x (torch.Tensor): Input tensor. Shape (N,) | |
| Returns: | |
| torch.Tensor: Output tensor. Shape (N,) | |
| """ | |
| x = x & 0x000003FF | |
| x = (x ^ (x << 16)) & 0xFF0000FF | |
| x = (x ^ (x << 8)) & 0x0300F00F | |
| x = (x ^ (x << 4)) & 0x030C30C3 | |
| x = (x ^ (x << 2)) & 0x09249249 | |
| return x | |
| def encode_morton3_vec( | |
| x: torch.Tensor, y: torch.Tensor, z: torch.Tensor | |
| ) -> torch.Tensor: | |
| """Compute Morton codes for 3D coordinates | |
| Args: | |
| x (torch.Tensor): X coordinates. Shape (N,) | |
| y (torch.Tensor): Y coordinates. Shape (N,) | |
| z (torch.Tensor): Z coordinates. Shape (N,) | |
| Returns: | |
| torch.Tensor: Morton codes. Shape (N,) | |
| """ | |
| return (part1by2_vec(z) << 2) + (part1by2_vec(y) << 1) + part1by2_vec(x) | |
| def sort_centers(centers: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: | |
| """Sort centers based on Morton codes | |
| Args: | |
| centers (torch.Tensor): Centers. Shape (N, 3) | |
| indices (torch.Tensor): Indices. Shape (N,) | |
| Returns: | |
| torch.Tensor: Sorted indices. Shape (N,) | |
| """ | |
| # Compute min and max values in a single operation | |
| min_vals, _ = torch.min(centers, dim=0) | |
| max_vals, _ = torch.max(centers, dim=0) | |
| # Compute the scaling factors | |
| lengths = max_vals - min_vals | |
| lengths[lengths == 0] = 1 # Prevent division by zero | |
| # Normalize and scale to 10-bit integer range (0-1024) | |
| scaled_centers = ((centers - min_vals) / lengths * 1024).floor().to(torch.int32) | |
| # Extract x, y, z coordinates | |
| x, y, z = scaled_centers[:, 0], scaled_centers[:, 1], scaled_centers[:, 2] | |
| # Compute Morton codes using vectorized operations | |
| morton = encode_morton3_vec(x, y, z) | |
| # Sort indices based on Morton codes | |
| sorted_indices = indices[torch.argsort(morton).to(indices.device)] | |
| return sorted_indices | |
| def pack_unorm(value: torch.Tensor, bits: int) -> torch.Tensor: | |
| """Pack a floating point value into an unsigned integer with a given number of bits. | |
| Args: | |
| value (torch.Tensor): Floating point value to pack. Shape (N,) | |
| bits (int): Number of bits to pack into. | |
| Returns: | |
| torch.Tensor: Packed value. Shape (N,) | |
| """ | |
| t = (1 << bits) - 1 | |
| packed = torch.clamp((value * t + 0.5).floor(), min=0, max=t) | |
| # Convert to integer type | |
| return packed.to(torch.int64) | |
| def pack_111011(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor: | |
| """Pack three floating point values into a 32-bit integer with 11, 10, and 11 bits. | |
| Args: | |
| x (torch.Tensor): X component. Shape (N,) | |
| y (torch.Tensor): Y component. Shape (N,) | |
| z (torch.Tensor): Z component. Shape (N,) | |
| Returns: | |
| torch.Tensor: Packed values. Shape (N,) | |
| """ | |
| # Pack each component using pack_unorm | |
| packed_x = pack_unorm(x, 11) << 21 | |
| packed_y = pack_unorm(y, 10) << 11 | |
| packed_z = pack_unorm(z, 11) | |
| # Combine the packed values using bitwise OR | |
| return packed_x | packed_y | packed_z | |
| def pack_8888( | |
| x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, w: torch.Tensor | |
| ) -> torch.Tensor: | |
| """Pack four floating point values into a 32-bit integer with 8 bits each. | |
| Args: | |
| x (torch.Tensor): X component. Shape (N,) | |
| y (torch.Tensor): Y component. Shape (N,) | |
| z (torch.Tensor): Z component. Shape (N,) | |
| w (torch.Tensor): W component. Shape (N,) | |
| Returns: | |
| torch.Tensor: Packed values. Shape (N,) | |
| """ | |
| # Pack each component using pack_unorm | |
| packed_x = pack_unorm(x, 8) << 24 | |
| packed_y = pack_unorm(y, 8) << 16 | |
| packed_z = pack_unorm(z, 8) << 8 | |
| packed_w = pack_unorm(w, 8) | |
| # Combine the packed values using bitwise OR | |
| return packed_x | packed_y | packed_z | packed_w | |
| def pack_rotation(q: torch.Tensor) -> torch.Tensor: | |
| """Pack a quaternion into a 32-bit integer. | |
| Args: | |
| q (torch.Tensor): Quaternions. Shape (N, 4) | |
| Returns: | |
| torch.Tensor: Packed values. Shape (N,) | |
| """ | |
| # Normalize each quaternion | |
| norms = torch.linalg.norm(q, dim=-1, keepdim=True) | |
| q = q / norms | |
| # Find the largest component index for each quaternion | |
| largest_components = torch.argmax(torch.abs(q), dim=-1) | |
| # Flip quaternions where the largest component is negative | |
| batch_indices = torch.arange(q.size(0), device=q.device) | |
| largest_values = q[batch_indices, largest_components] | |
| flip_mask = largest_values < 0 | |
| q[flip_mask] *= -1 | |
| # Precomputed indices for the components to pack (excluding largest) | |
| precomputed_indices = torch.tensor( | |
| [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]], dtype=torch.long, device=q.device | |
| ) | |
| # Gather components to pack for each quaternion | |
| pack_indices = precomputed_indices[largest_components] | |
| components_to_pack = q[batch_indices[:, None], pack_indices] | |
| # Scale and pack each component into 10-bit integers | |
| norm = math.sqrt(2) * 0.5 | |
| scaled = components_to_pack * norm + 0.5 | |
| packed = pack_unorm(scaled, 10) # Assuming pack_unorm is vectorized | |
| # Combine into the final 32-bit integer | |
| largest_packed = largest_components.to(torch.int64) << 30 | |
| c0_packed = packed[:, 0] << 20 | |
| c1_packed = packed[:, 1] << 10 | |
| c2_packed = packed[:, 2] | |
| result = largest_packed | c0_packed | c1_packed | c2_packed | |
| return result | |
| def splat2ply_bytes_compressed( | |
| means: torch.Tensor, | |
| scales: torch.Tensor, | |
| quats: torch.Tensor, | |
| opacities: torch.Tensor, | |
| sh0: torch.Tensor, | |
| shN: torch.Tensor, | |
| chunk_max_size: int = 256, | |
| opacity_threshold: float = 1 / 255, | |
| ) -> bytes: | |
| """Return the binary compressed Ply file. Used by Supersplat viewer. | |
| Args: | |
| means (torch.Tensor): Splat means. Shape (N, 3) | |
| scales (torch.Tensor): Splat scales. Shape (N, 3) | |
| quats (torch.Tensor): Splat quaternions. Shape (N, 4) | |
| opacities (torch.Tensor): Splat opacities. Shape (N,) | |
| sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3) | |
| shN (torch.Tensor): Spherical harmonics. Shape (N, K*3) | |
| chunk_max_size (int): Maximum number of splats per chunk. Default: 256 | |
| opacity_threshold (float): Opacity threshold. Default: 1 / 255 | |
| Returns: | |
| bytes: Binary compressed Ply file representing the model. | |
| """ | |
| # Filter the splats with too low opacity | |
| mask = torch.sigmoid(opacities) > opacity_threshold | |
| means = means[mask] | |
| scales = scales[mask] | |
| sh0_colors = sh2rgb(sh0) | |
| sh0_colors = sh0_colors[mask] | |
| shN = shN[mask] | |
| quats = quats[mask] | |
| opacities = opacities[mask] | |
| num_splats = means.shape[0] | |
| n_chunks = num_splats // chunk_max_size + (num_splats % chunk_max_size != 0) | |
| indices = torch.arange(num_splats) | |
| indices = sort_centers(means, indices) | |
| float_properties = [ | |
| "min_x", | |
| "min_y", | |
| "min_z", | |
| "max_x", | |
| "max_y", | |
| "max_z", | |
| "min_scale_x", | |
| "min_scale_y", | |
| "min_scale_z", | |
| "max_scale_x", | |
| "max_scale_y", | |
| "max_scale_z", | |
| "min_r", | |
| "min_g", | |
| "min_b", | |
| "max_r", | |
| "max_g", | |
| "max_b", | |
| ] | |
| uint_properties = [ | |
| "packed_position", | |
| "packed_rotation", | |
| "packed_scale", | |
| "packed_color", | |
| ] | |
| buffer = BytesIO() | |
| # Write PLY header | |
| buffer.write(b"ply\n") | |
| buffer.write(b"format binary_little_endian 1.0\n") | |
| buffer.write(f"element chunk {n_chunks}\n".encode()) | |
| for prop in float_properties: | |
| buffer.write(f"property float {prop}\n".encode()) | |
| buffer.write(f"element vertex {num_splats}\n".encode()) | |
| for prop in uint_properties: | |
| buffer.write(f"property uint {prop}\n".encode()) | |
| buffer.write(f"element sh {num_splats}\n".encode()) | |
| for j in range(shN.shape[1]): | |
| buffer.write(f"property uchar f_rest_{j}\n".encode()) | |
| buffer.write(b"end_header\n") | |
| chunk_data = [] | |
| splat_data = [] | |
| sh_data = [] | |
| for chunk_idx in range(n_chunks): | |
| chunk_end_idx = min((chunk_idx + 1) * chunk_max_size, num_splats) | |
| chunk_start_idx = chunk_idx * chunk_max_size | |
| splat_idxs = indices[chunk_start_idx:chunk_end_idx] | |
| # Bounds | |
| # Means | |
| chunk_means = means[splat_idxs] | |
| min_means = torch.min(chunk_means, dim=0).values | |
| max_means = torch.max(chunk_means, dim=0).values | |
| mean_bounds = torch.cat([min_means, max_means]) | |
| # Scales | |
| chunk_scales = scales[splat_idxs] | |
| min_scales = torch.min(chunk_scales, dim=0).values | |
| max_scales = torch.max(chunk_scales, dim=0).values | |
| min_scales = torch.clamp(min_scales, -20, 20) | |
| max_scales = torch.clamp(max_scales, -20, 20) | |
| scale_bounds = torch.cat([min_scales, max_scales]) | |
| # Colors | |
| chunk_colors = sh0_colors[splat_idxs] | |
| min_colors = torch.min(chunk_colors, dim=0).values | |
| max_colors = torch.max(chunk_colors, dim=0).values | |
| color_bounds = torch.cat([min_colors, max_colors]) | |
| chunk_data.extend([mean_bounds, scale_bounds, color_bounds]) | |
| # Quantized properties: | |
| # Means | |
| normalized_means = (chunk_means - min_means) / (max_means - min_means) | |
| means_i = pack_111011( | |
| normalized_means[:, 0], | |
| normalized_means[:, 1], | |
| normalized_means[:, 2], | |
| ) | |
| # Quaternions | |
| chunk_quats = quats[splat_idxs] | |
| quat_i = pack_rotation(chunk_quats) | |
| # Scales | |
| normalized_scales = (chunk_scales - min_scales) / (max_scales - min_scales) | |
| scales_i = pack_111011( | |
| normalized_scales[:, 0], | |
| normalized_scales[:, 1], | |
| normalized_scales[:, 2], | |
| ) | |
| # Colors | |
| normalized_colors = (chunk_colors - min_colors) / (max_colors - min_colors) | |
| chunk_opacities = opacities[splat_idxs] | |
| chunk_opacities = 1 / (1 + torch.exp(-chunk_opacities)) | |
| chunk_opacities = chunk_opacities.unsqueeze(-1) | |
| normalized_colors_i = torch.cat([normalized_colors, chunk_opacities], dim=-1) | |
| color_i = pack_8888( | |
| normalized_colors_i[:, 0], | |
| normalized_colors_i[:, 1], | |
| normalized_colors_i[:, 2], | |
| normalized_colors_i[:, 3], | |
| ) | |
| splat_data_chunk = torch.stack([means_i, quat_i, scales_i, color_i], dim=1) | |
| splat_data_chunk = splat_data_chunk.ravel().to(torch.int64) | |
| splat_data.extend([splat_data_chunk]) | |
| # Quantized spherical harmonics | |
| shN_chunk = shN[splat_idxs] | |
| shN_chunk_quantized = (shN_chunk / 8 + 0.5) * 256 | |
| shN_chunk_quantized = torch.clamp(torch.trunc(shN_chunk_quantized), 0, 255) | |
| shN_chunk_quantized = shN_chunk_quantized.to(torch.uint8) | |
| sh_data.extend([shN_chunk_quantized.ravel()]) | |
| float_dtype = np.dtype(np.float32).newbyteorder("<") | |
| uint32_dtype = np.dtype(np.uint32).newbyteorder("<") | |
| uint8_dtype = np.dtype(np.uint8) | |
| buffer.write( | |
| torch.cat(chunk_data).detach().cpu().numpy().astype(float_dtype).tobytes() | |
| ) | |
| buffer.write( | |
| torch.cat(splat_data).detach().cpu().numpy().astype(uint32_dtype).tobytes() | |
| ) | |
| buffer.write( | |
| torch.cat(sh_data).detach().cpu().numpy().astype(uint8_dtype).tobytes() | |
| ) | |
| return buffer.getvalue() | |
| def splat2ply_bytes( | |
| means: torch.Tensor, | |
| scales: torch.Tensor, | |
| quats: torch.Tensor, | |
| opacities: torch.Tensor, | |
| sh0: torch.Tensor, | |
| shN: torch.Tensor, | |
| ) -> bytes: | |
| """Return the binary Ply file. Supported by almost all viewers. | |
| Args: | |
| means (torch.Tensor): Splat means. Shape (N, 3) | |
| scales (torch.Tensor): Splat scales. Shape (N, 3) | |
| quats (torch.Tensor): Splat quaternions. Shape (N, 4) | |
| opacities (torch.Tensor): Splat opacities. Shape (N,) | |
| sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3) | |
| shN (torch.Tensor): Spherical harmonics. Shape (N, K*3) | |
| Returns: | |
| bytes: Binary Ply file representing the model. | |
| """ | |
| num_splats = means.shape[0] | |
| buffer = BytesIO() | |
| # Write PLY header | |
| buffer.write(b"ply\n") | |
| buffer.write(b"format binary_little_endian 1.0\n") | |
| buffer.write(f"element vertex {num_splats}\n".encode()) | |
| buffer.write(b"property float x\n") | |
| buffer.write(b"property float y\n") | |
| buffer.write(b"property float z\n") | |
| for i, data in enumerate([sh0, shN]): | |
| prefix = "f_dc" if i == 0 else "f_rest" | |
| for j in range(data.shape[1]): | |
| buffer.write(f"property float {prefix}_{j}\n".encode()) | |
| buffer.write(b"property float opacity\n") | |
| for i in range(scales.shape[1]): | |
| buffer.write(f"property float scale_{i}\n".encode()) | |
| for i in range(quats.shape[1]): | |
| buffer.write(f"property float rot_{i}\n".encode()) | |
| buffer.write(b"end_header\n") | |
| # Concatenate all tensors in the correct order | |
| splat_data = torch.cat( | |
| [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1 | |
| ) | |
| # Ensure correct dtype | |
| splat_data = splat_data.to(torch.float32) | |
| # Write binary data | |
| float_dtype = np.dtype(np.float32).newbyteorder("<") | |
| buffer.write(splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()) | |
| return buffer.getvalue() | |
| def splat2splat_bytes( | |
| means: torch.Tensor, | |
| scales: torch.Tensor, | |
| quats: torch.Tensor, | |
| opacities: torch.Tensor, | |
| sh0: torch.Tensor, | |
| ) -> bytes: | |
| """Return the binary Splat file. Supported by antimatter15 viewer. | |
| Args: | |
| means (torch.Tensor): Splat means. Shape (N, 3) | |
| scales (torch.Tensor): Splat scales. Shape (N, 3) | |
| quats (torch.Tensor): Splat quaternions. Shape (N, 4) | |
| opacities (torch.Tensor): Splat opacities. Shape (N,) | |
| sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3) | |
| Returns: | |
| bytes: Binary Splat file representing the model. | |
| """ | |
| # Preprocess | |
| scales = torch.exp(scales) | |
| sh0_color = sh2rgb(sh0) | |
| colors = torch.cat([sh0_color, torch.sigmoid(opacities).unsqueeze(-1)], dim=1) | |
| colors = (colors * 255).clamp(0, 255).to(torch.uint8) | |
| rots = (quats / torch.linalg.norm(quats, dim=1, keepdim=True)) * 128 + 128 | |
| rots = rots.clamp(0, 255).to(torch.uint8) | |
| # Sort splats | |
| num_splats = means.shape[0] | |
| indices = sort_centers(means, torch.arange(num_splats)) | |
| # Reorder everything | |
| means = means[indices] | |
| scales = scales[indices] | |
| colors = colors[indices] | |
| rots = rots[indices] | |
| float_dtype = np.dtype(np.float32).newbyteorder("<") | |
| means_np = means.detach().cpu().numpy().astype(float_dtype) | |
| scales_np = scales.detach().cpu().numpy().astype(float_dtype) | |
| colors_np = colors.detach().cpu().numpy().astype(np.uint8) | |
| rots_np = rots.detach().cpu().numpy().astype(np.uint8) | |
| buffer = BytesIO() | |
| for i in range(num_splats): | |
| buffer.write(means_np[i].tobytes()) | |
| buffer.write(scales_np[i].tobytes()) | |
| buffer.write(colors_np[i].tobytes()) | |
| buffer.write(rots_np[i].tobytes()) | |
| return buffer.getvalue() | |
| def export_splats( | |
| means: torch.Tensor, | |
| scales: torch.Tensor, | |
| quats: torch.Tensor, | |
| opacities: torch.Tensor, | |
| sh0: torch.Tensor, | |
| shN: torch.Tensor, | |
| format: Literal["ply", "splat", "ply_compressed"] = "ply", | |
| save_to: Optional[str] = None, | |
| ) -> bytes: | |
| """Export a Gaussian Splats model to bytes. | |
| The three supported formats are: | |
| - ply: A standard PLY file format. Supported by most viewers. | |
| - splat: A custom Splat file format. Supported by antimatter15 viewer. | |
| - ply_compressed: A compressed PLY file format. Used by Supersplat viewer. | |
| Args: | |
| means (torch.Tensor): Splat means. Shape (N, 3) | |
| scales (torch.Tensor): Splat scales. Shape (N, 3) | |
| quats (torch.Tensor): Splat quaternions. Shape (N, 4) | |
| opacities (torch.Tensor): Splat opacities. Shape (N,) | |
| sh0 (torch.Tensor): Spherical harmonics. Shape (N, 1, 3) | |
| shN (torch.Tensor): Spherical harmonics. Shape (N, K, 3) | |
| format (str): Export format. Options: "ply", "splat", "ply_compressed". Default: "ply" | |
| save_to (str): Output file path. If provided, the bytes will be written to file. | |
| """ | |
| total_splats = means.shape[0] | |
| assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)" | |
| assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)" | |
| assert quats.shape == (total_splats, 4), "Quaternions must be of shape (N, 4)" | |
| assert opacities.shape == (total_splats,), "Opacities must be of shape (N,)" | |
| assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)" | |
| assert ( | |
| shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3 | |
| ), f"shN must be of shape (N, K, 3), got {shN.shape}" | |
| # Reshape spherical harmonics | |
| sh0 = sh0.squeeze(1) # Shape (N, 3) | |
| shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1) # Shape (N, K * 3) | |
| # Check for NaN or Inf values | |
| invalid_mask = ( | |
| torch.isnan(means).any(dim=1) | |
| | torch.isinf(means).any(dim=1) | |
| | torch.isnan(scales).any(dim=1) | |
| | torch.isinf(scales).any(dim=1) | |
| | torch.isnan(quats).any(dim=1) | |
| | torch.isinf(quats).any(dim=1) | |
| | torch.isnan(opacities).any(dim=0) | |
| | torch.isinf(opacities).any(dim=0) | |
| | torch.isnan(sh0).any(dim=1) | |
| | torch.isinf(sh0).any(dim=1) | |
| | torch.isnan(shN).any(dim=1) | |
| | torch.isinf(shN).any(dim=1) | |
| ) | |
| # Filter out invalid entries | |
| valid_mask = ~invalid_mask | |
| means = means[valid_mask] | |
| scales = scales[valid_mask] | |
| quats = quats[valid_mask] | |
| opacities = opacities[valid_mask] | |
| sh0 = sh0[valid_mask] | |
| shN = shN[valid_mask] | |
| if format == "ply": | |
| data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN) | |
| elif format == "splat": | |
| data = splat2splat_bytes(means, scales, quats, opacities, sh0) | |
| elif format == "ply_compressed": | |
| data = splat2ply_bytes_compressed(means, scales, quats, opacities, sh0, shN) | |
| else: | |
| raise ValueError(f"Unsupported format: {format}") | |
| if save_to: | |
| with open(save_to, "wb") as binary_file: | |
| binary_file.write(data) | |
| return data | |