# DeepGEMM
|
|
DeepGEMM kernels for the [Hugging Face kernel-builder](https://github.com/huggingface/kernels) infrastructure.
|
|
This package provides the FP8/FP4/BF16 GEMM kernels, einsum, attention, and hyperconnection
operations from [DeepSeek-AI/DeepGEMM](https://github.com/DeepSeek-AI/DeepGEMM), adapted to the
kernels-community build structure with torch library bindings.
|
|
## Features
|
|
- **FP8/FP4 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
- **BF16 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
- **cuBLASLt GEMMs**: NT, NN, TN, TT wrappers
- **Einsum**: `bmk,bnk->mn`, `bhr,hdr->bhd`, `bhd,hdr->bhr` expressions (BF16 and FP8)
- **Attention**: FP8 MQA logits (regular and paged)
- **Hyperconnection**: TF32 prenorm GEMM
- **Layout utilities**: scaling-factor transformations, TMA alignment
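To make the einsum contractions concrete, here is what the listed subscript expressions compute, illustrated with plain NumPy (NumPy stands in here for the CUDA tensors the real kernels operate on; shapes are arbitrary):

```python
import numpy as np

# Arbitrary illustration shapes.
b, m, n, k, h, d, r = 2, 4, 3, 5, 6, 7, 8

# bmk,bnk->mn: contract over both the batch (b) and k dimensions.
out1 = np.einsum("bmk,bnk->mn", np.ones((b, m, k)), np.ones((b, n, k)))
assert out1.shape == (m, n)

# bhr,hdr->bhd: contract over r, with h shared between both operands.
out2 = np.einsum("bhr,hdr->bhd", np.ones((b, h, r)), np.ones((h, d, r)))
assert out2.shape == (b, h, d)

# bhd,hdr->bhr: contract over d; this is the transpose-like partner of the above.
out3 = np.einsum("bhd,hdr->bhr", np.ones((b, h, d)), np.ones((h, d, r)))
assert out3.shape == (b, h, r)
```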
|
|
## Architecture Support


- SM 9.0a (Hopper / H100)
- SM 10.0a (Blackwell / B200)
|
|
## Requirements


- CUDA >= 12.1
- PyTorch >= 2.1
- CUTLASS 3.9+
- NVRTC (part of the CUDA Toolkit)
|
|
## Installation


```bash
pip install kernels
```
|
|
```python
from kernels import get_kernel

deep_gemm = get_kernel("kernels-community/DeepGEMM")
```
|
|
## Usage
|
|
```python
import deep_gemm

# FP8 GEMM: D = A @ B.T
deep_gemm.fp8_gemm_nt((a_fp8, sfa), (b_fp8, sfb), d)

# BF16 GEMM: D = A @ B.T
deep_gemm.bf16_gemm_nt(a_bf16, b_bf16, d)

# cuBLASLt GEMM
deep_gemm.cublaslt_gemm_nt(a, b, d)
```
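As a semantics sketch only (not the real kernel), `fp8_gemm_nt` computes `D = A @ B.T` after applying the scale factors. The reference below assumes upstream DeepGEMM's scaling granularity of 1x128 blocks for A and 128x128 blocks for B, and stores the quantized values as plain floats for simplicity; the function name is hypothetical:

```python
import numpy as np

BLOCK = 128  # assumed scaling granularity, following upstream DeepGEMM


def fp8_gemm_nt_reference(a, sfa, b, sfb):
    """Pure-NumPy reference for D = A @ B.T with blockwise scales.

    a:   (m, k) quantized values (stored as floats here for simplicity)
    sfa: (m, k // BLOCK) per-row, per-128-column scales for A
    b:   (n, k) quantized values
    sfb: (n // BLOCK, k // BLOCK) 128x128 blockwise scales for B
    """
    # Dequantize: broadcast each scale over its 128-wide (or 128x128) block.
    a_deq = a * np.repeat(sfa, BLOCK, axis=1)
    b_deq = b * np.repeat(np.repeat(sfb, BLOCK, axis=0), BLOCK, axis=1)
    return a_deq @ b_deq.T
```

The real kernel writes into the preallocated `d` tensor and keeps its accumulation in higher precision on the GPU; this sketch only pins down the shapes and the dequantization broadcast.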
|
|
## JIT Compilation


DeepGEMM uses Just-In-Time (JIT) compilation for its CUDA kernels. The kernel
templates (`.cuh` files in `include/deep_gemm/`) are compiled at runtime using
NVCC or NVRTC. First invocations may be slower due to compilation; results are
cached in `~/.deep_gemm/` for subsequent calls.
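Since compiled kernels land under `~/.deep_gemm/`, a quick way to inspect the cache is a sketch like the following (the helper name is hypothetical, and the layout of entries inside the cache directory is an implementation detail, so this only lists top-level entries):

```python
from pathlib import Path


def list_jit_cache(cache_root: Path = Path.home() / ".deep_gemm") -> list:
    """Return the names of top-level entries in DeepGEMM's JIT cache.

    Returns an empty list when nothing has been compiled yet.
    """
    if not cache_root.is_dir():
        return []
    return sorted(p.name for p in cache_root.iterdir())


# Deleting the cache directory simply forces recompilation on the next call:
#   shutil.rmtree(Path.home() / ".deep_gemm", ignore_errors=True)
```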
|
|
### CUTLASS Runtime Dependency


The JIT-compiled kernels depend on CUTLASS headers (`cute/`, `cutlass/`) at
runtime. The package searches for CUTLASS in these locations, in order:


1. `DG_CUTLASS_INCLUDE` environment variable (direct path to the include dir)
2. `CUTLASS_HOME` environment variable (`$CUTLASS_HOME/include`)
3. Headers bundled in the package's `include/` directory
4. `$CUDA_HOME/include` (some CUDA 12.8+ installs bundle `cute/`)
5. The `nvidia-cutlass` Python package
|
|
Set one of these if JIT compilation fails with missing CUTLASS headers:


```bash
export CUTLASS_HOME=/path/to/cutlass
# or
export DG_CUTLASS_INCLUDE=/path/to/cutlass/include
```
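The search order above can be sketched as a resolution function. This is an illustration, not the package's actual lookup code: the function names are hypothetical, `package_dir` stands in for the installed package's own directory (step 3), and the `nvidia-cutlass` wheel lookup (step 5) is omitted because the wheel's header layout varies across versions:

```python
import os
from pathlib import Path
from typing import Optional


def _has_cutlass(include_dir: Path) -> bool:
    # A usable CUTLASS include dir contains both `cute/` and `cutlass/`.
    return (include_dir / "cute").is_dir() and (include_dir / "cutlass").is_dir()


def find_cutlass_include(package_dir: Optional[Path] = None) -> Optional[Path]:
    """Resolve a CUTLASS include directory following the search order above."""
    direct = os.environ.get("DG_CUTLASS_INCLUDE")  # step 1: direct override
    if direct and _has_cutlass(Path(direct)):
        return Path(direct)

    home = os.environ.get("CUTLASS_HOME")  # step 2: $CUTLASS_HOME/include
    if home and _has_cutlass(Path(home) / "include"):
        return Path(home) / "include"

    if package_dir and _has_cutlass(package_dir / "include"):  # step 3: bundled
        return package_dir / "include"

    cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")  # step 4
    if _has_cutlass(Path(cuda_home) / "include"):
        return Path(cuda_home) / "include"

    return None  # step 5 (nvidia-cutlass wheel) omitted in this sketch
```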
|
|