medmekk commited on
Commit
c67ae40
·
verified ·
1 Parent(s): 9090aa9

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. README.md +81 -0
  3. build.toml +100 -0
  4. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +289 -0
  5. build/torch210-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
  6. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +9 -0
  7. build/torch210-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py +26 -0
  8. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +3 -0
  9. build/torch210-cxx11-cu126-x86_64-linux/testing/__init__.py +4 -0
  10. build/torch210-cxx11-cu126-x86_64-linux/testing/bench.py +137 -0
  11. build/torch210-cxx11-cu126-x86_64-linux/testing/numeric.py +21 -0
  12. build/torch210-cxx11-cu126-x86_64-linux/testing/utils.py +38 -0
  13. build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py +3 -0
  14. build/torch210-cxx11-cu126-x86_64-linux/utils/layout.py +25 -0
  15. build/torch210-cxx11-cu126-x86_64-linux/utils/math.py +107 -0
  16. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +289 -0
  17. build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
  18. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +9 -0
  19. build/torch210-cxx11-cu128-x86_64-linux/deep_gemm/__init__.py +26 -0
  20. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +3 -0
  21. build/torch210-cxx11-cu128-x86_64-linux/testing/__init__.py +4 -0
  22. build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py +137 -0
  23. build/torch210-cxx11-cu128-x86_64-linux/testing/numeric.py +21 -0
  24. build/torch210-cxx11-cu128-x86_64-linux/testing/utils.py +38 -0
  25. build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py +3 -0
  26. build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py +25 -0
  27. build/torch210-cxx11-cu128-x86_64-linux/utils/math.py +107 -0
  28. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +289 -0
  29. build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
  30. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +9 -0
  31. build/torch210-cxx11-cu130-x86_64-linux/deep_gemm/__init__.py +26 -0
  32. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +3 -0
  33. build/torch210-cxx11-cu130-x86_64-linux/testing/__init__.py +4 -0
  34. build/torch210-cxx11-cu130-x86_64-linux/testing/bench.py +137 -0
  35. build/torch210-cxx11-cu130-x86_64-linux/testing/numeric.py +21 -0
  36. build/torch210-cxx11-cu130-x86_64-linux/testing/utils.py +38 -0
  37. build/torch210-cxx11-cu130-x86_64-linux/utils/__init__.py +3 -0
  38. build/torch210-cxx11-cu130-x86_64-linux/utils/layout.py +25 -0
  39. build/torch210-cxx11-cu130-x86_64-linux/utils/math.py +107 -0
  40. build/torch29-cxx11-cu126-x86_64-linux/__init__.py +289 -0
  41. build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so +3 -0
  42. build/torch29-cxx11-cu126-x86_64-linux/_ops.py +9 -0
  43. build/torch29-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py +26 -0
  44. build/torch29-cxx11-cu126-x86_64-linux/metadata.json +3 -0
  45. build/torch29-cxx11-cu126-x86_64-linux/testing/__init__.py +4 -0
  46. build/torch29-cxx11-cu126-x86_64-linux/testing/bench.py +137 -0
  47. build/torch29-cxx11-cu126-x86_64-linux/testing/numeric.py +21 -0
  48. build/torch29-cxx11-cu126-x86_64-linux/testing/utils.py +38 -0
  49. build/torch29-cxx11-cu126-x86_64-linux/utils/__init__.py +3 -0
  50. build/torch29-cxx11-cu126-x86_64-linux/utils/layout.py +25 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ build/torch210-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
