Instructions to use kernels-community/finegrained-fp8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/finegrained-fp8 with Kernels:
# !pip install kernels from kernels import get_kernel kernel = get_kernel("kernels-community/finegrained-fp8") - Notebooks
- Google Colab
- Kaggle
Fix divide-by-zero NaN in _fp8_act_quant_kernel
Browse files# Fix divide-by-zero NaN in `_fp8_act_quant_kernel`
`s = tl.max(tl.abs(x)) / 448.0` is zero when a 128-block is all-zero. Then `y = x/s = NaN` propagates through every downstream FP8 matmul.
```diff
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
- y = (x / s).to(y_ptr.dtype.element_ty)
+ amax = tl.max(tl.abs(x))
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
+ y = (x / s).to(y_ptr.dtype.element_ty)
```
Same hunk in `build/torch-cuda`, `build/torch-xpu`, `build/torch-rocm`. Matches `_per_token_group_quant_fp8` (vllm) and `_per_token_group_quant_8bit` (sglang).
## Trigger
Attention masks zero out hidden states at padded positions before the first FP8 quant. Fires on `Qwen/Qwen3-8B-FP8`, `Qwen/Qwen3.5-9B-FP8`, `Qwen/Qwen3.5-27B-FP8`, `RedHatAI/Qwen3-8B-FP8-block` when a batch contains any padding.
## Validation
H100 80GB, `Qwen/Qwen3.5-27B-FP8` via transformers git main, Triton path forced.
| condition | before | after |
|---|---|---|
| single prompt | clean | clean |
| batched, padded | NaN at L24/L36/L45 on every padded row | clean |
Reproducer in this PR: `tests/test_act_quant_zero_block.py` (four pytest cases, CUDA required).
## References
Originally diagnosed and retracted in huggingface/transformers#42831. The retraction was about a separate `lm-eval` padding issue; the divide-by-zero kernel bug is real and never landed a fix.
|
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
-
|
|
|
|
| 33 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 34 |
tl.store(y_ptr + offs, y)
|
| 35 |
tl.store(s_ptr + pid, s)
|
|
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
+
amax = tl.max(tl.abs(x))
|
| 33 |
+
s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
|
| 34 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 35 |
tl.store(y_ptr + offs, y)
|
| 36 |
tl.store(s_ptr + pid, s)
|
|
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
-
|
|
|
|
| 33 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 34 |
tl.store(y_ptr + offs, y)
|
| 35 |
tl.store(s_ptr + pid, s)
|
|
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
+
amax = tl.max(tl.abs(x))
|
| 33 |
+
s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
|
| 34 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 35 |
tl.store(y_ptr + offs, y)
|
| 36 |
tl.store(s_ptr + pid, s)
|
|
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
-
|
|
|
|
| 33 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 34 |
tl.store(y_ptr + offs, y)
|
| 35 |
tl.store(s_ptr + pid, s)
|
|
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
+
amax = tl.max(tl.abs(x))
|
| 33 |
+
s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
|
| 34 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 35 |
tl.store(y_ptr + offs, y)
|
| 36 |
tl.store(s_ptr + pid, s)
|