Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- README.md +81 -0
- build.toml +100 -0
- build/torch210-cxx11-cu126-x86_64-linux/__init__.py +289 -0
- build/torch210-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
- build/torch210-cxx11-cu126-x86_64-linux/_ops.py +9 -0
- build/torch210-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py +26 -0
- build/torch210-cxx11-cu126-x86_64-linux/metadata.json +3 -0
- build/torch210-cxx11-cu126-x86_64-linux/testing/__init__.py +4 -0
- build/torch210-cxx11-cu126-x86_64-linux/testing/bench.py +137 -0
- build/torch210-cxx11-cu126-x86_64-linux/testing/numeric.py +21 -0
- build/torch210-cxx11-cu126-x86_64-linux/testing/utils.py +38 -0
- build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py +3 -0
- build/torch210-cxx11-cu126-x86_64-linux/utils/layout.py +25 -0
- build/torch210-cxx11-cu126-x86_64-linux/utils/math.py +107 -0
- build/torch210-cxx11-cu128-x86_64-linux/__init__.py +289 -0
- build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
- build/torch210-cxx11-cu128-x86_64-linux/_ops.py +9 -0
- build/torch210-cxx11-cu128-x86_64-linux/deep_gemm/__init__.py +26 -0
- build/torch210-cxx11-cu128-x86_64-linux/metadata.json +3 -0
- build/torch210-cxx11-cu128-x86_64-linux/testing/__init__.py +4 -0
- build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py +137 -0
- build/torch210-cxx11-cu128-x86_64-linux/testing/numeric.py +21 -0
- build/torch210-cxx11-cu128-x86_64-linux/testing/utils.py +38 -0
- build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py +3 -0
- build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py +25 -0
- build/torch210-cxx11-cu128-x86_64-linux/utils/math.py +107 -0
- build/torch210-cxx11-cu130-x86_64-linux/__init__.py +289 -0
- build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
- build/torch210-cxx11-cu130-x86_64-linux/_ops.py +9 -0
- build/torch210-cxx11-cu130-x86_64-linux/deep_gemm/__init__.py +26 -0
- build/torch210-cxx11-cu130-x86_64-linux/metadata.json +3 -0
- build/torch210-cxx11-cu130-x86_64-linux/testing/__init__.py +4 -0
- build/torch210-cxx11-cu130-x86_64-linux/testing/bench.py +137 -0
- build/torch210-cxx11-cu130-x86_64-linux/testing/numeric.py +21 -0
- build/torch210-cxx11-cu130-x86_64-linux/testing/utils.py +38 -0
- build/torch210-cxx11-cu130-x86_64-linux/utils/__init__.py +3 -0
- build/torch210-cxx11-cu130-x86_64-linux/utils/layout.py +25 -0
- build/torch210-cxx11-cu130-x86_64-linux/utils/math.py +107 -0
- build/torch29-cxx11-cu126-x86_64-linux/__init__.py +289 -0
- build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
- build/torch29-cxx11-cu126-x86_64-linux/_ops.py +9 -0
- build/torch29-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py +26 -0
- build/torch29-cxx11-cu126-x86_64-linux/metadata.json +3 -0
- build/torch29-cxx11-cu126-x86_64-linux/testing/__init__.py +4 -0
- build/torch29-cxx11-cu126-x86_64-linux/testing/bench.py +137 -0
- build/torch29-cxx11-cu126-x86_64-linux/testing/numeric.py +21 -0
- build/torch29-cxx11-cu126-x86_64-linux/testing/utils.py +38 -0
- build/torch29-cxx11-cu126-x86_64-linux/utils/__init__.py +3 -0
- build/torch29-cxx11-cu126-x86_64-linux/utils/layout.py +25 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
build/torch210-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepGEMM
|
| 2 |
+
|
| 3 |
+
DeepGEMM kernel for the [Hugging Face kernel-builder](https://github.com/huggingface/kernels) infrastructure.
|
| 4 |
+
|
| 5 |
+
This package provides FP8/FP4/BF16 GEMM kernels, einsum, attention, and hyperconnection operations
|
| 6 |
+
from [DeepSeek-AI/DeepGEMM](https://github.com/DeepSeek-AI/DeepGEMM), adapted to the kernels-community
|
| 7 |
+
build structure with torch library bindings.
|
| 8 |
+
|
| 9 |
+
## Features
|
| 10 |
+
|
| 11 |
+
- **FP8/FP4 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
|
| 12 |
+
- **BF16 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
|
| 13 |
+
- **cuBLASLt GEMMs**: NT, NN, TN, TT wrappers
|
| 14 |
+
- **Einsum**: bmk,bnk->mn, bhr,hdr->bhd, bhd,hdr->bhr expressions (BF16 and FP8)
|
| 15 |
+
- **Attention**: FP8 MQA logits (regular and paged)
|
| 16 |
+
- **Hyperconnection**: TF32 prenorm GEMM
|
| 17 |
+
- **Layout utilities**: Scaling factor transformations, TMA alignment
|
| 18 |
+
|
| 19 |
+
## Architecture Support
|
| 20 |
+
|
| 21 |
+
- SM 9.0a (Hopper / H100)
|
| 22 |
+
- SM 10.0a (Blackwell / B200)
|
| 23 |
+
|
| 24 |
+
## Requirements
|
| 25 |
+
|
| 26 |
+
- CUDA >= 12.1
|
| 27 |
+
- PyTorch >= 2.1
|
| 28 |
+
- CUTLASS 3.9+
|
| 29 |
+
- NVRTC (part of CUDA Toolkit)
|
| 30 |
+
|
| 31 |
+
## Installation
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
pip install kernels
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
import kernels
|
| 39 |
+
kernels.install("kernels-community/DeepGEMM")
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## Usage
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
import deep_gemm
|
| 46 |
+
|
| 47 |
+
# FP8 GEMM: D = A @ B.T
|
| 48 |
+
deep_gemm.fp8_gemm_nt((a_fp8, sfa), (b_fp8, sfb), d)
|
| 49 |
+
|
| 50 |
+
# BF16 GEMM: D = A @ B.T
|
| 51 |
+
deep_gemm.bf16_gemm_nt(a_bf16, b_bf16, d)
|
| 52 |
+
|
| 53 |
+
# cuBLASLt GEMM
|
| 54 |
+
deep_gemm.cublaslt_gemm_nt(a, b, d)
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## JIT Compilation
|
| 58 |
+
|
| 59 |
+
DeepGEMM uses Just-In-Time (JIT) compilation for its CUDA kernels. The kernel
|
| 60 |
+
templates (`.cuh` files in `include/deep_gemm/`) are compiled at runtime using
|
| 61 |
+
NVCC or NVRTC. First invocations may be slower due to compilation; results are
|
| 62 |
+
cached in `~/.deep_gemm/` for subsequent calls.
|
| 63 |
+
|
| 64 |
+
### CUTLASS Runtime Dependency
|
| 65 |
+
|
| 66 |
+
The JIT-compiled kernels depend on CUTLASS headers (`cute/`, `cutlass/`) at
|
| 67 |
+
runtime. The package will automatically search for CUTLASS in these locations:
|
| 68 |
+
|
| 69 |
+
1. `DG_CUTLASS_INCLUDE` environment variable (direct path to include dir)
|
| 70 |
+
2. `CUTLASS_HOME` environment variable (`$CUTLASS_HOME/include`)
|
| 71 |
+
3. Bundled in the package's `include/` directory
|
| 72 |
+
4. `CUDA_HOME/include` (some CUDA 12.8+ installs bundle `cute/`)
|
| 73 |
+
5. `nvidia-cutlass` Python package
|
| 74 |
+
|
| 75 |
+
Set one of these if JIT compilation fails with missing CUTLASS headers:
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
export CUTLASS_HOME=/path/to/cutlass
|
| 79 |
+
# or
|
| 80 |
+
export DG_CUTLASS_INCLUDE=/path/to/cutlass/include
|
| 81 |
+
```
|
build.toml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
name = "deep_gemm"
|
| 3 |
+
backends = ["cuda"]
|
| 4 |
+
|
| 5 |
+
[general.hub]
|
| 6 |
+
repo-id = "kernels-community/DeepGEMM"
|
| 7 |
+
|
| 8 |
+
[torch]
|
| 9 |
+
src = [
|
| 10 |
+
"torch-ext/torch_binding.cpp",
|
| 11 |
+
"torch-ext/torch_binding.h",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
[kernel.deep_gemm]
|
| 15 |
+
backend = "cuda"
|
| 16 |
+
cuda-capabilities = [
|
| 17 |
+
"9.0a",
|
| 18 |
+
"10.0a",
|
| 19 |
+
]
|
| 20 |
+
cxx-flags = [
|
| 21 |
+
"-std=c++17",
|
| 22 |
+
"-O3",
|
| 23 |
+
"-Wno-psabi",
|
| 24 |
+
"-Wno-deprecated-declarations",
|
| 25 |
+
]
|
| 26 |
+
depends = [
|
| 27 |
+
"torch",
|
| 28 |
+
"cutlass_3_9",
|
| 29 |
+
]
|
| 30 |
+
include = [
|
| 31 |
+
".",
|
| 32 |
+
"csrc",
|
| 33 |
+
"deep_gemm/include",
|
| 34 |
+
"third-party/fmt/include",
|
| 35 |
+
]
|
| 36 |
+
src = [
|
| 37 |
+
"csrc/deep_gemm_impl.cpp",
|
| 38 |
+
"csrc/apis/attention.hpp",
|
| 39 |
+
"csrc/apis/einsum.hpp",
|
| 40 |
+
"csrc/apis/gemm.hpp",
|
| 41 |
+
"csrc/apis/hyperconnection.hpp",
|
| 42 |
+
"csrc/apis/layout.hpp",
|
| 43 |
+
"csrc/apis/runtime.hpp",
|
| 44 |
+
"csrc/jit/cache.hpp",
|
| 45 |
+
"csrc/jit/compiler.hpp",
|
| 46 |
+
"csrc/jit/device_runtime.hpp",
|
| 47 |
+
"csrc/jit/handle.hpp",
|
| 48 |
+
"csrc/jit/kernel_runtime.hpp",
|
| 49 |
+
"csrc/jit_kernels/heuristics/common.hpp",
|
| 50 |
+
"csrc/jit_kernels/heuristics/sm90.hpp",
|
| 51 |
+
"csrc/jit_kernels/heuristics/sm100.hpp",
|
| 52 |
+
"csrc/jit_kernels/impls/epilogue.hpp",
|
| 53 |
+
"csrc/jit_kernels/impls/runtime_utils.hpp",
|
| 54 |
+
"csrc/jit_kernels/impls/sm90_bf16_gemm.hpp",
|
| 55 |
+
"csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp",
|
| 56 |
+
"csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp",
|
| 57 |
+
"csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp",
|
| 58 |
+
"csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp",
|
| 59 |
+
"csrc/jit_kernels/impls/sm100_bf16_gemm.hpp",
|
| 60 |
+
"csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp",
|
| 61 |
+
"csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp",
|
| 62 |
+
"csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp",
|
| 63 |
+
"csrc/jit_kernels/impls/smxx_clean_logits.hpp",
|
| 64 |
+
"csrc/jit_kernels/impls/smxx_cublaslt.hpp",
|
| 65 |
+
"csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp",
|
| 66 |
+
"csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp",
|
| 67 |
+
"csrc/jit_kernels/impls/smxx_layout.hpp",
|
| 68 |
+
"csrc/utils/compatibility.hpp",
|
| 69 |
+
"csrc/utils/exception.hpp",
|
| 70 |
+
"csrc/utils/format.hpp",
|
| 71 |
+
"csrc/utils/hash.hpp",
|
| 72 |
+
"csrc/utils/layout.hpp",
|
| 73 |
+
"csrc/utils/lazy_init.hpp",
|
| 74 |
+
"csrc/utils/math.hpp",
|
| 75 |
+
"csrc/utils/system.hpp",
|
| 76 |
+
"deep_gemm/include/deep_gemm/common/cute_tie.cuh",
|
| 77 |
+
"deep_gemm/include/deep_gemm/common/epilogue_utils.cuh",
|
| 78 |
+
"deep_gemm/include/deep_gemm/common/reduction.cuh",
|
| 79 |
+
"deep_gemm/include/deep_gemm/common/scheduler.cuh",
|
| 80 |
+
"deep_gemm/include/deep_gemm/common/sm100_utils.cuh",
|
| 81 |
+
"deep_gemm/include/deep_gemm/common/sm90_utils.cuh",
|
| 82 |
+
"deep_gemm/include/deep_gemm/common/tma_utils.cuh",
|
| 83 |
+
"deep_gemm/include/deep_gemm/common/types.hpp",
|
| 84 |
+
"deep_gemm/include/deep_gemm/common/utils.cuh",
|
| 85 |
+
"deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh",
|
| 86 |
+
"deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh",
|
| 87 |
+
"deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh",
|
| 88 |
+
"deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh",
|
| 89 |
+
"deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh",
|
| 90 |
+
"deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh",
|
| 91 |
+
"deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh",
|
| 92 |
+
"deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh",
|
| 93 |
+
"deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh",
|
| 94 |
+
"deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh",
|
| 95 |
+
"deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh",
|
| 96 |
+
"deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh",
|
| 97 |
+
"deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh",
|
| 98 |
+
"deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh",
|
| 99 |
+
"deep_gemm/include/deep_gemm/impls/smxx_layout.cuh",
|
| 100 |
+
]
|
build/torch210-cxx11-cu126-x86_64-linux/__init__.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _find_cuda_home():
|
| 9 |
+
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
| 10 |
+
if cuda_home is None:
|
| 11 |
+
try:
|
| 12 |
+
with open(os.devnull, 'w') as devnull:
|
| 13 |
+
nvcc = subprocess.check_output(
|
| 14 |
+
['which', 'nvcc'], stderr=devnull
|
| 15 |
+
).decode().rstrip('\r\n')
|
| 16 |
+
cuda_home = os.path.dirname(os.path.dirname(nvcc))
|
| 17 |
+
except Exception:
|
| 18 |
+
cuda_home = '/usr/local/cuda'
|
| 19 |
+
if not os.path.exists(cuda_home):
|
| 20 |
+
cuda_home = ''
|
| 21 |
+
return cuda_home or ''
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _find_cutlass_include():
|
| 25 |
+
"""Find CUTLASS include path for JIT compilation of .cuh templates."""
|
| 26 |
+
# 1. Explicit env var
|
| 27 |
+
cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
|
| 28 |
+
if cutlass_include and os.path.isdir(cutlass_include):
|
| 29 |
+
return cutlass_include
|
| 30 |
+
|
| 31 |
+
# 2. CUTLASS_HOME env var
|
| 32 |
+
cutlass_home = os.environ.get('CUTLASS_HOME')
|
| 33 |
+
if cutlass_home:
|
| 34 |
+
p = os.path.join(cutlass_home, 'include')
|
| 35 |
+
if os.path.isdir(os.path.join(p, 'cute')):
|
| 36 |
+
return p
|
| 37 |
+
|
| 38 |
+
# 3. Check in package include/ directory (bundled cute/cutlass headers)
|
| 39 |
+
pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
|
| 40 |
+
if os.path.isdir(os.path.join(pkg_include, 'cute')):
|
| 41 |
+
return pkg_include
|
| 42 |
+
|
| 43 |
+
# 4. Check CUDA_HOME/include (some CUDA 12.8+ installs include cute/)
|
| 44 |
+
cuda_home = _find_cuda_home()
|
| 45 |
+
if cuda_home:
|
| 46 |
+
cuda_inc = os.path.join(cuda_home, 'include')
|
| 47 |
+
if os.path.isdir(os.path.join(cuda_inc, 'cute')):
|
| 48 |
+
return cuda_inc
|
| 49 |
+
|
| 50 |
+
# 5. Try to find nvidia-cutlass Python package
|
| 51 |
+
try:
|
| 52 |
+
import cutlass as _cutlass
|
| 53 |
+
cutlass_dir = os.path.dirname(_cutlass.__file__)
|
| 54 |
+
p = os.path.join(cutlass_dir, 'include')
|
| 55 |
+
if os.path.isdir(os.path.join(p, 'cute')):
|
| 56 |
+
return p
|
| 57 |
+
except ImportError:
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
# Return empty string; C++ side will also check env vars
|
| 61 |
+
return ""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def set_num_sms(new_num_sms):
|
| 65 |
+
ops.set_num_sms(new_num_sms)
|
| 66 |
+
|
| 67 |
+
def get_num_sms():
|
| 68 |
+
return ops.get_num_sms()
|
| 69 |
+
|
| 70 |
+
def set_tc_util(new_tc_util):
|
| 71 |
+
ops.set_tc_util(new_tc_util)
|
| 72 |
+
|
| 73 |
+
def get_tc_util():
|
| 74 |
+
return ops.get_tc_util()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# cuBLASLt GEMMs
|
| 78 |
+
def cublaslt_gemm_nt(a, b, d, c=None):
|
| 79 |
+
ops.cublaslt_gemm_nt(a, b, d, c)
|
| 80 |
+
|
| 81 |
+
def cublaslt_gemm_nn(a, b, d, c=None):
|
| 82 |
+
ops.cublaslt_gemm_nn(a, b, d, c)
|
| 83 |
+
|
| 84 |
+
def cublaslt_gemm_tn(a, b, d, c=None):
|
| 85 |
+
ops.cublaslt_gemm_tn(a, b, d, c)
|
| 86 |
+
|
| 87 |
+
def cublaslt_gemm_tt(a, b, d, c=None):
|
| 88 |
+
ops.cublaslt_gemm_tt(a, b, d, c)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
# FP8/FP4 GEMMs
|
| 93 |
+
def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
|
| 94 |
+
recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
|
| 95 |
+
ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
|
| 96 |
+
list(recipe) if recipe else None,
|
| 97 |
+
list(recipe_a) if recipe_a else None,
|
| 98 |
+
list(recipe_b) if recipe_b else None,
|
| 99 |
+
compiled_dims, disable_ue8m0_cast)
|
| 100 |
+
|
| 101 |
+
def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
|
| 102 |
+
recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
|
| 103 |
+
ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
|
| 104 |
+
list(recipe) if recipe else None,
|
| 105 |
+
list(recipe_a) if recipe_a else None,
|
| 106 |
+
list(recipe_b) if recipe_b else None,
|
| 107 |
+
compiled_dims, disable_ue8m0_cast)
|
| 108 |
+
|
| 109 |
+
def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
|
| 110 |
+
recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
|
| 111 |
+
ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
|
| 112 |
+
list(recipe) if recipe else None,
|
| 113 |
+
list(recipe_a) if recipe_a else None,
|
| 114 |
+
list(recipe_b) if recipe_b else None,
|
| 115 |
+
compiled_dims, disable_ue8m0_cast)
|
| 116 |
+
|
| 117 |
+
def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
|
| 118 |
+
recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
|
| 119 |
+
ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
|
| 120 |
+
list(recipe) if recipe else None,
|
| 121 |
+
list(recipe_a) if recipe_a else None,
|
| 122 |
+
list(recipe_b) if recipe_b else None,
|
| 123 |
+
compiled_dims, disable_ue8m0_cast)
|
| 124 |
+
|
| 125 |
+
fp8_gemm_nt = fp8_fp4_gemm_nt
|
| 126 |
+
fp8_gemm_nn = fp8_fp4_gemm_nn
|
| 127 |
+
fp8_gemm_tn = fp8_fp4_gemm_tn
|
| 128 |
+
fp8_gemm_tt = fp8_fp4_gemm_tt
|
| 129 |
+
|
| 130 |
+
def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
|
| 131 |
+
recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
|
| 132 |
+
disable_ue8m0_cast=False, use_psum_layout=False,
|
| 133 |
+
expected_m_for_psum_layout=None):
|
| 134 |
+
ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
|
| 135 |
+
a[0], a[1], b[0], b[1], d, grouped_layout,
|
| 136 |
+
list(recipe) if recipe else None,
|
| 137 |
+
list(recipe_a) if recipe_a else None,
|
| 138 |
+
list(recipe_b) if recipe_b else None,
|
| 139 |
+
compiled_dims, disable_ue8m0_cast, use_psum_layout,
|
| 140 |
+
expected_m_for_psum_layout)
|
| 141 |
+
|
| 142 |
+
m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous
|
| 143 |
+
|
| 144 |
+
def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
|
| 145 |
+
recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
|
| 146 |
+
disable_ue8m0_cast=False, use_psum_layout=False):
|
| 147 |
+
ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
|
| 148 |
+
a[0], a[1], b[0], b[1], d, grouped_layout,
|
| 149 |
+
list(recipe) if recipe else None,
|
| 150 |
+
list(recipe_a) if recipe_a else None,
|
| 151 |
+
list(recipe_b) if recipe_b else None,
|
| 152 |
+
compiled_dims, disable_ue8m0_cast, use_psum_layout)
|
| 153 |
+
|
| 154 |
+
m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous
|
| 155 |
+
|
| 156 |
+
def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
|
| 157 |
+
recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
|
| 158 |
+
disable_ue8m0_cast=False):
|
| 159 |
+
ops.m_grouped_fp8_fp4_gemm_nt_masked(
|
| 160 |
+
a[0], a[1], b[0], b[1], d, masked_m, expected_m,
|
| 161 |
+
list(recipe) if recipe else None,
|
| 162 |
+
list(recipe_a) if recipe_a else None,
|
| 163 |
+
list(recipe_b) if recipe_b else None,
|
| 164 |
+
compiled_dims, disable_ue8m0_cast)
|
| 165 |
+
|
| 166 |
+
m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
|
| 167 |
+
|
| 168 |
+
def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
|
| 169 |
+
recipe=(1, 1, 128), compiled_dims="mn"):
|
| 170 |
+
ops.k_grouped_fp8_gemm_nt_contiguous(
|
| 171 |
+
a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
|
| 172 |
+
list(recipe), compiled_dims)
|
| 173 |
+
|
| 174 |
+
def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
|
| 175 |
+
recipe=(1, 1, 128), compiled_dims="mn"):
|
| 176 |
+
ops.k_grouped_fp8_gemm_tn_contiguous(
|
| 177 |
+
a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
|
| 178 |
+
list(recipe), compiled_dims)
|
| 179 |
+
|
| 180 |
+
# BF16 GEMMs
|
| 181 |
+
def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
|
| 182 |
+
ops.bf16_gemm_nt(a, b, d, c, compiled_dims)
|
| 183 |
+
|
| 184 |
+
def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
|
| 185 |
+
ops.bf16_gemm_nn(a, b, d, c, compiled_dims)
|
| 186 |
+
|
| 187 |
+
def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
|
| 188 |
+
ops.bf16_gemm_tn(a, b, d, c, compiled_dims)
|
| 189 |
+
|
| 190 |
+
def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
|
| 191 |
+
ops.bf16_gemm_tt(a, b, d, c, compiled_dims)
|
| 192 |
+
|
| 193 |
+
def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
|
| 194 |
+
compiled_dims="nk", use_psum_layout=False,
|
| 195 |
+
expected_m_for_psum_layout=None):
|
| 196 |
+
ops.m_grouped_bf16_gemm_nt_contiguous(
|
| 197 |
+
a, b, d, grouped_layout, compiled_dims,
|
| 198 |
+
use_psum_layout, expected_m_for_psum_layout)
|
| 199 |
+
|
| 200 |
+
def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
|
| 201 |
+
compiled_dims="nk", use_psum_layout=False):
|
| 202 |
+
ops.m_grouped_bf16_gemm_nn_contiguous(
|
| 203 |
+
a, b, d, grouped_layout, compiled_dims, use_psum_layout)
|
| 204 |
+
|
| 205 |
+
def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
|
| 206 |
+
compiled_dims="nk"):
|
| 207 |
+
ops.m_grouped_bf16_gemm_nt_masked(
|
| 208 |
+
a, b, d, masked_m, expected_m, compiled_dims)
|
| 209 |
+
|
| 210 |
+
def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
|
| 211 |
+
c=None, compiled_dims="mn"):
|
| 212 |
+
ops.k_grouped_bf16_gemm_tn_contiguous(
|
| 213 |
+
a, b, d, ks, ks_tensor, c, compiled_dims)
|
| 214 |
+
|
| 215 |
+
# Einsum
|
| 216 |
+
def einsum(expr, a, b, d, c=None, use_cublaslt=False):
|
| 217 |
+
ops.einsum(expr, a, b, d, c, use_cublaslt)
|
| 218 |
+
|
| 219 |
+
def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
|
| 220 |
+
ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))
|
| 221 |
+
|
| 222 |
+
# Attention
|
| 223 |
+
def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
|
| 224 |
+
compiled_dims="nk", disable_ue8m0_cast=False):
|
| 225 |
+
ops.fp8_gemm_nt_skip_head_mid(
|
| 226 |
+
a[0], a[1], b[0], b[1], d, list(head_splits),
|
| 227 |
+
list(recipe) if recipe else None,
|
| 228 |
+
compiled_dims, disable_ue8m0_cast)
|
| 229 |
+
|
| 230 |
+
def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
|
| 231 |
+
cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
|
| 232 |
+
return ops.fp8_mqa_logits(
|
| 233 |
+
q, kv[0], kv[1], weights,
|
| 234 |
+
cu_seq_len_k_start, cu_seq_len_k_end,
|
| 235 |
+
clean_logits, max_seqlen_k)
|
| 236 |
+
|
| 237 |
+
def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
|
| 238 |
+
return ops.get_paged_mqa_logits_metadata(
|
| 239 |
+
context_lens, block_kv, num_sms)
|
| 240 |
+
|
| 241 |
+
def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
|
| 242 |
+
block_table, schedule_meta,
|
| 243 |
+
max_context_len, clean_logits=False):
|
| 244 |
+
return ops.fp8_paged_mqa_logits(
|
| 245 |
+
q, fused_kv_cache, weights, context_lens,
|
| 246 |
+
block_table, schedule_meta, max_context_len, clean_logits)
|
| 247 |
+
|
| 248 |
+
# Hyperconnection
|
| 249 |
+
def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
|
| 250 |
+
ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)
|
| 251 |
+
|
| 252 |
+
# Layout
|
| 253 |
+
def transform_sf_into_required_layout(sf, mn, k, recipe=None,
|
| 254 |
+
recipe_ab=None, num_groups=None, is_sfa=False,
|
| 255 |
+
disable_ue8m0_cast=False):
|
| 256 |
+
return ops.transform_sf_into_required_layout(
|
| 257 |
+
sf, mn, k,
|
| 258 |
+
list(recipe) if recipe else None,
|
| 259 |
+
list(recipe_ab) if recipe_ab else None,
|
| 260 |
+
num_groups, is_sfa, disable_ue8m0_cast)
|
| 261 |
+
|
| 262 |
+
def get_mk_alignment_for_contiguous_layout():
|
| 263 |
+
return ops.get_mk_alignment_for_contiguous_layout()
|
| 264 |
+
|
| 265 |
+
# Legacy aliases
|
| 266 |
+
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
|
| 267 |
+
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
|
| 268 |
+
|
| 269 |
+
except Exception:
|
| 270 |
+
pass
|
| 271 |
+
|
| 272 |
+
# Utils
|
| 273 |
+
from . import utils
|
| 274 |
+
from .utils import *
|
| 275 |
+
|
| 276 |
+
# Testing
|
| 277 |
+
from . import testing
|
| 278 |
+
|
| 279 |
+
# Initialize (gracefully skip if CUDA is not available, e.g. in build sandboxes)
|
| 280 |
+
try:
|
| 281 |
+
ops.init(
|
| 282 |
+
os.path.dirname(os.path.abspath(__file__)),
|
| 283 |
+
_find_cuda_home(),
|
| 284 |
+
_find_cutlass_include()
|
| 285 |
+
)
|
| 286 |
+
except Exception:
|
| 287 |
+
pass
|
| 288 |
+
|
| 289 |
+
__version__ = '2.3.0'
|
build/torch210-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be5d0bb69c96d55b15ba62ba83e0743eb80ef4e93198fe59862dc247540f4956
|
| 3 |
+
size 3006712
|
build/torch210-cxx11-cu126-x86_64-linux/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _deep_gemm_099ac3c_dirty
|
| 3 |
+
ops = torch.ops._deep_gemm_099ac3c_dirty
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
|
| 6 |
+
"""
|
| 7 |
+
Prefix op by namespace.
|
| 8 |
+
"""
|
| 9 |
+
return f"_deep_gemm_099ac3c_dirty::{op_name}"
|
build/torch210-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import importlib
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch210-cxx11-cu126-x86_64-linux/metadata.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python-depends": []
|
| 3 |
+
}
|
build/torch210-cxx11-cu126-x86_64-linux/testing/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import bench, numeric, utils
|
| 2 |
+
from .bench import *
|
| 3 |
+
from .numeric import *
|
| 4 |
+
from .utils import *
|
build/torch210-cxx11-cu126-x86_64-linux/testing/bench.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
| 7 |
+
high_precision: bool = False):
|
| 8 |
+
# Flush L2 cache with 256 MB data
|
| 9 |
+
torch.cuda.synchronize()
|
| 10 |
+
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
| 11 |
+
cache.zero_()
|
| 12 |
+
|
| 13 |
+
# Warmup
|
| 14 |
+
for _ in range(num_warmups):
|
| 15 |
+
fn()
|
| 16 |
+
|
| 17 |
+
# Add a large kernel to eliminate the CPU launch overhead
|
| 18 |
+
if high_precision:
|
| 19 |
+
x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
| 20 |
+
y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
| 21 |
+
x @ y
|
| 22 |
+
|
| 23 |
+
# Testing
|
| 24 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
| 25 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
| 26 |
+
start_event.record()
|
| 27 |
+
for i in range(num_tests):
|
| 28 |
+
fn()
|
| 29 |
+
end_event.record()
|
| 30 |
+
torch.cuda.synchronize()
|
| 31 |
+
|
| 32 |
+
return start_event.elapsed_time(end_event) / num_tests / 1e3
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class empty_suppress:
    """No-op context manager used when output suppression is disabled."""

    def __enter__(self):
        # Nothing to suppress; hand back the manager itself.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None lets any exception propagate, as before.
        return None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class suppress_stdout_stderr:
    """Redirect both Python-level and fd-level stdout/stderr to os.devnull.

    Re-pointing the underlying file descriptors (not just sys.stdout /
    sys.stderr) also silences output written directly from C/C++ code.
    Not reentrant and not thread-safe.
    """
    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The original fd numbers (normally 1 and 2) that get re-pointed below.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates that keep the original targets alive for restoration.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        # Saved Python-level stream objects.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the live fds at the null device.
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Swap the Python-level streams as well.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Restore the original fd targets from the saved duplicates...
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        # ...then release the duplicates and the null-device handles.
        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Benchmark `fn` under the Kineto profiler and return per-kernel times.

    `kernel_names` is a substring (or tuple of substrings) matched against the
    profiler table's kernel-name column. Returns the average runtime in
    seconds for each requested kernel (a tuple iff `kernel_names` is a tuple).
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        # Dummy timings so callers can still unpack the result.
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile: one discarded warmup step followed by one recorded step.
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        # Each requested name must match at most one row, otherwise the
        # aggregated timing would be ambiguous.
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times. In a matching row the last two whitespace-
    # separated columns are the total time (with a unit suffix) and the count.
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        # Accumulate total seconds across all launches.
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
|
build/torch210-cxx11-cu126-x86_64-linux/testing/numeric.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Iterable
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return one minus a cosine-style similarity of x and y (in float64).

    0 means identical; the value grows as the tensors diverge.
    """
    x64, y64 = x.double(), y.double()
    denominator = (x64 * x64 + y64 * y64).sum()
    # All-zero inputs are defined to be identical.
    if denominator == 0:
        return 0.0
    similarity = 2 * (x64 * y64).sum() / denominator
    return 1 - similarity
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def count_bytes(*tensors):
    """Sum the storage sizes (in bytes) of tensors, recursing into
    tuples/lists and skipping None entries."""
    def _nbytes(obj):
        if obj is None:
            return 0
        if isinstance(obj, (tuple, list)):
            return sum(_nbytes(member) for member in obj)
        return obj.numel() * obj.element_size()

    return sum(_nbytes(t) for t in tensors)
|
build/torch210-cxx11-cu126-x86_64-linux/testing/utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Callable
|
| 5 |
+
|
| 6 |
+
def get_arch_major() -> int:
    """Return the major CUDA compute capability of the current device."""
    return torch.cuda.get_device_capability()[0]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_filter(condition: Callable):
|
| 12 |
+
def decorator(func):
|
| 13 |
+
@functools.wraps(func)
|
| 14 |
+
def wrapper(*args, **kwargs):
|
| 15 |
+
if condition():
|
| 16 |
+
func(*args, **kwargs)
|
| 17 |
+
else:
|
| 18 |
+
print(f'{func.__name__}:')
|
| 19 |
+
print(f' > Filtered by {condition}')
|
| 20 |
+
print()
|
| 21 |
+
return wrapper
|
| 22 |
+
return decorator
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ignore_env(name: str, condition: Callable):
    """Decorator: temporarily remove env var `name` while calling the test.

    When `condition()` is truthy, `os.environ[name]` is popped before the
    call and restored afterwards — even if the wrapped function raises (the
    previous implementation leaked the removed variable on exceptions).
    When the condition is falsy, the function runs with the environment
    untouched. The wrapped function's return value is discarded.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not condition():
                func(*args, **kwargs)
                return
            saved = os.environ.pop(name, None)
            try:
                func(*args, **kwargs)
            finally:
                # Restore only when the variable existed before the call.
                if saved is not None:
                    os.environ[name] = saved
        return wrapper
    return decorator
|
build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import math, layout
|
| 2 |
+
from .layout import *
|
| 3 |
+
from .math import *
|
build/torch210-cxx11-cu126-x86_64-linux/utils/layout.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Thin wrappers around the compiled extension's TMA-layout helpers. If the
# compiled ops module is unavailable, this group of wrappers is not defined.
try:
    from .._ops import ops

    def get_tma_aligned_size(x, element_size):
        # Forwarded to the compiled op.
        return ops.get_tma_aligned_size(x, element_size)

    def get_mn_major_tma_aligned_tensor(sf):
        # Forwarded to the compiled op.
        return ops.get_mn_major_tma_aligned_tensor(sf)

    def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
        # Forwarded to the compiled op (UE8M0-packed variant).
        return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)

    def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
        # K-grouped variant; `ks_tensor`/`ks` describe the per-group sizes.
        return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
            sf, ks_tensor, ks)
except ImportError:
    pass

# NOTE(review): this second import is unconditional, so a missing compiled
# module still raises ImportError here despite the try/except above — confirm
# the intended degradation behavior.
from .._ops import ops as _ops

def get_mk_alignment_for_contiguous_layout():
    # Alignment requirement used by the contiguous grouped-GEMM layouts.
    return _ops.get_mk_alignment_for_contiguous_layout()

# Legacy aliases kept for backward compatibility.
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
build/torch210-cxx11-cu126-x86_64-linux/utils/math.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def ceil_div(x: int, y: int) -> int:
    """Integer ceiling of x / y."""
    # Identical to (x + y - 1) // y: shifting the numerator by y adds
    # exactly one to the floor quotient.
    return (x - 1) // y + 1
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def align(x: int, y: int) -> int:
    """Round x up to the next multiple of y."""
    # ceil_div(x, y) * y with the helper inlined.
    return ((x - 1) // y + 1) * y
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def ceil_to_ue8m0(x: torch.Tensor):
    """Round each magnitude up to the nearest power of two (a UE8M0-representable scale)."""
    # Scale factors are expected to be positive; bail out on degenerate input.
    assert x.view(-1).amax().item() > 0
    return torch.exp2(torch.ceil(torch.log2(x.abs())))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize each row of `x` to FP8 E4M3 with one scale per `gran_k`-wide chunk.

    Returns (quantized, sf): quantized has the shape of `x`; sf has one scale
    per (row, chunk).
    """
    assert x.dim() == 2
    num_rows, num_cols = x.shape
    padded_cols = align(num_cols, gran_k)
    # Zero-pad columns so each row splits evenly into gran_k-wide groups.
    padded = torch.empty((num_rows, padded_cols), dtype=x.dtype, device=x.device).fill_(0)
    padded[:, :num_cols] = x
    grouped = padded.view(num_rows, -1, gran_k)
    # Per-group absolute maximum, floored at 1e-4 to avoid blow-up on division.
    amax = grouped.abs().float().amax(dim=2).view(num_rows, -1).clamp(1e-4)
    scale = amax / 448.0  # 448 is the largest finite E4M3 magnitude
    if use_ue8m0:
        scale = ceil_to_ue8m0(scale)
    quantized = (grouped * (1.0 / scale.unsqueeze(2))).to(torch.float8_e4m3fn)
    return quantized.view(num_rows, padded_cols)[:, :num_cols].contiguous(), scale
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to FP8 E4M3 with one scale per (gran_k-row band, column).

    Returns (quantized, sf); `sf` has shape (m // gran_k, n).
    """
    assert x.dim() == 2 and x.size(0) % gran_k == 0
    m, n = x.shape
    x_view = x.view(-1, gran_k, n)
    # Column-wise max within each gran_k-row band; clamp avoids zero scales.
    x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
    sf = x_amax / 448.0  # 448 is the largest finite E4M3 magnitude
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to FP8 E4M3 with one scale per (gran_k x gran_k) block.

    Returns (quantized, sf): quantized has the shape of `x`; sf has shape
    (ceil(m / gran_k), ceil(n / gran_k)).
    """
    assert x.dim() == 2
    m, n = x.shape
    # Zero-pad both dimensions so the tensor tiles exactly into blocks.
    x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    # View as (row-blocks, gran_k, col-blocks, gran_k) to reduce per block.
    x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0  # 448 is the largest finite E4M3 magnitude
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    # Crop the padding back off and flatten the scale grid to 2-D.
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 55 |
+
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
| 56 |
+
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
| 57 |
+
sf = x_amax / 448.0
|
| 58 |
+
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 59 |
+
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
| 60 |
+
return x_scaled, sf.squeeze()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
ax = x.abs().clamp_max(6.0)
|
| 65 |
+
# {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
| 66 |
+
# midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
|
| 67 |
+
boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
|
| 68 |
+
device=x.device, dtype=ax.dtype)
|
| 69 |
+
idx = torch.bucketize(ax, boundaries)
|
| 70 |
+
code = idx.to(torch.uint8)
|
| 71 |
+
sign = (x < 0) & (idx != 0)
|
| 72 |
+
code = code | (sign.to(torch.uint8) << 3)
|
| 73 |
+
return code # uint8, 0..15
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize each row to packed FP4 E2M1 with one scale per gran_k chunk.

    Returns (packed, sf): `packed` is uint8 of shape (m, n // 2) holding two
    4-bit codes per byte (even column in the low nibble), `sf` the per-chunk
    scales.
    """
    assert x.dim() == 2
    m, n = x.shape
    assert n % 2 == 0
    padded_n = align(n, gran_k)
    # Zero-pad columns so each row splits evenly into gran_k-wide chunks.
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, -1, gran_k)
    x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4)
    sf = x_amax / 6.0  # 6 is the largest E2M1 magnitude
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = x_view * (1.0 / sf.unsqueeze(2))
    codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n)  # uint8, (m, padded_n)
    # Pack adjacent column pairs into single bytes: even column -> low nibble.
    codes2 = codes.view(m, padded_n // 2, 2)
    packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)  # uint8
    return packed[:, :n // 2].contiguous(), sf
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
    """Transpose a 2-D tensor of packed 4-bit codes (two codes per byte).

    Input is (m, n // 2) uint8; output is (n, m // 2) uint8, with the same
    low-nibble-first packing along the new rows.
    """
    assert a.dtype == torch.uint8
    assert a.dim() == 2
    rows, packed_cols = a.shape
    full_cols = packed_cols * 2
    assert rows % 2 == 0
    # Unpack: even logical columns come from the low nibble, odd from the high.
    unpacked = torch.empty((rows, full_cols), device=a.device, dtype=torch.uint8)
    unpacked[:, 0::2] = a & 0x0F
    unpacked[:, 1::2] = (a >> 4) & 0x0F
    # Transpose, then re-pack adjacent pairs along the new row direction.
    transposed = unpacked.transpose(0, 1).contiguous().view(full_cols, rows // 2, 2)
    repacked = (transposed[:, :, 0] & 0x0F) | ((transposed[:, :, 1] & 0x0F) << 4)
    return repacked.contiguous()
|
build/torch210-cxx11-cu128-x86_64-linux/__init__.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _find_cuda_home():
|
| 9 |
+
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
| 10 |
+
if cuda_home is None:
|
| 11 |
+
try:
|
| 12 |
+
with open(os.devnull, 'w') as devnull:
|
| 13 |
+
nvcc = subprocess.check_output(
|
| 14 |
+
['which', 'nvcc'], stderr=devnull
|
| 15 |
+
).decode().rstrip('\r\n')
|
| 16 |
+
cuda_home = os.path.dirname(os.path.dirname(nvcc))
|
| 17 |
+
except Exception:
|
| 18 |
+
cuda_home = '/usr/local/cuda'
|
| 19 |
+
if not os.path.exists(cuda_home):
|
| 20 |
+
cuda_home = ''
|
| 21 |
+
return cuda_home or ''
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _find_cutlass_include():
|
| 25 |
+
"""Find CUTLASS include path for JIT compilation of .cuh templates."""
|
| 26 |
+
# 1. Explicit env var
|
| 27 |
+
cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
|
| 28 |
+
if cutlass_include and os.path.isdir(cutlass_include):
|
| 29 |
+
return cutlass_include
|
| 30 |
+
|
| 31 |
+
# 2. CUTLASS_HOME env var
|
| 32 |
+
cutlass_home = os.environ.get('CUTLASS_HOME')
|
| 33 |
+
if cutlass_home:
|
| 34 |
+
p = os.path.join(cutlass_home, 'include')
|
| 35 |
+
if os.path.isdir(os.path.join(p, 'cute')):
|
| 36 |
+
return p
|
| 37 |
+
|
| 38 |
+
# 3. Check in package include/ directory (bundled cute/cutlass headers)
|
| 39 |
+
pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
|
| 40 |
+
if os.path.isdir(os.path.join(pkg_include, 'cute')):
|
| 41 |
+
return pkg_include
|
| 42 |
+
|
| 43 |
+
# 4. Check CUDA_HOME/include (some CUDA 12.8+ installs include cute/)
|
| 44 |
+
cuda_home = _find_cuda_home()
|
| 45 |
+
if cuda_home:
|
| 46 |
+
cuda_inc = os.path.join(cuda_home, 'include')
|
| 47 |
+
if os.path.isdir(os.path.join(cuda_inc, 'cute')):
|
| 48 |
+
return cuda_inc
|
| 49 |
+
|
| 50 |
+
# 5. Try to find nvidia-cutlass Python package
|
| 51 |
+
try:
|
| 52 |
+
import cutlass as _cutlass
|
| 53 |
+
cutlass_dir = os.path.dirname(_cutlass.__file__)
|
| 54 |
+
p = os.path.join(cutlass_dir, 'include')
|
| 55 |
+
if os.path.isdir(os.path.join(p, 'cute')):
|
| 56 |
+
return p
|
| 57 |
+
except ImportError:
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
# Return empty string; C++ side will also check env vars
|
| 61 |
+
return ""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def set_num_sms(new_num_sms):
    """Forward the SM-count override to the compiled extension."""
    ops.set_num_sms(new_num_sms)

def get_num_sms():
    """Return the SM count currently configured in the compiled extension."""
    return ops.get_num_sms()

def set_tc_util(new_tc_util):
    """Forward the tensor-core utilization setting to the compiled extension."""
    ops.set_tc_util(new_tc_util)

def get_tc_util():
    """Return the tensor-core utilization setting from the compiled extension."""
    return ops.get_tc_util()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# cuBLASLt GEMMs: thin forwarders to the compiled extension. The two-letter
# suffix encodes the operand layouts passed through to the backend; `c` is an
# optional extra input tensor, `d` receives the result in place.
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)

def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)

def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)

def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Optional wrappers: every function below forwards to a symbol in the compiled
# extension, so if the loaded binary lacks some of these ops the whole group
# is skipped as a unit.
# NOTE(review): the blanket `except Exception: pass` also hides genuine
# programming errors in this block — confirm the degradation is intentional.
try:
    # FP8/FP4 GEMMs. `a` and `b` are (values, scale_factors) pairs; recipe
    # tuples are forwarded as lists; `c` is an optional extra input.
    def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    # Short names kept for API compatibility.
    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt

    # Grouped GEMMs over contiguous and masked layouts.
    def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
                                             recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
                                             disable_ue8m0_cast=False, use_psum_layout=False,
                                             expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
                                             recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
                                             disable_ue8m0_cast=False, use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
                                         recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
                                         disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

    def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    # BF16 GEMMs
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
                                          compiled_dims="nk", use_psum_layout=False,
                                          expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
                                          compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
                                      compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
                                          c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)

    # Einsum
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))

    # Attention
    def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
                                  compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            list(recipe) if recipe else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
                       cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
        # `kv` is a (values, scale_factors) pair, unpacked for the op.
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
                             block_table, schedule_meta,
                             max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)

    # Hyperconnection
    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    # Layout
    def transform_sf_into_required_layout(sf, mn, k, recipe=None,
                                          recipe_ab=None, num_groups=None, is_sfa=False,
                                          disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k,
            list(recipe) if recipe else None,
            list(recipe_ab) if recipe_ab else None,
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()

    # Legacy aliases
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except Exception:
    pass
|
| 271 |
+
|
| 272 |
+
# Utils
from . import utils
from .utils import *

# Testing
from . import testing

# Initialize (gracefully skip if CUDA is not available, e.g. in build sandboxes)
try:
    ops.init(
        # Package root: where the bundled sources/headers for JIT live.
        os.path.dirname(os.path.abspath(__file__)),
        _find_cuda_home(),
        _find_cutlass_include()
    )
except Exception:
    pass

__version__ = '2.3.0'
|
build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b4ca9c42204f1909adcefc61053c7943c105eadb44a447a1ea9a488e01675df
|
| 3 |
+
size 3078080
|
build/torch210-cxx11-cu128-x86_64-linux/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _deep_gemm_099ac3c_dirty
|
| 3 |
+
ops = torch.ops._deep_gemm_099ac3c_dirty
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
    """Qualify `op_name` with this extension's torch op namespace."""
    return "_deep_gemm_099ac3c_dirty::" + op_name
|
build/torch210-cxx11-cu128-x86_64-linux/deep_gemm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
import importlib
import importlib.util
import sys
from pathlib import Path
from types import ModuleType
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
# Re-export the build variant's top-level `__init__.py` (one directory up) so
# that this `deep_gemm` subpackage exposes the same public names.
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch210-cxx11-cu128-x86_64-linux/metadata.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python-depends": []
|
| 3 |
+
}
|
build/torch210-cxx11-cu128-x86_64-linux/testing/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import bench, numeric, utils
|
| 2 |
+
from .bench import *
|
| 3 |
+
from .numeric import *
|
| 4 |
+
from .utils import *
|
build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Time `fn` on the current CUDA device and return average seconds per call.

    Args:
        fn: zero-argument callable that launches the workload to be timed.
        num_warmups: untimed warmup invocations of `fn`.
        num_tests: timed invocations averaged into the result.
        high_precision: if True, enqueue a large matmul before timing so the
            GPU is still busy when the timed kernels are launched, hiding CPU
            launch overhead from the measurement.

    Returns:
        Average wall time of one `fn()` call in seconds (CUDA-event based).
    """
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    cache.zero_()

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Add a large kernel to eliminate the CPU launch overhead
    if high_precision:
        x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        x @ y

    # Testing: events are recorded on the stream, so elapsed time covers only
    # GPU execution between the two records.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for i in range(num_tests):
        fn()
    end_event.record()
    torch.cuda.synchronize()

    # `elapsed_time` is in milliseconds; convert to seconds per call.
    return start_event.elapsed_time(end_event) / num_tests / 1e3
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class empty_suppress:
    """Do-nothing context manager (drop-in stand-in for `suppress_stdout_stderr`)."""

    def __enter__(self):
        # Return the manager itself so `with ... as s:` works uniformly.
        return self

    def __exit__(self, *_unused):
        # Returning None (falsy) means exceptions propagate unchanged.
        return None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class suppress_stdout_stderr:
    """Silence stdout/stderr at both the Python and file-descriptor level.

    Plain `contextlib.redirect_stdout` only swaps the Python stream objects;
    native code (e.g. kineto/CUPTI) writes straight to fds 1/2, so those fds
    are re-pointed at /dev/null too. Not reentrant and not thread-safe.
    """

    def __enter__(self):
        # Sinks everything will be redirected into.
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The live fd numbers (normally 1 and 2) that get re-pointed in place.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates keep the original terminal endpoints alive for restore.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        # Python-level stream objects, restored on exit.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Re-point fds 1/2 at /dev/null so native writes are dropped as well.
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # And swap the Python-level streams.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Restore Python-level streams first, then the underlying fds.
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        # Close the saved duplicates and the /dev/null handles.
        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Profile `fn` with the kineto profiler and return per-kernel average times.

    Args:
        fn: zero-argument callable launching the kernels of interest.
        kernel_names: one substring (str) or several (tuple) matched against
            rows of the profiler summary table.
        num_tests: invocations of `fn` per profiler cycle.
        suppress_kineto_output: hide kineto's console chatter while profiling.
        trace_path: if given, export a chrome trace to this path.
        flush_l2: memset a large buffer before each call to defeat L2 caching.
        with_multiple_kernels: allow a name to match more than one table row.

    Returns:
        Average seconds per kernel; a tuple if `kernel_names` was a tuple,
        otherwise a single float. Missing kernels yield 0.
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        # Dummy timings so callers still get the right result shape.
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile: one wait cycle then one active cycle (schedule below).
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table (text format: time is the 2nd-to-last column,
    # call count the last).
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times: convert 'ms'/'us' cells to seconds and
    # average weighted by the per-row call count.
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
|
build/torch210-cxx11-cu128-x86_64-linux/testing/numeric.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Iterable
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a normalized dissimilarity of `x` and `y` in [0, 2].

    Computed as 1 - 2*<x, y> / (|x|^2 + |y|^2) in double precision:
    0 when the tensors are identical, 2 when exactly opposite.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:  # Which means that all elements in x and y are 0
        return 0.0
    return 1 - 2 * (x * y).sum() / denominator
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def count_bytes(*tensors):
    """Sum the storage bytes of the given tensors.

    Tuples/lists are flattened recursively; `None` entries contribute 0.
    """
    total = 0
    for item in tensors:
        if item is None:
            continue
        if isinstance(item, (tuple, list)):
            total += count_bytes(*item)
        else:
            total += item.numel() * item.element_size()
    return total
|
build/torch210-cxx11-cu128-x86_64-linux/testing/utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Callable
|
| 5 |
+
|
| 6 |
+
def get_arch_major() -> int:
    """Return the CUDA compute-capability major version of the current device."""
    # `get_device_capability` returns (major, minor); only major is used.
    major, minor = torch.cuda.get_device_capability()
    return major
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_filter(condition: Callable):
|
| 12 |
+
def decorator(func):
|
| 13 |
+
@functools.wraps(func)
|
| 14 |
+
def wrapper(*args, **kwargs):
|
| 15 |
+
if condition():
|
| 16 |
+
func(*args, **kwargs)
|
| 17 |
+
else:
|
| 18 |
+
print(f'{func.__name__}:')
|
| 19 |
+
print(f' > Filtered by {condition}')
|
| 20 |
+
print()
|
| 21 |
+
return wrapper
|
| 22 |
+
return decorator
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ignore_env(name: str, condition: Callable):
    """Decorator factory: run the function with env var `name` temporarily unset.

    When `condition()` is truthy, `name` is popped from `os.environ` for the
    duration of the call and restored afterwards — including when the wrapped
    function raises (the original version leaked the unset state on error).
    The wrapped function's return value is propagated.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not condition():
                return func(*args, **kwargs)
            saved = os.environ.pop(name, None)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore even on exception; only if the var existed before.
                if saved is not None:
                    os.environ[name] = saved

        return wrapper
    return decorator
|
build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import math, layout
|
| 2 |
+
from .layout import *
|
| 3 |
+
from .math import *
|
build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TMA-alignment helpers: thin wrappers over the compiled extension's ops.
# Guarded so that older extension builds missing these symbols do not break
# module import.
try:
    from .._ops import ops

    def get_tma_aligned_size(x, element_size):
        # Forwarded to the C++ op; semantics defined by the extension.
        return ops.get_tma_aligned_size(x, element_size)

    def get_mn_major_tma_aligned_tensor(sf):
        return ops.get_mn_major_tma_aligned_tensor(sf)

    def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
        return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)

    def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
        return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
            sf, ks_tensor, ks)
except ImportError:
    pass

# NOTE(review): this unconditional import defeats the try/except above — if
# `.._ops` is unimportable, the module still fails to import here. Confirm
# whether this should live inside the guarded block.
from .._ops import ops as _ops

def get_mk_alignment_for_contiguous_layout():
    return _ops.get_mk_alignment_for_contiguous_layout()

# Backward-compatible aliases for the older, dimension-specific names.
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
build/torch210-cxx11-cu128-x86_64-linux/utils/math.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def ceil_div(x: int, y: int) -> int:
    """Ceiling integer division: smallest integer >= x / y (y > 0)."""
    # Negate-floor-negate trick: floor division of the negation rounds up.
    return -(-x // y)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def align(x: int, y: int) -> int:
    """Round `x` up to the nearest multiple of `y` (y > 0)."""
    remainder = x % y
    return x if remainder == 0 else x + (y - remainder)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def ceil_to_ue8m0(x: torch.Tensor):
    """Round each magnitude of `x` up to the nearest power of two.

    This snaps scale factors onto the UE8M0 grid (pure power-of-two
    exponents). All elements must be non-zero, otherwise log2 yields -inf.

    Args:
        x: tensor of positive scale factors (any shape, any layout).

    Returns:
        Tensor of the same shape with values 2**ceil(log2(|x|)).
    """
    # `reshape` instead of the original `view`: `view` raises on
    # non-contiguous inputs (e.g. transposed scale tensors); the check is
    # purely a reduction, so no copy semantics are relied upon.
    assert x.reshape(-1).amax().item() > 0
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` row-wise to FP8 (e4m3) with one scale per `gran_k` group.

    Each row is split into groups of `gran_k` elements (zero-padded at the
    end); every group gets a scale amax/448 (448 being the largest finite
    float8_e4m3fn magnitude), optionally snapped to a power of two.

    Args:
        x: 2-D float tensor of shape (m, n).
        use_ue8m0: snap scales to powers of two via `ceil_to_ue8m0`.
        gran_k: quantization group size along the row dimension.

    Returns:
        (fp8 tensor of shape (m, n), scales of shape (m, ceil(n / gran_k))).
    """
    assert x.dim() == 2
    m, n = x.shape
    padded_n = align(n, gran_k)
    # `torch.zeros` replaces `torch.empty(...).fill_(0)`: a single fused
    # allocation, and consistent with the sibling cast helpers in this file.
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, -1, gran_k)
    # Per-(row, group) max magnitude; clamp avoids division by ~0 scales.
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = x_amax / 448.0
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 32 |
+
assert x.dim() == 2 and x.size(0) % gran_k == 0
|
| 33 |
+
m, n = x.shape
|
| 34 |
+
x_view = x.view(-1, gran_k, n)
|
| 35 |
+
x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
|
| 36 |
+
sf = x_amax / 448.0
|
| 37 |
+
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 38 |
+
return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to FP8 (e4m3) with one scale per gran_k x gran_k tile.

    The input is zero-padded up to tile multiples in both dimensions; each
    tile's scale is amax/448, optionally snapped to a power of two.
    Returns (fp8 tensor cropped back to `x`'s shape, 2-D tile-scale tensor).
    """
    assert x.dim() == 2
    num_rows, num_cols = x.shape
    padded = torch.zeros((align(num_rows, gran_k), align(num_cols, gran_k)),
                         dtype=x.dtype, device=x.device)
    padded[:num_rows, :num_cols] = x
    # Axes: (row tile, row-in-tile, col tile, col-in-tile).
    tiles = padded.view(-1, gran_k, padded.size(1) // gran_k, gran_k)
    amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scales = amax / 448.0
    if use_ue8m0:
        scales = ceil_to_ue8m0(scales)
    quantized = (tiles * (1.0 / scales)).to(torch.float8_e4m3fn)
    return (quantized.view_as(padded)[:num_rows, :num_cols].contiguous(),
            scales.view(tiles.size(0), tiles.size(2)))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 55 |
+
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
| 56 |
+
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
| 57 |
+
sf = x_amax / 448.0
|
| 58 |
+
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 59 |
+
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
| 60 |
+
return x_scaled, sf.squeeze()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
ax = x.abs().clamp_max(6.0)
|
| 65 |
+
# {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
| 66 |
+
# midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
|
| 67 |
+
boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
|
| 68 |
+
device=x.device, dtype=ax.dtype)
|
| 69 |
+
idx = torch.bucketize(ax, boundaries)
|
| 70 |
+
code = idx.to(torch.uint8)
|
| 71 |
+
sign = (x < 0) & (idx != 0)
|
| 72 |
+
code = code | (sign.to(torch.uint8) << 3)
|
| 73 |
+
return code # uint8, 0..15
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize rows of `x` to nibble-packed FP4 (E2M1) with per-group scales.

    Rows are zero-padded to a multiple of `gran_k`; each group's scale is
    amax/6 (6 = max E2M1 magnitude), optionally snapped to a power of two.
    Returns (uint8 tensor of shape (m, n/2) with two codes per byte, scales).
    """
    assert x.dim() == 2
    num_rows, num_cols = x.shape
    assert num_cols % 2 == 0
    padded_cols = align(num_cols, gran_k)
    padded = torch.zeros((num_rows, padded_cols), dtype=x.dtype, device=x.device)
    padded[:, :num_cols] = x
    grouped = padded.view(num_rows, -1, gran_k)
    amax = grouped.abs().float().amax(dim=2).clamp_min(1e-4)
    scales = amax / 6.0
    if use_ue8m0:
        scales = ceil_to_ue8m0(scales)
    scaled = grouped * (1.0 / scales.unsqueeze(2))
    codes = _quantize_to_fp4_e2m1(scaled).view(num_rows, padded_cols)
    # Pack adjacent code pairs: even index -> low nibble, odd -> high nibble.
    pairs = codes.view(num_rows, padded_cols // 2, 2)
    packed = (pairs[:, :, 0] & 0x0F) | ((pairs[:, :, 1] & 0x0F) << 4)
    return packed[:, :num_cols // 2].contiguous(), scales
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
    """Transpose an (m, n/2) nibble-packed FP4 matrix into packed (n, m/2) form.

    Each byte holds two 4-bit codes (low nibble first). The matrix is
    unpacked to one code per element, transposed, and re-packed along the
    new rows; `m` must be even so the result packs cleanly.
    """
    assert a.dtype == torch.uint8
    assert a.dim() == 2
    rows, packed_cols = a.shape
    cols = packed_cols * 2
    assert (rows % 2) == 0
    # Unpack: even logical columns come from low nibbles, odd from high.
    unpacked = torch.empty((rows, cols), device=a.device, dtype=torch.uint8)
    unpacked[:, 0::2] = a & 0x0F
    unpacked[:, 1::2] = (a >> 4) & 0x0F
    # Transpose at code granularity, then re-pack row pairs into bytes.
    transposed = unpacked.transpose(0, 1).contiguous().view(cols, rows // 2, 2)
    repacked = (transposed[:, :, 0] & 0x0F) | ((transposed[:, :, 1] & 0x0F) << 4)
    return repacked.contiguous()
|
build/torch210-cxx11-cu130-x86_64-linux/__init__.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _find_cuda_home():
|
| 9 |
+
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
| 10 |
+
if cuda_home is None:
|
| 11 |
+
try:
|
| 12 |
+
with open(os.devnull, 'w') as devnull:
|
| 13 |
+
nvcc = subprocess.check_output(
|
| 14 |
+
['which', 'nvcc'], stderr=devnull
|
| 15 |
+
).decode().rstrip('\r\n')
|
| 16 |
+
cuda_home = os.path.dirname(os.path.dirname(nvcc))
|
| 17 |
+
except Exception:
|
| 18 |
+
cuda_home = '/usr/local/cuda'
|
| 19 |
+
if not os.path.exists(cuda_home):
|
| 20 |
+
cuda_home = ''
|
| 21 |
+
return cuda_home or ''
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _find_cutlass_include():
    """Find CUTLASS include path for JIT compilation of .cuh templates.

    Probes, in order: the DG_CUTLASS_INCLUDE and CUTLASS_HOME environment
    variables, headers bundled next to this package, CUDA's own include
    directory, and finally the `cutlass` Python package. A candidate is
    accepted only if it contains a `cute/` subdirectory. Returns '' when
    nothing is found (the C++ side re-checks the env vars itself).
    """
    # 1. Explicit env var
    cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
    if cutlass_include and os.path.isdir(cutlass_include):
        return cutlass_include

    # 2. CUTLASS_HOME env var
    cutlass_home = os.environ.get('CUTLASS_HOME')
    if cutlass_home:
        p = os.path.join(cutlass_home, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p

    # 3. Check in package include/ directory (bundled cute/cutlass headers)
    pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
    if os.path.isdir(os.path.join(pkg_include, 'cute')):
        return pkg_include

    # 4. Check CUDA_HOME/include (some CUDA 12.8+ installs include cute/)
    cuda_home = _find_cuda_home()
    if cuda_home:
        cuda_inc = os.path.join(cuda_home, 'include')
        if os.path.isdir(os.path.join(cuda_inc, 'cute')):
            return cuda_inc

    # 5. Try to find nvidia-cutlass Python package
    try:
        import cutlass as _cutlass
        cutlass_dir = os.path.dirname(_cutlass.__file__)
        p = os.path.join(cutlass_dir, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p
    except ImportError:
        pass

    # Return empty string; C++ side will also check env vars
    return ""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def set_num_sms(new_num_sms):
    # Forward the SM-count setting to the compiled extension.
    ops.set_num_sms(new_num_sms)

def get_num_sms():
    # Current SM count as tracked by the extension.
    return ops.get_num_sms()

def set_tc_util(new_tc_util):
    # Forward the tensor-core utilization setting to the extension.
    # (Exact semantics are defined on the C++ side — confirm there.)
    ops.set_tc_util(new_tc_util)

def get_tc_util():
    return ops.get_tc_util()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# cuBLASLt GEMMs
# Thin forwarding wrappers; the two-letter suffix encodes the (A, B) layouts
# (n = row-major / not transposed, t = transposed). `d` is the output tensor;
# `c` is an optional extra input forwarded to the op — presumably a
# bias/accumulate operand; confirm against the C++ op signature.
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)

def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)

def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)

def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# The wrappers below are defined inside a try so that an extension build
# missing any referenced name degrades to "feature absent" instead of
# breaking `import`. NOTE(review): a broad `except Exception: pass` also
# hides genuine definition-time errors — confirm this is intentional.
try:
    # FP8/FP4 GEMMs. Quantized operands are passed as (data, scale) pairs;
    # recipes are converted to lists for the C++ schema.
    def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    # Aliases kept for FP8-named callers.
    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt

    def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

    def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    # BF16 GEMMs: operands are plain tensors (no scale pairs).
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
            compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
            c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)

    # Einsum
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))

    # Attention
    def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
            compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            list(recipe) if recipe else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
            cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta,
            max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)

    # Hyperconnection
    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    # Layout
    def transform_sf_into_required_layout(sf, mn, k, recipe=None,
            recipe_ab=None, num_groups=None, is_sfa=False,
            disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k,
            list(recipe) if recipe else None,
            list(recipe_ab) if recipe_ab else None,
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()

    # Legacy aliases
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except Exception:
    pass
|
| 271 |
+
|
| 272 |
+
# Utils
|
| 273 |
+
from . import utils
|
| 274 |
+
from .utils import *
|
| 275 |
+
|
| 276 |
+
# Testing
|
| 277 |
+
from . import testing
|
| 278 |
+
|
| 279 |
+
# Initialize (gracefully skip if CUDA is not available, e.g. in build sandboxes)
try:
    # Hand the extension its package directory plus discovered CUDA/CUTLASS
    # locations so it can JIT-compile kernels later.
    ops.init(
        os.path.dirname(os.path.abspath(__file__)),
        _find_cuda_home(),
        _find_cutlass_include()
    )
except Exception:
    # Best-effort: leave the extension uninitialized rather than break import.
    pass

__version__ = '2.3.0'
|
build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8307e5e24ea3f68435a8251df19977bfd2323e60f761b4c3cd7c5ba7aada4c3f
|
| 3 |
+
size 3078072
|
build/torch210-cxx11-cu130-x86_64-linux/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _deep_gemm_099ac3c_dirty
|
| 3 |
+
ops = torch.ops._deep_gemm_099ac3c_dirty
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
    """Qualify `op_name` with this extension's torch op namespace."""
    namespace = "_deep_gemm_099ac3c_dirty"
    return f"{namespace}::{op_name}"
|
build/torch210-cxx11-cu130-x86_64-linux/deep_gemm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import importlib
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch210-cxx11-cu130-x86_64-linux/metadata.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python-depends": []
|
| 3 |
+
}
|
build/torch210-cxx11-cu130-x86_64-linux/testing/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import bench, numeric, utils
|
| 2 |
+
from .bench import *
|
| 3 |
+
from .numeric import *
|
| 4 |
+
from .utils import *
|
build/torch210-cxx11-cu130-x86_64-linux/testing/bench.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Time `fn` on the current CUDA device.

    Returns the average wall time per call in seconds, measured with CUDA
    events over `num_tests` calls after `num_warmups` warm-up calls.
    """
    # Evict L2 by touching a 256 MB scratch buffer before timing.
    torch.cuda.synchronize()
    scratch = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    scratch.zero_()

    # Warm-up calls (JIT compilation, caches, autotuning).
    for _ in range(num_warmups):
        fn()

    # Optionally enqueue a large matmul so the timed kernels queue up behind
    # it, hiding the CPU launch overhead from the measurement.
    if high_precision:
        lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        lhs @ rhs

    # Timed region, bracketed by CUDA events.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_tests):
        fn()
    end_event.record()
    torch.cuda.synchronize()

    # elapsed_time() is in milliseconds; convert to seconds per iteration.
    return start_event.elapsed_time(end_event) / num_tests / 1e3
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class empty_suppress:
    """No-op context manager used when output suppression is disabled."""

    def __enter__(self):
        return self

    def __exit__(self, *_):
        return None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class suppress_stdout_stderr:
    """Context manager that silences stdout/stderr at the file-descriptor
    level by repointing the underlying fds to ``os.devnull``.

    Unlike swapping only ``sys.stdout``/``sys.stderr``, redirecting the fds
    also hides output written directly by native code (C/CUDA libraries).
    """

    def __enter__(self):
        # Sinks for the redirected streams.
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The original fd numbers (typically 1 and 2) that will be repointed.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates of the original fds, kept so they can be restored later.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        # Python-level stream objects, restored on exit.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the original fds at /dev/null so native writes are dropped.
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Also swap the Python-level streams.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Restore the Python-level streams first.
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Point the original fd numbers back at their saved duplicates.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        # Release the duplicates and the /dev/null handles.
        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Benchmark `fn` with the torch (Kineto) profiler.

    Returns the average GPU time in seconds of the kernel(s) whose profiler
    name contains the substring(s) in `kernel_names` (a `str`, or a `tuple`
    of substrings, in which case a tuple of times is returned).
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        # One discarded wait step plus one recorded active step, so the
        # second inner loop below is the one that lands in the trace.
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        # Each requested substring must match at most one table row,
        # otherwise the reported time would mix different kernels.
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                # Table rows end with "... <total time><unit> <num calls>".
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
|
build/torch210-cxx11-cu130-x86_64-linux/testing/numeric.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Iterable
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return one minus a global cosine-style similarity of two tensors.

    0 means the tensors agree; larger values mean more disagreement. The
    computation is done in float64 for stability.
    """
    x64, y64 = x.double(), y.double()
    denominator = (x64 * x64 + y64 * y64).sum()
    # An all-zero pair would divide by zero; treat it as a perfect match.
    if denominator == 0:
        return 0.0
    similarity = 2 * (x64 * y64).sum() / denominator
    return 1 - similarity
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def count_bytes(*tensors):
    """Total storage in bytes of all given tensors.

    Recurses into nested lists/tuples and skips ``None`` entries.
    """
    def _size(item):
        if isinstance(item, (tuple, list)):
            return sum(_size(sub) for sub in item)
        if item is None:
            return 0
        return item.numel() * item.element_size()

    return sum(_size(t) for t in tensors)
|
build/torch210-cxx11-cu130-x86_64-linux/testing/utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Callable
|
| 5 |
+
|
| 6 |
+
def get_arch_major() -> int:
    """Return the CUDA compute-capability major version of the current device."""
    return torch.cuda.get_device_capability()[0]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_filter(condition: Callable):
|
| 12 |
+
def decorator(func):
|
| 13 |
+
@functools.wraps(func)
|
| 14 |
+
def wrapper(*args, **kwargs):
|
| 15 |
+
if condition():
|
| 16 |
+
func(*args, **kwargs)
|
| 17 |
+
else:
|
| 18 |
+
print(f'{func.__name__}:')
|
| 19 |
+
print(f' > Filtered by {condition}')
|
| 20 |
+
print()
|
| 21 |
+
return wrapper
|
| 22 |
+
return decorator
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ignore_env(name: str, condition: Callable):
    """Decorator that runs the wrapped function with environment variable
    `name` temporarily removed whenever `condition()` is truthy.

    The variable is restored in a ``finally`` block, so it survives even if
    the wrapped function raises (the previous version leaked the removal on
    exceptions).
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if condition():
                saved = os.environ.pop(name, None)
                try:
                    func(*args, **kwargs)
                finally:
                    # Restore only if it was actually set before.
                    if saved is not None:
                        os.environ[name] = saved
            else:
                func(*args, **kwargs)
        return wrapper
    return decorator
|
build/torch210-cxx11-cu130-x86_64-linux/utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import math, layout
|
| 2 |
+
from .layout import *
|
| 3 |
+
from .math import *
|
build/torch210-cxx11-cu130-x86_64-linux/utils/layout.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Thin Python wrappers around layout utilities implemented in the compiled
# `_deep_gemm` torch extension. The first group is guarded so that importing
# this module does not fail when those ops are unavailable.
try:
    from .._ops import ops

    def get_tma_aligned_size(x, element_size):
        """Forward to the compiled extension's `get_tma_aligned_size`."""
        return ops.get_tma_aligned_size(x, element_size)

    def get_mn_major_tma_aligned_tensor(sf):
        """Forward to the compiled extension's `get_mn_major_tma_aligned_tensor`."""
        return ops.get_mn_major_tma_aligned_tensor(sf)

    def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
        """Forward to `get_mn_major_tma_aligned_packed_ue8m0_tensor`."""
        return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)

    def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
        """Forward to the k-grouped variant of the packed UE8M0 transform."""
        return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
            sf, ks_tensor, ks)