37
+ build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
38
+ build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
39
+ build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
40
+ build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
41
+ build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DeepGEMM
2
+
3
+ DeepGEMM kernel for the [Hugging Face kernel-builder](https://github.com/huggingface/kernels) infrastructure.
4
+
5
+ This package provides FP8/FP4/BF16 GEMM kernels, einsum, attention, and hyperconnection operations
6
+ from [DeepSeek-AI/DeepGEMM](https://github.com/DeepSeek-AI/DeepGEMM), adapted to the kernels-community
7
+ build structure with torch library bindings.
8
+
9
+ ## Features
10
+
11
+ - **FP8/FP4 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
12
+ - **BF16 GEMMs**: NT, NN, TN, TT variants with M-grouped and K-grouped support
13
+ - **cuBLASLt GEMMs**: NT, NN, TN, TT wrappers
14
+ - **Einsum**: bmk,bnk->mn, bhr,hdr->bhd, bhd,hdr->bhr expressions (BF16 and FP8)
15
+ - **Attention**: FP8 MQA logits (regular and paged)
16
+ - **Hyperconnection**: TF32 prenorm GEMM
17
+ - **Layout utilities**: Scaling factor transformations, TMA alignment
18
+
19
+ ## Architecture Support
20
+
21
+ - SM 9.0a (Hopper / H100)
22
+ - SM 10.0a (Blackwell / B200)
23
+
24
+ ## Requirements
25
+
26
+ - CUDA >= 12.1
27
+ - PyTorch >= 2.1
28
+ - CUTLASS 3.9+
29
+ - NVRTC (part of CUDA Toolkit)
30
+
31
+ ## Installation
32
+
33
+ ```bash
34
+ pip install kernels
35
+ ```
36
+
37
+ ```python
38
+ import kernels
39
+ kernels.install("kernels-community/DeepGEMM")
40
+ ```
41
+
42
+ ## Usage
43
+
44
+ ```python
45
+ import deep_gemm
46
+
47
+ # FP8 GEMM: D = A @ B.T
48
+ deep_gemm.fp8_gemm_nt((a_fp8, sfa), (b_fp8, sfb), d)
49
+
50
+ # BF16 GEMM: D = A @ B.T
51
+ deep_gemm.bf16_gemm_nt(a_bf16, b_bf16, d)
52
+
53
+ # cuBLASLt GEMM
54
+ deep_gemm.cublaslt_gemm_nt(a, b, d)
55
+ ```
56
+
57
+ ## JIT Compilation
58
+
59
+ DeepGEMM uses Just-In-Time (JIT) compilation for its CUDA kernels. The kernel
60
+ templates (`.cuh` files in `include/deep_gemm/`) are compiled at runtime using
61
+ NVCC or NVRTC. First invocations may be slower due to compilation; results are
62
+ cached in `~/.deep_gemm/` for subsequent calls.
63
+
64
+ ### CUTLASS Runtime Dependency
65
+
66
+ The JIT-compiled kernels depend on CUTLASS headers (`cute/`, `cutlass/`) at
67
+ runtime. The package will automatically search for CUTLASS in these locations:
68
+
69
+ 1. `DG_CUTLASS_INCLUDE` environment variable (direct path to include dir)
70
+ 2. `CUTLASS_HOME` environment variable (`$CUTLASS_HOME/include`)
71
+ 3. Bundled in the package's `include/` directory
72
+ 4. `CUDA_HOME/include` (some CUDA 12.8+ installs bundle `cute/`)
73
+ 5. `nvidia-cutlass` Python package
74
+
75
+ Set one of these if JIT compilation fails with missing CUTLASS headers:
76
+
77
+ ```bash
78
+ export CUTLASS_HOME=/path/to/cutlass
79
+ # or
80
+ export DG_CUTLASS_INCLUDE=/path/to/cutlass/include
81
+ ```
build.toml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "deep_gemm"
3
+ backends = ["cuda"]
4
+
5
+ [general.hub]
6
+ repo-id = "kernels-community/DeepGEMM"
7
+
8
+ [torch]
9
+ src = [
10
+ "torch-ext/torch_binding.cpp",
11
+ "torch-ext/torch_binding.h",
12
+ ]
13
+
14
+ [kernel.deep_gemm]
15
+ backend = "cuda"
16
+ cuda-capabilities = [
17
+ "9.0a",
18
+ "10.0a",
19
+ ]
20
+ cxx-flags = [
21
+ "-std=c++17",
22
+ "-O3",
23
+ "-Wno-psabi",
24
+ "-Wno-deprecated-declarations",
25
+ ]
26
+ depends = [
27
+ "torch",
28
+ "cutlass_3_9",
29
+ ]
30
+ include = [
31
+ ".",
32
+ "csrc",
33
+ "deep_gemm/include",
34
+ "third-party/fmt/include",
35
+ ]
36
+ src = [
37
+ "csrc/deep_gemm_impl.cpp",
38
+ "csrc/apis/attention.hpp",
39
+ "csrc/apis/einsum.hpp",
40
+ "csrc/apis/gemm.hpp",
41
+ "csrc/apis/hyperconnection.hpp",
42
+ "csrc/apis/layout.hpp",
43
+ "csrc/apis/runtime.hpp",
44
+ "csrc/jit/cache.hpp",
45
+ "csrc/jit/compiler.hpp",
46
+ "csrc/jit/device_runtime.hpp",
47
+ "csrc/jit/handle.hpp",
48
+ "csrc/jit/kernel_runtime.hpp",
49
+ "csrc/jit_kernels/heuristics/common.hpp",
50
+ "csrc/jit_kernels/heuristics/sm90.hpp",
51
+ "csrc/jit_kernels/heuristics/sm100.hpp",
52
+ "csrc/jit_kernels/impls/epilogue.hpp",
53
+ "csrc/jit_kernels/impls/runtime_utils.hpp",
54
+ "csrc/jit_kernels/impls/sm90_bf16_gemm.hpp",
55
+ "csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp",
56
+ "csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp",
57
+ "csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp",
58
+ "csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp",
59
+ "csrc/jit_kernels/impls/sm100_bf16_gemm.hpp",
60
+ "csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp",
61
+ "csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp",
62
+ "csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp",
63
+ "csrc/jit_kernels/impls/smxx_clean_logits.hpp",
64
+ "csrc/jit_kernels/impls/smxx_cublaslt.hpp",
65
+ "csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp",
66
+ "csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp",
67
+ "csrc/jit_kernels/impls/smxx_layout.hpp",
68
+ "csrc/utils/compatibility.hpp",
69
+ "csrc/utils/exception.hpp",
70
+ "csrc/utils/format.hpp",
71
+ "csrc/utils/hash.hpp",
72
+ "csrc/utils/layout.hpp",
73
+ "csrc/utils/lazy_init.hpp",
74
+ "csrc/utils/math.hpp",
75
+ "csrc/utils/system.hpp",
76
+ "deep_gemm/include/deep_gemm/common/cute_tie.cuh",
77
+ "deep_gemm/include/deep_gemm/common/epilogue_utils.cuh",
78
+ "deep_gemm/include/deep_gemm/common/reduction.cuh",
79
+ "deep_gemm/include/deep_gemm/common/scheduler.cuh",
80
+ "deep_gemm/include/deep_gemm/common/sm100_utils.cuh",
81
+ "deep_gemm/include/deep_gemm/common/sm90_utils.cuh",
82
+ "deep_gemm/include/deep_gemm/common/tma_utils.cuh",
83
+ "deep_gemm/include/deep_gemm/common/types.hpp",
84
+ "deep_gemm/include/deep_gemm/common/utils.cuh",
85
+ "deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh",
86
+ "deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh",
87
+ "deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh",
88
+ "deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh",
89
+ "deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh",
90
+ "deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh",
91
+ "deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh",
92
+ "deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh",
93
+ "deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh",
94
+ "deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh",
95
+ "deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh",
96
+ "deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh",
97
+ "deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh",
98
+ "deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh",
99
+ "deep_gemm/include/deep_gemm/impls/smxx_layout.cuh",
100
+ ]
build/torch210-cxx11-cu126-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ def _find_cuda_home():
9
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
10
+ if cuda_home is None:
11
+ try:
12
+ with open(os.devnull, 'w') as devnull:
13
+ nvcc = subprocess.check_output(
14
+ ['which', 'nvcc'], stderr=devnull
15
+ ).decode().rstrip('\r\n')
16
+ cuda_home = os.path.dirname(os.path.dirname(nvcc))
17
+ except Exception:
18
+ cuda_home = '/usr/local/cuda'
19
+ if not os.path.exists(cuda_home):
20
+ cuda_home = ''
21
+ return cuda_home or ''
22
+
23
+
24
+ def _find_cutlass_include():
25
+ """Find CUTLASS include path for JIT compilation of .cuh templates."""
26
+ # 1. Explicit env var
27
+ cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
28
+ if cutlass_include and os.path.isdir(cutlass_include):
29
+ return cutlass_include
30
+
31
+ # 2. CUTLASS_HOME env var
32
+ cutlass_home = os.environ.get('CUTLASS_HOME')
33
+ if cutlass_home:
34
+ p = os.path.join(cutlass_home, 'include')
35
+ if os.path.isdir(os.path.join(p, 'cute')):
36
+ return p
37
+
38
+ # 3. Check in package include/ directory (bundled cute/cutlass headers)
39
+ pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
40
+ if os.path.isdir(os.path.join(pkg_include, 'cute')):
41
+ return pkg_include
42
+
43
+ # 4. Check CUDA_HOME/include (some CUDA 12.8+ installs include cute/)
44
+ cuda_home = _find_cuda_home()
45
+ if cuda_home:
46
+ cuda_inc = os.path.join(cuda_home, 'include')
47
+ if os.path.isdir(os.path.join(cuda_inc, 'cute')):
48
+ return cuda_inc
49
+
50
+ # 5. Try to find nvidia-cutlass Python package
51
+ try:
52
+ import cutlass as _cutlass
53
+ cutlass_dir = os.path.dirname(_cutlass.__file__)
54
+ p = os.path.join(cutlass_dir, 'include')
55
+ if os.path.isdir(os.path.join(p, 'cute')):
56
+ return p
57
+ except ImportError:
58
+ pass
59
+
60
+ # Return empty string; C++ side will also check env vars
61
+ return ""
62
+
63
+
64
def set_num_sms(new_num_sms):
    """Thin wrapper forwarding to the native ``set_num_sms`` op."""
    ops.set_num_sms(new_num_sms)


def get_num_sms():
    """Thin wrapper forwarding to the native ``get_num_sms`` op."""
    return ops.get_num_sms()


def set_tc_util(new_tc_util):
    """Thin wrapper forwarding to the native ``set_tc_util`` op."""
    ops.set_tc_util(new_tc_util)


def get_tc_util():
    """Thin wrapper forwarding to the native ``get_tc_util`` op."""
    return ops.get_tc_util()
75
+
76
+
77
# cuBLASLt GEMM wrappers: each forwards straight to the corresponding
# native op; `c` is an optional accumulator/bias operand defaulting to None.
def cublaslt_gemm_nt(a, b, d, c=None):
    """Forward to the native ``cublaslt_gemm_nt`` op."""
    ops.cublaslt_gemm_nt(a, b, d, c)


def cublaslt_gemm_nn(a, b, d, c=None):
    """Forward to the native ``cublaslt_gemm_nn`` op."""
    ops.cublaslt_gemm_nn(a, b, d, c)


def cublaslt_gemm_tn(a, b, d, c=None):
    """Forward to the native ``cublaslt_gemm_tn`` op."""
    ops.cublaslt_gemm_tn(a, b, d, c)


def cublaslt_gemm_tt(a, b, d, c=None):
    """Forward to the native ``cublaslt_gemm_tt`` op."""
    ops.cublaslt_gemm_tt(a, b, d, c)
89
+
90
+
91
# Python-level wrappers over the native ops. Defined inside a try block so
# that import of this module degrades gracefully if something goes wrong
# (the `except Exception: pass` at the end mirrors the original intent).
try:
    def _olist(seq):
        # Convert an optional recipe sequence to a list for the native op,
        # mapping falsy values (None, empty tuple) to None.
        return list(seq) if seq else None

    # --- FP8/FP4 GEMMs ---
    def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
                            _olist(recipe), _olist(recipe_a), _olist(recipe_b),
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
                            _olist(recipe), _olist(recipe_a), _olist(recipe_b),
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
                            _olist(recipe), _olist(recipe_a), _olist(recipe_b),
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
                            _olist(recipe), _olist(recipe_a), _olist(recipe_b),
                            compiled_dims, disable_ue8m0_cast)

    # FP8 aliases for the combined FP8/FP4 entry points.
    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt

    def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            _olist(recipe), _olist(recipe_a), _olist(recipe_b),
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            _olist(recipe), _olist(recipe_a), _olist(recipe_b),
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            _olist(recipe), _olist(recipe_a), _olist(recipe_b),
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

    def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    # --- BF16 GEMMs ---
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
                                      compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
                                          c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)

    # --- Einsum ---
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))

    # --- Attention ---
    def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
                                  compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            _olist(recipe), compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
                       cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
                             block_table, schedule_meta,
                             max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)

    # --- Hyperconnection ---
    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    # --- Layout ---
    def transform_sf_into_required_layout(sf, mn, k, recipe=None,
            recipe_ab=None, num_groups=None, is_sfa=False,
            disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k, _olist(recipe), _olist(recipe_ab),
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()

    # Legacy aliases
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except Exception:
    pass
271
+
272
# Re-export the Python-side utility helpers at package level.
from . import utils
from .utils import *

# Testing helpers are importable as `deep_gemm.testing`.
from . import testing

# Initialize the native extension with the package root, CUDA home and
# CUTLASS include paths. Failures are swallowed so the package can still be
# imported where CUDA is unavailable (e.g. in build sandboxes).
try:
    _pkg_root = os.path.dirname(os.path.abspath(__file__))
    ops.init(_pkg_root, _find_cuda_home(), _find_cutlass_include())
except Exception:
    pass

__version__ = '2.3.0'
build/torch210-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be5d0bb69c96d55b15ba62ba83e0743eb80ef4e93198fe59862dc247540f4956
3
+ size 3006712
build/torch210-cxx11-cu126-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _deep_gemm_099ac3c_dirty
3
+ ops = torch.ops._deep_gemm_099ac3c_dirty
4
+
5
def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this extension's torch op namespace."""
    return "_deep_gemm_099ac3c_dirty::" + op_name
build/torch210-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu126-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python-depends": []
3
+ }
build/torch210-cxx11-cu126-x86_64-linux/testing/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import bench, numeric, utils
2
+ from .bench import *
3
+ from .numeric import *
4
+ from .utils import *
build/torch210-cxx11-cu126-x86_64-linux/testing/bench.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+
6
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Time `fn` on the GPU with CUDA events; return mean seconds per call."""
    # Evict L2 by zeroing a 256 MB scratch buffer first.
    torch.cuda.synchronize()
    scratch = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    scratch.zero_()

    # Warmup runs are not timed.
    for _ in range(num_warmups):
        fn()

    # Optionally queue a large matmul so the timed launches are not
    # dominated by CPU launch overhead.
    if high_precision:
        lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        lhs @ rhs

    # Timed section: bracket the calls with CUDA events.
    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    begin.record()
    for _ in range(num_tests):
        fn()
    finish.record()
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; convert to seconds per call.
    return begin.elapsed_time(finish) / num_tests / 1e3
33
+
34
+
35
class empty_suppress:
    """No-op context manager used when output suppression is disabled."""

    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass
41
+
42
+
43
class suppress_stdout_stderr:
    """Context manager that silences stdout/stderr at the OS fd level.

    Redirects both the Python-level streams and the underlying file
    descriptors to ``/dev/null``, so output written directly to fds 1/2 by
    native code is suppressed too. Everything is restored on exit.
    """

    def __enter__(self):
        # Sinks for the redirected output.
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The fd numbers currently backing stdout/stderr.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates of the original fds so they can be restored on exit.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        # The Python-level stream objects, also restored on exit.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the real fds at /dev/null...
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # ...and the Python-level streams as well.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Restore Python-level streams first.
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Restore the original fds, then release the duplicates.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
76
+
77
+
78
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path=None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Measure per-kernel GPU time for `fn` with the Kineto profiler.

    Args:
        fn: Zero-argument callable that launches the kernel(s) to profile.
        kernel_names: A kernel-name substring, or a tuple of substrings.
        num_tests: Invocations of `fn` per profiling pass.
        suppress_kineto_output: Silence profiler output at the fd level.
        trace_path: Optional path for exporting a Chrome trace (str or None;
            the original annotated this as `str` despite the None default).
        flush_l2: Zero a large buffer before each call to evict L2.
        with_multiple_kernels: Allow a name to match several table rows.

    Returns:
        Average seconds per kernel: a tuple when `kernel_names` is a tuple,
        otherwise a single float. Returns dummy 1s when external NVIDIA
        tools are attached (see below).
    """
    assert isinstance(kernel_names, (str, tuple))
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling entirely when external NVIDIA tools are in use:
    # Kineto conflicts with Nsight Systems, Nsight Compute and Compute Sanitizer.
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # Run once outside the profiler: some auto-tuning kernels print on first use.
    fn()

    # Profile two scheduler steps (wait=1, active=1); only the second is recorded.
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            for _ in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table.
    prof_lines = profiler.key_averages().table(
        sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        # Each requested name must match at most one table row, otherwise the
        # averages below would mix unrelated kernels.
        for name in kernel_names:
            assert sum(name in line for line in prof_lines) <= 1, \
                f'Errors of the kernel {name} in the profiling table'

    # Save a Chrome trace if requested.
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times. The table's last two columns are the total
    # time (e.g. "123.4us") and the call count.
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        # Guard against zero matches to avoid a ZeroDivisionError.
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
build/torch210-cxx11-cu126-x86_64-linux/testing/numeric.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Iterable
3
+
4
+
5
def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return a similarity-based difference between two tensors.

    Computes ``1 - 2*sum(x*y) / sum(x*x + y*y)``: 0.0 for identical tensors,
    growing as they diverge.

    Fix: the original returned a Python float (0.0) on the all-zero branch
    but a 0-dim tensor otherwise; both branches now return a plain float.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:  # Which means that all elements in x and y are 0
        return 0.0
    sim = 2 * (x * y).sum() / denominator
    return float(1 - sim)
12
+
13
+
14
def count_bytes(*tensors):
    """Total storage bytes of the given tensors.

    Tuples/lists are recursed into; None entries contribute zero.
    """
    def _size(item):
        if isinstance(item, (tuple, list)):
            return count_bytes(*item)
        if item is None:
            return 0
        return item.numel() * item.element_size()

    return sum(_size(t) for t in tensors)
build/torch210-cxx11-cu126-x86_64-linux/testing/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import torch
4
+ from typing import Callable
5
+
6
def get_arch_major() -> int:
    """Return the major compute capability of the current CUDA device."""
    return torch.cuda.get_device_capability()[0]
9
+
10
+
11
+ def test_filter(condition: Callable):
12
+ def decorator(func):
13
+ @functools.wraps(func)
14
+ def wrapper(*args, **kwargs):
15
+ if condition():
16
+ func(*args, **kwargs)
17
+ else:
18
+ print(f'{func.__name__}:')
19
+ print(f' > Filtered by {condition}')
20
+ print()
21
+ return wrapper
22
+ return decorator
23
+
24
+
25
def ignore_env(name: str, condition: Callable):
    """Decorator: when ``condition()`` is truthy, run the wrapped function
    with environment variable *name* temporarily unset.

    The variable is restored afterwards. Fix: restoration now happens in a
    ``finally`` block, so the variable is no longer leaked (left unset) when
    the wrapped function raises.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if condition():
                saved = os.environ.pop(name, None)
                try:
                    func(*args, **kwargs)
                finally:
                    # Restore even on exception; None means it was not set.
                    if saved is not None:
                        os.environ[name] = saved
            else:
                func(*args, **kwargs)

        return wrapper
    return decorator
build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import math, layout
2
+ from .layout import *
3
+ from .math import *
build/torch210-cxx11-cu126-x86_64-linux/utils/layout.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
try:
    from .._ops import ops

    def get_tma_aligned_size(x, element_size):
        # Forwarded to the compiled extension; computes the TMA-aligned size
        # for a dimension of `x` elements of `element_size` bytes.
        return ops.get_tma_aligned_size(x, element_size)

    def get_mn_major_tma_aligned_tensor(sf):
        # Re-lay out a scale-factor tensor so it is MN-major and TMA-aligned.
        return ops.get_mn_major_tma_aligned_tensor(sf)

    def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
        # Same as above, but additionally packs scales into UE8M0 format.
        return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)

    def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
        # K-grouped variant; `ks_tensor`/`ks` describe the per-group K sizes.
        return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
            sf, ks_tensor, ks)
except ImportError:
    # NOTE(review): this guard looks ineffective — the unconditional
    # `from .._ops import ops as _ops` below would re-raise the same
    # ImportError anyway. Confirm whether the guard can be removed or the
    # import below should be guarded too.
    pass

from .._ops import ops as _ops

def get_mk_alignment_for_contiguous_layout():
    # Alignment (in rows/elements) required by the contiguous grouped-GEMM
    # layout; value comes from the compiled extension.
    return _ops.get_mk_alignment_for_contiguous_layout()

# Historical aliases kept for backward compatibility.
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
build/torch210-cxx11-cu126-x86_64-linux/utils/math.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+
5
def ceil_div(x: int, y: int) -> int:
    """Integer division of `x` by `y` (y > 0), rounded up."""
    quotient, remainder = divmod(x, y)
    return quotient + (1 if remainder else 0)
7
+
8
+
9
def align(x: int, y: int) -> int:
    """Round `x` up to the nearest multiple of `y` (y > 0)."""
    remainder = x % y
    return x if remainder == 0 else x + y - remainder
11
+
12
+
13
def ceil_to_ue8m0(x: torch.Tensor):
    """Round each |x| up to the nearest power of two (the UE8M0 scale grid)."""
    # An all-zero input would produce log2(0) = -inf; callers clamp first.
    assert x.view(-1).amax().item() > 0
    return torch.exp2(torch.log2(x.abs()).ceil())
16
+
17
+
18
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to FP8 (e4m3) with one scale per (row, `gran_k`-column group).

    Args:
        x: 2D input tensor; columns are zero-padded up to a multiple of `gran_k`.
        use_ue8m0: round scales up to powers of two (UE8M0 grid).
        gran_k: quantization group size along the last dimension.

    Returns:
        Tuple of (FP8 tensor with the original shape, float scales of shape
        (m, ceil(n / gran_k))).
    """
    assert x.dim() == 2
    m, n = x.shape
    padded_n = align(n, gran_k)
    # Fix: allocate zero-filled memory in one call instead of
    # empty().fill_(0), consistent with per_block_cast_to_fp8 below.
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, -1, gran_k)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = x_amax / 448.0  # 448 is the largest finite float8_e4m3fn value
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
29
+
30
+
31
+ def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
32
+ assert x.dim() == 2 and x.size(0) % gran_k == 0
33
+ m, n = x.shape
34
+ x_view = x.view(-1, gran_k, n)
35
+ x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
36
+ sf = x_amax / 448.0
37
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
38
+ return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
39
+
40
+
41
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize to FP8 (e4m3) with one scale per `gran_k` x `gran_k` block.

    The input is zero-padded in both dimensions up to block multiples;
    returns (FP8 tensor cropped back to the original shape, scales of shape
    (num_row_blocks, num_col_blocks)).
    """
    assert x.dim() == 2
    rows, cols = x.shape
    padded = torch.zeros((align(rows, gran_k), align(cols, gran_k)), dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x
    blocks = padded.view(-1, gran_k, padded.size(1) // gran_k, gran_k)
    amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scales = amax / 448.0  # 448 is the largest finite float8_e4m3fn value
    if use_ue8m0:
        scales = ceil_to_ue8m0(scales)
    quantized = (blocks * (1.0 / scales)).to(torch.float8_e4m3fn)
    return quantized.view_as(padded)[:rows, :cols].contiguous(), scales.view(blocks.size(0), blocks.size(2))
52
+
53
+
54
+ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
55
+ excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
56
+ x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
57
+ sf = x_amax / 448.0
58
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
59
+ x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
60
+ return x_scaled, sf.squeeze()
61
+
62
+
63
def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
    """Map values to 4-bit E2M1 codes (sign in bit 3, magnitude in bits 0-2).

    Magnitudes are clamped to 6.0 and rounded to the nearest representable
    level; negative zero is not encoded (the sign bit stays clear for the
    zero level). Returns uint8 codes in 0..15.
    """
    magnitude = x.abs().clamp_max(6.0)
    # Representable magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}; bucketize by the
    # midpoints between neighbouring levels for round-to-nearest.
    midpoints = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
                             device=x.device, dtype=magnitude.dtype)
    level = torch.bucketize(magnitude, midpoints)
    codes = level.to(torch.uint8)
    negative = (x < 0) & (level != 0)
    return codes | (negative.to(torch.uint8) << 3)
74
+
75
+
76
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize to packed FP4 (E2M1, two codes per byte), one scale per (row, `gran_k`-column group).

    Returns (uint8 tensor of shape (m, n // 2) with even columns in the low
    nibble, scales of shape (m, ceil(n / gran_k))).
    """
    assert x.dim() == 2
    rows, cols = x.shape
    assert cols % 2 == 0
    padded_cols = align(cols, gran_k)
    padded = torch.zeros((rows, padded_cols), dtype=x.dtype, device=x.device)
    padded[:, :cols] = x
    grouped = padded.view(rows, -1, gran_k)
    amax = grouped.abs().float().amax(dim=2).clamp_min(1e-4)
    scales = amax / 6.0  # 6.0 is the largest E2M1 magnitude
    if use_ue8m0:
        scales = ceil_to_ue8m0(scales)
    scaled = grouped * (1.0 / scales.unsqueeze(2))
    # One E2M1 code per element, then pack adjacent column pairs into bytes.
    codes = _quantize_to_fp4_e2m1(scaled).view(rows, padded_cols)
    pairs = codes.view(rows, padded_cols // 2, 2)
    packed = (pairs[:, :, 0] & 0x0F) | ((pairs[:, :, 1] & 0x0F) << 4)
    return packed[:, :cols // 2].contiguous(), scales
92
+
93
+
94
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
    """Transpose a matrix of packed FP4 codes (two 4-bit codes per uint8 byte).

    Input shape (m, n/2) becomes output shape (n, m/2); `m` must be even so
    the transposed rows can be re-packed in pairs.
    """
    assert a.dtype == torch.uint8
    assert a.dim() == 2
    rows, packed_cols = a.shape
    cols = packed_cols * 2
    assert (rows % 2) == 0
    # Unpack to one code per byte (low nibble = even column).
    unpacked = torch.empty((rows, cols), device=a.device, dtype=torch.uint8)
    unpacked[:, 0::2] = a & 0x0F
    unpacked[:, 1::2] = (a >> 4) & 0x0F
    # Transpose, then pack adjacent pairs along the old row axis back into bytes.
    transposed = unpacked.transpose(0, 1).contiguous().view(cols, rows // 2, 2)
    result = (transposed[:, :, 0] & 0x0F) | ((transposed[:, :, 1] & 0x0F) << 4)
    return result.contiguous()
build/torch210-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
def _find_cuda_home():
    """Locate the CUDA toolkit root directory.

    Resolution order:
      1. ``CUDA_HOME`` / ``CUDA_PATH`` environment variables (returned as-is).
      2. The directory two levels above the ``nvcc`` binary found on PATH.
      3. ``/usr/local/cuda``; auto-detected paths that do not exist are cleared.

    Returns:
        The CUDA home path, or an empty string if none could be found.
    """
    import shutil  # local import keeps this function self-contained

    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # shutil.which replaces the original subprocess call to `which nvcc`:
        # no child process, and it also works where `which` is unavailable.
        nvcc = shutil.which('nvcc')
        if nvcc is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        else:
            cuda_home = '/usr/local/cuda'
        if not os.path.exists(cuda_home):
            cuda_home = ''
    return cuda_home or ''
22
+
23
+
24
def _find_cutlass_include():
    """Find CUTLASS include path for JIT compilation of .cuh templates."""
    # 1. Explicit override via DG_CUTLASS_INCLUDE.
    explicit = os.environ.get('DG_CUTLASS_INCLUDE')
    if explicit and os.path.isdir(explicit):
        return explicit

    # 2. CUTLASS_HOME env var (must contain include/cute).
    cutlass_home = os.environ.get('CUTLASS_HOME')
    if cutlass_home:
        candidate = os.path.join(cutlass_home, 'include')
        if os.path.isdir(os.path.join(candidate, 'cute')):
            return candidate

    # 3. Headers bundled inside this package's include/ directory.
    bundled = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
    if os.path.isdir(os.path.join(bundled, 'cute')):
        return bundled

    # 4. CUDA_HOME/include (some CUDA 12.8+ installs ship cute/).
    cuda_home = _find_cuda_home()
    if cuda_home:
        candidate = os.path.join(cuda_home, 'include')
        if os.path.isdir(os.path.join(candidate, 'cute')):
            return candidate

    # 5. The nvidia-cutlass Python package, if installed.
    try:
        import cutlass as _cutlass
    except ImportError:
        pass
    else:
        candidate = os.path.join(os.path.dirname(_cutlass.__file__), 'include')
        if os.path.isdir(os.path.join(candidate, 'cute')):
            return candidate

    # Nothing found; the C++ side also checks environment variables.
    return ""
62
+
63
+
64
def set_num_sms(new_num_sms):
    """Set how many SMs the kernels may occupy (forwarded to the extension)."""
    ops.set_num_sms(new_num_sms)

def get_num_sms():
    """Return the currently configured SM count."""
    return ops.get_num_sms()

def set_tc_util(new_tc_util):
    """Set the tensor-core utilization knob (forwarded to the extension)."""
    ops.set_tc_util(new_tc_util)

def get_tc_util():
    """Return the current tensor-core utilization setting."""
    return ops.get_tc_util()


# cuBLASLt GEMMs
# Thin wrappers computing D from A, B (optionally accumulating C). The
# nt/nn/tn/tt suffix presumably encodes the A/B transposition convention —
# confirm against the compiled extension's schema.
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)

def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)

def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)

