| import warp as wp |
| from utils.wp_utils import to_warp_array, wp_vec3_mul_element, wp_vec3_add_element, wp_vec3_sqrt, wp_vec3_div_element, wp_vec3_clamp |
| from config import * |
|
|
@wp.kernel
def adam_update(
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),

    pos_grads: wp.array(dtype=wp.vec3),
    scale_grads: wp.array(dtype=wp.vec3),
    rot_grads: wp.array(dtype=wp.vec4),
    opacity_grads: wp.array(dtype=float),
    sh_grads: wp.array(dtype=wp.vec3),

    m_positions: wp.array(dtype=wp.vec3),
    m_scales: wp.array(dtype=wp.vec3),
    m_rotations: wp.array(dtype=wp.vec4),
    m_opacities: wp.array(dtype=float),
    m_shs: wp.array(dtype=wp.vec3),

    v_positions: wp.array(dtype=wp.vec3),
    v_scales: wp.array(dtype=wp.vec3),
    v_rotations: wp.array(dtype=wp.vec4),
    v_opacities: wp.array(dtype=float),
    v_shs: wp.array(dtype=wp.vec3),

    num_points: int,
    lr_pos: float,
    lr_scale: float,
    lr_rot: float,
    lr_opac: float,
    lr_sh: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    iteration: int
):
    """Apply one Adam step to every per-Gaussian parameter group.

    One thread handles one point ``p``. For each group it updates the
    first-moment (``m_*``) and second-moment (``v_*``) buffers in place,
    bias-corrects them with exponent ``iteration + 1`` (so ``iteration`` is
    presumably 0-based -- confirm with the caller) and applies
    ``param -= lr * m_hat / (sqrt(v_hat) + epsilon)`` component-wise.

    Post-step constraints:
      * scales are floored at 0.001 per component;
      * rotations are divided by their Euclidean length when it is > 0;
      * opacities are clamped to [0, 1].

    SH arrays are addressed as ``p * 16 + k`` for ``k`` in ``range(16)``.
    """
    p = wp.tid()
    if p >= num_points:
        return

    # Values shared by every parameter group.
    t = float(iteration + 1)
    bc1 = 1.0 - wp.pow(beta1, t)
    bc2 = 1.0 - wp.pow(beta2, t)
    w1 = 1.0 - beta1
    w2 = 1.0 - beta2
    eps3 = wp.vec3(epsilon, epsilon, epsilon)

    # --- positions ---
    g_pos = pos_grads[p]
    m_pos = beta1 * m_positions[p] + w1 * g_pos
    v_pos = beta2 * v_positions[p] + w2 * wp_vec3_mul_element(g_pos, g_pos)
    m_positions[p] = m_pos
    v_positions[p] = v_pos
    den_pos = wp_vec3_sqrt(v_pos / bc2) + eps3
    positions[p] = positions[p] - lr_pos * wp_vec3_div_element(m_pos / bc1, den_pos)

    # --- scales (floored at 0.001 after the step) ---
    g_scale = scale_grads[p]
    m_scale = beta1 * m_scales[p] + w1 * g_scale
    v_scale = beta2 * v_scales[p] + w2 * wp_vec3_mul_element(g_scale, g_scale)
    m_scales[p] = m_scale
    v_scales[p] = v_scale
    den_scale = wp_vec3_sqrt(v_scale / bc2) + eps3
    d_scale = lr_scale * wp_vec3_div_element(m_scale / bc1, den_scale)
    s = scales[p]
    scales[p] = wp.vec3(
        wp.max(s[0] - d_scale[0], 0.001),
        wp.max(s[1] - d_scale[1], 0.001),
        wp.max(s[2] - d_scale[2], 0.001),
    )

    # --- rotations (vec4, handled component by component) ---
    g_rot = rot_grads[p]
    g_rot_sq = wp.vec4(
        g_rot[0] * g_rot[0],
        g_rot[1] * g_rot[1],
        g_rot[2] * g_rot[2],
        g_rot[3] * g_rot[3],
    )
    m_rot = beta1 * m_rotations[p] + w1 * g_rot
    v_rot = beta2 * v_rotations[p] + w2 * g_rot_sq
    m_rotations[p] = m_rot
    v_rotations[p] = v_rot
    m_rot_hat = m_rot / bc1
    v_rot_hat = v_rot / bc2
    q = rotations[p] - wp.vec4(
        lr_rot * m_rot_hat[0] / (wp.sqrt(v_rot_hat[0]) + epsilon),
        lr_rot * m_rot_hat[1] / (wp.sqrt(v_rot_hat[1]) + epsilon),
        lr_rot * m_rot_hat[2] / (wp.sqrt(v_rot_hat[2]) + epsilon),
        lr_rot * m_rot_hat[3] / (wp.sqrt(v_rot_hat[3]) + epsilon),
    )
    rotations[p] = q

    # Re-normalise; a zero-length result is left untouched to avoid 0/0.
    q_len = wp.sqrt(q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3])
    if q_len > 0.0:
        rotations[p] = wp.vec4(q[0] / q_len, q[1] / q_len, q[2] / q_len, q[3] / q_len)

    # --- opacity (clamped to [0, 1] after the step) ---
    g_op = opacity_grads[p]
    m_op = beta1 * m_opacities[p] + w1 * g_op
    v_op = beta2 * v_opacities[p] + w2 * (g_op * g_op)
    m_opacities[p] = m_op
    v_opacities[p] = v_op
    m_op_hat = m_op / bc1
    d_op = lr_opac * m_op_hat / (wp.sqrt(v_op / bc2) + epsilon)
    opacities[p] = wp.max(wp.min(opacities[p] - d_op, 1.0), 0.0)

    # --- spherical-harmonic coefficients: 16 vec3 entries per point ---
    sh_base = p * 16
    for k in range(16):
        idx = sh_base + k
        g_sh = sh_grads[idx]
        m_sh = beta1 * m_shs[idx] + w1 * g_sh
        v_sh = beta2 * v_shs[idx] + w2 * wp_vec3_mul_element(g_sh, g_sh)
        m_shs[idx] = m_sh
        v_shs[idx] = v_sh
        den_sh = wp_vec3_sqrt(v_sh / bc2) + eps3
        shs[idx] = shs[idx] - lr_sh * wp_vec3_div_element(m_sh / bc1, den_sh)
