# 4dgs-dpm/gs/forward.py
import warp as wp
from utils.wp_utils import to_warp_array
from config import *
# Initialize Warp
wp.init()
print("Warp devices:", wp.get_devices())
# Spherical harmonics constants for degrees 0-1 (higher-degree constants are inlined at their point of use below)
SH_C0 = 0.28209479177387814
SH_C1 = 0.4886025119029199
# Native CUDA snippet for bit-level reinterpretation of a float
float_to_uint32_snippet = """
return reinterpret_cast<uint32_t&>(x);
"""
@wp.func_native(float_to_uint32_snippet)
def float_bits_to_uint32(x: float) -> wp.uint32:
...
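# ndc2pix maps normalized device coordinates [-1, 1] to continuous pixel
# coordinates in the reference rasterizer's convention: x = -1 -> -0.5 and
# x = 1 -> size - 0.5, so pixel centers land on integer coordinates.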
@wp.func
def ndc2pix(x: float, size: float) -> float:
return ((x + 1.0) * size - 1.0) * 0.5
@wp.func
def get_rect(p: wp.vec2, max_radius: float, tile_grid: wp.vec3):
    # Clamp the rectangle of tiles touched by a splat of radius max_radius around p
    grid_x = wp.int32(tile_grid[0])
    grid_y = wp.int32(tile_grid[1])
    rect_min_x = wp.min(grid_x, wp.max(wp.int32(0), wp.int32((p[0] - max_radius) / float(TILE_M))))
    rect_min_y = wp.min(grid_y, wp.max(wp.int32(0), wp.int32((p[1] - max_radius) / float(TILE_N))))
    rect_max_x = wp.min(grid_x, wp.max(wp.int32(0), wp.int32((p[0] + max_radius + float(TILE_M) - 1.0) / float(TILE_M))))
    rect_max_y = wp.min(grid_y, wp.max(wp.int32(0), wp.int32((p[1] + max_radius + float(TILE_N) - 1.0) / float(TILE_N))))
    return rect_min_x, rect_min_y, rect_max_x, rect_max_y
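# EWA splatting: the screen-space covariance is Sigma' = J W Sigma W^T J^T,
# where W is the rotation part of the view matrix and J is the affine
# approximation (Jacobian) of the perspective projection. Only the unique
# entries (a, b, c) of the symmetric 2x2 result are returned.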
@wp.func
def compute_cov2d(p_orig: wp.vec3, cov3d: VEC6, view_matrix: wp.mat44,
tan_fovx: float, tan_fovy: float, width: float, height: float) -> wp.vec3:
t = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * view_matrix
limx = 1.3 * tan_fovx
limy = 1.3 * tan_fovy
# Clamp X/Y to stay inside frustum
txtz = t[0] / t[2]
tytz = t[1] / t[2]
t[0] = min(limx, max(-limx, txtz)) * t[2]
t[1] = min(limy, max(-limy, tytz)) * t[2]
focal_x = width / (2.0 * tan_fovx)
focal_y = height / (2.0 * tan_fovy)
# compute Jacobian
J = wp.mat33(
focal_x / t[2], 0.0, -(focal_x * t[0]) / (t[2] * t[2]),
0.0, focal_y / t[2], -(focal_y * t[1]) / (t[2] * t[2]),
0.0, 0.0, 0.0
)
W = wp.mat33(
view_matrix[0, 0], view_matrix[0, 1], view_matrix[0, 2],
view_matrix[1, 0], view_matrix[1, 1], view_matrix[1, 2],
view_matrix[2, 0], view_matrix[2, 1], view_matrix[2, 2]
)
T = J * W
Vrk = wp.mat33(
cov3d[0], cov3d[1], cov3d[2],
cov3d[1], cov3d[3], cov3d[4],
cov3d[2], cov3d[4], cov3d[5]
)
cov = T * wp.transpose(Vrk) * wp.transpose(T)
return wp.vec3(cov[0, 0], cov[0, 1], cov[1, 1])
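# Covariance from scale and rotation: Sigma = (R S)(R S)^T = R S S^T R^T,
# which is symmetric positive semi-definite, so only 6 unique entries are stored.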
@wp.func
def compute_cov3d(scale: wp.vec3, scale_mod: float, rot: wp.vec4) -> VEC6:
# 2DGS: Create flat disk by setting z-scale to near-zero
# The Gaussian becomes a 2D ellipse oriented by the rotation
sz_2dgs = 1e-6 # Near-zero thickness for 2D Gaussian
# Create scaling matrix with modifier applied
S = wp.mat33(
scale_mod * scale[0], 0.0, 0.0,
0.0, scale_mod * scale[1], 0.0,
0.0, 0.0, sz_2dgs # Fixed small z-scale for 2DGS
)
R = wp.quat_to_matrix(wp.quaternion(rot[0], rot[1], rot[2], rot[3]))
M = R * S
# Compute 3D covariance matrix: Sigma = M * M^T
sigma = M * wp.transpose(M)
return VEC6(sigma[0, 0], sigma[0, 1], sigma[0, 2], sigma[1, 1], sigma[1, 2], sigma[2, 2])
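# Preprocess: one thread per Gaussian. Culls against the near plane, projects
# the center and covariance to screen space, evaluates SH color toward the
# camera, and records the screen-space radius plus the number of tiles
# touched (consumed by the prefix-sum and duplication steps below).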
@wp.kernel
def wp_preprocess(
orig_points: wp.array(dtype=wp.vec3),
scales: wp.array(dtype=wp.vec3),
scale_modifier: float,
rotations: wp.array(dtype=wp.vec4),
opacities: wp.array(dtype=float),
shs: wp.array(dtype=wp.vec3),
degree: int,
clamped: bool,
view_matrix: wp.mat44,
proj_matrix: wp.mat44,
cam_pos: wp.vec3,
W: int,
H: int,
tan_fovx: float,
tan_fovy: float,
focal_x: float,
focal_y: float,
radii: wp.array(dtype=int),
points_xy_image: wp.array(dtype=wp.vec2),
depths: wp.array(dtype=float),
cov3Ds: wp.array(dtype=VEC6),
rgb: wp.array(dtype=wp.vec3),
conic_opacity: wp.array(dtype=wp.vec4),
tile_grid: wp.vec3,
tiles_touched: wp.array(dtype=int),
clamped_state: wp.array(dtype=wp.vec3),
prefiltered: bool,
antialiasing: bool
):
    # One thread per Gaussian
    i = wp.tid()
p_orig = orig_points[i]
p_view = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * view_matrix
if p_view[2] < 0.2:
return
p_hom = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * proj_matrix
p_w = 1.0 / (p_hom[3] + 0.0000001)
p_proj = wp.vec3(p_hom[0] * p_w, p_hom[1] * p_w, p_hom[2] * p_w)
cov3d = compute_cov3d(scales[i], scale_modifier, rotations[i])
cov3Ds[i] = cov3d
# Compute 2D covariance matrix
cov2d = compute_cov2d(p_orig, cov3d, view_matrix, tan_fovx, tan_fovy, float(W), float(H))
    # Constants
    h_var = 0.3  # screen-space blur added to the diagonal of the 2D covariance
    W_float = float(W)
    H_float = float(H)
# Add blur/antialiasing factor to covariance
det_cov = cov2d[0] * cov2d[2] - cov2d[1] * cov2d[1]
cov_with_blur = wp.vec3(cov2d[0] + h_var, cov2d[1], cov2d[2] + h_var)
det_cov_plus_h_cov = cov_with_blur[0] * cov_with_blur[2] - cov_with_blur[1] * cov_with_blur[1]
# Invert covariance (EWA algorithm)
det = det_cov_plus_h_cov
if det == 0.0:
return
det_inv = 1.0 / det
conic = wp.vec3(
cov_with_blur[2] * det_inv,
-cov_with_blur[1] * det_inv,
cov_with_blur[0] * det_inv
)
# Compute eigenvalues of covariance matrix to find screen-space extent
mid = 0.5 * (cov_with_blur[0] + cov_with_blur[2])
lambda1 = mid + wp.sqrt(wp.max(0.1, mid * mid - det))
lambda2 = mid - wp.sqrt(wp.max(0.1, mid * mid - det))
my_radius = wp.ceil(3.0 * wp.sqrt(wp.max(lambda1, lambda2)))
# Convert to pixel coordinates
point_image = wp.vec2(ndc2pix(p_proj[0], W_float), ndc2pix(p_proj[1], H_float))
# Get rectangle of affected tiles
rect_min_x, rect_min_y, rect_max_x, rect_max_y = get_rect(point_image, my_radius, tile_grid)
# Skip if rectangle has 0 area
if (rect_max_x - rect_min_x) * (rect_max_y - rect_min_y) == 0:
return
# Compute color from spherical harmonics
pos = p_orig
dir_orig = pos - cam_pos
dir = wp.normalize(dir_orig)
x, y, z = dir[0], dir[1], dir[2]
# Base offset for this Gaussian's SH coefficients
base_idx = i * 16 # assuming degree 3 (16 coefficients)
# Start with the DC component (degree 0)
result = SH_C0 * shs[base_idx]
# Add higher degree terms if requested
if degree > 0:
# Degree 1 terms
result = result - SH_C1 * y * shs[base_idx + 1] + SH_C1 * z * shs[base_idx + 2] - SH_C1 * x * shs[base_idx + 3]
if degree > 1:
# Degree 2 terms
xx = x*x
yy = y*y
zz = z*z
xy = x*y
yz = y*z
xz = x*z
# Degree 2 terms with hardcoded constants
result = result + 1.0925484305920792 * xy * shs[base_idx + 4]
result = result + (-1.0925484305920792) * yz * shs[base_idx + 5]
result = result + 0.31539156525252005 * (2.0 * zz - xx - yy) * shs[base_idx + 6]
result = result + (-1.0925484305920792) * xz * shs[base_idx + 7]
result = result + 0.5462742152960396 * (xx - yy) * shs[base_idx + 8]
if degree > 2:
# Degree 3 terms with hardcoded constants
result = result + (-0.5900435899266435) * y * (3.0 * xx - yy) * shs[base_idx + 9]
result = result + 2.890611442640554 * xy * z * shs[base_idx + 10]
result = result + (-0.4570457994644658) * y * (4.0 * zz - xx - yy) * shs[base_idx + 11]
result = result + 0.3731763325901154 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy) * shs[base_idx + 12]
result = result + (-0.4570457994644658) * x * (4.0 * zz - xx - yy) * shs[base_idx + 13]
result = result + 1.445305721320277 * z * (xx - yy) * shs[base_idx + 14]
result = result + (-0.5900435899266435) * x * (xx - 3.0 * yy) * shs[base_idx + 15]
result = result + wp.vec3(0.5, 0.5, 0.5)
    # Record which color channels go negative before clamping
    # (1.0 = clamped, 0.0 = not), packed into a single wp.vec3
r_clamped = 0.0
g_clamped = 0.0
b_clamped = 0.0
if result[0] < 0.0:
r_clamped = 1.0
if result[1] < 0.0:
g_clamped = 1.0
if result[2] < 0.0:
b_clamped = 1.0
clamped_state[i] = wp.vec3(r_clamped, g_clamped, b_clamped)
if clamped:
# RGB colors are clamped to positive values
result = wp.vec3(
wp.max(result[0], 0.0),
wp.max(result[1], 0.0),
wp.max(result[2], 0.0)
)
rgb[i] = result
# Store computed data
depths[i] = p_view[2]
radii[i] = int(my_radius)
points_xy_image[i] = point_image
# Pack conic and opacity into single vec4
conic_opacity[i] = wp.vec4(conic[0], conic[1], conic[2], opacities[i])
# Store tile information
tiles_touched[i] = (rect_max_y - rect_min_y) * (rect_max_x - rect_min_x)
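# Rasterize: front-to-back alpha compositing per pixel,
#   C = sum_i c_i * alpha_i * T_i,  with  T_{i+1} = T_i * (1 - alpha_i),
# where alpha_i = min(0.99, opacity_i * exp(power_i)) comes from evaluating
# the 2D Gaussian conic. The loop terminates early once T drops below 1e-4.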
@wp.kernel
def wp_render_gaussians(
# Output buffers
rendered_image: wp.array2d(dtype=wp.vec3),
depth_image: wp.array2d(dtype=float),
# Tile data
ranges: wp.array(dtype=wp.vec2i),
point_list: wp.array(dtype=int),
# Image parameters
W: int,
H: int,
# Gaussian data
points_xy_image: wp.array(dtype=wp.vec2),
colors: wp.array(dtype=wp.vec3),
conic_opacity: wp.array(dtype=wp.vec4),
depths: wp.array(dtype=float),
# Background color
background: wp.vec3,
# Tile grid info
tile_grid: wp.vec3,
# Track additional data
final_Ts: wp.array2d(dtype=float),
n_contrib: wp.array2d(dtype=int),
):
tile_x, tile_y, tid_x, tid_y = wp.tid()
    # Defensive bounds check on the tile row (the launch grid already matches
    # the tile grid, so this should not normally trigger)
    if tile_y >= (H + TILE_N - 1) // TILE_N:
        return
    # Pixel origin of this tile
    pix_min_x = tile_x * TILE_M
    pix_min_y = tile_y * TILE_N
# Calculate pixel position for this thread
pix_x = pix_min_x + tid_x
pix_y = pix_min_y + tid_y
# Check if this thread processes a valid pixel
inside = (pix_x < W) and (pix_y < H)
if not inside:
return
pixf_x = float(pix_x)
pixf_y = float(pix_y)
# Get start/end range of IDs to process for this tile
tile_id = tile_y * int(tile_grid[0]) + tile_x
range_start = ranges[tile_id][0]
range_end = ranges[tile_id][1]
# Initialize blending variables
T = float(1.0) # Transmittance
r, g, b = float(0.0), float(0.0), float(0.0) # Accumulated color
expected_depth = float(0.0) # For depth calculation (alpha-weighted Z)
# Track the number of contributors to this pixel
contributor_count = int(0)
last_contributor = int(0)
# Iterate over all Gaussians influencing this tile
for i in range(range_start, range_end):
# Get Gaussian ID
gaussian_id = point_list[i]
# Get Gaussian data
xy = points_xy_image[gaussian_id]
con_o = conic_opacity[gaussian_id]
color = colors[gaussian_id]
# Compute distance to Gaussian center
d_x = xy[0] - pixf_x
d_y = xy[1] - pixf_y
# Increment contributor count for this pixel
contributor_count += 1
# Compute Gaussian power (exponent)
power = -0.5 * (con_o[0] * d_x * d_x + con_o[2] * d_y * d_y) - con_o[1] * d_x * d_y
        # A positive exponent indicates a degenerate conic evaluation
        # (numerical guard, as in the reference CUDA rasterizer); skip
if power > 0.0:
continue
# Compute alpha from power and opacity
alpha = wp.min(0.99, con_o[3] * wp.exp(power))
# Skip if alpha is too small
if alpha < (1.0 / 255.0):
continue
        # Transmittance remaining after compositing this splat
test_T = T * (1.0 - alpha)
if test_T < 0.0001:
break # Early termination if pixel is almost opaque
# Accumulate color contribution
r += color[0] * alpha * T
g += color[1] * alpha * T
b += color[2] * alpha * T
# Accumulate depth (alpha-weighted Z, same units as target depth)
expected_depth += depths[gaussian_id] * alpha * T
# Update transmittance
T = test_T
last_contributor = contributor_count
# Store final transmittance (T) and contributor count
final_Ts[pix_y, pix_x] = T
n_contrib[pix_y, pix_x] = last_contributor
# Write final color to output buffer (color + background)
rendered_image[pix_y, pix_x] = wp.vec3(
r + T * background[0],
g + T * background[1],
b + T * background[2]
)
# Write depth to output buffer (alpha-weighted expected depth)
depth_image[pix_y, pix_x] = expected_depth
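# Duplication: each visible Gaussian is emitted once per tile it overlaps.
# Key layout (int64): [ tile_id : high 32 bits | depth bits : low 32 bits ],
# so a single radix sort orders entries by tile first, then by depth within
# each tile.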
@wp.kernel
def wp_duplicate_with_keys(
points_xy_image: wp.array(dtype=wp.vec2),
depths: wp.array(dtype=float),
point_offsets: wp.array(dtype=int),
point_list_keys_unsorted: wp.array(dtype=wp.int64),
point_list_unsorted: wp.array(dtype=int),
radii: wp.array(dtype=int),
tile_grid: wp.vec3
):
tid = wp.tid()
if tid >= points_xy_image.shape[0]:
return
r = radii[tid]
if r <= 0:
return
# Find the global offset into key/value buffers
offset = 0
if tid > 0:
offset = point_offsets[tid - 1]
pos = points_xy_image[tid]
depth_val = depths[tid]
rect_min_x, rect_min_y, rect_max_x, rect_max_y = get_rect(pos, float(r), tile_grid)
for y in range(rect_min_y, rect_max_y):
for x in range(rect_min_x, rect_max_x):
tile_id = y * int(tile_grid[0]) + x
# Convert to int64 to avoid overflow during bit shift
tile_id_64 = wp.int64(tile_id)
shifted = tile_id_64 << wp.int64(32)
depth_bits = wp.int64(float_bits_to_uint32(depth_val))
# Combine tile ID and depth into single key
key = wp.int64(shifted) | depth_bits
point_list_keys_unsorted[offset] = key
point_list_unsorted[offset] = tid
offset += 1
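# After sorting, all entries belonging to a tile are contiguous; a boundary
# occurs wherever the tile id in the upper 32 bits of the key changes, which
# yields each tile's [start, end) range into point_list.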
@wp.kernel
def wp_identify_tile_ranges(
num_rendered: int,
point_list_keys: wp.array(dtype=wp.int64),
ranges: wp.array(dtype=wp.vec2i) # Each range is (start, end)
):
idx = wp.tid()
if idx >= num_rendered:
return
key = point_list_keys[idx]
curr_tile = int(key >> wp.int64(32))
# Set start of range if first element or tile changed
if idx == 0:
ranges[curr_tile][0] = 0
else:
prev_key = point_list_keys[idx - 1]
prev_tile = int(prev_key >> wp.int64(32))
if curr_tile != prev_tile:
ranges[prev_tile][1] = idx
ranges[curr_tile][0] = idx
# Set end of range if last element
if idx == num_rendered - 1:
ranges[curr_tile][1] = num_rendered
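# Serial inclusive scan executed by a single thread (launched with dim=1);
# O(N) but simple. A parallel alternative (not used here) would be
# wp.utils.array_scan(tiles_touched, point_offsets, inclusive=True).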
@wp.kernel
def wp_prefix_sum(input_array: wp.array(dtype=int),
output_array: wp.array(dtype=int)):
tid = wp.tid()
if tid == 0:
output_array[0] = input_array[0]
# Perform prefix sum
for i in range(1, input_array.shape[0]):
output_array[i] = output_array[i-1] + input_array[i]
@wp.kernel
def wp_copy_int64(src: wp.array(dtype=wp.int64), dst: wp.array(dtype=wp.int64), count: int):
i = wp.tid()
if i < count:
dst[i] = src[i]
@wp.kernel
def wp_copy_int(src: wp.array(dtype=int), dst: wp.array(dtype=int), count: int):
i = wp.tid()
if i < count:
dst[i] = src[i]
@wp.kernel
def track_pixel_stats(
rendered_image: wp.array2d(dtype=wp.vec3),
depth_image: wp.array2d(dtype=float),
background: wp.vec3,
final_Ts: wp.array2d(dtype=float),
n_contrib: wp.array2d(dtype=int),
W: int,
H: int
):
"""Kernel to track final transparency values and contributor counts for each pixel."""
x, y = wp.tid()
if x >= W or y >= H:
return
# Get the rendered pixel
pixel = rendered_image[y, x]
    # Approximate the final transmittance from the background contribution:
    # a pixel dominated by Gaussians has final_T near 0, while a pixel that
    # is mostly background has final_T near 1
diff_r = abs(pixel[0] - background[0])
diff_g = abs(pixel[1] - background[1])
diff_b = abs(pixel[2] - background[2])
has_content = (diff_r > 0.01) or (diff_g > 0.01) or (diff_b > 0.01)
if has_content:
        # final_Ts should already be tracked during rendering; if it was left
        # at zero, approximate it here (a larger deviation from the background
        # implies lower transmittance)
        if final_Ts[y, x] == 0.0:
            max_diff = max(diff_r, max(diff_g, diff_b))
            final_Ts[y, x] = 1.0 - min(0.99, max_diff)
# Set n_contrib to 1 if we know the pixel has content but no contributor count
if n_contrib[y, x] == 0:
n_contrib[y, x] = 1
def render_gaussians(
background,
means3D,
colors=None,
opacity=None,
scales=None,
rotations=None,
scale_modifier=1.0,
viewmatrix=None,
projmatrix=None,
tan_fovx=0.5,
tan_fovy=0.5,
image_height=256,
image_width=256,
sh=None,
degree=3,
campos=None,
prefiltered=False,
antialiasing=False,
clamped=True,
debug=False,
):
"""Render 3D Gaussians using Warp.
Args:
background: Background color tensor of shape (3,)
means3D: 3D positions tensor of shape (N, 3)
        colors: Optional RGB colors tensor of shape (N, 3); currently unused,
            colors are always evaluated from the SH coefficients
opacity: Opacity values tensor of shape (N, 1) or (N,)
scales: Scales tensor of shape (N, 3)
rotations: Rotation quaternions of shape (N, 4)
scale_modifier: Global scale modifier (float)
viewmatrix: View matrix tensor of shape (4, 4)
projmatrix: Projection matrix tensor of shape (4, 4)
tan_fovx: Tangent of the horizontal field of view
tan_fovy: Tangent of the vertical field of view
image_height: Height of the output image
image_width: Width of the output image
sh: Spherical harmonics coefficients tensor of shape (N, D, 3)
degree: Degree of spherical harmonics
campos: Camera position tensor of shape (3,)
prefiltered: Whether input Gaussians are prefiltered
antialiasing: Whether to apply antialiasing
clamped: Whether to clamp the colors
debug: Whether to print debug information
Returns:
Tuple of (rendered_image, depth_image, intermediate_buffers)
"""
rendered_image = wp.zeros((image_height, image_width), dtype=wp.vec3, device=DEVICE)
depth_image = wp.zeros((image_height, image_width), dtype=float, device=DEVICE)
# Create additional buffers for tracking transparency and contributors
final_Ts = wp.zeros((image_height, image_width), dtype=float, device=DEVICE)
n_contrib = wp.zeros((image_height, image_width), dtype=int, device=DEVICE)
background_warp = wp.vec3(background[0], background[1], background[2])
    points_warp = to_warp_array(means3D, wp.vec3)
    # SH coefficients arrive as (N, 16, 3); flatten to (N * 16, 3) so each
    # coefficient becomes one wp.vec3 and kernels can index i * 16 + k
    sh_data = sh.reshape(-1, 3) if hasattr(sh, 'reshape') else sh
    shs_warp = to_warp_array(sh_data, wp.vec3)
    # Remaining per-Gaussian parameters
    opacities_warp = to_warp_array(opacity, float, flatten=True)
    scales_warp = to_warp_array(scales, wp.vec3)
    rotations_warp = to_warp_array(rotations, wp.vec4)
# Handle camera parameters
view_matrix_warp = wp.mat44(viewmatrix.flatten()) if not isinstance(viewmatrix, wp.mat44) else viewmatrix
proj_matrix_warp = wp.mat44(projmatrix.flatten()) if not isinstance(projmatrix, wp.mat44) else projmatrix
campos_warp = wp.vec3(campos[0], campos[1], campos[2]) if not isinstance(campos, wp.vec3) else campos
# Calculate tile grid for spatial optimization
tile_grid = wp.vec3((image_width + TILE_M - 1) // TILE_M,
(image_height + TILE_N - 1) // TILE_N,
1)
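    # Example: with TILE_M = TILE_N = 16 (a common choice; the actual values
    # come from config), a 256x256 image yields a 16x16 tile grid.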
# Preallocate buffers for preprocessed data
num_points = points_warp.shape[0]
radii = wp.zeros(num_points, dtype=int, device=DEVICE)
points_xy_image = wp.zeros(num_points, dtype=wp.vec2, device=DEVICE)
depths = wp.zeros(num_points, dtype=float, device=DEVICE)
cov3Ds = wp.zeros(num_points, dtype=VEC6, device=DEVICE)
rgb = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
conic_opacity = wp.zeros(num_points, dtype=wp.vec4, device=DEVICE)
tiles_touched = wp.zeros(num_points, dtype=int, device=DEVICE)
# Add clamped_state buffer to track which color channels are clamped
clamped_state = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
if debug:
print(f"\nWARP RENDERING: {image_width}x{image_height} image, {num_points} gaussians")
print(f"Colors: {'from SH' if colors is None else 'provided'}, SH degree: {degree}")
print(f"Antialiasing: {antialiasing}, Prefiltered: {prefiltered}")
# Launch preprocessing kernel
wp.launch(
kernel=wp_preprocess,
dim=(num_points,),
inputs=[
points_warp, # orig_points
scales_warp, # scales
scale_modifier, # scale_modifier
rotations_warp, # rotations_quat
opacities_warp, # opacities
shs_warp, # shs
degree,
clamped, # clamped
view_matrix_warp, # view_matrix
proj_matrix_warp, # proj_matrix
campos_warp, # cam_pos
image_width, # W
image_height, # H
tan_fovx, # tan_fovx
tan_fovy, # tan_fovy
image_width / (2.0 * tan_fovx), # focal_x
image_height / (2.0 * tan_fovy), # focal_y
radii, # radii
points_xy_image, # points_xy_image
depths, # depths
cov3Ds, # cov3Ds
rgb, # rgb
conic_opacity, # conic_opacity
tile_grid, # tile_grid
tiles_touched, # tiles_touched
            clamped_state,   # clamped_state (per-Gaussian clamp flags)
prefiltered, # prefiltered
antialiasing # antialiasing
],
)
point_offsets = wp.zeros(num_points, dtype=int, device=DEVICE)
wp.launch(
kernel=wp_prefix_sum,
dim=1,
inputs=[
tiles_touched,
point_offsets
]
)
    # Total number of (gaussian, tile) pairs emitted by duplication
    num_rendered = int(wp.to_torch(point_offsets)[-1].item())
    if num_rendered > (1 << 30):
        # The radix sort below needs 2x working memory, so cap the workload
        raise ValueError("Number of rendered points exceeds the maximum supported by Warp.")
point_list_keys_unsorted = wp.zeros(num_rendered, dtype=wp.int64, device=DEVICE)
point_list_unsorted = wp.zeros(num_rendered, dtype=int, device=DEVICE)
point_list_keys = wp.zeros(num_rendered, dtype=wp.int64, device=DEVICE)
point_list = wp.zeros(num_rendered, dtype=int, device=DEVICE)
wp.launch(
kernel=wp_duplicate_with_keys,
dim=num_points,
inputs=[
points_xy_image,
depths,
point_offsets,
point_list_keys_unsorted,
point_list_unsorted,
radii,
tile_grid
]
    )
point_list_keys_unsorted_padded = wp.zeros(num_rendered * 2, dtype=wp.int64, device=DEVICE)
point_list_unsorted_padded = wp.zeros(num_rendered * 2, dtype=int, device=DEVICE)
# Copy data to padded arrays
wp.copy(point_list_keys_unsorted_padded, point_list_keys_unsorted)
wp.copy(point_list_unsorted_padded, point_list_unsorted)
wp.utils.radix_sort_pairs(
point_list_keys_unsorted_padded, # keys to sort
point_list_unsorted_padded, # values to sort along with keys
num_rendered # number of elements to sort
)
wp.launch(
kernel=wp_copy_int64,
dim=num_rendered,
inputs=[
point_list_keys_unsorted_padded,
point_list_keys,
num_rendered
]
)
wp.launch(
kernel=wp_copy_int,
dim=num_rendered,
inputs=[
point_list_unsorted_padded,
point_list,
num_rendered
]
)
tile_count = int(tile_grid[0] * tile_grid[1])
ranges = wp.zeros(tile_count, dtype=wp.vec2i, device=DEVICE) # each is (start, end)
if num_rendered > 0:
wp.launch(
            kernel=wp_identify_tile_ranges,
dim=num_rendered,
inputs=[
num_rendered,
point_list_keys,
ranges
]
)
wp.launch(
kernel=wp_render_gaussians,
dim=(int(tile_grid[0]), int(tile_grid[1]), TILE_M, TILE_N),
inputs=[
rendered_image, # Output color image
depth_image, # Output depth image
ranges, # Tile ranges
point_list, # Sorted point indices
image_width, # Image width
image_height, # Image height
points_xy_image, # 2D points
rgb, # Precomputed colors
conic_opacity, # Conic matrices and opacities
depths, # Depth values
background_warp, # Background color
tile_grid, # Tile grid configuration
final_Ts, # Final transparency values
n_contrib, # Number of contributors per pixel
]
)
    # Fallback pass: approximate final_Ts and n_contrib for any pixel the
    # render kernel left unpopulated
wp.launch(
kernel=track_pixel_stats,
dim=(image_width, image_height),
inputs=[
rendered_image,
depth_image,
background_warp,
final_Ts,
n_contrib,
image_width,
image_height
]
)
return rendered_image, depth_image, {
"radii": radii,
"point_offsets": point_offsets,
"points_xy_image": points_xy_image,
"depths": depths,
"colors": rgb,
"cov3Ds": cov3Ds,
"conic_opacity": conic_opacity,
"point_list": point_list,
"ranges": ranges,
"final_Ts": final_Ts, # Add final_Ts to intermediate buffers
"n_contrib": n_contrib, # Add contributor count to intermediate buffers
"clamped_state": clamped_state # Add clamped state to intermediate buffers
}
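

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). Assumptions:
    # to_warp_array accepts NumPy arrays, identity view/projection matrices are
    # acceptable for a quick sanity check, rotations use the (x, y, z, w) order
    # that wp.quaternion expects, and torch is available for the prefix-sum
    # readback inside render_gaussians.
    import numpy as np

    n = 8
    rng = np.random.default_rng(0)
    means = rng.uniform(-0.5, 0.5, size=(n, 3)).astype(np.float32)
    means[:, 2] += 2.0  # keep points in front of the z = 0.2 near-plane cull
    sh = np.zeros((n, 16, 3), dtype=np.float32)
    sh[:, 0, :] = rng.uniform(0.0, 1.0, size=(n, 3))  # DC term only

    image, depth, buffers = render_gaussians(
        background=np.zeros(3, dtype=np.float32),
        means3D=means,
        opacity=np.full(n, 0.8, dtype=np.float32),
        scales=np.full((n, 3), 0.05, dtype=np.float32),
        rotations=np.tile(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (n, 1)),
        viewmatrix=np.eye(4, dtype=np.float32),
        projmatrix=np.eye(4, dtype=np.float32),
        sh=sh,
        campos=np.zeros(3, dtype=np.float32),
        image_height=64,
        image_width=64,
    )
    print("rendered:", image.shape, "depth:", depth.shape)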