def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)
89
+
90
+
91
try:
    # Every wrapper below requires the matching schema to exist in the
    # compiled extension; see the note on the `except` at the bottom.
    # `a` and `b` are (tensor, scale-factor) pairs for the FP8/FP4 entry
    # points; recipe tuples are converted to lists for the C++ side.

    # FP8/FP4 GEMMs
    def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    # The fp8_fp4_* entry points serve both element types; plain fp8 aliases.
    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt

    # Grouped GEMMs: `grouped_layout` describes per-group row assignment
    # (contiguous variants); `masked_m` masks rows per group (masked variant).
    def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

    def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    # BF16 GEMMs
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
            compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
            c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)

    # Einsum
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))

    # Attention
    def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
            compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            list(recipe) if recipe else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
            cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta,
            max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)

    # Hyperconnection
    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    # Layout
    def transform_sf_into_required_layout(sf, mn, k, recipe=None,
            recipe_ab=None, num_groups=None, is_sfa=False,
            disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k,
            list(recipe) if recipe else None,
            list(recipe_ab) if recipe_ab else None,
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()

    # Legacy aliases
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except Exception:
    # NOTE(review): this guard hides *any* failure while the wrappers are
    # being defined — presumably missing schemas in older extension builds,
    # but it would also mask typos or schema mismatches. Consider narrowing
    # to the specific exception type the extension raises.
    pass
271
+
272
+ # Utils
273
+ from . import utils
274
+ from .utils import *
275
+
276
+ # Testing
277
+ from . import testing
278
+
279
# Initialize (gracefully skip if CUDA is not available, e.g. in build sandboxes)
try:
    ops.init(
        os.path.dirname(os.path.abspath(__file__)),  # package root (kernels/includes live here)
        _find_cuda_home(),
        _find_cutlass_include()
    )
except Exception:
    # Best-effort by design: importing the package must not fail on machines
    # without CUDA; any real problem resurfaces at the first kernel call.
    pass

__version__ = '2.3.0'
build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b4ca9c42204f1909adcefc61053c7943c105eadb44a447a1ea9a488e01675df
3
+ size 3078080
build/torch210-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _deep_gemm_099ac3c_dirty
3
+ ops = torch.ops._deep_gemm_099ac3c_dirty
4
+
5
def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return "_deep_gemm_099ac3c_dirty::" + op_name
build/torch210-cxx11-cu128-x86_64-linux/deep_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
def _import_from_path(file_path: Path) -> ModuleType:
    """Load a Python file as a module under a unique, path-derived name.

    We cannot use the module name as-is: after adding it to ``sys.modules``
    it would also be used for other imports, so the name is the hex-encoded
    hash of the absolute path, which cannot collide with a regular import.

    Raises:
        ImportError: if a spec or module cannot be created for the file.
    """
    # Fix: `import importlib` alone does not guarantee the `importlib.util`
    # submodule attribute is set; import it explicitly to avoid a latent
    # AttributeError when nothing else has imported it first.
    import importlib.util

    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python-depends": []
3
+ }
build/torch210-cxx11-cu128-x86_64-linux/testing/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import bench, numeric, utils
2
+ from .bench import *
3
+ from .numeric import *
4
+ from .utils import *
build/torch210-cxx11-cu128-x86_64-linux/testing/bench.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+
6
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Average wall time of `fn` in seconds, measured with CUDA events.

    Requires a CUDA device. The L2 cache is flushed once with a 256 MB
    memset, `fn` is warmed up `num_warmups` times, then `num_tests` timed
    calls are averaged.

    Args:
        fn: zero-argument callable launching the work to measure.
        num_warmups: untimed calls made before measurement.
        num_tests: timed calls averaged into the result.
        high_precision: enqueue one large matmul before timing so the GPU is
            still busy while the timed kernels are launched, hiding CPU
            launch overhead.

    Returns:
        Average seconds per call (CUDA `elapsed_time` is in ms, hence / 1e3).
    """
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    cache.zero_()

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Add a large kernel to eliminate the CPU launch overhead
    if high_precision:
        x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        x @ y  # result discarded; only enqueued to keep the GPU busy

    # Testing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for i in range(num_tests):
        fn()
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / num_tests / 1e3
33
+
34
+
35
class empty_suppress:
    """No-op context manager (drop-in stand-in for suppress_stdout_stderr)."""

    def __enter__(self):
        return self

    def __exit__(self, *_):
        return None
41
+
42
+
43
class suppress_stdout_stderr:
    """Context manager silencing both Python-level and fd-level output.

    File descriptors 1/2 are redirected to os.devnull (so prints from C/CUDA
    libraries are muted too) and sys.stdout/sys.stderr are rebound to the
    null files; everything is restored on exit. The statement order below is
    deliberate: fds must be duplicated before they are redirected.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The raw fd numbers currently backing stdout/stderr (normally 1, 2).
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates that keep the original output destinations reachable.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point fds 1/2 at /dev/null — covers native-code output.
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Point the Python-level streams at /dev/null as well.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Restore the original fds, then release the duplicates.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
76
+
77
+
78
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Profile `fn` with the torch (Kineto) profiler; return per-kernel times.

    Args:
        fn: zero-argument callable launching the kernels to profile.
        kernel_names: substring identifying one kernel, or a tuple of them.
        num_tests: invocations of `fn` per profiler cycle.
        suppress_kineto_output: silence profiler console output.
        trace_path: if given, export a Chrome trace to this path.
        flush_l2: memset a large buffer before every call to empty L2.
        with_multiple_kernels: allow one name to match several profiler rows.

    Returns:
        Average seconds per launch — a scalar for a string `kernel_names`, a
        tuple for a tuple. Returns dummy 1s when DG_USE_NVIDIA_TOOLS is set.
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            # Two cycles: the schedule skips the first (wait=1) and records
            # only the second (active=1).
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    # NOTE(review): this scrapes the human-readable table; the last two
    # columns are taken as total time ("12.3ms" / "45.6us") and call count —
    # brittle if the table format ever changes.
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
build/torch210-cxx11-cu128-x86_64-linux/testing/numeric.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from typing import Iterable


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Similarity-based difference: 0 for identical tensors, growing as they diverge."""
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:  # Which means that all elements in x and y are 0
        return 0.0
    return 1 - 2 * (x * y).sum() / denominator


def count_bytes(*tensors):
    """Total storage in bytes; nested tuples/lists flattened, None entries skipped."""
    def _nbytes(item):
        if item is None:
            return 0
        if isinstance(item, (tuple, list)):
            return count_bytes(*item)
        return item.numel() * item.element_size()

    return sum(_nbytes(item) for item in tensors)
build/torch210-cxx11-cu128-x86_64-linux/testing/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import torch
4
+ from typing import Callable
5
+
6
+ def get_arch_major() -> int:
7
+ major, minor = torch.cuda.get_device_capability()
8
+ return major
9
+
10
+
11
+ def test_filter(condition: Callable):
12
+ def decorator(func):
13
+ @functools.wraps(func)
14
+ def wrapper(*args, **kwargs):
15
+ if condition():
16
+ func(*args, **kwargs)
17
+ else:
18
+ print(f'{func.__name__}:')
19
+ print(f' > Filtered by {condition}')
20
+ print()
21
+ return wrapper
22
+ return decorator
23
+
24
+
25
+ def ignore_env(name: str, condition: Callable):
26
+ def decorator(func):
27
+ @functools.wraps(func)
28
+ def wrapper(*args, **kwargs):
29
+ if condition():
30
+ saved = os.environ.pop(name, None)
31
+ func(*args, **kwargs)
32
+ if saved is not None:
33
+ os.environ[name] = saved
34
+ else:
35
+ func(*args, **kwargs)
36
+
37
+ return wrapper
38
+ return decorator
build/torch210-cxx11-cu128-x86_64-linux/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import math, layout
2
+ from .layout import *
3
+ from .math import *
build/torch210-cxx11-cu128-x86_64-linux/utils/layout.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
try:
    from .._ops import ops

    def get_tma_aligned_size(x, element_size):
        # Forwarded to the compiled extension; computes the TMA-aligned size
        # for a dimension of `x` elements of `element_size` bytes.
        return ops.get_tma_aligned_size(x, element_size)

    def get_mn_major_tma_aligned_tensor(sf):
        # Re-lay out a scale-factor tensor so it is MN-major and TMA-aligned.
        return ops.get_mn_major_tma_aligned_tensor(sf)

    def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
        # Same as above, but additionally packs scales into UE8M0 format.
        return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)

    def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
        # K-grouped variant; `ks_tensor`/`ks` describe the per-group K sizes.
        return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
            sf, ks_tensor, ks)
except ImportError:
    # NOTE(review): this guard looks ineffective — the unconditional
    # `from .._ops import ops as _ops` below would re-raise the same
    # ImportError anyway. Confirm whether the guard can be removed or the
    # import below should be guarded too.
    pass

from .._ops import ops as _ops

def get_mk_alignment_for_contiguous_layout():
    # Alignment (in rows/elements) required by the contiguous grouped-GEMM
    # layout; value comes from the compiled extension.
    return _ops.get_mk_alignment_for_contiguous_layout()

# Historical aliases kept for backward compatibility.
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
build/torch210-cxx11-cu128-x86_64-linux/utils/math.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+
5
+ def ceil_div(x: int, y: int) -> int:
6
+ return (x + y - 1) // y
7
+
8
+
9
+ def align(x: int, y: int) -> int:
10
+ return ceil_div(x, y) * y
11
+
12
+
13
+ def ceil_to_ue8m0(x: torch.Tensor):
14
+ assert x.view(-1).amax().item() > 0
15
+ return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
16
+
17
+
18
+ def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
19
+ assert x.dim() == 2
20
+ m, n = x.shape
21
+ padded_n = align(n, gran_k)
22
+ x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
23
+ x_padded[:, :n] = x
24
+ x_view = x_padded.view(m, -1, gran_k)
25
+ x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
26
+ sf = x_amax / 448.0
27
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
28
+ return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
29
+
30
+
31
+ def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
32
+ assert x.dim() == 2 and x.size(0) % gran_k == 0
33
+ m, n = x.shape
34
+ x_view = x.view(-1, gran_k, n)
35
+ x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
36
+ sf = x_amax / 448.0
37
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
38
+ return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
39
+
40
+
41
+ def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
42
+ assert x.dim() == 2
43
+ m, n = x.shape
44
+ x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device)
45
+ x_padded[:m, :n] = x
46
+ x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k)
47
+ x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
48
+ sf = x_amax / 448.0
49
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
50
+ x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
51
+ return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2))
52
+
53
+
54
+ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
55
+ excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
56
+ x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
57
+ sf = x_amax / 448.0
58
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
59
+ x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
60
+ return x_scaled, sf.squeeze()
61
+
62
+
63
+ def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
64
+ ax = x.abs().clamp_max(6.0)
65
+ # {0, 0.5, 1, 1.5, 2, 3, 4, 6}
66
+ # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
67
+ boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
68
+ device=x.device, dtype=ax.dtype)
69
+ idx = torch.bucketize(ax, boundaries)
70
+ code = idx.to(torch.uint8)
71
+ sign = (x < 0) & (idx != 0)
72
+ code = code | (sign.to(torch.uint8) << 3)
73
+ return code # uint8, 0..15
74
+
75
+
76
+ def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
77
+ assert x.dim() == 2
78
+ m, n = x.shape
79
+ assert n % 2 == 0
80
+ padded_n = align(n, gran_k)
81
+ x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
82
+ x_padded[:, :n] = x
83
+ x_view = x_padded.view(m, -1, gran_k)
84
+ x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4)
85
+ sf = x_amax / 6.0
86
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
87
+ x_scaled = x_view * (1.0 / sf.unsqueeze(2))
88
+ codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n)
89
+ codes2 = codes.view(m, padded_n // 2, 2)
90
+ packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8
91
+ return packed[:, :n // 2].contiguous(), sf
92
+
93
+
94
+ def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
95
+ assert a.dtype == torch.uint8
96
+ assert a.dim() == 2
97
+ m, n2 = a.shape
98
+ n = n2 * 2
99
+ assert (m % 2) == 0
100
+ lo = a & 0x0F
101
+ hi = (a >> 4) & 0x0F
102
+ codes = torch.empty((m, n), device=a.device, dtype=torch.uint8)
103
+ codes[:, 0::2], codes[:, 1::2] = lo, hi
104
+ codes_t = codes.transpose(0, 1).contiguous()
105
+ codes2 = codes_t.view(n, m // 2, 2)
106
+ out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)
107
+ return out.contiguous()
build/torch210-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ def _find_cuda_home():
9
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
10
+ if cuda_home is None:
11
+ try:
12
+ with open(os.devnull, 'w') as devnull:
13
+ nvcc = subprocess.check_output(
14
+ ['which', 'nvcc'], stderr=devnull
15
+ ).decode().rstrip('\r\n')
16
+ cuda_home = os.path.dirname(os.path.dirname(nvcc))
17
+ except Exception:
18
+ cuda_home = '/usr/local/cuda'
19
+ if not os.path.exists(cuda_home):
20
+ cuda_home = ''
21
+ return cuda_home or ''
22
+
23
+
24
def _find_cutlass_include():
    """Find CUTLASS include path for JIT compilation of .cuh templates.

    A candidate directory is accepted when it contains a ``cute/``
    subdirectory.  Search order: ``DG_CUTLASS_INCLUDE`` env var,
    ``CUTLASS_HOME`` env var, the bundled package ``include/``,
    ``CUDA_HOME/include``, then the ``nvidia-cutlass`` Python package.
    Returns ``""`` when no candidate is found.
    """
    # 1. Explicit env var
    cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
    if cutlass_include and os.path.isdir(cutlass_include):
        return cutlass_include

    # 2. CUTLASS_HOME env var
    cutlass_home = os.environ.get('CUTLASS_HOME')
    if cutlass_home:
        p = os.path.join(cutlass_home, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p

    # 3. Check in package include/ directory (bundled cute/cutlass headers)
    pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
    if os.path.isdir(os.path.join(pkg_include, 'cute')):
        return pkg_include

    # 4. Check CUDA_HOME/include (some CUDA 12.8+ installs include cute/)
    cuda_home = _find_cuda_home()
    if cuda_home:
        cuda_inc = os.path.join(cuda_home, 'include')
        if os.path.isdir(os.path.join(cuda_inc, 'cute')):
            return cuda_inc

    # 5. Try to find nvidia-cutlass Python package
    try:
        import cutlass as _cutlass
        cutlass_dir = os.path.dirname(_cutlass.__file__)
        p = os.path.join(cutlass_dir, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p
    except ImportError:
        pass

    # Return empty string; C++ side will also check env vars
    return ""
62
+
63
+
64
def set_num_sms(new_num_sms):
    """Set the SM-count knob; forwarded to the compiled extension."""
    ops.set_num_sms(new_num_sms)

def get_num_sms():
    """Return the SM-count knob from the compiled extension."""
    return ops.get_num_sms()

def set_tc_util(new_tc_util):
    """Set the tensor-core utilization knob; forwarded to the compiled extension."""
    ops.set_tc_util(new_tc_util)

def get_tc_util():
    """Return the tensor-core utilization knob from the compiled extension."""
    return ops.get_tc_util()
75
+
76
+
77
# cuBLASLt GEMMs
# Thin forwarders to the compiled cuBLASLt paths.  The nt/nn/tn/tt suffix
# presumably encodes the operand transpose layout — confirm against the C++
# op signatures.  `c` is an optional extra tensor argument (default None).
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)

def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)

def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)

def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)
89
+
90
+
91
# All compiled-op wrappers are defined inside a single `try` so that a
# missing or partially-built extension leaves the module importable (the
# names are then simply absent).  Each FP8/FP4 wrapper unpacks a
# (tensor, scale-factor) pair per operand and normalizes recipe tuples to
# lists before forwarding to the compiled ops.
try:
    # FP8/FP4 GEMMs
    def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    # Backward-compatible aliases (FP8-only names).
    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt

    def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False, use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
            recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
            disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

    def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    # BF16 GEMMs
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
            compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
            compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
            c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)

    # Einsum
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))

    # Attention
    def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
            compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            list(recipe) if recipe else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
            cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta,
            max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)

    # Hyperconnection
    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    # Layout
    def transform_sf_into_required_layout(sf, mn, k, recipe=None,
            recipe_ab=None, num_groups=None, is_sfa=False,
            disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k,
            list(recipe) if recipe else None,
            list(recipe_ab) if recipe_ab else None,
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()

    # Legacy aliases
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except Exception:
    # NOTE(review): a broad except here also hides genuine failures (e.g. a
    # typo raising NameError during definition) — confirm narrowing to the
    # expected registration errors is not possible.
    pass
271
+
272
+ # Utils
273
+ from . import utils
274
+ from .utils import *
275
+
276
+ # Testing
277
+ from . import testing
278
+
279
# Initialize (gracefully skip if CUDA is not available, e.g. in build sandboxes)
try:
    ops.init(
        os.path.dirname(os.path.abspath(__file__)),  # package root
        _find_cuda_home(),
        _find_cutlass_include()
    )
except Exception:
    # NOTE(review): broad except — all init failures, not only "no CUDA",
    # are silently swallowed; confirm this is intended.
    pass
288
+
289
+ __version__ = '2.3.0'
build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8307e5e24ea3f68435a8251df19977bfd2323e60f761b4c3cd7c5ba7aada4c3f
3
+ size 3078072
build/torch210-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _deep_gemm_099ac3c_dirty
3
+ ops = torch.ops._deep_gemm_099ac3c_dirty
4
+
5
def add_op_namespace_prefix(op_name: str) -> str:
    """Qualify *op_name* with this extension's torch op namespace."""
    return "_deep_gemm_099ac3c_dirty::" + op_name
