koichi12 commited on
Commit
80a73eb
·
verified ·
1 Parent(s): b964460

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/nvidia/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py +0 -0
  4. .venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/nvidia/cublas/include/__init__.py +0 -0
  6. .venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas.h +891 -0
  8. .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h +1845 -0
  9. .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasXt.h +693 -0
  10. .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h +0 -0
  11. .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h +478 -0
  12. .venv/lib/python3.11/site-packages/nvidia/cublas/include/nvblas.h +824 -0
  13. .venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py +0 -0
  14. .venv/lib/python3.11/site-packages/nvidia/cublas/lib/__pycache__/__init__.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 +3 -0
  16. .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__init__.py +0 -0
  17. .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__init__.py +0 -0
  19. .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/nvrtc.h +869 -0
  21. .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py +0 -0
  22. .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/async.h +452 -0
  25. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/coalesced_scan.h +174 -0
  26. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/driver_abi.h +99 -0
  27. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/info.h +344 -0
  28. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/invoke.h +189 -0
  29. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/memory.h +135 -0
  30. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/partitioning.h +159 -0
  31. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/reduce.h +419 -0
  32. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/scan.h +320 -0
  33. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/sync.h +282 -0
  34. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__init__.py +0 -0
  35. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__pycache__/__init__.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 +3 -0
  37. .venv/lib/python3.11/site-packages/nvidia/cudnn/__init__.py +0 -0
  38. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn.h +68 -0
  39. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_v9.h +671 -0
  40. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend_v9.h +60 -0
  41. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph.h +909 -0
  42. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops.h +1316 -0
  43. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version_v9.h +70 -0
  44. .venv/lib/python3.11/site-packages/nvidia/cusolver/__init__.py +0 -0
  45. .venv/lib/python3.11/site-packages/nvidia/cusolver/__pycache__/__init__.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/nvidia/cusolver/include/__init__.py +0 -0
  47. .venv/lib/python3.11/site-packages/nvidia/cusolver/include/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverDn.h +0 -0
  49. .venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverMg.h +318 -0
  50. .venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverRf.h +339 -0
.gitattributes CHANGED
@@ -120,3 +120,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
120
  .venv/lib/python3.11/site-packages/click/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
121
  .venv/lib/python3.11/site-packages/pyasn1/type/__pycache__/univ.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
122
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
 
 
 
120
  .venv/lib/python3.11/site-packages/click/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
121
  .venv/lib/python3.11/site-packages/pyasn1/type/__pycache__/univ.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
122
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
123
+ .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
124
+ .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/nvidia/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (179 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (186 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cublas/include/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas.h ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * This is the public header file for the CUBLAS library, defining the API
52
+ *
53
+ * CUBLAS is an implementation of BLAS (Basic Linear Algebra Subroutines)
54
+ * on top of the CUDA runtime.
55
+ */
56
+
57
+ #if !defined(CUBLAS_H_)
58
+ #define CUBLAS_H_
59
+
60
+ #if defined(CUBLAS_V2_H_)
61
+ #error "It is an error to include both cublas.h and cublas_v2.h"
62
+ #endif
63
+
64
+ #include <cuda_runtime.h>
65
+
66
+ #ifndef CUBLASWINAPI
67
+ #ifdef _WIN32
68
+ #define CUBLASWINAPI __stdcall
69
+ #else
70
+ #define CUBLASWINAPI
71
+ #endif
72
+ #endif
73
+
74
+ #undef CUBLASAPI
75
+ #ifdef __CUDACC__
76
+ #define CUBLASAPI __host__
77
+ #else
78
+ #define CUBLASAPI
79
+ #endif
80
+
81
+ #include "cublas_api.h"
82
+
83
+ #if defined(__cplusplus)
84
+ extern "C" {
85
+ #endif
86
+
87
+ /* CUBLAS data types */
88
+ #define cublasStatus cublasStatus_t
89
+
90
+ cublasStatus CUBLASWINAPI cublasInit(void);
91
+ cublasStatus CUBLASWINAPI cublasShutdown(void);
92
+ cublasStatus CUBLASWINAPI cublasGetError(void);
93
+
94
+ cublasStatus CUBLASWINAPI cublasGetVersion(int* version);
95
+ cublasStatus CUBLASWINAPI cublasAlloc(int n, int elemSize, void** devicePtr);
96
+
97
+ cublasStatus CUBLASWINAPI cublasFree(void* devicePtr);
98
+
99
+ cublasStatus CUBLASWINAPI cublasSetKernelStream(cudaStream_t stream);
100
+
101
+ /* ---------------- CUBLAS BLAS1 functions ---------------- */
102
+ /* NRM2 */
103
+ float CUBLASWINAPI cublasSnrm2(int n, const float* x, int incx);
104
+ double CUBLASWINAPI cublasDnrm2(int n, const double* x, int incx);
105
+ float CUBLASWINAPI cublasScnrm2(int n, const cuComplex* x, int incx);
106
+ double CUBLASWINAPI cublasDznrm2(int n, const cuDoubleComplex* x, int incx);
107
+ /*------------------------------------------------------------------------*/
108
+ /* DOT */
109
+ float CUBLASWINAPI cublasSdot(int n, const float* x, int incx, const float* y, int incy);
110
+ double CUBLASWINAPI cublasDdot(int n, const double* x, int incx, const double* y, int incy);
111
+ cuComplex CUBLASWINAPI cublasCdotu(int n, const cuComplex* x, int incx, const cuComplex* y, int incy);
112
+ cuComplex CUBLASWINAPI cublasCdotc(int n, const cuComplex* x, int incx, const cuComplex* y, int incy);
113
+ cuDoubleComplex CUBLASWINAPI cublasZdotu(int n, const cuDoubleComplex* x, int incx, const cuDoubleComplex* y, int incy);
114
+ cuDoubleComplex CUBLASWINAPI cublasZdotc(int n, const cuDoubleComplex* x, int incx, const cuDoubleComplex* y, int incy);
115
+ /*------------------------------------------------------------------------*/
116
+ /* SCAL */
117
+ void CUBLASWINAPI cublasSscal(int n, float alpha, float* x, int incx);
118
+ void CUBLASWINAPI cublasDscal(int n, double alpha, double* x, int incx);
119
+ void CUBLASWINAPI cublasCscal(int n, cuComplex alpha, cuComplex* x, int incx);
120
+ void CUBLASWINAPI cublasZscal(int n, cuDoubleComplex alpha, cuDoubleComplex* x, int incx);
121
+
122
+ void CUBLASWINAPI cublasCsscal(int n, float alpha, cuComplex* x, int incx);
123
+ void CUBLASWINAPI cublasZdscal(int n, double alpha, cuDoubleComplex* x, int incx);
124
+ /*------------------------------------------------------------------------*/
125
+ /* AXPY */
126
+ void CUBLASWINAPI cublasSaxpy(int n, float alpha, const float* x, int incx, float* y, int incy);
127
+ void CUBLASWINAPI cublasDaxpy(int n, double alpha, const double* x, int incx, double* y, int incy);
128
+ void CUBLASWINAPI cublasCaxpy(int n, cuComplex alpha, const cuComplex* x, int incx, cuComplex* y, int incy);
129
+ void CUBLASWINAPI
130
+ cublasZaxpy(int n, cuDoubleComplex alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy);
131
+ /*------------------------------------------------------------------------*/
132
+ /* COPY */
133
+ void CUBLASWINAPI cublasScopy(int n, const float* x, int incx, float* y, int incy);
134
+ void CUBLASWINAPI cublasDcopy(int n, const double* x, int incx, double* y, int incy);
135
+ void CUBLASWINAPI cublasCcopy(int n, const cuComplex* x, int incx, cuComplex* y, int incy);
136
+ void CUBLASWINAPI cublasZcopy(int n, const cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy);
137
+ /*------------------------------------------------------------------------*/
138
+ /* SWAP */
139
+ void CUBLASWINAPI cublasSswap(int n, float* x, int incx, float* y, int incy);
140
+ void CUBLASWINAPI cublasDswap(int n, double* x, int incx, double* y, int incy);
141
+ void CUBLASWINAPI cublasCswap(int n, cuComplex* x, int incx, cuComplex* y, int incy);
142
+ void CUBLASWINAPI cublasZswap(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy);
143
+ /*------------------------------------------------------------------------*/
144
+ /* AMAX */
145
+ int CUBLASWINAPI cublasIsamax(int n, const float* x, int incx);
146
+ int CUBLASWINAPI cublasIdamax(int n, const double* x, int incx);
147
+ int CUBLASWINAPI cublasIcamax(int n, const cuComplex* x, int incx);
148
+ int CUBLASWINAPI cublasIzamax(int n, const cuDoubleComplex* x, int incx);
149
+ /*------------------------------------------------------------------------*/
150
+ /* AMIN */
151
+ int CUBLASWINAPI cublasIsamin(int n, const float* x, int incx);
152
+ int CUBLASWINAPI cublasIdamin(int n, const double* x, int incx);
153
+
154
+ int CUBLASWINAPI cublasIcamin(int n, const cuComplex* x, int incx);
155
+ int CUBLASWINAPI cublasIzamin(int n, const cuDoubleComplex* x, int incx);
156
+ /*------------------------------------------------------------------------*/
157
+ /* ASUM */
158
+ float CUBLASWINAPI cublasSasum(int n, const float* x, int incx);
159
+ double CUBLASWINAPI cublasDasum(int n, const double* x, int incx);
160
+ float CUBLASWINAPI cublasScasum(int n, const cuComplex* x, int incx);
161
+ double CUBLASWINAPI cublasDzasum(int n, const cuDoubleComplex* x, int incx);
162
+ /*------------------------------------------------------------------------*/
163
+ /* ROT */
164
+ void CUBLASWINAPI cublasSrot(int n, float* x, int incx, float* y, int incy, float sc, float ss);
165
+ void CUBLASWINAPI cublasDrot(int n, double* x, int incx, double* y, int incy, double sc, double ss);
166
+ void CUBLASWINAPI cublasCrot(int n, cuComplex* x, int incx, cuComplex* y, int incy, float c, cuComplex s);
167
+ void CUBLASWINAPI
168
+ cublasZrot(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, double sc, cuDoubleComplex cs);
169
+ void CUBLASWINAPI cublasCsrot(int n, cuComplex* x, int incx, cuComplex* y, int incy, float c, float s);
170
+ void CUBLASWINAPI cublasZdrot(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, double c, double s);
171
+ /*------------------------------------------------------------------------*/
172
+ /* ROTG */
173
+ void CUBLASWINAPI cublasSrotg(float* sa, float* sb, float* sc, float* ss);
174
+ void CUBLASWINAPI cublasDrotg(double* sa, double* sb, double* sc, double* ss);
175
+ void CUBLASWINAPI cublasCrotg(cuComplex* ca, cuComplex cb, float* sc, cuComplex* cs);
176
+ void CUBLASWINAPI cublasZrotg(cuDoubleComplex* ca, cuDoubleComplex cb, double* sc, cuDoubleComplex* cs);
177
+ /*------------------------------------------------------------------------*/
178
+ /* ROTM */
179
+ void CUBLASWINAPI cublasSrotm(int n, float* x, int incx, float* y, int incy, const float* sparam);
180
+ void CUBLASWINAPI cublasDrotm(int n, double* x, int incx, double* y, int incy, const double* sparam);
181
+ /*------------------------------------------------------------------------*/
182
+ /* ROTMG */
183
+ void CUBLASWINAPI cublasSrotmg(float* sd1, float* sd2, float* sx1, const float* sy1, float* sparam);
184
+ void CUBLASWINAPI cublasDrotmg(double* sd1, double* sd2, double* sx1, const double* sy1, double* sparam);
185
+
186
+ /* --------------- CUBLAS BLAS2 functions ---------------- */
187
+ /* GEMV */
188
+ void CUBLASWINAPI cublasSgemv(char trans,
189
+ int m,
190
+ int n,
191
+ float alpha,
192
+ const float* A,
193
+ int lda,
194
+ const float* x,
195
+ int incx,
196
+ float beta,
197
+ float* y,
198
+ int incy);
199
+ void CUBLASWINAPI cublasDgemv(char trans,
200
+ int m,
201
+ int n,
202
+ double alpha,
203
+ const double* A,
204
+ int lda,
205
+ const double* x,
206
+ int incx,
207
+ double beta,
208
+ double* y,
209
+ int incy);
210
+ void CUBLASWINAPI cublasCgemv(char trans,
211
+ int m,
212
+ int n,
213
+ cuComplex alpha,
214
+ const cuComplex* A,
215
+ int lda,
216
+ const cuComplex* x,
217
+ int incx,
218
+ cuComplex beta,
219
+ cuComplex* y,
220
+ int incy);
221
+ void CUBLASWINAPI cublasZgemv(char trans,
222
+ int m,
223
+ int n,
224
+ cuDoubleComplex alpha,
225
+ const cuDoubleComplex* A,
226
+ int lda,
227
+ const cuDoubleComplex* x,
228
+ int incx,
229
+ cuDoubleComplex beta,
230
+ cuDoubleComplex* y,
231
+ int incy);
232
+ /*------------------------------------------------------------------------*/
233
+ /* GBMV */
234
+ void CUBLASWINAPI cublasSgbmv(char trans,
235
+ int m,
236
+ int n,
237
+ int kl,
238
+ int ku,
239
+ float alpha,
240
+ const float* A,
241
+ int lda,
242
+ const float* x,
243
+ int incx,
244
+ float beta,
245
+ float* y,
246
+ int incy);
247
+ void CUBLASWINAPI cublasDgbmv(char trans,
248
+ int m,
249
+ int n,
250
+ int kl,
251
+ int ku,
252
+ double alpha,
253
+ const double* A,
254
+ int lda,
255
+ const double* x,
256
+ int incx,
257
+ double beta,
258
+ double* y,
259
+ int incy);
260
+ void CUBLASWINAPI cublasCgbmv(char trans,
261
+ int m,
262
+ int n,
263
+ int kl,
264
+ int ku,
265
+ cuComplex alpha,
266
+ const cuComplex* A,
267
+ int lda,
268
+ const cuComplex* x,
269
+ int incx,
270
+ cuComplex beta,
271
+ cuComplex* y,
272
+ int incy);
273
+ void CUBLASWINAPI cublasZgbmv(char trans,
274
+ int m,
275
+ int n,
276
+ int kl,
277
+ int ku,
278
+ cuDoubleComplex alpha,
279
+ const cuDoubleComplex* A,
280
+ int lda,
281
+ const cuDoubleComplex* x,
282
+ int incx,
283
+ cuDoubleComplex beta,
284
+ cuDoubleComplex* y,
285
+ int incy);
286
+ /*------------------------------------------------------------------------*/
287
+ /* TRMV */
288
+ void CUBLASWINAPI cublasStrmv(char uplo, char trans, char diag, int n, const float* A, int lda, float* x, int incx);
289
+ void CUBLASWINAPI cublasDtrmv(char uplo, char trans, char diag, int n, const double* A, int lda, double* x, int incx);
290
+ void CUBLASWINAPI
291
+ cublasCtrmv(char uplo, char trans, char diag, int n, const cuComplex* A, int lda, cuComplex* x, int incx);
292
+ void CUBLASWINAPI
293
+ cublasZtrmv(char uplo, char trans, char diag, int n, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
294
+ /*------------------------------------------------------------------------*/
295
+ /* TBMV */
296
+ void CUBLASWINAPI
297
+ cublasStbmv(char uplo, char trans, char diag, int n, int k, const float* A, int lda, float* x, int incx);
298
+ void CUBLASWINAPI
299
+ cublasDtbmv(char uplo, char trans, char diag, int n, int k, const double* A, int lda, double* x, int incx);
300
+ void CUBLASWINAPI
301
+ cublasCtbmv(char uplo, char trans, char diag, int n, int k, const cuComplex* A, int lda, cuComplex* x, int incx);
302
+ void CUBLASWINAPI cublasZtbmv(
303
+ char uplo, char trans, char diag, int n, int k, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
304
+ /*------------------------------------------------------------------------*/
305
+ /* TPMV */
306
+ void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, const float* AP, float* x, int incx);
307
+
308
+ void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, const double* AP, double* x, int incx);
309
+
310
+ void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, const cuComplex* AP, cuComplex* x, int incx);
311
+
312
+ void CUBLASWINAPI
313
+ cublasZtpmv(char uplo, char trans, char diag, int n, const cuDoubleComplex* AP, cuDoubleComplex* x, int incx);
314
+ /*------------------------------------------------------------------------*/
315
+ /* TRSV */
316
+ void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, const float* A, int lda, float* x, int incx);
317
+
318
+ void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, const double* A, int lda, double* x, int incx);
319
+
320
+ void CUBLASWINAPI
321
+ cublasCtrsv(char uplo, char trans, char diag, int n, const cuComplex* A, int lda, cuComplex* x, int incx);
322
+
323
+ void CUBLASWINAPI
324
+ cublasZtrsv(char uplo, char trans, char diag, int n, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
325
+ /*------------------------------------------------------------------------*/
326
+ /* TPSV */
327
+ void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, const float* AP, float* x, int incx);
328
+
329
+ void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, const double* AP, double* x, int incx);
330
+
331
+ void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, const cuComplex* AP, cuComplex* x, int incx);
332
+
333
+ void CUBLASWINAPI
334
+ cublasZtpsv(char uplo, char trans, char diag, int n, const cuDoubleComplex* AP, cuDoubleComplex* x, int incx);
335
+ /*------------------------------------------------------------------------*/
336
+ /* TBSV */
337
+ void CUBLASWINAPI
338
+ cublasStbsv(char uplo, char trans, char diag, int n, int k, const float* A, int lda, float* x, int incx);
339
+
340
+ void CUBLASWINAPI
341
+ cublasDtbsv(char uplo, char trans, char diag, int n, int k, const double* A, int lda, double* x, int incx);
342
+ void CUBLASWINAPI
343
+ cublasCtbsv(char uplo, char trans, char diag, int n, int k, const cuComplex* A, int lda, cuComplex* x, int incx);
344
+
345
+ void CUBLASWINAPI cublasZtbsv(
346
+ char uplo, char trans, char diag, int n, int k, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
347
+ /*------------------------------------------------------------------------*/
348
+ /* SYMV/HEMV */
349
+ void CUBLASWINAPI cublasSsymv(
350
+ char uplo, int n, float alpha, const float* A, int lda, const float* x, int incx, float beta, float* y, int incy);
351
+ void CUBLASWINAPI cublasDsymv(char uplo,
352
+ int n,
353
+ double alpha,
354
+ const double* A,
355
+ int lda,
356
+ const double* x,
357
+ int incx,
358
+ double beta,
359
+ double* y,
360
+ int incy);
361
+ void CUBLASWINAPI cublasChemv(char uplo,
362
+ int n,
363
+ cuComplex alpha,
364
+ const cuComplex* A,
365
+ int lda,
366
+ const cuComplex* x,
367
+ int incx,
368
+ cuComplex beta,
369
+ cuComplex* y,
370
+ int incy);
371
+ void CUBLASWINAPI cublasZhemv(char uplo,
372
+ int n,
373
+ cuDoubleComplex alpha,
374
+ const cuDoubleComplex* A,
375
+ int lda,
376
+ const cuDoubleComplex* x,
377
+ int incx,
378
+ cuDoubleComplex beta,
379
+ cuDoubleComplex* y,
380
+ int incy);
381
+ /*------------------------------------------------------------------------*/
382
+ /* SBMV/HBMV */
383
+ void CUBLASWINAPI cublasSsbmv(char uplo,
384
+ int n,
385
+ int k,
386
+ float alpha,
387
+ const float* A,
388
+ int lda,
389
+ const float* x,
390
+ int incx,
391
+ float beta,
392
+ float* y,
393
+ int incy);
394
+ void CUBLASWINAPI cublasDsbmv(char uplo,
395
+ int n,
396
+ int k,
397
+ double alpha,
398
+ const double* A,
399
+ int lda,
400
+ const double* x,
401
+ int incx,
402
+ double beta,
403
+ double* y,
404
+ int incy);
405
+ void CUBLASWINAPI cublasChbmv(char uplo,
406
+ int n,
407
+ int k,
408
+ cuComplex alpha,
409
+ const cuComplex* A,
410
+ int lda,
411
+ const cuComplex* x,
412
+ int incx,
413
+ cuComplex beta,
414
+ cuComplex* y,
415
+ int incy);
416
+ void CUBLASWINAPI cublasZhbmv(char uplo,
417
+ int n,
418
+ int k,
419
+ cuDoubleComplex alpha,
420
+ const cuDoubleComplex* A,
421
+ int lda,
422
+ const cuDoubleComplex* x,
423
+ int incx,
424
+ cuDoubleComplex beta,
425
+ cuDoubleComplex* y,
426
+ int incy);
427
+ /*------------------------------------------------------------------------*/
428
+ /* SPMV/HPMV */
429
+ void CUBLASWINAPI
430
+ cublasSspmv(char uplo, int n, float alpha, const float* AP, const float* x, int incx, float beta, float* y, int incy);
431
+ void CUBLASWINAPI cublasDspmv(
432
+ char uplo, int n, double alpha, const double* AP, const double* x, int incx, double beta, double* y, int incy);
433
+ void CUBLASWINAPI cublasChpmv(char uplo,
434
+ int n,
435
+ cuComplex alpha,
436
+ const cuComplex* AP,
437
+ const cuComplex* x,
438
+ int incx,
439
+ cuComplex beta,
440
+ cuComplex* y,
441
+ int incy);
442
+ void CUBLASWINAPI cublasZhpmv(char uplo,
443
+ int n,
444
+ cuDoubleComplex alpha,
445
+ const cuDoubleComplex* AP,
446
+ const cuDoubleComplex* x,
447
+ int incx,
448
+ cuDoubleComplex beta,
449
+ cuDoubleComplex* y,
450
+ int incy);
451
+
452
+ /*------------------------------------------------------------------------*/
453
+ /* GER */
454
+ void CUBLASWINAPI
455
+ cublasSger(int m, int n, float alpha, const float* x, int incx, const float* y, int incy, float* A, int lda);
456
+ void CUBLASWINAPI
457
+ cublasDger(int m, int n, double alpha, const double* x, int incx, const double* y, int incy, double* A, int lda);
458
+
459
+ void CUBLASWINAPI cublasCgeru(
460
+ int m, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* A, int lda);
461
+ void CUBLASWINAPI cublasCgerc(
462
+ int m, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* A, int lda);
463
+ void CUBLASWINAPI cublasZgeru(int m,
464
+ int n,
465
+ cuDoubleComplex alpha,
466
+ const cuDoubleComplex* x,
467
+ int incx,
468
+ const cuDoubleComplex* y,
469
+ int incy,
470
+ cuDoubleComplex* A,
471
+ int lda);
472
+ void CUBLASWINAPI cublasZgerc(int m,
473
+ int n,
474
+ cuDoubleComplex alpha,
475
+ const cuDoubleComplex* x,
476
+ int incx,
477
+ const cuDoubleComplex* y,
478
+ int incy,
479
+ cuDoubleComplex* A,
480
+ int lda);
481
+ /*------------------------------------------------------------------------*/
482
+ /* SYR/HER */
483
+ void CUBLASWINAPI cublasSsyr(char uplo, int n, float alpha, const float* x, int incx, float* A, int lda);
484
+ void CUBLASWINAPI cublasDsyr(char uplo, int n, double alpha, const double* x, int incx, double* A, int lda);
485
+
486
+ void CUBLASWINAPI cublasCher(char uplo, int n, float alpha, const cuComplex* x, int incx, cuComplex* A, int lda);
487
+ void CUBLASWINAPI
488
+ cublasZher(char uplo, int n, double alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* A, int lda);
489
+
490
+ /*------------------------------------------------------------------------*/
491
+ /* SPR/HPR */
492
+ void CUBLASWINAPI cublasSspr(char uplo, int n, float alpha, const float* x, int incx, float* AP);
493
+ void CUBLASWINAPI cublasDspr(char uplo, int n, double alpha, const double* x, int incx, double* AP);
494
+ void CUBLASWINAPI cublasChpr(char uplo, int n, float alpha, const cuComplex* x, int incx, cuComplex* AP);
495
+ void CUBLASWINAPI cublasZhpr(char uplo, int n, double alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* AP);
496
+ /*------------------------------------------------------------------------*/
497
+ /* SYR2/HER2 */
498
+ void CUBLASWINAPI
499
+ cublasSsyr2(char uplo, int n, float alpha, const float* x, int incx, const float* y, int incy, float* A, int lda);
500
+ void CUBLASWINAPI
501
+ cublasDsyr2(char uplo, int n, double alpha, const double* x, int incx, const double* y, int incy, double* A, int lda);
502
+ void CUBLASWINAPI cublasCher2(char uplo,
503
+ int n,
504
+ cuComplex alpha,
505
+ const cuComplex* x,
506
+ int incx,
507
+ const cuComplex* y,
508
+ int incy,
509
+ cuComplex* A,
510
+ int lda);
511
+ void CUBLASWINAPI cublasZher2(char uplo,
512
+ int n,
513
+ cuDoubleComplex alpha,
514
+ const cuDoubleComplex* x,
515
+ int incx,
516
+ const cuDoubleComplex* y,
517
+ int incy,
518
+ cuDoubleComplex* A,
519
+ int lda);
520
+
521
+ /*------------------------------------------------------------------------*/
522
+ /* SPR2/HPR2 */
523
+ void CUBLASWINAPI
524
+ cublasSspr2(char uplo, int n, float alpha, const float* x, int incx, const float* y, int incy, float* AP);
525
+ void CUBLASWINAPI
526
+ cublasDspr2(char uplo, int n, double alpha, const double* x, int incx, const double* y, int incy, double* AP);
527
+ void CUBLASWINAPI cublasChpr2(
528
+ char uplo, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* AP);
529
+ void CUBLASWINAPI cublasZhpr2(char uplo,
530
+ int n,
531
+ cuDoubleComplex alpha,
532
+ const cuDoubleComplex* x,
533
+ int incx,
534
+ const cuDoubleComplex* y,
535
+ int incy,
536
+ cuDoubleComplex* AP);
537
+ /* ------------------------BLAS3 Functions ------------------------------- */
538
+ /* GEMM */
539
+ void CUBLASWINAPI cublasSgemm(char transa,
540
+ char transb,
541
+ int m,
542
+ int n,
543
+ int k,
544
+ float alpha,
545
+ const float* A,
546
+ int lda,
547
+ const float* B,
548
+ int ldb,
549
+ float beta,
550
+ float* C,
551
+ int ldc);
552
+ void CUBLASWINAPI cublasDgemm(char transa,
553
+ char transb,
554
+ int m,
555
+ int n,
556
+ int k,
557
+ double alpha,
558
+ const double* A,
559
+ int lda,
560
+ const double* B,
561
+ int ldb,
562
+ double beta,
563
+ double* C,
564
+ int ldc);
565
+ void CUBLASWINAPI cublasCgemm(char transa,
566
+ char transb,
567
+ int m,
568
+ int n,
569
+ int k,
570
+ cuComplex alpha,
571
+ const cuComplex* A,
572
+ int lda,
573
+ const cuComplex* B,
574
+ int ldb,
575
+ cuComplex beta,
576
+ cuComplex* C,
577
+ int ldc);
578
+ void CUBLASWINAPI cublasZgemm(char transa,
579
+ char transb,
580
+ int m,
581
+ int n,
582
+ int k,
583
+ cuDoubleComplex alpha,
584
+ const cuDoubleComplex* A,
585
+ int lda,
586
+ const cuDoubleComplex* B,
587
+ int ldb,
588
+ cuDoubleComplex beta,
589
+ cuDoubleComplex* C,
590
+ int ldc);
591
+ /* -------------------------------------------------------*/
592
+ /* SYRK */
593
+ void CUBLASWINAPI
594
+ cublasSsyrk(char uplo, char trans, int n, int k, float alpha, const float* A, int lda, float beta, float* C, int ldc);
595
+ void CUBLASWINAPI cublasDsyrk(
596
+ char uplo, char trans, int n, int k, double alpha, const double* A, int lda, double beta, double* C, int ldc);
597
+
598
+ void CUBLASWINAPI cublasCsyrk(char uplo,
599
+ char trans,
600
+ int n,
601
+ int k,
602
+ cuComplex alpha,
603
+ const cuComplex* A,
604
+ int lda,
605
+ cuComplex beta,
606
+ cuComplex* C,
607
+ int ldc);
608
+ void CUBLASWINAPI cublasZsyrk(char uplo,
609
+ char trans,
610
+ int n,
611
+ int k,
612
+ cuDoubleComplex alpha,
613
+ const cuDoubleComplex* A,
614
+ int lda,
615
+ cuDoubleComplex beta,
616
+ cuDoubleComplex* C,
617
+ int ldc);
618
+ /* ------------------------------------------------------- */
619
+ /* HERK */
620
+ void CUBLASWINAPI cublasCherk(
621
+ char uplo, char trans, int n, int k, float alpha, const cuComplex* A, int lda, float beta, cuComplex* C, int ldc);
622
+ void CUBLASWINAPI cublasZherk(char uplo,
623
+ char trans,
624
+ int n,
625
+ int k,
626
+ double alpha,
627
+ const cuDoubleComplex* A,
628
+ int lda,
629
+ double beta,
630
+ cuDoubleComplex* C,
631
+ int ldc);
632
+ /* ------------------------------------------------------- */
633
+ /* SYR2K */
634
+ void CUBLASWINAPI cublasSsyr2k(char uplo,
635
+ char trans,
636
+ int n,
637
+ int k,
638
+ float alpha,
639
+ const float* A,
640
+ int lda,
641
+ const float* B,
642
+ int ldb,
643
+ float beta,
644
+ float* C,
645
+ int ldc);
646
+
647
+ void CUBLASWINAPI cublasDsyr2k(char uplo,
648
+ char trans,
649
+ int n,
650
+ int k,
651
+ double alpha,
652
+ const double* A,
653
+ int lda,
654
+ const double* B,
655
+ int ldb,
656
+ double beta,
657
+ double* C,
658
+ int ldc);
659
+ void CUBLASWINAPI cublasCsyr2k(char uplo,
660
+ char trans,
661
+ int n,
662
+ int k,
663
+ cuComplex alpha,
664
+ const cuComplex* A,
665
+ int lda,
666
+ const cuComplex* B,
667
+ int ldb,
668
+ cuComplex beta,
669
+ cuComplex* C,
670
+ int ldc);
671
+
672
+ void CUBLASWINAPI cublasZsyr2k(char uplo,
673
+ char trans,
674
+ int n,
675
+ int k,
676
+ cuDoubleComplex alpha,
677
+ const cuDoubleComplex* A,
678
+ int lda,
679
+ const cuDoubleComplex* B,
680
+ int ldb,
681
+ cuDoubleComplex beta,
682
+ cuDoubleComplex* C,
683
+ int ldc);
684
+ /* ------------------------------------------------------- */
685
+ /* HER2K */
686
+ void CUBLASWINAPI cublasCher2k(char uplo,
687
+ char trans,
688
+ int n,
689
+ int k,
690
+ cuComplex alpha,
691
+ const cuComplex* A,
692
+ int lda,
693
+ const cuComplex* B,
694
+ int ldb,
695
+ float beta,
696
+ cuComplex* C,
697
+ int ldc);
698
+
699
+ void CUBLASWINAPI cublasZher2k(char uplo,
700
+ char trans,
701
+ int n,
702
+ int k,
703
+ cuDoubleComplex alpha,
704
+ const cuDoubleComplex* A,
705
+ int lda,
706
+ const cuDoubleComplex* B,
707
+ int ldb,
708
+ double beta,
709
+ cuDoubleComplex* C,
710
+ int ldc);
711
+
712
+ /*------------------------------------------------------------------------*/
713
+ /* SYMM*/
714
+ void CUBLASWINAPI cublasSsymm(char side,
715
+ char uplo,
716
+ int m,
717
+ int n,
718
+ float alpha,
719
+ const float* A,
720
+ int lda,
721
+ const float* B,
722
+ int ldb,
723
+ float beta,
724
+ float* C,
725
+ int ldc);
726
+ void CUBLASWINAPI cublasDsymm(char side,
727
+ char uplo,
728
+ int m,
729
+ int n,
730
+ double alpha,
731
+ const double* A,
732
+ int lda,
733
+ const double* B,
734
+ int ldb,
735
+ double beta,
736
+ double* C,
737
+ int ldc);
738
+
739
+ void CUBLASWINAPI cublasCsymm(char side,
740
+ char uplo,
741
+ int m,
742
+ int n,
743
+ cuComplex alpha,
744
+ const cuComplex* A,
745
+ int lda,
746
+ const cuComplex* B,
747
+ int ldb,
748
+ cuComplex beta,
749
+ cuComplex* C,
750
+ int ldc);
751
+
752
+ void CUBLASWINAPI cublasZsymm(char side,
753
+ char uplo,
754
+ int m,
755
+ int n,
756
+ cuDoubleComplex alpha,
757
+ const cuDoubleComplex* A,
758
+ int lda,
759
+ const cuDoubleComplex* B,
760
+ int ldb,
761
+ cuDoubleComplex beta,
762
+ cuDoubleComplex* C,
763
+ int ldc);
764
+ /*------------------------------------------------------------------------*/
765
+ /* HEMM*/
766
+ void CUBLASWINAPI cublasChemm(char side,
767
+ char uplo,
768
+ int m,
769
+ int n,
770
+ cuComplex alpha,
771
+ const cuComplex* A,
772
+ int lda,
773
+ const cuComplex* B,
774
+ int ldb,
775
+ cuComplex beta,
776
+ cuComplex* C,
777
+ int ldc);
778
+ void CUBLASWINAPI cublasZhemm(char side,
779
+ char uplo,
780
+ int m,
781
+ int n,
782
+ cuDoubleComplex alpha,
783
+ const cuDoubleComplex* A,
784
+ int lda,
785
+ const cuDoubleComplex* B,
786
+ int ldb,
787
+ cuDoubleComplex beta,
788
+ cuDoubleComplex* C,
789
+ int ldc);
790
+
791
+ /*------------------------------------------------------------------------*/
792
+ /* TRSM*/
793
+ void CUBLASWINAPI cublasStrsm(char side,
794
+ char uplo,
795
+ char transa,
796
+ char diag,
797
+ int m,
798
+ int n,
799
+ float alpha,
800
+ const float* A,
801
+ int lda,
802
+ float* B,
803
+ int ldb);
804
+
805
+ void CUBLASWINAPI cublasDtrsm(char side,
806
+ char uplo,
807
+ char transa,
808
+ char diag,
809
+ int m,
810
+ int n,
811
+ double alpha,
812
+ const double* A,
813
+ int lda,
814
+ double* B,
815
+ int ldb);
816
+
817
+ void CUBLASWINAPI cublasCtrsm(char side,
818
+ char uplo,
819
+ char transa,
820
+ char diag,
821
+ int m,
822
+ int n,
823
+ cuComplex alpha,
824
+ const cuComplex* A,
825
+ int lda,
826
+ cuComplex* B,
827
+ int ldb);
828
+
829
+ void CUBLASWINAPI cublasZtrsm(char side,
830
+ char uplo,
831
+ char transa,
832
+ char diag,
833
+ int m,
834
+ int n,
835
+ cuDoubleComplex alpha,
836
+ const cuDoubleComplex* A,
837
+ int lda,
838
+ cuDoubleComplex* B,
839
+ int ldb);
840
+ /*------------------------------------------------------------------------*/
841
+ /* TRMM*/
842
+ void CUBLASWINAPI cublasStrmm(char side,
843
+ char uplo,
844
+ char transa,
845
+ char diag,
846
+ int m,
847
+ int n,
848
+ float alpha,
849
+ const float* A,
850
+ int lda,
851
+ float* B,
852
+ int ldb);
853
+ void CUBLASWINAPI cublasDtrmm(char side,
854
+ char uplo,
855
+ char transa,
856
+ char diag,
857
+ int m,
858
+ int n,
859
+ double alpha,
860
+ const double* A,
861
+ int lda,
862
+ double* B,
863
+ int ldb);
864
+ void CUBLASWINAPI cublasCtrmm(char side,
865
+ char uplo,
866
+ char transa,
867
+ char diag,
868
+ int m,
869
+ int n,
870
+ cuComplex alpha,
871
+ const cuComplex* A,
872
+ int lda,
873
+ cuComplex* B,
874
+ int ldb);
875
+ void CUBLASWINAPI cublasZtrmm(char side,
876
+ char uplo,
877
+ char transa,
878
+ char diag,
879
+ int m,
880
+ int n,
881
+ cuDoubleComplex alpha,
882
+ const cuDoubleComplex* A,
883
+ int lda,
884
+ cuDoubleComplex* B,
885
+ int ldb);
886
+
887
+ #if defined(__cplusplus)
888
+ }
889
+ #endif /* __cplusplus */
890
+
891
+ #endif /* !defined(CUBLAS_H_) */
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h ADDED
@@ -0,0 +1,1845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+ #pragma once
50
+
51
+ #ifndef CUBLASAPI
52
+ #ifdef __CUDACC__
53
+ #define CUBLASAPI __host__ __device__
54
+ #else
55
+ #define CUBLASAPI
56
+ #endif
57
+ #endif
58
+
59
+ #include <cublas_api.h>
60
+
61
+ #include <stdint.h>
62
+ #include <stddef.h>
63
+ #include <stdio.h>
64
+
65
+ #if defined(__cplusplus)
66
+ extern "C" {
67
+ #endif /* __cplusplus */
68
+
69
+ /** Opaque structure holding CUBLASLT context
70
+ */
71
+ typedef struct cublasLtContext* cublasLtHandle_t;
72
+
73
+ cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t* lightHandle);
74
+
75
+ cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle);
76
+
77
+ const char* CUBLASWINAPI cublasLtGetStatusName(cublasStatus_t status);
78
+
79
+ const char* CUBLASWINAPI cublasLtGetStatusString(cublasStatus_t status);
80
+
81
+ size_t CUBLASWINAPI cublasLtGetVersion(void);
82
+
83
+ size_t CUBLASWINAPI cublasLtGetCudartVersion(void);
84
+
85
+ cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type, int* value);
86
+
87
+ cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheGetCapacity(size_t* capacity);
88
+ cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheSetCapacity(size_t capacity);
89
+
90
+ /** Restricts usage of CPU instructions (ISA) specified by the flags in the mask.
91
+ *
92
+ * Flags can be combined with bitwise OR(|) operator. Supported flags:
93
+ * - 0x1 -- x86-64 AVX512 ISA
94
+ *
95
+ * Default mask: 0 (any applicable ISA is allowed).
96
+ *
97
+ * The function returns the previous value of the mask.
98
+ * The function takes precedence over the environment variable CUBLASLT_DISABLE_CPU_INSTRUCTIONS_MASK.
99
+ */
100
+ unsigned CUBLASWINAPI cublasLtDisableCpuInstructionsSetMask(unsigned mask);
101
+
102
+ /** Semi-opaque descriptor for matrix memory layout
103
+ */
104
+ typedef struct {
105
+ uint64_t data[8];
106
+ } cublasLtMatrixLayoutOpaque_t;
107
+
108
+ /** Opaque descriptor for matrix memory layout
109
+ */
110
+ typedef cublasLtMatrixLayoutOpaque_t* cublasLtMatrixLayout_t;
111
+
112
+ /** Semi-opaque algorithm descriptor (to avoid complicated alloc/free schemes)
113
+ *
114
+ * This structure can be trivially serialized and later restored for use with the same version of cuBLAS library to save
115
+ * on selecting the right configuration again.
116
+ */
117
+ typedef struct {
118
+ uint64_t data[8];
119
+ } cublasLtMatmulAlgo_t;
120
+
121
+ /** Semi-opaque descriptor for cublasLtMatmul() operation details
122
+ */
123
+ typedef struct {
124
+ uint64_t data[32];
125
+ } cublasLtMatmulDescOpaque_t;
126
+
127
+ /** Opaque descriptor for cublasLtMatmul() operation details
128
+ */
129
+ typedef cublasLtMatmulDescOpaque_t* cublasLtMatmulDesc_t;
130
+
131
+ /** Semi-opaque descriptor for cublasLtMatrixTransform() operation details
132
+ */
133
+ typedef struct {
134
+ uint64_t data[8];
135
+ } cublasLtMatrixTransformDescOpaque_t;
136
+
137
+ /** Opaque descriptor for cublasLtMatrixTransform() operation details
138
+ */
139
+ typedef cublasLtMatrixTransformDescOpaque_t* cublasLtMatrixTransformDesc_t;
140
+
141
+ /** Semi-opaque descriptor for cublasLtMatmulPreference() operation details
142
+ */
143
+ typedef struct {
144
+ uint64_t data[8];
145
+ } cublasLtMatmulPreferenceOpaque_t;
146
+
147
+ /** Opaque descriptor for cublasLtMatmulAlgoGetHeuristic() configuration
148
+ */
149
+ typedef cublasLtMatmulPreferenceOpaque_t* cublasLtMatmulPreference_t;
150
+
151
+ /** Tile size (in C/D matrix Rows x Cols)
152
+ *
153
+ * General order of tile IDs is sorted by size first and by first dimension second.
154
+ */
155
+ typedef enum {
156
+ CUBLASLT_MATMUL_TILE_UNDEFINED = 0,
157
+ CUBLASLT_MATMUL_TILE_8x8 = 1,
158
+ CUBLASLT_MATMUL_TILE_8x16 = 2,
159
+ CUBLASLT_MATMUL_TILE_16x8 = 3,
160
+ CUBLASLT_MATMUL_TILE_8x32 = 4,
161
+ CUBLASLT_MATMUL_TILE_16x16 = 5,
162
+ CUBLASLT_MATMUL_TILE_32x8 = 6,
163
+ CUBLASLT_MATMUL_TILE_8x64 = 7,
164
+ CUBLASLT_MATMUL_TILE_16x32 = 8,
165
+ CUBLASLT_MATMUL_TILE_32x16 = 9,
166
+ CUBLASLT_MATMUL_TILE_64x8 = 10,
167
+ CUBLASLT_MATMUL_TILE_32x32 = 11,
168
+ CUBLASLT_MATMUL_TILE_32x64 = 12,
169
+ CUBLASLT_MATMUL_TILE_64x32 = 13,
170
+ CUBLASLT_MATMUL_TILE_32x128 = 14,
171
+ CUBLASLT_MATMUL_TILE_64x64 = 15,
172
+ CUBLASLT_MATMUL_TILE_128x32 = 16,
173
+ CUBLASLT_MATMUL_TILE_64x128 = 17,
174
+ CUBLASLT_MATMUL_TILE_128x64 = 18,
175
+ CUBLASLT_MATMUL_TILE_64x256 = 19,
176
+ CUBLASLT_MATMUL_TILE_128x128 = 20,
177
+ CUBLASLT_MATMUL_TILE_256x64 = 21,
178
+ CUBLASLT_MATMUL_TILE_64x512 = 22,
179
+ CUBLASLT_MATMUL_TILE_128x256 = 23,
180
+ CUBLASLT_MATMUL_TILE_256x128 = 24,
181
+ CUBLASLT_MATMUL_TILE_512x64 = 25,
182
+ CUBLASLT_MATMUL_TILE_64x96 = 26,
183
+ CUBLASLT_MATMUL_TILE_96x64 = 27,
184
+ CUBLASLT_MATMUL_TILE_96x128 = 28,
185
+ CUBLASLT_MATMUL_TILE_128x160 = 29,
186
+ CUBLASLT_MATMUL_TILE_160x128 = 30,
187
+ CUBLASLT_MATMUL_TILE_192x128 = 31,
188
+ CUBLASLT_MATMUL_TILE_128x192 = 32,
189
+ CUBLASLT_MATMUL_TILE_128x96 = 33,
190
+ CUBLASLT_MATMUL_TILE_32x256 = 34,
191
+ CUBLASLT_MATMUL_TILE_256x32 = 35,
192
+ CUBLASLT_MATMUL_TILE_END
193
+ } cublasLtMatmulTile_t;
194
+
195
+ /** Size and number of stages in which elements are read into shared memory
196
+ *
197
+ * General order of stages IDs is sorted by stage size first and by number of stages second.
198
+ */
199
+ typedef enum {
200
+ CUBLASLT_MATMUL_STAGES_UNDEFINED = 0,
201
+ CUBLASLT_MATMUL_STAGES_16x1 = 1,
202
+ CUBLASLT_MATMUL_STAGES_16x2 = 2,
203
+ CUBLASLT_MATMUL_STAGES_16x3 = 3,
204
+ CUBLASLT_MATMUL_STAGES_16x4 = 4,
205
+ CUBLASLT_MATMUL_STAGES_16x5 = 5,
206
+ CUBLASLT_MATMUL_STAGES_16x6 = 6,
207
+ CUBLASLT_MATMUL_STAGES_32x1 = 7,
208
+ CUBLASLT_MATMUL_STAGES_32x2 = 8,
209
+ CUBLASLT_MATMUL_STAGES_32x3 = 9,
210
+ CUBLASLT_MATMUL_STAGES_32x4 = 10,
211
+ CUBLASLT_MATMUL_STAGES_32x5 = 11,
212
+ CUBLASLT_MATMUL_STAGES_32x6 = 12,
213
+ CUBLASLT_MATMUL_STAGES_64x1 = 13,
214
+ CUBLASLT_MATMUL_STAGES_64x2 = 14,
215
+ CUBLASLT_MATMUL_STAGES_64x3 = 15,
216
+ CUBLASLT_MATMUL_STAGES_64x4 = 16,
217
+ CUBLASLT_MATMUL_STAGES_64x5 = 17,
218
+ CUBLASLT_MATMUL_STAGES_64x6 = 18,
219
+ CUBLASLT_MATMUL_STAGES_128x1 = 19,
220
+ CUBLASLT_MATMUL_STAGES_128x2 = 20,
221
+ CUBLASLT_MATMUL_STAGES_128x3 = 21,
222
+ CUBLASLT_MATMUL_STAGES_128x4 = 22,
223
+ CUBLASLT_MATMUL_STAGES_128x5 = 23,
224
+ CUBLASLT_MATMUL_STAGES_128x6 = 24,
225
+ CUBLASLT_MATMUL_STAGES_32x10 = 25,
226
+ CUBLASLT_MATMUL_STAGES_8x4 = 26,
227
+ CUBLASLT_MATMUL_STAGES_16x10 = 27,
228
+ CUBLASLT_MATMUL_STAGES_8x5 = 28,
229
+ CUBLASLT_MATMUL_STAGES_8x3 = 31,
230
+ CUBLASLT_MATMUL_STAGES_8xAUTO = 32,
231
+ CUBLASLT_MATMUL_STAGES_16xAUTO = 33,
232
+ CUBLASLT_MATMUL_STAGES_32xAUTO = 34,
233
+ CUBLASLT_MATMUL_STAGES_64xAUTO = 35,
234
+ CUBLASLT_MATMUL_STAGES_128xAUTO = 36,
235
+ CUBLASLT_MATMUL_STAGES_END
236
+ } cublasLtMatmulStages_t;
237
+
238
+ /** Thread Block Cluster size
239
+ *
240
+ * Typically dimensioned similar to cublasLtMatmulTile_t, with the third coordinate unused at this time.
241
+ */
242
+ typedef enum {
243
+ /** Let library pick cluster shape automatically */
244
+ CUBLASLT_CLUSTER_SHAPE_AUTO = 0,
245
+ CUBLASLT_CLUSTER_SHAPE_1x1x1 = 2,
246
+ CUBLASLT_CLUSTER_SHAPE_2x1x1 = 3,
247
+ CUBLASLT_CLUSTER_SHAPE_4x1x1 = 4,
248
+ CUBLASLT_CLUSTER_SHAPE_1x2x1 = 5,
249
+ CUBLASLT_CLUSTER_SHAPE_2x2x1 = 6,
250
+ CUBLASLT_CLUSTER_SHAPE_4x2x1 = 7,
251
+ CUBLASLT_CLUSTER_SHAPE_1x4x1 = 8,
252
+ CUBLASLT_CLUSTER_SHAPE_2x4x1 = 9,
253
+ CUBLASLT_CLUSTER_SHAPE_4x4x1 = 10,
254
+ CUBLASLT_CLUSTER_SHAPE_8x1x1 = 11,
255
+ CUBLASLT_CLUSTER_SHAPE_1x8x1 = 12,
256
+ CUBLASLT_CLUSTER_SHAPE_8x2x1 = 13,
257
+ CUBLASLT_CLUSTER_SHAPE_2x8x1 = 14,
258
+ CUBLASLT_CLUSTER_SHAPE_16x1x1 = 15,
259
+ CUBLASLT_CLUSTER_SHAPE_1x16x1 = 16,
260
+ CUBLASLT_CLUSTER_SHAPE_3x1x1 = 17,
261
+ CUBLASLT_CLUSTER_SHAPE_5x1x1 = 18,
262
+ CUBLASLT_CLUSTER_SHAPE_6x1x1 = 19,
263
+ CUBLASLT_CLUSTER_SHAPE_7x1x1 = 20,
264
+ CUBLASLT_CLUSTER_SHAPE_9x1x1 = 21,
265
+ CUBLASLT_CLUSTER_SHAPE_10x1x1 = 22,
266
+ CUBLASLT_CLUSTER_SHAPE_11x1x1 = 23,
267
+ CUBLASLT_CLUSTER_SHAPE_12x1x1 = 24,
268
+ CUBLASLT_CLUSTER_SHAPE_13x1x1 = 25,
269
+ CUBLASLT_CLUSTER_SHAPE_14x1x1 = 26,
270
+ CUBLASLT_CLUSTER_SHAPE_15x1x1 = 27,
271
+ CUBLASLT_CLUSTER_SHAPE_3x2x1 = 28,
272
+ CUBLASLT_CLUSTER_SHAPE_5x2x1 = 29,
273
+ CUBLASLT_CLUSTER_SHAPE_6x2x1 = 30,
274
+ CUBLASLT_CLUSTER_SHAPE_7x2x1 = 31,
275
+ CUBLASLT_CLUSTER_SHAPE_1x3x1 = 32,
276
+ CUBLASLT_CLUSTER_SHAPE_2x3x1 = 33,
277
+ CUBLASLT_CLUSTER_SHAPE_3x3x1 = 34,
278
+ CUBLASLT_CLUSTER_SHAPE_4x3x1 = 35,
279
+ CUBLASLT_CLUSTER_SHAPE_5x3x1 = 36,
280
+ CUBLASLT_CLUSTER_SHAPE_3x4x1 = 37,
281
+ CUBLASLT_CLUSTER_SHAPE_1x5x1 = 38,
282
+ CUBLASLT_CLUSTER_SHAPE_2x5x1 = 39,
283
+ CUBLASLT_CLUSTER_SHAPE_3x5x1 = 40,
284
+ CUBLASLT_CLUSTER_SHAPE_1x6x1 = 41,
285
+ CUBLASLT_CLUSTER_SHAPE_2x6x1 = 42,
286
+ CUBLASLT_CLUSTER_SHAPE_1x7x1 = 43,
287
+ CUBLASLT_CLUSTER_SHAPE_2x7x1 = 44,
288
+ CUBLASLT_CLUSTER_SHAPE_1x9x1 = 45,
289
+ CUBLASLT_CLUSTER_SHAPE_1x10x1 = 46,
290
+ CUBLASLT_CLUSTER_SHAPE_1x11x1 = 47,
291
+ CUBLASLT_CLUSTER_SHAPE_1x12x1 = 48,
292
+ CUBLASLT_CLUSTER_SHAPE_1x13x1 = 49,
293
+ CUBLASLT_CLUSTER_SHAPE_1x14x1 = 50,
294
+ CUBLASLT_CLUSTER_SHAPE_1x15x1 = 51,
295
+ CUBLASLT_CLUSTER_SHAPE_END
296
+ } cublasLtClusterShape_t;
297
+
298
+ /** Inner size of the kernel
299
+ *
300
+ * Represents various aspects of internal kernel design, that don't impact CUDA grid size but may have other more subtle
301
+ * effects.
302
+ *
303
+ */
304
+ typedef enum {
305
+ CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED = 0,
306
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA884 = 1,
307
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA1684 = 2,
308
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA1688 = 3,
309
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA16816 = 4,
310
+ CUBLASLT_MATMUL_INNER_SHAPE_END
311
+ } cublasLtMatmulInnerShape_t;
312
+
313
+ /** Pointer mode to use for alpha/beta */
314
+ typedef enum {
315
+ /** matches CUBLAS_POINTER_MODE_HOST, pointer targets a single value host memory */
316
+ CUBLASLT_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST,
317
+ /** matches CUBLAS_POINTER_MODE_DEVICE, pointer targets a single value device memory */
318
+ CUBLASLT_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE,
319
+ /** pointer targets an array in device memory */
320
+ CUBLASLT_POINTER_MODE_DEVICE_VECTOR = 2,
321
+ /** alpha pointer targets an array in device memory, beta is zero. Note:
322
+ CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is not supported, must be 0. */
323
+ CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO = 3,
324
+ /** alpha pointer targets an array in device memory, beta is a single value in host memory. */
325
+ CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST = 4,
326
+ } cublasLtPointerMode_t;
327
+
328
+ /** Mask to define pointer mode capability */
329
+ typedef enum {
330
+ /** see CUBLASLT_POINTER_MODE_HOST */
331
+ CUBLASLT_POINTER_MODE_MASK_HOST = 1,
332
+ /** see CUBLASLT_POINTER_MODE_DEVICE */
333
+ CUBLASLT_POINTER_MODE_MASK_DEVICE = 2,
334
+ /** see CUBLASLT_POINTER_MODE_DEVICE_VECTOR */
335
+ CUBLASLT_POINTER_MODE_MASK_DEVICE_VECTOR = 4,
336
+ /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO */
337
+ CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_ZERO = 8,
338
+ /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST */
339
+ CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_HOST = 16,
340
+ } cublasLtPointerModeMask_t;
341
+
342
+ /** Implementation details that may affect numerical behavior of algorithms. */
343
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA (0x01ull << 0)
344
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA (0x02ull << 0)
345
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_IMMA (0x04ull << 0)
346
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_DMMA (0x08ull << 0)
347
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK (0xfeull << 0)
348
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK (0xffull << 0)
349
+
350
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_16F (0x01ull << 8)
351
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32F (0x02ull << 8)
352
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_64F (0x04ull << 8)
353
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32I (0x08ull << 8)
354
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_TYPE_MASK (0xffull << 8)
355
+
356
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16F (0x01ull << 16)
357
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16BF (0x02ull << 16)
358
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_TF32 (0x04ull << 16)
359
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_32F (0x08ull << 16)
360
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_64F (0x10ull << 16)
361
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8I (0x20ull << 16)
362
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3 (0x40ull << 16)
363
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2 (0x80ull << 16)
364
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK (0xffull << 16)
365
+
366
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN (0x01ull << 32)
367
+ typedef uint64_t cublasLtNumericalImplFlags_t;
368
+
369
+ /** Execute matrix multiplication (D = alpha * op(A) * op(B) + beta * C).
370
+ *
371
+ * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
372
+ * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
373
+ * when workspaceSizeInBytes is less than workspace required by configured
374
+ * algo
375
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
376
+ * operation
377
+ * \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
378
+ * \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
379
+ * \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
380
+ */
381
+ cublasStatus_t CUBLASWINAPI cublasLtMatmul(cublasLtHandle_t lightHandle,
382
+ cublasLtMatmulDesc_t computeDesc,
383
+ const void* alpha, /* host or device pointer */
384
+ const void* A,
385
+ cublasLtMatrixLayout_t Adesc,
386
+ const void* B,
387
+ cublasLtMatrixLayout_t Bdesc,
388
+ const void* beta, /* host or device pointer */
389
+ const void* C,
390
+ cublasLtMatrixLayout_t Cdesc,
391
+ void* D,
392
+ cublasLtMatrixLayout_t Ddesc,
393
+ const cublasLtMatmulAlgo_t* algo,
394
+ void* workspace,
395
+ size_t workspaceSizeInBytes,
396
+ cudaStream_t stream);
397
+
398
+ /** Matrix layout conversion helper (C = alpha * op(A) + beta * op(B))
399
+ *
400
+ * Can be used to change memory order of data or to scale and shift the values.
401
+ *
402
+ * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
403
+ * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
404
+ * when A is not NULL, but Adesc is NULL
405
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
406
+ * operation
407
+ * \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
408
+ * \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
409
+ * \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
410
+ */
411
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(cublasLtHandle_t lightHandle,
412
+ cublasLtMatrixTransformDesc_t transformDesc,
413
+ const void* alpha, /* host or device pointer */
414
+ const void* A,
415
+ cublasLtMatrixLayout_t Adesc,
416
+ const void* beta, /* host or device pointer */
417
+ const void* B,
418
+ cublasLtMatrixLayout_t Bdesc,
419
+ void* C,
420
+ cublasLtMatrixLayout_t Cdesc,
421
+ cudaStream_t stream);
422
+
423
+ /* ---------------------------------------------------------------------------------------*/
424
+ /* Helper functions for cublasLtMatrixLayout_t */
425
+ /* ---------------------------------------------------------------------------------------*/
426
+
427
+ /** Enum for data ordering */
428
+ typedef enum {
429
+ /** Column-major
430
+ *
431
+ * Leading dimension is the stride (in elements) to the beginning of next column in memory.
432
+ */
433
+ CUBLASLT_ORDER_COL = 0,
434
+ /** Row major
435
+ *
436
+ * Leading dimension is the stride (in elements) to the beginning of next row in memory.
437
+ */
438
+ CUBLASLT_ORDER_ROW = 1,
439
+ /** Column-major ordered tiles of 32 columns.
440
+ *
441
+ * Leading dimension is the stride (in elements) to the beginning of next group of 32-columns. E.g. if matrix has 33
442
+ * columns and 2 rows, ld must be at least (32) * 2 = 64.
443
+ */
444
+ CUBLASLT_ORDER_COL32 = 2,
445
+ /** Column-major ordered tiles of composite tiles with total 32 columns and 8 rows, tile composed of interleaved
446
+ * inner tiles of 4 columns within 4 even or odd rows in an alternating pattern.
447
+ *
448
+ * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile for the next
449
+ * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32 * 8) * 1 = 256.
450
+ */
451
+ CUBLASLT_ORDER_COL4_4R2_8C = 3,
452
+ /** Column-major ordered tiles of composite tiles with total 32 columns ands 32 rows.
453
+ * Element offset within the tile is calculated as (((row%8)/2*4+row/8)*2+row%2)*32+col.
454
+ *
455
+ * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 32 row tile for the next
456
+ * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32*32)*1 = 1024.
457
+ */
458
+ CUBLASLT_ORDER_COL32_2R_4R4 = 4,
459
+
460
+ } cublasLtOrder_t;
461
+
462
/** Attributes of memory layout */
typedef enum {
  /** Data type, see cudaDataType.
   *
   * uint32_t
   */
  CUBLASLT_MATRIX_LAYOUT_TYPE = 0,

  /** Memory order of the data, see cublasLtOrder_t.
   *
   * int32_t, default: CUBLASLT_ORDER_COL
   */
  CUBLASLT_MATRIX_LAYOUT_ORDER = 1,

  /** Number of rows.
   *
   * Usually only values that can be expressed as int32_t are supported.
   *
   * uint64_t
   */
  CUBLASLT_MATRIX_LAYOUT_ROWS = 2,

  /** Number of columns.
   *
   * Usually only values that can be expressed as int32_t are supported.
   *
   * uint64_t
   */
  CUBLASLT_MATRIX_LAYOUT_COLS = 3,

  /** Matrix leading dimension.
   *
   * For CUBLASLT_ORDER_COL this is stride (in elements) of matrix column, for more details and documentation for
   * other memory orders see documentation for cublasLtOrder_t values.
   *
   * Currently only non-negative values are supported, must be large enough so that matrix memory locations are not
   * overlapping (e.g. greater or equal to CUBLASLT_MATRIX_LAYOUT_ROWS in case of CUBLASLT_ORDER_COL).
   *
   * int64_t;
   */
  CUBLASLT_MATRIX_LAYOUT_LD = 4,

  /** Number of matmul operations to perform in the batch.
   *
   * See also CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT
   *
   * int32_t, default: 1
   */
  CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5,

  /** Stride (in elements) to the next matrix for strided batch operation.
   *
   * When matrix type is planar-complex (CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET != 0), batch stride
   * is interpreted by cublasLtMatmul() in number of real valued sub-elements. E.g. for data of type CUDA_C_16F,
   * offset of 1024B is encoded as a stride of value 512 (since each element of the real and imaginary matrices
   * is a 2B (16bit) floating point type).
   *
   * NOTE: A bug in cublasLtMatrixTransform() causes it to interpret the batch stride for a planar-complex matrix
   * as if it was specified in number of complex elements. Therefore an offset of 1024B must be encoded as stride
   * value 256 when calling cublasLtMatrixTransform() (each complex element is 4B with real and imaginary values 2B
   * each). This behavior is expected to be corrected in the next major cuBLAS version.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6,

  /** Stride (in bytes) to the imaginary plane for planar complex layout.
   *
   * int64_t, default: 0 - 0 means that layout is regular (real and imaginary parts of complex numbers are interleaved
   * in memory in each element)
   */
  CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7,
} cublasLtMatrixLayoutAttribute_t;
535
+
536
/** Internal. Do not use directly.
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal(  //
    cublasLtMatrixLayout_t matLayout,
    size_t size,
    cudaDataType type,
    uint64_t rows,
    uint64_t cols,
    int64_t ld);
545
+
546
/** Initialize matrix layout descriptor in pre-allocated space.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
 * \retval CUBLAS_STATUS_SUCCESS if descriptor was created successfully
 */
static inline cublasStatus_t cublasLtMatrixLayoutInit(
    cublasLtMatrixLayout_t matLayout, cudaDataType type, uint64_t rows, uint64_t cols, int64_t ld) {
  return cublasLtMatrixLayoutInit_internal(matLayout, sizeof(*matLayout), type, rows, cols, ld);
}
555
+
556
/** Create new matrix layout descriptor.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
 * \retval CUBLAS_STATUS_SUCCESS if descriptor was created successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate(  //
    cublasLtMatrixLayout_t* matLayout,
    cudaDataType type,
    uint64_t rows,
    uint64_t cols,
    int64_t ld);

/** Destroy matrix layout descriptor.
 *
 * \retval CUBLAS_STATUS_SUCCESS if operation was successful
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout);

/** Set matrix layout descriptor attribute.
 *
 * \param[in]  matLayout    The descriptor
 * \param[in]  attr         The attribute
 * \param[in]  buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 *
 * \retval     CUBLAS_STATUS_INVALID_VALUE  if buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                          selected attribute
 * \retval     CUBLAS_STATUS_SUCCESS        if attribute was set successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute(  //
    cublasLtMatrixLayout_t matLayout,
    cublasLtMatrixLayoutAttribute_t attr,
    const void* buf,
    size_t sizeInBytes);

/** Get matrix layout descriptor attribute.
 *
 * \param[in]  matLayout    The descriptor
 * \param[in]  attr         The attribute
 * \param[out] buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 * \param[out] sizeWritten  only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
 *                          bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
 *
 * \retval     CUBLAS_STATUS_INVALID_VALUE  if sizeInBytes is 0 and sizeWritten is NULL, or if  sizeInBytes is non-zero
 *                                          and buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                          selected attribute
 * \retval     CUBLAS_STATUS_SUCCESS        if attribute's value was successfully written to user memory
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute(  //
    cublasLtMatrixLayout_t matLayout,
    cublasLtMatrixLayoutAttribute_t attr,
    void* buf,
    size_t sizeInBytes,
    size_t* sizeWritten);
611
+
612
/* ---------------------------------------------------------------------------------------*/
/* Helper functions for cublasLtMatmulDesc_t */
/* ---------------------------------------------------------------------------------------*/
615
+
616
/** Matmul descriptor attributes to define details of the operation. */
typedef enum {
  /** Compute type, see cudaDataType. Defines data type used for multiply and accumulate operations and the
   * accumulator during matrix multiplication.
   *
   * int32_t
   */
  CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0,

  /** Scale type, see cudaDataType. Defines data type of alpha and beta. Accumulator and value from matrix C are
   * typically converted to scale type before final scaling. Value is then converted from scale type to type of matrix
   * D before being stored in memory.
   *
   * int32_t, default: same as CUBLASLT_MATMUL_DESC_COMPUTE_TYPE
   */
  CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1,

  /** Pointer mode of alpha and beta, see cublasLtPointerMode_t. When CUBLASLT_POINTER_MODE_DEVICE_VECTOR is in use,
   * alpha/beta vector lengths must match number of output matrix rows.
   *
   * int32_t, default: CUBLASLT_POINTER_MODE_HOST
   */
  CUBLASLT_MATMUL_DESC_POINTER_MODE = 2,

  /** Transform of matrix A, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATMUL_DESC_TRANSA = 3,

  /** Transform of matrix B, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATMUL_DESC_TRANSB = 4,

  /** Transform of matrix C, see cublasOperation_t.
   *
   * Currently only CUBLAS_OP_N is supported.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATMUL_DESC_TRANSC = 5,

  /** Matrix fill mode, see cublasFillMode_t.
   *
   * int32_t, default: CUBLAS_FILL_MODE_FULL
   */
  CUBLASLT_MATMUL_DESC_FILL_MODE = 6,

  /** Epilogue function, see cublasLtEpilogue_t.
   *
   * uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE = 7,

  /** Bias or bias gradient vector pointer in the device memory.
   *
   * Bias case. See CUBLASLT_EPILOGUE_BIAS.
   * For bias data type see CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE.
   *
   * Bias vector length must match matrix D rows count.
   *
   * Bias gradient case. See CUBLASLT_EPILOGUE_DRELU_BGRAD and CUBLASLT_EPILOGUE_DGELU_BGRAD.
   * Bias gradient vector elements are the same type as the output elements
   * (Ctype) with the exception of IMMA kernels (see above).
   *
   * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
   * depend on its value to determine expected pointer alignment.
   *
   * Bias case: const void *, default: NULL
   * Bias gradient case: void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8,

  /** Batch stride for bias or bias gradient vector.
   *
   * Used together with CUBLASLT_MATMUL_DESC_BIAS_POINTER when matrix D's CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10,

  /** Pointer for epilogue auxiliary buffer.
   *
   * - Output vector for ReLu bit-mask in forward pass when CUBLASLT_EPILOGUE_RELU_AUX
   *   or CUBLASLT_EPILOGUE_RELU_AUX_BIAS epilogue is used.
   * - Input vector for ReLu bit-mask in backward pass when
   *   CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is used.
   *
   * - Output of GELU input matrix in forward pass when
   *   CUBLASLT_EPILOGUE_GELU_AUX_BIAS epilogue is used.
   * - Input of GELU input matrix for backward pass when
   *   CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue is used.
   *
   * For aux data type see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE.
   *
   * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
   * depend on its value to determine expected pointer alignment.
   *
   * Requires setting CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD attribute.
   *
   * Forward pass: void *, default: NULL
   * Backward pass: const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11,

  /** Leading dimension for epilogue auxiliary buffer.
   *
   * - ReLu bit-mask matrix leading dimension in elements (i.e. bits)
   *   when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
   *   used. Must be divisible by 128 and be no less than the number of rows in the output matrix.
   *
   * - GELU input matrix leading dimension in elements
   *   when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
   *   Must be divisible by 8 and be no less than the number of rows in the output matrix.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12,

  /** Batch stride for epilogue auxiliary buffer.
   *
   * - ReLu bit-mask matrix batch stride in elements (i.e. bits)
   *   when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
   *   used. Must be divisible by 128.
   *
   * - GELU input matrix batch stride in elements
   *   when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
   *   Must be divisible by 8.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13,

  /** Batch stride for alpha vector.
   *
   * Used together with CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST when matrix D's
   * CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1. If CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO is set then
   * CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE must be set to 0 as this mode doesn't support batched alpha vector.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14,

  /** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs
   * when user expects a concurrent stream to be using some of the device resources.
   *
   * int32_t, default: 0 - use the number reported by the device.
   */
  CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15,

  /** Device pointer to the scale factor value that converts data in matrix A to the compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17,

  /** Device pointer to the scale factor value to convert data in matrix B to compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18,

  /** Device pointer to the scale factor value to convert data in matrix C to compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19,

  /** Device pointer to the scale factor value to convert data in matrix D to compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20,

  /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
   * output matrix.
   *
   * The computed value has the same type as the compute type.
   *
   * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
   * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21,

  /** Type of the data to be stored to the memory pointed to by CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   *
   * If unset, the data type defaults to the type of elements of the output matrix with some exceptions, see details
   * below.
   *
   * ReLu uses a bit-mask.
   *
   * GELU input matrix elements type is the same as the type of elements of
   * the output matrix with some exceptions, see details below.
   *
   * For fp8 kernels with output type CUDA_R_8F_E4M3 the aux data type can be CUDA_R_8F_E4M3 or CUDA_R_16F with some
   * restrictions. See https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t for more details.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * int32_t based on cudaDataType, default: -1
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22,

  /** Device pointer to the scaling factor value to convert results from compute type data range to storage
   * data range in the auxiliary matrix that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1. If set for an unsupported matrix data,
   * scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23,

  /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
   * buffer that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   *
   * The computed value has the same type as the compute type.
   *
   * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
   * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24,

  /** Flag for managing fp8 fast accumulation mode.
   * When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results
   * will not periodically be promoted to a higher precision.
   *
   * int8_t, default: 0 - fast accumulation mode is disabled.
   */
  CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25,

  /** Type of bias or bias gradient vector in the device memory.
   *
   * Bias case: see CUBLASLT_EPILOGUE_BIAS.
   *
   * Bias vector elements are the same type as the elements of output matrix (Dtype) with the following exceptions:
   * - IMMA kernels with computeType=CUDA_R_32I and Ctype=CUDA_R_8I where the bias vector elements
   *   are the same type as alpha, beta (CUBLASLT_MATMUL_DESC_SCALE_TYPE=CUDA_R_32F)
   * - fp8 kernels with an output type of CUDA_R_32F, CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2, See
   *   https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul for details.
   *
   * int32_t based on cudaDataType, default: -1
   */
  CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26,

  /** EXPERIMENTAL: Number of atomic synchronization chunks in the row dimension of the output matrix D.
   *
   * int32_t, default 0 (atomic synchronization disabled)
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS = 27,

  /** EXPERIMENTAL: Number of atomic synchronization chunks in the column dimension of the output matrix D.
   *
   * int32_t, default 0 (atomic synchronization disabled)
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS = 28,

  /** EXPERIMENTAL: Pointer to a device array of input atomic counters consumed by a matmul.
   *
   * int32_t *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER = 29,

  /** EXPERIMENTAL: Pointer to a device array of output atomic counters produced by a matmul.
   *
   * int32_t *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER = 30,
} cublasLtMatmulDescAttributes_t;
922
+
923
/** Internal. Do not use directly.
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal(  //
    cublasLtMatmulDesc_t matmulDesc,
    size_t size,
    cublasComputeType_t computeType,
    cudaDataType_t scaleType);
930
+
931
/** Initialize matmul operation descriptor in pre-allocated space.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
 * \retval CUBLAS_STATUS_SUCCESS if descriptor was initialized successfully
 */
static inline cublasStatus_t cublasLtMatmulDescInit(  //
    cublasLtMatmulDesc_t matmulDesc,
    cublasComputeType_t computeType,
    cudaDataType_t scaleType) {
  return cublasLtMatmulDescInit_internal(matmulDesc, sizeof(*matmulDesc), computeType, scaleType);
}
942
+
943
/** Create new matmul operation descriptor.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
 * \retval CUBLAS_STATUS_SUCCESS if descriptor was created successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(cublasLtMatmulDesc_t* matmulDesc,
                                                     cublasComputeType_t computeType,
                                                     cudaDataType_t scaleType);

/** Destroy matmul operation descriptor.
 *
 * \retval CUBLAS_STATUS_SUCCESS if operation was successful
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc);

/** Set matmul operation descriptor attribute.
 *
 * \param[in]  matmulDesc   The descriptor
 * \param[in]  attr         The attribute
 * \param[in]  buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 *
 * \retval     CUBLAS_STATUS_INVALID_VALUE  if buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                          selected attribute
 * \retval     CUBLAS_STATUS_SUCCESS        if attribute was set successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute(  //
    cublasLtMatmulDesc_t matmulDesc,
    cublasLtMatmulDescAttributes_t attr,
    const void* buf,
    size_t sizeInBytes);

/** Get matmul operation descriptor attribute.
 *
 * \param[in]  matmulDesc   The descriptor
 * \param[in]  attr         The attribute
 * \param[out] buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 * \param[out] sizeWritten  only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
 *                          bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
 *
 * \retval     CUBLAS_STATUS_INVALID_VALUE  if sizeInBytes is 0 and sizeWritten is NULL, or if  sizeInBytes is non-zero
 *                                          and buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                          selected attribute
 * \retval     CUBLAS_STATUS_SUCCESS        if attribute's value was successfully written to user memory
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute(  //
    cublasLtMatmulDesc_t matmulDesc,
    cublasLtMatmulDescAttributes_t attr,
    void* buf,
    size_t sizeInBytes,
    size_t* sizeWritten);
995
+
996
/* ---------------------------------------------------------------------------------------*/
/* Helper functions for cublasLtMatrixTransformDesc_t */
/* ---------------------------------------------------------------------------------------*/
999
+
1000
/** Matrix transform descriptor attributes to define details of the operation.
 */
typedef enum {
  /** Scale type, see cudaDataType. Inputs are converted to scale type for scaling and summation and results are then
   * converted to output type to store in memory.
   *
   * int32_t
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE,

  /** Pointer mode of alpha and beta, see cublasLtPointerMode_t.
   *
   * int32_t, default: CUBLASLT_POINTER_MODE_HOST
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE,

  /** Transform of matrix A, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA,

  /** Transform of matrix B, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB,
} cublasLtMatrixTransformDescAttributes_t;
1028
+
1029
/** Internal. Do not use directly.
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc,
                                                                     size_t size,
                                                                     cudaDataType scaleType);
1034
+
1035
/** Initialize matrix transform operation descriptor in pre-allocated space.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
 * \retval CUBLAS_STATUS_SUCCESS if descriptor was created successfully
 */
static inline cublasStatus_t cublasLtMatrixTransformDescInit(cublasLtMatrixTransformDesc_t transformDesc,
                                                             cudaDataType scaleType) {
  return cublasLtMatrixTransformDescInit_internal(transformDesc, sizeof(*transformDesc), scaleType);
}
1044
+
1045
/** Create new matrix transform operation descriptor.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
 * \retval CUBLAS_STATUS_SUCCESS if descriptor was created successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t* transformDesc,
                                                              cudaDataType scaleType);

/** Destroy matrix transform operation descriptor.
 *
 * \retval CUBLAS_STATUS_SUCCESS if operation was successful
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc);

/** Set matrix transform operation descriptor attribute.
 *
 * \param[in]  transformDesc  The descriptor
 * \param[in]  attr           The attribute
 * \param[in]  buf            memory address containing the new value
 * \param[in]  sizeInBytes    size of buf buffer for verification (in bytes)
 *
 * \retval     CUBLAS_STATUS_INVALID_VALUE  if buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                          selected attribute
 * \retval     CUBLAS_STATUS_SUCCESS        if attribute was set successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute(  //
    cublasLtMatrixTransformDesc_t transformDesc,
    cublasLtMatrixTransformDescAttributes_t attr,
    const void* buf,
    size_t sizeInBytes);

/** Get matrix transform operation descriptor attribute.
 *
 * \param[in]  transformDesc  The descriptor
 * \param[in]  attr           The attribute
 * \param[out] buf            memory address containing the new value
 * \param[in]  sizeInBytes    size of buf buffer for verification (in bytes)
 * \param[out] sizeWritten    only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number
 *                            of bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
 *
 * \retval     CUBLAS_STATUS_INVALID_VALUE  if sizeInBytes is 0 and sizeWritten is NULL, or if  sizeInBytes is non-zero
 *                                          and buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                          selected attribute
 * \retval     CUBLAS_STATUS_SUCCESS        if attribute's value was successfully written to user memory
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute(  //
    cublasLtMatrixTransformDesc_t transformDesc,
    cublasLtMatrixTransformDescAttributes_t attr,
    void* buf,
    size_t sizeInBytes,
    size_t* sizeWritten);
1096
+
1097
/** Reduction scheme for portions of the dot-product calculated in parallel (a. k. a. "split - K").
 */
typedef enum {
  /** No reduction scheme, dot-product shall be performed in one sequence.
   */
  CUBLASLT_REDUCTION_SCHEME_NONE = 0,

  /** Reduction is performed "in place" - using the output buffer (and output data type) and counters (in workspace) to
   * guarantee the sequentiality.
   */
  CUBLASLT_REDUCTION_SCHEME_INPLACE = 1,

  /** Intermediate results are stored in compute type in the workspace and reduced in a separate step.
   */
  CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE = 2,

  /** Intermediate results are stored in output type in the workspace and reduced in a separate step.
   */
  CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE = 4,

  /* Bit-mask covering all valid reduction-scheme bits above. */
  CUBLASLT_REDUCTION_SCHEME_MASK = 0x7,
} cublasLtReductionScheme_t;
1119
+
1120
/** Postprocessing options for the epilogue
 */
typedef enum {
  /** No special postprocessing, just scale and quantize results if necessary.
   */
  CUBLASLT_EPILOGUE_DEFAULT = 1,

  /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
   */
  CUBLASLT_EPILOGUE_RELU = 2,

  /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
   *
   * This epilogue mode produces an extra output, a ReLu bit-mask matrix,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_RELU_AUX = (CUBLASLT_EPILOGUE_RELU | 128),

  /** Bias, apply (broadcasted) Bias from bias vector. Bias vector length must match matrix D rows, it must be packed
   * (stride between vector elements is 1). Bias vector is broadcasted to all columns and added before applying final
   * postprocessing.
   */
  CUBLASLT_EPILOGUE_BIAS = 4,

  /** ReLu and Bias, apply Bias and then ReLu transform
   */
  CUBLASLT_EPILOGUE_RELU_BIAS = (CUBLASLT_EPILOGUE_RELU | CUBLASLT_EPILOGUE_BIAS),

  /** ReLu and Bias, apply Bias and then ReLu transform
   *
   * This epilogue mode produces an extra output, a ReLu bit-mask matrix,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_RELU_AUX_BIAS = (CUBLASLT_EPILOGUE_RELU_AUX | CUBLASLT_EPILOGUE_BIAS),

  /* ReLu gradient. Apply ReLu gradient to matmul output. Store ReLu gradient in the output matrix.
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DRELU = 8 | 128,

  /* ReLu and Bias gradients. Apply independently ReLu and Bias gradient to
   * matmul output. Store ReLu gradient in the output matrix, and Bias gradient
   * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DRELU_BGRAD = CUBLASLT_EPILOGUE_DRELU | 16,

  /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
   */
  CUBLASLT_EPILOGUE_GELU = 32,

  /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
   *
   * This epilogue mode outputs GELU input as a separate matrix (useful for training).
   * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_GELU_AUX = (CUBLASLT_EPILOGUE_GELU | 128),

  /** GELU and Bias, apply Bias and then GELU transform
   */
  CUBLASLT_EPILOGUE_GELU_BIAS = (CUBLASLT_EPILOGUE_GELU | CUBLASLT_EPILOGUE_BIAS),

  /** GELU and Bias, apply Bias and then GELU transform
   *
   * This epilogue mode outputs GELU input as a separate matrix (useful for training).
   * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_GELU_AUX_BIAS = (CUBLASLT_EPILOGUE_GELU_AUX | CUBLASLT_EPILOGUE_BIAS),

  /* GELU gradient. Apply GELU gradient to matmul output. Store GELU gradient in the output matrix.
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DGELU = 64 | 128,

  /* GELU and Bias gradients. Apply independently GELU and Bias gradient to
   * matmul output. Store GELU gradient in the output matrix, and Bias gradient
   * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DGELU_BGRAD = CUBLASLT_EPILOGUE_DGELU | 16,

  /** Bias gradient based on the input matrix A.
   *
   * The bias size corresponds to the number of rows of the matrix D.
   * The reduction happens over the GEMM's "k" dimension.
   *
   * Stores Bias gradient in the auxiliary output
   * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   */
  CUBLASLT_EPILOGUE_BGRADA = 256,

  /** Bias gradient based on the input matrix B.
   *
   * The bias size corresponds to the number of columns of the matrix D.
   * The reduction happens over the GEMM's "k" dimension.
   *
   * Stores Bias gradient in the auxiliary output
   * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   */
  CUBLASLT_EPILOGUE_BGRADB = 512,
} cublasLtEpilogue_t;
1229
+
1230
+ /** Matmul heuristic search mode
1231
+ */
1232
+ typedef enum {
1233
+ /** ask heuristics for best algo for given usecase
1234
+ */
1235
+ CUBLASLT_SEARCH_BEST_FIT = 0,
1236
+ /** only try to find best config for preconfigured algo id
1237
+ */
1238
+ CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID = 1,
1239
+ /** reserved for future use
1240
+ */
1241
+ CUBLASLT_SEARCH_RESERVED_02 = 2,
1242
+ /** reserved for future use
1243
+ */
1244
+ CUBLASLT_SEARCH_RESERVED_03 = 3,
1245
+ /** reserved for future use
1246
+ */
1247
+ CUBLASLT_SEARCH_RESERVED_04 = 4,
1248
+ /** reserved for future use
1249
+ */
1250
+ CUBLASLT_SEARCH_RESERVED_05 = 5,
1251
+ } cublasLtMatmulSearch_t;
1252
+
1253
+ /** Algo search preference to fine tune the heuristic function. */
1254
+ typedef enum {
1255
+ /** Search mode, see cublasLtMatmulSearch_t.
1256
+ *
1257
+ * uint32_t, default: CUBLASLT_SEARCH_BEST_FIT
1258
+ */
1259
+ CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0,
1260
+
1261
+ /** Maximum allowed workspace size in bytes.
1262
+ *
1263
+ * uint64_t, default: 0 - no workspace allowed
1264
+ */
1265
+ CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1,
1266
+
1267
+ /** Reduction scheme mask, see cublasLtReductionScheme_t. Filters heuristic result to only include algo configs that
1268
+ * use one of the required modes.
1269
+ *
1270
+ * E.g. mask value of 0x03 will allow only INPLACE and COMPUTE_TYPE reduction schemes.
1271
+ *
1272
+ * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_MASK (allows all reduction schemes)
1273
+ */
1274
+ CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3,
1275
+
1276
+ /** Minimum buffer alignment for matrix A (in bytes).
1277
+ *
1278
+ * Selecting a smaller value will exclude algorithms that can not work with matrix A that is not as strictly aligned
1279
+ * as they need.
1280
+ *
1281
+ * uint32_t, default: 256
1282
+ */
1283
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5,
1284
+
1285
+ /** Minimum buffer alignment for matrix B (in bytes).
1286
+ *
1287
+ * Selecting a smaller value will exclude algorithms that can not work with matrix B that is not as strictly aligned
1288
+ * as they need.
1289
+ *
1290
+ * uint32_t, default: 256
1291
+ */
1292
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6,
1293
+
1294
+ /** Minimum buffer alignment for matrix C (in bytes).
1295
+ *
1296
+ * Selecting a smaller value will exclude algorithms that can not work with matrix C that is not as strictly aligned
1297
+ * as they need.
1298
+ *
1299
+ * uint32_t, default: 256
1300
+ */
1301
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7,
1302
+
1303
+ /** Minimum buffer alignment for matrix D (in bytes).
1304
+ *
1305
+ * Selecting a smaller value will exclude algorithms that can not work with matrix D that is not as strictly aligned
1306
+ * as they need.
1307
+ *
1308
+ * uint32_t, default: 256
1309
+ */
1310
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8,
1311
+
1312
+ /** Maximum wave count.
1313
+ *
1314
+ * See cublasLtMatmulHeuristicResult_t::wavesCount.
1315
+ *
1316
+ * Selecting a non-zero value will exclude algorithms that report device utilization higher than specified.
1317
+ *
1318
+ * float, default: 0.0f
1319
+ */
1320
+ CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9,
1321
+
1322
+ /** Numerical implementation details mask, see cublasLtNumericalImplFlags_t. Filters heuristic result to only include
1323
+ * algorithms that use the allowed implementations.
1324
+ *
1325
+ * uint64_t, default: uint64_t(-1) (allow everything)
1326
+ */
1327
+ CUBLASLT_MATMUL_PREF_IMPL_MASK = 12,
1328
+ } cublasLtMatmulPreferenceAttributes_t;
1329
+
1330
+ /** Internal. Do not use directly.
1331
+ */
1332
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size);
1333
+
1334
+ /** Initialize matmul heuristic search preference descriptor in pre-allocated space.
1335
+ *
1336
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
1337
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
1338
+ */
1339
+ static inline cublasStatus_t cublasLtMatmulPreferenceInit(cublasLtMatmulPreference_t pref) {
1340
+ return cublasLtMatmulPreferenceInit_internal(pref, sizeof(*pref));
1341
+ }
1342
+
1343
+ /** Create new matmul heuristic search preference descriptor.
1344
+ *
1345
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
1346
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
1347
+ */
1348
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t* pref);
1349
+
1350
+ /** Destroy matmul heuristic search preference descriptor.
1351
+ *
1352
+ * \retval CUBLAS_STATUS_SUCCESS if operation was successful
1353
+ */
1354
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref);
1355
+
1356
+ /** Set matmul heuristic search preference descriptor attribute.
1357
+ *
1358
+ * \param[in] pref The descriptor
1359
+ * \param[in] attr The attribute
1360
+ * \param[in] buf memory address containing the new value
1361
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1362
+ *
1363
+ * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
1364
+ * selected attribute
1365
+ * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
1366
+ */
1367
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( //
1368
+ cublasLtMatmulPreference_t pref,
1369
+ cublasLtMatmulPreferenceAttributes_t attr,
1370
+ const void* buf,
1371
+ size_t sizeInBytes);
1372
+
1373
+ /** Get matmul heuristic search preference descriptor attribute.
1374
+ *
1375
+ * \param[in] pref The descriptor
1376
+ * \param[in] attr The attribute
1377
+ * \param[out] buf memory address containing the new value
1378
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1379
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
1380
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
1381
+ *
1382
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
1383
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
1384
+ * selected attribute
1385
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
1386
+ */
1387
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( //
1388
+ cublasLtMatmulPreference_t pref,
1389
+ cublasLtMatmulPreferenceAttributes_t attr,
1390
+ void* buf,
1391
+ size_t sizeInBytes,
1392
+ size_t* sizeWritten);
1393
+
1394
+ /** Results structure used by cublasLtMatmulGetAlgo.
1395
+ *
1396
+ * Holds returned configured algo descriptor and its runtime properties.
1397
+ */
1398
+ typedef struct {
1399
+ /** Matmul algorithm descriptor.
1400
+ *
1401
+ * Must be initialized with cublasLtMatmulAlgoInit() if preferences' CUBLASLT_MATMUL_PERF_SEARCH_MODE is set to
1402
+ * CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID
1403
+ */
1404
+ cublasLtMatmulAlgo_t algo;
1405
+
1406
+ /** Actual size of workspace memory required.
1407
+ */
1408
+ size_t workspaceSize;
1409
+
1410
+ /** Result status, other fields are only valid if after call to cublasLtMatmulAlgoGetHeuristic() this member is set to
1411
+ * CUBLAS_STATUS_SUCCESS.
1412
+ */
1413
+ cublasStatus_t state;
1414
+
1415
+ /** Waves count - a device utilization metric.
1416
+ *
1417
+ * wavesCount value of 1.0f suggests that when kernel is launched it will fully occupy the GPU.
1418
+ */
1419
+ float wavesCount;
1420
+
1421
+ int reserved[4];
1422
+ } cublasLtMatmulHeuristicResult_t;
1423
+
1424
+ /** Query cublasLt heuristic for algorithm appropriate for given use case.
1425
+ *
1426
+ * \param[in] lightHandle Pointer to the allocated cuBLASLt handle for the cuBLASLt
1427
+ * context. See cublasLtHandle_t.
1428
+ * \param[in] operationDesc Handle to the matrix multiplication descriptor.
1429
+ * \param[in] Adesc Handle to the layout descriptors for matrix A.
1430
+ * \param[in] Bdesc Handle to the layout descriptors for matrix B.
1431
+ * \param[in] Cdesc Handle to the layout descriptors for matrix C.
1432
+ * \param[in] Ddesc Handle to the layout descriptors for matrix D.
1433
+ * \param[in] preference Pointer to the structure holding the heuristic search
1434
+ * preferences descriptor. See cublasLtMatrixLayout_t.
1435
+ * \param[in] requestedAlgoCount Size of heuristicResultsArray (in elements) and requested
1436
+ * maximum number of algorithms to return.
1437
+ * \param[in, out] heuristicResultsArray Output algorithms and associated runtime characteristics,
1438
+ * ordered in increasing estimated compute time.
1439
+ * \param[out] returnAlgoCount The number of heuristicResultsArray elements written.
1440
+ *
1441
+ * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
1442
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if no heuristic function available for current configuration
1443
+ * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect
1444
+ * heuristicResultsArray[0 to (returnAlgoCount - 1)].state
1445
+ * for detail status of results
1446
+ */
1447
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(cublasLtHandle_t lightHandle,
1448
+ cublasLtMatmulDesc_t operationDesc,
1449
+ cublasLtMatrixLayout_t Adesc,
1450
+ cublasLtMatrixLayout_t Bdesc,
1451
+ cublasLtMatrixLayout_t Cdesc,
1452
+ cublasLtMatrixLayout_t Ddesc,
1453
+ cublasLtMatmulPreference_t preference,
1454
+ int requestedAlgoCount,
1455
+ cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
1456
+ int* returnAlgoCount);
1457
+
1458
+ /* ---------------------------------------------------------------------------------------*/
1459
+ /* Lower level API to be able to implement own Heuristic and Find routines */
1460
+ /* ---------------------------------------------------------------------------------------*/
1461
+
1462
+ /** Routine to get all algo IDs that can potentially run
1463
+ *
1464
+ * \param[in] int requestedAlgoCount requested number of algos (must be less or equal to size of algoIdsA
1465
+ * (in elements)) \param[out] algoIdsA array to write algoIds to \param[out] returnAlgoCount number of algoIds
1466
+ * actually written
1467
+ *
1468
+ * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
1469
+ * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect returnAlgoCount to get actual number of IDs
1470
+ * available
1471
+ */
1472
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(cublasLtHandle_t lightHandle,
1473
+ cublasComputeType_t computeType,
1474
+ cudaDataType_t scaleType,
1475
+ cudaDataType_t Atype,
1476
+ cudaDataType_t Btype,
1477
+ cudaDataType_t Ctype,
1478
+ cudaDataType_t Dtype,
1479
+ int requestedAlgoCount,
1480
+ int algoIdsArray[],
1481
+ int* returnAlgoCount);
1482
+
1483
+ /** Initialize algo structure
1484
+ *
1485
+ * \retval CUBLAS_STATUS_INVALID_VALUE if algo is NULL or algoId is outside of recognized range
1486
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if algoId is not supported for given combination of data types
1487
+ * \retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized
1488
+ */
1489
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(cublasLtHandle_t lightHandle,
1490
+ cublasComputeType_t computeType,
1491
+ cudaDataType_t scaleType,
1492
+ cudaDataType_t Atype,
1493
+ cudaDataType_t Btype,
1494
+ cudaDataType_t Ctype,
1495
+ cudaDataType_t Dtype,
1496
+ int algoId,
1497
+ cublasLtMatmulAlgo_t* algo);
1498
+
1499
+ /** Check configured algo descriptor for correctness and support on current device.
1500
+ *
1501
+ * Result includes required workspace size and calculated wave count.
1502
+ *
1503
+ * CUBLAS_STATUS_SUCCESS doesn't fully guarantee algo will run (will fail if e.g. buffers are not correctly aligned);
1504
+ * but if cublasLtMatmulAlgoCheck fails, the algo will not run.
1505
+ *
1506
+ * \param[in] algo algo configuration to check
1507
+ * \param[out] result result structure to report algo runtime characteristics; algo field is never updated
1508
+ *
1509
+ * \retval CUBLAS_STATUS_INVALID_VALUE if matrix layout descriptors or operation descriptor don't match algo
1510
+ * descriptor
1511
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if algo configuration or data type combination is not currently supported on
1512
+ * given device
1513
+ * \retval CUBLAS_STATUS_ARCH_MISMATCH if algo configuration cannot be run using the selected device
1514
+ * \retval CUBLAS_STATUS_SUCCESS if check was successful
1515
+ */
1516
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( //
1517
+ cublasLtHandle_t lightHandle,
1518
+ cublasLtMatmulDesc_t operationDesc,
1519
+ cublasLtMatrixLayout_t Adesc,
1520
+ cublasLtMatrixLayout_t Bdesc,
1521
+ cublasLtMatrixLayout_t Cdesc,
1522
+ cublasLtMatrixLayout_t Ddesc,
1523
+ const cublasLtMatmulAlgo_t* algo, ///< may point to result->algo
1524
+ cublasLtMatmulHeuristicResult_t* result);
1525
+
1526
+ /** Capabilities Attributes that can be retrieved from an initialized Algo structure
1527
+ */
1528
+ typedef enum {
1529
+ /** support for split K, see CUBLASLT_ALGO_CONFIG_SPLITK_NUM
1530
+ *
1531
+ * int32_t, 0 means no support, supported otherwise
1532
+ */
1533
+ CUBLASLT_ALGO_CAP_SPLITK_SUPPORT = 0,
1534
+
1535
+ /** reduction scheme mask, see cublasLtReductionScheme_t; shows supported reduction schemes, if reduction scheme is
1536
+ * not masked out it is supported.
1537
+ *
1538
+ * e.g. int isReductionSchemeComputeTypeSupported ? (reductionSchemeMask & CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE) ==
1539
+ * CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE ? 1 : 0;
1540
+ *
1541
+ * uint32_t
1542
+ */
1543
+ CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK = 1,
1544
+
1545
+ /** support for cta swizzling, see CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
1546
+ *
1547
+ * uint32_t, 0 means no support, 1 means supported value of 1, other values are reserved
1548
+ */
1549
+ CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT = 2,
1550
+
1551
+ /** support strided batch
1552
+ *
1553
+ * int32_t, 0 means no support, supported otherwise
1554
+ */
1555
+ CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT = 3,
1556
+
1557
+ /** support results out of place (D != C in D = alpha.A.B + beta.C)
1558
+ *
1559
+ * int32_t, 0 means no support, supported otherwise
1560
+ */
1561
+ CUBLASLT_ALGO_CAP_OUT_OF_PLACE_RESULT_SUPPORT = 4,
1562
+
1563
+ /** syrk/herk support (on top of regular gemm)
1564
+ *
1565
+ * int32_t, 0 means no support, supported otherwise
1566
+ */
1567
+ CUBLASLT_ALGO_CAP_UPLO_SUPPORT = 5,
1568
+
1569
+ /** tile ids possible to use, see cublasLtMatmulTile_t; if no tile ids are supported use
1570
+ * CUBLASLT_MATMUL_TILE_UNDEFINED
1571
+ *
1572
+ * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
1573
+ *
1574
+ * array of uint32_t
1575
+ */
1576
+ CUBLASLT_ALGO_CAP_TILE_IDS = 6,
1577
+
1578
+ /** custom option range is from 0 to CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX (inclusive), see
1579
+ * CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
1580
+ *
1581
+ * int32_t
1582
+ */
1583
+ CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX = 7,
1584
+
1585
+ /** whether algorithm supports custom (not COL or ROW memory order), see cublasLtOrder_t
1586
+ *
1587
+ * int32_t 0 means only COL and ROW memory order is allowed, non-zero means that algo might have different
1588
+ * requirements;
1589
+ */
1590
+ CUBLASLT_ALGO_CAP_CUSTOM_MEMORY_ORDER = 10,
1591
+
1592
+ /** bitmask enumerating pointer modes algorithm supports
1593
+ *
1594
+ * uint32_t, see cublasLtPointerModeMask_t
1595
+ */
1596
+ CUBLASLT_ALGO_CAP_POINTER_MODE_MASK = 11,
1597
+
1598
+ /** bitmask enumerating kinds of postprocessing algorithm supports in the epilogue
1599
+ *
1600
+ * uint32_t, see cublasLtEpilogue_t
1601
+ */
1602
+ CUBLASLT_ALGO_CAP_EPILOGUE_MASK = 12,
1603
+
1604
+ /** stages ids possible to use, see cublasLtMatmulStages_t; if no stages ids are supported use
1605
+ * CUBLASLT_MATMUL_STAGES_UNDEFINED
1606
+ *
1607
+ * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
1608
+ *
1609
+ * array of uint32_t
1610
+ */
1611
+ CUBLASLT_ALGO_CAP_STAGES_IDS = 13,
1612
+
1613
+ /** support for nagative ld for all of the matrices
1614
+ *
1615
+ * int32_t 0 means no support, supported otherwise
1616
+ */
1617
+ CUBLASLT_ALGO_CAP_LD_NEGATIVE = 14,
1618
+
1619
+ /** details about algorithm's implementation that affect it's numerical behavior
1620
+ *
1621
+ * uint64_t, see cublasLtNumericalImplFlags_t
1622
+ */
1623
+ CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS = 15,
1624
+
1625
+ /** minimum alignment required for A matrix in bytes
1626
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1627
+ *
1628
+ * uint32_t
1629
+ */
1630
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_A_BYTES = 16,
1631
+
1632
+ /** minimum alignment required for B matrix in bytes
1633
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1634
+ *
1635
+ * uint32_t
1636
+ */
1637
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_B_BYTES = 17,
1638
+
1639
+ /** minimum alignment required for C matrix in bytes
1640
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1641
+ *
1642
+ * uint32_t
1643
+ */
1644
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_C_BYTES = 18,
1645
+
1646
+ /** minimum alignment required for D matrix in bytes
1647
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1648
+ *
1649
+ * uint32_t
1650
+ */
1651
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_D_BYTES = 19,
1652
+
1653
+ /** EXPERIMENTAL: support for synchronization via atomic counters
1654
+ *
1655
+ * int32_t
1656
+ */
1657
+ CUBLASLT_ALGO_CAP_ATOMIC_SYNC = 20,
1658
+ } cublasLtMatmulAlgoCapAttributes_t;
1659
+
1660
+ /** Get algo capability attribute.
1661
+ *
1662
+ * E.g. to get list of supported Tile IDs:
1663
+ * cublasLtMatmulTile_t tiles[CUBLASLT_MATMUL_TILE_END];
1664
+ * size_t num_tiles, size_written;
1665
+ * if (cublasLtMatmulAlgoCapGetAttribute(algo, CUBLASLT_ALGO_CAP_TILE_IDS, tiles, sizeof(tiles), size_written) ==
1666
+ * CUBLAS_STATUS_SUCCESS) { num_tiles = size_written / sizeof(tiles[0]);
1667
+ * }
1668
+ *
1669
+ * \param[in] algo The algo descriptor
1670
+ * \param[in] attr The attribute
1671
+ * \param[out] buf memory address containing the new value
1672
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1673
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
1674
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
1675
+ *
1676
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
1677
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
1678
+ * selected attribute
1679
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
1680
+ */
1681
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(const cublasLtMatmulAlgo_t* algo,
1682
+ cublasLtMatmulAlgoCapAttributes_t attr,
1683
+ void* buf,
1684
+ size_t sizeInBytes,
1685
+ size_t* sizeWritten);
1686
+
1687
+ /** Algo Configuration Attributes that can be set according to the Algo capabilities
1688
+ */
1689
+ typedef enum {
1690
+ /** algorithm index, see cublasLtMatmulAlgoGetIds()
1691
+ *
1692
+ * readonly, set by cublasLtMatmulAlgoInit()
1693
+ * int32_t
1694
+ */
1695
+ CUBLASLT_ALGO_CONFIG_ID = 0,
1696
+ /** tile id, see cublasLtMatmulTile_t
1697
+ *
1698
+ * uint32_t, default: CUBLASLT_MATMUL_TILE_UNDEFINED
1699
+ */
1700
+ CUBLASLT_ALGO_CONFIG_TILE_ID = 1,
1701
+ /** Number of K splits. If the number of K splits is greater than one, SPLITK_NUM parts
1702
+ * of matrix multiplication will be computed in parallel. The results will be accumulated
1703
+ * according to CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
1704
+ *
1705
+ * int32_t, default: 1
1706
+ */
1707
+ CUBLASLT_ALGO_CONFIG_SPLITK_NUM = 2,
1708
+ /** reduction scheme, see cublasLtReductionScheme_t
1709
+ *
1710
+ * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_NONE
1711
+ */
1712
+ CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME = 3,
1713
+ /** cta swizzling, change mapping from CUDA grid coordinates to parts of the matrices
1714
+ *
1715
+ * possible values: 0, 1, other values reserved
1716
+ *
1717
+ * uint32_t, default: 0
1718
+ */
1719
+ CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING = 4,
1720
+ /** custom option, each algorithm can support some custom options that don't fit description of the other config
1721
+ * attributes, see CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX to get accepted range for any specific case
1722
+ *
1723
+ * uint32_t, default: 0
1724
+ */
1725
+ CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION = 5,
1726
+ /** stages id, see cublasLtMatmulStages_t
1727
+ *
1728
+ * uint32_t, default: CUBLASLT_MATMUL_STAGES_UNDEFINED
1729
+ */
1730
+ CUBLASLT_ALGO_CONFIG_STAGES_ID = 6,
1731
+ /** inner shape id, see cublasLtMatmulInnerShape_t
1732
+ *
1733
+ * uint16_t, default: 0 (CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED)
1734
+ */
1735
+ CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID = 7,
1736
+ /** Thread Block Cluster shape id, see cublasLtClusterShape_t. Defines cluster size to use.
1737
+ *
1738
+ * uint16_t, default: 0 (CUBLASLT_CLUSTER_SHAPE_AUTO)
1739
+ */
1740
+ CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID = 8,
1741
+ } cublasLtMatmulAlgoConfigAttributes_t;
1742
+
1743
+ /** Set algo configuration attribute.
1744
+ *
1745
+ * \param[in] algo The algo descriptor
1746
+ * \param[in] attr The attribute
1747
+ * \param[in] buf memory address containing the new value
1748
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1749
+ *
1750
+ * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
1751
+ * selected attribute
1752
+ * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
1753
+ */
1754
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(cublasLtMatmulAlgo_t* algo,
1755
+ cublasLtMatmulAlgoConfigAttributes_t attr,
1756
+ const void* buf,
1757
+ size_t sizeInBytes);
1758
+
1759
+ /** Get algo configuration attribute.
1760
+ *
1761
+ * \param[in] algo The algo descriptor
1762
+ * \param[in] attr The attribute
1763
+ * \param[out] buf memory address containing the new value
1764
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1765
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
1766
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
1767
+ *
1768
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
1769
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
1770
+ * selected attribute
1771
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
1772
+ */
1773
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(const cublasLtMatmulAlgo_t* algo,
1774
+ cublasLtMatmulAlgoConfigAttributes_t attr,
1775
+ void* buf,
1776
+ size_t sizeInBytes,
1777
+ size_t* sizeWritten);
1778
+
1779
+ /** Experimental: Logger callback type.
1780
+ */
1781
+ typedef void (*cublasLtLoggerCallback_t)(int logLevel, const char* functionName, const char* message);
1782
+
1783
+ /** Experimental: Logger callback setter.
1784
+ *
1785
+ * \param[in] callback a user defined callback function to be called by the logger
1786
+ *
1787
+ * \retval CUBLAS_STATUS_SUCCESS if callback was set successfully
1788
+ */
1789
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetCallback(cublasLtLoggerCallback_t callback);
1790
+
1791
+ /** Experimental: Log file setter.
1792
+ *
1793
+ * \param[in] file an open file with write permissions
1794
+ *
1795
+ * \retval CUBLAS_STATUS_SUCCESS if log file was set successfully
1796
+ */
1797
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetFile(FILE* file);
1798
+
1799
+ /** Experimental: Open log file.
1800
+ *
1801
+ * \param[in] logFile log file path. if the log file does not exist, it will be created
1802
+ *
1803
+ * \retval CUBLAS_STATUS_SUCCESS if log file was created successfully
1804
+ */
1805
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerOpenFile(const char* logFile);
1806
+
1807
+ /** Experimental: Log level setter.
1808
+ *
1809
+ * \param[in] level log level, should be one of the following:
1810
+ * 0. Off
1811
+ * 1. Errors
1812
+ * 2. Performance Trace
1813
+ * 3. Performance Hints
1814
+ * 4. Heuristics Trace
1815
+ * 5. API Trace
1816
+ *
1817
+ * \retval CUBLAS_STATUS_INVALID_VALUE if log level is not one of the above levels
1818
+ *
1819
+ * \retval CUBLAS_STATUS_SUCCESS if log level was set successfully
1820
+ */
1821
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetLevel(int level);
1822
+
1823
+ /** Experimental: Log mask setter.
1824
+ *
1825
+ * \param[in] mask log mask, should be a combination of the following masks:
1826
+ * 0. Off
1827
+ * 1. Errors
1828
+ * 2. Performance Trace
1829
+ * 4. Performance Hints
1830
+ * 8. Heuristics Trace
1831
+ * 16. API Trace
1832
+ *
1833
+ * \retval CUBLAS_STATUS_SUCCESS if log mask was set successfully
1834
+ */
1835
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetMask(int mask);
1836
+
1837
+ /** Experimental: Disable logging for the entire session.
1838
+ *
1839
+ * \retval CUBLAS_STATUS_SUCCESS if disabled logging
1840
+ */
1841
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerForceDisable();
1842
+
1843
+ #if defined(__cplusplus)
1844
+ }
1845
+ #endif /* __cplusplus */
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasXt.h ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cublasXt : Host API, Out of Core and Multi-GPU BLAS Library
51
+
52
+ */
53
+
54
+ #if !defined(CUBLAS_XT_H_)
55
+ #define CUBLAS_XT_H_
56
+
57
+ #include "driver_types.h"
58
+ #include "cuComplex.h" /* import complex data type */
59
+
60
+ #include "cublas_v2.h"
61
+
62
+ #if defined(__cplusplus)
63
+ extern "C" {
64
+ #endif /* __cplusplus */
65
+
66
+ struct cublasXtContext;
67
+ typedef struct cublasXtContext* cublasXtHandle_t;
68
+
69
+ cublasStatus_t CUBLASWINAPI cublasXtCreate(cublasXtHandle_t* handle);
70
+ cublasStatus_t CUBLASWINAPI cublasXtDestroy(cublasXtHandle_t handle);
71
+ cublasStatus_t CUBLASWINAPI cublasXtGetNumBoards(int nbDevices, int deviceId[], int* nbBoards);
72
+ cublasStatus_t CUBLASWINAPI cublasXtMaxBoards(int* nbGpuBoards);
73
+ /* This routine selects the Gpus that the user want to use for CUBLAS-XT */
74
+ cublasStatus_t CUBLASWINAPI cublasXtDeviceSelect(cublasXtHandle_t handle, int nbDevices, int deviceId[]);
75
+
76
+ /* This routine allows to change the dimension of the tiles ( blockDim x blockDim ) */
77
+ cublasStatus_t CUBLASWINAPI cublasXtSetBlockDim(cublasXtHandle_t handle, int blockDim);
78
+ cublasStatus_t CUBLASWINAPI cublasXtGetBlockDim(cublasXtHandle_t handle, int* blockDim);
79
+
80
+ typedef enum { CUBLASXT_PINNING_DISABLED = 0, CUBLASXT_PINNING_ENABLED = 1 } cublasXtPinnedMemMode_t;
81
+ /* This routine allows to CUBLAS-XT to pin the Host memory if it find out that some of the matrix passed
82
+ are not pinned : Pinning/Unpinning the Host memory is still a costly operation
83
+ It is better if the user controls the memory on its own (by pinning/unpinning oly when necessary)
84
+ */
85
+ cublasStatus_t CUBLASWINAPI cublasXtGetPinningMemMode(cublasXtHandle_t handle, cublasXtPinnedMemMode_t* mode);
86
+ cublasStatus_t CUBLASWINAPI cublasXtSetPinningMemMode(cublasXtHandle_t handle, cublasXtPinnedMemMode_t mode);
87
+
88
+ /* This routines is to provide a CPU Blas routines, used for too small sizes or hybrid computation */
89
+ typedef enum {
90
+ CUBLASXT_FLOAT = 0,
91
+ CUBLASXT_DOUBLE = 1,
92
+ CUBLASXT_COMPLEX = 2,
93
+ CUBLASXT_DOUBLECOMPLEX = 3,
94
+ } cublasXtOpType_t;
95
+
96
+ typedef enum {
97
+ CUBLASXT_GEMM = 0,
98
+ CUBLASXT_SYRK = 1,
99
+ CUBLASXT_HERK = 2,
100
+ CUBLASXT_SYMM = 3,
101
+ CUBLASXT_HEMM = 4,
102
+ CUBLASXT_TRSM = 5,
103
+ CUBLASXT_SYR2K = 6,
104
+ CUBLASXT_HER2K = 7,
105
+
106
+ CUBLASXT_SPMM = 8,
107
+ CUBLASXT_SYRKX = 9,
108
+ CUBLASXT_HERKX = 10,
109
+ CUBLASXT_TRMM = 11,
110
+ CUBLASXT_ROUTINE_MAX = 12,
111
+ } cublasXtBlasOp_t;
112
+
113
+ /* Currently only 32-bit integer BLAS routines are supported */
114
+ cublasStatus_t CUBLASWINAPI cublasXtSetCpuRoutine(cublasXtHandle_t handle,
115
+ cublasXtBlasOp_t blasOp,
116
+ cublasXtOpType_t type,
117
+ void* blasFunctor);
118
+
119
+ /* Specified the percentage of work that should done by the CPU, default is 0 (no work) */
120
+ cublasStatus_t CUBLASWINAPI cublasXtSetCpuRatio(cublasXtHandle_t handle,
121
+ cublasXtBlasOp_t blasOp,
122
+ cublasXtOpType_t type,
123
+ float ratio);
124
+
125
+ /* GEMM */
126
+ cublasStatus_t CUBLASWINAPI cublasXtSgemm(cublasXtHandle_t handle,
127
+ cublasOperation_t transa,
128
+ cublasOperation_t transb,
129
+ size_t m,
130
+ size_t n,
131
+ size_t k,
132
+ const float* alpha,
133
+ const float* A,
134
+ size_t lda,
135
+ const float* B,
136
+ size_t ldb,
137
+ const float* beta,
138
+ float* C,
139
+ size_t ldc);
140
+
141
+ cublasStatus_t CUBLASWINAPI cublasXtDgemm(cublasXtHandle_t handle,
142
+ cublasOperation_t transa,
143
+ cublasOperation_t transb,
144
+ size_t m,
145
+ size_t n,
146
+ size_t k,
147
+ const double* alpha,
148
+ const double* A,
149
+ size_t lda,
150
+ const double* B,
151
+ size_t ldb,
152
+ const double* beta,
153
+ double* C,
154
+ size_t ldc);
155
+
156
+ cublasStatus_t CUBLASWINAPI cublasXtCgemm(cublasXtHandle_t handle,
157
+ cublasOperation_t transa,
158
+ cublasOperation_t transb,
159
+ size_t m,
160
+ size_t n,
161
+ size_t k,
162
+ const cuComplex* alpha,
163
+ const cuComplex* A,
164
+ size_t lda,
165
+ const cuComplex* B,
166
+ size_t ldb,
167
+ const cuComplex* beta,
168
+ cuComplex* C,
169
+ size_t ldc);
170
+
171
+ cublasStatus_t CUBLASWINAPI cublasXtZgemm(cublasXtHandle_t handle,
172
+ cublasOperation_t transa,
173
+ cublasOperation_t transb,
174
+ size_t m,
175
+ size_t n,
176
+ size_t k,
177
+ const cuDoubleComplex* alpha,
178
+ const cuDoubleComplex* A,
179
+ size_t lda,
180
+ const cuDoubleComplex* B,
181
+ size_t ldb,
182
+ const cuDoubleComplex* beta,
183
+ cuDoubleComplex* C,
184
+ size_t ldc);
185
+ /* ------------------------------------------------------- */
186
+ /* SYRK */
187
+ cublasStatus_t CUBLASWINAPI cublasXtSsyrk(cublasXtHandle_t handle,
188
+ cublasFillMode_t uplo,
189
+ cublasOperation_t trans,
190
+ size_t n,
191
+ size_t k,
192
+ const float* alpha,
193
+ const float* A,
194
+ size_t lda,
195
+ const float* beta,
196
+ float* C,
197
+ size_t ldc);
198
+
199
+ cublasStatus_t CUBLASWINAPI cublasXtDsyrk(cublasXtHandle_t handle,
200
+ cublasFillMode_t uplo,
201
+ cublasOperation_t trans,
202
+ size_t n,
203
+ size_t k,
204
+ const double* alpha,
205
+ const double* A,
206
+ size_t lda,
207
+ const double* beta,
208
+ double* C,
209
+ size_t ldc);
210
+
211
+ cublasStatus_t CUBLASWINAPI cublasXtCsyrk(cublasXtHandle_t handle,
212
+ cublasFillMode_t uplo,
213
+ cublasOperation_t trans,
214
+ size_t n,
215
+ size_t k,
216
+ const cuComplex* alpha,
217
+ const cuComplex* A,
218
+ size_t lda,
219
+ const cuComplex* beta,
220
+ cuComplex* C,
221
+ size_t ldc);
222
+
223
+ cublasStatus_t CUBLASWINAPI cublasXtZsyrk(cublasXtHandle_t handle,
224
+ cublasFillMode_t uplo,
225
+ cublasOperation_t trans,
226
+ size_t n,
227
+ size_t k,
228
+ const cuDoubleComplex* alpha,
229
+ const cuDoubleComplex* A,
230
+ size_t lda,
231
+ const cuDoubleComplex* beta,
232
+ cuDoubleComplex* C,
233
+ size_t ldc);
234
+ /* -------------------------------------------------------------------- */
235
+ /* HERK */
236
+ cublasStatus_t CUBLASWINAPI cublasXtCherk(cublasXtHandle_t handle,
237
+ cublasFillMode_t uplo,
238
+ cublasOperation_t trans,
239
+ size_t n,
240
+ size_t k,
241
+ const float* alpha,
242
+ const cuComplex* A,
243
+ size_t lda,
244
+ const float* beta,
245
+ cuComplex* C,
246
+ size_t ldc);
247
+
248
+ cublasStatus_t CUBLASWINAPI cublasXtZherk(cublasXtHandle_t handle,
249
+ cublasFillMode_t uplo,
250
+ cublasOperation_t trans,
251
+ size_t n,
252
+ size_t k,
253
+ const double* alpha,
254
+ const cuDoubleComplex* A,
255
+ size_t lda,
256
+ const double* beta,
257
+ cuDoubleComplex* C,
258
+ size_t ldc);
259
+ /* -------------------------------------------------------------------- */
260
+ /* SYR2K */
261
+ cublasStatus_t CUBLASWINAPI cublasXtSsyr2k(cublasXtHandle_t handle,
262
+ cublasFillMode_t uplo,
263
+ cublasOperation_t trans,
264
+ size_t n,
265
+ size_t k,
266
+ const float* alpha,
267
+ const float* A,
268
+ size_t lda,
269
+ const float* B,
270
+ size_t ldb,
271
+ const float* beta,
272
+ float* C,
273
+ size_t ldc);
274
+
275
+ cublasStatus_t CUBLASWINAPI cublasXtDsyr2k(cublasXtHandle_t handle,
276
+ cublasFillMode_t uplo,
277
+ cublasOperation_t trans,
278
+ size_t n,
279
+ size_t k,
280
+ const double* alpha,
281
+ const double* A,
282
+ size_t lda,
283
+ const double* B,
284
+ size_t ldb,
285
+ const double* beta,
286
+ double* C,
287
+ size_t ldc);
288
+
289
+ cublasStatus_t CUBLASWINAPI cublasXtCsyr2k(cublasXtHandle_t handle,
290
+ cublasFillMode_t uplo,
291
+ cublasOperation_t trans,
292
+ size_t n,
293
+ size_t k,
294
+ const cuComplex* alpha,
295
+ const cuComplex* A,
296
+ size_t lda,
297
+ const cuComplex* B,
298
+ size_t ldb,
299
+ const cuComplex* beta,
300
+ cuComplex* C,
301
+ size_t ldc);
302
+
303
+ cublasStatus_t CUBLASWINAPI cublasXtZsyr2k(cublasXtHandle_t handle,
304
+ cublasFillMode_t uplo,
305
+ cublasOperation_t trans,
306
+ size_t n,
307
+ size_t k,
308
+ const cuDoubleComplex* alpha,
309
+ const cuDoubleComplex* A,
310
+ size_t lda,
311
+ const cuDoubleComplex* B,
312
+ size_t ldb,
313
+ const cuDoubleComplex* beta,
314
+ cuDoubleComplex* C,
315
+ size_t ldc);
316
+ /* -------------------------------------------------------------------- */
317
+ /* HERKX : variant extension of HERK */
318
+ cublasStatus_t CUBLASWINAPI cublasXtCherkx(cublasXtHandle_t handle,
319
+ cublasFillMode_t uplo,
320
+ cublasOperation_t trans,
321
+ size_t n,
322
+ size_t k,
323
+ const cuComplex* alpha,
324
+ const cuComplex* A,
325
+ size_t lda,
326
+ const cuComplex* B,
327
+ size_t ldb,
328
+ const float* beta,
329
+ cuComplex* C,
330
+ size_t ldc);
331
+
332
+ cublasStatus_t CUBLASWINAPI cublasXtZherkx(cublasXtHandle_t handle,
333
+ cublasFillMode_t uplo,
334
+ cublasOperation_t trans,
335
+ size_t n,
336
+ size_t k,
337
+ const cuDoubleComplex* alpha,
338
+ const cuDoubleComplex* A,
339
+ size_t lda,
340
+ const cuDoubleComplex* B,
341
+ size_t ldb,
342
+ const double* beta,
343
+ cuDoubleComplex* C,
344
+ size_t ldc);
345
+
346
+ /* -------------------------------------------------------------------- */
347
+ /* TRSM */
348
+ cublasStatus_t CUBLASWINAPI cublasXtStrsm(cublasXtHandle_t handle,
349
+ cublasSideMode_t side,
350
+ cublasFillMode_t uplo,
351
+ cublasOperation_t trans,
352
+ cublasDiagType_t diag,
353
+ size_t m,
354
+ size_t n,
355
+ const float* alpha,
356
+ const float* A,
357
+ size_t lda,
358
+ float* B,
359
+ size_t ldb);
360
+
361
+ cublasStatus_t CUBLASWINAPI cublasXtDtrsm(cublasXtHandle_t handle,
362
+ cublasSideMode_t side,
363
+ cublasFillMode_t uplo,
364
+ cublasOperation_t trans,
365
+ cublasDiagType_t diag,
366
+ size_t m,
367
+ size_t n,
368
+ const double* alpha,
369
+ const double* A,
370
+ size_t lda,
371
+ double* B,
372
+ size_t ldb);
373
+
374
+ cublasStatus_t CUBLASWINAPI cublasXtCtrsm(cublasXtHandle_t handle,
375
+ cublasSideMode_t side,
376
+ cublasFillMode_t uplo,
377
+ cublasOperation_t trans,
378
+ cublasDiagType_t diag,
379
+ size_t m,
380
+ size_t n,
381
+ const cuComplex* alpha,
382
+ const cuComplex* A,
383
+ size_t lda,
384
+ cuComplex* B,
385
+ size_t ldb);
386
+
387
+ cublasStatus_t CUBLASWINAPI cublasXtZtrsm(cublasXtHandle_t handle,
388
+ cublasSideMode_t side,
389
+ cublasFillMode_t uplo,
390
+ cublasOperation_t trans,
391
+ cublasDiagType_t diag,
392
+ size_t m,
393
+ size_t n,
394
+ const cuDoubleComplex* alpha,
395
+ const cuDoubleComplex* A,
396
+ size_t lda,
397
+ cuDoubleComplex* B,
398
+ size_t ldb);
399
+ /* -------------------------------------------------------------------- */
400
+ /* SYMM : Symmetric Multiply Matrix*/
401
+ cublasStatus_t CUBLASWINAPI cublasXtSsymm(cublasXtHandle_t handle,
402
+ cublasSideMode_t side,
403
+ cublasFillMode_t uplo,
404
+ size_t m,
405
+ size_t n,
406
+ const float* alpha,
407
+ const float* A,
408
+ size_t lda,
409
+ const float* B,
410
+ size_t ldb,
411
+ const float* beta,
412
+ float* C,
413
+ size_t ldc);
414
+
415
+ cublasStatus_t CUBLASWINAPI cublasXtDsymm(cublasXtHandle_t handle,
416
+ cublasSideMode_t side,
417
+ cublasFillMode_t uplo,
418
+ size_t m,
419
+ size_t n,
420
+ const double* alpha,
421
+ const double* A,
422
+ size_t lda,
423
+ const double* B,
424
+ size_t ldb,
425
+ const double* beta,
426
+ double* C,
427
+ size_t ldc);
428
+
429
+ cublasStatus_t CUBLASWINAPI cublasXtCsymm(cublasXtHandle_t handle,
430
+ cublasSideMode_t side,
431
+ cublasFillMode_t uplo,
432
+ size_t m,
433
+ size_t n,
434
+ const cuComplex* alpha,
435
+ const cuComplex* A,
436
+ size_t lda,
437
+ const cuComplex* B,
438
+ size_t ldb,
439
+ const cuComplex* beta,
440
+ cuComplex* C,
441
+ size_t ldc);
442
+
443
+ cublasStatus_t CUBLASWINAPI cublasXtZsymm(cublasXtHandle_t handle,
444
+ cublasSideMode_t side,
445
+ cublasFillMode_t uplo,
446
+ size_t m,
447
+ size_t n,
448
+ const cuDoubleComplex* alpha,
449
+ const cuDoubleComplex* A,
450
+ size_t lda,
451
+ const cuDoubleComplex* B,
452
+ size_t ldb,
453
+ const cuDoubleComplex* beta,
454
+ cuDoubleComplex* C,
455
+ size_t ldc);
456
+ /* -------------------------------------------------------------------- */
457
+ /* HEMM : Hermitian Matrix Multiply */
458
+ cublasStatus_t CUBLASWINAPI cublasXtChemm(cublasXtHandle_t handle,
459
+ cublasSideMode_t side,
460
+ cublasFillMode_t uplo,
461
+ size_t m,
462
+ size_t n,
463
+ const cuComplex* alpha,
464
+ const cuComplex* A,
465
+ size_t lda,
466
+ const cuComplex* B,
467
+ size_t ldb,
468
+ const cuComplex* beta,
469
+ cuComplex* C,
470
+ size_t ldc);
471
+
472
+ cublasStatus_t CUBLASWINAPI cublasXtZhemm(cublasXtHandle_t handle,
473
+ cublasSideMode_t side,
474
+ cublasFillMode_t uplo,
475
+ size_t m,
476
+ size_t n,
477
+ const cuDoubleComplex* alpha,
478
+ const cuDoubleComplex* A,
479
+ size_t lda,
480
+ const cuDoubleComplex* B,
481
+ size_t ldb,
482
+ const cuDoubleComplex* beta,
483
+ cuDoubleComplex* C,
484
+ size_t ldc);
485
+
486
+ /* -------------------------------------------------------------------- */
487
+ /* SYRKX : variant extension of SYRK */
488
+ cublasStatus_t CUBLASWINAPI cublasXtSsyrkx(cublasXtHandle_t handle,
489
+ cublasFillMode_t uplo,
490
+ cublasOperation_t trans,
491
+ size_t n,
492
+ size_t k,
493
+ const float* alpha,
494
+ const float* A,
495
+ size_t lda,
496
+ const float* B,
497
+ size_t ldb,
498
+ const float* beta,
499
+ float* C,
500
+ size_t ldc);
501
+
502
+ cublasStatus_t CUBLASWINAPI cublasXtDsyrkx(cublasXtHandle_t handle,
503
+ cublasFillMode_t uplo,
504
+ cublasOperation_t trans,
505
+ size_t n,
506
+ size_t k,
507
+ const double* alpha,
508
+ const double* A,
509
+ size_t lda,
510
+ const double* B,
511
+ size_t ldb,
512
+ const double* beta,
513
+ double* C,
514
+ size_t ldc);
515
+
516
+ cublasStatus_t CUBLASWINAPI cublasXtCsyrkx(cublasXtHandle_t handle,
517
+ cublasFillMode_t uplo,
518
+ cublasOperation_t trans,
519
+ size_t n,
520
+ size_t k,
521
+ const cuComplex* alpha,
522
+ const cuComplex* A,
523
+ size_t lda,
524
+ const cuComplex* B,
525
+ size_t ldb,
526
+ const cuComplex* beta,
527
+ cuComplex* C,
528
+ size_t ldc);
529
+
530
+ cublasStatus_t CUBLASWINAPI cublasXtZsyrkx(cublasXtHandle_t handle,
531
+ cublasFillMode_t uplo,
532
+ cublasOperation_t trans,
533
+ size_t n,
534
+ size_t k,
535
+ const cuDoubleComplex* alpha,
536
+ const cuDoubleComplex* A,
537
+ size_t lda,
538
+ const cuDoubleComplex* B,
539
+ size_t ldb,
540
+ const cuDoubleComplex* beta,
541
+ cuDoubleComplex* C,
542
+ size_t ldc);
543
+ /* -------------------------------------------------------------------- */
544
+ /* HER2K : variant extension of HERK */
545
+ cublasStatus_t CUBLASWINAPI cublasXtCher2k(cublasXtHandle_t handle,
546
+ cublasFillMode_t uplo,
547
+ cublasOperation_t trans,
548
+ size_t n,
549
+ size_t k,
550
+ const cuComplex* alpha,
551
+ const cuComplex* A,
552
+ size_t lda,
553
+ const cuComplex* B,
554
+ size_t ldb,
555
+ const float* beta,
556
+ cuComplex* C,
557
+ size_t ldc);
558
+
559
+ cublasStatus_t CUBLASWINAPI cublasXtZher2k(cublasXtHandle_t handle,
560
+ cublasFillMode_t uplo,
561
+ cublasOperation_t trans,
562
+ size_t n,
563
+ size_t k,
564
+ const cuDoubleComplex* alpha,
565
+ const cuDoubleComplex* A,
566
+ size_t lda,
567
+ const cuDoubleComplex* B,
568
+ size_t ldb,
569
+ const double* beta,
570
+ cuDoubleComplex* C,
571
+ size_t ldc);
572
+
573
+ /* -------------------------------------------------------------------- */
574
+ /* SPMM : Symmetric Packed Multiply Matrix*/
575
+ cublasStatus_t CUBLASWINAPI cublasXtSspmm(cublasXtHandle_t handle,
576
+ cublasSideMode_t side,
577
+ cublasFillMode_t uplo,
578
+ size_t m,
579
+ size_t n,
580
+ const float* alpha,
581
+ const float* AP,
582
+ const float* B,
583
+ size_t ldb,
584
+ const float* beta,
585
+ float* C,
586
+ size_t ldc);
587
+
588
+ cublasStatus_t CUBLASWINAPI cublasXtDspmm(cublasXtHandle_t handle,
589
+ cublasSideMode_t side,
590
+ cublasFillMode_t uplo,
591
+ size_t m,
592
+ size_t n,
593
+ const double* alpha,
594
+ const double* AP,
595
+ const double* B,
596
+ size_t ldb,
597
+ const double* beta,
598
+ double* C,
599
+ size_t ldc);
600
+
601
+ cublasStatus_t CUBLASWINAPI cublasXtCspmm(cublasXtHandle_t handle,
602
+ cublasSideMode_t side,
603
+ cublasFillMode_t uplo,
604
+ size_t m,
605
+ size_t n,
606
+ const cuComplex* alpha,
607
+ const cuComplex* AP,
608
+ const cuComplex* B,
609
+ size_t ldb,
610
+ const cuComplex* beta,
611
+ cuComplex* C,
612
+ size_t ldc);
613
+
614
+ cublasStatus_t CUBLASWINAPI cublasXtZspmm(cublasXtHandle_t handle,
615
+ cublasSideMode_t side,
616
+ cublasFillMode_t uplo,
617
+ size_t m,
618
+ size_t n,
619
+ const cuDoubleComplex* alpha,
620
+ const cuDoubleComplex* AP,
621
+ const cuDoubleComplex* B,
622
+ size_t ldb,
623
+ const cuDoubleComplex* beta,
624
+ cuDoubleComplex* C,
625
+ size_t ldc);
626
+
627
+ /* -------------------------------------------------------------------- */
628
+ /* TRMM */
629
+ cublasStatus_t CUBLASWINAPI cublasXtStrmm(cublasXtHandle_t handle,
630
+ cublasSideMode_t side,
631
+ cublasFillMode_t uplo,
632
+ cublasOperation_t trans,
633
+ cublasDiagType_t diag,
634
+ size_t m,
635
+ size_t n,
636
+ const float* alpha,
637
+ const float* A,
638
+ size_t lda,
639
+ const float* B,
640
+ size_t ldb,
641
+ float* C,
642
+ size_t ldc);
643
+
644
+ cublasStatus_t CUBLASWINAPI cublasXtDtrmm(cublasXtHandle_t handle,
645
+ cublasSideMode_t side,
646
+ cublasFillMode_t uplo,
647
+ cublasOperation_t trans,
648
+ cublasDiagType_t diag,
649
+ size_t m,
650
+ size_t n,
651
+ const double* alpha,
652
+ const double* A,
653
+ size_t lda,
654
+ const double* B,
655
+ size_t ldb,
656
+ double* C,
657
+ size_t ldc);
658
+
659
+ cublasStatus_t CUBLASWINAPI cublasXtCtrmm(cublasXtHandle_t handle,
660
+ cublasSideMode_t side,
661
+ cublasFillMode_t uplo,
662
+ cublasOperation_t trans,
663
+ cublasDiagType_t diag,
664
+ size_t m,
665
+ size_t n,
666
+ const cuComplex* alpha,
667
+ const cuComplex* A,
668
+ size_t lda,
669
+ const cuComplex* B,
670
+ size_t ldb,
671
+ cuComplex* C,
672
+ size_t ldc);
673
+
674
+ cublasStatus_t CUBLASWINAPI cublasXtZtrmm(cublasXtHandle_t handle,
675
+ cublasSideMode_t side,
676
+ cublasFillMode_t uplo,
677
+ cublasOperation_t trans,
678
+ cublasDiagType_t diag,
679
+ size_t m,
680
+ size_t n,
681
+ const cuDoubleComplex* alpha,
682
+ const cuDoubleComplex* A,
683
+ size_t lda,
684
+ const cuDoubleComplex* B,
685
+ size_t ldb,
686
+ cuDoubleComplex* C,
687
+ size_t ldc);
688
+
689
+ #if defined(__cplusplus)
690
+ }
691
+ #endif /* __cplusplus */
692
+
693
+ #endif /* !defined(CUBLAS_XT_H_) */
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * This is the public header file for the new CUBLAS library API, it mapped the generic
52
+ * Cublas name functions to the actual _v2 implementations.
53
+ */
54
+
55
+ #if !defined(CUBLAS_V2_H_)
56
+ #define CUBLAS_V2_H_
57
+
58
+ #if defined(CUBLAS_H_)
59
+ #error "It is an error to include both cublas.h and cublas_v2.h"
60
+ #endif
61
+
62
+ #undef CUBLASAPI
63
+ #ifdef __CUDACC__
64
+ #define CUBLASAPI __host__ __device__
65
+ #else
66
+ #define CUBLASAPI
67
+ #endif
68
+
69
+ #include "cublas_api.h"
70
+
71
+ #define cublasCreate cublasCreate_v2
72
+ #define cublasDestroy cublasDestroy_v2
73
+ #define cublasGetVersion cublasGetVersion_v2
74
+ #define cublasSetWorkspace cublasSetWorkspace_v2
75
+ #define cublasSetStream cublasSetStream_v2
76
+ #define cublasGetStream cublasGetStream_v2
77
+ #define cublasGetPointerMode cublasGetPointerMode_v2
78
+ #define cublasSetPointerMode cublasSetPointerMode_v2
79
+
80
+ /* 32-bit integer */
81
+
82
+ /* Blas1 Routines */
83
+
84
+ #define cublasSnrm2 cublasSnrm2_v2
85
+ #define cublasDnrm2 cublasDnrm2_v2
86
+ #define cublasScnrm2 cublasScnrm2_v2
87
+ #define cublasDznrm2 cublasDznrm2_v2
88
+
89
+ #define cublasSdot cublasSdot_v2
90
+ #define cublasDdot cublasDdot_v2
91
+ #define cublasCdotu cublasCdotu_v2
92
+ #define cublasCdotc cublasCdotc_v2
93
+ #define cublasZdotu cublasZdotu_v2
94
+ #define cublasZdotc cublasZdotc_v2
95
+
96
+ #define cublasSscal cublasSscal_v2
97
+ #define cublasDscal cublasDscal_v2
98
+ #define cublasCscal cublasCscal_v2
99
+ #define cublasCsscal cublasCsscal_v2
100
+ #define cublasZscal cublasZscal_v2
101
+ #define cublasZdscal cublasZdscal_v2
102
+
103
+ #define cublasSaxpy cublasSaxpy_v2
104
+ #define cublasDaxpy cublasDaxpy_v2
105
+ #define cublasCaxpy cublasCaxpy_v2
106
+ #define cublasZaxpy cublasZaxpy_v2
107
+
108
+ #define cublasScopy cublasScopy_v2
109
+ #define cublasDcopy cublasDcopy_v2
110
+ #define cublasCcopy cublasCcopy_v2
111
+ #define cublasZcopy cublasZcopy_v2
112
+
113
+ #define cublasSswap cublasSswap_v2
114
+ #define cublasDswap cublasDswap_v2
115
+ #define cublasCswap cublasCswap_v2
116
+ #define cublasZswap cublasZswap_v2
117
+
118
+ #define cublasIsamax cublasIsamax_v2
119
+ #define cublasIdamax cublasIdamax_v2
120
+ #define cublasIcamax cublasIcamax_v2
121
+ #define cublasIzamax cublasIzamax_v2
122
+
123
+ #define cublasIsamin cublasIsamin_v2
124
+ #define cublasIdamin cublasIdamin_v2
125
+ #define cublasIcamin cublasIcamin_v2
126
+ #define cublasIzamin cublasIzamin_v2
127
+
128
+ #define cublasSasum cublasSasum_v2
129
+ #define cublasDasum cublasDasum_v2
130
+ #define cublasScasum cublasScasum_v2
131
+ #define cublasDzasum cublasDzasum_v2
132
+
133
+ #define cublasSrot cublasSrot_v2
134
+ #define cublasDrot cublasDrot_v2
135
+ #define cublasCrot cublasCrot_v2
136
+ #define cublasCsrot cublasCsrot_v2
137
+ #define cublasZrot cublasZrot_v2
138
+ #define cublasZdrot cublasZdrot_v2
139
+
140
+ #define cublasSrotg cublasSrotg_v2
141
+ #define cublasDrotg cublasDrotg_v2
142
+ #define cublasCrotg cublasCrotg_v2
143
+ #define cublasZrotg cublasZrotg_v2
144
+
145
+ #define cublasSrotm cublasSrotm_v2
146
+ #define cublasDrotm cublasDrotm_v2
147
+
148
+ #define cublasSrotmg cublasSrotmg_v2
149
+ #define cublasDrotmg cublasDrotmg_v2
150
+
151
+ /* Blas2 Routines */
152
+
153
+ #define cublasSgemv cublasSgemv_v2
154
+ #define cublasDgemv cublasDgemv_v2
155
+ #define cublasCgemv cublasCgemv_v2
156
+ #define cublasZgemv cublasZgemv_v2
157
+
158
+ #define cublasSgbmv cublasSgbmv_v2
159
+ #define cublasDgbmv cublasDgbmv_v2
160
+ #define cublasCgbmv cublasCgbmv_v2
161
+ #define cublasZgbmv cublasZgbmv_v2
162
+
163
+ #define cublasStrmv cublasStrmv_v2
164
+ #define cublasDtrmv cublasDtrmv_v2
165
+ #define cublasCtrmv cublasCtrmv_v2
166
+ #define cublasZtrmv cublasZtrmv_v2
167
+
168
+ #define cublasStbmv cublasStbmv_v2
169
+ #define cublasDtbmv cublasDtbmv_v2
170
+ #define cublasCtbmv cublasCtbmv_v2
171
+ #define cublasZtbmv cublasZtbmv_v2
172
+
173
+ #define cublasStpmv cublasStpmv_v2
174
+ #define cublasDtpmv cublasDtpmv_v2
175
+ #define cublasCtpmv cublasCtpmv_v2
176
+ #define cublasZtpmv cublasZtpmv_v2
177
+
178
+ #define cublasStrsv cublasStrsv_v2
179
+ #define cublasDtrsv cublasDtrsv_v2
180
+ #define cublasCtrsv cublasCtrsv_v2
181
+ #define cublasZtrsv cublasZtrsv_v2
182
+
183
+ #define cublasStpsv cublasStpsv_v2
184
+ #define cublasDtpsv cublasDtpsv_v2
185
+ #define cublasCtpsv cublasCtpsv_v2
186
+ #define cublasZtpsv cublasZtpsv_v2
187
+
188
+ #define cublasStbsv cublasStbsv_v2
189
+ #define cublasDtbsv cublasDtbsv_v2
190
+ #define cublasCtbsv cublasCtbsv_v2
191
+ #define cublasZtbsv cublasZtbsv_v2
192
+
193
+ #define cublasSsymv cublasSsymv_v2
194
+ #define cublasDsymv cublasDsymv_v2
195
+ #define cublasCsymv cublasCsymv_v2
196
+ #define cublasZsymv cublasZsymv_v2
197
+ #define cublasChemv cublasChemv_v2
198
+ #define cublasZhemv cublasZhemv_v2
199
+
200
+ #define cublasSsbmv cublasSsbmv_v2
201
+ #define cublasDsbmv cublasDsbmv_v2
202
+ #define cublasChbmv cublasChbmv_v2
203
+ #define cublasZhbmv cublasZhbmv_v2
204
+
205
+ #define cublasSspmv cublasSspmv_v2
206
+ #define cublasDspmv cublasDspmv_v2
207
+ #define cublasChpmv cublasChpmv_v2
208
+ #define cublasZhpmv cublasZhpmv_v2
209
+
210
+ #define cublasSger cublasSger_v2
211
+ #define cublasDger cublasDger_v2
212
+ #define cublasCgeru cublasCgeru_v2
213
+ #define cublasCgerc cublasCgerc_v2
214
+ #define cublasZgeru cublasZgeru_v2
215
+ #define cublasZgerc cublasZgerc_v2
216
+
217
+ #define cublasSsyr cublasSsyr_v2
218
+ #define cublasDsyr cublasDsyr_v2
219
+ #define cublasCsyr cublasCsyr_v2
220
+ #define cublasZsyr cublasZsyr_v2
221
+ #define cublasCher cublasCher_v2
222
+ #define cublasZher cublasZher_v2
223
+
224
+ #define cublasSspr cublasSspr_v2
225
+ #define cublasDspr cublasDspr_v2
226
+ #define cublasChpr cublasChpr_v2
227
+ #define cublasZhpr cublasZhpr_v2
228
+
229
+ #define cublasSsyr2 cublasSsyr2_v2
230
+ #define cublasDsyr2 cublasDsyr2_v2
231
+ #define cublasCsyr2 cublasCsyr2_v2
232
+ #define cublasZsyr2 cublasZsyr2_v2
233
+ #define cublasCher2 cublasCher2_v2
234
+ #define cublasZher2 cublasZher2_v2
235
+
236
+ #define cublasSspr2 cublasSspr2_v2
237
+ #define cublasDspr2 cublasDspr2_v2
238
+ #define cublasChpr2 cublasChpr2_v2
239
+ #define cublasZhpr2 cublasZhpr2_v2
240
+
241
+ /* Blas3 Routines */
242
+
243
+ #define cublasSgemm cublasSgemm_v2
244
+ #define cublasDgemm cublasDgemm_v2
245
+ #define cublasCgemm cublasCgemm_v2
246
+ #define cublasZgemm cublasZgemm_v2
247
+
248
+ #define cublasSsyrk cublasSsyrk_v2
249
+ #define cublasDsyrk cublasDsyrk_v2
250
+ #define cublasCsyrk cublasCsyrk_v2
251
+ #define cublasZsyrk cublasZsyrk_v2
252
+ #define cublasCherk cublasCherk_v2
253
+ #define cublasZherk cublasZherk_v2
254
+
255
+ #define cublasSsyr2k cublasSsyr2k_v2
256
+ #define cublasDsyr2k cublasDsyr2k_v2
257
+ #define cublasCsyr2k cublasCsyr2k_v2
258
+ #define cublasZsyr2k cublasZsyr2k_v2
259
+ #define cublasCher2k cublasCher2k_v2
260
+ #define cublasZher2k cublasZher2k_v2
261
+
262
+ #define cublasSsymm cublasSsymm_v2
263
+ #define cublasDsymm cublasDsymm_v2
264
+ #define cublasCsymm cublasCsymm_v2
265
+ #define cublasZsymm cublasZsymm_v2
266
+ #define cublasChemm cublasChemm_v2
267
+ #define cublasZhemm cublasZhemm_v2
268
+
269
+ #define cublasStrsm cublasStrsm_v2
270
+ #define cublasDtrsm cublasDtrsm_v2
271
+ #define cublasCtrsm cublasCtrsm_v2
272
+ #define cublasZtrsm cublasZtrsm_v2
273
+
274
+ #define cublasStrmm cublasStrmm_v2
275
+ #define cublasDtrmm cublasDtrmm_v2
276
+ #define cublasCtrmm cublasCtrmm_v2
277
+ #define cublasZtrmm cublasZtrmm_v2
278
+
279
+ /* 64-bit integer */
280
+
281
+ /* Blas1 Routines */
282
+
283
+ #define cublasSnrm2_64 cublasSnrm2_v2_64
284
+ #define cublasDnrm2_64 cublasDnrm2_v2_64
285
+ #define cublasScnrm2_64 cublasScnrm2_v2_64
286
+ #define cublasDznrm2_64 cublasDznrm2_v2_64
287
+
288
+ #define cublasSdot_64 cublasSdot_v2_64
289
+ #define cublasDdot_64 cublasDdot_v2_64
290
+ #define cublasCdotu_64 cublasCdotu_v2_64
291
+ #define cublasCdotc_64 cublasCdotc_v2_64
292
+ #define cublasZdotu_64 cublasZdotu_v2_64
293
+ #define cublasZdotc_64 cublasZdotc_v2_64
294
+
295
+ #define cublasSscal_64 cublasSscal_v2_64
296
+ #define cublasDscal_64 cublasDscal_v2_64
297
+ #define cublasCscal_64 cublasCscal_v2_64
298
+ #define cublasCsscal_64 cublasCsscal_v2_64
299
+ #define cublasZscal_64 cublasZscal_v2_64
300
+ #define cublasZdscal_64 cublasZdscal_v2_64
301
+
302
+ #define cublasSaxpy_64 cublasSaxpy_v2_64
303
+ #define cublasDaxpy_64 cublasDaxpy_v2_64
304
+ #define cublasCaxpy_64 cublasCaxpy_v2_64
305
+ #define cublasZaxpy_64 cublasZaxpy_v2_64
306
+
307
+ #define cublasScopy_64 cublasScopy_v2_64
308
+ #define cublasDcopy_64 cublasDcopy_v2_64
309
+ #define cublasCcopy_64 cublasCcopy_v2_64
310
+ #define cublasZcopy_64 cublasZcopy_v2_64
311
+
312
+ #define cublasSswap_64 cublasSswap_v2_64
313
+ #define cublasDswap_64 cublasDswap_v2_64
314
+ #define cublasCswap_64 cublasCswap_v2_64
315
+ #define cublasZswap_64 cublasZswap_v2_64
316
+
317
+ #define cublasIsamax_64 cublasIsamax_v2_64
318
+ #define cublasIdamax_64 cublasIdamax_v2_64
319
+ #define cublasIcamax_64 cublasIcamax_v2_64
320
+ #define cublasIzamax_64 cublasIzamax_v2_64
321
+
322
+ #define cublasIsamin_64 cublasIsamin_v2_64
323
+ #define cublasIdamin_64 cublasIdamin_v2_64
324
+ #define cublasIcamin_64 cublasIcamin_v2_64
325
+ #define cublasIzamin_64 cublasIzamin_v2_64
326
+
327
+ #define cublasSasum_64 cublasSasum_v2_64
328
+ #define cublasDasum_64 cublasDasum_v2_64
329
+ #define cublasScasum_64 cublasScasum_v2_64
330
+ #define cublasDzasum_64 cublasDzasum_v2_64
331
+
332
+ #define cublasSrot_64 cublasSrot_v2_64
333
+ #define cublasDrot_64 cublasDrot_v2_64
334
+ #define cublasCrot_64 cublasCrot_v2_64
335
+ #define cublasCsrot_64 cublasCsrot_v2_64
336
+ #define cublasZrot_64 cublasZrot_v2_64
337
+ #define cublasZdrot_64 cublasZdrot_v2_64
338
+
339
+ #define cublasSrotg_64 cublasSrotg_v2_64
340
+ #define cublasDrotg_64 cublasDrotg_v2_64
341
+ #define cublasCrotg_64 cublasCrotg_v2_64
342
+ #define cublasZrotg_64 cublasZrotg_v2_64
343
+
344
+ #define cublasSrotm_64 cublasSrotm_v2_64
345
+ #define cublasDrotm_64 cublasDrotm_v2_64
346
+
347
+ #define cublasSrotmg_64 cublasSrotmg_v2_64
348
+ #define cublasDrotmg_64 cublasDrotmg_v2_64
349
+
350
+ /* Blas2 Routines */
351
+
352
+ #define cublasSgemv_64 cublasSgemv_v2_64
353
+ #define cublasDgemv_64 cublasDgemv_v2_64
354
+ #define cublasCgemv_64 cublasCgemv_v2_64
355
+ #define cublasZgemv_64 cublasZgemv_v2_64
356
+
357
+ #define cublasSgbmv_64 cublasSgbmv_v2_64
358
+ #define cublasDgbmv_64 cublasDgbmv_v2_64
359
+ #define cublasCgbmv_64 cublasCgbmv_v2_64
360
+ #define cublasZgbmv_64 cublasZgbmv_v2_64
361
+
362
+ #define cublasStrmv_64 cublasStrmv_v2_64
363
+ #define cublasDtrmv_64 cublasDtrmv_v2_64
364
+ #define cublasCtrmv_64 cublasCtrmv_v2_64
365
+ #define cublasZtrmv_64 cublasZtrmv_v2_64
366
+
367
+ #define cublasStbmv_64 cublasStbmv_v2_64
368
+ #define cublasDtbmv_64 cublasDtbmv_v2_64
369
+ #define cublasCtbmv_64 cublasCtbmv_v2_64
370
+ #define cublasZtbmv_64 cublasZtbmv_v2_64
371
+
372
+ #define cublasStpmv_64 cublasStpmv_v2_64
373
+ #define cublasDtpmv_64 cublasDtpmv_v2_64
374
+ #define cublasCtpmv_64 cublasCtpmv_v2_64
375
+ #define cublasZtpmv_64 cublasZtpmv_v2_64
376
+
377
+ #define cublasStrsv_64 cublasStrsv_v2_64
378
+ #define cublasDtrsv_64 cublasDtrsv_v2_64
379
+ #define cublasCtrsv_64 cublasCtrsv_v2_64
380
+ #define cublasZtrsv_64 cublasZtrsv_v2_64
381
+
382
+ #define cublasStpsv_64 cublasStpsv_v2_64
383
+ #define cublasDtpsv_64 cublasDtpsv_v2_64
384
+ #define cublasCtpsv_64 cublasCtpsv_v2_64
385
+ #define cublasZtpsv_64 cublasZtpsv_v2_64
386
+
387
+ #define cublasStbsv_64 cublasStbsv_v2_64
388
+ #define cublasDtbsv_64 cublasDtbsv_v2_64
389
+ #define cublasCtbsv_64 cublasCtbsv_v2_64
390
+ #define cublasZtbsv_64 cublasZtbsv_v2_64
391
+
392
+ #define cublasSsymv_64 cublasSsymv_v2_64
393
+ #define cublasDsymv_64 cublasDsymv_v2_64
394
+ #define cublasCsymv_64 cublasCsymv_v2_64
395
+ #define cublasZsymv_64 cublasZsymv_v2_64
396
+ #define cublasChemv_64 cublasChemv_v2_64
397
+ #define cublasZhemv_64 cublasZhemv_v2_64
398
+
399
+ #define cublasSsbmv_64 cublasSsbmv_v2_64
400
+ #define cublasDsbmv_64 cublasDsbmv_v2_64
401
+ #define cublasChbmv_64 cublasChbmv_v2_64
402
+ #define cublasZhbmv_64 cublasZhbmv_v2_64
403
+
404
+ #define cublasSspmv_64 cublasSspmv_v2_64
405
+ #define cublasDspmv_64 cublasDspmv_v2_64
406
+ #define cublasChpmv_64 cublasChpmv_v2_64
407
+ #define cublasZhpmv_64 cublasZhpmv_v2_64
408
+
409
+ #define cublasSger_64 cublasSger_v2_64
410
+ #define cublasDger_64 cublasDger_v2_64
411
+ #define cublasCgeru_64 cublasCgeru_v2_64
412
+ #define cublasCgerc_64 cublasCgerc_v2_64
413
+ #define cublasZgeru_64 cublasZgeru_v2_64
414
+ #define cublasZgerc_64 cublasZgerc_v2_64
415
+
416
+ #define cublasSsyr_64 cublasSsyr_v2_64
417
+ #define cublasDsyr_64 cublasDsyr_v2_64
418
+ #define cublasCsyr_64 cublasCsyr_v2_64
419
+ #define cublasZsyr_64 cublasZsyr_v2_64
420
+ #define cublasCher_64 cublasCher_v2_64
421
+ #define cublasZher_64 cublasZher_v2_64
422
+
423
+ #define cublasSspr_64 cublasSspr_v2_64
424
+ #define cublasDspr_64 cublasDspr_v2_64
425
+ #define cublasChpr_64 cublasChpr_v2_64
426
+ #define cublasZhpr_64 cublasZhpr_v2_64
427
+
428
+ #define cublasSsyr2_64 cublasSsyr2_v2_64
429
+ #define cublasDsyr2_64 cublasDsyr2_v2_64
430
+ #define cublasCsyr2_64 cublasCsyr2_v2_64
431
+ #define cublasZsyr2_64 cublasZsyr2_v2_64
432
+ #define cublasCher2_64 cublasCher2_v2_64
433
+ #define cublasZher2_64 cublasZher2_v2_64
434
+
435
+ #define cublasSspr2_64 cublasSspr2_v2_64
436
+ #define cublasDspr2_64 cublasDspr2_v2_64
437
+ #define cublasChpr2_64 cublasChpr2_v2_64
438
+ #define cublasZhpr2_64 cublasZhpr2_v2_64
439
+
440
+ /* Blas3 Routines */
441
+
442
+ #define cublasSgemm_64 cublasSgemm_v2_64
443
+ #define cublasDgemm_64 cublasDgemm_v2_64
444
+ #define cublasCgemm_64 cublasCgemm_v2_64
445
+ #define cublasZgemm_64 cublasZgemm_v2_64
446
+
447
+ #define cublasSsyrk_64 cublasSsyrk_v2_64
448
+ #define cublasDsyrk_64 cublasDsyrk_v2_64
449
+ #define cublasCsyrk_64 cublasCsyrk_v2_64
450
+ #define cublasZsyrk_64 cublasZsyrk_v2_64
451
+ #define cublasCherk_64 cublasCherk_v2_64
452
+ #define cublasZherk_64 cublasZherk_v2_64
453
+
454
+ #define cublasSsyr2k_64 cublasSsyr2k_v2_64
455
+ #define cublasDsyr2k_64 cublasDsyr2k_v2_64
456
+ #define cublasCsyr2k_64 cublasCsyr2k_v2_64
457
+ #define cublasZsyr2k_64 cublasZsyr2k_v2_64
458
+ #define cublasCher2k_64 cublasCher2k_v2_64
459
+ #define cublasZher2k_64 cublasZher2k_v2_64
460
+
461
+ #define cublasSsymm_64 cublasSsymm_v2_64
462
+ #define cublasDsymm_64 cublasDsymm_v2_64
463
+ #define cublasCsymm_64 cublasCsymm_v2_64
464
+ #define cublasZsymm_64 cublasZsymm_v2_64
465
+ #define cublasChemm_64 cublasChemm_v2_64
466
+ #define cublasZhemm_64 cublasZhemm_v2_64
467
+
468
+ #define cublasStrsm_64 cublasStrsm_v2_64
469
+ #define cublasDtrsm_64 cublasDtrsm_v2_64
470
+ #define cublasCtrsm_64 cublasCtrsm_v2_64
471
+ #define cublasZtrsm_64 cublasZtrsm_v2_64
472
+
473
+ #define cublasStrmm_64 cublasStrmm_v2_64
474
+ #define cublasDtrmm_64 cublasDtrmm_v2_64
475
+ #define cublasCtrmm_64 cublasCtrmm_v2_64
476
+ #define cublasZtrmm_64 cublasZtrmm_v2_64
477
+
478
+ #endif /* !defined(CUBLAS_V2_H_) */
.venv/lib/python3.11/site-packages/nvidia/cublas/include/nvblas.h ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #if !defined(NVBLAS_H_)
51
+ #define NVBLAS_H_
52
+
53
+ #include "driver_types.h"
54
+ #include "cuComplex.h" /* import complex data type */
55
+
56
+ #if defined(__cplusplus)
57
+ extern "C" {
58
+ #endif
59
+
60
+ /* GEMM */
61
+ void sgemm_(const char* transa,
62
+ const char* transb,
63
+ const int* m,
64
+ const int* n,
65
+ const int* k,
66
+ const float* alpha,
67
+ const float* a,
68
+ const int* lda,
69
+ const float* b,
70
+ const int* ldb,
71
+ const float* beta,
72
+ float* c,
73
+ const int* ldc);
74
+
75
+ void dgemm_(const char* transa,
76
+ const char* transb,
77
+ const int* m,
78
+ const int* n,
79
+ const int* k,
80
+ const double* alpha,
81
+ const double* a,
82
+ const int* lda,
83
+ const double* b,
84
+ const int* ldb,
85
+ const double* beta,
86
+ double* c,
87
+ const int* ldc);
88
+
89
+ void cgemm_(const char* transa,
90
+ const char* transb,
91
+ const int* m,
92
+ const int* n,
93
+ const int* k,
94
+ const cuComplex* alpha,
95
+ const cuComplex* a,
96
+ const int* lda,
97
+ const cuComplex* b,
98
+ const int* ldb,
99
+ const cuComplex* beta,
100
+ cuComplex* c,
101
+ const int* ldc);
102
+
103
+ void zgemm_(const char* transa,
104
+ const char* transb,
105
+ const int* m,
106
+ const int* n,
107
+ const int* k,
108
+ const cuDoubleComplex* alpha,
109
+ const cuDoubleComplex* a,
110
+ const int* lda,
111
+ const cuDoubleComplex* b,
112
+ const int* ldb,
113
+ const cuDoubleComplex* beta,
114
+ cuDoubleComplex* c,
115
+ const int* ldc);
116
+
117
+ void sgemm(const char* transa,
118
+ const char* transb,
119
+ const int* m,
120
+ const int* n,
121
+ const int* k,
122
+ const float* alpha,
123
+ const float* a,
124
+ const int* lda,
125
+ const float* b,
126
+ const int* ldb,
127
+ const float* beta,
128
+ float* c,
129
+ const int* ldc);
130
+
131
+ void dgemm(const char* transa,
132
+ const char* transb,
133
+ const int* m,
134
+ const int* n,
135
+ const int* k,
136
+ const double* alpha,
137
+ const double* a,
138
+ const int* lda,
139
+ const double* b,
140
+ const int* ldb,
141
+ const double* beta,
142
+ double* c,
143
+ const int* ldc);
144
+
145
+ void cgemm(const char* transa,
146
+ const char* transb,
147
+ const int* m,
148
+ const int* n,
149
+ const int* k,
150
+ const cuComplex* alpha,
151
+ const cuComplex* a,
152
+ const int* lda,
153
+ const cuComplex* b,
154
+ const int* ldb,
155
+ const cuComplex* beta,
156
+ cuComplex* c,
157
+ const int* ldc);
158
+
159
+ void zgemm(const char* transa,
160
+ const char* transb,
161
+ const int* m,
162
+ const int* n,
163
+ const int* k,
164
+ const cuDoubleComplex* alpha,
165
+ const cuDoubleComplex* a,
166
+ const int* lda,
167
+ const cuDoubleComplex* b,
168
+ const int* ldb,
169
+ const cuDoubleComplex* beta,
170
+ cuDoubleComplex* c,
171
+ const int* ldc);
172
+
173
+ /* SYRK */
174
+ void ssyrk_(const char* uplo,
175
+ const char* trans,
176
+ const int* n,
177
+ const int* k,
178
+ const float* alpha,
179
+ const float* a,
180
+ const int* lda,
181
+ const float* beta,
182
+ float* c,
183
+ const int* ldc);
184
+
185
+ void dsyrk_(const char* uplo,
186
+ const char* trans,
187
+ const int* n,
188
+ const int* k,
189
+ const double* alpha,
190
+ const double* a,
191
+ const int* lda,
192
+ const double* beta,
193
+ double* c,
194
+ const int* ldc);
195
+
196
+ void csyrk_(const char* uplo,
197
+ const char* trans,
198
+ const int* n,
199
+ const int* k,
200
+ const cuComplex* alpha,
201
+ const cuComplex* a,
202
+ const int* lda,
203
+ const cuComplex* beta,
204
+ cuComplex* c,
205
+ const int* ldc);
206
+
207
+ void zsyrk_(const char* uplo,
208
+ const char* trans,
209
+ const int* n,
210
+ const int* k,
211
+ const cuDoubleComplex* alpha,
212
+ const cuDoubleComplex* a,
213
+ const int* lda,
214
+ const cuDoubleComplex* beta,
215
+ cuDoubleComplex* c,
216
+ const int* ldc);
217
+
218
+ void ssyrk(const char* uplo,
219
+ const char* trans,
220
+ const int* n,
221
+ const int* k,
222
+ const float* alpha,
223
+ const float* a,
224
+ const int* lda,
225
+ const float* beta,
226
+ float* c,
227
+ const int* ldc);
228
+
229
+ void dsyrk(const char* uplo,
230
+ const char* trans,
231
+ const int* n,
232
+ const int* k,
233
+ const double* alpha,
234
+ const double* a,
235
+ const int* lda,
236
+ const double* beta,
237
+ double* c,
238
+ const int* ldc);
239
+
240
+ void csyrk(const char* uplo,
241
+ const char* trans,
242
+ const int* n,
243
+ const int* k,
244
+ const cuComplex* alpha,
245
+ const cuComplex* a,
246
+ const int* lda,
247
+ const cuComplex* beta,
248
+ cuComplex* c,
249
+ const int* ldc);
250
+
251
+ void zsyrk(const char* uplo,
252
+ const char* trans,
253
+ const int* n,
254
+ const int* k,
255
+ const cuDoubleComplex* alpha,
256
+ const cuDoubleComplex* a,
257
+ const int* lda,
258
+ const cuDoubleComplex* beta,
259
+ cuDoubleComplex* c,
260
+ const int* ldc);
261
+
262
+ /* HERK */
263
+ void cherk_(const char* uplo,
264
+ const char* trans,
265
+ const int* n,
266
+ const int* k,
267
+ const float* alpha,
268
+ const cuComplex* a,
269
+ const int* lda,
270
+ const float* beta,
271
+ cuComplex* c,
272
+ const int* ldc);
273
+
274
+ void zherk_(const char* uplo,
275
+ const char* trans,
276
+ const int* n,
277
+ const int* k,
278
+ const double* alpha,
279
+ const cuDoubleComplex* a,
280
+ const int* lda,
281
+ const double* beta,
282
+ cuDoubleComplex* c,
283
+ const int* ldc);
284
+
285
+ void cherk(const char* uplo,
286
+ const char* trans,
287
+ const int* n,
288
+ const int* k,
289
+ const float* alpha,
290
+ const cuComplex* a,
291
+ const int* lda,
292
+ const float* beta,
293
+ cuComplex* c,
294
+ const int* ldc);
295
+
296
+ void zherk(const char* uplo,
297
+ const char* trans,
298
+ const int* n,
299
+ const int* k,
300
+ const double* alpha,
301
+ const cuDoubleComplex* a,
302
+ const int* lda,
303
+ const double* beta,
304
+ cuDoubleComplex* c,
305
+ const int* ldc);
306
+
307
+ /* TRSM */
308
+ void strsm_(const char* side,
309
+ const char* uplo,
310
+ const char* transa,
311
+ const char* diag,
312
+ const int* m,
313
+ const int* n,
314
+ const float* alpha,
315
+ const float* a,
316
+ const int* lda,
317
+ float* b,
318
+ const int* ldb);
319
+
320
+ void dtrsm_(const char* side,
321
+ const char* uplo,
322
+ const char* transa,
323
+ const char* diag,
324
+ const int* m,
325
+ const int* n,
326
+ const double* alpha,
327
+ const double* a,
328
+ const int* lda,
329
+ double* b,
330
+ const int* ldb);
331
+
332
+ void ctrsm_(const char* side,
333
+ const char* uplo,
334
+ const char* transa,
335
+ const char* diag,
336
+ const int* m,
337
+ const int* n,
338
+ const cuComplex* alpha,
339
+ const cuComplex* a,
340
+ const int* lda,
341
+ cuComplex* b,
342
+ const int* ldb);
343
+
344
+ void ztrsm_(const char* side,
345
+ const char* uplo,
346
+ const char* transa,
347
+ const char* diag,
348
+ const int* m,
349
+ const int* n,
350
+ const cuDoubleComplex* alpha,
351
+ const cuDoubleComplex* a,
352
+ const int* lda,
353
+ cuDoubleComplex* b,
354
+ const int* ldb);
355
+
356
+ void strsm(const char* side,
357
+ const char* uplo,
358
+ const char* transa,
359
+ const char* diag,
360
+ const int* m,
361
+ const int* n,
362
+ const float* alpha,
363
+ const float* a,
364
+ const int* lda,
365
+ float* b,
366
+ const int* ldb);
367
+
368
+ void dtrsm(const char* side,
369
+ const char* uplo,
370
+ const char* transa,
371
+ const char* diag,
372
+ const int* m,
373
+ const int* n,
374
+ const double* alpha,
375
+ const double* a,
376
+ const int* lda,
377
+ double* b,
378
+ const int* ldb);
379
+
380
+ void ctrsm(const char* side,
381
+ const char* uplo,
382
+ const char* transa,
383
+ const char* diag,
384
+ const int* m,
385
+ const int* n,
386
+ const cuComplex* alpha,
387
+ const cuComplex* a,
388
+ const int* lda,
389
+ cuComplex* b,
390
+ const int* ldb);
391
+
392
+ void ztrsm(const char* side,
393
+ const char* uplo,
394
+ const char* transa,
395
+ const char* diag,
396
+ const int* m,
397
+ const int* n,
398
+ const cuDoubleComplex* alpha,
399
+ const cuDoubleComplex* a,
400
+ const int* lda,
401
+ cuDoubleComplex* b,
402
+ const int* ldb);
403
+
404
+ /* SYMM */
405
+ void ssymm_(const char* side,
406
+ const char* uplo,
407
+ const int* m,
408
+ const int* n,
409
+ const float* alpha,
410
+ const float* a,
411
+ const int* lda,
412
+ const float* b,
413
+ const int* ldb,
414
+ const float* beta,
415
+ float* c,
416
+ const int* ldc);
417
+
418
+ void dsymm_(const char* side,
419
+ const char* uplo,
420
+ const int* m,
421
+ const int* n,
422
+ const double* alpha,
423
+ const double* a,
424
+ const int* lda,
425
+ const double* b,
426
+ const int* ldb,
427
+ const double* beta,
428
+ double* c,
429
+ const int* ldc);
430
+
431
+ void csymm_(const char* side,
432
+ const char* uplo,
433
+ const int* m,
434
+ const int* n,
435
+ const cuComplex* alpha,
436
+ const cuComplex* a,
437
+ const int* lda,
438
+ const cuComplex* b,
439
+ const int* ldb,
440
+ const cuComplex* beta,
441
+ cuComplex* c,
442
+ const int* ldc);
443
+
444
+ void zsymm_(const char* side,
445
+ const char* uplo,
446
+ const int* m,
447
+ const int* n,
448
+ const cuDoubleComplex* alpha,
449
+ const cuDoubleComplex* a,
450
+ const int* lda,
451
+ const cuDoubleComplex* b,
452
+ const int* ldb,
453
+ const cuDoubleComplex* beta,
454
+ cuDoubleComplex* c,
455
+ const int* ldc);
456
+
457
+ void ssymm(const char* side,
458
+ const char* uplo,
459
+ const int* m,
460
+ const int* n,
461
+ const float* alpha,
462
+ const float* a,
463
+ const int* lda,
464
+ const float* b,
465
+ const int* ldb,
466
+ const float* beta,
467
+ float* c,
468
+ const int* ldc);
469
+
470
+ void dsymm(const char* side,
471
+ const char* uplo,
472
+ const int* m,
473
+ const int* n,
474
+ const double* alpha,
475
+ const double* a,
476
+ const int* lda,
477
+ const double* b,
478
+ const int* ldb,
479
+ const double* beta,
480
+ double* c,
481
+ const int* ldc);
482
+
483
+ void csymm(const char* side,
484
+ const char* uplo,
485
+ const int* m,
486
+ const int* n,
487
+ const cuComplex* alpha,
488
+ const cuComplex* a,
489
+ const int* lda,
490
+ const cuComplex* b,
491
+ const int* ldb,
492
+ const cuComplex* beta,
493
+ cuComplex* c,
494
+ const int* ldc);
495
+
496
+ void zsymm(const char* side,
497
+ const char* uplo,
498
+ const int* m,
499
+ const int* n,
500
+ const cuDoubleComplex* alpha,
501
+ const cuDoubleComplex* a,
502
+ const int* lda,
503
+ const cuDoubleComplex* b,
504
+ const int* ldb,
505
+ const cuDoubleComplex* beta,
506
+ cuDoubleComplex* c,
507
+ const int* ldc);
508
+
509
+ /* HEMM */
510
+ void chemm_(const char* side,
511
+ const char* uplo,
512
+ const int* m,
513
+ const int* n,
514
+ const cuComplex* alpha,
515
+ const cuComplex* a,
516
+ const int* lda,
517
+ const cuComplex* b,
518
+ const int* ldb,
519
+ const cuComplex* beta,
520
+ cuComplex* c,
521
+ const int* ldc);
522
+
523
+ void zhemm_(const char* side,
524
+ const char* uplo,
525
+ const int* m,
526
+ const int* n,
527
+ const cuDoubleComplex* alpha,
528
+ const cuDoubleComplex* a,
529
+ const int* lda,
530
+ const cuDoubleComplex* b,
531
+ const int* ldb,
532
+ const cuDoubleComplex* beta,
533
+ cuDoubleComplex* c,
534
+ const int* ldc);
535
+
536
+ /* HEMM with no underscore*/
537
+ void chemm(const char* side,
538
+ const char* uplo,
539
+ const int* m,
540
+ const int* n,
541
+ const cuComplex* alpha,
542
+ const cuComplex* a,
543
+ const int* lda,
544
+ const cuComplex* b,
545
+ const int* ldb,
546
+ const cuComplex* beta,
547
+ cuComplex* c,
548
+ const int* ldc);
549
+
550
+ void zhemm(const char* side,
551
+ const char* uplo,
552
+ const int* m,
553
+ const int* n,
554
+ const cuDoubleComplex* alpha,
555
+ const cuDoubleComplex* a,
556
+ const int* lda,
557
+ const cuDoubleComplex* b,
558
+ const int* ldb,
559
+ const cuDoubleComplex* beta,
560
+ cuDoubleComplex* c,
561
+ const int* ldc);
562
+
563
+ /* SYR2K */
564
+ void ssyr2k_(const char* uplo,
565
+ const char* trans,
566
+ const int* n,
567
+ const int* k,
568
+ const float* alpha,
569
+ const float* a,
570
+ const int* lda,
571
+ const float* b,
572
+ const int* ldb,
573
+ const float* beta,
574
+ float* c,
575
+ const int* ldc);
576
+
577
+ void dsyr2k_(const char* uplo,
578
+ const char* trans,
579
+ const int* n,
580
+ const int* k,
581
+ const double* alpha,
582
+ const double* a,
583
+ const int* lda,
584
+ const double* b,
585
+ const int* ldb,
586
+ const double* beta,
587
+ double* c,
588
+ const int* ldc);
589
+
590
+ void csyr2k_(const char* uplo,
591
+ const char* trans,
592
+ const int* n,
593
+ const int* k,
594
+ const cuComplex* alpha,
595
+ const cuComplex* a,
596
+ const int* lda,
597
+ const cuComplex* b,
598
+ const int* ldb,
599
+ const cuComplex* beta,
600
+ cuComplex* c,
601
+ const int* ldc);
602
+
603
+ void zsyr2k_(const char* uplo,
604
+ const char* trans,
605
+ const int* n,
606
+ const int* k,
607
+ const cuDoubleComplex* alpha,
608
+ const cuDoubleComplex* a,
609
+ const int* lda,
610
+ const cuDoubleComplex* b,
611
+ const int* ldb,
612
+ const cuDoubleComplex* beta,
613
+ cuDoubleComplex* c,
614
+ const int* ldc);
615
+
616
+ /* SYR2K no_underscore*/
617
+ void ssyr2k(const char* uplo,
618
+ const char* trans,
619
+ const int* n,
620
+ const int* k,
621
+ const float* alpha,
622
+ const float* a,
623
+ const int* lda,
624
+ const float* b,
625
+ const int* ldb,
626
+ const float* beta,
627
+ float* c,
628
+ const int* ldc);
629
+
630
+ void dsyr2k(const char* uplo,
631
+ const char* trans,
632
+ const int* n,
633
+ const int* k,
634
+ const double* alpha,
635
+ const double* a,
636
+ const int* lda,
637
+ const double* b,
638
+ const int* ldb,
639
+ const double* beta,
640
+ double* c,
641
+ const int* ldc);
642
+
643
+ void csyr2k(const char* uplo,
644
+ const char* trans,
645
+ const int* n,
646
+ const int* k,
647
+ const cuComplex* alpha,
648
+ const cuComplex* a,
649
+ const int* lda,
650
+ const cuComplex* b,
651
+ const int* ldb,
652
+ const cuComplex* beta,
653
+ cuComplex* c,
654
+ const int* ldc);
655
+
656
+ void zsyr2k(const char* uplo,
657
+ const char* trans,
658
+ const int* n,
659
+ const int* k,
660
+ const cuDoubleComplex* alpha,
661
+ const cuDoubleComplex* a,
662
+ const int* lda,
663
+ const cuDoubleComplex* b,
664
+ const int* ldb,
665
+ const cuDoubleComplex* beta,
666
+ cuDoubleComplex* c,
667
+ const int* ldc);
668
+
669
+ /* HERK */
670
+ void cher2k_(const char* uplo,
671
+ const char* trans,
672
+ const int* n,
673
+ const int* k,
674
+ const cuComplex* alpha,
675
+ const cuComplex* a,
676
+ const int* lda,
677
+ const cuComplex* b,
678
+ const int* ldb,
679
+ const float* beta,
680
+ cuComplex* c,
681
+ const int* ldc);
682
+
683
+ void zher2k_(const char* uplo,
684
+ const char* trans,
685
+ const int* n,
686
+ const int* k,
687
+ const cuDoubleComplex* alpha,
688
+ const cuDoubleComplex* a,
689
+ const int* lda,
690
+ const cuDoubleComplex* b,
691
+ const int* ldb,
692
+ const double* beta,
693
+ cuDoubleComplex* c,
694
+ const int* ldc);
695
+
696
+ /* HER2K with no underscore */
697
+ void cher2k(const char* uplo,
698
+ const char* trans,
699
+ const int* n,
700
+ const int* k,
701
+ const cuComplex* alpha,
702
+ const cuComplex* a,
703
+ const int* lda,
704
+ const cuComplex* b,
705
+ const int* ldb,
706
+ const float* beta,
707
+ cuComplex* c,
708
+ const int* ldc);
709
+
710
+ void zher2k(const char* uplo,
711
+ const char* trans,
712
+ const int* n,
713
+ const int* k,
714
+ const cuDoubleComplex* alpha,
715
+ const cuDoubleComplex* a,
716
+ const int* lda,
717
+ const cuDoubleComplex* b,
718
+ const int* ldb,
719
+ const double* beta,
720
+ cuDoubleComplex* c,
721
+ const int* ldc);
722
+
723
+ /* TRMM */
724
+ void strmm_(const char* side,
725
+ const char* uplo,
726
+ const char* transa,
727
+ const char* diag,
728
+ const int* m,
729
+ const int* n,
730
+ const float* alpha,
731
+ const float* a,
732
+ const int* lda,
733
+ float* b,
734
+ const int* ldb);
735
+
736
+ void dtrmm_(const char* side,
737
+ const char* uplo,
738
+ const char* transa,
739
+ const char* diag,
740
+ const int* m,
741
+ const int* n,
742
+ const double* alpha,
743
+ const double* a,
744
+ const int* lda,
745
+ double* b,
746
+ const int* ldb);
747
+
748
+ void ctrmm_(const char* side,
749
+ const char* uplo,
750
+ const char* transa,
751
+ const char* diag,
752
+ const int* m,
753
+ const int* n,
754
+ const cuComplex* alpha,
755
+ const cuComplex* a,
756
+ const int* lda,
757
+ cuComplex* b,
758
+ const int* ldb);
759
+
760
+ void ztrmm_(const char* side,
761
+ const char* uplo,
762
+ const char* transa,
763
+ const char* diag,
764
+ const int* m,
765
+ const int* n,
766
+ const cuDoubleComplex* alpha,
767
+ const cuDoubleComplex* a,
768
+ const int* lda,
769
+ cuDoubleComplex* b,
770
+ const int* ldb);
771
+
772
+ void strmm(const char* side,
773
+ const char* uplo,
774
+ const char* transa,
775
+ const char* diag,
776
+ const int* m,
777
+ const int* n,
778
+ const float* alpha,
779
+ const float* a,
780
+ const int* lda,
781
+ float* b,
782
+ const int* ldb);
783
+
784
+ void dtrmm(const char* side,
785
+ const char* uplo,
786
+ const char* transa,
787
+ const char* diag,
788
+ const int* m,
789
+ const int* n,
790
+ const double* alpha,
791
+ const double* a,
792
+ const int* lda,
793
+ double* b,
794
+ const int* ldb);
795
+
796
+ void ctrmm(const char* side,
797
+ const char* uplo,
798
+ const char* transa,
799
+ const char* diag,
800
+ const int* m,
801
+ const int* n,
802
+ const cuComplex* alpha,
803
+ const cuComplex* a,
804
+ const int* lda,
805
+ cuComplex* b,
806
+ const int* ldb);
807
+
808
+ void ztrmm(const char* side,
809
+ const char* uplo,
810
+ const char* transa,
811
+ const char* diag,
812
+ const int* m,
813
+ const int* n,
814
+ const cuDoubleComplex* alpha,
815
+ const cuDoubleComplex* a,
816
+ const int* lda,
817
+ cuDoubleComplex* b,
818
+ const int* ldb);
819
+
820
+ #if defined(__cplusplus)
821
+ }
822
+ #endif /* __cplusplus */
823
+
824
+ #endif /* !defined(NVBLAS_H_) */
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (190 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c2a58dc54154208392301d0fe3d53a120e4c1ebeab9e80ce91fe9948baeadc9
3
+ size 757496
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (190 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (198 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/nvrtc.h ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //
2
+ // NVIDIA_COPYRIGHT_BEGIN
3
+ //
4
+ // Copyright (c) 2014-2023, NVIDIA CORPORATION. All rights reserved.
5
+ //
6
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
7
+ // and proprietary rights in and to this software, related documentation
8
+ // and any modifications thereto. Any use, reproduction, disclosure or
9
+ // distribution of this software and related documentation without an express
10
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
11
+ //
12
+ // NVIDIA_COPYRIGHT_END
13
+ //
14
+
15
+ #ifndef __NVRTC_H__
16
+ #define __NVRTC_H__
17
+
18
+ #ifdef __cplusplus
19
+ extern "C" {
20
+ #endif /* __cplusplus */
21
+
22
+ #include <stdlib.h>
23
+
24
+
25
+ /*************************************************************************//**
26
+ *
27
+ * \defgroup error Error Handling
28
+ *
29
+ * NVRTC defines the following enumeration type and function for API call
30
+ * error handling.
31
+ *
32
+ ****************************************************************************/
33
+
34
+
35
+ /**
36
+ * \ingroup error
37
+ * \brief The enumerated type nvrtcResult defines API call result codes.
38
+ * NVRTC API functions return nvrtcResult to indicate the call
39
+ * result.
40
+ */
41
+ typedef enum {
42
+ NVRTC_SUCCESS = 0,
43
+ NVRTC_ERROR_OUT_OF_MEMORY = 1,
44
+ NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2,
45
+ NVRTC_ERROR_INVALID_INPUT = 3,
46
+ NVRTC_ERROR_INVALID_PROGRAM = 4,
47
+ NVRTC_ERROR_INVALID_OPTION = 5,
48
+ NVRTC_ERROR_COMPILATION = 6,
49
+ NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7,
50
+ NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8,
51
+ NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9,
52
+ NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10,
53
+ NVRTC_ERROR_INTERNAL_ERROR = 11,
54
+ NVRTC_ERROR_TIME_FILE_WRITE_FAILED = 12
55
+ } nvrtcResult;
56
+
57
+
58
+ /**
59
+ * \ingroup error
60
+ * \brief nvrtcGetErrorString is a helper function that returns a string
61
+ * describing the given nvrtcResult code, e.g., NVRTC_SUCCESS to
62
+ * \c "NVRTC_SUCCESS".
63
+ * For unrecognized enumeration values, it returns
64
+ * \c "NVRTC_ERROR unknown".
65
+ *
66
+ * \param [in] result CUDA Runtime Compilation API result code.
67
+ * \return Message string for the given #nvrtcResult code.
68
+ */
69
+ const char *nvrtcGetErrorString(nvrtcResult result);
70
+
71
+
72
+ /*************************************************************************//**
73
+ *
74
+ * \defgroup query General Information Query
75
+ *
76
+ * NVRTC defines the following function for general information query.
77
+ *
78
+ ****************************************************************************/
79
+
80
+
81
+ /**
82
+ * \ingroup query
83
+ * \brief nvrtcVersion sets the output parameters \p major and \p minor
84
+ * with the CUDA Runtime Compilation version number.
85
+ *
86
+ * \param [out] major CUDA Runtime Compilation major version number.
87
+ * \param [out] minor CUDA Runtime Compilation minor version number.
88
+ * \return
89
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
90
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
91
+ *
92
+ */
93
+ nvrtcResult nvrtcVersion(int *major, int *minor);
94
+
95
+
96
+ /**
97
+ * \ingroup query
98
+ * \brief nvrtcGetNumSupportedArchs sets the output parameter \p numArchs
99
+ * with the number of architectures supported by NVRTC. This can
100
+ * then be used to pass an array to ::nvrtcGetSupportedArchs to
101
+ * get the supported architectures.
102
+ *
103
+ * \param [out] numArchs number of supported architectures.
104
+ * \return
105
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
106
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
107
+ *
108
+ * see ::nvrtcGetSupportedArchs
109
+ */
110
+ nvrtcResult nvrtcGetNumSupportedArchs(int* numArchs);
111
+
112
+
113
+ /**
114
+ * \ingroup query
115
+ * \brief nvrtcGetSupportedArchs populates the array passed via the output parameter
116
+ * \p supportedArchs with the architectures supported by NVRTC. The array is
117
+ * sorted in the ascending order. The size of the array to be passed can be
118
+ * determined using ::nvrtcGetNumSupportedArchs.
119
+ *
120
+ * \param [out] supportedArchs sorted array of supported architectures.
121
+ * \return
122
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
123
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
124
+ *
125
+ * see ::nvrtcGetNumSupportedArchs
126
+ */
127
+ nvrtcResult nvrtcGetSupportedArchs(int* supportedArchs);
128
+
129
+
130
+ /*************************************************************************//**
131
+ *
132
+ * \defgroup compilation Compilation
133
+ *
134
+ * NVRTC defines the following type and functions for actual compilation.
135
+ *
136
+ ****************************************************************************/
137
+
138
+
139
+ /**
140
+ * \ingroup compilation
141
+ * \brief nvrtcProgram is the unit of compilation, and an opaque handle for
142
+ * a program.
143
+ *
144
+ * To compile a CUDA program string, an instance of nvrtcProgram must be
145
+ * created first with ::nvrtcCreateProgram, then compiled with
146
+ * ::nvrtcCompileProgram.
147
+ */
148
+ typedef struct _nvrtcProgram *nvrtcProgram;
149
+
150
+
151
+ /**
152
+ * \ingroup compilation
153
+ * \brief nvrtcCreateProgram creates an instance of nvrtcProgram with the
154
+ * given input parameters, and sets the output parameter \p prog with
155
+ * it.
156
+ *
157
+ * \param [out] prog CUDA Runtime Compilation program.
158
+ * \param [in] src CUDA program source.
159
+ * \param [in] name CUDA program name.\n
160
+ * \p name can be \c NULL; \c "default_program" is
161
+ * used when \p name is \c NULL or "".
162
+ * \param [in] numHeaders Number of headers used.\n
163
+ * \p numHeaders must be greater than or equal to 0.
164
+ * \param [in] headers Sources of the headers.\n
165
+ * \p headers can be \c NULL when \p numHeaders is
166
+ * 0.
167
+ * \param [in] includeNames Name of each header by which they can be
168
+ * included in the CUDA program source.\n
169
+ * \p includeNames can be \c NULL when \p numHeaders
170
+ * is 0. These headers must be included with the exact
171
+ * names specified here.
172
+ * \return
173
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
174
+ * - \link #nvrtcResult NVRTC_ERROR_OUT_OF_MEMORY \endlink
175
+ * - \link #nvrtcResult NVRTC_ERROR_PROGRAM_CREATION_FAILURE \endlink
176
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
177
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
178
+ *
179
+ * \see ::nvrtcDestroyProgram
180
+ */
181
+ nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
182
+ const char *src,
183
+ const char *name,
184
+ int numHeaders,
185
+ const char * const *headers,
186
+ const char * const *includeNames);
187
+
188
+
189
+ /**
190
+ * \ingroup compilation
191
+ * \brief nvrtcDestroyProgram destroys the given program.
192
+ *
193
+ * \param [in] prog CUDA Runtime Compilation program.
194
+ * \return
195
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
196
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
197
+ *
198
+ * \see ::nvrtcCreateProgram
199
+ */
200
+ nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog);
201
+
202
+
203
+ /**
204
+ * \ingroup compilation
205
+ * \brief nvrtcCompileProgram compiles the given program.
206
+ *
207
+ * \param [in] prog CUDA Runtime Compilation program.
208
+ * \param [in] numOptions Number of compiler options passed.
209
+ * \param [in] options Compiler options in the form of C string array.\n
210
+ * \p options can be \c NULL when \p numOptions is 0.
211
+ *
212
+ * \return
213
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
214
+ * - \link #nvrtcResult NVRTC_ERROR_OUT_OF_MEMORY \endlink
215
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
216
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
217
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_OPTION \endlink
218
+ * - \link #nvrtcResult NVRTC_ERROR_COMPILATION \endlink
219
+ * - \link #nvrtcResult NVRTC_ERROR_BUILTIN_OPERATION_FAILURE \endlink
220
+ * - \link #nvrtcResult NVRTC_ERROR_TIME_FILE_WRITE_FAILED \endlink
221
+ *
222
+ * It supports compile options listed in \ref options.
223
+ */
224
+ nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
225
+ int numOptions, const char * const *options);
226
+
227
+
228
+ /**
229
+ * \ingroup compilation
230
+ * \brief nvrtcGetPTXSize sets the value of \p ptxSizeRet with the size of the PTX
231
+ * generated by the previous compilation of \p prog (including the
232
+ * trailing \c NULL).
233
+ *
234
+ * \param [in] prog CUDA Runtime Compilation program.
235
+ * \param [out] ptxSizeRet Size of the generated PTX (including the trailing
236
+ * \c NULL).
237
+ * \return
238
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
239
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
240
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
241
+ *
242
+ * \see ::nvrtcGetPTX
243
+ */
244
+ nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
245
+
246
+
247
+ /**
248
+ * \ingroup compilation
249
+ * \brief nvrtcGetPTX stores the PTX generated by the previous compilation
250
+ * of \p prog in the memory pointed by \p ptx.
251
+ *
252
+ * \param [in] prog CUDA Runtime Compilation program.
253
+ * \param [out] ptx Compiled result.
254
+ * \return
255
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
256
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
257
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
258
+ *
259
+ * \see ::nvrtcGetPTXSize
260
+ */
261
+ nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
262
+
263
+
264
+ /**
265
+ * \ingroup compilation
266
+ * \brief nvrtcGetCUBINSize sets the value of \p cubinSizeRet with the size of the cubin
267
+ * generated by the previous compilation of \p prog. The value of
268
+ * cubinSizeRet is set to 0 if the value specified to \c -arch is a
269
+ * virtual architecture instead of an actual architecture.
270
+ *
271
+ * \param [in] prog CUDA Runtime Compilation program.
272
+ * \param [out] cubinSizeRet Size of the generated cubin.
273
+ * \return
274
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
275
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
276
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
277
+ *
278
+ * \see ::nvrtcGetCUBIN
279
+ */
280
+ nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog, size_t *cubinSizeRet);
281
+
282
+
283
+ /**
284
+ * \ingroup compilation
285
+ * \brief nvrtcGetCUBIN stores the cubin generated by the previous compilation
286
+ * of \p prog in the memory pointed by \p cubin. No cubin is available
287
+ * if the value specified to \c -arch is a virtual architecture instead
288
+ * of an actual architecture.
289
+ *
290
+ * \param [in] prog CUDA Runtime Compilation program.
291
+ * \param [out] cubin Compiled and assembled result.
292
+ * \return
293
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
294
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
295
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
296
+ *
297
+ * \see ::nvrtcGetCUBINSize
298
+ */
299
+ nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin);
300
+
301
+
302
+ #if defined(_WIN32)
303
+ # define __DEPRECATED__(msg) __declspec(deprecated(msg))
304
+ #elif (defined(__GNUC__) && (__GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 5 && !defined(__clang__))))
305
+ # define __DEPRECATED__(msg) __attribute__((deprecated))
306
+ #elif (defined(__GNUC__))
307
+ # define __DEPRECATED__(msg) __attribute__((deprecated(msg)))
308
+ #else
309
+ # define __DEPRECATED__(msg)
310
+ #endif
311
+
312
+ /**
313
+ * \ingroup compilation
314
+ * \brief
315
+ * DEPRECATION NOTICE: This function will be removed in a future release. Please use
316
+ * nvrtcGetLTOIRSize (and nvrtcGetLTOIR) instead.
317
+ */
318
+ __DEPRECATED__("This function will be removed in a future release. Please use nvrtcGetLTOIRSize instead")
319
+ nvrtcResult nvrtcGetNVVMSize(nvrtcProgram prog, size_t *nvvmSizeRet);
320
+
321
+ /**
322
+ * \ingroup compilation
323
+ * \brief
324
+ * DEPRECATION NOTICE: This function will be removed in a future release. Please use
325
+ * nvrtcGetLTOIR (and nvrtcGetLTOIRSize) instead.
326
+ */
327
+ __DEPRECATED__("This function will be removed in a future release. Please use nvrtcGetLTOIR instead")
328
+ nvrtcResult nvrtcGetNVVM(nvrtcProgram prog, char *nvvm);
329
+
330
+ #undef __DEPRECATED__
331
+
332
+ /**
333
+ * \ingroup compilation
334
+ * \brief nvrtcGetLTOIRSize sets the value of \p LTOIRSizeRet with the size of the LTO IR
335
+ * generated by the previous compilation of \p prog. The value of
336
+ * LTOIRSizeRet is set to 0 if the program was not compiled with
337
+ * \c -dlto.
338
+ *
339
+ * \param [in] prog CUDA Runtime Compilation program.
340
+ * \param [out] LTOIRSizeRet Size of the generated LTO IR.
341
+ * \return
342
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
343
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
344
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
345
+ *
346
+ * \see ::nvrtcGetLTOIR
347
+ */
348
+ nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *LTOIRSizeRet);
349
+
350
+
351
+ /**
352
+ * \ingroup compilation
353
+ * \brief nvrtcGetLTOIR stores the LTO IR generated by the previous compilation
354
+ * of \p prog in the memory pointed by \p LTOIR. No LTO IR is available
355
+ * if the program was compiled without \c -dlto.
356
+ *
357
+ * \param [in] prog CUDA Runtime Compilation program.
358
+ * \param [out] LTOIR Compiled result.
359
+ * \return
360
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
361
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
362
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
363
+ *
364
+ * \see ::nvrtcGetLTOIRSize
365
+ */
366
+ nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *LTOIR);
367
+
368
+
369
+ /**
370
+ * \ingroup compilation
371
+ * \brief nvrtcGetOptiXIRSize sets the value of \p optixirSizeRet with the size of the OptiX IR
372
+ * generated by the previous compilation of \p prog. The value of
373
+ * nvrtcGetOptiXIRSize is set to 0 if the program was compiled with
374
+ * options incompatible with OptiX IR generation.
375
+ *
376
+ * \param [in] prog CUDA Runtime Compilation program.
377
+ * \param [out] optixirSizeRet Size of the generated LTO IR.
378
+ * \return
379
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
380
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
381
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
382
+ *
383
+ * \see ::nvrtcGetOptiXIR
384
+ */
385
+ nvrtcResult nvrtcGetOptiXIRSize(nvrtcProgram prog, size_t *optixirSizeRet);
386
+
387
+
388
+ /**
389
+ * \ingroup compilation
390
+ * \brief nvrtcGetOptiXIR stores the OptiX IR generated by the previous compilation
391
+ * of \p prog in the memory pointed by \p optixir. No OptiX IR is available
392
+ * if the program was compiled with options incompatible with OptiX IR generation.
393
+ *
394
+ * \param [in] prog CUDA Runtime Compilation program.
395
+ * \param [out] Optix IR Compiled result.
396
+ * \return
397
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
398
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
399
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
400
+ *
401
+ * \see ::nvrtcGetOptiXIRSize
402
+ */
403
+ nvrtcResult nvrtcGetOptiXIR(nvrtcProgram prog, char *optixir);
404
+
405
+ /**
406
+ * \ingroup compilation
407
+ * \brief nvrtcGetProgramLogSize sets \p logSizeRet with the size of the
408
+ * log generated by the previous compilation of \p prog (including the
409
+ * trailing \c NULL).
410
+ *
411
+ * Note that compilation log may be generated with warnings and informative
412
+ * messages, even when the compilation of \p prog succeeds.
413
+ *
414
+ * \param [in] prog CUDA Runtime Compilation program.
415
+ * \param [out] logSizeRet Size of the compilation log
416
+ * (including the trailing \c NULL).
417
+ * \return
418
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
419
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
420
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
421
+ *
422
+ * \see ::nvrtcGetProgramLog
423
+ */
424
+ nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet);
425
+
426
+
427
+ /**
428
+ * \ingroup compilation
429
+ * \brief nvrtcGetProgramLog stores the log generated by the previous
430
+ * compilation of \p prog in the memory pointed by \p log.
431
+ *
432
+ * \param [in] prog CUDA Runtime Compilation program.
433
+ * \param [out] log Compilation log.
434
+ * \return
435
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
436
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
437
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
438
+ *
439
+ * \see ::nvrtcGetProgramLogSize
440
+ */
441
+ nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
442
+
443
+
444
+ /**
445
+ * \ingroup compilation
446
+ * \brief nvrtcAddNameExpression notes the given name expression
447
+ * denoting the address of a __global__ function
448
+ * or __device__/__constant__ variable.
449
+ *
450
+ * The identical name expression string must be provided on a subsequent
451
+ * call to nvrtcGetLoweredName to extract the lowered name.
452
+ * \param [in] prog CUDA Runtime Compilation program.
453
+ * \param [in] name_expression constant expression denoting the address of
454
+ * a __global__ function or __device__/__constant__ variable.
455
+ * \return
456
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
457
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
458
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
459
+ * - \link #nvrtcResult NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION \endlink
460
+ *
461
+ * \see ::nvrtcGetLoweredName
462
+ */
463
+ nvrtcResult nvrtcAddNameExpression(nvrtcProgram prog,
464
+ const char * const name_expression);
465
+
466
+ /**
467
+ * \ingroup compilation
468
+ * \brief nvrtcGetLoweredName extracts the lowered (mangled) name
469
+ * for a __global__ function or __device__/__constant__ variable,
470
+ * and updates *lowered_name to point to it. The memory containing
471
+ * the name is released when the NVRTC program is destroyed by
472
+ * nvrtcDestroyProgram.
473
+ * The identical name expression must have been previously
474
+ * provided to nvrtcAddNameExpression.
475
+ *
476
+ * \param [in] prog CUDA Runtime Compilation program.
477
+ * \param [in] name_expression constant expression denoting the address of
478
+ * a __global__ function or __device__/__constant__ variable.
479
+ * \param [out] lowered_name initialized by the function to point to a
480
+ * C string containing the lowered (mangled)
481
+ * name corresponding to the provided name expression.
482
+ * \return
483
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
484
+ * - \link #nvrtcResult NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION \endlink
485
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
486
+ * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
487
+ * - \link #nvrtcResult NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID \endlink
488
+ *
489
+ * \see ::nvrtcAddNameExpression
490
+ */
491
+ nvrtcResult nvrtcGetLoweredName(nvrtcProgram prog,
492
+ const char *const name_expression,
493
+ const char** lowered_name);
494
+
495
+
496
+ /**
497
+ * \defgroup options Supported Compile Options
498
+ *
499
+ * NVRTC supports the compile options below.
500
+ * Option names with two preceding dashs (\c --) are long option names and
501
+ * option names with one preceding dash (\c -) are short option names.
502
+ * Short option names can be used instead of long option names.
503
+ * When a compile option takes an argument, an assignment operator (\c =)
504
+ * is used to separate the compile option argument from the compile option
505
+ * name, e.g., \c "--gpu-architecture=compute_60".
506
+ * Alternatively, the compile option name and the argument can be specified in
507
+ * separate strings without an assignment operator, .e.g,
508
+ * \c "--gpu-architecture" \c "compute_60".
509
+ * Single-character short option names, such as \c -D, \c -U, and \c -I, do
510
+ * not require an assignment operator, and the compile option name and the
511
+ * argument can be present in the same string with or without spaces between
512
+ * them.
513
+ * For instance, \c "-D=<def>", \c "-D<def>", and \c "-D <def>" are all
514
+ * supported.
515
+ *
516
+ * The valid compiler options are:
517
+ *
518
+ * - Compilation targets
519
+ * - \c --gpu-architecture=\<arch\> (\c -arch)\n
520
+ * Specify the name of the class of GPU architectures for which the
521
+ * input must be compiled.\n
522
+ * - Valid <c>\<arch\></c>s:
523
+ * - \c compute_50
524
+ * - \c compute_52
525
+ * - \c compute_53
526
+ * - \c compute_60
527
+ * - \c compute_61
528
+ * - \c compute_62
529
+ * - \c compute_70
530
+ * - \c compute_72
531
+ * - \c compute_75
532
+ * - \c compute_80
533
+ * - \c compute_87
534
+ * - \c compute_89
535
+ * - \c compute_90
536
+ * - \c compute_90a
537
+ * - \c sm_50
538
+ * - \c sm_52
539
+ * - \c sm_53
540
+ * - \c sm_60
541
+ * - \c sm_61
542
+ * - \c sm_62
543
+ * - \c sm_70
544
+ * - \c sm_72
545
+ * - \c sm_75
546
+ * - \c sm_80
547
+ * - \c sm_87
548
+ * - \c sm_89
549
+ * - \c sm_90
550
+ * - \c sm_90a
551
+ * - Default: \c compute_52
552
+ * - Separate compilation / whole-program compilation
553
+ * - \c --device-c (\c -dc)\n
554
+ * Generate relocatable code that can be linked with other relocatable
555
+ * device code. It is equivalent to --relocatable-device-code=true.
556
+ * - \c --device-w (\c -dw)\n
557
+ * Generate non-relocatable code. It is equivalent to
558
+ * \c --relocatable-device-code=false.
559
+ * - \c --relocatable-device-code={true|false} (\c -rdc)\n
560
+ * Enable (disable) the generation of relocatable device code.
561
+ * - Default: \c false
562
+ * - \c --extensible-whole-program (\c -ewp)\n
563
+ * Do extensible whole program compilation of device code.
564
+ * - Default: \c false
565
+ * - Debugging support
566
+ * - \c --device-debug (\c -G)\n
567
+ * Generate debug information. If --dopt is not specified,
568
+ * then turns off all optimizations.
569
+ * - \c --generate-line-info (\c -lineinfo)\n
570
+ * Generate line-number information.
571
+ * - Code generation
572
+ * - \c --dopt on (\c -dopt)\n
573
+ * - \c --dopt=on \n
574
+ * Enable device code optimization. When specified along with '-G', enables
575
+ * limited debug information generation for optimized device code (currently,
576
+ * only line number information).
577
+ * When '-G' is not specified, '-dopt=on' is implicit.
578
+ * - \c --ptxas-options \<options\> (\c -Xptxas)\n
579
+ * - \c --ptxas-options=\<options\> \n
580
+ * Specify options directly to ptxas, the PTX optimizing assembler.
581
+ * - \c --maxrregcount=\<N\> (\c -maxrregcount)\n
582
+ * Specify the maximum amount of registers that GPU functions can use.
583
+ * Until a function-specific limit, a higher value will generally
584
+ * increase the performance of individual GPU threads that execute this
585
+ * function. However, because thread registers are allocated from a
586
+ * global register pool on each GPU, a higher value of this option will
587
+ * also reduce the maximum thread block size, thereby reducing the amount
588
+ * of thread parallelism. Hence, a good maxrregcount value is the result
589
+ * of a trade-off. If this option is not specified, then no maximum is
590
+ * assumed. Value less than the minimum registers required by ABI will
591
+ * be bumped up by the compiler to ABI minimum limit.
592
+ * - \c --ftz={true|false} (\c -ftz)\n
593
+ * When performing single-precision floating-point operations, flush
594
+ * denormal values to zero or preserve denormal values.
595
+ * \c --use_fast_math implies \c --ftz=true.
596
+ * - Default: \c false
597
+ * - \c --prec-sqrt={true|false} (\c -prec-sqrt)\n
598
+ * For single-precision floating-point square root, use IEEE
599
+ * round-to-nearest mode or use a faster approximation.
600
+ * \c --use_fast_math implies \c --prec-sqrt=false.
601
+ * - Default: \c true
602
+ * - \c --prec-div={true|false} (\c -prec-div)\n
603
+ * For single-precision floating-point division and reciprocals, use IEEE
604
+ * round-to-nearest mode or use a faster approximation.
605
+ * \c --use_fast_math implies \c --prec-div=false.
606
+ * - Default: \c true
607
+ * - \c --fmad={true|false} (\c -fmad)\n
608
+ * Enables (disables) the contraction of floating-point multiplies and
609
+ * adds/subtracts into floating-point multiply-add operations (FMAD,
610
+ * FFMA, or DFMA). \c --use_fast_math implies \c --fmad=true.
611
+ * - Default: \c true
612
+ * - \c --use_fast_math (\c -use_fast_math)\n
613
+ * Make use of fast math operations.
614
+ * \c --use_fast_math implies \c --ftz=true \c --prec-div=false
615
+ * \c --prec-sqrt=false \c --fmad=true.
616
+ * - \c --extra-device-vectorization (\c -extra-device-vectorization)\n
617
+ * Enables more aggressive device code vectorization in the NVVM optimizer.
618
+ * - \c --modify-stack-limit={true|false} (\c -modify-stack-limit)\n
619
+ * On Linux, during compilation, use \c setrlimit() to increase stack size
620
+ * to maximum allowed. The limit is reset to the previous value at the
621
+ * end of compilation.
622
+ * Note: \c setrlimit() changes the value for the entire process.
623
+ * - Default: \c true
624
+ * - \c --dlink-time-opt (\c -dlto)\n
625
+ * Generate intermediate code for later link-time optimization.
626
+ * It implies \c -rdc=true.
627
+ * Note: when this option is used the nvrtcGetLTOIR API should be used,
628
+ * as PTX or Cubin will not be generated.
629
+ * - \c --gen-opt-lto (\c -gen-opt-lto)\n
630
+ * Run the optimizer passes before generating the LTO IR.
631
+ * - \c --optix-ir (\c -optix-ir)\n
632
+ * Generate OptiX IR. The Optix IR is only intended for consumption by OptiX
633
+ * through appropriate APIs. This feature is not supported with
634
+ * link-time-optimization (\c -dlto)\n.
635
+ * Note: when this option is used the nvrtcGetOptiX API should be used,
636
+ * as PTX or Cubin will not be generated.
637
+ * - \c --jump-table-density=[0-101] (\c -jtd)\n
638
+ * Specify the case density percentage in switch statements, and use it as
639
+ * a minimal threshold to determine whether jump table(brx.idx instruction)
640
+ * will be used to implement a switch statement. Default value is 101. The
641
+ * percentage ranges from 0 to 101 inclusively.
642
+ * - Preprocessing
643
+ * - \c --define-macro=\<def\> (\c -D)\n
644
+ * \c \<def\> can be either \c \<name\> or \c \<name=definitions\>.
645
+ * - \c \<name\> \n
646
+ * Predefine \c \<name\> as a macro with definition \c 1.
647
+ * - \c \<name\>=\<definition\> \n
648
+ * The contents of \c \<definition\> are tokenized and preprocessed
649
+ * as if they appeared during translation phase three in a \c \#define
650
+ * directive. In particular, the definition will be truncated by
651
+ * embedded new line characters.
652
+ * - \c --undefine-macro=\<def\> (\c -U)\n
653
+ * Cancel any previous definition of \c \<def\>.
654
+ * - \c --include-path=\<dir\> (\c -I)\n
655
+ * Add the directory \c \<dir\> to the list of directories to be
656
+ * searched for headers. These paths are searched after the list of
657
+ * headers given to ::nvrtcCreateProgram.
658
+ * - \c --pre-include=\<header\> (\c -include)\n
659
+ * Preinclude \c \<header\> during preprocessing.
660
+ * - \c --no-source-include (\c -no-source-include)
661
+ * The preprocessor by default adds the directory of each input sources
662
+ * to the include path. This option disables this feature and only
663
+ * considers the path specified explicitly.
664
+ * - Language Dialect
665
+ * - \c --std={c++03|c++11|c++14|c++17|c++20}
666
+ * (\c -std={c++11|c++14|c++17|c++20})\n
667
+ * Set language dialect to C++03, C++11, C++14, C++17 or C++20
668
+ * - Default: \c c++17
669
+ * - \c --builtin-move-forward={true|false} (\c -builtin-move-forward)\n
670
+ * Provide builtin definitions of \c std::move and \c std::forward,
671
+ * when C++11 or later language dialect is selected.
672
+ * - Default: \c true
673
+ * - \c --builtin-initializer-list={true|false}
674
+ * (\c -builtin-initializer-list)\n
675
+ * Provide builtin definitions of \c std::initializer_list class and
676
+ * member functions when C++11 or later language dialect is selected.
677
+ * - Default: \c true
678
+ * - Misc.
679
+ * - \c --disable-warnings (\c -w)\n
680
+ * Inhibit all warning messages.
681
+ * - \c --restrict (\c -restrict)\n
682
+ * Programmer assertion that all kernel pointer parameters are restrict
683
+ * pointers.
684
+ * - \c --device-as-default-execution-space
685
+ * (\c -default-device)\n
686
+ * Treat entities with no execution space annotation as \c __device__
687
+ * entities.
688
+ * - \c --device-int128 (\c -device-int128)\n
689
+ * Allow the \c __int128 type in device code. Also causes the macro \c __CUDACC_RTC_INT128__
690
+ * to be defined.
691
+ * - \c --optimization-info=\<kind\> (\c -opt-info)\n
692
+ * Provide optimization reports for the specified kind of optimization.
693
+ * The following kind tags are supported:
694
+ * - \c inline : emit a remark when a function is inlined.
695
+ * - \c --display-error-number (\c -err-no)\n
696
+ * Display diagnostic number for warning messages. (Default)
697
+ * - \c --no-display-error-number (\c -no-err-no)\n
698
+ * Disables the display of a diagnostic number for warning messages.
699
+ * - \c --diag-error=<error-number>,... (\c -diag-error)\n
700
+ * Emit error for specified diagnostic message number(s). Message numbers can be separated by comma.
701
+ * - \c --diag-suppress=<error-number>,... (\c -diag-suppress)\n
702
+ * Suppress specified diagnostic message number(s). Message numbers can be separated by comma.
703
+ * - \c --diag-warn=<error-number>,... (\c -diag-warn)\n
704
+ * Emit warning for specified diagnostic message number(s). Message numbers can be separated by comma.
705
+ * - \c --brief-diagnostics={true|false} (\c -brief-diag)\n
706
+ * This option disables or enables showing source line and column info
707
+ * in a diagnostic.
708
+ * The --brief-diagnostics=true will not show the source line and column info.
709
+ * - Default: \c false
710
+ * - \c --time=<file-name> (\c -time)\n
711
+ * Generate a comma separated value table with the time taken by each compilation
712
+ * phase, and append it at the end of the file given as the option argument.
713
+ * If the file does not exist, the column headings are generated in the first row
714
+ * of the table. If the file name is '-', the timing data is written to the compilation log.
715
+ * - \c --split-compile=<number of threads> (\c -split-compile=<number of threads>)\n
716
+ * Perform compiler optimizations in parallel.
717
+ * Split compilation attempts to reduce compile time by enabling the compiler to run certain
718
+ * optimization passes concurrently. This option accepts a numerical value that specifies the
719
+ * maximum number of threads the compiler can use. One can also allow the compiler to use the maximum
720
+ * threads available on the system by setting --split-compile=0.
721
+ * Setting --split-compile=1 will cause this option to be ignored.
722
+ * - \c --fdevice-syntax-only (\c -fdevice-syntax-only)\n
723
+ * Ends device compilation after front-end syntax checking. This option does not generate valid
724
+ * device code.
725
+ * - \c --minimal (\c -minimal)\n
726
+ * Omit certain language features to reduce compile time for small programs.
727
+ * In particular, the following are omitted:
728
+ * - Texture and surface functions and associated types, e.g., \c cudaTextureObject_t.
729
+ * - CUDA Runtime Functions that are provided by the cudadevrt device code library,
730
+ * typically named with prefix "cuda", e.g., \c cudaMalloc.
731
+ * - Kernel launch from device code.
732
+ * - Types and macros associated with CUDA Runtime and Driver APIs,
733
+ * provided by cuda/tools/cudart/driver_types.h, typically named with prefix "cuda", e.g., \c cudaError_t.
734
+ *
735
+ */
736
+
737
+ #ifdef __cplusplus
738
+ }
739
+ #endif /* __cplusplus */
740
+
741
+
742
+ /* The utility function 'nvrtcGetTypeName' is not available by default. Define
743
+ the macro 'NVRTC_GET_TYPE_NAME' to a non-zero value to make it available.
744
+ */
745
+
746
+ #if NVRTC_GET_TYPE_NAME || __DOXYGEN_ONLY__
747
+
748
+ #if NVRTC_USE_CXXABI || __clang__ || __GNUC__ || __DOXYGEN_ONLY__
749
+ #include <cxxabi.h>
750
+ #include <cstdlib>
751
+
752
+ #elif defined(_WIN32)
753
+ #include <Windows.h>
754
+ #include <DbgHelp.h>
755
+ #endif /* NVRTC_USE_CXXABI || __clang__ || __GNUC__ */
756
+
757
+
758
+ #include <string>
759
+ #include <typeinfo>
760
+
761
+ template <typename T> struct __nvrtcGetTypeName_helper_t { };
762
+
763
+ /*************************************************************************//**
764
+ *
765
+ * \defgroup hosthelper Host Helper
766
+ *
767
+ * NVRTC defines the following functions for easier interaction with host code.
768
+ *
769
+ ****************************************************************************/
770
+
771
+ /**
772
+ * \ingroup hosthelper
773
+ * \brief nvrtcGetTypeName stores the source level name of a type in the given
774
+ * std::string location.
775
+ *
776
+ * This function is only provided when the macro NVRTC_GET_TYPE_NAME is
777
+ * defined with a non-zero value. It uses abi::__cxa_demangle or UnDecorateSymbolName
778
+ * function calls to extract the type name, when using gcc/clang or cl.exe compilers,
779
+ * respectively. If the name extraction fails, it will return NVRTC_INTERNAL_ERROR,
780
+ * otherwise *result is initialized with the extracted name.
781
+ *
782
+ * Windows-specific notes:
783
+ * - nvrtcGetTypeName() is not multi-thread safe because it calls UnDecorateSymbolName(),
784
+ * which is not multi-thread safe.
785
+ * - The returned string may contain Microsoft-specific keywords such as __ptr64 and __cdecl.
786
+ *
787
+ * \param [in] tinfo: reference to object of type std::type_info for a given type.
788
+ * \param [in] result: pointer to std::string in which to store the type name.
789
+ * \return
790
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
791
+ * - \link #nvrtcResult NVRTC_ERROR_INTERNAL_ERROR \endlink
792
+ *
793
+ */
794
+ inline nvrtcResult nvrtcGetTypeName(const std::type_info &tinfo, std::string *result)
795
+ {
796
+ #if USE_CXXABI || __clang__ || __GNUC__
797
+ const char *name = tinfo.name();
798
+ int status;
799
+ char *undecorated_name = abi::__cxa_demangle(name, 0, 0, &status);
800
+ if (status == 0) {
801
+ *result = undecorated_name;
802
+ free(undecorated_name);
803
+ return NVRTC_SUCCESS;
804
+ }
805
+ #elif defined(_WIN32)
806
+ const char *name = tinfo.raw_name();
807
+ if (!name || *name != '.') {
808
+ return NVRTC_ERROR_INTERNAL_ERROR;
809
+ }
810
+ char undecorated_name[4096];
811
+ //name+1 skips over the '.' prefix
812
+ if(UnDecorateSymbolName(name+1, undecorated_name,
813
+ sizeof(undecorated_name) / sizeof(*undecorated_name),
814
+ //note: doesn't seem to work correctly without UNDNAME_NO_ARGUMENTS.
815
+ UNDNAME_NO_ARGUMENTS | UNDNAME_NAME_ONLY ) ) {
816
+ *result = undecorated_name;
817
+ return NVRTC_SUCCESS;
818
+ }
819
+ #endif /* USE_CXXABI || __clang__ || __GNUC__ */
820
+
821
+ return NVRTC_ERROR_INTERNAL_ERROR;
822
+ }
823
+
824
+ /**
825
+ * \ingroup hosthelper
826
+ * \brief nvrtcGetTypeName stores the source level name of the template type argument
827
+ * T in the given std::string location.
828
+ *
829
+ * This function is only provided when the macro NVRTC_GET_TYPE_NAME is
830
+ * defined with a non-zero value. It uses abi::__cxa_demangle or UnDecorateSymbolName
831
+ * function calls to extract the type name, when using gcc/clang or cl.exe compilers,
832
+ * respectively. If the name extraction fails, it will return NVRTC_INTERNAL_ERROR,
833
+ * otherwise *result is initialized with the extracted name.
834
+ *
835
+ * Windows-specific notes:
836
+ * - nvrtcGetTypeName() is not multi-thread safe because it calls UnDecorateSymbolName(),
837
+ * which is not multi-thread safe.
838
+ * - The returned string may contain Microsoft-specific keywords such as __ptr64 and __cdecl.
839
+ *
840
+ * \param [in] result: pointer to std::string in which to store the type name.
841
+ * \return
842
+ * - \link #nvrtcResult NVRTC_SUCCESS \endlink
843
+ * - \link #nvrtcResult NVRTC_ERROR_INTERNAL_ERROR \endlink
844
+ *
845
+ */
846
+
847
+ template <typename T>
848
+ nvrtcResult nvrtcGetTypeName(std::string *result)
849
+ {
850
+ nvrtcResult res = nvrtcGetTypeName(typeid(__nvrtcGetTypeName_helper_t<T>),
851
+ result);
852
+ if (res != NVRTC_SUCCESS)
853
+ return res;
854
+
855
+ std::string repr = *result;
856
+ std::size_t idx = repr.find("__nvrtcGetTypeName_helper_t");
857
+ idx = (idx != std::string::npos) ? repr.find("<", idx) : idx;
858
+ std::size_t last_idx = repr.find_last_of('>');
859
+ if (idx == std::string::npos || last_idx == std::string::npos) {
860
+ return NVRTC_ERROR_INTERNAL_ERROR;
861
+ }
862
+ ++idx;
863
+ *result = repr.substr(idx, last_idx - idx);
864
+ return NVRTC_SUCCESS;
865
+ }
866
+
867
+ #endif /* NVRTC_GET_TYPE_NAME */
868
+
869
+ #endif /* __NVRTC_H__ */
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (192 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/async.h ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ #ifndef _CG_ASYNC_H
50
+ #define _CG_ASYNC_H
51
+
52
+ #include "helpers.h"
53
+ #include "info.h"
54
+
55
+ #include <cuda_pipeline.h>
56
+
57
+ _CG_BEGIN_NAMESPACE
58
+
59
+ namespace details {
60
+ // Groups supported by memcpy_async
61
+ template <class TyGroup>
62
+ struct _async_copy_group_supported : public _CG_STL_NAMESPACE::false_type {};
63
+
64
+ template <unsigned int Sz, typename TyPar>
65
+ struct _async_copy_group_supported<cooperative_groups::thread_block_tile<Sz, TyPar>>
66
+ : public _CG_STL_NAMESPACE::true_type {};
67
+ template <>
68
+ struct _async_copy_group_supported<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
69
+ template <>
70
+ struct _async_copy_group_supported<cooperative_groups::thread_block> : public _CG_STL_NAMESPACE::true_type {};
71
+
72
+ template <class TyGroup>
73
+ using async_copy_group_supported = _async_copy_group_supported<details::remove_qual<TyGroup>>;
74
+
75
+ // Groups that require optimization
76
+ template <class TyGroup>
77
+ struct _async_copy_optimize_tile : public _CG_STL_NAMESPACE::false_type {};
78
+
79
+ template <typename TyPar>
80
+ struct _async_copy_optimize_tile<cooperative_groups::thread_block_tile<1, TyPar>>
81
+ : public _CG_STL_NAMESPACE::false_type {};
82
+
83
+ template <unsigned int Sz, typename TyPar>
84
+ struct _async_copy_optimize_tile<cooperative_groups::thread_block_tile<Sz, TyPar>>
85
+ : public _CG_STL_NAMESPACE::true_type {};
86
+
87
+ template <class TyGroup>
88
+ using async_copy_optimize_tile = _async_copy_optimize_tile<details::remove_qual<TyGroup>>;
89
+
90
+ // SFINAE helpers for tile optimizations
91
+ template <class TyGroup>
92
+ using enable_tile_optimization =
93
+ typename _CG_STL_NAMESPACE::enable_if<async_copy_optimize_tile<TyGroup>::value, void *>::type;
94
+
95
+ template <class TyGroup>
96
+ using disable_tile_optimization =
97
+ typename _CG_STL_NAMESPACE::enable_if<!async_copy_optimize_tile<TyGroup>::value, void *>::type;
98
+
99
+ // Segment for punning to aligned types
100
+ template <unsigned int N>
101
+ struct _Segment {
102
+ int _seg[N];
103
+ };
104
+
105
+ // Trivial layout guaranteed-aligned copy-async compatible segments
106
+ template <unsigned int N>
107
+ struct Segment;
108
+ template <>
109
+ struct __align__(4) Segment<1> : public _Segment<1>{};
110
+ template <>
111
+ struct __align__(8) Segment<2> : public _Segment<2>{};
112
+ template <>
113
+ struct __align__(16) Segment<4> : public _Segment<4>{};
114
+
115
+ // Interleaved element by element copies from source to dest
116
+ template <typename TyGroup, typename TyElem>
117
+ _CG_STATIC_QUALIFIER void inline_copy(TyGroup &group, TyElem *__restrict__ dst, const TyElem *__restrict__ src,
118
+ size_t count) {
119
+ const unsigned int rank = group.thread_rank();
120
+ const unsigned int stride = group.size();
121
+
122
+ for (size_t idx = rank; idx < count; idx += stride) {
123
+ dst[idx] = src[idx];
124
+ }
125
+ }
126
+
127
+ template <typename TyGroup, typename TyElem, enable_tile_optimization<TyGroup> = nullptr>
128
+ _CG_STATIC_QUALIFIER void accelerated_async_copy(TyGroup &group, TyElem *__restrict__ dst,
129
+ const TyElem *__restrict__ src, size_t count) {
130
+ static_assert(async_copy_group_supported<TyGroup>::value,
131
+ "Async copy is only supported for groups that represent private shared memory");
132
+
133
+ if (count == 0) {
134
+ return;
135
+ }
136
+
137
+ const bool dstIsNotShared = !__isShared(dst);
138
+ const bool srcIsNotGlobal = !__isGlobal(src);
139
+
140
+ if (dstIsNotShared || srcIsNotGlobal) {
141
+ inline_copy(group, dst, src, count);
142
+ return;
143
+ }
144
+
145
+ const unsigned int stride = group.size();
146
+ const unsigned int rank = group.thread_rank();
147
+ // Efficient copies require warps to operate on the same amount of work at each step.
148
+ // remainders are handled in a separate stage to prevent branching
149
+ const unsigned int subWarpMask = (stride - 1);
150
+ const unsigned int subwarpCopies = (subWarpMask & (unsigned int)count);
151
+ const unsigned int maxSubwarpRank = min(rank, subwarpCopies - 1);
152
+
153
+ const size_t warpCopies = (count & (~subWarpMask));
154
+
155
+ for (size_t idx = 0; idx < warpCopies; idx += stride) {
156
+ size_t _srcIdx = rank + idx;
157
+ size_t _dstIdx = rank + idx;
158
+ __pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem));
159
+ }
160
+
161
+ if (subwarpCopies) {
162
+ size_t _srcIdx = warpCopies + maxSubwarpRank;
163
+ size_t _dstIdx = warpCopies + maxSubwarpRank;
164
+ __pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem));
165
+ }
166
+ }
167
+
168
+ template <typename TyGroup, typename TyElem, disable_tile_optimization<TyGroup> = nullptr>
169
+ _CG_STATIC_QUALIFIER void accelerated_async_copy(TyGroup &group, TyElem *__restrict__ dst,
170
+ const TyElem *__restrict__ src, size_t count) {
171
+ static_assert(async_copy_group_supported<TyGroup>::value,
172
+ "Async copy is only supported for groups that represent private shared memory");
173
+
174
+ const bool dstIsNotShared = !__isShared(dst);
175
+ const bool srcIsNotGlobal = !__isGlobal(src);
176
+
177
+ if (dstIsNotShared || srcIsNotGlobal) {
178
+ inline_copy(group, dst, src, count);
179
+ return;
180
+ }
181
+
182
+ unsigned int stride = group.size();
183
+ unsigned int rank = group.thread_rank();
184
+
185
+ for (size_t idx = rank; idx < count; idx += stride) {
186
+ size_t _srcIdx = idx;
187
+ size_t _dstIdx = idx;
188
+ __pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem));
189
+ }
190
+ }
191
+
192
+ // Determine best possible alignment given an input and initial conditions
193
+ // Attempts to generate as little code as possible, most likely should only be used with 1 and 2 byte alignments
194
+ template <unsigned int MinAlignment, unsigned int MaxAlignment>
195
+ _CG_STATIC_QUALIFIER uint32_t find_best_alignment(void *__restrict__ dst, const void *__restrict__ src) {
196
+ // Narrowing conversion intentional
197
+ uint32_t base1 = (uint32_t) reinterpret_cast<uintptr_t>(src);
198
+ uint32_t base2 = (uint32_t) reinterpret_cast<uintptr_t>(dst);
199
+
200
+ uint32_t diff = ((base1) ^ (base2)) & (MaxAlignment - 1);
201
+
202
+ // range [MaxAlignment, alignof(elem)], step: x >> 1
203
+ // over range of possible alignments, choose best available out of range
204
+ uint32_t out = MaxAlignment;
205
+ #pragma unroll
206
+ for (uint32_t alignment = (MaxAlignment >> 1); alignment >= MinAlignment; alignment >>= 1) {
207
+ if (alignment & diff)
208
+ out = alignment;
209
+ }
210
+
211
+ return out;
212
+ }
213
+
214
+ // Determine best possible alignment given an input and initial conditions
215
+ // Attempts to generate as little code as possible, most likely should only be used with 1 and 2 byte alignments
216
+ template <typename TyType, typename TyGroup>
217
+ _CG_STATIC_QUALIFIER void copy_like(const TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
218
+ size_t count) {
219
+ const char *src = reinterpret_cast<const char *>(_src);
220
+ char *dst = reinterpret_cast<char *>(_dst);
221
+
222
+ constexpr uint32_t targetAlignment = (uint32_t)alignof(TyType);
223
+
224
+ uint32_t base = (uint32_t) reinterpret_cast<uintptr_t>(src);
225
+ uint32_t alignOffset = ((~base) + 1) & (targetAlignment - 1);
226
+
227
+ inline_copy(group, dst, src, alignOffset);
228
+ count -= alignOffset;
229
+ src += alignOffset;
230
+ dst += alignOffset;
231
+
232
+ // Copy using the best available alignment, async_copy expects n-datums, not bytes
233
+ size_t asyncCount = count / sizeof(TyType);
234
+ accelerated_async_copy(group, reinterpret_cast<TyType *>(dst), reinterpret_cast<const TyType *>(src), asyncCount);
235
+ asyncCount *= sizeof(TyType);
236
+
237
+ count -= asyncCount;
238
+ src += asyncCount;
239
+ dst += asyncCount;
240
+ inline_copy(group, dst, src, count);
241
+ }
242
+
243
+ // We must determine alignment and manually align src/dst ourselves
244
+ template <size_t AlignHint>
245
+ struct _memcpy_async_align_dispatch {
246
+ template <typename TyGroup>
247
+ _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ dst, const void *__restrict__ src, size_t count) {
248
+ uint32_t alignment = find_best_alignment<AlignHint, 16>(dst, src);
249
+
250
+ // Avoid copying the extra bytes if desired copy count is smaller
251
+ alignment = count < alignment ? AlignHint : alignment;
252
+
253
+ switch (alignment) {
254
+ default:
255
+ case 1:
256
+ inline_copy(group, reinterpret_cast<char *>(dst), reinterpret_cast<const char *>(src), count);
257
+ break;
258
+ case 2:
259
+ inline_copy(group, reinterpret_cast<short *>(dst), reinterpret_cast<const short *>(src), count >> 1);
260
+ break;
261
+ case 4:
262
+ copy_like<Segment<1>>(group, dst, src, count);
263
+ break;
264
+ case 8:
265
+ copy_like<Segment<2>>(group, dst, src, count);
266
+ break;
267
+ case 16:
268
+ copy_like<Segment<4>>(group, dst, src, count);
269
+ break;
270
+ }
271
+ }
272
+ };
273
+
274
+ // Specialization for 4 byte alignments
275
+ template <>
276
+ struct _memcpy_async_align_dispatch<4> {
277
+ template <typename TyGroup>
278
+ _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
279
+ size_t count) {
280
+ const Segment<1> *src = reinterpret_cast<const Segment<1> *>(_src);
281
+ Segment<1> *dst = reinterpret_cast<Segment<1> *>(_dst);
282
+
283
+ // Dispatch straight to aligned LDGSTS calls
284
+ accelerated_async_copy(group, dst, src, count / sizeof(*dst));
285
+ }
286
+ };
287
+
288
+ // Specialization for 8 byte alignments
289
+ template <>
290
+ struct _memcpy_async_align_dispatch<8> {
291
+ template <typename TyGroup>
292
+ _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
293
+ size_t count) {
294
+ const Segment<2> *src = reinterpret_cast<const Segment<2> *>(_src);
295
+ Segment<2> *dst = reinterpret_cast<Segment<2> *>(_dst);
296
+
297
+ // Dispatch straight to aligned LDGSTS calls
298
+ accelerated_async_copy(group, dst, src, count / sizeof(*dst));
299
+ }
300
+ };
301
+
302
+ // Alignments over 16 are truncated to 16 and bypass alignment
303
+ // This is the highest performing memcpy available
304
+ template <>
305
+ struct _memcpy_async_align_dispatch<16> {
306
+ template <typename TyGroup>
307
+ _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
308
+ size_t count) {
309
+ const Segment<4> *src = reinterpret_cast<const Segment<4> *>(_src);
310
+ Segment<4> *dst = reinterpret_cast<Segment<4> *>(_dst);
311
+
312
+ // Dispatch straight to aligned LDGSTS calls
313
+ accelerated_async_copy(group, dst, src, count / sizeof(*dst));
314
+ }
315
+ };
316
+
317
+ // byte-wide API
318
+ template <size_t Alignment, class TyGroup>
319
+ _CG_STATIC_QUALIFIER void _memcpy_async_dispatch_to_aligned_copy(const TyGroup &group, void *__restrict__ _dst,
320
+ const void *__restrict__ _src, size_t count) {
321
+ static_assert(!(Alignment & (Alignment - 1)), "Known static alignment dispatch must be a power of 2");
322
+ details::_memcpy_async_align_dispatch<Alignment>::copy(group, _dst, _src, count);
323
+ }
324
+
325
+ // Internal dispatch APIs
326
+ // These deduce the alignments and sizes necessary to invoke the underlying copy engine
327
+ template <typename Ty>
328
+ using is_void = _CG_STL_NAMESPACE::is_same<Ty, void>;
329
+
330
+ template <typename Ty>
331
+ using enable_if_not_void = typename _CG_STL_NAMESPACE::enable_if<!is_void<Ty>::value, void *>::type;
332
+
333
+ template <typename Ty>
334
+ using enable_if_void = typename _CG_STL_NAMESPACE::enable_if<is_void<Ty>::value, void *>::type;
335
+
336
+ template <typename Ty>
337
+ using enable_if_integral =
338
+ typename _CG_STL_NAMESPACE::enable_if<_CG_STL_NAMESPACE::is_integral<Ty>::value, void *>::type;
339
+
340
+ // byte-wide API using aligned_sized_t
341
+ template <class TyGroup, template <size_t> typename Alignment, size_t Hint>
342
+ _CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, void *__restrict__ _dst,
343
+ const void *__restrict__ _src, const Alignment<Hint> &count) {
344
+ constexpr size_t _align = (Hint > 16) ? 16 : Hint;
345
+
346
+ details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, _dst, _src, (size_t)count);
347
+ }
348
+
349
+ // byte-wide API using type for aligment
350
+ template <class TyGroup, typename TyElem, typename TySize, size_t Hint = alignof(TyElem),
351
+ enable_if_not_void<TyElem> = nullptr, enable_if_integral<TySize> = nullptr>
352
+ _CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, TyElem *__restrict__ _dst,
353
+ const TyElem *__restrict__ _src, const TySize& count) {
354
+ constexpr size_t _align = (Hint > 16) ? 16 : Hint;
355
+
356
+ details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, _dst, _src, count);
357
+ }
358
+
359
+ // byte-wide API with full alignment deduction required
360
+ template <class TyGroup, typename TyElem, typename TySize, enable_if_void<TyElem> = nullptr,
361
+ enable_if_integral<TySize> = nullptr>
362
+ _CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, TyElem *__restrict__ _dst,
363
+ const TyElem *__restrict__ _src, const TySize& count) {
364
+ details::_memcpy_async_dispatch_to_aligned_copy<1>(group, _dst, _src, count);
365
+ }
366
+
367
+ // 1d-datum API
368
+ template <class TyGroup, typename TyElem, size_t Hint = alignof(TyElem)>
369
+ _CG_STATIC_QUALIFIER void _memcpy_async_datum(const TyGroup &group, TyElem *__restrict__ dst, const size_t dstCount,
370
+ const TyElem *__restrict__ src, const size_t srcCount) {
371
+ constexpr unsigned int _align = Hint;
372
+ const size_t totalCount = min(dstCount, srcCount) * sizeof(TyElem);
373
+
374
+ details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, dst, src, totalCount);
375
+ }
376
+
377
+ // 1d-datum API using aligned_size_t
378
+ template <class TyGroup, typename TyElem, template <size_t> typename Alignment, size_t Hint>
379
+ _CG_STATIC_QUALIFIER void _memcpy_async_datum(const TyGroup &group, TyElem *__restrict__ dst, const Alignment<Hint> &dstCount,
380
+ const TyElem *__restrict__ src, const Alignment<Hint> &srcCount) {
381
+ constexpr unsigned int _align = Hint;
382
+ const size_t totalCount = min((size_t)dstCount, (size_t)srcCount) * sizeof(TyElem);
383
+
384
+ details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, dst, src, totalCount);
385
+ }
386
+
387
+ } // namespace details
388
+
389
+ /*
390
+ * Group submit batch of async-copy to cover contiguous 1D array
391
+ * and commit that batch to eventually wait for completion.
392
+ */
393
+ template <class TyGroup, typename TyElem, typename TySizeT>
394
+ _CG_STATIC_QUALIFIER void memcpy_async(const TyGroup &group, TyElem *__restrict__ _dst, const TyElem *__restrict__ _src,
395
+ const TySizeT &count) {
396
+ details::_memcpy_async_bytes(group, _dst, _src, count);
397
+ __pipeline_commit();
398
+ }
399
+
400
+ /*
401
+ * Group submit batch of async-copy to cover contiguous 1D array
402
+ * and commit that batch to eventually wait for completion.
403
+ * Object counts are in datum sized chunks, not bytes.
404
+ */
405
+ template <class TyGroup, class TyElem, typename DstLayout, typename SrcLayout>
406
+ _CG_STATIC_QUALIFIER void memcpy_async(const TyGroup &group, TyElem *__restrict__ dst, const DstLayout &dstLayout,
407
+ const TyElem *__restrict__ src, const SrcLayout &srcLayout) {
408
+ details::_memcpy_async_datum(group, dst, dstLayout, src, srcLayout);
409
+ __pipeline_commit();
410
+ }
411
+
412
+ /* Group wait for prior Nth stage of memcpy_async to complete. */
413
+ template <unsigned int Stage, class TyGroup>
414
+ _CG_STATIC_QUALIFIER void wait_prior(const TyGroup &group) {
415
+ __pipeline_wait_prior(Stage);
416
+ group.sync();
417
+ }
418
+
419
+ /* Group wait all previously submitted memcpy_async to complete. */
420
+ template <class TyGroup>
421
+ _CG_STATIC_QUALIFIER void wait(const TyGroup &group) {
422
+ __pipeline_wait_prior(0);
423
+ group.sync();
424
+ }
425
+
426
+ /***************** CG APIs including pipeline are deprecated *****************/
427
+
428
+ /* Group submit batch of async-copy to cover of contiguous 1D array
429
+ to a pipeline and commit the batch*/
430
+ template <class TyGroup, class TyElem>
431
+ _CG_DEPRECATED _CG_STATIC_QUALIFIER void memcpy_async(TyGroup &group, TyElem *dst, size_t dstCount, const TyElem *src, size_t srcCount,
432
+ nvcuda::experimental::pipeline &pipe) {
433
+ details::_memcpy_async_datum(group, dst, dstCount, src, srcCount);
434
+ pipe.commit();
435
+ }
436
+
437
+ /* Group wait for prior Nth stage of memcpy_async to complete. */
438
+ template <unsigned int Stage, class TyGroup>
439
+ _CG_DEPRECATED _CG_STATIC_QUALIFIER void wait_prior(TyGroup &group, nvcuda::experimental::pipeline &pipe) {
440
+ pipe.wait_prior<Stage>();
441
+ group.sync();
442
+ }
443
+
444
+ /* Group wait for stage-S of memcpy_async to complete. */
445
+ template <class TyGroup>
446
+ _CG_DEPRECATED _CG_STATIC_QUALIFIER void wait(TyGroup &group, nvcuda::experimental::pipeline &pipe, size_t stage) {
447
+ pipe.wait(stage);
448
+ group.sync();
449
+ }
450
+ _CG_END_NAMESPACE
451
+
452
+ #endif // _CG_ASYNC_H
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/coalesced_scan.h ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ #ifndef _CG_COALESCED_SCAN_H_
50
+ #define _CG_COALESCED_SCAN_H_
51
+
52
+ #include "info.h"
53
+ #include "helpers.h"
54
+ #include "cooperative_groups.h"
55
+ #include "partitioning.h"
56
+ #include "functional.h"
57
+
58
+ _CG_BEGIN_NAMESPACE
59
+
60
+ namespace details {
61
+
62
+ template <typename TyGroup, typename TyVal, typename TyOp>
63
+ _CG_QUALIFIER auto inclusive_scan_contiguous(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
64
+ auto out = val;
65
+ for (int mask = 1; mask < group.size(); mask <<= 1) {
66
+ auto tmp = group.shfl_up(out, mask);
67
+ if (mask <= group.thread_rank()) {
68
+ out = op(out, tmp);
69
+ }
70
+ }
71
+
72
+ return out;
73
+ }
74
+
75
+ template <typename TyGroup, typename TyVal, typename TyOp>
76
+ _CG_QUALIFIER auto inclusive_scan_non_contiguous(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
77
+ const unsigned int groupSize = group.size();
78
+ auto out = val;
79
+
80
+ const unsigned int mask = details::_coalesced_group_data_access::get_mask(group);
81
+ unsigned int lanemask = details::lanemask32_lt() & mask;
82
+ unsigned int srcLane = details::laneid();
83
+
84
+ const unsigned int base = __ffs(mask)-1; /* lane with rank == 0 */
85
+ const unsigned int rank = __popc(lanemask);
86
+
87
+ for (unsigned int i = 1, j = 1; i < groupSize; i <<= 1) {
88
+ if (i <= rank) {
89
+ srcLane -= j;
90
+ j = i; /* maximum possible lane */
91
+
92
+ unsigned int begLane = base + rank - i; /* minimum possible lane */
93
+
94
+ /* Next source lane is in the range [ begLane .. srcLane ]
95
+ * If begLane < srcLane then do a binary search.
96
+ */
97
+ while (begLane < srcLane) {
98
+ const unsigned int halfLane = (begLane + srcLane) >> 1;
99
+ const unsigned int halfMask = lanemask >> halfLane;
100
+ const unsigned int d = __popc(halfMask);
101
+ if (d < i) {
102
+ srcLane = halfLane - 1; /* halfLane too large */
103
+ }
104
+ else if ((i < d) || !(halfMask & 0x01)) {
105
+ begLane = halfLane + 1; /* halfLane too small */
106
+ }
107
+ else {
108
+ begLane = srcLane = halfLane; /* happen to hit */
109
+ }
110
+ }
111
+ }
112
+
113
+ auto tmp = details::tile::shuffle_dispatch<TyVal>::shfl(out, mask, srcLane, 32);
114
+ if (i <= rank) {
115
+ out = op(out, tmp);
116
+ }
117
+ }
118
+ return out;
119
+ }
120
+
121
+ template <unsigned int TySize, typename ParentT, typename TyVal, typename TyOp>
122
+ _CG_QUALIFIER auto coalesced_inclusive_scan(const __single_warp_thread_block_tile<TySize, ParentT>& group,
123
+ TyVal&& val,
124
+ TyOp&& op) -> decltype(op(val, val)) {
125
+ return inclusive_scan_contiguous(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
126
+ }
127
+
128
+ template <typename TyVal, typename TyOp>
129
+ _CG_QUALIFIER auto coalesced_inclusive_scan(const coalesced_group& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
130
+ if (group.size() == 32) {
131
+ return inclusive_scan_contiguous(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
132
+ }
133
+ else {
134
+ return inclusive_scan_non_contiguous(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
135
+ }
136
+ }
137
+
138
+ template <bool IntegralOptimized>
139
+ struct scan_choose_convertion;
140
+
141
+ template<>
142
+ struct scan_choose_convertion<true> {
143
+ template <typename TyGroup, typename TyRes, typename TyVal>
144
+ _CG_STATIC_QUALIFIER details::remove_qual<TyVal> convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val) {
145
+ return result - val;
146
+ }
147
+ };
148
+
149
+ template<>
150
+ struct scan_choose_convertion<false> {
151
+ template <typename TyGroup, typename TyRes, typename TyVal>
152
+ _CG_STATIC_QUALIFIER details::remove_qual<TyVal> convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val) {
153
+ auto ret = group.shfl_up(result, 1);
154
+ if (group.thread_rank() == 0) {
155
+ return {};
156
+ }
157
+ else {
158
+ return ret;
159
+ }
160
+ }
161
+ };
162
+
163
+ template <typename TyGroup, typename TyRes, typename TyVal, typename TyFn>
164
+ _CG_QUALIFIER auto convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
165
+ using conversion = scan_choose_convertion<_CG_STL_NAMESPACE::is_same<remove_qual<TyFn>, cooperative_groups::plus<remove_qual<TyVal>>>::value
166
+ && _CG_STL_NAMESPACE::is_integral<remove_qual<TyVal>>::value>;
167
+ return conversion::convert_inclusive_to_exclusive(group, result, _CG_STL_NAMESPACE::forward<TyVal>(val));
168
+ }
169
+
170
+ } // details
171
+
172
+ _CG_END_NAMESPACE
173
+
174
+ #endif // _CG_COALESCED_SCAN_H_
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/driver_abi.h ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ #ifndef _CG_DRIVER_API_H
50
+ #define _CG_DRIVER_API_H
51
+
52
+ #include "info.h"
53
+
54
+ _CG_BEGIN_NAMESPACE
55
+
56
+ namespace details {
57
+ template <unsigned int RegId>
58
+ _CG_QUALIFIER unsigned int load_env_reg() {
59
+ // Abort by default
60
+ _CG_ABORT();
61
+ return 0;
62
+ }
63
+
64
+ template <unsigned int HiReg, unsigned int LoReg>
65
+ _CG_QUALIFIER unsigned long long load_env_reg64() {
66
+ unsigned long long registerLo = load_env_reg<LoReg>();
67
+ unsigned long long registerHi = load_env_reg<HiReg>();
68
+
69
+ return (registerHi << 32) | registerLo;
70
+ }
71
+
72
+ // inline PTX for accessing registers requires an immediate for the special reg
73
+ # define LOAD_ENVREG(NUMBER) \
74
+ template <> _CG_QUALIFIER unsigned int load_env_reg<NUMBER>() { \
75
+ unsigned int r; \
76
+ asm ("mov.u32 %0, %%envreg" #NUMBER ";" : "=r"(r)); \
77
+ return r; \
78
+ }
79
+
80
+ // Instantiate loaders for registers used
81
+ LOAD_ENVREG(0);
82
+ LOAD_ENVREG(1);
83
+ LOAD_ENVREG(2);
84
+ # undef LOAD_ENVREG
85
+
86
+ struct grid_workspace {
87
+ unsigned int wsSize;
88
+ unsigned int barrier;
89
+ };
90
+
91
+ _CG_QUALIFIER grid_workspace* get_grid_workspace() {
92
+ unsigned long long gridWsAbiAddress = load_env_reg64<1, 2>();
93
+ // Interpret the address from envreg 1 and 2 as the driver's grid workspace
94
+ return (reinterpret_cast<grid_workspace*>(gridWsAbiAddress));
95
+ }
96
+ }
97
+ _CG_END_NAMESPACE
98
+
99
+ #endif // _CG_DRIVER_API_H
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/info.h ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2021 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+
50
+
51
+ #ifndef _CG_INFO_H_
52
+ #define _CG_INFO_H_
53
+ /*
54
+ ** Define: _CG_VERSION
55
+ */
56
+ #define _CG_VERSION 1000
57
+
58
+ /*
59
+ ** Define: _CG_ABI_VERSION
60
+ */
61
+ #ifndef _CG_ABI_VERSION
62
+ # define _CG_ABI_VERSION 1
63
+ #endif
64
+
65
+ /*
66
+ ** Define: _CG_ABI_EXPERIMENTAL
67
+ ** Desc: If enabled, sets all features enabled (ABI-breaking or experimental)
68
+ */
69
+ #if defined(_CG_ABI_EXPERIMENTAL)
70
+ #endif
71
+
72
+ #define _CG_CONCAT_INNER(x, y) x ## y
73
+ #define _CG_CONCAT_OUTER(x, y) _CG_CONCAT_INNER(x, y)
74
+ #define _CG_NAMESPACE _CG_CONCAT_OUTER(__v, _CG_ABI_VERSION)
75
+
76
+ #define _CG_BEGIN_NAMESPACE \
77
+ namespace cooperative_groups { namespace _CG_NAMESPACE {
78
+ #define _CG_END_NAMESPACE \
79
+ }; using namespace _CG_NAMESPACE; };
80
+
81
+ #if (defined(__cplusplus) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MSC_VER >= 1900))
82
+ # define _CG_CPP11_FEATURES
83
+ #endif
84
+
85
+ #if !defined(_CG_QUALIFIER)
86
+ # define _CG_QUALIFIER __forceinline__ __device__
87
+ #endif
88
+ #if !defined(_CG_STATIC_QUALIFIER)
89
+ # define _CG_STATIC_QUALIFIER static __forceinline__ __device__
90
+ #endif
91
+ #if !defined(_CG_CONSTEXPR_QUALIFIER)
92
+ # if defined(_CG_CPP11_FEATURES)
93
+ # define _CG_CONSTEXPR_QUALIFIER constexpr __forceinline__ __device__
94
+ # else
95
+ # define _CG_CONSTEXPR_QUALIFIER _CG_QUALIFIER
96
+ # endif
97
+ #endif
98
+ #if !defined(_CG_STATIC_CONSTEXPR_QUALIFIER)
99
+ # if defined(_CG_CPP11_FEATURES)
100
+ # define _CG_STATIC_CONSTEXPR_QUALIFIER static constexpr __forceinline__ __device__
101
+ # else
102
+ # define _CG_STATIC_CONSTEXPR_QUALIFIER _CG_STATIC_QUALIFIER
103
+ # endif
104
+ #endif
105
+
106
+ #if defined(_MSC_VER)
107
+ # define _CG_DEPRECATED __declspec(deprecated)
108
+ #else
109
+ # define _CG_DEPRECATED __attribute__((deprecated))
110
+ #endif
111
+
112
+ #if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__)
113
+ # define _CG_HAS_GRID_GROUP
114
+ #endif
115
+ #if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__)
116
+ # define _CG_HAS_MULTI_GRID_GROUP
117
+ #endif
118
+ #if (__CUDA_ARCH__ >= 700) || !defined(__CUDA_ARCH__)
119
+ # define _CG_HAS_MATCH_COLLECTIVE
120
+ #endif
121
+
122
+ #if (__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__) && (defined(__NVCC__) || defined(__CUDACC_RTC__))
123
+ # define _CG_HAS_OP_REDUX
124
+ #endif
125
+
126
+ #if ((__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__)) && !defined(_CG_USER_PROVIDED_SHARED_MEMORY)
127
+ # define _CG_HAS_RESERVED_SHARED
128
+ #endif
129
+
130
+ #if ((__CUDA_ARCH__ >= 900) || !defined(__CUDA_ARCH__)) && \
131
+ (defined(__NVCC__) || defined(__CUDACC_RTC__) || defined(_CG_CLUSTER_INTRINSICS_AVAILABLE)) && \
132
+ defined(_CG_CPP11_FEATURES)
133
+ # define _CG_HAS_CLUSTER_GROUP
134
+ #endif
135
+
136
+ #if (__CUDA_ARCH__ >= 900) || !defined(__CUDA_ARCH__)
137
+ # define _CG_HAS_INSTR_ELECT
138
+ #endif
139
+
140
+ // Has __half and __half2
141
+ // Only usable if you include the cuda_fp16.h extension, and
142
+ // _before_ including cooperative_groups.h
143
+ #ifdef __CUDA_FP16_TYPES_EXIST__
144
+ # define _CG_HAS_FP16_COLLECTIVE
145
+ #endif
146
+
147
+ // Include libcu++ where supported.
148
+ #if defined(_CG_CPP11_FEATURES) && !defined(__QNX__) && !defined(__ibmxl__) && \
149
+ (defined(__NVCC__) || defined(__CUDACC_RTC__)) && \
150
+ (defined(__x86_64__) || defined(__aarch64__) || defined(__ppc64__)|| defined(_M_X64) || defined(_M_ARM64)) && \
151
+ (defined(_MSC_VER) || defined(__GNUC__) || defined(__clang__))
152
+ # define _CG_USE_CUDA_STL
153
+ #else
154
+ # define _CG_USE_OWN_TRAITS
155
+ #endif
156
+
157
+ #if defined(_CG_USE_CUDA_STL) && (!defined(__CUDA_ARCH__) || \
158
+ ((!defined(_MSC_VER) && __CUDA_ARCH__ >= 600) || (defined(_MSC_VER) && __CUDA_ARCH__ >= 700)))
159
+ # define _CG_HAS_STL_ATOMICS
160
+ #endif
161
+
162
+ #ifdef _CG_CPP11_FEATURES
163
+ // Use cuda::std:: for type_traits
164
+ # if defined(_CG_USE_CUDA_STL)
165
+ # define _CG_STL_NAMESPACE cuda::std
166
+ # include <cuda/std/type_traits>
167
+ // Use CG's implementation of type traits
168
+ # else
169
+ # define _CG_STL_NAMESPACE cooperative_groups::details::templates
170
+ # endif
171
+ #endif
172
+
173
+ #ifdef _CG_CPP11_FEATURES
174
+ # define _CG_STATIC_CONST_DECL static constexpr
175
+ # define _CG_CONST_DECL constexpr
176
+ #else
177
+ # define _CG_STATIC_CONST_DECL static const
178
+ # define _CG_CONST_DECL const
179
+ #endif
180
+
181
+ #if (defined(_MSC_VER) && !defined(_WIN64)) || defined(__arm__)
182
+ # define _CG_ASM_PTR_CONSTRAINT "r"
183
+ #else
184
+ # define _CG_ASM_PTR_CONSTRAINT "l"
185
+ #endif
186
+
187
+ /*
188
+ ** Define: CG_DEBUG
189
+ ** What: Enables various runtime safety checks
190
+ */
191
+ #if defined(__CUDACC_DEBUG__) && defined(CG_DEBUG) && !defined(NDEBUG)
192
+ # define _CG_DEBUG
193
+ #endif
194
+
195
+ #if defined(_CG_DEBUG)
196
+ # include <assert.h>
197
+ # define _CG_ASSERT(x) assert((x));
198
+ # define _CG_ABORT() assert(0);
199
+ #else
200
+ # define _CG_ASSERT(x)
201
+ # define _CG_ABORT() __trap();
202
+ #endif
203
+
204
+ _CG_BEGIN_NAMESPACE
205
+
206
+ namespace details {
207
+ _CG_STATIC_CONST_DECL unsigned int default_max_block_size = 1024;
208
+
209
+ #if defined(_CG_CPP11_FEATURES) && !defined(_CG_USE_CUDA_STL)
210
+ namespace templates {
211
+
212
+ /**
213
+ * Integral constants
214
+ **/
215
+ template <typename Ty, Ty Val>
216
+ struct integral_constant {
217
+ static constexpr Ty value = Val;
218
+ typedef Ty type;
219
+
220
+ _CG_QUALIFIER constexpr operator type() const noexcept { return value; }
221
+ _CG_QUALIFIER constexpr type operator()() const noexcept { return value; }
222
+ };
223
+
224
+ typedef integral_constant<bool, true> true_type;
225
+ typedef integral_constant<bool, false> false_type;
226
+
227
+ /**
228
+ * CV Qualifiers
229
+ **/
230
+ template <class Ty> struct is_lvalue_reference : public details::templates::false_type {};
231
+ template <class Ty> struct is_lvalue_reference<Ty&> : public details::templates::true_type {};
232
+
233
+ template <class Ty> struct remove_reference {typedef Ty type;};
234
+ template <class Ty> struct remove_reference<Ty&> {typedef Ty type;};
235
+ template <class Ty> struct remove_reference<Ty&&> {typedef Ty type;};
236
+
237
+ template <class Ty>
238
+ using remove_reference_t = typename details::templates::remove_reference<Ty>::type;
239
+
240
+ template <class Ty> struct remove_const {typedef Ty type;};
241
+ template <class Ty> struct remove_const<const Ty> {typedef Ty type;};
242
+
243
+ template <class Ty> struct remove_volatile {typedef Ty type;};
244
+ template <class Ty> struct remove_volatile<volatile Ty> {typedef Ty type;};
245
+
246
+ template <class Ty> struct remove_cv {typedef typename details::templates::remove_volatile<typename details::templates::remove_const<Ty>::type>::type type;};
247
+
248
+ template <class Ty>
249
+ using remove_cv_t = typename details::templates::remove_cv<Ty>::type;
250
+
251
+ template <class Ty>
252
+ _CG_QUALIFIER Ty&& forward(remove_reference_t<Ty> &t) noexcept {
253
+ return static_cast<Ty&&>(t);
254
+ }
255
+
256
+ template <class Ty>
257
+ _CG_QUALIFIER Ty&& forward(remove_reference_t<Ty> &&t) noexcept {
258
+ static_assert(!details::templates::is_lvalue_reference<Ty>::value, "Forwarding an rvalue as an lvalue is not allowed.");
259
+ return static_cast<Ty&&>(t);
260
+ }
261
+
262
+ /**
263
+ * is_integral
264
+ **/
265
+ template <class Ty> struct _is_integral : public details::templates::false_type {};
266
+ template <> struct _is_integral<bool> : public details::templates::true_type {};
267
+ template <> struct _is_integral<char> : public details::templates::true_type {};
268
+ template <> struct _is_integral<unsigned char> : public details::templates::true_type {};
269
+ template <> struct _is_integral<short> : public details::templates::true_type {};
270
+ template <> struct _is_integral<unsigned short> : public details::templates::true_type {};
271
+ template <> struct _is_integral<int> : public details::templates::true_type {};
272
+ template <> struct _is_integral<unsigned int> : public details::templates::true_type {};
273
+ template <> struct _is_integral<long> : public details::templates::true_type {};
274
+ template <> struct _is_integral<long long> : public details::templates::true_type {};
275
+ template <> struct _is_integral<unsigned long> : public details::templates::true_type {};
276
+ template <> struct _is_integral<unsigned long long> : public details::templates::true_type {};
277
+ //Vector type support?
278
+
279
+ template <typename Ty>
280
+ struct is_integral : public details::templates::_is_integral<typename details::templates::remove_cv<Ty>::type> {};
281
+
282
+ /**
283
+ * is_floating_point
284
+ **/
285
+ template <class Ty> struct _is_floating_point : public details::templates::false_type {};
286
+ template <> struct _is_floating_point<float> : public details::templates::true_type {};
287
+ template <> struct _is_floating_point<double> : public details::templates::true_type {};
288
+ template <> struct _is_floating_point<long double> : public details::templates::true_type {};
289
+ # ifdef __CUDA_FP16_TYPES_EXIST__
290
+ template <> struct _is_floating_point<__half> : public details::templates::true_type {};
291
+ template <> struct _is_floating_point<__half2> : public details::templates::true_type {};
292
+ # endif
293
+ //Vector type support?
294
+
295
+ template <typename Ty>
296
+ struct is_floating_point : public details::templates::_is_floating_point<typename details::templates::remove_cv<Ty>::type> {};
297
+
298
+ template <class T>
299
+ struct is_arithmetic : details::templates::integral_constant<
300
+ bool,
301
+ details::templates::is_integral<T>::value ||
302
+ details::templates::is_floating_point<T>::value> {};
303
+
304
+ template <typename Ty, bool = details::templates::is_arithmetic<Ty>::value>
305
+ struct _is_unsigned : details::templates::integral_constant<bool, Ty(0) < Ty(-1)> {};
306
+
307
+ template <typename Ty>
308
+ struct _is_unsigned<Ty,false> : details::templates::false_type {};
309
+
310
+ template <typename Ty>
311
+ struct is_unsigned : _is_unsigned<typename details::templates::remove_cv<Ty>::type> {};
312
+
313
+ template <typename Ty> struct _is_pointer : public details::templates::false_type {};
314
+ template <typename Ty> struct _is_pointer<Ty*> : public details::templates::true_type {};
315
+
316
+ template <typename Ty>
317
+ struct is_pointer : _is_pointer<typename details::templates::remove_cv<Ty>::type> {};
318
+
319
+ /**
320
+ * programmatic type traits
321
+ **/
322
+ template<bool B, class Ty = void>
323
+ struct enable_if {};
324
+
325
+ template<class Ty>
326
+ struct enable_if<true, Ty> { typedef Ty type; };
327
+
328
+ template<bool Cond, typename Ty = void>
329
+ using enable_if_t = typename details::templates::enable_if<Cond, Ty>::type;
330
+
331
+ template<class Ty1, class Ty2>
332
+ struct is_same : details::templates::false_type {};
333
+
334
+ template<class Ty>
335
+ struct is_same<Ty, Ty> : details::templates::true_type {};
336
+
337
+ } // templates
338
+ #endif // _CG_CPP11_FEATURES
339
+
340
+ } // details
341
+ _CG_END_NAMESPACE
342
+
343
+
344
+ #endif // _CG_INFO_H_
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/invoke.h ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #ifndef _CG_INVOKE_H
51
+ #define _CG_INVOKE_H
52
+
53
+ #include "info.h"
54
+ #include "helpers.h"
55
+
56
+ #if defined(_CG_CPP11_FEATURES)
57
+
58
+ _CG_BEGIN_NAMESPACE
59
+
60
+ namespace details {
61
+
62
+ template <typename Group>
63
+ struct _elect_group_supported : _CG_STL_NAMESPACE::false_type {};
64
+ #ifdef _CG_HAS_INSTR_ELECT
65
+ template<>
66
+ struct _elect_group_supported<coalesced_group> : _CG_STL_NAMESPACE::true_type {};
67
+ template<unsigned int Size, typename Parent>
68
+ struct _elect_group_supported<thread_block_tile<Size, Parent>> :
69
+ _CG_STL_NAMESPACE::integral_constant<bool, (Size <= 32)> {};
70
+ #endif
71
+
72
+ template <typename Group>
73
+ struct elect_group_supported : public _elect_group_supported<details::remove_qual<Group>> {};
74
+
75
+ template<typename Group>
76
+ _CG_STATIC_QUALIFIER bool elect_one(const Group& group, unsigned int mask, unsigned int& leader_lane) {
77
+ int is_leader = 0;
78
+ #ifdef _CG_HAS_INSTR_ELECT
79
+ asm("{\n\t"
80
+ " .reg .pred p;\n\t"
81
+ " elect.sync %0|p, %2;\n\t"
82
+ " @p mov.s32 %1, 1;\n\t"
83
+ "}"
84
+ : "+r"(leader_lane), "+r"(is_leader) : "r" (mask));
85
+ #endif
86
+ return is_leader;
87
+ }
88
+
89
+ template<bool UseElect>
90
+ struct invoke_one_impl {};
91
+
92
+ template<>
93
+ struct invoke_one_impl<true> {
94
+ template<typename Group, typename Fn, typename... Args>
95
+ _CG_STATIC_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) {
96
+ auto mask = details::_coalesced_group_data_access::get_mask(group);
97
+ unsigned int leader_lane = 0;
98
+
99
+ if (elect_one(group, mask, leader_lane)) {
100
+ _CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
101
+ }
102
+ }
103
+
104
+ template<typename Group, typename Fn, typename... Args>
105
+ _CG_STATIC_QUALIFIER auto invoke_one_broadcast(const Group& group, Fn&& fn, Args&&... args)
106
+ -> typename _CG_STL_NAMESPACE::remove_reference<
107
+ decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
108
+
109
+ using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
110
+ details::remove_qual<ResultType> result;
111
+ auto mask = details::_coalesced_group_data_access::get_mask(group);
112
+ unsigned int leader_lane = 0;
113
+
114
+ if (elect_one(group, mask, leader_lane)) {
115
+ result = _CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
116
+ }
117
+
118
+ // Need to use low level api instead of group.shfl, because elect_one returns lane id, not group rank.
119
+ return tile::shuffle_dispatch<ResultType>::shfl(result, mask, leader_lane, 32);
120
+ }
121
+ };
122
+
123
+ template<>
124
+ struct invoke_one_impl<false> {
125
+ template<typename Group, typename Fn, typename... Args>
126
+ _CG_STATIC_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) {
127
+ if (group.thread_rank() == 0) {
128
+ _CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
129
+ }
130
+ }
131
+
132
+ template<typename Group, typename Fn, typename... Args>
133
+ _CG_STATIC_QUALIFIER auto invoke_one_broadcast(const Group& group, Fn&& fn, Args&&... args)
134
+ -> typename _CG_STL_NAMESPACE::remove_reference<
135
+ decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
136
+
137
+ using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
138
+ details::remove_qual<ResultType> result;
139
+
140
+ if (group.thread_rank() == 0) {
141
+ result = _CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
142
+ }
143
+
144
+ return group.shfl(result, 0);
145
+ }
146
+ };
147
+
148
+
149
+ }; // namespace details
150
+
151
+ template<typename Group, typename Fn, typename... Args>
152
+ _CG_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) {
153
+ using impl = details::invoke_one_impl<details::elect_group_supported<Group>::value>;
154
+ impl::invoke_one(group, _CG_STL_NAMESPACE::forward<Fn>(fn), _CG_STL_NAMESPACE::forward<Args>(args)...);
155
+ }
156
+
157
+ template<typename Fn, typename... Args>
158
+ _CG_QUALIFIER auto invoke_one_broadcast(const coalesced_group& group, Fn&& fn, Args&&... args)
159
+ -> typename _CG_STL_NAMESPACE::remove_reference<
160
+ decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
161
+
162
+ using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
163
+ static_assert(!_CG_STL_NAMESPACE::is_same<ResultType, void>::value,
164
+ "For invocables returning void invoke_one should be used instead");
165
+ using impl = details::invoke_one_impl<details::elect_group_supported<coalesced_group>::value>;
166
+ return impl::invoke_one_broadcast(group,
167
+ _CG_STL_NAMESPACE::forward<Fn>(fn),
168
+ _CG_STL_NAMESPACE::forward<Args>(args)...);
169
+ }
170
+
171
+ template<unsigned int Size, typename Parent, typename Fn, typename... Args>
172
+ _CG_QUALIFIER auto invoke_one_broadcast(const thread_block_tile<Size, Parent>& group, Fn&& fn, Args&&... args)
173
+ -> typename _CG_STL_NAMESPACE::remove_reference<
174
+ decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
175
+
176
+ using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
177
+ static_assert(!_CG_STL_NAMESPACE::is_same<ResultType, void>::value,
178
+ "For invocables returning void invoke_one should be used instead");
179
+ using impl = details::invoke_one_impl<details::elect_group_supported<thread_block_tile<Size, Parent>>::value>;
180
+ return impl::invoke_one_broadcast(group,
181
+ _CG_STL_NAMESPACE::forward<Fn>(fn),
182
+ _CG_STL_NAMESPACE::forward<Args>(args)...);
183
+ }
184
+
185
+ _CG_END_NAMESPACE
186
+
187
+ #endif //_CG_CPP11_FEATURES
188
+
189
+ #endif // _CG_INVOKE_H
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/memory.h ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ #ifndef _COOPERATIVE_GROUPS_MEMORY_H_
50
+ # define _COOPERATIVE_GROUPS_MEMORY_H_
51
+
52
+ #include "info.h"
53
+
54
+ _CG_BEGIN_NAMESPACE
55
+
56
+ #if defined(_CG_CPP11_FEATURES)
57
+ namespace details {
58
+ _CG_STATIC_CONST_DECL int scratch_num_reserved_bytes = 12;
59
+
60
+ #if defined(_CG_HAS_RESERVED_SHARED)
61
+ _CG_STATIC_QUALIFIER void* reserved_shared_ptr()
62
+ {
63
+ void *ptr;
64
+ asm ("{\n\t"
65
+ " .reg .u32 start;\n\t"
66
+ " .reg .u64 extended;\n\t"
67
+ " mov.u32 start, %%reserved_smem_offset_1;\n\t"
68
+ " cvt.u64.u32 extended, start;\n\t"
69
+ " cvta.shared.u64 %0, extended;\n\t"
70
+ "}"
71
+ : "=" _CG_ASM_PTR_CONSTRAINT(ptr));
72
+ return ptr;
73
+ }
74
+ #endif
75
+
76
+ struct multi_warp_scratch {
77
+ // One barrier per possible size of the group.
78
+ _CG_STATIC_CONST_DECL unsigned int memory_barriers_count = 5;
79
+ _CG_STATIC_CONST_DECL size_t sync_memory_size = memory_barriers_count * sizeof(barrier_t);
80
+
81
+ using communication_type = unsigned long long;
82
+ _CG_STATIC_CONST_DECL size_t communication_size = sizeof(communication_type);
83
+
84
+ // Layout of the scratch space:
85
+ barrier_t barriers[memory_barriers_count];
86
+ char reserved[scratch_num_reserved_bytes]; // Reserve 12 bytes for future use
87
+ communication_type communication_memory[default_max_block_size / 32];
88
+
89
+ _CG_STATIC_CONSTEXPR_QUALIFIER unsigned int scratch_size_needed(unsigned int max_block_size) {
90
+ // One slot of collectives memory per warp.
91
+ return scratch_num_reserved_bytes + sync_memory_size + max_block_size / 32 * communication_size;
92
+ }
93
+
94
+ _CG_QUALIFIER void init_barriers(unsigned int thread_rank) {
95
+ if (thread_rank < memory_barriers_count) {
96
+ barriers[thread_rank] = 0;
97
+ }
98
+ }
99
+ };
100
+
101
+ #if defined(_CG_HAS_RESERVED_SHARED)
102
+ // CG can expect at least 288 bytes available in reserved shared
103
+ static_assert(sizeof(multi_warp_scratch) <= 288, "multi-warp scratch size is too large");
104
+ #endif
105
+
106
+ // Make sure the structure can fit into the user provided memory
107
+ static_assert(sizeof(multi_warp_scratch) <= multi_warp_scratch::scratch_size_needed(default_max_block_size),
108
+ "multi-warp scratch size is too large");
109
+
110
+
111
+ _CG_QUALIFIER multi_warp_scratch* get_scratch_ptr(void* user_scratch) {
112
+ void *ptr;
113
+ #if defined(_CG_HAS_RESERVED_SHARED)
114
+ ptr = reserved_shared_ptr();
115
+ #else
116
+ ptr = user_scratch;
117
+ #endif
118
+ return static_cast<multi_warp_scratch*>(ptr);
119
+
120
+ }
121
+
122
+ }
123
+
124
+ template <unsigned int MaxBlockSize = details::default_max_block_size>
125
+ struct __align__(details::multi_warp_scratch::communication_size) block_tile_memory {
126
+ private:
127
+ #if !defined(_CG_HAS_RESERVED_SHARED)
128
+ char scratch[details::multi_warp_scratch::scratch_size_needed(MaxBlockSize)];
129
+ #endif
130
+ };
131
+ #endif
132
+
133
+ _CG_END_NAMESPACE
134
+
135
+ #endif /* !_COOPERATIVE_GROUPS_MEMORY_H_ */
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/partitioning.h ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #ifndef _CG_PARTITIONING_H
51
+ #define _CG_PARTITIONING_H
52
+
53
+ #include "info.h"
54
+ #include "helpers.h"
55
+
56
+ _CG_BEGIN_NAMESPACE
57
+
58
+ namespace details {
59
+
60
+ template <typename TyGroup>
61
+ _CG_STATIC_QUALIFIER coalesced_group _binary_partition(const TyGroup &tile, bool pred) {
62
+ const unsigned int fullMask = ~0u;
63
+
64
+ unsigned int thisMask = _coalesced_group_data_access::get_mask(tile);
65
+ unsigned int predMask = pred ? 0 : fullMask;
66
+ unsigned int setMask = __ballot_sync(thisMask, pred);
67
+
68
+ if (setMask == thisMask || setMask == 0) {
69
+ coalesced_group subTile = _coalesced_group_data_access::construct_from_mask<coalesced_group>(thisMask);
70
+ _coalesced_group_data_access::modify_meta_group(subTile, 0, 1);
71
+ return subTile;
72
+ }
73
+ else {
74
+ unsigned int subMask = thisMask & (setMask ^ predMask);
75
+ coalesced_group subTile = _coalesced_group_data_access::construct_from_mask<coalesced_group>(subMask);
76
+ _coalesced_group_data_access::modify_meta_group(subTile, pred, 2);
77
+ return subTile;
78
+ }
79
+ }
80
+
81
+ #if defined(_CG_HAS_MATCH_COLLECTIVE) && defined(_CG_CPP11_FEATURES)
82
+ template <typename TyPredicate>
83
+ struct _labeled_partition_dispatch {
84
+ template <typename TyGroup>
85
+ _CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, TyPredicate pred) {
86
+ unsigned int thisMask = _coalesced_group_data_access::get_mask(tile);
87
+ unsigned int thisBias = __ffs(thisMask) - 1; // Subtract 1 to index properly from [1-32]
88
+ unsigned int subMask = __match_any_sync(thisMask, pred);
89
+
90
+ coalesced_group subTile = _coalesced_group_data_access::construct_from_mask<coalesced_group>(subMask);
91
+
92
+ int leaderLaneId = subTile.shfl(details::laneid(), 0);
93
+
94
+ bool isLeader = !subTile.thread_rank();
95
+ unsigned int leaderMask = __ballot_sync(thisMask, isLeader);
96
+ unsigned int tileRank = __fns(leaderMask, leaderLaneId, 0) - thisBias;
97
+
98
+ _coalesced_group_data_access::modify_meta_group(subTile, tileRank, __popc(leaderMask));
99
+
100
+ return subTile;
101
+ }
102
+ };
103
+
104
+ template <>
105
+ struct _labeled_partition_dispatch<bool> {
106
+ template <typename TyGroup>
107
+ _CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, bool pred) {
108
+ return _binary_partition(tile, pred);
109
+ }
110
+ };
111
+
112
+ template <typename TyPredicate>
113
+ struct _labeled_partition_dispatch<TyPredicate*> {
114
+ template <typename TyGroup>
115
+ _CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, TyPredicate* pred) {
116
+ auto impl = _labeled_partition_dispatch<unsigned long long>();
117
+ return impl(tile, reinterpret_cast<unsigned long long>(pred));
118
+ }
119
+ };
120
+ #endif
121
+ }; // namespace details
122
+
123
+ _CG_STATIC_QUALIFIER coalesced_group binary_partition(const coalesced_group &tile, bool pred) {
124
+ return details::_binary_partition(tile, pred);
125
+ }
126
+
127
+ template <unsigned int Size, typename ParentT>
128
+ _CG_STATIC_QUALIFIER coalesced_group binary_partition(const thread_block_tile<Size, ParentT> &tile, bool pred) {
129
+ #ifdef _CG_CPP11_FEATURES
130
+ static_assert(Size <= 32, "Binary partition is available only for tiles of size smaller or equal to 32");
131
+ #endif
132
+ return details::_binary_partition(tile, pred);
133
+ }
134
+
135
+
136
+ #if defined(_CG_HAS_MATCH_COLLECTIVE) && defined(_CG_CPP11_FEATURES)
137
+ template <typename TyPredicate>
138
+ _CG_STATIC_QUALIFIER coalesced_group labeled_partition(const coalesced_group &tile, TyPredicate pred) {
139
+ static_assert(_CG_STL_NAMESPACE::is_integral<TyPredicate>::value ||
140
+ _CG_STL_NAMESPACE::is_pointer<TyPredicate>::value,
141
+ "labeled_partition predicate must be an integral or pointer type");
142
+ auto dispatch = details::_labeled_partition_dispatch<details::remove_qual<TyPredicate>>();
143
+ return dispatch(tile, pred);
144
+ }
145
+
146
+ template <typename TyPredicate, unsigned int Size, typename ParentT>
147
+ _CG_STATIC_QUALIFIER coalesced_group labeled_partition(const thread_block_tile<Size, ParentT> &tile, TyPredicate pred) {
148
+ static_assert(_CG_STL_NAMESPACE::is_integral<TyPredicate>::value ||
149
+ _CG_STL_NAMESPACE::is_pointer<TyPredicate>::value,
150
+ "labeled_partition predicate must be an integral or pointer type");
151
+ static_assert(Size <= 32, "Labeled partition is available only for tiles of size smaller or equal to 32");
152
+ auto dispatch = details::_labeled_partition_dispatch<details::remove_qual<TyPredicate>>();
153
+ return dispatch(tile, pred);
154
+ }
155
+ #endif
156
+
157
+ _CG_END_NAMESPACE
158
+
159
+ #endif // _CG_PARTITIONING_H
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/reduce.h ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ #ifndef _CG_REDUCE_H_
50
+ #define _CG_REDUCE_H_
51
+
52
+ #include "info.h"
53
+ #include "helpers.h"
54
+ #include "coalesced_reduce.h"
55
+ #include "functional.h"
56
+ #include "cooperative_groups.h"
57
+
58
+ _CG_BEGIN_NAMESPACE
59
+
60
+ namespace details {
61
+
62
+ template <class Ty>
63
+ using _redux_is_add_supported = _CG_STL_NAMESPACE::integral_constant<
64
+ bool,
65
+ _CG_STL_NAMESPACE::is_integral<Ty>::value && (sizeof(Ty) <= 4)>;
66
+
67
+ template <class Ty>
68
+ using redux_is_add_supported = _redux_is_add_supported<Ty>;
69
+
70
+ // A specialization for 64 bit logical operations is possible
71
+ // but for now only accelerate 32 bit bitwise ops
72
+ template <class Ty>
73
+ using redux_is_logical_supported = redux_is_add_supported<Ty>;
74
+
75
+ // Base operator support case
76
+ template <class TyOp, class Ty> struct _redux_op_supported : public _CG_STL_NAMESPACE::false_type {};
77
+ #ifdef _CG_HAS_OP_REDUX
78
+ template <class Ty> struct _redux_op_supported<cooperative_groups::plus<Ty>, Ty> : public redux_is_add_supported<Ty> {};
79
+ template <class Ty> struct _redux_op_supported<cooperative_groups::less<Ty>, Ty> : public redux_is_add_supported<Ty> {};
80
+ template <class Ty> struct _redux_op_supported<cooperative_groups::greater<Ty>, Ty> : public redux_is_add_supported<Ty> {};
81
+ template <class Ty> struct _redux_op_supported<cooperative_groups::bit_and<Ty>, Ty> : public redux_is_logical_supported<Ty> {};
82
+ template <class Ty> struct _redux_op_supported<cooperative_groups::bit_or<Ty>, Ty> : public redux_is_logical_supported<Ty> {};
83
+ template <class Ty> struct _redux_op_supported<cooperative_groups::bit_xor<Ty>, Ty> : public redux_is_logical_supported<Ty> {};
84
+ #endif
85
+
86
+ template <class Ty, template <class> class TyOp>
87
+ using redux_op_supported = _redux_op_supported<
88
+ typename details::remove_qual<TyOp<Ty>>,
89
+ Ty>;
90
+
91
+ // Groups smaller than 16 actually have worse performance characteristics when used with redux
92
+ // tiles of size 16 and 32 perform the same or better and have better code generation profiles
93
+ template <class TyGroup> struct _redux_group_optimized : public _CG_STL_NAMESPACE::false_type {};
94
+
95
+ template <unsigned int Sz, typename TyPar>
96
+ struct _redux_group_optimized<cooperative_groups::thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::integral_constant<
97
+ bool,
98
+ (Sz >= 16)> {};
99
+ template <unsigned int Sz, typename TyPar>
100
+ struct _redux_group_optimized<internal_thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::integral_constant<
101
+ bool,
102
+ (Sz >= 16)> {};
103
+ template <>
104
+ struct _redux_group_optimized<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
105
+
106
+ template <typename TyGroup>
107
+ using redux_group_optimized = _redux_group_optimized<details::remove_qual<TyGroup>>;
108
+
109
+ template <template <class> class TyOp>
110
+ _CG_STATIC_QUALIFIER int pick_redux(int mask, int val);
111
+ template <template <class> class TyOp>
112
+ _CG_STATIC_QUALIFIER unsigned int pick_redux(int mask, unsigned int val);
113
+
114
+ #ifdef _CG_HAS_OP_REDUX
115
+ template <> _CG_QUALIFIER int pick_redux<cooperative_groups::plus>(int mask, int val) {
116
+ return __reduce_add_sync(mask, val);
117
+ }
118
+ template <> _CG_QUALIFIER int pick_redux<cooperative_groups::less>(int mask, int val) {
119
+ return __reduce_min_sync(mask, val);
120
+ }
121
+ template <> _CG_QUALIFIER int pick_redux<cooperative_groups::greater>(int mask, int val) {
122
+ return __reduce_max_sync(mask, val);
123
+ }
124
+ template <> _CG_QUALIFIER int pick_redux<cooperative_groups::bit_and>(int mask, int val) {
125
+ return __reduce_and_sync(mask, val);
126
+ }
127
+ template <> _CG_QUALIFIER int pick_redux<cooperative_groups::bit_xor>(int mask, int val) {
128
+ return __reduce_xor_sync(mask, val);
129
+ }
130
+ template <> _CG_QUALIFIER int pick_redux<cooperative_groups::bit_or>(int mask, int val) {
131
+ return __reduce_or_sync(mask, val);
132
+ }
133
+
134
+ template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::plus>(int mask, unsigned int val) {
135
+ return __reduce_add_sync(mask, val);
136
+ }
137
+ template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::less>(int mask, unsigned int val) {
138
+ return __reduce_min_sync(mask, val);
139
+ }
140
+ template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::greater>(int mask, unsigned int val) {
141
+ return __reduce_max_sync(mask, val);
142
+ }
143
+ template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::bit_and>(int mask, unsigned int val) {
144
+ return __reduce_and_sync(mask, val);
145
+ }
146
+ template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::bit_xor>(int mask, unsigned int val) {
147
+ return __reduce_xor_sync(mask, val);
148
+ }
149
+ template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::bit_or>(int mask, unsigned int val) {
150
+ return __reduce_or_sync(mask, val);
151
+ }
152
+ #endif
153
+
154
+
155
+ template <typename TyVal, bool = _CG_STL_NAMESPACE::is_unsigned<TyVal>::value>
156
+ struct _accelerated_op;
157
+
158
+ // Signed type redux intrinsic dispatch
159
+ template <typename TyVal>
160
+ struct _accelerated_op<TyVal, false> {
161
+ template <template <class> class TyOp>
162
+ _CG_STATIC_QUALIFIER TyVal redux(int mask, TyVal val) {
163
+ return static_cast<TyVal>(pick_redux<TyOp>(mask, static_cast<int>(val)));
164
+ }
165
+ };
166
+
167
+ // Unsigned type redux intrinsic dispatch
168
+ template <typename TyVal>
169
+ struct _accelerated_op<TyVal, true> {
170
+ template <template <class> class TyOp>
171
+ _CG_STATIC_QUALIFIER TyVal redux(int mask, TyVal val) {
172
+ return static_cast<TyVal>(pick_redux<TyOp>(mask, static_cast<unsigned int>(val)));
173
+ }
174
+ };
175
+
176
+ template <typename TyVal>
177
+ using accelerated_op = _accelerated_op<TyVal>;
178
+
179
+
180
+ template <typename TyVal, typename TyFnInput, typename TyGroup>
181
+ class _redux_dispatch {
182
+ template <class Ty, template <class> class TyOp>
183
+ using _redux_is_usable = _CG_STL_NAMESPACE::integral_constant<bool,
184
+ redux_op_supported<Ty, TyOp>::value &&
185
+ redux_group_optimized<TyGroup>::value>;
186
+
187
+ template <class Ty, template <class> class TyOp>
188
+ using redux_is_usable = typename _CG_STL_NAMESPACE::enable_if<_redux_is_usable<Ty, TyOp>::value, void>::type*;
189
+
190
+ template <class Ty, template <class> class TyOp>
191
+ using redux_is_not_usable = typename _CG_STL_NAMESPACE::enable_if<!_redux_is_usable<Ty, TyOp>::value, void>::type*;
192
+
193
+ public:
194
+ // Dispatch to redux if the combination of op and args are supported
195
+ template<
196
+ template <class> class TyOp,
197
+ redux_is_usable<TyFnInput, TyOp> = nullptr>
198
+ _CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>&& op) -> decltype(op(val, val)) {
199
+ // Retrieve the mask for the group and dispatch to redux
200
+ return accelerated_op<TyFnInput>::template redux<TyOp>(_coalesced_group_data_access::get_mask(group), _CG_STL_NAMESPACE::forward<TyVal>(val));
201
+ }
202
+
203
+ template<
204
+ template <class> class TyOp,
205
+ redux_is_usable<TyFnInput, TyOp> = nullptr>
206
+ _CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>& op) -> decltype(op(val, val)) {
207
+ // Retrieve the mask for the group and dispatch to redux
208
+ return accelerated_op<TyFnInput>::template redux<TyOp>(_coalesced_group_data_access::get_mask(group), _CG_STL_NAMESPACE::forward<TyVal>(val));
209
+ }
210
+
211
+ // Fallback shuffle sync reduction
212
+ template <
213
+ template <class> class TyOp,
214
+ redux_is_not_usable<TyFnInput, TyOp> = nullptr>
215
+ _CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>&& op) -> decltype(op(val, val)) {
216
+ //Dispatch to fallback shuffle sync accelerated reduction
217
+ return coalesced_reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp<TyFnInput>>(op));
218
+ }
219
+
220
+ };
221
+
222
+ // Group support for reduce.
223
+ template <class TyGroup> struct _reduce_group_supported : public _CG_STL_NAMESPACE::false_type {};
224
+
225
+ template <unsigned int Sz, typename TyPar>
226
+ struct _reduce_group_supported<cooperative_groups::thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
227
+ template <unsigned int Sz, typename TyPar>
228
+ struct _reduce_group_supported<internal_thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
229
+ template <>
230
+ struct _reduce_group_supported<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
231
+
232
+ template <typename TyGroup>
233
+ using reduce_group_supported = _reduce_group_supported<details::remove_qual<TyGroup>>;
234
+
235
+ template <typename TyVal, typename TyFnInput, template <class> class TyOp, typename TyGroup>
236
+ _CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>&& op) -> decltype(op(val, val)) {
237
+ static_assert(details::is_op_type_same<TyFnInput, TyVal>::value, "Operator and argument types differ");
238
+
239
+ using dispatch = details::_redux_dispatch<TyVal, TyFnInput, TyGroup>;
240
+ return dispatch::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp<TyFnInput>>(op));
241
+ }
242
+
243
+ template <typename TyVal, typename TyFnInput, template <class> class TyOp, typename TyGroup>
244
+ _CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>& op) -> decltype(op(val, val)) {
245
+ static_assert(details::is_op_type_same<TyFnInput, TyVal>::value, "Operator and argument types differ");
246
+
247
+ using dispatch = details::_redux_dispatch<TyVal, TyFnInput, TyGroup>;
248
+ return dispatch::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp<TyFnInput>>(op));
249
+ }
250
+
251
+
252
+ template <typename TyVal, typename TyOp, typename TyGroup>
253
+ _CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
254
+ return details::coalesced_reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
255
+ }
256
+
257
+ template <unsigned int GroupId>
258
+ struct tile_reduce_dispatch;
259
+
260
+ template <>
261
+ struct tile_reduce_dispatch<details::coalesced_group_id> {
262
+ template <typename TyGroup, typename TyVal, typename TyFn>
263
+ _CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
264
+ return details::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
265
+ }
266
+ };
267
+
268
+ #if defined(_CG_CPP11_FEATURES)
269
+ template <>
270
+ struct tile_reduce_dispatch<details::multi_tile_group_id> {
271
+ template <unsigned int Size, typename ParentT, typename TyVal, typename TyFn>
272
+ _CG_STATIC_QUALIFIER auto reduce(const thread_block_tile<Size, ParentT>& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
273
+ using warpType = details::internal_thread_block_tile<32, __static_size_multi_warp_tile_base<Size>>;
274
+ using TyRet = details::remove_qual<TyVal>;
275
+ const unsigned int num_warps = Size / 32;
276
+
277
+ auto warp_lambda = [&] (const warpType& warp, TyRet* warp_scratch_location) {
278
+ *warp_scratch_location =
279
+ details::reduce(warp, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
280
+ };
281
+ auto inter_warp_lambda =
282
+ [&] (const details::internal_thread_block_tile<num_warps, warpType>& subwarp, TyRet* thread_scratch_location) {
283
+ *thread_scratch_location =
284
+ details::reduce(subwarp, *thread_scratch_location, _CG_STL_NAMESPACE::forward<TyFn>(op));
285
+ };
286
+ return details::multi_warp_collectives_helper<TyRet>(group, warp_lambda, inter_warp_lambda);
287
+ }
288
+ };
289
+
290
+ template <unsigned int GroupId>
291
+ struct tile_async_reduce_dispatch;
292
+
293
+ template <>
294
+ struct tile_async_reduce_dispatch<details::coalesced_group_id> {
295
+ template <typename GroupT, typename TyDst, typename TyVal, typename TyFn, typename TyResHandler>
296
+ _CG_STATIC_QUALIFIER void reduce(const GroupT& group, TyDst& dst, TyVal&& val, TyFn&& op, TyResHandler& res_handler) {
297
+ // Do regular, in group reduction
298
+ auto result = details::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
299
+
300
+ // One thread stores/updates the destination
301
+ if (group.thread_rank() == 0) {
302
+ res_handler(result);
303
+ }
304
+ }
305
+ };
306
+
307
+ template <>
308
+ struct tile_async_reduce_dispatch<details::multi_tile_group_id> {
309
+ template <unsigned int TySize, typename ParentT, typename TyDst, typename TyInputVal, typename TyFn, typename TyResHandler>
310
+ _CG_STATIC_QUALIFIER void reduce(const thread_block_tile<TySize, ParentT>& group, TyDst& dst, TyInputVal&& val, TyFn&& op, TyResHandler& res_handler) {
311
+ using TyVal = remove_qual<TyInputVal>;
312
+ const unsigned int num_warps = TySize / 32;
313
+ details::barrier_t* sync_location = multi_warp_sync_location_getter(group);
314
+ auto warp_scratch_location = multi_warp_scratch_location_getter<TyVal>(group, group.thread_rank() / 32);
315
+
316
+ // Do in warp reduce
317
+ auto warp = details::tiled_partition_internal<32, thread_block_tile<TySize, ParentT>>();
318
+ *warp_scratch_location = details::reduce(warp, _CG_STL_NAMESPACE::forward<TyInputVal>(val), op);
319
+
320
+ // Tile of size num_warps from the last warp to arrive does final reduction step
321
+ if (details::sync_warps_last_releases(sync_location, details::cta::thread_rank(), num_warps)) {
322
+ auto subwarp = details::tiled_partition_internal<num_warps, decltype(warp)>();
323
+ if (subwarp.meta_group_rank() == 0) {
324
+ auto thread_scratch_location = multi_warp_scratch_location_getter<TyVal>(group, subwarp.thread_rank());
325
+ auto thread_val = *thread_scratch_location;
326
+ // Release other warps, we read their contribution already.
327
+ subwarp.sync();
328
+ details::sync_warps_release(sync_location, subwarp.thread_rank() == 0, details::cta::thread_rank(), num_warps);
329
+ TyVal result = details::reduce(subwarp, thread_val, op);
330
+ // One thread stores the result or updates the atomic
331
+ if (subwarp.thread_rank() == 0) {
332
+ res_handler(result);
333
+ }
334
+ }
335
+ warp.sync();
336
+ }
337
+ }
338
+ };
339
+ #endif
340
+
341
+ template <typename TyGroup, typename TyInputVal, typename TyRetVal>
342
+ _CG_QUALIFIER void check_reduce_params() {
343
+ static_assert(details::is_op_type_same<TyInputVal, TyRetVal>::value, "Operator input and output types differ");
344
+ static_assert(details::reduce_group_supported<TyGroup>::value, "This group does not exclusively represent a tile");
345
+ };
346
+
347
+ template <typename TyGroup, typename TyDstVal, typename TyInputVal, typename TyRetVal>
348
+ _CG_QUALIFIER void check_async_reduce_params() {
349
+ check_reduce_params<TyGroup, TyInputVal, TyRetVal>();
350
+ static_assert(details::is_op_type_same<TyDstVal, TyInputVal>::value, "Destination and input types differ");
351
+ }
352
+ } // details
353
+
354
+ template <typename TyGroup, typename TyVal, typename TyFn>
355
+ _CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
356
+ details::check_reduce_params<TyGroup, details::remove_qual<TyVal>, decltype(op(val, val))>();
357
+
358
+ using dispatch = details::tile_reduce_dispatch<TyGroup::_group_id>;
359
+ return dispatch::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
360
+ }
361
+
362
+ #if defined(_CG_CPP11_FEATURES)
363
+
364
+ # if defined(_CG_HAS_STL_ATOMICS)
365
+ template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
366
+ void _CG_QUALIFIER reduce_update_async(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
367
+ details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
368
+ auto update_lambda = [&] (TyVal& result) {
369
+ details::atomic_update(dst, result, op);
370
+ };
371
+ using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
372
+ dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), update_lambda);
373
+ }
374
+
375
+ template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
376
+ void _CG_QUALIFIER reduce_update_async(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
377
+ details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
378
+ auto update_lambda = [&] (TyVal& result) {
379
+ details::atomic_update(dst, result, op);
380
+ };
381
+ using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
382
+ dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), update_lambda);
383
+ }
384
+
385
+ template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
386
+ void _CG_QUALIFIER reduce_store_async(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
387
+ details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
388
+ auto store_lambda = [&] (TyVal& result) {
389
+ details::atomic_store(dst, result);
390
+ };
391
+ using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
392
+ dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), store_lambda);
393
+ }
394
+
395
+ template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
396
+ void _CG_QUALIFIER reduce_store_async(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
397
+ details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
398
+ auto store_lambda = [&] (TyVal& result) {
399
+ details::atomic_store(dst, result);
400
+ };
401
+ using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
402
+ dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), store_lambda);
403
+ }
404
+ # endif
405
+
406
+ template<typename TyGroup, typename TyVal, typename TyInputVal, typename TyFn>
407
+ void _CG_QUALIFIER reduce_store_async(const TyGroup& group, TyVal* dst, TyInputVal&& val, TyFn&& op) {
408
+ details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
409
+ auto store_lambda = [&] (TyVal& result) {
410
+ *dst = result;
411
+ };
412
+ using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
413
+ dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), store_lambda);
414
+ }
415
+ #endif
416
+
417
+ _CG_END_NAMESPACE
418
+
419
+ #endif // _CG_REDUCE_H_
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/scan.h ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ #ifndef _CG_SCAN_H_
50
+ #define _CG_SCAN_H_
51
+
52
+ #include "info.h"
53
+ #include "helpers.h"
54
+ #include "functional.h"
55
+ #include "coalesced_scan.h"
56
+
57
+ _CG_BEGIN_NAMESPACE
58
+
59
+ namespace details {
60
+
61
+ // Group support for scan.
62
+ template <class TyGroup> struct _scan_group_supported : public _CG_STL_NAMESPACE::false_type {};
63
+
64
+ template <unsigned int Sz, typename TyPar>
65
+ struct _scan_group_supported<cooperative_groups::thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
66
+ template <unsigned int Sz, typename TyPar>
67
+ struct _scan_group_supported<internal_thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
68
+ template <>
69
+ struct _scan_group_supported<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
70
+
71
+ template <typename TyGroup>
72
+ using scan_group_supported = _scan_group_supported<details::remove_qual<TyGroup>>;
73
+
74
+ template <bool IsIntegralPlus>
75
+ struct integral_optimized_scan;
76
+
77
+ enum class ScanType { exclusive, inclusive };
78
+
79
+ template <unsigned int GroupId, ScanType TyScan>
80
+ struct scan_dispatch;
81
+
82
+ template <ScanType TyScan>
83
+ struct scan_dispatch<details::coalesced_group_id, TyScan> {
84
+ template <typename TyGroup, typename TyVal, typename TyFn>
85
+ _CG_STATIC_QUALIFIER auto scan(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
86
+ auto scan_result = coalesced_inclusive_scan(group, val, op);
87
+ if (TyScan == ScanType::exclusive) {
88
+ scan_result = convert_inclusive_to_exclusive(group,
89
+ scan_result,
90
+ _CG_STL_NAMESPACE::forward<TyVal>(val),
91
+ _CG_STL_NAMESPACE::forward<TyFn>(op));
92
+ }
93
+ return scan_result;
94
+ }
95
+ };
96
+
97
+ #if defined(_CG_CPP11_FEATURES)
98
+ template <ScanType TyScan>
99
+ struct scan_dispatch<details::multi_tile_group_id, TyScan> {
100
+ template <unsigned int Size, typename ParentT, typename TyVal, typename TyFn>
101
+ _CG_STATIC_QUALIFIER auto scan(const thread_block_tile<Size, ParentT>& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
102
+ using warpType = details::internal_thread_block_tile<32, __static_size_multi_warp_tile_base<Size>>;
103
+ using TyRet = details::remove_qual<TyVal>;
104
+ const unsigned int num_warps = Size / 32;
105
+ // In warp scan result, calculated in warp_lambda
106
+ TyRet warp_scan;
107
+
108
+ // In warp scan, put sum in the warp_scratch_location
109
+ auto warp_lambda = [&] (const warpType& warp, TyRet* warp_scratch_location) {
110
+ warp_scan =
111
+ details::coalesced_inclusive_scan(warp, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
112
+ if (warp.thread_rank() + 1 == warp.size()) {
113
+ *warp_scratch_location = warp_scan;
114
+ }
115
+ if (TyScan == ScanType::exclusive) {
116
+ warp_scan = warp.shfl_up(warp_scan, 1);
117
+ }
118
+ };
119
+
120
+ // Tile of size num_warps performing the final scan part (exclusive scan of warp sums), other threads will add it
121
+ // to its in-warp scan result
122
+ auto inter_warp_lambda =
123
+ [&] (const details::internal_thread_block_tile<num_warps, warpType>& subwarp, TyRet* thread_scratch_location) {
124
+ auto thread_val = *thread_scratch_location;
125
+ auto result = coalesced_inclusive_scan(subwarp, thread_val, op);
126
+ *thread_scratch_location = convert_inclusive_to_exclusive(subwarp, result, thread_val, op);
127
+ };
128
+
129
+ TyRet previous_warps_sum = details::multi_warp_collectives_helper<TyRet>(group, warp_lambda, inter_warp_lambda);
130
+ if (TyScan == ScanType::exclusive && warpType::thread_rank() == 0) {
131
+ return previous_warps_sum;
132
+ }
133
+ if (warpType::meta_group_rank() == 0) {
134
+ return warp_scan;
135
+ }
136
+ else {
137
+ return op(warp_scan, previous_warps_sum);
138
+ }
139
+ }
140
+ };
141
+
142
+ #if defined(_CG_HAS_STL_ATOMICS)
143
+ template <unsigned int GroupId, ScanType TyScan>
144
+ struct scan_update_dispatch;
145
+
146
+ template <ScanType TyScan>
147
+ struct scan_update_dispatch<details::coalesced_group_id, TyScan> {
148
+ template <typename TyGroup, typename TyAtomic, typename TyVal, typename TyFn>
149
+ _CG_STATIC_QUALIFIER auto scan(const TyGroup& group, TyAtomic& dst, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
150
+ details::remove_qual<TyVal> old;
151
+
152
+ // Do regular in group scan
153
+ auto scan_result = details::coalesced_inclusive_scan(group, val, op);
154
+
155
+ // Last thread updates the atomic and distributes its old value to other threads
156
+ if (group.thread_rank() == group.size() - 1) {
157
+ old = atomic_update(dst, scan_result, _CG_STL_NAMESPACE::forward<TyFn>(op));
158
+ }
159
+ old = group.shfl(old, group.size() - 1);
160
+ if (TyScan == ScanType::exclusive) {
161
+ scan_result = convert_inclusive_to_exclusive(group, scan_result, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
162
+ }
163
+ scan_result = op(old, scan_result);
164
+ return scan_result;
165
+ }
166
+ };
167
+
168
+ template <ScanType TyScan>
169
+ struct scan_update_dispatch<details::multi_tile_group_id, TyScan> {
170
+ template <unsigned int Size, typename ParentT, typename TyAtomic, typename TyVal, typename TyFn>
171
+ _CG_STATIC_QUALIFIER auto scan(const thread_block_tile<Size, ParentT>& group, TyAtomic& dst, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
172
+ using warpType = details::internal_thread_block_tile<32, __static_size_multi_warp_tile_base<Size>>;
173
+ using TyRet = details::remove_qual<TyVal>;
174
+ const unsigned int num_warps = Size / 32;
175
+ // In warp scan result, calculated in warp_lambda
176
+ TyRet warp_scan;
177
+
178
+ // In warp scan, put sum in the warp_scratch_location
179
+ auto warp_lambda = [&] (const warpType& warp, TyRet* warp_scratch_location) {
180
+ warp_scan =
181
+ details::coalesced_inclusive_scan(warp, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
182
+ if (warp.thread_rank() + 1 == warp.size()) {
183
+ *warp_scratch_location = warp_scan;
184
+ }
185
+ if (TyScan == ScanType::exclusive) {
186
+ warp_scan = warp.shfl_up(warp_scan, 1);
187
+ }
188
+ };
189
+
190
+ // Tile of size num_warps performing the final scan part (exclusive scan of warp sums), other threads will add it
191
+ // to its in-warp scan result
192
+ auto inter_warp_lambda =
193
+ [&] (const details::internal_thread_block_tile<num_warps, warpType>& subwarp, TyRet* thread_scratch_location) {
194
+ auto thread_val = *thread_scratch_location;
195
+ auto scan_result = details::coalesced_inclusive_scan(subwarp, thread_val, op);
196
+ TyRet offset;
197
+ // Single thread does the atomic update with sum of all contributions and reads the old value.
198
+ if (subwarp.thread_rank() == subwarp.size() - 1) {
199
+ offset = details::atomic_update(dst, scan_result, op);
200
+ }
201
+ offset = subwarp.shfl(offset, subwarp.size() - 1);
202
+ scan_result = convert_inclusive_to_exclusive(subwarp, scan_result, thread_val, op);
203
+ // Add offset read from the atomic to the scanned warp sum.
204
+ // Skipping first thread, since it got defautly constructed value from the conversion,
205
+ // it should just return the offset received from the thread that did the atomic update.
206
+ if (subwarp.thread_rank() != 0) {
207
+ offset = op(scan_result, offset);
208
+ }
209
+ *thread_scratch_location = offset;
210
+ };
211
+
212
+ TyRet previous_warps_sum = details::multi_warp_collectives_helper<TyRet>(group, warp_lambda, inter_warp_lambda);
213
+ if (TyScan == ScanType::exclusive && warpType::thread_rank() == 0) {
214
+ return previous_warps_sum;
215
+ }
216
+ return op(warp_scan, previous_warps_sum);
217
+ }
218
+ };
219
+ #endif
220
+ #endif
221
+
222
+ template <typename TyGroup, typename TyInputVal, typename TyRetVal>
223
+ _CG_QUALIFIER void check_scan_params() {
224
+ static_assert(details::is_op_type_same<TyInputVal, TyRetVal>::value, "Operator input and output types differ");
225
+ static_assert(details::scan_group_supported<TyGroup>::value, "This group does not exclusively represent a tile");
226
+ }
227
+
228
+ #if defined(_CG_HAS_STL_ATOMICS)
229
+ template <typename TyGroup, typename TyDstVal, typename TyInputVal, typename TyRetVal>
230
+ _CG_QUALIFIER void check_scan_update_params() {
231
+ check_scan_params<TyGroup, TyInputVal, TyRetVal>();
232
+ static_assert(details::is_op_type_same<TyDstVal, TyInputVal>::value, "Destination and input types differ");
233
+ }
234
+ #endif
235
+
236
+ } // details
237
+
238
+ template <typename TyGroup, typename TyVal, typename TyFn>
239
+ _CG_QUALIFIER auto inclusive_scan(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
240
+ details::check_scan_params<TyGroup, TyVal, decltype(op(val, val))>();
241
+
242
+ using dispatch = details::scan_dispatch<TyGroup::_group_id, details::ScanType::inclusive>;
243
+ return dispatch::scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
244
+ }
245
+
246
+ template <typename TyGroup, typename TyVal>
247
+ _CG_QUALIFIER details::remove_qual<TyVal> inclusive_scan(const TyGroup& group, TyVal&& val) {
248
+ return inclusive_scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), cooperative_groups::plus<details::remove_qual<TyVal>>());
249
+ }
250
+
251
+ template <typename TyGroup, typename TyVal, typename TyFn>
252
+ _CG_QUALIFIER auto exclusive_scan(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
253
+ details::check_scan_params<TyGroup, TyVal, decltype(op(val, val))>();
254
+
255
+ using dispatch = details::scan_dispatch<TyGroup::_group_id, details::ScanType::exclusive>;
256
+ return dispatch::scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
257
+ }
258
+
259
+ template <typename TyGroup, typename TyVal>
260
+ _CG_QUALIFIER details::remove_qual<TyVal> exclusive_scan(const TyGroup& group, TyVal&& val) {
261
+ return exclusive_scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), cooperative_groups::plus<details::remove_qual<TyVal>>());
262
+ }
263
+
264
+ #if defined(_CG_HAS_STL_ATOMICS)
265
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
266
+ _CG_QUALIFIER auto inclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
267
+ details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
268
+
269
+ using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::inclusive>;
270
+ return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
271
+ }
272
+
273
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
274
+ _CG_QUALIFIER TyVal inclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco> & dst, TyInputVal&& val) {
275
+ return inclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
276
+ }
277
+
278
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
279
+ _CG_QUALIFIER auto exclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
280
+ details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
281
+
282
+ using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::exclusive>;
283
+ return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
284
+ }
285
+
286
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
287
+ _CG_QUALIFIER TyVal exclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val) {
288
+ return exclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
289
+ }
290
+
291
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
292
+ _CG_QUALIFIER auto inclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
293
+ details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
294
+
295
+ using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::inclusive>;
296
+ return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
297
+ }
298
+
299
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
300
+ _CG_QUALIFIER TyVal inclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco> & dst, TyInputVal&& val) {
301
+ return inclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
302
+ }
303
+
304
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
305
+ _CG_QUALIFIER auto exclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
306
+ details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
307
+
308
+ using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::exclusive>;
309
+ return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
310
+ }
311
+
312
+ template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
313
+ _CG_QUALIFIER TyVal exclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val) {
314
+ return exclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
315
+ }
316
+ #endif
317
+
318
+ _CG_END_NAMESPACE
319
+
320
+ #endif // _CG_SCAN_H_
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/sync.h ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ #ifndef _CG_GRID_H
50
+ #define _CG_GRID_H
51
+
52
+ #include "info.h"
53
+
54
+ _CG_BEGIN_NAMESPACE
55
+
56
+ namespace details
57
+ {
58
+
59
+ typedef unsigned int barrier_t;
60
+
61
+ _CG_STATIC_QUALIFIER bool bar_has_flipped(unsigned int old_arrive, unsigned int current_arrive) {
62
+ return (((old_arrive ^ current_arrive) & 0x80000000) != 0);
63
+ }
64
+
65
+ _CG_STATIC_QUALIFIER bool is_cta_master() {
66
+ return (threadIdx.x + threadIdx.y + threadIdx.z == 0);
67
+ }
68
+
69
+ _CG_STATIC_QUALIFIER unsigned int sync_grids_arrive(volatile barrier_t *arrived) {
70
+ unsigned int oldArrive = 0;
71
+
72
+ __barrier_sync(0);
73
+
74
+ if (is_cta_master()) {
75
+ unsigned int expected = gridDim.x * gridDim.y * gridDim.z;
76
+ bool gpu_master = (blockIdx.x + blockIdx.y + blockIdx.z == 0);
77
+ unsigned int nb = 1;
78
+
79
+ if (gpu_master) {
80
+ nb = 0x80000000 - (expected - 1);
81
+ }
82
+
83
+ #if __CUDA_ARCH__ < 700
84
+ // Fence; barrier update; volatile polling; fence
85
+ __threadfence();
86
+
87
+ oldArrive = atomicAdd((unsigned int*)arrived, nb);
88
+ #else
89
+ // Barrier update with release; polling with acquire
90
+ asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb) : "memory");
91
+ #endif
92
+ }
93
+
94
+ return oldArrive;
95
+ }
96
+
97
+
98
+ _CG_STATIC_QUALIFIER void sync_grids_wait(unsigned int oldArrive, volatile barrier_t *arrived) {
99
+ if (is_cta_master()) {
100
+ #if __CUDA_ARCH__ < 700
101
+ while (!bar_has_flipped(oldArrive, *arrived));
102
+
103
+ __threadfence();
104
+
105
+ #else
106
+ unsigned int current_arrive;
107
+ do {
108
+ asm volatile("ld.acquire.gpu.u32 %0,[%1];" : "=r"(current_arrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) : "memory");
109
+ } while (!bar_has_flipped(oldArrive, current_arrive));
110
+ #endif
111
+ }
112
+
113
+ __barrier_sync(0);
114
+ }
115
+
116
+ /* - Multi warp groups synchronization routines - */
117
+
118
+ // Need both acquire and release for the last warp, since it won't be able to acquire with red.and
119
+ _CG_STATIC_QUALIFIER unsigned int atom_or_acq_rel_cta(unsigned int *addr, unsigned int val) {
120
+ unsigned int old;
121
+ #if __CUDA_ARCH__ < 700
122
+ __threadfence_block();
123
+ old = atomicOr(addr, val);
124
+ #else
125
+ asm volatile("atom.or.acq_rel.cta.b32 %0,[%1],%2;" : "=r"(old) : _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
126
+ #endif
127
+ return old;
128
+ }
129
+
130
+ // Special case where barrier is arrived, but not waited on
131
+ _CG_STATIC_QUALIFIER void red_or_release_cta(unsigned int *addr, unsigned int val) {
132
+ #if __CUDA_ARCH__ < 700
133
+ __threadfence_block();
134
+ atomicOr(addr, val);
135
+ #else
136
+ asm volatile("red.or.release.cta.b32 [%0],%1;" :: _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
137
+ #endif
138
+ }
139
+
140
+ // Usually called by last arriving warp to released other warps, can be relaxed, since or was already acq_rel
141
+ _CG_STATIC_QUALIFIER void red_and_relaxed_cta(unsigned int *addr, unsigned int val) {
142
+ #if __CUDA_ARCH__ < 700
143
+ atomicAnd(addr, val);
144
+ #else
145
+ asm volatile("red.and.relaxed.cta.b32 [%0],%1;" :: _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
146
+ #endif
147
+ }
148
+
149
+ // Special case of release, where last warp was doing extra work before releasing others, need to be release
150
+ // to ensure that extra work is visible
151
+ _CG_STATIC_QUALIFIER void red_and_release_cta(unsigned int *addr, unsigned int val) {
152
+ #if __CUDA_ARCH__ < 700
153
+ __threadfence_block();
154
+ atomicAnd(addr, val);
155
+ #else
156
+ asm volatile("red.and.release.cta.b32 [%0],%1;" :: _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
157
+ #endif
158
+ }
159
+
160
+ // Read the barrier, acquire to ensure all memory operations following the sync are correctly performed after it is released
161
+ _CG_STATIC_QUALIFIER unsigned int ld_acquire_cta(unsigned int *addr) {
162
+ unsigned int val;
163
+ #if __CUDA_ARCH__ < 700
164
+ val = *((volatile unsigned int*) addr);
165
+ __threadfence_block();
166
+ #else
167
+ asm volatile("ld.acquire.cta.u32 %0,[%1];" : "=r"(val) : _CG_ASM_PTR_CONSTRAINT(addr) : "memory");
168
+ #endif
169
+ return val;
170
+ }
171
+
172
+ // Get synchronization bit mask of my thread_block_tile of size num_warps. Thread ranks 0..31 have the first bit assigned to them,
173
+ // thread ranks 32..63 second etc
174
+ // Bit masks are unique for each group, groups of the same size will have the same number of bits set, but on different positions
175
+ _CG_STATIC_QUALIFIER unsigned int get_group_mask(unsigned int thread_rank, unsigned int num_warps) {
176
+ return num_warps == 32 ? ~0 : ((1 << num_warps) - 1) << (num_warps * (thread_rank / (num_warps * 32)));
177
+ }
178
+
179
+ _CG_STATIC_QUALIFIER void barrier_wait(barrier_t *arrived, unsigned int warp_bit) {
180
+ while(ld_acquire_cta(arrived) & warp_bit);
181
+ }
182
+
183
+ // Default blocking sync.
184
+ _CG_STATIC_QUALIFIER void sync_warps(barrier_t *arrived, unsigned int thread_rank, unsigned int num_warps) {
185
+ unsigned int warp_id = thread_rank / 32;
186
+ bool warp_master = (thread_rank % 32 == 0);
187
+ unsigned int warp_bit = 1 << warp_id;
188
+ unsigned int group_mask = get_group_mask(thread_rank, num_warps);
189
+
190
+ __syncwarp(0xFFFFFFFF);
191
+
192
+ if (warp_master) {
193
+ unsigned int old = atom_or_acq_rel_cta(arrived, warp_bit);
194
+ if (((old | warp_bit) & group_mask) == group_mask) {
195
+ red_and_relaxed_cta(arrived, ~group_mask);
196
+ }
197
+ else {
198
+ barrier_wait(arrived, warp_bit);
199
+ }
200
+ }
201
+
202
+ __syncwarp(0xFFFFFFFF);
203
+ }
204
+
205
+ // Blocking sync, except the last arriving warp, that releases other warps, returns to do other stuff first.
206
+ // Warp returning true from this function needs to call sync_warps_release.
207
+ _CG_STATIC_QUALIFIER bool sync_warps_last_releases(barrier_t *arrived, unsigned int thread_rank, unsigned int num_warps) {
208
+ unsigned int warp_id = thread_rank / 32;
209
+ bool warp_master = (thread_rank % 32 == 0);
210
+ unsigned int warp_bit = 1 << warp_id;
211
+ unsigned int group_mask = get_group_mask(thread_rank, num_warps);
212
+
213
+ __syncwarp(0xFFFFFFFF);
214
+
215
+ unsigned int old = 0;
216
+ if (warp_master) {
217
+ old = atom_or_acq_rel_cta(arrived, warp_bit);
218
+ }
219
+ old = __shfl_sync(0xFFFFFFFF, old, 0);
220
+ if (((old | warp_bit) & group_mask) == group_mask) {
221
+ return true;
222
+ }
223
+ barrier_wait(arrived, warp_bit);
224
+
225
+ return false;
226
+ }
227
+
228
+ // Release my group from the barrier.
229
+ _CG_STATIC_QUALIFIER void sync_warps_release(barrier_t *arrived, bool is_master, unsigned int thread_rank, unsigned int num_warps) {
230
+ unsigned int group_mask = get_group_mask(thread_rank, num_warps);
231
+ if (is_master) {
232
+ red_and_release_cta(arrived, ~group_mask);
233
+ }
234
+ }
235
+
236
+ // Arrive at my group barrier, but don't block or release the barrier, even if every one arrives.
237
+ // sync_warps_release needs to be called by some warp after this one to reset the barrier.
238
+ _CG_STATIC_QUALIFIER void sync_warps_arrive(barrier_t *arrived, unsigned int thread_rank, unsigned int num_warps) {
239
+ unsigned int warp_id = thread_rank / 32;
240
+ bool warp_master = (thread_rank % 32 == 0);
241
+ unsigned int warp_bit = 1 << warp_id;
242
+ unsigned int group_mask = get_group_mask(thread_rank, num_warps);
243
+
244
+ __syncwarp(0xFFFFFFFF);
245
+
246
+ if (warp_master) {
247
+ red_or_release_cta(arrived, warp_bit);
248
+ }
249
+ }
250
+
251
+ // Wait for my warp to be released from the barrier. Warp must have arrived first.
252
+ _CG_STATIC_QUALIFIER void sync_warps_wait(barrier_t *arrived, unsigned int thread_rank) {
253
+ unsigned int warp_id = thread_rank / 32;
254
+ unsigned int warp_bit = 1 << warp_id;
255
+
256
+ barrier_wait(arrived, warp_bit);
257
+ }
258
+
259
+ // Wait for specific warp to arrive at the barrier
260
+ _CG_QUALIFIER void sync_warps_wait_for_specific_warp(barrier_t *arrived, unsigned int wait_warp_id) {
261
+ unsigned int wait_mask = 1 << wait_warp_id;
262
+ while((ld_acquire_cta(arrived) & wait_mask) != wait_mask);
263
+ }
264
+
265
+ // Initialize the bit corresponding to my warp in the barrier
266
+ _CG_QUALIFIER void sync_warps_reset(barrier_t *arrived, unsigned int thread_rank) {
267
+ unsigned int warp_id = thread_rank / 32;
268
+ unsigned int warp_bit = 1 << warp_id;
269
+
270
+ __syncwarp(0xFFFFFFFF);
271
+
272
+ if (thread_rank % 32 == 0) {
273
+ red_and_release_cta(arrived, ~warp_bit);
274
+ }
275
+ // No need to sync after the atomic, there will be a sync of the group that is being partitioned right after this.
276
+ }
277
+
278
+ } // details
279
+
280
+ _CG_END_NAMESPACE
281
+
282
+ #endif // _CG_GRID_H
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8774224f5b11a73b15d074a3fcce7327322c5c4cfdfd924d6a826779eec968fe
3
+ size 707904
.venv/lib/python3.11/site-packages/nvidia/cudnn/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn.h ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cudnn : Neural Networks Library */
51
+
52
+ #if !defined(CUDNN_H_)
53
+ #define CUDNN_H_
54
+ #if defined(__cplusplus)
55
+ extern "C" {
56
+ #endif
57
+
58
+ #include <cuda_runtime_api.h>
59
+ #include "cudnn_version.h"
60
+ #include "cudnn_graph.h"
61
+ #include "cudnn_ops.h"
62
+ #include "cudnn_adv.h"
63
+ #include "cudnn_cnn.h"
64
+
65
+ #if defined(__cplusplus)
66
+ }
67
+ #endif
68
+ #endif /* CUDNN_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_v9.h ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cudnn_adv : cuDNN's advanced and experimental features.
51
+
52
+ */
53
+
54
+ #if !defined(CUDNN_ADV_H_)
55
+ #define CUDNN_ADV_H_
56
+
57
+ #include <stdint.h>
58
+
59
+ #include "cudnn_version.h"
60
+ #include "cudnn_ops.h"
61
+
62
+ /* These version numbers are autogenerated, do not edit manually. */
63
+ #define CUDNN_ADV_MAJOR 9
64
+ #define CUDNN_ADV_MINOR 1
65
+ #define CUDNN_ADV_PATCH 0
66
+
67
+ #if (CUDNN_ADV_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_MINOR != CUDNN_MINOR) || (CUDNN_ADV_PATCH != CUDNN_PATCHLEVEL)
68
+ #error Version mismatch in cuDNN ADV INFER!!!
69
+ #endif
70
+
71
+ #if defined(__cplusplus)
72
+ extern "C" {
73
+ #endif
74
+
75
+ /* BASIC RNN API */
76
+
77
+ typedef enum {
78
+ CUDNN_RNN_ALGO_STANDARD = 0,
79
+ CUDNN_RNN_ALGO_PERSIST_STATIC = 1,
80
+ CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2,
81
+ CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H = 3,
82
+ CUDNN_RNN_ALGO_COUNT = 4,
83
+ } cudnnRNNAlgo_t;
84
+
85
+ typedef enum {
86
+ CUDNN_FWD_MODE_INFERENCE = 0,
87
+ CUDNN_FWD_MODE_TRAINING = 1,
88
+ } cudnnForwardMode_t;
89
+
90
+ typedef enum {
91
+ CUDNN_RNN_RELU = 0, /* basic RNN cell type with ReLu activation */
92
+ CUDNN_RNN_TANH = 1, /* basic RNN cell type with tanh activation */
93
+ CUDNN_LSTM = 2, /* LSTM with optional recurrent projection and clipping */
94
+ CUDNN_GRU = 3, /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1); */
95
+ } cudnnRNNMode_t;
96
+
97
+ typedef enum {
98
+ CUDNN_RNN_NO_BIAS = 0, /* rnn cell formulas do not use biases */
99
+ CUDNN_RNN_SINGLE_INP_BIAS = 1, /* rnn cell formulas use one input bias in input GEMM */
100
+ CUDNN_RNN_DOUBLE_BIAS = 2, /* default, rnn cell formulas use two bias vectors */
101
+ CUDNN_RNN_SINGLE_REC_BIAS = 3 /* rnn cell formulas use one recurrent bias in recurrent GEMM */
102
+ } cudnnRNNBiasMode_t;
103
+
104
+ typedef enum {
105
+ CUDNN_UNIDIRECTIONAL = 0, /* single direction network */
106
+ CUDNN_BIDIRECTIONAL = 1, /* output concatination at each layer */
107
+ } cudnnDirectionMode_t;
108
+
109
+ typedef enum {
110
+ CUDNN_LINEAR_INPUT = 0, /* adjustable weight matrix in first layer input GEMM */
111
+ CUDNN_SKIP_INPUT = 1, /* fixed identity matrix in the first layer input GEMM */
112
+ } cudnnRNNInputMode_t;
113
+
114
+ typedef enum {
115
+ CUDNN_RNN_CLIP_NONE = 0, /* disables LSTM cell clipping */
116
+ CUDNN_RNN_CLIP_MINMAX = 1, /* enables LSTM cell clipping */
117
+ } cudnnRNNClipMode_t;
118
+
119
+ typedef enum {
120
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED = 0, /* padded, outer stride from one time-step to the next */
121
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1, /* sequence length sorted and packed as in basic RNN api */
122
+ CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, /* padded, outer stride from one batch to the next */
123
+ } cudnnRNNDataLayout_t;
124
+
125
+ /* For auxFlags in cudnnSetRNNDescriptor_v8() */
126
+ #define CUDNN_RNN_PADDED_IO_DISABLED 0
127
+ #define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0)
128
+
129
+ struct cudnnRNNStruct;
130
+ typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t;
131
+
132
+ struct cudnnRNNDataStruct;
133
+ typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t;
134
+
135
+ cudnnStatus_t CUDNNWINAPI
136
+ cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc);
137
+
138
+ cudnnStatus_t CUDNNWINAPI
139
+ cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
140
+
141
+ /*
142
+ * mathPrec in cudnnSetRNNDescriptor_v8() specifies compute precision.
143
+ * Compute precision is further modified by mathType that sets the
144
+ * preferred option for using NVIDIA Tensor Cores. dataType specify
145
+ * input/output data type and weight/bias type.
146
+ */
147
+
148
+ cudnnStatus_t CUDNNWINAPI
149
+ cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
150
+ cudnnRNNAlgo_t algo,
151
+ cudnnRNNMode_t cellMode,
152
+ cudnnRNNBiasMode_t biasMode,
153
+ cudnnDirectionMode_t dirMode,
154
+ cudnnRNNInputMode_t inputMode,
155
+ cudnnDataType_t dataType,
156
+ cudnnDataType_t mathPrec,
157
+ cudnnMathType_t mathType,
158
+ int32_t inputSize,
159
+ int32_t hiddenSize,
160
+ int32_t projSize,
161
+ int32_t numLayers,
162
+ cudnnDropoutDescriptor_t dropoutDesc,
163
+ uint32_t auxFlags);
164
+
165
+ cudnnStatus_t CUDNNWINAPI
166
+ cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
167
+ cudnnRNNAlgo_t *algo,
168
+ cudnnRNNMode_t *cellMode,
169
+ cudnnRNNBiasMode_t *biasMode,
170
+ cudnnDirectionMode_t *dirMode,
171
+ cudnnRNNInputMode_t *inputMode,
172
+ cudnnDataType_t *dataType,
173
+ cudnnDataType_t *mathPrec,
174
+ cudnnMathType_t *mathType,
175
+ int32_t *inputSize,
176
+ int32_t *hiddenSize,
177
+ int32_t *projSize,
178
+ int32_t *numLayers,
179
+ cudnnDropoutDescriptor_t *dropoutDesc,
180
+ uint32_t *auxFlags);
181
+
182
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
183
+ cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc,
184
+ cudnnRNNClipMode_t clipMode,
185
+ cudnnNanPropagation_t clipNanOpt,
186
+ double lclip,
187
+ double rclip);
188
+
189
+ cudnnStatus_t CUDNNWINAPI
190
+ cudnnRNNSetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t clipMode, double lclip, double rclip);
191
+
192
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
193
+ cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc,
194
+ cudnnRNNClipMode_t *clipMode,
195
+ cudnnNanPropagation_t *clipNanOpt,
196
+ double *lclip,
197
+ double *rclip);
198
+
199
+ cudnnStatus_t CUDNNWINAPI
200
+ cudnnRNNGetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t *clipMode, double *lclip, double *rclip);
201
+
202
+ cudnnStatus_t CUDNNWINAPI
203
+ cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch);
204
+
205
+ cudnnStatus_t CUDNNWINAPI
206
+ cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle,
207
+ cudnnRNNDescriptor_t rnnDesc,
208
+ cudnnForwardMode_t fwdMode,
209
+ cudnnRNNDataDescriptor_t xDesc,
210
+ size_t *workSpaceSize,
211
+ size_t *reserveSpaceSize);
212
+
213
+ cudnnStatus_t CUDNNWINAPI
214
+ cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize);
215
+
216
+ cudnnStatus_t CUDNNWINAPI
217
+ cudnnGetRNNWeightParams(cudnnHandle_t handle,
218
+ cudnnRNNDescriptor_t rnnDesc,
219
+ int32_t pseudoLayer,
220
+ size_t weightSpaceSize,
221
+ const void *weightSpace,
222
+ int32_t linLayerID,
223
+ cudnnTensorDescriptor_t mDesc,
224
+ void **mAddr,
225
+ cudnnTensorDescriptor_t bDesc,
226
+ void **bAddr);
227
+
228
+ cudnnStatus_t CUDNNWINAPI
229
+ cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc);
230
+
231
+ cudnnStatus_t CUDNNWINAPI
232
+ cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc);
233
+
234
+ cudnnStatus_t CUDNNWINAPI
235
+ cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
236
+ cudnnDataType_t dataType,
237
+ cudnnRNNDataLayout_t layout,
238
+ int maxSeqLength,
239
+ int batchSize,
240
+ int vectorSize,
241
+ const int seqLengthArray[], /* length of each sequence in the batch */
242
+ void *paddingFill); /* symbol for filling padding position in output */
243
+
244
+ cudnnStatus_t CUDNNWINAPI
245
+ cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
246
+ cudnnDataType_t *dataType,
247
+ cudnnRNNDataLayout_t *layout,
248
+ int *maxSeqLength,
249
+ int *batchSize,
250
+ int *vectorSize,
251
+ int arrayLengthRequested,
252
+ int seqLengthArray[],
253
+ void *paddingFill);
254
+
255
+ cudnnStatus_t CUDNNWINAPI
256
+ cudnnRNNForward(cudnnHandle_t handle,
257
+ cudnnRNNDescriptor_t rnnDesc,
258
+ cudnnForwardMode_t fwdMode,
259
+ const int32_t devSeqLengths[],
260
+ cudnnRNNDataDescriptor_t xDesc,
261
+ const void *x,
262
+ cudnnRNNDataDescriptor_t yDesc,
263
+ void *y,
264
+ cudnnTensorDescriptor_t hDesc,
265
+ const void *hx,
266
+ void *hy,
267
+ cudnnTensorDescriptor_t cDesc,
268
+ const void *cx,
269
+ void *cy,
270
+ size_t weightSpaceSize,
271
+ const void *weightSpace,
272
+ size_t workSpaceSize,
273
+ void *workSpace,
274
+ size_t reserveSpaceSize,
275
+ void *reserveSpace);
276
+
277
+ /* Sequence data descriptor */
278
+
279
+ typedef enum {
280
+ CUDNN_SEQDATA_TIME_DIM = 0, /* index in time */
281
+ CUDNN_SEQDATA_BATCH_DIM = 1, /* index in batch */
282
+ CUDNN_SEQDATA_BEAM_DIM = 2, /* index in beam */
283
+ CUDNN_SEQDATA_VECT_DIM = 3 /* index in vector */
284
+ } cudnnSeqDataAxis_t;
285
+
286
+ struct cudnnSeqDataStruct;
287
+ typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t CUDNN_DEPRECATED;
288
+
289
+ #define CUDNN_SEQDATA_DIM_COUNT 4 /* dimension count */
290
+
291
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
292
+ cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc);
293
+
294
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
295
+ cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc);
296
+
297
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
298
+ cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
299
+ cudnnDataType_t dataType,
300
+ int nbDims,
301
+ const int dimA[],
302
+ const cudnnSeqDataAxis_t axes[],
303
+ size_t seqLengthArraySize,
304
+ const int seqLengthArray[],
305
+ void *paddingFill);
306
+
307
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
308
+ cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
309
+ cudnnDataType_t *dataType,
310
+ int *nbDims,
311
+ int nbDimsRequested,
312
+ int dimA[],
313
+ cudnnSeqDataAxis_t axes[],
314
+ size_t *seqLengthArraySize,
315
+ size_t seqLengthSizeRequested,
316
+ int seqLengthArray[],
317
+ void *paddingFill);
318
+
319
+ /* Multihead Attention */
320
+
321
+ /*
322
+ * Multi-head attention options passed via 'attnMode' in cudnnSetAttnDescriptor().
323
+ * Use the bitwise OR operator to combine several settings listed below. Additional
324
+ * minor options can be added here w/o changing or introducing new API functions.
325
+ */
326
+ #define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0 /* multiple Q-s map to a single (K,V) set when beam size > 1 */
327
+ #define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */
328
+ #define CUDNN_ATTN_DISABLE_PROJ_BIASES 0 /* no biases in attention input and output projections */
329
+ #define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1) /* use biases in attention input and output projections */
330
+
331
+ struct cudnnAttnStruct;
332
+ typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t CUDNN_DEPRECATED;
333
+
334
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
335
+ cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc);
336
+
337
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
338
+ cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc);
339
+
340
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
341
+ cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
342
+ unsigned attnMode,
343
+ int nHeads,
344
+ double smScaler,
345
+ cudnnDataType_t dataType,
346
+ cudnnDataType_t computePrec,
347
+ cudnnMathType_t mathType,
348
+ cudnnDropoutDescriptor_t attnDropoutDesc,
349
+ cudnnDropoutDescriptor_t postDropoutDesc,
350
+ int qSize,
351
+ int kSize,
352
+ int vSize,
353
+ int qProjSize,
354
+ int kProjSize,
355
+ int vProjSize,
356
+ int oProjSize,
357
+ int qoMaxSeqLength,
358
+ int kvMaxSeqLength,
359
+ int maxBatchSize,
360
+ int maxBeamSize);
361
+
362
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
363
+ cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
364
+ unsigned *attnMode,
365
+ int *nHeads,
366
+ double *smScaler,
367
+ cudnnDataType_t *dataType,
368
+ cudnnDataType_t *computePrec,
369
+ cudnnMathType_t *mathType,
370
+ cudnnDropoutDescriptor_t *attnDropoutDesc,
371
+ cudnnDropoutDescriptor_t *postDropoutDesc,
372
+ int *qSize,
373
+ int *kSize,
374
+ int *vSize,
375
+ int *qProjSize,
376
+ int *kProjSize,
377
+ int *vProjSize,
378
+ int *oProjSize,
379
+ int *qoMaxSeqLength,
380
+ int *kvMaxSeqLength,
381
+ int *maxBatchSize,
382
+ int *maxBeamSize);
383
+
384
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
385
+ cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
386
+ const cudnnAttnDescriptor_t attnDesc,
387
+ size_t *weightSizeInBytes,
388
+ size_t *workSpaceSizeInBytes,
389
+ size_t *reserveSpaceSizeInBytes);
390
+
391
+ typedef enum {
392
+ CUDNN_MH_ATTN_Q_WEIGHTS = 0, /* input projection weights for 'queries' */
393
+ CUDNN_MH_ATTN_K_WEIGHTS = 1, /* input projection weights for 'keys' */
394
+ CUDNN_MH_ATTN_V_WEIGHTS = 2, /* input projection weights for 'values' */
395
+ CUDNN_MH_ATTN_O_WEIGHTS = 3, /* output projection weights */
396
+ CUDNN_MH_ATTN_Q_BIASES = 4, /* input projection bias tensor for 'queries' */
397
+ CUDNN_MH_ATTN_K_BIASES = 5, /* input projection bias for 'keys' */
398
+ CUDNN_MH_ATTN_V_BIASES = 6, /* input projection bias for 'values' */
399
+ CUDNN_MH_ATTN_O_BIASES = 7, /* output projection biases */
400
+ } cudnnMultiHeadAttnWeightKind_t;
401
+
402
+ #define CUDNN_ATTN_WKIND_COUNT 8 /* Number of attention weight/bias tensors */
403
+
404
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
405
+ cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
406
+ const cudnnAttnDescriptor_t attnDesc,
407
+ cudnnMultiHeadAttnWeightKind_t wKind,
408
+ size_t weightSizeInBytes,
409
+ const void *weights,
410
+ cudnnTensorDescriptor_t wDesc,
411
+ void **wAddr);
412
+
413
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
414
+ cudnnMultiHeadAttnForward(cudnnHandle_t handle,
415
+ const cudnnAttnDescriptor_t attnDesc,
416
+ int currIdx,
417
+ const int loWinIdx[],
418
+ const int hiWinIdx[],
419
+ const int devSeqLengthsQO[],
420
+ const int devSeqLengthsKV[],
421
+ const cudnnSeqDataDescriptor_t qDesc,
422
+ const void *queries,
423
+ const void *residuals,
424
+ const cudnnSeqDataDescriptor_t kDesc,
425
+ const void *keys,
426
+ const cudnnSeqDataDescriptor_t vDesc,
427
+ const void *values,
428
+ const cudnnSeqDataDescriptor_t oDesc,
429
+ void *out,
430
+ size_t weightSizeInBytes,
431
+ const void *weights,
432
+ size_t workSpaceSizeInBytes,
433
+ void *workSpace,
434
+ size_t reserveSpaceSizeInBytes,
435
+ void *reserveSpace);
436
+
437
+ /*
438
+ * \brief Cross-library version checker.
439
+ * This function is implemented differently in each sub-library. Each sublib
440
+ * checks whether its own version matches that of its dependencies.
441
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
442
+ * CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
443
+ */
444
+ cudnnStatus_t CUDNNWINAPI
445
+ cudnnAdvVersionCheck(void);
446
+
447
+ typedef enum {
448
+ CUDNN_WGRAD_MODE_ADD = 0, /* add partial gradients to wgrad output buffers */
449
+ CUDNN_WGRAD_MODE_SET = 1, /* write partial gradients to wgrad output buffers */
450
+ } cudnnWgradMode_t;
451
+
452
+ cudnnStatus_t CUDNNWINAPI
453
+ cudnnRNNBackwardData_v8(cudnnHandle_t handle,
454
+ cudnnRNNDescriptor_t rnnDesc,
455
+ const int32_t devSeqLengths[],
456
+ cudnnRNNDataDescriptor_t yDesc,
457
+ const void *y,
458
+ const void *dy,
459
+ cudnnRNNDataDescriptor_t xDesc,
460
+ void *dx,
461
+ cudnnTensorDescriptor_t hDesc,
462
+ const void *hx,
463
+ const void *dhy,
464
+ void *dhx,
465
+ cudnnTensorDescriptor_t cDesc,
466
+ const void *cx,
467
+ const void *dcy,
468
+ void *dcx,
469
+ size_t weightSpaceSize,
470
+ const void *weightSpace,
471
+ size_t workSpaceSize,
472
+ void *workSpace,
473
+ size_t reserveSpaceSize,
474
+ void *reserveSpace);
475
+
476
+ cudnnStatus_t CUDNNWINAPI
477
+ cudnnRNNBackwardWeights_v8(cudnnHandle_t handle,
478
+ cudnnRNNDescriptor_t rnnDesc,
479
+ cudnnWgradMode_t addGrad,
480
+ const int32_t devSeqLengths[],
481
+ cudnnRNNDataDescriptor_t xDesc,
482
+ const void *x,
483
+ cudnnTensorDescriptor_t hDesc,
484
+ const void *hx,
485
+ cudnnRNNDataDescriptor_t yDesc,
486
+ const void *y,
487
+ size_t weightSpaceSize,
488
+ void *dweightSpace,
489
+ size_t workSpaceSize,
490
+ void *workSpace,
491
+ size_t reserveSpaceSize,
492
+ void *reserveSpace);
493
+
494
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
495
+ cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle,
496
+ const cudnnAttnDescriptor_t attnDesc,
497
+ const int loWinIdx[],
498
+ const int hiWinIdx[],
499
+ const int devSeqLengthsDQDO[],
500
+ const int devSeqLengthsDKDV[],
501
+ const cudnnSeqDataDescriptor_t doDesc,
502
+ const void *dout,
503
+ const cudnnSeqDataDescriptor_t dqDesc,
504
+ void *dqueries,
505
+ const void *queries,
506
+ const cudnnSeqDataDescriptor_t dkDesc,
507
+ void *dkeys,
508
+ const void *keys,
509
+ const cudnnSeqDataDescriptor_t dvDesc,
510
+ void *dvalues,
511
+ const void *values,
512
+ size_t weightSizeInBytes,
513
+ const void *weights,
514
+ size_t workSpaceSizeInBytes,
515
+ void *workSpace,
516
+ size_t reserveSpaceSizeInBytes,
517
+ void *reserveSpace);
518
+
519
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
520
+ cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle,
521
+ const cudnnAttnDescriptor_t attnDesc,
522
+ cudnnWgradMode_t addGrad,
523
+ const cudnnSeqDataDescriptor_t qDesc,
524
+ const void *queries,
525
+ const cudnnSeqDataDescriptor_t kDesc,
526
+ const void *keys,
527
+ const cudnnSeqDataDescriptor_t vDesc,
528
+ const void *values,
529
+ const cudnnSeqDataDescriptor_t doDesc,
530
+ const void *dout,
531
+ size_t weightSizeInBytes,
532
+ const void *weights,
533
+ void *dweights,
534
+ size_t workSpaceSizeInBytes,
535
+ void *workSpace,
536
+ size_t reserveSpaceSizeInBytes,
537
+ void *reserveSpace);
538
+
539
+ /*
540
+ * CTC (Connectionist Temporal Classification) loss descriptor create/destory/set/get functions
541
+ */
542
+ /* Input normalization mode for loss function */
543
+ typedef enum {
544
+ CUDNN_LOSS_NORMALIZATION_NONE = 0,
545
+ CUDNN_LOSS_NORMALIZATION_SOFTMAX = 1,
546
+ } cudnnLossNormalizationMode_t;
547
+
548
+ cudnnStatus_t CUDNNWINAPI
549
+ cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc);
550
+
551
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
552
+ cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType);
553
+
554
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
555
+ cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
556
+ cudnnDataType_t compType,
557
+ cudnnLossNormalizationMode_t normMode,
558
+ cudnnNanPropagation_t gradMode);
559
+
560
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
561
+ cudnnSetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
562
+ cudnnDataType_t compType,
563
+ cudnnLossNormalizationMode_t normMode,
564
+ cudnnNanPropagation_t gradMode,
565
+ int maxLabelLength);
566
+
567
+ cudnnStatus_t CUDNNWINAPI
568
+ cudnnSetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
569
+ cudnnDataType_t compType,
570
+ cudnnLossNormalizationMode_t normMode,
571
+ cudnnCTCGradMode_t ctcGradMode,
572
+ int maxLabelLength);
573
+
574
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
575
+ cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType);
576
+
577
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
578
+ cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
579
+ cudnnDataType_t *compType,
580
+ cudnnLossNormalizationMode_t *normMode,
581
+ cudnnNanPropagation_t *gradMode);
582
+
583
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
584
+ cudnnGetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
585
+ cudnnDataType_t *compType,
586
+ cudnnLossNormalizationMode_t *normMode,
587
+ cudnnNanPropagation_t *gradMode,
588
+ int *maxLabelLength);
589
+
590
+ cudnnStatus_t CUDNNWINAPI
591
+ cudnnGetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
592
+ cudnnDataType_t *compType,
593
+ cudnnLossNormalizationMode_t *normMode,
594
+ cudnnCTCGradMode_t *ctcGradMode,
595
+ int *maxLabelLength);
596
+
597
+ cudnnStatus_t CUDNNWINAPI
598
+ cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc);
599
+
600
+ /* return the ctc costs and gradients, given the probabilities and labels */
601
+ cudnnStatus_t CUDNNWINAPI
602
+ cudnnCTCLoss(
603
+ cudnnHandle_t handle,
604
+ const cudnnTensorDescriptor_t
605
+ probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
606
+ mini batch size, A is the alphabet size) */
607
+ const void *probs, /* probabilities after softmax, in GPU memory */
608
+ const int hostLabels[], /* labels, in CPU memory */
609
+ const int hostLabelLengths[], /* the length of each label, in CPU memory */
610
+ const int hostInputLengths[], /* the lengths of timing steps in each batch, in CPU memory */
611
+ void *costs, /* the returned costs of CTC, in GPU memory */
612
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
613
+ void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
614
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
615
+ cudnnCTCLossDescriptor_t ctcLossDesc,
616
+ void *workspace, /* pointer to the workspace, in GPU memory */
617
+ size_t workSpaceSizeInBytes); /* size of the workspace */
618
+
619
+ /* return the ctc costs and gradients, given the probabilities and labels */
620
+ cudnnStatus_t CUDNNWINAPI
621
+ cudnnCTCLoss_v8(
622
+ cudnnHandle_t handle,
623
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
624
+ cudnnCTCLossDescriptor_t ctcLossDesc,
625
+ const cudnnTensorDescriptor_t
626
+ probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
627
+ mini batch size, A is the alphabet size) */
628
+ const void *probs, /* probabilities after softmax, in GPU memory */
629
+ const int labels[], /* labels, in GPU memory */
630
+ const int labelLengths[], /* the length of each label, in GPU memory */
631
+ const int inputLengths[], /* the lengths of timing steps in each batch, in GPU memory */
632
+ void *costs, /* the returned costs of CTC, in GPU memory */
633
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
634
+ void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
635
+ size_t workSpaceSizeInBytes, /* size of the workspace */
636
+ void *workspace); /* pointer to the workspace, in GPU memory */
637
+
638
+ /* return the workspace size needed for ctc */
639
+ cudnnStatus_t CUDNNWINAPI
640
+ cudnnGetCTCLossWorkspaceSize(
641
+ cudnnHandle_t handle,
642
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
643
+ timing steps, N is the mini batch size, A is the alphabet size) */
644
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
645
+ dimensions are T,N,A. To compute costs
646
+ only, set it to NULL */
647
+ const int *labels, /* labels, in CPU memory */
648
+ const int *labelLengths, /* the length of each label, in CPU memory */
649
+ const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */
650
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
651
+ cudnnCTCLossDescriptor_t ctcLossDesc,
652
+ size_t *sizeInBytes); /* pointer to the returned workspace size */
653
+
654
+ /* return the workspace size needed for ctc */
655
+ cudnnStatus_t CUDNNWINAPI
656
+ cudnnGetCTCLossWorkspaceSize_v8(
657
+ cudnnHandle_t handle,
658
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
659
+ cudnnCTCLossDescriptor_t ctcLossDesc,
660
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
661
+ timing steps, N is the mini batch size, A is the alphabet size) */
662
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
663
+ dimensions are T,N,A. To compute costs
664
+ only, set it to NULL */
665
+ size_t *sizeInBytes); /* pointer to the returned workspace size */
666
+
667
+ #if defined(__cplusplus)
668
+ }
669
+ #endif
670
+
671
+ #endif /* CUDNN_ADV_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend_v9.h ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #ifndef _CUDNN_BACKEND_H_
51
+ #define _CUDNN_BACKEND_H_
52
+
53
+ /*
54
+ * The content of this header has been moved into cudnn_graph.h.
55
+ * This header is kept for the backward compatibility purpose.
56
+ */
57
+
58
+ #include "cudnn_graph.h"
59
+
60
+ #endif /* _CUDNN_BACKEND_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph.h ADDED
@@ -0,0 +1,909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_graph : cuDNN's basic definitions operations.
52
+ */
53
+
54
+ #if !defined(CUDNN_GRAPH_H_)
55
+ #define CUDNN_GRAPH_H_
56
+
57
+ #include <cuda_runtime_api.h>
58
+ #include <library_types.h>
59
+
60
+ #include <stdint.h>
61
+
62
+ #include "cudnn_version.h"
63
+
64
+ /* These version numbers are autogenerated, do not edit manually. */
65
+ #define CUDNN_GRAPH_MAJOR 9
66
+ #define CUDNN_GRAPH_MINOR 1
67
+ #define CUDNN_GRAPH_PATCH 0
68
+
69
+ #if (CUDNN_GRAPH_MAJOR != CUDNN_MAJOR) || (CUDNN_GRAPH_MINOR != CUDNN_MINOR) || (CUDNN_GRAPH_PATCH != CUDNN_PATCHLEVEL)
70
+ #error Version mismatch in cuDNN GRAPH!!!
71
+ #endif
72
+
73
+ #ifndef CUDNNWINAPI
74
+ #ifdef _WIN32
75
+ #define CUDNNWINAPI __stdcall
76
+ #else
77
+ #define CUDNNWINAPI
78
+ #endif
79
+ #endif
80
+
81
+ /* Warnings for deprecated API-s are enabled using the CUDNN_WARN_DEPRECATED macro */
82
+ #if defined(CUDNN_WARN_DEPRECATED) && (defined(__GNUC__) || defined(__clang__))
83
+ /* GCC, Intel C/C++, Cray C/C++, CLANG, IBM XL C/C++ little endian */
84
+ #define CUDNN_DEPRECATED __attribute__((deprecated))
85
+ #define CUDNN_DEPRECATED_ENUM __attribute__((deprecated))
86
+ #elif defined(CUDNN_WARN_DEPRECATED) && defined(_MSC_VER)
87
+ /* Microsoft Visual C++ */
88
+ #define CUDNN_DEPRECATED __declspec(deprecated)
89
+ #define CUDNN_DEPRECATED_ENUM __declspec(deprecated)
90
+ #elif defined(CUDNN_WARN_DEPRECATED) && (__cplusplus >= 201402L)
91
+ /* C++14 compilers */
92
+ #define CUDNN_DEPRECATED [[deprecated]]
93
+ #define CUDNN_DEPRECATED_ENUM [[deprecated]]
94
+ #else
95
+ /* No support for the deprecated attribute */
96
+ #define CUDNN_DEPRECATED
97
+ #define CUDNN_DEPRECATED_ENUM
98
+ #endif
99
+
100
+ #if defined(__cplusplus)
101
+ extern "C" {
102
+ #endif
103
+
104
+ struct cudnnContext;
105
+ typedef struct cudnnContext *cudnnHandle_t;
106
+
107
+ size_t CUDNNWINAPI
108
+ cudnnGetVersion(void);
109
+
110
+ size_t CUDNNWINAPI
111
+ cudnnGetMaxDeviceVersion(void);
112
+
113
+ /* Returns CUDA Runtime version statically linked against cudnn */
114
+ size_t CUDNNWINAPI
115
+ cudnnGetCudartVersion(void);
116
+
117
+ /*
118
+ * CUDNN return codes
119
+ */
120
+ typedef enum {
121
+ CUDNN_STATUS_SUCCESS = 0,
122
+
123
+ /* Uncategorized errors */
124
+ CUDNN_STATUS_NOT_INITIALIZED = 1001,
125
+ CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH = 1002,
126
+ CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH = 1003,
127
+ CUDNN_STATUS_DEPRECATED = 1004,
128
+ CUDNN_STATUS_LICENSE_ERROR = 1005,
129
+ CUDNN_STATUS_RUNTIME_IN_PROGRESS = 1006,
130
+ CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 1007,
131
+
132
+ CUDNN_STATUS_BAD_PARAM = 2000,
133
+ CUDNN_STATUS_BAD_PARAM_NULL_POINTER = 2002,
134
+ CUDNN_STATUS_BAD_PARAM_MISALIGNED_POINTER = 2003,
135
+ CUDNN_STATUS_BAD_PARAM_NOT_FINALIZED = 2004,
136
+ CUDNN_STATUS_BAD_PARAM_OUT_OF_BOUND = 2005,
137
+ CUDNN_STATUS_BAD_PARAM_SIZE_INSUFFICIENT = 2006,
138
+ CUDNN_STATUS_BAD_PARAM_STREAM_MISMATCH = 2007,
139
+ CUDNN_STATUS_BAD_PARAM_SHAPE_MISMATCH = 2008,
140
+ CUDNN_STATUS_BAD_PARAM_DUPLICATED_ENTRIES = 2009,
141
+ CUDNN_STATUS_BAD_PARAM_ATTRIBUTE_TYPE = 2010,
142
+
143
+ CUDNN_STATUS_NOT_SUPPORTED = 3000,
144
+ CUDNN_STATUS_NOT_SUPPORTED_GRAPH_PATTERN = 3001,
145
+ CUDNN_STATUS_NOT_SUPPORTED_SHAPE = 3002,
146
+ CUDNN_STATUS_NOT_SUPPORTED_DATA_TYPE = 3003,
147
+ CUDNN_STATUS_NOT_SUPPORTED_LAYOUT = 3004,
148
+ CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDA_DRIVER = 3005,
149
+ CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDART = 3006,
150
+ CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH = 3007,
151
+ CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING = 3008,
152
+ CUDNN_STATUS_NOT_SUPPORTED_SUBLIBRARY_UNAVAILABLE = 3009,
153
+ CUDNN_STATUS_NOT_SUPPORTED_SHARED_MEMORY_INSUFFICIENT = 3010,
154
+ CUDNN_STATUS_NOT_SUPPORTED_PADDING = 3011,
155
+ CUDNN_STATUS_NOT_SUPPORTED_BAD_LAUNCH_PARAM = 3012,
156
+
157
+ CUDNN_STATUS_INTERNAL_ERROR = 4000,
158
+ CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED = 4001,
159
+ CUDNN_STATUS_INTERNAL_ERROR_UNEXPECTED_VALUE = 4002,
160
+ CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED = 4003,
161
+ CUDNN_STATUS_INTERNAL_ERROR_DEVICE_ALLOCATION_FAILED = 4004,
162
+ CUDNN_STATUS_INTERNAL_ERROR_BAD_LAUNCH_PARAM = 4005,
163
+ CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED = 4006,
164
+
165
+ CUDNN_STATUS_EXECUTION_FAILED = 5000,
166
+ CUDNN_STATUS_EXECUTION_FAILED_CUDA_DRIVER = 5001,
167
+ CUDNN_STATUS_EXECUTION_FAILED_CUBLAS = 5002,
168
+ CUDNN_STATUS_EXECUTION_FAILED_CUDART = 5003,
169
+ CUDNN_STATUS_EXECUTION_FAILED_CURAND = 5004,
170
+
171
+ CUDNN_STATUS_ALLOC_FAILED CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED,
172
+ CUDNN_STATUS_INVALID_VALUE CUDNN_DEPRECATED_ENUM = 2001 /* please transition to CUDNN_STATUS_BAD_PARAM instead */,
173
+ CUDNN_STATUS_ARCH_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH,
174
+ CUDNN_STATUS_MAPPING_ERROR CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED,
175
+ CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING CUDNN_DEPRECATED_ENUM =
176
+ CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING,
177
+ CUDNN_STATUS_VERSION_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH,
178
+ } cudnnStatus_t;
179
+
180
+ #define CUDNN_STATUS_FULL_ERROR_CODE(category, specific_err) ((cudnnStatus_t)(0 + (category) + (specific_err)))
181
+ #define CUDNN_STATUS_CATEGORY(full_error_code) ((full_error_code) / 1000 * 1000)
182
+ #define CUDNN_STATUS_SPECIFIC_ERROR(full_error_code) ((full_error_code) % 1000)
183
+
184
+ /* human-readable error messages */
185
+ const char *CUDNNWINAPI
186
+ cudnnGetErrorString(cudnnStatus_t status);
187
+
188
+ void CUDNNWINAPI
189
+ cudnnGetLastErrorString(char *message, size_t max_size);
190
+
191
+ /* Forward definition in this version only */
192
+ typedef struct cudnnRuntimeTag_t cudnnRuntimeTag_t CUDNN_DEPRECATED;
193
+
194
+ typedef enum {
195
+ CUDNN_ERRQUERY_RAWCODE = 0,
196
+ CUDNN_ERRQUERY_NONBLOCKING = 1,
197
+ CUDNN_ERRQUERY_BLOCKING = 2,
198
+ } cudnnErrQueryMode_t;
199
+
200
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
201
+ cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag);
202
+
203
+ cudnnStatus_t CUDNNWINAPI
204
+ cudnnGetProperty(libraryPropertyType type, int *value);
205
+
206
+ cudnnStatus_t CUDNNWINAPI
207
+ cudnnCreate(cudnnHandle_t *handle);
208
+ cudnnStatus_t CUDNNWINAPI
209
+ cudnnDestroy(cudnnHandle_t handle);
210
+ cudnnStatus_t CUDNNWINAPI
211
+ cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
212
+ cudnnStatus_t CUDNNWINAPI
213
+ cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId);
214
+ /*
215
+ * CUDNN data type
216
+ */
217
+ typedef enum {
218
+ CUDNN_DATA_FLOAT = 0,
219
+ CUDNN_DATA_DOUBLE = 1,
220
+ CUDNN_DATA_HALF = 2,
221
+ CUDNN_DATA_INT8 = 3,
222
+ CUDNN_DATA_INT32 = 4,
223
+ CUDNN_DATA_INT8x4 CUDNN_DEPRECATED_ENUM = 5,
224
+ CUDNN_DATA_UINT8 = 6,
225
+ CUDNN_DATA_UINT8x4 CUDNN_DEPRECATED_ENUM = 7,
226
+ CUDNN_DATA_INT8x32 CUDNN_DEPRECATED_ENUM = 8,
227
+ CUDNN_DATA_BFLOAT16 = 9,
228
+ CUDNN_DATA_INT64 = 10,
229
+ CUDNN_DATA_BOOLEAN = 11,
230
+ CUDNN_DATA_FP8_E4M3 = 12,
231
+ CUDNN_DATA_FP8_E5M2 = 13,
232
+ CUDNN_DATA_FAST_FLOAT_FOR_FP8 = 14,
233
+ } cudnnDataType_t;
234
+
235
+ /*
236
+ * CUDNN math type
237
+ */
238
+ typedef enum {
239
+ CUDNN_DEFAULT_MATH = 0,
240
+ CUDNN_TENSOR_OP_MATH = 1,
241
+ CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION = 2,
242
+ CUDNN_FMA_MATH = 3,
243
+ } cudnnMathType_t;
244
+
245
+ /*
246
+ * CUDNN propagate Nan
247
+ */
248
+ typedef enum {
249
+ CUDNN_NOT_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 0,
250
+ CUDNN_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 1,
251
+ } cudnnNanPropagation_t;
252
+
253
+ /*
254
+ * Behavior for OOB samples. OOB samples are samples where L+R > T is encountered during the gradient calculation. If
255
+ * gradMode is set to CUDNN_CTC_SKIP_OOB_GRADIENTS, then the CTC loss function does not write to the gradient buffer for
256
+ * that sample. Instead, the current values, even not finite, are retained. If gradMode is set to
257
+ * CUDNN_CTC_ZERO_OOB_GRADIENTS, then the gradient for that sample is set to zero. This guarantees a finite gradient.
258
+ */
259
+ typedef enum {
260
+ CUDNN_CTC_ZERO_OOB_GRADIENTS = 0,
261
+ CUDNN_CTC_SKIP_OOB_GRADIENTS = 1,
262
+ } cudnnCTCGradMode_t;
263
+
264
+ typedef enum {
265
+ CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */
266
+ CUDNN_TENSOR_NHWC = 1, /* feature maps interleaved ( cStride = 1 )*/
267
+ CUDNN_TENSOR_NCHW_VECT_C = 2, /* each image point is vector of element of C, vector length in data type */
268
+ } cudnnTensorFormat_t;
269
+
270
+ /*
271
+ * CUDNN ReduceTensor op type
272
+ */
273
+ typedef enum {
274
+ CUDNN_REDUCE_TENSOR_ADD = 0,
275
+ CUDNN_REDUCE_TENSOR_MUL = 1,
276
+ CUDNN_REDUCE_TENSOR_MIN = 2,
277
+ CUDNN_REDUCE_TENSOR_MAX = 3,
278
+ CUDNN_REDUCE_TENSOR_AMAX = 4,
279
+ CUDNN_REDUCE_TENSOR_AVG = 5,
280
+ CUDNN_REDUCE_TENSOR_NORM1 = 6,
281
+ CUDNN_REDUCE_TENSOR_NORM2 = 7,
282
+ CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS = 8,
283
+ } cudnnReduceTensorOp_t;
284
+
285
+ /*
286
+ * activation mode
287
+ */
288
+ typedef enum {
289
+ CUDNN_ACTIVATION_SIGMOID = 0,
290
+ CUDNN_ACTIVATION_RELU = 1,
291
+ CUDNN_ACTIVATION_TANH = 2,
292
+ CUDNN_ACTIVATION_CLIPPED_RELU = 3,
293
+ CUDNN_ACTIVATION_ELU = 4,
294
+ CUDNN_ACTIVATION_IDENTITY = 5,
295
+ CUDNN_ACTIVATION_SWISH = 6
296
+ } cudnnActivationMode_t CUDNN_DEPRECATED;
297
+
298
+ typedef enum {
299
+ CUDNN_SEV_FATAL = 0,
300
+ CUDNN_SEV_ERROR = 1,
301
+ CUDNN_SEV_WARNING = 2,
302
+ CUDNN_SEV_INFO = 3,
303
+ } cudnnSeverity_t;
304
+
305
+ /* Message masks to be used with cudnnSetCallback() */
306
+ #define CUDNN_SEV_ERROR_EN (1U << CUDNN_SEV_ERROR)
307
+ #define CUDNN_SEV_WARNING_EN (1U << CUDNN_SEV_WARNING)
308
+ #define CUDNN_SEV_INFO_EN (1U << CUDNN_SEV_INFO)
309
+
310
+ /* struct containing useful informaiton for each API call */
311
+ typedef struct cudnnDebugStruct {
312
+ unsigned cudnn_version;
313
+ cudnnStatus_t cudnnStatus;
314
+ unsigned time_sec; /* epoch time in seconds */
315
+ unsigned time_usec; /* microseconds part of epoch time */
316
+ unsigned time_delta; /* time since start in seconds */
317
+ cudnnHandle_t handle; /* cudnn handle */
318
+ cudaStream_t stream; /* cuda stream ID */
319
+ unsigned long long pid; /* process ID */
320
+ unsigned long long tid; /* thread ID */
321
+ int cudaDeviceId; /* CUDA device ID */
322
+ int reserved[15]; /* reserved for future use */
323
+ } cudnnDebug_t;
324
+
325
+ typedef void (*cudnnCallback_t)(cudnnSeverity_t sev, void *udata, const cudnnDebug_t *dbg, const char *msg);
326
+
327
+ cudnnStatus_t CUDNNWINAPI
328
+ cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr);
329
+
330
+ cudnnStatus_t CUDNNWINAPI
331
+ cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr);
332
+
333
+ /*
334
+ * \brief Cross-library version checker.
335
+ * This function is implemented differently in each sub-library. Each sublib
336
+ * checks whether its own version matches that of its dependencies.
337
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
338
+ * CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
339
+ */
340
+ cudnnStatus_t CUDNNWINAPI
341
+ cudnnGraphVersionCheck(void);
342
+
343
+ /* Maximum supported number of tensor dimensions */
344
+ #define CUDNN_DIM_MAX 8
345
+
346
+ /*
347
+ * convolution mode
348
+ */
349
+ typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t;
350
+
351
+ /*
352
+ * CUDNN Reorder
353
+ */
354
+ typedef enum {
355
+ CUDNN_DEFAULT_REORDER = 0,
356
+ CUDNN_NO_REORDER = 1,
357
+ } cudnnReorderType_t CUDNN_DEPRECATED;
358
+
359
+ typedef void *cudnnBackendDescriptor_t;
360
+
361
+ typedef struct cudnnFractionStruct {
362
+ int64_t numerator;
363
+ int64_t denominator;
364
+ } cudnnFraction_t;
365
+
366
+ typedef enum {
367
+ CUDNN_POINTWISE_ADD = 0,
368
+ CUDNN_POINTWISE_ADD_SQUARE = 5,
369
+ CUDNN_POINTWISE_DIV = 6,
370
+ CUDNN_POINTWISE_MAX = 3,
371
+ CUDNN_POINTWISE_MIN = 2,
372
+ CUDNN_POINTWISE_MOD = 7,
373
+ CUDNN_POINTWISE_MUL = 1,
374
+ CUDNN_POINTWISE_POW = 8,
375
+ CUDNN_POINTWISE_SUB = 9,
376
+
377
+ CUDNN_POINTWISE_ABS = 10,
378
+ CUDNN_POINTWISE_CEIL = 11,
379
+ CUDNN_POINTWISE_COS = 12,
380
+ CUDNN_POINTWISE_EXP = 13,
381
+ CUDNN_POINTWISE_FLOOR = 14,
382
+ CUDNN_POINTWISE_LOG = 15,
383
+ CUDNN_POINTWISE_NEG = 16,
384
+ CUDNN_POINTWISE_RSQRT = 17,
385
+ CUDNN_POINTWISE_SIN = 18,
386
+ CUDNN_POINTWISE_SQRT = 4,
387
+ CUDNN_POINTWISE_TAN = 19,
388
+ CUDNN_POINTWISE_ERF = 20,
389
+ CUDNN_POINTWISE_IDENTITY = 21,
390
+ CUDNN_POINTWISE_RECIPROCAL = 22,
391
+ CUDNN_POINTWISE_ATAN2 = 23,
392
+
393
+ CUDNN_POINTWISE_RELU_FWD = 100,
394
+ CUDNN_POINTWISE_TANH_FWD = 101,
395
+ CUDNN_POINTWISE_SIGMOID_FWD = 102,
396
+ CUDNN_POINTWISE_ELU_FWD = 103,
397
+ CUDNN_POINTWISE_GELU_FWD = 104,
398
+ CUDNN_POINTWISE_SOFTPLUS_FWD = 105,
399
+ CUDNN_POINTWISE_SWISH_FWD = 106,
400
+ CUDNN_POINTWISE_GELU_APPROX_TANH_FWD = 107,
401
+
402
+ CUDNN_POINTWISE_RELU_BWD = 200,
403
+ CUDNN_POINTWISE_TANH_BWD = 201,
404
+ CUDNN_POINTWISE_SIGMOID_BWD = 202,
405
+ CUDNN_POINTWISE_ELU_BWD = 203,
406
+ CUDNN_POINTWISE_GELU_BWD = 204,
407
+ CUDNN_POINTWISE_SOFTPLUS_BWD = 205,
408
+ CUDNN_POINTWISE_SWISH_BWD = 206,
409
+ CUDNN_POINTWISE_GELU_APPROX_TANH_BWD = 207,
410
+
411
+ CUDNN_POINTWISE_CMP_EQ = 300,
412
+ CUDNN_POINTWISE_CMP_NEQ = 301,
413
+ CUDNN_POINTWISE_CMP_GT = 302,
414
+ CUDNN_POINTWISE_CMP_GE = 303,
415
+ CUDNN_POINTWISE_CMP_LT = 304,
416
+ CUDNN_POINTWISE_CMP_LE = 305,
417
+
418
+ CUDNN_POINTWISE_LOGICAL_AND = 400,
419
+ CUDNN_POINTWISE_LOGICAL_OR = 401,
420
+ CUDNN_POINTWISE_LOGICAL_NOT = 402,
421
+
422
+ CUDNN_POINTWISE_GEN_INDEX = 501,
423
+
424
+ CUDNN_POINTWISE_BINARY_SELECT = 601,
425
+ } cudnnPointwiseMode_t;
426
+
427
+ typedef enum {
428
+ CUDNN_RESAMPLE_NEAREST = 0,
429
+ CUDNN_RESAMPLE_BILINEAR = 1,
430
+ CUDNN_RESAMPLE_AVGPOOL = 2,
431
+ CUDNN_RESAMPLE_AVGPOOL_INCLUDE_PADDING = 2,
432
+ CUDNN_RESAMPLE_AVGPOOL_EXCLUDE_PADDING = 4,
433
+ CUDNN_RESAMPLE_MAXPOOL = 3,
434
+ } cudnnResampleMode_t;
435
+
436
+ typedef enum {
437
+ CUDNN_SIGNAL_SET = 0,
438
+ CUDNN_SIGNAL_WAIT = 1,
439
+ } cudnnSignalMode_t;
440
+
441
+ typedef enum {
442
+ CUDNN_GENSTATS_SUM_SQSUM = 0,
443
+ } cudnnGenStatsMode_t;
444
+
445
+ typedef enum {
446
+ CUDNN_BN_FINALIZE_STATISTICS_TRAINING = 0,
447
+ CUDNN_BN_FINALIZE_STATISTICS_INFERENCE = 1,
448
+ } cudnnBnFinalizeStatsMode_t;
449
+
450
+ typedef enum {
451
+ CUDNN_RNG_DISTRIBUTION_BERNOULLI,
452
+ CUDNN_RNG_DISTRIBUTION_UNIFORM,
453
+ CUDNN_RNG_DISTRIBUTION_NORMAL,
454
+ } cudnnRngDistribution_t;
455
+
456
+ typedef enum {
457
+ CUDNN_ATTR_POINTWISE_MODE = 0,
458
+ CUDNN_ATTR_POINTWISE_MATH_PREC = 1,
459
+ CUDNN_ATTR_POINTWISE_NAN_PROPAGATION CUDNN_DEPRECATED_ENUM = 2,
460
+ CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3,
461
+ CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4,
462
+ CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5,
463
+ CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6,
464
+ CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7,
465
+ CUDNN_ATTR_POINTWISE_SWISH_BETA = 8,
466
+ CUDNN_ATTR_POINTWISE_AXIS = 9,
467
+
468
+ CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100,
469
+ CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101,
470
+ CUDNN_ATTR_CONVOLUTION_DILATIONS = 102,
471
+ CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103,
472
+ CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104,
473
+ CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105,
474
+ CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106,
475
+
476
+ CUDNN_ATTR_ENGINEHEUR_MODE = 200,
477
+ CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201,
478
+ CUDNN_ATTR_ENGINEHEUR_RESULTS = 202,
479
+ CUDNN_ATTR_ENGINEHEUR_SM_COUNT_TARGET = 203,
480
+
481
+ CUDNN_ATTR_ENGINECFG_ENGINE = 300,
482
+ CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301,
483
+ CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302,
484
+
485
+ CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400,
486
+ CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401,
487
+ CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402,
488
+ CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403,
489
+ CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404,
490
+ CUDNN_ATTR_EXECUTION_PLAN_JSON_REPRESENTATION = 405,
491
+
492
+ CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500,
493
+ CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501,
494
+ CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502,
495
+ CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503,
496
+
497
+ CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600,
498
+ CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601,
499
+
500
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700,
501
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701,
502
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702,
503
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703,
504
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704,
505
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705,
506
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706,
507
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707,
508
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708,
509
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709,
510
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710,
511
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711,
512
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712,
513
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713,
514
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714,
515
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715,
516
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716,
517
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717,
518
+
519
+ CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750,
520
+ CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751,
521
+ CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752,
522
+ CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753,
523
+ CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754,
524
+ CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755,
525
+ CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756,
526
+ CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757,
527
+ CUDNN_ATTR_OPERATION_POINTWISE_TDESC = 758,
528
+
529
+ CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770,
530
+ CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771,
531
+ CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772,
532
+ CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773,
533
+ CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774,
534
+
535
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780,
536
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781,
537
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782,
538
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783,
539
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784,
540
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785,
541
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786,
542
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787,
543
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788,
544
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789,
545
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790,
546
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791,
547
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792,
548
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793,
549
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794,
550
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795,
551
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796,
552
+
553
+ CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800,
554
+ CUDNN_ATTR_OPERATIONGRAPH_OPS = 801,
555
+ CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802,
556
+
557
+ CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900,
558
+ CUDNN_ATTR_TENSOR_DATA_TYPE = 901,
559
+ CUDNN_ATTR_TENSOR_DIMENSIONS = 902,
560
+ CUDNN_ATTR_TENSOR_STRIDES = 903,
561
+ CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904,
562
+ CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905,
563
+ CUDNN_ATTR_TENSOR_UNIQUE_ID = 906,
564
+ CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907,
565
+ CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908,
566
+ CUDNN_ATTR_TENSOR_REORDERING_MODE = 909,
567
+ CUDNN_ATTR_TENSOR_RAGGED_OFFSET_DESC = 913,
568
+
569
+ CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000,
570
+ CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001,
571
+ CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002,
572
+ CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003,
573
+
574
+ CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100,
575
+ CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101,
576
+
577
+ CUDNN_ATTR_KNOB_INFO_TYPE = 1200,
578
+ CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201,
579
+ CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202,
580
+ CUDNN_ATTR_KNOB_INFO_STRIDE = 1203,
581
+
582
+ CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300,
583
+ CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301,
584
+ CUDNN_ATTR_ENGINE_KNOB_INFO = 1302,
585
+ CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303,
586
+ CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304,
587
+ CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305,
588
+ CUDNN_ATTR_ENGINE_SM_COUNT_TARGET = 1306,
589
+
590
+ CUDNN_ATTR_MATMUL_COMP_TYPE = 1500,
591
+ CUDNN_ATTR_MATMUL_PADDING_VALUE = 1503,
592
+
593
+ CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520,
594
+ CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521,
595
+ CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522,
596
+ CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523,
597
+ CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT CUDNN_DEPRECATED_ENUM = 1524,
598
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_M_OVERRIDE_DESC = 1525,
599
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_N_OVERRIDE_DESC = 1526,
600
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_K_OVERRIDE_DESC = 1527,
601
+
602
+ CUDNN_ATTR_REDUCTION_OPERATOR = 1600,
603
+ CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601,
604
+
605
+ CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610,
606
+ CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611,
607
+ CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612,
608
+
609
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620,
610
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621,
611
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622,
612
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623,
613
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624,
614
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625,
615
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626,
616
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627,
617
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628,
618
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629,
619
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630,
620
+
621
+ CUDNN_ATTR_RESAMPLE_MODE = 1700,
622
+ CUDNN_ATTR_RESAMPLE_COMP_TYPE = 1701,
623
+ CUDNN_ATTR_RESAMPLE_SPATIAL_DIMS = 1702,
624
+ CUDNN_ATTR_RESAMPLE_POST_PADDINGS = 1703,
625
+ CUDNN_ATTR_RESAMPLE_PRE_PADDINGS = 1704,
626
+ CUDNN_ATTR_RESAMPLE_STRIDES = 1705,
627
+ CUDNN_ATTR_RESAMPLE_WINDOW_DIMS = 1706,
628
+ CUDNN_ATTR_RESAMPLE_NAN_PROPAGATION = 1707,
629
+ CUDNN_ATTR_RESAMPLE_PADDING_MODE = 1708,
630
+
631
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_XDESC = 1710,
632
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_YDESC = 1711,
633
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_IDXDESC = 1712,
634
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_ALPHA CUDNN_DEPRECATED_ENUM = 1713,
635
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_BETA CUDNN_DEPRECATED_ENUM = 1714,
636
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_DESC = 1716,
637
+
638
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DXDESC = 1720,
639
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DYDESC = 1721,
640
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_IDXDESC = 1722,
641
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_ALPHA CUDNN_DEPRECATED_ENUM = 1723,
642
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_BETA CUDNN_DEPRECATED_ENUM = 1724,
643
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DESC = 1725,
644
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_XDESC = 1726,
645
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_YDESC = 1727,
646
+
647
+ CUDNN_ATTR_OPERATION_CONCAT_AXIS = 1800,
648
+ CUDNN_ATTR_OPERATION_CONCAT_INPUT_DESCS = 1801,
649
+ CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX = 1802,
650
+ CUDNN_ATTR_OPERATION_CONCAT_OUTPUT_DESC = 1803,
651
+
652
+ CUDNN_ATTR_OPERATION_SIGNAL_MODE = 1900,
653
+ CUDNN_ATTR_OPERATION_SIGNAL_FLAGDESC = 1901,
654
+ CUDNN_ATTR_OPERATION_SIGNAL_VALUE = 1902,
655
+ CUDNN_ATTR_OPERATION_SIGNAL_XDESC = 1903,
656
+ CUDNN_ATTR_OPERATION_SIGNAL_YDESC = 1904,
657
+
658
+ CUDNN_ATTR_OPERATION_NORM_FWD_MODE = 2000,
659
+ CUDNN_ATTR_OPERATION_NORM_FWD_PHASE = 2001,
660
+ CUDNN_ATTR_OPERATION_NORM_FWD_XDESC = 2002,
661
+ CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC = 2003,
662
+ CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC = 2004,
663
+ CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC = 2005,
664
+ CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC = 2006,
665
+ CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC = 2007,
666
+ CUDNN_ATTR_OPERATION_NORM_FWD_EXP_AVG_FACTOR_DESC = 2008,
667
+ CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_MEAN_DESC = 2009,
668
+ CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_VAR_DESC = 2010,
669
+ CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_MEAN_DESC = 2011,
670
+ CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_VAR_DESC = 2012,
671
+ CUDNN_ATTR_OPERATION_NORM_FWD_YDESC = 2013,
672
+ CUDNN_ATTR_OPERATION_NORM_FWD_PEER_STAT_DESCS = 2014,
673
+
674
+ CUDNN_ATTR_OPERATION_NORM_BWD_MODE = 2100,
675
+ CUDNN_ATTR_OPERATION_NORM_BWD_XDESC = 2101,
676
+ CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC = 2102,
677
+ CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC = 2103,
678
+ CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC = 2104,
679
+ CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC = 2105,
680
+ CUDNN_ATTR_OPERATION_NORM_BWD_EPSILON_DESC = 2106,
681
+ CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC = 2107,
682
+ CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC = 2108,
683
+ CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC = 2109,
684
+ CUDNN_ATTR_OPERATION_NORM_BWD_PEER_STAT_DESCS = 2110,
685
+
686
+ CUDNN_ATTR_OPERATION_RESHAPE_XDESC = 2200,
687
+ CUDNN_ATTR_OPERATION_RESHAPE_YDESC = 2201,
688
+
689
+ CUDNN_ATTR_RNG_DISTRIBUTION = 2300,
690
+ CUDNN_ATTR_RNG_NORMAL_DIST_MEAN = 2301,
691
+ CUDNN_ATTR_RNG_NORMAL_DIST_STANDARD_DEVIATION = 2302,
692
+ CUDNN_ATTR_RNG_UNIFORM_DIST_MAXIMUM = 2303,
693
+ CUDNN_ATTR_RNG_UNIFORM_DIST_MINIMUM = 2304,
694
+ CUDNN_ATTR_RNG_BERNOULLI_DIST_PROBABILITY = 2305,
695
+
696
+ CUDNN_ATTR_OPERATION_RNG_YDESC = 2310,
697
+ CUDNN_ATTR_OPERATION_RNG_SEED = 2311,
698
+ CUDNN_ATTR_OPERATION_RNG_DESC = 2312,
699
+ CUDNN_ATTR_OPERATION_RNG_OFFSET_DESC = 2313,
700
+ } cudnnBackendAttributeName_t;
701
+
702
+ typedef enum {
703
+ CUDNN_TYPE_HANDLE = 0,
704
+ CUDNN_TYPE_DATA_TYPE,
705
+ CUDNN_TYPE_BOOLEAN,
706
+ CUDNN_TYPE_INT64,
707
+ CUDNN_TYPE_FLOAT,
708
+ CUDNN_TYPE_DOUBLE,
709
+ CUDNN_TYPE_VOID_PTR,
710
+ CUDNN_TYPE_CONVOLUTION_MODE,
711
+ CUDNN_TYPE_HEUR_MODE,
712
+ CUDNN_TYPE_KNOB_TYPE,
713
+ CUDNN_TYPE_NAN_PROPOGATION CUDNN_DEPRECATED_ENUM,
714
+ CUDNN_TYPE_NUMERICAL_NOTE,
715
+ CUDNN_TYPE_LAYOUT_TYPE,
716
+ CUDNN_TYPE_ATTRIB_NAME,
717
+ CUDNN_TYPE_POINTWISE_MODE,
718
+ CUDNN_TYPE_BACKEND_DESCRIPTOR,
719
+ CUDNN_TYPE_GENSTATS_MODE,
720
+ CUDNN_TYPE_BN_FINALIZE_STATS_MODE,
721
+ CUDNN_TYPE_REDUCTION_OPERATOR_TYPE,
722
+ CUDNN_TYPE_BEHAVIOR_NOTE,
723
+ CUDNN_TYPE_TENSOR_REORDERING_MODE,
724
+ CUDNN_TYPE_RESAMPLE_MODE,
725
+ CUDNN_TYPE_PADDING_MODE,
726
+ CUDNN_TYPE_INT32,
727
+ CUDNN_TYPE_CHAR,
728
+ CUDNN_TYPE_SIGNAL_MODE,
729
+ CUDNN_TYPE_FRACTION,
730
+ CUDNN_TYPE_NORM_MODE,
731
+ CUDNN_TYPE_NORM_FWD_PHASE,
732
+ CUDNN_TYPE_RNG_DISTRIBUTION
733
+ } cudnnBackendAttributeType_t;
734
+
735
+ typedef enum {
736
+ CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0,
737
+ CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR,
738
+ CUDNN_BACKEND_ENGINE_DESCRIPTOR,
739
+ CUDNN_BACKEND_ENGINECFG_DESCRIPTOR,
740
+ CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR,
741
+ CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR,
742
+ CUDNN_BACKEND_INTERMEDIATE_INFO_DESCRIPTOR,
743
+ CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR,
744
+ CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR,
745
+ CUDNN_BACKEND_LAYOUT_INFO_DESCRIPTOR,
746
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
747
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
748
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
749
+ CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR,
750
+ CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR,
751
+ CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR,
752
+ CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR,
753
+ CUDNN_BACKEND_TENSOR_DESCRIPTOR,
754
+ CUDNN_BACKEND_MATMUL_DESCRIPTOR,
755
+ CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR,
756
+ CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR,
757
+ CUDNN_BACKEND_REDUCTION_DESCRIPTOR,
758
+ CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR,
759
+ CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR,
760
+ CUDNN_BACKEND_RESAMPLE_DESCRIPTOR,
761
+ CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR,
762
+ CUDNN_BACKEND_OPERATION_RESAMPLE_BWD_DESCRIPTOR,
763
+ CUDNN_BACKEND_OPERATION_CONCAT_DESCRIPTOR,
764
+ CUDNN_BACKEND_OPERATION_SIGNAL_DESCRIPTOR,
765
+ CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR,
766
+ CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR,
767
+ CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR,
768
+ CUDNN_BACKEND_RNG_DESCRIPTOR,
769
+ CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR,
770
+ } cudnnBackendDescriptorType_t;
771
+
772
+ typedef enum {
773
+ CUDNN_NUMERICAL_NOTE_TENSOR_CORE = 0,
774
+ CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS,
775
+ CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION,
776
+ CUDNN_NUMERICAL_NOTE_FFT,
777
+ CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC,
778
+ CUDNN_NUMERICAL_NOTE_WINOGRAD,
779
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_4x4,
780
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_6x6,
781
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_13x13,
782
+ CUDNN_NUMERICAL_NOTE_STRICT_NAN_PROP,
783
+ CUDNN_NUMERICAL_NOTE_TYPE_COUNT,
784
+ } cudnnBackendNumericalNote_t;
785
+
786
+ typedef enum {
787
+ CUDNN_BEHAVIOR_NOTE_RUNTIME_COMPILATION = 0,
788
+ CUDNN_BEHAVIOR_NOTE_REQUIRES_FILTER_INT8x32_REORDER = 1,
789
+ CUDNN_BEHAVIOR_NOTE_REQUIRES_BIAS_INT8x32_REORDER = 2,
790
+ CUDNN_BEHAVIOR_NOTE_TYPE_COUNT,
791
+ } cudnnBackendBehaviorNote_t;
792
+
793
+ typedef enum {
794
+ CUDNN_KNOB_TYPE_SPLIT_K CUDNN_DEPRECATED_ENUM = 0,
795
+ CUDNN_KNOB_TYPE_SWIZZLE = 1,
796
+ CUDNN_KNOB_TYPE_TILE_SIZE = 2,
797
+ CUDNN_KNOB_TYPE_USE_TEX CUDNN_DEPRECATED_ENUM = 3,
798
+ CUDNN_KNOB_TYPE_EDGE = 4,
799
+ CUDNN_KNOB_TYPE_KBLOCK CUDNN_DEPRECATED_ENUM = 5,
800
+ CUDNN_KNOB_TYPE_LDGA CUDNN_DEPRECATED_ENUM = 6,
801
+ CUDNN_KNOB_TYPE_LDGB CUDNN_DEPRECATED_ENUM = 7,
802
+ CUDNN_KNOB_TYPE_CHUNK_K CUDNN_DEPRECATED_ENUM = 8,
803
+ CUDNN_KNOB_TYPE_SPLIT_H CUDNN_DEPRECATED_ENUM = 9,
804
+ CUDNN_KNOB_TYPE_WINO_TILE CUDNN_DEPRECATED_ENUM = 10,
805
+ CUDNN_KNOB_TYPE_MULTIPLY = 11,
806
+ CUDNN_KNOB_TYPE_SPLIT_K_BUF = 12,
807
+ CUDNN_KNOB_TYPE_TILEK = 13,
808
+ CUDNN_KNOB_TYPE_STAGES = 14,
809
+ CUDNN_KNOB_TYPE_REDUCTION_MODE = 15,
810
+ CUDNN_KNOB_TYPE_CTA_SPLIT_K_MODE CUDNN_DEPRECATED_ENUM = 16,
811
+ CUDNN_KNOB_TYPE_SPLIT_K_SLC = 17,
812
+ CUDNN_KNOB_TYPE_IDX_MODE CUDNN_DEPRECATED_ENUM = 18,
813
+ CUDNN_KNOB_TYPE_SLICED CUDNN_DEPRECATED_ENUM = 19,
814
+ CUDNN_KNOB_TYPE_SPLIT_RS CUDNN_DEPRECATED_ENUM = 20,
815
+ CUDNN_KNOB_TYPE_SINGLEBUFFER CUDNN_DEPRECATED_ENUM = 21,
816
+ CUDNN_KNOB_TYPE_LDGC CUDNN_DEPRECATED_ENUM = 22,
817
+ CUDNN_KNOB_TYPE_SPECFILT = 23,
818
+ CUDNN_KNOB_TYPE_KERNEL_CFG = 24,
819
+ CUDNN_KNOB_TYPE_WORKSPACE = 25,
820
+ CUDNN_KNOB_TYPE_TILE_CGA CUDNN_DEPRECATED_ENUM = 26,
821
+ CUDNN_KNOB_TYPE_TILE_CGA_M = 27,
822
+ CUDNN_KNOB_TYPE_TILE_CGA_N = 28,
823
+ CUDNN_KNOB_TYPE_BLOCK_SIZE = 29,
824
+ CUDNN_KNOB_TYPE_OCCUPANCY = 30,
825
+ CUDNN_KNOB_TYPE_ARRAY_SIZE_PER_THREAD = 31,
826
+ CUDNN_KNOB_TYPE_NUM_C_PER_BLOCK CUDNN_DEPRECATED_ENUM = 32,
827
+ CUDNN_KNOB_TYPE_SPLIT_COLS = 33,
828
+ CUDNN_KNOB_TYPE_TILE_ROWS = 34,
829
+ CUDNN_KNOB_TYPE_TILE_COLS = 35,
830
+ CUDNN_KNOB_TYPE_LOAD_SIZE = 36,
831
+ CUDNN_KNOB_TYPE_COUNTS,
832
+ } cudnnBackendKnobType_t;
833
+
834
+ typedef enum {
835
+ CUDNN_LAYOUT_TYPE_PREFERRED_NCHW = 0,
836
+ CUDNN_LAYOUT_TYPE_PREFERRED_NHWC = 1,
837
+ CUDNN_LAYOUT_TYPE_PREFERRED_PAD4CK = 2,
838
+ CUDNN_LAYOUT_TYPE_PREFERRED_PAD8CK = 3,
839
+ CUDNN_LAYOUT_TYPE_COUNT = 4,
840
+ } cudnnBackendLayoutType_t;
841
+
842
+ typedef enum {
843
+ CUDNN_HEUR_MODE_INSTANT = 0,
844
+ CUDNN_HEUR_MODE_B = 1,
845
+ CUDNN_HEUR_MODE_FALLBACK = 2,
846
+ CUDNN_HEUR_MODE_A = 3,
847
+ CUDNN_HEUR_MODES_COUNT = 4,
848
+ } cudnnBackendHeurMode_t;
849
+
850
+ typedef enum {
851
+ CUDNN_TENSOR_REORDERING_NONE = 0,
852
+ CUDNN_TENSOR_REORDERING_INT8x32 = 1,
853
+ CUDNN_TENSOR_REORDERING_F16x16 = 2,
854
+ } cudnnBackendTensorReordering_t;
855
+
856
+ typedef enum {
857
+ CUDNN_ZERO_PAD = 0,
858
+ CUDNN_NEG_INF_PAD = 1,
859
+ CUDNN_EDGE_VAL_PAD = 2,
860
+ } cudnnPaddingMode_t;
861
+
862
+ typedef enum {
863
+ CUDNN_LAYER_NORM = 0,
864
+ CUDNN_INSTANCE_NORM = 1,
865
+ CUDNN_BATCH_NORM = 2,
866
+ CUDNN_GROUP_NORM = 3,
867
+ CUDNN_RMS_NORM = 4,
868
+ } cudnnBackendNormMode_t;
869
+
870
+ typedef enum {
871
+ CUDNN_NORM_FWD_INFERENCE = 0,
872
+ CUDNN_NORM_FWD_TRAINING = 1,
873
+ } cudnnBackendNormFwdPhase_t;
874
+
875
+ cudnnStatus_t CUDNNWINAPI
876
+ cudnnBackendCreateDescriptor(cudnnBackendDescriptorType_t descriptorType, cudnnBackendDescriptor_t *descriptor);
877
+
878
+ cudnnStatus_t CUDNNWINAPI
879
+ cudnnBackendDestroyDescriptor(cudnnBackendDescriptor_t descriptor);
880
+
881
+ cudnnStatus_t CUDNNWINAPI
882
+ cudnnBackendInitialize(cudnnBackendDescriptor_t descriptor);
883
+
884
+ cudnnStatus_t CUDNNWINAPI
885
+ cudnnBackendFinalize(cudnnBackendDescriptor_t descriptor);
886
+
887
+ cudnnStatus_t CUDNNWINAPI
888
+ cudnnBackendSetAttribute(cudnnBackendDescriptor_t descriptor,
889
+ cudnnBackendAttributeName_t attributeName,
890
+ cudnnBackendAttributeType_t attributeType,
891
+ int64_t elementCount,
892
+ const void *arrayOfElements);
893
+
894
+ cudnnStatus_t CUDNNWINAPI
895
+ cudnnBackendGetAttribute(cudnnBackendDescriptor_t const descriptor,
896
+ cudnnBackendAttributeName_t attributeName,
897
+ cudnnBackendAttributeType_t attributeType,
898
+ int64_t requestedElementCount,
899
+ int64_t *elementCount,
900
+ void *arrayOfElements);
901
+
902
+ cudnnStatus_t CUDNNWINAPI
903
+ cudnnBackendExecute(cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t variantPack);
904
+
905
+ #if defined(__cplusplus)
906
+ }
907
+ #endif
908
+
909
+ #endif /* CUDNN_GRAPH_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops.h ADDED
@@ -0,0 +1,1316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_ops : cuDNN's basic definitions and basic operations.
52
+ */
53
+
54
+ #if !defined(CUDNN_OPS_H_)
55
+ #define CUDNN_OPS_H_
56
+
57
+ #include <stdint.h>
58
+
59
+ #include "cudnn_version.h"
60
+ #include "cudnn_graph.h"
61
+
62
+ /* These version numbers are autogenerated, do not edit manually. */
63
+ #define CUDNN_OPS_MAJOR 9
64
+ #define CUDNN_OPS_MINOR 1
65
+ #define CUDNN_OPS_PATCH 0
66
+
67
+ #if (CUDNN_OPS_MAJOR != CUDNN_MAJOR) || (CUDNN_OPS_MINOR != CUDNN_MINOR) || (CUDNN_OPS_PATCH != CUDNN_PATCHLEVEL)
68
+ #error Version mismatch in cuDNN OPS INFER!!!
69
+ #endif
70
+
71
+ #if defined(__cplusplus)
72
+ extern "C" {
73
+ #endif
74
+
75
+ /* Data structures to represent Image/Filter and the Neural Network Layer */
76
+ typedef struct cudnnTensorStruct *cudnnTensorDescriptor_t;
77
+ typedef struct cudnnPoolingStruct *cudnnPoolingDescriptor_t CUDNN_DEPRECATED;
78
+ typedef struct cudnnFilterStruct *cudnnFilterDescriptor_t CUDNN_DEPRECATED;
79
+ typedef struct cudnnLRNStruct *cudnnLRNDescriptor_t;
80
+ typedef struct cudnnActivationStruct *cudnnActivationDescriptor_t CUDNN_DEPRECATED;
81
+ typedef struct cudnnSpatialTransformerStruct *cudnnSpatialTransformerDescriptor_t;
82
+ typedef struct cudnnOpTensorStruct *cudnnOpTensorDescriptor_t CUDNN_DEPRECATED;
83
+ typedef struct cudnnReduceTensorStruct *cudnnReduceTensorDescriptor_t CUDNN_DEPRECATED;
84
+ typedef struct cudnnCTCLossStruct *cudnnCTCLossDescriptor_t;
85
+ typedef struct cudnnTensorTransformStruct *cudnnTensorTransformDescriptor_t CUDNN_DEPRECATED;
86
+ /*
87
+ * CUDNN Determinism
88
+ */
89
+ typedef enum {
90
+ CUDNN_NON_DETERMINISTIC = 0,
91
+ CUDNN_DETERMINISTIC = 1,
92
+ } cudnnDeterminism_t;
93
+
94
+ /* Create an instance of a generic Tensor descriptor */
95
+ cudnnStatus_t CUDNNWINAPI
96
+ cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc);
97
+
98
+ cudnnStatus_t CUDNNWINAPI
99
+ cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc,
100
+ cudnnTensorFormat_t format,
101
+ cudnnDataType_t dataType, /* image data type */
102
+ int n, /* number of inputs (batch size) */
103
+ int c, /* number of input feature maps */
104
+ int h, /* height of input section */
105
+ int w); /* width of input section */
106
+
107
+ cudnnStatus_t CUDNNWINAPI
108
+ cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
109
+ cudnnDataType_t dataType, /* image data type */
110
+ int n, /* number of inputs (batch size) */
111
+ int c, /* number of input feature maps */
112
+ int h, /* height of input section */
113
+ int w, /* width of input section */
114
+ int nStride,
115
+ int cStride,
116
+ int hStride,
117
+ int wStride);
118
+
119
+ cudnnStatus_t CUDNNWINAPI
120
+ cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc,
121
+ cudnnDataType_t *dataType, /* image data type */
122
+ int *n, /* number of inputs (batch size) */
123
+ int *c, /* number of input feature maps */
124
+ int *h, /* height of input section */
125
+ int *w, /* width of input section */
126
+ int *nStride,
127
+ int *cStride,
128
+ int *hStride,
129
+ int *wStride);
130
+
131
+ cudnnStatus_t CUDNNWINAPI
132
+ cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc,
133
+ cudnnDataType_t dataType,
134
+ int nbDims,
135
+ const int dimA[],
136
+ const int strideA[]);
137
+
138
+ cudnnStatus_t CUDNNWINAPI
139
+ cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
140
+ cudnnTensorFormat_t format,
141
+ cudnnDataType_t dataType,
142
+ int nbDims,
143
+ const int dimA[]);
144
+
145
+ cudnnStatus_t CUDNNWINAPI
146
+ cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc,
147
+ int nbDimsRequested,
148
+ cudnnDataType_t *dataType,
149
+ int *nbDims,
150
+ int dimA[],
151
+ int strideA[]);
152
+
153
+ cudnnStatus_t CUDNNWINAPI
154
+ cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size);
155
+
156
+ /* PixelOffset( n, c, h, w ) = n *input_stride + c * feature_stride + h * h_stride + w * w_stride
157
+
158
+ 1)Example of all images in row major order one batch of features after the other (with an optional padding on row)
159
+ input_stride : c x h x h_stride
160
+ feature_stride : h x h_stride
161
+ h_stride : >= w ( h_stride = w if no padding)
162
+ w_stride : 1
163
+
164
+
165
+ 2)Example of all images in row major with features maps interleaved
166
+ input_stride : c x h x h_stride
167
+ feature_stride : 1
168
+ h_stride : w x c
169
+ w_stride : c
170
+
171
+ 3)Example of all images in column major order one batch of features after the other (with optional padding on column)
172
+ input_stride : c x w x w_stride
173
+ feature_stride : w x w_stride
174
+ h_stride : 1
175
+ w_stride : >= h
176
+
177
+ */
178
+
179
+ /* Destroy an instance of Tensor4d descriptor */
180
+ cudnnStatus_t CUDNNWINAPI
181
+ cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc);
182
+
183
+ /* Fold/unfold transforms */
184
+ typedef enum {
185
+ CUDNN_TRANSFORM_FOLD = 0U,
186
+ CUDNN_TRANSFORM_UNFOLD = 1U,
187
+ } cudnnFoldingDirection_t;
188
+
189
+ /** Create a destination descriptor for cudnnTransformTensor */
190
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
191
+ cudnnInitTransformDest(const cudnnTensorTransformDescriptor_t transformDesc,
192
+ const cudnnTensorDescriptor_t srcDesc,
193
+ cudnnTensorDescriptor_t destDesc,
194
+ size_t *destSizeInBytes);
195
+
196
+ /** Create an empty tensor transform descriptor */
197
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
198
+ cudnnCreateTensorTransformDescriptor(cudnnTensorTransformDescriptor_t *transformDesc);
199
+
200
+ /** Initialize a previously created tensor transform descriptor. */
201
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
202
+ cudnnSetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
203
+ const uint32_t nbDims,
204
+ const cudnnTensorFormat_t destFormat,
205
+ const int32_t padBeforeA[],
206
+ const int32_t padAfterA[],
207
+ const uint32_t foldA[],
208
+ const cudnnFoldingDirection_t direction);
209
+
210
+ /**
211
+ * Retrieves the values stored in a previously initialized tensor transform
212
+ * descriptor.
213
+ */
214
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
215
+ cudnnGetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
216
+ uint32_t nbDimsRequested,
217
+ cudnnTensorFormat_t *destFormat,
218
+ int32_t padBeforeA[],
219
+ int32_t padAfterA[],
220
+ uint32_t foldA[],
221
+ cudnnFoldingDirection_t *direction);
222
+
223
+ /**
224
+ * Destroys a previously created tensor transform descriptor.
225
+ */
226
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
227
+ cudnnDestroyTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc);
228
+
229
+ /* Tensor layout conversion helper (y = alpha * x + beta * y) */
230
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
231
+ cudnnTransformTensor(cudnnHandle_t handle,
232
+ const void *alpha,
233
+ const cudnnTensorDescriptor_t xDesc,
234
+ const void *x,
235
+ const void *beta,
236
+ const cudnnTensorDescriptor_t yDesc,
237
+ void *y);
238
+
239
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
240
+ cudnnTransformTensorEx(cudnnHandle_t handle,
241
+ const cudnnTensorTransformDescriptor_t transDesc,
242
+ const void *alpha,
243
+ const cudnnTensorDescriptor_t srcDesc,
244
+ const void *srcData,
245
+ const void *beta,
246
+ const cudnnTensorDescriptor_t destDesc,
247
+ void *destData);
248
+
249
+ /* Tensor Bias addition : C = alpha * A + beta * C */
250
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
251
+ cudnnAddTensor(cudnnHandle_t handle,
252
+ const void *alpha,
253
+ const cudnnTensorDescriptor_t aDesc,
254
+ const void *A,
255
+ const void *beta,
256
+ const cudnnTensorDescriptor_t cDesc,
257
+ void *C);
258
+
259
+ /*
260
+ * CUDNN OpTensor op type
261
+ */
262
+ typedef enum {
263
+ CUDNN_OP_TENSOR_ADD = 0,
264
+ CUDNN_OP_TENSOR_MUL = 1,
265
+ CUDNN_OP_TENSOR_MIN = 2,
266
+ CUDNN_OP_TENSOR_MAX = 3,
267
+ CUDNN_OP_TENSOR_SQRT = 4,
268
+ CUDNN_OP_TENSOR_NOT = 5,
269
+ } cudnnOpTensorOp_t;
270
+
271
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
272
+ cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc);
273
+
274
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
275
+ cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc,
276
+ cudnnOpTensorOp_t opTensorOp,
277
+ cudnnDataType_t opTensorCompType,
278
+ cudnnNanPropagation_t opTensorNanOpt);
279
+
280
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
281
+ cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc,
282
+ cudnnOpTensorOp_t *opTensorOp,
283
+ cudnnDataType_t *opTensorCompType,
284
+ cudnnNanPropagation_t *opTensorNanOpt);
285
+
286
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
287
+ cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc);
288
+
289
+ /* Tensor operation : C = op( alpha1 * A, alpha2 * B ) + beta * C */
290
+ /* B tensor is ignored for CUDNN_OP_TENSOR_SQRT, CUDNN_OP_TENSOR_NOT. */
291
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
292
+ cudnnOpTensor(cudnnHandle_t handle,
293
+ const cudnnOpTensorDescriptor_t opTensorDesc,
294
+ const void *alpha1,
295
+ const cudnnTensorDescriptor_t aDesc,
296
+ const void *A,
297
+ const void *alpha2,
298
+ const cudnnTensorDescriptor_t bDesc,
299
+ const void *B,
300
+ const void *beta,
301
+ const cudnnTensorDescriptor_t cDesc,
302
+ void *C);
303
+
304
+ /*
305
+ * CUDNN ReduceTensor indices type
306
+ */
307
+ typedef enum {
308
+ CUDNN_REDUCE_TENSOR_NO_INDICES = 0,
309
+ CUDNN_REDUCE_TENSOR_FLATTENED_INDICES = 1,
310
+ } cudnnReduceTensorIndices_t CUDNN_DEPRECATED;
311
+
312
+ /*
313
+ * CUDNN tensor indices type size (all unsigned)
314
+ * Currently not supported, default is 32 bit unsigned.
315
+ */
316
+ typedef enum {
317
+ CUDNN_32BIT_INDICES = 0,
318
+ CUDNN_64BIT_INDICES = 1,
319
+ CUDNN_16BIT_INDICES = 2,
320
+ CUDNN_8BIT_INDICES = 3,
321
+ } cudnnIndicesType_t CUDNN_DEPRECATED;
322
+
323
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
324
+ cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc);
325
+
326
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
327
+ cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc,
328
+ cudnnReduceTensorOp_t reduceTensorOp,
329
+ cudnnDataType_t reduceTensorCompType,
330
+ cudnnNanPropagation_t reduceTensorNanOpt,
331
+ cudnnReduceTensorIndices_t reduceTensorIndices,
332
+ cudnnIndicesType_t reduceTensorIndicesType);
333
+
334
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
335
+ cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc,
336
+ cudnnReduceTensorOp_t *reduceTensorOp,
337
+ cudnnDataType_t *reduceTensorCompType,
338
+ cudnnNanPropagation_t *reduceTensorNanOpt,
339
+ cudnnReduceTensorIndices_t *reduceTensorIndices,
340
+ cudnnIndicesType_t *reduceTensorIndicesType);
341
+
342
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
343
+ cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc);
344
+
345
+ /* Helper function to return the minimum size of the index space to be passed to the reduction given the input and
346
+ * output tensors */
347
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
348
+ cudnnGetReductionIndicesSize(cudnnHandle_t handle,
349
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
350
+ const cudnnTensorDescriptor_t aDesc,
351
+ const cudnnTensorDescriptor_t cDesc,
352
+ size_t *sizeInBytes);
353
+
354
+ /* Helper function to return the minimum size of the workspace to be passed to the reduction given the input and output
355
+ * tensors */
356
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
357
+ cudnnGetReductionWorkspaceSize(cudnnHandle_t handle,
358
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
359
+ const cudnnTensorDescriptor_t aDesc,
360
+ const cudnnTensorDescriptor_t cDesc,
361
+ size_t *sizeInBytes);
362
+
363
+ /* Tensor operation : C = reduce op( alpha * A ) + beta * C */
364
+ /* The NaN propagation enum applies to only the min and max reduce ops; the other reduce ops propagate NaN as usual. */
365
+ /* The indices space is ignored for reduce ops other than min or max. */
366
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
367
+ cudnnReduceTensor(cudnnHandle_t handle,
368
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
369
+ void *indices,
370
+ size_t indicesSizeInBytes,
371
+ void *workspace,
372
+ size_t workspaceSizeInBytes,
373
+ const void *alpha,
374
+ const cudnnTensorDescriptor_t aDesc,
375
+ const void *A,
376
+ const void *beta,
377
+ const cudnnTensorDescriptor_t cDesc,
378
+ void *C);
379
+
380
+ /* Set all values of a tensor to a given value : y[i] = value[0] */
381
+ cudnnStatus_t CUDNNWINAPI
382
+ cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr);
383
+
384
+ /* Scale all values of a tensor by a given factor : y[i] = alpha * y[i] */
385
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
386
+ cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha);
387
+
388
+ /* Create an instance of FilterStruct */
389
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
390
+ cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc);
391
+
392
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
393
+ cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc,
394
+ cudnnDataType_t dataType, /* image data type */
395
+ cudnnTensorFormat_t format,
396
+ int k, /* number of output feature maps */
397
+ int c, /* number of input feature maps */
398
+ int h, /* height of each input filter */
399
+ int w); /* width of each input filter */
400
+
401
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
402
+ cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc,
403
+ cudnnDataType_t *dataType, /* image data type */
404
+ cudnnTensorFormat_t *format,
405
+ int *k, /* number of output feature maps */
406
+ int *c, /* number of input feature maps */
407
+ int *h, /* height of each input filter */
408
+ int *w); /* width of each input filter */
409
+
410
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
411
+ cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc,
412
+ cudnnDataType_t dataType, /* image data type */
413
+ cudnnTensorFormat_t format,
414
+ int nbDims,
415
+ const int filterDimA[]);
416
+
417
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
418
+ cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc,
419
+ int nbDimsRequested,
420
+ cudnnDataType_t *dataType, /* image data type */
421
+ cudnnTensorFormat_t *format,
422
+ int *nbDims,
423
+ int filterDimA[]);
424
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
425
+ cudnnGetFilterSizeInBytes(const cudnnFilterDescriptor_t filterDesc, size_t *size);
426
+
427
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
428
+ cudnnTransformFilter(cudnnHandle_t handle,
429
+ const cudnnTensorTransformDescriptor_t transDesc,
430
+ const void *alpha,
431
+ const cudnnFilterDescriptor_t srcDesc,
432
+ const void *srcData,
433
+ const void *beta,
434
+ const cudnnFilterDescriptor_t destDesc,
435
+ void *destData);
436
+
437
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
438
+ cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc);
439
+
440
+ /*
441
+ * softmax algorithm
442
+ */
443
+ typedef enum {
444
+ CUDNN_SOFTMAX_FAST = 0, /* straightforward implementation */
445
+ CUDNN_SOFTMAX_ACCURATE = 1, /* subtract max from every point to avoid overflow */
446
+ CUDNN_SOFTMAX_LOG = 2
447
+ } cudnnSoftmaxAlgorithm_t;
448
+
449
+ typedef enum {
450
+ CUDNN_SOFTMAX_MODE_INSTANCE = 0, /* compute the softmax over all C, H, W for each N */
451
+ CUDNN_SOFTMAX_MODE_CHANNEL = 1 /* compute the softmax over all C for each H, W, N */
452
+ } cudnnSoftmaxMode_t;
453
+
454
+ /* Softmax functions: All of the form "output = alpha * Op(inputs) + beta * output" */
455
+
456
+ /* Function to perform forward softmax */
457
+ cudnnStatus_t CUDNNWINAPI
458
+ cudnnSoftmaxForward(cudnnHandle_t handle,
459
+ cudnnSoftmaxAlgorithm_t algo,
460
+ cudnnSoftmaxMode_t mode,
461
+ const void *alpha,
462
+ const cudnnTensorDescriptor_t xDesc,
463
+ const void *x,
464
+ const void *beta,
465
+ const cudnnTensorDescriptor_t yDesc,
466
+ void *y);
467
+
468
+ /*
469
+ * pooling mode
470
+ */
471
+ typedef enum {
472
+ CUDNN_POOLING_MAX = 0,
473
+ CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values */
474
+ CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values */
475
+ CUDNN_POOLING_MAX_DETERMINISTIC = 3
476
+ } cudnnPoolingMode_t CUDNN_DEPRECATED;
477
+
478
+ /* Create an instance of pooling descriptor */
479
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
480
+ cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
481
+
482
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
483
+ cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc,
484
+ cudnnPoolingMode_t mode,
485
+ cudnnNanPropagation_t maxpoolingNanOpt,
486
+ int windowHeight,
487
+ int windowWidth,
488
+ int verticalPadding,
489
+ int horizontalPadding,
490
+ int verticalStride,
491
+ int horizontalStride);
492
+
493
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
494
+ cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
495
+ cudnnPoolingMode_t *mode,
496
+ cudnnNanPropagation_t *maxpoolingNanOpt,
497
+ int *windowHeight,
498
+ int *windowWidth,
499
+ int *verticalPadding,
500
+ int *horizontalPadding,
501
+ int *verticalStride,
502
+ int *horizontalStride);
503
+
504
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
505
+ cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc,
506
+ const cudnnPoolingMode_t mode,
507
+ const cudnnNanPropagation_t maxpoolingNanOpt,
508
+ int nbDims,
509
+ const int windowDimA[],
510
+ const int paddingA[],
511
+ const int strideA[]);
512
+
513
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
514
+ cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
515
+ int nbDimsRequested,
516
+ cudnnPoolingMode_t *mode,
517
+ cudnnNanPropagation_t *maxpoolingNanOpt,
518
+ int *nbDims,
519
+ int windowDimA[],
520
+ int paddingA[],
521
+ int strideA[]);
522
+
523
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
524
+ cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
525
+ const cudnnTensorDescriptor_t inputTensorDesc,
526
+ int nbDims,
527
+ int outputTensorDimA[]);
528
+
529
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
530
+ cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
531
+ const cudnnTensorDescriptor_t inputTensorDesc,
532
+ int *n,
533
+ int *c,
534
+ int *h,
535
+ int *w);
536
+
537
+ /* Destroy an instance of pooling descriptor */
538
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
539
+ cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc);
540
+
541
+ /* Pooling functions: All of the form "output = alpha * Op(inputs) + beta * output" */
542
+
543
+ /* Function to perform forward pooling */
544
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
545
+ cudnnPoolingForward(cudnnHandle_t handle,
546
+ const cudnnPoolingDescriptor_t poolingDesc,
547
+ const void *alpha,
548
+ const cudnnTensorDescriptor_t xDesc,
549
+ const void *x,
550
+ const void *beta,
551
+ const cudnnTensorDescriptor_t yDesc,
552
+ void *y);
553
+
554
+ /* Activation functions: All of the form "output = alpha * Op(inputs) + beta * output" */
555
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
556
+ cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc);
557
+
558
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
559
+ cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc,
560
+ cudnnActivationMode_t mode,
561
+ cudnnNanPropagation_t reluNanOpt,
562
+ double coef); /* ceiling for clipped RELU, alpha for ELU */
563
+
564
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
565
+ cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc,
566
+ cudnnActivationMode_t *mode,
567
+ cudnnNanPropagation_t *reluNanOpt,
568
+ double *coef); /* ceiling for clipped RELU, alpha for ELU */
569
+
570
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
571
+ cudnnSetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double swish_beta);
572
+
573
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
574
+ cudnnGetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double *swish_beta);
575
+
576
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
577
+ cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc);
578
+
579
+ /* Function to perform forward activation */
580
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
581
+ cudnnActivationForward(cudnnHandle_t handle,
582
+ cudnnActivationDescriptor_t activationDesc,
583
+ const void *alpha,
584
+ const cudnnTensorDescriptor_t xDesc,
585
+ const void *x,
586
+ const void *beta,
587
+ const cudnnTensorDescriptor_t yDesc,
588
+ void *y);
589
+
590
+ /*
591
+ * Create an instance of LRN (Local Response Normalization) descriptor
592
+ * Uses lrnN=5, lrnAlpha=1e-4, lrnBeta=0.75, lrnK=2.0 as defaults from Krizhevsky'12 ImageNet paper
593
+ */
594
+ cudnnStatus_t CUDNNWINAPI
595
+ cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc);
596
+
597
+ #define CUDNN_LRN_MIN_N 1 /* minimum allowed lrnN */
598
+ #define CUDNN_LRN_MAX_N 16 /* maximum allowed lrnN */
599
+ #define CUDNN_LRN_MIN_K 1e-5 /* minimum allowed lrnK */
600
+ #define CUDNN_LRN_MIN_BETA 0.01 /* minimum allowed lrnBeta */
601
+
602
+ /* LRN layer mode */
603
+ typedef enum {
604
+ CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0, /* Normalize across tensor's dimA[1] dimension */
605
+ } cudnnLRNMode_t;
606
+
607
+ /*
608
+ * Uses a window [center-lookBehind, center+lookAhead], where
609
+ * lookBehind = floor( (lrnN-1)/2 ), lookAhead = lrnN-lookBehind-1.
610
+ * Values of double parameters cast to tensor data type.
611
+ */
612
+ cudnnStatus_t CUDNNWINAPI
613
+ cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK);
614
+ /*
615
+ * Retrieve the settings currently stored in an LRN layer descriptor
616
+ * Any of the provided pointers can be NULL (no corresponding value will be returned)
617
+ */
618
+ cudnnStatus_t CUDNNWINAPI
619
+ cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK);
620
+
621
+ /* Destroy an instance of LRN descriptor */
622
+ cudnnStatus_t CUDNNWINAPI
623
+ cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc);
624
+
625
+ /* LRN functions: output = alpha * normalize(x) + beta * old_y */
626
+
627
+ /* LRN cross-channel forward computation. Double parameters cast to tensor data type */
628
+ cudnnStatus_t CUDNNWINAPI
629
+ cudnnLRNCrossChannelForward(cudnnHandle_t handle,
630
+ cudnnLRNDescriptor_t normDesc,
631
+ cudnnLRNMode_t lrnMode,
632
+ const void *alpha,
633
+ const cudnnTensorDescriptor_t xDesc,
634
+ const void *x,
635
+ const void *beta,
636
+ const cudnnTensorDescriptor_t yDesc,
637
+ void *y);
638
+
639
+ typedef enum {
640
+ CUDNN_DIVNORM_PRECOMPUTED_MEANS = 0,
641
+ } cudnnDivNormMode_t;
642
+
643
+ /* LCN/divisive normalization functions: y = alpha * normalize(x) + beta * y */
644
+ cudnnStatus_t CUDNNWINAPI
645
+ cudnnDivisiveNormalizationForward(cudnnHandle_t handle,
646
+ cudnnLRNDescriptor_t normDesc,
647
+ cudnnDivNormMode_t mode,
648
+ const void *alpha,
649
+ const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */
650
+ const void *x,
651
+ const void *means, /* if NULL, means are assumed to be zero */
652
+ void *temp,
653
+ void *temp2,
654
+ const void *beta,
655
+ const cudnnTensorDescriptor_t yDesc,
656
+ void *y);
657
+
658
+ typedef enum {
659
+ /* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
660
+ CUDNN_BATCHNORM_PER_ACTIVATION = 0,
661
+
662
+ /* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
663
+ CUDNN_BATCHNORM_SPATIAL = 1,
664
+
665
+ /*
666
+ * bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors).
667
+ * May be faster than CUDNN_BATCHNORM_SPATIAL but imposes some limits on the range of values
668
+ */
669
+ CUDNN_BATCHNORM_SPATIAL_PERSISTENT = 2,
670
+ } cudnnBatchNormMode_t CUDNN_DEPRECATED;
671
+
672
+ #define CUDNN_BN_MIN_EPSILON 0.0 /* Minimum epsilon allowed to be used in the Batch Normalization formula */
673
+
674
+ /*
675
+ * Derives a tensor descriptor from layer data descriptor for BatchNormalization
676
+ * scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
677
+ * bnScaleBiasMeanVarDesc and bnScaleBiasDiffDesc in Batch Normalization forward and backward functions.
678
+ */
679
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
680
+ cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc,
681
+ const cudnnTensorDescriptor_t xDesc,
682
+ cudnnBatchNormMode_t mode);
683
+
684
+ typedef enum {
685
+ CUDNN_BATCHNORM_OPS_BN = 0, /* do batch normalization only */
686
+ CUDNN_BATCHNORM_OPS_BN_ACTIVATION = 1, /* do batchNorm, then activation */
687
+ CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION = 2, /* do batchNorm, then elemWiseAdd, then activation */
688
+ } cudnnBatchNormOps_t CUDNN_DEPRECATED;
689
+
690
+ /*
691
+ * Performs Batch Normalization during Inference:
692
+ * y[i] = bnScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + bnBias[k]
693
+ * with bnScale, bnBias, runningMean, runningInvVariance tensors indexed
694
+ * according to spatial or per-activation mode. Refer to cudnnBatchNormalizationForwardTraining
695
+ * above for notes on function arguments.
696
+ */
697
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
698
+ cudnnBatchNormalizationForwardInference(cudnnHandle_t handle,
699
+ cudnnBatchNormMode_t mode,
700
+ const void *alpha, /* alpha[0] = result blend factor */
701
+ const void *beta, /* beta[0] = dest layer blend factor */
702
+ const cudnnTensorDescriptor_t xDesc,
703
+ const void *x, /* NxCxHxW */
704
+ const cudnnTensorDescriptor_t yDesc,
705
+ void *y, /* NxCxHxW */
706
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
707
+ const void *bnScale,
708
+ const void *bnBias,
709
+ const void *estimatedMean,
710
+ const void *estimatedVariance,
711
+ double epsilon);
712
+
713
+ typedef enum {
714
+ /* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
715
+ CUDNN_NORM_PER_ACTIVATION = 0,
716
+
717
+ /* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
718
+ CUDNN_NORM_PER_CHANNEL = 1,
719
+ } cudnnNormMode_t CUDNN_DEPRECATED;
720
+
721
+ typedef enum { CUDNN_NORM_ALGO_STANDARD = 0, CUDNN_NORM_ALGO_PERSIST = 1 } cudnnNormAlgo_t CUDNN_DEPRECATED;
722
+
723
+ /*
724
+ * Derives a tensor descriptor from layer data descriptor for Normalization
725
+ * scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
726
+ * normScaleBiasMeanVarDesc and normScaleBiasDiffDesc in Normalization forward and backward functions.
727
+ */
728
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
729
+ cudnnDeriveNormTensorDescriptor(cudnnTensorDescriptor_t derivedNormScaleBiasDesc,
730
+ cudnnTensorDescriptor_t derivedNormMeanVarDesc,
731
+ const cudnnTensorDescriptor_t xDesc,
732
+ cudnnNormMode_t mode,
733
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
734
+
735
+ typedef enum {
736
+ CUDNN_NORM_OPS_NORM = 0, /* do normalization only */
737
+ CUDNN_NORM_OPS_NORM_ACTIVATION = 1, /* do Norm, then activation */
738
+ CUDNN_NORM_OPS_NORM_ADD_ACTIVATION = 2, /* do Norm, then elemWiseAdd, then activation */
739
+ } cudnnNormOps_t CUDNN_DEPRECATED;
740
+
741
+ /*
742
+ * Performs Normalization during Inference:
743
+ * y[i] = normScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + normBias[k]
744
+ * with normScale, normBias, runningMean, runningInvVariance tensors indexed
745
+ * according to per-channel or per-activation mode. Refer to cudnnNormalizationForwardTraining
746
+ * above for notes on function arguments.
747
+ */
748
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
749
+ cudnnNormalizationForwardInference(cudnnHandle_t handle,
750
+ cudnnNormMode_t mode,
751
+ cudnnNormOps_t normOps,
752
+ cudnnNormAlgo_t algo,
753
+ const void *alpha, /* alpha[0] = result blend factor */
754
+ const void *beta, /* beta[0] = dest layer blend factor */
755
+ const cudnnTensorDescriptor_t xDesc,
756
+ const void *x, /* NxCxHxW */
757
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
758
+ const void *normScale,
759
+ const void *normBias,
760
+ const cudnnTensorDescriptor_t normMeanVarDesc,
761
+ const void *estimatedMean,
762
+ const void *estimatedVariance,
763
+ const cudnnTensorDescriptor_t zDesc,
764
+ const void *z,
765
+ cudnnActivationDescriptor_t activationDesc,
766
+ const cudnnTensorDescriptor_t yDesc,
767
+ void *y, /* NxCxHxW */
768
+ double epsilon,
769
+ int groupCnt); /* Place hold for future work*/
770
+
771
+ /* APIs for spatial transformer network*/
772
+ typedef enum {
773
+ CUDNN_SAMPLER_BILINEAR = 0,
774
+ } cudnnSamplerType_t;
775
+
776
+ cudnnStatus_t CUDNNWINAPI
777
+ cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc);
778
+
779
+ cudnnStatus_t CUDNNWINAPI
780
+ cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc,
781
+ cudnnSamplerType_t samplerType,
782
+ cudnnDataType_t dataType,
783
+ const int nbDims,
784
+ const int dimA[]);
785
+
786
+ cudnnStatus_t CUDNNWINAPI
787
+ cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc);
788
+
789
+ cudnnStatus_t CUDNNWINAPI
790
+ cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle,
791
+ const cudnnSpatialTransformerDescriptor_t stDesc,
792
+ const void *theta,
793
+ void *grid);
794
+
795
+ cudnnStatus_t CUDNNWINAPI
796
+ cudnnSpatialTfSamplerForward(cudnnHandle_t handle,
797
+ cudnnSpatialTransformerDescriptor_t stDesc,
798
+ const void *alpha,
799
+ const cudnnTensorDescriptor_t xDesc,
800
+ const void *x,
801
+ const void *grid,
802
+ const void *beta,
803
+ cudnnTensorDescriptor_t yDesc,
804
+ void *y);
805
+
806
+ typedef struct cudnnDropoutStruct *cudnnDropoutDescriptor_t;
807
+
808
+ cudnnStatus_t CUDNNWINAPI
809
+ cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc);
810
+
811
+ cudnnStatus_t CUDNNWINAPI
812
+ cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc);
813
+
814
+ /*helper function to determine size of the states to be passed to cudnnSetDropoutDescriptor */
815
+ cudnnStatus_t CUDNNWINAPI
816
+ cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes);
817
+
818
+ /*helper function to determine size of the reserve space to be passed to dropout forward/backward calls */
819
+ cudnnStatus_t CUDNNWINAPI
820
+ cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes);
821
+
822
+ cudnnStatus_t CUDNNWINAPI
823
+ cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
824
+ cudnnHandle_t handle,
825
+ float dropout,
826
+ void *states,
827
+ size_t stateSizeInBytes,
828
+ unsigned long long seed);
829
+
830
+ /* Restores the dropout descriptor to a previously saved-off state */
831
+ cudnnStatus_t CUDNNWINAPI
832
+ cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
833
+ cudnnHandle_t handle,
834
+ float dropout,
835
+ void *states,
836
+ size_t stateSizeInBytes,
837
+ unsigned long long seed);
838
+
839
+ cudnnStatus_t CUDNNWINAPI
840
+ cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
841
+ cudnnHandle_t handle,
842
+ float *dropout,
843
+ void **states,
844
+ unsigned long long *seed);
845
+
846
+ cudnnStatus_t CUDNNWINAPI
847
+ cudnnDropoutForward(cudnnHandle_t handle,
848
+ const cudnnDropoutDescriptor_t dropoutDesc,
849
+ const cudnnTensorDescriptor_t xdesc,
850
+ const void *x,
851
+ const cudnnTensorDescriptor_t ydesc,
852
+ void *y,
853
+ void *reserveSpace,
854
+ size_t reserveSpaceSizeInBytes);
855
+
856
+ /* TODO: move these enums out to the appropriate submodule */
857
+ typedef enum {
858
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0,
859
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
860
+ CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2,
861
+ CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3,
862
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4,
863
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5,
864
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6,
865
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7,
866
+ CUDNN_CONVOLUTION_FWD_ALGO_COUNT = 8
867
+ } cudnnConvolutionFwdAlgo_t;
868
+
869
+ typedef enum {
870
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic */
871
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
872
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
873
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic */
874
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, /* not implemented */
875
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5,
876
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING = 6,
877
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7
878
+ } cudnnConvolutionBwdFilterAlgo_t;
879
+
880
+ typedef enum {
881
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic */
882
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
883
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
884
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
885
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4,
886
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5,
887
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6
888
+ } cudnnConvolutionBwdDataAlgo_t;
889
+
890
+ typedef enum { CUDNN_CTC_LOSS_ALGO_DETERMINISTIC = 0, CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC = 1 } cudnnCTCLossAlgo_t;
891
+
892
+ /*
893
+ * \brief Cross-library version checker.
894
+ * This function is implemented differently in each sub-library. Each sublib
895
+ * checks whether its own version matches that of its dependencies.
896
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
897
+ * CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
898
+ */
899
+ cudnnStatus_t CUDNNWINAPI
900
+ cudnnOpsVersionCheck(void);
901
+
902
+ /* Function to perform backward softmax */
903
+ cudnnStatus_t CUDNNWINAPI
904
+ cudnnSoftmaxBackward(cudnnHandle_t handle,
905
+ cudnnSoftmaxAlgorithm_t algo,
906
+ cudnnSoftmaxMode_t mode,
907
+ const void *alpha,
908
+ const cudnnTensorDescriptor_t yDesc,
909
+ const void *y,
910
+ const cudnnTensorDescriptor_t dyDesc,
911
+ const void *dy,
912
+ const void *beta,
913
+ const cudnnTensorDescriptor_t dxDesc,
914
+ void *dx);
915
+
916
+ /* Function to perform backward pooling */
917
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
918
+ cudnnPoolingBackward(cudnnHandle_t handle,
919
+ const cudnnPoolingDescriptor_t poolingDesc,
920
+ const void *alpha,
921
+ const cudnnTensorDescriptor_t yDesc,
922
+ const void *y,
923
+ const cudnnTensorDescriptor_t dyDesc,
924
+ const void *dy,
925
+ const cudnnTensorDescriptor_t xDesc,
926
+ const void *x,
927
+ const void *beta,
928
+ const cudnnTensorDescriptor_t dxDesc,
929
+ void *dx);
930
+
931
+ /* Function to perform backward activation */
932
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
933
+ cudnnActivationBackward(cudnnHandle_t handle,
934
+ cudnnActivationDescriptor_t activationDesc,
935
+ const void *alpha,
936
+ const cudnnTensorDescriptor_t yDesc,
937
+ const void *y,
938
+ const cudnnTensorDescriptor_t dyDesc,
939
+ const void *dy,
940
+ const cudnnTensorDescriptor_t xDesc,
941
+ const void *x,
942
+ const void *beta,
943
+ const cudnnTensorDescriptor_t dxDesc,
944
+ void *dx);
945
+
946
+ /* LRN cross-channel backward computation. Double parameters cast to tensor data type */
947
+ cudnnStatus_t CUDNNWINAPI
948
+ cudnnLRNCrossChannelBackward(cudnnHandle_t handle,
949
+ cudnnLRNDescriptor_t normDesc,
950
+ cudnnLRNMode_t lrnMode,
951
+ const void *alpha,
952
+ const cudnnTensorDescriptor_t yDesc,
953
+ const void *y,
954
+ const cudnnTensorDescriptor_t dyDesc,
955
+ const void *dy,
956
+ const cudnnTensorDescriptor_t xDesc,
957
+ const void *x,
958
+ const void *beta,
959
+ const cudnnTensorDescriptor_t dxDesc,
960
+ void *dx);
961
+
962
+ cudnnStatus_t CUDNNWINAPI
963
+ cudnnDivisiveNormalizationBackward(cudnnHandle_t handle,
964
+ cudnnLRNDescriptor_t normDesc,
965
+ cudnnDivNormMode_t mode,
966
+ const void *alpha,
967
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */
968
+ const void *x,
969
+ const void *means, /* if NULL, means are assumed to be zero */
970
+ const void *dy,
971
+ void *temp,
972
+ void *temp2,
973
+ const void *beta,
974
+ const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */
975
+ void *dx, /* output x differential */
976
+ void *dMeans); /* output means differential, can be NULL */
977
+
978
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
979
+ cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle,
980
+ cudnnBatchNormMode_t mode,
981
+ cudnnBatchNormOps_t bnOps,
982
+ const cudnnTensorDescriptor_t xDesc,
983
+ const cudnnTensorDescriptor_t zDesc,
984
+ const cudnnTensorDescriptor_t yDesc,
985
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
986
+ const cudnnActivationDescriptor_t activationDesc,
987
+ size_t *sizeInBytes);
988
+
989
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
990
+ cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle,
991
+ cudnnBatchNormMode_t mode,
992
+ cudnnBatchNormOps_t bnOps,
993
+ const cudnnTensorDescriptor_t xDesc,
994
+ const cudnnTensorDescriptor_t yDesc,
995
+ const cudnnTensorDescriptor_t dyDesc,
996
+ const cudnnTensorDescriptor_t dzDesc,
997
+ const cudnnTensorDescriptor_t dxDesc,
998
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
999
+ const cudnnActivationDescriptor_t activationDesc,
1000
+ size_t *sizeInBytes);
1001
+
1002
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1003
+ cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle,
1004
+ cudnnBatchNormMode_t mode,
1005
+ cudnnBatchNormOps_t bnOps,
1006
+ const cudnnActivationDescriptor_t activationDesc,
1007
+ const cudnnTensorDescriptor_t xDesc,
1008
+ size_t *sizeInBytes);
1009
+
1010
+ /* Computes y = BN(x). Also accumulates moving averages of mean and inverse variances */
1011
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1012
+ cudnnBatchNormalizationForwardTraining(
1013
+ cudnnHandle_t handle,
1014
+ cudnnBatchNormMode_t mode,
1015
+
1016
+ const void *alpha, /* alpha[0] = result blend factor */
1017
+ const void *beta, /* beta[0] = dest layer blend factor */
1018
+
1019
+ const cudnnTensorDescriptor_t xDesc,
1020
+ const void *x, /* NxCxHxW */
1021
+ const cudnnTensorDescriptor_t yDesc,
1022
+ void *y, /* NxCxHxW */
1023
+
1024
+ /* Shared desc for the next 6 tensors in the argument list.
1025
+ Data type to be set as follows:
1026
+ type = (typeOf(x) == double) ? double : float
1027
+ Dimensions for this descriptor depend on normalization mode
1028
+ - Spatial Normalization : tensors are expected to have dims 1xCx1x1
1029
+ (normalization is performed across NxHxW)
1030
+ - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW
1031
+ (normalization is performed across N) */
1032
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
1033
+
1034
+ /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */
1035
+ const void *bnScale,
1036
+ const void *bnBias,
1037
+
1038
+ /* MUST use factor=1 in the very first call of a complete training cycle.
1039
+ Use a factor=1/(1+n) at N-th call to the function to get
1040
+ Cumulative Moving Average (CMA) behavior
1041
+ CMA[n] = (x[1]+...+x[n])/n
1042
+ Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) =
1043
+ ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) =
1044
+ CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
1045
+ double exponentialAverageFactor,
1046
+
1047
+ /* Used in Training phase only.
1048
+ runningMean = newMean*factor + runningMean*(1-factor) */
1049
+ void *resultRunningMean,
1050
+ /* Output in training mode, input in inference. Is the moving average
1051
+ of variance[x] (factor is applied in the same way as for runningMean) */
1052
+ void *resultRunningVariance,
1053
+
1054
+ /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
1055
+ double epsilon,
1056
+
1057
+ /* Optionally save intermediate results from the forward pass here
1058
+ - can be reused to speed up backward pass. NULL if unused */
1059
+ void *resultSaveMean,
1060
+ void *resultSaveInvVariance);
1061
+
1062
+ /* Computes y = relu(BN(x) + z). Also accumulates moving averages of mean and inverse variances */
1063
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1064
+ cudnnBatchNormalizationForwardTrainingEx(
1065
+ cudnnHandle_t handle,
1066
+ cudnnBatchNormMode_t mode,
1067
+ cudnnBatchNormOps_t bnOps,
1068
+
1069
+ const void *alpha, /* alpha[0] = result blend factor */
1070
+ const void *beta, /* beta[0] = dest layer blend factor */
1071
+
1072
+ const cudnnTensorDescriptor_t xDesc,
1073
+ const void *xData,
1074
+ const cudnnTensorDescriptor_t zDesc,
1075
+ const void *zData,
1076
+ const cudnnTensorDescriptor_t yDesc,
1077
+ void *yData,
1078
+
1079
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
1080
+ const void *bnScale,
1081
+ const void *bnBias,
1082
+
1083
+ double exponentialAverageFactor,
1084
+ void *resultRunningMean,
1085
+ void *resultRunningVariance,
1086
+
1087
+ /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
1088
+ double epsilon,
1089
+
1090
+ /* Optionally save intermediate results from the forward pass here
1091
+ - can be reused to speed up backward pass. NULL if unused */
1092
+ void *resultSaveMean,
1093
+ void *resultSaveInvVariance,
1094
+
1095
+ cudnnActivationDescriptor_t activationDesc,
1096
+ void *workspace,
1097
+ size_t workSpaceSizeInBytes,
1098
+ void *reserveSpace,
1099
+ size_t reserveSpaceSizeInBytes);
1100
+
1101
+ /* Performs backward pass of Batch Normalization layer. Returns x gradient,
1102
+ * bnScale gradient and bnBias gradient */
1103
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1104
+ cudnnBatchNormalizationBackward(cudnnHandle_t handle,
1105
+ cudnnBatchNormMode_t mode,
1106
+ const void *alphaDataDiff,
1107
+ const void *betaDataDiff,
1108
+ const void *alphaParamDiff,
1109
+ const void *betaParamDiff,
1110
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
1111
+ const void *x,
1112
+ const cudnnTensorDescriptor_t dyDesc,
1113
+ const void *dy,
1114
+ const cudnnTensorDescriptor_t dxDesc,
1115
+ void *dx,
1116
+ /* Shared tensor desc for the 4 tensors below */
1117
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
1118
+ const void *bnScale, /* bnBias doesn't affect backpropagation */
1119
+ /* scale and bias diff are not backpropagated below this layer */
1120
+ void *dBnScaleResult,
1121
+ void *dBnBiasResult,
1122
+ /* Same epsilon as forward pass */
1123
+ double epsilon,
1124
+
1125
+ /* Optionally cached intermediate results from
1126
+ forward pass */
1127
+ const void *savedMean,
1128
+ const void *savedInvVariance);
1129
+
1130
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1131
+ cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle,
1132
+ cudnnBatchNormMode_t mode,
1133
+ cudnnBatchNormOps_t bnOps,
1134
+
1135
+ const void *alphaDataDiff,
1136
+ const void *betaDataDiff,
1137
+ const void *alphaParamDiff,
1138
+ const void *betaParamDiff,
1139
+ const cudnnTensorDescriptor_t xDesc,
1140
+ const void *xData,
1141
+ const cudnnTensorDescriptor_t yDesc,
1142
+ const void *yData,
1143
+ const cudnnTensorDescriptor_t dyDesc,
1144
+ const void *dyData,
1145
+ const cudnnTensorDescriptor_t dzDesc,
1146
+ void *dzData,
1147
+ const cudnnTensorDescriptor_t dxDesc,
1148
+ void *dxData,
1149
+
1150
+ /* Shared tensor desc for the 4 tensors below */
1151
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
1152
+ const void *bnScaleData,
1153
+ const void *bnBiasData, /* needed if there is activation */
1154
+ void *dBnScaleData,
1155
+ void *dBnBiasData,
1156
+ double epsilon, /* Same epsilon as forward pass */
1157
+
1158
+ /* Optionally cached intermediate results from
1159
+ forward pass */
1160
+ const void *savedMean,
1161
+ const void *savedInvVariance,
1162
+ cudnnActivationDescriptor_t activationDesc,
1163
+ void *workSpace,
1164
+ size_t workSpaceSizeInBytes,
1165
+ void *reserveSpace,
1166
+ size_t reserveSpaceSizeInBytes);
1167
+
1168
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1169
+ cudnnGetNormalizationForwardTrainingWorkspaceSize(cudnnHandle_t handle,
1170
+ cudnnNormMode_t mode,
1171
+ cudnnNormOps_t normOps,
1172
+ cudnnNormAlgo_t algo,
1173
+ const cudnnTensorDescriptor_t xDesc,
1174
+ const cudnnTensorDescriptor_t zDesc,
1175
+ const cudnnTensorDescriptor_t yDesc,
1176
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
1177
+ const cudnnActivationDescriptor_t activationDesc,
1178
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1179
+ size_t *sizeInBytes,
1180
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1181
+
1182
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1183
+ cudnnGetNormalizationBackwardWorkspaceSize(cudnnHandle_t handle,
1184
+ cudnnNormMode_t mode,
1185
+ cudnnNormOps_t normOps,
1186
+ cudnnNormAlgo_t algo,
1187
+ const cudnnTensorDescriptor_t xDesc,
1188
+ const cudnnTensorDescriptor_t yDesc,
1189
+ const cudnnTensorDescriptor_t dyDesc,
1190
+ const cudnnTensorDescriptor_t dzDesc,
1191
+ const cudnnTensorDescriptor_t dxDesc,
1192
+ const cudnnTensorDescriptor_t dNormScaleBiasDesc,
1193
+ const cudnnActivationDescriptor_t activationDesc,
1194
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1195
+ size_t *sizeInBytes,
1196
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1197
+
1198
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1199
+ cudnnGetNormalizationTrainingReserveSpaceSize(cudnnHandle_t handle,
1200
+ cudnnNormMode_t mode,
1201
+ cudnnNormOps_t normOps,
1202
+ cudnnNormAlgo_t algo,
1203
+ const cudnnActivationDescriptor_t activationDesc,
1204
+ const cudnnTensorDescriptor_t xDesc,
1205
+ size_t *sizeInBytes,
1206
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1207
+
1208
+ /* Computes y = relu(Norm(x) + z). Also accumulates moving averages of mean and inverse variances */
1209
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1210
+ cudnnNormalizationForwardTraining(cudnnHandle_t handle,
1211
+ cudnnNormMode_t mode,
1212
+ cudnnNormOps_t normOps,
1213
+ cudnnNormAlgo_t algo,
1214
+ const void *alpha, /* alpha[0] = result blend factor */
1215
+ const void *beta, /* beta[0] = dest layer blend factor */
1216
+ const cudnnTensorDescriptor_t xDesc,
1217
+ const void *xData,
1218
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
1219
+ const void *normScale,
1220
+ const void *normBias,
1221
+ double exponentialAverageFactor,
1222
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1223
+ void *resultRunningMean,
1224
+ void *resultRunningVariance,
1225
+ /* Has to be >= 0. Should be the same in forward and backward functions. */
1226
+ double epsilon,
1227
+ /* Optionally save intermediate results from the forward pass here
1228
+ - can be reused to speed up backward pass. NULL if unused */
1229
+ void *resultSaveMean,
1230
+ void *resultSaveInvVariance,
1231
+ cudnnActivationDescriptor_t activationDesc,
1232
+ const cudnnTensorDescriptor_t zDesc,
1233
+ const void *zData,
1234
+ const cudnnTensorDescriptor_t yDesc,
1235
+ void *yData,
1236
+ void *workspace,
1237
+ size_t workSpaceSizeInBytes,
1238
+ void *reserveSpace,
1239
+ size_t reserveSpaceSizeInBytes,
1240
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1241
+
1242
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1243
+ cudnnNormalizationBackward(cudnnHandle_t handle,
1244
+ cudnnNormMode_t mode,
1245
+ cudnnNormOps_t normOps,
1246
+ cudnnNormAlgo_t algo,
1247
+ const void *alphaDataDiff,
1248
+ const void *betaDataDiff,
1249
+ const void *alphaParamDiff,
1250
+ const void *betaParamDiff,
1251
+ const cudnnTensorDescriptor_t xDesc,
1252
+ const void *xData,
1253
+ const cudnnTensorDescriptor_t yDesc,
1254
+ const void *yData,
1255
+ const cudnnTensorDescriptor_t dyDesc,
1256
+ const void *dyData,
1257
+ const cudnnTensorDescriptor_t dzDesc,
1258
+ void *dzData,
1259
+ const cudnnTensorDescriptor_t dxDesc,
1260
+ void *dxData,
1261
+ /* Shared tensor desc for the 4 tensors below */
1262
+ const cudnnTensorDescriptor_t dNormScaleBiasDesc,
1263
+ const void *normScaleData,
1264
+ const void *normBiasData, /* needed if there is activation */
1265
+ void *dNormScaleData,
1266
+ void *dNormBiasData,
1267
+ double epsilon, /* Same epsilon as forward pass */
1268
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1269
+ /* Optionally cached intermediate results from
1270
+ forward pass */
1271
+ const void *savedMean,
1272
+ const void *savedInvVariance,
1273
+ cudnnActivationDescriptor_t activationDesc,
1274
+ void *workSpace,
1275
+ size_t workSpaceSizeInBytes,
1276
+ void *reserveSpace,
1277
+ size_t reserveSpaceSizeInBytes,
1278
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1279
+
1280
+ cudnnStatus_t CUDNNWINAPI
1281
+ cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle,
1282
+ const cudnnSpatialTransformerDescriptor_t stDesc,
1283
+ const void *dgrid,
1284
+ void *dtheta);
1285
+
1286
+ cudnnStatus_t CUDNNWINAPI
1287
+ cudnnSpatialTfSamplerBackward(cudnnHandle_t handle,
1288
+ cudnnSpatialTransformerDescriptor_t stDesc,
1289
+ const void *alpha,
1290
+ const cudnnTensorDescriptor_t xDesc,
1291
+ const void *x,
1292
+ const void *beta,
1293
+ const cudnnTensorDescriptor_t dxDesc,
1294
+ void *dx,
1295
+ const void *alphaDgrid,
1296
+ const cudnnTensorDescriptor_t dyDesc,
1297
+ const void *dy,
1298
+ const void *grid,
1299
+ const void *betaDgrid,
1300
+ void *dgrid);
1301
+
1302
+ cudnnStatus_t CUDNNWINAPI
1303
+ cudnnDropoutBackward(cudnnHandle_t handle,
1304
+ const cudnnDropoutDescriptor_t dropoutDesc,
1305
+ const cudnnTensorDescriptor_t dydesc,
1306
+ const void *dy,
1307
+ const cudnnTensorDescriptor_t dxdesc,
1308
+ void *dx,
1309
+ void *reserveSpace,
1310
+ size_t reserveSpaceSizeInBytes);
1311
+
1312
+ #if defined(__cplusplus)
1313
+ }
1314
+ #endif
1315
+
1316
+ #endif /* CUDNN_OPS_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version_v9.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /**
51
+ * \file: The master cuDNN version file.
52
+ */
53
+
54
+ #ifndef CUDNN_VERSION_H_
55
+ #define CUDNN_VERSION_H_
56
+
57
+ #define CUDNN_MAJOR 9
58
+ #define CUDNN_MINOR 1
59
+ #define CUDNN_PATCHLEVEL 0
60
+
61
+ #define CUDNN_VERSION (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
62
+
63
+ /* cannot use constexpr here since this is a C-only file */
64
+ /* Below is the max SM version this cuDNN library is aware of and supports natively */
65
+
66
+ #define CUDNN_MAX_SM_MAJOR_NUMBER 9
67
+ #define CUDNN_MAX_SM_MINOR_NUMBER 0
68
+ #define CUDNN_MAX_DEVICE_VERSION (CUDNN_MAX_SM_MAJOR_NUMBER * 100 + CUDNN_MAX_SM_MINOR_NUMBER * 10)
69
+
70
+ #endif /* CUDNN_VERSION_H */
.venv/lib/python3.11/site-packages/nvidia/cusolver/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cusolver/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (188 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverDn.h ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverMg.h ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2019 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #if !defined(CUSOLVERMG_H_)
51
+ #define CUSOLVERMG_H_
52
+
53
+ #include <stdint.h>
54
+ #include "cusolverDn.h"
55
+
56
+ #if defined(__cplusplus)
57
+ extern "C" {
58
+ #endif /* __cplusplus */
59
+
60
+ struct cusolverMgContext;
61
+ typedef struct cusolverMgContext *cusolverMgHandle_t;
62
+
63
+ /**
64
+ * \beief This enum decides how 1D device Ids (or process ranks) get mapped to
65
+ * a 2D grid.
66
+ */
67
+ typedef enum {
68
+
69
+ CUDALIBMG_GRID_MAPPING_ROW_MAJOR = 1,
70
+ CUDALIBMG_GRID_MAPPING_COL_MAJOR = 0
71
+
72
+ } cusolverMgGridMapping_t;
73
+
74
+ /** \brief Opaque structure of the distributed grid */
75
+ typedef void *cudaLibMgGrid_t;
76
+ /** \brief Opaque structure of the distributed matrix descriptor */
77
+ typedef void *cudaLibMgMatrixDesc_t;
78
+
79
+ cusolverStatus_t CUSOLVERAPI cusolverMgCreate(cusolverMgHandle_t *handle);
80
+
81
+ cusolverStatus_t CUSOLVERAPI cusolverMgDestroy(cusolverMgHandle_t handle);
82
+
83
+ cusolverStatus_t CUSOLVERAPI cusolverMgDeviceSelect(
84
+ cusolverMgHandle_t handle,
85
+ int nbDevices,
86
+ int deviceId[]);
87
+
88
+ /**
89
+ * \brief Allocates resources related to the shared memory device grid.
90
+ * \param[out] grid the opaque data strcuture that holds the grid
91
+ * \param[in] numRowDevices number of devices in the row
92
+ * \param[in] numColDevices number of devices in the column
93
+ * \param[in] deviceId This array of size height * width stores the
94
+ * device-ids of the 2D grid; each entry must correspond to a valid
95
+ * gpu or to -1 (denoting CPU). \param[in] mapping whether the 2D grid is in
96
+ * row/column major \returns the status code
97
+ */
98
+ cusolverStatus_t CUSOLVERAPI cusolverMgCreateDeviceGrid(
99
+ cudaLibMgGrid_t * grid,
100
+ int32_t numRowDevices,
101
+ int32_t numColDevices,
102
+ const int32_t deviceId[],
103
+ cusolverMgGridMapping_t mapping);
104
+
105
+ /**
106
+ * \brief Releases the allocated resources related to the distributed grid.
107
+ * \param[in] grid the opaque data strcuture that holds the distributed grid
108
+ * \returns the status code
109
+ */
110
+ cusolverStatus_t CUSOLVERAPI cusolverMgDestroyGrid(cudaLibMgGrid_t grid);
111
+
112
+ /**
113
+ * \brief Allocates resources related to the distributed matrix descriptor.
114
+ * \param[out] desc the opaque data strcuture that holds the descriptor
115
+ * \param[in] numRows number of total rows
116
+ * \param[in] numCols number of total columns
117
+ * \param[in] rowBlockSize row block size
118
+ * \param[in] colBlockSize column block size
119
+ * \param[in] dataType the data type of each element in cudaDataType
120
+ * \param[in] grid the opaque data structure of the distributed grid
121
+ * \returns the status code
122
+ */
123
+ cusolverStatus_t CUSOLVERAPI cusolverMgCreateMatrixDesc(
124
+ cudaLibMgMatrixDesc_t *desc,
125
+ int64_t numRows,
126
+ int64_t numCols,
127
+ int64_t rowBlockSize,
128
+ int64_t colBlockSize,
129
+ cudaDataType dataType,
130
+ const cudaLibMgGrid_t grid);
131
+
132
+ /**
133
+ * \brief Releases the allocated resources related to the distributed matrix
134
+ * descriptor. \param[in] desc the opaque data strcuture that holds the
135
+ * descriptor \returns the status code
136
+ */
137
+ cusolverStatus_t CUSOLVERAPI
138
+ cusolverMgDestroyMatrixDesc(cudaLibMgMatrixDesc_t desc);
139
+
140
+ cusolverStatus_t CUSOLVERAPI cusolverMgSyevd_bufferSize(
141
+ cusolverMgHandle_t handle,
142
+ cusolverEigMode_t jobz,
143
+ cublasFillMode_t uplo,
144
+ int N,
145
+ void * array_d_A[],
146
+ int IA,
147
+ int JA,
148
+ cudaLibMgMatrixDesc_t descrA,
149
+ void * W,
150
+ cudaDataType dataTypeW,
151
+ cudaDataType computeType,
152
+ int64_t * lwork);
153
+
154
+ cusolverStatus_t CUSOLVERAPI cusolverMgSyevd(
155
+ cusolverMgHandle_t handle,
156
+ cusolverEigMode_t jobz,
157
+ cublasFillMode_t uplo,
158
+ int N,
159
+ void * array_d_A[],
160
+ int IA,
161
+ int JA,
162
+ cudaLibMgMatrixDesc_t descrA,
163
+ void * W,
164
+ cudaDataType dataTypeW,
165
+ cudaDataType computeType,
166
+ void * array_d_work[],
167
+ int64_t lwork,
168
+ int * info);
169
+
170
+ cusolverStatus_t CUSOLVERAPI cusolverMgGetrf_bufferSize(
171
+ cusolverMgHandle_t handle,
172
+ int M,
173
+ int N,
174
+ void * array_d_A[],
175
+ int IA,
176
+ int JA,
177
+ cudaLibMgMatrixDesc_t descrA,
178
+ int * array_d_IPIV[],
179
+ cudaDataType computeType,
180
+ int64_t * lwork);
181
+
182
+ cusolverStatus_t CUSOLVERAPI cusolverMgGetrf(
183
+ cusolverMgHandle_t handle,
184
+ int M,
185
+ int N,
186
+ void * array_d_A[],
187
+ int IA,
188
+ int JA,
189
+ cudaLibMgMatrixDesc_t descrA,
190
+ int * array_d_IPIV[],
191
+ cudaDataType computeType,
192
+ void * array_d_work[],
193
+ int64_t lwork,
194
+ int * info);
195
+
196
+ cusolverStatus_t CUSOLVERAPI cusolverMgGetrs_bufferSize(
197
+ cusolverMgHandle_t handle,
198
+ cublasOperation_t TRANS,
199
+ int N,
200
+ int NRHS,
201
+ void * array_d_A[],
202
+ int IA,
203
+ int JA,
204
+ cudaLibMgMatrixDesc_t descrA,
205
+ int * array_d_IPIV[],
206
+ void * array_d_B[],
207
+ int IB,
208
+ int JB,
209
+ cudaLibMgMatrixDesc_t descrB,
210
+ cudaDataType computeType,
211
+ int64_t * lwork);
212
+
213
+ cusolverStatus_t CUSOLVERAPI cusolverMgGetrs(
214
+ cusolverMgHandle_t handle,
215
+ cublasOperation_t TRANS,
216
+ int N,
217
+ int NRHS,
218
+ void * array_d_A[],
219
+ int IA,
220
+ int JA,
221
+ cudaLibMgMatrixDesc_t descrA,
222
+ int * array_d_IPIV[],
223
+ void * array_d_B[],
224
+ int IB,
225
+ int JB,
226
+ cudaLibMgMatrixDesc_t descrB,
227
+ cudaDataType computeType,
228
+ void * array_d_work[],
229
+ int64_t lwork,
230
+ int * info);
231
+
232
+ cusolverStatus_t CUSOLVERAPI cusolverMgPotrf_bufferSize(
233
+ cusolverMgHandle_t handle,
234
+ cublasFillMode_t uplo,
235
+ int N,
236
+ void * array_d_A[],
237
+ int IA,
238
+ int JA,
239
+ cudaLibMgMatrixDesc_t descrA,
240
+ cudaDataType computeType,
241
+ int64_t * lwork);
242
+
243
+ cusolverStatus_t CUSOLVERAPI cusolverMgPotrf(
244
+ cusolverMgHandle_t handle,
245
+ cublasFillMode_t uplo,
246
+ int N,
247
+ void * array_d_A[],
248
+ int IA,
249
+ int JA,
250
+ cudaLibMgMatrixDesc_t descrA,
251
+ cudaDataType computeType,
252
+ void * array_d_work[],
253
+ int64_t lwork,
254
+ int * h_info);
255
+
256
+ cusolverStatus_t CUSOLVERAPI cusolverMgPotrs_bufferSize(
257
+ cusolverMgHandle_t handle,
258
+ cublasFillMode_t uplo,
259
+ int n,
260
+ int nrhs,
261
+ void * array_d_A[],
262
+ int IA,
263
+ int JA,
264
+ cudaLibMgMatrixDesc_t descrA,
265
+ void * array_d_B[],
266
+ int IB,
267
+ int JB,
268
+ cudaLibMgMatrixDesc_t descrB,
269
+ cudaDataType computeType,
270
+ int64_t * lwork);
271
+
272
+ cusolverStatus_t CUSOLVERAPI cusolverMgPotrs(
273
+ cusolverMgHandle_t handle,
274
+ cublasFillMode_t uplo,
275
+ int n,
276
+ int nrhs,
277
+ void * array_d_A[],
278
+ int IA,
279
+ int JA,
280
+ cudaLibMgMatrixDesc_t descrA,
281
+ void * array_d_B[],
282
+ int IB,
283
+ int JB,
284
+ cudaLibMgMatrixDesc_t descrB,
285
+ cudaDataType computeType,
286
+ void * array_d_work[],
287
+ int64_t lwork,
288
+ int * h_info);
289
+
290
+ cusolverStatus_t CUSOLVERAPI cusolverMgPotri_bufferSize(
291
+ cusolverMgHandle_t handle,
292
+ cublasFillMode_t uplo,
293
+ int N,
294
+ void * array_d_A[],
295
+ int IA,
296
+ int JA,
297
+ cudaLibMgMatrixDesc_t descrA,
298
+ cudaDataType computeType,
299
+ int64_t * lwork);
300
+
301
+ cusolverStatus_t CUSOLVERAPI cusolverMgPotri(
302
+ cusolverMgHandle_t handle,
303
+ cublasFillMode_t uplo,
304
+ int N,
305
+ void * array_d_A[],
306
+ int IA,
307
+ int JA,
308
+ cudaLibMgMatrixDesc_t descrA,
309
+ cudaDataType computeType,
310
+ void * array_d_work[],
311
+ int64_t lwork,
312
+ int * h_info);
313
+
314
+ #if defined(__cplusplus)
315
+ }
316
+ #endif /* __cplusplus */
317
+
318
+ #endif // CUSOLVERMG_H_
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverRf.h ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2014 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #if !defined(CUSOLVERRF_H_)
51
+ #define CUSOLVERRF_H_
52
+
53
+ #include "driver_types.h"
54
+ #include "cuComplex.h"
55
+ #include "cusolver_common.h"
56
+
57
+ #if defined(__cplusplus)
58
+ extern "C" {
59
+ #endif /* __cplusplus */
60
+
61
+ /* CUSOLVERRF mode */
62
+ typedef enum {
63
+ CUSOLVERRF_RESET_VALUES_FAST_MODE_OFF = 0, // default
64
+ CUSOLVERRF_RESET_VALUES_FAST_MODE_ON = 1
65
+ } cusolverRfResetValuesFastMode_t;
66
+
67
+ /* CUSOLVERRF matrix format */
68
+ typedef enum {
69
+ CUSOLVERRF_MATRIX_FORMAT_CSR = 0, // default
70
+ CUSOLVERRF_MATRIX_FORMAT_CSC = 1
71
+ } cusolverRfMatrixFormat_t;
72
+
73
+ /* CUSOLVERRF unit diagonal */
74
+ typedef enum {
75
+ CUSOLVERRF_UNIT_DIAGONAL_STORED_L = 0, // default
76
+ CUSOLVERRF_UNIT_DIAGONAL_STORED_U = 1,
77
+ CUSOLVERRF_UNIT_DIAGONAL_ASSUMED_L = 2,
78
+ CUSOLVERRF_UNIT_DIAGONAL_ASSUMED_U = 3
79
+ } cusolverRfUnitDiagonal_t;
80
+
81
+ /* CUSOLVERRF factorization algorithm */
82
+ typedef enum {
83
+ CUSOLVERRF_FACTORIZATION_ALG0 = 0, // default
84
+ CUSOLVERRF_FACTORIZATION_ALG1 = 1,
85
+ CUSOLVERRF_FACTORIZATION_ALG2 = 2,
86
+ } cusolverRfFactorization_t;
87
+
88
+ /* CUSOLVERRF triangular solve algorithm */
89
+ typedef enum {
90
+ CUSOLVERRF_TRIANGULAR_SOLVE_ALG1 = 1, // default
91
+ CUSOLVERRF_TRIANGULAR_SOLVE_ALG2 = 2,
92
+ CUSOLVERRF_TRIANGULAR_SOLVE_ALG3 = 3
93
+ } cusolverRfTriangularSolve_t;
94
+
95
+ /* CUSOLVERRF numeric boost report */
96
+ typedef enum {
97
+ CUSOLVERRF_NUMERIC_BOOST_NOT_USED = 0, // default
98
+ CUSOLVERRF_NUMERIC_BOOST_USED = 1
99
+ } cusolverRfNumericBoostReport_t;
100
+
101
+ /* Opaque structure holding CUSOLVERRF library common */
102
+ struct cusolverRfCommon;
103
+ typedef struct cusolverRfCommon* cusolverRfHandle_t;
104
+
105
+ /* CUSOLVERRF create (allocate memory) and destroy (free memory) in the handle
106
+ */
107
+ cusolverStatus_t CUSOLVERAPI cusolverRfCreate(cusolverRfHandle_t* handle);
108
+ cusolverStatus_t CUSOLVERAPI cusolverRfDestroy(cusolverRfHandle_t handle);
109
+
110
+ /* CUSOLVERRF set and get input format */
111
+ cusolverStatus_t CUSOLVERAPI cusolverRfGetMatrixFormat(
112
+ cusolverRfHandle_t handle,
113
+ cusolverRfMatrixFormat_t* format,
114
+ cusolverRfUnitDiagonal_t* diag);
115
+
116
+ cusolverStatus_t CUSOLVERAPI cusolverRfSetMatrixFormat(
117
+ cusolverRfHandle_t handle,
118
+ cusolverRfMatrixFormat_t format,
119
+ cusolverRfUnitDiagonal_t diag);
120
+
121
+ /* CUSOLVERRF set and get numeric properties */
122
+ cusolverStatus_t CUSOLVERAPI cusolverRfSetNumericProperties(
123
+ cusolverRfHandle_t handle,
124
+ double zero,
125
+ double boost);
126
+
127
+ cusolverStatus_t CUSOLVERAPI cusolverRfGetNumericProperties(
128
+ cusolverRfHandle_t handle,
129
+ double* zero,
130
+ double* boost);
131
+
132
+ cusolverStatus_t CUSOLVERAPI cusolverRfGetNumericBoostReport(
133
+ cusolverRfHandle_t handle,
134
+ cusolverRfNumericBoostReport_t* report);
135
+
136
+ /* CUSOLVERRF choose the triangular solve algorithm */
137
+ cusolverStatus_t CUSOLVERAPI cusolverRfSetAlgs(
138
+ cusolverRfHandle_t handle,
139
+ cusolverRfFactorization_t factAlg,
140
+ cusolverRfTriangularSolve_t solveAlg);
141
+
142
+ cusolverStatus_t CUSOLVERAPI cusolverRfGetAlgs(
143
+ cusolverRfHandle_t handle,
144
+ cusolverRfFactorization_t* factAlg,
145
+ cusolverRfTriangularSolve_t* solveAlg);
146
+
147
+ /* CUSOLVERRF set and get fast mode */
148
+ cusolverStatus_t CUSOLVERAPI cusolverRfGetResetValuesFastMode(
149
+ cusolverRfHandle_t handle,
150
+ cusolverRfResetValuesFastMode_t* fastMode);
151
+
152
+ cusolverStatus_t CUSOLVERAPI cusolverRfSetResetValuesFastMode(
153
+ cusolverRfHandle_t handle,
154
+ cusolverRfResetValuesFastMode_t fastMode);
155
+
156
+ /*** Non-Batched Routines ***/
157
+ /* CUSOLVERRF setup of internal structures from host or device memory */
158
+ cusolverStatus_t CUSOLVERAPI
159
+ cusolverRfSetupHost(/* Input (in the host memory) */
160
+ int n,
161
+ int nnzA,
162
+ int* h_csrRowPtrA,
163
+ int* h_csrColIndA,
164
+ double* h_csrValA,
165
+ int nnzL,
166
+ int* h_csrRowPtrL,
167
+ int* h_csrColIndL,
168
+ double* h_csrValL,
169
+ int nnzU,
170
+ int* h_csrRowPtrU,
171
+ int* h_csrColIndU,
172
+ double* h_csrValU,
173
+ int* h_P,
174
+ int* h_Q,
175
+ /* Output */
176
+ cusolverRfHandle_t handle);
177
+
178
+ cusolverStatus_t CUSOLVERAPI
179
+ cusolverRfSetupDevice(/* Input (in the device memory) */
180
+ int n,
181
+ int nnzA,
182
+ int* csrRowPtrA,
183
+ int* csrColIndA,
184
+ double* csrValA,
185
+ int nnzL,
186
+ int* csrRowPtrL,
187
+ int* csrColIndL,
188
+ double* csrValL,
189
+ int nnzU,
190
+ int* csrRowPtrU,
191
+ int* csrColIndU,
192
+ double* csrValU,
193
+ int* P,
194
+ int* Q,
195
+ /* Output */
196
+ cusolverRfHandle_t handle);
197
+
198
+ /* CUSOLVERRF update the matrix values (assuming the reordering, pivoting
199
+ and consequently the sparsity pattern of L and U did not change),
200
+ and zero out the remaining values. */
201
+ cusolverStatus_t CUSOLVERAPI
202
+ cusolverRfResetValues(/* Input (in the device memory) */
203
+ int n,
204
+ int nnzA,
205
+ int* csrRowPtrA,
206
+ int* csrColIndA,
207
+ double* csrValA,
208
+ int* P,
209
+ int* Q,
210
+ /* Output */
211
+ cusolverRfHandle_t handle);
212
+
213
+ /* CUSOLVERRF analysis (for parallelism) */
214
+ cusolverStatus_t CUSOLVERAPI cusolverRfAnalyze(cusolverRfHandle_t handle);
215
+
216
+ /* CUSOLVERRF re-factorization (for parallelism) */
217
+ cusolverStatus_t CUSOLVERAPI cusolverRfRefactor(cusolverRfHandle_t handle);
218
+
219
+ /* CUSOLVERRF extraction: Get L & U packed into a single matrix M */
220
+ cusolverStatus_t CUSOLVERAPI
221
+ cusolverRfAccessBundledFactorsDevice(/* Input */
222
+ cusolverRfHandle_t handle,
223
+ /* Output (in the host memory) */
224
+ int* nnzM,
225
+ /* Output (in the device memory) */
226
+ int** Mp,
227
+ int** Mi,
228
+ double** Mx);
229
+
230
+ cusolverStatus_t CUSOLVERAPI
231
+ cusolverRfExtractBundledFactorsHost(/* Input */
232
+ cusolverRfHandle_t handle,
233
+ /* Output (in the host memory) */
234
+ int* h_nnzM,
235
+ int** h_Mp,
236
+ int** h_Mi,
237
+ double** h_Mx);
238
+
239
+ /* CUSOLVERRF extraction: Get L & U individually */
240
+ cusolverStatus_t CUSOLVERAPI
241
+ cusolverRfExtractSplitFactorsHost(/* Input */
242
+ cusolverRfHandle_t handle,
243
+ /* Output (in the host memory) */
244
+ int* h_nnzL,
245
+ int** h_csrRowPtrL,
246
+ int** h_csrColIndL,
247
+ double** h_csrValL,
248
+ int* h_nnzU,
249
+ int** h_csrRowPtrU,
250
+ int** h_csrColIndU,
251
+ double** h_csrValU);
252
+
253
+ /* CUSOLVERRF (forward and backward triangular) solves */
254
+ cusolverStatus_t CUSOLVERAPI
255
+ cusolverRfSolve(/* Input (in the device memory) */
256
+ cusolverRfHandle_t handle,
257
+ int* P,
258
+ int* Q,
259
+ int nrhs, // only nrhs=1 is supported
260
+ double* Temp, // of size ldt*nrhs (ldt>=n)
261
+ int ldt,
262
+ /* Input/Output (in the device memory) */
263
+ double* XF,
264
+ /* Input */
265
+ int ldxf);
266
+
267
+ /*** Batched Routines ***/
268
+ /* CUSOLVERRF-batch setup of internal structures from host */
269
+ cusolverStatus_t CUSOLVERAPI
270
+ cusolverRfBatchSetupHost(/* Input (in the host memory)*/
271
+ int batchSize,
272
+ int n,
273
+ int nnzA,
274
+ int* h_csrRowPtrA,
275
+ int* h_csrColIndA,
276
+ double* h_csrValA_array[],
277
+ int nnzL,
278
+ int* h_csrRowPtrL,
279
+ int* h_csrColIndL,
280
+ double* h_csrValL,
281
+ int nnzU,
282
+ int* h_csrRowPtrU,
283
+ int* h_csrColIndU,
284
+ double* h_csrValU,
285
+ int* h_P,
286
+ int* h_Q,
287
+ /* Output (in the device memory) */
288
+ cusolverRfHandle_t handle);
289
+
290
+ /* CUSOLVERRF-batch update the matrix values (assuming the reordering,
291
+ pivoting and consequently the sparsity pattern of L and U did not change),
292
+ and zero out the remaining values. */
293
+ cusolverStatus_t CUSOLVERAPI
294
+ cusolverRfBatchResetValues(/* Input (in the device memory) */
295
+ int batchSize,
296
+ int n,
297
+ int nnzA,
298
+ int* csrRowPtrA,
299
+ int* csrColIndA,
300
+ double* csrValA_array[],
301
+ int* P,
302
+ int* Q,
303
+ /* Output */
304
+ cusolverRfHandle_t handle);
305
+
306
+ /* CUSOLVERRF-batch analysis (for parallelism) */
307
+ cusolverStatus_t CUSOLVERAPI
308
+ cusolverRfBatchAnalyze(cusolverRfHandle_t handle);
309
+
310
+ /* CUSOLVERRF-batch re-factorization (for parallelism) */
311
+ cusolverStatus_t CUSOLVERAPI
312
+ cusolverRfBatchRefactor(cusolverRfHandle_t handle);
313
+
314
+ /* CUSOLVERRF-batch (forward and backward triangular) solves */
315
+ cusolverStatus_t CUSOLVERAPI
316
+ cusolverRfBatchSolve(/* Input (in the device memory) */
317
+ cusolverRfHandle_t handle,
318
+ int* P,
319
+ int* Q,
320
+ int nrhs, // only nrhs=1 is supported
321
+ double* Temp, // of size 2*batchSize*(n*nrhs)
322
+ int ldt, // only ldt=n is supported
323
+ /* Input/Output (in the device memory) */
324
+ double* XF_array[],
325
+ /* Input */
326
+ int ldxf);
327
+
328
+ /* CUSOLVERRF-batch obtain the position of zero pivot */
329
+ cusolverStatus_t CUSOLVERAPI
330
+ cusolverRfBatchZeroPivot(/* Input */
331
+ cusolverRfHandle_t handle,
332
+ /* Output (in the host memory) */
333
+ int* position);
334
+
335
+ #if defined(__cplusplus)
336
+ }
337
+ #endif /* __cplusplus */
338
+
339
+ #endif /* CUSOLVERRF_H_ */