except ImportError:
    pass

# NOTE(review): this import sits *outside* the try block above, so a missing
# compiled extension still raises ImportError here, which makes the guard
# above moot — confirm whether that is intentional.
from .._ops import ops as _ops

def get_mk_alignment_for_contiguous_layout():
    """Forward to the compiled extension's `get_mk_alignment_for_contiguous_layout`."""
    return _ops.get_mk_alignment_for_contiguous_layout()

# Legacy aliases kept for backward compatibility.
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
build/torch210-cxx11-cu130-x86_64-linux/utils/math.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def ceil_div(x: int, y: int) -> int:
    """Integer division of `x` by `y` (positive), rounded up."""
    return -(-x // y)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def align(x: int, y: int) -> int:
    """Round `x` up to the nearest multiple of `y` (positive)."""
    return ((x + y - 1) // y) * y
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def ceil_to_ue8m0(x: torch.Tensor):
    """Round each element's magnitude up to the nearest power of two.

    Power-of-two scales are what the UE8M0 scale format can represent.
    """
    # Reject inputs whose maximum is not positive: log2 would yield -inf.
    assert x.view(-1).amax().item() > 0
    return 2.0 ** x.abs().log2().ceil()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 19 |
+
assert x.dim() == 2
|
| 20 |
+
m, n = x.shape
|
| 21 |
+
padded_n = align(n, gran_k)
|
| 22 |
+
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
|
| 23 |
+
x_padded[:, :n] = x
|
| 24 |
+
x_view = x_padded.view(m, -1, gran_k)
|
| 25 |
+
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
| 26 |
+
sf = x_amax / 448.0
|
| 27 |
+
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 28 |
+
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 32 |
+
assert x.dim() == 2 and x.size(0) % gran_k == 0
|
| 33 |
+
m, n = x.shape
|
| 34 |
+
x_view = x.view(-1, gran_k, n)
|
| 35 |
+
x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
|
| 36 |
+
sf = x_amax / 448.0
|
| 37 |
+
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 38 |
+
return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 42 |
+
assert x.dim() == 2
|
| 43 |
+
m, n = x.shape
|
| 44 |
+
x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device)
|
| 45 |
+
x_padded[:m, :n] = x
|
| 46 |
+
x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k)
|
| 47 |
+
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
| 48 |
+
sf = x_amax / 448.0
|
| 49 |
+
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 50 |
+
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
|
| 51 |
+
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 55 |
+
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
| 56 |
+
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
| 57 |
+
sf = x_amax / 448.0
|
| 58 |
+
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
| 59 |
+
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
| 60 |
+
return x_scaled, sf.squeeze()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
    """Map values to 4-bit FP4 (E2M1) codes: 3 level bits plus a sign bit.

    Representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}; inputs are
    clamped to 6 and rounded to the nearest level via the midpoints between
    consecutive levels. Returns uint8 codes in [0, 15]; negative zero is
    canonicalized to +0 (the sign bit is never set for level 0).
    """
    magnitude = x.abs().clamp_max(6.0)
    # Midpoints between consecutive representable magnitudes.
    midpoints = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
                             device=x.device, dtype=magnitude.dtype)
    level = torch.bucketize(magnitude, midpoints)
    code = level.to(torch.uint8)
    negative = (x < 0) & (level != 0)
    return code | (negative.to(torch.uint8) << 3)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize rows of `x` to nibble-packed FP4 (E2M1), one scale per
    `gran_k` chunk.

    Returns a uint8 tensor of shape ``(m, n // 2)`` carrying two FP4 codes
    per byte (even column in the low nibble) plus the float scale factors.
    """
    assert x.dim() == 2
    m, n = x.shape
    assert n % 2 == 0
    # Zero-pad the row length up to a multiple of the quantization group.
    padded_n = ((n + gran_k - 1) // gran_k) * gran_k
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    grouped = x_padded.view(m, -1, gran_k)
    amax = grouped.abs().float().amax(dim=2).clamp_min(1e-4)
    sf = amax / 6.0  # 6 is the largest E2M1 magnitude
    if use_ue8m0:
        sf = ceil_to_ue8m0(sf)
    scaled = grouped * (1.0 / sf.unsqueeze(2))
    codes = _quantize_to_fp4_e2m1(scaled).view(m, padded_n)
    pairs = codes.view(m, padded_n // 2, 2)
    # Pack neighbouring codes: even index -> low nibble, odd -> high nibble.
    packed = (pairs[:, :, 0] & 0x0F) | ((pairs[:, :, 1] & 0x0F) << 4)
    return packed[:, :n // 2].contiguous(), sf
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
    """Transpose a matrix of nibble-packed FP4 codes.

    `a` is ``(m, n // 2)`` uint8 with two codes per byte (low nibble first);
    the result is ``(n, m // 2)`` packed the same way.
    """
    assert a.dtype == torch.uint8
    assert a.dim() == 2
    m, half_n = a.shape
    n = half_n * 2
    assert (m % 2) == 0
    # Unpack into one code per element...
    unpacked = torch.empty((m, n), device=a.device, dtype=torch.uint8)
    unpacked[:, 0::2] = a & 0x0F
    unpacked[:, 1::2] = (a >> 4) & 0x0F
    # ...then transpose and re-pack adjacent pairs along the new rows.
    pairs = unpacked.transpose(0, 1).contiguous().view(n, m // 2, 2)
    repacked = (pairs[:, :, 0] & 0x0F) | ((pairs[:, :, 1] & 0x0F) << 4)
    return repacked.contiguous()
|
build/torch29-cxx11-cu126-x86_64-linux/__init__.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _find_cuda_home():
    """Locate the CUDA toolkit root directory.

    Resolution order: the ``CUDA_HOME`` / ``CUDA_PATH`` environment
    variables, the parent of the directory containing ``nvcc`` on PATH,
    then ``/usr/local/cuda``. Returns ``''`` when nothing usable is found.
    """
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Prefer the portable stdlib lookup over shelling out to `which`
        # (which is Unix-only and spawns a subprocess).
        import shutil
        nvcc = shutil.which('nvcc')
        if nvcc is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        else:
            cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = ''
    return cuda_home or ''
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _find_cutlass_include():
    """Find a CUTLASS include path for JIT compilation of .cuh templates.

    Checks, in order: the DG_CUTLASS_INCLUDE env var, CUTLASS_HOME/include,
    headers bundled with this package, CUDA_HOME/include (some CUDA 12.8+
    installs ship cute/), and the nvidia-cutlass Python package. Returns ''
    when none is found; the C++ side also checks env vars in that case.
    """
    def _has_cute(path):
        return os.path.isdir(os.path.join(path, 'cute'))

    # 1. Explicit override (only required to be an existing directory).
    explicit = os.environ.get('DG_CUTLASS_INCLUDE')
    if explicit and os.path.isdir(explicit):
        return explicit

    # 2. CUTLASS_HOME env var.
    cutlass_home = os.environ.get('CUTLASS_HOME')
    if cutlass_home:
        candidate = os.path.join(cutlass_home, 'include')
        if _has_cute(candidate):
            return candidate

    # 3. Headers bundled inside this package.
    bundled = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
    if _has_cute(bundled):
        return bundled

    # 4. CUDA_HOME/include.
    cuda_home = _find_cuda_home()
    if cuda_home:
        candidate = os.path.join(cuda_home, 'include')
        if _has_cute(candidate):
            return candidate

    # 5. The nvidia-cutlass Python package, if installed.
    try:
        import cutlass as _cutlass
        candidate = os.path.join(os.path.dirname(_cutlass.__file__), 'include')
        if _has_cute(candidate):
            return candidate
    except ImportError:
        pass

    return ""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def set_num_sms(new_num_sms):
    """Forward to the compiled extension's `set_num_sms` setting."""
    ops.set_num_sms(new_num_sms)

def get_num_sms():
    """Return the value held by the compiled extension's `get_num_sms`."""
    return ops.get_num_sms()

def set_tc_util(new_tc_util):
    """Forward to the compiled extension's `set_tc_util` setting."""
    ops.set_tc_util(new_tc_util)

def get_tc_util():
    """Return the value held by the compiled extension's `get_tc_util`."""
    return ops.get_tc_util()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# cuBLASLt GEMMs — thin forwards to the compiled extension. The nt/nn/tn/tt
# suffix presumably encodes the A/B transpose layout; confirm against the
# C++ op definitions.
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)

def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)

def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)