build/torch210-cxx11-cu130-x86_64-linux/deep_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python-depends": []
3
+ }
build/torch210-cxx11-cu130-x86_64-linux/testing/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import bench, numeric, utils
2
+ from .bench import *
3
+ from .numeric import *
4
+ from .utils import *
build/torch210-cxx11-cu130-x86_64-linux/testing/bench.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+
6
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Time ``fn`` on the GPU with CUDA events.

    Returns the mean wall time per call in seconds.  The L2 cache is flushed
    once before warmup so the timed calls do not benefit from earlier runs.
    With ``high_precision`` a large matmul is launched right before timing
    starts so the event recording is not skewed by CPU launch overhead.
    """
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    cache.zero_()

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Add a large kernel to eliminate the CPU launch overhead
    if high_precision:
        x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        x @ y

    # Testing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for i in range(num_tests):
        fn()
    end_event.record()
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; convert to seconds per call.
    return start_event.elapsed_time(end_event) / num_tests / 1e3
33
+
34
+
35
class empty_suppress:
    """Do-nothing context manager, used when output suppression is disabled."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None (falsy) lets any exception propagate.
        return None
41
+
42
+
43
class suppress_stdout_stderr:
    """Silence both Python-level and OS-level stdout/stderr.

    The file descriptors themselves are redirected to /dev/null, which also
    captures output written by C/C++ code (e.g. kineto) that bypasses
    ``sys.stdout``/``sys.stderr``.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The live fd numbers (normally 1 and 2).
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates that keep the real streams reachable for restoration.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the live fds at /dev/null (captures C-level writes) ...
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # ... and redirect the Python-level stream objects too.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Restore the original fds, then release the duplicates.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
76
+
77
+
78
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Measure per-kernel GPU time for ``fn`` via the kineto profiler.

    ``kernel_names`` is a substring (or tuple of substrings) matched against
    rows of the profiler summary table.  Returns the average time per kernel
    launch in seconds — a single value, or a tuple when ``kernel_names`` is
    a tuple.  Under external NVIDIA tools (``DG_USE_NVIDIA_TOOLS``) profiling
    is skipped and dummy value(s) of 1 are returned.
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            # Two scheduler steps: step 0 is the `wait` phase, step 1 is the
            # one actually recorded (`active=1`).
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        # Each name must match at most one table row, otherwise the
        # per-kernel attribution below would be ambiguous.
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                # Last two whitespace-separated columns of a table row are
                # the average time (with unit suffix) and the call count.
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
build/torch210-cxx11-cu130-x86_64-linux/testing/numeric.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Iterable
3
+
4
+
5
def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return a cosine-similarity-based difference between two tensors.

    0.0 means identical direction and magnitude; 2.0 means ``y == -x``.
    Computed in float64 for numerical stability.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:  # Which means that all elements in x and y are 0
        return 0.0
    sim = 2 * (x * y).sum() / denominator
    # `.item()` so both branches consistently return a Python float (the
    # original returned a 0-dim tensor here but 0.0 above).
    return (1 - sim).item()
12
+
13
+
14
def count_bytes(*tensors):
    """Total storage size in bytes of all given tensors.

    Tuples/lists are recursed into; ``None`` entries contribute nothing.
    """
    def _nbytes(item):
        # One entry: container -> recurse, None -> 0, tensor -> its bytes.
        if isinstance(item, (tuple, list)):
            return count_bytes(*item)
        if item is None:
            return 0
        return item.numel() * item.element_size()

    return sum(_nbytes(item) for item in tensors)
build/torch210-cxx11-cu130-x86_64-linux/testing/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import torch
4
+ from typing import Callable
5
+
6
def get_arch_major() -> int:
    """Major CUDA compute capability of the current device."""
    return torch.cuda.get_device_capability()[0]
9
+
10
+
11
+ def test_filter(condition: Callable):
12
+ def decorator(func):
13
+ @functools.wraps(func)
14
+ def wrapper(*args, **kwargs):
15
+ if condition():
16
+ func(*args, **kwargs)
17
+ else:
18
+ print(f'{func.__name__}:')
19
+ print(f' > Filtered by {condition}')
20
+ print()
21
+ return wrapper
22
+ return decorator
23
+
24
+
25
def ignore_env(name: str, condition: Callable):
    """Decorator factory: when ``condition()`` is truthy, temporarily remove
    the environment variable ``name`` while the wrapped function runs.

    The variable is restored via ``finally`` so it survives exceptions from
    the wrapped function (the original implementation leaked the removal
    when the function raised).
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if condition():
                saved = os.environ.pop(name, None)
                try:
                    func(*args, **kwargs)
                finally:
                    if saved is not None:
                        os.environ[name] = saved
            else:
                func(*args, **kwargs)

        return wrapper
    return decorator