|
|
|
|
@wp.kernel
def reset_opacities(
    opacities: wp.array(dtype=float),
    max_opacity: float,
    num_points: int
):
    """Cap every opacity at ``max_opacity`` to prevent oversaturation.

    Only opacities above the cap are lowered; values already at or below
    ``max_opacity`` are left unchanged. The previous version assigned
    ``max_opacity`` unconditionally. That raised weak Gaussians up to the cap,
    which can lift them above the threshold that ``prune_gaussians`` uses to
    discard them.

    Args:
        opacities: per-point opacity, modified in place.
        max_opacity: upper bound applied to each opacity.
        num_points: number of valid entries; extra threads exit early.
    """
    i = wp.tid()
    if i >= num_points:
        return

    opacities[i] = wp.min(opacities[i], max_opacity)
|
|
@wp.kernel
def reset_densification_stats(
    xyz_gradient_accum: wp.array(dtype=float),
    denom: wp.array(dtype=float),
    max_radii2D: wp.array(dtype=float),
    num_points: int
):
    """Zero the per-point densification accumulators.

    Meant to run after the number of Gaussians changes. Thread ``tid`` clears
    entry ``tid`` of all three arrays; threads with ``tid >= num_points`` do
    nothing.
    """
    tid = wp.tid()
    if tid < num_points:
        zero = float(0.0)
        max_radii2D[tid] = zero
        denom[tid] = zero
        xyz_gradient_accum[tid] = zero
