Spaces:
Runtime error
Runtime error
| """ | |
| Fused KNN Gather + Attention kernel. | |
| Replaces the two-step gather-then-attend pattern in KNNAttention with a single | |
| CUDA kernel that never materializes the [N, K, C] intermediate tensors. | |
| Usage: | |
| from .fused_knn_attn import fused_knn_attention # dispatcher: CUDA kernel if available, else PyTorch fallback | |
| # Same result as the unfused version, but faster and with lower memory use | |
| out = fused_knn_attention(q, k, v, knn_idx, scale) | |
| """ | |
| import torch | |
| from torch.autograd import Function | |
# The compiled CUDA extension is optional: record whether it could be
# imported so callers can choose between it and a pure-PyTorch code path.
try:
    import fused_knn_attn_cuda as _C
except ImportError:
    FUSED_KNN_ATTN_CUDA_AVAILABLE = False
else:
    FUSED_KNN_ATTN_CUDA_AVAILABLE = True
class FusedKNNAttentionFunction(Function):
    """Autograd function for fused KNN gather + scaled dot-product attention.

    Both passes are delegated to the compiled ``fused_knn_attn_cuda``
    extension (bound to ``_C`` at module level), so this class is only usable
    when that import succeeded.

    Forward:
        q, k, v: [N, C]  (query, key, value features for all points)
        idx:     [N, K]  (pre-computed KNN indices, int32)
        scale:   float   (attention scale factor, typically head_dim ** -0.5)

    Returns:
        out: [N, C]. Always float32, because every floating input is cast to
        float32 before the kernel is launched.
    """

    @staticmethod
    def forward(ctx, q, k, v, idx, scale):
        """Run the fused forward kernel and stash what backward needs.

        Args:
            ctx: autograd context used to carry tensors/values to ``backward``.
            q, k, v: feature tensors; converted to contiguous float32.
            idx: neighbor indices; converted to contiguous int32.
            scale: multiplier forwarded to the kernel as a Python float.

        Returns:
            A freshly allocated float32 tensor of shape ``(N, C)`` on
            ``q.device`` that the kernel fills in.
        """
        # The kernel is handed raw buffers, so normalise layout and dtype here.
        q = q.contiguous().float()
        k = k.contiguous().float()
        v = v.contiguous().float()
        idx = idx.contiguous().int()

        N, C = q.shape
        num_k = idx.shape[1]

        # Output buffers are written by the kernel; no need to zero them first.
        out = torch.empty((N, C), dtype=torch.float32, device=q.device)
        attn_weights = torch.empty(
            (N, num_k), dtype=torch.float32, device=q.device
        )

        _C.fused_knn_attn_forward_cuda(
            q, k, v, idx, out, attn_weights,
            N, C, num_k, float(scale),
        )

        # Save the normalised (float32 / int32) tensors so backward passes the
        # kernel exactly the same buffers the forward pass used.
        ctx.save_for_backward(q, k, v, idx, attn_weights)
        ctx.scale = scale
        ctx.N = N
        ctx.C = C
        ctx.num_k = num_k
        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Run the fused backward kernel.

        Args:
            ctx: context populated by ``forward``.
            grad_out: gradient w.r.t. the forward output; converted to
                contiguous float32 before the kernel call.

        Returns:
            ``(grad_q, grad_k, grad_v, None, None)`` -- one entry per forward
            input; ``idx`` and ``scale`` receive no gradient.
        """
        q, k, v, idx, attn_weights = ctx.saved_tensors
        scale = ctx.scale
        N, C, num_k = ctx.N, ctx.C, ctx.num_k

        grad_out = grad_out.contiguous().float()

        # Zero-initialised because the kernel is handed these as output
        # buffers (presumably accumulating into grad_k / grad_v, since several
        # queries can share a neighbor -- confirm against the kernel source).
        grad_q = torch.zeros((N, C), dtype=torch.float32, device=q.device)
        # Sized from the saved tensors, matching the pure-PyTorch fallback;
        # identical to (N, C) for the documented [N, C] inputs.
        grad_k = torch.zeros_like(k)
        grad_v = torch.zeros_like(v)

        _C.fused_knn_attn_backward_cuda(
            grad_out, q, k, v, idx, attn_weights,
            grad_q, grad_k, grad_v,
            N, C, num_k, float(scale),
        )

        # idx and scale don't need gradients
        return grad_q, grad_k, grad_v, None, None
class FusedKNNAttentionFunctionPyTorch(Function):
    """Pure-PyTorch fallback (same semantics, no CUDA extension required).

    Iterates over the K neighbor slots so that only [N, C]-sized gathers are
    alive at any time, instead of building full [N, K, C] tensors. Intended
    as a dependency-free substitute for the CUDA kernel.

    Forward:
        q, k, v: [N, C]  (query, key, value features)
        idx:     [N, K]  (neighbor indices, any integer dtype; rows index
                          into ``k`` and ``v``)
        scale:   float   (multiplier applied to the q.k dot products)

    Returns:
        out: [N, C], same dtype and device as ``q``.
    """

    @staticmethod
    def forward(ctx, q, k, v, idx, scale):
        """Compute softmax(scale * q.k_neighbors) weighted sum of v_neighbors."""
        n_points, n_channels = q.shape
        num_k = idx.shape[1]
        # Convert the indices once; row kk of this [K, N] tensor is the
        # contiguous index vector of the kk-th neighbor of every point.
        neighbor_cols = idx.long().t().contiguous()

        # scores[i, kk] = scale * <q[i], k[idx[i, kk]]>
        scores = torch.empty(n_points, num_k, device=q.device, dtype=q.dtype)
        for kk, nbr in enumerate(neighbor_cols):
            scores[:, kk] = (q * k[nbr]).sum(dim=-1) * scale

        attn_weights = torch.softmax(scores, dim=-1)  # [N, K]

        # out[i] = sum_kk attn[i, kk] * v[idx[i, kk]]
        out = torch.zeros(n_points, n_channels, device=q.device, dtype=q.dtype)
        for kk, nbr in enumerate(neighbor_cols):
            out += attn_weights[:, kk:kk + 1] * v[nbr]

        ctx.save_for_backward(q, k, v, idx, attn_weights)
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``(grad_q, grad_k, grad_v, None, None)`` for the forward inputs."""
        q, k, v, idx, attn_weights = ctx.saved_tensors
        scale = ctx.scale
        n_points, n_channels = q.shape
        num_k = idx.shape[1]
        neighbor_cols = idx.long().t().contiguous()  # [K, N]

        # grad_attn[i, kk] = <grad_out[i], v[idx[i, kk]]>
        grad_attn = torch.empty(n_points, num_k, device=q.device, dtype=q.dtype)
        for kk, nbr in enumerate(neighbor_cols):
            grad_attn[:, kk] = (grad_out * v[nbr]).sum(dim=-1)

        # Softmax backward: grad_scores = attn * (grad_attn - sum(attn * grad_attn))
        ds = (attn_weights * grad_attn).sum(dim=-1, keepdim=True)  # [N, 1]
        grad_scores = attn_weights * (grad_attn - ds)  # [N, K]

        grad_q = torch.zeros(n_points, n_channels, device=q.device, dtype=q.dtype)
        grad_k = torch.zeros_like(k)
        grad_v = torch.zeros_like(v)
        for kk, nbr in enumerate(neighbor_cols):
            score_grad = grad_scores[:, kk:kk + 1]  # [N, 1]
            grad_q += score_grad * k[nbr] * scale
            # index_add_ accumulates, which is required because the same
            # key/value row can be a neighbor of several query points.
            grad_k.index_add_(0, nbr, score_grad * q * scale)
            grad_v.index_add_(0, nbr, attn_weights[:, kk:kk + 1] * grad_out)

        # idx and scale don't need gradients
        return grad_q, grad_k, grad_v, None, None
def fused_knn_attention(q, k, v, idx, scale):
    """Fused KNN gather + attention with automatic backend selection.

    The CUDA autograd function is used only when the compiled extension was
    imported successfully and ``q`` is a CUDA tensor; every other case goes
    through the pure-PyTorch autograd function.

    Args:
        q: [N, C] query features
        k: [N, C] key features
        v: [N, C] value features
        idx: [N, K] pre-computed KNN neighbor indices (int32)
        scale: attention scale factor

    Returns:
        out: [N, C] attention output
    """
    # Short-circuit keeps q.is_cuda from being consulted when the extension
    # is missing.
    use_cuda_kernel = FUSED_KNN_ATTN_CUDA_AVAILABLE and q.is_cuda
    backend = (
        FusedKNNAttentionFunction
        if use_cuda_kernel
        else FusedKNNAttentionFunctionPyTorch
    )
    return backend.apply(q, k, v, idx, scale)