build/torch210-cxx11-cu130-x86_64-linux/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import math, layout
2
+ from .layout import *
3
+ from .math import *
build/torch210-cxx11-cu130-x86_64-linux/utils/layout.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from .._ops import ops
3
+
4
+ def get_tma_aligned_size(x, element_size):
5
+ return ops.get_tma_aligned_size(x, element_size)
6
+
7
+ def get_mn_major_tma_aligned_tensor(sf):
8
+ return ops.get_mn_major_tma_aligned_tensor(sf)
9
+
10
+ def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
11
+ return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
12
+
13
+ def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
14
+ return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
15
+ sf, ks_tensor, ks)
16
+ except ImportError:
17
+ pass
18
+
19
+ from .._ops import ops as _ops
20
+
21
+ def get_mk_alignment_for_contiguous_layout():
22
+ return _ops.get_mk_alignment_for_contiguous_layout()
23
+
24
+ get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
25
+ get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
build/torch210-cxx11-cu130-x86_64-linux/utils/math.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+
5
def ceil_div(x: int, y: int) -> int:
    """Ceiling of ``x / y``; intended for positive ``y`` (relies on floor division)."""
    return (x + y - 1) // y
7
+
8
+
9
def align(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``."""
    return ceil_div(x, y) * y
11
+
12
+
13
def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    """Round each scale up to the nearest power of two (UE8M0 format).

    NOTE(review): the assert only checks the *global* max is positive and
    forces a host sync via `.item()`; zero entries would still yield
    `log2(0) = -inf` — confirm callers always pass strictly positive scales
    (the cast helpers clamp to 1e-4 before calling this).
    """
    assert x.view(-1).amax().item() > 0
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
16
+
17
+
18
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` row-wise to FP8 (e4m3) with one scale per
    ``gran_k``-wide group of each row.

    Returns ``(x_fp8, sf)``: ``x_fp8`` has the same shape as ``x`` and
    ``sf`` has shape ``(m, ceil(n / gran_k))``.  With ``use_ue8m0`` the
    scales are rounded up to powers of two.
    """
    assert x.dim() == 2
    m, n = x.shape
    padded_n = align(n, gran_k)
    # `zeros` instead of `empty().fill_(0)`: same result, clearer intent.
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, -1, gran_k)
    # Per-group max magnitude; `clamp_min` (consistent with the fp4 helper)
    # keeps the scale away from zero.  `amax(dim=2)` already yields (m, g),
    # so the former redundant `.view(m, -1)` is dropped.
    x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4)
    sf = x_amax / 448.0  # 448 is the largest finite e4m3 value
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn)
    return x_fp8.view(m, padded_n)[:, :n].contiguous(), sf