|
|
|
|
@wp.kernel
def mark_split_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    split_mask: wp.array(dtype=int),
    num_points: int
):
    """Flag large, high-gradient Gaussians for splitting.

    ``split_mask[i]`` is set to 1 when ``grads[i] >= grad_threshold`` AND the
    largest scale component is ``> percent_dense * scene_extent``; otherwise 0.
    """
    tid = wp.tid()
    if tid >= num_points:
        return

    s = scales[tid]
    largest = wp.max(wp.max(s[0], s[1]), s[2])
    limit = percent_dense * scene_extent

    if grads[tid] >= grad_threshold and largest > limit:
        split_mask[tid] = 1
        return
    split_mask[tid] = 0
|
|
@wp.kernel
def mark_clone_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    clone_mask: wp.array(dtype=int),
    num_points: int
):
    """Flag small, high-gradient Gaussians for cloning.

    ``clone_mask[i]`` is set to 1 when ``grads[i] >= grad_threshold`` AND the
    largest scale component is ``<= percent_dense * scene_extent``; otherwise
    0. The size test is the complement of the one in ``mark_split_candidates``.
    """
    tid = wp.tid()
    if tid >= num_points:
        return

    s = scales[tid]
    largest = wp.max(wp.max(s[0], s[1]), s[2])
    limit = percent_dense * scene_extent

    if grads[tid] >= grad_threshold and largest <= limit:
        clone_mask[tid] = 1
        return
    clone_mask[tid] = 0
