# 4dgs-dpm/gs/utils/point_cloud_utils.py
# Uploaded by dxm21 using huggingface_hub (commit 8c48cce, verified).
import math
import numpy as np
from plyfile import PlyData, PlyElement
import math
import os
import warp as wp
def load_ply(filepath):
    """
    Load a Gaussian splat PLY file.

    Args:
        filepath: Path of the PLY file; passed unchanged to ``PlyData.read``.

    Returns:
        dict with float32 NumPy arrays (N = number of vertices):
            positions: (N, 3) from the ``x``/``y``/``z`` properties.
            scales: (N, 3), ``exp`` of ``scale_0..2`` (the file stores logs).
            rotations: (N, 4), ``rot_0..rot_3`` in file order.
            opacities: (N, 1), the ``opacity`` property as stored.
            shs: (N * 16, 3); for each point, row 0 is the DC term
                (``f_dc_0..2``) followed by 15 rows of higher-order terms.
            num_points: N as an int.

    The file must provide every property listed above, including all 45
    ``f_rest_*`` fields; a missing one fails at the field lookup.
    """
    vertex = PlyData.read(filepath)['vertex']
    num_points = len(vertex)

    def stack_fields(*names):
        # Gather the named per-vertex properties as columns of one array.
        return np.stack([vertex[name] for name in names], axis=-1)

    positions = stack_fields('x', 'y', 'z').astype(np.float32)

    # Scales are stored in log space, so undo the log on load.
    scales = np.exp(stack_fields('scale_0', 'scale_1', 'scale_2')).astype(np.float32)

    opacities = vertex['opacity'].astype(np.float32).reshape(-1, 1)

    rotations = stack_fields('rot_0', 'rot_1', 'rot_2', 'rot_3').astype(np.float32)

    # SH coefficients: DC term first, then the 45 remaining values.
    sh_dc = stack_fields('f_dc_0', 'f_dc_1', 'f_dc_2').astype(np.float32)
    sh_rest = stack_fields(*(f'f_rest_{i}' for i in range(45))).astype(np.float32)

    # Consecutive triples (f_rest_{3j}, f_rest_{3j+1}, f_rest_{3j+2}) are read
    # as the three channels of coefficient j + 1, which mirrors how save_ply
    # in this module writes them.
    # NOTE(review): files from other exporters may order f_rest_* by channel
    # instead of by coefficient - confirm before loading external data.
    sh_rest = sh_rest.reshape(num_points, 15, 3)

    # Vectorized interleave: prepend the DC row to each point's 15 rest rows,
    # then flatten to the (N*16, 3) layout expected by the renderer. This
    # replaces a Python loop over every point and coefficient.
    shs = np.concatenate([sh_dc[:, np.newaxis, :], sh_rest], axis=1)
    shs = shs.reshape(num_points * 16, 3)

    return {
        'positions': positions,
        'scales': scales,
        'rotations': rotations,
        'opacities': opacities,
        'shs': shs,
        'num_points': num_points
    }
def _leading_rows(values, count, label):
    """Return the first ``count`` rows of ``values`` as a NumPy array."""
    # Accept anything exposing .numpy() as well as plain NumPy arrays or
    # nested sequences.
    array = values.numpy() if hasattr(values, 'numpy') else np.asarray(values)
    if len(array) < count:
        raise IndexError(
            f"'{label}' has {len(array)} rows but {count} are required"
        )
    return array[:count]


# Function to save point cloud to PLY file
def save_ply(params, filepath, num_points, colors=None):
    """
    Write the first ``num_points`` Gaussians of ``params`` to a binary PLY file.

    Args:
        params: Mapping with 'positions', 'scales', 'rotations', 'opacities'
            and 'shs'. Each value may be a NumPy array or any object exposing
            ``.numpy()``. 'shs' holds 16 rows of 3 values per point (row 0 is
            the DC term); 'opacities' may be shaped (N,) or (N, 1).
        filepath: Destination path. Its parent directory is created when the
            path has a directory component.
        num_points: Number of leading points to write; inputs may hold more.
        colors: Optional per-point RGB in [0, 1] (array or object with
            ``.numpy()``). When omitted, colors are derived from the SH DC term.

    Raises:
        IndexError: If an input holds fewer rows than ``num_points`` requires.

    Per-vertex layout written: x, y, z, scale_0..2 (natural log of the scale),
    opacity, rot_0..3 (in array order), red/green/blue as uint8, f_dc_0..2,
    f_rest_0..44 (coefficient-major: the three channels of coefficient 1,
    then coefficient 2, and so on).
    """
    positions = _leading_rows(params['positions'], num_points, 'positions')
    scales = _leading_rows(params['scales'], num_points, 'scales')
    rotations = _leading_rows(params['rotations'], num_points, 'rotations')
    # Flatten so both (N,) and (N, 1) opacities yield one scalar per point.
    opacities = _leading_rows(
        params['opacities'], num_points, 'opacities'
    ).reshape(num_points)
    # View the flat (N*16, 3) buffer as (point, coefficient, channel).
    shs = _leading_rows(params['shs'], num_points * 16, 'shs')[:, :3]
    shs = shs.reshape(num_points, 16, 3)

    if colors is not None:
        colors_np = _leading_rows(colors, num_points, 'colors')
    else:
        # Simplified SH -> RGB: DC term plus a 0.5 offset, clamped to [0, 1].
        # NOTE(review): the usual conversion also scales the DC term by the
        # degree-0 SH constant (~0.282); kept as-is to preserve existing output.
        colors_np = np.clip(shs[:, 0, :] + 0.5, 0.0, 1.0).astype(np.float32)

    # Define the structure of the PLY file
    position_names = ('x', 'y', 'z')
    scale_names = ('scale_0', 'scale_1', 'scale_2')
    rotation_names = ('rot_0', 'rot_1', 'rot_2', 'rot_3')
    color_names = ('red', 'green', 'blue')
    dc_names = ('f_dc_0', 'f_dc_1', 'f_dc_2')
    rest_names = tuple(f'f_rest_{i}' for i in range(45))  # 15 coeffs * 3 channels

    vertex_type = (
        [(name, 'f4') for name in position_names + scale_names + ('opacity',)]
        + [(name, 'f4') for name in rotation_names]
        + [(name, 'u1') for name in color_names]
        + [(name, 'f4') for name in dc_names + rest_names]
    )
    vertex_array = np.empty(num_points, dtype=vertex_type)

    def fill(names, matrix):
        # Copy column k of ``matrix`` into the field called names[k].
        for column, name in enumerate(names):
            vertex_array[name] = matrix[:, column]

    fill(position_names, positions)
    fill(scale_names, np.log(scales))  # Log-space encoding
    vertex_array['opacity'] = opacities
    fill(rotation_names, rotations)
    # Convert colors to the 0-255 range; the uint8 cast truncates like int().
    fill(color_names, np.clip(colors_np[:, :3] * 255, 0, 255).astype(np.uint8))
    fill(dc_names, shs[:, 0, :])
    fill(rest_names, shs[:, 1:, :].reshape(num_points, 45))

    el = PlyElement.describe(vertex_array, 'vertex')

    # Create the directory if needed; a bare filename has an empty dirname,
    # which os.makedirs would reject.
    out_dir = os.path.dirname(filepath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save the PLY file
    PlyData([el], text=False).write(filepath)
    print(f"Point cloud saved to {filepath}")