29
+
30
+
31
+ def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
32
+ assert x.dim() == 2 and x.size(0) % gran_k == 0
33
+ m, n = x.shape
34
+ x_view = x.view(-1, gran_k, n)
35
+ x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
36
+ sf = x_amax / 448.0
37
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
38
+ return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
39
+
40
+
41
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to FP8 (e4m3) with one scale per ``gran_k`` x ``gran_k`` block.

    Returns ``(x_fp8, sf)``: ``x_fp8`` has the shape of ``x`` and ``sf`` has
    shape ``(ceil(m / gran_k), ceil(n / gran_k))``.  With ``use_ue8m0`` the
    scales are rounded up to powers of two.
    """
    assert x.dim() == 2
    m, n = x.shape
    # Pad both dimensions up to multiples of the block size.
    x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    # (row blocks, gran_k, col blocks, gran_k)
    x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k)
    # Per-block max magnitude; the clamp keeps scales away from zero.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0  # 448 is the largest finite e4m3 value
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2))
52
+
53
+
54
+ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
55
+ excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
56
+ x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
57
+ sf = x_amax / 448.0
58
+ sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
59
+ x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
60
+ return x_scaled, sf.squeeze()
61
+
62
+
63
def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
    """Round each element to the nearest FP4 (e2m1) value and return its
    4-bit code (0..15) as a ``uint8`` tensor.

    Representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}; bit 3 of the
    code is the sign.  Ties at the midpoints round toward the smaller
    magnitude.
    """
    ax = x.abs().clamp_max(6.0)
    # {0, 0.5, 1, 1.5, 2, 3, 4, 6}
    # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
    boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
                              device=x.device, dtype=ax.dtype)
    idx = torch.bucketize(ax, boundaries)
    code = idx.to(torch.uint8)
    # Sign bit only for non-zero codes: negative zero is never emitted.
    sign = (x < 0) & (idx != 0)
    code = code | (sign.to(torch.uint8) << 3)
    return code  # uint8, 0..15
74
+
75
+
76
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` row-wise to packed FP4 (e2m1).

    One float scale is computed per ``gran_k``-wide group of each row.
    Returns ``(packed, sf)``: ``packed`` is ``uint8`` of shape ``(m, n // 2)``
    with two 4-bit codes per byte (low nibble = even column); ``sf`` has
    shape ``(m, padded_n // gran_k)``.
    """
    assert x.dim() == 2
    m, n = x.shape
    assert n % 2 == 0
    # Pad columns up to a multiple of the scaling granularity.
    padded_n = align(n, gran_k)
    padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    padded[:, :n] = x
    grouped = padded.view(m, -1, gran_k)
    # Per-group max magnitude, kept away from zero.
    amax = grouped.abs().float().amax(dim=2).clamp_min(1e-4)
    sf = amax / 6.0  # 6 is the largest e2m1 magnitude
    if use_ue8m0:
        sf = ceil_to_ue8m0(sf)
    scaled = grouped * (1.0 / sf.unsqueeze(2))
    codes = _quantize_to_fp4_e2m1(scaled).view(m, padded_n)
    # Pack adjacent column pairs into single bytes, low nibble first.
    pairs = codes.view(m, padded_n // 2, 2)
    packed = (pairs[:, :, 0] & 0x0F) | ((pairs[:, :, 1] & 0x0F) << 4)
    return packed[:, :n // 2].contiguous(), sf
92
+
93
+
94
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
    """Transpose a matrix of packed FP4 codes.

    ``a`` is ``(m, n // 2)`` uint8 with two 4-bit codes per byte (low nibble
    = even column).  Returns the packed transpose, shape ``(n, m // 2)``.
    """
    assert a.dtype == torch.uint8
    assert a.dim() == 2
    m, n2 = a.shape
    n = n2 * 2
    assert (m % 2) == 0
    # Unpack one nibble per byte: low nibbles to even columns.
    lo = a & 0x0F
    hi = (a >> 4) & 0x0F
    codes = torch.empty((m, n), device=a.device, dtype=torch.uint8)
    codes[:, 0::2], codes[:, 1::2] = lo, hi
    # Transpose, then re-pack adjacent pairs along the new row dimension.
    codes_t = codes.transpose(0, 1).contiguous()
    codes2 = codes_t.view(n, m // 2, 2)
    out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)
    return out.contiguous()
build/torch29-cxx11-cu126-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ def _find_cuda_home():
9
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
10
+ if cuda_home is None:
11
+ try:
12
+ with open(os.devnull, 'w') as devnull:
13
+ nvcc = subprocess.check_output(
14
+ ['which', 'nvcc'], stderr=devnull
15
+ ).decode().rstrip('\r\n')
16
+ cuda_home = os.path.dirname(os.path.dirname(nvcc))
17
+ except Exception:
18
+ cuda_home = '/usr/local/cuda'
19
+ if not os.path.exists(cuda_home):
20
+ cuda_home = ''
21
+ return cuda_home or ''
22
+
23
+
24
def _find_cutlass_include():
    """Find CUTLASS include path for JIT compilation of .cuh templates.

    A candidate directory is accepted when it contains a ``cute/``
    subdirectory.  Search order: ``DG_CUTLASS_INCLUDE`` env var,
    ``CUTLASS_HOME`` env var, the bundled package ``include/``,
    ``CUDA_HOME/include``, then the ``nvidia-cutlass`` Python package.
    Returns ``""`` when no candidate is found.
    """
    # 1. Explicit env var
    cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
    if cutlass_include and os.path.isdir(cutlass_include):
        return cutlass_include

    # 2. CUTLASS_HOME env var
    cutlass_home = os.environ.get('CUTLASS_HOME')
    if cutlass_home:
        p = os.path.join(cutlass_home, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p

    # 3. Check in package include/ directory (bundled cute/cutlass headers)
    pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
    if os.path.isdir(os.path.join(pkg_include, 'cute')):
        return pkg_include

    # 4. Check CUDA_HOME/include (some CUDA 12.8+ installs include cute/)
    cuda_home = _find_cuda_home()
    if cuda_home:
        cuda_inc = os.path.join(cuda_home, 'include')
        if os.path.isdir(os.path.join(cuda_inc, 'cute')):
            return cuda_inc

    # 5. Try to find nvidia-cutlass Python package
    try:
        import cutlass as _cutlass
        cutlass_dir = os.path.dirname(_cutlass.__file__)
        p = os.path.join(cutlass_dir, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p
    except ImportError:
        pass

    # Return empty string; C++ side will also check env vars
    return ""
62
+
63
+
64
def set_num_sms(new_num_sms):
    """Set the SM-count knob; forwarded to the compiled extension."""
    ops.set_num_sms(new_num_sms)

def get_num_sms():
    """Return the SM-count knob from the compiled extension."""
    return ops.get_num_sms()

def set_tc_util(new_tc_util):
    """Set the tensor-core utilization knob; forwarded to the compiled extension."""
    ops.set_tc_util(new_tc_util)

def get_tc_util():
    """Return the tensor-core utilization knob from the compiled extension."""
    return ops.get_tc_util()


# cuBLASLt GEMMs
# Thin forwarders to the compiled cuBLASLt paths.  The nt/nn/tn/tt suffix
# presumably encodes the operand transpose layout — confirm against the C++
# op signatures.  `c` is an optional extra tensor argument (default None).
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)

def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)

def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)

def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)
89
+
90
+
91
+ try:
92
+ # FP8/FP4 GEMMs
93
+ def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
94
+ recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
95
+ ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
96
+ list(recipe) if recipe else None,
97
+ list(recipe_a) if recipe_a else None,
98
+ list(recipe_b) if recipe_b else None,
99
+ compiled_dims, disable_ue8m0_cast)
100
+
101
+ def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
102
+ recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
103
+ ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
104
+ list(recipe) if recipe else None,
105
+ list(recipe_a) if recipe_a else None,
106
+ list(recipe_b) if recipe_b else None,
107
+ compiled_dims, disable_ue8m0_cast)
108
+
109
+ def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
110
+ recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
111
+ ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
112
+ list(recipe) if recipe else None,
113
+ list(recipe_a) if recipe_a else None,
114
+ list(recipe_b) if recipe_b else None,
115
+ compiled_dims, disable_ue8m0_cast)
116
+
117
+ def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
118
+ recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
119
+ ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
120
+ list(recipe) if recipe else None,
121
+ list(recipe_a) if recipe_a else None,
122
+ list(recipe_b) if recipe_b else None,
123
+ compiled_dims, disable_ue8m0_cast)
124
+
125
+ fp8_gemm_nt = fp8_fp4_gemm_nt
126
+ fp8_gemm_nn = fp8_fp4_gemm_nn
127
+ fp8_gemm_tn = fp8_fp4_gemm_tn
128
+ fp8_gemm_tt = fp8_fp4_gemm_tt
129
+
130
+ def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
131
+ recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
132
+ disable_ue8m0_cast=False, use_psum_layout=False,
133
+ expected_m_for_psum_layout=None):
134
+ ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
135
+ a[0], a[1], b[0], b[1], d, grouped_layout,
136
+ list(recipe) if recipe else None,
137
+ list(recipe_a) if recipe_a else None,
138
+ list(recipe_b) if recipe_b else None,
139
+ compiled_dims, disable_ue8m0_cast, use_psum_layout,
140
+ expected_m_for_psum_layout)
141
+
142
+ m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous
143
+
144
def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
        recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
        disable_ue8m0_cast=False, use_psum_layout=False):
    """M-grouped FP8/FP4 NN GEMM over contiguously packed groups.

    Same calling convention as the NT variant, minus
    ``expected_m_for_psum_layout``.  ``a``/``b`` are (values, scales)
    pairs; falsy recipes are normalized to None.
    """
    ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
        a[0], a[1], b[0], b[1], d, grouped_layout,
        list(recipe) if recipe else None,
        list(recipe_a) if recipe_a else None,
        list(recipe_b) if recipe_b else None,
        compiled_dims, disable_ue8m0_cast, use_psum_layout)

