Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .venv/lib/python3.11/site-packages/nvidia/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas.h +891 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h +1845 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasXt.h +693 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h +478 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/include/nvblas.h +824 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/lib/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 +3 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/nvrtc.h +869 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/async.h +452 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/coalesced_scan.h +174 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/driver_abi.h +99 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/info.h +344 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/invoke.h +189 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/memory.h +135 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/partitioning.h +159 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/reduce.h +419 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/scan.h +320 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/sync.h +282 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 +3 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn.h +68 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_v9.h +671 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend_v9.h +60 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph.h +909 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops.h +1316 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version_v9.h +70 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/include/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/include/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverDn.h +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverMg.h +318 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverRf.h +339 -0
.gitattributes
CHANGED
|
@@ -120,3 +120,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 120 |
.venv/lib/python3.11/site-packages/click/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 121 |
.venv/lib/python3.11/site-packages/pyasn1/type/__pycache__/univ.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 122 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 120 |
.venv/lib/python3.11/site-packages/click/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 121 |
.venv/lib/python3.11/site-packages/pyasn1/type/__pycache__/univ.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 122 |
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text
|
| 123 |
+
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
|
| 124 |
+
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/nvidia/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (179 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (186 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas.h
ADDED
|
@@ -0,0 +1,891 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* This is the public header file for the CUBLAS library, defining the API
|
| 52 |
+
*
|
| 53 |
+
* CUBLAS is an implementation of BLAS (Basic Linear Algebra Subroutines)
|
| 54 |
+
* on top of the CUDA runtime.
|
| 55 |
+
*/
|
| 56 |
+
|
| 57 |
+
#if !defined(CUBLAS_H_)
|
| 58 |
+
#define CUBLAS_H_
|
| 59 |
+
|
| 60 |
+
#if defined(CUBLAS_V2_H_)
|
| 61 |
+
#error "It is an error to include both cublas.h and cublas_v2.h"
|
| 62 |
+
#endif
|
| 63 |
+
|
| 64 |
+
#include <cuda_runtime.h>
|
| 65 |
+
|
| 66 |
+
#ifndef CUBLASWINAPI
|
| 67 |
+
#ifdef _WIN32
|
| 68 |
+
#define CUBLASWINAPI __stdcall
|
| 69 |
+
#else
|
| 70 |
+
#define CUBLASWINAPI
|
| 71 |
+
#endif
|
| 72 |
+
#endif
|
| 73 |
+
|
| 74 |
+
#undef CUBLASAPI
|
| 75 |
+
#ifdef __CUDACC__
|
| 76 |
+
#define CUBLASAPI __host__
|
| 77 |
+
#else
|
| 78 |
+
#define CUBLASAPI
|
| 79 |
+
#endif
|
| 80 |
+
|
| 81 |
+
#include "cublas_api.h"
|
| 82 |
+
|
| 83 |
+
#if defined(__cplusplus)
|
| 84 |
+
extern "C" {
|
| 85 |
+
#endif
|
| 86 |
+
|
| 87 |
+
/* CUBLAS data types */
|
| 88 |
+
#define cublasStatus cublasStatus_t
|
| 89 |
+
|
| 90 |
+
cublasStatus CUBLASWINAPI cublasInit(void);
|
| 91 |
+
cublasStatus CUBLASWINAPI cublasShutdown(void);
|
| 92 |
+
cublasStatus CUBLASWINAPI cublasGetError(void);
|
| 93 |
+
|
| 94 |
+
cublasStatus CUBLASWINAPI cublasGetVersion(int* version);
|
| 95 |
+
cublasStatus CUBLASWINAPI cublasAlloc(int n, int elemSize, void** devicePtr);
|
| 96 |
+
|
| 97 |
+
cublasStatus CUBLASWINAPI cublasFree(void* devicePtr);
|
| 98 |
+
|
| 99 |
+
cublasStatus CUBLASWINAPI cublasSetKernelStream(cudaStream_t stream);
|
| 100 |
+
|
| 101 |
+
/* ---------------- CUBLAS BLAS1 functions ---------------- */
|
| 102 |
+
/* NRM2 */
|
| 103 |
+
float CUBLASWINAPI cublasSnrm2(int n, const float* x, int incx);
|
| 104 |
+
double CUBLASWINAPI cublasDnrm2(int n, const double* x, int incx);
|
| 105 |
+
float CUBLASWINAPI cublasScnrm2(int n, const cuComplex* x, int incx);
|
| 106 |
+
double CUBLASWINAPI cublasDznrm2(int n, const cuDoubleComplex* x, int incx);
|
| 107 |
+
/*------------------------------------------------------------------------*/
|
| 108 |
+
/* DOT */
|
| 109 |
+
float CUBLASWINAPI cublasSdot(int n, const float* x, int incx, const float* y, int incy);
|
| 110 |
+
double CUBLASWINAPI cublasDdot(int n, const double* x, int incx, const double* y, int incy);
|
| 111 |
+
cuComplex CUBLASWINAPI cublasCdotu(int n, const cuComplex* x, int incx, const cuComplex* y, int incy);
|
| 112 |
+
cuComplex CUBLASWINAPI cublasCdotc(int n, const cuComplex* x, int incx, const cuComplex* y, int incy);
|
| 113 |
+
cuDoubleComplex CUBLASWINAPI cublasZdotu(int n, const cuDoubleComplex* x, int incx, const cuDoubleComplex* y, int incy);
|
| 114 |
+
cuDoubleComplex CUBLASWINAPI cublasZdotc(int n, const cuDoubleComplex* x, int incx, const cuDoubleComplex* y, int incy);
|
| 115 |
+
/*------------------------------------------------------------------------*/
|
| 116 |
+
/* SCAL */
|
| 117 |
+
void CUBLASWINAPI cublasSscal(int n, float alpha, float* x, int incx);
|
| 118 |
+
void CUBLASWINAPI cublasDscal(int n, double alpha, double* x, int incx);
|
| 119 |
+
void CUBLASWINAPI cublasCscal(int n, cuComplex alpha, cuComplex* x, int incx);
|
| 120 |
+
void CUBLASWINAPI cublasZscal(int n, cuDoubleComplex alpha, cuDoubleComplex* x, int incx);
|
| 121 |
+
|
| 122 |
+
void CUBLASWINAPI cublasCsscal(int n, float alpha, cuComplex* x, int incx);
|
| 123 |
+
void CUBLASWINAPI cublasZdscal(int n, double alpha, cuDoubleComplex* x, int incx);
|
| 124 |
+
/*------------------------------------------------------------------------*/
|
| 125 |
+
/* AXPY */
|
| 126 |
+
void CUBLASWINAPI cublasSaxpy(int n, float alpha, const float* x, int incx, float* y, int incy);
|
| 127 |
+
void CUBLASWINAPI cublasDaxpy(int n, double alpha, const double* x, int incx, double* y, int incy);
|
| 128 |
+
void CUBLASWINAPI cublasCaxpy(int n, cuComplex alpha, const cuComplex* x, int incx, cuComplex* y, int incy);
|
| 129 |
+
void CUBLASWINAPI
|
| 130 |
+
cublasZaxpy(int n, cuDoubleComplex alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy);
|
| 131 |
+
/*------------------------------------------------------------------------*/
|
| 132 |
+
/* COPY */
|
| 133 |
+
void CUBLASWINAPI cublasScopy(int n, const float* x, int incx, float* y, int incy);
|
| 134 |
+
void CUBLASWINAPI cublasDcopy(int n, const double* x, int incx, double* y, int incy);
|
| 135 |
+
void CUBLASWINAPI cublasCcopy(int n, const cuComplex* x, int incx, cuComplex* y, int incy);
|
| 136 |
+
void CUBLASWINAPI cublasZcopy(int n, const cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy);
|
| 137 |
+
/*------------------------------------------------------------------------*/
|
| 138 |
+
/* SWAP */
|
| 139 |
+
void CUBLASWINAPI cublasSswap(int n, float* x, int incx, float* y, int incy);
|
| 140 |
+
void CUBLASWINAPI cublasDswap(int n, double* x, int incx, double* y, int incy);
|
| 141 |
+
void CUBLASWINAPI cublasCswap(int n, cuComplex* x, int incx, cuComplex* y, int incy);
|
| 142 |
+
void CUBLASWINAPI cublasZswap(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy);
|
| 143 |
+
/*------------------------------------------------------------------------*/
|
| 144 |
+
/* AMAX */
|
| 145 |
+
int CUBLASWINAPI cublasIsamax(int n, const float* x, int incx);
|
| 146 |
+
int CUBLASWINAPI cublasIdamax(int n, const double* x, int incx);
|
| 147 |
+
int CUBLASWINAPI cublasIcamax(int n, const cuComplex* x, int incx);
|
| 148 |
+
int CUBLASWINAPI cublasIzamax(int n, const cuDoubleComplex* x, int incx);
|
| 149 |
+
/*------------------------------------------------------------------------*/
|
| 150 |
+
/* AMIN */
|
| 151 |
+
int CUBLASWINAPI cublasIsamin(int n, const float* x, int incx);
|
| 152 |
+
int CUBLASWINAPI cublasIdamin(int n, const double* x, int incx);
|
| 153 |
+
|
| 154 |
+
int CUBLASWINAPI cublasIcamin(int n, const cuComplex* x, int incx);
|
| 155 |
+
int CUBLASWINAPI cublasIzamin(int n, const cuDoubleComplex* x, int incx);
|
| 156 |
+
/*------------------------------------------------------------------------*/
|
| 157 |
+
/* ASUM */
|
| 158 |
+
float CUBLASWINAPI cublasSasum(int n, const float* x, int incx);
|
| 159 |
+
double CUBLASWINAPI cublasDasum(int n, const double* x, int incx);
|
| 160 |
+
float CUBLASWINAPI cublasScasum(int n, const cuComplex* x, int incx);
|
| 161 |
+
double CUBLASWINAPI cublasDzasum(int n, const cuDoubleComplex* x, int incx);
|
| 162 |
+
/*------------------------------------------------------------------------*/
|
| 163 |
+
/* ROT */
|
| 164 |
+
void CUBLASWINAPI cublasSrot(int n, float* x, int incx, float* y, int incy, float sc, float ss);
|
| 165 |
+
void CUBLASWINAPI cublasDrot(int n, double* x, int incx, double* y, int incy, double sc, double ss);
|
| 166 |
+
void CUBLASWINAPI cublasCrot(int n, cuComplex* x, int incx, cuComplex* y, int incy, float c, cuComplex s);
|
| 167 |
+
void CUBLASWINAPI
|
| 168 |
+
cublasZrot(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, double sc, cuDoubleComplex cs);
|
| 169 |
+
void CUBLASWINAPI cublasCsrot(int n, cuComplex* x, int incx, cuComplex* y, int incy, float c, float s);
|
| 170 |
+
void CUBLASWINAPI cublasZdrot(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, double c, double s);
|
| 171 |
+
/*------------------------------------------------------------------------*/
|
| 172 |
+
/* ROTG */
|
| 173 |
+
void CUBLASWINAPI cublasSrotg(float* sa, float* sb, float* sc, float* ss);
|
| 174 |
+
void CUBLASWINAPI cublasDrotg(double* sa, double* sb, double* sc, double* ss);
|
| 175 |
+
void CUBLASWINAPI cublasCrotg(cuComplex* ca, cuComplex cb, float* sc, cuComplex* cs);
|
| 176 |
+
void CUBLASWINAPI cublasZrotg(cuDoubleComplex* ca, cuDoubleComplex cb, double* sc, cuDoubleComplex* cs);
|
| 177 |
+
/*------------------------------------------------------------------------*/
|
| 178 |
+
/* ROTM */
|
| 179 |
+
void CUBLASWINAPI cublasSrotm(int n, float* x, int incx, float* y, int incy, const float* sparam);
|
| 180 |
+
void CUBLASWINAPI cublasDrotm(int n, double* x, int incx, double* y, int incy, const double* sparam);
|
| 181 |
+
/*------------------------------------------------------------------------*/
|
| 182 |
+
/* ROTMG */
|
| 183 |
+
void CUBLASWINAPI cublasSrotmg(float* sd1, float* sd2, float* sx1, const float* sy1, float* sparam);
|
| 184 |
+
void CUBLASWINAPI cublasDrotmg(double* sd1, double* sd2, double* sx1, const double* sy1, double* sparam);
|
| 185 |
+
|
| 186 |
+
/* --------------- CUBLAS BLAS2 functions ---------------- */
|
| 187 |
+
/* GEMV */
|
| 188 |
+
void CUBLASWINAPI cublasSgemv(char trans,
|
| 189 |
+
int m,
|
| 190 |
+
int n,
|
| 191 |
+
float alpha,
|
| 192 |
+
const float* A,
|
| 193 |
+
int lda,
|
| 194 |
+
const float* x,
|
| 195 |
+
int incx,
|
| 196 |
+
float beta,
|
| 197 |
+
float* y,
|
| 198 |
+
int incy);
|
| 199 |
+
void CUBLASWINAPI cublasDgemv(char trans,
|
| 200 |
+
int m,
|
| 201 |
+
int n,
|
| 202 |
+
double alpha,
|
| 203 |
+
const double* A,
|
| 204 |
+
int lda,
|
| 205 |
+
const double* x,
|
| 206 |
+
int incx,
|
| 207 |
+
double beta,
|
| 208 |
+
double* y,
|
| 209 |
+
int incy);
|
| 210 |
+
void CUBLASWINAPI cublasCgemv(char trans,
|
| 211 |
+
int m,
|
| 212 |
+
int n,
|
| 213 |
+
cuComplex alpha,
|
| 214 |
+
const cuComplex* A,
|
| 215 |
+
int lda,
|
| 216 |
+
const cuComplex* x,
|
| 217 |
+
int incx,
|
| 218 |
+
cuComplex beta,
|
| 219 |
+
cuComplex* y,
|
| 220 |
+
int incy);
|
| 221 |
+
void CUBLASWINAPI cublasZgemv(char trans,
|
| 222 |
+
int m,
|
| 223 |
+
int n,
|
| 224 |
+
cuDoubleComplex alpha,
|
| 225 |
+
const cuDoubleComplex* A,
|
| 226 |
+
int lda,
|
| 227 |
+
const cuDoubleComplex* x,
|
| 228 |
+
int incx,
|
| 229 |
+
cuDoubleComplex beta,
|
| 230 |
+
cuDoubleComplex* y,
|
| 231 |
+
int incy);
|
| 232 |
+
/*------------------------------------------------------------------------*/
|
| 233 |
+
/* GBMV */
|
| 234 |
+
void CUBLASWINAPI cublasSgbmv(char trans,
|
| 235 |
+
int m,
|
| 236 |
+
int n,
|
| 237 |
+
int kl,
|
| 238 |
+
int ku,
|
| 239 |
+
float alpha,
|
| 240 |
+
const float* A,
|
| 241 |
+
int lda,
|
| 242 |
+
const float* x,
|
| 243 |
+
int incx,
|
| 244 |
+
float beta,
|
| 245 |
+
float* y,
|
| 246 |
+
int incy);
|
| 247 |
+
void CUBLASWINAPI cublasDgbmv(char trans,
|
| 248 |
+
int m,
|
| 249 |
+
int n,
|
| 250 |
+
int kl,
|
| 251 |
+
int ku,
|
| 252 |
+
double alpha,
|
| 253 |
+
const double* A,
|
| 254 |
+
int lda,
|
| 255 |
+
const double* x,
|
| 256 |
+
int incx,
|
| 257 |
+
double beta,
|
| 258 |
+
double* y,
|
| 259 |
+
int incy);
|
| 260 |
+
void CUBLASWINAPI cublasCgbmv(char trans,
|
| 261 |
+
int m,
|
| 262 |
+
int n,
|
| 263 |
+
int kl,
|
| 264 |
+
int ku,
|
| 265 |
+
cuComplex alpha,
|
| 266 |
+
const cuComplex* A,
|
| 267 |
+
int lda,
|
| 268 |
+
const cuComplex* x,
|
| 269 |
+
int incx,
|
| 270 |
+
cuComplex beta,
|
| 271 |
+
cuComplex* y,
|
| 272 |
+
int incy);
|
| 273 |
+
void CUBLASWINAPI cublasZgbmv(char trans,
|
| 274 |
+
int m,
|
| 275 |
+
int n,
|
| 276 |
+
int kl,
|
| 277 |
+
int ku,
|
| 278 |
+
cuDoubleComplex alpha,
|
| 279 |
+
const cuDoubleComplex* A,
|
| 280 |
+
int lda,
|
| 281 |
+
const cuDoubleComplex* x,
|
| 282 |
+
int incx,
|
| 283 |
+
cuDoubleComplex beta,
|
| 284 |
+
cuDoubleComplex* y,
|
| 285 |
+
int incy);
|
| 286 |
+
/*------------------------------------------------------------------------*/
|
| 287 |
+
/* TRMV */
|
| 288 |
+
void CUBLASWINAPI cublasStrmv(char uplo, char trans, char diag, int n, const float* A, int lda, float* x, int incx);
|
| 289 |
+
void CUBLASWINAPI cublasDtrmv(char uplo, char trans, char diag, int n, const double* A, int lda, double* x, int incx);
|
| 290 |
+
void CUBLASWINAPI
|
| 291 |
+
cublasCtrmv(char uplo, char trans, char diag, int n, const cuComplex* A, int lda, cuComplex* x, int incx);
|
| 292 |
+
void CUBLASWINAPI
|
| 293 |
+
cublasZtrmv(char uplo, char trans, char diag, int n, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
|
| 294 |
+
/*------------------------------------------------------------------------*/
|
| 295 |
+
/* TBMV */
|
| 296 |
+
void CUBLASWINAPI
|
| 297 |
+
cublasStbmv(char uplo, char trans, char diag, int n, int k, const float* A, int lda, float* x, int incx);
|
| 298 |
+
void CUBLASWINAPI
|
| 299 |
+
cublasDtbmv(char uplo, char trans, char diag, int n, int k, const double* A, int lda, double* x, int incx);
|
| 300 |
+
void CUBLASWINAPI
|
| 301 |
+
cublasCtbmv(char uplo, char trans, char diag, int n, int k, const cuComplex* A, int lda, cuComplex* x, int incx);
|
| 302 |
+
void CUBLASWINAPI cublasZtbmv(
|
| 303 |
+
char uplo, char trans, char diag, int n, int k, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
|
| 304 |
+
/*------------------------------------------------------------------------*/
|
| 305 |
+
/* TPMV */
|
| 306 |
+
void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, const float* AP, float* x, int incx);
|
| 307 |
+
|
| 308 |
+
void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, const double* AP, double* x, int incx);
|
| 309 |
+
|
| 310 |
+
void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, const cuComplex* AP, cuComplex* x, int incx);
|
| 311 |
+
|
| 312 |
+
void CUBLASWINAPI
|
| 313 |
+
cublasZtpmv(char uplo, char trans, char diag, int n, const cuDoubleComplex* AP, cuDoubleComplex* x, int incx);
|
| 314 |
+
/*------------------------------------------------------------------------*/
|
| 315 |
+
/* TRSV */
|
| 316 |
+
void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, const float* A, int lda, float* x, int incx);
|
| 317 |
+
|
| 318 |
+
void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, const double* A, int lda, double* x, int incx);
|
| 319 |
+
|
| 320 |
+
void CUBLASWINAPI
|
| 321 |
+
cublasCtrsv(char uplo, char trans, char diag, int n, const cuComplex* A, int lda, cuComplex* x, int incx);
|
| 322 |
+
|
| 323 |
+
void CUBLASWINAPI
|
| 324 |
+
cublasZtrsv(char uplo, char trans, char diag, int n, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
|
| 325 |
+
/*------------------------------------------------------------------------*/
|
| 326 |
+
/* TPSV */
|
| 327 |
+
void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, const float* AP, float* x, int incx);
|
| 328 |
+
|
| 329 |
+
void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, const double* AP, double* x, int incx);
|
| 330 |
+
|
| 331 |
+
void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, const cuComplex* AP, cuComplex* x, int incx);
|
| 332 |
+
|
| 333 |
+
void CUBLASWINAPI
|
| 334 |
+
cublasZtpsv(char uplo, char trans, char diag, int n, const cuDoubleComplex* AP, cuDoubleComplex* x, int incx);
|
| 335 |
+
/*------------------------------------------------------------------------*/
|
| 336 |
+
/* TBSV */
|
| 337 |
+
void CUBLASWINAPI
|
| 338 |
+
cublasStbsv(char uplo, char trans, char diag, int n, int k, const float* A, int lda, float* x, int incx);
|
| 339 |
+
|
| 340 |
+
void CUBLASWINAPI
|
| 341 |
+
cublasDtbsv(char uplo, char trans, char diag, int n, int k, const double* A, int lda, double* x, int incx);
|
| 342 |
+
void CUBLASWINAPI
|
| 343 |
+
cublasCtbsv(char uplo, char trans, char diag, int n, int k, const cuComplex* A, int lda, cuComplex* x, int incx);
|
| 344 |
+
|
| 345 |
+
void CUBLASWINAPI cublasZtbsv(
|
| 346 |
+
char uplo, char trans, char diag, int n, int k, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx);
|
| 347 |
+
/*------------------------------------------------------------------------*/
|
| 348 |
+
/* SYMV/HEMV */
|
| 349 |
+
void CUBLASWINAPI cublasSsymv(
|
| 350 |
+
char uplo, int n, float alpha, const float* A, int lda, const float* x, int incx, float beta, float* y, int incy);
|
| 351 |
+
void CUBLASWINAPI cublasDsymv(char uplo,
|
| 352 |
+
int n,
|
| 353 |
+
double alpha,
|
| 354 |
+
const double* A,
|
| 355 |
+
int lda,
|
| 356 |
+
const double* x,
|
| 357 |
+
int incx,
|
| 358 |
+
double beta,
|
| 359 |
+
double* y,
|
| 360 |
+
int incy);
|
| 361 |
+
void CUBLASWINAPI cublasChemv(char uplo,
|
| 362 |
+
int n,
|
| 363 |
+
cuComplex alpha,
|
| 364 |
+
const cuComplex* A,
|
| 365 |
+
int lda,
|
| 366 |
+
const cuComplex* x,
|
| 367 |
+
int incx,
|
| 368 |
+
cuComplex beta,
|
| 369 |
+
cuComplex* y,
|
| 370 |
+
int incy);
|
| 371 |
+
void CUBLASWINAPI cublasZhemv(char uplo,
|
| 372 |
+
int n,
|
| 373 |
+
cuDoubleComplex alpha,
|
| 374 |
+
const cuDoubleComplex* A,
|
| 375 |
+
int lda,
|
| 376 |
+
const cuDoubleComplex* x,
|
| 377 |
+
int incx,
|
| 378 |
+
cuDoubleComplex beta,
|
| 379 |
+
cuDoubleComplex* y,
|
| 380 |
+
int incy);
|
| 381 |
+
/*------------------------------------------------------------------------*/
|
| 382 |
+
/* SBMV/HBMV */
|
| 383 |
+
void CUBLASWINAPI cublasSsbmv(char uplo,
|
| 384 |
+
int n,
|
| 385 |
+
int k,
|
| 386 |
+
float alpha,
|
| 387 |
+
const float* A,
|
| 388 |
+
int lda,
|
| 389 |
+
const float* x,
|
| 390 |
+
int incx,
|
| 391 |
+
float beta,
|
| 392 |
+
float* y,
|
| 393 |
+
int incy);
|
| 394 |
+
void CUBLASWINAPI cublasDsbmv(char uplo,
|
| 395 |
+
int n,
|
| 396 |
+
int k,
|
| 397 |
+
double alpha,
|
| 398 |
+
const double* A,
|
| 399 |
+
int lda,
|
| 400 |
+
const double* x,
|
| 401 |
+
int incx,
|
| 402 |
+
double beta,
|
| 403 |
+
double* y,
|
| 404 |
+
int incy);
|
| 405 |
+
void CUBLASWINAPI cublasChbmv(char uplo,
|
| 406 |
+
int n,
|
| 407 |
+
int k,
|
| 408 |
+
cuComplex alpha,
|
| 409 |
+
const cuComplex* A,
|
| 410 |
+
int lda,
|
| 411 |
+
const cuComplex* x,
|
| 412 |
+
int incx,
|
| 413 |
+
cuComplex beta,
|
| 414 |
+
cuComplex* y,
|
| 415 |
+
int incy);
|
| 416 |
+
void CUBLASWINAPI cublasZhbmv(char uplo,
|
| 417 |
+
int n,
|
| 418 |
+
int k,
|
| 419 |
+
cuDoubleComplex alpha,
|
| 420 |
+
const cuDoubleComplex* A,
|
| 421 |
+
int lda,
|
| 422 |
+
const cuDoubleComplex* x,
|
| 423 |
+
int incx,
|
| 424 |
+
cuDoubleComplex beta,
|
| 425 |
+
cuDoubleComplex* y,
|
| 426 |
+
int incy);
|
| 427 |
+
/*------------------------------------------------------------------------*/
|
| 428 |
+
/* SPMV/HPMV */
|
| 429 |
+
void CUBLASWINAPI
|
| 430 |
+
cublasSspmv(char uplo, int n, float alpha, const float* AP, const float* x, int incx, float beta, float* y, int incy);
|
| 431 |
+
void CUBLASWINAPI cublasDspmv(
|
| 432 |
+
char uplo, int n, double alpha, const double* AP, const double* x, int incx, double beta, double* y, int incy);
|
| 433 |
+
void CUBLASWINAPI cublasChpmv(char uplo,
|
| 434 |
+
int n,
|
| 435 |
+
cuComplex alpha,
|
| 436 |
+
const cuComplex* AP,
|
| 437 |
+
const cuComplex* x,
|
| 438 |
+
int incx,
|
| 439 |
+
cuComplex beta,
|
| 440 |
+
cuComplex* y,
|
| 441 |
+
int incy);
|
| 442 |
+
void CUBLASWINAPI cublasZhpmv(char uplo,
|
| 443 |
+
int n,
|
| 444 |
+
cuDoubleComplex alpha,
|
| 445 |
+
const cuDoubleComplex* AP,
|
| 446 |
+
const cuDoubleComplex* x,
|
| 447 |
+
int incx,
|
| 448 |
+
cuDoubleComplex beta,
|
| 449 |
+
cuDoubleComplex* y,
|
| 450 |
+
int incy);
|
| 451 |
+
|
| 452 |
+
/*------------------------------------------------------------------------*/
|
| 453 |
+
/* GER */
|
| 454 |
+
void CUBLASWINAPI
|
| 455 |
+
cublasSger(int m, int n, float alpha, const float* x, int incx, const float* y, int incy, float* A, int lda);
|
| 456 |
+
void CUBLASWINAPI
|
| 457 |
+
cublasDger(int m, int n, double alpha, const double* x, int incx, const double* y, int incy, double* A, int lda);
|
| 458 |
+
|
| 459 |
+
void CUBLASWINAPI cublasCgeru(
|
| 460 |
+
int m, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* A, int lda);
|
| 461 |
+
void CUBLASWINAPI cublasCgerc(
|
| 462 |
+
int m, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* A, int lda);
|
| 463 |
+
void CUBLASWINAPI cublasZgeru(int m,
|
| 464 |
+
int n,
|
| 465 |
+
cuDoubleComplex alpha,
|
| 466 |
+
const cuDoubleComplex* x,
|
| 467 |
+
int incx,
|
| 468 |
+
const cuDoubleComplex* y,
|
| 469 |
+
int incy,
|
| 470 |
+
cuDoubleComplex* A,
|
| 471 |
+
int lda);
|
| 472 |
+
void CUBLASWINAPI cublasZgerc(int m,
|
| 473 |
+
int n,
|
| 474 |
+
cuDoubleComplex alpha,
|
| 475 |
+
const cuDoubleComplex* x,
|
| 476 |
+
int incx,
|
| 477 |
+
const cuDoubleComplex* y,
|
| 478 |
+
int incy,
|
| 479 |
+
cuDoubleComplex* A,
|
| 480 |
+
int lda);
|
| 481 |
+
/*------------------------------------------------------------------------*/
|
| 482 |
+
/* SYR/HER */
|
| 483 |
+
void CUBLASWINAPI cublasSsyr(char uplo, int n, float alpha, const float* x, int incx, float* A, int lda);
|
| 484 |
+
void CUBLASWINAPI cublasDsyr(char uplo, int n, double alpha, const double* x, int incx, double* A, int lda);
|
| 485 |
+
|
| 486 |
+
void CUBLASWINAPI cublasCher(char uplo, int n, float alpha, const cuComplex* x, int incx, cuComplex* A, int lda);
|
| 487 |
+
void CUBLASWINAPI
|
| 488 |
+
cublasZher(char uplo, int n, double alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* A, int lda);
|
| 489 |
+
|
| 490 |
+
/*------------------------------------------------------------------------*/
|
| 491 |
+
/* SPR/HPR */
|
| 492 |
+
void CUBLASWINAPI cublasSspr(char uplo, int n, float alpha, const float* x, int incx, float* AP);
|
| 493 |
+
void CUBLASWINAPI cublasDspr(char uplo, int n, double alpha, const double* x, int incx, double* AP);
|
| 494 |
+
void CUBLASWINAPI cublasChpr(char uplo, int n, float alpha, const cuComplex* x, int incx, cuComplex* AP);
|
| 495 |
+
void CUBLASWINAPI cublasZhpr(char uplo, int n, double alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* AP);
|
| 496 |
+
/*------------------------------------------------------------------------*/
|
| 497 |
+
/* SYR2/HER2 */
|
| 498 |
+
void CUBLASWINAPI
|
| 499 |
+
cublasSsyr2(char uplo, int n, float alpha, const float* x, int incx, const float* y, int incy, float* A, int lda);
|
| 500 |
+
void CUBLASWINAPI
|
| 501 |
+
cublasDsyr2(char uplo, int n, double alpha, const double* x, int incx, const double* y, int incy, double* A, int lda);
|
| 502 |
+
void CUBLASWINAPI cublasCher2(char uplo,
|
| 503 |
+
int n,
|
| 504 |
+
cuComplex alpha,
|
| 505 |
+
const cuComplex* x,
|
| 506 |
+
int incx,
|
| 507 |
+
const cuComplex* y,
|
| 508 |
+
int incy,
|
| 509 |
+
cuComplex* A,
|
| 510 |
+
int lda);
|
| 511 |
+
void CUBLASWINAPI cublasZher2(char uplo,
|
| 512 |
+
int n,
|
| 513 |
+
cuDoubleComplex alpha,
|
| 514 |
+
const cuDoubleComplex* x,
|
| 515 |
+
int incx,
|
| 516 |
+
const cuDoubleComplex* y,
|
| 517 |
+
int incy,
|
| 518 |
+
cuDoubleComplex* A,
|
| 519 |
+
int lda);
|
| 520 |
+
|
| 521 |
+
/*------------------------------------------------------------------------*/
|
| 522 |
+
/* SPR2/HPR2 */
|
| 523 |
+
void CUBLASWINAPI
|
| 524 |
+
cublasSspr2(char uplo, int n, float alpha, const float* x, int incx, const float* y, int incy, float* AP);
|
| 525 |
+
void CUBLASWINAPI
|
| 526 |
+
cublasDspr2(char uplo, int n, double alpha, const double* x, int incx, const double* y, int incy, double* AP);
|
| 527 |
+
void CUBLASWINAPI cublasChpr2(
|
| 528 |
+
char uplo, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* AP);
|
| 529 |
+
void CUBLASWINAPI cublasZhpr2(char uplo,
|
| 530 |
+
int n,
|
| 531 |
+
cuDoubleComplex alpha,
|
| 532 |
+
const cuDoubleComplex* x,
|
| 533 |
+
int incx,
|
| 534 |
+
const cuDoubleComplex* y,
|
| 535 |
+
int incy,
|
| 536 |
+
cuDoubleComplex* AP);
|
| 537 |
+
/* ------------------------BLAS3 Functions ------------------------------- */
|
| 538 |
+
/* GEMM */
|
| 539 |
+
void CUBLASWINAPI cublasSgemm(char transa,
|
| 540 |
+
char transb,
|
| 541 |
+
int m,
|
| 542 |
+
int n,
|
| 543 |
+
int k,
|
| 544 |
+
float alpha,
|
| 545 |
+
const float* A,
|
| 546 |
+
int lda,
|
| 547 |
+
const float* B,
|
| 548 |
+
int ldb,
|
| 549 |
+
float beta,
|
| 550 |
+
float* C,
|
| 551 |
+
int ldc);
|
| 552 |
+
void CUBLASWINAPI cublasDgemm(char transa,
|
| 553 |
+
char transb,
|
| 554 |
+
int m,
|
| 555 |
+
int n,
|
| 556 |
+
int k,
|
| 557 |
+
double alpha,
|
| 558 |
+
const double* A,
|
| 559 |
+
int lda,
|
| 560 |
+
const double* B,
|
| 561 |
+
int ldb,
|
| 562 |
+
double beta,
|
| 563 |
+
double* C,
|
| 564 |
+
int ldc);
|
| 565 |
+
void CUBLASWINAPI cublasCgemm(char transa,
|
| 566 |
+
char transb,
|
| 567 |
+
int m,
|
| 568 |
+
int n,
|
| 569 |
+
int k,
|
| 570 |
+
cuComplex alpha,
|
| 571 |
+
const cuComplex* A,
|
| 572 |
+
int lda,
|
| 573 |
+
const cuComplex* B,
|
| 574 |
+
int ldb,
|
| 575 |
+
cuComplex beta,
|
| 576 |
+
cuComplex* C,
|
| 577 |
+
int ldc);
|
| 578 |
+
void CUBLASWINAPI cublasZgemm(char transa,
|
| 579 |
+
char transb,
|
| 580 |
+
int m,
|
| 581 |
+
int n,
|
| 582 |
+
int k,
|
| 583 |
+
cuDoubleComplex alpha,
|
| 584 |
+
const cuDoubleComplex* A,
|
| 585 |
+
int lda,
|
| 586 |
+
const cuDoubleComplex* B,
|
| 587 |
+
int ldb,
|
| 588 |
+
cuDoubleComplex beta,
|
| 589 |
+
cuDoubleComplex* C,
|
| 590 |
+
int ldc);
|
| 591 |
+
/* -------------------------------------------------------*/
|
| 592 |
+
/* SYRK */
|
| 593 |
+
void CUBLASWINAPI
|
| 594 |
+
cublasSsyrk(char uplo, char trans, int n, int k, float alpha, const float* A, int lda, float beta, float* C, int ldc);
|
| 595 |
+
void CUBLASWINAPI cublasDsyrk(
|
| 596 |
+
char uplo, char trans, int n, int k, double alpha, const double* A, int lda, double beta, double* C, int ldc);
|
| 597 |
+
|
| 598 |
+
void CUBLASWINAPI cublasCsyrk(char uplo,
|
| 599 |
+
char trans,
|
| 600 |
+
int n,
|
| 601 |
+
int k,
|
| 602 |
+
cuComplex alpha,
|
| 603 |
+
const cuComplex* A,
|
| 604 |
+
int lda,
|
| 605 |
+
cuComplex beta,
|
| 606 |
+
cuComplex* C,
|
| 607 |
+
int ldc);
|
| 608 |
+
void CUBLASWINAPI cublasZsyrk(char uplo,
|
| 609 |
+
char trans,
|
| 610 |
+
int n,
|
| 611 |
+
int k,
|
| 612 |
+
cuDoubleComplex alpha,
|
| 613 |
+
const cuDoubleComplex* A,
|
| 614 |
+
int lda,
|
| 615 |
+
cuDoubleComplex beta,
|
| 616 |
+
cuDoubleComplex* C,
|
| 617 |
+
int ldc);
|
| 618 |
+
/* ------------------------------------------------------- */
|
| 619 |
+
/* HERK */
|
| 620 |
+
void CUBLASWINAPI cublasCherk(
|
| 621 |
+
char uplo, char trans, int n, int k, float alpha, const cuComplex* A, int lda, float beta, cuComplex* C, int ldc);
|
| 622 |
+
void CUBLASWINAPI cublasZherk(char uplo,
|
| 623 |
+
char trans,
|
| 624 |
+
int n,
|
| 625 |
+
int k,
|
| 626 |
+
double alpha,
|
| 627 |
+
const cuDoubleComplex* A,
|
| 628 |
+
int lda,
|
| 629 |
+
double beta,
|
| 630 |
+
cuDoubleComplex* C,
|
| 631 |
+
int ldc);
|
| 632 |
+
/* ------------------------------------------------------- */
|
| 633 |
+
/* SYR2K */
|
| 634 |
+
void CUBLASWINAPI cublasSsyr2k(char uplo,
|
| 635 |
+
char trans,
|
| 636 |
+
int n,
|
| 637 |
+
int k,
|
| 638 |
+
float alpha,
|
| 639 |
+
const float* A,
|
| 640 |
+
int lda,
|
| 641 |
+
const float* B,
|
| 642 |
+
int ldb,
|
| 643 |
+
float beta,
|
| 644 |
+
float* C,
|
| 645 |
+
int ldc);
|
| 646 |
+
|
| 647 |
+
void CUBLASWINAPI cublasDsyr2k(char uplo,
|
| 648 |
+
char trans,
|
| 649 |
+
int n,
|
| 650 |
+
int k,
|
| 651 |
+
double alpha,
|
| 652 |
+
const double* A,
|
| 653 |
+
int lda,
|
| 654 |
+
const double* B,
|
| 655 |
+
int ldb,
|
| 656 |
+
double beta,
|
| 657 |
+
double* C,
|
| 658 |
+
int ldc);
|
| 659 |
+
void CUBLASWINAPI cublasCsyr2k(char uplo,
|
| 660 |
+
char trans,
|
| 661 |
+
int n,
|
| 662 |
+
int k,
|
| 663 |
+
cuComplex alpha,
|
| 664 |
+
const cuComplex* A,
|
| 665 |
+
int lda,
|
| 666 |
+
const cuComplex* B,
|
| 667 |
+
int ldb,
|
| 668 |
+
cuComplex beta,
|
| 669 |
+
cuComplex* C,
|
| 670 |
+
int ldc);
|
| 671 |
+
|
| 672 |
+
void CUBLASWINAPI cublasZsyr2k(char uplo,
|
| 673 |
+
char trans,
|
| 674 |
+
int n,
|
| 675 |
+
int k,
|
| 676 |
+
cuDoubleComplex alpha,
|
| 677 |
+
const cuDoubleComplex* A,
|
| 678 |
+
int lda,
|
| 679 |
+
const cuDoubleComplex* B,
|
| 680 |
+
int ldb,
|
| 681 |
+
cuDoubleComplex beta,
|
| 682 |
+
cuDoubleComplex* C,
|
| 683 |
+
int ldc);
|
| 684 |
+
/* ------------------------------------------------------- */
|
| 685 |
+
/* HER2K */
|
| 686 |
+
void CUBLASWINAPI cublasCher2k(char uplo,
|
| 687 |
+
char trans,
|
| 688 |
+
int n,
|
| 689 |
+
int k,
|
| 690 |
+
cuComplex alpha,
|
| 691 |
+
const cuComplex* A,
|
| 692 |
+
int lda,
|
| 693 |
+
const cuComplex* B,
|
| 694 |
+
int ldb,
|
| 695 |
+
float beta,
|
| 696 |
+
cuComplex* C,
|
| 697 |
+
int ldc);
|
| 698 |
+
|
| 699 |
+
void CUBLASWINAPI cublasZher2k(char uplo,
|
| 700 |
+
char trans,
|
| 701 |
+
int n,
|
| 702 |
+
int k,
|
| 703 |
+
cuDoubleComplex alpha,
|
| 704 |
+
const cuDoubleComplex* A,
|
| 705 |
+
int lda,
|
| 706 |
+
const cuDoubleComplex* B,
|
| 707 |
+
int ldb,
|
| 708 |
+
double beta,
|
| 709 |
+
cuDoubleComplex* C,
|
| 710 |
+
int ldc);
|
| 711 |
+
|
| 712 |
+
/*------------------------------------------------------------------------*/
|
| 713 |
+
/* SYMM*/
|
| 714 |
+
void CUBLASWINAPI cublasSsymm(char side,
|
| 715 |
+
char uplo,
|
| 716 |
+
int m,
|
| 717 |
+
int n,
|
| 718 |
+
float alpha,
|
| 719 |
+
const float* A,
|
| 720 |
+
int lda,
|
| 721 |
+
const float* B,
|
| 722 |
+
int ldb,
|
| 723 |
+
float beta,
|
| 724 |
+
float* C,
|
| 725 |
+
int ldc);
|
| 726 |
+
void CUBLASWINAPI cublasDsymm(char side,
|
| 727 |
+
char uplo,
|
| 728 |
+
int m,
|
| 729 |
+
int n,
|
| 730 |
+
double alpha,
|
| 731 |
+
const double* A,
|
| 732 |
+
int lda,
|
| 733 |
+
const double* B,
|
| 734 |
+
int ldb,
|
| 735 |
+
double beta,
|
| 736 |
+
double* C,
|
| 737 |
+
int ldc);
|
| 738 |
+
|
| 739 |
+
void CUBLASWINAPI cublasCsymm(char side,
|
| 740 |
+
char uplo,
|
| 741 |
+
int m,
|
| 742 |
+
int n,
|
| 743 |
+
cuComplex alpha,
|
| 744 |
+
const cuComplex* A,
|
| 745 |
+
int lda,
|
| 746 |
+
const cuComplex* B,
|
| 747 |
+
int ldb,
|
| 748 |
+
cuComplex beta,
|
| 749 |
+
cuComplex* C,
|
| 750 |
+
int ldc);
|
| 751 |
+
|
| 752 |
+
void CUBLASWINAPI cublasZsymm(char side,
|
| 753 |
+
char uplo,
|
| 754 |
+
int m,
|
| 755 |
+
int n,
|
| 756 |
+
cuDoubleComplex alpha,
|
| 757 |
+
const cuDoubleComplex* A,
|
| 758 |
+
int lda,
|
| 759 |
+
const cuDoubleComplex* B,
|
| 760 |
+
int ldb,
|
| 761 |
+
cuDoubleComplex beta,
|
| 762 |
+
cuDoubleComplex* C,
|
| 763 |
+
int ldc);
|
| 764 |
+
/*------------------------------------------------------------------------*/
|
| 765 |
+
/* HEMM*/
|
| 766 |
+
void CUBLASWINAPI cublasChemm(char side,
|
| 767 |
+
char uplo,
|
| 768 |
+
int m,
|
| 769 |
+
int n,
|
| 770 |
+
cuComplex alpha,
|
| 771 |
+
const cuComplex* A,
|
| 772 |
+
int lda,
|
| 773 |
+
const cuComplex* B,
|
| 774 |
+
int ldb,
|
| 775 |
+
cuComplex beta,
|
| 776 |
+
cuComplex* C,
|
| 777 |
+
int ldc);
|
| 778 |
+
void CUBLASWINAPI cublasZhemm(char side,
|
| 779 |
+
char uplo,
|
| 780 |
+
int m,
|
| 781 |
+
int n,
|
| 782 |
+
cuDoubleComplex alpha,
|
| 783 |
+
const cuDoubleComplex* A,
|
| 784 |
+
int lda,
|
| 785 |
+
const cuDoubleComplex* B,
|
| 786 |
+
int ldb,
|
| 787 |
+
cuDoubleComplex beta,
|
| 788 |
+
cuDoubleComplex* C,
|
| 789 |
+
int ldc);
|
| 790 |
+
|
| 791 |
+
/*------------------------------------------------------------------------*/
|
| 792 |
+
/* TRSM*/
|
| 793 |
+
void CUBLASWINAPI cublasStrsm(char side,
|
| 794 |
+
char uplo,
|
| 795 |
+
char transa,
|
| 796 |
+
char diag,
|
| 797 |
+
int m,
|
| 798 |
+
int n,
|
| 799 |
+
float alpha,
|
| 800 |
+
const float* A,
|
| 801 |
+
int lda,
|
| 802 |
+
float* B,
|
| 803 |
+
int ldb);
|
| 804 |
+
|
| 805 |
+
void CUBLASWINAPI cublasDtrsm(char side,
|
| 806 |
+
char uplo,
|
| 807 |
+
char transa,
|
| 808 |
+
char diag,
|
| 809 |
+
int m,
|
| 810 |
+
int n,
|
| 811 |
+
double alpha,
|
| 812 |
+
const double* A,
|
| 813 |
+
int lda,
|
| 814 |
+
double* B,
|
| 815 |
+
int ldb);
|
| 816 |
+
|
| 817 |
+
void CUBLASWINAPI cublasCtrsm(char side,
|
| 818 |
+
char uplo,
|
| 819 |
+
char transa,
|
| 820 |
+
char diag,
|
| 821 |
+
int m,
|
| 822 |
+
int n,
|
| 823 |
+
cuComplex alpha,
|
| 824 |
+
const cuComplex* A,
|
| 825 |
+
int lda,
|
| 826 |
+
cuComplex* B,
|
| 827 |
+
int ldb);
|
| 828 |
+
|
| 829 |
+
void CUBLASWINAPI cublasZtrsm(char side,
|
| 830 |
+
char uplo,
|
| 831 |
+
char transa,
|
| 832 |
+
char diag,
|
| 833 |
+
int m,
|
| 834 |
+
int n,
|
| 835 |
+
cuDoubleComplex alpha,
|
| 836 |
+
const cuDoubleComplex* A,
|
| 837 |
+
int lda,
|
| 838 |
+
cuDoubleComplex* B,
|
| 839 |
+
int ldb);
|
| 840 |
+
/*------------------------------------------------------------------------*/
|
| 841 |
+
/* TRMM*/
|
| 842 |
+
void CUBLASWINAPI cublasStrmm(char side,
|
| 843 |
+
char uplo,
|
| 844 |
+
char transa,
|
| 845 |
+
char diag,
|
| 846 |
+
int m,
|
| 847 |
+
int n,
|
| 848 |
+
float alpha,
|
| 849 |
+
const float* A,
|
| 850 |
+
int lda,
|
| 851 |
+
float* B,
|
| 852 |
+
int ldb);
|
| 853 |
+
void CUBLASWINAPI cublasDtrmm(char side,
|
| 854 |
+
char uplo,
|
| 855 |
+
char transa,
|
| 856 |
+
char diag,
|
| 857 |
+
int m,
|
| 858 |
+
int n,
|
| 859 |
+
double alpha,
|
| 860 |
+
const double* A,
|
| 861 |
+
int lda,
|
| 862 |
+
double* B,
|
| 863 |
+
int ldb);
|
| 864 |
+
void CUBLASWINAPI cublasCtrmm(char side,
|
| 865 |
+
char uplo,
|
| 866 |
+
char transa,
|
| 867 |
+
char diag,
|
| 868 |
+
int m,
|
| 869 |
+
int n,
|
| 870 |
+
cuComplex alpha,
|
| 871 |
+
const cuComplex* A,
|
| 872 |
+
int lda,
|
| 873 |
+
cuComplex* B,
|
| 874 |
+
int ldb);
|
| 875 |
+
void CUBLASWINAPI cublasZtrmm(char side,
|
| 876 |
+
char uplo,
|
| 877 |
+
char transa,
|
| 878 |
+
char diag,
|
| 879 |
+
int m,
|
| 880 |
+
int n,
|
| 881 |
+
cuDoubleComplex alpha,
|
| 882 |
+
const cuDoubleComplex* A,
|
| 883 |
+
int lda,
|
| 884 |
+
cuDoubleComplex* B,
|
| 885 |
+
int ldb);
|
| 886 |
+
|
| 887 |
+
#if defined(__cplusplus)
|
| 888 |
+
}
|
| 889 |
+
#endif /* __cplusplus */
|
| 890 |
+
|
| 891 |
+
#endif /* !defined(CUBLAS_H_) */
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h
ADDED
|
@@ -0,0 +1,1845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
#pragma once
|
| 50 |
+
|
| 51 |
+
#ifndef CUBLASAPI
|
| 52 |
+
#ifdef __CUDACC__
|
| 53 |
+
#define CUBLASAPI __host__ __device__
|
| 54 |
+
#else
|
| 55 |
+
#define CUBLASAPI
|
| 56 |
+
#endif
|
| 57 |
+
#endif
|
| 58 |
+
|
| 59 |
+
#include <cublas_api.h>
|
| 60 |
+
|
| 61 |
+
#include <stdint.h>
|
| 62 |
+
#include <stddef.h>
|
| 63 |
+
#include <stdio.h>
|
| 64 |
+
|
| 65 |
+
#if defined(__cplusplus)
|
| 66 |
+
extern "C" {
|
| 67 |
+
#endif /* __cplusplus */
|
| 68 |
+
|
| 69 |
+
/** Opaque structure holding CUBLASLT context
|
| 70 |
+
*/
|
| 71 |
+
typedef struct cublasLtContext* cublasLtHandle_t;
|
| 72 |
+
|
| 73 |
+
cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t* lightHandle);
|
| 74 |
+
|
| 75 |
+
cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle);
|
| 76 |
+
|
| 77 |
+
const char* CUBLASWINAPI cublasLtGetStatusName(cublasStatus_t status);
|
| 78 |
+
|
| 79 |
+
const char* CUBLASWINAPI cublasLtGetStatusString(cublasStatus_t status);
|
| 80 |
+
|
| 81 |
+
size_t CUBLASWINAPI cublasLtGetVersion(void);
|
| 82 |
+
|
| 83 |
+
size_t CUBLASWINAPI cublasLtGetCudartVersion(void);
|
| 84 |
+
|
| 85 |
+
cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type, int* value);
|
| 86 |
+
|
| 87 |
+
cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheGetCapacity(size_t* capacity);
|
| 88 |
+
cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheSetCapacity(size_t capacity);
|
| 89 |
+
|
| 90 |
+
/** Restricts usage of CPU instructions (ISA) specified by the flags in the mask.
|
| 91 |
+
*
|
| 92 |
+
* Flags can be combined with bitwise OR(|) operator. Supported flags:
|
| 93 |
+
* - 0x1 -- x86-64 AVX512 ISA
|
| 94 |
+
*
|
| 95 |
+
* Default mask: 0 (any applicable ISA is allowed).
|
| 96 |
+
*
|
| 97 |
+
* The function returns the previous value of the mask.
|
| 98 |
+
* The function takes precedence over the environment variable CUBLASLT_DISABLE_CPU_INSTRUCTIONS_MASK.
|
| 99 |
+
*/
|
| 100 |
+
unsigned CUBLASWINAPI cublasLtDisableCpuInstructionsSetMask(unsigned mask);
|
| 101 |
+
|
| 102 |
+
/** Semi-opaque descriptor for matrix memory layout
|
| 103 |
+
*/
|
| 104 |
+
typedef struct {
|
| 105 |
+
uint64_t data[8];
|
| 106 |
+
} cublasLtMatrixLayoutOpaque_t;
|
| 107 |
+
|
| 108 |
+
/** Opaque descriptor for matrix memory layout
|
| 109 |
+
*/
|
| 110 |
+
typedef cublasLtMatrixLayoutOpaque_t* cublasLtMatrixLayout_t;
|
| 111 |
+
|
| 112 |
+
/** Semi-opaque algorithm descriptor (to avoid complicated alloc/free schemes)
|
| 113 |
+
*
|
| 114 |
+
* This structure can be trivially serialized and later restored for use with the same version of cuBLAS library to save
|
| 115 |
+
* on selecting the right configuration again.
|
| 116 |
+
*/
|
| 117 |
+
typedef struct {
|
| 118 |
+
uint64_t data[8];
|
| 119 |
+
} cublasLtMatmulAlgo_t;
|
| 120 |
+
|
| 121 |
+
/** Semi-opaque descriptor for cublasLtMatmul() operation details
|
| 122 |
+
*/
|
| 123 |
+
typedef struct {
|
| 124 |
+
uint64_t data[32];
|
| 125 |
+
} cublasLtMatmulDescOpaque_t;
|
| 126 |
+
|
| 127 |
+
/** Opaque descriptor for cublasLtMatmul() operation details
|
| 128 |
+
*/
|
| 129 |
+
typedef cublasLtMatmulDescOpaque_t* cublasLtMatmulDesc_t;
|
| 130 |
+
|
| 131 |
+
/** Semi-opaque descriptor for cublasLtMatrixTransform() operation details
|
| 132 |
+
*/
|
| 133 |
+
typedef struct {
|
| 134 |
+
uint64_t data[8];
|
| 135 |
+
} cublasLtMatrixTransformDescOpaque_t;
|
| 136 |
+
|
| 137 |
+
/** Opaque descriptor for cublasLtMatrixTransform() operation details
|
| 138 |
+
*/
|
| 139 |
+
typedef cublasLtMatrixTransformDescOpaque_t* cublasLtMatrixTransformDesc_t;
|
| 140 |
+
|
| 141 |
+
/** Semi-opaque descriptor for cublasLtMatmulPreference() operation details
|
| 142 |
+
*/
|
| 143 |
+
typedef struct {
|
| 144 |
+
uint64_t data[8];
|
| 145 |
+
} cublasLtMatmulPreferenceOpaque_t;
|
| 146 |
+
|
| 147 |
+
/** Opaque descriptor for cublasLtMatmulAlgoGetHeuristic() configuration
|
| 148 |
+
*/
|
| 149 |
+
typedef cublasLtMatmulPreferenceOpaque_t* cublasLtMatmulPreference_t;
|
| 150 |
+
|
| 151 |
+
/** Tile size (in C/D matrix Rows x Cols)
|
| 152 |
+
*
|
| 153 |
+
* General order of tile IDs is sorted by size first and by first dimension second.
|
| 154 |
+
*/
|
| 155 |
+
typedef enum {
|
| 156 |
+
CUBLASLT_MATMUL_TILE_UNDEFINED = 0,
|
| 157 |
+
CUBLASLT_MATMUL_TILE_8x8 = 1,
|
| 158 |
+
CUBLASLT_MATMUL_TILE_8x16 = 2,
|
| 159 |
+
CUBLASLT_MATMUL_TILE_16x8 = 3,
|
| 160 |
+
CUBLASLT_MATMUL_TILE_8x32 = 4,
|
| 161 |
+
CUBLASLT_MATMUL_TILE_16x16 = 5,
|
| 162 |
+
CUBLASLT_MATMUL_TILE_32x8 = 6,
|
| 163 |
+
CUBLASLT_MATMUL_TILE_8x64 = 7,
|
| 164 |
+
CUBLASLT_MATMUL_TILE_16x32 = 8,
|
| 165 |
+
CUBLASLT_MATMUL_TILE_32x16 = 9,
|
| 166 |
+
CUBLASLT_MATMUL_TILE_64x8 = 10,
|
| 167 |
+
CUBLASLT_MATMUL_TILE_32x32 = 11,
|
| 168 |
+
CUBLASLT_MATMUL_TILE_32x64 = 12,
|
| 169 |
+
CUBLASLT_MATMUL_TILE_64x32 = 13,
|
| 170 |
+
CUBLASLT_MATMUL_TILE_32x128 = 14,
|
| 171 |
+
CUBLASLT_MATMUL_TILE_64x64 = 15,
|
| 172 |
+
CUBLASLT_MATMUL_TILE_128x32 = 16,
|
| 173 |
+
CUBLASLT_MATMUL_TILE_64x128 = 17,
|
| 174 |
+
CUBLASLT_MATMUL_TILE_128x64 = 18,
|
| 175 |
+
CUBLASLT_MATMUL_TILE_64x256 = 19,
|
| 176 |
+
CUBLASLT_MATMUL_TILE_128x128 = 20,
|
| 177 |
+
CUBLASLT_MATMUL_TILE_256x64 = 21,
|
| 178 |
+
CUBLASLT_MATMUL_TILE_64x512 = 22,
|
| 179 |
+
CUBLASLT_MATMUL_TILE_128x256 = 23,
|
| 180 |
+
CUBLASLT_MATMUL_TILE_256x128 = 24,
|
| 181 |
+
CUBLASLT_MATMUL_TILE_512x64 = 25,
|
| 182 |
+
CUBLASLT_MATMUL_TILE_64x96 = 26,
|
| 183 |
+
CUBLASLT_MATMUL_TILE_96x64 = 27,
|
| 184 |
+
CUBLASLT_MATMUL_TILE_96x128 = 28,
|
| 185 |
+
CUBLASLT_MATMUL_TILE_128x160 = 29,
|
| 186 |
+
CUBLASLT_MATMUL_TILE_160x128 = 30,
|
| 187 |
+
CUBLASLT_MATMUL_TILE_192x128 = 31,
|
| 188 |
+
CUBLASLT_MATMUL_TILE_128x192 = 32,
|
| 189 |
+
CUBLASLT_MATMUL_TILE_128x96 = 33,
|
| 190 |
+
CUBLASLT_MATMUL_TILE_32x256 = 34,
|
| 191 |
+
CUBLASLT_MATMUL_TILE_256x32 = 35,
|
| 192 |
+
CUBLASLT_MATMUL_TILE_END
|
| 193 |
+
} cublasLtMatmulTile_t;
|
| 194 |
+
|
| 195 |
+
/** Size and number of stages in which elements are read into shared memory
|
| 196 |
+
*
|
| 197 |
+
* General order of stages IDs is sorted by stage size first and by number of stages second.
|
| 198 |
+
*/
|
| 199 |
+
typedef enum {
|
| 200 |
+
CUBLASLT_MATMUL_STAGES_UNDEFINED = 0,
|
| 201 |
+
CUBLASLT_MATMUL_STAGES_16x1 = 1,
|
| 202 |
+
CUBLASLT_MATMUL_STAGES_16x2 = 2,
|
| 203 |
+
CUBLASLT_MATMUL_STAGES_16x3 = 3,
|
| 204 |
+
CUBLASLT_MATMUL_STAGES_16x4 = 4,
|
| 205 |
+
CUBLASLT_MATMUL_STAGES_16x5 = 5,
|
| 206 |
+
CUBLASLT_MATMUL_STAGES_16x6 = 6,
|
| 207 |
+
CUBLASLT_MATMUL_STAGES_32x1 = 7,
|
| 208 |
+
CUBLASLT_MATMUL_STAGES_32x2 = 8,
|
| 209 |
+
CUBLASLT_MATMUL_STAGES_32x3 = 9,
|
| 210 |
+
CUBLASLT_MATMUL_STAGES_32x4 = 10,
|
| 211 |
+
CUBLASLT_MATMUL_STAGES_32x5 = 11,
|
| 212 |
+
CUBLASLT_MATMUL_STAGES_32x6 = 12,
|
| 213 |
+
CUBLASLT_MATMUL_STAGES_64x1 = 13,
|
| 214 |
+
CUBLASLT_MATMUL_STAGES_64x2 = 14,
|
| 215 |
+
CUBLASLT_MATMUL_STAGES_64x3 = 15,
|
| 216 |
+
CUBLASLT_MATMUL_STAGES_64x4 = 16,
|
| 217 |
+
CUBLASLT_MATMUL_STAGES_64x5 = 17,
|
| 218 |
+
CUBLASLT_MATMUL_STAGES_64x6 = 18,
|
| 219 |
+
CUBLASLT_MATMUL_STAGES_128x1 = 19,
|
| 220 |
+
CUBLASLT_MATMUL_STAGES_128x2 = 20,
|
| 221 |
+
CUBLASLT_MATMUL_STAGES_128x3 = 21,
|
| 222 |
+
CUBLASLT_MATMUL_STAGES_128x4 = 22,
|
| 223 |
+
CUBLASLT_MATMUL_STAGES_128x5 = 23,
|
| 224 |
+
CUBLASLT_MATMUL_STAGES_128x6 = 24,
|
| 225 |
+
CUBLASLT_MATMUL_STAGES_32x10 = 25,
|
| 226 |
+
CUBLASLT_MATMUL_STAGES_8x4 = 26,
|
| 227 |
+
CUBLASLT_MATMUL_STAGES_16x10 = 27,
|
| 228 |
+
CUBLASLT_MATMUL_STAGES_8x5 = 28,
|
| 229 |
+
CUBLASLT_MATMUL_STAGES_8x3 = 31,
|
| 230 |
+
CUBLASLT_MATMUL_STAGES_8xAUTO = 32,
|
| 231 |
+
CUBLASLT_MATMUL_STAGES_16xAUTO = 33,
|
| 232 |
+
CUBLASLT_MATMUL_STAGES_32xAUTO = 34,
|
| 233 |
+
CUBLASLT_MATMUL_STAGES_64xAUTO = 35,
|
| 234 |
+
CUBLASLT_MATMUL_STAGES_128xAUTO = 36,
|
| 235 |
+
CUBLASLT_MATMUL_STAGES_END
|
| 236 |
+
} cublasLtMatmulStages_t;
|
| 237 |
+
|
| 238 |
+
/** Thread Block Cluster size
|
| 239 |
+
*
|
| 240 |
+
* Typically dimensioned similar to cublasLtMatmulTile_t, with the third coordinate unused at this time.
|
| 241 |
+
*/
|
| 242 |
+
typedef enum {
|
| 243 |
+
/** Let library pick cluster shape automatically */
|
| 244 |
+
CUBLASLT_CLUSTER_SHAPE_AUTO = 0,
|
| 245 |
+
CUBLASLT_CLUSTER_SHAPE_1x1x1 = 2,
|
| 246 |
+
CUBLASLT_CLUSTER_SHAPE_2x1x1 = 3,
|
| 247 |
+
CUBLASLT_CLUSTER_SHAPE_4x1x1 = 4,
|
| 248 |
+
CUBLASLT_CLUSTER_SHAPE_1x2x1 = 5,
|
| 249 |
+
CUBLASLT_CLUSTER_SHAPE_2x2x1 = 6,
|
| 250 |
+
CUBLASLT_CLUSTER_SHAPE_4x2x1 = 7,
|
| 251 |
+
CUBLASLT_CLUSTER_SHAPE_1x4x1 = 8,
|
| 252 |
+
CUBLASLT_CLUSTER_SHAPE_2x4x1 = 9,
|
| 253 |
+
CUBLASLT_CLUSTER_SHAPE_4x4x1 = 10,
|
| 254 |
+
CUBLASLT_CLUSTER_SHAPE_8x1x1 = 11,
|
| 255 |
+
CUBLASLT_CLUSTER_SHAPE_1x8x1 = 12,
|
| 256 |
+
CUBLASLT_CLUSTER_SHAPE_8x2x1 = 13,
|
| 257 |
+
CUBLASLT_CLUSTER_SHAPE_2x8x1 = 14,
|
| 258 |
+
CUBLASLT_CLUSTER_SHAPE_16x1x1 = 15,
|
| 259 |
+
CUBLASLT_CLUSTER_SHAPE_1x16x1 = 16,
|
| 260 |
+
CUBLASLT_CLUSTER_SHAPE_3x1x1 = 17,
|
| 261 |
+
CUBLASLT_CLUSTER_SHAPE_5x1x1 = 18,
|
| 262 |
+
CUBLASLT_CLUSTER_SHAPE_6x1x1 = 19,
|
| 263 |
+
CUBLASLT_CLUSTER_SHAPE_7x1x1 = 20,
|
| 264 |
+
CUBLASLT_CLUSTER_SHAPE_9x1x1 = 21,
|
| 265 |
+
CUBLASLT_CLUSTER_SHAPE_10x1x1 = 22,
|
| 266 |
+
CUBLASLT_CLUSTER_SHAPE_11x1x1 = 23,
|
| 267 |
+
CUBLASLT_CLUSTER_SHAPE_12x1x1 = 24,
|
| 268 |
+
CUBLASLT_CLUSTER_SHAPE_13x1x1 = 25,
|
| 269 |
+
CUBLASLT_CLUSTER_SHAPE_14x1x1 = 26,
|
| 270 |
+
CUBLASLT_CLUSTER_SHAPE_15x1x1 = 27,
|
| 271 |
+
CUBLASLT_CLUSTER_SHAPE_3x2x1 = 28,
|
| 272 |
+
CUBLASLT_CLUSTER_SHAPE_5x2x1 = 29,
|
| 273 |
+
CUBLASLT_CLUSTER_SHAPE_6x2x1 = 30,
|
| 274 |
+
CUBLASLT_CLUSTER_SHAPE_7x2x1 = 31,
|
| 275 |
+
CUBLASLT_CLUSTER_SHAPE_1x3x1 = 32,
|
| 276 |
+
CUBLASLT_CLUSTER_SHAPE_2x3x1 = 33,
|
| 277 |
+
CUBLASLT_CLUSTER_SHAPE_3x3x1 = 34,
|
| 278 |
+
CUBLASLT_CLUSTER_SHAPE_4x3x1 = 35,
|
| 279 |
+
CUBLASLT_CLUSTER_SHAPE_5x3x1 = 36,
|
| 280 |
+
CUBLASLT_CLUSTER_SHAPE_3x4x1 = 37,
|
| 281 |
+
CUBLASLT_CLUSTER_SHAPE_1x5x1 = 38,
|
| 282 |
+
CUBLASLT_CLUSTER_SHAPE_2x5x1 = 39,
|
| 283 |
+
CUBLASLT_CLUSTER_SHAPE_3x5x1 = 40,
|
| 284 |
+
CUBLASLT_CLUSTER_SHAPE_1x6x1 = 41,
|
| 285 |
+
CUBLASLT_CLUSTER_SHAPE_2x6x1 = 42,
|
| 286 |
+
CUBLASLT_CLUSTER_SHAPE_1x7x1 = 43,
|
| 287 |
+
CUBLASLT_CLUSTER_SHAPE_2x7x1 = 44,
|
| 288 |
+
CUBLASLT_CLUSTER_SHAPE_1x9x1 = 45,
|
| 289 |
+
CUBLASLT_CLUSTER_SHAPE_1x10x1 = 46,
|
| 290 |
+
CUBLASLT_CLUSTER_SHAPE_1x11x1 = 47,
|
| 291 |
+
CUBLASLT_CLUSTER_SHAPE_1x12x1 = 48,
|
| 292 |
+
CUBLASLT_CLUSTER_SHAPE_1x13x1 = 49,
|
| 293 |
+
CUBLASLT_CLUSTER_SHAPE_1x14x1 = 50,
|
| 294 |
+
CUBLASLT_CLUSTER_SHAPE_1x15x1 = 51,
|
| 295 |
+
CUBLASLT_CLUSTER_SHAPE_END
|
| 296 |
+
} cublasLtClusterShape_t;
|
| 297 |
+
|
| 298 |
+
/** Inner size of the kernel
|
| 299 |
+
*
|
| 300 |
+
* Represents various aspects of internal kernel design, that don't impact CUDA grid size but may have other more subtle
|
| 301 |
+
* effects.
|
| 302 |
+
*
|
| 303 |
+
*/
|
| 304 |
+
typedef enum {
|
| 305 |
+
CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED = 0,
|
| 306 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA884 = 1,
|
| 307 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA1684 = 2,
|
| 308 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA1688 = 3,
|
| 309 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA16816 = 4,
|
| 310 |
+
CUBLASLT_MATMUL_INNER_SHAPE_END
|
| 311 |
+
} cublasLtMatmulInnerShape_t;
|
| 312 |
+
|
| 313 |
+
/** Pointer mode to use for alpha/beta */
|
| 314 |
+
typedef enum {
|
| 315 |
+
/** matches CUBLAS_POINTER_MODE_HOST, pointer targets a single value host memory */
|
| 316 |
+
CUBLASLT_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST,
|
| 317 |
+
/** matches CUBLAS_POINTER_MODE_DEVICE, pointer targets a single value device memory */
|
| 318 |
+
CUBLASLT_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE,
|
| 319 |
+
/** pointer targets an array in device memory */
|
| 320 |
+
CUBLASLT_POINTER_MODE_DEVICE_VECTOR = 2,
|
| 321 |
+
/** alpha pointer targets an array in device memory, beta is zero. Note:
|
| 322 |
+
CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is not supported, must be 0. */
|
| 323 |
+
CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO = 3,
|
| 324 |
+
/** alpha pointer targets an array in device memory, beta is a single value in host memory. */
|
| 325 |
+
CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST = 4,
|
| 326 |
+
} cublasLtPointerMode_t;
|
| 327 |
+
|
| 328 |
+
/** Mask to define pointer mode capability */
|
| 329 |
+
typedef enum {
|
| 330 |
+
/** see CUBLASLT_POINTER_MODE_HOST */
|
| 331 |
+
CUBLASLT_POINTER_MODE_MASK_HOST = 1,
|
| 332 |
+
/** see CUBLASLT_POINTER_MODE_DEVICE */
|
| 333 |
+
CUBLASLT_POINTER_MODE_MASK_DEVICE = 2,
|
| 334 |
+
/** see CUBLASLT_POINTER_MODE_DEVICE_VECTOR */
|
| 335 |
+
CUBLASLT_POINTER_MODE_MASK_DEVICE_VECTOR = 4,
|
| 336 |
+
/** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO */
|
| 337 |
+
CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_ZERO = 8,
|
| 338 |
+
/** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST */
|
| 339 |
+
CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_HOST = 16,
|
| 340 |
+
} cublasLtPointerModeMask_t;
|
| 341 |
+
|
| 342 |
+
/** Implementation details that may affect numerical behavior of algorithms. */
|
| 343 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA (0x01ull << 0)
|
| 344 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA (0x02ull << 0)
|
| 345 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_IMMA (0x04ull << 0)
|
| 346 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_DMMA (0x08ull << 0)
|
| 347 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK (0xfeull << 0)
|
| 348 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK (0xffull << 0)
|
| 349 |
+
|
| 350 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_16F (0x01ull << 8)
|
| 351 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32F (0x02ull << 8)
|
| 352 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_64F (0x04ull << 8)
|
| 353 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32I (0x08ull << 8)
|
| 354 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_TYPE_MASK (0xffull << 8)
|
| 355 |
+
|
| 356 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16F (0x01ull << 16)
|
| 357 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16BF (0x02ull << 16)
|
| 358 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_TF32 (0x04ull << 16)
|
| 359 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_32F (0x08ull << 16)
|
| 360 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_64F (0x10ull << 16)
|
| 361 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8I (0x20ull << 16)
|
| 362 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3 (0x40ull << 16)
|
| 363 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2 (0x80ull << 16)
|
| 364 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK (0xffull << 16)
|
| 365 |
+
|
| 366 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN (0x01ull << 32)
|
| 367 |
+
typedef uint64_t cublasLtNumericalImplFlags_t;
|
| 368 |
+
|
| 369 |
+
/** Execute matrix multiplication (D = alpha * op(A) * op(B) + beta * C).
|
| 370 |
+
*
|
| 371 |
+
* \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
|
| 372 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
|
| 373 |
+
* when workspaceSizeInBytes is less than workspace required by configured
|
| 374 |
+
* algo
|
| 375 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
|
| 376 |
+
* operation
|
| 377 |
+
* \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
|
| 378 |
+
* \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
|
| 379 |
+
* \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
|
| 380 |
+
*/
|
| 381 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmul(cublasLtHandle_t lightHandle,
|
| 382 |
+
cublasLtMatmulDesc_t computeDesc,
|
| 383 |
+
const void* alpha, /* host or device pointer */
|
| 384 |
+
const void* A,
|
| 385 |
+
cublasLtMatrixLayout_t Adesc,
|
| 386 |
+
const void* B,
|
| 387 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 388 |
+
const void* beta, /* host or device pointer */
|
| 389 |
+
const void* C,
|
| 390 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 391 |
+
void* D,
|
| 392 |
+
cublasLtMatrixLayout_t Ddesc,
|
| 393 |
+
const cublasLtMatmulAlgo_t* algo,
|
| 394 |
+
void* workspace,
|
| 395 |
+
size_t workspaceSizeInBytes,
|
| 396 |
+
cudaStream_t stream);
|
| 397 |
+
|
| 398 |
+
/** Matrix layout conversion helper (C = alpha * op(A) + beta * op(B))
|
| 399 |
+
*
|
| 400 |
+
* Can be used to change memory order of data or to scale and shift the values.
|
| 401 |
+
*
|
| 402 |
+
* \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
|
| 403 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
|
| 404 |
+
* when A is not NULL, but Adesc is NULL
|
| 405 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
|
| 406 |
+
* operation
|
| 407 |
+
* \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
|
| 408 |
+
* \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
|
| 409 |
+
* \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
|
| 410 |
+
*/
|
| 411 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(cublasLtHandle_t lightHandle,
|
| 412 |
+
cublasLtMatrixTransformDesc_t transformDesc,
|
| 413 |
+
const void* alpha, /* host or device pointer */
|
| 414 |
+
const void* A,
|
| 415 |
+
cublasLtMatrixLayout_t Adesc,
|
| 416 |
+
const void* beta, /* host or device pointer */
|
| 417 |
+
const void* B,
|
| 418 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 419 |
+
void* C,
|
| 420 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 421 |
+
cudaStream_t stream);
|
| 422 |
+
|
| 423 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 424 |
+
/* Helper functions for cublasLtMatrixLayout_t */
|
| 425 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 426 |
+
|
| 427 |
+
/** Enum for data ordering */
|
| 428 |
+
typedef enum {
|
| 429 |
+
/** Column-major
|
| 430 |
+
*
|
| 431 |
+
* Leading dimension is the stride (in elements) to the beginning of next column in memory.
|
| 432 |
+
*/
|
| 433 |
+
CUBLASLT_ORDER_COL = 0,
|
| 434 |
+
/** Row major
|
| 435 |
+
*
|
| 436 |
+
* Leading dimension is the stride (in elements) to the beginning of next row in memory.
|
| 437 |
+
*/
|
| 438 |
+
CUBLASLT_ORDER_ROW = 1,
|
| 439 |
+
/** Column-major ordered tiles of 32 columns.
|
| 440 |
+
*
|
| 441 |
+
* Leading dimension is the stride (in elements) to the beginning of next group of 32-columns. E.g. if matrix has 33
|
| 442 |
+
* columns and 2 rows, ld must be at least (32) * 2 = 64.
|
| 443 |
+
*/
|
| 444 |
+
CUBLASLT_ORDER_COL32 = 2,
|
| 445 |
+
/** Column-major ordered tiles of composite tiles with total 32 columns and 8 rows, tile composed of interleaved
|
| 446 |
+
* inner tiles of 4 columns within 4 even or odd rows in an alternating pattern.
|
| 447 |
+
*
|
| 448 |
+
* Leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile for the next
|
| 449 |
+
* 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32 * 8) * 1 = 256.
|
| 450 |
+
*/
|
| 451 |
+
CUBLASLT_ORDER_COL4_4R2_8C = 3,
|
| 452 |
+
/** Column-major ordered tiles of composite tiles with total 32 columns ands 32 rows.
|
| 453 |
+
* Element offset within the tile is calculated as (((row%8)/2*4+row/8)*2+row%2)*32+col.
|
| 454 |
+
*
|
| 455 |
+
* Leading dimension is the stride (in elements) to the beginning of the first 32 column x 32 row tile for the next
|
| 456 |
+
* 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32*32)*1 = 1024.
|
| 457 |
+
*/
|
| 458 |
+
CUBLASLT_ORDER_COL32_2R_4R4 = 4,
|
| 459 |
+
|
| 460 |
+
} cublasLtOrder_t;
|
| 461 |
+
|
| 462 |
+
/** Attributes of memory layout */
|
| 463 |
+
typedef enum {
|
| 464 |
+
/** Data type, see cudaDataType.
|
| 465 |
+
*
|
| 466 |
+
* uint32_t
|
| 467 |
+
*/
|
| 468 |
+
CUBLASLT_MATRIX_LAYOUT_TYPE = 0,
|
| 469 |
+
|
| 470 |
+
/** Memory order of the data, see cublasLtOrder_t.
|
| 471 |
+
*
|
| 472 |
+
* int32_t, default: CUBLASLT_ORDER_COL
|
| 473 |
+
*/
|
| 474 |
+
CUBLASLT_MATRIX_LAYOUT_ORDER = 1,
|
| 475 |
+
|
| 476 |
+
/** Number of rows.
|
| 477 |
+
*
|
| 478 |
+
* Usually only values that can be expressed as int32_t are supported.
|
| 479 |
+
*
|
| 480 |
+
* uint64_t
|
| 481 |
+
*/
|
| 482 |
+
CUBLASLT_MATRIX_LAYOUT_ROWS = 2,
|
| 483 |
+
|
| 484 |
+
/** Number of columns.
|
| 485 |
+
*
|
| 486 |
+
* Usually only values that can be expressed as int32_t are supported.
|
| 487 |
+
*
|
| 488 |
+
* uint64_t
|
| 489 |
+
*/
|
| 490 |
+
CUBLASLT_MATRIX_LAYOUT_COLS = 3,
|
| 491 |
+
|
| 492 |
+
/** Matrix leading dimension.
|
| 493 |
+
*
|
| 494 |
+
* For CUBLASLT_ORDER_COL this is stride (in elements) of matrix column, for more details and documentation for
|
| 495 |
+
* other memory orders see documentation for cublasLtOrder_t values.
|
| 496 |
+
*
|
| 497 |
+
* Currently only non-negative values are supported, must be large enough so that matrix memory locations are not
|
| 498 |
+
* overlapping (e.g. greater or equal to CUBLASLT_MATRIX_LAYOUT_ROWS in case of CUBLASLT_ORDER_COL).
|
| 499 |
+
*
|
| 500 |
+
* int64_t;
|
| 501 |
+
*/
|
| 502 |
+
CUBLASLT_MATRIX_LAYOUT_LD = 4,
|
| 503 |
+
|
| 504 |
+
/** Number of matmul operations to perform in the batch.
|
| 505 |
+
*
|
| 506 |
+
* See also CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT
|
| 507 |
+
*
|
| 508 |
+
* int32_t, default: 1
|
| 509 |
+
*/
|
| 510 |
+
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5,
|
| 511 |
+
|
| 512 |
+
/** Stride (in elements) to the next matrix for strided batch operation.
|
| 513 |
+
*
|
| 514 |
+
* When matrix type is planar-complex (CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET != 0), batch stride
|
| 515 |
+
* is interpreted by cublasLtMatmul() in number of real valued sub-elements. E.g. for data of type CUDA_C_16F,
|
| 516 |
+
* offset of 1024B is encoded as a stride of value 512 (since each element of the real and imaginary matrices
|
| 517 |
+
* is a 2B (16bit) floating point type).
|
| 518 |
+
*
|
| 519 |
+
* NOTE: A bug in cublasLtMatrixTransform() causes it to interpret the batch stride for a planar-complex matrix
|
| 520 |
+
* as if it was specified in number of complex elements. Therefore an offset of 1024B must be encoded as stride
|
| 521 |
+
* value 256 when calling cublasLtMatrixTransform() (each complex element is 4B with real and imaginary values 2B
|
| 522 |
+
* each). This behavior is expected to be corrected in the next major cuBLAS version.
|
| 523 |
+
*
|
| 524 |
+
* int64_t, default: 0
|
| 525 |
+
*/
|
| 526 |
+
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6,
|
| 527 |
+
|
| 528 |
+
/** Stride (in bytes) to the imaginary plane for planar complex layout.
|
| 529 |
+
*
|
| 530 |
+
* int64_t, default: 0 - 0 means that layout is regular (real and imaginary parts of complex numbers are interleaved
|
| 531 |
+
* in memory in each element)
|
| 532 |
+
*/
|
| 533 |
+
CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7,
|
| 534 |
+
} cublasLtMatrixLayoutAttribute_t;
|
| 535 |
+
|
| 536 |
+
/** Internal. Do not use directly.
|
| 537 |
+
*/
|
| 538 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal( //
|
| 539 |
+
cublasLtMatrixLayout_t matLayout,
|
| 540 |
+
size_t size,
|
| 541 |
+
cudaDataType type,
|
| 542 |
+
uint64_t rows,
|
| 543 |
+
uint64_t cols,
|
| 544 |
+
int64_t ld);
|
| 545 |
+
|
| 546 |
+
/** Initialize matrix layout descriptor in pre-allocated space.
|
| 547 |
+
*
|
| 548 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 549 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 550 |
+
*/
|
| 551 |
+
static inline cublasStatus_t cublasLtMatrixLayoutInit(
|
| 552 |
+
cublasLtMatrixLayout_t matLayout, cudaDataType type, uint64_t rows, uint64_t cols, int64_t ld) {
|
| 553 |
+
return cublasLtMatrixLayoutInit_internal(matLayout, sizeof(*matLayout), type, rows, cols, ld);
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
/** Create new matrix layout descriptor.
|
| 557 |
+
*
|
| 558 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 559 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 560 |
+
*/
|
| 561 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate( //
|
| 562 |
+
cublasLtMatrixLayout_t* matLayout,
|
| 563 |
+
cudaDataType type,
|
| 564 |
+
uint64_t rows,
|
| 565 |
+
uint64_t cols,
|
| 566 |
+
int64_t ld);
|
| 567 |
+
|
| 568 |
+
/** Destroy matrix layout descriptor.
|
| 569 |
+
*
|
| 570 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 571 |
+
*/
|
| 572 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout);
|
| 573 |
+
|
| 574 |
+
/** Set matrix layout descriptor attribute.
|
| 575 |
+
*
|
| 576 |
+
* \param[in] matLayout The descriptor
|
| 577 |
+
* \param[in] attr The attribute
|
| 578 |
+
* \param[in] buf memory address containing the new value
|
| 579 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 580 |
+
*
|
| 581 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 582 |
+
* selected attribute
|
| 583 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 584 |
+
*/
|
| 585 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute( //
|
| 586 |
+
cublasLtMatrixLayout_t matLayout,
|
| 587 |
+
cublasLtMatrixLayoutAttribute_t attr,
|
| 588 |
+
const void* buf,
|
| 589 |
+
size_t sizeInBytes);
|
| 590 |
+
|
| 591 |
+
/** Get matrix layout descriptor attribute.
|
| 592 |
+
*
|
| 593 |
+
* \param[in] matLayout The descriptor
|
| 594 |
+
* \param[in] attr The attribute
|
| 595 |
+
* \param[out] buf memory address containing the new value
|
| 596 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 597 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 598 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 599 |
+
*
|
| 600 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 601 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 602 |
+
* selected attribute
|
| 603 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 604 |
+
*/
|
| 605 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute( //
|
| 606 |
+
cublasLtMatrixLayout_t matLayout,
|
| 607 |
+
cublasLtMatrixLayoutAttribute_t attr,
|
| 608 |
+
void* buf,
|
| 609 |
+
size_t sizeInBytes,
|
| 610 |
+
size_t* sizeWritten);
|
| 611 |
+
|
| 612 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 613 |
+
/* Helper functions for cublasLtMatmulDesc_t */
|
| 614 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 615 |
+
|
| 616 |
+
/** Matmul descriptor attributes to define details of the operation. */
|
| 617 |
+
typedef enum {
|
| 618 |
+
/** Compute type, see cudaDataType. Defines data type used for multiply and accumulate operations and the
|
| 619 |
+
* accumulator during matrix multiplication.
|
| 620 |
+
*
|
| 621 |
+
* int32_t
|
| 622 |
+
*/
|
| 623 |
+
CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0,
|
| 624 |
+
|
| 625 |
+
/** Scale type, see cudaDataType. Defines data type of alpha and beta. Accumulator and value from matrix C are
|
| 626 |
+
* typically converted to scale type before final scaling. Value is then converted from scale type to type of matrix
|
| 627 |
+
* D before being stored in memory.
|
| 628 |
+
*
|
| 629 |
+
* int32_t, default: same as CUBLASLT_MATMUL_DESC_COMPUTE_TYPE
|
| 630 |
+
*/
|
| 631 |
+
CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1,
|
| 632 |
+
|
| 633 |
+
/** Pointer mode of alpha and beta, see cublasLtPointerMode_t. When CUBLASLT_POINTER_MODE_DEVICE_VECTOR is in use,
|
| 634 |
+
* alpha/beta vector lenghts must match number of output matrix rows.
|
| 635 |
+
*
|
| 636 |
+
* int32_t, default: CUBLASLT_POINTER_MODE_HOST
|
| 637 |
+
*/
|
| 638 |
+
CUBLASLT_MATMUL_DESC_POINTER_MODE = 2,
|
| 639 |
+
|
| 640 |
+
/** Transform of matrix A, see cublasOperation_t.
|
| 641 |
+
*
|
| 642 |
+
* int32_t, default: CUBLAS_OP_N
|
| 643 |
+
*/
|
| 644 |
+
CUBLASLT_MATMUL_DESC_TRANSA = 3,
|
| 645 |
+
|
| 646 |
+
/** Transform of matrix B, see cublasOperation_t.
|
| 647 |
+
*
|
| 648 |
+
* int32_t, default: CUBLAS_OP_N
|
| 649 |
+
*/
|
| 650 |
+
CUBLASLT_MATMUL_DESC_TRANSB = 4,
|
| 651 |
+
|
| 652 |
+
/** Transform of matrix C, see cublasOperation_t.
|
| 653 |
+
*
|
| 654 |
+
* Currently only CUBLAS_OP_N is supported.
|
| 655 |
+
*
|
| 656 |
+
* int32_t, default: CUBLAS_OP_N
|
| 657 |
+
*/
|
| 658 |
+
CUBLASLT_MATMUL_DESC_TRANSC = 5,
|
| 659 |
+
|
| 660 |
+
/** Matrix fill mode, see cublasFillMode_t.
|
| 661 |
+
*
|
| 662 |
+
* int32_t, default: CUBLAS_FILL_MODE_FULL
|
| 663 |
+
*/
|
| 664 |
+
CUBLASLT_MATMUL_DESC_FILL_MODE = 6,
|
| 665 |
+
|
| 666 |
+
/** Epilogue function, see cublasLtEpilogue_t.
|
| 667 |
+
*
|
| 668 |
+
* uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT
|
| 669 |
+
*/
|
| 670 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE = 7,
|
| 671 |
+
|
| 672 |
+
/** Bias or bias gradient vector pointer in the device memory.
|
| 673 |
+
*
|
| 674 |
+
* Bias case. See CUBLASLT_EPILOGUE_BIAS.
|
| 675 |
+
* For bias data type see CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE.
|
| 676 |
+
*
|
| 677 |
+
* Bias vector length must match matrix D rows count.
|
| 678 |
+
*
|
| 679 |
+
* Bias gradient case. See CUBLASLT_EPILOGUE_DRELU_BGRAD and CUBLASLT_EPILOGUE_DGELU_BGRAD.
|
| 680 |
+
* Bias gradient vector elements are the same type as the output elements
|
| 681 |
+
* (Ctype) with the exception of IMMA kernels (see above).
|
| 682 |
+
*
|
| 683 |
+
* Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
|
| 684 |
+
* depend on its value to determine expected pointer alignment.
|
| 685 |
+
*
|
| 686 |
+
* Bias case: const void *, default: NULL
|
| 687 |
+
* Bias gradient case: void *, default: NULL
|
| 688 |
+
*/
|
| 689 |
+
CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8,
|
| 690 |
+
|
| 691 |
+
/** Batch stride for bias or bias gradient vector.
|
| 692 |
+
*
|
| 693 |
+
* Used together with CUBLASLT_MATMUL_DESC_BIAS_POINTER when matrix D's CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1.
|
| 694 |
+
*
|
| 695 |
+
* int64_t, default: 0
|
| 696 |
+
*/
|
| 697 |
+
CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10,
|
| 698 |
+
|
| 699 |
+
/** Pointer for epilogue auxiliary buffer.
|
| 700 |
+
*
|
| 701 |
+
* - Output vector for ReLu bit-mask in forward pass when CUBLASLT_EPILOGUE_RELU_AUX
|
| 702 |
+
* or CUBLASLT_EPILOGUE_RELU_AUX_BIAS epilogue is used.
|
| 703 |
+
* - Input vector for ReLu bit-mask in backward pass when
|
| 704 |
+
* CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is used.
|
| 705 |
+
*
|
| 706 |
+
* - Output of GELU input matrix in forward pass when
|
| 707 |
+
* CUBLASLT_EPILOGUE_GELU_AUX_BIAS epilogue is used.
|
| 708 |
+
* - Input of GELU input matrix for backward pass when
|
| 709 |
+
* CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue is used.
|
| 710 |
+
*
|
| 711 |
+
* For aux data type see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE.
|
| 712 |
+
*
|
| 713 |
+
* Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
|
| 714 |
+
* depend on its value to determine expected pointer alignment.
|
| 715 |
+
*
|
| 716 |
+
* Requires setting CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD attribute.
|
| 717 |
+
*
|
| 718 |
+
* Forward pass: void *, default: NULL
|
| 719 |
+
* Backward pass: const void *, default: NULL
|
| 720 |
+
*/
|
| 721 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11,
|
| 722 |
+
|
| 723 |
+
/** Leading dimension for epilogue auxiliary buffer.
|
| 724 |
+
*
|
| 725 |
+
* - ReLu bit-mask matrix leading dimension in elements (i.e. bits)
|
| 726 |
+
* when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
|
| 727 |
+
* used. Must be divisible by 128 and be no less than the number of rows in the output matrix.
|
| 728 |
+
*
|
| 729 |
+
* - GELU input matrix leading dimension in elements
|
| 730 |
+
* when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
|
| 731 |
+
* Must be divisible by 8 and be no less than the number of rows in the output matrix.
|
| 732 |
+
*
|
| 733 |
+
* int64_t, default: 0
|
| 734 |
+
*/
|
| 735 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12,
|
| 736 |
+
|
| 737 |
+
/** Batch stride for epilogue auxiliary buffer.
|
| 738 |
+
*
|
| 739 |
+
* - ReLu bit-mask matrix batch stride in elements (i.e. bits)
|
| 740 |
+
* when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
|
| 741 |
+
* used. Must be divisible by 128.
|
| 742 |
+
*
|
| 743 |
+
* - GELU input matrix batch stride in elements
|
| 744 |
+
* when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
|
| 745 |
+
* Must be divisible by 8.
|
| 746 |
+
*
|
| 747 |
+
* int64_t, default: 0
|
| 748 |
+
*/
|
| 749 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13,
|
| 750 |
+
|
| 751 |
+
/** Batch stride for alpha vector.
|
| 752 |
+
*
|
| 753 |
+
* Used together with CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST when matrix D's
|
| 754 |
+
* CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1. If CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO is set then
|
| 755 |
+
* CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE must be set to 0 as this mode doesnt supported batched alpha vector.
|
| 756 |
+
*
|
| 757 |
+
* int64_t, default: 0
|
| 758 |
+
*/
|
| 759 |
+
CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14,
|
| 760 |
+
|
| 761 |
+
/** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs
|
| 762 |
+
* when user expects a concurrent stream to be using some of the device resources.
|
| 763 |
+
*
|
| 764 |
+
* int32_t, default: 0 - use the number reported by the device.
|
| 765 |
+
*/
|
| 766 |
+
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15,
|
| 767 |
+
|
| 768 |
+
/** Device pointer to the scale factor value that converts data in matrix A to the compute data type range.
|
| 769 |
+
*
|
| 770 |
+
* The scaling factor value must have the same type as the compute type.
|
| 771 |
+
*
|
| 772 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 773 |
+
*
|
| 774 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 775 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 776 |
+
*
|
| 777 |
+
* const void *, default: NULL
|
| 778 |
+
*/
|
| 779 |
+
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17,
|
| 780 |
+
|
| 781 |
+
/** Device pointer to the scale factor value to convert data in matrix B to compute data type range.
|
| 782 |
+
*
|
| 783 |
+
* The scaling factor value must have the same type as the compute type.
|
| 784 |
+
*
|
| 785 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 786 |
+
*
|
| 787 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 788 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 789 |
+
*
|
| 790 |
+
* const void *, default: NULL
|
| 791 |
+
*/
|
| 792 |
+
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18,
|
| 793 |
+
|
| 794 |
+
/** Device pointer to the scale factor value to convert data in matrix C to compute data type range.
|
| 795 |
+
*
|
| 796 |
+
* The scaling factor value must have the same type as the compute type.
|
| 797 |
+
*
|
| 798 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 799 |
+
*
|
| 800 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 801 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 802 |
+
*
|
| 803 |
+
* const void *, default: NULL
|
| 804 |
+
*/
|
| 805 |
+
CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19,
|
| 806 |
+
|
| 807 |
+
/** Device pointer to the scale factor value to convert data in matrix D to compute data type range.
|
| 808 |
+
*
|
| 809 |
+
* The scaling factor value must have the same type as the compute type.
|
| 810 |
+
*
|
| 811 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 812 |
+
*
|
| 813 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 814 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 815 |
+
*
|
| 816 |
+
* const void *, default: NULL
|
| 817 |
+
*/
|
| 818 |
+
CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20,
|
| 819 |
+
|
| 820 |
+
/** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
|
| 821 |
+
* output matrix.
|
| 822 |
+
*
|
| 823 |
+
* The computed value has the same type as the compute type.
|
| 824 |
+
*
|
| 825 |
+
* If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
|
| 826 |
+
* data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
|
| 827 |
+
*
|
| 828 |
+
* void *, default: NULL
|
| 829 |
+
*/
|
| 830 |
+
CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21,
|
| 831 |
+
|
| 832 |
+
/** Type of the data to be stored to the memory pointed to by CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 833 |
+
*
|
| 834 |
+
* If unset, the data type defaults to the type of elements of the output matrix with some exceptions, see details
|
| 835 |
+
* below.
|
| 836 |
+
*
|
| 837 |
+
* ReLu uses a bit-mask.
|
| 838 |
+
*
|
| 839 |
+
* GELU input matrix elements type is the same as the type of elements of
|
| 840 |
+
* the output matrix with some exceptions, see details below.
|
| 841 |
+
*
|
| 842 |
+
* For fp8 kernels with output type CUDA_R_8F_E4M3 the aux data type can be CUDA_R_8F_E4M3 or CUDA_R_16F with some
|
| 843 |
+
* restrictions. See https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t for more details.
|
| 844 |
+
*
|
| 845 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 846 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 847 |
+
*
|
| 848 |
+
* int32_t based on cudaDataType, default: -1
|
| 849 |
+
*/
|
| 850 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22,
|
| 851 |
+
|
| 852 |
+
/** Device pointer to the scaling factor value to convert results from compute type data range to storage
|
| 853 |
+
* data range in the auxiliary matrix that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 854 |
+
*
|
| 855 |
+
* The scaling factor value must have the same type as the compute type.
|
| 856 |
+
*
|
| 857 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1. If set for an unsupported matrix data,
|
| 858 |
+
* scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
|
| 859 |
+
*
|
| 860 |
+
* void *, default: NULL
|
| 861 |
+
*/
|
| 862 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23,
|
| 863 |
+
|
| 864 |
+
/** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
|
| 865 |
+
* buffer that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 866 |
+
*
|
| 867 |
+
* The computed value has the same type as the compute type.
|
| 868 |
+
*
|
| 869 |
+
* If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
|
| 870 |
+
* data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
|
| 871 |
+
*
|
| 872 |
+
* void *, default: NULL
|
| 873 |
+
*/
|
| 874 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24,
|
| 875 |
+
|
| 876 |
+
/** Flag for managing fp8 fast accumulation mode.
|
| 877 |
+
* When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results
|
| 878 |
+
* will not periodically be promoted to a higher precision.
|
| 879 |
+
*
|
| 880 |
+
* int8_t, default: 0 - fast accumulation mode is disabled.
|
| 881 |
+
*/
|
| 882 |
+
CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25,
|
| 883 |
+
|
| 884 |
+
/** Type of bias or bias gradient vector in the device memory.
|
| 885 |
+
*
|
| 886 |
+
* Bias case: see CUBLASLT_EPILOGUE_BIAS.
|
| 887 |
+
*
|
| 888 |
+
* Bias vector elements are the same type as the elements of output matrix (Dtype) with the following exceptions:
|
| 889 |
+
* - IMMA kernels with computeType=CUDA_R_32I and Ctype=CUDA_R_8I where the bias vector elements
|
| 890 |
+
* are the same type as alpha, beta (CUBLASLT_MATMUL_DESC_SCALE_TYPE=CUDA_R_32F)
|
| 891 |
+
* - fp8 kernels with an output type of CUDA_R_32F, CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2, See
|
| 892 |
+
* https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul for details.
|
| 893 |
+
*
|
| 894 |
+
* int32_t based on cudaDataType, default: -1
|
| 895 |
+
*/
|
| 896 |
+
CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26,
|
| 897 |
+
|
| 898 |
+
/** EXPERIMENTAL: Number of atomic synchronization chunks in the row dimension of the output matrix D.
|
| 899 |
+
*
|
| 900 |
+
* int32_t, default 0 (atomic synchronization disabled)
|
| 901 |
+
*/
|
| 902 |
+
CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS = 27,
|
| 903 |
+
|
| 904 |
+
/** EXPERIMENTAL: Number of atomic synchronization chunks in the column dimension of the output matrix D.
|
| 905 |
+
*
|
| 906 |
+
* int32_t, default 0 (atomic synchronization disabled)
|
| 907 |
+
*/
|
| 908 |
+
CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS = 28,
|
| 909 |
+
|
| 910 |
+
/** EXPERIMENTAL: Pointer to a device array of input atomic counters consumed by a matmul.
|
| 911 |
+
*
|
| 912 |
+
* int32_t *, default: NULL
|
| 913 |
+
* */
|
| 914 |
+
CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER = 29,
|
| 915 |
+
|
| 916 |
+
/** EXPERIMENTAL: Pointer to a device array of output atomic counters produced by a matmul.
|
| 917 |
+
*
|
| 918 |
+
* int32_t *, default: NULL
|
| 919 |
+
* */
|
| 920 |
+
CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER = 30,
|
| 921 |
+
} cublasLtMatmulDescAttributes_t;
|
| 922 |
+
|
| 923 |
+
/** Internal. Do not use directly.
|
| 924 |
+
*/
|
| 925 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( //
|
| 926 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 927 |
+
size_t size,
|
| 928 |
+
cublasComputeType_t computeType,
|
| 929 |
+
cudaDataType_t scaleType);
|
| 930 |
+
|
| 931 |
+
/** Initialize matmul operation descriptor in pre-allocated space.
|
| 932 |
+
*
|
| 933 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 934 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was initialized successfully
|
| 935 |
+
*/
|
| 936 |
+
static inline cublasStatus_t cublasLtMatmulDescInit( //
|
| 937 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 938 |
+
cublasComputeType_t computeType,
|
| 939 |
+
cudaDataType_t scaleType) {
|
| 940 |
+
return cublasLtMatmulDescInit_internal(matmulDesc, sizeof(*matmulDesc), computeType, scaleType);
|
| 941 |
+
}
|
| 942 |
+
|
| 943 |
+
/** Create new matmul operation descriptor.
|
| 944 |
+
*
|
| 945 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 946 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 947 |
+
*/
|
| 948 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(cublasLtMatmulDesc_t* matmulDesc,
|
| 949 |
+
cublasComputeType_t computeType,
|
| 950 |
+
cudaDataType_t scaleType);
|
| 951 |
+
|
| 952 |
+
/** Destroy matmul operation descriptor.
|
| 953 |
+
*
|
| 954 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 955 |
+
*/
|
| 956 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc);
|
| 957 |
+
|
| 958 |
+
/** Set matmul operation descriptor attribute.
|
| 959 |
+
*
|
| 960 |
+
* \param[in] matmulDesc The descriptor
|
| 961 |
+
* \param[in] attr The attribute
|
| 962 |
+
* \param[in] buf memory address containing the new value
|
| 963 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 964 |
+
*
|
| 965 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 966 |
+
* selected attribute
|
| 967 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 968 |
+
*/
|
| 969 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute( //
|
| 970 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 971 |
+
cublasLtMatmulDescAttributes_t attr,
|
| 972 |
+
const void* buf,
|
| 973 |
+
size_t sizeInBytes);
|
| 974 |
+
|
| 975 |
+
/** Get matmul operation descriptor attribute.
|
| 976 |
+
*
|
| 977 |
+
* \param[in] matmulDesc The descriptor
|
| 978 |
+
* \param[in] attr The attribute
|
| 979 |
+
* \param[out] buf memory address containing the new value
|
| 980 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 981 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 982 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 983 |
+
*
|
| 984 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 985 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 986 |
+
* selected attribute
|
| 987 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 988 |
+
*/
|
| 989 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute( //
|
| 990 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 991 |
+
cublasLtMatmulDescAttributes_t attr,
|
| 992 |
+
void* buf,
|
| 993 |
+
size_t sizeInBytes,
|
| 994 |
+
size_t* sizeWritten);
|
| 995 |
+
|
| 996 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 997 |
+
/* Helper functions for cublasLtMatrixTransformDesc_t */
|
| 998 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 999 |
+
|
| 1000 |
+
/** Matrix transform descriptor attributes to define details of the operation.
|
| 1001 |
+
*/
|
| 1002 |
+
typedef enum {
|
| 1003 |
+
/** Scale type, see cudaDataType. Inputs are converted to scale type for scaling and summation and results are then
|
| 1004 |
+
* converted to output type to store in memory.
|
| 1005 |
+
*
|
| 1006 |
+
* int32_t
|
| 1007 |
+
*/
|
| 1008 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE,
|
| 1009 |
+
|
| 1010 |
+
/** Pointer mode of alpha and beta, see cublasLtPointerMode_t.
|
| 1011 |
+
*
|
| 1012 |
+
* int32_t, default: CUBLASLT_POINTER_MODE_HOST
|
| 1013 |
+
*/
|
| 1014 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE,
|
| 1015 |
+
|
| 1016 |
+
/** Transform of matrix A, see cublasOperation_t.
|
| 1017 |
+
*
|
| 1018 |
+
* int32_t, default: CUBLAS_OP_N
|
| 1019 |
+
*/
|
| 1020 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA,
|
| 1021 |
+
|
| 1022 |
+
/** Transform of matrix B, see cublasOperation_t.
|
| 1023 |
+
*
|
| 1024 |
+
* int32_t, default: CUBLAS_OP_N
|
| 1025 |
+
*/
|
| 1026 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB,
|
| 1027 |
+
} cublasLtMatrixTransformDescAttributes_t;
|
| 1028 |
+
|
| 1029 |
+
/** Internal. Do not use directly.
|
| 1030 |
+
*/
|
| 1031 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc,
|
| 1032 |
+
size_t size,
|
| 1033 |
+
cudaDataType scaleType);
|
| 1034 |
+
|
| 1035 |
+
/** Initialize matrix transform operation descriptor in pre-allocated space.
|
| 1036 |
+
*
|
| 1037 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 1038 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1039 |
+
*/
|
| 1040 |
+
static inline cublasStatus_t cublasLtMatrixTransformDescInit(cublasLtMatrixTransformDesc_t transformDesc,
|
| 1041 |
+
cudaDataType scaleType) {
|
| 1042 |
+
return cublasLtMatrixTransformDescInit_internal(transformDesc, sizeof(*transformDesc), scaleType);
|
| 1043 |
+
}
|
| 1044 |
+
|
| 1045 |
+
/** Create new matrix transform operation descriptor.
|
| 1046 |
+
*
|
| 1047 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 1048 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1049 |
+
*/
|
| 1050 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t* transformDesc,
|
| 1051 |
+
cudaDataType scaleType);
|
| 1052 |
+
|
| 1053 |
+
/** Destroy matrix transform operation descriptor.
|
| 1054 |
+
*
|
| 1055 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 1056 |
+
*/
|
| 1057 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc);
|
| 1058 |
+
|
| 1059 |
+
/** Set matrix transform operation descriptor attribute.
|
| 1060 |
+
*
|
| 1061 |
+
* \param[in] transformDesc The descriptor
|
| 1062 |
+
* \param[in] attr The attribute
|
| 1063 |
+
* \param[in] buf memory address containing the new value
|
| 1064 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1065 |
+
*
|
| 1066 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1067 |
+
* selected attribute
|
| 1068 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 1069 |
+
*/
|
| 1070 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute( //
|
| 1071 |
+
cublasLtMatrixTransformDesc_t transformDesc,
|
| 1072 |
+
cublasLtMatrixTransformDescAttributes_t attr,
|
| 1073 |
+
const void* buf,
|
| 1074 |
+
size_t sizeInBytes);
|
| 1075 |
+
|
| 1076 |
+
/** Get matrix transform operation descriptor attribute.
|
| 1077 |
+
*
|
| 1078 |
+
* \param[in] transformDesc The descriptor
|
| 1079 |
+
* \param[in] attr The attribute
|
| 1080 |
+
* \param[out] buf memory address containing the new value
|
| 1081 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1082 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number
|
| 1083 |
+
* of bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1084 |
+
*
|
| 1085 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1086 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1087 |
+
* selected attribute
|
| 1088 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1089 |
+
*/
|
| 1090 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute( //
|
| 1091 |
+
cublasLtMatrixTransformDesc_t transformDesc,
|
| 1092 |
+
cublasLtMatrixTransformDescAttributes_t attr,
|
| 1093 |
+
void* buf,
|
| 1094 |
+
size_t sizeInBytes,
|
| 1095 |
+
size_t* sizeWritten);
|
| 1096 |
+
|
| 1097 |
+
/** Reduction scheme for portions of the dot-product calculated in parallel (a. k. a. "split - K").
|
| 1098 |
+
*/
|
| 1099 |
+
typedef enum {
|
| 1100 |
+
/** No reduction scheme, dot-product shall be performed in one sequence.
|
| 1101 |
+
*/
|
| 1102 |
+
CUBLASLT_REDUCTION_SCHEME_NONE = 0,
|
| 1103 |
+
|
| 1104 |
+
/** Reduction is performed "in place" - using the output buffer (and output data type) and counters (in workspace) to
|
| 1105 |
+
* guarantee the sequentiality.
|
| 1106 |
+
*/
|
| 1107 |
+
CUBLASLT_REDUCTION_SCHEME_INPLACE = 1,
|
| 1108 |
+
|
| 1109 |
+
/** Intermediate results are stored in compute type in the workspace and reduced in a separate step.
|
| 1110 |
+
*/
|
| 1111 |
+
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE = 2,
|
| 1112 |
+
|
| 1113 |
+
/** Intermediate results are stored in output type in the workspace and reduced in a separate step.
|
| 1114 |
+
*/
|
| 1115 |
+
CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE = 4,
|
| 1116 |
+
|
| 1117 |
+
CUBLASLT_REDUCTION_SCHEME_MASK = 0x7,
|
| 1118 |
+
} cublasLtReductionScheme_t;
|
| 1119 |
+
|
| 1120 |
+
/** Postprocessing options for the epilogue
|
| 1121 |
+
*/
|
| 1122 |
+
typedef enum {
|
| 1123 |
+
/** No special postprocessing, just scale and quantize results if necessary.
|
| 1124 |
+
*/
|
| 1125 |
+
CUBLASLT_EPILOGUE_DEFAULT = 1,
|
| 1126 |
+
|
| 1127 |
+
/** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
|
| 1128 |
+
*/
|
| 1129 |
+
CUBLASLT_EPILOGUE_RELU = 2,
|
| 1130 |
+
|
| 1131 |
+
/** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
|
| 1132 |
+
*
|
| 1133 |
+
* This epilogue mode produces an extra output, a ReLu bit-mask matrix,
|
| 1134 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1135 |
+
*/
|
| 1136 |
+
CUBLASLT_EPILOGUE_RELU_AUX = (CUBLASLT_EPILOGUE_RELU | 128),
|
| 1137 |
+
|
| 1138 |
+
/** Bias, apply (broadcasted) Bias from bias vector. Bias vector length must match matrix D rows, it must be packed
|
| 1139 |
+
* (stride between vector elements is 1). Bias vector is broadcasted to all columns and added before applying final
|
| 1140 |
+
* postprocessing.
|
| 1141 |
+
*/
|
| 1142 |
+
CUBLASLT_EPILOGUE_BIAS = 4,
|
| 1143 |
+
|
| 1144 |
+
/** ReLu and Bias, apply Bias and then ReLu transform
|
| 1145 |
+
*/
|
| 1146 |
+
CUBLASLT_EPILOGUE_RELU_BIAS = (CUBLASLT_EPILOGUE_RELU | CUBLASLT_EPILOGUE_BIAS),
|
| 1147 |
+
|
| 1148 |
+
/** ReLu and Bias, apply Bias and then ReLu transform
|
| 1149 |
+
*
|
| 1150 |
+
* This epilogue mode produces an extra output, a ReLu bit-mask matrix,
|
| 1151 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1152 |
+
*/
|
| 1153 |
+
CUBLASLT_EPILOGUE_RELU_AUX_BIAS = (CUBLASLT_EPILOGUE_RELU_AUX | CUBLASLT_EPILOGUE_BIAS),
|
| 1154 |
+
|
| 1155 |
+
/* ReLu gradient. Apply ReLu gradient to matmul output. Store ReLu gradient in the output matrix.
|
| 1156 |
+
*
|
| 1157 |
+
* This epilogue mode requires an extra input,
|
| 1158 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1159 |
+
*/
|
| 1160 |
+
CUBLASLT_EPILOGUE_DRELU = 8 | 128,
|
| 1161 |
+
|
| 1162 |
+
/* ReLu and Bias gradients. Apply independently ReLu and Bias gradient to
|
| 1163 |
+
* matmul output. Store ReLu gradient in the output matrix, and Bias gradient
|
| 1164 |
+
* in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1165 |
+
*
|
| 1166 |
+
* This epilogue mode requires an extra input,
|
| 1167 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1168 |
+
*/
|
| 1169 |
+
CUBLASLT_EPILOGUE_DRELU_BGRAD = CUBLASLT_EPILOGUE_DRELU | 16,
|
| 1170 |
+
|
| 1171 |
+
/** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
|
| 1172 |
+
*/
|
| 1173 |
+
CUBLASLT_EPILOGUE_GELU = 32,
|
| 1174 |
+
|
| 1175 |
+
/** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
|
| 1176 |
+
*
|
| 1177 |
+
* This epilogue mode outputs GELU input as a separate matrix (useful for training).
|
| 1178 |
+
* See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1179 |
+
*/
|
| 1180 |
+
CUBLASLT_EPILOGUE_GELU_AUX = (CUBLASLT_EPILOGUE_GELU | 128),
|
| 1181 |
+
|
| 1182 |
+
/** GELU and Bias, apply Bias and then GELU transform
|
| 1183 |
+
*/
|
| 1184 |
+
CUBLASLT_EPILOGUE_GELU_BIAS = (CUBLASLT_EPILOGUE_GELU | CUBLASLT_EPILOGUE_BIAS),
|
| 1185 |
+
|
| 1186 |
+
/** GELU and Bias, apply Bias and then GELU transform
|
| 1187 |
+
*
|
| 1188 |
+
* This epilogue mode outputs GELU input as a separate matrix (useful for training).
|
| 1189 |
+
* See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1190 |
+
*/
|
| 1191 |
+
CUBLASLT_EPILOGUE_GELU_AUX_BIAS = (CUBLASLT_EPILOGUE_GELU_AUX | CUBLASLT_EPILOGUE_BIAS),
|
| 1192 |
+
|
| 1193 |
+
/* GELU gradient. Apply GELU gradient to matmul output. Store GELU gradient in the output matrix.
|
| 1194 |
+
*
|
| 1195 |
+
* This epilogue mode requires an extra input,
|
| 1196 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1197 |
+
*/
|
| 1198 |
+
CUBLASLT_EPILOGUE_DGELU = 64 | 128,
|
| 1199 |
+
|
| 1200 |
+
/* GELU and Bias gradients. Apply independently GELU and Bias gradient to
|
| 1201 |
+
* matmul output. Store GELU gradient in the output matrix, and Bias gradient
|
| 1202 |
+
* in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1203 |
+
*
|
| 1204 |
+
* This epilogue mode requires an extra input,
|
| 1205 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1206 |
+
*/
|
| 1207 |
+
CUBLASLT_EPILOGUE_DGELU_BGRAD = CUBLASLT_EPILOGUE_DGELU | 16,
|
| 1208 |
+
|
| 1209 |
+
/** Bias gradient based on the input matrix A.
|
| 1210 |
+
*
|
| 1211 |
+
* The bias size corresponds to the number of rows of the matrix D.
|
| 1212 |
+
* The reduction happens over the GEMM's "k" dimension.
|
| 1213 |
+
*
|
| 1214 |
+
* Stores Bias gradient in the auxiliary output
|
| 1215 |
+
* (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1216 |
+
*/
|
| 1217 |
+
CUBLASLT_EPILOGUE_BGRADA = 256,
|
| 1218 |
+
|
| 1219 |
+
/** Bias gradient based on the input matrix B.
|
| 1220 |
+
*
|
| 1221 |
+
* The bias size corresponds to the number of columns of the matrix D.
|
| 1222 |
+
* The reduction happens over the GEMM's "k" dimension.
|
| 1223 |
+
*
|
| 1224 |
+
* Stores Bias gradient in the auxiliary output
|
| 1225 |
+
* (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1226 |
+
*/
|
| 1227 |
+
CUBLASLT_EPILOGUE_BGRADB = 512,
|
| 1228 |
+
} cublasLtEpilogue_t;
|
| 1229 |
+
|
| 1230 |
+
/** Matmul heuristic search mode
|
| 1231 |
+
*/
|
| 1232 |
+
typedef enum {
|
| 1233 |
+
/** ask heuristics for best algo for given usecase
|
| 1234 |
+
*/
|
| 1235 |
+
CUBLASLT_SEARCH_BEST_FIT = 0,
|
| 1236 |
+
/** only try to find best config for preconfigured algo id
|
| 1237 |
+
*/
|
| 1238 |
+
CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID = 1,
|
| 1239 |
+
/** reserved for future use
|
| 1240 |
+
*/
|
| 1241 |
+
CUBLASLT_SEARCH_RESERVED_02 = 2,
|
| 1242 |
+
/** reserved for future use
|
| 1243 |
+
*/
|
| 1244 |
+
CUBLASLT_SEARCH_RESERVED_03 = 3,
|
| 1245 |
+
/** reserved for future use
|
| 1246 |
+
*/
|
| 1247 |
+
CUBLASLT_SEARCH_RESERVED_04 = 4,
|
| 1248 |
+
/** reserved for future use
|
| 1249 |
+
*/
|
| 1250 |
+
CUBLASLT_SEARCH_RESERVED_05 = 5,
|
| 1251 |
+
} cublasLtMatmulSearch_t;
|
| 1252 |
+
|
| 1253 |
+
/** Algo search preference to fine tune the heuristic function. */
|
| 1254 |
+
typedef enum {
|
| 1255 |
+
/** Search mode, see cublasLtMatmulSearch_t.
|
| 1256 |
+
*
|
| 1257 |
+
* uint32_t, default: CUBLASLT_SEARCH_BEST_FIT
|
| 1258 |
+
*/
|
| 1259 |
+
CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0,
|
| 1260 |
+
|
| 1261 |
+
/** Maximum allowed workspace size in bytes.
|
| 1262 |
+
*
|
| 1263 |
+
* uint64_t, default: 0 - no workspace allowed
|
| 1264 |
+
*/
|
| 1265 |
+
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1,
|
| 1266 |
+
|
| 1267 |
+
/** Reduction scheme mask, see cublasLtReductionScheme_t. Filters heuristic result to only include algo configs that
|
| 1268 |
+
* use one of the required modes.
|
| 1269 |
+
*
|
| 1270 |
+
* E.g. mask value of 0x03 will allow only INPLACE and COMPUTE_TYPE reduction schemes.
|
| 1271 |
+
*
|
| 1272 |
+
* uint32_t, default: CUBLASLT_REDUCTION_SCHEME_MASK (allows all reduction schemes)
|
| 1273 |
+
*/
|
| 1274 |
+
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3,
|
| 1275 |
+
|
| 1276 |
+
/** Minimum buffer alignment for matrix A (in bytes).
|
| 1277 |
+
*
|
| 1278 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix A that is not as strictly aligned
|
| 1279 |
+
* as they need.
|
| 1280 |
+
*
|
| 1281 |
+
* uint32_t, default: 256
|
| 1282 |
+
*/
|
| 1283 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5,
|
| 1284 |
+
|
| 1285 |
+
/** Minimum buffer alignment for matrix B (in bytes).
|
| 1286 |
+
*
|
| 1287 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix B that is not as strictly aligned
|
| 1288 |
+
* as they need.
|
| 1289 |
+
*
|
| 1290 |
+
* uint32_t, default: 256
|
| 1291 |
+
*/
|
| 1292 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6,
|
| 1293 |
+
|
| 1294 |
+
/** Minimum buffer alignment for matrix C (in bytes).
|
| 1295 |
+
*
|
| 1296 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix C that is not as strictly aligned
|
| 1297 |
+
* as they need.
|
| 1298 |
+
*
|
| 1299 |
+
* uint32_t, default: 256
|
| 1300 |
+
*/
|
| 1301 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7,
|
| 1302 |
+
|
| 1303 |
+
/** Minimum buffer alignment for matrix D (in bytes).
|
| 1304 |
+
*
|
| 1305 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix D that is not as strictly aligned
|
| 1306 |
+
* as they need.
|
| 1307 |
+
*
|
| 1308 |
+
* uint32_t, default: 256
|
| 1309 |
+
*/
|
| 1310 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8,
|
| 1311 |
+
|
| 1312 |
+
/** Maximum wave count.
|
| 1313 |
+
*
|
| 1314 |
+
* See cublasLtMatmulHeuristicResult_t::wavesCount.
|
| 1315 |
+
*
|
| 1316 |
+
* Selecting a non-zero value will exclude algorithms that report device utilization higher than specified.
|
| 1317 |
+
*
|
| 1318 |
+
* float, default: 0.0f
|
| 1319 |
+
*/
|
| 1320 |
+
CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9,
|
| 1321 |
+
|
| 1322 |
+
/** Numerical implementation details mask, see cublasLtNumericalImplFlags_t. Filters heuristic result to only include
|
| 1323 |
+
* algorithms that use the allowed implementations.
|
| 1324 |
+
*
|
| 1325 |
+
* uint64_t, default: uint64_t(-1) (allow everything)
|
| 1326 |
+
*/
|
| 1327 |
+
CUBLASLT_MATMUL_PREF_IMPL_MASK = 12,
|
| 1328 |
+
} cublasLtMatmulPreferenceAttributes_t;
|
| 1329 |
+
|
| 1330 |
+
/** Internal. Do not use directly.
|
| 1331 |
+
*/
|
| 1332 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size);
|
| 1333 |
+
|
| 1334 |
+
/** Initialize matmul heuristic search preference descriptor in pre-allocated space.
|
| 1335 |
+
*
|
| 1336 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 1337 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1338 |
+
*/
|
| 1339 |
+
static inline cublasStatus_t cublasLtMatmulPreferenceInit(cublasLtMatmulPreference_t pref) {
|
| 1340 |
+
return cublasLtMatmulPreferenceInit_internal(pref, sizeof(*pref));
|
| 1341 |
+
}
|
| 1342 |
+
|
| 1343 |
+
/** Create new matmul heuristic search preference descriptor.
|
| 1344 |
+
*
|
| 1345 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 1346 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1347 |
+
*/
|
| 1348 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t* pref);
|
| 1349 |
+
|
| 1350 |
+
/** Destroy matmul heuristic search preference descriptor.
|
| 1351 |
+
*
|
| 1352 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 1353 |
+
*/
|
| 1354 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref);
|
| 1355 |
+
|
| 1356 |
+
/** Set matmul heuristic search preference descriptor attribute.
|
| 1357 |
+
*
|
| 1358 |
+
* \param[in] pref The descriptor
|
| 1359 |
+
* \param[in] attr The attribute
|
| 1360 |
+
* \param[in] buf memory address containing the new value
|
| 1361 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1362 |
+
*
|
| 1363 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1364 |
+
* selected attribute
|
| 1365 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 1366 |
+
*/
|
| 1367 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( //
|
| 1368 |
+
cublasLtMatmulPreference_t pref,
|
| 1369 |
+
cublasLtMatmulPreferenceAttributes_t attr,
|
| 1370 |
+
const void* buf,
|
| 1371 |
+
size_t sizeInBytes);
|
| 1372 |
+
|
| 1373 |
+
/** Get matmul heuristic search preference descriptor attribute.
|
| 1374 |
+
*
|
| 1375 |
+
* \param[in] pref The descriptor
|
| 1376 |
+
* \param[in] attr The attribute
|
| 1377 |
+
* \param[out] buf memory address containing the new value
|
| 1378 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1379 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 1380 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1381 |
+
*
|
| 1382 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1383 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1384 |
+
* selected attribute
|
| 1385 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1386 |
+
*/
|
| 1387 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( //
|
| 1388 |
+
cublasLtMatmulPreference_t pref,
|
| 1389 |
+
cublasLtMatmulPreferenceAttributes_t attr,
|
| 1390 |
+
void* buf,
|
| 1391 |
+
size_t sizeInBytes,
|
| 1392 |
+
size_t* sizeWritten);
|
| 1393 |
+
|
| 1394 |
+
/** Results structure used by cublasLtMatmulGetAlgo.
|
| 1395 |
+
*
|
| 1396 |
+
* Holds returned configured algo descriptor and its runtime properties.
|
| 1397 |
+
*/
|
| 1398 |
+
typedef struct {
|
| 1399 |
+
/** Matmul algorithm descriptor.
|
| 1400 |
+
*
|
| 1401 |
+
* Must be initialized with cublasLtMatmulAlgoInit() if preferences' CUBLASLT_MATMUL_PERF_SEARCH_MODE is set to
|
| 1402 |
+
* CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID
|
| 1403 |
+
*/
|
| 1404 |
+
cublasLtMatmulAlgo_t algo;
|
| 1405 |
+
|
| 1406 |
+
/** Actual size of workspace memory required.
|
| 1407 |
+
*/
|
| 1408 |
+
size_t workspaceSize;
|
| 1409 |
+
|
| 1410 |
+
/** Result status, other fields are only valid if after call to cublasLtMatmulAlgoGetHeuristic() this member is set to
|
| 1411 |
+
* CUBLAS_STATUS_SUCCESS.
|
| 1412 |
+
*/
|
| 1413 |
+
cublasStatus_t state;
|
| 1414 |
+
|
| 1415 |
+
/** Waves count - a device utilization metric.
|
| 1416 |
+
*
|
| 1417 |
+
* wavesCount value of 1.0f suggests that when kernel is launched it will fully occupy the GPU.
|
| 1418 |
+
*/
|
| 1419 |
+
float wavesCount;
|
| 1420 |
+
|
| 1421 |
+
int reserved[4];
|
| 1422 |
+
} cublasLtMatmulHeuristicResult_t;
|
| 1423 |
+
|
| 1424 |
+
/** Query cublasLt heuristic for algorithm appropriate for given use case.
|
| 1425 |
+
*
|
| 1426 |
+
* \param[in] lightHandle Pointer to the allocated cuBLASLt handle for the cuBLASLt
|
| 1427 |
+
* context. See cublasLtHandle_t.
|
| 1428 |
+
* \param[in] operationDesc Handle to the matrix multiplication descriptor.
|
| 1429 |
+
* \param[in] Adesc Handle to the layout descriptors for matrix A.
|
| 1430 |
+
* \param[in] Bdesc Handle to the layout descriptors for matrix B.
|
| 1431 |
+
* \param[in] Cdesc Handle to the layout descriptors for matrix C.
|
| 1432 |
+
* \param[in] Ddesc Handle to the layout descriptors for matrix D.
|
| 1433 |
+
* \param[in] preference Pointer to the structure holding the heuristic search
|
| 1434 |
+
* preferences descriptor. See cublasLtMatrixLayout_t.
|
| 1435 |
+
* \param[in] requestedAlgoCount Size of heuristicResultsArray (in elements) and requested
|
| 1436 |
+
* maximum number of algorithms to return.
|
| 1437 |
+
* \param[in, out] heuristicResultsArray Output algorithms and associated runtime characteristics,
|
| 1438 |
+
* ordered in increasing estimated compute time.
|
| 1439 |
+
* \param[out] returnAlgoCount The number of heuristicResultsArray elements written.
|
| 1440 |
+
*
|
| 1441 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
|
| 1442 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if no heuristic function available for current configuration
|
| 1443 |
+
* \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect
|
| 1444 |
+
* heuristicResultsArray[0 to (returnAlgoCount - 1)].state
|
| 1445 |
+
* for detail status of results
|
| 1446 |
+
*/
|
| 1447 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(cublasLtHandle_t lightHandle,
|
| 1448 |
+
cublasLtMatmulDesc_t operationDesc,
|
| 1449 |
+
cublasLtMatrixLayout_t Adesc,
|
| 1450 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 1451 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 1452 |
+
cublasLtMatrixLayout_t Ddesc,
|
| 1453 |
+
cublasLtMatmulPreference_t preference,
|
| 1454 |
+
int requestedAlgoCount,
|
| 1455 |
+
cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
|
| 1456 |
+
int* returnAlgoCount);
|
| 1457 |
+
|
| 1458 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 1459 |
+
/* Lower level API to be able to implement own Heuristic and Find routines */
|
| 1460 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 1461 |
+
|
| 1462 |
+
/** Routine to get all algo IDs that can potentially run
|
| 1463 |
+
*
|
| 1464 |
+
* \param[in] int requestedAlgoCount requested number of algos (must be less or equal to size of algoIdsA
|
| 1465 |
+
* (in elements)) \param[out] algoIdsA array to write algoIds to \param[out] returnAlgoCount number of algoIds
|
| 1466 |
+
* actually written
|
| 1467 |
+
*
|
| 1468 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
|
| 1469 |
+
* \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect returnAlgoCount to get actual number of IDs
|
| 1470 |
+
* available
|
| 1471 |
+
*/
|
| 1472 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(cublasLtHandle_t lightHandle,
|
| 1473 |
+
cublasComputeType_t computeType,
|
| 1474 |
+
cudaDataType_t scaleType,
|
| 1475 |
+
cudaDataType_t Atype,
|
| 1476 |
+
cudaDataType_t Btype,
|
| 1477 |
+
cudaDataType_t Ctype,
|
| 1478 |
+
cudaDataType_t Dtype,
|
| 1479 |
+
int requestedAlgoCount,
|
| 1480 |
+
int algoIdsArray[],
|
| 1481 |
+
int* returnAlgoCount);
|
| 1482 |
+
|
| 1483 |
+
/** Initialize algo structure
|
| 1484 |
+
*
|
| 1485 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if algo is NULL or algoId is outside of recognized range
|
| 1486 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if algoId is not supported for given combination of data types
|
| 1487 |
+
* \retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized
|
| 1488 |
+
*/
|
| 1489 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(cublasLtHandle_t lightHandle,
|
| 1490 |
+
cublasComputeType_t computeType,
|
| 1491 |
+
cudaDataType_t scaleType,
|
| 1492 |
+
cudaDataType_t Atype,
|
| 1493 |
+
cudaDataType_t Btype,
|
| 1494 |
+
cudaDataType_t Ctype,
|
| 1495 |
+
cudaDataType_t Dtype,
|
| 1496 |
+
int algoId,
|
| 1497 |
+
cublasLtMatmulAlgo_t* algo);
|
| 1498 |
+
|
| 1499 |
+
/** Check configured algo descriptor for correctness and support on current device.
|
| 1500 |
+
*
|
| 1501 |
+
* Result includes required workspace size and calculated wave count.
|
| 1502 |
+
*
|
| 1503 |
+
* CUBLAS_STATUS_SUCCESS doesn't fully guarantee algo will run (will fail if e.g. buffers are not correctly aligned);
|
| 1504 |
+
* but if cublasLtMatmulAlgoCheck fails, the algo will not run.
|
| 1505 |
+
*
|
| 1506 |
+
* \param[in] algo algo configuration to check
|
| 1507 |
+
* \param[out] result result structure to report algo runtime characteristics; algo field is never updated
|
| 1508 |
+
*
|
| 1509 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if matrix layout descriptors or operation descriptor don't match algo
|
| 1510 |
+
* descriptor
|
| 1511 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if algo configuration or data type combination is not currently supported on
|
| 1512 |
+
* given device
|
| 1513 |
+
* \retval CUBLAS_STATUS_ARCH_MISMATCH if algo configuration cannot be run using the selected device
|
| 1514 |
+
* \retval CUBLAS_STATUS_SUCCESS if check was successful
|
| 1515 |
+
*/
|
| 1516 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( //
|
| 1517 |
+
cublasLtHandle_t lightHandle,
|
| 1518 |
+
cublasLtMatmulDesc_t operationDesc,
|
| 1519 |
+
cublasLtMatrixLayout_t Adesc,
|
| 1520 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 1521 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 1522 |
+
cublasLtMatrixLayout_t Ddesc,
|
| 1523 |
+
const cublasLtMatmulAlgo_t* algo, ///< may point to result->algo
|
| 1524 |
+
cublasLtMatmulHeuristicResult_t* result);
|
| 1525 |
+
|
| 1526 |
+
/** Capabilities Attributes that can be retrieved from an initialized Algo structure
|
| 1527 |
+
*/
|
| 1528 |
+
typedef enum {
|
| 1529 |
+
/** support for split K, see CUBLASLT_ALGO_CONFIG_SPLITK_NUM
|
| 1530 |
+
*
|
| 1531 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1532 |
+
*/
|
| 1533 |
+
CUBLASLT_ALGO_CAP_SPLITK_SUPPORT = 0,
|
| 1534 |
+
|
| 1535 |
+
/** reduction scheme mask, see cublasLtReductionScheme_t; shows supported reduction schemes, if reduction scheme is
|
| 1536 |
+
* not masked out it is supported.
|
| 1537 |
+
*
|
| 1538 |
+
* e.g. int isReductionSchemeComputeTypeSupported ? (reductionSchemeMask & CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE) ==
|
| 1539 |
+
* CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE ? 1 : 0;
|
| 1540 |
+
*
|
| 1541 |
+
* uint32_t
|
| 1542 |
+
*/
|
| 1543 |
+
CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK = 1,
|
| 1544 |
+
|
| 1545 |
+
/** support for cta swizzling, see CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
|
| 1546 |
+
*
|
| 1547 |
+
* uint32_t, 0 means no support, 1 means supported value of 1, other values are reserved
|
| 1548 |
+
*/
|
| 1549 |
+
CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT = 2,
|
| 1550 |
+
|
| 1551 |
+
/** support strided batch
|
| 1552 |
+
*
|
| 1553 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1554 |
+
*/
|
| 1555 |
+
CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT = 3,
|
| 1556 |
+
|
| 1557 |
+
/** support results out of place (D != C in D = alpha.A.B + beta.C)
|
| 1558 |
+
*
|
| 1559 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1560 |
+
*/
|
| 1561 |
+
CUBLASLT_ALGO_CAP_OUT_OF_PLACE_RESULT_SUPPORT = 4,
|
| 1562 |
+
|
| 1563 |
+
/** syrk/herk support (on top of regular gemm)
|
| 1564 |
+
*
|
| 1565 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1566 |
+
*/
|
| 1567 |
+
CUBLASLT_ALGO_CAP_UPLO_SUPPORT = 5,
|
| 1568 |
+
|
| 1569 |
+
/** tile ids possible to use, see cublasLtMatmulTile_t; if no tile ids are supported use
|
| 1570 |
+
* CUBLASLT_MATMUL_TILE_UNDEFINED
|
| 1571 |
+
*
|
| 1572 |
+
* use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
|
| 1573 |
+
*
|
| 1574 |
+
* array of uint32_t
|
| 1575 |
+
*/
|
| 1576 |
+
CUBLASLT_ALGO_CAP_TILE_IDS = 6,
|
| 1577 |
+
|
| 1578 |
+
/** custom option range is from 0 to CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX (inclusive), see
|
| 1579 |
+
* CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
|
| 1580 |
+
*
|
| 1581 |
+
* int32_t
|
| 1582 |
+
*/
|
| 1583 |
+
CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX = 7,
|
| 1584 |
+
|
| 1585 |
+
/** whether algorithm supports custom (not COL or ROW memory order), see cublasLtOrder_t
|
| 1586 |
+
*
|
| 1587 |
+
* int32_t 0 means only COL and ROW memory order is allowed, non-zero means that algo might have different
|
| 1588 |
+
* requirements;
|
| 1589 |
+
*/
|
| 1590 |
+
CUBLASLT_ALGO_CAP_CUSTOM_MEMORY_ORDER = 10,
|
| 1591 |
+
|
| 1592 |
+
/** bitmask enumerating pointer modes algorithm supports
|
| 1593 |
+
*
|
| 1594 |
+
* uint32_t, see cublasLtPointerModeMask_t
|
| 1595 |
+
*/
|
| 1596 |
+
CUBLASLT_ALGO_CAP_POINTER_MODE_MASK = 11,
|
| 1597 |
+
|
| 1598 |
+
/** bitmask enumerating kinds of postprocessing algorithm supports in the epilogue
|
| 1599 |
+
*
|
| 1600 |
+
* uint32_t, see cublasLtEpilogue_t
|
| 1601 |
+
*/
|
| 1602 |
+
CUBLASLT_ALGO_CAP_EPILOGUE_MASK = 12,
|
| 1603 |
+
|
| 1604 |
+
/** stages ids possible to use, see cublasLtMatmulStages_t; if no stages ids are supported use
|
| 1605 |
+
* CUBLASLT_MATMUL_STAGES_UNDEFINED
|
| 1606 |
+
*
|
| 1607 |
+
* use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
|
| 1608 |
+
*
|
| 1609 |
+
* array of uint32_t
|
| 1610 |
+
*/
|
| 1611 |
+
CUBLASLT_ALGO_CAP_STAGES_IDS = 13,
|
| 1612 |
+
|
| 1613 |
+
/** support for nagative ld for all of the matrices
|
| 1614 |
+
*
|
| 1615 |
+
* int32_t 0 means no support, supported otherwise
|
| 1616 |
+
*/
|
| 1617 |
+
CUBLASLT_ALGO_CAP_LD_NEGATIVE = 14,
|
| 1618 |
+
|
| 1619 |
+
/** details about algorithm's implementation that affect it's numerical behavior
|
| 1620 |
+
*
|
| 1621 |
+
* uint64_t, see cublasLtNumericalImplFlags_t
|
| 1622 |
+
*/
|
| 1623 |
+
CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS = 15,
|
| 1624 |
+
|
| 1625 |
+
/** minimum alignment required for A matrix in bytes
|
| 1626 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1627 |
+
*
|
| 1628 |
+
* uint32_t
|
| 1629 |
+
*/
|
| 1630 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_A_BYTES = 16,
|
| 1631 |
+
|
| 1632 |
+
/** minimum alignment required for B matrix in bytes
|
| 1633 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1634 |
+
*
|
| 1635 |
+
* uint32_t
|
| 1636 |
+
*/
|
| 1637 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_B_BYTES = 17,
|
| 1638 |
+
|
| 1639 |
+
/** minimum alignment required for C matrix in bytes
|
| 1640 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1641 |
+
*
|
| 1642 |
+
* uint32_t
|
| 1643 |
+
*/
|
| 1644 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_C_BYTES = 18,
|
| 1645 |
+
|
| 1646 |
+
/** minimum alignment required for D matrix in bytes
|
| 1647 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1648 |
+
*
|
| 1649 |
+
* uint32_t
|
| 1650 |
+
*/
|
| 1651 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_D_BYTES = 19,
|
| 1652 |
+
|
| 1653 |
+
/** EXPERIMENTAL: support for synchronization via atomic counters
|
| 1654 |
+
*
|
| 1655 |
+
* int32_t
|
| 1656 |
+
*/
|
| 1657 |
+
CUBLASLT_ALGO_CAP_ATOMIC_SYNC = 20,
|
| 1658 |
+
} cublasLtMatmulAlgoCapAttributes_t;
|
| 1659 |
+
|
| 1660 |
+
/** Get algo capability attribute.
|
| 1661 |
+
*
|
| 1662 |
+
* E.g. to get list of supported Tile IDs:
|
| 1663 |
+
* cublasLtMatmulTile_t tiles[CUBLASLT_MATMUL_TILE_END];
|
| 1664 |
+
* size_t num_tiles, size_written;
|
| 1665 |
+
* if (cublasLtMatmulAlgoCapGetAttribute(algo, CUBLASLT_ALGO_CAP_TILE_IDS, tiles, sizeof(tiles), size_written) ==
|
| 1666 |
+
* CUBLAS_STATUS_SUCCESS) { num_tiles = size_written / sizeof(tiles[0]);
|
| 1667 |
+
* }
|
| 1668 |
+
*
|
| 1669 |
+
* \param[in] algo The algo descriptor
|
| 1670 |
+
* \param[in] attr The attribute
|
| 1671 |
+
* \param[out] buf memory address containing the new value
|
| 1672 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1673 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 1674 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1675 |
+
*
|
| 1676 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1677 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1678 |
+
* selected attribute
|
| 1679 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1680 |
+
*/
|
| 1681 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(const cublasLtMatmulAlgo_t* algo,
|
| 1682 |
+
cublasLtMatmulAlgoCapAttributes_t attr,
|
| 1683 |
+
void* buf,
|
| 1684 |
+
size_t sizeInBytes,
|
| 1685 |
+
size_t* sizeWritten);
|
| 1686 |
+
|
| 1687 |
+
/** Algo Configuration Attributes that can be set according to the Algo capabilities
|
| 1688 |
+
*/
|
| 1689 |
+
typedef enum {
|
| 1690 |
+
/** algorithm index, see cublasLtMatmulAlgoGetIds()
|
| 1691 |
+
*
|
| 1692 |
+
* readonly, set by cublasLtMatmulAlgoInit()
|
| 1693 |
+
* int32_t
|
| 1694 |
+
*/
|
| 1695 |
+
CUBLASLT_ALGO_CONFIG_ID = 0,
|
| 1696 |
+
/** tile id, see cublasLtMatmulTile_t
|
| 1697 |
+
*
|
| 1698 |
+
* uint32_t, default: CUBLASLT_MATMUL_TILE_UNDEFINED
|
| 1699 |
+
*/
|
| 1700 |
+
CUBLASLT_ALGO_CONFIG_TILE_ID = 1,
|
| 1701 |
+
/** Number of K splits. If the number of K splits is greater than one, SPLITK_NUM parts
|
| 1702 |
+
* of matrix multiplication will be computed in parallel. The results will be accumulated
|
| 1703 |
+
* according to CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
|
| 1704 |
+
*
|
| 1705 |
+
* int32_t, default: 1
|
| 1706 |
+
*/
|
| 1707 |
+
CUBLASLT_ALGO_CONFIG_SPLITK_NUM = 2,
|
| 1708 |
+
/** reduction scheme, see cublasLtReductionScheme_t
|
| 1709 |
+
*
|
| 1710 |
+
* uint32_t, default: CUBLASLT_REDUCTION_SCHEME_NONE
|
| 1711 |
+
*/
|
| 1712 |
+
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME = 3,
|
| 1713 |
+
/** cta swizzling, change mapping from CUDA grid coordinates to parts of the matrices
|
| 1714 |
+
*
|
| 1715 |
+
* possible values: 0, 1, other values reserved
|
| 1716 |
+
*
|
| 1717 |
+
* uint32_t, default: 0
|
| 1718 |
+
*/
|
| 1719 |
+
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING = 4,
|
| 1720 |
+
/** custom option, each algorithm can support some custom options that don't fit description of the other config
|
| 1721 |
+
* attributes, see CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX to get accepted range for any specific case
|
| 1722 |
+
*
|
| 1723 |
+
* uint32_t, default: 0
|
| 1724 |
+
*/
|
| 1725 |
+
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION = 5,
|
| 1726 |
+
/** stages id, see cublasLtMatmulStages_t
|
| 1727 |
+
*
|
| 1728 |
+
* uint32_t, default: CUBLASLT_MATMUL_STAGES_UNDEFINED
|
| 1729 |
+
*/
|
| 1730 |
+
CUBLASLT_ALGO_CONFIG_STAGES_ID = 6,
|
| 1731 |
+
/** inner shape id, see cublasLtMatmulInnerShape_t
|
| 1732 |
+
*
|
| 1733 |
+
* uint16_t, default: 0 (CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED)
|
| 1734 |
+
*/
|
| 1735 |
+
CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID = 7,
|
| 1736 |
+
/** Thread Block Cluster shape id, see cublasLtClusterShape_t. Defines cluster size to use.
|
| 1737 |
+
*
|
| 1738 |
+
* uint16_t, default: 0 (CUBLASLT_CLUSTER_SHAPE_AUTO)
|
| 1739 |
+
*/
|
| 1740 |
+
CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID = 8,
|
| 1741 |
+
} cublasLtMatmulAlgoConfigAttributes_t;
|
| 1742 |
+
|
| 1743 |
+
/** Set algo configuration attribute.
|
| 1744 |
+
*
|
| 1745 |
+
* \param[in] algo The algo descriptor
|
| 1746 |
+
* \param[in] attr The attribute
|
| 1747 |
+
* \param[in] buf memory address containing the new value
|
| 1748 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1749 |
+
*
|
| 1750 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1751 |
+
* selected attribute
|
| 1752 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 1753 |
+
*/
|
| 1754 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(cublasLtMatmulAlgo_t* algo,
|
| 1755 |
+
cublasLtMatmulAlgoConfigAttributes_t attr,
|
| 1756 |
+
const void* buf,
|
| 1757 |
+
size_t sizeInBytes);
|
| 1758 |
+
|
| 1759 |
+
/** Get algo configuration attribute.
|
| 1760 |
+
*
|
| 1761 |
+
* \param[in] algo The algo descriptor
|
| 1762 |
+
* \param[in] attr The attribute
|
| 1763 |
+
* \param[out] buf memory address containing the new value
|
| 1764 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1765 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 1766 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1767 |
+
*
|
| 1768 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1769 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1770 |
+
* selected attribute
|
| 1771 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1772 |
+
*/
|
| 1773 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(const cublasLtMatmulAlgo_t* algo,
|
| 1774 |
+
cublasLtMatmulAlgoConfigAttributes_t attr,
|
| 1775 |
+
void* buf,
|
| 1776 |
+
size_t sizeInBytes,
|
| 1777 |
+
size_t* sizeWritten);
|
| 1778 |
+
|
| 1779 |
+
/** Experimental: Logger callback type.
|
| 1780 |
+
*/
|
| 1781 |
+
typedef void (*cublasLtLoggerCallback_t)(int logLevel, const char* functionName, const char* message);
|
| 1782 |
+
|
| 1783 |
+
/** Experimental: Logger callback setter.
|
| 1784 |
+
*
|
| 1785 |
+
* \param[in] callback a user defined callback function to be called by the logger
|
| 1786 |
+
*
|
| 1787 |
+
* \retval CUBLAS_STATUS_SUCCESS if callback was set successfully
|
| 1788 |
+
*/
|
| 1789 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetCallback(cublasLtLoggerCallback_t callback);
|
| 1790 |
+
|
| 1791 |
+
/** Experimental: Log file setter.
|
| 1792 |
+
*
|
| 1793 |
+
* \param[in] file an open file with write permissions
|
| 1794 |
+
*
|
| 1795 |
+
* \retval CUBLAS_STATUS_SUCCESS if log file was set successfully
|
| 1796 |
+
*/
|
| 1797 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetFile(FILE* file);
|
| 1798 |
+
|
| 1799 |
+
/** Experimental: Open log file.
|
| 1800 |
+
*
|
| 1801 |
+
* \param[in] logFile log file path. if the log file does not exist, it will be created
|
| 1802 |
+
*
|
| 1803 |
+
* \retval CUBLAS_STATUS_SUCCESS if log file was created successfully
|
| 1804 |
+
*/
|
| 1805 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerOpenFile(const char* logFile);
|
| 1806 |
+
|
| 1807 |
+
/** Experimental: Log level setter.
|
| 1808 |
+
*
|
| 1809 |
+
* \param[in] level log level, should be one of the following:
|
| 1810 |
+
* 0. Off
|
| 1811 |
+
* 1. Errors
|
| 1812 |
+
* 2. Performance Trace
|
| 1813 |
+
* 3. Performance Hints
|
| 1814 |
+
* 4. Heuristics Trace
|
| 1815 |
+
* 5. API Trace
|
| 1816 |
+
*
|
| 1817 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if log level is not one of the above levels
|
| 1818 |
+
*
|
| 1819 |
+
* \retval CUBLAS_STATUS_SUCCESS if log level was set successfully
|
| 1820 |
+
*/
|
| 1821 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetLevel(int level);
|
| 1822 |
+
|
| 1823 |
+
/** Experimental: Log mask setter.
|
| 1824 |
+
*
|
| 1825 |
+
* \param[in] mask log mask, should be a combination of the following masks:
|
| 1826 |
+
* 0. Off
|
| 1827 |
+
* 1. Errors
|
| 1828 |
+
* 2. Performance Trace
|
| 1829 |
+
* 4. Performance Hints
|
| 1830 |
+
* 8. Heuristics Trace
|
| 1831 |
+
* 16. API Trace
|
| 1832 |
+
*
|
| 1833 |
+
* \retval CUBLAS_STATUS_SUCCESS if log mask was set successfully
|
| 1834 |
+
*/
|
| 1835 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetMask(int mask);
|
| 1836 |
+
|
| 1837 |
+
/** Experimental: Disable logging for the entire session.
|
| 1838 |
+
*
|
| 1839 |
+
* \retval CUBLAS_STATUS_SUCCESS if disabled logging
|
| 1840 |
+
*/
|
| 1841 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerForceDisable();
|
| 1842 |
+
|
| 1843 |
+
#if defined(__cplusplus)
|
| 1844 |
+
}
|
| 1845 |
+
#endif /* __cplusplus */
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasXt.h
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/* cublasXt : Host API, Out of Core and Multi-GPU BLAS Library
|
| 51 |
+
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUBLAS_XT_H_)
|
| 55 |
+
#define CUBLAS_XT_H_
|
| 56 |
+
|
| 57 |
+
#include "driver_types.h"
|
| 58 |
+
#include "cuComplex.h" /* import complex data type */
|
| 59 |
+
|
| 60 |
+
#include "cublas_v2.h"
|
| 61 |
+
|
| 62 |
+
#if defined(__cplusplus)
|
| 63 |
+
extern "C" {
|
| 64 |
+
#endif /* __cplusplus */
|
| 65 |
+
|
| 66 |
+
struct cublasXtContext;
|
| 67 |
+
typedef struct cublasXtContext* cublasXtHandle_t;
|
| 68 |
+
|
| 69 |
+
cublasStatus_t CUBLASWINAPI cublasXtCreate(cublasXtHandle_t* handle);
|
| 70 |
+
cublasStatus_t CUBLASWINAPI cublasXtDestroy(cublasXtHandle_t handle);
|
| 71 |
+
cublasStatus_t CUBLASWINAPI cublasXtGetNumBoards(int nbDevices, int deviceId[], int* nbBoards);
|
| 72 |
+
cublasStatus_t CUBLASWINAPI cublasXtMaxBoards(int* nbGpuBoards);
|
| 73 |
+
/* This routine selects the Gpus that the user want to use for CUBLAS-XT */
|
| 74 |
+
cublasStatus_t CUBLASWINAPI cublasXtDeviceSelect(cublasXtHandle_t handle, int nbDevices, int deviceId[]);
|
| 75 |
+
|
| 76 |
+
/* This routine allows to change the dimension of the tiles ( blockDim x blockDim ) */
|
| 77 |
+
cublasStatus_t CUBLASWINAPI cublasXtSetBlockDim(cublasXtHandle_t handle, int blockDim);
|
| 78 |
+
cublasStatus_t CUBLASWINAPI cublasXtGetBlockDim(cublasXtHandle_t handle, int* blockDim);
|
| 79 |
+
|
| 80 |
+
typedef enum { CUBLASXT_PINNING_DISABLED = 0, CUBLASXT_PINNING_ENABLED = 1 } cublasXtPinnedMemMode_t;
|
| 81 |
+
/* This routine allows to CUBLAS-XT to pin the Host memory if it find out that some of the matrix passed
|
| 82 |
+
are not pinned : Pinning/Unpinning the Host memory is still a costly operation
|
| 83 |
+
It is better if the user controls the memory on its own (by pinning/unpinning oly when necessary)
|
| 84 |
+
*/
|
| 85 |
+
cublasStatus_t CUBLASWINAPI cublasXtGetPinningMemMode(cublasXtHandle_t handle, cublasXtPinnedMemMode_t* mode);
|
| 86 |
+
cublasStatus_t CUBLASWINAPI cublasXtSetPinningMemMode(cublasXtHandle_t handle, cublasXtPinnedMemMode_t mode);
|
| 87 |
+
|
| 88 |
+
/* This routines is to provide a CPU Blas routines, used for too small sizes or hybrid computation */
|
| 89 |
+
typedef enum {
|
| 90 |
+
CUBLASXT_FLOAT = 0,
|
| 91 |
+
CUBLASXT_DOUBLE = 1,
|
| 92 |
+
CUBLASXT_COMPLEX = 2,
|
| 93 |
+
CUBLASXT_DOUBLECOMPLEX = 3,
|
| 94 |
+
} cublasXtOpType_t;
|
| 95 |
+
|
| 96 |
+
typedef enum {
|
| 97 |
+
CUBLASXT_GEMM = 0,
|
| 98 |
+
CUBLASXT_SYRK = 1,
|
| 99 |
+
CUBLASXT_HERK = 2,
|
| 100 |
+
CUBLASXT_SYMM = 3,
|
| 101 |
+
CUBLASXT_HEMM = 4,
|
| 102 |
+
CUBLASXT_TRSM = 5,
|
| 103 |
+
CUBLASXT_SYR2K = 6,
|
| 104 |
+
CUBLASXT_HER2K = 7,
|
| 105 |
+
|
| 106 |
+
CUBLASXT_SPMM = 8,
|
| 107 |
+
CUBLASXT_SYRKX = 9,
|
| 108 |
+
CUBLASXT_HERKX = 10,
|
| 109 |
+
CUBLASXT_TRMM = 11,
|
| 110 |
+
CUBLASXT_ROUTINE_MAX = 12,
|
| 111 |
+
} cublasXtBlasOp_t;
|
| 112 |
+
|
| 113 |
+
/* Currently only 32-bit integer BLAS routines are supported */
|
| 114 |
+
cublasStatus_t CUBLASWINAPI cublasXtSetCpuRoutine(cublasXtHandle_t handle,
|
| 115 |
+
cublasXtBlasOp_t blasOp,
|
| 116 |
+
cublasXtOpType_t type,
|
| 117 |
+
void* blasFunctor);
|
| 118 |
+
|
| 119 |
+
/* Specified the percentage of work that should done by the CPU, default is 0 (no work) */
|
| 120 |
+
cublasStatus_t CUBLASWINAPI cublasXtSetCpuRatio(cublasXtHandle_t handle,
|
| 121 |
+
cublasXtBlasOp_t blasOp,
|
| 122 |
+
cublasXtOpType_t type,
|
| 123 |
+
float ratio);
|
| 124 |
+
|
| 125 |
+
/* GEMM */
|
| 126 |
+
cublasStatus_t CUBLASWINAPI cublasXtSgemm(cublasXtHandle_t handle,
|
| 127 |
+
cublasOperation_t transa,
|
| 128 |
+
cublasOperation_t transb,
|
| 129 |
+
size_t m,
|
| 130 |
+
size_t n,
|
| 131 |
+
size_t k,
|
| 132 |
+
const float* alpha,
|
| 133 |
+
const float* A,
|
| 134 |
+
size_t lda,
|
| 135 |
+
const float* B,
|
| 136 |
+
size_t ldb,
|
| 137 |
+
const float* beta,
|
| 138 |
+
float* C,
|
| 139 |
+
size_t ldc);
|
| 140 |
+
|
| 141 |
+
cublasStatus_t CUBLASWINAPI cublasXtDgemm(cublasXtHandle_t handle,
|
| 142 |
+
cublasOperation_t transa,
|
| 143 |
+
cublasOperation_t transb,
|
| 144 |
+
size_t m,
|
| 145 |
+
size_t n,
|
| 146 |
+
size_t k,
|
| 147 |
+
const double* alpha,
|
| 148 |
+
const double* A,
|
| 149 |
+
size_t lda,
|
| 150 |
+
const double* B,
|
| 151 |
+
size_t ldb,
|
| 152 |
+
const double* beta,
|
| 153 |
+
double* C,
|
| 154 |
+
size_t ldc);
|
| 155 |
+
|
| 156 |
+
cublasStatus_t CUBLASWINAPI cublasXtCgemm(cublasXtHandle_t handle,
|
| 157 |
+
cublasOperation_t transa,
|
| 158 |
+
cublasOperation_t transb,
|
| 159 |
+
size_t m,
|
| 160 |
+
size_t n,
|
| 161 |
+
size_t k,
|
| 162 |
+
const cuComplex* alpha,
|
| 163 |
+
const cuComplex* A,
|
| 164 |
+
size_t lda,
|
| 165 |
+
const cuComplex* B,
|
| 166 |
+
size_t ldb,
|
| 167 |
+
const cuComplex* beta,
|
| 168 |
+
cuComplex* C,
|
| 169 |
+
size_t ldc);
|
| 170 |
+
|
| 171 |
+
cublasStatus_t CUBLASWINAPI cublasXtZgemm(cublasXtHandle_t handle,
|
| 172 |
+
cublasOperation_t transa,
|
| 173 |
+
cublasOperation_t transb,
|
| 174 |
+
size_t m,
|
| 175 |
+
size_t n,
|
| 176 |
+
size_t k,
|
| 177 |
+
const cuDoubleComplex* alpha,
|
| 178 |
+
const cuDoubleComplex* A,
|
| 179 |
+
size_t lda,
|
| 180 |
+
const cuDoubleComplex* B,
|
| 181 |
+
size_t ldb,
|
| 182 |
+
const cuDoubleComplex* beta,
|
| 183 |
+
cuDoubleComplex* C,
|
| 184 |
+
size_t ldc);
|
| 185 |
+
/* ------------------------------------------------------- */
|
| 186 |
+
/* SYRK */
|
| 187 |
+
cublasStatus_t CUBLASWINAPI cublasXtSsyrk(cublasXtHandle_t handle,
|
| 188 |
+
cublasFillMode_t uplo,
|
| 189 |
+
cublasOperation_t trans,
|
| 190 |
+
size_t n,
|
| 191 |
+
size_t k,
|
| 192 |
+
const float* alpha,
|
| 193 |
+
const float* A,
|
| 194 |
+
size_t lda,
|
| 195 |
+
const float* beta,
|
| 196 |
+
float* C,
|
| 197 |
+
size_t ldc);
|
| 198 |
+
|
| 199 |
+
cublasStatus_t CUBLASWINAPI cublasXtDsyrk(cublasXtHandle_t handle,
|
| 200 |
+
cublasFillMode_t uplo,
|
| 201 |
+
cublasOperation_t trans,
|
| 202 |
+
size_t n,
|
| 203 |
+
size_t k,
|
| 204 |
+
const double* alpha,
|
| 205 |
+
const double* A,
|
| 206 |
+
size_t lda,
|
| 207 |
+
const double* beta,
|
| 208 |
+
double* C,
|
| 209 |
+
size_t ldc);
|
| 210 |
+
|
| 211 |
+
cublasStatus_t CUBLASWINAPI cublasXtCsyrk(cublasXtHandle_t handle,
|
| 212 |
+
cublasFillMode_t uplo,
|
| 213 |
+
cublasOperation_t trans,
|
| 214 |
+
size_t n,
|
| 215 |
+
size_t k,
|
| 216 |
+
const cuComplex* alpha,
|
| 217 |
+
const cuComplex* A,
|
| 218 |
+
size_t lda,
|
| 219 |
+
const cuComplex* beta,
|
| 220 |
+
cuComplex* C,
|
| 221 |
+
size_t ldc);
|
| 222 |
+
|
| 223 |
+
cublasStatus_t CUBLASWINAPI cublasXtZsyrk(cublasXtHandle_t handle,
|
| 224 |
+
cublasFillMode_t uplo,
|
| 225 |
+
cublasOperation_t trans,
|
| 226 |
+
size_t n,
|
| 227 |
+
size_t k,
|
| 228 |
+
const cuDoubleComplex* alpha,
|
| 229 |
+
const cuDoubleComplex* A,
|
| 230 |
+
size_t lda,
|
| 231 |
+
const cuDoubleComplex* beta,
|
| 232 |
+
cuDoubleComplex* C,
|
| 233 |
+
size_t ldc);
|
| 234 |
+
/* -------------------------------------------------------------------- */
|
| 235 |
+
/* HERK */
|
| 236 |
+
cublasStatus_t CUBLASWINAPI cublasXtCherk(cublasXtHandle_t handle,
|
| 237 |
+
cublasFillMode_t uplo,
|
| 238 |
+
cublasOperation_t trans,
|
| 239 |
+
size_t n,
|
| 240 |
+
size_t k,
|
| 241 |
+
const float* alpha,
|
| 242 |
+
const cuComplex* A,
|
| 243 |
+
size_t lda,
|
| 244 |
+
const float* beta,
|
| 245 |
+
cuComplex* C,
|
| 246 |
+
size_t ldc);
|
| 247 |
+
|
| 248 |
+
cublasStatus_t CUBLASWINAPI cublasXtZherk(cublasXtHandle_t handle,
|
| 249 |
+
cublasFillMode_t uplo,
|
| 250 |
+
cublasOperation_t trans,
|
| 251 |
+
size_t n,
|
| 252 |
+
size_t k,
|
| 253 |
+
const double* alpha,
|
| 254 |
+
const cuDoubleComplex* A,
|
| 255 |
+
size_t lda,
|
| 256 |
+
const double* beta,
|
| 257 |
+
cuDoubleComplex* C,
|
| 258 |
+
size_t ldc);
|
| 259 |
+
/* -------------------------------------------------------------------- */
|
| 260 |
+
/* SYR2K */
|
| 261 |
+
cublasStatus_t CUBLASWINAPI cublasXtSsyr2k(cublasXtHandle_t handle,
|
| 262 |
+
cublasFillMode_t uplo,
|
| 263 |
+
cublasOperation_t trans,
|
| 264 |
+
size_t n,
|
| 265 |
+
size_t k,
|
| 266 |
+
const float* alpha,
|
| 267 |
+
const float* A,
|
| 268 |
+
size_t lda,
|
| 269 |
+
const float* B,
|
| 270 |
+
size_t ldb,
|
| 271 |
+
const float* beta,
|
| 272 |
+
float* C,
|
| 273 |
+
size_t ldc);
|
| 274 |
+
|
| 275 |
+
cublasStatus_t CUBLASWINAPI cublasXtDsyr2k(cublasXtHandle_t handle,
|
| 276 |
+
cublasFillMode_t uplo,
|
| 277 |
+
cublasOperation_t trans,
|
| 278 |
+
size_t n,
|
| 279 |
+
size_t k,
|
| 280 |
+
const double* alpha,
|
| 281 |
+
const double* A,
|
| 282 |
+
size_t lda,
|
| 283 |
+
const double* B,
|
| 284 |
+
size_t ldb,
|
| 285 |
+
const double* beta,
|
| 286 |
+
double* C,
|
| 287 |
+
size_t ldc);
|
| 288 |
+
|
| 289 |
+
cublasStatus_t CUBLASWINAPI cublasXtCsyr2k(cublasXtHandle_t handle,
|
| 290 |
+
cublasFillMode_t uplo,
|
| 291 |
+
cublasOperation_t trans,
|
| 292 |
+
size_t n,
|
| 293 |
+
size_t k,
|
| 294 |
+
const cuComplex* alpha,
|
| 295 |
+
const cuComplex* A,
|
| 296 |
+
size_t lda,
|
| 297 |
+
const cuComplex* B,
|
| 298 |
+
size_t ldb,
|
| 299 |
+
const cuComplex* beta,
|
| 300 |
+
cuComplex* C,
|
| 301 |
+
size_t ldc);
|
| 302 |
+
|
| 303 |
+
cublasStatus_t CUBLASWINAPI cublasXtZsyr2k(cublasXtHandle_t handle,
|
| 304 |
+
cublasFillMode_t uplo,
|
| 305 |
+
cublasOperation_t trans,
|
| 306 |
+
size_t n,
|
| 307 |
+
size_t k,
|
| 308 |
+
const cuDoubleComplex* alpha,
|
| 309 |
+
const cuDoubleComplex* A,
|
| 310 |
+
size_t lda,
|
| 311 |
+
const cuDoubleComplex* B,
|
| 312 |
+
size_t ldb,
|
| 313 |
+
const cuDoubleComplex* beta,
|
| 314 |
+
cuDoubleComplex* C,
|
| 315 |
+
size_t ldc);
|
| 316 |
+
/* -------------------------------------------------------------------- */
|
| 317 |
+
/* HERKX : variant extension of HERK */
|
| 318 |
+
cublasStatus_t CUBLASWINAPI cublasXtCherkx(cublasXtHandle_t handle,
|
| 319 |
+
cublasFillMode_t uplo,
|
| 320 |
+
cublasOperation_t trans,
|
| 321 |
+
size_t n,
|
| 322 |
+
size_t k,
|
| 323 |
+
const cuComplex* alpha,
|
| 324 |
+
const cuComplex* A,
|
| 325 |
+
size_t lda,
|
| 326 |
+
const cuComplex* B,
|
| 327 |
+
size_t ldb,
|
| 328 |
+
const float* beta,
|
| 329 |
+
cuComplex* C,
|
| 330 |
+
size_t ldc);
|
| 331 |
+
|
| 332 |
+
cublasStatus_t CUBLASWINAPI cublasXtZherkx(cublasXtHandle_t handle,
|
| 333 |
+
cublasFillMode_t uplo,
|
| 334 |
+
cublasOperation_t trans,
|
| 335 |
+
size_t n,
|
| 336 |
+
size_t k,
|
| 337 |
+
const cuDoubleComplex* alpha,
|
| 338 |
+
const cuDoubleComplex* A,
|
| 339 |
+
size_t lda,
|
| 340 |
+
const cuDoubleComplex* B,
|
| 341 |
+
size_t ldb,
|
| 342 |
+
const double* beta,
|
| 343 |
+
cuDoubleComplex* C,
|
| 344 |
+
size_t ldc);
|
| 345 |
+
|
| 346 |
+
/* -------------------------------------------------------------------- */
|
| 347 |
+
/* TRSM */
|
| 348 |
+
cublasStatus_t CUBLASWINAPI cublasXtStrsm(cublasXtHandle_t handle,
|
| 349 |
+
cublasSideMode_t side,
|
| 350 |
+
cublasFillMode_t uplo,
|
| 351 |
+
cublasOperation_t trans,
|
| 352 |
+
cublasDiagType_t diag,
|
| 353 |
+
size_t m,
|
| 354 |
+
size_t n,
|
| 355 |
+
const float* alpha,
|
| 356 |
+
const float* A,
|
| 357 |
+
size_t lda,
|
| 358 |
+
float* B,
|
| 359 |
+
size_t ldb);
|
| 360 |
+
|
| 361 |
+
cublasStatus_t CUBLASWINAPI cublasXtDtrsm(cublasXtHandle_t handle,
|
| 362 |
+
cublasSideMode_t side,
|
| 363 |
+
cublasFillMode_t uplo,
|
| 364 |
+
cublasOperation_t trans,
|
| 365 |
+
cublasDiagType_t diag,
|
| 366 |
+
size_t m,
|
| 367 |
+
size_t n,
|
| 368 |
+
const double* alpha,
|
| 369 |
+
const double* A,
|
| 370 |
+
size_t lda,
|
| 371 |
+
double* B,
|
| 372 |
+
size_t ldb);
|
| 373 |
+
|
| 374 |
+
cublasStatus_t CUBLASWINAPI cublasXtCtrsm(cublasXtHandle_t handle,
|
| 375 |
+
cublasSideMode_t side,
|
| 376 |
+
cublasFillMode_t uplo,
|
| 377 |
+
cublasOperation_t trans,
|
| 378 |
+
cublasDiagType_t diag,
|
| 379 |
+
size_t m,
|
| 380 |
+
size_t n,
|
| 381 |
+
const cuComplex* alpha,
|
| 382 |
+
const cuComplex* A,
|
| 383 |
+
size_t lda,
|
| 384 |
+
cuComplex* B,
|
| 385 |
+
size_t ldb);
|
| 386 |
+
|
| 387 |
+
cublasStatus_t CUBLASWINAPI cublasXtZtrsm(cublasXtHandle_t handle,
|
| 388 |
+
cublasSideMode_t side,
|
| 389 |
+
cublasFillMode_t uplo,
|
| 390 |
+
cublasOperation_t trans,
|
| 391 |
+
cublasDiagType_t diag,
|
| 392 |
+
size_t m,
|
| 393 |
+
size_t n,
|
| 394 |
+
const cuDoubleComplex* alpha,
|
| 395 |
+
const cuDoubleComplex* A,
|
| 396 |
+
size_t lda,
|
| 397 |
+
cuDoubleComplex* B,
|
| 398 |
+
size_t ldb);
|
| 399 |
+
/* -------------------------------------------------------------------- */
|
| 400 |
+
/* SYMM : Symmetric Multiply Matrix*/
|
| 401 |
+
cublasStatus_t CUBLASWINAPI cublasXtSsymm(cublasXtHandle_t handle,
|
| 402 |
+
cublasSideMode_t side,
|
| 403 |
+
cublasFillMode_t uplo,
|
| 404 |
+
size_t m,
|
| 405 |
+
size_t n,
|
| 406 |
+
const float* alpha,
|
| 407 |
+
const float* A,
|
| 408 |
+
size_t lda,
|
| 409 |
+
const float* B,
|
| 410 |
+
size_t ldb,
|
| 411 |
+
const float* beta,
|
| 412 |
+
float* C,
|
| 413 |
+
size_t ldc);
|
| 414 |
+
|
| 415 |
+
cublasStatus_t CUBLASWINAPI cublasXtDsymm(cublasXtHandle_t handle,
|
| 416 |
+
cublasSideMode_t side,
|
| 417 |
+
cublasFillMode_t uplo,
|
| 418 |
+
size_t m,
|
| 419 |
+
size_t n,
|
| 420 |
+
const double* alpha,
|
| 421 |
+
const double* A,
|
| 422 |
+
size_t lda,
|
| 423 |
+
const double* B,
|
| 424 |
+
size_t ldb,
|
| 425 |
+
const double* beta,
|
| 426 |
+
double* C,
|
| 427 |
+
size_t ldc);
|
| 428 |
+
|
| 429 |
+
cublasStatus_t CUBLASWINAPI cublasXtCsymm(cublasXtHandle_t handle,
|
| 430 |
+
cublasSideMode_t side,
|
| 431 |
+
cublasFillMode_t uplo,
|
| 432 |
+
size_t m,
|
| 433 |
+
size_t n,
|
| 434 |
+
const cuComplex* alpha,
|
| 435 |
+
const cuComplex* A,
|
| 436 |
+
size_t lda,
|
| 437 |
+
const cuComplex* B,
|
| 438 |
+
size_t ldb,
|
| 439 |
+
const cuComplex* beta,
|
| 440 |
+
cuComplex* C,
|
| 441 |
+
size_t ldc);
|
| 442 |
+
|
| 443 |
+
cublasStatus_t CUBLASWINAPI cublasXtZsymm(cublasXtHandle_t handle,
|
| 444 |
+
cublasSideMode_t side,
|
| 445 |
+
cublasFillMode_t uplo,
|
| 446 |
+
size_t m,
|
| 447 |
+
size_t n,
|
| 448 |
+
const cuDoubleComplex* alpha,
|
| 449 |
+
const cuDoubleComplex* A,
|
| 450 |
+
size_t lda,
|
| 451 |
+
const cuDoubleComplex* B,
|
| 452 |
+
size_t ldb,
|
| 453 |
+
const cuDoubleComplex* beta,
|
| 454 |
+
cuDoubleComplex* C,
|
| 455 |
+
size_t ldc);
|
| 456 |
+
/* -------------------------------------------------------------------- */
|
| 457 |
+
/* HEMM : Hermitian Matrix Multiply */
|
| 458 |
+
cublasStatus_t CUBLASWINAPI cublasXtChemm(cublasXtHandle_t handle,
|
| 459 |
+
cublasSideMode_t side,
|
| 460 |
+
cublasFillMode_t uplo,
|
| 461 |
+
size_t m,
|
| 462 |
+
size_t n,
|
| 463 |
+
const cuComplex* alpha,
|
| 464 |
+
const cuComplex* A,
|
| 465 |
+
size_t lda,
|
| 466 |
+
const cuComplex* B,
|
| 467 |
+
size_t ldb,
|
| 468 |
+
const cuComplex* beta,
|
| 469 |
+
cuComplex* C,
|
| 470 |
+
size_t ldc);
|
| 471 |
+
|
| 472 |
+
cublasStatus_t CUBLASWINAPI cublasXtZhemm(cublasXtHandle_t handle,
|
| 473 |
+
cublasSideMode_t side,
|
| 474 |
+
cublasFillMode_t uplo,
|
| 475 |
+
size_t m,
|
| 476 |
+
size_t n,
|
| 477 |
+
const cuDoubleComplex* alpha,
|
| 478 |
+
const cuDoubleComplex* A,
|
| 479 |
+
size_t lda,
|
| 480 |
+
const cuDoubleComplex* B,
|
| 481 |
+
size_t ldb,
|
| 482 |
+
const cuDoubleComplex* beta,
|
| 483 |
+
cuDoubleComplex* C,
|
| 484 |
+
size_t ldc);
|
| 485 |
+
|
| 486 |
+
/* -------------------------------------------------------------------- */
|
| 487 |
+
/* SYRKX : variant extension of SYRK */
|
| 488 |
+
cublasStatus_t CUBLASWINAPI cublasXtSsyrkx(cublasXtHandle_t handle,
|
| 489 |
+
cublasFillMode_t uplo,
|
| 490 |
+
cublasOperation_t trans,
|
| 491 |
+
size_t n,
|
| 492 |
+
size_t k,
|
| 493 |
+
const float* alpha,
|
| 494 |
+
const float* A,
|
| 495 |
+
size_t lda,
|
| 496 |
+
const float* B,
|
| 497 |
+
size_t ldb,
|
| 498 |
+
const float* beta,
|
| 499 |
+
float* C,
|
| 500 |
+
size_t ldc);
|
| 501 |
+
|
| 502 |
+
cublasStatus_t CUBLASWINAPI cublasXtDsyrkx(cublasXtHandle_t handle,
|
| 503 |
+
cublasFillMode_t uplo,
|
| 504 |
+
cublasOperation_t trans,
|
| 505 |
+
size_t n,
|
| 506 |
+
size_t k,
|
| 507 |
+
const double* alpha,
|
| 508 |
+
const double* A,
|
| 509 |
+
size_t lda,
|
| 510 |
+
const double* B,
|
| 511 |
+
size_t ldb,
|
| 512 |
+
const double* beta,
|
| 513 |
+
double* C,
|
| 514 |
+
size_t ldc);
|
| 515 |
+
|
| 516 |
+
cublasStatus_t CUBLASWINAPI cublasXtCsyrkx(cublasXtHandle_t handle,
|
| 517 |
+
cublasFillMode_t uplo,
|
| 518 |
+
cublasOperation_t trans,
|
| 519 |
+
size_t n,
|
| 520 |
+
size_t k,
|
| 521 |
+
const cuComplex* alpha,
|
| 522 |
+
const cuComplex* A,
|
| 523 |
+
size_t lda,
|
| 524 |
+
const cuComplex* B,
|
| 525 |
+
size_t ldb,
|
| 526 |
+
const cuComplex* beta,
|
| 527 |
+
cuComplex* C,
|
| 528 |
+
size_t ldc);
|
| 529 |
+
|
| 530 |
+
cublasStatus_t CUBLASWINAPI cublasXtZsyrkx(cublasXtHandle_t handle,
|
| 531 |
+
cublasFillMode_t uplo,
|
| 532 |
+
cublasOperation_t trans,
|
| 533 |
+
size_t n,
|
| 534 |
+
size_t k,
|
| 535 |
+
const cuDoubleComplex* alpha,
|
| 536 |
+
const cuDoubleComplex* A,
|
| 537 |
+
size_t lda,
|
| 538 |
+
const cuDoubleComplex* B,
|
| 539 |
+
size_t ldb,
|
| 540 |
+
const cuDoubleComplex* beta,
|
| 541 |
+
cuDoubleComplex* C,
|
| 542 |
+
size_t ldc);
|
| 543 |
+
/* -------------------------------------------------------------------- */
|
| 544 |
+
/* HER2K : variant extension of HERK */
|
| 545 |
+
cublasStatus_t CUBLASWINAPI cublasXtCher2k(cublasXtHandle_t handle,
|
| 546 |
+
cublasFillMode_t uplo,
|
| 547 |
+
cublasOperation_t trans,
|
| 548 |
+
size_t n,
|
| 549 |
+
size_t k,
|
| 550 |
+
const cuComplex* alpha,
|
| 551 |
+
const cuComplex* A,
|
| 552 |
+
size_t lda,
|
| 553 |
+
const cuComplex* B,
|
| 554 |
+
size_t ldb,
|
| 555 |
+
const float* beta,
|
| 556 |
+
cuComplex* C,
|
| 557 |
+
size_t ldc);
|
| 558 |
+
|
| 559 |
+
cublasStatus_t CUBLASWINAPI cublasXtZher2k(cublasXtHandle_t handle,
|
| 560 |
+
cublasFillMode_t uplo,
|
| 561 |
+
cublasOperation_t trans,
|
| 562 |
+
size_t n,
|
| 563 |
+
size_t k,
|
| 564 |
+
const cuDoubleComplex* alpha,
|
| 565 |
+
const cuDoubleComplex* A,
|
| 566 |
+
size_t lda,
|
| 567 |
+
const cuDoubleComplex* B,
|
| 568 |
+
size_t ldb,
|
| 569 |
+
const double* beta,
|
| 570 |
+
cuDoubleComplex* C,
|
| 571 |
+
size_t ldc);
|
| 572 |
+
|
| 573 |
+
/* -------------------------------------------------------------------- */
|
| 574 |
+
/* SPMM : Symmetric Packed Multiply Matrix*/
|
| 575 |
+
cublasStatus_t CUBLASWINAPI cublasXtSspmm(cublasXtHandle_t handle,
|
| 576 |
+
cublasSideMode_t side,
|
| 577 |
+
cublasFillMode_t uplo,
|
| 578 |
+
size_t m,
|
| 579 |
+
size_t n,
|
| 580 |
+
const float* alpha,
|
| 581 |
+
const float* AP,
|
| 582 |
+
const float* B,
|
| 583 |
+
size_t ldb,
|
| 584 |
+
const float* beta,
|
| 585 |
+
float* C,
|
| 586 |
+
size_t ldc);
|
| 587 |
+
|
| 588 |
+
cublasStatus_t CUBLASWINAPI cublasXtDspmm(cublasXtHandle_t handle,
|
| 589 |
+
cublasSideMode_t side,
|
| 590 |
+
cublasFillMode_t uplo,
|
| 591 |
+
size_t m,
|
| 592 |
+
size_t n,
|
| 593 |
+
const double* alpha,
|
| 594 |
+
const double* AP,
|
| 595 |
+
const double* B,
|
| 596 |
+
size_t ldb,
|
| 597 |
+
const double* beta,
|
| 598 |
+
double* C,
|
| 599 |
+
size_t ldc);
|
| 600 |
+
|
| 601 |
+
cublasStatus_t CUBLASWINAPI cublasXtCspmm(cublasXtHandle_t handle,
|
| 602 |
+
cublasSideMode_t side,
|
| 603 |
+
cublasFillMode_t uplo,
|
| 604 |
+
size_t m,
|
| 605 |
+
size_t n,
|
| 606 |
+
const cuComplex* alpha,
|
| 607 |
+
const cuComplex* AP,
|
| 608 |
+
const cuComplex* B,
|
| 609 |
+
size_t ldb,
|
| 610 |
+
const cuComplex* beta,
|
| 611 |
+
cuComplex* C,
|
| 612 |
+
size_t ldc);
|
| 613 |
+
|
| 614 |
+
cublasStatus_t CUBLASWINAPI cublasXtZspmm(cublasXtHandle_t handle,
|
| 615 |
+
cublasSideMode_t side,
|
| 616 |
+
cublasFillMode_t uplo,
|
| 617 |
+
size_t m,
|
| 618 |
+
size_t n,
|
| 619 |
+
const cuDoubleComplex* alpha,
|
| 620 |
+
const cuDoubleComplex* AP,
|
| 621 |
+
const cuDoubleComplex* B,
|
| 622 |
+
size_t ldb,
|
| 623 |
+
const cuDoubleComplex* beta,
|
| 624 |
+
cuDoubleComplex* C,
|
| 625 |
+
size_t ldc);
|
| 626 |
+
|
| 627 |
+
/* -------------------------------------------------------------------- */
|
| 628 |
+
/* TRMM */
|
| 629 |
+
cublasStatus_t CUBLASWINAPI cublasXtStrmm(cublasXtHandle_t handle,
|
| 630 |
+
cublasSideMode_t side,
|
| 631 |
+
cublasFillMode_t uplo,
|
| 632 |
+
cublasOperation_t trans,
|
| 633 |
+
cublasDiagType_t diag,
|
| 634 |
+
size_t m,
|
| 635 |
+
size_t n,
|
| 636 |
+
const float* alpha,
|
| 637 |
+
const float* A,
|
| 638 |
+
size_t lda,
|
| 639 |
+
const float* B,
|
| 640 |
+
size_t ldb,
|
| 641 |
+
float* C,
|
| 642 |
+
size_t ldc);
|
| 643 |
+
|
| 644 |
+
cublasStatus_t CUBLASWINAPI cublasXtDtrmm(cublasXtHandle_t handle,
|
| 645 |
+
cublasSideMode_t side,
|
| 646 |
+
cublasFillMode_t uplo,
|
| 647 |
+
cublasOperation_t trans,
|
| 648 |
+
cublasDiagType_t diag,
|
| 649 |
+
size_t m,
|
| 650 |
+
size_t n,
|
| 651 |
+
const double* alpha,
|
| 652 |
+
const double* A,
|
| 653 |
+
size_t lda,
|
| 654 |
+
const double* B,
|
| 655 |
+
size_t ldb,
|
| 656 |
+
double* C,
|
| 657 |
+
size_t ldc);
|
| 658 |
+
|
| 659 |
+
cublasStatus_t CUBLASWINAPI cublasXtCtrmm(cublasXtHandle_t handle,
|
| 660 |
+
cublasSideMode_t side,
|
| 661 |
+
cublasFillMode_t uplo,
|
| 662 |
+
cublasOperation_t trans,
|
| 663 |
+
cublasDiagType_t diag,
|
| 664 |
+
size_t m,
|
| 665 |
+
size_t n,
|
| 666 |
+
const cuComplex* alpha,
|
| 667 |
+
const cuComplex* A,
|
| 668 |
+
size_t lda,
|
| 669 |
+
const cuComplex* B,
|
| 670 |
+
size_t ldb,
|
| 671 |
+
cuComplex* C,
|
| 672 |
+
size_t ldc);
|
| 673 |
+
|
| 674 |
+
cublasStatus_t CUBLASWINAPI cublasXtZtrmm(cublasXtHandle_t handle,
|
| 675 |
+
cublasSideMode_t side,
|
| 676 |
+
cublasFillMode_t uplo,
|
| 677 |
+
cublasOperation_t trans,
|
| 678 |
+
cublasDiagType_t diag,
|
| 679 |
+
size_t m,
|
| 680 |
+
size_t n,
|
| 681 |
+
const cuDoubleComplex* alpha,
|
| 682 |
+
const cuDoubleComplex* A,
|
| 683 |
+
size_t lda,
|
| 684 |
+
const cuDoubleComplex* B,
|
| 685 |
+
size_t ldb,
|
| 686 |
+
cuDoubleComplex* C,
|
| 687 |
+
size_t ldc);
|
| 688 |
+
|
| 689 |
+
#if defined(__cplusplus)
|
| 690 |
+
}
|
| 691 |
+
#endif /* __cplusplus */
|
| 692 |
+
|
| 693 |
+
#endif /* !defined(CUBLAS_XT_H_) */
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* This is the public header file for the new CUBLAS library API, it mapped the generic
|
| 52 |
+
* Cublas name functions to the actual _v2 implementations.
|
| 53 |
+
*/
|
| 54 |
+
|
| 55 |
+
#if !defined(CUBLAS_V2_H_)
|
| 56 |
+
#define CUBLAS_V2_H_
|
| 57 |
+
|
| 58 |
+
#if defined(CUBLAS_H_)
|
| 59 |
+
#error "It is an error to include both cublas.h and cublas_v2.h"
|
| 60 |
+
#endif
|
| 61 |
+
|
| 62 |
+
#undef CUBLASAPI
|
| 63 |
+
#ifdef __CUDACC__
|
| 64 |
+
#define CUBLASAPI __host__ __device__
|
| 65 |
+
#else
|
| 66 |
+
#define CUBLASAPI
|
| 67 |
+
#endif
|
| 68 |
+
|
| 69 |
+
#include "cublas_api.h"
|
| 70 |
+
|
| 71 |
+
#define cublasCreate cublasCreate_v2
|
| 72 |
+
#define cublasDestroy cublasDestroy_v2
|
| 73 |
+
#define cublasGetVersion cublasGetVersion_v2
|
| 74 |
+
#define cublasSetWorkspace cublasSetWorkspace_v2
|
| 75 |
+
#define cublasSetStream cublasSetStream_v2
|
| 76 |
+
#define cublasGetStream cublasGetStream_v2
|
| 77 |
+
#define cublasGetPointerMode cublasGetPointerMode_v2
|
| 78 |
+
#define cublasSetPointerMode cublasSetPointerMode_v2
|
| 79 |
+
|
| 80 |
+
/* 32-bit integer */
|
| 81 |
+
|
| 82 |
+
/* Blas1 Routines */
|
| 83 |
+
|
| 84 |
+
#define cublasSnrm2 cublasSnrm2_v2
|
| 85 |
+
#define cublasDnrm2 cublasDnrm2_v2
|
| 86 |
+
#define cublasScnrm2 cublasScnrm2_v2
|
| 87 |
+
#define cublasDznrm2 cublasDznrm2_v2
|
| 88 |
+
|
| 89 |
+
#define cublasSdot cublasSdot_v2
|
| 90 |
+
#define cublasDdot cublasDdot_v2
|
| 91 |
+
#define cublasCdotu cublasCdotu_v2
|
| 92 |
+
#define cublasCdotc cublasCdotc_v2
|
| 93 |
+
#define cublasZdotu cublasZdotu_v2
|
| 94 |
+
#define cublasZdotc cublasZdotc_v2
|
| 95 |
+
|
| 96 |
+
#define cublasSscal cublasSscal_v2
|
| 97 |
+
#define cublasDscal cublasDscal_v2
|
| 98 |
+
#define cublasCscal cublasCscal_v2
|
| 99 |
+
#define cublasCsscal cublasCsscal_v2
|
| 100 |
+
#define cublasZscal cublasZscal_v2
|
| 101 |
+
#define cublasZdscal cublasZdscal_v2
|
| 102 |
+
|
| 103 |
+
#define cublasSaxpy cublasSaxpy_v2
|
| 104 |
+
#define cublasDaxpy cublasDaxpy_v2
|
| 105 |
+
#define cublasCaxpy cublasCaxpy_v2
|
| 106 |
+
#define cublasZaxpy cublasZaxpy_v2
|
| 107 |
+
|
| 108 |
+
#define cublasScopy cublasScopy_v2
|
| 109 |
+
#define cublasDcopy cublasDcopy_v2
|
| 110 |
+
#define cublasCcopy cublasCcopy_v2
|
| 111 |
+
#define cublasZcopy cublasZcopy_v2
|
| 112 |
+
|
| 113 |
+
#define cublasSswap cublasSswap_v2
|
| 114 |
+
#define cublasDswap cublasDswap_v2
|
| 115 |
+
#define cublasCswap cublasCswap_v2
|
| 116 |
+
#define cublasZswap cublasZswap_v2
|
| 117 |
+
|
| 118 |
+
#define cublasIsamax cublasIsamax_v2
|
| 119 |
+
#define cublasIdamax cublasIdamax_v2
|
| 120 |
+
#define cublasIcamax cublasIcamax_v2
|
| 121 |
+
#define cublasIzamax cublasIzamax_v2
|
| 122 |
+
|
| 123 |
+
#define cublasIsamin cublasIsamin_v2
|
| 124 |
+
#define cublasIdamin cublasIdamin_v2
|
| 125 |
+
#define cublasIcamin cublasIcamin_v2
|
| 126 |
+
#define cublasIzamin cublasIzamin_v2
|
| 127 |
+
|
| 128 |
+
#define cublasSasum cublasSasum_v2
|
| 129 |
+
#define cublasDasum cublasDasum_v2
|
| 130 |
+
#define cublasScasum cublasScasum_v2
|
| 131 |
+
#define cublasDzasum cublasDzasum_v2
|
| 132 |
+
|
| 133 |
+
#define cublasSrot cublasSrot_v2
|
| 134 |
+
#define cublasDrot cublasDrot_v2
|
| 135 |
+
#define cublasCrot cublasCrot_v2
|
| 136 |
+
#define cublasCsrot cublasCsrot_v2
|
| 137 |
+
#define cublasZrot cublasZrot_v2
|
| 138 |
+
#define cublasZdrot cublasZdrot_v2
|
| 139 |
+
|
| 140 |
+
#define cublasSrotg cublasSrotg_v2
|
| 141 |
+
#define cublasDrotg cublasDrotg_v2
|
| 142 |
+
#define cublasCrotg cublasCrotg_v2
|
| 143 |
+
#define cublasZrotg cublasZrotg_v2
|
| 144 |
+
|
| 145 |
+
#define cublasSrotm cublasSrotm_v2
|
| 146 |
+
#define cublasDrotm cublasDrotm_v2
|
| 147 |
+
|
| 148 |
+
#define cublasSrotmg cublasSrotmg_v2
|
| 149 |
+
#define cublasDrotmg cublasDrotmg_v2
|
| 150 |
+
|
| 151 |
+
/* Blas2 Routines */
|
| 152 |
+
|
| 153 |
+
#define cublasSgemv cublasSgemv_v2
|
| 154 |
+
#define cublasDgemv cublasDgemv_v2
|
| 155 |
+
#define cublasCgemv cublasCgemv_v2
|
| 156 |
+
#define cublasZgemv cublasZgemv_v2
|
| 157 |
+
|
| 158 |
+
#define cublasSgbmv cublasSgbmv_v2
|
| 159 |
+
#define cublasDgbmv cublasDgbmv_v2
|
| 160 |
+
#define cublasCgbmv cublasCgbmv_v2
|
| 161 |
+
#define cublasZgbmv cublasZgbmv_v2
|
| 162 |
+
|
| 163 |
+
#define cublasStrmv cublasStrmv_v2
|
| 164 |
+
#define cublasDtrmv cublasDtrmv_v2
|
| 165 |
+
#define cublasCtrmv cublasCtrmv_v2
|
| 166 |
+
#define cublasZtrmv cublasZtrmv_v2
|
| 167 |
+
|
| 168 |
+
#define cublasStbmv cublasStbmv_v2
|
| 169 |
+
#define cublasDtbmv cublasDtbmv_v2
|
| 170 |
+
#define cublasCtbmv cublasCtbmv_v2
|
| 171 |
+
#define cublasZtbmv cublasZtbmv_v2
|
| 172 |
+
|
| 173 |
+
#define cublasStpmv cublasStpmv_v2
|
| 174 |
+
#define cublasDtpmv cublasDtpmv_v2
|
| 175 |
+
#define cublasCtpmv cublasCtpmv_v2
|
| 176 |
+
#define cublasZtpmv cublasZtpmv_v2
|
| 177 |
+
|
| 178 |
+
#define cublasStrsv cublasStrsv_v2
|
| 179 |
+
#define cublasDtrsv cublasDtrsv_v2
|
| 180 |
+
#define cublasCtrsv cublasCtrsv_v2
|
| 181 |
+
#define cublasZtrsv cublasZtrsv_v2
|
| 182 |
+
|
| 183 |
+
#define cublasStpsv cublasStpsv_v2
|
| 184 |
+
#define cublasDtpsv cublasDtpsv_v2
|
| 185 |
+
#define cublasCtpsv cublasCtpsv_v2
|
| 186 |
+
#define cublasZtpsv cublasZtpsv_v2
|
| 187 |
+
|
| 188 |
+
#define cublasStbsv cublasStbsv_v2
|
| 189 |
+
#define cublasDtbsv cublasDtbsv_v2
|
| 190 |
+
#define cublasCtbsv cublasCtbsv_v2
|
| 191 |
+
#define cublasZtbsv cublasZtbsv_v2
|
| 192 |
+
|
| 193 |
+
#define cublasSsymv cublasSsymv_v2
|
| 194 |
+
#define cublasDsymv cublasDsymv_v2
|
| 195 |
+
#define cublasCsymv cublasCsymv_v2
|
| 196 |
+
#define cublasZsymv cublasZsymv_v2
|
| 197 |
+
#define cublasChemv cublasChemv_v2
|
| 198 |
+
#define cublasZhemv cublasZhemv_v2
|
| 199 |
+
|
| 200 |
+
#define cublasSsbmv cublasSsbmv_v2
|
| 201 |
+
#define cublasDsbmv cublasDsbmv_v2
|
| 202 |
+
#define cublasChbmv cublasChbmv_v2
|
| 203 |
+
#define cublasZhbmv cublasZhbmv_v2
|
| 204 |
+
|
| 205 |
+
#define cublasSspmv cublasSspmv_v2
|
| 206 |
+
#define cublasDspmv cublasDspmv_v2
|
| 207 |
+
#define cublasChpmv cublasChpmv_v2
|
| 208 |
+
#define cublasZhpmv cublasZhpmv_v2
|
| 209 |
+
|
| 210 |
+
#define cublasSger cublasSger_v2
|
| 211 |
+
#define cublasDger cublasDger_v2
|
| 212 |
+
#define cublasCgeru cublasCgeru_v2
|
| 213 |
+
#define cublasCgerc cublasCgerc_v2
|
| 214 |
+
#define cublasZgeru cublasZgeru_v2
|
| 215 |
+
#define cublasZgerc cublasZgerc_v2
|
| 216 |
+
|
| 217 |
+
#define cublasSsyr cublasSsyr_v2
|
| 218 |
+
#define cublasDsyr cublasDsyr_v2
|
| 219 |
+
#define cublasCsyr cublasCsyr_v2
|
| 220 |
+
#define cublasZsyr cublasZsyr_v2
|
| 221 |
+
#define cublasCher cublasCher_v2
|
| 222 |
+
#define cublasZher cublasZher_v2
|
| 223 |
+
|
| 224 |
+
#define cublasSspr cublasSspr_v2
|
| 225 |
+
#define cublasDspr cublasDspr_v2
|
| 226 |
+
#define cublasChpr cublasChpr_v2
|
| 227 |
+
#define cublasZhpr cublasZhpr_v2
|
| 228 |
+
|
| 229 |
+
#define cublasSsyr2 cublasSsyr2_v2
|
| 230 |
+
#define cublasDsyr2 cublasDsyr2_v2
|
| 231 |
+
#define cublasCsyr2 cublasCsyr2_v2
|
| 232 |
+
#define cublasZsyr2 cublasZsyr2_v2
|
| 233 |
+
#define cublasCher2 cublasCher2_v2
|
| 234 |
+
#define cublasZher2 cublasZher2_v2
|
| 235 |
+
|
| 236 |
+
#define cublasSspr2 cublasSspr2_v2
|
| 237 |
+
#define cublasDspr2 cublasDspr2_v2
|
| 238 |
+
#define cublasChpr2 cublasChpr2_v2
|
| 239 |
+
#define cublasZhpr2 cublasZhpr2_v2
|
| 240 |
+
|
| 241 |
+
/* Blas3 Routines */
|
| 242 |
+
|
| 243 |
+
#define cublasSgemm cublasSgemm_v2
|
| 244 |
+
#define cublasDgemm cublasDgemm_v2
|
| 245 |
+
#define cublasCgemm cublasCgemm_v2
|
| 246 |
+
#define cublasZgemm cublasZgemm_v2
|
| 247 |
+
|
| 248 |
+
#define cublasSsyrk cublasSsyrk_v2
|
| 249 |
+
#define cublasDsyrk cublasDsyrk_v2
|
| 250 |
+
#define cublasCsyrk cublasCsyrk_v2
|
| 251 |
+
#define cublasZsyrk cublasZsyrk_v2
|
| 252 |
+
#define cublasCherk cublasCherk_v2
|
| 253 |
+
#define cublasZherk cublasZherk_v2
|
| 254 |
+
|
| 255 |
+
#define cublasSsyr2k cublasSsyr2k_v2
|
| 256 |
+
#define cublasDsyr2k cublasDsyr2k_v2
|
| 257 |
+
#define cublasCsyr2k cublasCsyr2k_v2
|
| 258 |
+
#define cublasZsyr2k cublasZsyr2k_v2
|
| 259 |
+
#define cublasCher2k cublasCher2k_v2
|
| 260 |
+
#define cublasZher2k cublasZher2k_v2
|
| 261 |
+
|
| 262 |
+
#define cublasSsymm cublasSsymm_v2
|
| 263 |
+
#define cublasDsymm cublasDsymm_v2
|
| 264 |
+
#define cublasCsymm cublasCsymm_v2
|
| 265 |
+
#define cublasZsymm cublasZsymm_v2
|
| 266 |
+
#define cublasChemm cublasChemm_v2
|
| 267 |
+
#define cublasZhemm cublasZhemm_v2
|
| 268 |
+
|
| 269 |
+
#define cublasStrsm cublasStrsm_v2
|
| 270 |
+
#define cublasDtrsm cublasDtrsm_v2
|
| 271 |
+
#define cublasCtrsm cublasCtrsm_v2
|
| 272 |
+
#define cublasZtrsm cublasZtrsm_v2
|
| 273 |
+
|
| 274 |
+
#define cublasStrmm cublasStrmm_v2
|
| 275 |
+
#define cublasDtrmm cublasDtrmm_v2
|
| 276 |
+
#define cublasCtrmm cublasCtrmm_v2
|
| 277 |
+
#define cublasZtrmm cublasZtrmm_v2
|
| 278 |
+
|
| 279 |
+
/* 64-bit integer */
|
| 280 |
+
|
| 281 |
+
/* Blas1 Routines */
|
| 282 |
+
|
| 283 |
+
#define cublasSnrm2_64 cublasSnrm2_v2_64
|
| 284 |
+
#define cublasDnrm2_64 cublasDnrm2_v2_64
|
| 285 |
+
#define cublasScnrm2_64 cublasScnrm2_v2_64
|
| 286 |
+
#define cublasDznrm2_64 cublasDznrm2_v2_64
|
| 287 |
+
|
| 288 |
+
#define cublasSdot_64 cublasSdot_v2_64
|
| 289 |
+
#define cublasDdot_64 cublasDdot_v2_64
|
| 290 |
+
#define cublasCdotu_64 cublasCdotu_v2_64
|
| 291 |
+
#define cublasCdotc_64 cublasCdotc_v2_64
|
| 292 |
+
#define cublasZdotu_64 cublasZdotu_v2_64
|
| 293 |
+
#define cublasZdotc_64 cublasZdotc_v2_64
|
| 294 |
+
|
| 295 |
+
#define cublasSscal_64 cublasSscal_v2_64
|
| 296 |
+
#define cublasDscal_64 cublasDscal_v2_64
|
| 297 |
+
#define cublasCscal_64 cublasCscal_v2_64
|
| 298 |
+
#define cublasCsscal_64 cublasCsscal_v2_64
|
| 299 |
+
#define cublasZscal_64 cublasZscal_v2_64
|
| 300 |
+
#define cublasZdscal_64 cublasZdscal_v2_64
|
| 301 |
+
|
| 302 |
+
#define cublasSaxpy_64 cublasSaxpy_v2_64
|
| 303 |
+
#define cublasDaxpy_64 cublasDaxpy_v2_64
|
| 304 |
+
#define cublasCaxpy_64 cublasCaxpy_v2_64
|
| 305 |
+
#define cublasZaxpy_64 cublasZaxpy_v2_64
|
| 306 |
+
|
| 307 |
+
#define cublasScopy_64 cublasScopy_v2_64
|
| 308 |
+
#define cublasDcopy_64 cublasDcopy_v2_64
|
| 309 |
+
#define cublasCcopy_64 cublasCcopy_v2_64
|
| 310 |
+
#define cublasZcopy_64 cublasZcopy_v2_64
|
| 311 |
+
|
| 312 |
+
#define cublasSswap_64 cublasSswap_v2_64
|
| 313 |
+
#define cublasDswap_64 cublasDswap_v2_64
|
| 314 |
+
#define cublasCswap_64 cublasCswap_v2_64
|
| 315 |
+
#define cublasZswap_64 cublasZswap_v2_64
|
| 316 |
+
|
| 317 |
+
#define cublasIsamax_64 cublasIsamax_v2_64
|
| 318 |
+
#define cublasIdamax_64 cublasIdamax_v2_64
|
| 319 |
+
#define cublasIcamax_64 cublasIcamax_v2_64
|
| 320 |
+
#define cublasIzamax_64 cublasIzamax_v2_64
|
| 321 |
+
|
| 322 |
+
#define cublasIsamin_64 cublasIsamin_v2_64
|
| 323 |
+
#define cublasIdamin_64 cublasIdamin_v2_64
|
| 324 |
+
#define cublasIcamin_64 cublasIcamin_v2_64
|
| 325 |
+
#define cublasIzamin_64 cublasIzamin_v2_64
|
| 326 |
+
|
| 327 |
+
#define cublasSasum_64 cublasSasum_v2_64
|
| 328 |
+
#define cublasDasum_64 cublasDasum_v2_64
|
| 329 |
+
#define cublasScasum_64 cublasScasum_v2_64
|
| 330 |
+
#define cublasDzasum_64 cublasDzasum_v2_64
|
| 331 |
+
|
| 332 |
+
#define cublasSrot_64 cublasSrot_v2_64
|
| 333 |
+
#define cublasDrot_64 cublasDrot_v2_64
|
| 334 |
+
#define cublasCrot_64 cublasCrot_v2_64
|
| 335 |
+
#define cublasCsrot_64 cublasCsrot_v2_64
|
| 336 |
+
#define cublasZrot_64 cublasZrot_v2_64
|
| 337 |
+
#define cublasZdrot_64 cublasZdrot_v2_64
|
| 338 |
+
|
| 339 |
+
#define cublasSrotg_64 cublasSrotg_v2_64
|
| 340 |
+
#define cublasDrotg_64 cublasDrotg_v2_64
|
| 341 |
+
#define cublasCrotg_64 cublasCrotg_v2_64
|
| 342 |
+
#define cublasZrotg_64 cublasZrotg_v2_64
|
| 343 |
+
|
| 344 |
+
#define cublasSrotm_64 cublasSrotm_v2_64
|
| 345 |
+
#define cublasDrotm_64 cublasDrotm_v2_64
|
| 346 |
+
|
| 347 |
+
#define cublasSrotmg_64 cublasSrotmg_v2_64
|
| 348 |
+
#define cublasDrotmg_64 cublasDrotmg_v2_64
|
| 349 |
+
|
| 350 |
+
/* Blas2 Routines */
|
| 351 |
+
|
| 352 |
+
#define cublasSgemv_64 cublasSgemv_v2_64
|
| 353 |
+
#define cublasDgemv_64 cublasDgemv_v2_64
|
| 354 |
+
#define cublasCgemv_64 cublasCgemv_v2_64
|
| 355 |
+
#define cublasZgemv_64 cublasZgemv_v2_64
|
| 356 |
+
|
| 357 |
+
#define cublasSgbmv_64 cublasSgbmv_v2_64
|
| 358 |
+
#define cublasDgbmv_64 cublasDgbmv_v2_64
|
| 359 |
+
#define cublasCgbmv_64 cublasCgbmv_v2_64
|
| 360 |
+
#define cublasZgbmv_64 cublasZgbmv_v2_64
|
| 361 |
+
|
| 362 |
+
#define cublasStrmv_64 cublasStrmv_v2_64
|
| 363 |
+
#define cublasDtrmv_64 cublasDtrmv_v2_64
|
| 364 |
+
#define cublasCtrmv_64 cublasCtrmv_v2_64
|
| 365 |
+
#define cublasZtrmv_64 cublasZtrmv_v2_64
|
| 366 |
+
|
| 367 |
+
#define cublasStbmv_64 cublasStbmv_v2_64
|
| 368 |
+
#define cublasDtbmv_64 cublasDtbmv_v2_64
|
| 369 |
+
#define cublasCtbmv_64 cublasCtbmv_v2_64
|
| 370 |
+
#define cublasZtbmv_64 cublasZtbmv_v2_64
|
| 371 |
+
|
| 372 |
+
#define cublasStpmv_64 cublasStpmv_v2_64
|
| 373 |
+
#define cublasDtpmv_64 cublasDtpmv_v2_64
|
| 374 |
+
#define cublasCtpmv_64 cublasCtpmv_v2_64
|
| 375 |
+
#define cublasZtpmv_64 cublasZtpmv_v2_64
|
| 376 |
+
|
| 377 |
+
#define cublasStrsv_64 cublasStrsv_v2_64
|
| 378 |
+
#define cublasDtrsv_64 cublasDtrsv_v2_64
|
| 379 |
+
#define cublasCtrsv_64 cublasCtrsv_v2_64
|
| 380 |
+
#define cublasZtrsv_64 cublasZtrsv_v2_64
|
| 381 |
+
|
| 382 |
+
#define cublasStpsv_64 cublasStpsv_v2_64
|
| 383 |
+
#define cublasDtpsv_64 cublasDtpsv_v2_64
|
| 384 |
+
#define cublasCtpsv_64 cublasCtpsv_v2_64
|
| 385 |
+
#define cublasZtpsv_64 cublasZtpsv_v2_64
|
| 386 |
+
|
| 387 |
+
#define cublasStbsv_64 cublasStbsv_v2_64
|
| 388 |
+
#define cublasDtbsv_64 cublasDtbsv_v2_64
|
| 389 |
+
#define cublasCtbsv_64 cublasCtbsv_v2_64
|
| 390 |
+
#define cublasZtbsv_64 cublasZtbsv_v2_64
|
| 391 |
+
|
| 392 |
+
#define cublasSsymv_64 cublasSsymv_v2_64
|
| 393 |
+
#define cublasDsymv_64 cublasDsymv_v2_64
|
| 394 |
+
#define cublasCsymv_64 cublasCsymv_v2_64
|
| 395 |
+
#define cublasZsymv_64 cublasZsymv_v2_64
|
| 396 |
+
#define cublasChemv_64 cublasChemv_v2_64
|
| 397 |
+
#define cublasZhemv_64 cublasZhemv_v2_64
|
| 398 |
+
|
| 399 |
+
#define cublasSsbmv_64 cublasSsbmv_v2_64
|
| 400 |
+
#define cublasDsbmv_64 cublasDsbmv_v2_64
|
| 401 |
+
#define cublasChbmv_64 cublasChbmv_v2_64
|
| 402 |
+
#define cublasZhbmv_64 cublasZhbmv_v2_64
|
| 403 |
+
|
| 404 |
+
#define cublasSspmv_64 cublasSspmv_v2_64
|
| 405 |
+
#define cublasDspmv_64 cublasDspmv_v2_64
|
| 406 |
+
#define cublasChpmv_64 cublasChpmv_v2_64
|
| 407 |
+
#define cublasZhpmv_64 cublasZhpmv_v2_64
|
| 408 |
+
|
| 409 |
+
#define cublasSger_64 cublasSger_v2_64
|
| 410 |
+
#define cublasDger_64 cublasDger_v2_64
|
| 411 |
+
#define cublasCgeru_64 cublasCgeru_v2_64
|
| 412 |
+
#define cublasCgerc_64 cublasCgerc_v2_64
|
| 413 |
+
#define cublasZgeru_64 cublasZgeru_v2_64
|
| 414 |
+
#define cublasZgerc_64 cublasZgerc_v2_64
|
| 415 |
+
|
| 416 |
+
#define cublasSsyr_64 cublasSsyr_v2_64
|
| 417 |
+
#define cublasDsyr_64 cublasDsyr_v2_64
|
| 418 |
+
#define cublasCsyr_64 cublasCsyr_v2_64
|
| 419 |
+
#define cublasZsyr_64 cublasZsyr_v2_64
|
| 420 |
+
#define cublasCher_64 cublasCher_v2_64
|
| 421 |
+
#define cublasZher_64 cublasZher_v2_64
|
| 422 |
+
|
| 423 |
+
#define cublasSspr_64 cublasSspr_v2_64
|
| 424 |
+
#define cublasDspr_64 cublasDspr_v2_64
|
| 425 |
+
#define cublasChpr_64 cublasChpr_v2_64
|
| 426 |
+
#define cublasZhpr_64 cublasZhpr_v2_64
|
| 427 |
+
|
| 428 |
+
#define cublasSsyr2_64 cublasSsyr2_v2_64
|
| 429 |
+
#define cublasDsyr2_64 cublasDsyr2_v2_64
|
| 430 |
+
#define cublasCsyr2_64 cublasCsyr2_v2_64
|
| 431 |
+
#define cublasZsyr2_64 cublasZsyr2_v2_64
|
| 432 |
+
#define cublasCher2_64 cublasCher2_v2_64
|
| 433 |
+
#define cublasZher2_64 cublasZher2_v2_64
|
| 434 |
+
|
| 435 |
+
#define cublasSspr2_64 cublasSspr2_v2_64
|
| 436 |
+
#define cublasDspr2_64 cublasDspr2_v2_64
|
| 437 |
+
#define cublasChpr2_64 cublasChpr2_v2_64
|
| 438 |
+
#define cublasZhpr2_64 cublasZhpr2_v2_64
|
| 439 |
+
|
| 440 |
+
/* Blas3 Routines */
|
| 441 |
+
|
| 442 |
+
#define cublasSgemm_64 cublasSgemm_v2_64
|
| 443 |
+
#define cublasDgemm_64 cublasDgemm_v2_64
|
| 444 |
+
#define cublasCgemm_64 cublasCgemm_v2_64
|
| 445 |
+
#define cublasZgemm_64 cublasZgemm_v2_64
|
| 446 |
+
|
| 447 |
+
#define cublasSsyrk_64 cublasSsyrk_v2_64
|
| 448 |
+
#define cublasDsyrk_64 cublasDsyrk_v2_64
|
| 449 |
+
#define cublasCsyrk_64 cublasCsyrk_v2_64
|
| 450 |
+
#define cublasZsyrk_64 cublasZsyrk_v2_64
|
| 451 |
+
#define cublasCherk_64 cublasCherk_v2_64
|
| 452 |
+
#define cublasZherk_64 cublasZherk_v2_64
|
| 453 |
+
|
| 454 |
+
#define cublasSsyr2k_64 cublasSsyr2k_v2_64
|
| 455 |
+
#define cublasDsyr2k_64 cublasDsyr2k_v2_64
|
| 456 |
+
#define cublasCsyr2k_64 cublasCsyr2k_v2_64
|
| 457 |
+
#define cublasZsyr2k_64 cublasZsyr2k_v2_64
|
| 458 |
+
#define cublasCher2k_64 cublasCher2k_v2_64
|
| 459 |
+
#define cublasZher2k_64 cublasZher2k_v2_64
|
| 460 |
+
|
| 461 |
+
#define cublasSsymm_64 cublasSsymm_v2_64
|
| 462 |
+
#define cublasDsymm_64 cublasDsymm_v2_64
|
| 463 |
+
#define cublasCsymm_64 cublasCsymm_v2_64
|
| 464 |
+
#define cublasZsymm_64 cublasZsymm_v2_64
|
| 465 |
+
#define cublasChemm_64 cublasChemm_v2_64
|
| 466 |
+
#define cublasZhemm_64 cublasZhemm_v2_64
|
| 467 |
+
|
| 468 |
+
#define cublasStrsm_64 cublasStrsm_v2_64
|
| 469 |
+
#define cublasDtrsm_64 cublasDtrsm_v2_64
|
| 470 |
+
#define cublasCtrsm_64 cublasCtrsm_v2_64
|
| 471 |
+
#define cublasZtrsm_64 cublasZtrsm_v2_64
|
| 472 |
+
|
| 473 |
+
#define cublasStrmm_64 cublasStrmm_v2_64
|
| 474 |
+
#define cublasDtrmm_64 cublasDtrmm_v2_64
|
| 475 |
+
#define cublasCtrmm_64 cublasCtrmm_v2_64
|
| 476 |
+
#define cublasZtrmm_64 cublasZtrmm_v2_64
|
| 477 |
+
|
| 478 |
+
#endif /* !defined(CUBLAS_V2_H_) */
|
.venv/lib/python3.11/site-packages/nvidia/cublas/include/nvblas.h
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#if !defined(NVBLAS_H_)
|
| 51 |
+
#define NVBLAS_H_
|
| 52 |
+
|
| 53 |
+
#include "driver_types.h"
|
| 54 |
+
#include "cuComplex.h" /* import complex data type */
|
| 55 |
+
|
| 56 |
+
#if defined(__cplusplus)
|
| 57 |
+
extern "C" {
|
| 58 |
+
#endif
|
| 59 |
+
|
| 60 |
+
/* GEMM */
|
| 61 |
+
void sgemm_(const char* transa,
|
| 62 |
+
const char* transb,
|
| 63 |
+
const int* m,
|
| 64 |
+
const int* n,
|
| 65 |
+
const int* k,
|
| 66 |
+
const float* alpha,
|
| 67 |
+
const float* a,
|
| 68 |
+
const int* lda,
|
| 69 |
+
const float* b,
|
| 70 |
+
const int* ldb,
|
| 71 |
+
const float* beta,
|
| 72 |
+
float* c,
|
| 73 |
+
const int* ldc);
|
| 74 |
+
|
| 75 |
+
void dgemm_(const char* transa,
|
| 76 |
+
const char* transb,
|
| 77 |
+
const int* m,
|
| 78 |
+
const int* n,
|
| 79 |
+
const int* k,
|
| 80 |
+
const double* alpha,
|
| 81 |
+
const double* a,
|
| 82 |
+
const int* lda,
|
| 83 |
+
const double* b,
|
| 84 |
+
const int* ldb,
|
| 85 |
+
const double* beta,
|
| 86 |
+
double* c,
|
| 87 |
+
const int* ldc);
|
| 88 |
+
|
| 89 |
+
void cgemm_(const char* transa,
|
| 90 |
+
const char* transb,
|
| 91 |
+
const int* m,
|
| 92 |
+
const int* n,
|
| 93 |
+
const int* k,
|
| 94 |
+
const cuComplex* alpha,
|
| 95 |
+
const cuComplex* a,
|
| 96 |
+
const int* lda,
|
| 97 |
+
const cuComplex* b,
|
| 98 |
+
const int* ldb,
|
| 99 |
+
const cuComplex* beta,
|
| 100 |
+
cuComplex* c,
|
| 101 |
+
const int* ldc);
|
| 102 |
+
|
| 103 |
+
void zgemm_(const char* transa,
|
| 104 |
+
const char* transb,
|
| 105 |
+
const int* m,
|
| 106 |
+
const int* n,
|
| 107 |
+
const int* k,
|
| 108 |
+
const cuDoubleComplex* alpha,
|
| 109 |
+
const cuDoubleComplex* a,
|
| 110 |
+
const int* lda,
|
| 111 |
+
const cuDoubleComplex* b,
|
| 112 |
+
const int* ldb,
|
| 113 |
+
const cuDoubleComplex* beta,
|
| 114 |
+
cuDoubleComplex* c,
|
| 115 |
+
const int* ldc);
|
| 116 |
+
|
| 117 |
+
void sgemm(const char* transa,
|
| 118 |
+
const char* transb,
|
| 119 |
+
const int* m,
|
| 120 |
+
const int* n,
|
| 121 |
+
const int* k,
|
| 122 |
+
const float* alpha,
|
| 123 |
+
const float* a,
|
| 124 |
+
const int* lda,
|
| 125 |
+
const float* b,
|
| 126 |
+
const int* ldb,
|
| 127 |
+
const float* beta,
|
| 128 |
+
float* c,
|
| 129 |
+
const int* ldc);
|
| 130 |
+
|
| 131 |
+
void dgemm(const char* transa,
|
| 132 |
+
const char* transb,
|
| 133 |
+
const int* m,
|
| 134 |
+
const int* n,
|
| 135 |
+
const int* k,
|
| 136 |
+
const double* alpha,
|
| 137 |
+
const double* a,
|
| 138 |
+
const int* lda,
|
| 139 |
+
const double* b,
|
| 140 |
+
const int* ldb,
|
| 141 |
+
const double* beta,
|
| 142 |
+
double* c,
|
| 143 |
+
const int* ldc);
|
| 144 |
+
|
| 145 |
+
void cgemm(const char* transa,
|
| 146 |
+
const char* transb,
|
| 147 |
+
const int* m,
|
| 148 |
+
const int* n,
|
| 149 |
+
const int* k,
|
| 150 |
+
const cuComplex* alpha,
|
| 151 |
+
const cuComplex* a,
|
| 152 |
+
const int* lda,
|
| 153 |
+
const cuComplex* b,
|
| 154 |
+
const int* ldb,
|
| 155 |
+
const cuComplex* beta,
|
| 156 |
+
cuComplex* c,
|
| 157 |
+
const int* ldc);
|
| 158 |
+
|
| 159 |
+
void zgemm(const char* transa,
|
| 160 |
+
const char* transb,
|
| 161 |
+
const int* m,
|
| 162 |
+
const int* n,
|
| 163 |
+
const int* k,
|
| 164 |
+
const cuDoubleComplex* alpha,
|
| 165 |
+
const cuDoubleComplex* a,
|
| 166 |
+
const int* lda,
|
| 167 |
+
const cuDoubleComplex* b,
|
| 168 |
+
const int* ldb,
|
| 169 |
+
const cuDoubleComplex* beta,
|
| 170 |
+
cuDoubleComplex* c,
|
| 171 |
+
const int* ldc);
|
| 172 |
+
|
| 173 |
+
/* SYRK */
|
| 174 |
+
void ssyrk_(const char* uplo,
|
| 175 |
+
const char* trans,
|
| 176 |
+
const int* n,
|
| 177 |
+
const int* k,
|
| 178 |
+
const float* alpha,
|
| 179 |
+
const float* a,
|
| 180 |
+
const int* lda,
|
| 181 |
+
const float* beta,
|
| 182 |
+
float* c,
|
| 183 |
+
const int* ldc);
|
| 184 |
+
|
| 185 |
+
void dsyrk_(const char* uplo,
|
| 186 |
+
const char* trans,
|
| 187 |
+
const int* n,
|
| 188 |
+
const int* k,
|
| 189 |
+
const double* alpha,
|
| 190 |
+
const double* a,
|
| 191 |
+
const int* lda,
|
| 192 |
+
const double* beta,
|
| 193 |
+
double* c,
|
| 194 |
+
const int* ldc);
|
| 195 |
+
|
| 196 |
+
void csyrk_(const char* uplo,
|
| 197 |
+
const char* trans,
|
| 198 |
+
const int* n,
|
| 199 |
+
const int* k,
|
| 200 |
+
const cuComplex* alpha,
|
| 201 |
+
const cuComplex* a,
|
| 202 |
+
const int* lda,
|
| 203 |
+
const cuComplex* beta,
|
| 204 |
+
cuComplex* c,
|
| 205 |
+
const int* ldc);
|
| 206 |
+
|
| 207 |
+
void zsyrk_(const char* uplo,
|
| 208 |
+
const char* trans,
|
| 209 |
+
const int* n,
|
| 210 |
+
const int* k,
|
| 211 |
+
const cuDoubleComplex* alpha,
|
| 212 |
+
const cuDoubleComplex* a,
|
| 213 |
+
const int* lda,
|
| 214 |
+
const cuDoubleComplex* beta,
|
| 215 |
+
cuDoubleComplex* c,
|
| 216 |
+
const int* ldc);
|
| 217 |
+
|
| 218 |
+
void ssyrk(const char* uplo,
|
| 219 |
+
const char* trans,
|
| 220 |
+
const int* n,
|
| 221 |
+
const int* k,
|
| 222 |
+
const float* alpha,
|
| 223 |
+
const float* a,
|
| 224 |
+
const int* lda,
|
| 225 |
+
const float* beta,
|
| 226 |
+
float* c,
|
| 227 |
+
const int* ldc);
|
| 228 |
+
|
| 229 |
+
void dsyrk(const char* uplo,
|
| 230 |
+
const char* trans,
|
| 231 |
+
const int* n,
|
| 232 |
+
const int* k,
|
| 233 |
+
const double* alpha,
|
| 234 |
+
const double* a,
|
| 235 |
+
const int* lda,
|
| 236 |
+
const double* beta,
|
| 237 |
+
double* c,
|
| 238 |
+
const int* ldc);
|
| 239 |
+
|
| 240 |
+
void csyrk(const char* uplo,
|
| 241 |
+
const char* trans,
|
| 242 |
+
const int* n,
|
| 243 |
+
const int* k,
|
| 244 |
+
const cuComplex* alpha,
|
| 245 |
+
const cuComplex* a,
|
| 246 |
+
const int* lda,
|
| 247 |
+
const cuComplex* beta,
|
| 248 |
+
cuComplex* c,
|
| 249 |
+
const int* ldc);
|
| 250 |
+
|
| 251 |
+
void zsyrk(const char* uplo,
|
| 252 |
+
const char* trans,
|
| 253 |
+
const int* n,
|
| 254 |
+
const int* k,
|
| 255 |
+
const cuDoubleComplex* alpha,
|
| 256 |
+
const cuDoubleComplex* a,
|
| 257 |
+
const int* lda,
|
| 258 |
+
const cuDoubleComplex* beta,
|
| 259 |
+
cuDoubleComplex* c,
|
| 260 |
+
const int* ldc);
|
| 261 |
+
|
| 262 |
+
/* HERK */
|
| 263 |
+
void cherk_(const char* uplo,
|
| 264 |
+
const char* trans,
|
| 265 |
+
const int* n,
|
| 266 |
+
const int* k,
|
| 267 |
+
const float* alpha,
|
| 268 |
+
const cuComplex* a,
|
| 269 |
+
const int* lda,
|
| 270 |
+
const float* beta,
|
| 271 |
+
cuComplex* c,
|
| 272 |
+
const int* ldc);
|
| 273 |
+
|
| 274 |
+
void zherk_(const char* uplo,
|
| 275 |
+
const char* trans,
|
| 276 |
+
const int* n,
|
| 277 |
+
const int* k,
|
| 278 |
+
const double* alpha,
|
| 279 |
+
const cuDoubleComplex* a,
|
| 280 |
+
const int* lda,
|
| 281 |
+
const double* beta,
|
| 282 |
+
cuDoubleComplex* c,
|
| 283 |
+
const int* ldc);
|
| 284 |
+
|
| 285 |
+
void cherk(const char* uplo,
|
| 286 |
+
const char* trans,
|
| 287 |
+
const int* n,
|
| 288 |
+
const int* k,
|
| 289 |
+
const float* alpha,
|
| 290 |
+
const cuComplex* a,
|
| 291 |
+
const int* lda,
|
| 292 |
+
const float* beta,
|
| 293 |
+
cuComplex* c,
|
| 294 |
+
const int* ldc);
|
| 295 |
+
|
| 296 |
+
void zherk(const char* uplo,
|
| 297 |
+
const char* trans,
|
| 298 |
+
const int* n,
|
| 299 |
+
const int* k,
|
| 300 |
+
const double* alpha,
|
| 301 |
+
const cuDoubleComplex* a,
|
| 302 |
+
const int* lda,
|
| 303 |
+
const double* beta,
|
| 304 |
+
cuDoubleComplex* c,
|
| 305 |
+
const int* ldc);
|
| 306 |
+
|
| 307 |
+
/* TRSM */
|
| 308 |
+
void strsm_(const char* side,
|
| 309 |
+
const char* uplo,
|
| 310 |
+
const char* transa,
|
| 311 |
+
const char* diag,
|
| 312 |
+
const int* m,
|
| 313 |
+
const int* n,
|
| 314 |
+
const float* alpha,
|
| 315 |
+
const float* a,
|
| 316 |
+
const int* lda,
|
| 317 |
+
float* b,
|
| 318 |
+
const int* ldb);
|
| 319 |
+
|
| 320 |
+
void dtrsm_(const char* side,
|
| 321 |
+
const char* uplo,
|
| 322 |
+
const char* transa,
|
| 323 |
+
const char* diag,
|
| 324 |
+
const int* m,
|
| 325 |
+
const int* n,
|
| 326 |
+
const double* alpha,
|
| 327 |
+
const double* a,
|
| 328 |
+
const int* lda,
|
| 329 |
+
double* b,
|
| 330 |
+
const int* ldb);
|
| 331 |
+
|
| 332 |
+
void ctrsm_(const char* side,
|
| 333 |
+
const char* uplo,
|
| 334 |
+
const char* transa,
|
| 335 |
+
const char* diag,
|
| 336 |
+
const int* m,
|
| 337 |
+
const int* n,
|
| 338 |
+
const cuComplex* alpha,
|
| 339 |
+
const cuComplex* a,
|
| 340 |
+
const int* lda,
|
| 341 |
+
cuComplex* b,
|
| 342 |
+
const int* ldb);
|
| 343 |
+
|
| 344 |
+
void ztrsm_(const char* side,
|
| 345 |
+
const char* uplo,
|
| 346 |
+
const char* transa,
|
| 347 |
+
const char* diag,
|
| 348 |
+
const int* m,
|
| 349 |
+
const int* n,
|
| 350 |
+
const cuDoubleComplex* alpha,
|
| 351 |
+
const cuDoubleComplex* a,
|
| 352 |
+
const int* lda,
|
| 353 |
+
cuDoubleComplex* b,
|
| 354 |
+
const int* ldb);
|
| 355 |
+
|
| 356 |
+
void strsm(const char* side,
|
| 357 |
+
const char* uplo,
|
| 358 |
+
const char* transa,
|
| 359 |
+
const char* diag,
|
| 360 |
+
const int* m,
|
| 361 |
+
const int* n,
|
| 362 |
+
const float* alpha,
|
| 363 |
+
const float* a,
|
| 364 |
+
const int* lda,
|
| 365 |
+
float* b,
|
| 366 |
+
const int* ldb);
|
| 367 |
+
|
| 368 |
+
void dtrsm(const char* side,
|
| 369 |
+
const char* uplo,
|
| 370 |
+
const char* transa,
|
| 371 |
+
const char* diag,
|
| 372 |
+
const int* m,
|
| 373 |
+
const int* n,
|
| 374 |
+
const double* alpha,
|
| 375 |
+
const double* a,
|
| 376 |
+
const int* lda,
|
| 377 |
+
double* b,
|
| 378 |
+
const int* ldb);
|
| 379 |
+
|
| 380 |
+
void ctrsm(const char* side,
|
| 381 |
+
const char* uplo,
|
| 382 |
+
const char* transa,
|
| 383 |
+
const char* diag,
|
| 384 |
+
const int* m,
|
| 385 |
+
const int* n,
|
| 386 |
+
const cuComplex* alpha,
|
| 387 |
+
const cuComplex* a,
|
| 388 |
+
const int* lda,
|
| 389 |
+
cuComplex* b,
|
| 390 |
+
const int* ldb);
|
| 391 |
+
|
| 392 |
+
void ztrsm(const char* side,
|
| 393 |
+
const char* uplo,
|
| 394 |
+
const char* transa,
|
| 395 |
+
const char* diag,
|
| 396 |
+
const int* m,
|
| 397 |
+
const int* n,
|
| 398 |
+
const cuDoubleComplex* alpha,
|
| 399 |
+
const cuDoubleComplex* a,
|
| 400 |
+
const int* lda,
|
| 401 |
+
cuDoubleComplex* b,
|
| 402 |
+
const int* ldb);
|
| 403 |
+
|
| 404 |
+
/* SYMM */
|
| 405 |
+
void ssymm_(const char* side,
|
| 406 |
+
const char* uplo,
|
| 407 |
+
const int* m,
|
| 408 |
+
const int* n,
|
| 409 |
+
const float* alpha,
|
| 410 |
+
const float* a,
|
| 411 |
+
const int* lda,
|
| 412 |
+
const float* b,
|
| 413 |
+
const int* ldb,
|
| 414 |
+
const float* beta,
|
| 415 |
+
float* c,
|
| 416 |
+
const int* ldc);
|
| 417 |
+
|
| 418 |
+
void dsymm_(const char* side,
|
| 419 |
+
const char* uplo,
|
| 420 |
+
const int* m,
|
| 421 |
+
const int* n,
|
| 422 |
+
const double* alpha,
|
| 423 |
+
const double* a,
|
| 424 |
+
const int* lda,
|
| 425 |
+
const double* b,
|
| 426 |
+
const int* ldb,
|
| 427 |
+
const double* beta,
|
| 428 |
+
double* c,
|
| 429 |
+
const int* ldc);
|
| 430 |
+
|
| 431 |
+
void csymm_(const char* side,
|
| 432 |
+
const char* uplo,
|
| 433 |
+
const int* m,
|
| 434 |
+
const int* n,
|
| 435 |
+
const cuComplex* alpha,
|
| 436 |
+
const cuComplex* a,
|
| 437 |
+
const int* lda,
|
| 438 |
+
const cuComplex* b,
|
| 439 |
+
const int* ldb,
|
| 440 |
+
const cuComplex* beta,
|
| 441 |
+
cuComplex* c,
|
| 442 |
+
const int* ldc);
|
| 443 |
+
|
| 444 |
+
void zsymm_(const char* side,
|
| 445 |
+
const char* uplo,
|
| 446 |
+
const int* m,
|
| 447 |
+
const int* n,
|
| 448 |
+
const cuDoubleComplex* alpha,
|
| 449 |
+
const cuDoubleComplex* a,
|
| 450 |
+
const int* lda,
|
| 451 |
+
const cuDoubleComplex* b,
|
| 452 |
+
const int* ldb,
|
| 453 |
+
const cuDoubleComplex* beta,
|
| 454 |
+
cuDoubleComplex* c,
|
| 455 |
+
const int* ldc);
|
| 456 |
+
|
| 457 |
+
void ssymm(const char* side,
|
| 458 |
+
const char* uplo,
|
| 459 |
+
const int* m,
|
| 460 |
+
const int* n,
|
| 461 |
+
const float* alpha,
|
| 462 |
+
const float* a,
|
| 463 |
+
const int* lda,
|
| 464 |
+
const float* b,
|
| 465 |
+
const int* ldb,
|
| 466 |
+
const float* beta,
|
| 467 |
+
float* c,
|
| 468 |
+
const int* ldc);
|
| 469 |
+
|
| 470 |
+
void dsymm(const char* side,
|
| 471 |
+
const char* uplo,
|
| 472 |
+
const int* m,
|
| 473 |
+
const int* n,
|
| 474 |
+
const double* alpha,
|
| 475 |
+
const double* a,
|
| 476 |
+
const int* lda,
|
| 477 |
+
const double* b,
|
| 478 |
+
const int* ldb,
|
| 479 |
+
const double* beta,
|
| 480 |
+
double* c,
|
| 481 |
+
const int* ldc);
|
| 482 |
+
|
| 483 |
+
void csymm(const char* side,
|
| 484 |
+
const char* uplo,
|
| 485 |
+
const int* m,
|
| 486 |
+
const int* n,
|
| 487 |
+
const cuComplex* alpha,
|
| 488 |
+
const cuComplex* a,
|
| 489 |
+
const int* lda,
|
| 490 |
+
const cuComplex* b,
|
| 491 |
+
const int* ldb,
|
| 492 |
+
const cuComplex* beta,
|
| 493 |
+
cuComplex* c,
|
| 494 |
+
const int* ldc);
|
| 495 |
+
|
| 496 |
+
void zsymm(const char* side,
|
| 497 |
+
const char* uplo,
|
| 498 |
+
const int* m,
|
| 499 |
+
const int* n,
|
| 500 |
+
const cuDoubleComplex* alpha,
|
| 501 |
+
const cuDoubleComplex* a,
|
| 502 |
+
const int* lda,
|
| 503 |
+
const cuDoubleComplex* b,
|
| 504 |
+
const int* ldb,
|
| 505 |
+
const cuDoubleComplex* beta,
|
| 506 |
+
cuDoubleComplex* c,
|
| 507 |
+
const int* ldc);
|
| 508 |
+
|
| 509 |
+
/* HEMM */
|
| 510 |
+
void chemm_(const char* side,
|
| 511 |
+
const char* uplo,
|
| 512 |
+
const int* m,
|
| 513 |
+
const int* n,
|
| 514 |
+
const cuComplex* alpha,
|
| 515 |
+
const cuComplex* a,
|
| 516 |
+
const int* lda,
|
| 517 |
+
const cuComplex* b,
|
| 518 |
+
const int* ldb,
|
| 519 |
+
const cuComplex* beta,
|
| 520 |
+
cuComplex* c,
|
| 521 |
+
const int* ldc);
|
| 522 |
+
|
| 523 |
+
void zhemm_(const char* side,
|
| 524 |
+
const char* uplo,
|
| 525 |
+
const int* m,
|
| 526 |
+
const int* n,
|
| 527 |
+
const cuDoubleComplex* alpha,
|
| 528 |
+
const cuDoubleComplex* a,
|
| 529 |
+
const int* lda,
|
| 530 |
+
const cuDoubleComplex* b,
|
| 531 |
+
const int* ldb,
|
| 532 |
+
const cuDoubleComplex* beta,
|
| 533 |
+
cuDoubleComplex* c,
|
| 534 |
+
const int* ldc);
|
| 535 |
+
|
| 536 |
+
/* HEMM with no underscore*/
|
| 537 |
+
void chemm(const char* side,
|
| 538 |
+
const char* uplo,
|
| 539 |
+
const int* m,
|
| 540 |
+
const int* n,
|
| 541 |
+
const cuComplex* alpha,
|
| 542 |
+
const cuComplex* a,
|
| 543 |
+
const int* lda,
|
| 544 |
+
const cuComplex* b,
|
| 545 |
+
const int* ldb,
|
| 546 |
+
const cuComplex* beta,
|
| 547 |
+
cuComplex* c,
|
| 548 |
+
const int* ldc);
|
| 549 |
+
|
| 550 |
+
void zhemm(const char* side,
|
| 551 |
+
const char* uplo,
|
| 552 |
+
const int* m,
|
| 553 |
+
const int* n,
|
| 554 |
+
const cuDoubleComplex* alpha,
|
| 555 |
+
const cuDoubleComplex* a,
|
| 556 |
+
const int* lda,
|
| 557 |
+
const cuDoubleComplex* b,
|
| 558 |
+
const int* ldb,
|
| 559 |
+
const cuDoubleComplex* beta,
|
| 560 |
+
cuDoubleComplex* c,
|
| 561 |
+
const int* ldc);
|
| 562 |
+
|
| 563 |
+
/* SYR2K */
|
| 564 |
+
void ssyr2k_(const char* uplo,
|
| 565 |
+
const char* trans,
|
| 566 |
+
const int* n,
|
| 567 |
+
const int* k,
|
| 568 |
+
const float* alpha,
|
| 569 |
+
const float* a,
|
| 570 |
+
const int* lda,
|
| 571 |
+
const float* b,
|
| 572 |
+
const int* ldb,
|
| 573 |
+
const float* beta,
|
| 574 |
+
float* c,
|
| 575 |
+
const int* ldc);
|
| 576 |
+
|
| 577 |
+
void dsyr2k_(const char* uplo,
|
| 578 |
+
const char* trans,
|
| 579 |
+
const int* n,
|
| 580 |
+
const int* k,
|
| 581 |
+
const double* alpha,
|
| 582 |
+
const double* a,
|
| 583 |
+
const int* lda,
|
| 584 |
+
const double* b,
|
| 585 |
+
const int* ldb,
|
| 586 |
+
const double* beta,
|
| 587 |
+
double* c,
|
| 588 |
+
const int* ldc);
|
| 589 |
+
|
| 590 |
+
void csyr2k_(const char* uplo,
|
| 591 |
+
const char* trans,
|
| 592 |
+
const int* n,
|
| 593 |
+
const int* k,
|
| 594 |
+
const cuComplex* alpha,
|
| 595 |
+
const cuComplex* a,
|
| 596 |
+
const int* lda,
|
| 597 |
+
const cuComplex* b,
|
| 598 |
+
const int* ldb,
|
| 599 |
+
const cuComplex* beta,
|
| 600 |
+
cuComplex* c,
|
| 601 |
+
const int* ldc);
|
| 602 |
+
|
| 603 |
+
void zsyr2k_(const char* uplo,
|
| 604 |
+
const char* trans,
|
| 605 |
+
const int* n,
|
| 606 |
+
const int* k,
|
| 607 |
+
const cuDoubleComplex* alpha,
|
| 608 |
+
const cuDoubleComplex* a,
|
| 609 |
+
const int* lda,
|
| 610 |
+
const cuDoubleComplex* b,
|
| 611 |
+
const int* ldb,
|
| 612 |
+
const cuDoubleComplex* beta,
|
| 613 |
+
cuDoubleComplex* c,
|
| 614 |
+
const int* ldc);
|
| 615 |
+
|
| 616 |
+
/* SYR2K no_underscore*/
|
| 617 |
+
void ssyr2k(const char* uplo,
|
| 618 |
+
const char* trans,
|
| 619 |
+
const int* n,
|
| 620 |
+
const int* k,
|
| 621 |
+
const float* alpha,
|
| 622 |
+
const float* a,
|
| 623 |
+
const int* lda,
|
| 624 |
+
const float* b,
|
| 625 |
+
const int* ldb,
|
| 626 |
+
const float* beta,
|
| 627 |
+
float* c,
|
| 628 |
+
const int* ldc);
|
| 629 |
+
|
| 630 |
+
void dsyr2k(const char* uplo,
|
| 631 |
+
const char* trans,
|
| 632 |
+
const int* n,
|
| 633 |
+
const int* k,
|
| 634 |
+
const double* alpha,
|
| 635 |
+
const double* a,
|
| 636 |
+
const int* lda,
|
| 637 |
+
const double* b,
|
| 638 |
+
const int* ldb,
|
| 639 |
+
const double* beta,
|
| 640 |
+
double* c,
|
| 641 |
+
const int* ldc);
|
| 642 |
+
|
| 643 |
+
void csyr2k(const char* uplo,
|
| 644 |
+
const char* trans,
|
| 645 |
+
const int* n,
|
| 646 |
+
const int* k,
|
| 647 |
+
const cuComplex* alpha,
|
| 648 |
+
const cuComplex* a,
|
| 649 |
+
const int* lda,
|
| 650 |
+
const cuComplex* b,
|
| 651 |
+
const int* ldb,
|
| 652 |
+
const cuComplex* beta,
|
| 653 |
+
cuComplex* c,
|
| 654 |
+
const int* ldc);
|
| 655 |
+
|
| 656 |
+
void zsyr2k(const char* uplo,
|
| 657 |
+
const char* trans,
|
| 658 |
+
const int* n,
|
| 659 |
+
const int* k,
|
| 660 |
+
const cuDoubleComplex* alpha,
|
| 661 |
+
const cuDoubleComplex* a,
|
| 662 |
+
const int* lda,
|
| 663 |
+
const cuDoubleComplex* b,
|
| 664 |
+
const int* ldb,
|
| 665 |
+
const cuDoubleComplex* beta,
|
| 666 |
+
cuDoubleComplex* c,
|
| 667 |
+
const int* ldc);
|
| 668 |
+
|
| 669 |
+
/* HERK */
|
| 670 |
+
void cher2k_(const char* uplo,
|
| 671 |
+
const char* trans,
|
| 672 |
+
const int* n,
|
| 673 |
+
const int* k,
|
| 674 |
+
const cuComplex* alpha,
|
| 675 |
+
const cuComplex* a,
|
| 676 |
+
const int* lda,
|
| 677 |
+
const cuComplex* b,
|
| 678 |
+
const int* ldb,
|
| 679 |
+
const float* beta,
|
| 680 |
+
cuComplex* c,
|
| 681 |
+
const int* ldc);
|
| 682 |
+
|
| 683 |
+
void zher2k_(const char* uplo,
|
| 684 |
+
const char* trans,
|
| 685 |
+
const int* n,
|
| 686 |
+
const int* k,
|
| 687 |
+
const cuDoubleComplex* alpha,
|
| 688 |
+
const cuDoubleComplex* a,
|
| 689 |
+
const int* lda,
|
| 690 |
+
const cuDoubleComplex* b,
|
| 691 |
+
const int* ldb,
|
| 692 |
+
const double* beta,
|
| 693 |
+
cuDoubleComplex* c,
|
| 694 |
+
const int* ldc);
|
| 695 |
+
|
| 696 |
+
/* HER2K with no underscore */
|
| 697 |
+
void cher2k(const char* uplo,
|
| 698 |
+
const char* trans,
|
| 699 |
+
const int* n,
|
| 700 |
+
const int* k,
|
| 701 |
+
const cuComplex* alpha,
|
| 702 |
+
const cuComplex* a,
|
| 703 |
+
const int* lda,
|
| 704 |
+
const cuComplex* b,
|
| 705 |
+
const int* ldb,
|
| 706 |
+
const float* beta,
|
| 707 |
+
cuComplex* c,
|
| 708 |
+
const int* ldc);
|
| 709 |
+
|
| 710 |
+
void zher2k(const char* uplo,
|
| 711 |
+
const char* trans,
|
| 712 |
+
const int* n,
|
| 713 |
+
const int* k,
|
| 714 |
+
const cuDoubleComplex* alpha,
|
| 715 |
+
const cuDoubleComplex* a,
|
| 716 |
+
const int* lda,
|
| 717 |
+
const cuDoubleComplex* b,
|
| 718 |
+
const int* ldb,
|
| 719 |
+
const double* beta,
|
| 720 |
+
cuDoubleComplex* c,
|
| 721 |
+
const int* ldc);
|
| 722 |
+
|
| 723 |
+
/* TRMM */
|
| 724 |
+
void strmm_(const char* side,
|
| 725 |
+
const char* uplo,
|
| 726 |
+
const char* transa,
|
| 727 |
+
const char* diag,
|
| 728 |
+
const int* m,
|
| 729 |
+
const int* n,
|
| 730 |
+
const float* alpha,
|
| 731 |
+
const float* a,
|
| 732 |
+
const int* lda,
|
| 733 |
+
float* b,
|
| 734 |
+
const int* ldb);
|
| 735 |
+
|
| 736 |
+
void dtrmm_(const char* side,
|
| 737 |
+
const char* uplo,
|
| 738 |
+
const char* transa,
|
| 739 |
+
const char* diag,
|
| 740 |
+
const int* m,
|
| 741 |
+
const int* n,
|
| 742 |
+
const double* alpha,
|
| 743 |
+
const double* a,
|
| 744 |
+
const int* lda,
|
| 745 |
+
double* b,
|
| 746 |
+
const int* ldb);
|
| 747 |
+
|
| 748 |
+
void ctrmm_(const char* side,
|
| 749 |
+
const char* uplo,
|
| 750 |
+
const char* transa,
|
| 751 |
+
const char* diag,
|
| 752 |
+
const int* m,
|
| 753 |
+
const int* n,
|
| 754 |
+
const cuComplex* alpha,
|
| 755 |
+
const cuComplex* a,
|
| 756 |
+
const int* lda,
|
| 757 |
+
cuComplex* b,
|
| 758 |
+
const int* ldb);
|
| 759 |
+
|
| 760 |
+
void ztrmm_(const char* side,
|
| 761 |
+
const char* uplo,
|
| 762 |
+
const char* transa,
|
| 763 |
+
const char* diag,
|
| 764 |
+
const int* m,
|
| 765 |
+
const int* n,
|
| 766 |
+
const cuDoubleComplex* alpha,
|
| 767 |
+
const cuDoubleComplex* a,
|
| 768 |
+
const int* lda,
|
| 769 |
+
cuDoubleComplex* b,
|
| 770 |
+
const int* ldb);
|
| 771 |
+
|
| 772 |
+
void strmm(const char* side,
|
| 773 |
+
const char* uplo,
|
| 774 |
+
const char* transa,
|
| 775 |
+
const char* diag,
|
| 776 |
+
const int* m,
|
| 777 |
+
const int* n,
|
| 778 |
+
const float* alpha,
|
| 779 |
+
const float* a,
|
| 780 |
+
const int* lda,
|
| 781 |
+
float* b,
|
| 782 |
+
const int* ldb);
|
| 783 |
+
|
| 784 |
+
void dtrmm(const char* side,
|
| 785 |
+
const char* uplo,
|
| 786 |
+
const char* transa,
|
| 787 |
+
const char* diag,
|
| 788 |
+
const int* m,
|
| 789 |
+
const int* n,
|
| 790 |
+
const double* alpha,
|
| 791 |
+
const double* a,
|
| 792 |
+
const int* lda,
|
| 793 |
+
double* b,
|
| 794 |
+
const int* ldb);
|
| 795 |
+
|
| 796 |
+
void ctrmm(const char* side,
|
| 797 |
+
const char* uplo,
|
| 798 |
+
const char* transa,
|
| 799 |
+
const char* diag,
|
| 800 |
+
const int* m,
|
| 801 |
+
const int* n,
|
| 802 |
+
const cuComplex* alpha,
|
| 803 |
+
const cuComplex* a,
|
| 804 |
+
const int* lda,
|
| 805 |
+
cuComplex* b,
|
| 806 |
+
const int* ldb);
|
| 807 |
+
|
| 808 |
+
void ztrmm(const char* side,
|
| 809 |
+
const char* uplo,
|
| 810 |
+
const char* transa,
|
| 811 |
+
const char* diag,
|
| 812 |
+
const int* m,
|
| 813 |
+
const int* n,
|
| 814 |
+
const cuDoubleComplex* alpha,
|
| 815 |
+
const cuDoubleComplex* a,
|
| 816 |
+
const int* lda,
|
| 817 |
+
cuDoubleComplex* b,
|
| 818 |
+
const int* ldb);
|
| 819 |
+
|
| 820 |
+
#if defined(__cplusplus)
|
| 821 |
+
}
|
| 822 |
+
#endif /* __cplusplus */
|
| 823 |
+
|
| 824 |
+
#endif /* !defined(NVBLAS_H_) */
|
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (190 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7c2a58dc54154208392301d0fe3d53a120e4c1ebeab9e80ce91fe9948baeadc9
|
| 3 |
+
size 757496
|
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (190 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/nvrtc.h
ADDED
|
@@ -0,0 +1,869 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// NVIDIA_COPYRIGHT_BEGIN
|
| 3 |
+
//
|
| 4 |
+
// Copyright (c) 2014-2023, NVIDIA CORPORATION. All rights reserved.
|
| 5 |
+
//
|
| 6 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 7 |
+
// and proprietary rights in and to this software, related documentation
|
| 8 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
| 9 |
+
// distribution of this software and related documentation without an express
|
| 10 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 11 |
+
//
|
| 12 |
+
// NVIDIA_COPYRIGHT_END
|
| 13 |
+
//
|
| 14 |
+
|
| 15 |
+
#ifndef __NVRTC_H__
|
| 16 |
+
#define __NVRTC_H__
|
| 17 |
+
|
| 18 |
+
#ifdef __cplusplus
|
| 19 |
+
extern "C" {
|
| 20 |
+
#endif /* __cplusplus */
|
| 21 |
+
|
| 22 |
+
#include <stdlib.h>
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
/*************************************************************************//**
|
| 26 |
+
*
|
| 27 |
+
* \defgroup error Error Handling
|
| 28 |
+
*
|
| 29 |
+
* NVRTC defines the following enumeration type and function for API call
|
| 30 |
+
* error handling.
|
| 31 |
+
*
|
| 32 |
+
****************************************************************************/
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* \ingroup error
|
| 37 |
+
* \brief The enumerated type nvrtcResult defines API call result codes.
|
| 38 |
+
* NVRTC API functions return nvrtcResult to indicate the call
|
| 39 |
+
* result.
|
| 40 |
+
*/
|
| 41 |
+
typedef enum {
|
| 42 |
+
NVRTC_SUCCESS = 0,
|
| 43 |
+
NVRTC_ERROR_OUT_OF_MEMORY = 1,
|
| 44 |
+
NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2,
|
| 45 |
+
NVRTC_ERROR_INVALID_INPUT = 3,
|
| 46 |
+
NVRTC_ERROR_INVALID_PROGRAM = 4,
|
| 47 |
+
NVRTC_ERROR_INVALID_OPTION = 5,
|
| 48 |
+
NVRTC_ERROR_COMPILATION = 6,
|
| 49 |
+
NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7,
|
| 50 |
+
NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8,
|
| 51 |
+
NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9,
|
| 52 |
+
NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10,
|
| 53 |
+
NVRTC_ERROR_INTERNAL_ERROR = 11,
|
| 54 |
+
NVRTC_ERROR_TIME_FILE_WRITE_FAILED = 12
|
| 55 |
+
} nvrtcResult;
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
/**
|
| 59 |
+
* \ingroup error
|
| 60 |
+
* \brief nvrtcGetErrorString is a helper function that returns a string
|
| 61 |
+
* describing the given nvrtcResult code, e.g., NVRTC_SUCCESS to
|
| 62 |
+
* \c "NVRTC_SUCCESS".
|
| 63 |
+
* For unrecognized enumeration values, it returns
|
| 64 |
+
* \c "NVRTC_ERROR unknown".
|
| 65 |
+
*
|
| 66 |
+
* \param [in] result CUDA Runtime Compilation API result code.
|
| 67 |
+
* \return Message string for the given #nvrtcResult code.
|
| 68 |
+
*/
|
| 69 |
+
const char *nvrtcGetErrorString(nvrtcResult result);
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
/*************************************************************************//**
|
| 73 |
+
*
|
| 74 |
+
* \defgroup query General Information Query
|
| 75 |
+
*
|
| 76 |
+
* NVRTC defines the following function for general information query.
|
| 77 |
+
*
|
| 78 |
+
****************************************************************************/
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
/**
|
| 82 |
+
* \ingroup query
|
| 83 |
+
* \brief nvrtcVersion sets the output parameters \p major and \p minor
|
| 84 |
+
* with the CUDA Runtime Compilation version number.
|
| 85 |
+
*
|
| 86 |
+
* \param [out] major CUDA Runtime Compilation major version number.
|
| 87 |
+
* \param [out] minor CUDA Runtime Compilation minor version number.
|
| 88 |
+
* \return
|
| 89 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 90 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 91 |
+
*
|
| 92 |
+
*/
|
| 93 |
+
nvrtcResult nvrtcVersion(int *major, int *minor);
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
/**
|
| 97 |
+
* \ingroup query
|
| 98 |
+
* \brief nvrtcGetNumSupportedArchs sets the output parameter \p numArchs
|
| 99 |
+
* with the number of architectures supported by NVRTC. This can
|
| 100 |
+
* then be used to pass an array to ::nvrtcGetSupportedArchs to
|
| 101 |
+
* get the supported architectures.
|
| 102 |
+
*
|
| 103 |
+
* \param [out] numArchs number of supported architectures.
|
| 104 |
+
* \return
|
| 105 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 106 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 107 |
+
*
|
| 108 |
+
* see ::nvrtcGetSupportedArchs
|
| 109 |
+
*/
|
| 110 |
+
nvrtcResult nvrtcGetNumSupportedArchs(int* numArchs);
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
/**
|
| 114 |
+
* \ingroup query
|
| 115 |
+
* \brief nvrtcGetSupportedArchs populates the array passed via the output parameter
|
| 116 |
+
* \p supportedArchs with the architectures supported by NVRTC. The array is
|
| 117 |
+
* sorted in the ascending order. The size of the array to be passed can be
|
| 118 |
+
* determined using ::nvrtcGetNumSupportedArchs.
|
| 119 |
+
*
|
| 120 |
+
* \param [out] supportedArchs sorted array of supported architectures.
|
| 121 |
+
* \return
|
| 122 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 123 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 124 |
+
*
|
| 125 |
+
* see ::nvrtcGetNumSupportedArchs
|
| 126 |
+
*/
|
| 127 |
+
nvrtcResult nvrtcGetSupportedArchs(int* supportedArchs);
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
/*************************************************************************//**
|
| 131 |
+
*
|
| 132 |
+
* \defgroup compilation Compilation
|
| 133 |
+
*
|
| 134 |
+
* NVRTC defines the following type and functions for actual compilation.
|
| 135 |
+
*
|
| 136 |
+
****************************************************************************/
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
/**
|
| 140 |
+
* \ingroup compilation
|
| 141 |
+
* \brief nvrtcProgram is the unit of compilation, and an opaque handle for
|
| 142 |
+
* a program.
|
| 143 |
+
*
|
| 144 |
+
* To compile a CUDA program string, an instance of nvrtcProgram must be
|
| 145 |
+
* created first with ::nvrtcCreateProgram, then compiled with
|
| 146 |
+
* ::nvrtcCompileProgram.
|
| 147 |
+
*/
|
| 148 |
+
typedef struct _nvrtcProgram *nvrtcProgram;
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
/**
|
| 152 |
+
* \ingroup compilation
|
| 153 |
+
* \brief nvrtcCreateProgram creates an instance of nvrtcProgram with the
|
| 154 |
+
* given input parameters, and sets the output parameter \p prog with
|
| 155 |
+
* it.
|
| 156 |
+
*
|
| 157 |
+
* \param [out] prog CUDA Runtime Compilation program.
|
| 158 |
+
* \param [in] src CUDA program source.
|
| 159 |
+
* \param [in] name CUDA program name.\n
|
| 160 |
+
* \p name can be \c NULL; \c "default_program" is
|
| 161 |
+
* used when \p name is \c NULL or "".
|
| 162 |
+
* \param [in] numHeaders Number of headers used.\n
|
| 163 |
+
* \p numHeaders must be greater than or equal to 0.
|
| 164 |
+
* \param [in] headers Sources of the headers.\n
|
| 165 |
+
* \p headers can be \c NULL when \p numHeaders is
|
| 166 |
+
* 0.
|
| 167 |
+
* \param [in] includeNames Name of each header by which they can be
|
| 168 |
+
* included in the CUDA program source.\n
|
| 169 |
+
* \p includeNames can be \c NULL when \p numHeaders
|
| 170 |
+
* is 0. These headers must be included with the exact
|
| 171 |
+
* names specified here.
|
| 172 |
+
* \return
|
| 173 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 174 |
+
* - \link #nvrtcResult NVRTC_ERROR_OUT_OF_MEMORY \endlink
|
| 175 |
+
* - \link #nvrtcResult NVRTC_ERROR_PROGRAM_CREATION_FAILURE \endlink
|
| 176 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 177 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 178 |
+
*
|
| 179 |
+
* \see ::nvrtcDestroyProgram
|
| 180 |
+
*/
|
| 181 |
+
nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
|
| 182 |
+
const char *src,
|
| 183 |
+
const char *name,
|
| 184 |
+
int numHeaders,
|
| 185 |
+
const char * const *headers,
|
| 186 |
+
const char * const *includeNames);
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
/**
|
| 190 |
+
* \ingroup compilation
|
| 191 |
+
* \brief nvrtcDestroyProgram destroys the given program.
|
| 192 |
+
*
|
| 193 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 194 |
+
* \return
|
| 195 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 196 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 197 |
+
*
|
| 198 |
+
* \see ::nvrtcCreateProgram
|
| 199 |
+
*/
|
| 200 |
+
nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog);
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
/**
|
| 204 |
+
* \ingroup compilation
|
| 205 |
+
* \brief nvrtcCompileProgram compiles the given program.
|
| 206 |
+
*
|
| 207 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 208 |
+
* \param [in] numOptions Number of compiler options passed.
|
| 209 |
+
* \param [in] options Compiler options in the form of C string array.\n
|
| 210 |
+
* \p options can be \c NULL when \p numOptions is 0.
|
| 211 |
+
*
|
| 212 |
+
* \return
|
| 213 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 214 |
+
* - \link #nvrtcResult NVRTC_ERROR_OUT_OF_MEMORY \endlink
|
| 215 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 216 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 217 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_OPTION \endlink
|
| 218 |
+
* - \link #nvrtcResult NVRTC_ERROR_COMPILATION \endlink
|
| 219 |
+
* - \link #nvrtcResult NVRTC_ERROR_BUILTIN_OPERATION_FAILURE \endlink
|
| 220 |
+
* - \link #nvrtcResult NVRTC_ERROR_TIME_FILE_WRITE_FAILED \endlink
|
| 221 |
+
*
|
| 222 |
+
* It supports compile options listed in \ref options.
|
| 223 |
+
*/
|
| 224 |
+
nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
|
| 225 |
+
int numOptions, const char * const *options);
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
/**
|
| 229 |
+
* \ingroup compilation
|
| 230 |
+
* \brief nvrtcGetPTXSize sets the value of \p ptxSizeRet with the size of the PTX
|
| 231 |
+
* generated by the previous compilation of \p prog (including the
|
| 232 |
+
* trailing \c NULL).
|
| 233 |
+
*
|
| 234 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 235 |
+
* \param [out] ptxSizeRet Size of the generated PTX (including the trailing
|
| 236 |
+
* \c NULL).
|
| 237 |
+
* \return
|
| 238 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 239 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 240 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 241 |
+
*
|
| 242 |
+
* \see ::nvrtcGetPTX
|
| 243 |
+
*/
|
| 244 |
+
nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
/**
|
| 248 |
+
* \ingroup compilation
|
| 249 |
+
* \brief nvrtcGetPTX stores the PTX generated by the previous compilation
|
| 250 |
+
* of \p prog in the memory pointed by \p ptx.
|
| 251 |
+
*
|
| 252 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 253 |
+
* \param [out] ptx Compiled result.
|
| 254 |
+
* \return
|
| 255 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 256 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 257 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 258 |
+
*
|
| 259 |
+
* \see ::nvrtcGetPTXSize
|
| 260 |
+
*/
|
| 261 |
+
nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
/**
|
| 265 |
+
* \ingroup compilation
|
| 266 |
+
* \brief nvrtcGetCUBINSize sets the value of \p cubinSizeRet with the size of the cubin
|
| 267 |
+
* generated by the previous compilation of \p prog. The value of
|
| 268 |
+
* cubinSizeRet is set to 0 if the value specified to \c -arch is a
|
| 269 |
+
* virtual architecture instead of an actual architecture.
|
| 270 |
+
*
|
| 271 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 272 |
+
* \param [out] cubinSizeRet Size of the generated cubin.
|
| 273 |
+
* \return
|
| 274 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 275 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 276 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 277 |
+
*
|
| 278 |
+
* \see ::nvrtcGetCUBIN
|
| 279 |
+
*/
|
| 280 |
+
nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog, size_t *cubinSizeRet);
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
/**
|
| 284 |
+
* \ingroup compilation
|
| 285 |
+
* \brief nvrtcGetCUBIN stores the cubin generated by the previous compilation
|
| 286 |
+
* of \p prog in the memory pointed by \p cubin. No cubin is available
|
| 287 |
+
* if the value specified to \c -arch is a virtual architecture instead
|
| 288 |
+
* of an actual architecture.
|
| 289 |
+
*
|
| 290 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 291 |
+
* \param [out] cubin Compiled and assembled result.
|
| 292 |
+
* \return
|
| 293 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 294 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 295 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 296 |
+
*
|
| 297 |
+
* \see ::nvrtcGetCUBINSize
|
| 298 |
+
*/
|
| 299 |
+
nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin);
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
#if defined(_WIN32)
|
| 303 |
+
# define __DEPRECATED__(msg) __declspec(deprecated(msg))
|
| 304 |
+
#elif (defined(__GNUC__) && (__GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 5 && !defined(__clang__))))
|
| 305 |
+
# define __DEPRECATED__(msg) __attribute__((deprecated))
|
| 306 |
+
#elif (defined(__GNUC__))
|
| 307 |
+
# define __DEPRECATED__(msg) __attribute__((deprecated(msg)))
|
| 308 |
+
#else
|
| 309 |
+
# define __DEPRECATED__(msg)
|
| 310 |
+
#endif
|
| 311 |
+
|
| 312 |
+
/**
|
| 313 |
+
* \ingroup compilation
|
| 314 |
+
* \brief
|
| 315 |
+
* DEPRECATION NOTICE: This function will be removed in a future release. Please use
|
| 316 |
+
* nvrtcGetLTOIRSize (and nvrtcGetLTOIR) instead.
|
| 317 |
+
*/
|
| 318 |
+
__DEPRECATED__("This function will be removed in a future release. Please use nvrtcGetLTOIRSize instead")
|
| 319 |
+
nvrtcResult nvrtcGetNVVMSize(nvrtcProgram prog, size_t *nvvmSizeRet);
|
| 320 |
+
|
| 321 |
+
/**
|
| 322 |
+
* \ingroup compilation
|
| 323 |
+
* \brief
|
| 324 |
+
* DEPRECATION NOTICE: This function will be removed in a future release. Please use
|
| 325 |
+
* nvrtcGetLTOIR (and nvrtcGetLTOIRSize) instead.
|
| 326 |
+
*/
|
| 327 |
+
__DEPRECATED__("This function will be removed in a future release. Please use nvrtcGetLTOIR instead")
|
| 328 |
+
nvrtcResult nvrtcGetNVVM(nvrtcProgram prog, char *nvvm);
|
| 329 |
+
|
| 330 |
+
#undef __DEPRECATED__
|
| 331 |
+
|
| 332 |
+
/**
|
| 333 |
+
* \ingroup compilation
|
| 334 |
+
* \brief nvrtcGetLTOIRSize sets the value of \p LTOIRSizeRet with the size of the LTO IR
|
| 335 |
+
* generated by the previous compilation of \p prog. The value of
|
| 336 |
+
* LTOIRSizeRet is set to 0 if the program was not compiled with
|
| 337 |
+
* \c -dlto.
|
| 338 |
+
*
|
| 339 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 340 |
+
* \param [out] LTOIRSizeRet Size of the generated LTO IR.
|
| 341 |
+
* \return
|
| 342 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 343 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 344 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 345 |
+
*
|
| 346 |
+
* \see ::nvrtcGetLTOIR
|
| 347 |
+
*/
|
| 348 |
+
nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *LTOIRSizeRet);
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
/**
|
| 352 |
+
* \ingroup compilation
|
| 353 |
+
* \brief nvrtcGetLTOIR stores the LTO IR generated by the previous compilation
|
| 354 |
+
* of \p prog in the memory pointed by \p LTOIR. No LTO IR is available
|
| 355 |
+
* if the program was compiled without \c -dlto.
|
| 356 |
+
*
|
| 357 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 358 |
+
* \param [out] LTOIR Compiled result.
|
| 359 |
+
* \return
|
| 360 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 361 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 362 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 363 |
+
*
|
| 364 |
+
* \see ::nvrtcGetLTOIRSize
|
| 365 |
+
*/
|
| 366 |
+
nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *LTOIR);
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
/**
|
| 370 |
+
* \ingroup compilation
|
| 371 |
+
* \brief nvrtcGetOptiXIRSize sets the value of \p optixirSizeRet with the size of the OptiX IR
|
| 372 |
+
* generated by the previous compilation of \p prog. The value of
|
| 373 |
+
 *        optixirSizeRet is set to 0 if the program was compiled with
|
| 374 |
+
* options incompatible with OptiX IR generation.
|
| 375 |
+
*
|
| 376 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 377 |
+
 * \param [out] optixirSizeRet Size of the generated OptiX IR.
|
| 378 |
+
* \return
|
| 379 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 380 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 381 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 382 |
+
*
|
| 383 |
+
* \see ::nvrtcGetOptiXIR
|
| 384 |
+
*/
|
| 385 |
+
nvrtcResult nvrtcGetOptiXIRSize(nvrtcProgram prog, size_t *optixirSizeRet);
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
/**
|
| 389 |
+
* \ingroup compilation
|
| 390 |
+
* \brief nvrtcGetOptiXIR stores the OptiX IR generated by the previous compilation
|
| 391 |
+
* of \p prog in the memory pointed by \p optixir. No OptiX IR is available
|
| 392 |
+
* if the program was compiled with options incompatible with OptiX IR generation.
|
| 393 |
+
*
|
| 394 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 395 |
+
 * \param [out] optixir Compiled OptiX IR result.
|
| 396 |
+
* \return
|
| 397 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 398 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 399 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 400 |
+
*
|
| 401 |
+
* \see ::nvrtcGetOptiXIRSize
|
| 402 |
+
*/
|
| 403 |
+
nvrtcResult nvrtcGetOptiXIR(nvrtcProgram prog, char *optixir);
|
| 404 |
+
|
| 405 |
+
/**
|
| 406 |
+
* \ingroup compilation
|
| 407 |
+
* \brief nvrtcGetProgramLogSize sets \p logSizeRet with the size of the
|
| 408 |
+
* log generated by the previous compilation of \p prog (including the
|
| 409 |
+
* trailing \c NULL).
|
| 410 |
+
*
|
| 411 |
+
* Note that compilation log may be generated with warnings and informative
|
| 412 |
+
* messages, even when the compilation of \p prog succeeds.
|
| 413 |
+
*
|
| 414 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 415 |
+
* \param [out] logSizeRet Size of the compilation log
|
| 416 |
+
* (including the trailing \c NULL).
|
| 417 |
+
* \return
|
| 418 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 419 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 420 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 421 |
+
*
|
| 422 |
+
* \see ::nvrtcGetProgramLog
|
| 423 |
+
*/
|
| 424 |
+
nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet);
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
/**
|
| 428 |
+
* \ingroup compilation
|
| 429 |
+
* \brief nvrtcGetProgramLog stores the log generated by the previous
|
| 430 |
+
* compilation of \p prog in the memory pointed by \p log.
|
| 431 |
+
*
|
| 432 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 433 |
+
* \param [out] log Compilation log.
|
| 434 |
+
* \return
|
| 435 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 436 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 437 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 438 |
+
*
|
| 439 |
+
* \see ::nvrtcGetProgramLogSize
|
| 440 |
+
*/
|
| 441 |
+
nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
/**
|
| 445 |
+
* \ingroup compilation
|
| 446 |
+
* \brief nvrtcAddNameExpression notes the given name expression
|
| 447 |
+
* denoting the address of a __global__ function
|
| 448 |
+
* or __device__/__constant__ variable.
|
| 449 |
+
*
|
| 450 |
+
* The identical name expression string must be provided on a subsequent
|
| 451 |
+
* call to nvrtcGetLoweredName to extract the lowered name.
|
| 452 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 453 |
+
* \param [in] name_expression constant expression denoting the address of
|
| 454 |
+
* a __global__ function or __device__/__constant__ variable.
|
| 455 |
+
* \return
|
| 456 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 457 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 458 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 459 |
+
* - \link #nvrtcResult NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION \endlink
|
| 460 |
+
*
|
| 461 |
+
* \see ::nvrtcGetLoweredName
|
| 462 |
+
*/
|
| 463 |
+
nvrtcResult nvrtcAddNameExpression(nvrtcProgram prog,
|
| 464 |
+
const char * const name_expression);
|
| 465 |
+
|
| 466 |
+
/**
|
| 467 |
+
* \ingroup compilation
|
| 468 |
+
* \brief nvrtcGetLoweredName extracts the lowered (mangled) name
|
| 469 |
+
* for a __global__ function or __device__/__constant__ variable,
|
| 470 |
+
* and updates *lowered_name to point to it. The memory containing
|
| 471 |
+
* the name is released when the NVRTC program is destroyed by
|
| 472 |
+
* nvrtcDestroyProgram.
|
| 473 |
+
* The identical name expression must have been previously
|
| 474 |
+
* provided to nvrtcAddNameExpression.
|
| 475 |
+
*
|
| 476 |
+
* \param [in] prog CUDA Runtime Compilation program.
|
| 477 |
+
* \param [in] name_expression constant expression denoting the address of
|
| 478 |
+
* a __global__ function or __device__/__constant__ variable.
|
| 479 |
+
* \param [out] lowered_name initialized by the function to point to a
|
| 480 |
+
* C string containing the lowered (mangled)
|
| 481 |
+
* name corresponding to the provided name expression.
|
| 482 |
+
* \return
|
| 483 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 484 |
+
* - \link #nvrtcResult NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION \endlink
|
| 485 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink
|
| 486 |
+
* - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink
|
| 487 |
+
* - \link #nvrtcResult NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID \endlink
|
| 488 |
+
*
|
| 489 |
+
* \see ::nvrtcAddNameExpression
|
| 490 |
+
*/
|
| 491 |
+
nvrtcResult nvrtcGetLoweredName(nvrtcProgram prog,
|
| 492 |
+
const char *const name_expression,
|
| 493 |
+
const char** lowered_name);
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
/**
|
| 497 |
+
* \defgroup options Supported Compile Options
|
| 498 |
+
*
|
| 499 |
+
* NVRTC supports the compile options below.
|
| 500 |
+
 * Option names with two preceding dashes (\c --) are long option names and
|
| 501 |
+
* option names with one preceding dash (\c -) are short option names.
|
| 502 |
+
* Short option names can be used instead of long option names.
|
| 503 |
+
* When a compile option takes an argument, an assignment operator (\c =)
|
| 504 |
+
* is used to separate the compile option argument from the compile option
|
| 505 |
+
* name, e.g., \c "--gpu-architecture=compute_60".
|
| 506 |
+
* Alternatively, the compile option name and the argument can be specified in
|
| 507 |
+
 * separate strings without an assignment operator, e.g.,
|
| 508 |
+
* \c "--gpu-architecture" \c "compute_60".
|
| 509 |
+
* Single-character short option names, such as \c -D, \c -U, and \c -I, do
|
| 510 |
+
* not require an assignment operator, and the compile option name and the
|
| 511 |
+
* argument can be present in the same string with or without spaces between
|
| 512 |
+
* them.
|
| 513 |
+
* For instance, \c "-D=<def>", \c "-D<def>", and \c "-D <def>" are all
|
| 514 |
+
* supported.
|
| 515 |
+
*
|
| 516 |
+
* The valid compiler options are:
|
| 517 |
+
*
|
| 518 |
+
* - Compilation targets
|
| 519 |
+
* - \c --gpu-architecture=\<arch\> (\c -arch)\n
|
| 520 |
+
* Specify the name of the class of GPU architectures for which the
|
| 521 |
+
* input must be compiled.\n
|
| 522 |
+
* - Valid <c>\<arch\></c>s:
|
| 523 |
+
* - \c compute_50
|
| 524 |
+
* - \c compute_52
|
| 525 |
+
* - \c compute_53
|
| 526 |
+
* - \c compute_60
|
| 527 |
+
* - \c compute_61
|
| 528 |
+
* - \c compute_62
|
| 529 |
+
* - \c compute_70
|
| 530 |
+
* - \c compute_72
|
| 531 |
+
* - \c compute_75
|
| 532 |
+
* - \c compute_80
|
| 533 |
+
* - \c compute_87
|
| 534 |
+
* - \c compute_89
|
| 535 |
+
* - \c compute_90
|
| 536 |
+
* - \c compute_90a
|
| 537 |
+
* - \c sm_50
|
| 538 |
+
* - \c sm_52
|
| 539 |
+
* - \c sm_53
|
| 540 |
+
* - \c sm_60
|
| 541 |
+
* - \c sm_61
|
| 542 |
+
* - \c sm_62
|
| 543 |
+
* - \c sm_70
|
| 544 |
+
* - \c sm_72
|
| 545 |
+
* - \c sm_75
|
| 546 |
+
* - \c sm_80
|
| 547 |
+
* - \c sm_87
|
| 548 |
+
* - \c sm_89
|
| 549 |
+
* - \c sm_90
|
| 550 |
+
* - \c sm_90a
|
| 551 |
+
* - Default: \c compute_52
|
| 552 |
+
* - Separate compilation / whole-program compilation
|
| 553 |
+
* - \c --device-c (\c -dc)\n
|
| 554 |
+
* Generate relocatable code that can be linked with other relocatable
|
| 555 |
+
* device code. It is equivalent to --relocatable-device-code=true.
|
| 556 |
+
* - \c --device-w (\c -dw)\n
|
| 557 |
+
* Generate non-relocatable code. It is equivalent to
|
| 558 |
+
* \c --relocatable-device-code=false.
|
| 559 |
+
* - \c --relocatable-device-code={true|false} (\c -rdc)\n
|
| 560 |
+
* Enable (disable) the generation of relocatable device code.
|
| 561 |
+
* - Default: \c false
|
| 562 |
+
* - \c --extensible-whole-program (\c -ewp)\n
|
| 563 |
+
* Do extensible whole program compilation of device code.
|
| 564 |
+
* - Default: \c false
|
| 565 |
+
* - Debugging support
|
| 566 |
+
* - \c --device-debug (\c -G)\n
|
| 567 |
+
* Generate debug information. If --dopt is not specified,
|
| 568 |
+
* then turns off all optimizations.
|
| 569 |
+
* - \c --generate-line-info (\c -lineinfo)\n
|
| 570 |
+
* Generate line-number information.
|
| 571 |
+
* - Code generation
|
| 572 |
+
* - \c --dopt on (\c -dopt)\n
|
| 573 |
+
* - \c --dopt=on \n
|
| 574 |
+
* Enable device code optimization. When specified along with '-G', enables
|
| 575 |
+
* limited debug information generation for optimized device code (currently,
|
| 576 |
+
* only line number information).
|
| 577 |
+
* When '-G' is not specified, '-dopt=on' is implicit.
|
| 578 |
+
* - \c --ptxas-options \<options\> (\c -Xptxas)\n
|
| 579 |
+
* - \c --ptxas-options=\<options\> \n
|
| 580 |
+
* Specify options directly to ptxas, the PTX optimizing assembler.
|
| 581 |
+
* - \c --maxrregcount=\<N\> (\c -maxrregcount)\n
|
| 582 |
+
* Specify the maximum amount of registers that GPU functions can use.
|
| 583 |
+
* Until a function-specific limit, a higher value will generally
|
| 584 |
+
* increase the performance of individual GPU threads that execute this
|
| 585 |
+
* function. However, because thread registers are allocated from a
|
| 586 |
+
* global register pool on each GPU, a higher value of this option will
|
| 587 |
+
* also reduce the maximum thread block size, thereby reducing the amount
|
| 588 |
+
* of thread parallelism. Hence, a good maxrregcount value is the result
|
| 589 |
+
* of a trade-off. If this option is not specified, then no maximum is
|
| 590 |
+
* assumed. Value less than the minimum registers required by ABI will
|
| 591 |
+
* be bumped up by the compiler to ABI minimum limit.
|
| 592 |
+
* - \c --ftz={true|false} (\c -ftz)\n
|
| 593 |
+
* When performing single-precision floating-point operations, flush
|
| 594 |
+
* denormal values to zero or preserve denormal values.
|
| 595 |
+
* \c --use_fast_math implies \c --ftz=true.
|
| 596 |
+
* - Default: \c false
|
| 597 |
+
* - \c --prec-sqrt={true|false} (\c -prec-sqrt)\n
|
| 598 |
+
* For single-precision floating-point square root, use IEEE
|
| 599 |
+
* round-to-nearest mode or use a faster approximation.
|
| 600 |
+
* \c --use_fast_math implies \c --prec-sqrt=false.
|
| 601 |
+
* - Default: \c true
|
| 602 |
+
* - \c --prec-div={true|false} (\c -prec-div)\n
|
| 603 |
+
* For single-precision floating-point division and reciprocals, use IEEE
|
| 604 |
+
* round-to-nearest mode or use a faster approximation.
|
| 605 |
+
* \c --use_fast_math implies \c --prec-div=false.
|
| 606 |
+
* - Default: \c true
|
| 607 |
+
* - \c --fmad={true|false} (\c -fmad)\n
|
| 608 |
+
* Enables (disables) the contraction of floating-point multiplies and
|
| 609 |
+
* adds/subtracts into floating-point multiply-add operations (FMAD,
|
| 610 |
+
* FFMA, or DFMA). \c --use_fast_math implies \c --fmad=true.
|
| 611 |
+
* - Default: \c true
|
| 612 |
+
* - \c --use_fast_math (\c -use_fast_math)\n
|
| 613 |
+
* Make use of fast math operations.
|
| 614 |
+
* \c --use_fast_math implies \c --ftz=true \c --prec-div=false
|
| 615 |
+
* \c --prec-sqrt=false \c --fmad=true.
|
| 616 |
+
* - \c --extra-device-vectorization (\c -extra-device-vectorization)\n
|
| 617 |
+
* Enables more aggressive device code vectorization in the NVVM optimizer.
|
| 618 |
+
* - \c --modify-stack-limit={true|false} (\c -modify-stack-limit)\n
|
| 619 |
+
* On Linux, during compilation, use \c setrlimit() to increase stack size
|
| 620 |
+
* to maximum allowed. The limit is reset to the previous value at the
|
| 621 |
+
* end of compilation.
|
| 622 |
+
* Note: \c setrlimit() changes the value for the entire process.
|
| 623 |
+
* - Default: \c true
|
| 624 |
+
* - \c --dlink-time-opt (\c -dlto)\n
|
| 625 |
+
* Generate intermediate code for later link-time optimization.
|
| 626 |
+
* It implies \c -rdc=true.
|
| 627 |
+
* Note: when this option is used the nvrtcGetLTOIR API should be used,
|
| 628 |
+
* as PTX or Cubin will not be generated.
|
| 629 |
+
* - \c --gen-opt-lto (\c -gen-opt-lto)\n
|
| 630 |
+
* Run the optimizer passes before generating the LTO IR.
|
| 631 |
+
* - \c --optix-ir (\c -optix-ir)\n
|
| 632 |
+
* Generate OptiX IR. The Optix IR is only intended for consumption by OptiX
|
| 633 |
+
* through appropriate APIs. This feature is not supported with
|
| 634 |
+
* link-time-optimization (\c -dlto)\n.
|
| 635 |
+
 *     Note: when this option is used the nvrtcGetOptiXIR API should be used,
|
| 636 |
+
* as PTX or Cubin will not be generated.
|
| 637 |
+
* - \c --jump-table-density=[0-101] (\c -jtd)\n
|
| 638 |
+
* Specify the case density percentage in switch statements, and use it as
|
| 639 |
+
* a minimal threshold to determine whether jump table(brx.idx instruction)
|
| 640 |
+
* will be used to implement a switch statement. Default value is 101. The
|
| 641 |
+
* percentage ranges from 0 to 101 inclusively.
|
| 642 |
+
* - Preprocessing
|
| 643 |
+
* - \c --define-macro=\<def\> (\c -D)\n
|
| 644 |
+
* \c \<def\> can be either \c \<name\> or \c \<name=definitions\>.
|
| 645 |
+
* - \c \<name\> \n
|
| 646 |
+
* Predefine \c \<name\> as a macro with definition \c 1.
|
| 647 |
+
* - \c \<name\>=\<definition\> \n
|
| 648 |
+
* The contents of \c \<definition\> are tokenized and preprocessed
|
| 649 |
+
* as if they appeared during translation phase three in a \c \#define
|
| 650 |
+
* directive. In particular, the definition will be truncated by
|
| 651 |
+
* embedded new line characters.
|
| 652 |
+
* - \c --undefine-macro=\<def\> (\c -U)\n
|
| 653 |
+
* Cancel any previous definition of \c \<def\>.
|
| 654 |
+
* - \c --include-path=\<dir\> (\c -I)\n
|
| 655 |
+
* Add the directory \c \<dir\> to the list of directories to be
|
| 656 |
+
* searched for headers. These paths are searched after the list of
|
| 657 |
+
* headers given to ::nvrtcCreateProgram.
|
| 658 |
+
* - \c --pre-include=\<header\> (\c -include)\n
|
| 659 |
+
* Preinclude \c \<header\> during preprocessing.
|
| 660 |
+
* - \c --no-source-include (\c -no-source-include)
|
| 661 |
+
* The preprocessor by default adds the directory of each input sources
|
| 662 |
+
* to the include path. This option disables this feature and only
|
| 663 |
+
* considers the path specified explicitly.
|
| 664 |
+
* - Language Dialect
|
| 665 |
+
* - \c --std={c++03|c++11|c++14|c++17|c++20}
|
| 666 |
+
* (\c -std={c++11|c++14|c++17|c++20})\n
|
| 667 |
+
* Set language dialect to C++03, C++11, C++14, C++17 or C++20
|
| 668 |
+
* - Default: \c c++17
|
| 669 |
+
* - \c --builtin-move-forward={true|false} (\c -builtin-move-forward)\n
|
| 670 |
+
* Provide builtin definitions of \c std::move and \c std::forward,
|
| 671 |
+
* when C++11 or later language dialect is selected.
|
| 672 |
+
* - Default: \c true
|
| 673 |
+
* - \c --builtin-initializer-list={true|false}
|
| 674 |
+
* (\c -builtin-initializer-list)\n
|
| 675 |
+
* Provide builtin definitions of \c std::initializer_list class and
|
| 676 |
+
* member functions when C++11 or later language dialect is selected.
|
| 677 |
+
* - Default: \c true
|
| 678 |
+
* - Misc.
|
| 679 |
+
* - \c --disable-warnings (\c -w)\n
|
| 680 |
+
* Inhibit all warning messages.
|
| 681 |
+
* - \c --restrict (\c -restrict)\n
|
| 682 |
+
* Programmer assertion that all kernel pointer parameters are restrict
|
| 683 |
+
* pointers.
|
| 684 |
+
* - \c --device-as-default-execution-space
|
| 685 |
+
* (\c -default-device)\n
|
| 686 |
+
* Treat entities with no execution space annotation as \c __device__
|
| 687 |
+
* entities.
|
| 688 |
+
* - \c --device-int128 (\c -device-int128)\n
|
| 689 |
+
* Allow the \c __int128 type in device code. Also causes the macro \c __CUDACC_RTC_INT128__
|
| 690 |
+
* to be defined.
|
| 691 |
+
* - \c --optimization-info=\<kind\> (\c -opt-info)\n
|
| 692 |
+
* Provide optimization reports for the specified kind of optimization.
|
| 693 |
+
* The following kind tags are supported:
|
| 694 |
+
* - \c inline : emit a remark when a function is inlined.
|
| 695 |
+
* - \c --display-error-number (\c -err-no)\n
|
| 696 |
+
* Display diagnostic number for warning messages. (Default)
|
| 697 |
+
* - \c --no-display-error-number (\c -no-err-no)\n
|
| 698 |
+
* Disables the display of a diagnostic number for warning messages.
|
| 699 |
+
* - \c --diag-error=<error-number>,... (\c -diag-error)\n
|
| 700 |
+
* Emit error for specified diagnostic message number(s). Message numbers can be separated by comma.
|
| 701 |
+
* - \c --diag-suppress=<error-number>,... (\c -diag-suppress)\n
|
| 702 |
+
* Suppress specified diagnostic message number(s). Message numbers can be separated by comma.
|
| 703 |
+
* - \c --diag-warn=<error-number>,... (\c -diag-warn)\n
|
| 704 |
+
* Emit warning for specified diagnostic message number(s). Message numbers can be separated by comma.
|
| 705 |
+
* - \c --brief-diagnostics={true|false} (\c -brief-diag)\n
|
| 706 |
+
* This option disables or enables showing source line and column info
|
| 707 |
+
* in a diagnostic.
|
| 708 |
+
* The --brief-diagnostics=true will not show the source line and column info.
|
| 709 |
+
* - Default: \c false
|
| 710 |
+
* - \c --time=<file-name> (\c -time)\n
|
| 711 |
+
* Generate a comma separated value table with the time taken by each compilation
|
| 712 |
+
* phase, and append it at the end of the file given as the option argument.
|
| 713 |
+
* If the file does not exist, the column headings are generated in the first row
|
| 714 |
+
* of the table. If the file name is '-', the timing data is written to the compilation log.
|
| 715 |
+
* - \c --split-compile=<number of threads> (\c -split-compile=<number of threads>)\n
|
| 716 |
+
* Perform compiler optimizations in parallel.
|
| 717 |
+
* Split compilation attempts to reduce compile time by enabling the compiler to run certain
|
| 718 |
+
* optimization passes concurrently. This option accepts a numerical value that specifies the
|
| 719 |
+
* maximum number of threads the compiler can use. One can also allow the compiler to use the maximum
|
| 720 |
+
* threads available on the system by setting --split-compile=0.
|
| 721 |
+
* Setting --split-compile=1 will cause this option to be ignored.
|
| 722 |
+
* - \c --fdevice-syntax-only (\c -fdevice-syntax-only)\n
|
| 723 |
+
* Ends device compilation after front-end syntax checking. This option does not generate valid
|
| 724 |
+
* device code.
|
| 725 |
+
* - \c --minimal (\c -minimal)\n
|
| 726 |
+
* Omit certain language features to reduce compile time for small programs.
|
| 727 |
+
* In particular, the following are omitted:
|
| 728 |
+
* - Texture and surface functions and associated types, e.g., \c cudaTextureObject_t.
|
| 729 |
+
* - CUDA Runtime Functions that are provided by the cudadevrt device code library,
|
| 730 |
+
* typically named with prefix "cuda", e.g., \c cudaMalloc.
|
| 731 |
+
* - Kernel launch from device code.
|
| 732 |
+
* - Types and macros associated with CUDA Runtime and Driver APIs,
|
| 733 |
+
* provided by cuda/tools/cudart/driver_types.h, typically named with prefix "cuda", e.g., \c cudaError_t.
|
| 734 |
+
*
|
| 735 |
+
*/
|
| 736 |
+
|
| 737 |
+
#ifdef __cplusplus
|
| 738 |
+
}
|
| 739 |
+
#endif /* __cplusplus */
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
/* The utility function 'nvrtcGetTypeName' is not available by default. Define
|
| 743 |
+
the macro 'NVRTC_GET_TYPE_NAME' to a non-zero value to make it available.
|
| 744 |
+
*/
|
| 745 |
+
|
| 746 |
+
#if NVRTC_GET_TYPE_NAME || __DOXYGEN_ONLY__
|
| 747 |
+
|
| 748 |
+
#if NVRTC_USE_CXXABI || __clang__ || __GNUC__ || __DOXYGEN_ONLY__
|
| 749 |
+
#include <cxxabi.h>
|
| 750 |
+
#include <cstdlib>
|
| 751 |
+
|
| 752 |
+
#elif defined(_WIN32)
|
| 753 |
+
#include <Windows.h>
|
| 754 |
+
#include <DbgHelp.h>
|
| 755 |
+
#endif /* NVRTC_USE_CXXABI || __clang__ || __GNUC__ */
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
#include <string>
|
| 759 |
+
#include <typeinfo>
|
| 760 |
+
|
| 761 |
+
template <typename T> struct __nvrtcGetTypeName_helper_t { };
|
| 762 |
+
|
| 763 |
+
/*************************************************************************//**
|
| 764 |
+
*
|
| 765 |
+
* \defgroup hosthelper Host Helper
|
| 766 |
+
*
|
| 767 |
+
* NVRTC defines the following functions for easier interaction with host code.
|
| 768 |
+
*
|
| 769 |
+
****************************************************************************/
|
| 770 |
+
|
| 771 |
+
/**
|
| 772 |
+
* \ingroup hosthelper
|
| 773 |
+
* \brief nvrtcGetTypeName stores the source level name of a type in the given
|
| 774 |
+
* std::string location.
|
| 775 |
+
*
|
| 776 |
+
* This function is only provided when the macro NVRTC_GET_TYPE_NAME is
|
| 777 |
+
* defined with a non-zero value. It uses abi::__cxa_demangle or UnDecorateSymbolName
|
| 778 |
+
* function calls to extract the type name, when using gcc/clang or cl.exe compilers,
|
| 779 |
+
* respectively. If the name extraction fails, it will return NVRTC_INTERNAL_ERROR,
|
| 780 |
+
* otherwise *result is initialized with the extracted name.
|
| 781 |
+
*
|
| 782 |
+
* Windows-specific notes:
|
| 783 |
+
* - nvrtcGetTypeName() is not multi-thread safe because it calls UnDecorateSymbolName(),
|
| 784 |
+
* which is not multi-thread safe.
|
| 785 |
+
* - The returned string may contain Microsoft-specific keywords such as __ptr64 and __cdecl.
|
| 786 |
+
*
|
| 787 |
+
* \param [in] tinfo: reference to object of type std::type_info for a given type.
|
| 788 |
+
* \param [in] result: pointer to std::string in which to store the type name.
|
| 789 |
+
* \return
|
| 790 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 791 |
+
* - \link #nvrtcResult NVRTC_ERROR_INTERNAL_ERROR \endlink
|
| 792 |
+
*
|
| 793 |
+
*/
|
| 794 |
+
inline nvrtcResult nvrtcGetTypeName(const std::type_info &tinfo, std::string *result)
{
  /* Gate must match the header-inclusion gate above, which tests
     NVRTC_USE_CXXABI (not USE_CXXABI): otherwise defining NVRTC_USE_CXXABI on a
     compiler that is neither gcc/clang nor MSVC pulls in <cxxabi.h> but
     compiles this implementation out, silently returning
     NVRTC_ERROR_INTERNAL_ERROR. */
#if NVRTC_USE_CXXABI || __clang__ || __GNUC__
  const char *name = tinfo.name();
  int status;
  /* abi::__cxa_demangle returns a malloc'd buffer on success (status == 0);
     the caller owns it and must free() it. */
  char *undecorated_name = abi::__cxa_demangle(name, 0, 0, &status);
  if (status == 0) {
    *result = undecorated_name;
    free(undecorated_name);
    return NVRTC_SUCCESS;
  }
#elif defined(_WIN32)
  /* MSVC: raw_name() yields a decorated name prefixed with '.'. */
  const char *name = tinfo.raw_name();
  if (!name || *name != '.') {
    return NVRTC_ERROR_INTERNAL_ERROR;
  }
  char undecorated_name[4096];
  /* name+1 skips over the '.' prefix.
     NOTE: UnDecorateSymbolName() is not multi-thread safe (see the
     function's Doxygen comment above). */
  if(UnDecorateSymbolName(name+1, undecorated_name,
                          sizeof(undecorated_name) / sizeof(*undecorated_name),
                          /* note: doesn't seem to work correctly without
                             UNDNAME_NO_ARGUMENTS. */
                          UNDNAME_NO_ARGUMENTS | UNDNAME_NAME_ONLY ) ) {
    *result = undecorated_name;
    return NVRTC_SUCCESS;
  }
#endif /* NVRTC_USE_CXXABI || __clang__ || __GNUC__ */

  /* Demangling unavailable or failed. */
  return NVRTC_ERROR_INTERNAL_ERROR;
}
|
| 823 |
+
|
| 824 |
+
/**
|
| 825 |
+
* \ingroup hosthelper
|
| 826 |
+
* \brief nvrtcGetTypeName stores the source level name of the template type argument
|
| 827 |
+
* T in the given std::string location.
|
| 828 |
+
*
|
| 829 |
+
* This function is only provided when the macro NVRTC_GET_TYPE_NAME is
|
| 830 |
+
* defined with a non-zero value. It uses abi::__cxa_demangle or UnDecorateSymbolName
|
| 831 |
+
* function calls to extract the type name, when using gcc/clang or cl.exe compilers,
|
| 832 |
+
* respectively. If the name extraction fails, it will return NVRTC_INTERNAL_ERROR,
|
| 833 |
+
* otherwise *result is initialized with the extracted name.
|
| 834 |
+
*
|
| 835 |
+
* Windows-specific notes:
|
| 836 |
+
* - nvrtcGetTypeName() is not multi-thread safe because it calls UnDecorateSymbolName(),
|
| 837 |
+
* which is not multi-thread safe.
|
| 838 |
+
* - The returned string may contain Microsoft-specific keywords such as __ptr64 and __cdecl.
|
| 839 |
+
*
|
| 840 |
+
* \param [in] result: pointer to std::string in which to store the type name.
|
| 841 |
+
* \return
|
| 842 |
+
* - \link #nvrtcResult NVRTC_SUCCESS \endlink
|
| 843 |
+
* - \link #nvrtcResult NVRTC_ERROR_INTERNAL_ERROR \endlink
|
| 844 |
+
*
|
| 845 |
+
*/
|
| 846 |
+
|
| 847 |
+
template <typename T>
|
| 848 |
+
nvrtcResult nvrtcGetTypeName(std::string *result)
|
| 849 |
+
{
|
| 850 |
+
nvrtcResult res = nvrtcGetTypeName(typeid(__nvrtcGetTypeName_helper_t<T>),
|
| 851 |
+
result);
|
| 852 |
+
if (res != NVRTC_SUCCESS)
|
| 853 |
+
return res;
|
| 854 |
+
|
| 855 |
+
std::string repr = *result;
|
| 856 |
+
std::size_t idx = repr.find("__nvrtcGetTypeName_helper_t");
|
| 857 |
+
idx = (idx != std::string::npos) ? repr.find("<", idx) : idx;
|
| 858 |
+
std::size_t last_idx = repr.find_last_of('>');
|
| 859 |
+
if (idx == std::string::npos || last_idx == std::string::npos) {
|
| 860 |
+
return NVRTC_ERROR_INTERNAL_ERROR;
|
| 861 |
+
}
|
| 862 |
+
++idx;
|
| 863 |
+
*result = repr.substr(idx, last_idx - idx);
|
| 864 |
+
return NVRTC_SUCCESS;
|
| 865 |
+
}
|
| 866 |
+
|
| 867 |
+
#endif /* NVRTC_GET_TYPE_NAME */
|
| 868 |
+
|
| 869 |
+
#endif /* __NVRTC_H__ */
|
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/async.h
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef _CG_ASYNC_H
|
| 50 |
+
#define _CG_ASYNC_H
|
| 51 |
+
|
| 52 |
+
#include "helpers.h"
|
| 53 |
+
#include "info.h"
|
| 54 |
+
|
| 55 |
+
#include <cuda_pipeline.h>
|
| 56 |
+
|
| 57 |
+
_CG_BEGIN_NAMESPACE
|
| 58 |
+
|
| 59 |
+
namespace details {
|
| 60 |
+
// Groups supported by memcpy_async
|
| 61 |
+
template <class TyGroup>
|
| 62 |
+
struct _async_copy_group_supported : public _CG_STL_NAMESPACE::false_type {};
|
| 63 |
+
|
| 64 |
+
template <unsigned int Sz, typename TyPar>
|
| 65 |
+
struct _async_copy_group_supported<cooperative_groups::thread_block_tile<Sz, TyPar>>
|
| 66 |
+
: public _CG_STL_NAMESPACE::true_type {};
|
| 67 |
+
template <>
|
| 68 |
+
struct _async_copy_group_supported<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
|
| 69 |
+
template <>
|
| 70 |
+
struct _async_copy_group_supported<cooperative_groups::thread_block> : public _CG_STL_NAMESPACE::true_type {};
|
| 71 |
+
|
| 72 |
+
template <class TyGroup>
|
| 73 |
+
using async_copy_group_supported = _async_copy_group_supported<details::remove_qual<TyGroup>>;
|
| 74 |
+
|
| 75 |
+
// Groups that require optimization
|
| 76 |
+
template <class TyGroup>
|
| 77 |
+
struct _async_copy_optimize_tile : public _CG_STL_NAMESPACE::false_type {};
|
| 78 |
+
|
| 79 |
+
template <typename TyPar>
|
| 80 |
+
struct _async_copy_optimize_tile<cooperative_groups::thread_block_tile<1, TyPar>>
|
| 81 |
+
: public _CG_STL_NAMESPACE::false_type {};
|
| 82 |
+
|
| 83 |
+
template <unsigned int Sz, typename TyPar>
|
| 84 |
+
struct _async_copy_optimize_tile<cooperative_groups::thread_block_tile<Sz, TyPar>>
|
| 85 |
+
: public _CG_STL_NAMESPACE::true_type {};
|
| 86 |
+
|
| 87 |
+
template <class TyGroup>
|
| 88 |
+
using async_copy_optimize_tile = _async_copy_optimize_tile<details::remove_qual<TyGroup>>;
|
| 89 |
+
|
| 90 |
+
// SFINAE helpers for tile optimizations
|
| 91 |
+
template <class TyGroup>
|
| 92 |
+
using enable_tile_optimization =
|
| 93 |
+
typename _CG_STL_NAMESPACE::enable_if<async_copy_optimize_tile<TyGroup>::value, void *>::type;
|
| 94 |
+
|
| 95 |
+
template <class TyGroup>
|
| 96 |
+
using disable_tile_optimization =
|
| 97 |
+
typename _CG_STL_NAMESPACE::enable_if<!async_copy_optimize_tile<TyGroup>::value, void *>::type;
|
| 98 |
+
|
| 99 |
+
// Segment for punning to aligned types
|
| 100 |
+
template <unsigned int N>
|
| 101 |
+
struct _Segment {
|
| 102 |
+
int _seg[N];
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
// Trivial layout guaranteed-aligned copy-async compatible segments
|
| 106 |
+
template <unsigned int N>
|
| 107 |
+
struct Segment;
|
| 108 |
+
template <>
|
| 109 |
+
struct __align__(4) Segment<1> : public _Segment<1>{};
|
| 110 |
+
template <>
|
| 111 |
+
struct __align__(8) Segment<2> : public _Segment<2>{};
|
| 112 |
+
template <>
|
| 113 |
+
struct __align__(16) Segment<4> : public _Segment<4>{};
|
| 114 |
+
|
| 115 |
+
// Interleaved element by element copies from source to dest
|
| 116 |
+
template <typename TyGroup, typename TyElem>
|
| 117 |
+
_CG_STATIC_QUALIFIER void inline_copy(TyGroup &group, TyElem *__restrict__ dst, const TyElem *__restrict__ src,
|
| 118 |
+
size_t count) {
|
| 119 |
+
const unsigned int rank = group.thread_rank();
|
| 120 |
+
const unsigned int stride = group.size();
|
| 121 |
+
|
| 122 |
+
for (size_t idx = rank; idx < count; idx += stride) {
|
| 123 |
+
dst[idx] = src[idx];
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template <typename TyGroup, typename TyElem, enable_tile_optimization<TyGroup> = nullptr>
|
| 128 |
+
_CG_STATIC_QUALIFIER void accelerated_async_copy(TyGroup &group, TyElem *__restrict__ dst,
|
| 129 |
+
const TyElem *__restrict__ src, size_t count) {
|
| 130 |
+
static_assert(async_copy_group_supported<TyGroup>::value,
|
| 131 |
+
"Async copy is only supported for groups that represent private shared memory");
|
| 132 |
+
|
| 133 |
+
if (count == 0) {
|
| 134 |
+
return;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
const bool dstIsNotShared = !__isShared(dst);
|
| 138 |
+
const bool srcIsNotGlobal = !__isGlobal(src);
|
| 139 |
+
|
| 140 |
+
if (dstIsNotShared || srcIsNotGlobal) {
|
| 141 |
+
inline_copy(group, dst, src, count);
|
| 142 |
+
return;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
const unsigned int stride = group.size();
|
| 146 |
+
const unsigned int rank = group.thread_rank();
|
| 147 |
+
// Efficient copies require warps to operate on the same amount of work at each step.
|
| 148 |
+
// remainders are handled in a separate stage to prevent branching
|
| 149 |
+
const unsigned int subWarpMask = (stride - 1);
|
| 150 |
+
const unsigned int subwarpCopies = (subWarpMask & (unsigned int)count);
|
| 151 |
+
const unsigned int maxSubwarpRank = min(rank, subwarpCopies - 1);
|
| 152 |
+
|
| 153 |
+
const size_t warpCopies = (count & (~subWarpMask));
|
| 154 |
+
|
| 155 |
+
for (size_t idx = 0; idx < warpCopies; idx += stride) {
|
| 156 |
+
size_t _srcIdx = rank + idx;
|
| 157 |
+
size_t _dstIdx = rank + idx;
|
| 158 |
+
__pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem));
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
if (subwarpCopies) {
|
| 162 |
+
size_t _srcIdx = warpCopies + maxSubwarpRank;
|
| 163 |
+
size_t _dstIdx = warpCopies + maxSubwarpRank;
|
| 164 |
+
__pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem));
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
template <typename TyGroup, typename TyElem, disable_tile_optimization<TyGroup> = nullptr>
|
| 169 |
+
_CG_STATIC_QUALIFIER void accelerated_async_copy(TyGroup &group, TyElem *__restrict__ dst,
|
| 170 |
+
const TyElem *__restrict__ src, size_t count) {
|
| 171 |
+
static_assert(async_copy_group_supported<TyGroup>::value,
|
| 172 |
+
"Async copy is only supported for groups that represent private shared memory");
|
| 173 |
+
|
| 174 |
+
const bool dstIsNotShared = !__isShared(dst);
|
| 175 |
+
const bool srcIsNotGlobal = !__isGlobal(src);
|
| 176 |
+
|
| 177 |
+
if (dstIsNotShared || srcIsNotGlobal) {
|
| 178 |
+
inline_copy(group, dst, src, count);
|
| 179 |
+
return;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
unsigned int stride = group.size();
|
| 183 |
+
unsigned int rank = group.thread_rank();
|
| 184 |
+
|
| 185 |
+
for (size_t idx = rank; idx < count; idx += stride) {
|
| 186 |
+
size_t _srcIdx = idx;
|
| 187 |
+
size_t _dstIdx = idx;
|
| 188 |
+
__pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem));
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
// Determine best possible alignment given an input and initial conditions
|
| 193 |
+
// Attempts to generate as little code as possible, most likely should only be used with 1 and 2 byte alignments
|
| 194 |
+
template <unsigned int MinAlignment, unsigned int MaxAlignment>
|
| 195 |
+
_CG_STATIC_QUALIFIER uint32_t find_best_alignment(void *__restrict__ dst, const void *__restrict__ src) {
|
| 196 |
+
// Narrowing conversion intentional
|
| 197 |
+
uint32_t base1 = (uint32_t) reinterpret_cast<uintptr_t>(src);
|
| 198 |
+
uint32_t base2 = (uint32_t) reinterpret_cast<uintptr_t>(dst);
|
| 199 |
+
|
| 200 |
+
uint32_t diff = ((base1) ^ (base2)) & (MaxAlignment - 1);
|
| 201 |
+
|
| 202 |
+
// range [MaxAlignment, alignof(elem)], step: x >> 1
|
| 203 |
+
// over range of possible alignments, choose best available out of range
|
| 204 |
+
uint32_t out = MaxAlignment;
|
| 205 |
+
#pragma unroll
|
| 206 |
+
for (uint32_t alignment = (MaxAlignment >> 1); alignment >= MinAlignment; alignment >>= 1) {
|
| 207 |
+
if (alignment & diff)
|
| 208 |
+
out = alignment;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
return out;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// Determine best possible alignment given an input and initial conditions
|
| 215 |
+
// Attempts to generate as little code as possible, most likely should only be used with 1 and 2 byte alignments
|
| 216 |
+
template <typename TyType, typename TyGroup>
|
| 217 |
+
_CG_STATIC_QUALIFIER void copy_like(const TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
|
| 218 |
+
size_t count) {
|
| 219 |
+
const char *src = reinterpret_cast<const char *>(_src);
|
| 220 |
+
char *dst = reinterpret_cast<char *>(_dst);
|
| 221 |
+
|
| 222 |
+
constexpr uint32_t targetAlignment = (uint32_t)alignof(TyType);
|
| 223 |
+
|
| 224 |
+
uint32_t base = (uint32_t) reinterpret_cast<uintptr_t>(src);
|
| 225 |
+
uint32_t alignOffset = ((~base) + 1) & (targetAlignment - 1);
|
| 226 |
+
|
| 227 |
+
inline_copy(group, dst, src, alignOffset);
|
| 228 |
+
count -= alignOffset;
|
| 229 |
+
src += alignOffset;
|
| 230 |
+
dst += alignOffset;
|
| 231 |
+
|
| 232 |
+
// Copy using the best available alignment, async_copy expects n-datums, not bytes
|
| 233 |
+
size_t asyncCount = count / sizeof(TyType);
|
| 234 |
+
accelerated_async_copy(group, reinterpret_cast<TyType *>(dst), reinterpret_cast<const TyType *>(src), asyncCount);
|
| 235 |
+
asyncCount *= sizeof(TyType);
|
| 236 |
+
|
| 237 |
+
count -= asyncCount;
|
| 238 |
+
src += asyncCount;
|
| 239 |
+
dst += asyncCount;
|
| 240 |
+
inline_copy(group, dst, src, count);
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
// We must determine alignment and manually align src/dst ourselves
|
| 244 |
+
template <size_t AlignHint>
|
| 245 |
+
struct _memcpy_async_align_dispatch {
|
| 246 |
+
template <typename TyGroup>
|
| 247 |
+
_CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ dst, const void *__restrict__ src, size_t count) {
|
| 248 |
+
uint32_t alignment = find_best_alignment<AlignHint, 16>(dst, src);
|
| 249 |
+
|
| 250 |
+
// Avoid copying the extra bytes if desired copy count is smaller
|
| 251 |
+
alignment = count < alignment ? AlignHint : alignment;
|
| 252 |
+
|
| 253 |
+
switch (alignment) {
|
| 254 |
+
default:
|
| 255 |
+
case 1:
|
| 256 |
+
inline_copy(group, reinterpret_cast<char *>(dst), reinterpret_cast<const char *>(src), count);
|
| 257 |
+
break;
|
| 258 |
+
case 2:
|
| 259 |
+
inline_copy(group, reinterpret_cast<short *>(dst), reinterpret_cast<const short *>(src), count >> 1);
|
| 260 |
+
break;
|
| 261 |
+
case 4:
|
| 262 |
+
copy_like<Segment<1>>(group, dst, src, count);
|
| 263 |
+
break;
|
| 264 |
+
case 8:
|
| 265 |
+
copy_like<Segment<2>>(group, dst, src, count);
|
| 266 |
+
break;
|
| 267 |
+
case 16:
|
| 268 |
+
copy_like<Segment<4>>(group, dst, src, count);
|
| 269 |
+
break;
|
| 270 |
+
}
|
| 271 |
+
}
|
| 272 |
+
};
|
| 273 |
+
|
| 274 |
+
// Specialization for 4 byte alignments
|
| 275 |
+
template <>
|
| 276 |
+
struct _memcpy_async_align_dispatch<4> {
|
| 277 |
+
template <typename TyGroup>
|
| 278 |
+
_CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
|
| 279 |
+
size_t count) {
|
| 280 |
+
const Segment<1> *src = reinterpret_cast<const Segment<1> *>(_src);
|
| 281 |
+
Segment<1> *dst = reinterpret_cast<Segment<1> *>(_dst);
|
| 282 |
+
|
| 283 |
+
// Dispatch straight to aligned LDGSTS calls
|
| 284 |
+
accelerated_async_copy(group, dst, src, count / sizeof(*dst));
|
| 285 |
+
}
|
| 286 |
+
};
|
| 287 |
+
|
| 288 |
+
// Specialization for 8 byte alignments
|
| 289 |
+
template <>
|
| 290 |
+
struct _memcpy_async_align_dispatch<8> {
|
| 291 |
+
template <typename TyGroup>
|
| 292 |
+
_CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
|
| 293 |
+
size_t count) {
|
| 294 |
+
const Segment<2> *src = reinterpret_cast<const Segment<2> *>(_src);
|
| 295 |
+
Segment<2> *dst = reinterpret_cast<Segment<2> *>(_dst);
|
| 296 |
+
|
| 297 |
+
// Dispatch straight to aligned LDGSTS calls
|
| 298 |
+
accelerated_async_copy(group, dst, src, count / sizeof(*dst));
|
| 299 |
+
}
|
| 300 |
+
};
|
| 301 |
+
|
| 302 |
+
// Alignments over 16 are truncated to 16 and bypass alignment
|
| 303 |
+
// This is the highest performing memcpy available
|
| 304 |
+
template <>
|
| 305 |
+
struct _memcpy_async_align_dispatch<16> {
|
| 306 |
+
template <typename TyGroup>
|
| 307 |
+
_CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src,
|
| 308 |
+
size_t count) {
|
| 309 |
+
const Segment<4> *src = reinterpret_cast<const Segment<4> *>(_src);
|
| 310 |
+
Segment<4> *dst = reinterpret_cast<Segment<4> *>(_dst);
|
| 311 |
+
|
| 312 |
+
// Dispatch straight to aligned LDGSTS calls
|
| 313 |
+
accelerated_async_copy(group, dst, src, count / sizeof(*dst));
|
| 314 |
+
}
|
| 315 |
+
};
|
| 316 |
+
|
| 317 |
+
// byte-wide API
|
| 318 |
+
template <size_t Alignment, class TyGroup>
|
| 319 |
+
_CG_STATIC_QUALIFIER void _memcpy_async_dispatch_to_aligned_copy(const TyGroup &group, void *__restrict__ _dst,
|
| 320 |
+
const void *__restrict__ _src, size_t count) {
|
| 321 |
+
static_assert(!(Alignment & (Alignment - 1)), "Known static alignment dispatch must be a power of 2");
|
| 322 |
+
details::_memcpy_async_align_dispatch<Alignment>::copy(group, _dst, _src, count);
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
// Internal dispatch APIs
|
| 326 |
+
// These deduce the alignments and sizes necessary to invoke the underlying copy engine
|
| 327 |
+
template <typename Ty>
|
| 328 |
+
using is_void = _CG_STL_NAMESPACE::is_same<Ty, void>;
|
| 329 |
+
|
| 330 |
+
template <typename Ty>
|
| 331 |
+
using enable_if_not_void = typename _CG_STL_NAMESPACE::enable_if<!is_void<Ty>::value, void *>::type;
|
| 332 |
+
|
| 333 |
+
template <typename Ty>
|
| 334 |
+
using enable_if_void = typename _CG_STL_NAMESPACE::enable_if<is_void<Ty>::value, void *>::type;
|
| 335 |
+
|
| 336 |
+
template <typename Ty>
|
| 337 |
+
using enable_if_integral =
|
| 338 |
+
typename _CG_STL_NAMESPACE::enable_if<_CG_STL_NAMESPACE::is_integral<Ty>::value, void *>::type;
|
| 339 |
+
|
| 340 |
+
// byte-wide API using aligned_sized_t
|
| 341 |
+
template <class TyGroup, template <size_t> typename Alignment, size_t Hint>
|
| 342 |
+
_CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, void *__restrict__ _dst,
|
| 343 |
+
const void *__restrict__ _src, const Alignment<Hint> &count) {
|
| 344 |
+
constexpr size_t _align = (Hint > 16) ? 16 : Hint;
|
| 345 |
+
|
| 346 |
+
details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, _dst, _src, (size_t)count);
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
// byte-wide API using type for aligment
|
| 350 |
+
template <class TyGroup, typename TyElem, typename TySize, size_t Hint = alignof(TyElem),
|
| 351 |
+
enable_if_not_void<TyElem> = nullptr, enable_if_integral<TySize> = nullptr>
|
| 352 |
+
_CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, TyElem *__restrict__ _dst,
|
| 353 |
+
const TyElem *__restrict__ _src, const TySize& count) {
|
| 354 |
+
constexpr size_t _align = (Hint > 16) ? 16 : Hint;
|
| 355 |
+
|
| 356 |
+
details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, _dst, _src, count);
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
// byte-wide API with full alignment deduction required
|
| 360 |
+
template <class TyGroup, typename TyElem, typename TySize, enable_if_void<TyElem> = nullptr,
|
| 361 |
+
enable_if_integral<TySize> = nullptr>
|
| 362 |
+
_CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, TyElem *__restrict__ _dst,
|
| 363 |
+
const TyElem *__restrict__ _src, const TySize& count) {
|
| 364 |
+
details::_memcpy_async_dispatch_to_aligned_copy<1>(group, _dst, _src, count);
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
// 1d-datum API
|
| 368 |
+
template <class TyGroup, typename TyElem, size_t Hint = alignof(TyElem)>
|
| 369 |
+
_CG_STATIC_QUALIFIER void _memcpy_async_datum(const TyGroup &group, TyElem *__restrict__ dst, const size_t dstCount,
|
| 370 |
+
const TyElem *__restrict__ src, const size_t srcCount) {
|
| 371 |
+
constexpr unsigned int _align = Hint;
|
| 372 |
+
const size_t totalCount = min(dstCount, srcCount) * sizeof(TyElem);
|
| 373 |
+
|
| 374 |
+
details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, dst, src, totalCount);
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
// 1d-datum API using aligned_size_t
|
| 378 |
+
template <class TyGroup, typename TyElem, template <size_t> typename Alignment, size_t Hint>
|
| 379 |
+
_CG_STATIC_QUALIFIER void _memcpy_async_datum(const TyGroup &group, TyElem *__restrict__ dst, const Alignment<Hint> &dstCount,
|
| 380 |
+
const TyElem *__restrict__ src, const Alignment<Hint> &srcCount) {
|
| 381 |
+
constexpr unsigned int _align = Hint;
|
| 382 |
+
const size_t totalCount = min((size_t)dstCount, (size_t)srcCount) * sizeof(TyElem);
|
| 383 |
+
|
| 384 |
+
details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, dst, src, totalCount);
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
} // namespace details
|
| 388 |
+
|
| 389 |
+
/*
|
| 390 |
+
* Group submit batch of async-copy to cover contiguous 1D array
|
| 391 |
+
* and commit that batch to eventually wait for completion.
|
| 392 |
+
*/
|
| 393 |
+
template <class TyGroup, typename TyElem, typename TySizeT>
|
| 394 |
+
_CG_STATIC_QUALIFIER void memcpy_async(const TyGroup &group, TyElem *__restrict__ _dst, const TyElem *__restrict__ _src,
|
| 395 |
+
const TySizeT &count) {
|
| 396 |
+
details::_memcpy_async_bytes(group, _dst, _src, count);
|
| 397 |
+
__pipeline_commit();
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
/*
|
| 401 |
+
* Group submit batch of async-copy to cover contiguous 1D array
|
| 402 |
+
* and commit that batch to eventually wait for completion.
|
| 403 |
+
* Object counts are in datum sized chunks, not bytes.
|
| 404 |
+
*/
|
| 405 |
+
template <class TyGroup, class TyElem, typename DstLayout, typename SrcLayout>
|
| 406 |
+
_CG_STATIC_QUALIFIER void memcpy_async(const TyGroup &group, TyElem *__restrict__ dst, const DstLayout &dstLayout,
|
| 407 |
+
const TyElem *__restrict__ src, const SrcLayout &srcLayout) {
|
| 408 |
+
details::_memcpy_async_datum(group, dst, dstLayout, src, srcLayout);
|
| 409 |
+
__pipeline_commit();
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
/* Group wait for prior Nth stage of memcpy_async to complete. */
|
| 413 |
+
template <unsigned int Stage, class TyGroup>
|
| 414 |
+
_CG_STATIC_QUALIFIER void wait_prior(const TyGroup &group) {
|
| 415 |
+
__pipeline_wait_prior(Stage);
|
| 416 |
+
group.sync();
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
/* Group wait all previously submitted memcpy_async to complete. */
|
| 420 |
+
template <class TyGroup>
|
| 421 |
+
_CG_STATIC_QUALIFIER void wait(const TyGroup &group) {
|
| 422 |
+
__pipeline_wait_prior(0);
|
| 423 |
+
group.sync();
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
/***************** CG APIs including pipeline are deprecated *****************/
|
| 427 |
+
|
| 428 |
+
/* Group submit batch of async-copy to cover of contiguous 1D array
|
| 429 |
+
to a pipeline and commit the batch*/
|
| 430 |
+
template <class TyGroup, class TyElem>
|
| 431 |
+
_CG_DEPRECATED _CG_STATIC_QUALIFIER void memcpy_async(TyGroup &group, TyElem *dst, size_t dstCount, const TyElem *src, size_t srcCount,
|
| 432 |
+
nvcuda::experimental::pipeline &pipe) {
|
| 433 |
+
details::_memcpy_async_datum(group, dst, dstCount, src, srcCount);
|
| 434 |
+
pipe.commit();
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
/* Group wait for prior Nth stage of memcpy_async to complete. */
|
| 438 |
+
template <unsigned int Stage, class TyGroup>
|
| 439 |
+
_CG_DEPRECATED _CG_STATIC_QUALIFIER void wait_prior(TyGroup &group, nvcuda::experimental::pipeline &pipe) {
|
| 440 |
+
pipe.wait_prior<Stage>();
|
| 441 |
+
group.sync();
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
/* Group wait for stage-S of memcpy_async to complete. */
|
| 445 |
+
template <class TyGroup>
|
| 446 |
+
_CG_DEPRECATED _CG_STATIC_QUALIFIER void wait(TyGroup &group, nvcuda::experimental::pipeline &pipe, size_t stage) {
|
| 447 |
+
pipe.wait(stage);
|
| 448 |
+
group.sync();
|
| 449 |
+
}
|
| 450 |
+
_CG_END_NAMESPACE
|
| 451 |
+
|
| 452 |
+
#endif // _CG_ASYNC_H
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/coalesced_scan.h
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef _CG_COALESCED_SCAN_H_
|
| 50 |
+
#define _CG_COALESCED_SCAN_H_
|
| 51 |
+
|
| 52 |
+
#include "info.h"
|
| 53 |
+
#include "helpers.h"
|
| 54 |
+
#include "cooperative_groups.h"
|
| 55 |
+
#include "partitioning.h"
|
| 56 |
+
#include "functional.h"
|
| 57 |
+
|
| 58 |
+
_CG_BEGIN_NAMESPACE
|
| 59 |
+
|
| 60 |
+
namespace details {
|
| 61 |
+
|
| 62 |
+
template <typename TyGroup, typename TyVal, typename TyOp>
|
| 63 |
+
_CG_QUALIFIER auto inclusive_scan_contiguous(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
|
| 64 |
+
auto out = val;
|
| 65 |
+
for (int mask = 1; mask < group.size(); mask <<= 1) {
|
| 66 |
+
auto tmp = group.shfl_up(out, mask);
|
| 67 |
+
if (mask <= group.thread_rank()) {
|
| 68 |
+
out = op(out, tmp);
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
return out;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// Inclusive scan for a coalesced group whose active lanes are NOT contiguous
// in the warp (rank != lane id). The doubling-stride pattern of the
// contiguous scan is kept, but at each step the source *rank* (rank - i) must
// be translated back to a physical *lane* id before shuffling. That lane is
// located with a popcount-guided binary search over this thread's
// lower-lane membership mask.
template <typename TyGroup, typename TyVal, typename TyOp>
_CG_QUALIFIER auto inclusive_scan_non_contiguous(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
    const unsigned int groupSize = group.size();
    auto out = val;

    // Bitmask of all lanes participating in this coalesced group.
    const unsigned int mask = details::_coalesced_group_data_access::get_mask(group);
    // Participating lanes strictly below this thread's lane.
    unsigned int lanemask = details::lanemask32_lt() & mask;
    unsigned int srcLane = details::laneid();

    const unsigned int base = __ffs(mask)-1; /* lane with rank == 0 */
    const unsigned int rank = __popc(lanemask);

    for (unsigned int i = 1, j = 1; i < groupSize; i <<= 1) {
        if (i <= rank) {
            // Previous iteration's source lane is an upper bound for this
            // iteration's; `j` tracks how far it could have moved.
            srcLane -= j;
            j = i; /* maximum possible lane */

            unsigned int begLane = base + rank - i; /* minimum possible lane */

            /* Next source lane is in the range [ begLane .. srcLane ]
             * If begLane < srcLane then do a binary search.
             */
            while (begLane < srcLane) {
                const unsigned int halfLane = (begLane + srcLane) >> 1;
                // Number of participating lanes at or above halfLane (within
                // the lower-lane mask); compared against the desired distance i.
                const unsigned int halfMask = lanemask >> halfLane;
                const unsigned int d = __popc(halfMask);
                if (d < i) {
                    srcLane = halfLane - 1; /* halfLane too large */
                }
                else if ((i < d) || !(halfMask & 0x01)) {
                    // Either too many lanes above halfLane, or halfLane itself
                    // is not a participating lane.
                    begLane = halfLane + 1; /* halfLane too small */
                }
                else {
                    begLane = srcLane = halfLane; /* happen to hit */
                }
            }
        }

        // Shuffle is issued by every thread (warp-synchronous); threads with
        // rank < i ignore the received value.
        auto tmp = details::tile::shuffle_dispatch<TyVal>::shfl(out, mask, srcLane, 32);
        if (i <= rank) {
            out = op(out, tmp);
        }
    }
    return out;
}
|
| 120 |
+
|
| 121 |
+
// Entry point for single-warp thread_block_tile groups: tiles are always
// contiguous (rank maps directly to lane offset), so the cheap contiguous
// scan applies unconditionally.
template <unsigned int TySize, typename ParentT, typename TyVal, typename TyOp>
_CG_QUALIFIER auto coalesced_inclusive_scan(const __single_warp_thread_block_tile<TySize, ParentT>& group,
                                            TyVal&& val,
                                            TyOp&& op) -> decltype(op(val, val)) {
    return inclusive_scan_contiguous(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
}
|
| 127 |
+
|
| 128 |
+
template <typename TyVal, typename TyOp>
|
| 129 |
+
_CG_QUALIFIER auto coalesced_inclusive_scan(const coalesced_group& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
|
| 130 |
+
if (group.size() == 32) {
|
| 131 |
+
return inclusive_scan_contiguous(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
|
| 132 |
+
}
|
| 133 |
+
else {
|
| 134 |
+
return inclusive_scan_non_contiguous(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// Tag-dispatch selector for converting an inclusive scan result into an
// exclusive one. Specialized on whether the subtraction shortcut is valid
// (integral values combined with cooperative_groups::plus).
template <bool IntegralOptimized>
struct scan_choose_convertion;
|
| 140 |
+
|
| 141 |
+
// Optimized conversion: for integral addition the exclusive result is simply
// the inclusive result minus the thread's own contribution — no extra
// shuffle needed.
template<>
struct scan_choose_convertion<true> {
    template <typename TyGroup, typename TyRes, typename TyVal>
    _CG_STATIC_QUALIFIER details::remove_qual<TyVal> convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val) {
        return result - val;
    }
};
|
| 148 |
+
|
| 149 |
+
template<>
|
| 150 |
+
struct scan_choose_convertion<false> {
|
| 151 |
+
template <typename TyGroup, typename TyRes, typename TyVal>
|
| 152 |
+
_CG_STATIC_QUALIFIER details::remove_qual<TyVal> convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val) {
|
| 153 |
+
auto ret = group.shfl_up(result, 1);
|
| 154 |
+
if (group.thread_rank() == 0) {
|
| 155 |
+
return {};
|
| 156 |
+
}
|
| 157 |
+
else {
|
| 158 |
+
return ret;
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
};
|
| 162 |
+
|
| 163 |
+
// Converts a computed inclusive scan (`result`) into this thread's exclusive
// scan value. Picks the subtraction shortcut only when the operator is
// cooperative_groups::plus over an integral value type; otherwise falls back
// to the shuffle-based shift. Note: the shortcut decision is made at compile
// time via scan_choose_convertion.
template <typename TyGroup, typename TyRes, typename TyVal, typename TyFn>
_CG_QUALIFIER auto convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
    using conversion = scan_choose_convertion<_CG_STL_NAMESPACE::is_same<remove_qual<TyFn>, cooperative_groups::plus<remove_qual<TyVal>>>::value
                                              && _CG_STL_NAMESPACE::is_integral<remove_qual<TyVal>>::value>;
    return conversion::convert_inclusive_to_exclusive(group, result, _CG_STL_NAMESPACE::forward<TyVal>(val));
}
|
| 169 |
+
|
| 170 |
+
} // details
|
| 171 |
+
|
| 172 |
+
_CG_END_NAMESPACE
|
| 173 |
+
|
| 174 |
+
#endif // _CG_COALESCED_SCAN_H_
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/driver_abi.h
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef _CG_DRIVER_API_H
|
| 50 |
+
#define _CG_DRIVER_API_H
|
| 51 |
+
|
| 52 |
+
#include "info.h"
|
| 53 |
+
|
| 54 |
+
_CG_BEGIN_NAMESPACE
|
| 55 |
+
|
| 56 |
+
namespace details {
|
| 57 |
+
// Reads driver environment register RegId. Only the explicitly specialized
// register ids (generated by LOAD_ENVREG below) are valid; this primary
// template traps at runtime for any other id, since the PTX instruction
// requires the register number as an immediate.
template <unsigned int RegId>
_CG_QUALIFIER unsigned int load_env_reg() {
    // Abort by default
    _CG_ABORT();
    return 0; // unreachable; satisfies the return-type requirement
}
|
| 63 |
+
|
| 64 |
+
template <unsigned int HiReg, unsigned int LoReg>
|
| 65 |
+
_CG_QUALIFIER unsigned long long load_env_reg64() {
|
| 66 |
+
unsigned long long registerLo = load_env_reg<LoReg>();
|
| 67 |
+
unsigned long long registerHi = load_env_reg<HiReg>();
|
| 68 |
+
|
| 69 |
+
return (registerHi << 32) | registerLo;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// inline PTX for accessing registers requires an immediate for the special reg
|
| 73 |
+
# define LOAD_ENVREG(NUMBER) \
|
| 74 |
+
template <> _CG_QUALIFIER unsigned int load_env_reg<NUMBER>() { \
|
| 75 |
+
unsigned int r; \
|
| 76 |
+
asm ("mov.u32 %0, %%envreg" #NUMBER ";" : "=r"(r)); \
|
| 77 |
+
return r; \
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// Instantiate loaders for registers used
|
| 81 |
+
LOAD_ENVREG(0);
|
| 82 |
+
LOAD_ENVREG(1);
|
| 83 |
+
LOAD_ENVREG(2);
|
| 84 |
+
# undef LOAD_ENVREG
|
| 85 |
+
|
| 86 |
+
// View of the driver-provided grid workspace (see get_grid_workspace below).
// NOTE(review): field order/size appears to mirror a driver-side ABI layout —
// do not reorder or resize without confirming against the driver definition.
struct grid_workspace {
    unsigned int wsSize;   // total workspace size in... units not shown here; TODO confirm
    unsigned int barrier;  // storage used for grid-wide synchronization; semantics defined by the driver
};
|
| 90 |
+
|
| 91 |
+
// Returns a pointer to the driver's grid workspace. The 64-bit address is
// split across environment registers 1 (high word) and 2 (low word).
_CG_QUALIFIER grid_workspace* get_grid_workspace() {
    const unsigned long long workspace_address = load_env_reg64<1, 2>();
    // Interpret the address from envreg 1 and 2 as the driver's grid workspace
    return reinterpret_cast<grid_workspace*>(workspace_address);
}
|
| 96 |
+
}
|
| 97 |
+
_CG_END_NAMESPACE
|
| 98 |
+
|
| 99 |
+
#endif // _CG_DRIVER_API_H
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/info.h
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2021 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
#ifndef _CG_INFO_H_
|
| 52 |
+
#define _CG_INFO_H_
|
| 53 |
+
/*
|
| 54 |
+
** Define: _CG_VERSION
|
| 55 |
+
*/
|
| 56 |
+
#define _CG_VERSION 1000
|
| 57 |
+
|
| 58 |
+
/*
|
| 59 |
+
** Define: _CG_ABI_VERSION
|
| 60 |
+
*/
|
| 61 |
+
#ifndef _CG_ABI_VERSION
|
| 62 |
+
# define _CG_ABI_VERSION 1
|
| 63 |
+
#endif
|
| 64 |
+
|
| 65 |
+
/*
|
| 66 |
+
** Define: _CG_ABI_EXPERIMENTAL
|
| 67 |
+
** Desc: If enabled, sets all features enabled (ABI-breaking or experimental)
|
| 68 |
+
*/
|
| 69 |
+
#if defined(_CG_ABI_EXPERIMENTAL)
|
| 70 |
+
#endif
|
| 71 |
+
|
| 72 |
+
#define _CG_CONCAT_INNER(x, y) x ## y
|
| 73 |
+
#define _CG_CONCAT_OUTER(x, y) _CG_CONCAT_INNER(x, y)
|
| 74 |
+
#define _CG_NAMESPACE _CG_CONCAT_OUTER(__v, _CG_ABI_VERSION)
|
| 75 |
+
|
| 76 |
+
#define _CG_BEGIN_NAMESPACE \
|
| 77 |
+
namespace cooperative_groups { namespace _CG_NAMESPACE {
|
| 78 |
+
#define _CG_END_NAMESPACE \
|
| 79 |
+
}; using namespace _CG_NAMESPACE; };
|
| 80 |
+
|
| 81 |
+
#if (defined(__cplusplus) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MSC_VER >= 1900))
|
| 82 |
+
# define _CG_CPP11_FEATURES
|
| 83 |
+
#endif
|
| 84 |
+
|
| 85 |
+
#if !defined(_CG_QUALIFIER)
|
| 86 |
+
# define _CG_QUALIFIER __forceinline__ __device__
|
| 87 |
+
#endif
|
| 88 |
+
#if !defined(_CG_STATIC_QUALIFIER)
|
| 89 |
+
# define _CG_STATIC_QUALIFIER static __forceinline__ __device__
|
| 90 |
+
#endif
|
| 91 |
+
#if !defined(_CG_CONSTEXPR_QUALIFIER)
|
| 92 |
+
# if defined(_CG_CPP11_FEATURES)
|
| 93 |
+
# define _CG_CONSTEXPR_QUALIFIER constexpr __forceinline__ __device__
|
| 94 |
+
# else
|
| 95 |
+
# define _CG_CONSTEXPR_QUALIFIER _CG_QUALIFIER
|
| 96 |
+
# endif
|
| 97 |
+
#endif
|
| 98 |
+
#if !defined(_CG_STATIC_CONSTEXPR_QUALIFIER)
|
| 99 |
+
# if defined(_CG_CPP11_FEATURES)
|
| 100 |
+
# define _CG_STATIC_CONSTEXPR_QUALIFIER static constexpr __forceinline__ __device__
|
| 101 |
+
# else
|
| 102 |
+
# define _CG_STATIC_CONSTEXPR_QUALIFIER _CG_STATIC_QUALIFIER
|
| 103 |
+
# endif
|
| 104 |
+
#endif
|
| 105 |
+
|
| 106 |
+
#if defined(_MSC_VER)
|
| 107 |
+
# define _CG_DEPRECATED __declspec(deprecated)
|
| 108 |
+
#else
|
| 109 |
+
# define _CG_DEPRECATED __attribute__((deprecated))
|
| 110 |
+
#endif
|
| 111 |
+
|
| 112 |
+
#if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__)
|
| 113 |
+
# define _CG_HAS_GRID_GROUP
|
| 114 |
+
#endif
|
| 115 |
+
#if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__)
|
| 116 |
+
# define _CG_HAS_MULTI_GRID_GROUP
|
| 117 |
+
#endif
|
| 118 |
+
#if (__CUDA_ARCH__ >= 700) || !defined(__CUDA_ARCH__)
|
| 119 |
+
# define _CG_HAS_MATCH_COLLECTIVE
|
| 120 |
+
#endif
|
| 121 |
+
|
| 122 |
+
#if (__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__) && (defined(__NVCC__) || defined(__CUDACC_RTC__))
|
| 123 |
+
# define _CG_HAS_OP_REDUX
|
| 124 |
+
#endif
|
| 125 |
+
|
| 126 |
+
#if ((__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__)) && !defined(_CG_USER_PROVIDED_SHARED_MEMORY)
|
| 127 |
+
# define _CG_HAS_RESERVED_SHARED
|
| 128 |
+
#endif
|
| 129 |
+
|
| 130 |
+
#if ((__CUDA_ARCH__ >= 900) || !defined(__CUDA_ARCH__)) && \
|
| 131 |
+
(defined(__NVCC__) || defined(__CUDACC_RTC__) || defined(_CG_CLUSTER_INTRINSICS_AVAILABLE)) && \
|
| 132 |
+
defined(_CG_CPP11_FEATURES)
|
| 133 |
+
# define _CG_HAS_CLUSTER_GROUP
|
| 134 |
+
#endif
|
| 135 |
+
|
| 136 |
+
#if (__CUDA_ARCH__ >= 900) || !defined(__CUDA_ARCH__)
|
| 137 |
+
# define _CG_HAS_INSTR_ELECT
|
| 138 |
+
#endif
|
| 139 |
+
|
| 140 |
+
// Has __half and __half2
|
| 141 |
+
// Only usable if you include the cuda_fp16.h extension, and
|
| 142 |
+
// _before_ including cooperative_groups.h
|
| 143 |
+
#ifdef __CUDA_FP16_TYPES_EXIST__
|
| 144 |
+
# define _CG_HAS_FP16_COLLECTIVE
|
| 145 |
+
#endif
|
| 146 |
+
|
| 147 |
+
// Include libcu++ where supported.
|
| 148 |
+
#if defined(_CG_CPP11_FEATURES) && !defined(__QNX__) && !defined(__ibmxl__) && \
|
| 149 |
+
(defined(__NVCC__) || defined(__CUDACC_RTC__)) && \
|
| 150 |
+
(defined(__x86_64__) || defined(__aarch64__) || defined(__ppc64__)|| defined(_M_X64) || defined(_M_ARM64)) && \
|
| 151 |
+
(defined(_MSC_VER) || defined(__GNUC__) || defined(__clang__))
|
| 152 |
+
# define _CG_USE_CUDA_STL
|
| 153 |
+
#else
|
| 154 |
+
# define _CG_USE_OWN_TRAITS
|
| 155 |
+
#endif
|
| 156 |
+
|
| 157 |
+
#if defined(_CG_USE_CUDA_STL) && (!defined(__CUDA_ARCH__) || \
|
| 158 |
+
((!defined(_MSC_VER) && __CUDA_ARCH__ >= 600) || (defined(_MSC_VER) && __CUDA_ARCH__ >= 700)))
|
| 159 |
+
# define _CG_HAS_STL_ATOMICS
|
| 160 |
+
#endif
|
| 161 |
+
|
| 162 |
+
#ifdef _CG_CPP11_FEATURES
|
| 163 |
+
// Use cuda::std:: for type_traits
|
| 164 |
+
# if defined(_CG_USE_CUDA_STL)
|
| 165 |
+
# define _CG_STL_NAMESPACE cuda::std
|
| 166 |
+
# include <cuda/std/type_traits>
|
| 167 |
+
// Use CG's implementation of type traits
|
| 168 |
+
# else
|
| 169 |
+
# define _CG_STL_NAMESPACE cooperative_groups::details::templates
|
| 170 |
+
# endif
|
| 171 |
+
#endif
|
| 172 |
+
|
| 173 |
+
#ifdef _CG_CPP11_FEATURES
|
| 174 |
+
# define _CG_STATIC_CONST_DECL static constexpr
|
| 175 |
+
# define _CG_CONST_DECL constexpr
|
| 176 |
+
#else
|
| 177 |
+
# define _CG_STATIC_CONST_DECL static const
|
| 178 |
+
# define _CG_CONST_DECL const
|
| 179 |
+
#endif
|
| 180 |
+
|
| 181 |
+
#if (defined(_MSC_VER) && !defined(_WIN64)) || defined(__arm__)
|
| 182 |
+
# define _CG_ASM_PTR_CONSTRAINT "r"
|
| 183 |
+
#else
|
| 184 |
+
# define _CG_ASM_PTR_CONSTRAINT "l"
|
| 185 |
+
#endif
|
| 186 |
+
|
| 187 |
+
/*
|
| 188 |
+
** Define: CG_DEBUG
|
| 189 |
+
** What: Enables various runtime safety checks
|
| 190 |
+
*/
|
| 191 |
+
#if defined(__CUDACC_DEBUG__) && defined(CG_DEBUG) && !defined(NDEBUG)
|
| 192 |
+
# define _CG_DEBUG
|
| 193 |
+
#endif
|
| 194 |
+
|
| 195 |
+
#if defined(_CG_DEBUG)
|
| 196 |
+
# include <assert.h>
|
| 197 |
+
# define _CG_ASSERT(x) assert((x));
|
| 198 |
+
# define _CG_ABORT() assert(0);
|
| 199 |
+
#else
|
| 200 |
+
# define _CG_ASSERT(x)
|
| 201 |
+
# define _CG_ABORT() __trap();
|
| 202 |
+
#endif
|
| 203 |
+
|
| 204 |
+
_CG_BEGIN_NAMESPACE
|
| 205 |
+
|
| 206 |
+
namespace details {
|
| 207 |
+
// Fallback block-size bound when the real size is not known at compile time
// (1024 — presumably the hardware maximum threads per block; confirm per-arch).
_CG_STATIC_CONST_DECL unsigned int default_max_block_size = 1024;
|
| 208 |
+
|
| 209 |
+
#if defined(_CG_CPP11_FEATURES) && !defined(_CG_USE_CUDA_STL)
|
| 210 |
+
namespace templates {
|
| 211 |
+
|
| 212 |
+
/**
|
| 213 |
+
* Integral constants
|
| 214 |
+
**/
|
| 215 |
+
// Minimal stand-in for std::integral_constant, used when libcu++ is not
// available. NOTE(review): unlike std, `type` here is the value type Ty
// rather than the integral_constant itself — callers in this file rely on
// ::value, so the difference is benign internally.
template <typename Ty, Ty Val>
struct integral_constant {
    static constexpr Ty value = Val;
    typedef Ty type;

    // Implicit conversion and call operator both yield the stored constant,
    // mirroring std::integral_constant's interface.
    _CG_QUALIFIER constexpr operator type() const noexcept { return value; }
    _CG_QUALIFIER constexpr type operator()() const noexcept { return value; }
};

typedef integral_constant<bool, true> true_type;
typedef integral_constant<bool, false> false_type;
|
| 226 |
+
|
| 227 |
+
/**
|
| 228 |
+
* CV Qualifiers
|
| 229 |
+
**/
|
| 230 |
+
// Reference / cv-qualifier traits mirroring their std:: counterparts.

// True only for lvalue-reference types (Ty&).
template <class Ty> struct is_lvalue_reference : public details::templates::false_type {};
template <class Ty> struct is_lvalue_reference<Ty&> : public details::templates::true_type {};

// Strips one level of reference (both Ty& and Ty&&).
template <class Ty> struct remove_reference {typedef Ty type;};
template <class Ty> struct remove_reference<Ty&> {typedef Ty type;};
template <class Ty> struct remove_reference<Ty&&> {typedef Ty type;};

template <class Ty>
using remove_reference_t = typename details::templates::remove_reference<Ty>::type;

// Strips top-level const / volatile qualifiers.
template <class Ty> struct remove_const {typedef Ty type;};
template <class Ty> struct remove_const<const Ty> {typedef Ty type;};

template <class Ty> struct remove_volatile {typedef Ty type;};
template <class Ty> struct remove_volatile<volatile Ty> {typedef Ty type;};

// remove_cv = remove_volatile of remove_const.
template <class Ty> struct remove_cv {typedef typename details::templates::remove_volatile<typename details::templates::remove_const<Ty>::type>::type type;};

template <class Ty>
using remove_cv_t = typename details::templates::remove_cv<Ty>::type;
|
| 250 |
+
|
| 251 |
+
// Perfect-forwarding helpers mirroring std::forward.

// Overload for lvalue arguments: casts back to Ty&& (collapses to an lvalue
// reference when Ty is an lvalue reference type).
template <class Ty>
_CG_QUALIFIER Ty&& forward(remove_reference_t<Ty> &t) noexcept {
    return static_cast<Ty&&>(t);
}

// Overload for rvalue arguments: forwarding an rvalue as an lvalue reference
// would create a dangling alias, so it is rejected at compile time.
template <class Ty>
_CG_QUALIFIER Ty&& forward(remove_reference_t<Ty> &&t) noexcept {
    static_assert(!details::templates::is_lvalue_reference<Ty>::value, "Forwarding an rvalue as an lvalue is not allowed.");
    return static_cast<Ty&&>(t);
}
|
| 261 |
+
|
| 262 |
+
/**
|
| 263 |
+
* is_integral
|
| 264 |
+
**/
|
| 265 |
+
template <class Ty> struct _is_integral : public details::templates::false_type {};
|
| 266 |
+
template <> struct _is_integral<bool> : public details::templates::true_type {};
|
| 267 |
+
template <> struct _is_integral<char> : public details::templates::true_type {};
|
| 268 |
+
template <> struct _is_integral<unsigned char> : public details::templates::true_type {};
|
| 269 |
+
template <> struct _is_integral<short> : public details::templates::true_type {};
|
| 270 |
+
template <> struct _is_integral<unsigned short> : public details::templates::true_type {};
|
| 271 |
+
template <> struct _is_integral<int> : public details::templates::true_type {};
|
| 272 |
+
template <> struct _is_integral<unsigned int> : public details::templates::true_type {};
|
| 273 |
+
template <> struct _is_integral<long> : public details::templates::true_type {};
|
| 274 |
+
template <> struct _is_integral<long long> : public details::templates::true_type {};
|
| 275 |
+
template <> struct _is_integral<unsigned long> : public details::templates::true_type {};
|
| 276 |
+
template <> struct _is_integral<unsigned long long> : public details::templates::true_type {};
|
| 277 |
+
//Vector type support?
|
| 278 |
+
|
| 279 |
+
template <typename Ty>
|
| 280 |
+
struct is_integral : public details::templates::_is_integral<typename details::templates::remove_cv<Ty>::type> {};
|
| 281 |
+
|
| 282 |
+
/**
|
| 283 |
+
* is_floating_point
|
| 284 |
+
**/
|
| 285 |
+
// is_floating_point mirroring std::is_floating_point, extended with CUDA's
// __half / __half2 when cuda_fp16.h was included before this header.
template <class Ty> struct _is_floating_point : public details::templates::false_type {};
template <> struct _is_floating_point<float> : public details::templates::true_type {};
template <> struct _is_floating_point<double> : public details::templates::true_type {};
template <> struct _is_floating_point<long double> : public details::templates::true_type {};
# ifdef __CUDA_FP16_TYPES_EXIST__
template <> struct _is_floating_point<__half> : public details::templates::true_type {};
template <> struct _is_floating_point<__half2> : public details::templates::true_type {};
# endif
//Vector type support?

// Public trait: strips cv-qualifiers before testing.
template <typename Ty>
struct is_floating_point : public details::templates::_is_floating_point<typename details::templates::remove_cv<Ty>::type> {};
|
| 297 |
+
|
| 298 |
+
// is_arithmetic: integral or floating-point, as in std.
template <class T>
struct is_arithmetic : details::templates::integral_constant<
    bool,
    details::templates::is_integral<T>::value ||
    details::templates::is_floating_point<T>::value> {};

// is_unsigned: for arithmetic types, unsigned iff Ty(-1) wraps above Ty(0);
// the bool template parameter gates the Ty(0) < Ty(-1) comparison so it is
// only instantiated for arithmetic types.
template <typename Ty, bool = details::templates::is_arithmetic<Ty>::value>
struct _is_unsigned : details::templates::integral_constant<bool, Ty(0) < Ty(-1)> {};

// Non-arithmetic types are never unsigned.
template <typename Ty>
struct _is_unsigned<Ty,false> : details::templates::false_type {};

template <typename Ty>
struct is_unsigned : _is_unsigned<typename details::templates::remove_cv<Ty>::type> {};

// is_pointer: true for any object/function pointer type (after cv-stripping).
template <typename Ty> struct _is_pointer : public details::templates::false_type {};
template <typename Ty> struct _is_pointer<Ty*> : public details::templates::true_type {};

template <typename Ty>
struct is_pointer : _is_pointer<typename details::templates::remove_cv<Ty>::type> {};
|
| 318 |
+
|
| 319 |
+
/**
|
| 320 |
+
* programmatic type traits
|
| 321 |
+
**/
|
| 322 |
+
// SFINAE helpers mirroring std::enable_if / std::is_same.

// Primary template has no `type`; substitution fails when B is false.
template<bool B, class Ty = void>
struct enable_if {};

template<class Ty>
struct enable_if<true, Ty> { typedef Ty type; };

template<bool Cond, typename Ty = void>
using enable_if_t = typename details::templates::enable_if<Cond, Ty>::type;

// True only when both type arguments are exactly the same type.
template<class Ty1, class Ty2>
struct is_same : details::templates::false_type {};

template<class Ty>
struct is_same<Ty, Ty> : details::templates::true_type {};
|
| 336 |
+
|
| 337 |
+
} // templates
|
| 338 |
+
#endif // _CG_CPP11_FEATURES
|
| 339 |
+
|
| 340 |
+
} // details
|
| 341 |
+
_CG_END_NAMESPACE
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
#endif // _CG_INFO_H_
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/invoke.h
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#ifndef _CG_INVOKE_H
|
| 51 |
+
#define _CG_INVOKE_H
|
| 52 |
+
|
| 53 |
+
#include "info.h"
|
| 54 |
+
#include "helpers.h"
|
| 55 |
+
|
| 56 |
+
#if defined(_CG_CPP11_FEATURES)
|
| 57 |
+
|
| 58 |
+
_CG_BEGIN_NAMESPACE
|
| 59 |
+
|
| 60 |
+
namespace details {
|
| 61 |
+
|
| 62 |
+
template <typename Group>
|
| 63 |
+
struct _elect_group_supported : _CG_STL_NAMESPACE::false_type {};
|
| 64 |
+
#ifdef _CG_HAS_INSTR_ELECT
|
| 65 |
+
template<>
|
| 66 |
+
struct _elect_group_supported<coalesced_group> : _CG_STL_NAMESPACE::true_type {};
|
| 67 |
+
template<unsigned int Size, typename Parent>
|
| 68 |
+
struct _elect_group_supported<thread_block_tile<Size, Parent>> :
|
| 69 |
+
_CG_STL_NAMESPACE::integral_constant<bool, (Size <= 32)> {};
|
| 70 |
+
#endif
|
| 71 |
+
|
| 72 |
+
template <typename Group>
|
| 73 |
+
struct elect_group_supported : public _elect_group_supported<details::remove_qual<Group>> {};
|
| 74 |
+
|
| 75 |
+
template<typename Group>
|
| 76 |
+
_CG_STATIC_QUALIFIER bool elect_one(const Group& group, unsigned int mask, unsigned int& leader_lane) {
|
| 77 |
+
int is_leader = 0;
|
| 78 |
+
#ifdef _CG_HAS_INSTR_ELECT
|
| 79 |
+
asm("{\n\t"
|
| 80 |
+
" .reg .pred p;\n\t"
|
| 81 |
+
" elect.sync %0|p, %2;\n\t"
|
| 82 |
+
" @p mov.s32 %1, 1;\n\t"
|
| 83 |
+
"}"
|
| 84 |
+
: "+r"(leader_lane), "+r"(is_leader) : "r" (mask));
|
| 85 |
+
#endif
|
| 86 |
+
return is_leader;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
template<bool UseElect>
|
| 90 |
+
struct invoke_one_impl {};
|
| 91 |
+
|
| 92 |
+
template<>
|
| 93 |
+
struct invoke_one_impl<true> {
|
| 94 |
+
template<typename Group, typename Fn, typename... Args>
|
| 95 |
+
_CG_STATIC_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) {
|
| 96 |
+
auto mask = details::_coalesced_group_data_access::get_mask(group);
|
| 97 |
+
unsigned int leader_lane = 0;
|
| 98 |
+
|
| 99 |
+
if (elect_one(group, mask, leader_lane)) {
|
| 100 |
+
_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
template<typename Group, typename Fn, typename... Args>
|
| 105 |
+
_CG_STATIC_QUALIFIER auto invoke_one_broadcast(const Group& group, Fn&& fn, Args&&... args)
|
| 106 |
+
-> typename _CG_STL_NAMESPACE::remove_reference<
|
| 107 |
+
decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
|
| 108 |
+
|
| 109 |
+
using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
|
| 110 |
+
details::remove_qual<ResultType> result;
|
| 111 |
+
auto mask = details::_coalesced_group_data_access::get_mask(group);
|
| 112 |
+
unsigned int leader_lane = 0;
|
| 113 |
+
|
| 114 |
+
if (elect_one(group, mask, leader_lane)) {
|
| 115 |
+
result = _CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
// Need to use low level api instead of group.shfl, because elect_one returns lane id, not group rank.
|
| 119 |
+
return tile::shuffle_dispatch<ResultType>::shfl(result, mask, leader_lane, 32);
|
| 120 |
+
}
|
| 121 |
+
};
|
| 122 |
+
|
| 123 |
+
template<>
|
| 124 |
+
struct invoke_one_impl<false> {
|
| 125 |
+
template<typename Group, typename Fn, typename... Args>
|
| 126 |
+
_CG_STATIC_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) {
|
| 127 |
+
if (group.thread_rank() == 0) {
|
| 128 |
+
_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
template<typename Group, typename Fn, typename... Args>
|
| 133 |
+
_CG_STATIC_QUALIFIER auto invoke_one_broadcast(const Group& group, Fn&& fn, Args&&... args)
|
| 134 |
+
-> typename _CG_STL_NAMESPACE::remove_reference<
|
| 135 |
+
decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
|
| 136 |
+
|
| 137 |
+
using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
|
| 138 |
+
details::remove_qual<ResultType> result;
|
| 139 |
+
|
| 140 |
+
if (group.thread_rank() == 0) {
|
| 141 |
+
result = _CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
return group.shfl(result, 0);
|
| 145 |
+
}
|
| 146 |
+
};
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
}; // namespace details
|
| 150 |
+
|
| 151 |
+
template<typename Group, typename Fn, typename... Args>
|
| 152 |
+
_CG_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) {
|
| 153 |
+
using impl = details::invoke_one_impl<details::elect_group_supported<Group>::value>;
|
| 154 |
+
impl::invoke_one(group, _CG_STL_NAMESPACE::forward<Fn>(fn), _CG_STL_NAMESPACE::forward<Args>(args)...);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
template<typename Fn, typename... Args>
|
| 158 |
+
_CG_QUALIFIER auto invoke_one_broadcast(const coalesced_group& group, Fn&& fn, Args&&... args)
|
| 159 |
+
-> typename _CG_STL_NAMESPACE::remove_reference<
|
| 160 |
+
decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
|
| 161 |
+
|
| 162 |
+
using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
|
| 163 |
+
static_assert(!_CG_STL_NAMESPACE::is_same<ResultType, void>::value,
|
| 164 |
+
"For invocables returning void invoke_one should be used instead");
|
| 165 |
+
using impl = details::invoke_one_impl<details::elect_group_supported<coalesced_group>::value>;
|
| 166 |
+
return impl::invoke_one_broadcast(group,
|
| 167 |
+
_CG_STL_NAMESPACE::forward<Fn>(fn),
|
| 168 |
+
_CG_STL_NAMESPACE::forward<Args>(args)...);
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
template<unsigned int Size, typename Parent, typename Fn, typename... Args>
|
| 172 |
+
_CG_QUALIFIER auto invoke_one_broadcast(const thread_block_tile<Size, Parent>& group, Fn&& fn, Args&&... args)
|
| 173 |
+
-> typename _CG_STL_NAMESPACE::remove_reference<
|
| 174 |
+
decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...))>::type {
|
| 175 |
+
|
| 176 |
+
using ResultType = decltype(_CG_STL_NAMESPACE::forward<Fn>(fn)(_CG_STL_NAMESPACE::forward<Args>(args)...));
|
| 177 |
+
static_assert(!_CG_STL_NAMESPACE::is_same<ResultType, void>::value,
|
| 178 |
+
"For invocables returning void invoke_one should be used instead");
|
| 179 |
+
using impl = details::invoke_one_impl<details::elect_group_supported<thread_block_tile<Size, Parent>>::value>;
|
| 180 |
+
return impl::invoke_one_broadcast(group,
|
| 181 |
+
_CG_STL_NAMESPACE::forward<Fn>(fn),
|
| 182 |
+
_CG_STL_NAMESPACE::forward<Args>(args)...);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
_CG_END_NAMESPACE
|
| 186 |
+
|
| 187 |
+
#endif //_CG_CPP11_FEATURES
|
| 188 |
+
|
| 189 |
+
#endif // _CG_INVOKE_H
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/memory.h
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef _COOPERATIVE_GROUPS_MEMORY_H_
|
| 50 |
+
# define _COOPERATIVE_GROUPS_MEMORY_H_
|
| 51 |
+
|
| 52 |
+
#include "info.h"
|
| 53 |
+
|
| 54 |
+
_CG_BEGIN_NAMESPACE
|
| 55 |
+
|
| 56 |
+
#if defined(_CG_CPP11_FEATURES)
|
| 57 |
+
namespace details {
|
| 58 |
+
_CG_STATIC_CONST_DECL int scratch_num_reserved_bytes = 12;
|
| 59 |
+
|
| 60 |
+
#if defined(_CG_HAS_RESERVED_SHARED)
|
| 61 |
+
_CG_STATIC_QUALIFIER void* reserved_shared_ptr()
|
| 62 |
+
{
|
| 63 |
+
void *ptr;
|
| 64 |
+
asm ("{\n\t"
|
| 65 |
+
" .reg .u32 start;\n\t"
|
| 66 |
+
" .reg .u64 extended;\n\t"
|
| 67 |
+
" mov.u32 start, %%reserved_smem_offset_1;\n\t"
|
| 68 |
+
" cvt.u64.u32 extended, start;\n\t"
|
| 69 |
+
" cvta.shared.u64 %0, extended;\n\t"
|
| 70 |
+
"}"
|
| 71 |
+
: "=" _CG_ASM_PTR_CONSTRAINT(ptr));
|
| 72 |
+
return ptr;
|
| 73 |
+
}
|
| 74 |
+
#endif
|
| 75 |
+
|
| 76 |
+
struct multi_warp_scratch {
|
| 77 |
+
// One barrier per possible size of the group.
|
| 78 |
+
_CG_STATIC_CONST_DECL unsigned int memory_barriers_count = 5;
|
| 79 |
+
_CG_STATIC_CONST_DECL size_t sync_memory_size = memory_barriers_count * sizeof(barrier_t);
|
| 80 |
+
|
| 81 |
+
using communication_type = unsigned long long;
|
| 82 |
+
_CG_STATIC_CONST_DECL size_t communication_size = sizeof(communication_type);
|
| 83 |
+
|
| 84 |
+
// Layout of the scratch space:
|
| 85 |
+
barrier_t barriers[memory_barriers_count];
|
| 86 |
+
char reserved[scratch_num_reserved_bytes]; // Reserve 12 bytes for future use
|
| 87 |
+
communication_type communication_memory[default_max_block_size / 32];
|
| 88 |
+
|
| 89 |
+
_CG_STATIC_CONSTEXPR_QUALIFIER unsigned int scratch_size_needed(unsigned int max_block_size) {
|
| 90 |
+
// One slot of collectives memory per warp.
|
| 91 |
+
return scratch_num_reserved_bytes + sync_memory_size + max_block_size / 32 * communication_size;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
_CG_QUALIFIER void init_barriers(unsigned int thread_rank) {
|
| 95 |
+
if (thread_rank < memory_barriers_count) {
|
| 96 |
+
barriers[thread_rank] = 0;
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
};
|
| 100 |
+
|
| 101 |
+
#if defined(_CG_HAS_RESERVED_SHARED)
|
| 102 |
+
// CG can expect at least 288 bytes available in reserved shared
|
| 103 |
+
static_assert(sizeof(multi_warp_scratch) <= 288, "multi-warp scratch size is too large");
|
| 104 |
+
#endif
|
| 105 |
+
|
| 106 |
+
// Make sure the structure can fit into the user provided memory
|
| 107 |
+
static_assert(sizeof(multi_warp_scratch) <= multi_warp_scratch::scratch_size_needed(default_max_block_size),
|
| 108 |
+
"multi-warp scratch size is too large");
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
_CG_QUALIFIER multi_warp_scratch* get_scratch_ptr(void* user_scratch) {
|
| 112 |
+
void *ptr;
|
| 113 |
+
#if defined(_CG_HAS_RESERVED_SHARED)
|
| 114 |
+
ptr = reserved_shared_ptr();
|
| 115 |
+
#else
|
| 116 |
+
ptr = user_scratch;
|
| 117 |
+
#endif
|
| 118 |
+
return static_cast<multi_warp_scratch*>(ptr);
|
| 119 |
+
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
template <unsigned int MaxBlockSize = details::default_max_block_size>
|
| 125 |
+
struct __align__(details::multi_warp_scratch::communication_size) block_tile_memory {
|
| 126 |
+
private:
|
| 127 |
+
#if !defined(_CG_HAS_RESERVED_SHARED)
|
| 128 |
+
char scratch[details::multi_warp_scratch::scratch_size_needed(MaxBlockSize)];
|
| 129 |
+
#endif
|
| 130 |
+
};
|
| 131 |
+
#endif
|
| 132 |
+
|
| 133 |
+
_CG_END_NAMESPACE
|
| 134 |
+
|
| 135 |
+
#endif /* !_COOPERATIVE_GROUPS_MEMORY_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/partitioning.h
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#ifndef _CG_PARTITIONING_H
|
| 51 |
+
#define _CG_PARTITIONING_H
|
| 52 |
+
|
| 53 |
+
#include "info.h"
|
| 54 |
+
#include "helpers.h"
|
| 55 |
+
|
| 56 |
+
_CG_BEGIN_NAMESPACE
|
| 57 |
+
|
| 58 |
+
namespace details {
|
| 59 |
+
|
| 60 |
+
template <typename TyGroup>
|
| 61 |
+
_CG_STATIC_QUALIFIER coalesced_group _binary_partition(const TyGroup &tile, bool pred) {
|
| 62 |
+
const unsigned int fullMask = ~0u;
|
| 63 |
+
|
| 64 |
+
unsigned int thisMask = _coalesced_group_data_access::get_mask(tile);
|
| 65 |
+
unsigned int predMask = pred ? 0 : fullMask;
|
| 66 |
+
unsigned int setMask = __ballot_sync(thisMask, pred);
|
| 67 |
+
|
| 68 |
+
if (setMask == thisMask || setMask == 0) {
|
| 69 |
+
coalesced_group subTile = _coalesced_group_data_access::construct_from_mask<coalesced_group>(thisMask);
|
| 70 |
+
_coalesced_group_data_access::modify_meta_group(subTile, 0, 1);
|
| 71 |
+
return subTile;
|
| 72 |
+
}
|
| 73 |
+
else {
|
| 74 |
+
unsigned int subMask = thisMask & (setMask ^ predMask);
|
| 75 |
+
coalesced_group subTile = _coalesced_group_data_access::construct_from_mask<coalesced_group>(subMask);
|
| 76 |
+
_coalesced_group_data_access::modify_meta_group(subTile, pred, 2);
|
| 77 |
+
return subTile;
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
#if defined(_CG_HAS_MATCH_COLLECTIVE) && defined(_CG_CPP11_FEATURES)
|
| 82 |
+
template <typename TyPredicate>
|
| 83 |
+
struct _labeled_partition_dispatch {
|
| 84 |
+
template <typename TyGroup>
|
| 85 |
+
_CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, TyPredicate pred) {
|
| 86 |
+
unsigned int thisMask = _coalesced_group_data_access::get_mask(tile);
|
| 87 |
+
unsigned int thisBias = __ffs(thisMask) - 1; // Subtract 1 to index properly from [1-32]
|
| 88 |
+
unsigned int subMask = __match_any_sync(thisMask, pred);
|
| 89 |
+
|
| 90 |
+
coalesced_group subTile = _coalesced_group_data_access::construct_from_mask<coalesced_group>(subMask);
|
| 91 |
+
|
| 92 |
+
int leaderLaneId = subTile.shfl(details::laneid(), 0);
|
| 93 |
+
|
| 94 |
+
bool isLeader = !subTile.thread_rank();
|
| 95 |
+
unsigned int leaderMask = __ballot_sync(thisMask, isLeader);
|
| 96 |
+
unsigned int tileRank = __fns(leaderMask, leaderLaneId, 0) - thisBias;
|
| 97 |
+
|
| 98 |
+
_coalesced_group_data_access::modify_meta_group(subTile, tileRank, __popc(leaderMask));
|
| 99 |
+
|
| 100 |
+
return subTile;
|
| 101 |
+
}
|
| 102 |
+
};
|
| 103 |
+
|
| 104 |
+
template <>
|
| 105 |
+
struct _labeled_partition_dispatch<bool> {
|
| 106 |
+
template <typename TyGroup>
|
| 107 |
+
_CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, bool pred) {
|
| 108 |
+
return _binary_partition(tile, pred);
|
| 109 |
+
}
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
template <typename TyPredicate>
|
| 113 |
+
struct _labeled_partition_dispatch<TyPredicate*> {
|
| 114 |
+
template <typename TyGroup>
|
| 115 |
+
_CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, TyPredicate* pred) {
|
| 116 |
+
auto impl = _labeled_partition_dispatch<unsigned long long>();
|
| 117 |
+
return impl(tile, reinterpret_cast<unsigned long long>(pred));
|
| 118 |
+
}
|
| 119 |
+
};
|
| 120 |
+
#endif
|
| 121 |
+
}; // namespace details
|
| 122 |
+
|
| 123 |
+
_CG_STATIC_QUALIFIER coalesced_group binary_partition(const coalesced_group &tile, bool pred) {
|
| 124 |
+
return details::_binary_partition(tile, pred);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template <unsigned int Size, typename ParentT>
|
| 128 |
+
_CG_STATIC_QUALIFIER coalesced_group binary_partition(const thread_block_tile<Size, ParentT> &tile, bool pred) {
|
| 129 |
+
#ifdef _CG_CPP11_FEATURES
|
| 130 |
+
static_assert(Size <= 32, "Binary partition is available only for tiles of size smaller or equal to 32");
|
| 131 |
+
#endif
|
| 132 |
+
return details::_binary_partition(tile, pred);
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
#if defined(_CG_HAS_MATCH_COLLECTIVE) && defined(_CG_CPP11_FEATURES)
|
| 137 |
+
template <typename TyPredicate>
|
| 138 |
+
_CG_STATIC_QUALIFIER coalesced_group labeled_partition(const coalesced_group &tile, TyPredicate pred) {
|
| 139 |
+
static_assert(_CG_STL_NAMESPACE::is_integral<TyPredicate>::value ||
|
| 140 |
+
_CG_STL_NAMESPACE::is_pointer<TyPredicate>::value,
|
| 141 |
+
"labeled_partition predicate must be an integral or pointer type");
|
| 142 |
+
auto dispatch = details::_labeled_partition_dispatch<details::remove_qual<TyPredicate>>();
|
| 143 |
+
return dispatch(tile, pred);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
template <typename TyPredicate, unsigned int Size, typename ParentT>
|
| 147 |
+
_CG_STATIC_QUALIFIER coalesced_group labeled_partition(const thread_block_tile<Size, ParentT> &tile, TyPredicate pred) {
|
| 148 |
+
static_assert(_CG_STL_NAMESPACE::is_integral<TyPredicate>::value ||
|
| 149 |
+
_CG_STL_NAMESPACE::is_pointer<TyPredicate>::value,
|
| 150 |
+
"labeled_partition predicate must be an integral or pointer type");
|
| 151 |
+
static_assert(Size <= 32, "Labeled partition is available only for tiles of size smaller or equal to 32");
|
| 152 |
+
auto dispatch = details::_labeled_partition_dispatch<details::remove_qual<TyPredicate>>();
|
| 153 |
+
return dispatch(tile, pred);
|
| 154 |
+
}
|
| 155 |
+
#endif
|
| 156 |
+
|
| 157 |
+
_CG_END_NAMESPACE
|
| 158 |
+
|
| 159 |
+
#endif // _CG_PARTITIONING_H
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/reduce.h
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef _CG_REDUCE_H_
|
| 50 |
+
#define _CG_REDUCE_H_
|
| 51 |
+
|
| 52 |
+
#include "info.h"
|
| 53 |
+
#include "helpers.h"
|
| 54 |
+
#include "coalesced_reduce.h"
|
| 55 |
+
#include "functional.h"
|
| 56 |
+
#include "cooperative_groups.h"
|
| 57 |
+
|
| 58 |
+
_CG_BEGIN_NAMESPACE
|
| 59 |
+
|
| 60 |
+
namespace details {
|
| 61 |
+
|
| 62 |
+
template <class Ty>
|
| 63 |
+
using _redux_is_add_supported = _CG_STL_NAMESPACE::integral_constant<
|
| 64 |
+
bool,
|
| 65 |
+
_CG_STL_NAMESPACE::is_integral<Ty>::value && (sizeof(Ty) <= 4)>;
|
| 66 |
+
|
| 67 |
+
template <class Ty>
|
| 68 |
+
using redux_is_add_supported = _redux_is_add_supported<Ty>;
|
| 69 |
+
|
| 70 |
+
// A specialization for 64 bit logical operations is possible
|
| 71 |
+
// but for now only accelerate 32 bit bitwise ops
|
| 72 |
+
template <class Ty>
|
| 73 |
+
using redux_is_logical_supported = redux_is_add_supported<Ty>;
|
| 74 |
+
|
| 75 |
+
// Base operator support case
|
| 76 |
+
template <class TyOp, class Ty> struct _redux_op_supported : public _CG_STL_NAMESPACE::false_type {};
|
| 77 |
+
#ifdef _CG_HAS_OP_REDUX
|
| 78 |
+
template <class Ty> struct _redux_op_supported<cooperative_groups::plus<Ty>, Ty> : public redux_is_add_supported<Ty> {};
|
| 79 |
+
template <class Ty> struct _redux_op_supported<cooperative_groups::less<Ty>, Ty> : public redux_is_add_supported<Ty> {};
|
| 80 |
+
template <class Ty> struct _redux_op_supported<cooperative_groups::greater<Ty>, Ty> : public redux_is_add_supported<Ty> {};
|
| 81 |
+
template <class Ty> struct _redux_op_supported<cooperative_groups::bit_and<Ty>, Ty> : public redux_is_logical_supported<Ty> {};
|
| 82 |
+
template <class Ty> struct _redux_op_supported<cooperative_groups::bit_or<Ty>, Ty> : public redux_is_logical_supported<Ty> {};
|
| 83 |
+
template <class Ty> struct _redux_op_supported<cooperative_groups::bit_xor<Ty>, Ty> : public redux_is_logical_supported<Ty> {};
|
| 84 |
+
#endif
|
| 85 |
+
|
| 86 |
+
template <class Ty, template <class> class TyOp>
|
| 87 |
+
using redux_op_supported = _redux_op_supported<
|
| 88 |
+
typename details::remove_qual<TyOp<Ty>>,
|
| 89 |
+
Ty>;
|
| 90 |
+
|
| 91 |
+
// Groups smaller than 16 actually have worse performance characteristics when used with redux
|
| 92 |
+
// tiles of size 16 and 32 perform the same or better and have better code generation profiles
|
| 93 |
+
template <class TyGroup> struct _redux_group_optimized : public _CG_STL_NAMESPACE::false_type {};
|
| 94 |
+
|
| 95 |
+
template <unsigned int Sz, typename TyPar>
|
| 96 |
+
struct _redux_group_optimized<cooperative_groups::thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::integral_constant<
|
| 97 |
+
bool,
|
| 98 |
+
(Sz >= 16)> {};
|
| 99 |
+
template <unsigned int Sz, typename TyPar>
|
| 100 |
+
struct _redux_group_optimized<internal_thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::integral_constant<
|
| 101 |
+
bool,
|
| 102 |
+
(Sz >= 16)> {};
|
| 103 |
+
template <>
|
| 104 |
+
struct _redux_group_optimized<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
|
| 105 |
+
|
| 106 |
+
template <typename TyGroup>
|
| 107 |
+
using redux_group_optimized = _redux_group_optimized<details::remove_qual<TyGroup>>;
|
| 108 |
+
|
| 109 |
+
template <template <class> class TyOp>
|
| 110 |
+
_CG_STATIC_QUALIFIER int pick_redux(int mask, int val);
|
| 111 |
+
template <template <class> class TyOp>
|
| 112 |
+
_CG_STATIC_QUALIFIER unsigned int pick_redux(int mask, unsigned int val);
|
| 113 |
+
|
| 114 |
+
#ifdef _CG_HAS_OP_REDUX
|
| 115 |
+
template <> _CG_QUALIFIER int pick_redux<cooperative_groups::plus>(int mask, int val) {
|
| 116 |
+
return __reduce_add_sync(mask, val);
|
| 117 |
+
}
|
| 118 |
+
template <> _CG_QUALIFIER int pick_redux<cooperative_groups::less>(int mask, int val) {
|
| 119 |
+
return __reduce_min_sync(mask, val);
|
| 120 |
+
}
|
| 121 |
+
template <> _CG_QUALIFIER int pick_redux<cooperative_groups::greater>(int mask, int val) {
|
| 122 |
+
return __reduce_max_sync(mask, val);
|
| 123 |
+
}
|
| 124 |
+
template <> _CG_QUALIFIER int pick_redux<cooperative_groups::bit_and>(int mask, int val) {
|
| 125 |
+
return __reduce_and_sync(mask, val);
|
| 126 |
+
}
|
| 127 |
+
template <> _CG_QUALIFIER int pick_redux<cooperative_groups::bit_xor>(int mask, int val) {
|
| 128 |
+
return __reduce_xor_sync(mask, val);
|
| 129 |
+
}
|
| 130 |
+
template <> _CG_QUALIFIER int pick_redux<cooperative_groups::bit_or>(int mask, int val) {
|
| 131 |
+
return __reduce_or_sync(mask, val);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::plus>(int mask, unsigned int val) {
|
| 135 |
+
return __reduce_add_sync(mask, val);
|
| 136 |
+
}
|
| 137 |
+
template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::less>(int mask, unsigned int val) {
|
| 138 |
+
return __reduce_min_sync(mask, val);
|
| 139 |
+
}
|
| 140 |
+
template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::greater>(int mask, unsigned int val) {
|
| 141 |
+
return __reduce_max_sync(mask, val);
|
| 142 |
+
}
|
| 143 |
+
template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::bit_and>(int mask, unsigned int val) {
|
| 144 |
+
return __reduce_and_sync(mask, val);
|
| 145 |
+
}
|
| 146 |
+
template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::bit_xor>(int mask, unsigned int val) {
|
| 147 |
+
return __reduce_xor_sync(mask, val);
|
| 148 |
+
}
|
| 149 |
+
template <> _CG_QUALIFIER unsigned int pick_redux<cooperative_groups::bit_or>(int mask, unsigned int val) {
|
| 150 |
+
return __reduce_or_sync(mask, val);
|
| 151 |
+
}
|
| 152 |
+
#endif
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
template <typename TyVal, bool = _CG_STL_NAMESPACE::is_unsigned<TyVal>::value>
|
| 156 |
+
struct _accelerated_op;
|
| 157 |
+
|
| 158 |
+
// Signed type redux intrinsic dispatch
|
| 159 |
+
template <typename TyVal>
|
| 160 |
+
struct _accelerated_op<TyVal, false> {
|
| 161 |
+
template <template <class> class TyOp>
|
| 162 |
+
_CG_STATIC_QUALIFIER TyVal redux(int mask, TyVal val) {
|
| 163 |
+
return static_cast<TyVal>(pick_redux<TyOp>(mask, static_cast<int>(val)));
|
| 164 |
+
}
|
| 165 |
+
};
|
| 166 |
+
|
| 167 |
+
// Unsigned type redux intrinsic dispatch
|
| 168 |
+
template <typename TyVal>
|
| 169 |
+
struct _accelerated_op<TyVal, true> {
|
| 170 |
+
template <template <class> class TyOp>
|
| 171 |
+
_CG_STATIC_QUALIFIER TyVal redux(int mask, TyVal val) {
|
| 172 |
+
return static_cast<TyVal>(pick_redux<TyOp>(mask, static_cast<unsigned int>(val)));
|
| 173 |
+
}
|
| 174 |
+
};
|
| 175 |
+
|
| 176 |
+
template <typename TyVal>
|
| 177 |
+
using accelerated_op = _accelerated_op<TyVal>;
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
template <typename TyVal, typename TyFnInput, typename TyGroup>
|
| 181 |
+
class _redux_dispatch {
|
| 182 |
+
template <class Ty, template <class> class TyOp>
|
| 183 |
+
using _redux_is_usable = _CG_STL_NAMESPACE::integral_constant<bool,
|
| 184 |
+
redux_op_supported<Ty, TyOp>::value &&
|
| 185 |
+
redux_group_optimized<TyGroup>::value>;
|
| 186 |
+
|
| 187 |
+
template <class Ty, template <class> class TyOp>
|
| 188 |
+
using redux_is_usable = typename _CG_STL_NAMESPACE::enable_if<_redux_is_usable<Ty, TyOp>::value, void>::type*;
|
| 189 |
+
|
| 190 |
+
template <class Ty, template <class> class TyOp>
|
| 191 |
+
using redux_is_not_usable = typename _CG_STL_NAMESPACE::enable_if<!_redux_is_usable<Ty, TyOp>::value, void>::type*;
|
| 192 |
+
|
| 193 |
+
public:
|
| 194 |
+
// Dispatch to redux if the combination of op and args are supported
|
| 195 |
+
template<
|
| 196 |
+
template <class> class TyOp,
|
| 197 |
+
redux_is_usable<TyFnInput, TyOp> = nullptr>
|
| 198 |
+
_CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>&& op) -> decltype(op(val, val)) {
|
| 199 |
+
// Retrieve the mask for the group and dispatch to redux
|
| 200 |
+
return accelerated_op<TyFnInput>::template redux<TyOp>(_coalesced_group_data_access::get_mask(group), _CG_STL_NAMESPACE::forward<TyVal>(val));
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
template<
|
| 204 |
+
template <class> class TyOp,
|
| 205 |
+
redux_is_usable<TyFnInput, TyOp> = nullptr>
|
| 206 |
+
_CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>& op) -> decltype(op(val, val)) {
|
| 207 |
+
// Retrieve the mask for the group and dispatch to redux
|
| 208 |
+
return accelerated_op<TyFnInput>::template redux<TyOp>(_coalesced_group_data_access::get_mask(group), _CG_STL_NAMESPACE::forward<TyVal>(val));
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
// Fallback shuffle sync reduction
|
| 212 |
+
template <
|
| 213 |
+
template <class> class TyOp,
|
| 214 |
+
redux_is_not_usable<TyFnInput, TyOp> = nullptr>
|
| 215 |
+
_CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>&& op) -> decltype(op(val, val)) {
|
| 216 |
+
//Dispatch to fallback shuffle sync accelerated reduction
|
| 217 |
+
return coalesced_reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp<TyFnInput>>(op));
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
};
|
| 221 |
+
|
| 222 |
+
// Group support for reduce.
|
| 223 |
+
template <class TyGroup> struct _reduce_group_supported : public _CG_STL_NAMESPACE::false_type {};
|
| 224 |
+
|
| 225 |
+
template <unsigned int Sz, typename TyPar>
|
| 226 |
+
struct _reduce_group_supported<cooperative_groups::thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
|
| 227 |
+
template <unsigned int Sz, typename TyPar>
|
| 228 |
+
struct _reduce_group_supported<internal_thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
|
| 229 |
+
template <>
|
| 230 |
+
struct _reduce_group_supported<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
|
| 231 |
+
|
| 232 |
+
template <typename TyGroup>
|
| 233 |
+
using reduce_group_supported = _reduce_group_supported<details::remove_qual<TyGroup>>;
|
| 234 |
+
|
| 235 |
+
template <typename TyVal, typename TyFnInput, template <class> class TyOp, typename TyGroup>
|
| 236 |
+
_CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>&& op) -> decltype(op(val, val)) {
|
| 237 |
+
static_assert(details::is_op_type_same<TyFnInput, TyVal>::value, "Operator and argument types differ");
|
| 238 |
+
|
| 239 |
+
using dispatch = details::_redux_dispatch<TyVal, TyFnInput, TyGroup>;
|
| 240 |
+
return dispatch::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp<TyFnInput>>(op));
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
template <typename TyVal, typename TyFnInput, template <class> class TyOp, typename TyGroup>
|
| 244 |
+
_CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp<TyFnInput>& op) -> decltype(op(val, val)) {
|
| 245 |
+
static_assert(details::is_op_type_same<TyFnInput, TyVal>::value, "Operator and argument types differ");
|
| 246 |
+
|
| 247 |
+
using dispatch = details::_redux_dispatch<TyVal, TyFnInput, TyGroup>;
|
| 248 |
+
return dispatch::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp<TyFnInput>>(op));
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
template <typename TyVal, typename TyOp, typename TyGroup>
|
| 253 |
+
_CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) {
|
| 254 |
+
return details::coalesced_reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyOp>(op));
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
template <unsigned int GroupId>
|
| 258 |
+
struct tile_reduce_dispatch;
|
| 259 |
+
|
| 260 |
+
template <>
|
| 261 |
+
struct tile_reduce_dispatch<details::coalesced_group_id> {
|
| 262 |
+
template <typename TyGroup, typename TyVal, typename TyFn>
|
| 263 |
+
_CG_STATIC_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 264 |
+
return details::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 265 |
+
}
|
| 266 |
+
};
|
| 267 |
+
|
| 268 |
+
#if defined(_CG_CPP11_FEATURES)
|
| 269 |
+
template <>
|
| 270 |
+
struct tile_reduce_dispatch<details::multi_tile_group_id> {
|
| 271 |
+
template <unsigned int Size, typename ParentT, typename TyVal, typename TyFn>
|
| 272 |
+
_CG_STATIC_QUALIFIER auto reduce(const thread_block_tile<Size, ParentT>& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 273 |
+
using warpType = details::internal_thread_block_tile<32, __static_size_multi_warp_tile_base<Size>>;
|
| 274 |
+
using TyRet = details::remove_qual<TyVal>;
|
| 275 |
+
const unsigned int num_warps = Size / 32;
|
| 276 |
+
|
| 277 |
+
auto warp_lambda = [&] (const warpType& warp, TyRet* warp_scratch_location) {
|
| 278 |
+
*warp_scratch_location =
|
| 279 |
+
details::reduce(warp, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
|
| 280 |
+
};
|
| 281 |
+
auto inter_warp_lambda =
|
| 282 |
+
[&] (const details::internal_thread_block_tile<num_warps, warpType>& subwarp, TyRet* thread_scratch_location) {
|
| 283 |
+
*thread_scratch_location =
|
| 284 |
+
details::reduce(subwarp, *thread_scratch_location, _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 285 |
+
};
|
| 286 |
+
return details::multi_warp_collectives_helper<TyRet>(group, warp_lambda, inter_warp_lambda);
|
| 287 |
+
}
|
| 288 |
+
};
|
| 289 |
+
|
| 290 |
+
template <unsigned int GroupId>
|
| 291 |
+
struct tile_async_reduce_dispatch;
|
| 292 |
+
|
| 293 |
+
template <>
|
| 294 |
+
struct tile_async_reduce_dispatch<details::coalesced_group_id> {
|
| 295 |
+
template <typename GroupT, typename TyDst, typename TyVal, typename TyFn, typename TyResHandler>
|
| 296 |
+
_CG_STATIC_QUALIFIER void reduce(const GroupT& group, TyDst& dst, TyVal&& val, TyFn&& op, TyResHandler& res_handler) {
|
| 297 |
+
// Do regular, in group reduction
|
| 298 |
+
auto result = details::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
|
| 299 |
+
|
| 300 |
+
// One thread stores/updates the destination
|
| 301 |
+
if (group.thread_rank() == 0) {
|
| 302 |
+
res_handler(result);
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
};
|
| 306 |
+
|
| 307 |
+
template <>
|
| 308 |
+
struct tile_async_reduce_dispatch<details::multi_tile_group_id> {
|
| 309 |
+
template <unsigned int TySize, typename ParentT, typename TyDst, typename TyInputVal, typename TyFn, typename TyResHandler>
|
| 310 |
+
_CG_STATIC_QUALIFIER void reduce(const thread_block_tile<TySize, ParentT>& group, TyDst& dst, TyInputVal&& val, TyFn&& op, TyResHandler& res_handler) {
|
| 311 |
+
using TyVal = remove_qual<TyInputVal>;
|
| 312 |
+
const unsigned int num_warps = TySize / 32;
|
| 313 |
+
details::barrier_t* sync_location = multi_warp_sync_location_getter(group);
|
| 314 |
+
auto warp_scratch_location = multi_warp_scratch_location_getter<TyVal>(group, group.thread_rank() / 32);
|
| 315 |
+
|
| 316 |
+
// Do in warp reduce
|
| 317 |
+
auto warp = details::tiled_partition_internal<32, thread_block_tile<TySize, ParentT>>();
|
| 318 |
+
*warp_scratch_location = details::reduce(warp, _CG_STL_NAMESPACE::forward<TyInputVal>(val), op);
|
| 319 |
+
|
| 320 |
+
// Tile of size num_warps from the last warp to arrive does final reduction step
|
| 321 |
+
if (details::sync_warps_last_releases(sync_location, details::cta::thread_rank(), num_warps)) {
|
| 322 |
+
auto subwarp = details::tiled_partition_internal<num_warps, decltype(warp)>();
|
| 323 |
+
if (subwarp.meta_group_rank() == 0) {
|
| 324 |
+
auto thread_scratch_location = multi_warp_scratch_location_getter<TyVal>(group, subwarp.thread_rank());
|
| 325 |
+
auto thread_val = *thread_scratch_location;
|
| 326 |
+
// Release other warps, we read their contribution already.
|
| 327 |
+
subwarp.sync();
|
| 328 |
+
details::sync_warps_release(sync_location, subwarp.thread_rank() == 0, details::cta::thread_rank(), num_warps);
|
| 329 |
+
TyVal result = details::reduce(subwarp, thread_val, op);
|
| 330 |
+
// One thread stores the result or updates the atomic
|
| 331 |
+
if (subwarp.thread_rank() == 0) {
|
| 332 |
+
res_handler(result);
|
| 333 |
+
}
|
| 334 |
+
}
|
| 335 |
+
warp.sync();
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
};
|
| 339 |
+
#endif
|
| 340 |
+
|
| 341 |
+
template <typename TyGroup, typename TyInputVal, typename TyRetVal>
|
| 342 |
+
_CG_QUALIFIER void check_reduce_params() {
|
| 343 |
+
static_assert(details::is_op_type_same<TyInputVal, TyRetVal>::value, "Operator input and output types differ");
|
| 344 |
+
static_assert(details::reduce_group_supported<TyGroup>::value, "This group does not exclusively represent a tile");
|
| 345 |
+
};
|
| 346 |
+
|
| 347 |
+
template <typename TyGroup, typename TyDstVal, typename TyInputVal, typename TyRetVal>
|
| 348 |
+
_CG_QUALIFIER void check_async_reduce_params() {
|
| 349 |
+
check_reduce_params<TyGroup, TyInputVal, TyRetVal>();
|
| 350 |
+
static_assert(details::is_op_type_same<TyDstVal, TyInputVal>::value, "Destination and input types differ");
|
| 351 |
+
}
|
| 352 |
+
} // details
|
| 353 |
+
|
| 354 |
+
template <typename TyGroup, typename TyVal, typename TyFn>
|
| 355 |
+
_CG_QUALIFIER auto reduce(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 356 |
+
details::check_reduce_params<TyGroup, details::remove_qual<TyVal>, decltype(op(val, val))>();
|
| 357 |
+
|
| 358 |
+
using dispatch = details::tile_reduce_dispatch<TyGroup::_group_id>;
|
| 359 |
+
return dispatch::reduce(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
#if defined(_CG_CPP11_FEATURES)
|
| 363 |
+
|
| 364 |
+
# if defined(_CG_HAS_STL_ATOMICS)
|
| 365 |
+
template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
|
| 366 |
+
void _CG_QUALIFIER reduce_update_async(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
|
| 367 |
+
details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 368 |
+
auto update_lambda = [&] (TyVal& result) {
|
| 369 |
+
details::atomic_update(dst, result, op);
|
| 370 |
+
};
|
| 371 |
+
using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
|
| 372 |
+
dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), update_lambda);
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
|
| 376 |
+
void _CG_QUALIFIER reduce_update_async(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
|
| 377 |
+
details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 378 |
+
auto update_lambda = [&] (TyVal& result) {
|
| 379 |
+
details::atomic_update(dst, result, op);
|
| 380 |
+
};
|
| 381 |
+
using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
|
| 382 |
+
dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), update_lambda);
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
|
| 386 |
+
void _CG_QUALIFIER reduce_store_async(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
|
| 387 |
+
details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 388 |
+
auto store_lambda = [&] (TyVal& result) {
|
| 389 |
+
details::atomic_store(dst, result);
|
| 390 |
+
};
|
| 391 |
+
using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
|
| 392 |
+
dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), store_lambda);
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
template<typename TyGroup, typename TyVal, cuda::thread_scope Sco, typename TyInputVal, typename TyFn>
|
| 396 |
+
void _CG_QUALIFIER reduce_store_async(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) {
|
| 397 |
+
details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 398 |
+
auto store_lambda = [&] (TyVal& result) {
|
| 399 |
+
details::atomic_store(dst, result);
|
| 400 |
+
};
|
| 401 |
+
using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
|
| 402 |
+
dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), store_lambda);
|
| 403 |
+
}
|
| 404 |
+
# endif
|
| 405 |
+
|
| 406 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, typename TyFn>
|
| 407 |
+
void _CG_QUALIFIER reduce_store_async(const TyGroup& group, TyVal* dst, TyInputVal&& val, TyFn&& op) {
|
| 408 |
+
details::check_async_reduce_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 409 |
+
auto store_lambda = [&] (TyVal& result) {
|
| 410 |
+
*dst = result;
|
| 411 |
+
};
|
| 412 |
+
using dispatch = details::tile_async_reduce_dispatch<TyGroup::_group_id>;
|
| 413 |
+
dispatch::reduce(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op), store_lambda);
|
| 414 |
+
}
|
| 415 |
+
#endif
|
| 416 |
+
|
| 417 |
+
_CG_END_NAMESPACE
|
| 418 |
+
|
| 419 |
+
#endif // _CG_REDUCE_H_
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/scan.h
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef _CG_SCAN_H_
|
| 50 |
+
#define _CG_SCAN_H_
|
| 51 |
+
|
| 52 |
+
#include "info.h"
|
| 53 |
+
#include "helpers.h"
|
| 54 |
+
#include "functional.h"
|
| 55 |
+
#include "coalesced_scan.h"
|
| 56 |
+
|
| 57 |
+
_CG_BEGIN_NAMESPACE
|
| 58 |
+
|
| 59 |
+
namespace details {
|
| 60 |
+
|
| 61 |
+
// Group support for scan.
|
| 62 |
+
template <class TyGroup> struct _scan_group_supported : public _CG_STL_NAMESPACE::false_type {};
|
| 63 |
+
|
| 64 |
+
template <unsigned int Sz, typename TyPar>
|
| 65 |
+
struct _scan_group_supported<cooperative_groups::thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
|
| 66 |
+
template <unsigned int Sz, typename TyPar>
|
| 67 |
+
struct _scan_group_supported<internal_thread_block_tile<Sz, TyPar>> : public _CG_STL_NAMESPACE::true_type {};
|
| 68 |
+
template <>
|
| 69 |
+
struct _scan_group_supported<cooperative_groups::coalesced_group> : public _CG_STL_NAMESPACE::true_type {};
|
| 70 |
+
|
| 71 |
+
template <typename TyGroup>
|
| 72 |
+
using scan_group_supported = _scan_group_supported<details::remove_qual<TyGroup>>;
|
| 73 |
+
|
| 74 |
+
template <bool IsIntegralPlus>
|
| 75 |
+
struct integral_optimized_scan;
|
| 76 |
+
|
| 77 |
+
enum class ScanType { exclusive, inclusive };
|
| 78 |
+
|
| 79 |
+
template <unsigned int GroupId, ScanType TyScan>
|
| 80 |
+
struct scan_dispatch;
|
| 81 |
+
|
| 82 |
+
template <ScanType TyScan>
|
| 83 |
+
struct scan_dispatch<details::coalesced_group_id, TyScan> {
|
| 84 |
+
template <typename TyGroup, typename TyVal, typename TyFn>
|
| 85 |
+
_CG_STATIC_QUALIFIER auto scan(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 86 |
+
auto scan_result = coalesced_inclusive_scan(group, val, op);
|
| 87 |
+
if (TyScan == ScanType::exclusive) {
|
| 88 |
+
scan_result = convert_inclusive_to_exclusive(group,
|
| 89 |
+
scan_result,
|
| 90 |
+
_CG_STL_NAMESPACE::forward<TyVal>(val),
|
| 91 |
+
_CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 92 |
+
}
|
| 93 |
+
return scan_result;
|
| 94 |
+
}
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
#if defined(_CG_CPP11_FEATURES)
|
| 98 |
+
template <ScanType TyScan>
|
| 99 |
+
struct scan_dispatch<details::multi_tile_group_id, TyScan> {
|
| 100 |
+
template <unsigned int Size, typename ParentT, typename TyVal, typename TyFn>
|
| 101 |
+
_CG_STATIC_QUALIFIER auto scan(const thread_block_tile<Size, ParentT>& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 102 |
+
using warpType = details::internal_thread_block_tile<32, __static_size_multi_warp_tile_base<Size>>;
|
| 103 |
+
using TyRet = details::remove_qual<TyVal>;
|
| 104 |
+
const unsigned int num_warps = Size / 32;
|
| 105 |
+
// In warp scan result, calculated in warp_lambda
|
| 106 |
+
TyRet warp_scan;
|
| 107 |
+
|
| 108 |
+
// In warp scan, put sum in the warp_scratch_location
|
| 109 |
+
auto warp_lambda = [&] (const warpType& warp, TyRet* warp_scratch_location) {
|
| 110 |
+
warp_scan =
|
| 111 |
+
details::coalesced_inclusive_scan(warp, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
|
| 112 |
+
if (warp.thread_rank() + 1 == warp.size()) {
|
| 113 |
+
*warp_scratch_location = warp_scan;
|
| 114 |
+
}
|
| 115 |
+
if (TyScan == ScanType::exclusive) {
|
| 116 |
+
warp_scan = warp.shfl_up(warp_scan, 1);
|
| 117 |
+
}
|
| 118 |
+
};
|
| 119 |
+
|
| 120 |
+
// Tile of size num_warps performing the final scan part (exclusive scan of warp sums), other threads will add it
|
| 121 |
+
// to its in-warp scan result
|
| 122 |
+
auto inter_warp_lambda =
|
| 123 |
+
[&] (const details::internal_thread_block_tile<num_warps, warpType>& subwarp, TyRet* thread_scratch_location) {
|
| 124 |
+
auto thread_val = *thread_scratch_location;
|
| 125 |
+
auto result = coalesced_inclusive_scan(subwarp, thread_val, op);
|
| 126 |
+
*thread_scratch_location = convert_inclusive_to_exclusive(subwarp, result, thread_val, op);
|
| 127 |
+
};
|
| 128 |
+
|
| 129 |
+
TyRet previous_warps_sum = details::multi_warp_collectives_helper<TyRet>(group, warp_lambda, inter_warp_lambda);
|
| 130 |
+
if (TyScan == ScanType::exclusive && warpType::thread_rank() == 0) {
|
| 131 |
+
return previous_warps_sum;
|
| 132 |
+
}
|
| 133 |
+
if (warpType::meta_group_rank() == 0) {
|
| 134 |
+
return warp_scan;
|
| 135 |
+
}
|
| 136 |
+
else {
|
| 137 |
+
return op(warp_scan, previous_warps_sum);
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
};
|
| 141 |
+
|
| 142 |
+
#if defined(_CG_HAS_STL_ATOMICS)
|
| 143 |
+
template <unsigned int GroupId, ScanType TyScan>
|
| 144 |
+
struct scan_update_dispatch;
|
| 145 |
+
|
| 146 |
+
template <ScanType TyScan>
|
| 147 |
+
struct scan_update_dispatch<details::coalesced_group_id, TyScan> {
|
| 148 |
+
template <typename TyGroup, typename TyAtomic, typename TyVal, typename TyFn>
|
| 149 |
+
_CG_STATIC_QUALIFIER auto scan(const TyGroup& group, TyAtomic& dst, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 150 |
+
details::remove_qual<TyVal> old;
|
| 151 |
+
|
| 152 |
+
// Do regular in group scan
|
| 153 |
+
auto scan_result = details::coalesced_inclusive_scan(group, val, op);
|
| 154 |
+
|
| 155 |
+
// Last thread updates the atomic and distributes its old value to other threads
|
| 156 |
+
if (group.thread_rank() == group.size() - 1) {
|
| 157 |
+
old = atomic_update(dst, scan_result, _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 158 |
+
}
|
| 159 |
+
old = group.shfl(old, group.size() - 1);
|
| 160 |
+
if (TyScan == ScanType::exclusive) {
|
| 161 |
+
scan_result = convert_inclusive_to_exclusive(group, scan_result, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
|
| 162 |
+
}
|
| 163 |
+
scan_result = op(old, scan_result);
|
| 164 |
+
return scan_result;
|
| 165 |
+
}
|
| 166 |
+
};
|
| 167 |
+
|
| 168 |
+
template <ScanType TyScan>
|
| 169 |
+
struct scan_update_dispatch<details::multi_tile_group_id, TyScan> {
|
| 170 |
+
template <unsigned int Size, typename ParentT, typename TyAtomic, typename TyVal, typename TyFn>
|
| 171 |
+
_CG_STATIC_QUALIFIER auto scan(const thread_block_tile<Size, ParentT>& group, TyAtomic& dst, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 172 |
+
using warpType = details::internal_thread_block_tile<32, __static_size_multi_warp_tile_base<Size>>;
|
| 173 |
+
using TyRet = details::remove_qual<TyVal>;
|
| 174 |
+
const unsigned int num_warps = Size / 32;
|
| 175 |
+
// In warp scan result, calculated in warp_lambda
|
| 176 |
+
TyRet warp_scan;
|
| 177 |
+
|
| 178 |
+
// In warp scan, put sum in the warp_scratch_location
|
| 179 |
+
auto warp_lambda = [&] (const warpType& warp, TyRet* warp_scratch_location) {
|
| 180 |
+
warp_scan =
|
| 181 |
+
details::coalesced_inclusive_scan(warp, _CG_STL_NAMESPACE::forward<TyVal>(val), op);
|
| 182 |
+
if (warp.thread_rank() + 1 == warp.size()) {
|
| 183 |
+
*warp_scratch_location = warp_scan;
|
| 184 |
+
}
|
| 185 |
+
if (TyScan == ScanType::exclusive) {
|
| 186 |
+
warp_scan = warp.shfl_up(warp_scan, 1);
|
| 187 |
+
}
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
// Tile of size num_warps performing the final scan part (exclusive scan of warp sums), other threads will add it
|
| 191 |
+
// to its in-warp scan result
|
| 192 |
+
auto inter_warp_lambda =
|
| 193 |
+
[&] (const details::internal_thread_block_tile<num_warps, warpType>& subwarp, TyRet* thread_scratch_location) {
|
| 194 |
+
auto thread_val = *thread_scratch_location;
|
| 195 |
+
auto scan_result = details::coalesced_inclusive_scan(subwarp, thread_val, op);
|
| 196 |
+
TyRet offset;
|
| 197 |
+
// Single thread does the atomic update with sum of all contributions and reads the old value.
|
| 198 |
+
if (subwarp.thread_rank() == subwarp.size() - 1) {
|
| 199 |
+
offset = details::atomic_update(dst, scan_result, op);
|
| 200 |
+
}
|
| 201 |
+
offset = subwarp.shfl(offset, subwarp.size() - 1);
|
| 202 |
+
scan_result = convert_inclusive_to_exclusive(subwarp, scan_result, thread_val, op);
|
| 203 |
+
// Add offset read from the atomic to the scanned warp sum.
|
| 204 |
+
// Skipping first thread, since it got defautly constructed value from the conversion,
|
| 205 |
+
// it should just return the offset received from the thread that did the atomic update.
|
| 206 |
+
if (subwarp.thread_rank() != 0) {
|
| 207 |
+
offset = op(scan_result, offset);
|
| 208 |
+
}
|
| 209 |
+
*thread_scratch_location = offset;
|
| 210 |
+
};
|
| 211 |
+
|
| 212 |
+
TyRet previous_warps_sum = details::multi_warp_collectives_helper<TyRet>(group, warp_lambda, inter_warp_lambda);
|
| 213 |
+
if (TyScan == ScanType::exclusive && warpType::thread_rank() == 0) {
|
| 214 |
+
return previous_warps_sum;
|
| 215 |
+
}
|
| 216 |
+
return op(warp_scan, previous_warps_sum);
|
| 217 |
+
}
|
| 218 |
+
};
|
| 219 |
+
#endif
|
| 220 |
+
#endif
|
| 221 |
+
|
| 222 |
+
template <typename TyGroup, typename TyInputVal, typename TyRetVal>
|
| 223 |
+
_CG_QUALIFIER void check_scan_params() {
|
| 224 |
+
static_assert(details::is_op_type_same<TyInputVal, TyRetVal>::value, "Operator input and output types differ");
|
| 225 |
+
static_assert(details::scan_group_supported<TyGroup>::value, "This group does not exclusively represent a tile");
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
#if defined(_CG_HAS_STL_ATOMICS)
|
| 229 |
+
template <typename TyGroup, typename TyDstVal, typename TyInputVal, typename TyRetVal>
|
| 230 |
+
_CG_QUALIFIER void check_scan_update_params() {
|
| 231 |
+
check_scan_params<TyGroup, TyInputVal, TyRetVal>();
|
| 232 |
+
static_assert(details::is_op_type_same<TyDstVal, TyInputVal>::value, "Destination and input types differ");
|
| 233 |
+
}
|
| 234 |
+
#endif
|
| 235 |
+
|
| 236 |
+
} // details
|
| 237 |
+
|
| 238 |
+
template <typename TyGroup, typename TyVal, typename TyFn>
|
| 239 |
+
_CG_QUALIFIER auto inclusive_scan(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 240 |
+
details::check_scan_params<TyGroup, TyVal, decltype(op(val, val))>();
|
| 241 |
+
|
| 242 |
+
using dispatch = details::scan_dispatch<TyGroup::_group_id, details::ScanType::inclusive>;
|
| 243 |
+
return dispatch::scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
template <typename TyGroup, typename TyVal>
|
| 247 |
+
_CG_QUALIFIER details::remove_qual<TyVal> inclusive_scan(const TyGroup& group, TyVal&& val) {
|
| 248 |
+
return inclusive_scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), cooperative_groups::plus<details::remove_qual<TyVal>>());
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
template <typename TyGroup, typename TyVal, typename TyFn>
|
| 252 |
+
_CG_QUALIFIER auto exclusive_scan(const TyGroup& group, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 253 |
+
details::check_scan_params<TyGroup, TyVal, decltype(op(val, val))>();
|
| 254 |
+
|
| 255 |
+
using dispatch = details::scan_dispatch<TyGroup::_group_id, details::ScanType::exclusive>;
|
| 256 |
+
return dispatch::scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
template <typename TyGroup, typename TyVal>
|
| 260 |
+
_CG_QUALIFIER details::remove_qual<TyVal> exclusive_scan(const TyGroup& group, TyVal&& val) {
|
| 261 |
+
return exclusive_scan(group, _CG_STL_NAMESPACE::forward<TyVal>(val), cooperative_groups::plus<details::remove_qual<TyVal>>());
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
#if defined(_CG_HAS_STL_ATOMICS)
|
| 265 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
|
| 266 |
+
_CG_QUALIFIER auto inclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 267 |
+
details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 268 |
+
|
| 269 |
+
using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::inclusive>;
|
| 270 |
+
return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
|
| 274 |
+
_CG_QUALIFIER TyVal inclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco> & dst, TyInputVal&& val) {
|
| 275 |
+
return inclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
|
| 279 |
+
_CG_QUALIFIER auto exclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 280 |
+
details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 281 |
+
|
| 282 |
+
using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::exclusive>;
|
| 283 |
+
return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
|
| 287 |
+
_CG_QUALIFIER TyVal exclusive_scan_update(const TyGroup& group, cuda::atomic<TyVal, Sco>& dst, TyInputVal&& val) {
|
| 288 |
+
return exclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
|
| 292 |
+
_CG_QUALIFIER auto inclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 293 |
+
details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 294 |
+
|
| 295 |
+
using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::inclusive>;
|
| 296 |
+
return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
|
| 300 |
+
_CG_QUALIFIER TyVal inclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco> & dst, TyInputVal&& val) {
|
| 301 |
+
return inclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco, typename TyFn>
|
| 305 |
+
_CG_QUALIFIER auto exclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val, TyFn&& op) -> decltype(op(val, val)) {
|
| 306 |
+
details::check_scan_update_params<TyGroup, TyVal, details::remove_qual<TyInputVal>, decltype(op(val, val))>();
|
| 307 |
+
|
| 308 |
+
using dispatch = details::scan_update_dispatch<TyGroup::_group_id, details::ScanType::exclusive>;
|
| 309 |
+
return dispatch::scan(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), _CG_STL_NAMESPACE::forward<TyFn>(op));
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
template<typename TyGroup, typename TyVal, typename TyInputVal, cuda::thread_scope Sco>
|
| 313 |
+
_CG_QUALIFIER TyVal exclusive_scan_update(const TyGroup& group, const cuda::atomic_ref<TyVal, Sco>& dst, TyInputVal&& val) {
|
| 314 |
+
return exclusive_scan_update(group, dst, _CG_STL_NAMESPACE::forward<TyInputVal>(val), cooperative_groups::plus<TyVal>());
|
| 315 |
+
}
|
| 316 |
+
#endif
|
| 317 |
+
|
| 318 |
+
_CG_END_NAMESPACE
|
| 319 |
+
|
| 320 |
+
#endif // _CG_SCAN_H_
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/sync.h
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef _CG_GRID_H
|
| 50 |
+
#define _CG_GRID_H
|
| 51 |
+
|
| 52 |
+
#include "info.h"
|
| 53 |
+
|
| 54 |
+
_CG_BEGIN_NAMESPACE
|
| 55 |
+
|
| 56 |
+
namespace details
|
| 57 |
+
{
|
| 58 |
+
|
| 59 |
+
typedef unsigned int barrier_t;
|
| 60 |
+
|
| 61 |
+
_CG_STATIC_QUALIFIER bool bar_has_flipped(unsigned int old_arrive, unsigned int current_arrive) {
|
| 62 |
+
return (((old_arrive ^ current_arrive) & 0x80000000) != 0);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
_CG_STATIC_QUALIFIER bool is_cta_master() {
|
| 66 |
+
return (threadIdx.x + threadIdx.y + threadIdx.z == 0);
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
_CG_STATIC_QUALIFIER unsigned int sync_grids_arrive(volatile barrier_t *arrived) {
|
| 70 |
+
unsigned int oldArrive = 0;
|
| 71 |
+
|
| 72 |
+
__barrier_sync(0);
|
| 73 |
+
|
| 74 |
+
if (is_cta_master()) {
|
| 75 |
+
unsigned int expected = gridDim.x * gridDim.y * gridDim.z;
|
| 76 |
+
bool gpu_master = (blockIdx.x + blockIdx.y + blockIdx.z == 0);
|
| 77 |
+
unsigned int nb = 1;
|
| 78 |
+
|
| 79 |
+
if (gpu_master) {
|
| 80 |
+
nb = 0x80000000 - (expected - 1);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
#if __CUDA_ARCH__ < 700
|
| 84 |
+
// Fence; barrier update; volatile polling; fence
|
| 85 |
+
__threadfence();
|
| 86 |
+
|
| 87 |
+
oldArrive = atomicAdd((unsigned int*)arrived, nb);
|
| 88 |
+
#else
|
| 89 |
+
// Barrier update with release; polling with acquire
|
| 90 |
+
asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb) : "memory");
|
| 91 |
+
#endif
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
return oldArrive;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
_CG_STATIC_QUALIFIER void sync_grids_wait(unsigned int oldArrive, volatile barrier_t *arrived) {
|
| 99 |
+
if (is_cta_master()) {
|
| 100 |
+
#if __CUDA_ARCH__ < 700
|
| 101 |
+
while (!bar_has_flipped(oldArrive, *arrived));
|
| 102 |
+
|
| 103 |
+
__threadfence();
|
| 104 |
+
|
| 105 |
+
#else
|
| 106 |
+
unsigned int current_arrive;
|
| 107 |
+
do {
|
| 108 |
+
asm volatile("ld.acquire.gpu.u32 %0,[%1];" : "=r"(current_arrive) : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived) : "memory");
|
| 109 |
+
} while (!bar_has_flipped(oldArrive, current_arrive));
|
| 110 |
+
#endif
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
__barrier_sync(0);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
/* - Multi warp groups synchronization routines - */
|
| 117 |
+
|
| 118 |
+
// Need both acquire and release for the last warp, since it won't be able to acquire with red.and
|
| 119 |
+
_CG_STATIC_QUALIFIER unsigned int atom_or_acq_rel_cta(unsigned int *addr, unsigned int val) {
|
| 120 |
+
unsigned int old;
|
| 121 |
+
#if __CUDA_ARCH__ < 700
|
| 122 |
+
__threadfence_block();
|
| 123 |
+
old = atomicOr(addr, val);
|
| 124 |
+
#else
|
| 125 |
+
asm volatile("atom.or.acq_rel.cta.b32 %0,[%1],%2;" : "=r"(old) : _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
|
| 126 |
+
#endif
|
| 127 |
+
return old;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
// Special case where barrier is arrived, but not waited on
|
| 131 |
+
_CG_STATIC_QUALIFIER void red_or_release_cta(unsigned int *addr, unsigned int val) {
|
| 132 |
+
#if __CUDA_ARCH__ < 700
|
| 133 |
+
__threadfence_block();
|
| 134 |
+
atomicOr(addr, val);
|
| 135 |
+
#else
|
| 136 |
+
asm volatile("red.or.release.cta.b32 [%0],%1;" :: _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
|
| 137 |
+
#endif
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Usually called by last arriving warp to released other warps, can be relaxed, since or was already acq_rel
|
| 141 |
+
_CG_STATIC_QUALIFIER void red_and_relaxed_cta(unsigned int *addr, unsigned int val) {
|
| 142 |
+
#if __CUDA_ARCH__ < 700
|
| 143 |
+
atomicAnd(addr, val);
|
| 144 |
+
#else
|
| 145 |
+
asm volatile("red.and.relaxed.cta.b32 [%0],%1;" :: _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
|
| 146 |
+
#endif
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// Special case of release, where last warp was doing extra work before releasing others, need to be release
|
| 150 |
+
// to ensure that extra work is visible
|
| 151 |
+
_CG_STATIC_QUALIFIER void red_and_release_cta(unsigned int *addr, unsigned int val) {
|
| 152 |
+
#if __CUDA_ARCH__ < 700
|
| 153 |
+
__threadfence_block();
|
| 154 |
+
atomicAnd(addr, val);
|
| 155 |
+
#else
|
| 156 |
+
asm volatile("red.and.release.cta.b32 [%0],%1;" :: _CG_ASM_PTR_CONSTRAINT(addr), "r"(val) : "memory");
|
| 157 |
+
#endif
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
// Read the barrier, acquire to ensure all memory operations following the sync are correctly performed after it is released
|
| 161 |
+
_CG_STATIC_QUALIFIER unsigned int ld_acquire_cta(unsigned int *addr) {
|
| 162 |
+
unsigned int val;
|
| 163 |
+
#if __CUDA_ARCH__ < 700
|
| 164 |
+
val = *((volatile unsigned int*) addr);
|
| 165 |
+
__threadfence_block();
|
| 166 |
+
#else
|
| 167 |
+
asm volatile("ld.acquire.cta.u32 %0,[%1];" : "=r"(val) : _CG_ASM_PTR_CONSTRAINT(addr) : "memory");
|
| 168 |
+
#endif
|
| 169 |
+
return val;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
// Get synchronization bit mask of my thread_block_tile of size num_warps. Thread ranks 0..31 have the first bit assigned to them,
|
| 173 |
+
// thread ranks 32..63 second etc
|
| 174 |
+
// Bit masks are unique for each group, groups of the same size will have the same number of bits set, but on different positions
|
| 175 |
+
_CG_STATIC_QUALIFIER unsigned int get_group_mask(unsigned int thread_rank, unsigned int num_warps) {
|
| 176 |
+
return num_warps == 32 ? ~0 : ((1 << num_warps) - 1) << (num_warps * (thread_rank / (num_warps * 32)));
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
_CG_STATIC_QUALIFIER void barrier_wait(barrier_t *arrived, unsigned int warp_bit) {
|
| 180 |
+
while(ld_acquire_cta(arrived) & warp_bit);
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// Default blocking sync.
|
| 184 |
+
_CG_STATIC_QUALIFIER void sync_warps(barrier_t *arrived, unsigned int thread_rank, unsigned int num_warps) {
|
| 185 |
+
unsigned int warp_id = thread_rank / 32;
|
| 186 |
+
bool warp_master = (thread_rank % 32 == 0);
|
| 187 |
+
unsigned int warp_bit = 1 << warp_id;
|
| 188 |
+
unsigned int group_mask = get_group_mask(thread_rank, num_warps);
|
| 189 |
+
|
| 190 |
+
__syncwarp(0xFFFFFFFF);
|
| 191 |
+
|
| 192 |
+
if (warp_master) {
|
| 193 |
+
unsigned int old = atom_or_acq_rel_cta(arrived, warp_bit);
|
| 194 |
+
if (((old | warp_bit) & group_mask) == group_mask) {
|
| 195 |
+
red_and_relaxed_cta(arrived, ~group_mask);
|
| 196 |
+
}
|
| 197 |
+
else {
|
| 198 |
+
barrier_wait(arrived, warp_bit);
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
__syncwarp(0xFFFFFFFF);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// Blocking sync, except the last arriving warp, that releases other warps, returns to do other stuff first.
|
| 206 |
+
// Warp returning true from this function needs to call sync_warps_release.
|
| 207 |
+
_CG_STATIC_QUALIFIER bool sync_warps_last_releases(barrier_t *arrived, unsigned int thread_rank, unsigned int num_warps) {
|
| 208 |
+
unsigned int warp_id = thread_rank / 32;
|
| 209 |
+
bool warp_master = (thread_rank % 32 == 0);
|
| 210 |
+
unsigned int warp_bit = 1 << warp_id;
|
| 211 |
+
unsigned int group_mask = get_group_mask(thread_rank, num_warps);
|
| 212 |
+
|
| 213 |
+
__syncwarp(0xFFFFFFFF);
|
| 214 |
+
|
| 215 |
+
unsigned int old = 0;
|
| 216 |
+
if (warp_master) {
|
| 217 |
+
old = atom_or_acq_rel_cta(arrived, warp_bit);
|
| 218 |
+
}
|
| 219 |
+
old = __shfl_sync(0xFFFFFFFF, old, 0);
|
| 220 |
+
if (((old | warp_bit) & group_mask) == group_mask) {
|
| 221 |
+
return true;
|
| 222 |
+
}
|
| 223 |
+
barrier_wait(arrived, warp_bit);
|
| 224 |
+
|
| 225 |
+
return false;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
// Release my group from the barrier.
|
| 229 |
+
_CG_STATIC_QUALIFIER void sync_warps_release(barrier_t *arrived, bool is_master, unsigned int thread_rank, unsigned int num_warps) {
|
| 230 |
+
unsigned int group_mask = get_group_mask(thread_rank, num_warps);
|
| 231 |
+
if (is_master) {
|
| 232 |
+
red_and_release_cta(arrived, ~group_mask);
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
// Arrive at my group barrier, but don't block or release the barrier, even if every one arrives.
|
| 237 |
+
// sync_warps_release needs to be called by some warp after this one to reset the barrier.
|
| 238 |
+
_CG_STATIC_QUALIFIER void sync_warps_arrive(barrier_t *arrived, unsigned int thread_rank, unsigned int num_warps) {
|
| 239 |
+
unsigned int warp_id = thread_rank / 32;
|
| 240 |
+
bool warp_master = (thread_rank % 32 == 0);
|
| 241 |
+
unsigned int warp_bit = 1 << warp_id;
|
| 242 |
+
unsigned int group_mask = get_group_mask(thread_rank, num_warps);
|
| 243 |
+
|
| 244 |
+
__syncwarp(0xFFFFFFFF);
|
| 245 |
+
|
| 246 |
+
if (warp_master) {
|
| 247 |
+
red_or_release_cta(arrived, warp_bit);
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
// Wait for my warp to be released from the barrier. Warp must have arrived first.
|
| 252 |
+
_CG_STATIC_QUALIFIER void sync_warps_wait(barrier_t *arrived, unsigned int thread_rank) {
|
| 253 |
+
unsigned int warp_id = thread_rank / 32;
|
| 254 |
+
unsigned int warp_bit = 1 << warp_id;
|
| 255 |
+
|
| 256 |
+
barrier_wait(arrived, warp_bit);
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// Wait for specific warp to arrive at the barrier
|
| 260 |
+
_CG_QUALIFIER void sync_warps_wait_for_specific_warp(barrier_t *arrived, unsigned int wait_warp_id) {
|
| 261 |
+
unsigned int wait_mask = 1 << wait_warp_id;
|
| 262 |
+
while((ld_acquire_cta(arrived) & wait_mask) != wait_mask);
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
// Initialize the bit corresponding to my warp in the barrier
|
| 266 |
+
_CG_QUALIFIER void sync_warps_reset(barrier_t *arrived, unsigned int thread_rank) {
|
| 267 |
+
unsigned int warp_id = thread_rank / 32;
|
| 268 |
+
unsigned int warp_bit = 1 << warp_id;
|
| 269 |
+
|
| 270 |
+
__syncwarp(0xFFFFFFFF);
|
| 271 |
+
|
| 272 |
+
if (thread_rank % 32 == 0) {
|
| 273 |
+
red_and_release_cta(arrived, ~warp_bit);
|
| 274 |
+
}
|
| 275 |
+
// No need to sync after the atomic, there will be a sync of the group that is being partitioned right after this.
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
} // details
|
| 279 |
+
|
| 280 |
+
_CG_END_NAMESPACE
|
| 281 |
+
|
| 282 |
+
#endif // _CG_GRID_H
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8774224f5b11a73b15d074a3fcce7327322c5c4cfdfd924d6a826779eec968fe
|
| 3 |
+
size 707904
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn.h
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/* cudnn : Neural Networks Library */
|
| 51 |
+
|
| 52 |
+
#if !defined(CUDNN_H_)
|
| 53 |
+
#define CUDNN_H_
|
| 54 |
+
#if defined(__cplusplus)
|
| 55 |
+
extern "C" {
|
| 56 |
+
#endif
|
| 57 |
+
|
| 58 |
+
#include <cuda_runtime_api.h>
|
| 59 |
+
#include "cudnn_version.h"
|
| 60 |
+
#include "cudnn_graph.h"
|
| 61 |
+
#include "cudnn_ops.h"
|
| 62 |
+
#include "cudnn_adv.h"
|
| 63 |
+
#include "cudnn_cnn.h"
|
| 64 |
+
|
| 65 |
+
#if defined(__cplusplus)
|
| 66 |
+
}
|
| 67 |
+
#endif
|
| 68 |
+
#endif /* CUDNN_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_v9.h
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/* cudnn_adv : cuDNN's advanced and experimental features.
|
| 51 |
+
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_ADV_H_)
|
| 55 |
+
#define CUDNN_ADV_H_
|
| 56 |
+
|
| 57 |
+
#include <stdint.h>
|
| 58 |
+
|
| 59 |
+
#include "cudnn_version.h"
|
| 60 |
+
#include "cudnn_ops.h"
|
| 61 |
+
|
| 62 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 63 |
+
#define CUDNN_ADV_MAJOR 9
|
| 64 |
+
#define CUDNN_ADV_MINOR 1
|
| 65 |
+
#define CUDNN_ADV_PATCH 0
|
| 66 |
+
|
| 67 |
+
#if (CUDNN_ADV_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_MINOR != CUDNN_MINOR) || (CUDNN_ADV_PATCH != CUDNN_PATCHLEVEL)
|
| 68 |
+
#error Version mismatch in cuDNN ADV INFER!!!
|
| 69 |
+
#endif
|
| 70 |
+
|
| 71 |
+
#if defined(__cplusplus)
|
| 72 |
+
extern "C" {
|
| 73 |
+
#endif
|
| 74 |
+
|
| 75 |
+
/* BASIC RNN API */
|
| 76 |
+
|
| 77 |
+
typedef enum {
|
| 78 |
+
CUDNN_RNN_ALGO_STANDARD = 0,
|
| 79 |
+
CUDNN_RNN_ALGO_PERSIST_STATIC = 1,
|
| 80 |
+
CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2,
|
| 81 |
+
CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H = 3,
|
| 82 |
+
CUDNN_RNN_ALGO_COUNT = 4,
|
| 83 |
+
} cudnnRNNAlgo_t;
|
| 84 |
+
|
| 85 |
+
typedef enum {
|
| 86 |
+
CUDNN_FWD_MODE_INFERENCE = 0,
|
| 87 |
+
CUDNN_FWD_MODE_TRAINING = 1,
|
| 88 |
+
} cudnnForwardMode_t;
|
| 89 |
+
|
| 90 |
+
typedef enum {
|
| 91 |
+
CUDNN_RNN_RELU = 0, /* basic RNN cell type with ReLu activation */
|
| 92 |
+
CUDNN_RNN_TANH = 1, /* basic RNN cell type with tanh activation */
|
| 93 |
+
CUDNN_LSTM = 2, /* LSTM with optional recurrent projection and clipping */
|
| 94 |
+
CUDNN_GRU = 3, /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1); */
|
| 95 |
+
} cudnnRNNMode_t;
|
| 96 |
+
|
| 97 |
+
typedef enum {
|
| 98 |
+
CUDNN_RNN_NO_BIAS = 0, /* rnn cell formulas do not use biases */
|
| 99 |
+
CUDNN_RNN_SINGLE_INP_BIAS = 1, /* rnn cell formulas use one input bias in input GEMM */
|
| 100 |
+
CUDNN_RNN_DOUBLE_BIAS = 2, /* default, rnn cell formulas use two bias vectors */
|
| 101 |
+
CUDNN_RNN_SINGLE_REC_BIAS = 3 /* rnn cell formulas use one recurrent bias in recurrent GEMM */
|
| 102 |
+
} cudnnRNNBiasMode_t;
|
| 103 |
+
|
| 104 |
+
typedef enum {
|
| 105 |
+
CUDNN_UNIDIRECTIONAL = 0, /* single direction network */
|
| 106 |
+
CUDNN_BIDIRECTIONAL = 1, /* output concatination at each layer */
|
| 107 |
+
} cudnnDirectionMode_t;
|
| 108 |
+
|
| 109 |
+
typedef enum {
|
| 110 |
+
CUDNN_LINEAR_INPUT = 0, /* adjustable weight matrix in first layer input GEMM */
|
| 111 |
+
CUDNN_SKIP_INPUT = 1, /* fixed identity matrix in the first layer input GEMM */
|
| 112 |
+
} cudnnRNNInputMode_t;
|
| 113 |
+
|
| 114 |
+
typedef enum {
|
| 115 |
+
CUDNN_RNN_CLIP_NONE = 0, /* disables LSTM cell clipping */
|
| 116 |
+
CUDNN_RNN_CLIP_MINMAX = 1, /* enables LSTM cell clipping */
|
| 117 |
+
} cudnnRNNClipMode_t;
|
| 118 |
+
|
| 119 |
+
typedef enum {
|
| 120 |
+
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED = 0, /* padded, outer stride from one time-step to the next */
|
| 121 |
+
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1, /* sequence length sorted and packed as in basic RNN api */
|
| 122 |
+
CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, /* padded, outer stride from one batch to the next */
|
| 123 |
+
} cudnnRNNDataLayout_t;
|
| 124 |
+
|
| 125 |
+
/* For auxFlags in cudnnSetRNNDescriptor_v8() */
|
| 126 |
+
#define CUDNN_RNN_PADDED_IO_DISABLED 0
|
| 127 |
+
#define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0)
|
| 128 |
+
|
| 129 |
+
struct cudnnRNNStruct;
|
| 130 |
+
typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t;
|
| 131 |
+
|
| 132 |
+
struct cudnnRNNDataStruct;
|
| 133 |
+
typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t;
|
| 134 |
+
|
| 135 |
+
cudnnStatus_t CUDNNWINAPI
|
| 136 |
+
cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc);
|
| 137 |
+
|
| 138 |
+
cudnnStatus_t CUDNNWINAPI
|
| 139 |
+
cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
|
| 140 |
+
|
| 141 |
+
/*
|
| 142 |
+
* mathPrec in cudnnSetRNNDescriptor_v8() specifies compute precision.
|
| 143 |
+
* Compute precision is further modified by mathType that sets the
|
| 144 |
+
* preferred option for using NVIDIA Tensor Cores. dataType specify
|
| 145 |
+
* input/output data type and weight/bias type.
|
| 146 |
+
*/
|
| 147 |
+
|
| 148 |
+
cudnnStatus_t CUDNNWINAPI
|
| 149 |
+
cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 150 |
+
cudnnRNNAlgo_t algo,
|
| 151 |
+
cudnnRNNMode_t cellMode,
|
| 152 |
+
cudnnRNNBiasMode_t biasMode,
|
| 153 |
+
cudnnDirectionMode_t dirMode,
|
| 154 |
+
cudnnRNNInputMode_t inputMode,
|
| 155 |
+
cudnnDataType_t dataType,
|
| 156 |
+
cudnnDataType_t mathPrec,
|
| 157 |
+
cudnnMathType_t mathType,
|
| 158 |
+
int32_t inputSize,
|
| 159 |
+
int32_t hiddenSize,
|
| 160 |
+
int32_t projSize,
|
| 161 |
+
int32_t numLayers,
|
| 162 |
+
cudnnDropoutDescriptor_t dropoutDesc,
|
| 163 |
+
uint32_t auxFlags);
|
| 164 |
+
|
| 165 |
+
cudnnStatus_t CUDNNWINAPI
|
| 166 |
+
cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 167 |
+
cudnnRNNAlgo_t *algo,
|
| 168 |
+
cudnnRNNMode_t *cellMode,
|
| 169 |
+
cudnnRNNBiasMode_t *biasMode,
|
| 170 |
+
cudnnDirectionMode_t *dirMode,
|
| 171 |
+
cudnnRNNInputMode_t *inputMode,
|
| 172 |
+
cudnnDataType_t *dataType,
|
| 173 |
+
cudnnDataType_t *mathPrec,
|
| 174 |
+
cudnnMathType_t *mathType,
|
| 175 |
+
int32_t *inputSize,
|
| 176 |
+
int32_t *hiddenSize,
|
| 177 |
+
int32_t *projSize,
|
| 178 |
+
int32_t *numLayers,
|
| 179 |
+
cudnnDropoutDescriptor_t *dropoutDesc,
|
| 180 |
+
uint32_t *auxFlags);
|
| 181 |
+
|
| 182 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 183 |
+
cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 184 |
+
cudnnRNNClipMode_t clipMode,
|
| 185 |
+
cudnnNanPropagation_t clipNanOpt,
|
| 186 |
+
double lclip,
|
| 187 |
+
double rclip);
|
| 188 |
+
|
| 189 |
+
cudnnStatus_t CUDNNWINAPI
|
| 190 |
+
cudnnRNNSetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t clipMode, double lclip, double rclip);
|
| 191 |
+
|
| 192 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 193 |
+
cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 194 |
+
cudnnRNNClipMode_t *clipMode,
|
| 195 |
+
cudnnNanPropagation_t *clipNanOpt,
|
| 196 |
+
double *lclip,
|
| 197 |
+
double *rclip);
|
| 198 |
+
|
| 199 |
+
cudnnStatus_t CUDNNWINAPI
|
| 200 |
+
cudnnRNNGetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t *clipMode, double *lclip, double *rclip);
|
| 201 |
+
|
| 202 |
+
cudnnStatus_t CUDNNWINAPI
|
| 203 |
+
cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch);
|
| 204 |
+
|
| 205 |
+
cudnnStatus_t CUDNNWINAPI
|
| 206 |
+
cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle,
|
| 207 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 208 |
+
cudnnForwardMode_t fwdMode,
|
| 209 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 210 |
+
size_t *workSpaceSize,
|
| 211 |
+
size_t *reserveSpaceSize);
|
| 212 |
+
|
| 213 |
+
cudnnStatus_t CUDNNWINAPI
|
| 214 |
+
cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize);
|
| 215 |
+
|
| 216 |
+
cudnnStatus_t CUDNNWINAPI
|
| 217 |
+
cudnnGetRNNWeightParams(cudnnHandle_t handle,
|
| 218 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 219 |
+
int32_t pseudoLayer,
|
| 220 |
+
size_t weightSpaceSize,
|
| 221 |
+
const void *weightSpace,
|
| 222 |
+
int32_t linLayerID,
|
| 223 |
+
cudnnTensorDescriptor_t mDesc,
|
| 224 |
+
void **mAddr,
|
| 225 |
+
cudnnTensorDescriptor_t bDesc,
|
| 226 |
+
void **bAddr);
|
| 227 |
+
|
| 228 |
+
cudnnStatus_t CUDNNWINAPI
|
| 229 |
+
cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc);
|
| 230 |
+
|
| 231 |
+
cudnnStatus_t CUDNNWINAPI
|
| 232 |
+
cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc);
|
| 233 |
+
|
| 234 |
+
cudnnStatus_t CUDNNWINAPI
|
| 235 |
+
cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
|
| 236 |
+
cudnnDataType_t dataType,
|
| 237 |
+
cudnnRNNDataLayout_t layout,
|
| 238 |
+
int maxSeqLength,
|
| 239 |
+
int batchSize,
|
| 240 |
+
int vectorSize,
|
| 241 |
+
const int seqLengthArray[], /* length of each sequence in the batch */
|
| 242 |
+
void *paddingFill); /* symbol for filling padding position in output */
|
| 243 |
+
|
| 244 |
+
cudnnStatus_t CUDNNWINAPI
|
| 245 |
+
cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
|
| 246 |
+
cudnnDataType_t *dataType,
|
| 247 |
+
cudnnRNNDataLayout_t *layout,
|
| 248 |
+
int *maxSeqLength,
|
| 249 |
+
int *batchSize,
|
| 250 |
+
int *vectorSize,
|
| 251 |
+
int arrayLengthRequested,
|
| 252 |
+
int seqLengthArray[],
|
| 253 |
+
void *paddingFill);
|
| 254 |
+
|
| 255 |
+
cudnnStatus_t CUDNNWINAPI
|
| 256 |
+
cudnnRNNForward(cudnnHandle_t handle,
|
| 257 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 258 |
+
cudnnForwardMode_t fwdMode,
|
| 259 |
+
const int32_t devSeqLengths[],
|
| 260 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 261 |
+
const void *x,
|
| 262 |
+
cudnnRNNDataDescriptor_t yDesc,
|
| 263 |
+
void *y,
|
| 264 |
+
cudnnTensorDescriptor_t hDesc,
|
| 265 |
+
const void *hx,
|
| 266 |
+
void *hy,
|
| 267 |
+
cudnnTensorDescriptor_t cDesc,
|
| 268 |
+
const void *cx,
|
| 269 |
+
void *cy,
|
| 270 |
+
size_t weightSpaceSize,
|
| 271 |
+
const void *weightSpace,
|
| 272 |
+
size_t workSpaceSize,
|
| 273 |
+
void *workSpace,
|
| 274 |
+
size_t reserveSpaceSize,
|
| 275 |
+
void *reserveSpace);
|
| 276 |
+
|
| 277 |
+
/* Sequence data descriptor */
|
| 278 |
+
|
| 279 |
+
typedef enum {
|
| 280 |
+
CUDNN_SEQDATA_TIME_DIM = 0, /* index in time */
|
| 281 |
+
CUDNN_SEQDATA_BATCH_DIM = 1, /* index in batch */
|
| 282 |
+
CUDNN_SEQDATA_BEAM_DIM = 2, /* index in beam */
|
| 283 |
+
CUDNN_SEQDATA_VECT_DIM = 3 /* index in vector */
|
| 284 |
+
} cudnnSeqDataAxis_t;
|
| 285 |
+
|
| 286 |
+
struct cudnnSeqDataStruct;
|
| 287 |
+
typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t CUDNN_DEPRECATED;
|
| 288 |
+
|
| 289 |
+
#define CUDNN_SEQDATA_DIM_COUNT 4 /* dimension count */
|
| 290 |
+
|
| 291 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 292 |
+
cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc);
|
| 293 |
+
|
| 294 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 295 |
+
cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc);
|
| 296 |
+
|
| 297 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 298 |
+
cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
|
| 299 |
+
cudnnDataType_t dataType,
|
| 300 |
+
int nbDims,
|
| 301 |
+
const int dimA[],
|
| 302 |
+
const cudnnSeqDataAxis_t axes[],
|
| 303 |
+
size_t seqLengthArraySize,
|
| 304 |
+
const int seqLengthArray[],
|
| 305 |
+
void *paddingFill);
|
| 306 |
+
|
| 307 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 308 |
+
cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
|
| 309 |
+
cudnnDataType_t *dataType,
|
| 310 |
+
int *nbDims,
|
| 311 |
+
int nbDimsRequested,
|
| 312 |
+
int dimA[],
|
| 313 |
+
cudnnSeqDataAxis_t axes[],
|
| 314 |
+
size_t *seqLengthArraySize,
|
| 315 |
+
size_t seqLengthSizeRequested,
|
| 316 |
+
int seqLengthArray[],
|
| 317 |
+
void *paddingFill);
|
| 318 |
+
|
| 319 |
+
/* Multihead Attention */
|
| 320 |
+
|
| 321 |
+
/*
|
| 322 |
+
* Multi-head attention options passed via 'attnMode' in cudnnSetAttnDescriptor().
|
| 323 |
+
* Use the bitwise OR operator to combine several settings listed below. Additional
|
| 324 |
+
* minor options can be added here w/o changing or introducing new API functions.
|
| 325 |
+
*/
|
| 326 |
+
#define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0 /* multiple Q-s map to a single (K,V) set when beam size > 1 */
|
| 327 |
+
#define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */
|
| 328 |
+
#define CUDNN_ATTN_DISABLE_PROJ_BIASES 0 /* no biases in attention input and output projections */
|
| 329 |
+
#define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1) /* use biases in attention input and output projections */
|
| 330 |
+
|
| 331 |
+
struct cudnnAttnStruct;
|
| 332 |
+
typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t CUDNN_DEPRECATED;
|
| 333 |
+
|
| 334 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 335 |
+
cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc);
|
| 336 |
+
|
| 337 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 338 |
+
cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc);
|
| 339 |
+
|
| 340 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 341 |
+
cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
|
| 342 |
+
unsigned attnMode,
|
| 343 |
+
int nHeads,
|
| 344 |
+
double smScaler,
|
| 345 |
+
cudnnDataType_t dataType,
|
| 346 |
+
cudnnDataType_t computePrec,
|
| 347 |
+
cudnnMathType_t mathType,
|
| 348 |
+
cudnnDropoutDescriptor_t attnDropoutDesc,
|
| 349 |
+
cudnnDropoutDescriptor_t postDropoutDesc,
|
| 350 |
+
int qSize,
|
| 351 |
+
int kSize,
|
| 352 |
+
int vSize,
|
| 353 |
+
int qProjSize,
|
| 354 |
+
int kProjSize,
|
| 355 |
+
int vProjSize,
|
| 356 |
+
int oProjSize,
|
| 357 |
+
int qoMaxSeqLength,
|
| 358 |
+
int kvMaxSeqLength,
|
| 359 |
+
int maxBatchSize,
|
| 360 |
+
int maxBeamSize);
|
| 361 |
+
|
| 362 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 363 |
+
cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
|
| 364 |
+
unsigned *attnMode,
|
| 365 |
+
int *nHeads,
|
| 366 |
+
double *smScaler,
|
| 367 |
+
cudnnDataType_t *dataType,
|
| 368 |
+
cudnnDataType_t *computePrec,
|
| 369 |
+
cudnnMathType_t *mathType,
|
| 370 |
+
cudnnDropoutDescriptor_t *attnDropoutDesc,
|
| 371 |
+
cudnnDropoutDescriptor_t *postDropoutDesc,
|
| 372 |
+
int *qSize,
|
| 373 |
+
int *kSize,
|
| 374 |
+
int *vSize,
|
| 375 |
+
int *qProjSize,
|
| 376 |
+
int *kProjSize,
|
| 377 |
+
int *vProjSize,
|
| 378 |
+
int *oProjSize,
|
| 379 |
+
int *qoMaxSeqLength,
|
| 380 |
+
int *kvMaxSeqLength,
|
| 381 |
+
int *maxBatchSize,
|
| 382 |
+
int *maxBeamSize);
|
| 383 |
+
|
| 384 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 385 |
+
cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
|
| 386 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 387 |
+
size_t *weightSizeInBytes,
|
| 388 |
+
size_t *workSpaceSizeInBytes,
|
| 389 |
+
size_t *reserveSpaceSizeInBytes);
|
| 390 |
+
|
| 391 |
+
typedef enum {
|
| 392 |
+
CUDNN_MH_ATTN_Q_WEIGHTS = 0, /* input projection weights for 'queries' */
|
| 393 |
+
CUDNN_MH_ATTN_K_WEIGHTS = 1, /* input projection weights for 'keys' */
|
| 394 |
+
CUDNN_MH_ATTN_V_WEIGHTS = 2, /* input projection weights for 'values' */
|
| 395 |
+
CUDNN_MH_ATTN_O_WEIGHTS = 3, /* output projection weights */
|
| 396 |
+
CUDNN_MH_ATTN_Q_BIASES = 4, /* input projection bias tensor for 'queries' */
|
| 397 |
+
CUDNN_MH_ATTN_K_BIASES = 5, /* input projection bias for 'keys' */
|
| 398 |
+
CUDNN_MH_ATTN_V_BIASES = 6, /* input projection bias for 'values' */
|
| 399 |
+
CUDNN_MH_ATTN_O_BIASES = 7, /* output projection biases */
|
| 400 |
+
} cudnnMultiHeadAttnWeightKind_t;
|
| 401 |
+
|
| 402 |
+
#define CUDNN_ATTN_WKIND_COUNT 8 /* Number of attention weight/bias tensors */
|
| 403 |
+
|
| 404 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 405 |
+
cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
|
| 406 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 407 |
+
cudnnMultiHeadAttnWeightKind_t wKind,
|
| 408 |
+
size_t weightSizeInBytes,
|
| 409 |
+
const void *weights,
|
| 410 |
+
cudnnTensorDescriptor_t wDesc,
|
| 411 |
+
void **wAddr);
|
| 412 |
+
|
| 413 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 414 |
+
cudnnMultiHeadAttnForward(cudnnHandle_t handle,
|
| 415 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 416 |
+
int currIdx,
|
| 417 |
+
const int loWinIdx[],
|
| 418 |
+
const int hiWinIdx[],
|
| 419 |
+
const int devSeqLengthsQO[],
|
| 420 |
+
const int devSeqLengthsKV[],
|
| 421 |
+
const cudnnSeqDataDescriptor_t qDesc,
|
| 422 |
+
const void *queries,
|
| 423 |
+
const void *residuals,
|
| 424 |
+
const cudnnSeqDataDescriptor_t kDesc,
|
| 425 |
+
const void *keys,
|
| 426 |
+
const cudnnSeqDataDescriptor_t vDesc,
|
| 427 |
+
const void *values,
|
| 428 |
+
const cudnnSeqDataDescriptor_t oDesc,
|
| 429 |
+
void *out,
|
| 430 |
+
size_t weightSizeInBytes,
|
| 431 |
+
const void *weights,
|
| 432 |
+
size_t workSpaceSizeInBytes,
|
| 433 |
+
void *workSpace,
|
| 434 |
+
size_t reserveSpaceSizeInBytes,
|
| 435 |
+
void *reserveSpace);
|
| 436 |
+
|
| 437 |
+
/*
|
| 438 |
+
* \brief Cross-library version checker.
|
| 439 |
+
* This function is implemented differently in each sub-library. Each sublib
|
| 440 |
+
* checks whether its own version matches that of its dependencies.
|
| 441 |
+
* \returns CUDNN_STATUS_SUCCESS if the version check passes,
|
| 442 |
+
* CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
|
| 443 |
+
*/
|
| 444 |
+
cudnnStatus_t CUDNNWINAPI
|
| 445 |
+
cudnnAdvVersionCheck(void);
|
| 446 |
+
|
| 447 |
+
typedef enum {
|
| 448 |
+
CUDNN_WGRAD_MODE_ADD = 0, /* add partial gradients to wgrad output buffers */
|
| 449 |
+
CUDNN_WGRAD_MODE_SET = 1, /* write partial gradients to wgrad output buffers */
|
| 450 |
+
} cudnnWgradMode_t;
|
| 451 |
+
|
| 452 |
+
cudnnStatus_t CUDNNWINAPI
|
| 453 |
+
cudnnRNNBackwardData_v8(cudnnHandle_t handle,
|
| 454 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 455 |
+
const int32_t devSeqLengths[],
|
| 456 |
+
cudnnRNNDataDescriptor_t yDesc,
|
| 457 |
+
const void *y,
|
| 458 |
+
const void *dy,
|
| 459 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 460 |
+
void *dx,
|
| 461 |
+
cudnnTensorDescriptor_t hDesc,
|
| 462 |
+
const void *hx,
|
| 463 |
+
const void *dhy,
|
| 464 |
+
void *dhx,
|
| 465 |
+
cudnnTensorDescriptor_t cDesc,
|
| 466 |
+
const void *cx,
|
| 467 |
+
const void *dcy,
|
| 468 |
+
void *dcx,
|
| 469 |
+
size_t weightSpaceSize,
|
| 470 |
+
const void *weightSpace,
|
| 471 |
+
size_t workSpaceSize,
|
| 472 |
+
void *workSpace,
|
| 473 |
+
size_t reserveSpaceSize,
|
| 474 |
+
void *reserveSpace);
|
| 475 |
+
|
| 476 |
+
cudnnStatus_t CUDNNWINAPI
|
| 477 |
+
cudnnRNNBackwardWeights_v8(cudnnHandle_t handle,
|
| 478 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 479 |
+
cudnnWgradMode_t addGrad,
|
| 480 |
+
const int32_t devSeqLengths[],
|
| 481 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 482 |
+
const void *x,
|
| 483 |
+
cudnnTensorDescriptor_t hDesc,
|
| 484 |
+
const void *hx,
|
| 485 |
+
cudnnRNNDataDescriptor_t yDesc,
|
| 486 |
+
const void *y,
|
| 487 |
+
size_t weightSpaceSize,
|
| 488 |
+
void *dweightSpace,
|
| 489 |
+
size_t workSpaceSize,
|
| 490 |
+
void *workSpace,
|
| 491 |
+
size_t reserveSpaceSize,
|
| 492 |
+
void *reserveSpace);
|
| 493 |
+
|
| 494 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 495 |
+
cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle,
|
| 496 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 497 |
+
const int loWinIdx[],
|
| 498 |
+
const int hiWinIdx[],
|
| 499 |
+
const int devSeqLengthsDQDO[],
|
| 500 |
+
const int devSeqLengthsDKDV[],
|
| 501 |
+
const cudnnSeqDataDescriptor_t doDesc,
|
| 502 |
+
const void *dout,
|
| 503 |
+
const cudnnSeqDataDescriptor_t dqDesc,
|
| 504 |
+
void *dqueries,
|
| 505 |
+
const void *queries,
|
| 506 |
+
const cudnnSeqDataDescriptor_t dkDesc,
|
| 507 |
+
void *dkeys,
|
| 508 |
+
const void *keys,
|
| 509 |
+
const cudnnSeqDataDescriptor_t dvDesc,
|
| 510 |
+
void *dvalues,
|
| 511 |
+
const void *values,
|
| 512 |
+
size_t weightSizeInBytes,
|
| 513 |
+
const void *weights,
|
| 514 |
+
size_t workSpaceSizeInBytes,
|
| 515 |
+
void *workSpace,
|
| 516 |
+
size_t reserveSpaceSizeInBytes,
|
| 517 |
+
void *reserveSpace);
|
| 518 |
+
|
| 519 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 520 |
+
cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle,
|
| 521 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 522 |
+
cudnnWgradMode_t addGrad,
|
| 523 |
+
const cudnnSeqDataDescriptor_t qDesc,
|
| 524 |
+
const void *queries,
|
| 525 |
+
const cudnnSeqDataDescriptor_t kDesc,
|
| 526 |
+
const void *keys,
|
| 527 |
+
const cudnnSeqDataDescriptor_t vDesc,
|
| 528 |
+
const void *values,
|
| 529 |
+
const cudnnSeqDataDescriptor_t doDesc,
|
| 530 |
+
const void *dout,
|
| 531 |
+
size_t weightSizeInBytes,
|
| 532 |
+
const void *weights,
|
| 533 |
+
void *dweights,
|
| 534 |
+
size_t workSpaceSizeInBytes,
|
| 535 |
+
void *workSpace,
|
| 536 |
+
size_t reserveSpaceSizeInBytes,
|
| 537 |
+
void *reserveSpace);
|
| 538 |
+
|
| 539 |
+
/*
|
| 540 |
+
* CTC (Connectionist Temporal Classification) loss descriptor create/destory/set/get functions
|
| 541 |
+
*/
|
| 542 |
+
/* Input normalization mode for loss function */
|
| 543 |
+
typedef enum {
|
| 544 |
+
CUDNN_LOSS_NORMALIZATION_NONE = 0,
|
| 545 |
+
CUDNN_LOSS_NORMALIZATION_SOFTMAX = 1,
|
| 546 |
+
} cudnnLossNormalizationMode_t;
|
| 547 |
+
|
| 548 |
+
cudnnStatus_t CUDNNWINAPI
|
| 549 |
+
cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc);
|
| 550 |
+
|
| 551 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 552 |
+
cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType);
|
| 553 |
+
|
| 554 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 555 |
+
cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 556 |
+
cudnnDataType_t compType,
|
| 557 |
+
cudnnLossNormalizationMode_t normMode,
|
| 558 |
+
cudnnNanPropagation_t gradMode);
|
| 559 |
+
|
| 560 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 561 |
+
cudnnSetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 562 |
+
cudnnDataType_t compType,
|
| 563 |
+
cudnnLossNormalizationMode_t normMode,
|
| 564 |
+
cudnnNanPropagation_t gradMode,
|
| 565 |
+
int maxLabelLength);
|
| 566 |
+
|
| 567 |
+
cudnnStatus_t CUDNNWINAPI
|
| 568 |
+
cudnnSetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 569 |
+
cudnnDataType_t compType,
|
| 570 |
+
cudnnLossNormalizationMode_t normMode,
|
| 571 |
+
cudnnCTCGradMode_t ctcGradMode,
|
| 572 |
+
int maxLabelLength);
|
| 573 |
+
|
| 574 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 575 |
+
cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType);
|
| 576 |
+
|
| 577 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 578 |
+
cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 579 |
+
cudnnDataType_t *compType,
|
| 580 |
+
cudnnLossNormalizationMode_t *normMode,
|
| 581 |
+
cudnnNanPropagation_t *gradMode);
|
| 582 |
+
|
| 583 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 584 |
+
cudnnGetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 585 |
+
cudnnDataType_t *compType,
|
| 586 |
+
cudnnLossNormalizationMode_t *normMode,
|
| 587 |
+
cudnnNanPropagation_t *gradMode,
|
| 588 |
+
int *maxLabelLength);
|
| 589 |
+
|
| 590 |
+
cudnnStatus_t CUDNNWINAPI
|
| 591 |
+
cudnnGetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 592 |
+
cudnnDataType_t *compType,
|
| 593 |
+
cudnnLossNormalizationMode_t *normMode,
|
| 594 |
+
cudnnCTCGradMode_t *ctcGradMode,
|
| 595 |
+
int *maxLabelLength);
|
| 596 |
+
|
| 597 |
+
cudnnStatus_t CUDNNWINAPI
|
| 598 |
+
cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc);
|
| 599 |
+
|
| 600 |
+
/* return the ctc costs and gradients, given the probabilities and labels */
|
| 601 |
+
cudnnStatus_t CUDNNWINAPI
|
| 602 |
+
cudnnCTCLoss(
|
| 603 |
+
cudnnHandle_t handle,
|
| 604 |
+
const cudnnTensorDescriptor_t
|
| 605 |
+
probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
|
| 606 |
+
mini batch size, A is the alphabet size) */
|
| 607 |
+
const void *probs, /* probabilities after softmax, in GPU memory */
|
| 608 |
+
const int hostLabels[], /* labels, in CPU memory */
|
| 609 |
+
const int hostLabelLengths[], /* the length of each label, in CPU memory */
|
| 610 |
+
const int hostInputLengths[], /* the lengths of timing steps in each batch, in CPU memory */
|
| 611 |
+
void *costs, /* the returned costs of CTC, in GPU memory */
|
| 612 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
|
| 613 |
+
void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
|
| 614 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 615 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 616 |
+
void *workspace, /* pointer to the workspace, in GPU memory */
|
| 617 |
+
size_t workSpaceSizeInBytes); /* size of the workspace */
|
| 618 |
+
|
| 619 |
+
/* return the ctc costs and gradients, given the probabilities and labels */
|
| 620 |
+
cudnnStatus_t CUDNNWINAPI
|
| 621 |
+
cudnnCTCLoss_v8(
|
| 622 |
+
cudnnHandle_t handle,
|
| 623 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 624 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 625 |
+
const cudnnTensorDescriptor_t
|
| 626 |
+
probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
|
| 627 |
+
mini batch size, A is the alphabet size) */
|
| 628 |
+
const void *probs, /* probabilities after softmax, in GPU memory */
|
| 629 |
+
const int labels[], /* labels, in GPU memory */
|
| 630 |
+
const int labelLengths[], /* the length of each label, in GPU memory */
|
| 631 |
+
const int inputLengths[], /* the lengths of timing steps in each batch, in GPU memory */
|
| 632 |
+
void *costs, /* the returned costs of CTC, in GPU memory */
|
| 633 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
|
| 634 |
+
void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
|
| 635 |
+
size_t workSpaceSizeInBytes, /* size of the workspace */
|
| 636 |
+
void *workspace); /* pointer to the workspace, in GPU memory */
|
| 637 |
+
|
| 638 |
+
/* return the workspace size needed for ctc */
|
| 639 |
+
cudnnStatus_t CUDNNWINAPI
|
| 640 |
+
cudnnGetCTCLossWorkspaceSize(
|
| 641 |
+
cudnnHandle_t handle,
|
| 642 |
+
const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
|
| 643 |
+
timing steps, N is the mini batch size, A is the alphabet size) */
|
| 644 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
|
| 645 |
+
dimensions are T,N,A. To compute costs
|
| 646 |
+
only, set it to NULL */
|
| 647 |
+
const int *labels, /* labels, in CPU memory */
|
| 648 |
+
const int *labelLengths, /* the length of each label, in CPU memory */
|
| 649 |
+
const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */
|
| 650 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 651 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 652 |
+
size_t *sizeInBytes); /* pointer to the returned workspace size */
|
| 653 |
+
|
| 654 |
+
/* return the workspace size needed for ctc */
|
| 655 |
+
cudnnStatus_t CUDNNWINAPI
|
| 656 |
+
cudnnGetCTCLossWorkspaceSize_v8(
|
| 657 |
+
cudnnHandle_t handle,
|
| 658 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 659 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 660 |
+
const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
|
| 661 |
+
timing steps, N is the mini batch size, A is the alphabet size) */
|
| 662 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
|
| 663 |
+
dimensions are T,N,A. To compute costs
|
| 664 |
+
only, set it to NULL */
|
| 665 |
+
size_t *sizeInBytes); /* pointer to the returned workspace size */
|
| 666 |
+
|
| 667 |
+
#if defined(__cplusplus)
|
| 668 |
+
}
|
| 669 |
+
#endif
|
| 670 |
+
|
| 671 |
+
#endif /* CUDNN_ADV_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend_v9.h
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#ifndef _CUDNN_BACKEND_H_
|
| 51 |
+
#define _CUDNN_BACKEND_H_
|
| 52 |
+
|
| 53 |
+
/*
|
| 54 |
+
* The content of this header has been moved into cudnn_graph.h.
|
| 55 |
+
* This header is kept for the backward compatibility purpose.
|
| 56 |
+
*/
|
| 57 |
+
|
| 58 |
+
#include "cudnn_graph.h"
|
| 59 |
+
|
| 60 |
+
#endif /* _CUDNN_BACKEND_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph.h
ADDED
|
@@ -0,0 +1,909 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* cudnn_graph : cuDNN's basic definitions operations.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_GRAPH_H_)
|
| 55 |
+
#define CUDNN_GRAPH_H_
|
| 56 |
+
|
| 57 |
+
#include <cuda_runtime_api.h>
|
| 58 |
+
#include <library_types.h>
|
| 59 |
+
|
| 60 |
+
#include <stdint.h>
|
| 61 |
+
|
| 62 |
+
#include "cudnn_version.h"
|
| 63 |
+
|
| 64 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 65 |
+
#define CUDNN_GRAPH_MAJOR 9
|
| 66 |
+
#define CUDNN_GRAPH_MINOR 1
|
| 67 |
+
#define CUDNN_GRAPH_PATCH 0
|
| 68 |
+
|
| 69 |
+
#if (CUDNN_GRAPH_MAJOR != CUDNN_MAJOR) || (CUDNN_GRAPH_MINOR != CUDNN_MINOR) || (CUDNN_GRAPH_PATCH != CUDNN_PATCHLEVEL)
|
| 70 |
+
#error Version mismatch in cuDNN GRAPH!!!
|
| 71 |
+
#endif
|
| 72 |
+
|
| 73 |
+
#ifndef CUDNNWINAPI
|
| 74 |
+
#ifdef _WIN32
|
| 75 |
+
#define CUDNNWINAPI __stdcall
|
| 76 |
+
#else
|
| 77 |
+
#define CUDNNWINAPI
|
| 78 |
+
#endif
|
| 79 |
+
#endif
|
| 80 |
+
|
| 81 |
+
/* Warnings for deprecated API-s are enabled using the CUDNN_WARN_DEPRECATED macro */
|
| 82 |
+
#if defined(CUDNN_WARN_DEPRECATED) && (defined(__GNUC__) || defined(__clang__))
|
| 83 |
+
/* GCC, Intel C/C++, Cray C/C++, CLANG, IBM XL C/C++ little endian */
|
| 84 |
+
#define CUDNN_DEPRECATED __attribute__((deprecated))
|
| 85 |
+
#define CUDNN_DEPRECATED_ENUM __attribute__((deprecated))
|
| 86 |
+
#elif defined(CUDNN_WARN_DEPRECATED) && defined(_MSC_VER)
|
| 87 |
+
/* Microsoft Visual C++ */
|
| 88 |
+
#define CUDNN_DEPRECATED __declspec(deprecated)
|
| 89 |
+
#define CUDNN_DEPRECATED_ENUM __declspec(deprecated)
|
| 90 |
+
#elif defined(CUDNN_WARN_DEPRECATED) && (__cplusplus >= 201402L)
|
| 91 |
+
/* C++14 compilers */
|
| 92 |
+
#define CUDNN_DEPRECATED [[deprecated]]
|
| 93 |
+
#define CUDNN_DEPRECATED_ENUM [[deprecated]]
|
| 94 |
+
#else
|
| 95 |
+
/* No support for the deprecated attribute */
|
| 96 |
+
#define CUDNN_DEPRECATED
|
| 97 |
+
#define CUDNN_DEPRECATED_ENUM
|
| 98 |
+
#endif
|
| 99 |
+
|
| 100 |
+
#if defined(__cplusplus)
|
| 101 |
+
extern "C" {
|
| 102 |
+
#endif
|
| 103 |
+
|
| 104 |
+
struct cudnnContext;
|
| 105 |
+
typedef struct cudnnContext *cudnnHandle_t;
|
| 106 |
+
|
| 107 |
+
size_t CUDNNWINAPI
|
| 108 |
+
cudnnGetVersion(void);
|
| 109 |
+
|
| 110 |
+
size_t CUDNNWINAPI
|
| 111 |
+
cudnnGetMaxDeviceVersion(void);
|
| 112 |
+
|
| 113 |
+
/* Returns CUDA Runtime version statically linked against cudnn */
|
| 114 |
+
size_t CUDNNWINAPI
|
| 115 |
+
cudnnGetCudartVersion(void);
|
| 116 |
+
|
| 117 |
+
/*
|
| 118 |
+
* CUDNN return codes
|
| 119 |
+
*/
|
| 120 |
+
typedef enum {
|
| 121 |
+
CUDNN_STATUS_SUCCESS = 0,
|
| 122 |
+
|
| 123 |
+
/* Uncategorized errors */
|
| 124 |
+
CUDNN_STATUS_NOT_INITIALIZED = 1001,
|
| 125 |
+
CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH = 1002,
|
| 126 |
+
CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH = 1003,
|
| 127 |
+
CUDNN_STATUS_DEPRECATED = 1004,
|
| 128 |
+
CUDNN_STATUS_LICENSE_ERROR = 1005,
|
| 129 |
+
CUDNN_STATUS_RUNTIME_IN_PROGRESS = 1006,
|
| 130 |
+
CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 1007,
|
| 131 |
+
|
| 132 |
+
CUDNN_STATUS_BAD_PARAM = 2000,
|
| 133 |
+
CUDNN_STATUS_BAD_PARAM_NULL_POINTER = 2002,
|
| 134 |
+
CUDNN_STATUS_BAD_PARAM_MISALIGNED_POINTER = 2003,
|
| 135 |
+
CUDNN_STATUS_BAD_PARAM_NOT_FINALIZED = 2004,
|
| 136 |
+
CUDNN_STATUS_BAD_PARAM_OUT_OF_BOUND = 2005,
|
| 137 |
+
CUDNN_STATUS_BAD_PARAM_SIZE_INSUFFICIENT = 2006,
|
| 138 |
+
CUDNN_STATUS_BAD_PARAM_STREAM_MISMATCH = 2007,
|
| 139 |
+
CUDNN_STATUS_BAD_PARAM_SHAPE_MISMATCH = 2008,
|
| 140 |
+
CUDNN_STATUS_BAD_PARAM_DUPLICATED_ENTRIES = 2009,
|
| 141 |
+
CUDNN_STATUS_BAD_PARAM_ATTRIBUTE_TYPE = 2010,
|
| 142 |
+
|
| 143 |
+
CUDNN_STATUS_NOT_SUPPORTED = 3000,
|
| 144 |
+
CUDNN_STATUS_NOT_SUPPORTED_GRAPH_PATTERN = 3001,
|
| 145 |
+
CUDNN_STATUS_NOT_SUPPORTED_SHAPE = 3002,
|
| 146 |
+
CUDNN_STATUS_NOT_SUPPORTED_DATA_TYPE = 3003,
|
| 147 |
+
CUDNN_STATUS_NOT_SUPPORTED_LAYOUT = 3004,
|
| 148 |
+
CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDA_DRIVER = 3005,
|
| 149 |
+
CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDART = 3006,
|
| 150 |
+
CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH = 3007,
|
| 151 |
+
CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING = 3008,
|
| 152 |
+
CUDNN_STATUS_NOT_SUPPORTED_SUBLIBRARY_UNAVAILABLE = 3009,
|
| 153 |
+
CUDNN_STATUS_NOT_SUPPORTED_SHARED_MEMORY_INSUFFICIENT = 3010,
|
| 154 |
+
CUDNN_STATUS_NOT_SUPPORTED_PADDING = 3011,
|
| 155 |
+
CUDNN_STATUS_NOT_SUPPORTED_BAD_LAUNCH_PARAM = 3012,
|
| 156 |
+
|
| 157 |
+
CUDNN_STATUS_INTERNAL_ERROR = 4000,
|
| 158 |
+
CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED = 4001,
|
| 159 |
+
CUDNN_STATUS_INTERNAL_ERROR_UNEXPECTED_VALUE = 4002,
|
| 160 |
+
CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED = 4003,
|
| 161 |
+
CUDNN_STATUS_INTERNAL_ERROR_DEVICE_ALLOCATION_FAILED = 4004,
|
| 162 |
+
CUDNN_STATUS_INTERNAL_ERROR_BAD_LAUNCH_PARAM = 4005,
|
| 163 |
+
CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED = 4006,
|
| 164 |
+
|
| 165 |
+
CUDNN_STATUS_EXECUTION_FAILED = 5000,
|
| 166 |
+
CUDNN_STATUS_EXECUTION_FAILED_CUDA_DRIVER = 5001,
|
| 167 |
+
CUDNN_STATUS_EXECUTION_FAILED_CUBLAS = 5002,
|
| 168 |
+
CUDNN_STATUS_EXECUTION_FAILED_CUDART = 5003,
|
| 169 |
+
CUDNN_STATUS_EXECUTION_FAILED_CURAND = 5004,
|
| 170 |
+
|
| 171 |
+
CUDNN_STATUS_ALLOC_FAILED CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED,
|
| 172 |
+
CUDNN_STATUS_INVALID_VALUE CUDNN_DEPRECATED_ENUM = 2001 /* please transition to CUDNN_STATUS_BAD_PARAM instead */,
|
| 173 |
+
CUDNN_STATUS_ARCH_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH,
|
| 174 |
+
CUDNN_STATUS_MAPPING_ERROR CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED,
|
| 175 |
+
CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING CUDNN_DEPRECATED_ENUM =
|
| 176 |
+
CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING,
|
| 177 |
+
CUDNN_STATUS_VERSION_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH,
|
| 178 |
+
} cudnnStatus_t;
|
| 179 |
+
|
| 180 |
+
#define CUDNN_STATUS_FULL_ERROR_CODE(category, specific_err) ((cudnnStatus_t)(0 + (category) + (specific_err)))
|
| 181 |
+
#define CUDNN_STATUS_CATEGORY(full_error_code) ((full_error_code) / 1000 * 1000)
|
| 182 |
+
#define CUDNN_STATUS_SPECIFIC_ERROR(full_error_code) ((full_error_code) % 1000)
|
| 183 |
+
|
| 184 |
+
/* human-readable error messages */
|
| 185 |
+
const char *CUDNNWINAPI
|
| 186 |
+
cudnnGetErrorString(cudnnStatus_t status);
|
| 187 |
+
|
| 188 |
+
void CUDNNWINAPI
|
| 189 |
+
cudnnGetLastErrorString(char *message, size_t max_size);
|
| 190 |
+
|
| 191 |
+
/* Forward definition in this version only */
|
| 192 |
+
typedef struct cudnnRuntimeTag_t cudnnRuntimeTag_t CUDNN_DEPRECATED;
|
| 193 |
+
|
| 194 |
+
typedef enum {
|
| 195 |
+
CUDNN_ERRQUERY_RAWCODE = 0,
|
| 196 |
+
CUDNN_ERRQUERY_NONBLOCKING = 1,
|
| 197 |
+
CUDNN_ERRQUERY_BLOCKING = 2,
|
| 198 |
+
} cudnnErrQueryMode_t;
|
| 199 |
+
|
| 200 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 201 |
+
cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag);
|
| 202 |
+
|
| 203 |
+
cudnnStatus_t CUDNNWINAPI
|
| 204 |
+
cudnnGetProperty(libraryPropertyType type, int *value);
|
| 205 |
+
|
| 206 |
+
cudnnStatus_t CUDNNWINAPI
|
| 207 |
+
cudnnCreate(cudnnHandle_t *handle);
|
| 208 |
+
cudnnStatus_t CUDNNWINAPI
|
| 209 |
+
cudnnDestroy(cudnnHandle_t handle);
|
| 210 |
+
cudnnStatus_t CUDNNWINAPI
|
| 211 |
+
cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
|
| 212 |
+
cudnnStatus_t CUDNNWINAPI
|
| 213 |
+
cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId);
|
| 214 |
+
/*
|
| 215 |
+
* CUDNN data type
|
| 216 |
+
*/
|
| 217 |
+
typedef enum {
|
| 218 |
+
CUDNN_DATA_FLOAT = 0,
|
| 219 |
+
CUDNN_DATA_DOUBLE = 1,
|
| 220 |
+
CUDNN_DATA_HALF = 2,
|
| 221 |
+
CUDNN_DATA_INT8 = 3,
|
| 222 |
+
CUDNN_DATA_INT32 = 4,
|
| 223 |
+
CUDNN_DATA_INT8x4 CUDNN_DEPRECATED_ENUM = 5,
|
| 224 |
+
CUDNN_DATA_UINT8 = 6,
|
| 225 |
+
CUDNN_DATA_UINT8x4 CUDNN_DEPRECATED_ENUM = 7,
|
| 226 |
+
CUDNN_DATA_INT8x32 CUDNN_DEPRECATED_ENUM = 8,
|
| 227 |
+
CUDNN_DATA_BFLOAT16 = 9,
|
| 228 |
+
CUDNN_DATA_INT64 = 10,
|
| 229 |
+
CUDNN_DATA_BOOLEAN = 11,
|
| 230 |
+
CUDNN_DATA_FP8_E4M3 = 12,
|
| 231 |
+
CUDNN_DATA_FP8_E5M2 = 13,
|
| 232 |
+
CUDNN_DATA_FAST_FLOAT_FOR_FP8 = 14,
|
| 233 |
+
} cudnnDataType_t;
|
| 234 |
+
|
| 235 |
+
/*
|
| 236 |
+
* CUDNN math type
|
| 237 |
+
*/
|
| 238 |
+
typedef enum {
|
| 239 |
+
CUDNN_DEFAULT_MATH = 0,
|
| 240 |
+
CUDNN_TENSOR_OP_MATH = 1,
|
| 241 |
+
CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION = 2,
|
| 242 |
+
CUDNN_FMA_MATH = 3,
|
| 243 |
+
} cudnnMathType_t;
|
| 244 |
+
|
| 245 |
+
/*
|
| 246 |
+
* CUDNN propagate Nan
|
| 247 |
+
*/
|
| 248 |
+
typedef enum {
|
| 249 |
+
CUDNN_NOT_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 0,
|
| 250 |
+
CUDNN_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 1,
|
| 251 |
+
} cudnnNanPropagation_t;
|
| 252 |
+
|
| 253 |
+
/*
|
| 254 |
+
* Behavior for OOB samples. OOB samples are samples where L+R > T is encountered during the gradient calculation. If
|
| 255 |
+
* gradMode is set to CUDNN_CTC_SKIP_OOB_GRADIENTS, then the CTC loss function does not write to the gradient buffer for
|
| 256 |
+
* that sample. Instead, the current values, even not finite, are retained. If gradMode is set to
|
| 257 |
+
* CUDNN_CTC_ZERO_OOB_GRADIENTS, then the gradient for that sample is set to zero. This guarantees a finite gradient.
|
| 258 |
+
*/
|
| 259 |
+
typedef enum {
|
| 260 |
+
CUDNN_CTC_ZERO_OOB_GRADIENTS = 0,
|
| 261 |
+
CUDNN_CTC_SKIP_OOB_GRADIENTS = 1,
|
| 262 |
+
} cudnnCTCGradMode_t;
|
| 263 |
+
|
| 264 |
+
typedef enum {
|
| 265 |
+
CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */
|
| 266 |
+
CUDNN_TENSOR_NHWC = 1, /* feature maps interleaved ( cStride = 1 )*/
|
| 267 |
+
CUDNN_TENSOR_NCHW_VECT_C = 2, /* each image point is vector of element of C, vector length in data type */
|
| 268 |
+
} cudnnTensorFormat_t;
|
| 269 |
+
|
| 270 |
+
/*
|
| 271 |
+
* CUDNN ReduceTensor op type
|
| 272 |
+
*/
|
| 273 |
+
typedef enum {
|
| 274 |
+
CUDNN_REDUCE_TENSOR_ADD = 0,
|
| 275 |
+
CUDNN_REDUCE_TENSOR_MUL = 1,
|
| 276 |
+
CUDNN_REDUCE_TENSOR_MIN = 2,
|
| 277 |
+
CUDNN_REDUCE_TENSOR_MAX = 3,
|
| 278 |
+
CUDNN_REDUCE_TENSOR_AMAX = 4,
|
| 279 |
+
CUDNN_REDUCE_TENSOR_AVG = 5,
|
| 280 |
+
CUDNN_REDUCE_TENSOR_NORM1 = 6,
|
| 281 |
+
CUDNN_REDUCE_TENSOR_NORM2 = 7,
|
| 282 |
+
CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS = 8,
|
| 283 |
+
} cudnnReduceTensorOp_t;
|
| 284 |
+
|
| 285 |
+
/*
|
| 286 |
+
* activation mode
|
| 287 |
+
*/
|
| 288 |
+
typedef enum {
|
| 289 |
+
CUDNN_ACTIVATION_SIGMOID = 0,
|
| 290 |
+
CUDNN_ACTIVATION_RELU = 1,
|
| 291 |
+
CUDNN_ACTIVATION_TANH = 2,
|
| 292 |
+
CUDNN_ACTIVATION_CLIPPED_RELU = 3,
|
| 293 |
+
CUDNN_ACTIVATION_ELU = 4,
|
| 294 |
+
CUDNN_ACTIVATION_IDENTITY = 5,
|
| 295 |
+
CUDNN_ACTIVATION_SWISH = 6
|
| 296 |
+
} cudnnActivationMode_t CUDNN_DEPRECATED;
|
| 297 |
+
|
| 298 |
+
typedef enum {
|
| 299 |
+
CUDNN_SEV_FATAL = 0,
|
| 300 |
+
CUDNN_SEV_ERROR = 1,
|
| 301 |
+
CUDNN_SEV_WARNING = 2,
|
| 302 |
+
CUDNN_SEV_INFO = 3,
|
| 303 |
+
} cudnnSeverity_t;
|
| 304 |
+
|
| 305 |
+
/* Message masks to be used with cudnnSetCallback() */
|
| 306 |
+
#define CUDNN_SEV_ERROR_EN (1U << CUDNN_SEV_ERROR)
|
| 307 |
+
#define CUDNN_SEV_WARNING_EN (1U << CUDNN_SEV_WARNING)
|
| 308 |
+
#define CUDNN_SEV_INFO_EN (1U << CUDNN_SEV_INFO)
|
| 309 |
+
|
| 310 |
+
/* struct containing useful informaiton for each API call */
|
| 311 |
+
typedef struct cudnnDebugStruct {
|
| 312 |
+
unsigned cudnn_version;
|
| 313 |
+
cudnnStatus_t cudnnStatus;
|
| 314 |
+
unsigned time_sec; /* epoch time in seconds */
|
| 315 |
+
unsigned time_usec; /* microseconds part of epoch time */
|
| 316 |
+
unsigned time_delta; /* time since start in seconds */
|
| 317 |
+
cudnnHandle_t handle; /* cudnn handle */
|
| 318 |
+
cudaStream_t stream; /* cuda stream ID */
|
| 319 |
+
unsigned long long pid; /* process ID */
|
| 320 |
+
unsigned long long tid; /* thread ID */
|
| 321 |
+
int cudaDeviceId; /* CUDA device ID */
|
| 322 |
+
int reserved[15]; /* reserved for future use */
|
| 323 |
+
} cudnnDebug_t;
|
| 324 |
+
|
| 325 |
+
typedef void (*cudnnCallback_t)(cudnnSeverity_t sev, void *udata, const cudnnDebug_t *dbg, const char *msg);
|
| 326 |
+
|
| 327 |
+
cudnnStatus_t CUDNNWINAPI
|
| 328 |
+
cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr);
|
| 329 |
+
|
| 330 |
+
cudnnStatus_t CUDNNWINAPI
|
| 331 |
+
cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr);
|
| 332 |
+
|
| 333 |
+
/*
|
| 334 |
+
* \brief Cross-library version checker.
|
| 335 |
+
* This function is implemented differently in each sub-library. Each sublib
|
| 336 |
+
* checks whether its own version matches that of its dependencies.
|
| 337 |
+
* \returns CUDNN_STATUS_SUCCESS if the version check passes,
|
| 338 |
+
* CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
|
| 339 |
+
*/
|
| 340 |
+
cudnnStatus_t CUDNNWINAPI
|
| 341 |
+
cudnnGraphVersionCheck(void);
|
| 342 |
+
|
| 343 |
+
/* Maximum supported number of tensor dimensions */
|
| 344 |
+
#define CUDNN_DIM_MAX 8
|
| 345 |
+
|
| 346 |
+
/*
|
| 347 |
+
* convolution mode
|
| 348 |
+
*/
|
| 349 |
+
typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t;
|
| 350 |
+
|
| 351 |
+
/*
|
| 352 |
+
* CUDNN Reorder
|
| 353 |
+
*/
|
| 354 |
+
typedef enum {
|
| 355 |
+
CUDNN_DEFAULT_REORDER = 0,
|
| 356 |
+
CUDNN_NO_REORDER = 1,
|
| 357 |
+
} cudnnReorderType_t CUDNN_DEPRECATED;
|
| 358 |
+
|
| 359 |
+
typedef void *cudnnBackendDescriptor_t;
|
| 360 |
+
|
| 361 |
+
typedef struct cudnnFractionStruct {
|
| 362 |
+
int64_t numerator;
|
| 363 |
+
int64_t denominator;
|
| 364 |
+
} cudnnFraction_t;
|
| 365 |
+
|
| 366 |
+
typedef enum {
|
| 367 |
+
CUDNN_POINTWISE_ADD = 0,
|
| 368 |
+
CUDNN_POINTWISE_ADD_SQUARE = 5,
|
| 369 |
+
CUDNN_POINTWISE_DIV = 6,
|
| 370 |
+
CUDNN_POINTWISE_MAX = 3,
|
| 371 |
+
CUDNN_POINTWISE_MIN = 2,
|
| 372 |
+
CUDNN_POINTWISE_MOD = 7,
|
| 373 |
+
CUDNN_POINTWISE_MUL = 1,
|
| 374 |
+
CUDNN_POINTWISE_POW = 8,
|
| 375 |
+
CUDNN_POINTWISE_SUB = 9,
|
| 376 |
+
|
| 377 |
+
CUDNN_POINTWISE_ABS = 10,
|
| 378 |
+
CUDNN_POINTWISE_CEIL = 11,
|
| 379 |
+
CUDNN_POINTWISE_COS = 12,
|
| 380 |
+
CUDNN_POINTWISE_EXP = 13,
|
| 381 |
+
CUDNN_POINTWISE_FLOOR = 14,
|
| 382 |
+
CUDNN_POINTWISE_LOG = 15,
|
| 383 |
+
CUDNN_POINTWISE_NEG = 16,
|
| 384 |
+
CUDNN_POINTWISE_RSQRT = 17,
|
| 385 |
+
CUDNN_POINTWISE_SIN = 18,
|
| 386 |
+
CUDNN_POINTWISE_SQRT = 4,
|
| 387 |
+
CUDNN_POINTWISE_TAN = 19,
|
| 388 |
+
CUDNN_POINTWISE_ERF = 20,
|
| 389 |
+
CUDNN_POINTWISE_IDENTITY = 21,
|
| 390 |
+
CUDNN_POINTWISE_RECIPROCAL = 22,
|
| 391 |
+
CUDNN_POINTWISE_ATAN2 = 23,
|
| 392 |
+
|
| 393 |
+
CUDNN_POINTWISE_RELU_FWD = 100,
|
| 394 |
+
CUDNN_POINTWISE_TANH_FWD = 101,
|
| 395 |
+
CUDNN_POINTWISE_SIGMOID_FWD = 102,
|
| 396 |
+
CUDNN_POINTWISE_ELU_FWD = 103,
|
| 397 |
+
CUDNN_POINTWISE_GELU_FWD = 104,
|
| 398 |
+
CUDNN_POINTWISE_SOFTPLUS_FWD = 105,
|
| 399 |
+
CUDNN_POINTWISE_SWISH_FWD = 106,
|
| 400 |
+
CUDNN_POINTWISE_GELU_APPROX_TANH_FWD = 107,
|
| 401 |
+
|
| 402 |
+
CUDNN_POINTWISE_RELU_BWD = 200,
|
| 403 |
+
CUDNN_POINTWISE_TANH_BWD = 201,
|
| 404 |
+
CUDNN_POINTWISE_SIGMOID_BWD = 202,
|
| 405 |
+
CUDNN_POINTWISE_ELU_BWD = 203,
|
| 406 |
+
CUDNN_POINTWISE_GELU_BWD = 204,
|
| 407 |
+
CUDNN_POINTWISE_SOFTPLUS_BWD = 205,
|
| 408 |
+
CUDNN_POINTWISE_SWISH_BWD = 206,
|
| 409 |
+
CUDNN_POINTWISE_GELU_APPROX_TANH_BWD = 207,
|
| 410 |
+
|
| 411 |
+
CUDNN_POINTWISE_CMP_EQ = 300,
|
| 412 |
+
CUDNN_POINTWISE_CMP_NEQ = 301,
|
| 413 |
+
CUDNN_POINTWISE_CMP_GT = 302,
|
| 414 |
+
CUDNN_POINTWISE_CMP_GE = 303,
|
| 415 |
+
CUDNN_POINTWISE_CMP_LT = 304,
|
| 416 |
+
CUDNN_POINTWISE_CMP_LE = 305,
|
| 417 |
+
|
| 418 |
+
CUDNN_POINTWISE_LOGICAL_AND = 400,
|
| 419 |
+
CUDNN_POINTWISE_LOGICAL_OR = 401,
|
| 420 |
+
CUDNN_POINTWISE_LOGICAL_NOT = 402,
|
| 421 |
+
|
| 422 |
+
CUDNN_POINTWISE_GEN_INDEX = 501,
|
| 423 |
+
|
| 424 |
+
CUDNN_POINTWISE_BINARY_SELECT = 601,
|
| 425 |
+
} cudnnPointwiseMode_t;
|
| 426 |
+
|
| 427 |
+
typedef enum {
|
| 428 |
+
CUDNN_RESAMPLE_NEAREST = 0,
|
| 429 |
+
CUDNN_RESAMPLE_BILINEAR = 1,
|
| 430 |
+
CUDNN_RESAMPLE_AVGPOOL = 2,
|
| 431 |
+
CUDNN_RESAMPLE_AVGPOOL_INCLUDE_PADDING = 2,
|
| 432 |
+
CUDNN_RESAMPLE_AVGPOOL_EXCLUDE_PADDING = 4,
|
| 433 |
+
CUDNN_RESAMPLE_MAXPOOL = 3,
|
| 434 |
+
} cudnnResampleMode_t;
|
| 435 |
+
|
| 436 |
+
typedef enum {
|
| 437 |
+
CUDNN_SIGNAL_SET = 0,
|
| 438 |
+
CUDNN_SIGNAL_WAIT = 1,
|
| 439 |
+
} cudnnSignalMode_t;
|
| 440 |
+
|
| 441 |
+
typedef enum {
|
| 442 |
+
CUDNN_GENSTATS_SUM_SQSUM = 0,
|
| 443 |
+
} cudnnGenStatsMode_t;
|
| 444 |
+
|
| 445 |
+
typedef enum {
|
| 446 |
+
CUDNN_BN_FINALIZE_STATISTICS_TRAINING = 0,
|
| 447 |
+
CUDNN_BN_FINALIZE_STATISTICS_INFERENCE = 1,
|
| 448 |
+
} cudnnBnFinalizeStatsMode_t;
|
| 449 |
+
|
| 450 |
+
typedef enum {
|
| 451 |
+
CUDNN_RNG_DISTRIBUTION_BERNOULLI,
|
| 452 |
+
CUDNN_RNG_DISTRIBUTION_UNIFORM,
|
| 453 |
+
CUDNN_RNG_DISTRIBUTION_NORMAL,
|
| 454 |
+
} cudnnRngDistribution_t;
|
| 455 |
+
|
| 456 |
+
typedef enum {
|
| 457 |
+
CUDNN_ATTR_POINTWISE_MODE = 0,
|
| 458 |
+
CUDNN_ATTR_POINTWISE_MATH_PREC = 1,
|
| 459 |
+
CUDNN_ATTR_POINTWISE_NAN_PROPAGATION CUDNN_DEPRECATED_ENUM = 2,
|
| 460 |
+
CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3,
|
| 461 |
+
CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4,
|
| 462 |
+
CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5,
|
| 463 |
+
CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6,
|
| 464 |
+
CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7,
|
| 465 |
+
CUDNN_ATTR_POINTWISE_SWISH_BETA = 8,
|
| 466 |
+
CUDNN_ATTR_POINTWISE_AXIS = 9,
|
| 467 |
+
|
| 468 |
+
CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100,
|
| 469 |
+
CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101,
|
| 470 |
+
CUDNN_ATTR_CONVOLUTION_DILATIONS = 102,
|
| 471 |
+
CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103,
|
| 472 |
+
CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104,
|
| 473 |
+
CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105,
|
| 474 |
+
CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106,
|
| 475 |
+
|
| 476 |
+
CUDNN_ATTR_ENGINEHEUR_MODE = 200,
|
| 477 |
+
CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201,
|
| 478 |
+
CUDNN_ATTR_ENGINEHEUR_RESULTS = 202,
|
| 479 |
+
CUDNN_ATTR_ENGINEHEUR_SM_COUNT_TARGET = 203,
|
| 480 |
+
|
| 481 |
+
CUDNN_ATTR_ENGINECFG_ENGINE = 300,
|
| 482 |
+
CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301,
|
| 483 |
+
CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302,
|
| 484 |
+
|
| 485 |
+
CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400,
|
| 486 |
+
CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401,
|
| 487 |
+
CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402,
|
| 488 |
+
CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403,
|
| 489 |
+
CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404,
|
| 490 |
+
CUDNN_ATTR_EXECUTION_PLAN_JSON_REPRESENTATION = 405,
|
| 491 |
+
|
| 492 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500,
|
| 493 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501,
|
| 494 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502,
|
| 495 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503,
|
| 496 |
+
|
| 497 |
+
CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600,
|
| 498 |
+
CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601,
|
| 499 |
+
|
| 500 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700,
|
| 501 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701,
|
| 502 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702,
|
| 503 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703,
|
| 504 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704,
|
| 505 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705,
|
| 506 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706,
|
| 507 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707,
|
| 508 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708,
|
| 509 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709,
|
| 510 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710,
|
| 511 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711,
|
| 512 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712,
|
| 513 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713,
|
| 514 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714,
|
| 515 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715,
|
| 516 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716,
|
| 517 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717,
|
| 518 |
+
|
| 519 |
+
CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750,
|
| 520 |
+
CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751,
|
| 521 |
+
CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752,
|
| 522 |
+
CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753,
|
| 523 |
+
CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754,
|
| 524 |
+
CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755,
|
| 525 |
+
CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756,
|
| 526 |
+
CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757,
|
| 527 |
+
CUDNN_ATTR_OPERATION_POINTWISE_TDESC = 758,
|
| 528 |
+
|
| 529 |
+
CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770,
|
| 530 |
+
CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771,
|
| 531 |
+
CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772,
|
| 532 |
+
CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773,
|
| 533 |
+
CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774,
|
| 534 |
+
|
| 535 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780,
|
| 536 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781,
|
| 537 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782,
|
| 538 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783,
|
| 539 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784,
|
| 540 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785,
|
| 541 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786,
|
| 542 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787,
|
| 543 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788,
|
| 544 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789,
|
| 545 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790,
|
| 546 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791,
|
| 547 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792,
|
| 548 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793,
|
| 549 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794,
|
| 550 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795,
|
| 551 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796,
|
| 552 |
+
|
| 553 |
+
CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800,
|
| 554 |
+
CUDNN_ATTR_OPERATIONGRAPH_OPS = 801,
|
| 555 |
+
CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802,
|
| 556 |
+
|
| 557 |
+
CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900,
|
| 558 |
+
CUDNN_ATTR_TENSOR_DATA_TYPE = 901,
|
| 559 |
+
CUDNN_ATTR_TENSOR_DIMENSIONS = 902,
|
| 560 |
+
CUDNN_ATTR_TENSOR_STRIDES = 903,
|
| 561 |
+
CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904,
|
| 562 |
+
CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905,
|
| 563 |
+
CUDNN_ATTR_TENSOR_UNIQUE_ID = 906,
|
| 564 |
+
CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907,
|
| 565 |
+
CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908,
|
| 566 |
+
CUDNN_ATTR_TENSOR_REORDERING_MODE = 909,
|
| 567 |
+
CUDNN_ATTR_TENSOR_RAGGED_OFFSET_DESC = 913,
|
| 568 |
+
|
| 569 |
+
CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000,
|
| 570 |
+
CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001,
|
| 571 |
+
CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002,
|
| 572 |
+
CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003,
|
| 573 |
+
|
| 574 |
+
CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100,
|
| 575 |
+
CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101,
|
| 576 |
+
|
| 577 |
+
CUDNN_ATTR_KNOB_INFO_TYPE = 1200,
|
| 578 |
+
CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201,
|
| 579 |
+
CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202,
|
| 580 |
+
CUDNN_ATTR_KNOB_INFO_STRIDE = 1203,
|
| 581 |
+
|
| 582 |
+
CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300,
|
| 583 |
+
CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301,
|
| 584 |
+
CUDNN_ATTR_ENGINE_KNOB_INFO = 1302,
|
| 585 |
+
CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303,
|
| 586 |
+
CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304,
|
| 587 |
+
CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305,
|
| 588 |
+
CUDNN_ATTR_ENGINE_SM_COUNT_TARGET = 1306,
|
| 589 |
+
|
| 590 |
+
CUDNN_ATTR_MATMUL_COMP_TYPE = 1500,
|
| 591 |
+
CUDNN_ATTR_MATMUL_PADDING_VALUE = 1503,
|
| 592 |
+
|
| 593 |
+
CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520,
|
| 594 |
+
CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521,
|
| 595 |
+
CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522,
|
| 596 |
+
CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523,
|
| 597 |
+
CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT CUDNN_DEPRECATED_ENUM = 1524,
|
| 598 |
+
CUDNN_ATTR_OPERATION_MATMUL_GEMM_M_OVERRIDE_DESC = 1525,
|
| 599 |
+
CUDNN_ATTR_OPERATION_MATMUL_GEMM_N_OVERRIDE_DESC = 1526,
|
| 600 |
+
CUDNN_ATTR_OPERATION_MATMUL_GEMM_K_OVERRIDE_DESC = 1527,
|
| 601 |
+
|
| 602 |
+
CUDNN_ATTR_REDUCTION_OPERATOR = 1600,
|
| 603 |
+
CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601,
|
| 604 |
+
|
| 605 |
+
CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610,
|
| 606 |
+
CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611,
|
| 607 |
+
CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612,
|
| 608 |
+
|
| 609 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620,
|
| 610 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621,
|
| 611 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622,
|
| 612 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623,
|
| 613 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624,
|
| 614 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625,
|
| 615 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626,
|
| 616 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627,
|
| 617 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628,
|
| 618 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629,
|
| 619 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630,
|
| 620 |
+
|
| 621 |
+
CUDNN_ATTR_RESAMPLE_MODE = 1700,
|
| 622 |
+
CUDNN_ATTR_RESAMPLE_COMP_TYPE = 1701,
|
| 623 |
+
CUDNN_ATTR_RESAMPLE_SPATIAL_DIMS = 1702,
|
| 624 |
+
CUDNN_ATTR_RESAMPLE_POST_PADDINGS = 1703,
|
| 625 |
+
CUDNN_ATTR_RESAMPLE_PRE_PADDINGS = 1704,
|
| 626 |
+
CUDNN_ATTR_RESAMPLE_STRIDES = 1705,
|
| 627 |
+
CUDNN_ATTR_RESAMPLE_WINDOW_DIMS = 1706,
|
| 628 |
+
CUDNN_ATTR_RESAMPLE_NAN_PROPAGATION = 1707,
|
| 629 |
+
CUDNN_ATTR_RESAMPLE_PADDING_MODE = 1708,
|
| 630 |
+
|
| 631 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_XDESC = 1710,
|
| 632 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_YDESC = 1711,
|
| 633 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_IDXDESC = 1712,
|
| 634 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_ALPHA CUDNN_DEPRECATED_ENUM = 1713,
|
| 635 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_BETA CUDNN_DEPRECATED_ENUM = 1714,
|
| 636 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_DESC = 1716,
|
| 637 |
+
|
| 638 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DXDESC = 1720,
|
| 639 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DYDESC = 1721,
|
| 640 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_IDXDESC = 1722,
|
| 641 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_ALPHA CUDNN_DEPRECATED_ENUM = 1723,
|
| 642 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_BETA CUDNN_DEPRECATED_ENUM = 1724,
|
| 643 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DESC = 1725,
|
| 644 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_XDESC = 1726,
|
| 645 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_YDESC = 1727,
|
| 646 |
+
|
| 647 |
+
CUDNN_ATTR_OPERATION_CONCAT_AXIS = 1800,
|
| 648 |
+
CUDNN_ATTR_OPERATION_CONCAT_INPUT_DESCS = 1801,
|
| 649 |
+
CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX = 1802,
|
| 650 |
+
CUDNN_ATTR_OPERATION_CONCAT_OUTPUT_DESC = 1803,
|
| 651 |
+
|
| 652 |
+
CUDNN_ATTR_OPERATION_SIGNAL_MODE = 1900,
|
| 653 |
+
CUDNN_ATTR_OPERATION_SIGNAL_FLAGDESC = 1901,
|
| 654 |
+
CUDNN_ATTR_OPERATION_SIGNAL_VALUE = 1902,
|
| 655 |
+
CUDNN_ATTR_OPERATION_SIGNAL_XDESC = 1903,
|
| 656 |
+
CUDNN_ATTR_OPERATION_SIGNAL_YDESC = 1904,
|
| 657 |
+
|
| 658 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_MODE = 2000,
|
| 659 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_PHASE = 2001,
|
| 660 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_XDESC = 2002,
|
| 661 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC = 2003,
|
| 662 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC = 2004,
|
| 663 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC = 2005,
|
| 664 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC = 2006,
|
| 665 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC = 2007,
|
| 666 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_EXP_AVG_FACTOR_DESC = 2008,
|
| 667 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_MEAN_DESC = 2009,
|
| 668 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_VAR_DESC = 2010,
|
| 669 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_MEAN_DESC = 2011,
|
| 670 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_VAR_DESC = 2012,
|
| 671 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_YDESC = 2013,
|
| 672 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_PEER_STAT_DESCS = 2014,
|
| 673 |
+
|
| 674 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_MODE = 2100,
|
| 675 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_XDESC = 2101,
|
| 676 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC = 2102,
|
| 677 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC = 2103,
|
| 678 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC = 2104,
|
| 679 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC = 2105,
|
| 680 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_EPSILON_DESC = 2106,
|
| 681 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC = 2107,
|
| 682 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC = 2108,
|
| 683 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC = 2109,
|
| 684 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_PEER_STAT_DESCS = 2110,
|
| 685 |
+
|
| 686 |
+
CUDNN_ATTR_OPERATION_RESHAPE_XDESC = 2200,
|
| 687 |
+
CUDNN_ATTR_OPERATION_RESHAPE_YDESC = 2201,
|
| 688 |
+
|
| 689 |
+
CUDNN_ATTR_RNG_DISTRIBUTION = 2300,
|
| 690 |
+
CUDNN_ATTR_RNG_NORMAL_DIST_MEAN = 2301,
|
| 691 |
+
CUDNN_ATTR_RNG_NORMAL_DIST_STANDARD_DEVIATION = 2302,
|
| 692 |
+
CUDNN_ATTR_RNG_UNIFORM_DIST_MAXIMUM = 2303,
|
| 693 |
+
CUDNN_ATTR_RNG_UNIFORM_DIST_MINIMUM = 2304,
|
| 694 |
+
CUDNN_ATTR_RNG_BERNOULLI_DIST_PROBABILITY = 2305,
|
| 695 |
+
|
| 696 |
+
CUDNN_ATTR_OPERATION_RNG_YDESC = 2310,
|
| 697 |
+
CUDNN_ATTR_OPERATION_RNG_SEED = 2311,
|
| 698 |
+
CUDNN_ATTR_OPERATION_RNG_DESC = 2312,
|
| 699 |
+
CUDNN_ATTR_OPERATION_RNG_OFFSET_DESC = 2313,
|
| 700 |
+
} cudnnBackendAttributeName_t;
|
| 701 |
+
|
| 702 |
+
typedef enum {
|
| 703 |
+
CUDNN_TYPE_HANDLE = 0,
|
| 704 |
+
CUDNN_TYPE_DATA_TYPE,
|
| 705 |
+
CUDNN_TYPE_BOOLEAN,
|
| 706 |
+
CUDNN_TYPE_INT64,
|
| 707 |
+
CUDNN_TYPE_FLOAT,
|
| 708 |
+
CUDNN_TYPE_DOUBLE,
|
| 709 |
+
CUDNN_TYPE_VOID_PTR,
|
| 710 |
+
CUDNN_TYPE_CONVOLUTION_MODE,
|
| 711 |
+
CUDNN_TYPE_HEUR_MODE,
|
| 712 |
+
CUDNN_TYPE_KNOB_TYPE,
|
| 713 |
+
CUDNN_TYPE_NAN_PROPOGATION CUDNN_DEPRECATED_ENUM,
|
| 714 |
+
CUDNN_TYPE_NUMERICAL_NOTE,
|
| 715 |
+
CUDNN_TYPE_LAYOUT_TYPE,
|
| 716 |
+
CUDNN_TYPE_ATTRIB_NAME,
|
| 717 |
+
CUDNN_TYPE_POINTWISE_MODE,
|
| 718 |
+
CUDNN_TYPE_BACKEND_DESCRIPTOR,
|
| 719 |
+
CUDNN_TYPE_GENSTATS_MODE,
|
| 720 |
+
CUDNN_TYPE_BN_FINALIZE_STATS_MODE,
|
| 721 |
+
CUDNN_TYPE_REDUCTION_OPERATOR_TYPE,
|
| 722 |
+
CUDNN_TYPE_BEHAVIOR_NOTE,
|
| 723 |
+
CUDNN_TYPE_TENSOR_REORDERING_MODE,
|
| 724 |
+
CUDNN_TYPE_RESAMPLE_MODE,
|
| 725 |
+
CUDNN_TYPE_PADDING_MODE,
|
| 726 |
+
CUDNN_TYPE_INT32,
|
| 727 |
+
CUDNN_TYPE_CHAR,
|
| 728 |
+
CUDNN_TYPE_SIGNAL_MODE,
|
| 729 |
+
CUDNN_TYPE_FRACTION,
|
| 730 |
+
CUDNN_TYPE_NORM_MODE,
|
| 731 |
+
CUDNN_TYPE_NORM_FWD_PHASE,
|
| 732 |
+
CUDNN_TYPE_RNG_DISTRIBUTION
|
| 733 |
+
} cudnnBackendAttributeType_t;
|
| 734 |
+
|
| 735 |
+
typedef enum {
|
| 736 |
+
CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0,
|
| 737 |
+
CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR,
|
| 738 |
+
CUDNN_BACKEND_ENGINE_DESCRIPTOR,
|
| 739 |
+
CUDNN_BACKEND_ENGINECFG_DESCRIPTOR,
|
| 740 |
+
CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR,
|
| 741 |
+
CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR,
|
| 742 |
+
CUDNN_BACKEND_INTERMEDIATE_INFO_DESCRIPTOR,
|
| 743 |
+
CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR,
|
| 744 |
+
CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR,
|
| 745 |
+
CUDNN_BACKEND_LAYOUT_INFO_DESCRIPTOR,
|
| 746 |
+
CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
|
| 747 |
+
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
|
| 748 |
+
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
|
| 749 |
+
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR,
|
| 750 |
+
CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR,
|
| 751 |
+
CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR,
|
| 752 |
+
CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR,
|
| 753 |
+
CUDNN_BACKEND_TENSOR_DESCRIPTOR,
|
| 754 |
+
CUDNN_BACKEND_MATMUL_DESCRIPTOR,
|
| 755 |
+
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR,
|
| 756 |
+
CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR,
|
| 757 |
+
CUDNN_BACKEND_REDUCTION_DESCRIPTOR,
|
| 758 |
+
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR,
|
| 759 |
+
CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR,
|
| 760 |
+
CUDNN_BACKEND_RESAMPLE_DESCRIPTOR,
|
| 761 |
+
CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR,
|
| 762 |
+
CUDNN_BACKEND_OPERATION_RESAMPLE_BWD_DESCRIPTOR,
|
| 763 |
+
CUDNN_BACKEND_OPERATION_CONCAT_DESCRIPTOR,
|
| 764 |
+
CUDNN_BACKEND_OPERATION_SIGNAL_DESCRIPTOR,
|
| 765 |
+
CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR,
|
| 766 |
+
CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR,
|
| 767 |
+
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR,
|
| 768 |
+
CUDNN_BACKEND_RNG_DESCRIPTOR,
|
| 769 |
+
CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR,
|
| 770 |
+
} cudnnBackendDescriptorType_t;
|
| 771 |
+
|
| 772 |
+
typedef enum {
|
| 773 |
+
CUDNN_NUMERICAL_NOTE_TENSOR_CORE = 0,
|
| 774 |
+
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS,
|
| 775 |
+
CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION,
|
| 776 |
+
CUDNN_NUMERICAL_NOTE_FFT,
|
| 777 |
+
CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC,
|
| 778 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD,
|
| 779 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_4x4,
|
| 780 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_6x6,
|
| 781 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_13x13,
|
| 782 |
+
CUDNN_NUMERICAL_NOTE_STRICT_NAN_PROP,
|
| 783 |
+
CUDNN_NUMERICAL_NOTE_TYPE_COUNT,
|
| 784 |
+
} cudnnBackendNumericalNote_t;
|
| 785 |
+
|
| 786 |
+
typedef enum {
|
| 787 |
+
CUDNN_BEHAVIOR_NOTE_RUNTIME_COMPILATION = 0,
|
| 788 |
+
CUDNN_BEHAVIOR_NOTE_REQUIRES_FILTER_INT8x32_REORDER = 1,
|
| 789 |
+
CUDNN_BEHAVIOR_NOTE_REQUIRES_BIAS_INT8x32_REORDER = 2,
|
| 790 |
+
CUDNN_BEHAVIOR_NOTE_TYPE_COUNT,
|
| 791 |
+
} cudnnBackendBehaviorNote_t;
|
| 792 |
+
|
| 793 |
+
typedef enum {
|
| 794 |
+
CUDNN_KNOB_TYPE_SPLIT_K CUDNN_DEPRECATED_ENUM = 0,
|
| 795 |
+
CUDNN_KNOB_TYPE_SWIZZLE = 1,
|
| 796 |
+
CUDNN_KNOB_TYPE_TILE_SIZE = 2,
|
| 797 |
+
CUDNN_KNOB_TYPE_USE_TEX CUDNN_DEPRECATED_ENUM = 3,
|
| 798 |
+
CUDNN_KNOB_TYPE_EDGE = 4,
|
| 799 |
+
CUDNN_KNOB_TYPE_KBLOCK CUDNN_DEPRECATED_ENUM = 5,
|
| 800 |
+
CUDNN_KNOB_TYPE_LDGA CUDNN_DEPRECATED_ENUM = 6,
|
| 801 |
+
CUDNN_KNOB_TYPE_LDGB CUDNN_DEPRECATED_ENUM = 7,
|
| 802 |
+
CUDNN_KNOB_TYPE_CHUNK_K CUDNN_DEPRECATED_ENUM = 8,
|
| 803 |
+
CUDNN_KNOB_TYPE_SPLIT_H CUDNN_DEPRECATED_ENUM = 9,
|
| 804 |
+
CUDNN_KNOB_TYPE_WINO_TILE CUDNN_DEPRECATED_ENUM = 10,
|
| 805 |
+
CUDNN_KNOB_TYPE_MULTIPLY = 11,
|
| 806 |
+
CUDNN_KNOB_TYPE_SPLIT_K_BUF = 12,
|
| 807 |
+
CUDNN_KNOB_TYPE_TILEK = 13,
|
| 808 |
+
CUDNN_KNOB_TYPE_STAGES = 14,
|
| 809 |
+
CUDNN_KNOB_TYPE_REDUCTION_MODE = 15,
|
| 810 |
+
CUDNN_KNOB_TYPE_CTA_SPLIT_K_MODE CUDNN_DEPRECATED_ENUM = 16,
|
| 811 |
+
CUDNN_KNOB_TYPE_SPLIT_K_SLC = 17,
|
| 812 |
+
CUDNN_KNOB_TYPE_IDX_MODE CUDNN_DEPRECATED_ENUM = 18,
|
| 813 |
+
CUDNN_KNOB_TYPE_SLICED CUDNN_DEPRECATED_ENUM = 19,
|
| 814 |
+
CUDNN_KNOB_TYPE_SPLIT_RS CUDNN_DEPRECATED_ENUM = 20,
|
| 815 |
+
CUDNN_KNOB_TYPE_SINGLEBUFFER CUDNN_DEPRECATED_ENUM = 21,
|
| 816 |
+
CUDNN_KNOB_TYPE_LDGC CUDNN_DEPRECATED_ENUM = 22,
|
| 817 |
+
CUDNN_KNOB_TYPE_SPECFILT = 23,
|
| 818 |
+
CUDNN_KNOB_TYPE_KERNEL_CFG = 24,
|
| 819 |
+
CUDNN_KNOB_TYPE_WORKSPACE = 25,
|
| 820 |
+
CUDNN_KNOB_TYPE_TILE_CGA CUDNN_DEPRECATED_ENUM = 26,
|
| 821 |
+
CUDNN_KNOB_TYPE_TILE_CGA_M = 27,
|
| 822 |
+
CUDNN_KNOB_TYPE_TILE_CGA_N = 28,
|
| 823 |
+
CUDNN_KNOB_TYPE_BLOCK_SIZE = 29,
|
| 824 |
+
CUDNN_KNOB_TYPE_OCCUPANCY = 30,
|
| 825 |
+
CUDNN_KNOB_TYPE_ARRAY_SIZE_PER_THREAD = 31,
|
| 826 |
+
CUDNN_KNOB_TYPE_NUM_C_PER_BLOCK CUDNN_DEPRECATED_ENUM = 32,
|
| 827 |
+
CUDNN_KNOB_TYPE_SPLIT_COLS = 33,
|
| 828 |
+
CUDNN_KNOB_TYPE_TILE_ROWS = 34,
|
| 829 |
+
CUDNN_KNOB_TYPE_TILE_COLS = 35,
|
| 830 |
+
CUDNN_KNOB_TYPE_LOAD_SIZE = 36,
|
| 831 |
+
CUDNN_KNOB_TYPE_COUNTS,
|
| 832 |
+
} cudnnBackendKnobType_t;
|
| 833 |
+
|
| 834 |
+
typedef enum {
|
| 835 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_NCHW = 0,
|
| 836 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_NHWC = 1,
|
| 837 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_PAD4CK = 2,
|
| 838 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_PAD8CK = 3,
|
| 839 |
+
CUDNN_LAYOUT_TYPE_COUNT = 4,
|
| 840 |
+
} cudnnBackendLayoutType_t;
|
| 841 |
+
|
| 842 |
+
typedef enum {
|
| 843 |
+
CUDNN_HEUR_MODE_INSTANT = 0,
|
| 844 |
+
CUDNN_HEUR_MODE_B = 1,
|
| 845 |
+
CUDNN_HEUR_MODE_FALLBACK = 2,
|
| 846 |
+
CUDNN_HEUR_MODE_A = 3,
|
| 847 |
+
CUDNN_HEUR_MODES_COUNT = 4,
|
| 848 |
+
} cudnnBackendHeurMode_t;
|
| 849 |
+
|
| 850 |
+
typedef enum {
|
| 851 |
+
CUDNN_TENSOR_REORDERING_NONE = 0,
|
| 852 |
+
CUDNN_TENSOR_REORDERING_INT8x32 = 1,
|
| 853 |
+
CUDNN_TENSOR_REORDERING_F16x16 = 2,
|
| 854 |
+
} cudnnBackendTensorReordering_t;
|
| 855 |
+
|
| 856 |
+
typedef enum {
|
| 857 |
+
CUDNN_ZERO_PAD = 0,
|
| 858 |
+
CUDNN_NEG_INF_PAD = 1,
|
| 859 |
+
CUDNN_EDGE_VAL_PAD = 2,
|
| 860 |
+
} cudnnPaddingMode_t;
|
| 861 |
+
|
| 862 |
+
typedef enum {
|
| 863 |
+
CUDNN_LAYER_NORM = 0,
|
| 864 |
+
CUDNN_INSTANCE_NORM = 1,
|
| 865 |
+
CUDNN_BATCH_NORM = 2,
|
| 866 |
+
CUDNN_GROUP_NORM = 3,
|
| 867 |
+
CUDNN_RMS_NORM = 4,
|
| 868 |
+
} cudnnBackendNormMode_t;
|
| 869 |
+
|
| 870 |
+
typedef enum {
|
| 871 |
+
CUDNN_NORM_FWD_INFERENCE = 0,
|
| 872 |
+
CUDNN_NORM_FWD_TRAINING = 1,
|
| 873 |
+
} cudnnBackendNormFwdPhase_t;
|
| 874 |
+
|
| 875 |
+
cudnnStatus_t CUDNNWINAPI
|
| 876 |
+
cudnnBackendCreateDescriptor(cudnnBackendDescriptorType_t descriptorType, cudnnBackendDescriptor_t *descriptor);
|
| 877 |
+
|
| 878 |
+
cudnnStatus_t CUDNNWINAPI
|
| 879 |
+
cudnnBackendDestroyDescriptor(cudnnBackendDescriptor_t descriptor);
|
| 880 |
+
|
| 881 |
+
cudnnStatus_t CUDNNWINAPI
|
| 882 |
+
cudnnBackendInitialize(cudnnBackendDescriptor_t descriptor);
|
| 883 |
+
|
| 884 |
+
cudnnStatus_t CUDNNWINAPI
|
| 885 |
+
cudnnBackendFinalize(cudnnBackendDescriptor_t descriptor);
|
| 886 |
+
|
| 887 |
+
cudnnStatus_t CUDNNWINAPI
|
| 888 |
+
cudnnBackendSetAttribute(cudnnBackendDescriptor_t descriptor,
|
| 889 |
+
cudnnBackendAttributeName_t attributeName,
|
| 890 |
+
cudnnBackendAttributeType_t attributeType,
|
| 891 |
+
int64_t elementCount,
|
| 892 |
+
const void *arrayOfElements);
|
| 893 |
+
|
| 894 |
+
cudnnStatus_t CUDNNWINAPI
|
| 895 |
+
cudnnBackendGetAttribute(cudnnBackendDescriptor_t const descriptor,
|
| 896 |
+
cudnnBackendAttributeName_t attributeName,
|
| 897 |
+
cudnnBackendAttributeType_t attributeType,
|
| 898 |
+
int64_t requestedElementCount,
|
| 899 |
+
int64_t *elementCount,
|
| 900 |
+
void *arrayOfElements);
|
| 901 |
+
|
| 902 |
+
cudnnStatus_t CUDNNWINAPI
|
| 903 |
+
cudnnBackendExecute(cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t variantPack);
|
| 904 |
+
|
| 905 |
+
#if defined(__cplusplus)
|
| 906 |
+
}
|
| 907 |
+
#endif
|
| 908 |
+
|
| 909 |
+
#endif /* CUDNN_GRAPH_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops.h
ADDED
|
@@ -0,0 +1,1316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* cudnn_ops : cuDNN's basic definitions and basic operations.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_OPS_H_)
|
| 55 |
+
#define CUDNN_OPS_H_
|
| 56 |
+
|
| 57 |
+
#include <stdint.h>
|
| 58 |
+
|
| 59 |
+
#include "cudnn_version.h"
|
| 60 |
+
#include "cudnn_graph.h"
|
| 61 |
+
|
| 62 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 63 |
+
#define CUDNN_OPS_MAJOR 9
|
| 64 |
+
#define CUDNN_OPS_MINOR 1
|
| 65 |
+
#define CUDNN_OPS_PATCH 0
|
| 66 |
+
|
| 67 |
+
#if (CUDNN_OPS_MAJOR != CUDNN_MAJOR) || (CUDNN_OPS_MINOR != CUDNN_MINOR) || (CUDNN_OPS_PATCH != CUDNN_PATCHLEVEL)
|
| 68 |
+
#error Version mismatch in cuDNN OPS INFER!!!
|
| 69 |
+
#endif
|
| 70 |
+
|
| 71 |
+
#if defined(__cplusplus)
|
| 72 |
+
extern "C" {
|
| 73 |
+
#endif
|
| 74 |
+
|
| 75 |
+
/* Data structures to represent Image/Filter and the Neural Network Layer */
|
| 76 |
+
typedef struct cudnnTensorStruct *cudnnTensorDescriptor_t;
|
| 77 |
+
typedef struct cudnnPoolingStruct *cudnnPoolingDescriptor_t CUDNN_DEPRECATED;
|
| 78 |
+
typedef struct cudnnFilterStruct *cudnnFilterDescriptor_t CUDNN_DEPRECATED;
|
| 79 |
+
typedef struct cudnnLRNStruct *cudnnLRNDescriptor_t;
|
| 80 |
+
typedef struct cudnnActivationStruct *cudnnActivationDescriptor_t CUDNN_DEPRECATED;
|
| 81 |
+
typedef struct cudnnSpatialTransformerStruct *cudnnSpatialTransformerDescriptor_t;
|
| 82 |
+
typedef struct cudnnOpTensorStruct *cudnnOpTensorDescriptor_t CUDNN_DEPRECATED;
|
| 83 |
+
typedef struct cudnnReduceTensorStruct *cudnnReduceTensorDescriptor_t CUDNN_DEPRECATED;
|
| 84 |
+
typedef struct cudnnCTCLossStruct *cudnnCTCLossDescriptor_t;
|
| 85 |
+
typedef struct cudnnTensorTransformStruct *cudnnTensorTransformDescriptor_t CUDNN_DEPRECATED;
|
| 86 |
+
/*
|
| 87 |
+
* CUDNN Determinism
|
| 88 |
+
*/
|
| 89 |
+
typedef enum {
|
| 90 |
+
CUDNN_NON_DETERMINISTIC = 0,
|
| 91 |
+
CUDNN_DETERMINISTIC = 1,
|
| 92 |
+
} cudnnDeterminism_t;
|
| 93 |
+
|
| 94 |
+
/* Create an instance of a generic Tensor descriptor */
|
| 95 |
+
cudnnStatus_t CUDNNWINAPI
|
| 96 |
+
cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc);
|
| 97 |
+
|
| 98 |
+
cudnnStatus_t CUDNNWINAPI
|
| 99 |
+
cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc,
|
| 100 |
+
cudnnTensorFormat_t format,
|
| 101 |
+
cudnnDataType_t dataType, /* image data type */
|
| 102 |
+
int n, /* number of inputs (batch size) */
|
| 103 |
+
int c, /* number of input feature maps */
|
| 104 |
+
int h, /* height of input section */
|
| 105 |
+
int w); /* width of input section */
|
| 106 |
+
|
| 107 |
+
cudnnStatus_t CUDNNWINAPI
|
| 108 |
+
cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
|
| 109 |
+
cudnnDataType_t dataType, /* image data type */
|
| 110 |
+
int n, /* number of inputs (batch size) */
|
| 111 |
+
int c, /* number of input feature maps */
|
| 112 |
+
int h, /* height of input section */
|
| 113 |
+
int w, /* width of input section */
|
| 114 |
+
int nStride,
|
| 115 |
+
int cStride,
|
| 116 |
+
int hStride,
|
| 117 |
+
int wStride);
|
| 118 |
+
|
| 119 |
+
cudnnStatus_t CUDNNWINAPI
|
| 120 |
+
cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc,
|
| 121 |
+
cudnnDataType_t *dataType, /* image data type */
|
| 122 |
+
int *n, /* number of inputs (batch size) */
|
| 123 |
+
int *c, /* number of input feature maps */
|
| 124 |
+
int *h, /* height of input section */
|
| 125 |
+
int *w, /* width of input section */
|
| 126 |
+
int *nStride,
|
| 127 |
+
int *cStride,
|
| 128 |
+
int *hStride,
|
| 129 |
+
int *wStride);
|
| 130 |
+
|
| 131 |
+
cudnnStatus_t CUDNNWINAPI
|
| 132 |
+
cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc,
|
| 133 |
+
cudnnDataType_t dataType,
|
| 134 |
+
int nbDims,
|
| 135 |
+
const int dimA[],
|
| 136 |
+
const int strideA[]);
|
| 137 |
+
|
| 138 |
+
cudnnStatus_t CUDNNWINAPI
|
| 139 |
+
cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
|
| 140 |
+
cudnnTensorFormat_t format,
|
| 141 |
+
cudnnDataType_t dataType,
|
| 142 |
+
int nbDims,
|
| 143 |
+
const int dimA[]);
|
| 144 |
+
|
| 145 |
+
cudnnStatus_t CUDNNWINAPI
|
| 146 |
+
cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc,
|
| 147 |
+
int nbDimsRequested,
|
| 148 |
+
cudnnDataType_t *dataType,
|
| 149 |
+
int *nbDims,
|
| 150 |
+
int dimA[],
|
| 151 |
+
int strideA[]);
|
| 152 |
+
|
| 153 |
+
cudnnStatus_t CUDNNWINAPI
|
| 154 |
+
cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size);
|
| 155 |
+
|
| 156 |
+
/* PixelOffset( n, c, h, w ) = n *input_stride + c * feature_stride + h * h_stride + w * w_stride
|
| 157 |
+
|
| 158 |
+
1)Example of all images in row major order one batch of features after the other (with an optional padding on row)
|
| 159 |
+
input_stride : c x h x h_stride
|
| 160 |
+
feature_stride : h x h_stride
|
| 161 |
+
h_stride : >= w ( h_stride = w if no padding)
|
| 162 |
+
w_stride : 1
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
2)Example of all images in row major with features maps interleaved
|
| 166 |
+
input_stride : c x h x h_stride
|
| 167 |
+
feature_stride : 1
|
| 168 |
+
h_stride : w x c
|
| 169 |
+
w_stride : c
|
| 170 |
+
|
| 171 |
+
3)Example of all images in column major order one batch of features after the other (with optional padding on column)
|
| 172 |
+
input_stride : c x w x w_stride
|
| 173 |
+
feature_stride : w x w_stride
|
| 174 |
+
h_stride : 1
|
| 175 |
+
w_stride : >= h
|
| 176 |
+
|
| 177 |
+
*/
|
| 178 |
+
|
| 179 |
+
/* Destroy an instance of Tensor4d descriptor */
|
| 180 |
+
cudnnStatus_t CUDNNWINAPI
|
| 181 |
+
cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc);
|
| 182 |
+
|
| 183 |
+
/* Fold/unfold transforms */
|
| 184 |
+
typedef enum {
|
| 185 |
+
CUDNN_TRANSFORM_FOLD = 0U,
|
| 186 |
+
CUDNN_TRANSFORM_UNFOLD = 1U,
|
| 187 |
+
} cudnnFoldingDirection_t;
|
| 188 |
+
|
| 189 |
+
/** Create a destination descriptor for cudnnTransformTensor */
|
| 190 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 191 |
+
cudnnInitTransformDest(const cudnnTensorTransformDescriptor_t transformDesc,
|
| 192 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 193 |
+
cudnnTensorDescriptor_t destDesc,
|
| 194 |
+
size_t *destSizeInBytes);
|
| 195 |
+
|
| 196 |
+
/** Create an empty tensor transform descriptor */
|
| 197 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 198 |
+
cudnnCreateTensorTransformDescriptor(cudnnTensorTransformDescriptor_t *transformDesc);
|
| 199 |
+
|
| 200 |
+
/** Initialize a previously created tensor transform descriptor. */
|
| 201 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 202 |
+
cudnnSetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
|
| 203 |
+
const uint32_t nbDims,
|
| 204 |
+
const cudnnTensorFormat_t destFormat,
|
| 205 |
+
const int32_t padBeforeA[],
|
| 206 |
+
const int32_t padAfterA[],
|
| 207 |
+
const uint32_t foldA[],
|
| 208 |
+
const cudnnFoldingDirection_t direction);
|
| 209 |
+
|
| 210 |
+
/**
|
| 211 |
+
* Retrieves the values stored in a previously initialized tensor transform
|
| 212 |
+
* descriptor.
|
| 213 |
+
*/
|
| 214 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 215 |
+
cudnnGetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
|
| 216 |
+
uint32_t nbDimsRequested,
|
| 217 |
+
cudnnTensorFormat_t *destFormat,
|
| 218 |
+
int32_t padBeforeA[],
|
| 219 |
+
int32_t padAfterA[],
|
| 220 |
+
uint32_t foldA[],
|
| 221 |
+
cudnnFoldingDirection_t *direction);
|
| 222 |
+
|
| 223 |
+
/**
|
| 224 |
+
* Destroys a previously created tensor transform descriptor.
|
| 225 |
+
*/
|
| 226 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 227 |
+
cudnnDestroyTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc);
|
| 228 |
+
|
| 229 |
+
/* Tensor layout conversion helper (y = alpha * x + beta * y) */
|
| 230 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 231 |
+
cudnnTransformTensor(cudnnHandle_t handle,
|
| 232 |
+
const void *alpha,
|
| 233 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 234 |
+
const void *x,
|
| 235 |
+
const void *beta,
|
| 236 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 237 |
+
void *y);
|
| 238 |
+
|
| 239 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 240 |
+
cudnnTransformTensorEx(cudnnHandle_t handle,
|
| 241 |
+
const cudnnTensorTransformDescriptor_t transDesc,
|
| 242 |
+
const void *alpha,
|
| 243 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 244 |
+
const void *srcData,
|
| 245 |
+
const void *beta,
|
| 246 |
+
const cudnnTensorDescriptor_t destDesc,
|
| 247 |
+
void *destData);
|
| 248 |
+
|
| 249 |
+
/* Tensor Bias addition : C = alpha * A + beta * C */
|
| 250 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 251 |
+
cudnnAddTensor(cudnnHandle_t handle,
|
| 252 |
+
const void *alpha,
|
| 253 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 254 |
+
const void *A,
|
| 255 |
+
const void *beta,
|
| 256 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 257 |
+
void *C);
|
| 258 |
+
|
| 259 |
+
/*
|
| 260 |
+
* CUDNN OpTensor op type
|
| 261 |
+
*/
|
| 262 |
+
typedef enum {
|
| 263 |
+
CUDNN_OP_TENSOR_ADD = 0,
|
| 264 |
+
CUDNN_OP_TENSOR_MUL = 1,
|
| 265 |
+
CUDNN_OP_TENSOR_MIN = 2,
|
| 266 |
+
CUDNN_OP_TENSOR_MAX = 3,
|
| 267 |
+
CUDNN_OP_TENSOR_SQRT = 4,
|
| 268 |
+
CUDNN_OP_TENSOR_NOT = 5,
|
| 269 |
+
} cudnnOpTensorOp_t;
|
| 270 |
+
|
| 271 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 272 |
+
cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc);
|
| 273 |
+
|
| 274 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 275 |
+
cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc,
|
| 276 |
+
cudnnOpTensorOp_t opTensorOp,
|
| 277 |
+
cudnnDataType_t opTensorCompType,
|
| 278 |
+
cudnnNanPropagation_t opTensorNanOpt);
|
| 279 |
+
|
| 280 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 281 |
+
cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc,
|
| 282 |
+
cudnnOpTensorOp_t *opTensorOp,
|
| 283 |
+
cudnnDataType_t *opTensorCompType,
|
| 284 |
+
cudnnNanPropagation_t *opTensorNanOpt);
|
| 285 |
+
|
| 286 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 287 |
+
cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc);
|
| 288 |
+
|
| 289 |
+
/* Tensor operation : C = op( alpha1 * A, alpha2 * B ) + beta * C */
|
| 290 |
+
/* B tensor is ignored for CUDNN_OP_TENSOR_SQRT, CUDNN_OP_TENSOR_NOT. */
|
| 291 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 292 |
+
cudnnOpTensor(cudnnHandle_t handle,
|
| 293 |
+
const cudnnOpTensorDescriptor_t opTensorDesc,
|
| 294 |
+
const void *alpha1,
|
| 295 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 296 |
+
const void *A,
|
| 297 |
+
const void *alpha2,
|
| 298 |
+
const cudnnTensorDescriptor_t bDesc,
|
| 299 |
+
const void *B,
|
| 300 |
+
const void *beta,
|
| 301 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 302 |
+
void *C);
|
| 303 |
+
|
| 304 |
+
/*
|
| 305 |
+
* CUDNN ReduceTensor indices type
|
| 306 |
+
*/
|
| 307 |
+
typedef enum {
|
| 308 |
+
CUDNN_REDUCE_TENSOR_NO_INDICES = 0,
|
| 309 |
+
CUDNN_REDUCE_TENSOR_FLATTENED_INDICES = 1,
|
| 310 |
+
} cudnnReduceTensorIndices_t CUDNN_DEPRECATED;
|
| 311 |
+
|
| 312 |
+
/*
|
| 313 |
+
* CUDNN tensor indices type size (all unsigned)
|
| 314 |
+
* Currently not supported, default is 32 bit unsigned.
|
| 315 |
+
*/
|
| 316 |
+
typedef enum {
|
| 317 |
+
CUDNN_32BIT_INDICES = 0,
|
| 318 |
+
CUDNN_64BIT_INDICES = 1,
|
| 319 |
+
CUDNN_16BIT_INDICES = 2,
|
| 320 |
+
CUDNN_8BIT_INDICES = 3,
|
| 321 |
+
} cudnnIndicesType_t CUDNN_DEPRECATED;
|
| 322 |
+
|
| 323 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 324 |
+
cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc);
|
| 325 |
+
|
| 326 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 327 |
+
cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 328 |
+
cudnnReduceTensorOp_t reduceTensorOp,
|
| 329 |
+
cudnnDataType_t reduceTensorCompType,
|
| 330 |
+
cudnnNanPropagation_t reduceTensorNanOpt,
|
| 331 |
+
cudnnReduceTensorIndices_t reduceTensorIndices,
|
| 332 |
+
cudnnIndicesType_t reduceTensorIndicesType);
|
| 333 |
+
|
| 334 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 335 |
+
cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 336 |
+
cudnnReduceTensorOp_t *reduceTensorOp,
|
| 337 |
+
cudnnDataType_t *reduceTensorCompType,
|
| 338 |
+
cudnnNanPropagation_t *reduceTensorNanOpt,
|
| 339 |
+
cudnnReduceTensorIndices_t *reduceTensorIndices,
|
| 340 |
+
cudnnIndicesType_t *reduceTensorIndicesType);
|
| 341 |
+
|
| 342 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 343 |
+
cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc);
|
| 344 |
+
|
| 345 |
+
/* Helper function to return the minimum size of the index space to be passed to the reduction given the input and
|
| 346 |
+
* output tensors */
|
| 347 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 348 |
+
cudnnGetReductionIndicesSize(cudnnHandle_t handle,
|
| 349 |
+
const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 350 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 351 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 352 |
+
size_t *sizeInBytes);
|
| 353 |
+
|
| 354 |
+
/* Helper function to return the minimum size of the workspace to be passed to the reduction given the input and output
|
| 355 |
+
* tensors */
|
| 356 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 357 |
+
cudnnGetReductionWorkspaceSize(cudnnHandle_t handle,
|
| 358 |
+
const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 359 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 360 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 361 |
+
size_t *sizeInBytes);
|
| 362 |
+
|
| 363 |
+
/* Tensor operation : C = reduce op( alpha * A ) + beta * C */
|
| 364 |
+
/* The NaN propagation enum applies to only the min and max reduce ops; the other reduce ops propagate NaN as usual. */
|
| 365 |
+
/* The indices space is ignored for reduce ops other than min or max. */
|
| 366 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 367 |
+
cudnnReduceTensor(cudnnHandle_t handle,
|
| 368 |
+
const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 369 |
+
void *indices,
|
| 370 |
+
size_t indicesSizeInBytes,
|
| 371 |
+
void *workspace,
|
| 372 |
+
size_t workspaceSizeInBytes,
|
| 373 |
+
const void *alpha,
|
| 374 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 375 |
+
const void *A,
|
| 376 |
+
const void *beta,
|
| 377 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 378 |
+
void *C);
|
| 379 |
+
|
| 380 |
+
/* Set all values of a tensor to a given value : y[i] = value[0] */
|
| 381 |
+
cudnnStatus_t CUDNNWINAPI
|
| 382 |
+
cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr);
|
| 383 |
+
|
| 384 |
+
/* Scale all values of a tensor by a given factor : y[i] = alpha * y[i] */
|
| 385 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 386 |
+
cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha);
|
| 387 |
+
|
| 388 |
+
/* Create an instance of FilterStruct */
|
| 389 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 390 |
+
cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc);
|
| 391 |
+
|
| 392 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 393 |
+
cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc,
|
| 394 |
+
cudnnDataType_t dataType, /* image data type */
|
| 395 |
+
cudnnTensorFormat_t format,
|
| 396 |
+
int k, /* number of output feature maps */
|
| 397 |
+
int c, /* number of input feature maps */
|
| 398 |
+
int h, /* height of each input filter */
|
| 399 |
+
int w); /* width of each input filter */
|
| 400 |
+
|
| 401 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 402 |
+
cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc,
|
| 403 |
+
cudnnDataType_t *dataType, /* image data type */
|
| 404 |
+
cudnnTensorFormat_t *format,
|
| 405 |
+
int *k, /* number of output feature maps */
|
| 406 |
+
int *c, /* number of input feature maps */
|
| 407 |
+
int *h, /* height of each input filter */
|
| 408 |
+
int *w); /* width of each input filter */
|
| 409 |
+
|
| 410 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 411 |
+
cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc,
|
| 412 |
+
cudnnDataType_t dataType, /* image data type */
|
| 413 |
+
cudnnTensorFormat_t format,
|
| 414 |
+
int nbDims,
|
| 415 |
+
const int filterDimA[]);
|
| 416 |
+
|
| 417 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 418 |
+
cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc,
|
| 419 |
+
int nbDimsRequested,
|
| 420 |
+
cudnnDataType_t *dataType, /* image data type */
|
| 421 |
+
cudnnTensorFormat_t *format,
|
| 422 |
+
int *nbDims,
|
| 423 |
+
int filterDimA[]);
|
| 424 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 425 |
+
cudnnGetFilterSizeInBytes(const cudnnFilterDescriptor_t filterDesc, size_t *size);
|
| 426 |
+
|
| 427 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 428 |
+
cudnnTransformFilter(cudnnHandle_t handle,
|
| 429 |
+
const cudnnTensorTransformDescriptor_t transDesc,
|
| 430 |
+
const void *alpha,
|
| 431 |
+
const cudnnFilterDescriptor_t srcDesc,
|
| 432 |
+
const void *srcData,
|
| 433 |
+
const void *beta,
|
| 434 |
+
const cudnnFilterDescriptor_t destDesc,
|
| 435 |
+
void *destData);
|
| 436 |
+
|
| 437 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 438 |
+
cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc);
|
| 439 |
+
|
| 440 |
+
/*
|
| 441 |
+
* softmax algorithm
|
| 442 |
+
*/
|
| 443 |
+
typedef enum {
|
| 444 |
+
CUDNN_SOFTMAX_FAST = 0, /* straightforward implementation */
|
| 445 |
+
CUDNN_SOFTMAX_ACCURATE = 1, /* subtract max from every point to avoid overflow */
|
| 446 |
+
CUDNN_SOFTMAX_LOG = 2
|
| 447 |
+
} cudnnSoftmaxAlgorithm_t;
|
| 448 |
+
|
| 449 |
+
typedef enum {
|
| 450 |
+
CUDNN_SOFTMAX_MODE_INSTANCE = 0, /* compute the softmax over all C, H, W for each N */
|
| 451 |
+
CUDNN_SOFTMAX_MODE_CHANNEL = 1 /* compute the softmax over all C for each H, W, N */
|
| 452 |
+
} cudnnSoftmaxMode_t;
|
| 453 |
+
|
| 454 |
+
/* Softmax functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 455 |
+
|
| 456 |
+
/* Function to perform forward softmax */
|
| 457 |
+
cudnnStatus_t CUDNNWINAPI
|
| 458 |
+
cudnnSoftmaxForward(cudnnHandle_t handle,
|
| 459 |
+
cudnnSoftmaxAlgorithm_t algo,
|
| 460 |
+
cudnnSoftmaxMode_t mode,
|
| 461 |
+
const void *alpha,
|
| 462 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 463 |
+
const void *x,
|
| 464 |
+
const void *beta,
|
| 465 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 466 |
+
void *y);
|
| 467 |
+
|
| 468 |
+
/*
|
| 469 |
+
* pooling mode
|
| 470 |
+
*/
|
| 471 |
+
typedef enum {
|
| 472 |
+
CUDNN_POOLING_MAX = 0,
|
| 473 |
+
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values */
|
| 474 |
+
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values */
|
| 475 |
+
CUDNN_POOLING_MAX_DETERMINISTIC = 3
|
| 476 |
+
} cudnnPoolingMode_t CUDNN_DEPRECATED;
|
| 477 |
+
|
| 478 |
+
/* Create an instance of pooling descriptor */
|
| 479 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 480 |
+
cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
|
| 481 |
+
|
| 482 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 483 |
+
cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc,
|
| 484 |
+
cudnnPoolingMode_t mode,
|
| 485 |
+
cudnnNanPropagation_t maxpoolingNanOpt,
|
| 486 |
+
int windowHeight,
|
| 487 |
+
int windowWidth,
|
| 488 |
+
int verticalPadding,
|
| 489 |
+
int horizontalPadding,
|
| 490 |
+
int verticalStride,
|
| 491 |
+
int horizontalStride);
|
| 492 |
+
|
| 493 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 494 |
+
cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
|
| 495 |
+
cudnnPoolingMode_t *mode,
|
| 496 |
+
cudnnNanPropagation_t *maxpoolingNanOpt,
|
| 497 |
+
int *windowHeight,
|
| 498 |
+
int *windowWidth,
|
| 499 |
+
int *verticalPadding,
|
| 500 |
+
int *horizontalPadding,
|
| 501 |
+
int *verticalStride,
|
| 502 |
+
int *horizontalStride);
|
| 503 |
+
|
| 504 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 505 |
+
cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc,
|
| 506 |
+
const cudnnPoolingMode_t mode,
|
| 507 |
+
const cudnnNanPropagation_t maxpoolingNanOpt,
|
| 508 |
+
int nbDims,
|
| 509 |
+
const int windowDimA[],
|
| 510 |
+
const int paddingA[],
|
| 511 |
+
const int strideA[]);
|
| 512 |
+
|
| 513 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 514 |
+
cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
|
| 515 |
+
int nbDimsRequested,
|
| 516 |
+
cudnnPoolingMode_t *mode,
|
| 517 |
+
cudnnNanPropagation_t *maxpoolingNanOpt,
|
| 518 |
+
int *nbDims,
|
| 519 |
+
int windowDimA[],
|
| 520 |
+
int paddingA[],
|
| 521 |
+
int strideA[]);
|
| 522 |
+
|
| 523 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 524 |
+
cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
|
| 525 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 526 |
+
int nbDims,
|
| 527 |
+
int outputTensorDimA[]);
|
| 528 |
+
|
| 529 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 530 |
+
cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
|
| 531 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 532 |
+
int *n,
|
| 533 |
+
int *c,
|
| 534 |
+
int *h,
|
| 535 |
+
int *w);
|
| 536 |
+
|
| 537 |
+
/* Destroy an instance of pooling descriptor */
|
| 538 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 539 |
+
cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc);
|
| 540 |
+
|
| 541 |
+
/* Pooling functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 542 |
+
|
| 543 |
+
/* Function to perform forward pooling */
|
| 544 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 545 |
+
cudnnPoolingForward(cudnnHandle_t handle,
|
| 546 |
+
const cudnnPoolingDescriptor_t poolingDesc,
|
| 547 |
+
const void *alpha,
|
| 548 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 549 |
+
const void *x,
|
| 550 |
+
const void *beta,
|
| 551 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 552 |
+
void *y);
|
| 553 |
+
|
| 554 |
+
/* Activation functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 555 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 556 |
+
cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc);
|
| 557 |
+
|
| 558 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 559 |
+
cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc,
|
| 560 |
+
cudnnActivationMode_t mode,
|
| 561 |
+
cudnnNanPropagation_t reluNanOpt,
|
| 562 |
+
double coef); /* ceiling for clipped RELU, alpha for ELU */
|
| 563 |
+
|
| 564 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 565 |
+
cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc,
|
| 566 |
+
cudnnActivationMode_t *mode,
|
| 567 |
+
cudnnNanPropagation_t *reluNanOpt,
|
| 568 |
+
double *coef); /* ceiling for clipped RELU, alpha for ELU */
|
| 569 |
+
|
| 570 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 571 |
+
cudnnSetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double swish_beta);
|
| 572 |
+
|
| 573 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 574 |
+
cudnnGetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double *swish_beta);
|
| 575 |
+
|
| 576 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 577 |
+
cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc);
|
| 578 |
+
|
| 579 |
+
/* Function to perform forward activation */
|
| 580 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 581 |
+
cudnnActivationForward(cudnnHandle_t handle,
|
| 582 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 583 |
+
const void *alpha,
|
| 584 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 585 |
+
const void *x,
|
| 586 |
+
const void *beta,
|
| 587 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 588 |
+
void *y);
|
| 589 |
+
|
| 590 |
+
/*
|
| 591 |
+
* Create an instance of LRN (Local Response Normalization) descriptor
|
| 592 |
+
* Uses lrnN=5, lrnAlpha=1e-4, lrnBeta=0.75, lrnK=2.0 as defaults from Krizhevsky'12 ImageNet paper
|
| 593 |
+
*/
|
| 594 |
+
cudnnStatus_t CUDNNWINAPI
|
| 595 |
+
cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc);
|
| 596 |
+
|
| 597 |
+
#define CUDNN_LRN_MIN_N 1 /* minimum allowed lrnN */
|
| 598 |
+
#define CUDNN_LRN_MAX_N 16 /* maximum allowed lrnN */
|
| 599 |
+
#define CUDNN_LRN_MIN_K 1e-5 /* minimum allowed lrnK */
|
| 600 |
+
#define CUDNN_LRN_MIN_BETA 0.01 /* minimum allowed lrnBeta */
|
| 601 |
+
|
| 602 |
+
/* LRN layer mode */
|
| 603 |
+
typedef enum {
|
| 604 |
+
CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0, /* Normalize across tensor's dimA[1] dimension */
|
| 605 |
+
} cudnnLRNMode_t;
|
| 606 |
+
|
| 607 |
+
/*
|
| 608 |
+
* Uses a window [center-lookBehind, center+lookAhead], where
|
| 609 |
+
* lookBehind = floor( (lrnN-1)/2 ), lookAhead = lrnN-lookBehind-1.
|
| 610 |
+
* Values of double parameters cast to tensor data type.
|
| 611 |
+
*/
|
| 612 |
+
cudnnStatus_t CUDNNWINAPI
|
| 613 |
+
cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK);
|
| 614 |
+
/*
|
| 615 |
+
* Retrieve the settings currently stored in an LRN layer descriptor
|
| 616 |
+
* Any of the provided pointers can be NULL (no corresponding value will be returned)
|
| 617 |
+
*/
|
| 618 |
+
cudnnStatus_t CUDNNWINAPI
|
| 619 |
+
cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK);
|
| 620 |
+
|
| 621 |
+
/* Destroy an instance of LRN descriptor */
|
| 622 |
+
cudnnStatus_t CUDNNWINAPI
|
| 623 |
+
cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc);
|
| 624 |
+
|
| 625 |
+
/* LRN functions: output = alpha * normalize(x) + beta * old_y */
|
| 626 |
+
|
| 627 |
+
/* LRN cross-channel forward computation. Double parameters cast to tensor data type */
|
| 628 |
+
cudnnStatus_t CUDNNWINAPI
|
| 629 |
+
cudnnLRNCrossChannelForward(cudnnHandle_t handle,
|
| 630 |
+
cudnnLRNDescriptor_t normDesc,
|
| 631 |
+
cudnnLRNMode_t lrnMode,
|
| 632 |
+
const void *alpha,
|
| 633 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 634 |
+
const void *x,
|
| 635 |
+
const void *beta,
|
| 636 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 637 |
+
void *y);
|
| 638 |
+
|
| 639 |
+
typedef enum {
|
| 640 |
+
CUDNN_DIVNORM_PRECOMPUTED_MEANS = 0,
|
| 641 |
+
} cudnnDivNormMode_t;
|
| 642 |
+
|
| 643 |
+
/* LCN/divisive normalization functions: y = alpha * normalize(x) + beta * y */
|
| 644 |
+
cudnnStatus_t CUDNNWINAPI
|
| 645 |
+
cudnnDivisiveNormalizationForward(cudnnHandle_t handle,
|
| 646 |
+
cudnnLRNDescriptor_t normDesc,
|
| 647 |
+
cudnnDivNormMode_t mode,
|
| 648 |
+
const void *alpha,
|
| 649 |
+
const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */
|
| 650 |
+
const void *x,
|
| 651 |
+
const void *means, /* if NULL, means are assumed to be zero */
|
| 652 |
+
void *temp,
|
| 653 |
+
void *temp2,
|
| 654 |
+
const void *beta,
|
| 655 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 656 |
+
void *y);
|
| 657 |
+
|
| 658 |
+
typedef enum {
|
| 659 |
+
/* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
|
| 660 |
+
CUDNN_BATCHNORM_PER_ACTIVATION = 0,
|
| 661 |
+
|
| 662 |
+
/* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
|
| 663 |
+
CUDNN_BATCHNORM_SPATIAL = 1,
|
| 664 |
+
|
| 665 |
+
/*
|
| 666 |
+
* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors).
|
| 667 |
+
* May be faster than CUDNN_BATCHNORM_SPATIAL but imposes some limits on the range of values
|
| 668 |
+
*/
|
| 669 |
+
CUDNN_BATCHNORM_SPATIAL_PERSISTENT = 2,
|
| 670 |
+
} cudnnBatchNormMode_t CUDNN_DEPRECATED;
|
| 671 |
+
|
| 672 |
+
#define CUDNN_BN_MIN_EPSILON 0.0 /* Minimum epsilon allowed to be used in the Batch Normalization formula */
|
| 673 |
+
|
| 674 |
+
/*
|
| 675 |
+
* Derives a tensor descriptor from layer data descriptor for BatchNormalization
|
| 676 |
+
* scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
|
| 677 |
+
* bnScaleBiasMeanVarDesc and bnScaleBiasDiffDesc in Batch Normalization forward and backward functions.
|
| 678 |
+
*/
|
| 679 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 680 |
+
cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc,
|
| 681 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 682 |
+
cudnnBatchNormMode_t mode);
|
| 683 |
+
|
| 684 |
+
typedef enum {
|
| 685 |
+
CUDNN_BATCHNORM_OPS_BN = 0, /* do batch normalization only */
|
| 686 |
+
CUDNN_BATCHNORM_OPS_BN_ACTIVATION = 1, /* do batchNorm, then activation */
|
| 687 |
+
CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION = 2, /* do batchNorm, then elemWiseAdd, then activation */
|
| 688 |
+
} cudnnBatchNormOps_t CUDNN_DEPRECATED;
|
| 689 |
+
|
| 690 |
+
/*
|
| 691 |
+
* Performs Batch Normalization during Inference:
|
| 692 |
+
* y[i] = bnScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + bnBias[k]
|
| 693 |
+
* with bnScale, bnBias, runningMean, runningInvVariance tensors indexed
|
| 694 |
+
* according to spatial or per-activation mode. Refer to cudnnBatchNormalizationForwardTraining
|
| 695 |
+
* above for notes on function arguments.
|
| 696 |
+
*/
|
| 697 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 698 |
+
cudnnBatchNormalizationForwardInference(cudnnHandle_t handle,
|
| 699 |
+
cudnnBatchNormMode_t mode,
|
| 700 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 701 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 702 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 703 |
+
const void *x, /* NxCxHxW */
|
| 704 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 705 |
+
void *y, /* NxCxHxW */
|
| 706 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 707 |
+
const void *bnScale,
|
| 708 |
+
const void *bnBias,
|
| 709 |
+
const void *estimatedMean,
|
| 710 |
+
const void *estimatedVariance,
|
| 711 |
+
double epsilon);
|
| 712 |
+
|
| 713 |
+
typedef enum {
|
| 714 |
+
/* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
|
| 715 |
+
CUDNN_NORM_PER_ACTIVATION = 0,
|
| 716 |
+
|
| 717 |
+
/* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
|
| 718 |
+
CUDNN_NORM_PER_CHANNEL = 1,
|
| 719 |
+
} cudnnNormMode_t CUDNN_DEPRECATED;
|
| 720 |
+
|
| 721 |
+
typedef enum { CUDNN_NORM_ALGO_STANDARD = 0, CUDNN_NORM_ALGO_PERSIST = 1 } cudnnNormAlgo_t CUDNN_DEPRECATED;
|
| 722 |
+
|
| 723 |
+
/*
|
| 724 |
+
* Derives a tensor descriptor from layer data descriptor for Normalization
|
| 725 |
+
* scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
|
| 726 |
+
* normScaleBiasMeanVarDesc and normScaleBiasDiffDesc in Normalization forward and backward functions.
|
| 727 |
+
*/
|
| 728 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 729 |
+
cudnnDeriveNormTensorDescriptor(cudnnTensorDescriptor_t derivedNormScaleBiasDesc,
|
| 730 |
+
cudnnTensorDescriptor_t derivedNormMeanVarDesc,
|
| 731 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 732 |
+
cudnnNormMode_t mode,
|
| 733 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 734 |
+
|
| 735 |
+
typedef enum {
|
| 736 |
+
CUDNN_NORM_OPS_NORM = 0, /* do normalization only */
|
| 737 |
+
CUDNN_NORM_OPS_NORM_ACTIVATION = 1, /* do Norm, then activation */
|
| 738 |
+
CUDNN_NORM_OPS_NORM_ADD_ACTIVATION = 2, /* do Norm, then elemWiseAdd, then activation */
|
| 739 |
+
} cudnnNormOps_t CUDNN_DEPRECATED;
|
| 740 |
+
|
| 741 |
+
/*
|
| 742 |
+
* Performs Normalization during Inference:
|
| 743 |
+
* y[i] = normScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + normBias[k]
|
| 744 |
+
* with normScale, normBias, runningMean, runningInvVariance tensors indexed
|
| 745 |
+
* according to per-channel or per-activation mode. Refer to cudnnNormalizationForwardTraining
|
| 746 |
+
* above for notes on function arguments.
|
| 747 |
+
*/
|
| 748 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 749 |
+
cudnnNormalizationForwardInference(cudnnHandle_t handle,
|
| 750 |
+
cudnnNormMode_t mode,
|
| 751 |
+
cudnnNormOps_t normOps,
|
| 752 |
+
cudnnNormAlgo_t algo,
|
| 753 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 754 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 755 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 756 |
+
const void *x, /* NxCxHxW */
|
| 757 |
+
const cudnnTensorDescriptor_t normScaleBiasDesc,
|
| 758 |
+
const void *normScale,
|
| 759 |
+
const void *normBias,
|
| 760 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 761 |
+
const void *estimatedMean,
|
| 762 |
+
const void *estimatedVariance,
|
| 763 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 764 |
+
const void *z,
|
| 765 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 766 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 767 |
+
void *y, /* NxCxHxW */
|
| 768 |
+
double epsilon,
|
| 769 |
+
int groupCnt); /* Place hold for future work*/
|
| 770 |
+
|
| 771 |
+
/* APIs for spatial transformer network*/
|
| 772 |
+
typedef enum {
|
| 773 |
+
CUDNN_SAMPLER_BILINEAR = 0,
|
| 774 |
+
} cudnnSamplerType_t;
|
| 775 |
+
|
| 776 |
+
cudnnStatus_t CUDNNWINAPI
|
| 777 |
+
cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc);
|
| 778 |
+
|
| 779 |
+
cudnnStatus_t CUDNNWINAPI
|
| 780 |
+
cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc,
|
| 781 |
+
cudnnSamplerType_t samplerType,
|
| 782 |
+
cudnnDataType_t dataType,
|
| 783 |
+
const int nbDims,
|
| 784 |
+
const int dimA[]);
|
| 785 |
+
|
| 786 |
+
cudnnStatus_t CUDNNWINAPI
|
| 787 |
+
cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc);
|
| 788 |
+
|
| 789 |
+
cudnnStatus_t CUDNNWINAPI
|
| 790 |
+
cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle,
|
| 791 |
+
const cudnnSpatialTransformerDescriptor_t stDesc,
|
| 792 |
+
const void *theta,
|
| 793 |
+
void *grid);
|
| 794 |
+
|
| 795 |
+
cudnnStatus_t CUDNNWINAPI
|
| 796 |
+
cudnnSpatialTfSamplerForward(cudnnHandle_t handle,
|
| 797 |
+
cudnnSpatialTransformerDescriptor_t stDesc,
|
| 798 |
+
const void *alpha,
|
| 799 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 800 |
+
const void *x,
|
| 801 |
+
const void *grid,
|
| 802 |
+
const void *beta,
|
| 803 |
+
cudnnTensorDescriptor_t yDesc,
|
| 804 |
+
void *y);
|
| 805 |
+
|
| 806 |
+
typedef struct cudnnDropoutStruct *cudnnDropoutDescriptor_t;
|
| 807 |
+
|
| 808 |
+
cudnnStatus_t CUDNNWINAPI
|
| 809 |
+
cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc);
|
| 810 |
+
|
| 811 |
+
cudnnStatus_t CUDNNWINAPI
|
| 812 |
+
cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc);
|
| 813 |
+
|
| 814 |
+
/*helper function to determine size of the states to be passed to cudnnSetDropoutDescriptor */
|
| 815 |
+
cudnnStatus_t CUDNNWINAPI
|
| 816 |
+
cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes);
|
| 817 |
+
|
| 818 |
+
/*helper function to determine size of the reserve space to be passed to dropout forward/backward calls */
|
| 819 |
+
cudnnStatus_t CUDNNWINAPI
|
| 820 |
+
cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes);
|
| 821 |
+
|
| 822 |
+
cudnnStatus_t CUDNNWINAPI
|
| 823 |
+
cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
|
| 824 |
+
cudnnHandle_t handle,
|
| 825 |
+
float dropout,
|
| 826 |
+
void *states,
|
| 827 |
+
size_t stateSizeInBytes,
|
| 828 |
+
unsigned long long seed);
|
| 829 |
+
|
| 830 |
+
/* Restores the dropout descriptor to a previously saved-off state */
|
| 831 |
+
cudnnStatus_t CUDNNWINAPI
|
| 832 |
+
cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
|
| 833 |
+
cudnnHandle_t handle,
|
| 834 |
+
float dropout,
|
| 835 |
+
void *states,
|
| 836 |
+
size_t stateSizeInBytes,
|
| 837 |
+
unsigned long long seed);
|
| 838 |
+
|
| 839 |
+
cudnnStatus_t CUDNNWINAPI
|
| 840 |
+
cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
|
| 841 |
+
cudnnHandle_t handle,
|
| 842 |
+
float *dropout,
|
| 843 |
+
void **states,
|
| 844 |
+
unsigned long long *seed);
|
| 845 |
+
|
| 846 |
+
cudnnStatus_t CUDNNWINAPI
|
| 847 |
+
cudnnDropoutForward(cudnnHandle_t handle,
|
| 848 |
+
const cudnnDropoutDescriptor_t dropoutDesc,
|
| 849 |
+
const cudnnTensorDescriptor_t xdesc,
|
| 850 |
+
const void *x,
|
| 851 |
+
const cudnnTensorDescriptor_t ydesc,
|
| 852 |
+
void *y,
|
| 853 |
+
void *reserveSpace,
|
| 854 |
+
size_t reserveSpaceSizeInBytes);
|
| 855 |
+
|
| 856 |
+
/* TODO: move these enums out to the appropriate submodule */
|
| 857 |
+
typedef enum {
|
| 858 |
+
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0,
|
| 859 |
+
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
|
| 860 |
+
CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2,
|
| 861 |
+
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3,
|
| 862 |
+
CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4,
|
| 863 |
+
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5,
|
| 864 |
+
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6,
|
| 865 |
+
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7,
|
| 866 |
+
CUDNN_CONVOLUTION_FWD_ALGO_COUNT = 8
|
| 867 |
+
} cudnnConvolutionFwdAlgo_t;
|
| 868 |
+
|
| 869 |
+
typedef enum {
|
| 870 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic */
|
| 871 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
|
| 872 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
|
| 873 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic */
|
| 874 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, /* not implemented */
|
| 875 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5,
|
| 876 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING = 6,
|
| 877 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7
|
| 878 |
+
} cudnnConvolutionBwdFilterAlgo_t;
|
| 879 |
+
|
| 880 |
+
typedef enum {
|
| 881 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic */
|
| 882 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
|
| 883 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
|
| 884 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
|
| 885 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4,
|
| 886 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5,
|
| 887 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6
|
| 888 |
+
} cudnnConvolutionBwdDataAlgo_t;
|
| 889 |
+
|
| 890 |
+
typedef enum { CUDNN_CTC_LOSS_ALGO_DETERMINISTIC = 0, CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC = 1 } cudnnCTCLossAlgo_t;
|
| 891 |
+
|
| 892 |
+
/*
|
| 893 |
+
* \brief Cross-library version checker.
|
| 894 |
+
* This function is implemented differently in each sub-library. Each sublib
|
| 895 |
+
* checks whether its own version matches that of its dependencies.
|
| 896 |
+
* \returns CUDNN_STATUS_SUCCESS if the version check passes,
|
| 897 |
+
* CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
|
| 898 |
+
*/
|
| 899 |
+
cudnnStatus_t CUDNNWINAPI
|
| 900 |
+
cudnnOpsVersionCheck(void);
|
| 901 |
+
|
| 902 |
+
/* Function to perform backward softmax */
|
| 903 |
+
cudnnStatus_t CUDNNWINAPI
|
| 904 |
+
cudnnSoftmaxBackward(cudnnHandle_t handle,
|
| 905 |
+
cudnnSoftmaxAlgorithm_t algo,
|
| 906 |
+
cudnnSoftmaxMode_t mode,
|
| 907 |
+
const void *alpha,
|
| 908 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 909 |
+
const void *y,
|
| 910 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 911 |
+
const void *dy,
|
| 912 |
+
const void *beta,
|
| 913 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 914 |
+
void *dx);
|
| 915 |
+
|
| 916 |
+
/* Function to perform backward pooling */
|
| 917 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 918 |
+
cudnnPoolingBackward(cudnnHandle_t handle,
|
| 919 |
+
const cudnnPoolingDescriptor_t poolingDesc,
|
| 920 |
+
const void *alpha,
|
| 921 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 922 |
+
const void *y,
|
| 923 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 924 |
+
const void *dy,
|
| 925 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 926 |
+
const void *x,
|
| 927 |
+
const void *beta,
|
| 928 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 929 |
+
void *dx);
|
| 930 |
+
|
| 931 |
+
/* Function to perform backward activation */
|
| 932 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 933 |
+
cudnnActivationBackward(cudnnHandle_t handle,
|
| 934 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 935 |
+
const void *alpha,
|
| 936 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 937 |
+
const void *y,
|
| 938 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 939 |
+
const void *dy,
|
| 940 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 941 |
+
const void *x,
|
| 942 |
+
const void *beta,
|
| 943 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 944 |
+
void *dx);
|
| 945 |
+
|
| 946 |
+
/* LRN cross-channel backward computation. Double parameters cast to tensor data type */
|
| 947 |
+
cudnnStatus_t CUDNNWINAPI
|
| 948 |
+
cudnnLRNCrossChannelBackward(cudnnHandle_t handle,
|
| 949 |
+
cudnnLRNDescriptor_t normDesc,
|
| 950 |
+
cudnnLRNMode_t lrnMode,
|
| 951 |
+
const void *alpha,
|
| 952 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 953 |
+
const void *y,
|
| 954 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 955 |
+
const void *dy,
|
| 956 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 957 |
+
const void *x,
|
| 958 |
+
const void *beta,
|
| 959 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 960 |
+
void *dx);
|
| 961 |
+
|
| 962 |
+
cudnnStatus_t CUDNNWINAPI
|
| 963 |
+
cudnnDivisiveNormalizationBackward(cudnnHandle_t handle,
|
| 964 |
+
cudnnLRNDescriptor_t normDesc,
|
| 965 |
+
cudnnDivNormMode_t mode,
|
| 966 |
+
const void *alpha,
|
| 967 |
+
const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */
|
| 968 |
+
const void *x,
|
| 969 |
+
const void *means, /* if NULL, means are assumed to be zero */
|
| 970 |
+
const void *dy,
|
| 971 |
+
void *temp,
|
| 972 |
+
void *temp2,
|
| 973 |
+
const void *beta,
|
| 974 |
+
const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */
|
| 975 |
+
void *dx, /* output x differential */
|
| 976 |
+
void *dMeans); /* output means differential, can be NULL */
|
| 977 |
+
|
| 978 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 979 |
+
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle,
|
| 980 |
+
cudnnBatchNormMode_t mode,
|
| 981 |
+
cudnnBatchNormOps_t bnOps,
|
| 982 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 983 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 984 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 985 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 986 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 987 |
+
size_t *sizeInBytes);
|
| 988 |
+
|
| 989 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 990 |
+
cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle,
|
| 991 |
+
cudnnBatchNormMode_t mode,
|
| 992 |
+
cudnnBatchNormOps_t bnOps,
|
| 993 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 994 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 995 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 996 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 997 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 998 |
+
const cudnnTensorDescriptor_t dBnScaleBiasDesc,
|
| 999 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1000 |
+
size_t *sizeInBytes);
|
| 1001 |
+
|
| 1002 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1003 |
+
cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle,
|
| 1004 |
+
cudnnBatchNormMode_t mode,
|
| 1005 |
+
cudnnBatchNormOps_t bnOps,
|
| 1006 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1007 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1008 |
+
size_t *sizeInBytes);
|
| 1009 |
+
|
| 1010 |
+
/* Computes y = BN(x). Also accumulates moving averages of mean and inverse variances */
|
| 1011 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1012 |
+
cudnnBatchNormalizationForwardTraining(
|
| 1013 |
+
cudnnHandle_t handle,
|
| 1014 |
+
cudnnBatchNormMode_t mode,
|
| 1015 |
+
|
| 1016 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 1017 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 1018 |
+
|
| 1019 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1020 |
+
const void *x, /* NxCxHxW */
|
| 1021 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1022 |
+
void *y, /* NxCxHxW */
|
| 1023 |
+
|
| 1024 |
+
/* Shared desc for the next 6 tensors in the argument list.
|
| 1025 |
+
Data type to be set as follows:
|
| 1026 |
+
type = (typeOf(x) == double) ? double : float
|
| 1027 |
+
Dimensions for this descriptor depend on normalization mode
|
| 1028 |
+
- Spatial Normalization : tensors are expected to have dims 1xCx1x1
|
| 1029 |
+
(normalization is performed across NxHxW)
|
| 1030 |
+
- Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW
|
| 1031 |
+
(normalization is performed across N) */
|
| 1032 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 1033 |
+
|
| 1034 |
+
/* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */
|
| 1035 |
+
const void *bnScale,
|
| 1036 |
+
const void *bnBias,
|
| 1037 |
+
|
| 1038 |
+
/* MUST use factor=1 in the very first call of a complete training cycle.
|
| 1039 |
+
Use a factor=1/(1+n) at N-th call to the function to get
|
| 1040 |
+
Cumulative Moving Average (CMA) behavior
|
| 1041 |
+
CMA[n] = (x[1]+...+x[n])/n
|
| 1042 |
+
Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) =
|
| 1043 |
+
((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) =
|
| 1044 |
+
CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
|
| 1045 |
+
double exponentialAverageFactor,
|
| 1046 |
+
|
| 1047 |
+
/* Used in Training phase only.
|
| 1048 |
+
runningMean = newMean*factor + runningMean*(1-factor) */
|
| 1049 |
+
void *resultRunningMean,
|
| 1050 |
+
/* Output in training mode, input in inference. Is the moving average
|
| 1051 |
+
of variance[x] (factor is applied in the same way as for runningMean) */
|
| 1052 |
+
void *resultRunningVariance,
|
| 1053 |
+
|
| 1054 |
+
/* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
|
| 1055 |
+
double epsilon,
|
| 1056 |
+
|
| 1057 |
+
/* Optionally save intermediate results from the forward pass here
|
| 1058 |
+
- can be reused to speed up backward pass. NULL if unused */
|
| 1059 |
+
void *resultSaveMean,
|
| 1060 |
+
void *resultSaveInvVariance);
|
| 1061 |
+
|
| 1062 |
+
/* Computes y = relu(BN(x) + z). Also accumulates moving averages of mean and inverse variances */
|
| 1063 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1064 |
+
cudnnBatchNormalizationForwardTrainingEx(
|
| 1065 |
+
cudnnHandle_t handle,
|
| 1066 |
+
cudnnBatchNormMode_t mode,
|
| 1067 |
+
cudnnBatchNormOps_t bnOps,
|
| 1068 |
+
|
| 1069 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 1070 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 1071 |
+
|
| 1072 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1073 |
+
const void *xData,
|
| 1074 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 1075 |
+
const void *zData,
|
| 1076 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1077 |
+
void *yData,
|
| 1078 |
+
|
| 1079 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 1080 |
+
const void *bnScale,
|
| 1081 |
+
const void *bnBias,
|
| 1082 |
+
|
| 1083 |
+
double exponentialAverageFactor,
|
| 1084 |
+
void *resultRunningMean,
|
| 1085 |
+
void *resultRunningVariance,
|
| 1086 |
+
|
| 1087 |
+
/* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
|
| 1088 |
+
double epsilon,
|
| 1089 |
+
|
| 1090 |
+
/* Optionally save intermediate results from the forward pass here
|
| 1091 |
+
- can be reused to speed up backward pass. NULL if unused */
|
| 1092 |
+
void *resultSaveMean,
|
| 1093 |
+
void *resultSaveInvVariance,
|
| 1094 |
+
|
| 1095 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1096 |
+
void *workspace,
|
| 1097 |
+
size_t workSpaceSizeInBytes,
|
| 1098 |
+
void *reserveSpace,
|
| 1099 |
+
size_t reserveSpaceSizeInBytes);
|
| 1100 |
+
|
| 1101 |
+
/* Performs backward pass of Batch Normalization layer. Returns x gradient,
|
| 1102 |
+
* bnScale gradient and bnBias gradient */
|
| 1103 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1104 |
+
cudnnBatchNormalizationBackward(cudnnHandle_t handle,
|
| 1105 |
+
cudnnBatchNormMode_t mode,
|
| 1106 |
+
const void *alphaDataDiff,
|
| 1107 |
+
const void *betaDataDiff,
|
| 1108 |
+
const void *alphaParamDiff,
|
| 1109 |
+
const void *betaParamDiff,
|
| 1110 |
+
const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
|
| 1111 |
+
const void *x,
|
| 1112 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1113 |
+
const void *dy,
|
| 1114 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1115 |
+
void *dx,
|
| 1116 |
+
/* Shared tensor desc for the 4 tensors below */
|
| 1117 |
+
const cudnnTensorDescriptor_t dBnScaleBiasDesc,
|
| 1118 |
+
const void *bnScale, /* bnBias doesn't affect backpropagation */
|
| 1119 |
+
/* scale and bias diff are not backpropagated below this layer */
|
| 1120 |
+
void *dBnScaleResult,
|
| 1121 |
+
void *dBnBiasResult,
|
| 1122 |
+
/* Same epsilon as forward pass */
|
| 1123 |
+
double epsilon,
|
| 1124 |
+
|
| 1125 |
+
/* Optionally cached intermediate results from
|
| 1126 |
+
forward pass */
|
| 1127 |
+
const void *savedMean,
|
| 1128 |
+
const void *savedInvVariance);
|
| 1129 |
+
|
| 1130 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1131 |
+
cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle,
|
| 1132 |
+
cudnnBatchNormMode_t mode,
|
| 1133 |
+
cudnnBatchNormOps_t bnOps,
|
| 1134 |
+
|
| 1135 |
+
const void *alphaDataDiff,
|
| 1136 |
+
const void *betaDataDiff,
|
| 1137 |
+
const void *alphaParamDiff,
|
| 1138 |
+
const void *betaParamDiff,
|
| 1139 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1140 |
+
const void *xData,
|
| 1141 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1142 |
+
const void *yData,
|
| 1143 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1144 |
+
const void *dyData,
|
| 1145 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 1146 |
+
void *dzData,
|
| 1147 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1148 |
+
void *dxData,
|
| 1149 |
+
|
| 1150 |
+
/* Shared tensor desc for the 4 tensors below */
|
| 1151 |
+
const cudnnTensorDescriptor_t dBnScaleBiasDesc,
|
| 1152 |
+
const void *bnScaleData,
|
| 1153 |
+
const void *bnBiasData, /* needed if there is activation */
|
| 1154 |
+
void *dBnScaleData,
|
| 1155 |
+
void *dBnBiasData,
|
| 1156 |
+
double epsilon, /* Same epsilon as forward pass */
|
| 1157 |
+
|
| 1158 |
+
/* Optionally cached intermediate results from
|
| 1159 |
+
forward pass */
|
| 1160 |
+
const void *savedMean,
|
| 1161 |
+
const void *savedInvVariance,
|
| 1162 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1163 |
+
void *workSpace,
|
| 1164 |
+
size_t workSpaceSizeInBytes,
|
| 1165 |
+
void *reserveSpace,
|
| 1166 |
+
size_t reserveSpaceSizeInBytes);
|
| 1167 |
+
|
| 1168 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1169 |
+
cudnnGetNormalizationForwardTrainingWorkspaceSize(cudnnHandle_t handle,
|
| 1170 |
+
cudnnNormMode_t mode,
|
| 1171 |
+
cudnnNormOps_t normOps,
|
| 1172 |
+
cudnnNormAlgo_t algo,
|
| 1173 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1174 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 1175 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1176 |
+
const cudnnTensorDescriptor_t normScaleBiasDesc,
|
| 1177 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1178 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1179 |
+
size_t *sizeInBytes,
|
| 1180 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1181 |
+
|
| 1182 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1183 |
+
cudnnGetNormalizationBackwardWorkspaceSize(cudnnHandle_t handle,
|
| 1184 |
+
cudnnNormMode_t mode,
|
| 1185 |
+
cudnnNormOps_t normOps,
|
| 1186 |
+
cudnnNormAlgo_t algo,
|
| 1187 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1188 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1189 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1190 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 1191 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1192 |
+
const cudnnTensorDescriptor_t dNormScaleBiasDesc,
|
| 1193 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1194 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1195 |
+
size_t *sizeInBytes,
|
| 1196 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1197 |
+
|
| 1198 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1199 |
+
cudnnGetNormalizationTrainingReserveSpaceSize(cudnnHandle_t handle,
|
| 1200 |
+
cudnnNormMode_t mode,
|
| 1201 |
+
cudnnNormOps_t normOps,
|
| 1202 |
+
cudnnNormAlgo_t algo,
|
| 1203 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1204 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1205 |
+
size_t *sizeInBytes,
|
| 1206 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1207 |
+
|
| 1208 |
+
/* Computes y = relu(Norm(x) + z). Also accumulates moving averages of mean and inverse variances */
|
| 1209 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1210 |
+
cudnnNormalizationForwardTraining(cudnnHandle_t handle,
|
| 1211 |
+
cudnnNormMode_t mode,
|
| 1212 |
+
cudnnNormOps_t normOps,
|
| 1213 |
+
cudnnNormAlgo_t algo,
|
| 1214 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 1215 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 1216 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1217 |
+
const void *xData,
|
| 1218 |
+
const cudnnTensorDescriptor_t normScaleBiasDesc,
|
| 1219 |
+
const void *normScale,
|
| 1220 |
+
const void *normBias,
|
| 1221 |
+
double exponentialAverageFactor,
|
| 1222 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1223 |
+
void *resultRunningMean,
|
| 1224 |
+
void *resultRunningVariance,
|
| 1225 |
+
/* Has to be >= 0. Should be the same in forward and backward functions. */
|
| 1226 |
+
double epsilon,
|
| 1227 |
+
/* Optionally save intermediate results from the forward pass here
|
| 1228 |
+
- can be reused to speed up backward pass. NULL if unused */
|
| 1229 |
+
void *resultSaveMean,
|
| 1230 |
+
void *resultSaveInvVariance,
|
| 1231 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1232 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 1233 |
+
const void *zData,
|
| 1234 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1235 |
+
void *yData,
|
| 1236 |
+
void *workspace,
|
| 1237 |
+
size_t workSpaceSizeInBytes,
|
| 1238 |
+
void *reserveSpace,
|
| 1239 |
+
size_t reserveSpaceSizeInBytes,
|
| 1240 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1241 |
+
|
| 1242 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1243 |
+
cudnnNormalizationBackward(cudnnHandle_t handle,
|
| 1244 |
+
cudnnNormMode_t mode,
|
| 1245 |
+
cudnnNormOps_t normOps,
|
| 1246 |
+
cudnnNormAlgo_t algo,
|
| 1247 |
+
const void *alphaDataDiff,
|
| 1248 |
+
const void *betaDataDiff,
|
| 1249 |
+
const void *alphaParamDiff,
|
| 1250 |
+
const void *betaParamDiff,
|
| 1251 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1252 |
+
const void *xData,
|
| 1253 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1254 |
+
const void *yData,
|
| 1255 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1256 |
+
const void *dyData,
|
| 1257 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 1258 |
+
void *dzData,
|
| 1259 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1260 |
+
void *dxData,
|
| 1261 |
+
/* Shared tensor desc for the 4 tensors below */
|
| 1262 |
+
const cudnnTensorDescriptor_t dNormScaleBiasDesc,
|
| 1263 |
+
const void *normScaleData,
|
| 1264 |
+
const void *normBiasData, /* needed if there is activation */
|
| 1265 |
+
void *dNormScaleData,
|
| 1266 |
+
void *dNormBiasData,
|
| 1267 |
+
double epsilon, /* Same epsilon as forward pass */
|
| 1268 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1269 |
+
/* Optionally cached intermediate results from
|
| 1270 |
+
forward pass */
|
| 1271 |
+
const void *savedMean,
|
| 1272 |
+
const void *savedInvVariance,
|
| 1273 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1274 |
+
void *workSpace,
|
| 1275 |
+
size_t workSpaceSizeInBytes,
|
| 1276 |
+
void *reserveSpace,
|
| 1277 |
+
size_t reserveSpaceSizeInBytes,
|
| 1278 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1279 |
+
|
| 1280 |
+
cudnnStatus_t CUDNNWINAPI
|
| 1281 |
+
cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle,
|
| 1282 |
+
const cudnnSpatialTransformerDescriptor_t stDesc,
|
| 1283 |
+
const void *dgrid,
|
| 1284 |
+
void *dtheta);
|
| 1285 |
+
|
| 1286 |
+
cudnnStatus_t CUDNNWINAPI
|
| 1287 |
+
cudnnSpatialTfSamplerBackward(cudnnHandle_t handle,
|
| 1288 |
+
cudnnSpatialTransformerDescriptor_t stDesc,
|
| 1289 |
+
const void *alpha,
|
| 1290 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1291 |
+
const void *x,
|
| 1292 |
+
const void *beta,
|
| 1293 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1294 |
+
void *dx,
|
| 1295 |
+
const void *alphaDgrid,
|
| 1296 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1297 |
+
const void *dy,
|
| 1298 |
+
const void *grid,
|
| 1299 |
+
const void *betaDgrid,
|
| 1300 |
+
void *dgrid);
|
| 1301 |
+
|
| 1302 |
+
cudnnStatus_t CUDNNWINAPI
|
| 1303 |
+
cudnnDropoutBackward(cudnnHandle_t handle,
|
| 1304 |
+
const cudnnDropoutDescriptor_t dropoutDesc,
|
| 1305 |
+
const cudnnTensorDescriptor_t dydesc,
|
| 1306 |
+
const void *dy,
|
| 1307 |
+
const cudnnTensorDescriptor_t dxdesc,
|
| 1308 |
+
void *dx,
|
| 1309 |
+
void *reserveSpace,
|
| 1310 |
+
size_t reserveSpaceSizeInBytes);
|
| 1311 |
+
|
| 1312 |
+
#if defined(__cplusplus)
|
| 1313 |
+
}
|
| 1314 |
+
#endif
|
| 1315 |
+
|
| 1316 |
+
#endif /* CUDNN_OPS_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version_v9.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/**
|
| 51 |
+
* \file: The master cuDNN version file.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#ifndef CUDNN_VERSION_H_
|
| 55 |
+
#define CUDNN_VERSION_H_
|
| 56 |
+
|
| 57 |
+
#define CUDNN_MAJOR 9
|
| 58 |
+
#define CUDNN_MINOR 1
|
| 59 |
+
#define CUDNN_PATCHLEVEL 0
|
| 60 |
+
|
| 61 |
+
#define CUDNN_VERSION (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
|
| 62 |
+
|
| 63 |
+
/* cannot use constexpr here since this is a C-only file */
|
| 64 |
+
/* Below is the max SM version this cuDNN library is aware of and supports natively */
|
| 65 |
+
|
| 66 |
+
#define CUDNN_MAX_SM_MAJOR_NUMBER 9
|
| 67 |
+
#define CUDNN_MAX_SM_MINOR_NUMBER 0
|
| 68 |
+
#define CUDNN_MAX_DEVICE_VERSION (CUDNN_MAX_SM_MAJOR_NUMBER * 100 + CUDNN_MAX_SM_MINOR_NUMBER * 10)
|
| 69 |
+
|
| 70 |
+
#endif /* CUDNN_VERSION_H */
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (188 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverDn.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverMg.h
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2019 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#if !defined(CUSOLVERMG_H_)
|
| 51 |
+
#define CUSOLVERMG_H_
|
| 52 |
+
|
| 53 |
+
#include <stdint.h>
|
| 54 |
+
#include "cusolverDn.h"
|
| 55 |
+
|
| 56 |
+
#if defined(__cplusplus)
|
| 57 |
+
extern "C" {
|
| 58 |
+
#endif /* __cplusplus */
|
| 59 |
+
|
| 60 |
+
struct cusolverMgContext;
|
| 61 |
+
typedef struct cusolverMgContext *cusolverMgHandle_t;
|
| 62 |
+
|
| 63 |
+
/**
|
| 64 |
+
* \beief This enum decides how 1D device Ids (or process ranks) get mapped to
|
| 65 |
+
* a 2D grid.
|
| 66 |
+
*/
|
| 67 |
+
typedef enum {
|
| 68 |
+
|
| 69 |
+
CUDALIBMG_GRID_MAPPING_ROW_MAJOR = 1,
|
| 70 |
+
CUDALIBMG_GRID_MAPPING_COL_MAJOR = 0
|
| 71 |
+
|
| 72 |
+
} cusolverMgGridMapping_t;
|
| 73 |
+
|
| 74 |
+
/** \brief Opaque structure of the distributed grid */
|
| 75 |
+
typedef void *cudaLibMgGrid_t;
|
| 76 |
+
/** \brief Opaque structure of the distributed matrix descriptor */
|
| 77 |
+
typedef void *cudaLibMgMatrixDesc_t;
|
| 78 |
+
|
| 79 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgCreate(cusolverMgHandle_t *handle);
|
| 80 |
+
|
| 81 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgDestroy(cusolverMgHandle_t handle);
|
| 82 |
+
|
| 83 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgDeviceSelect(
|
| 84 |
+
cusolverMgHandle_t handle,
|
| 85 |
+
int nbDevices,
|
| 86 |
+
int deviceId[]);
|
| 87 |
+
|
| 88 |
+
/**
|
| 89 |
+
* \brief Allocates resources related to the shared memory device grid.
|
| 90 |
+
* \param[out] grid the opaque data strcuture that holds the grid
|
| 91 |
+
* \param[in] numRowDevices number of devices in the row
|
| 92 |
+
* \param[in] numColDevices number of devices in the column
|
| 93 |
+
* \param[in] deviceId This array of size height * width stores the
|
| 94 |
+
* device-ids of the 2D grid; each entry must correspond to a valid
|
| 95 |
+
* gpu or to -1 (denoting CPU). \param[in] mapping whether the 2D grid is in
|
| 96 |
+
* row/column major \returns the status code
|
| 97 |
+
*/
|
| 98 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgCreateDeviceGrid(
|
| 99 |
+
cudaLibMgGrid_t * grid,
|
| 100 |
+
int32_t numRowDevices,
|
| 101 |
+
int32_t numColDevices,
|
| 102 |
+
const int32_t deviceId[],
|
| 103 |
+
cusolverMgGridMapping_t mapping);
|
| 104 |
+
|
| 105 |
+
/**
|
| 106 |
+
* \brief Releases the allocated resources related to the distributed grid.
|
| 107 |
+
* \param[in] grid the opaque data strcuture that holds the distributed grid
|
| 108 |
+
* \returns the status code
|
| 109 |
+
*/
|
| 110 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgDestroyGrid(cudaLibMgGrid_t grid);
|
| 111 |
+
|
| 112 |
+
/**
|
| 113 |
+
* \brief Allocates resources related to the distributed matrix descriptor.
|
| 114 |
+
* \param[out] desc the opaque data strcuture that holds the descriptor
|
| 115 |
+
* \param[in] numRows number of total rows
|
| 116 |
+
* \param[in] numCols number of total columns
|
| 117 |
+
* \param[in] rowBlockSize row block size
|
| 118 |
+
* \param[in] colBlockSize column block size
|
| 119 |
+
* \param[in] dataType the data type of each element in cudaDataType
|
| 120 |
+
* \param[in] grid the opaque data structure of the distributed grid
|
| 121 |
+
* \returns the status code
|
| 122 |
+
*/
|
| 123 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgCreateMatrixDesc(
|
| 124 |
+
cudaLibMgMatrixDesc_t *desc,
|
| 125 |
+
int64_t numRows,
|
| 126 |
+
int64_t numCols,
|
| 127 |
+
int64_t rowBlockSize,
|
| 128 |
+
int64_t colBlockSize,
|
| 129 |
+
cudaDataType dataType,
|
| 130 |
+
const cudaLibMgGrid_t grid);
|
| 131 |
+
|
| 132 |
+
/**
|
| 133 |
+
* \brief Releases the allocated resources related to the distributed matrix
|
| 134 |
+
* descriptor. \param[in] desc the opaque data strcuture that holds the
|
| 135 |
+
* descriptor \returns the status code
|
| 136 |
+
*/
|
| 137 |
+
cusolverStatus_t CUSOLVERAPI
|
| 138 |
+
cusolverMgDestroyMatrixDesc(cudaLibMgMatrixDesc_t desc);
|
| 139 |
+
|
| 140 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgSyevd_bufferSize(
|
| 141 |
+
cusolverMgHandle_t handle,
|
| 142 |
+
cusolverEigMode_t jobz,
|
| 143 |
+
cublasFillMode_t uplo,
|
| 144 |
+
int N,
|
| 145 |
+
void * array_d_A[],
|
| 146 |
+
int IA,
|
| 147 |
+
int JA,
|
| 148 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 149 |
+
void * W,
|
| 150 |
+
cudaDataType dataTypeW,
|
| 151 |
+
cudaDataType computeType,
|
| 152 |
+
int64_t * lwork);
|
| 153 |
+
|
| 154 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgSyevd(
|
| 155 |
+
cusolverMgHandle_t handle,
|
| 156 |
+
cusolverEigMode_t jobz,
|
| 157 |
+
cublasFillMode_t uplo,
|
| 158 |
+
int N,
|
| 159 |
+
void * array_d_A[],
|
| 160 |
+
int IA,
|
| 161 |
+
int JA,
|
| 162 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 163 |
+
void * W,
|
| 164 |
+
cudaDataType dataTypeW,
|
| 165 |
+
cudaDataType computeType,
|
| 166 |
+
void * array_d_work[],
|
| 167 |
+
int64_t lwork,
|
| 168 |
+
int * info);
|
| 169 |
+
|
| 170 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgGetrf_bufferSize(
|
| 171 |
+
cusolverMgHandle_t handle,
|
| 172 |
+
int M,
|
| 173 |
+
int N,
|
| 174 |
+
void * array_d_A[],
|
| 175 |
+
int IA,
|
| 176 |
+
int JA,
|
| 177 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 178 |
+
int * array_d_IPIV[],
|
| 179 |
+
cudaDataType computeType,
|
| 180 |
+
int64_t * lwork);
|
| 181 |
+
|
| 182 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgGetrf(
|
| 183 |
+
cusolverMgHandle_t handle,
|
| 184 |
+
int M,
|
| 185 |
+
int N,
|
| 186 |
+
void * array_d_A[],
|
| 187 |
+
int IA,
|
| 188 |
+
int JA,
|
| 189 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 190 |
+
int * array_d_IPIV[],
|
| 191 |
+
cudaDataType computeType,
|
| 192 |
+
void * array_d_work[],
|
| 193 |
+
int64_t lwork,
|
| 194 |
+
int * info);
|
| 195 |
+
|
| 196 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgGetrs_bufferSize(
|
| 197 |
+
cusolverMgHandle_t handle,
|
| 198 |
+
cublasOperation_t TRANS,
|
| 199 |
+
int N,
|
| 200 |
+
int NRHS,
|
| 201 |
+
void * array_d_A[],
|
| 202 |
+
int IA,
|
| 203 |
+
int JA,
|
| 204 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 205 |
+
int * array_d_IPIV[],
|
| 206 |
+
void * array_d_B[],
|
| 207 |
+
int IB,
|
| 208 |
+
int JB,
|
| 209 |
+
cudaLibMgMatrixDesc_t descrB,
|
| 210 |
+
cudaDataType computeType,
|
| 211 |
+
int64_t * lwork);
|
| 212 |
+
|
| 213 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgGetrs(
|
| 214 |
+
cusolverMgHandle_t handle,
|
| 215 |
+
cublasOperation_t TRANS,
|
| 216 |
+
int N,
|
| 217 |
+
int NRHS,
|
| 218 |
+
void * array_d_A[],
|
| 219 |
+
int IA,
|
| 220 |
+
int JA,
|
| 221 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 222 |
+
int * array_d_IPIV[],
|
| 223 |
+
void * array_d_B[],
|
| 224 |
+
int IB,
|
| 225 |
+
int JB,
|
| 226 |
+
cudaLibMgMatrixDesc_t descrB,
|
| 227 |
+
cudaDataType computeType,
|
| 228 |
+
void * array_d_work[],
|
| 229 |
+
int64_t lwork,
|
| 230 |
+
int * info);
|
| 231 |
+
|
| 232 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgPotrf_bufferSize(
|
| 233 |
+
cusolverMgHandle_t handle,
|
| 234 |
+
cublasFillMode_t uplo,
|
| 235 |
+
int N,
|
| 236 |
+
void * array_d_A[],
|
| 237 |
+
int IA,
|
| 238 |
+
int JA,
|
| 239 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 240 |
+
cudaDataType computeType,
|
| 241 |
+
int64_t * lwork);
|
| 242 |
+
|
| 243 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgPotrf(
|
| 244 |
+
cusolverMgHandle_t handle,
|
| 245 |
+
cublasFillMode_t uplo,
|
| 246 |
+
int N,
|
| 247 |
+
void * array_d_A[],
|
| 248 |
+
int IA,
|
| 249 |
+
int JA,
|
| 250 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 251 |
+
cudaDataType computeType,
|
| 252 |
+
void * array_d_work[],
|
| 253 |
+
int64_t lwork,
|
| 254 |
+
int * h_info);
|
| 255 |
+
|
| 256 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgPotrs_bufferSize(
|
| 257 |
+
cusolverMgHandle_t handle,
|
| 258 |
+
cublasFillMode_t uplo,
|
| 259 |
+
int n,
|
| 260 |
+
int nrhs,
|
| 261 |
+
void * array_d_A[],
|
| 262 |
+
int IA,
|
| 263 |
+
int JA,
|
| 264 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 265 |
+
void * array_d_B[],
|
| 266 |
+
int IB,
|
| 267 |
+
int JB,
|
| 268 |
+
cudaLibMgMatrixDesc_t descrB,
|
| 269 |
+
cudaDataType computeType,
|
| 270 |
+
int64_t * lwork);
|
| 271 |
+
|
| 272 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgPotrs(
|
| 273 |
+
cusolverMgHandle_t handle,
|
| 274 |
+
cublasFillMode_t uplo,
|
| 275 |
+
int n,
|
| 276 |
+
int nrhs,
|
| 277 |
+
void * array_d_A[],
|
| 278 |
+
int IA,
|
| 279 |
+
int JA,
|
| 280 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 281 |
+
void * array_d_B[],
|
| 282 |
+
int IB,
|
| 283 |
+
int JB,
|
| 284 |
+
cudaLibMgMatrixDesc_t descrB,
|
| 285 |
+
cudaDataType computeType,
|
| 286 |
+
void * array_d_work[],
|
| 287 |
+
int64_t lwork,
|
| 288 |
+
int * h_info);
|
| 289 |
+
|
| 290 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgPotri_bufferSize(
|
| 291 |
+
cusolverMgHandle_t handle,
|
| 292 |
+
cublasFillMode_t uplo,
|
| 293 |
+
int N,
|
| 294 |
+
void * array_d_A[],
|
| 295 |
+
int IA,
|
| 296 |
+
int JA,
|
| 297 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 298 |
+
cudaDataType computeType,
|
| 299 |
+
int64_t * lwork);
|
| 300 |
+
|
| 301 |
+
cusolverStatus_t CUSOLVERAPI cusolverMgPotri(
|
| 302 |
+
cusolverMgHandle_t handle,
|
| 303 |
+
cublasFillMode_t uplo,
|
| 304 |
+
int N,
|
| 305 |
+
void * array_d_A[],
|
| 306 |
+
int IA,
|
| 307 |
+
int JA,
|
| 308 |
+
cudaLibMgMatrixDesc_t descrA,
|
| 309 |
+
cudaDataType computeType,
|
| 310 |
+
void * array_d_work[],
|
| 311 |
+
int64_t lwork,
|
| 312 |
+
int * h_info);
|
| 313 |
+
|
| 314 |
+
#if defined(__cplusplus)
|
| 315 |
+
}
|
| 316 |
+
#endif /* __cplusplus */
|
| 317 |
+
|
| 318 |
+
#endif // CUSOLVERMG_H_
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverRf.h
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2014 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#if !defined(CUSOLVERRF_H_)
|
| 51 |
+
#define CUSOLVERRF_H_
|
| 52 |
+
|
| 53 |
+
#include "driver_types.h"
|
| 54 |
+
#include "cuComplex.h"
|
| 55 |
+
#include "cusolver_common.h"
|
| 56 |
+
|
| 57 |
+
#if defined(__cplusplus)
|
| 58 |
+
extern "C" {
|
| 59 |
+
#endif /* __cplusplus */
|
| 60 |
+
|
| 61 |
+
/* CUSOLVERRF mode */
|
| 62 |
+
typedef enum {
|
| 63 |
+
CUSOLVERRF_RESET_VALUES_FAST_MODE_OFF = 0, // default
|
| 64 |
+
CUSOLVERRF_RESET_VALUES_FAST_MODE_ON = 1
|
| 65 |
+
} cusolverRfResetValuesFastMode_t;
|
| 66 |
+
|
| 67 |
+
/* CUSOLVERRF matrix format */
|
| 68 |
+
typedef enum {
|
| 69 |
+
CUSOLVERRF_MATRIX_FORMAT_CSR = 0, // default
|
| 70 |
+
CUSOLVERRF_MATRIX_FORMAT_CSC = 1
|
| 71 |
+
} cusolverRfMatrixFormat_t;
|
| 72 |
+
|
| 73 |
+
/* CUSOLVERRF unit diagonal */
|
| 74 |
+
typedef enum {
|
| 75 |
+
CUSOLVERRF_UNIT_DIAGONAL_STORED_L = 0, // default
|
| 76 |
+
CUSOLVERRF_UNIT_DIAGONAL_STORED_U = 1,
|
| 77 |
+
CUSOLVERRF_UNIT_DIAGONAL_ASSUMED_L = 2,
|
| 78 |
+
CUSOLVERRF_UNIT_DIAGONAL_ASSUMED_U = 3
|
| 79 |
+
} cusolverRfUnitDiagonal_t;
|
| 80 |
+
|
| 81 |
+
/* CUSOLVERRF factorization algorithm */
|
| 82 |
+
typedef enum {
|
| 83 |
+
CUSOLVERRF_FACTORIZATION_ALG0 = 0, // default
|
| 84 |
+
CUSOLVERRF_FACTORIZATION_ALG1 = 1,
|
| 85 |
+
CUSOLVERRF_FACTORIZATION_ALG2 = 2,
|
| 86 |
+
} cusolverRfFactorization_t;
|
| 87 |
+
|
| 88 |
+
/* CUSOLVERRF triangular solve algorithm */
|
| 89 |
+
typedef enum {
|
| 90 |
+
CUSOLVERRF_TRIANGULAR_SOLVE_ALG1 = 1, // default
|
| 91 |
+
CUSOLVERRF_TRIANGULAR_SOLVE_ALG2 = 2,
|
| 92 |
+
CUSOLVERRF_TRIANGULAR_SOLVE_ALG3 = 3
|
| 93 |
+
} cusolverRfTriangularSolve_t;
|
| 94 |
+
|
| 95 |
+
/* CUSOLVERRF numeric boost report */
|
| 96 |
+
typedef enum {
|
| 97 |
+
CUSOLVERRF_NUMERIC_BOOST_NOT_USED = 0, // default
|
| 98 |
+
CUSOLVERRF_NUMERIC_BOOST_USED = 1
|
| 99 |
+
} cusolverRfNumericBoostReport_t;
|
| 100 |
+
|
| 101 |
+
/* Opaque structure holding CUSOLVERRF library common */
|
| 102 |
+
struct cusolverRfCommon;
|
| 103 |
+
typedef struct cusolverRfCommon* cusolverRfHandle_t;
|
| 104 |
+
|
| 105 |
+
/* CUSOLVERRF create (allocate memory) and destroy (free memory) in the handle
|
| 106 |
+
*/
|
| 107 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfCreate(cusolverRfHandle_t* handle);
|
| 108 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfDestroy(cusolverRfHandle_t handle);
|
| 109 |
+
|
| 110 |
+
/* CUSOLVERRF set and get input format */
|
| 111 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfGetMatrixFormat(
|
| 112 |
+
cusolverRfHandle_t handle,
|
| 113 |
+
cusolverRfMatrixFormat_t* format,
|
| 114 |
+
cusolverRfUnitDiagonal_t* diag);
|
| 115 |
+
|
| 116 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfSetMatrixFormat(
|
| 117 |
+
cusolverRfHandle_t handle,
|
| 118 |
+
cusolverRfMatrixFormat_t format,
|
| 119 |
+
cusolverRfUnitDiagonal_t diag);
|
| 120 |
+
|
| 121 |
+
/* CUSOLVERRF set and get numeric properties */
|
| 122 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfSetNumericProperties(
|
| 123 |
+
cusolverRfHandle_t handle,
|
| 124 |
+
double zero,
|
| 125 |
+
double boost);
|
| 126 |
+
|
| 127 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfGetNumericProperties(
|
| 128 |
+
cusolverRfHandle_t handle,
|
| 129 |
+
double* zero,
|
| 130 |
+
double* boost);
|
| 131 |
+
|
| 132 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfGetNumericBoostReport(
|
| 133 |
+
cusolverRfHandle_t handle,
|
| 134 |
+
cusolverRfNumericBoostReport_t* report);
|
| 135 |
+
|
| 136 |
+
/* CUSOLVERRF choose the triangular solve algorithm */
|
| 137 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfSetAlgs(
|
| 138 |
+
cusolverRfHandle_t handle,
|
| 139 |
+
cusolverRfFactorization_t factAlg,
|
| 140 |
+
cusolverRfTriangularSolve_t solveAlg);
|
| 141 |
+
|
| 142 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfGetAlgs(
|
| 143 |
+
cusolverRfHandle_t handle,
|
| 144 |
+
cusolverRfFactorization_t* factAlg,
|
| 145 |
+
cusolverRfTriangularSolve_t* solveAlg);
|
| 146 |
+
|
| 147 |
+
/* CUSOLVERRF set and get fast mode */
|
| 148 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfGetResetValuesFastMode(
|
| 149 |
+
cusolverRfHandle_t handle,
|
| 150 |
+
cusolverRfResetValuesFastMode_t* fastMode);
|
| 151 |
+
|
| 152 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfSetResetValuesFastMode(
|
| 153 |
+
cusolverRfHandle_t handle,
|
| 154 |
+
cusolverRfResetValuesFastMode_t fastMode);
|
| 155 |
+
|
| 156 |
+
/*** Non-Batched Routines ***/
|
| 157 |
+
/* CUSOLVERRF setup of internal structures from host or device memory */
|
| 158 |
+
cusolverStatus_t CUSOLVERAPI
|
| 159 |
+
cusolverRfSetupHost(/* Input (in the host memory) */
|
| 160 |
+
int n,
|
| 161 |
+
int nnzA,
|
| 162 |
+
int* h_csrRowPtrA,
|
| 163 |
+
int* h_csrColIndA,
|
| 164 |
+
double* h_csrValA,
|
| 165 |
+
int nnzL,
|
| 166 |
+
int* h_csrRowPtrL,
|
| 167 |
+
int* h_csrColIndL,
|
| 168 |
+
double* h_csrValL,
|
| 169 |
+
int nnzU,
|
| 170 |
+
int* h_csrRowPtrU,
|
| 171 |
+
int* h_csrColIndU,
|
| 172 |
+
double* h_csrValU,
|
| 173 |
+
int* h_P,
|
| 174 |
+
int* h_Q,
|
| 175 |
+
/* Output */
|
| 176 |
+
cusolverRfHandle_t handle);
|
| 177 |
+
|
| 178 |
+
cusolverStatus_t CUSOLVERAPI
|
| 179 |
+
cusolverRfSetupDevice(/* Input (in the device memory) */
|
| 180 |
+
int n,
|
| 181 |
+
int nnzA,
|
| 182 |
+
int* csrRowPtrA,
|
| 183 |
+
int* csrColIndA,
|
| 184 |
+
double* csrValA,
|
| 185 |
+
int nnzL,
|
| 186 |
+
int* csrRowPtrL,
|
| 187 |
+
int* csrColIndL,
|
| 188 |
+
double* csrValL,
|
| 189 |
+
int nnzU,
|
| 190 |
+
int* csrRowPtrU,
|
| 191 |
+
int* csrColIndU,
|
| 192 |
+
double* csrValU,
|
| 193 |
+
int* P,
|
| 194 |
+
int* Q,
|
| 195 |
+
/* Output */
|
| 196 |
+
cusolverRfHandle_t handle);
|
| 197 |
+
|
| 198 |
+
/* CUSOLVERRF update the matrix values (assuming the reordering, pivoting
|
| 199 |
+
and consequently the sparsity pattern of L and U did not change),
|
| 200 |
+
and zero out the remaining values. */
|
| 201 |
+
cusolverStatus_t CUSOLVERAPI
|
| 202 |
+
cusolverRfResetValues(/* Input (in the device memory) */
|
| 203 |
+
int n,
|
| 204 |
+
int nnzA,
|
| 205 |
+
int* csrRowPtrA,
|
| 206 |
+
int* csrColIndA,
|
| 207 |
+
double* csrValA,
|
| 208 |
+
int* P,
|
| 209 |
+
int* Q,
|
| 210 |
+
/* Output */
|
| 211 |
+
cusolverRfHandle_t handle);
|
| 212 |
+
|
| 213 |
+
/* CUSOLVERRF analysis (for parallelism) */
|
| 214 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfAnalyze(cusolverRfHandle_t handle);
|
| 215 |
+
|
| 216 |
+
/* CUSOLVERRF re-factorization (for parallelism) */
|
| 217 |
+
cusolverStatus_t CUSOLVERAPI cusolverRfRefactor(cusolverRfHandle_t handle);
|
| 218 |
+
|
| 219 |
+
/* CUSOLVERRF extraction: Get L & U packed into a single matrix M */
|
| 220 |
+
cusolverStatus_t CUSOLVERAPI
|
| 221 |
+
cusolverRfAccessBundledFactorsDevice(/* Input */
|
| 222 |
+
cusolverRfHandle_t handle,
|
| 223 |
+
/* Output (in the host memory) */
|
| 224 |
+
int* nnzM,
|
| 225 |
+
/* Output (in the device memory) */
|
| 226 |
+
int** Mp,
|
| 227 |
+
int** Mi,
|
| 228 |
+
double** Mx);
|
| 229 |
+
|
| 230 |
+
cusolverStatus_t CUSOLVERAPI
|
| 231 |
+
cusolverRfExtractBundledFactorsHost(/* Input */
|
| 232 |
+
cusolverRfHandle_t handle,
|
| 233 |
+
/* Output (in the host memory) */
|
| 234 |
+
int* h_nnzM,
|
| 235 |
+
int** h_Mp,
|
| 236 |
+
int** h_Mi,
|
| 237 |
+
double** h_Mx);
|
| 238 |
+
|
| 239 |
+
/* CUSOLVERRF extraction: Get L & U individually */
|
| 240 |
+
cusolverStatus_t CUSOLVERAPI
|
| 241 |
+
cusolverRfExtractSplitFactorsHost(/* Input */
|
| 242 |
+
cusolverRfHandle_t handle,
|
| 243 |
+
/* Output (in the host memory) */
|
| 244 |
+
int* h_nnzL,
|
| 245 |
+
int** h_csrRowPtrL,
|
| 246 |
+
int** h_csrColIndL,
|
| 247 |
+
double** h_csrValL,
|
| 248 |
+
int* h_nnzU,
|
| 249 |
+
int** h_csrRowPtrU,
|
| 250 |
+
int** h_csrColIndU,
|
| 251 |
+
double** h_csrValU);
|
| 252 |
+
|
| 253 |
+
/* CUSOLVERRF (forward and backward triangular) solves */
|
| 254 |
+
cusolverStatus_t CUSOLVERAPI
|
| 255 |
+
cusolverRfSolve(/* Input (in the device memory) */
|
| 256 |
+
cusolverRfHandle_t handle,
|
| 257 |
+
int* P,
|
| 258 |
+
int* Q,
|
| 259 |
+
int nrhs, // only nrhs=1 is supported
|
| 260 |
+
double* Temp, // of size ldt*nrhs (ldt>=n)
|
| 261 |
+
int ldt,
|
| 262 |
+
/* Input/Output (in the device memory) */
|
| 263 |
+
double* XF,
|
| 264 |
+
/* Input */
|
| 265 |
+
int ldxf);
|
| 266 |
+
|
| 267 |
+
/*** Batched Routines ***/
|
| 268 |
+
/* CUSOLVERRF-batch setup of internal structures from host */
|
| 269 |
+
cusolverStatus_t CUSOLVERAPI
|
| 270 |
+
cusolverRfBatchSetupHost(/* Input (in the host memory)*/
|
| 271 |
+
int batchSize,
|
| 272 |
+
int n,
|
| 273 |
+
int nnzA,
|
| 274 |
+
int* h_csrRowPtrA,
|
| 275 |
+
int* h_csrColIndA,
|
| 276 |
+
double* h_csrValA_array[],
|
| 277 |
+
int nnzL,
|
| 278 |
+
int* h_csrRowPtrL,
|
| 279 |
+
int* h_csrColIndL,
|
| 280 |
+
double* h_csrValL,
|
| 281 |
+
int nnzU,
|
| 282 |
+
int* h_csrRowPtrU,
|
| 283 |
+
int* h_csrColIndU,
|
| 284 |
+
double* h_csrValU,
|
| 285 |
+
int* h_P,
|
| 286 |
+
int* h_Q,
|
| 287 |
+
/* Output (in the device memory) */
|
| 288 |
+
cusolverRfHandle_t handle);
|
| 289 |
+
|
| 290 |
+
/* CUSOLVERRF-batch update the matrix values (assuming the reordering,
|
| 291 |
+
pivoting and consequently the sparsity pattern of L and U did not change),
|
| 292 |
+
and zero out the remaining values. */
|
| 293 |
+
cusolverStatus_t CUSOLVERAPI
|
| 294 |
+
cusolverRfBatchResetValues(/* Input (in the device memory) */
|
| 295 |
+
int batchSize,
|
| 296 |
+
int n,
|
| 297 |
+
int nnzA,
|
| 298 |
+
int* csrRowPtrA,
|
| 299 |
+
int* csrColIndA,
|
| 300 |
+
double* csrValA_array[],
|
| 301 |
+
int* P,
|
| 302 |
+
int* Q,
|
| 303 |
+
/* Output */
|
| 304 |
+
cusolverRfHandle_t handle);
|
| 305 |
+
|
| 306 |
+
/* CUSOLVERRF-batch analysis (for parallelism) */
|
| 307 |
+
cusolverStatus_t CUSOLVERAPI
|
| 308 |
+
cusolverRfBatchAnalyze(cusolverRfHandle_t handle);
|
| 309 |
+
|
| 310 |
+
/* CUSOLVERRF-batch re-factorization (for parallelism) */
|
| 311 |
+
cusolverStatus_t CUSOLVERAPI
|
| 312 |
+
cusolverRfBatchRefactor(cusolverRfHandle_t handle);
|
| 313 |
+
|
| 314 |
+
/* CUSOLVERRF-batch (forward and backward triangular) solves */
|
| 315 |
+
cusolverStatus_t CUSOLVERAPI
|
| 316 |
+
cusolverRfBatchSolve(/* Input (in the device memory) */
|
| 317 |
+
cusolverRfHandle_t handle,
|
| 318 |
+
int* P,
|
| 319 |
+
int* Q,
|
| 320 |
+
int nrhs, // only nrhs=1 is supported
|
| 321 |
+
double* Temp, // of size 2*batchSize*(n*nrhs)
|
| 322 |
+
int ldt, // only ldt=n is supported
|
| 323 |
+
/* Input/Output (in the device memory) */
|
| 324 |
+
double* XF_array[],
|
| 325 |
+
/* Input */
|
| 326 |
+
int ldxf);
|
| 327 |
+
|
| 328 |
+
/* CUSOLVERRF-batch obtain the position of zero pivot */
|
| 329 |
+
cusolverStatus_t CUSOLVERAPI
|
| 330 |
+
cusolverRfBatchZeroPivot(/* Input */
|
| 331 |
+
cusolverRfHandle_t handle,
|
| 332 |
+
/* Output (in the host memory) */
|
| 333 |
+
int* position);
|
| 334 |
+
|
| 335 |
+
#if defined(__cplusplus)
|
| 336 |
+
}
|
| 337 |
+
#endif /* __cplusplus */
|
| 338 |
+
|
| 339 |
+
#endif /* CUSOLVERRF_H_ */
|