# 4dgs-dpm / gs / optimizer.py
# dxm21's picture
# Upload folder using huggingface_hub
# 8c48cce verified
import warp as wp
from utils.wp_utils import to_warp_array, wp_vec3_mul_element, wp_vec3_add_element, wp_vec3_sqrt, wp_vec3_div_element, wp_vec3_clamp
from config import *
@wp.kernel
def adam_update(
    # Parameters
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    # Gradients
    pos_grads: wp.array(dtype=wp.vec3),
    scale_grads: wp.array(dtype=wp.vec3),
    rot_grads: wp.array(dtype=wp.vec4),
    opacity_grads: wp.array(dtype=float),
    sh_grads: wp.array(dtype=wp.vec3),
    # First moments (m)
    m_positions: wp.array(dtype=wp.vec3),
    m_scales: wp.array(dtype=wp.vec3),
    m_rotations: wp.array(dtype=wp.vec4),
    m_opacities: wp.array(dtype=float),
    m_shs: wp.array(dtype=wp.vec3),
    # Second moments (v)
    v_positions: wp.array(dtype=wp.vec3),
    v_scales: wp.array(dtype=wp.vec3),
    v_rotations: wp.array(dtype=wp.vec4),
    v_opacities: wp.array(dtype=float),
    v_shs: wp.array(dtype=wp.vec3),
    num_points: int,
    lr_pos: float,
    lr_scale: float,
    lr_rot: float,
    lr_opac: float,
    lr_sh: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    iteration: int
):
    """Run one Adam step over all parameter groups of a single Gaussian.

    Each thread handles the Gaussian at index ``wp.tid()``; threads with an
    index >= ``num_points`` exit immediately.  For every group (position,
    scale, rotation, opacity, SH coefficients) the kernel:

      1. updates the first moment:  m = beta1 * m + (1 - beta1) * g
      2. updates the second moment: v = beta2 * v + (1 - beta2) * g * g
      3. divides both by the bias-correction terms 1 - beta^(iteration + 1)
      4. subtracts lr * m_hat / (sqrt(v_hat) + epsilon) from the parameter

    Group-specific post-processing:
      * scales     - each component is floored at 0.001 after the step
      * rotations  - the 4-vector is renormalised to unit length when its
                     length is > 0
      * opacities  - the result is clamped to [0, 1]
      * shs        - 16 vec3 entries per Gaussian, addressed as i * 16 + j

    The moment arrays and parameter arrays are modified in place; gradient
    arrays are only read.  ``iteration`` is used as ``iteration + 1`` in the
    exponent, so the first call is expected to pass 0 (a negative value would
    give a non-positive exponent).

    The ``wp_vec3_*`` helpers come from ``utils.wp_utils``; they are assumed
    to be element-wise vec3 operations as their names suggest.
    """
    tid = wp.tid()
    if tid >= num_points:
        return

    # Quantities shared by every parameter group.
    step = float(iteration + 1)
    bc1 = 1.0 - wp.pow(beta1, step)
    bc2 = 1.0 - wp.pow(beta2, step)
    one_minus_b1 = 1.0 - beta1
    one_minus_b2 = 1.0 - beta2
    eps3 = wp.vec3(epsilon, epsilon, epsilon)

    # ---- positions ----
    g_pos = pos_grads[tid]
    m_pos = beta1 * m_positions[tid] + one_minus_b1 * g_pos
    v_pos = beta2 * v_positions[tid] + one_minus_b2 * wp_vec3_mul_element(g_pos, g_pos)
    m_positions[tid] = m_pos
    v_positions[tid] = v_pos
    pos_denom = wp_vec3_sqrt(v_pos / bc2) + eps3
    positions[tid] = positions[tid] - lr_pos * wp_vec3_div_element(m_pos / bc1, pos_denom)

    # ---- scales (floored at 0.001 so they stay strictly positive) ----
    g_scale = scale_grads[tid]
    m_scale = beta1 * m_scales[tid] + one_minus_b1 * g_scale
    v_scale = beta2 * v_scales[tid] + one_minus_b2 * wp_vec3_mul_element(g_scale, g_scale)
    m_scales[tid] = m_scale
    v_scales[tid] = v_scale
    scale_denom = wp_vec3_sqrt(v_scale / bc2) + eps3
    scale_step = lr_scale * wp_vec3_div_element(m_scale / bc1, scale_denom)
    stepped_scale = scales[tid] - scale_step
    scales[tid] = wp.vec3(
        wp.max(stepped_scale[0], 0.001),
        wp.max(stepped_scale[1], 0.001),
        wp.max(stepped_scale[2], 0.001)
    )

    # ---- rotations (vec4, so the element-wise math is spelled out) ----
    g_rot = rot_grads[tid]
    m_rot = beta1 * m_rotations[tid] + one_minus_b1 * g_rot
    v_rot = beta2 * v_rotations[tid] + one_minus_b2 * wp.vec4(
        g_rot[0] * g_rot[0],
        g_rot[1] * g_rot[1],
        g_rot[2] * g_rot[2],
        g_rot[3] * g_rot[3]
    )
    m_rotations[tid] = m_rot
    v_rotations[tid] = v_rot
    m_rot_hat = m_rot / bc1
    v_rot_hat = v_rot / bc2
    rot_step = wp.vec4(
        lr_rot * m_rot_hat[0] / (wp.sqrt(v_rot_hat[0]) + epsilon),
        lr_rot * m_rot_hat[1] / (wp.sqrt(v_rot_hat[1]) + epsilon),
        lr_rot * m_rot_hat[2] / (wp.sqrt(v_rot_hat[2]) + epsilon),
        lr_rot * m_rot_hat[3] / (wp.sqrt(v_rot_hat[3]) + epsilon)
    )
    quat = rotations[tid] - rot_step
    quat_norm = wp.sqrt(quat[0] * quat[0] +
                        quat[1] * quat[1] +
                        quat[2] * quat[2] +
                        quat[3] * quat[3])
    # Renormalise so the stored value stays a unit quaternion; a zero-length
    # result is stored unchanged to avoid dividing by zero.
    if quat_norm > 0.0:
        rotations[tid] = wp.vec4(
            quat[0] / quat_norm,
            quat[1] / quat_norm,
            quat[2] / quat_norm,
            quat[3] / quat_norm
        )
    else:
        rotations[tid] = quat

    # ---- opacity (scalar, clamped to [0, 1]) ----
    g_opac = opacity_grads[tid]
    m_opac = beta1 * m_opacities[tid] + one_minus_b1 * g_opac
    v_opac = beta2 * v_opacities[tid] + one_minus_b2 * (g_opac * g_opac)
    m_opacities[tid] = m_opac
    v_opacities[tid] = v_opac
    opac_step = lr_opac * (m_opac / bc1) / (wp.sqrt(v_opac / bc2) + epsilon)
    opacities[tid] = wp.max(wp.min(opacities[tid] - opac_step, 1.0), 0.0)

    # ---- SH coefficients: 16 consecutive vec3 entries per Gaussian ----
    sh_base = tid * 16
    for k in range(16):
        c = sh_base + k
        g_sh = sh_grads[c]
        m_sh = beta1 * m_shs[c] + one_minus_b1 * g_sh
        v_sh = beta2 * v_shs[c] + one_minus_b2 * wp_vec3_mul_element(g_sh, g_sh)
        m_shs[c] = m_sh
        v_shs[c] = v_sh
        sh_denom = wp_vec3_sqrt(v_sh / bc2) + eps3
        shs[c] = shs[c] - lr_sh * wp_vec3_div_element(m_sh / bc1, sh_denom)
