SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
"""
Fused KNN Gather + Attention kernel.

Replaces the two-step gather-then-attend pattern in KNNAttention with a single
CUDA kernel that never materializes the [N, K, C] intermediate tensors.

Usage:
    from .fused_knn_attn import fused_knn_attention  # dispatches to an autograd Function

    # Same result as the unfused version, but faster and using less memory
    out = fused_knn_attention(q, k, v, knn_idx, scale)
"""
import torch
from torch.autograd import Function
# Try to import the compiled CUDA extension; fall back to pure-PyTorch
try:
import fused_knn_attn_cuda as _C
FUSED_KNN_ATTN_CUDA_AVAILABLE = True
except ImportError:
FUSED_KNN_ATTN_CUDA_AVAILABLE = False
class FusedKNNAttentionFunction(Function):
    """Autograd function for fused KNN gather + scaled dot-product attention.

    Both passes are delegated to the compiled ``fused_knn_attn_cuda``
    extension (bound to ``_C`` at module level).

    Forward:
        q, k, v: [N, C] (query, key, value features for all points)
        idx:     [N, K] (pre-computed KNN indices; cast to int32 here)
        scale:   float  (attention scale factor, typically head_dim ** -0.5)

    Returns:
        out: [N, C], always float32 because the inputs are cast to float32
        before the kernel is launched.
    """

    @staticmethod
    def forward(ctx, q, k, v, idx, scale):
        # The kernel is handed contiguous float32 features and int32 indices.
        q = q.contiguous().float()
        k = k.contiguous().float()
        v = v.contiguous().float()
        idx = idx.contiguous().int()

        N, C = q.shape
        num_k = idx.shape[1]

        # Output buffers filled in place by the kernel.
        out = torch.empty((N, C), dtype=torch.float32, device=q.device)
        attn_weights = torch.empty(
            (N, num_k), dtype=torch.float32, device=q.device
        )

        _C.fused_knn_attn_forward_cuda(
            q, k, v, idx, out, attn_weights,
            N, C, num_k, float(scale),
        )

        # Keep the converted tensors so backward passes the kernel exactly
        # the same layout/dtype that forward used.
        ctx.save_for_backward(q, k, v, idx, attn_weights)
        ctx.scale = scale
        ctx.N = N
        ctx.C = C
        ctx.num_k = num_k
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, idx, attn_weights = ctx.saved_tensors
        scale = ctx.scale
        N, C, num_k = ctx.N, ctx.C, ctx.num_k

        grad_out = grad_out.contiguous().float()

        # Zero-initialised gradient buffers passed to the kernel. Each one is
        # sized from the tensor it belongs to (matching the pure-PyTorch
        # fallback) rather than from the query shape, so grad_k / grad_v
        # always have the shape autograd expects for k / v. For the
        # documented [N, C] inputs this is identical to an (N, C) float32
        # buffer.
        grad_q = torch.zeros_like(q)
        grad_k = torch.zeros_like(k)
        grad_v = torch.zeros_like(v)

        _C.fused_knn_attn_backward_cuda(
            grad_out, q, k, v, idx, attn_weights,
            grad_q, grad_k, grad_v,
            N, C, num_k, float(scale),
        )

        # idx and scale don't need gradients
        return grad_q, grad_k, grad_v, None, None
class FusedKNNAttentionFunctionPyTorch(Function):
    """Pure-PyTorch fallback (same semantics, no CUDA extension required).

    Avoids materializing full [N, K, C] tensors by iterating over the K
    neighbors one column at a time, so every temporary is at most [N, C].

    Forward:
        q, k, v: [N, C] (query, key, value features)
        idx:     [N, K] (pre-computed KNN indices, any integer dtype)
        scale:   float  (attention scale factor)

    Returns:
        out: [N, C] with the dtype and device of ``q``.
    """

    @staticmethod
    def forward(ctx, q, k, v, idx, scale):
        N, C = q.shape
        num_k = idx.shape[1]
        # Tensor indexing needs int64; convert once instead of per neighbor.
        idx_long = idx.long()

        # scores[:, kk] = scale * <q, K[idx[:, kk]]>
        scores = torch.empty(N, num_k, device=q.device, dtype=q.dtype)
        for kk in range(num_k):
            k_neighbor = k[idx_long[:, kk]]  # [N, C]
            scores[:, kk] = (q * k_neighbor).sum(dim=-1) * scale

        attn_weights = torch.softmax(scores, dim=-1)  # [N, K]

        # out = sum_kk attn[:, kk] * V[idx[:, kk]]
        out = torch.zeros(N, C, device=q.device, dtype=q.dtype)
        for kk in range(num_k):
            out += attn_weights[:, kk:kk + 1] * v[idx_long[:, kk]]

        ctx.save_for_backward(q, k, v, idx, attn_weights)
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, idx, attn_weights = ctx.saved_tensors
        scale = ctx.scale
        N, C = q.shape
        num_k = idx.shape[1]
        idx_long = idx.long()

        # Only compute the gradients autograd actually asked for.
        need_q, need_k, need_v = ctx.needs_input_grad[:3]

        grad_scores = None
        if need_q or need_k:
            # grad_attn[:, kk] = <grad_out, V[idx[:, kk]]>
            grad_attn = torch.empty(N, num_k, device=q.device, dtype=q.dtype)
            for kk in range(num_k):
                v_neighbor = v[idx_long[:, kk]]
                grad_attn[:, kk] = (grad_out * v_neighbor).sum(dim=-1)

            # Softmax backward:
            #   grad_scores = attn * (grad_attn - sum(attn * grad_attn))
            ds = (attn_weights * grad_attn).sum(dim=-1, keepdim=True)  # [N, 1]
            grad_scores = attn_weights * (grad_attn - ds)  # [N, K]

        grad_q = (
            torch.zeros(N, C, device=q.device, dtype=q.dtype)
            if need_q else None
        )
        grad_k = torch.zeros_like(k) if need_k else None
        grad_v = torch.zeros_like(v) if need_v else None

        # Single pass over the neighbors; each gradient still accumulates its
        # contributions in ascending kk order.
        for kk in range(num_k):
            neighbor_idx = idx_long[:, kk]
            if grad_scores is not None:
                g = grad_scores[:, kk:kk + 1]  # [N, 1]
                if need_q:
                    grad_q += g * k[neighbor_idx] * scale
                if need_k:
                    # Scatter-add: several queries may share one neighbor.
                    grad_k.index_add_(0, neighbor_idx, g * q * scale)
            if need_v:
                grad_v.index_add_(
                    0, neighbor_idx, attn_weights[:, kk:kk + 1] * grad_out
                )

        # idx and scale don't need gradients
        return grad_q, grad_k, grad_v, None, None
def fused_knn_attention(q, k, v, idx, scale):
    """Run fused KNN gather + attention on the best available backend.

    The compiled CUDA kernel is selected when the extension imported
    successfully and ``q`` lives on a CUDA device; otherwise the pure-PyTorch
    fallback is used.

    Args:
        q: [N, C] query features
        k: [N, C] key features
        v: [N, C] value features
        idx: [N, K] pre-computed KNN neighbor indices (int32)
        scale: attention scale factor

    Returns:
        out: [N, C] attention output
    """
    use_cuda_kernel = FUSED_KNN_ATTN_CUDA_AVAILABLE and q.is_cuda
    backend = (
        FusedKNNAttentionFunction
        if use_cuda_kernel
        else FusedKNNAttentionFunctionPyTorch
    )
    return backend.apply(q, k, v, idx, scale)