# FP8-only alias kept for backward compatibility.
m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous
155
+
156
def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
        recipe=None, recipe_a=None, recipe_b=None, compiled_dims="nk",
        disable_ue8m0_cast=False):
    """M-grouped FP8/FP4 NT GEMM with per-group masking.

    ``masked_m`` holds per-group valid row counts and ``expected_m`` a
    scheduling hint — presumably; exact semantics are defined by the
    compiled op.  ``a``/``b`` are (values, scales) pairs; falsy recipes
    are normalized to None.
    """
    ops.m_grouped_fp8_fp4_gemm_nt_masked(
        a[0], a[1], b[0], b[1], d, masked_m, expected_m,
        list(recipe) if recipe else None,
        list(recipe_a) if recipe_a else None,
        list(recipe_b) if recipe_b else None,
        compiled_dims, disable_ue8m0_cast)

# FP8-only alias kept for backward compatibility.
m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
167
+
168
def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
        recipe=(1, 1, 128), compiled_dims="mn"):
    """K-grouped FP8 NT GEMM over contiguously packed K segments.

    ``ks`` and ``ks_tensor`` carry the per-group K sizes (host list and
    device tensor respectively — presumed; confirm against the op).
    Unlike the M-grouped wrappers, ``recipe`` is always converted with
    ``list()`` and has a non-None default, so passing None would raise.
    """
    ops.k_grouped_fp8_gemm_nt_contiguous(
        a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
        list(recipe), compiled_dims)

def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
        recipe=(1, 1, 128), compiled_dims="mn"):
    """K-grouped FP8 TN GEMM; same calling convention as the NT variant."""
    ops.k_grouped_fp8_gemm_tn_contiguous(
        a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
        list(recipe), compiled_dims)
179
+
180
# BF16 GEMMs
# Plain pass-throughs: BF16 inputs carry no scale factors, so tensors are
# forwarded as-is (no (values, scales) pair unpacking).  Note NT/NN default
# compiled_dims="nk" while TN/TT default "mn", mirroring the FP8 wrappers.
def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
    """NT-layout BF16 GEMM; writes the result into ``d``."""
    ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
    """NN-layout BF16 GEMM; writes the result into ``d``."""
    ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
    """TN-layout BF16 GEMM; writes the result into ``d``."""
    ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
    """TT-layout BF16 GEMM; writes the result into ``d``."""
    ops.bf16_gemm_tt(a, b, d, c, compiled_dims)
192
+
193
def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
        compiled_dims="nk", use_psum_layout=False,
        expected_m_for_psum_layout=None):
    """M-grouped BF16 NT GEMM over contiguously packed groups.

    BF16 counterpart of the FP8/FP4 grouped wrapper — no scale-factor
    pairs, tensors are forwarded directly.
    """
    ops.m_grouped_bf16_gemm_nt_contiguous(
        a, b, d, grouped_layout, compiled_dims,
        use_psum_layout, expected_m_for_psum_layout)

def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
        compiled_dims="nk", use_psum_layout=False):
    """M-grouped BF16 NN GEMM over contiguously packed groups."""
    ops.m_grouped_bf16_gemm_nn_contiguous(
        a, b, d, grouped_layout, compiled_dims, use_psum_layout)

def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
        compiled_dims="nk"):
    """M-grouped BF16 NT GEMM with per-group masking (see FP8 variant)."""
    ops.m_grouped_bf16_gemm_nt_masked(
        a, b, d, masked_m, expected_m, compiled_dims)

def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
        c=None, compiled_dims="mn"):
    """K-grouped BF16 TN GEMM over contiguously packed K segments."""
    ops.k_grouped_bf16_gemm_tn_contiguous(
        a, b, d, ks, ks_tensor, c, compiled_dims)