def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# All kernel entry points below are thin forwards to the compiled torch
# extension `ops`. For the quantized variants, `a` and `b` are
# (tensor, scale-factor) pairs, and recipe tuples are converted to plain
# lists for the C++ side.
# NOTE(review): the blanket `except Exception: pass` at the bottom also
# hides genuine errors (e.g. typos) in this whole section — confirm the
# broad guard is intended.
try:
    # FP8/FP4 GEMMs
    def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    # FP8-named aliases of the combined FP8/FP4 entry points.
    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt

    def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
                                             recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
                                             disable_ue8m0_cast=False, use_psum_layout=False,
                                             expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
                                             recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
                                             disable_ue8m0_cast=False, use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
                                         recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
                                         disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

    def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    # BF16 GEMMs
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
                                          compiled_dims="nk", use_psum_layout=False,
                                          expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
                                          compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
                                      compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
                                          c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)

    # Einsum
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))

    # Attention
    def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
                                  compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            list(recipe) if recipe else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
                       cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
                             block_table, schedule_meta,
                             max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)

    # Hyperconnection
    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    # Layout
    def transform_sf_into_required_layout(sf, mn, k, recipe=None,
                                          recipe_ab=None, num_groups=None, is_sfa=False,
                                          disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k,
            list(recipe) if recipe else None,
            list(recipe_ab) if recipe_ab else None,
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()

    # Legacy aliases
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except Exception:
    pass
|
| 271 |
+
|
| 272 |
+
# Utils
|
| 273 |
+
from . import utils
|
| 274 |
+
from .utils import *
|
| 275 |
+
|
| 276 |
+
# Testing
|
| 277 |
+
from . import testing
|
| 278 |
+
|
| 279 |
+
# Initialize the compiled extension with the package directory, the CUDA
# toolkit location, and the CUTLASS include path. Failures are swallowed so
# the package can still be imported where CUDA is unavailable (e.g. in
# build sandboxes).
try:
    ops.init(
        os.path.dirname(os.path.abspath(__file__)),
        _find_cuda_home(),
        _find_cutlass_include()
    )
except Exception:
    pass

__version__ = '2.3.0'
|
build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ac9ad7e5f8bcd1642692d50e321db2ee6a668bdc448fa481490e307e2dfb0ffe
|
| 3 |
+
size 2967864
|
build/torch29-cxx11-cu126-x86_64-linux/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _deep_gemm_099ac3c_dirty
|
| 3 |
+
ops = torch.ops._deep_gemm_099ac3c_dirty
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
    """Qualify `op_name` with this extension's torch op namespace."""
    return "_deep_gemm_099ac3c_dirty::" + op_name
|
build/torch29-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import importlib
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
    """Load a Python module from an explicit file path under a unique name.

    The module is registered in ``sys.modules`` under a name derived from a
    hash of the absolute path (not the file name), so loading same-named
    files from different directories never collides with regular imports.

    Raises:
        ImportError: if a spec or module cannot be created for ``file_path``.
    """
    # `import importlib` alone does not guarantee the `importlib.util`
    # submodule is initialized; import it explicitly before use.
    import importlib.util

    # Hex-encoded unsigned hash of the absolute path keeps the registered
    # module name unique per file location.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    # Register before executing so self-referential imports inside the
    # module resolve.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Load the package-level `__init__.py` two directories up under a unique
# module name and re-export everything it defines from this namespace.
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch29-cxx11-cu126-x86_64-linux/metadata.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python-depends": []
|
| 3 |
+
}
|
build/torch29-cxx11-cu126-x86_64-linux/testing/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import bench, numeric, utils
|
| 2 |
+
from .bench import *
|
| 3 |
+
from .numeric import *
|
| 4 |
+
from .utils import *
|
build/torch29-cxx11-cu126-x86_64-linux/testing/bench.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Time `fn` on the GPU and return the average seconds per call.

    Flushes L2 before warm-up, runs `num_warmups` untimed calls, then times
    `num_tests` calls between two CUDA events.
    """
    # Evict L2 by touching a 256 MB scratch buffer before measuring.
    torch.cuda.synchronize()
    scratch = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    scratch.zero_()

    # Untimed warm-up calls.
    for _ in range(num_warmups):
        fn()

    # Optionally enqueue a large matmul so the GPU stays busy while the CPU
    # launches the timed kernels, hiding launch overhead.
    if high_precision:
        lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        lhs @ rhs

    # Timed section, bracketed by CUDA events.
    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    begin.record()
    for _ in range(num_tests):
        fn()
    finish.record()
    torch.cuda.synchronize()

    # `elapsed_time` is milliseconds; convert to seconds per iteration.
    return begin.elapsed_time(finish) / num_tests / 1e3
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class empty_suppress:
    """No-op context manager, used when output suppression is disabled."""

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        pass
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class suppress_stdout_stderr:
    """Silence both Python-level and OS-level stdout/stderr.

    Redirects the underlying file descriptors to os.devnull in addition to
    swapping `sys.stdout`/`sys.stderr`, so output written directly to fd 1/2
    by native code (e.g. the profiler backend) is suppressed as well.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The fd numbers currently backing stdout/stderr (usually 1 and 2);
        # these are the targets that get re-pointed at /dev/null.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicate the original fds so they can be restored on exit.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        # Keep the Python-level stream objects for restoration.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the live fds at /dev/null — silences native writes too.
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Silence Python-level writes as well.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Restore the original fds from the saved duplicates...
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        # ...then close the duplicates to avoid leaking fds.
        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Profile `fn` with the torch (kineto) profiler and return per-kernel
    average times in seconds.

    `kernel_names` is a substring (or tuple of substrings) matched against
    the profiler's kernel-summary table. Returns a single float for a string
    argument, or a tuple of floats (one per name) for a tuple argument.
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        # Dummy timings so callers still get the expected result shape.
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        # wait=1, active=1: the first schedule step is discarded as warm-up,
        # only the second is recorded.
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        # Each requested name must match at most one table row, otherwise the
        # per-name timing below would aggregate unrelated kernels.
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    # The last two columns of each table row are total time (with unit
    # suffix) and call count; convert to seconds and average per call.
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
|
build/torch29-cxx11-cu126-x86_64-linux/testing/numeric.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Iterable
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Relative difference: 1 - 2*sum(x*y) / sum(x*x + y*y), in float64.

    0 means identical tensors; all-zero inputs are treated as identical.
    """
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    if denom == 0:  # every element of both x and y is zero
        return 0.0
    return 1 - 2 * (x * y).sum() / denom
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def count_bytes(*tensors):
    """Total storage bytes across all tensors.

    Nested tuples/lists are flattened recursively; `None` entries are skipped.
    """
    total = 0
    for item in tensors:
        if item is None:
            continue
        if isinstance(item, (tuple, list)):
            total += count_bytes(*item)
        else:
            total += item.numel() * item.element_size()
    return total
|
build/torch29-cxx11-cu126-x86_64-linux/testing/utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Callable
|
| 5 |
+
|
| 6 |
+
def get_arch_major() -> int:
    """Return the CUDA compute-capability major version of the current device."""
    return torch.cuda.get_device_capability()[0]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_filter(condition: Callable):
|
| 12 |
+
def decorator(func):
|
| 13 |
+
@functools.wraps(func)
|
| 14 |
+
def wrapper(*args, **kwargs):
|
| 15 |
+
if condition():
|
| 16 |
+
func(*args, **kwargs)
|
| 17 |
+
else:
|
| 18 |
+
print(f'{func.__name__}:')
|
| 19 |
+
print(f' > Filtered by {condition}')
|
| 20 |
+
print()
|
| 21 |
+
return wrapper
|
| 22 |
+
return decorator
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ignore_env(name: str, condition: Callable):
    """Decorator factory: when `condition()` is truthy, run the wrapped
    function with the environment variable `name` temporarily removed and
    restore its previous value afterwards — even if the function raises
    (the original version leaked the removed variable on exceptions).

    When the condition is falsy, the function runs with the environment
    untouched. The wrapped function's return value is propagated.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not condition():
                return func(*args, **kwargs)
            # `None` means the variable was not set in the first place.
            saved = os.environ.pop(name, None)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore only if it previously existed.
                if saved is not None:
                    os.environ[name] = saved
        return wrapper
    return decorator
|
build/torch29-cxx11-cu126-x86_64-linux/utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import math, layout
|
| 2 |
+
from .layout import *
|
| 3 |
+
from .math import *
|
build/torch29-cxx11-cu126-x86_64-linux/utils/layout.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Thin Python wrappers over the compiled extension's layout helpers.
try:
    from .._ops import ops

    def get_tma_aligned_size(x, element_size):
        # Forward to the compiled op.
        return ops.get_tma_aligned_size(x, element_size)

    def get_mn_major_tma_aligned_tensor(sf):
        # Forward to the compiled op.
        return ops.get_mn_major_tma_aligned_tensor(sf)

    def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
        # Forward to the compiled op.
        return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)

    def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
        # Forward to the compiled op.
        return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
            sf, ks_tensor, ks)
except ImportError:
    # NOTE(review): this silently skips the wrappers above when `_ops` cannot
    # be imported, yet the unconditional import below would still raise in
    # that case — confirm whether `_ops` can actually be missing here.
    pass

from .._ops import ops as _ops

def get_mk_alignment_for_contiguous_layout():
    # Required M/K alignment for the contiguous layout, as reported by the op.
    return _ops.get_mk_alignment_for_contiguous_layout()

# Backward-compatible aliases for the combined M/K alignment getter.
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|