@wp.kernel
def reset_opacities(
    opacities: wp.array(dtype=float),
    max_opacity: float,
    num_points: int
):
    """Overwrite the opacity of every Gaussian with ``max_opacity``.

    One thread per Gaussian; threads with index >= ``num_points`` do nothing.

    NOTE(review): the value is assigned unconditionally, so opacities that
    were already below ``max_opacity`` are raised to it.  The parameter name
    suggests a cap (``min(opacity, max_opacity)``) may have been intended -
    confirm against the caller before changing.
    """
    tid = wp.tid()
    if tid < num_points:
        opacities[tid] = max_opacity
@wp.kernel
def reset_densification_stats(
    xyz_gradient_accum: wp.array(dtype=float),
    denom: wp.array(dtype=float),
    max_radii2D: wp.array(dtype=float),
    num_points: int
):
    """Zero the per-Gaussian densification accumulators.

    Sets ``xyz_gradient_accum``, ``denom`` and ``max_radii2D`` to 0.0 for the
    first ``num_points`` entries (one thread per entry).
    """
    tid = wp.tid()
    if tid < num_points:
        max_radii2D[tid] = 0.0
        denom[tid] = 0.0
        xyz_gradient_accum[tid] = 0.0
@wp.kernel
def mark_split_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    split_mask: wp.array(dtype=int),
    num_points: int
):
    """Flag Gaussians that are both high-gradient and large.

    ``split_mask[i]`` is set to 1 when ``grads[i] >= grad_threshold`` and the
    largest scale component exceeds ``percent_dense * scene_extent``;
    otherwise it is set to 0.  Every entry below ``num_points`` is written.
    """
    tid = wp.tid()
    if tid >= num_points:
        return

    s = scales[tid]
    largest_axis = wp.max(wp.max(s[0], s[1]), s[2])
    size_limit = percent_dense * scene_extent

    if grads[tid] >= grad_threshold and largest_axis > size_limit:
        split_mask[tid] = 1
    else:
        split_mask[tid] = 0
