Kernels

Fix divide-by-zero NaN in _fp8_act_quant_kernel

#1
build/torch-cuda/act_quant.py CHANGED
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
 
33
  y = (x / s).to(y_ptr.dtype.element_ty)
34
  tl.store(y_ptr + offs, y)
35
  tl.store(s_ptr + pid, s)
 
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
+ amax = tl.max(tl.abs(x))
33
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
34
  y = (x / s).to(y_ptr.dtype.element_ty)
35
  tl.store(y_ptr + offs, y)
36
  tl.store(s_ptr + pid, s)
build/torch-rocm/act_quant.py CHANGED
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
 
33
  y = (x / s).to(y_ptr.dtype.element_ty)
34
  tl.store(y_ptr + offs, y)
35
  tl.store(s_ptr + pid, s)
 
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
+ amax = tl.max(tl.abs(x))
33
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
34
  y = (x / s).to(y_ptr.dtype.element_ty)
35
  tl.store(y_ptr + offs, y)
36
  tl.store(s_ptr + pid, s)
build/torch-xpu/act_quant.py CHANGED
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
- s = tl.max(tl.abs(x)) / 448.0 # float8_e4m3fn max
 
33
  y = (x / s).to(y_ptr.dtype.element_ty)
34
  tl.store(y_ptr + offs, y)
35
  tl.store(s_ptr + pid, s)
 
29
  pid = tl.program_id(axis=0)
30
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31
  x = tl.load(x_ptr + offs).to(tl.float32)
32
+ amax = tl.max(tl.abs(x))
33
+ s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
34
  y = (x / s).to(y_ptr.dtype.element_ty)
35
  tl.store(y_ptr + offs, y)
36
  tl.store(s_ptr + pid, s)