214
+
215
# Einsum
def einsum(expr, a, b, d, c=None, use_cublaslt=False):
    """Two-operand einsum evaluated by the compiled kernel; result is
    written into ``d``.  ``use_cublaslt`` selects a cuBLASLt backend —
    presumably; confirm against the C++ implementation."""
    ops.einsum(expr, a, b, d, c, use_cublaslt)

def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
    """FP8 einsum; ``a``/``b`` are (values, scales) pairs and ``recipe``
    is always list-converted (None is not accepted here)."""
    ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))
221
+
222
# Attention
def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
        compiled_dims="nk", disable_ue8m0_cast=False):
    """FP8 NT GEMM that skips the middle section of attention heads;
    ``head_splits`` is list-converted before the call.  No return value —
    the result goes into ``d``."""
    ops.fp8_gemm_nt_skip_head_mid(
        a[0], a[1], b[0], b[1], d, list(head_splits),
        list(recipe) if recipe else None,
        compiled_dims, disable_ue8m0_cast)

def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
        cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
    """Compute MQA attention logits from FP8 inputs; ``kv`` is a
    (values, scales) pair.  Returns the logits tensor produced by the
    compiled op."""
    return ops.fp8_mqa_logits(
        q, kv[0], kv[1], weights,
        cu_seq_len_k_start, cu_seq_len_k_end,
        clean_logits, max_seqlen_k)

def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
    """Build the scheduling metadata consumed by fp8_paged_mqa_logits."""
    return ops.get_paged_mqa_logits_metadata(
        context_lens, block_kv, num_sms)

def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
        block_table, schedule_meta,
        max_context_len, clean_logits=False):
    """Paged-KV-cache variant of fp8_mqa_logits; ``schedule_meta`` comes
    from get_paged_mqa_logits_metadata.  Note ``clean_logits`` defaults
    to False here, unlike fp8_mqa_logits."""
    return ops.fp8_paged_mqa_logits(
        q, fused_kv_cache, weights, context_lens,
        block_table, schedule_meta, max_context_len, clean_logits)
247
+
248
# Hyperconnection
def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
    """TF32 hyperconnection pre-norm GEMM; also accumulates the squared
    sum into ``sqr_sum`` — presumed from the name; exact semantics are
    defined by the compiled op."""
    ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)
251
+
252
# Layout
def transform_sf_into_required_layout(sf, mn, k, recipe=None,
        recipe_ab=None, num_groups=None, is_sfa=False,
        disable_ue8m0_cast=False):
    """Rearrange a scale-factor tensor ``sf`` into the layout the GEMM
    kernels require for an (mn, k) problem.  ``is_sfa`` distinguishes
    A-side from B-side scale factors — presumed from the name.  Falsy
    recipes are normalized to None."""
    return ops.transform_sf_into_required_layout(
        sf, mn, k,
        list(recipe) if recipe else None,
        list(recipe_ab) if recipe_ab else None,
        num_groups, is_sfa, disable_ue8m0_cast)

def get_mk_alignment_for_contiguous_layout():
    """Alignment requirement for the grouped-contiguous layouts, as
    reported by the compiled extension."""
    return ops.get_mk_alignment_for_contiguous_layout()
264
+
265
# Legacy aliases
# Older call sites used the dtype-first naming scheme; keep both spellings.
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
268
+
269
+ except Exception:
270
+ pass
271
+
272
+ # Utils
273
+ from . import utils
274
+ from .utils import *
275
+
276
+ # Testing
277
+ from . import testing
278
+
279
# Initialize (gracefully skip if CUDA is not available, e.g. in build sandboxes)
try:
    ops.init(
        os.path.dirname(os.path.abspath(__file__)),
        _find_cuda_home(),
        _find_cutlass_include()
    )
# NOTE(review): this broad `except Exception` is deliberate best-effort
# init, but it also hides genuine configuration errors (bad CUDA_HOME,
# missing CUTLASS headers) — consider logging the exception.
except Exception:
    pass

__version__ = '2.3.0'
build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_099ac3c_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9ad7e5f8bcd1642692d50e321db2ee6a668bdc448fa481490e307e2dfb0ffe
3
+ size 2967864
build/torch29-cxx11-cu126-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _deep_gemm_099ac3c_dirty
3
+ ops = torch.ops._deep_gemm_099ac3c_dirty
4
+
5
def add_op_namespace_prefix(op_name: str) -> str:
    """Qualify ``op_name`` with this extension's torch op namespace."""
    return "_deep_gemm_099ac3c_dirty::" + op_name
build/torch29-cxx11-cu126-x86_64-linux/deep_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch29-cxx11-cu126-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python-depends": []
3
+ }
build/torch29-cxx11-cu126-x86_64-linux/testing/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import bench, numeric, utils
2
+ from .bench import *
3
+ from .numeric import *
4
+ from .utils import *
build/torch29-cxx11-cu126-x86_64-linux/testing/bench.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+
6
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Time ``fn`` on the GPU with CUDA events.

    Returns the mean wall time per call in seconds.  ``high_precision``
    enqueues a large matmul first so the timed launches queue behind it,
    hiding CPU launch overhead from the measurement.
    """
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    l2_flush = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    l2_flush.zero_()

    # Untimed warmup runs.
    for _ in range(num_warmups):
        fn()

    # Add a large kernel to eliminate the CPU launch overhead
    if high_precision:
        lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        lhs @ rhs

    # Timed runs, bracketed by CUDA events.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_tests):
        fn()
    end_event.record()
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; convert to seconds per call.
    return start_event.elapsed_time(end_event) / num_tests / 1e3
33
+
34
+
35
class empty_suppress:
    """No-op context manager, used where output suppression is optional."""

    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass
41
+
42
+
43
class suppress_stdout_stderr:
    """Silence both Python-level and C-level stdout/stderr.

    Redirects the underlying OS file descriptors (not just the Python
    stream objects) to os.devnull, so output printed from C/C++ code
    (e.g. the kineto profiler) is suppressed as well.  Restores
    everything on exit.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # The fd numbers currently backing stdout/stderr (usually 1 and 2).
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicate the original fds so they can be restored in __exit__.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the original fd numbers at /dev/null — this is what
        # silences writes coming from native code.
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Also swap the Python-level stream objects.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Restore the original fds, then release the duplicates.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
76
+
77
+
78
def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Measure per-kernel GPU time for ``fn`` via the kineto profiler.

    ``kernel_names`` is a substring (or tuple of substrings) matched
    against profiler table rows.  Returns the average time in seconds
    per matched kernel — a scalar for a string input, a tuple for a
    tuple input.  When DG_USE_NVIDIA_TOOLS is set, profiling is skipped
    and dummy 1s are returned (kineto conflicts with NVIDIA tooling).
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        # wait=1/active=1: the first outer iteration is discarded, only
        # the second is recorded.
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        # Each name must match at most one row, otherwise the sum below
        # would silently blend different kernels.
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                # The table's last two columns are avg time and call count.
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
build/torch29-cxx11-cu126-x86_64-linux/testing/numeric.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Iterable
3
+
4
+
5
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Relative difference based on cosine-style similarity.

    Returns 0 when the tensors are identical and approaches 1 as they
    decorrelate.  Computed in float64 for numerical stability; two
    all-zero tensors are treated as identical (0.0).
    """
    a, b = x.double(), y.double()
    denominator = (a * a + b * b).sum()
    # Zero denominator means every element of both inputs is zero.
    if denominator == 0:
        return 0.0
    similarity = 2 * (a * b).sum() / denominator
    return 1 - similarity
12
+
13
+
14
def count_bytes(*tensors):
    """Total storage bytes across the given tensors.

    Recurses into nested lists/tuples; ``None`` entries are skipped.
    """
    total = 0
    for item in tensors:
        if item is None:
            continue
        if isinstance(item, (tuple, list)):
            total += count_bytes(*item)
        else:
            total += item.numel() * item.element_size()
    return total
build/torch29-cxx11-cu126-x86_64-linux/testing/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import torch
4
+ from typing import Callable
5
+
6
def get_arch_major() -> int:
    """Return the CUDA compute-capability major version of the current device."""
    return torch.cuda.get_device_capability()[0]
9
+
10
+
11
+ def test_filter(condition: Callable):
12
+ def decorator(func):
13
+ @functools.wraps(func)
14
+ def wrapper(*args, **kwargs):
15
+ if condition():
16
+ func(*args, **kwargs)
17
+ else:
18
+ print(f'{func.__name__}:')
19
+ print(f' > Filtered by {condition}')
20
+ print()
21
+ return wrapper
22
+ return decorator
23
+
24
+
25
def ignore_env(name: str, condition: Callable):
    """Decorator factory: when ``condition()`` is true, run the wrapped
    function with environment variable ``name`` temporarily removed.

    Fixes over the original:
      * the variable is restored via ``try/finally``, so it is no longer
        leaked (left removed) when the wrapped function raises;
      * the wrapped function's return value is propagated.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not condition():
                return func(*args, **kwargs)
            saved = os.environ.pop(name, None)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore only if the variable was actually set before.
                if saved is not None:
                    os.environ[name] = saved
        return wrapper
    return decorator
build/torch29-cxx11-cu126-x86_64-linux/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import math, layout
2
+ from .layout import *
3
+ from .math import *
build/torch29-cxx11-cu126-x86_64-linux/utils/layout.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Guard: in partial/broken builds the compiled extension may be absent;
# in that case the TMA-layout helpers are simply not defined.
try:
    from .._ops import ops

    def get_tma_aligned_size(x, element_size):
        # Pass-through to the compiled helper; returns the TMA-aligned size.
        return ops.get_tma_aligned_size(x, element_size)

    def get_mn_major_tma_aligned_tensor(sf):
        return ops.get_mn_major_tma_aligned_tensor(sf)

    def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
        return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)

    def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
        return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
            sf, ks_tensor, ks)
except ImportError:
    pass

# NOTE(review): this second import is NOT guarded — if `.._ops` is missing,
# the try/except above is defeated and importing this module still fails
# right here.  Consider folding it into the guarded block (or dropping the
# guard entirely if the extension is in fact mandatory).
from .._ops import ops as _ops

def get_mk_alignment_for_contiguous_layout():
    # Alignment requirement for the grouped-contiguous layouts.
    return _ops.get_mk_alignment_for_contiguous_layout()

# Historical names kept for callers that predate the unified helper.
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout