CUDA Kernel Troubleshooting Guide
Overview
This guide covers common issues encountered when developing, building, and integrating custom CUDA kernels for the HuggingFace Kernels ecosystem. Issues are organized by category: build issues, performance issues, integration issues, torch.compile compatibility, and debugging tips.
Build Issues
Type Conversion Errors
Error: error: no suitable conversion function from "__nv_bfloat16" to "float" exists
This is the most common build error. PyTorch extension builds pass -D__CUDA_NO_BFLOAT16_CONVERSIONS__ to nvcc, so there is no implicit conversion between __nv_bfloat16 and float.
// BAD: Implicit conversion
__nv_bfloat16 val = input[idx];
float result = val * 2.0f; // ERROR: no implicit conversion
// GOOD: Explicit conversion
__nv_bfloat16 val = input[idx];
float result = __bfloat162float(val) * 2.0f;
Error: error: no operator "*" matches these operands (operand types are: __nv_bfloat16 * float)
You cannot mix BF16 and FP32 in arithmetic operations:
// BAD: Mixed types
__nv_bfloat16 a = input[idx];
float b = 2.0f;
auto c = a * b; // ERROR
// GOOD: Convert to same type first
__nv_bfloat16 a = input[idx];
float b = 2.0f;
float c = __bfloat162float(a) * b; // OK: both float
Error: error: no suitable constructor exists to convert from "float" to "__nv_bfloat16"
// BAD: Direct assignment
__nv_bfloat16 val = 1.0f; // ERROR
// GOOD: Use conversion function
__nv_bfloat16 val = __float2bfloat16(1.0f); // OK
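For intuition about what these conversion intrinsics do: bfloat16 is simply the top 16 bits of an IEEE-754 float32 (1 sign, 8 exponent, 7 mantissa bits). A rough Python model of the narrowing conversion (truncating here for simplicity; the real __float2bfloat16 rounds to nearest even) — the helper names are illustrative, not part of any API:

```python
import struct

def f32_to_bf16(x: float) -> int:
    """Keep the top 16 bits of the float32 bit pattern (simplified bfloat16)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16

def bf16_to_f32(b: int) -> float:
    """Widen a bfloat16 bit pattern back to float32 (exact, no rounding)."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

# Powers of two survive the round trip exactly; other values lose precision
# because only 7 mantissa bits remain (~2-3 decimal digits).
assert bf16_to_f32(f32_to_bf16(1.0)) == 1.0
assert abs(bf16_to_f32(f32_to_bf16(3.14159)) - 3.14159) < 0.02
```

This is also why BF16 keeps the full FP32 exponent range (no overflow on conversion) while giving up mantissa precision — the opposite trade-off from FP16.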
Error: error: identifier "__nv_bfloat16" is undefined
Missing header include:
// Add this at the top of your .cu file:
#include <cuda_bf16.h> // For __nv_bfloat16, __nv_bfloat162
#include <cuda_fp16.h> // For half, half2
Error: error: calling a __host__ function from a __global__ function is not allowed
Using host-only functions in device code:
// NOTE: printf IS supported in device code (sm_20+, i.e. every GPU
// that current CUDA toolkits target), so it is rarely the cause here:
__global__ void kernel() {
printf("debug: %f\n", val); // OK in device code
}
// The usual culprit is a host-only C++ standard library function:
// BAD: std::sqrt in device code
__global__ void kernel() {
float x = std::sqrt(val); // ERROR: host function
}
// GOOD: Use CUDA math functions
__global__ void kernel() {
float x = sqrtf(val); // OK: device function
float y = rsqrtf(val); // OK: reciprocal square root
float z = expf(val); // OK: exponential
float w = tanhf(val); // OK: hyperbolic tangent
}
Missing CUDA Headers
Error: fatal error: cuda_runtime.h: No such file or directory
The CUDA toolkit is not properly installed or not in the include path.
# Check CUDA installation
nvcc --version
# Check include paths
echo $CUDA_HOME
ls $CUDA_HOME/include/cuda_runtime.h
# If CUDA_HOME is not set:
export CUDA_HOME=/usr/local/cuda
# Or find it:
which nvcc
# /usr/local/cuda-12.4/bin/nvcc -> CUDA_HOME=/usr/local/cuda-12.4
Error: fatal error: torch/extension.h: No such file or directory
PyTorch C++ headers are not found:
# Check PyTorch installation
python -c "import torch; print(torch.utils.cmake_prefix_path)"
python -c "import torch; print(torch.__path__)"
# The headers should be at:
python -c "from torch.utils.cpp_extension import include_paths; print(include_paths())"
Error: error: namespace "at" has no member "BFloat16"
Your PyTorch version may be too old:
# BFloat16 support requires PyTorch >= 1.10
python -c "import torch; print(torch.__version__)"
# Upgrade if needed:
pip install torch --upgrade
Architecture Mismatch
Error: error: --gpu-architecture=sm_90 is not supported
Your CUDA toolkit version does not support the target architecture:
| Architecture | Minimum CUDA Version |
|---|---|
| sm_75 (T4) | CUDA 10.0 |
| sm_80 (A100) | CUDA 11.0 |
| sm_86 (A10G) | CUDA 11.1 |
| sm_89 (L4) | CUDA 11.8 |
| sm_90 (H100) | CUDA 12.0 |
# Check your CUDA version
nvcc --version
# Update build.toml to match your CUDA version
# If you have CUDA 11.8, remove sm_90:
# cuda-capabilities = ["8.0"] # Not ["8.0", "9.0"]
Performance Issues
Bank Conflicts in Shared Memory
Shared memory is organized into 32 banks. When multiple threads in a warp access different addresses in the same bank, a bank conflict occurs, serializing the accesses.
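The bank mapping can be modeled in a few lines of Python to predict the conflict degree of an access pattern before profiling. This is an illustrative sketch (the helper name is hypothetical), assuming 4-byte bank words and 32 banks as on current architectures:

```python
def max_conflict_degree(word_indices):
    """Worst-case bank conflict degree for one warp's shared-memory access.

    word_indices: the 4-byte word index each of the 32 threads reads.
    bank = word_index % 32; threads hitting *different* words in the
    same bank serialize (same word is a broadcast, not a conflict).
    """
    per_bank = {}
    for w in word_indices:
        per_bank.setdefault(w % 32, set()).add(w)
    return max(len(words) for words in per_bank.values())

# Stride-1 access: every thread in a different bank -> conflict-free
assert max_conflict_degree([t for t in range(32)]) == 1
# Stride-2 access: 2-way conflict
assert max_conflict_degree([t * 2 for t in range(32)]) == 2
# Column access in a float[32][32] tile: 32-way conflict
assert max_conflict_degree([t * 32 for t in range(32)]) == 32
# Same tile padded to [32][33]: conflict-free
assert max_conflict_degree([t * 33 for t in range(32)]) == 1
```

The last two assertions are exactly the column-access and padding cases shown below.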
Diagnosis
# Use ncu to detect bank conflicts
ncu --metrics \
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,\
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum \
python benchmark.py
Common Cause: Column-Major Access
// BAD: Bank conflicts when accessing columns
__shared__ float smem[32][32];
float val = smem[threadIdx.x][col]; // 32 threads access same bank if col is fixed
// GOOD: Pad shared memory to avoid conflicts
__shared__ float smem[32][33]; // +1 padding shifts bank alignment
float val = smem[threadIdx.x][col]; // No conflicts
Common Cause: Stride-2 Access
// BAD: Even-numbered threads hit same banks
__shared__ float smem[256];
float val = smem[threadIdx.x * 2]; // Stride-2 = 2-way bank conflict
// GOOD: Sequential access, process 2 elements per thread later
float val = smem[threadIdx.x];
Poor Occupancy
Diagnosis
# Check occupancy
ncu --metrics \
sm__warps_active.avg.pct_of_peak_sustained_active,\
launch__occupancy_limit_registers,\
launch__registers_per_thread,\
launch__block_size \
python benchmark.py
Cause: Too Many Registers
// Check register usage at compile time
// nvcc --ptxas-options=-v kernel.cu
// If register usage is high, limit it:
__global__ __launch_bounds__(256, 4) // max 256 threads, min 4 blocks/SM
void my_kernel() {
// Compiler will try to limit registers to allow 4 blocks/SM
// For H100 (2048 threads/SM): 4 blocks * 256 threads = 1024 threads = 50% occupancy
// (8 blocks of 256 threads would be needed for 100%)
// Register budget: 65536 / (4 blocks * 256 threads) = 64 registers per thread
}
// Or use compiler flag:
// --maxrregcount=64
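The register arithmetic above generalizes to a small back-of-the-envelope calculator. A hypothetical Python helper, using H100-like limits as defaults and ignoring register allocation granularity and shared memory:

```python
def occupancy_estimate(regs_per_thread, threads_per_block,
                       regs_per_sm=65536, max_threads_per_sm=2048,
                       max_blocks_per_sm=32):
    """Estimate blocks/SM and occupancy, limited by registers and threads.

    Simplified model: real hardware also rounds register allocation to a
    granularity and accounts for shared memory per block.
    """
    by_regs = regs_per_sm // (regs_per_thread * threads_per_block)
    by_threads = max_threads_per_sm // threads_per_block
    blocks = min(by_regs, by_threads, max_blocks_per_sm)
    return blocks, blocks * threads_per_block / max_threads_per_sm

# 64 registers/thread, 256-thread blocks: register-limited to 4 blocks -> 50%
assert occupancy_estimate(64, 256) == (4, 0.5)
# 32 registers/thread: thread-limited to 8 blocks -> 100%
assert occupancy_estimate(32, 256) == (8, 1.0)
```

For authoritative numbers, prefer the compiler's own report (nvcc --ptxas-options=-v) plus ncu; this sketch is only for quick what-if reasoning.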
Cause: Too Much Shared Memory
// If your kernel uses too much shared memory, fewer blocks can run per SM
// H100: up to 228 KB shared memory per SM (configurable carveout)
// BAD: Each block uses 96 KB -> only 2 blocks per SM
__shared__ float big_buffer[24576]; // 96 KB
// GOOD: Reduce shared memory or use multiple stages
__shared__ float small_buffer[8192]; // 32 KB -> 7 blocks per SM
// Process data in multiple stages
Cause: Block Size Too Large or Small
// Use cudaOccupancyMaxPotentialBlockSize to find optimal block size
int min_grid_size, optimal_block_size;
cudaOccupancyMaxPotentialBlockSize(
&min_grid_size,
&optimal_block_size,
my_kernel,
shared_mem_bytes,
0 // no limit
);
printf("Optimal block size: %d\n", optimal_block_size);
Memory Coalescing Issues
Diagnosis
# Check memory efficiency
ncu --metrics \
l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum,\
l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum \
python benchmark.py
# For 4-byte loads, a fully coalesced warp touches 128 bytes = 4 sectors (32 B each)
# Ideal ratio: sectors/requests = 4 for 4-byte accesses (perfectly coalesced)
# Bad ratio: sectors/requests approaching 32 (one sector per thread)
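The sectors-per-request ratio can be reasoned about offline. A small illustrative model (the helper name is hypothetical), assuming 32-byte sectors and one request per warp-wide load:

```python
def sectors_per_request(byte_addresses):
    """Number of distinct 32-byte sectors one warp-wide load touches.

    One warp instruction = one request; each unique 32-byte-aligned
    region it reads is a sector fetched from L2/DRAM.
    """
    return len({addr // 32 for addr in byte_addresses})

# Coalesced float loads: 32 threads * 4 bytes = 128 contiguous bytes = 4 sectors
assert sectors_per_request([t * 4 for t in range(32)]) == 4
# Stride of hidden_size=4096 floats: every thread lands in its own sector
assert sectors_per_request([t * 4096 * 4 for t in range(32)]) == 32
```

The strided case fetches 32 * 32 = 1024 bytes to deliver 128 useful bytes, which is exactly the waste the ncu metrics above expose.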
Common Cause: Strided Access Pattern
// BAD: Accessing every N-th element (stride = N)
// If hidden_states is [batch, seq, hidden] and you access along batch dimension:
int idx = threadIdx.x * hidden_size; // Stride = hidden_size
float val = input[idx]; // Non-coalesced!
// GOOD: Access consecutive elements
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float val = input[idx]; // Coalesced!
// If you need to process along a non-contiguous dimension,
// transpose the data first or use shared memory:
__shared__ float smem[BLOCK_SIZE];
// Load coalesced
smem[threadIdx.x] = input[blockIdx.x * blockDim.x + threadIdx.x];
__syncthreads();
// Access with stride in shared memory (much cheaper than global)
float val = smem[(threadIdx.x * stride) % BLOCK_SIZE];
Common Cause: Misaligned Access
// BAD: Starting from an unaligned address
float4 data = reinterpret_cast<const float4*>(input + 1)[tid]; // Misaligned!
// GOOD: Ensure alignment
// PyTorch tensors are typically 256-byte aligned at the start
// But offsets (e.g., input + row * hidden_size) may not be float4-aligned
// Check at runtime:
assert(reinterpret_cast<uintptr_t>(ptr) % 16 == 0); // 16-byte aligned for float4
Integration Issues
NoneType Weight Error
Error: TypeError: expected Tensor as element 1 in argument 0, but got NoneType
This occurs when patching a diffusers RMSNorm that has no weight (elementwise_affine=False):
# The error happens here:
def patched_forward(hidden_states):
return cuda_rmsnorm(hidden_states, module.weight, module.eps)
# ^^^^^^^^^^^^^ This is None!
# FIX: Always check for None weight
def patched_forward(hidden_states):
if module.weight is not None:
return cuda_rmsnorm(hidden_states, module.weight, module.eps)
else:
return cuda_rmsnorm_no_weight(hidden_states, module.eps)
Important: This issue is specific to diffusers. In transformers, RMSNorm weight is always present.
GEGLU Not Called / Wrong Activation
Symptom: Model produces garbage output after kernel injection
This usually means you injected GEGLU into a model that uses GELU (or vice versa):
# LTX-Video uses GELU, NOT GEGLU
# If you inject GEGLU into LTX-Video, the dimensions will be wrong
# and the output will be garbage
# Diagnosis: check what activation the model actually uses
for name, module in model.named_modules():
if 'act' in name.lower() or 'gelu' in name.lower() or 'silu' in name.lower():
print(f"{name}: {type(module).__name__}")
# Model-specific activations:
# LTX-Video: GELU
# SD3: GEGLU
# FLUX: GEGLU
# LLaMA: SiLU (in MLP gate)
Symptom: Custom activation kernel is never called
The activation may be inlined or implemented differently than expected:
# Some models implement GEGLU as a combined module, not separate GELU + multiply
# Check if the feed-forward uses a custom class:
for name, module in model.named_modules():
if 'ff' in name or 'mlp' in name:
print(f"{name}: {type(module)}")
for child_name, child in module.named_children():
print(f" {child_name}: {type(child)}")
CPU Offloading Conflicts
Error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
This happens when custom kernels are injected after CPU offloading is enabled:
# WRONG ORDER:
pipe.enable_model_cpu_offload()
inject_custom_kernels(pipe.transformer) # Model may be on CPU at this point!
# CORRECT ORDER:
inject_custom_kernels(pipe.transformer) # Inject first
pipe.enable_model_cpu_offload() # Then enable offloading
Symptom: Custom kernels work but model is slower than baseline
CPU offloading may be moving the model back and forth, negating kernel speedups:
# Check if offloading is active
print(hasattr(pipe, '_offload_gpu_id'))
# If using sequential offloading, each forward call triggers CPU->GPU transfer
# This overhead can dwarf any kernel speedup
# Solution: If you have enough GPU memory, don't use offloading
pipe.to("cuda") # Keep everything on GPU
isinstance Not Matching
Symptom: isinstance(module, RMSNorm) returns False even though the module looks like RMSNorm
This happens when you import RMSNorm from the wrong module:
# WRONG: Importing from the wrong place
from torch.nn import RMSNorm # PyTorch's own RMSNorm (added in PyTorch 2.4)
# The model uses diffusers' RMSNorm, which is a DIFFERENT class
from diffusers.models.normalization import RMSNorm # This is what you need
# Or for transformers models:
from transformers.models.llama.modeling_llama import LlamaRMSNorm
# DEBUGGING: Check what class the module actually is
for name, module in model.named_modules():
if "norm" in name.lower():
print(f"{name}: {type(module).__module__}.{type(module).__name__}")
Another cause: Dynamic module reloading
If you reload modules during development, the class identity may change:
# After importlib.reload(), isinstance checks may fail because
# the class object is different even though the code is the same
# FIX: Use string-based type checking as fallback
def is_rmsnorm(module):
class_name = type(module).__name__
return "RMSNorm" in class_name
torch.compile Compatibility
Error: torch._dynamo.exc.Unsupported: call_function ... unsupported
Custom CUDA kernels must be registered as custom ops to work with torch.compile:
# BAD: Direct kernel call -- breaks torch.compile
def rmsnorm(x, weight, eps):
return _cuda_rmsnorm(x, weight, eps) # torch.compile cannot trace this
# GOOD: Register as custom op
from torch.library import custom_op
@custom_op("mylib::rmsnorm", mutates_args=())
def rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
return _cuda_rmsnorm(input, weight, eps)
@rmsnorm.register_fake
def rmsnorm_fake(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
return torch.empty_like(input)
Error: torch._dynamo.exc.BackendCompilerFailed: ... graph break
Graph breaks prevent torch.compile from optimizing through your custom kernel:
# Check for graph breaks
import torch._dynamo
torch._dynamo.config.verbose = True
# Or use explain mode
explanation = torch._dynamo.explain(model)(input_tensor)
print(explanation)
Error: RuntimeError: Tensor-likes are not close! ... during FakeTensor propagation
The register_fake function returns incorrect shapes or dtypes:
# BAD: Wrong shape in fake implementation
@rmsnorm.register_fake
def rmsnorm_fake(input, weight, eps):
return torch.empty(input.shape[0], input.shape[1]) # Wrong shape!
# GOOD: Match the real kernel's output exactly
@rmsnorm.register_fake
def rmsnorm_fake(input, weight, eps):
return torch.empty_like(input) # Same shape, dtype, device
CUDA Graphs Issues with torch.compile
# torch.compile with mode="reduce-overhead" uses CUDA graphs
# CUDA graphs require:
# 1. Fixed input shapes (no dynamic shapes)
# 2. No CPU-GPU synchronization inside the graph
# 3. No memory allocations inside the graph
# If your kernel allocates memory, CUDA graphs will fail:
# BAD:
def my_kernel_wrapper(x):
output = torch.empty_like(x) # Allocation inside the graph!
_cuda_kernel(x, output)
return output
# GOOD: Pre-allocate or use torch's allocation
@custom_op("mylib::my_op", mutates_args=())
def my_op(x: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
_cuda_kernel(x.data_ptr(), output.data_ptr(), x.numel())
return output
Debugging Tips
Print From CUDA Kernels
// Use printf for debugging (supported on sm_20+, all modern GPUs)
__global__ void debug_kernel(const float* input, int n) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx == 0) { // Only print from one thread!
printf("First value: %f\n", input[0]);
printf("n = %d\n", n);
}
}
// WARNING: printf from many threads will be extremely slow
// Always guard with `if (idx == 0)` or similar
Check for CUDA Errors
// Always check for CUDA errors after kernel launches
#define CUDA_CHECK(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s:%d: %s\n", \
__FILE__, __LINE__, cudaGetErrorString(err)); \
exit(1); \
} \
} while (0)
// Usage:
my_kernel<<<grid, block>>>(args...);
CUDA_CHECK(cudaGetLastError()); // Check launch errors
CUDA_CHECK(cudaDeviceSynchronize()); // Check execution errors
In Python:
# Enable CUDA error checking
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # Synchronous launch -- errors at launch site
# Check for errors
torch.cuda.synchronize() # Forces synchronization, surfaces async errors
Memory Debugging
# Check GPU memory usage
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
print(f"Max allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
# Memory snapshot (PyTorch 2.1+)
torch.cuda.memory._record_memory_history()
# ... run your code ...
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
# Visualize at https://pytorch.org/memory_viz
NaN/Inf Detection
# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)
# Check for NaN in outputs
def check_nan(tensor, name="tensor"):
if torch.isnan(tensor).any():
print(f"NaN detected in {name}!")
print(f" Shape: {tensor.shape}")
print(f" NaN count: {torch.isnan(tensor).sum().item()}")
print(f" Min: {tensor[~torch.isnan(tensor)].min().item()}")
print(f" Max: {tensor[~torch.isnan(tensor)].max().item()}")
return True
return False
# Check intermediate values in kernels
def debug_kernel_output(kernel_fn, *args):
output = kernel_fn(*args)
has_nan = check_nan(output, "kernel output")
has_inf = torch.isinf(output).any()
if has_inf:
print(f"Inf detected! Count: {torch.isinf(output).sum().item()}")
return output
Common NaN Sources in CUDA Kernels
// 1. Division by zero
float result = a / b; // NaN if b == 0
// FIX: Add epsilon
float result = a / (b + 1e-8f);
// 2. rsqrt of zero or negative
float inv_std = rsqrtf(variance); // NaN if variance < 0, Inf if variance == 0
// FIX: Add epsilon
float inv_std = rsqrtf(variance + 1e-6f);
// 3. exp overflow
float result = expf(large_value); // Inf if large_value > ~88
// FIX: Subtract max before exp (online softmax pattern)
float result = expf(val - max_val);
// 4. log of zero or negative
float result = logf(x); // -Inf if x == 0, NaN if x < 0
// FIX: Clamp input
float result = logf(fmaxf(x, 1e-8f));
// 5. Race condition in reduction
// If __syncthreads() is missing, partial results may be NaN
__shared__ float shared_result;
if (threadIdx.x == 0) shared_result = value;
__syncthreads(); // CRITICAL: without this, other threads read garbage
float result = shared_result;
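The max-subtraction trick in item 3 is the same one that makes softmax numerically stable. The idea carries over to plain Python, where math.exp overflows just like expf (only at a larger threshold, ~709 for float64):

```python
import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]      # overflows for large inputs
    s = sum(exps)
    return [e / s for e in exps]

def stable_softmax(xs):
    m = max(xs)                           # subtract the max first:
    exps = [math.exp(x - m) for x in xs]  # all exp arguments are now <= 0
    s = sum(exps)
    return [e / s for e in exps]

try:
    naive_softmax([1000.0, 1000.0])       # math.exp(1000.0) overflows
    overflowed = False
except OverflowError:
    overflowed = True
assert overflowed
assert stable_softmax([1000.0, 1000.0]) == [0.5, 0.5]
```

Because softmax is shift-invariant, subtracting the maximum changes nothing mathematically but keeps every exponent in the safe range — the same reasoning behind the expf(val - max_val) fix above.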
Comparing Custom Kernel vs Reference
def compare_implementations(custom_fn, reference_fn, *args, name="kernel"):
"""Compare a custom kernel against a PyTorch reference."""
# Run both
custom_out = custom_fn(*args)
ref_out = reference_fn(*args)
# Statistics
abs_diff = (custom_out - ref_out).abs()
rel_diff = abs_diff / (ref_out.abs() + 1e-8)
print(f"\n{name} comparison:")
print(f" Shape: {custom_out.shape}")
print(f" Dtype: {custom_out.dtype}")
print(f" Max absolute diff: {abs_diff.max().item():.8f}")
print(f" Mean absolute diff: {abs_diff.mean().item():.8f}")
print(f" Max relative diff: {rel_diff.max().item():.8f}")
print(f" Mean relative diff: {rel_diff.mean().item():.8f}")
# Check specific thresholds
thresholds = {
torch.float32: (1e-5, 1e-4),
torch.float16: (1e-2, 1e-2),
torch.bfloat16: (1e-2, 1e-2),
}
rtol, atol = thresholds.get(custom_out.dtype, (1e-3, 1e-3))
is_close = torch.allclose(custom_out, ref_out, rtol=rtol, atol=atol)
print(f" torch.allclose (rtol={rtol}, atol={atol}): {is_close}")
if not is_close:
# Show worst-case elements
flat_diff = abs_diff.flatten()
worst_indices = flat_diff.topk(min(5, flat_diff.numel())).indices
print(f" Worst elements:")
for idx in worst_indices:
print(f" idx={idx.item()}: custom={custom_out.flatten()[idx].item():.8f}, "
f"ref={ref_out.flatten()[idx].item():.8f}, "
f"diff={flat_diff[idx].item():.8f}")
return is_close
Profiling Specific Kernel Launches
# Use NVTX markers to identify specific kernels in profiler output
import torch.cuda.nvtx as nvtx
def profiled_forward(model, input_tensor):
nvtx.range_push("full_forward")
nvtx.range_push("rmsnorm")
x = custom_rmsnorm(input_tensor, weight, eps)
nvtx.range_pop()
nvtx.range_push("attention")
x = attention(x)
nvtx.range_pop()
nvtx.range_push("ffn")
x = ffn(x)
nvtx.range_pop()
nvtx.range_pop()
return x
# Profile with NVTX markers visible
nsys profile --trace=cuda,nvtx python profiled_script.py
Quick Sanity Checks
def sanity_check_kernel(kernel_fn, name="kernel"):
"""Quick sanity checks for a CUDA kernel."""
print(f"\n=== Sanity checks for {name} ===")
# Check 1: Basic forward pass
print("1. Basic forward pass...", end=" ")
try:
x = torch.randn(2, 64, 2048, dtype=torch.bfloat16, device="cuda")
w = torch.ones(2048, dtype=torch.bfloat16, device="cuda")
out = kernel_fn(x, w, 1e-6)
assert out.shape == x.shape, f"Shape mismatch: {out.shape} vs {x.shape}"
assert out.dtype == x.dtype, f"Dtype mismatch: {out.dtype} vs {x.dtype}"
assert out.is_cuda, "Output not on CUDA"
print("PASS")
except Exception as e:
print(f"FAIL: {e}")
# Check 2: No NaN in output
print("2. No NaN in output...", end=" ")
try:
assert not torch.isnan(out).any(), "NaN detected"
print("PASS")
except Exception as e:
print(f"FAIL: {e}")
# Check 3: No Inf in output
print("3. No Inf in output...", end=" ")
try:
assert not torch.isinf(out).any(), "Inf detected"
print("PASS")
except Exception as e:
print(f"FAIL: {e}")
# Check 4: Output is not all zeros
print("4. Non-trivial output...", end=" ")
try:
assert out.abs().max() > 1e-6, "Output is all zeros"
print("PASS")
except Exception as e:
print(f"FAIL: {e}")
# Check 5: Deterministic
print("5. Deterministic...", end=" ")
try:
out2 = kernel_fn(x, w, 1e-6)
assert torch.equal(out, out2), "Non-deterministic output"
print("PASS")
except Exception as e:
print(f"FAIL: {e}")
# Check 6: Batch independence
print("6. Batch independence...", end=" ")
try:
x1 = x[0:1]
x_full = x
out_single = kernel_fn(x1, w, 1e-6)
out_batch = kernel_fn(x_full, w, 1e-6)
assert torch.allclose(out_single[0], out_batch[0], atol=1e-6), \
"Single-batch and multi-batch outputs differ"
print("PASS")
except Exception as e:
print(f"FAIL: {e}")
print("=== Done ===\n")
Common Error Quick Reference
| Error | Likely Cause | Fix |
|---|---|---|
| no suitable conversion from __nv_bfloat16 to float | Missing explicit conversion | Use __bfloat162float() |
| no operator * matches | Mixed BF16/FP32 arithmetic | Convert to same type first |
| expected Tensor, got NoneType | RMSNorm weight is None | Check weight is not None |
| Expected all tensors on same device | CPU offloading conflict | Inject kernels before offloading |
| isinstance returns False | Wrong class imported | Import from correct module |
| sm_90 not supported | CUDA too old for H100 | Update to CUDA 12.0+ |
| cuda_runtime.h not found | Missing CUDA toolkit | Install CUDA or set CUDA_HOME |
| NaN in output | Numerical instability | Add epsilon, check for division by zero |
| Graph break in torch.compile | Unregistered custom op | Use @custom_op registration |
| Tensor-likes not close | Wrong fake tensor shape | Match shape/dtype in register_fake |
| CUDA out of memory | Tensor too large for GPU | Reduce batch size, use quantization |
| Misaligned address | Unaligned vectorized load | Check pointer alignment |
Summary
When troubleshooting CUDA kernel issues:
- Build errors: usually type conversion issues -- use explicit __bfloat162float() / __float2bfloat16() conversions
- Performance: profile with ncu; check bank conflicts, coalescing, and occupancy
- Integration: watch for None weights (diffusers), wrong epsilon attribute names (transformers), and CPU offloading order
- torch.compile: register kernels with @custom_op and provide correct register_fake implementations
- Debugging: use CUDA_LAUNCH_BLOCKING=1, check for NaN/Inf, and compare against PyTorch reference implementations