Kernels
JackYoung27 commited on
Commit
ec0cc2d
·
verified ·
1 Parent(s): b0b199b

Fix divide-by-zero NaN in _fp8_act_quant_kernel

Browse files

# Fix divide-by-zero NaN in `_fp8_act_quant_kernel`

`s = tl.max(tl.abs(x)) / 448.0` is zero when a 128-block is all-zero. Then `y = x/s = NaN` propagates through every downstream FP8 matmul.

```diff
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
- y = (x / s).to(y_ptr.dtype.element_ty)
+ amax = tl.max(tl.abs(x))
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
+ y = (x / s).to(y_ptr.dtype.element_ty)
```

Same hunk in `build/torch-cuda`, `build/torch-xpu`, `build/torch-rocm`. Matches `_per_token_group_quant_fp8` (vllm) and `_per_token_group_quant_8bit` (sglang).

## Trigger

Attention masks zero out hidden states at padded positions before the first FP8 quant. Fires on `Qwen/Qwen3-8B-FP8`, `Qwen/Qwen3.5-9B-FP8`, `Qwen/Qwen3.5-27B-FP8`, `RedHatAI/Qwen3-8B-FP8-block` when a batch contains any padding.

## Validation

H100 80GB, `Qwen/Qwen3.5-27B-FP8` via transformers git main, Triton path forced.

| condition | before | after |
|---|---|---|
| single prompt | clean | clean |
| batched, padded | NaN at L24/L36/L45 on every padded row | clean |

Reproducer in this PR: `tests/test_act_quant_zero_block.py` (four pytest cases, CUDA required).

## References

Originally diagnosed and retracted in huggingface/transformers#42831. The retraction was about a separate `lm-eval` padding issue; the divide-by-zero kernel bug is real and never landed a fix.

build/torch-cuda/act_quant.py CHANGED
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
 
33
  y = (x / s).to(y_ptr.dtype.element_ty)
34
  tl.store(y_ptr + offs, y)
35
  tl.store(s_ptr + pid, s)
 
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
+ amax = tl.max(tl.abs(x))
33
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
34
  y = (x / s).to(y_ptr.dtype.element_ty)
35
  tl.store(y_ptr + offs, y)
36
  tl.store(s_ptr + pid, s)
build/torch-rocm/act_quant.py CHANGED
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
 
33
  y = (x / s).to(y_ptr.dtype.element_ty)
34
  tl.store(y_ptr + offs, y)
35
  tl.store(s_ptr + pid, s)
 
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
+ amax = tl.max(tl.abs(x))
33
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
34
  y = (x / s).to(y_ptr.dtype.element_ty)
35
  tl.store(y_ptr + offs, y)
36
  tl.store(s_ptr + pid, s)
build/torch-xpu/act_quant.py CHANGED
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
 
33
  y = (x / s).to(y_ptr.dtype.element_ty)
34
  tl.store(y_ptr + offs, y)
35
  tl.store(s_ptr + pid, s)
 
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
+ amax = tl.max(tl.abs(x))
33
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
34
  y = (x / s).to(y_ptr.dtype.element_ty)
35
  tl.store(y_ptr + offs, y)
36
  tl.store(s_ptr + pid, s)