Instructions to use kernels-community/finegrained-fp8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/finegrained-fp8 with Kernels:
# !pip install kernels from kernels import get_kernel kernel = get_kernel("kernels-community/finegrained-fp8") - Notebooks
- Google Colab
- Kaggle
Fix divide-by-zero NaN in _fp8_act_quant_kernel
#1
by JackYoung27 - opened
build/torch-cuda/act_quant.py
CHANGED
|
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
-
|
|
|
|
| 33 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 34 |
tl.store(y_ptr + offs, y)
|
| 35 |
tl.store(s_ptr + pid, s)
|
|
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
+
amax = tl.max(tl.abs(x))
|
| 33 |
+
s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
|
| 34 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 35 |
tl.store(y_ptr + offs, y)
|
| 36 |
tl.store(s_ptr + pid, s)
|
build/torch-rocm/act_quant.py
CHANGED
|
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
-
|
|
|
|
| 33 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 34 |
tl.store(y_ptr + offs, y)
|
| 35 |
tl.store(s_ptr + pid, s)
|
|
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
+
amax = tl.max(tl.abs(x))
|
| 33 |
+
s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
|
| 34 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 35 |
tl.store(y_ptr + offs, y)
|
| 36 |
tl.store(s_ptr + pid, s)
|
build/torch-xpu/act_quant.py
CHANGED
|
@@ -29,7 +29,8 @@ def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
-
|
|
|
|
| 33 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 34 |
tl.store(y_ptr + offs, y)
|
| 35 |
tl.store(s_ptr + pid, s)
|
|
|
|
| 29 |
pid = tl.program_id(axis=0)
|
| 30 |
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 31 |
x = tl.load(x_ptr + offs).to(tl.float32)
|
| 32 |
+
amax = tl.max(tl.abs(x))
|
| 33 |
+
s = tl.maximum(amax / 448.0, 1e-12) # eps floor; float8_e4m3fn max = 448
|
| 34 |
y = (x / s).to(y_ptr.dtype.element_ty)
|
| 35 |
tl.store(y_ptr + offs, y)
|
| 36 |
tl.store(s_ptr + pid, s)
|