@wp.kernel
def mark_clone_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    clone_mask: wp.array(dtype=int),
    num_points: int
):
    """Flag Gaussians that are both high-gradient and small.

    ``clone_mask[i]`` is set to 1 when ``grads[i] >= grad_threshold`` and the
    largest scale component is at most ``percent_dense * scene_extent``;
    otherwise it is set to 0.  Every entry below ``num_points`` is written.
    """
    tid = wp.tid()
    if tid >= num_points:
        return

    s = scales[tid]
    largest_axis = wp.max(wp.max(s[0], s[1]), s[2])
    size_limit = percent_dense * scene_extent

    if grads[tid] >= grad_threshold and largest_axis <= size_limit:
        clone_mask[tid] = 1
    else:
        clone_mask[tid] = 0
@wp.kernel
def split_gaussians(
    split_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    N_split: int,
    scale_factor: float,
    offset: int,
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Copy every Gaussian to the output and append children for marked ones.

    Thread ``i`` (for ``i < len(positions)``) first copies Gaussian ``i`` -
    position, scale, rotation, opacity and its 16 SH entries - to slot ``i``
    of the output arrays.  If ``split_mask[i] == 1`` it then writes
    ``N_split`` children to slots ``offset + prefix_sum[i] * N_split + j``.
    Each child inherits rotation, opacity and SH coefficients, has its scale
    multiplied by ``scale_factor``, and has its position jittered by
    ``(randf * 2 - 1) * 0.01`` per axis, with the random state taken from the
    child index.  Children whose slot falls outside ``out_positions`` are
    silently skipped.

    ``prefix_sum[i]`` is presumably the number of marked Gaussians before
    ``i`` (exclusive scan of ``split_mask``) - confirm against the caller.
    """
    tid = wp.tid()
    if tid >= len(positions):
        return

    parent_pos = positions[tid]
    parent_scale = scales[tid]
    parent_rot = rotations[tid]
    parent_opacity = opacities[tid]
    parent_sh = tid * 16

    # Pass the parent through unchanged.
    out_positions[tid] = parent_pos
    out_scales[tid] = parent_scale
    out_rotations[tid] = parent_rot
    out_opacities[tid] = parent_opacity
    for c in range(16):
        out_shs[parent_sh + c] = shs[parent_sh + c]

    if split_mask[tid] != 1:
        return

    # All children share the same shrunken scale.
    child_scale = wp.vec3(
        parent_scale[0] * scale_factor,
        parent_scale[1] * scale_factor,
        parent_scale[2] * scale_factor
    )
    first_child = offset + prefix_sum[tid] * N_split
    capacity = len(out_positions)

    for j in range(N_split):
        child = first_child + j
        if child < capacity:
            # Small per-axis jitter; the random state is derived from the
            # child slot so each child gets a different offset.
            jitter = wp.vec3(
                ((wp.randf(wp.uint32(child * 3))) * 2.0 - 1.0) * 0.01,
                ((wp.randf(wp.uint32(child * 3 + 1))) * 2.0 - 1.0) * 0.01,
                ((wp.randf(wp.uint32(child * 3 + 2))) * 2.0 - 1.0) * 0.01
            )
            out_positions[child] = parent_pos + jitter
            out_scales[child] = child_scale
            out_rotations[child] = parent_rot
            out_opacities[child] = parent_opacity
            for k in range(16):
                out_shs[child * 16 + k] = shs[parent_sh + k]
@wp.kernel
def clone_gaussians(
    clone_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),  # shape: [N * 16]
    noise_scale: float,
    offset: int,  # where to start writing new points
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3),
):
    """Copy every Gaussian to the output and append one clone per marked one.

    Thread ``i`` (for ``i < offset``) copies Gaussian ``i`` to slot ``i`` of
    the output arrays.  If ``clone_mask[i] == 1`` a duplicate is written to
    slot ``prefix_sum[i] + offset`` with identical scale, rotation, opacity
    and SH coefficients, and with ``randf * noise_scale`` added to each
    position component (random state taken from the source index).

    ``offset`` doubles as the bounds guard, so it is presumably the number of
    source Gaussians - confirm against the caller.  The noise is presumably
    non-negative (if ``wp.randf`` returns values in [0, 1)) - confirm.
    """
    tid = wp.tid()
    if tid >= offset:
        return

    src_pos = positions[tid]
    src_scale = scales[tid]
    src_rot = rotations[tid]
    src_opacity = opacities[tid]
    src_sh = tid * 16

    # Pass the source Gaussian through unchanged.
    out_positions[tid] = src_pos
    out_scales[tid] = src_scale
    out_rotations[tid] = src_rot
    out_opacities[tid] = src_opacity
    for c in range(16):
        out_shs[src_sh + c] = shs[src_sh + c]

    if clone_mask[tid] != 1:
        return

    dst = prefix_sum[tid] + offset
    seed = tid * 3
    jitter = wp.vec3(
        wp.randf(wp.uint32(seed)) * noise_scale,
        wp.randf(wp.uint32(seed + 1)) * noise_scale,
        wp.randf(wp.uint32(seed + 2)) * noise_scale
    )

    out_positions[dst] = src_pos + jitter
    out_scales[dst] = src_scale
    out_rotations[dst] = src_rot
    out_opacities[dst] = src_opacity
    dst_sh = dst * 16
    for c in range(16):
        out_shs[dst_sh + c] = shs[src_sh + c]
@wp.kernel
def prune_gaussians(
    opacities: wp.array(dtype=float),
    opacity_threshold: float,
    valid_mask: wp.array(dtype=int),
    num_points: int
):
    """Build a keep/remove mask from opacity.

    ``valid_mask[i]`` becomes 1 when ``opacities[i]`` is strictly greater
    than ``opacity_threshold`` and 0 otherwise, for every ``i`` below
    ``num_points``.
    """
    tid = wp.tid()
    if tid < num_points:
        if opacities[tid] > opacity_threshold:
            valid_mask[tid] = 1
        else:
            valid_mask[tid] = 0
@wp.kernel
def compact_gaussians(
    valid_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),  # shape: [N * 16]
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Scatter the Gaussians marked valid into densely packed output arrays.

    Thread ``i`` copies Gaussian ``i`` (position, scale, rotation, opacity
    and its 16 SH entries) to output slot ``prefix_sum[i]`` when
    ``valid_mask[i]`` is non-zero; Gaussians with a zero mask are dropped.

    ``prefix_sum[i]`` is presumably the number of valid Gaussians before
    ``i`` (exclusive scan of ``valid_mask``) - confirm against the caller.
    """
    i = wp.tid()
    # This kernel takes no explicit point count, so bound the thread index by
    # the mask length; a launch with dim larger than the array would
    # otherwise read valid_mask out of range.
    if i >= len(valid_mask):
        return
    if valid_mask[i] == 0:
        return

    new_i = prefix_sum[i]
    out_positions[new_i] = positions[i]
    out_scales[new_i] = scales[i]
    out_rotations[new_i] = rotations[i]
    out_opacities[new_i] = opacities[i]

    # 16 SH entries per Gaussian, stored contiguously.
    src_sh = i * 16
    dst_sh = new_i * 16
    for j in range(16):
        out_shs[dst_sh + j] = shs[src_sh + j]