|
|
@wp.kernel
def split_gaussians(
    split_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    N_split: int,
    scale_factor: float,
    offset: int,
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Split large Gaussians into multiple smaller ones.

    Every source Gaussian ``src`` is first copied unchanged to output slot
    ``src``; the parent is NOT removed here (NOTE(review): confirm the caller
    prunes split parents). For each source with ``split_mask[src] == 1``,
    ``N_split`` children are written starting at
    ``offset + prefix_sum[src] * N_split``. Each child has:
      * scale multiplied by ``scale_factor``;
      * position jittered uniformly in [-0.01, 0.01) per axis;
      * rotation, opacity and 16 SH coefficients copied from the parent.
    Children whose slot is ``>= len(out_positions)`` are skipped. The jitter
    is seeded only by the child index, so it repeats across calls.
    """
    src = wp.tid()
    if src >= len(positions):
        return

    src_sh = src * 16

    # Carry the source Gaussian over to the same slot.
    out_positions[src] = positions[src]
    out_scales[src] = scales[src]
    out_rotations[src] = rotations[src]
    out_opacities[src] = opacities[src]
    for k in range(16):
        out_shs[src_sh + k] = shs[src_sh + k]

    if split_mask[src] != 1:
        return

    # Child scale is the same for all children of this parent.
    s = scales[src]
    child_scale = wp.vec3(
        s[0] * scale_factor,
        s[1] * scale_factor,
        s[2] * scale_factor,
    )
    first_child = offset + prefix_sum[src] * N_split
    capacity = len(out_positions)

    for c in range(N_split):
        dst = first_child + c
        if dst < capacity:
            seed = dst * 3
            jitter = wp.vec3(
                (wp.randf(wp.uint32(seed)) * 2.0 - 1.0) * 0.01,
                (wp.randf(wp.uint32(seed + 1)) * 2.0 - 1.0) * 0.01,
                (wp.randf(wp.uint32(seed + 2)) * 2.0 - 1.0) * 0.01,
            )

            out_positions[dst] = positions[src] + jitter
            out_scales[dst] = child_scale
            out_rotations[dst] = rotations[src]
            out_opacities[dst] = opacities[src]

            # NOTE(review): assumes out_shs holds 16 entries per output slot.
            dst_sh = dst * 16
            for k in range(16):
                out_shs[dst_sh + k] = shs[src_sh + k]
|
|
|
|
@wp.kernel
def clone_gaussians(
    clone_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),

    noise_scale: float,
    offset: int,
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3),
):
    """Copy the source Gaussians and append one jittered clone per flagged point.

    Threads with ``i >= offset`` exit immediately, so ``offset`` acts as the
    source point count. Each source ``i`` is copied to output slot ``i``,
    including its 16 SH coefficients at ``i * 16 + k``. When
    ``clone_mask[i] == 1``, a clone is written to ``prefix_sum[i] + offset``:
      * its position is offset by uniform noise in [-noise_scale, noise_scale)
        on each axis;
      * scale, rotation, opacity and SH coefficients are copied unchanged.
    Clones whose slot falls outside ``out_positions`` are skipped.
    """
    i = wp.tid()
    if i >= offset:
        return

    out_positions[i] = positions[i]
    out_scales[i] = scales[i]
    out_rotations[i] = rotations[i]
    out_opacities[i] = opacities[i]
    for j in range(16):
        out_shs[i * 16 + j] = shs[i * 16 + j]

    if clone_mask[i] != 1:
        return

    base_idx = prefix_sum[i] + offset
    # Guard an undersized output buffer, as split_gaussians does for children.
    if base_idx >= len(out_positions):
        return

    # Centred jitter. The previous randf(...) * noise_scale lay in
    # [0, noise_scale) and always pushed clones toward +x/+y/+z.
    noise = wp.vec3(
        (wp.randf(wp.uint32(i * 3)) * 2.0 - 1.0) * noise_scale,
        (wp.randf(wp.uint32(i * 3 + 1)) * 2.0 - 1.0) * noise_scale,
        (wp.randf(wp.uint32(i * 3 + 2)) * 2.0 - 1.0) * noise_scale,
    )

    out_positions[base_idx] = positions[i] + noise
    out_scales[base_idx] = scales[i]
    out_rotations[base_idx] = rotations[i]
    out_opacities[base_idx] = opacities[i]

    # NOTE(review): assumes out_shs holds 16 entries per output slot.
    for j in range(16):
        out_shs[base_idx * 16 + j] = shs[i * 16 + j]
|
|
@wp.kernel
def prune_gaussians(
    opacities: wp.array(dtype=float),
    opacity_threshold: float,
    valid_mask: wp.array(dtype=int),
    num_points: int
):
    """Mark which Gaussians survive an opacity-based prune.

    ``valid_mask[i]`` is 1 when ``opacities[i] > opacity_threshold`` (strictly
    greater), otherwise 0. Threads with ``i >= num_points`` do nothing.
    """
    tid = wp.tid()
    if tid >= num_points:
        return

    if opacities[tid] > opacity_threshold:
        valid_mask[tid] = 1
        return
    valid_mask[tid] = 0
|
|
@wp.kernel
def compact_gaussians(
    valid_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),

    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Scatter the Gaussians kept by ``valid_mask`` into a packed output.

    Thread ``i`` copies point ``i`` to slot ``prefix_sum[i]`` when
    ``valid_mask[i]`` is non-zero. The 16 SH coefficients are copied from
    ``i * 16 + k`` to ``prefix_sum[i] * 16 + k``. ``prefix_sum`` is presumably
    an exclusive scan of ``valid_mask``, so kept points land in consecutive
    slots -- confirm against the caller.
    """
    i = wp.tid()
    # Bounds guard, matching every other kernel in this file, so an
    # oversized launch cannot read valid_mask / prefix_sum out of range.
    if i >= len(valid_mask):
        return
    if valid_mask[i] == 0:
        return

    new_i = prefix_sum[i]

    out_positions[new_i] = positions[i]
    out_scales[new_i] = scales[i]
    out_rotations[new_i] = rotations[i]
    out_opacities[new_i] = opacities[i]

    for j in range(16):
        out_shs[new_i * 16 + j] = shs[i * 16 + j]
|
|