# DeepGEMM
DeepGEMM kernels for the [Hugging Face kernel-builder](https://github.com/huggingface/kernels) infrastructure.

This package provides the FP8/FP4/BF16 GEMM kernels, einsum, attention, and hyperconnection
operations from [DeepSeek-AI/DeepGEMM](https://github.com/DeepSeek-AI/DeepGEMM), adapted to the
kernels-community build structure with torch library bindings.
## Features
- **FP8/FP4 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
- **BF16 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
- **cuBLASLt GEMMs**: NT, NN, TN, TT wrappers
- **Einsum**: `bmk,bnk->mn`, `bhr,hdr->bhd`, and `bhd,hdr->bhr` expressions (BF16 and FP8)
- **Attention**: FP8 MQA logits (regular and paged)
- **Hyperconnection**: TF32 prenorm GEMM
- **Layout utilities**: Scaling factor transformations, TMA alignment
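The einsum expressions listed above follow standard einsum subscript semantics. A minimal NumPy sketch of the three contractions (NumPy is used here only as a reference for the shapes and semantics, not as the kernel API; all shapes are illustrative):

```python
import numpy as np

# Illustrative shapes: b=batch, m/n=output dims, k=reduction,
# h=heads, d=head dim, r=rank.
b, m, n, k = 2, 4, 3, 5
h, d, r = 2, 4, 3
rng = np.random.default_rng(0)

# bmk,bnk->mn: batched GEMM with the batch dimension reduced away
A = rng.random((b, m, k))
B = rng.random((b, n, k))
out_mn = np.einsum("bmk,bnk->mn", A, B)

# bhr,hdr->bhd: per-head projection from rank space up to head dim
X = rng.random((b, h, r))
W = rng.random((h, d, r))
out_bhd = np.einsum("bhr,hdr->bhd", X, W)

# bhd,hdr->bhr: the matching projection back down to rank space
Y = rng.random((b, h, d))
out_bhr = np.einsum("bhd,hdr->bhr", Y, W)

print(out_mn.shape, out_bhd.shape, out_bhr.shape)
```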
## Architecture Support
- SM 9.0a (Hopper / H100)
- SM 10.0a (Blackwell / B200)
## Requirements
- CUDA >= 12.1
- PyTorch >= 2.1
- CUTLASS 3.9+
- NVRTC (part of CUDA Toolkit)
## Installation
```bash
pip install kernels
```
```python
from kernels import get_kernel

deep_gemm = get_kernel("kernels-community/DeepGEMM")
```
## Usage
```python
import deep_gemm

# FP8 GEMM: D = A @ B.T (sfa/sfb are the operand scaling factors)
deep_gemm.fp8_gemm_nt((a_fp8, sfa), (b_fp8, sfb), d)

# BF16 GEMM: D = A @ B.T
deep_gemm.bf16_gemm_nt(a_bf16, b_bf16, d)

# cuBLASLt GEMM: D = A @ B.T
deep_gemm.cublaslt_gemm_nt(a, b, d)
```
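The `nt`/`nn`/`tn`/`tt` suffixes follow the usual BLAS convention: the first letter marks whether A is transposed, the second whether B is (so `nt` computes D = A @ B.T, as shown above). A NumPy sketch of the four layouts, illustrating the semantics only, not the kernel API:

```python
import numpy as np

# Illustrative shapes: D is always (m, n); k is the reduction dim.
m, n, k = 4, 3, 5
rng = np.random.default_rng(0)
A_mk = rng.random((m, k))  # A stored untransposed ("n")
B_nk = rng.random((n, k))  # B stored so that "t" means B.T is used
A_km = A_mk.T              # A stored transposed ("t")
B_kn = B_nk.T              # B stored untransposed ("n")

d_nt = A_mk @ B_nk.T    # NT: D = A @ B.T
d_nn = A_mk @ B_kn      # NN: D = A @ B
d_tn = A_km.T @ B_kn    # TN: D = A.T @ B
d_tt = A_km.T @ B_nk.T  # TT: D = A.T @ B.T

# All four layouts describe the same product here.
assert np.allclose(d_nt, d_nn) and np.allclose(d_tn, d_tt)
```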
## JIT Compilation
DeepGEMM uses Just-In-Time (JIT) compilation for its CUDA kernels. The kernel
templates (`.cuh` files in `include/deep_gemm/`) are compiled at runtime using
NVCC or NVRTC. First invocations may be slower due to compilation; results are
cached in `~/.deep_gemm/` for subsequent calls.
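Because compiled kernels land in `~/.deep_gemm/`, the cache can be inspected (or deleted to force recompilation) with ordinary filesystem tools, e.g.:

```python
from pathlib import Path

# JIT cache location documented above
cache_dir = Path.home() / ".deep_gemm"

if cache_dir.is_dir():
    entries = sum(1 for _ in cache_dir.rglob("*"))
    print(f"{entries} cached entries under {cache_dir}")
else:
    print(f"no JIT cache yet at {cache_dir}")
```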
### CUTLASS Runtime Dependency
The JIT-compiled kernels depend on the CUTLASS headers (`cute/`, `cutlass/`) at
runtime. The package searches for CUTLASS in the following locations, in order:
1. `DG_CUTLASS_INCLUDE` environment variable (direct path to include dir)
2. `CUTLASS_HOME` environment variable (`$CUTLASS_HOME/include`)
3. Bundled in the package's `include/` directory
4. `CUDA_HOME/include` (some CUDA 12.8+ installs bundle `cute/`)
5. `nvidia-cutlass` Python package
Set one of these if JIT compilation fails with missing CUTLASS headers:
```bash
export CUTLASS_HOME=/path/to/cutlass
# or
export DG_CUTLASS_INCLUDE=/path/to/cutlass/include
```