diff --git a/.gitattributes b/.gitattributes index 93651752679e61d91a4b6c7ee6ef40807128fb1a..0a1aa7e6d18fa08133e1d1c566c2cd00932e1c44 100644 --- a/.gitattributes +++ b/.gitattributes @@ -120,3 +120,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/click/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/pyasn1/type/__pycache__/univ.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libvpx-9f572e11.so.9.1.0 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/nvidia/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a06e3d156e265a8b1a785b12396dfb9464bcb668 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py b/.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47c1362208903c6ef96c60829cc1fd3583542e41 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__init__.py b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0f7e1ee93b8a6338f663829b17541f316899dad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas.h b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas.h new file mode 100644 index 0000000000000000000000000000000000000000..96eadad8a8e8c3979b99910ceea41ceaf2c8b58e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas.h @@ -0,0 +1,891 @@ +/* + * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +/* + * This is the public header file for the CUBLAS library, defining the API + * + * CUBLAS is an implementation of BLAS (Basic Linear Algebra Subroutines) + * on top of the CUDA runtime. + */ + +#if !defined(CUBLAS_H_) +#define CUBLAS_H_ + +#if defined(CUBLAS_V2_H_) +#error "It is an error to include both cublas.h and cublas_v2.h" +#endif + +#include + +#ifndef CUBLASWINAPI +#ifdef _WIN32 +#define CUBLASWINAPI __stdcall +#else +#define CUBLASWINAPI +#endif +#endif + +#undef CUBLASAPI +#ifdef __CUDACC__ +#define CUBLASAPI __host__ +#else +#define CUBLASAPI +#endif + +#include "cublas_api.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +/* CUBLAS data types */ +#define cublasStatus cublasStatus_t + +cublasStatus CUBLASWINAPI cublasInit(void); +cublasStatus CUBLASWINAPI cublasShutdown(void); +cublasStatus CUBLASWINAPI cublasGetError(void); + +cublasStatus CUBLASWINAPI cublasGetVersion(int* version); +cublasStatus CUBLASWINAPI cublasAlloc(int n, int elemSize, void** devicePtr); + +cublasStatus CUBLASWINAPI cublasFree(void* devicePtr); + +cublasStatus CUBLASWINAPI cublasSetKernelStream(cudaStream_t stream); + +/* ---------------- CUBLAS BLAS1 functions ---------------- */ +/* NRM2 */ +float CUBLASWINAPI cublasSnrm2(int n, const float* x, int incx); +double CUBLASWINAPI cublasDnrm2(int n, const double* x, int incx); +float CUBLASWINAPI cublasScnrm2(int n, const cuComplex* x, int incx); +double CUBLASWINAPI cublasDznrm2(int n, const cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* DOT */ +float CUBLASWINAPI cublasSdot(int n, const float* x, int incx, const float* y, int incy); +double CUBLASWINAPI cublasDdot(int n, const double* x, int incx, const double* y, int incy); +cuComplex CUBLASWINAPI cublasCdotu(int n, const cuComplex* x, int incx, const cuComplex* y, int incy); +cuComplex CUBLASWINAPI cublasCdotc(int n, const cuComplex* x, int incx, const cuComplex* y, int incy); +cuDoubleComplex CUBLASWINAPI cublasZdotu(int n, const cuDoubleComplex* x, int incx, const cuDoubleComplex* y, int incy); +cuDoubleComplex CUBLASWINAPI cublasZdotc(int n, const cuDoubleComplex* x, int incx, const cuDoubleComplex* y, int incy); +/*------------------------------------------------------------------------*/ +/* SCAL */ +void CUBLASWINAPI cublasSscal(int n, float alpha, float* x, int incx); +void CUBLASWINAPI cublasDscal(int n, double alpha, double* x, int incx); +void CUBLASWINAPI cublasCscal(int n, cuComplex alpha, cuComplex* x, int incx); +void CUBLASWINAPI cublasZscal(int n, cuDoubleComplex alpha, cuDoubleComplex* x, int incx); + +void CUBLASWINAPI cublasCsscal(int n, float alpha, cuComplex* x, int incx); +void CUBLASWINAPI cublasZdscal(int n, double alpha, cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* AXPY */ +void CUBLASWINAPI cublasSaxpy(int n, float alpha, const float* x, int incx, float* y, int incy); +void CUBLASWINAPI cublasDaxpy(int n, double alpha, const double* x, int incx, double* y, int incy); +void CUBLASWINAPI cublasCaxpy(int n, cuComplex alpha, const cuComplex* x, int incx, cuComplex* y, int incy); +void CUBLASWINAPI +cublasZaxpy(int n, cuDoubleComplex alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy); +/*------------------------------------------------------------------------*/ +/* COPY */ +void CUBLASWINAPI cublasScopy(int n, const float* x, int incx, float* y, int incy); +void CUBLASWINAPI cublasDcopy(int n, const double* x, int incx, double* y, int incy); +void CUBLASWINAPI cublasCcopy(int n, const cuComplex* x, int incx, cuComplex* y, int incy); +void CUBLASWINAPI cublasZcopy(int n, const cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy); +/*------------------------------------------------------------------------*/ +/* SWAP */ +void CUBLASWINAPI cublasSswap(int n, float* x, int incx, float* y, int incy); +void CUBLASWINAPI cublasDswap(int n, double* x, int incx, double* y, int incy); +void CUBLASWINAPI cublasCswap(int n, cuComplex* x, int incx, cuComplex* y, int incy); +void CUBLASWINAPI cublasZswap(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy); +/*------------------------------------------------------------------------*/ +/* AMAX */ +int CUBLASWINAPI cublasIsamax(int n, const float* x, int incx); +int CUBLASWINAPI cublasIdamax(int n, const double* x, int incx); +int CUBLASWINAPI cublasIcamax(int n, const cuComplex* x, int incx); +int CUBLASWINAPI cublasIzamax(int n, const cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* AMIN */ +int CUBLASWINAPI cublasIsamin(int n, const float* x, int incx); +int CUBLASWINAPI cublasIdamin(int n, const double* x, int incx); + +int CUBLASWINAPI cublasIcamin(int n, const cuComplex* x, int incx); +int CUBLASWINAPI cublasIzamin(int n, const cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* ASUM */ +float CUBLASWINAPI cublasSasum(int n, const float* x, int incx); +double CUBLASWINAPI cublasDasum(int n, const double* x, int incx); +float CUBLASWINAPI cublasScasum(int n, const cuComplex* x, int incx); +double CUBLASWINAPI cublasDzasum(int n, const cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* ROT */ +void CUBLASWINAPI cublasSrot(int n, float* x, int incx, float* y, int incy, float sc, float ss); +void CUBLASWINAPI cublasDrot(int n, double* x, int incx, double* y, int incy, double sc, double ss); +void CUBLASWINAPI cublasCrot(int n, cuComplex* x, int incx, cuComplex* y, int incy, float c, cuComplex s); +void CUBLASWINAPI +cublasZrot(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, double sc, cuDoubleComplex cs); +void CUBLASWINAPI cublasCsrot(int n, cuComplex* x, int incx, cuComplex* y, int incy, float c, float s); +void CUBLASWINAPI cublasZdrot(int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, double c, double s); +/*------------------------------------------------------------------------*/ +/* ROTG */ +void CUBLASWINAPI cublasSrotg(float* sa, float* sb, float* sc, float* ss); +void CUBLASWINAPI cublasDrotg(double* sa, double* sb, double* sc, double* ss); +void CUBLASWINAPI cublasCrotg(cuComplex* ca, cuComplex cb, float* sc, cuComplex* cs); +void CUBLASWINAPI cublasZrotg(cuDoubleComplex* ca, cuDoubleComplex cb, double* sc, cuDoubleComplex* cs); +/*------------------------------------------------------------------------*/ +/* ROTM */ +void CUBLASWINAPI cublasSrotm(int n, float* x, int incx, float* y, int incy, const float* sparam); +void CUBLASWINAPI cublasDrotm(int n, double* x, int incx, double* y, int incy, const double* sparam); +/*------------------------------------------------------------------------*/ +/* ROTMG */ +void CUBLASWINAPI cublasSrotmg(float* sd1, float* sd2, float* sx1, const float* sy1, float* sparam); +void CUBLASWINAPI cublasDrotmg(double* sd1, double* sd2, double* sx1, const double* sy1, double* sparam); + +/* --------------- CUBLAS BLAS2 functions ---------------- */ +/* GEMV */ +void CUBLASWINAPI cublasSgemv(char trans, + int m, + int n, + float alpha, + const float* A, + int lda, + const float* x, + int incx, + float beta, + float* y, + int incy); +void CUBLASWINAPI cublasDgemv(char trans, + int m, + int n, + double alpha, + const double* A, + int lda, + const double* x, + int incx, + double beta, + double* y, + int incy); +void CUBLASWINAPI cublasCgemv(char trans, + int m, + int n, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + cuComplex beta, + cuComplex* y, + int incy); +void CUBLASWINAPI cublasZgemv(char trans, + int m, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex beta, + cuDoubleComplex* y, + int incy); +/*------------------------------------------------------------------------*/ +/* GBMV */ +void CUBLASWINAPI cublasSgbmv(char trans, + int m, + int n, + int kl, + int ku, + float alpha, + const float* A, + int lda, + const float* x, + int incx, + float beta, + float* y, + int incy); +void CUBLASWINAPI cublasDgbmv(char trans, + int m, + int n, + int kl, + int ku, + double alpha, + const double* A, + int lda, + const double* x, + int incx, + double beta, + double* y, + int incy); +void CUBLASWINAPI cublasCgbmv(char trans, + int m, + int n, + int kl, + int ku, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + cuComplex beta, + cuComplex* y, + int incy); +void CUBLASWINAPI cublasZgbmv(char trans, + int m, + int n, + int kl, + int ku, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex beta, + cuDoubleComplex* y, + int incy); +/*------------------------------------------------------------------------*/ +/* TRMV */ +void CUBLASWINAPI cublasStrmv(char uplo, char trans, char diag, int n, const float* A, int lda, float* x, int incx); +void CUBLASWINAPI cublasDtrmv(char uplo, char trans, char diag, int n, const double* A, int lda, double* x, int incx); +void CUBLASWINAPI +cublasCtrmv(char uplo, char trans, char diag, int n, const cuComplex* A, int lda, cuComplex* x, int incx); +void CUBLASWINAPI +cublasZtrmv(char uplo, char trans, char diag, int n, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* TBMV */ +void CUBLASWINAPI +cublasStbmv(char uplo, char trans, char diag, int n, int k, const float* A, int lda, float* x, int incx); +void CUBLASWINAPI +cublasDtbmv(char uplo, char trans, char diag, int n, int k, const double* A, int lda, double* x, int incx); +void CUBLASWINAPI +cublasCtbmv(char uplo, char trans, char diag, int n, int k, const cuComplex* A, int lda, cuComplex* x, int incx); +void CUBLASWINAPI cublasZtbmv( + char uplo, char trans, char diag, int n, int k, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* TPMV */ +void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, const float* AP, float* x, int incx); + +void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, const double* AP, double* x, int incx); + +void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, const cuComplex* AP, cuComplex* x, int incx); + +void CUBLASWINAPI +cublasZtpmv(char uplo, char trans, char diag, int n, const cuDoubleComplex* AP, cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* TRSV */ +void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, const float* A, int lda, float* x, int incx); + +void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, const double* A, int lda, double* x, int incx); + +void CUBLASWINAPI +cublasCtrsv(char uplo, char trans, char diag, int n, const cuComplex* A, int lda, cuComplex* x, int incx); + +void CUBLASWINAPI +cublasZtrsv(char uplo, char trans, char diag, int n, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* TPSV */ +void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, const float* AP, float* x, int incx); + +void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, const double* AP, double* x, int incx); + +void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, const cuComplex* AP, cuComplex* x, int incx); + +void CUBLASWINAPI +cublasZtpsv(char uplo, char trans, char diag, int n, const cuDoubleComplex* AP, cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* TBSV */ +void CUBLASWINAPI +cublasStbsv(char uplo, char trans, char diag, int n, int k, const float* A, int lda, float* x, int incx); + +void CUBLASWINAPI +cublasDtbsv(char uplo, char trans, char diag, int n, int k, const double* A, int lda, double* x, int incx); +void CUBLASWINAPI +cublasCtbsv(char uplo, char trans, char diag, int n, int k, const cuComplex* A, int lda, cuComplex* x, int incx); + +void CUBLASWINAPI cublasZtbsv( + char uplo, char trans, char diag, int n, int k, const cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx); +/*------------------------------------------------------------------------*/ +/* SYMV/HEMV */ +void CUBLASWINAPI cublasSsymv( + char uplo, int n, float alpha, const float* A, int lda, const float* x, int incx, float beta, float* y, int incy); +void CUBLASWINAPI cublasDsymv(char uplo, + int n, + double alpha, + const double* A, + int lda, + const double* x, + int incx, + double beta, + double* y, + int incy); +void CUBLASWINAPI cublasChemv(char uplo, + int n, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + cuComplex beta, + cuComplex* y, + int incy); +void CUBLASWINAPI cublasZhemv(char uplo, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex beta, + cuDoubleComplex* y, + int incy); +/*------------------------------------------------------------------------*/ +/* SBMV/HBMV */ +void CUBLASWINAPI cublasSsbmv(char uplo, + int n, + int k, + float alpha, + const float* A, + int lda, + const float* x, + int incx, + float beta, + float* y, + int incy); +void CUBLASWINAPI cublasDsbmv(char uplo, + int n, + int k, + double alpha, + const double* A, + int lda, + const double* x, + int incx, + double beta, + double* y, + int incy); +void CUBLASWINAPI cublasChbmv(char uplo, + int n, + int k, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + cuComplex beta, + cuComplex* y, + int incy); +void CUBLASWINAPI cublasZhbmv(char uplo, + int n, + int k, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex beta, + cuDoubleComplex* y, + int incy); +/*------------------------------------------------------------------------*/ +/* SPMV/HPMV */ +void CUBLASWINAPI +cublasSspmv(char uplo, int n, float alpha, const float* AP, const float* x, int incx, float beta, float* y, int incy); +void CUBLASWINAPI cublasDspmv( + char uplo, int n, double alpha, const double* AP, const double* x, int incx, double beta, double* y, int incy); +void CUBLASWINAPI cublasChpmv(char uplo, + int n, + cuComplex alpha, + const cuComplex* AP, + const cuComplex* x, + int incx, + cuComplex beta, + cuComplex* y, + int incy); +void CUBLASWINAPI cublasZhpmv(char uplo, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* AP, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex beta, + cuDoubleComplex* y, + int incy); + +/*------------------------------------------------------------------------*/ +/* GER */ +void CUBLASWINAPI +cublasSger(int m, int n, float alpha, const float* x, int incx, const float* y, int incy, float* A, int lda); +void CUBLASWINAPI +cublasDger(int m, int n, double alpha, const double* x, int incx, const double* y, int incy, double* A, int lda); + +void CUBLASWINAPI cublasCgeru( + int m, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* A, int lda); +void CUBLASWINAPI cublasCgerc( + int m, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* A, int lda); +void CUBLASWINAPI cublasZgeru(int m, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* A, + int lda); +void CUBLASWINAPI cublasZgerc(int m, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* A, + int lda); +/*------------------------------------------------------------------------*/ +/* SYR/HER */ +void CUBLASWINAPI cublasSsyr(char uplo, int n, float alpha, const float* x, int incx, float* A, int lda); +void CUBLASWINAPI cublasDsyr(char uplo, int n, double alpha, const double* x, int incx, double* A, int lda); + +void CUBLASWINAPI cublasCher(char uplo, int n, float alpha, const cuComplex* x, int incx, cuComplex* A, int lda); +void CUBLASWINAPI +cublasZher(char uplo, int n, double alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* A, int lda); + +/*------------------------------------------------------------------------*/ +/* SPR/HPR */ +void CUBLASWINAPI cublasSspr(char uplo, int n, float alpha, const float* x, int incx, float* AP); +void CUBLASWINAPI cublasDspr(char uplo, int n, double alpha, const double* x, int incx, double* AP); +void CUBLASWINAPI cublasChpr(char uplo, int n, float alpha, const cuComplex* x, int incx, cuComplex* AP); +void CUBLASWINAPI cublasZhpr(char uplo, int n, double alpha, const cuDoubleComplex* x, int incx, cuDoubleComplex* AP); +/*------------------------------------------------------------------------*/ +/* SYR2/HER2 */ +void CUBLASWINAPI +cublasSsyr2(char uplo, int n, float alpha, const float* x, int incx, const float* y, int incy, float* A, int lda); +void CUBLASWINAPI +cublasDsyr2(char uplo, int n, double alpha, const double* x, int incx, const double* y, int incy, double* A, int lda); +void CUBLASWINAPI cublasCher2(char uplo, + int n, + cuComplex alpha, + const cuComplex* x, + int incx, + const cuComplex* y, + int incy, + cuComplex* A, + int lda); +void CUBLASWINAPI cublasZher2(char uplo, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* A, + int lda); + +/*------------------------------------------------------------------------*/ +/* SPR2/HPR2 */ +void CUBLASWINAPI +cublasSspr2(char uplo, int n, float alpha, const float* x, int incx, const float* y, int incy, float* AP); +void CUBLASWINAPI +cublasDspr2(char uplo, int n, double alpha, const double* x, int incx, const double* y, int incy, double* AP); +void CUBLASWINAPI cublasChpr2( + char uplo, int n, cuComplex alpha, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* AP); +void CUBLASWINAPI cublasZhpr2(char uplo, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* AP); +/* ------------------------BLAS3 Functions ------------------------------- */ +/* GEMM */ +void CUBLASWINAPI cublasSgemm(char transa, + char transb, + int m, + int n, + int k, + float alpha, + const float* A, + int lda, + const float* B, + int ldb, + float beta, + float* C, + int ldc); +void CUBLASWINAPI cublasDgemm(char transa, + char transb, + int m, + int n, + int k, + double alpha, + const double* A, + int lda, + const double* B, + int ldb, + double beta, + double* C, + int ldc); +void CUBLASWINAPI cublasCgemm(char transa, + char transb, + int m, + int n, + int k, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + cuComplex beta, + cuComplex* C, + int ldc); +void CUBLASWINAPI cublasZgemm(char transa, + char transb, + int m, + int n, + int k, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + cuDoubleComplex beta, + cuDoubleComplex* C, + int ldc); +/* -------------------------------------------------------*/ +/* SYRK */ +void CUBLASWINAPI +cublasSsyrk(char uplo, char trans, int n, int k, float alpha, const float* A, int lda, float beta, float* C, int ldc); +void CUBLASWINAPI cublasDsyrk( + char uplo, char trans, int n, int k, double alpha, const double* A, int lda, double beta, double* C, int ldc); + +void CUBLASWINAPI cublasCsyrk(char uplo, + char trans, + int n, + int k, + cuComplex alpha, + const cuComplex* A, + int lda, + cuComplex beta, + cuComplex* C, + int ldc); +void CUBLASWINAPI cublasZsyrk(char uplo, + char trans, + int n, + int k, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex beta, + cuDoubleComplex* C, + int ldc); +/* ------------------------------------------------------- */ +/* HERK */ +void CUBLASWINAPI cublasCherk( + char uplo, char trans, int n, int k, float alpha, const cuComplex* A, int lda, float beta, cuComplex* C, int ldc); +void CUBLASWINAPI cublasZherk(char uplo, + char trans, + int n, + int k, + double alpha, + const cuDoubleComplex* A, + int lda, + double beta, + cuDoubleComplex* C, + int ldc); +/* ------------------------------------------------------- */ +/* SYR2K */ +void CUBLASWINAPI cublasSsyr2k(char uplo, + char trans, + int n, + int k, + float alpha, + const float* A, + int lda, + const float* B, + int ldb, + float beta, + float* C, + int ldc); + +void CUBLASWINAPI cublasDsyr2k(char uplo, + char trans, + int n, + int k, + double alpha, + const double* A, + int lda, + const double* B, + int ldb, + double beta, + double* C, + int ldc); +void CUBLASWINAPI cublasCsyr2k(char uplo, + char trans, + int n, + int k, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + cuComplex beta, + cuComplex* C, + int ldc); + +void CUBLASWINAPI cublasZsyr2k(char uplo, + char trans, + int n, + int k, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + cuDoubleComplex beta, + cuDoubleComplex* C, + int ldc); +/* ------------------------------------------------------- */ +/* HER2K */ +void CUBLASWINAPI cublasCher2k(char uplo, + char trans, + int n, + int k, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + float beta, + cuComplex* C, + int ldc); + +void CUBLASWINAPI cublasZher2k(char uplo, + char trans, + int n, + int k, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + double beta, + cuDoubleComplex* C, + int ldc); + +/*------------------------------------------------------------------------*/ +/* SYMM*/ +void CUBLASWINAPI cublasSsymm(char side, + char uplo, + int m, + int n, + float alpha, + const float* A, + int lda, + const float* B, + int ldb, + float beta, + float* C, + int ldc); +void CUBLASWINAPI cublasDsymm(char side, + char uplo, + int m, + int n, + double alpha, + const double* A, + int lda, + const double* B, + int ldb, + double beta, + double* C, + int ldc); + +void CUBLASWINAPI cublasCsymm(char side, + char uplo, + int m, + int n, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + cuComplex beta, + cuComplex* C, + int ldc); + +void CUBLASWINAPI cublasZsymm(char side, + char uplo, + int m, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + cuDoubleComplex beta, + cuDoubleComplex* C, + int ldc); +/*------------------------------------------------------------------------*/ +/* HEMM*/ +void CUBLASWINAPI cublasChemm(char side, + char uplo, + int m, + int n, + cuComplex alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + cuComplex beta, + cuComplex* C, + int ldc); +void CUBLASWINAPI cublasZhemm(char side, + char uplo, + int m, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + cuDoubleComplex beta, + cuDoubleComplex* C, + int ldc); + +/*------------------------------------------------------------------------*/ +/* TRSM*/ +void CUBLASWINAPI cublasStrsm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + float alpha, + const float* A, + int lda, + float* B, + int ldb); + +void CUBLASWINAPI cublasDtrsm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + double alpha, + const double* A, + int lda, + double* B, + int ldb); + +void CUBLASWINAPI cublasCtrsm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + cuComplex alpha, + const cuComplex* A, + int lda, + cuComplex* B, + int ldb); + +void CUBLASWINAPI cublasZtrsm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex* B, + int ldb); +/*------------------------------------------------------------------------*/ +/* TRMM*/ +void CUBLASWINAPI cublasStrmm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + float alpha, + const float* A, + int lda, + float* B, + int ldb); +void CUBLASWINAPI cublasDtrmm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + double alpha, + const double* A, + int lda, + double* B, + int ldb); +void CUBLASWINAPI cublasCtrmm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + cuComplex alpha, + const cuComplex* A, + int lda, + cuComplex* B, + int ldb); +void CUBLASWINAPI cublasZtrmm(char side, + char uplo, + char transa, + char diag, + int m, + int n, + cuDoubleComplex alpha, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex* B, + int ldb); + +#if defined(__cplusplus) +} +#endif /* __cplusplus */ + +#endif /* !defined(CUBLAS_H_) */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h new file mode 100644 index 0000000000000000000000000000000000000000..a7c9f346cadcad731d90e2f2c75f1549ff68240e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h @@ -0,0 +1,1845 @@ +/* + * Copyright 1993-2022 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ +#pragma once + +#ifndef CUBLASAPI +#ifdef __CUDACC__ +#define CUBLASAPI __host__ __device__ +#else +#define CUBLASAPI +#endif +#endif + +#include + +#include +#include +#include + +#if defined(__cplusplus) +extern "C" { +#endif /* __cplusplus */ + +/** Opaque structure holding CUBLASLT context + */ +typedef struct cublasLtContext* cublasLtHandle_t; + +cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t* lightHandle); + +cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle); + +const char* CUBLASWINAPI cublasLtGetStatusName(cublasStatus_t status); + +const char* CUBLASWINAPI cublasLtGetStatusString(cublasStatus_t status); + +size_t CUBLASWINAPI cublasLtGetVersion(void); + +size_t CUBLASWINAPI cublasLtGetCudartVersion(void); + +cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type, int* value); + +cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheGetCapacity(size_t* capacity); +cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheSetCapacity(size_t capacity); + +/** Restricts usage of CPU instructions (ISA) specified by the flags in the mask. + * + * Flags can be combined with bitwise OR(|) operator. Supported flags: + * - 0x1 -- x86-64 AVX512 ISA + * + * Default mask: 0 (any applicable ISA is allowed). + * + * The function returns the previous value of the mask. + * The function takes precedence over the environment variable CUBLASLT_DISABLE_CPU_INSTRUCTIONS_MASK. + */ +unsigned CUBLASWINAPI cublasLtDisableCpuInstructionsSetMask(unsigned mask); + +/** Semi-opaque descriptor for matrix memory layout + */ +typedef struct { + uint64_t data[8]; +} cublasLtMatrixLayoutOpaque_t; + +/** Opaque descriptor for matrix memory layout + */ +typedef cublasLtMatrixLayoutOpaque_t* cublasLtMatrixLayout_t; + +/** Semi-opaque algorithm descriptor (to avoid complicated alloc/free schemes) + * + * This structure can be trivially serialized and later restored for use with the same version of cuBLAS library to save + * on selecting the right configuration again. + */ +typedef struct { + uint64_t data[8]; +} cublasLtMatmulAlgo_t; + +/** Semi-opaque descriptor for cublasLtMatmul() operation details + */ +typedef struct { + uint64_t data[32]; +} cublasLtMatmulDescOpaque_t; + +/** Opaque descriptor for cublasLtMatmul() operation details + */ +typedef cublasLtMatmulDescOpaque_t* cublasLtMatmulDesc_t; + +/** Semi-opaque descriptor for cublasLtMatrixTransform() operation details + */ +typedef struct { + uint64_t data[8]; +} cublasLtMatrixTransformDescOpaque_t; + +/** Opaque descriptor for cublasLtMatrixTransform() operation details + */ +typedef cublasLtMatrixTransformDescOpaque_t* cublasLtMatrixTransformDesc_t; + +/** Semi-opaque descriptor for cublasLtMatmulPreference() operation details + */ +typedef struct { + uint64_t data[8]; +} cublasLtMatmulPreferenceOpaque_t; + +/** Opaque descriptor for cublasLtMatmulAlgoGetHeuristic() configuration + */ +typedef cublasLtMatmulPreferenceOpaque_t* cublasLtMatmulPreference_t; + +/** Tile size (in C/D matrix Rows x Cols) + * + * General order of tile IDs is sorted by size first and by first dimension second. + */ +typedef enum { + CUBLASLT_MATMUL_TILE_UNDEFINED = 0, + CUBLASLT_MATMUL_TILE_8x8 = 1, + CUBLASLT_MATMUL_TILE_8x16 = 2, + CUBLASLT_MATMUL_TILE_16x8 = 3, + CUBLASLT_MATMUL_TILE_8x32 = 4, + CUBLASLT_MATMUL_TILE_16x16 = 5, + CUBLASLT_MATMUL_TILE_32x8 = 6, + CUBLASLT_MATMUL_TILE_8x64 = 7, + CUBLASLT_MATMUL_TILE_16x32 = 8, + CUBLASLT_MATMUL_TILE_32x16 = 9, + CUBLASLT_MATMUL_TILE_64x8 = 10, + CUBLASLT_MATMUL_TILE_32x32 = 11, + CUBLASLT_MATMUL_TILE_32x64 = 12, + CUBLASLT_MATMUL_TILE_64x32 = 13, + CUBLASLT_MATMUL_TILE_32x128 = 14, + CUBLASLT_MATMUL_TILE_64x64 = 15, + CUBLASLT_MATMUL_TILE_128x32 = 16, + CUBLASLT_MATMUL_TILE_64x128 = 17, + CUBLASLT_MATMUL_TILE_128x64 = 18, + CUBLASLT_MATMUL_TILE_64x256 = 19, + CUBLASLT_MATMUL_TILE_128x128 = 20, + CUBLASLT_MATMUL_TILE_256x64 = 21, + CUBLASLT_MATMUL_TILE_64x512 = 22, + CUBLASLT_MATMUL_TILE_128x256 = 23, + CUBLASLT_MATMUL_TILE_256x128 = 24, + CUBLASLT_MATMUL_TILE_512x64 = 25, + CUBLASLT_MATMUL_TILE_64x96 = 26, + CUBLASLT_MATMUL_TILE_96x64 = 27, + CUBLASLT_MATMUL_TILE_96x128 = 28, + CUBLASLT_MATMUL_TILE_128x160 = 29, + CUBLASLT_MATMUL_TILE_160x128 = 30, + CUBLASLT_MATMUL_TILE_192x128 = 31, + CUBLASLT_MATMUL_TILE_128x192 = 32, + CUBLASLT_MATMUL_TILE_128x96 = 33, + CUBLASLT_MATMUL_TILE_32x256 = 34, + CUBLASLT_MATMUL_TILE_256x32 = 35, + CUBLASLT_MATMUL_TILE_END +} cublasLtMatmulTile_t; + +/** Size and number of stages in which elements are read into shared memory + * + * General order of stages IDs is sorted by stage size first and by number of stages second. + */ +typedef enum { + CUBLASLT_MATMUL_STAGES_UNDEFINED = 0, + CUBLASLT_MATMUL_STAGES_16x1 = 1, + CUBLASLT_MATMUL_STAGES_16x2 = 2, + CUBLASLT_MATMUL_STAGES_16x3 = 3, + CUBLASLT_MATMUL_STAGES_16x4 = 4, + CUBLASLT_MATMUL_STAGES_16x5 = 5, + CUBLASLT_MATMUL_STAGES_16x6 = 6, + CUBLASLT_MATMUL_STAGES_32x1 = 7, + CUBLASLT_MATMUL_STAGES_32x2 = 8, + CUBLASLT_MATMUL_STAGES_32x3 = 9, + CUBLASLT_MATMUL_STAGES_32x4 = 10, + CUBLASLT_MATMUL_STAGES_32x5 = 11, + CUBLASLT_MATMUL_STAGES_32x6 = 12, + CUBLASLT_MATMUL_STAGES_64x1 = 13, + CUBLASLT_MATMUL_STAGES_64x2 = 14, + CUBLASLT_MATMUL_STAGES_64x3 = 15, + CUBLASLT_MATMUL_STAGES_64x4 = 16, + CUBLASLT_MATMUL_STAGES_64x5 = 17, + CUBLASLT_MATMUL_STAGES_64x6 = 18, + CUBLASLT_MATMUL_STAGES_128x1 = 19, + CUBLASLT_MATMUL_STAGES_128x2 = 20, + CUBLASLT_MATMUL_STAGES_128x3 = 21, + CUBLASLT_MATMUL_STAGES_128x4 = 22, + CUBLASLT_MATMUL_STAGES_128x5 = 23, + CUBLASLT_MATMUL_STAGES_128x6 = 24, + CUBLASLT_MATMUL_STAGES_32x10 = 25, + CUBLASLT_MATMUL_STAGES_8x4 = 26, + CUBLASLT_MATMUL_STAGES_16x10 = 27, + CUBLASLT_MATMUL_STAGES_8x5 = 28, + CUBLASLT_MATMUL_STAGES_8x3 = 31, + CUBLASLT_MATMUL_STAGES_8xAUTO = 32, + CUBLASLT_MATMUL_STAGES_16xAUTO = 33, + CUBLASLT_MATMUL_STAGES_32xAUTO = 34, + CUBLASLT_MATMUL_STAGES_64xAUTO = 35, + CUBLASLT_MATMUL_STAGES_128xAUTO = 36, + CUBLASLT_MATMUL_STAGES_END +} cublasLtMatmulStages_t; + +/** Thread Block Cluster size + * + * Typically dimensioned similar to cublasLtMatmulTile_t, with the third coordinate unused at this time. + */ +typedef enum { + /** Let library pick cluster shape automatically */ + CUBLASLT_CLUSTER_SHAPE_AUTO = 0, + CUBLASLT_CLUSTER_SHAPE_1x1x1 = 2, + CUBLASLT_CLUSTER_SHAPE_2x1x1 = 3, + CUBLASLT_CLUSTER_SHAPE_4x1x1 = 4, + CUBLASLT_CLUSTER_SHAPE_1x2x1 = 5, + CUBLASLT_CLUSTER_SHAPE_2x2x1 = 6, + CUBLASLT_CLUSTER_SHAPE_4x2x1 = 7, + CUBLASLT_CLUSTER_SHAPE_1x4x1 = 8, + CUBLASLT_CLUSTER_SHAPE_2x4x1 = 9, + CUBLASLT_CLUSTER_SHAPE_4x4x1 = 10, + CUBLASLT_CLUSTER_SHAPE_8x1x1 = 11, + CUBLASLT_CLUSTER_SHAPE_1x8x1 = 12, + CUBLASLT_CLUSTER_SHAPE_8x2x1 = 13, + CUBLASLT_CLUSTER_SHAPE_2x8x1 = 14, + CUBLASLT_CLUSTER_SHAPE_16x1x1 = 15, + CUBLASLT_CLUSTER_SHAPE_1x16x1 = 16, + CUBLASLT_CLUSTER_SHAPE_3x1x1 = 17, + CUBLASLT_CLUSTER_SHAPE_5x1x1 = 18, + CUBLASLT_CLUSTER_SHAPE_6x1x1 = 19, + CUBLASLT_CLUSTER_SHAPE_7x1x1 = 20, + CUBLASLT_CLUSTER_SHAPE_9x1x1 = 21, + CUBLASLT_CLUSTER_SHAPE_10x1x1 = 22, + CUBLASLT_CLUSTER_SHAPE_11x1x1 = 23, + CUBLASLT_CLUSTER_SHAPE_12x1x1 = 24, + CUBLASLT_CLUSTER_SHAPE_13x1x1 = 25, + CUBLASLT_CLUSTER_SHAPE_14x1x1 = 26, + CUBLASLT_CLUSTER_SHAPE_15x1x1 = 27, + CUBLASLT_CLUSTER_SHAPE_3x2x1 = 28, + CUBLASLT_CLUSTER_SHAPE_5x2x1 = 29, + CUBLASLT_CLUSTER_SHAPE_6x2x1 = 30, + CUBLASLT_CLUSTER_SHAPE_7x2x1 = 31, + CUBLASLT_CLUSTER_SHAPE_1x3x1 = 32, + CUBLASLT_CLUSTER_SHAPE_2x3x1 = 33, + CUBLASLT_CLUSTER_SHAPE_3x3x1 = 34, + CUBLASLT_CLUSTER_SHAPE_4x3x1 = 35, + CUBLASLT_CLUSTER_SHAPE_5x3x1 = 36, + CUBLASLT_CLUSTER_SHAPE_3x4x1 = 37, + CUBLASLT_CLUSTER_SHAPE_1x5x1 = 38, + CUBLASLT_CLUSTER_SHAPE_2x5x1 = 39, + CUBLASLT_CLUSTER_SHAPE_3x5x1 = 40, + CUBLASLT_CLUSTER_SHAPE_1x6x1 = 41, + CUBLASLT_CLUSTER_SHAPE_2x6x1 = 42, + CUBLASLT_CLUSTER_SHAPE_1x7x1 = 43, + CUBLASLT_CLUSTER_SHAPE_2x7x1 = 44, + CUBLASLT_CLUSTER_SHAPE_1x9x1 = 45, + CUBLASLT_CLUSTER_SHAPE_1x10x1 = 46, + CUBLASLT_CLUSTER_SHAPE_1x11x1 = 47, + CUBLASLT_CLUSTER_SHAPE_1x12x1 = 48, + CUBLASLT_CLUSTER_SHAPE_1x13x1 = 49, + CUBLASLT_CLUSTER_SHAPE_1x14x1 = 50, + CUBLASLT_CLUSTER_SHAPE_1x15x1 = 51, + CUBLASLT_CLUSTER_SHAPE_END +} cublasLtClusterShape_t; + +/** Inner size of the kernel + * + * Represents various aspects of internal kernel design, that don't impact CUDA grid size but may have other more subtle + * effects. + * + */ +typedef enum { + CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED = 0, + CUBLASLT_MATMUL_INNER_SHAPE_MMA884 = 1, + CUBLASLT_MATMUL_INNER_SHAPE_MMA1684 = 2, + CUBLASLT_MATMUL_INNER_SHAPE_MMA1688 = 3, + CUBLASLT_MATMUL_INNER_SHAPE_MMA16816 = 4, + CUBLASLT_MATMUL_INNER_SHAPE_END +} cublasLtMatmulInnerShape_t; + +/** Pointer mode to use for alpha/beta */ +typedef enum { + /** matches CUBLAS_POINTER_MODE_HOST, pointer targets a single value host memory */ + CUBLASLT_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST, + /** matches CUBLAS_POINTER_MODE_DEVICE, pointer targets a single value device memory */ + CUBLASLT_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE, + /** pointer targets an array in device memory */ + CUBLASLT_POINTER_MODE_DEVICE_VECTOR = 2, + /** alpha pointer targets an array in device memory, beta is zero. Note: + CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is not supported, must be 0. */ + CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO = 3, + /** alpha pointer targets an array in device memory, beta is a single value in host memory. */ + CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST = 4, +} cublasLtPointerMode_t; + +/** Mask to define pointer mode capability */ +typedef enum { + /** see CUBLASLT_POINTER_MODE_HOST */ + CUBLASLT_POINTER_MODE_MASK_HOST = 1, + /** see CUBLASLT_POINTER_MODE_DEVICE */ + CUBLASLT_POINTER_MODE_MASK_DEVICE = 2, + /** see CUBLASLT_POINTER_MODE_DEVICE_VECTOR */ + CUBLASLT_POINTER_MODE_MASK_DEVICE_VECTOR = 4, + /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO */ + CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_ZERO = 8, + /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST */ + CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_HOST = 16, +} cublasLtPointerModeMask_t; + +/** Implementation details that may affect numerical behavior of algorithms. */ +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA (0x01ull << 0) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA (0x02ull << 0) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_IMMA (0x04ull << 0) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_DMMA (0x08ull << 0) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK (0xfeull << 0) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK (0xffull << 0) + +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_16F (0x01ull << 8) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32F (0x02ull << 8) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_64F (0x04ull << 8) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32I (0x08ull << 8) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_TYPE_MASK (0xffull << 8) + +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16F (0x01ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16BF (0x02ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_TF32 (0x04ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_32F (0x08ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_64F (0x10ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8I (0x20ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3 (0x40ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2 (0x80ull << 16) +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK (0xffull << 16) + +#define CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN (0x01ull << 32) +typedef uint64_t cublasLtNumericalImplFlags_t; + +/** Execute matrix multiplication (D = alpha * op(A) * op(B) + beta * C). + * + * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized + * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g. + * when workspaceSizeInBytes is less than workspace required by configured + * algo + * \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured + * operation + * \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device + * \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device + * \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmul(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t computeDesc, + const void* alpha, /* host or device pointer */ + const void* A, + cublasLtMatrixLayout_t Adesc, + const void* B, + cublasLtMatrixLayout_t Bdesc, + const void* beta, /* host or device pointer */ + const void* C, + cublasLtMatrixLayout_t Cdesc, + void* D, + cublasLtMatrixLayout_t Ddesc, + const cublasLtMatmulAlgo_t* algo, + void* workspace, + size_t workspaceSizeInBytes, + cudaStream_t stream); + +/** Matrix layout conversion helper (C = alpha * op(A) + beta * op(B)) + * + * Can be used to change memory order of data or to scale and shift the values. + * + * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized + * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g. + * when A is not NULL, but Adesc is NULL + * \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured + * operation + * \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device + * \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device + * \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(cublasLtHandle_t lightHandle, + cublasLtMatrixTransformDesc_t transformDesc, + const void* alpha, /* host or device pointer */ + const void* A, + cublasLtMatrixLayout_t Adesc, + const void* beta, /* host or device pointer */ + const void* B, + cublasLtMatrixLayout_t Bdesc, + void* C, + cublasLtMatrixLayout_t Cdesc, + cudaStream_t stream); + +/* ---------------------------------------------------------------------------------------*/ +/* Helper functions for cublasLtMatrixLayout_t */ +/* ---------------------------------------------------------------------------------------*/ + +/** Enum for data ordering */ +typedef enum { + /** Column-major + * + * Leading dimension is the stride (in elements) to the beginning of next column in memory. + */ + CUBLASLT_ORDER_COL = 0, + /** Row major + * + * Leading dimension is the stride (in elements) to the beginning of next row in memory. + */ + CUBLASLT_ORDER_ROW = 1, + /** Column-major ordered tiles of 32 columns. + * + * Leading dimension is the stride (in elements) to the beginning of next group of 32-columns. E.g. if matrix has 33 + * columns and 2 rows, ld must be at least (32) * 2 = 64. + */ + CUBLASLT_ORDER_COL32 = 2, + /** Column-major ordered tiles of composite tiles with total 32 columns and 8 rows, tile composed of interleaved + * inner tiles of 4 columns within 4 even or odd rows in an alternating pattern. + * + * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile for the next + * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32 * 8) * 1 = 256. + */ + CUBLASLT_ORDER_COL4_4R2_8C = 3, + /** Column-major ordered tiles of composite tiles with total 32 columns ands 32 rows. + * Element offset within the tile is calculated as (((row%8)/2*4+row/8)*2+row%2)*32+col. + * + * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 32 row tile for the next + * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32*32)*1 = 1024. + */ + CUBLASLT_ORDER_COL32_2R_4R4 = 4, + +} cublasLtOrder_t; + +/** Attributes of memory layout */ +typedef enum { + /** Data type, see cudaDataType. + * + * uint32_t + */ + CUBLASLT_MATRIX_LAYOUT_TYPE = 0, + + /** Memory order of the data, see cublasLtOrder_t. + * + * int32_t, default: CUBLASLT_ORDER_COL + */ + CUBLASLT_MATRIX_LAYOUT_ORDER = 1, + + /** Number of rows. + * + * Usually only values that can be expressed as int32_t are supported. + * + * uint64_t + */ + CUBLASLT_MATRIX_LAYOUT_ROWS = 2, + + /** Number of columns. + * + * Usually only values that can be expressed as int32_t are supported. + * + * uint64_t + */ + CUBLASLT_MATRIX_LAYOUT_COLS = 3, + + /** Matrix leading dimension. + * + * For CUBLASLT_ORDER_COL this is stride (in elements) of matrix column, for more details and documentation for + * other memory orders see documentation for cublasLtOrder_t values. + * + * Currently only non-negative values are supported, must be large enough so that matrix memory locations are not + * overlapping (e.g. greater or equal to CUBLASLT_MATRIX_LAYOUT_ROWS in case of CUBLASLT_ORDER_COL). + * + * int64_t; + */ + CUBLASLT_MATRIX_LAYOUT_LD = 4, + + /** Number of matmul operations to perform in the batch. + * + * See also CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT + * + * int32_t, default: 1 + */ + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5, + + /** Stride (in elements) to the next matrix for strided batch operation. + * + * When matrix type is planar-complex (CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET != 0), batch stride + * is interpreted by cublasLtMatmul() in number of real valued sub-elements. E.g. for data of type CUDA_C_16F, + * offset of 1024B is encoded as a stride of value 512 (since each element of the real and imaginary matrices + * is a 2B (16bit) floating point type). + * + * NOTE: A bug in cublasLtMatrixTransform() causes it to interpret the batch stride for a planar-complex matrix + * as if it was specified in number of complex elements. Therefore an offset of 1024B must be encoded as stride + * value 256 when calling cublasLtMatrixTransform() (each complex element is 4B with real and imaginary values 2B + * each). This behavior is expected to be corrected in the next major cuBLAS version. + * + * int64_t, default: 0 + */ + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6, + + /** Stride (in bytes) to the imaginary plane for planar complex layout. + * + * int64_t, default: 0 - 0 means that layout is regular (real and imaginary parts of complex numbers are interleaved + * in memory in each element) + */ + CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7, +} cublasLtMatrixLayoutAttribute_t; + +/** Internal. Do not use directly. + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal( // + cublasLtMatrixLayout_t matLayout, + size_t size, + cudaDataType type, + uint64_t rows, + uint64_t cols, + int64_t ld); + +/** Initialize matrix layout descriptor in pre-allocated space. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient + * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully + */ +static inline cublasStatus_t cublasLtMatrixLayoutInit( + cublasLtMatrixLayout_t matLayout, cudaDataType type, uint64_t rows, uint64_t cols, int64_t ld) { + return cublasLtMatrixLayoutInit_internal(matLayout, sizeof(*matLayout), type, rows, cols, ld); +} + +/** Create new matrix layout descriptor. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated + * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate( // + cublasLtMatrixLayout_t* matLayout, + cudaDataType type, + uint64_t rows, + uint64_t cols, + int64_t ld); + +/** Destroy matrix layout descriptor. + * + * \retval CUBLAS_STATUS_SUCCESS if operation was successful + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout); + +/** Set matrix layout descriptor attribute. + * + * \param[in] matLayout The descriptor + * \param[in] attr The attribute + * \param[in] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * + * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute( // + cublasLtMatrixLayout_t matLayout, + cublasLtMatrixLayoutAttribute_t attr, + const void* buf, + size_t sizeInBytes); + +/** Get matrix layout descriptor attribute. + * + * \param[in] matLayout The descriptor + * \param[in] attr The attribute + * \param[out] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of + * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents + * + * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero + * and buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute( // + cublasLtMatrixLayout_t matLayout, + cublasLtMatrixLayoutAttribute_t attr, + void* buf, + size_t sizeInBytes, + size_t* sizeWritten); + +/* ---------------------------------------------------------------------------------------*/ +/* Helper functions for cublasLtMatmulDesc_t */ +/* ---------------------------------------------------------------------------------------*/ + +/** Matmul descriptor attributes to define details of the operation. */ +typedef enum { + /** Compute type, see cudaDataType. Defines data type used for multiply and accumulate operations and the + * accumulator during matrix multiplication. + * + * int32_t + */ + CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0, + + /** Scale type, see cudaDataType. Defines data type of alpha and beta. Accumulator and value from matrix C are + * typically converted to scale type before final scaling. Value is then converted from scale type to type of matrix + * D before being stored in memory. + * + * int32_t, default: same as CUBLASLT_MATMUL_DESC_COMPUTE_TYPE + */ + CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1, + + /** Pointer mode of alpha and beta, see cublasLtPointerMode_t. When CUBLASLT_POINTER_MODE_DEVICE_VECTOR is in use, + * alpha/beta vector lenghts must match number of output matrix rows. + * + * int32_t, default: CUBLASLT_POINTER_MODE_HOST + */ + CUBLASLT_MATMUL_DESC_POINTER_MODE = 2, + + /** Transform of matrix A, see cublasOperation_t. + * + * int32_t, default: CUBLAS_OP_N + */ + CUBLASLT_MATMUL_DESC_TRANSA = 3, + + /** Transform of matrix B, see cublasOperation_t. + * + * int32_t, default: CUBLAS_OP_N + */ + CUBLASLT_MATMUL_DESC_TRANSB = 4, + + /** Transform of matrix C, see cublasOperation_t. + * + * Currently only CUBLAS_OP_N is supported. + * + * int32_t, default: CUBLAS_OP_N + */ + CUBLASLT_MATMUL_DESC_TRANSC = 5, + + /** Matrix fill mode, see cublasFillMode_t. + * + * int32_t, default: CUBLAS_FILL_MODE_FULL + */ + CUBLASLT_MATMUL_DESC_FILL_MODE = 6, + + /** Epilogue function, see cublasLtEpilogue_t. + * + * uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT + */ + CUBLASLT_MATMUL_DESC_EPILOGUE = 7, + + /** Bias or bias gradient vector pointer in the device memory. + * + * Bias case. See CUBLASLT_EPILOGUE_BIAS. + * For bias data type see CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE. + * + * Bias vector length must match matrix D rows count. + * + * Bias gradient case. See CUBLASLT_EPILOGUE_DRELU_BGRAD and CUBLASLT_EPILOGUE_DGELU_BGRAD. + * Bias gradient vector elements are the same type as the output elements + * (Ctype) with the exception of IMMA kernels (see above). + * + * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic() + * depend on its value to determine expected pointer alignment. + * + * Bias case: const void *, default: NULL + * Bias gradient case: void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8, + + /** Batch stride for bias or bias gradient vector. + * + * Used together with CUBLASLT_MATMUL_DESC_BIAS_POINTER when matrix D's CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1. + * + * int64_t, default: 0 + */ + CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10, + + /** Pointer for epilogue auxiliary buffer. + * + * - Output vector for ReLu bit-mask in forward pass when CUBLASLT_EPILOGUE_RELU_AUX + * or CUBLASLT_EPILOGUE_RELU_AUX_BIAS epilogue is used. + * - Input vector for ReLu bit-mask in backward pass when + * CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is used. + * + * - Output of GELU input matrix in forward pass when + * CUBLASLT_EPILOGUE_GELU_AUX_BIAS epilogue is used. + * - Input of GELU input matrix for backward pass when + * CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue is used. + * + * For aux data type see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE. + * + * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic() + * depend on its value to determine expected pointer alignment. + * + * Requires setting CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD attribute. + * + * Forward pass: void *, default: NULL + * Backward pass: const void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11, + + /** Leading dimension for epilogue auxiliary buffer. + * + * - ReLu bit-mask matrix leading dimension in elements (i.e. bits) + * when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is + * used. Must be divisible by 128 and be no less than the number of rows in the output matrix. + * + * - GELU input matrix leading dimension in elements + * when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used. + * Must be divisible by 8 and be no less than the number of rows in the output matrix. + * + * int64_t, default: 0 + */ + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12, + + /** Batch stride for epilogue auxiliary buffer. + * + * - ReLu bit-mask matrix batch stride in elements (i.e. bits) + * when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is + * used. Must be divisible by 128. + * + * - GELU input matrix batch stride in elements + * when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used. + * Must be divisible by 8. + * + * int64_t, default: 0 + */ + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13, + + /** Batch stride for alpha vector. + * + * Used together with CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST when matrix D's + * CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1. If CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO is set then + * CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE must be set to 0 as this mode doesnt supported batched alpha vector. + * + * int64_t, default: 0 + */ + CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14, + + /** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs + * when user expects a concurrent stream to be using some of the device resources. + * + * int32_t, default: 0 - use the number reported by the device. + */ + CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15, + + /** Device pointer to the scale factor value that converts data in matrix A to the compute data type range. + * + * The scaling factor value must have the same type as the compute type. + * + * If not specified, or set to NULL, the scaling factor is assumed to be 1. + * + * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul() + * will return CUBLAS_INVALID_VALUE. + * + * const void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17, + + /** Device pointer to the scale factor value to convert data in matrix B to compute data type range. + * + * The scaling factor value must have the same type as the compute type. + * + * If not specified, or set to NULL, the scaling factor is assumed to be 1. + * + * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul() + * will return CUBLAS_INVALID_VALUE. + * + * const void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18, + + /** Device pointer to the scale factor value to convert data in matrix C to compute data type range. + * + * The scaling factor value must have the same type as the compute type. + * + * If not specified, or set to NULL, the scaling factor is assumed to be 1. + * + * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul() + * will return CUBLAS_INVALID_VALUE. + * + * const void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19, + + /** Device pointer to the scale factor value to convert data in matrix D to compute data type range. + * + * The scaling factor value must have the same type as the compute type. + * + * If not specified, or set to NULL, the scaling factor is assumed to be 1. + * + * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul() + * will return CUBLAS_INVALID_VALUE. + * + * const void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20, + + /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the + * output matrix. + * + * The computed value has the same type as the compute type. + * + * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix + * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE. + * + * void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21, + + /** Type of the data to be stored to the memory pointed to by CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + * + * If unset, the data type defaults to the type of elements of the output matrix with some exceptions, see details + * below. + * + * ReLu uses a bit-mask. + * + * GELU input matrix elements type is the same as the type of elements of + * the output matrix with some exceptions, see details below. + * + * For fp8 kernels with output type CUDA_R_8F_E4M3 the aux data type can be CUDA_R_8F_E4M3 or CUDA_R_16F with some + * restrictions. See https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t for more details. + * + * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul() + * will return CUBLAS_INVALID_VALUE. + * + * int32_t based on cudaDataType, default: -1 + */ + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22, + + /** Device pointer to the scaling factor value to convert results from compute type data range to storage + * data range in the auxiliary matrix that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + * + * The scaling factor value must have the same type as the compute type. + * + * If not specified, or set to NULL, the scaling factor is assumed to be 1. If set for an unsupported matrix data, + * scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE. + * + * void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23, + + /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the + * buffer that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + * + * The computed value has the same type as the compute type. + * + * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix + * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE. + * + * void *, default: NULL + */ + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24, + + /** Flag for managing fp8 fast accumulation mode. + * When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results + * will not periodically be promoted to a higher precision. + * + * int8_t, default: 0 - fast accumulation mode is disabled. + */ + CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25, + + /** Type of bias or bias gradient vector in the device memory. + * + * Bias case: see CUBLASLT_EPILOGUE_BIAS. + * + * Bias vector elements are the same type as the elements of output matrix (Dtype) with the following exceptions: + * - IMMA kernels with computeType=CUDA_R_32I and Ctype=CUDA_R_8I where the bias vector elements + * are the same type as alpha, beta (CUBLASLT_MATMUL_DESC_SCALE_TYPE=CUDA_R_32F) + * - fp8 kernels with an output type of CUDA_R_32F, CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2, See + * https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul for details. + * + * int32_t based on cudaDataType, default: -1 + */ + CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26, + + /** EXPERIMENTAL: Number of atomic synchronization chunks in the row dimension of the output matrix D. + * + * int32_t, default 0 (atomic synchronization disabled) + */ + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS = 27, + + /** EXPERIMENTAL: Number of atomic synchronization chunks in the column dimension of the output matrix D. + * + * int32_t, default 0 (atomic synchronization disabled) + */ + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS = 28, + + /** EXPERIMENTAL: Pointer to a device array of input atomic counters consumed by a matmul. + * + * int32_t *, default: NULL + * */ + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER = 29, + + /** EXPERIMENTAL: Pointer to a device array of output atomic counters produced by a matmul. + * + * int32_t *, default: NULL + * */ + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER = 30, +} cublasLtMatmulDescAttributes_t; + +/** Internal. Do not use directly. + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( // + cublasLtMatmulDesc_t matmulDesc, + size_t size, + cublasComputeType_t computeType, + cudaDataType_t scaleType); + +/** Initialize matmul operation descriptor in pre-allocated space. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient + * \retval CUBLAS_STATUS_SUCCESS if desciptor was initialized successfully + */ +static inline cublasStatus_t cublasLtMatmulDescInit( // + cublasLtMatmulDesc_t matmulDesc, + cublasComputeType_t computeType, + cudaDataType_t scaleType) { + return cublasLtMatmulDescInit_internal(matmulDesc, sizeof(*matmulDesc), computeType, scaleType); +} + +/** Create new matmul operation descriptor. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated + * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(cublasLtMatmulDesc_t* matmulDesc, + cublasComputeType_t computeType, + cudaDataType_t scaleType); + +/** Destroy matmul operation descriptor. + * + * \retval CUBLAS_STATUS_SUCCESS if operation was successful + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc); + +/** Set matmul operation descriptor attribute. + * + * \param[in] matmulDesc The descriptor + * \param[in] attr The attribute + * \param[in] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * + * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute( // + cublasLtMatmulDesc_t matmulDesc, + cublasLtMatmulDescAttributes_t attr, + const void* buf, + size_t sizeInBytes); + +/** Get matmul operation descriptor attribute. + * + * \param[in] matmulDesc The descriptor + * \param[in] attr The attribute + * \param[out] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of + * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents + * + * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero + * and buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute( // + cublasLtMatmulDesc_t matmulDesc, + cublasLtMatmulDescAttributes_t attr, + void* buf, + size_t sizeInBytes, + size_t* sizeWritten); + +/* ---------------------------------------------------------------------------------------*/ +/* Helper functions for cublasLtMatrixTransformDesc_t */ +/* ---------------------------------------------------------------------------------------*/ + +/** Matrix transform descriptor attributes to define details of the operation. + */ +typedef enum { + /** Scale type, see cudaDataType. Inputs are converted to scale type for scaling and summation and results are then + * converted to output type to store in memory. + * + * int32_t + */ + CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE, + + /** Pointer mode of alpha and beta, see cublasLtPointerMode_t. + * + * int32_t, default: CUBLASLT_POINTER_MODE_HOST + */ + CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, + + /** Transform of matrix A, see cublasOperation_t. + * + * int32_t, default: CUBLAS_OP_N + */ + CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, + + /** Transform of matrix B, see cublasOperation_t. + * + * int32_t, default: CUBLAS_OP_N + */ + CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB, +} cublasLtMatrixTransformDescAttributes_t; + +/** Internal. Do not use directly. + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc, + size_t size, + cudaDataType scaleType); + +/** Initialize matrix transform operation descriptor in pre-allocated space. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient + * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully + */ +static inline cublasStatus_t cublasLtMatrixTransformDescInit(cublasLtMatrixTransformDesc_t transformDesc, + cudaDataType scaleType) { + return cublasLtMatrixTransformDescInit_internal(transformDesc, sizeof(*transformDesc), scaleType); +} + +/** Create new matrix transform operation descriptor. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated + * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t* transformDesc, + cudaDataType scaleType); + +/** Destroy matrix transform operation descriptor. + * + * \retval CUBLAS_STATUS_SUCCESS if operation was successful + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc); + +/** Set matrix transform operation descriptor attribute. + * + * \param[in] transformDesc The descriptor + * \param[in] attr The attribute + * \param[in] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * + * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute( // + cublasLtMatrixTransformDesc_t transformDesc, + cublasLtMatrixTransformDescAttributes_t attr, + const void* buf, + size_t sizeInBytes); + +/** Get matrix transform operation descriptor attribute. + * + * \param[in] transformDesc The descriptor + * \param[in] attr The attribute + * \param[out] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number + * of bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents + * + * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero + * and buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory + */ +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute( // + cublasLtMatrixTransformDesc_t transformDesc, + cublasLtMatrixTransformDescAttributes_t attr, + void* buf, + size_t sizeInBytes, + size_t* sizeWritten); + +/** Reduction scheme for portions of the dot-product calculated in parallel (a. k. a. "split - K"). + */ +typedef enum { + /** No reduction scheme, dot-product shall be performed in one sequence. + */ + CUBLASLT_REDUCTION_SCHEME_NONE = 0, + + /** Reduction is performed "in place" - using the output buffer (and output data type) and counters (in workspace) to + * guarantee the sequentiality. + */ + CUBLASLT_REDUCTION_SCHEME_INPLACE = 1, + + /** Intermediate results are stored in compute type in the workspace and reduced in a separate step. + */ + CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE = 2, + + /** Intermediate results are stored in output type in the workspace and reduced in a separate step. + */ + CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE = 4, + + CUBLASLT_REDUCTION_SCHEME_MASK = 0x7, +} cublasLtReductionScheme_t; + +/** Postprocessing options for the epilogue + */ +typedef enum { + /** No special postprocessing, just scale and quantize results if necessary. + */ + CUBLASLT_EPILOGUE_DEFAULT = 1, + + /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)). + */ + CUBLASLT_EPILOGUE_RELU = 2, + + /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)). + * + * This epilogue mode produces an extra output, a ReLu bit-mask matrix, + * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_RELU_AUX = (CUBLASLT_EPILOGUE_RELU | 128), + + /** Bias, apply (broadcasted) Bias from bias vector. Bias vector length must match matrix D rows, it must be packed + * (stride between vector elements is 1). Bias vector is broadcasted to all columns and added before applying final + * postprocessing. + */ + CUBLASLT_EPILOGUE_BIAS = 4, + + /** ReLu and Bias, apply Bias and then ReLu transform + */ + CUBLASLT_EPILOGUE_RELU_BIAS = (CUBLASLT_EPILOGUE_RELU | CUBLASLT_EPILOGUE_BIAS), + + /** ReLu and Bias, apply Bias and then ReLu transform + * + * This epilogue mode produces an extra output, a ReLu bit-mask matrix, + * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_RELU_AUX_BIAS = (CUBLASLT_EPILOGUE_RELU_AUX | CUBLASLT_EPILOGUE_BIAS), + + /* ReLu gradient. Apply ReLu gradient to matmul output. Store ReLu gradient in the output matrix. + * + * This epilogue mode requires an extra input, + * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_DRELU = 8 | 128, + + /* ReLu and Bias gradients. Apply independently ReLu and Bias gradient to + * matmul output. Store ReLu gradient in the output matrix, and Bias gradient + * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER). + * + * This epilogue mode requires an extra input, + * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_DRELU_BGRAD = CUBLASLT_EPILOGUE_DRELU | 16, + + /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)). + */ + CUBLASLT_EPILOGUE_GELU = 32, + + /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)). + * + * This epilogue mode outputs GELU input as a separate matrix (useful for training). + * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_GELU_AUX = (CUBLASLT_EPILOGUE_GELU | 128), + + /** GELU and Bias, apply Bias and then GELU transform + */ + CUBLASLT_EPILOGUE_GELU_BIAS = (CUBLASLT_EPILOGUE_GELU | CUBLASLT_EPILOGUE_BIAS), + + /** GELU and Bias, apply Bias and then GELU transform + * + * This epilogue mode outputs GELU input as a separate matrix (useful for training). + * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_GELU_AUX_BIAS = (CUBLASLT_EPILOGUE_GELU_AUX | CUBLASLT_EPILOGUE_BIAS), + + /* GELU gradient. Apply GELU gradient to matmul output. Store GELU gradient in the output matrix. + * + * This epilogue mode requires an extra input, + * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_DGELU = 64 | 128, + + /* GELU and Bias gradients. Apply independently GELU and Bias gradient to + * matmul output. Store GELU gradient in the output matrix, and Bias gradient + * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER). + * + * This epilogue mode requires an extra input, + * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER. + */ + CUBLASLT_EPILOGUE_DGELU_BGRAD = CUBLASLT_EPILOGUE_DGELU | 16, + + /** Bias gradient based on the input matrix A. + * + * The bias size corresponds to the number of rows of the matrix D. + * The reduction happens over the GEMM's "k" dimension. + * + * Stores Bias gradient in the auxiliary output + * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER). + */ + CUBLASLT_EPILOGUE_BGRADA = 256, + + /** Bias gradient based on the input matrix B. + * + * The bias size corresponds to the number of columns of the matrix D. + * The reduction happens over the GEMM's "k" dimension. + * + * Stores Bias gradient in the auxiliary output + * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER). + */ + CUBLASLT_EPILOGUE_BGRADB = 512, +} cublasLtEpilogue_t; + +/** Matmul heuristic search mode + */ +typedef enum { + /** ask heuristics for best algo for given usecase + */ + CUBLASLT_SEARCH_BEST_FIT = 0, + /** only try to find best config for preconfigured algo id + */ + CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID = 1, + /** reserved for future use + */ + CUBLASLT_SEARCH_RESERVED_02 = 2, + /** reserved for future use + */ + CUBLASLT_SEARCH_RESERVED_03 = 3, + /** reserved for future use + */ + CUBLASLT_SEARCH_RESERVED_04 = 4, + /** reserved for future use + */ + CUBLASLT_SEARCH_RESERVED_05 = 5, +} cublasLtMatmulSearch_t; + +/** Algo search preference to fine tune the heuristic function. */ +typedef enum { + /** Search mode, see cublasLtMatmulSearch_t. + * + * uint32_t, default: CUBLASLT_SEARCH_BEST_FIT + */ + CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0, + + /** Maximum allowed workspace size in bytes. + * + * uint64_t, default: 0 - no workspace allowed + */ + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1, + + /** Reduction scheme mask, see cublasLtReductionScheme_t. Filters heuristic result to only include algo configs that + * use one of the required modes. + * + * E.g. mask value of 0x03 will allow only INPLACE and COMPUTE_TYPE reduction schemes. + * + * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_MASK (allows all reduction schemes) + */ + CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3, + + /** Minimum buffer alignment for matrix A (in bytes). + * + * Selecting a smaller value will exclude algorithms that can not work with matrix A that is not as strictly aligned + * as they need. + * + * uint32_t, default: 256 + */ + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5, + + /** Minimum buffer alignment for matrix B (in bytes). + * + * Selecting a smaller value will exclude algorithms that can not work with matrix B that is not as strictly aligned + * as they need. + * + * uint32_t, default: 256 + */ + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6, + + /** Minimum buffer alignment for matrix C (in bytes). + * + * Selecting a smaller value will exclude algorithms that can not work with matrix C that is not as strictly aligned + * as they need. + * + * uint32_t, default: 256 + */ + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7, + + /** Minimum buffer alignment for matrix D (in bytes). + * + * Selecting a smaller value will exclude algorithms that can not work with matrix D that is not as strictly aligned + * as they need. + * + * uint32_t, default: 256 + */ + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8, + + /** Maximum wave count. + * + * See cublasLtMatmulHeuristicResult_t::wavesCount. + * + * Selecting a non-zero value will exclude algorithms that report device utilization higher than specified. + * + * float, default: 0.0f + */ + CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9, + + /** Numerical implementation details mask, see cublasLtNumericalImplFlags_t. Filters heuristic result to only include + * algorithms that use the allowed implementations. + * + * uint64_t, default: uint64_t(-1) (allow everything) + */ + CUBLASLT_MATMUL_PREF_IMPL_MASK = 12, +} cublasLtMatmulPreferenceAttributes_t; + +/** Internal. Do not use directly. + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size); + +/** Initialize matmul heuristic search preference descriptor in pre-allocated space. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient + * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully + */ +static inline cublasStatus_t cublasLtMatmulPreferenceInit(cublasLtMatmulPreference_t pref) { + return cublasLtMatmulPreferenceInit_internal(pref, sizeof(*pref)); +} + +/** Create new matmul heuristic search preference descriptor. + * + * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated + * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t* pref); + +/** Destroy matmul heuristic search preference descriptor. + * + * \retval CUBLAS_STATUS_SUCCESS if operation was successful + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref); + +/** Set matmul heuristic search preference descriptor attribute. + * + * \param[in] pref The descriptor + * \param[in] attr The attribute + * \param[in] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * + * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( // + cublasLtMatmulPreference_t pref, + cublasLtMatmulPreferenceAttributes_t attr, + const void* buf, + size_t sizeInBytes); + +/** Get matmul heuristic search preference descriptor attribute. + * + * \param[in] pref The descriptor + * \param[in] attr The attribute + * \param[out] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of + * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents + * + * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero + * and buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( // + cublasLtMatmulPreference_t pref, + cublasLtMatmulPreferenceAttributes_t attr, + void* buf, + size_t sizeInBytes, + size_t* sizeWritten); + +/** Results structure used by cublasLtMatmulGetAlgo. + * + * Holds returned configured algo descriptor and its runtime properties. + */ +typedef struct { + /** Matmul algorithm descriptor. + * + * Must be initialized with cublasLtMatmulAlgoInit() if preferences' CUBLASLT_MATMUL_PERF_SEARCH_MODE is set to + * CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID + */ + cublasLtMatmulAlgo_t algo; + + /** Actual size of workspace memory required. + */ + size_t workspaceSize; + + /** Result status, other fields are only valid if after call to cublasLtMatmulAlgoGetHeuristic() this member is set to + * CUBLAS_STATUS_SUCCESS. + */ + cublasStatus_t state; + + /** Waves count - a device utilization metric. + * + * wavesCount value of 1.0f suggests that when kernel is launched it will fully occupy the GPU. + */ + float wavesCount; + + int reserved[4]; +} cublasLtMatmulHeuristicResult_t; + +/** Query cublasLt heuristic for algorithm appropriate for given use case. + * + * \param[in] lightHandle Pointer to the allocated cuBLASLt handle for the cuBLASLt + * context. See cublasLtHandle_t. + * \param[in] operationDesc Handle to the matrix multiplication descriptor. + * \param[in] Adesc Handle to the layout descriptors for matrix A. + * \param[in] Bdesc Handle to the layout descriptors for matrix B. + * \param[in] Cdesc Handle to the layout descriptors for matrix C. + * \param[in] Ddesc Handle to the layout descriptors for matrix D. + * \param[in] preference Pointer to the structure holding the heuristic search + * preferences descriptor. See cublasLtMatrixLayout_t. + * \param[in] requestedAlgoCount Size of heuristicResultsArray (in elements) and requested + * maximum number of algorithms to return. + * \param[in, out] heuristicResultsArray Output algorithms and associated runtime characteristics, + * ordered in increasing estimated compute time. + * \param[out] returnAlgoCount The number of heuristicResultsArray elements written. + * + * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero + * \retval CUBLAS_STATUS_NOT_SUPPORTED if no heuristic function available for current configuration + * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect + * heuristicResultsArray[0 to (returnAlgoCount - 1)].state + * for detail status of results + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t operationDesc, + cublasLtMatrixLayout_t Adesc, + cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, + cublasLtMatrixLayout_t Ddesc, + cublasLtMatmulPreference_t preference, + int requestedAlgoCount, + cublasLtMatmulHeuristicResult_t heuristicResultsArray[], + int* returnAlgoCount); + +/* ---------------------------------------------------------------------------------------*/ +/* Lower level API to be able to implement own Heuristic and Find routines */ +/* ---------------------------------------------------------------------------------------*/ + +/** Routine to get all algo IDs that can potentially run + * + * \param[in] int requestedAlgoCount requested number of algos (must be less or equal to size of algoIdsA + * (in elements)) \param[out] algoIdsA array to write algoIds to \param[out] returnAlgoCount number of algoIds + * actually written + * + * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero + * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect returnAlgoCount to get actual number of IDs + * available + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(cublasLtHandle_t lightHandle, + cublasComputeType_t computeType, + cudaDataType_t scaleType, + cudaDataType_t Atype, + cudaDataType_t Btype, + cudaDataType_t Ctype, + cudaDataType_t Dtype, + int requestedAlgoCount, + int algoIdsArray[], + int* returnAlgoCount); + +/** Initialize algo structure + * + * \retval CUBLAS_STATUS_INVALID_VALUE if algo is NULL or algoId is outside of recognized range + * \retval CUBLAS_STATUS_NOT_SUPPORTED if algoId is not supported for given combination of data types + * \retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(cublasLtHandle_t lightHandle, + cublasComputeType_t computeType, + cudaDataType_t scaleType, + cudaDataType_t Atype, + cudaDataType_t Btype, + cudaDataType_t Ctype, + cudaDataType_t Dtype, + int algoId, + cublasLtMatmulAlgo_t* algo); + +/** Check configured algo descriptor for correctness and support on current device. + * + * Result includes required workspace size and calculated wave count. + * + * CUBLAS_STATUS_SUCCESS doesn't fully guarantee algo will run (will fail if e.g. buffers are not correctly aligned); + * but if cublasLtMatmulAlgoCheck fails, the algo will not run. + * + * \param[in] algo algo configuration to check + * \param[out] result result structure to report algo runtime characteristics; algo field is never updated + * + * \retval CUBLAS_STATUS_INVALID_VALUE if matrix layout descriptors or operation descriptor don't match algo + * descriptor + * \retval CUBLAS_STATUS_NOT_SUPPORTED if algo configuration or data type combination is not currently supported on + * given device + * \retval CUBLAS_STATUS_ARCH_MISMATCH if algo configuration cannot be run using the selected device + * \retval CUBLAS_STATUS_SUCCESS if check was successful + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( // + cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t operationDesc, + cublasLtMatrixLayout_t Adesc, + cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, + cublasLtMatrixLayout_t Ddesc, + const cublasLtMatmulAlgo_t* algo, ///< may point to result->algo + cublasLtMatmulHeuristicResult_t* result); + +/** Capabilities Attributes that can be retrieved from an initialized Algo structure + */ +typedef enum { + /** support for split K, see CUBLASLT_ALGO_CONFIG_SPLITK_NUM + * + * int32_t, 0 means no support, supported otherwise + */ + CUBLASLT_ALGO_CAP_SPLITK_SUPPORT = 0, + + /** reduction scheme mask, see cublasLtReductionScheme_t; shows supported reduction schemes, if reduction scheme is + * not masked out it is supported. + * + * e.g. int isReductionSchemeComputeTypeSupported ? (reductionSchemeMask & CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE) == + * CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE ? 1 : 0; + * + * uint32_t + */ + CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK = 1, + + /** support for cta swizzling, see CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING + * + * uint32_t, 0 means no support, 1 means supported value of 1, other values are reserved + */ + CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT = 2, + + /** support strided batch + * + * int32_t, 0 means no support, supported otherwise + */ + CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT = 3, + + /** support results out of place (D != C in D = alpha.A.B + beta.C) + * + * int32_t, 0 means no support, supported otherwise + */ + CUBLASLT_ALGO_CAP_OUT_OF_PLACE_RESULT_SUPPORT = 4, + + /** syrk/herk support (on top of regular gemm) + * + * int32_t, 0 means no support, supported otherwise + */ + CUBLASLT_ALGO_CAP_UPLO_SUPPORT = 5, + + /** tile ids possible to use, see cublasLtMatmulTile_t; if no tile ids are supported use + * CUBLASLT_MATMUL_TILE_UNDEFINED + * + * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count + * + * array of uint32_t + */ + CUBLASLT_ALGO_CAP_TILE_IDS = 6, + + /** custom option range is from 0 to CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX (inclusive), see + * CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION + * + * int32_t + */ + CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX = 7, + + /** whether algorithm supports custom (not COL or ROW memory order), see cublasLtOrder_t + * + * int32_t 0 means only COL and ROW memory order is allowed, non-zero means that algo might have different + * requirements; + */ + CUBLASLT_ALGO_CAP_CUSTOM_MEMORY_ORDER = 10, + + /** bitmask enumerating pointer modes algorithm supports + * + * uint32_t, see cublasLtPointerModeMask_t + */ + CUBLASLT_ALGO_CAP_POINTER_MODE_MASK = 11, + + /** bitmask enumerating kinds of postprocessing algorithm supports in the epilogue + * + * uint32_t, see cublasLtEpilogue_t + */ + CUBLASLT_ALGO_CAP_EPILOGUE_MASK = 12, + + /** stages ids possible to use, see cublasLtMatmulStages_t; if no stages ids are supported use + * CUBLASLT_MATMUL_STAGES_UNDEFINED + * + * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count + * + * array of uint32_t + */ + CUBLASLT_ALGO_CAP_STAGES_IDS = 13, + + /** support for nagative ld for all of the matrices + * + * int32_t 0 means no support, supported otherwise + */ + CUBLASLT_ALGO_CAP_LD_NEGATIVE = 14, + + /** details about algorithm's implementation that affect it's numerical behavior + * + * uint64_t, see cublasLtNumericalImplFlags_t + */ + CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS = 15, + + /** minimum alignment required for A matrix in bytes + * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) + * + * uint32_t + */ + CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_A_BYTES = 16, + + /** minimum alignment required for B matrix in bytes + * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) + * + * uint32_t + */ + CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_B_BYTES = 17, + + /** minimum alignment required for C matrix in bytes + * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) + * + * uint32_t + */ + CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_C_BYTES = 18, + + /** minimum alignment required for D matrix in bytes + * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) + * + * uint32_t + */ + CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_D_BYTES = 19, + + /** EXPERIMENTAL: support for synchronization via atomic counters + * + * int32_t + */ + CUBLASLT_ALGO_CAP_ATOMIC_SYNC = 20, +} cublasLtMatmulAlgoCapAttributes_t; + +/** Get algo capability attribute. + * + * E.g. to get list of supported Tile IDs: + * cublasLtMatmulTile_t tiles[CUBLASLT_MATMUL_TILE_END]; + * size_t num_tiles, size_written; + * if (cublasLtMatmulAlgoCapGetAttribute(algo, CUBLASLT_ALGO_CAP_TILE_IDS, tiles, sizeof(tiles), size_written) == + * CUBLAS_STATUS_SUCCESS) { num_tiles = size_written / sizeof(tiles[0]); + * } + * + * \param[in] algo The algo descriptor + * \param[in] attr The attribute + * \param[out] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of + * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents + * + * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero + * and buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(const cublasLtMatmulAlgo_t* algo, + cublasLtMatmulAlgoCapAttributes_t attr, + void* buf, + size_t sizeInBytes, + size_t* sizeWritten); + +/** Algo Configuration Attributes that can be set according to the Algo capabilities + */ +typedef enum { + /** algorithm index, see cublasLtMatmulAlgoGetIds() + * + * readonly, set by cublasLtMatmulAlgoInit() + * int32_t + */ + CUBLASLT_ALGO_CONFIG_ID = 0, + /** tile id, see cublasLtMatmulTile_t + * + * uint32_t, default: CUBLASLT_MATMUL_TILE_UNDEFINED + */ + CUBLASLT_ALGO_CONFIG_TILE_ID = 1, + /** Number of K splits. If the number of K splits is greater than one, SPLITK_NUM parts + * of matrix multiplication will be computed in parallel. The results will be accumulated + * according to CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME + * + * int32_t, default: 1 + */ + CUBLASLT_ALGO_CONFIG_SPLITK_NUM = 2, + /** reduction scheme, see cublasLtReductionScheme_t + * + * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_NONE + */ + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME = 3, + /** cta swizzling, change mapping from CUDA grid coordinates to parts of the matrices + * + * possible values: 0, 1, other values reserved + * + * uint32_t, default: 0 + */ + CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING = 4, + /** custom option, each algorithm can support some custom options that don't fit description of the other config + * attributes, see CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX to get accepted range for any specific case + * + * uint32_t, default: 0 + */ + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION = 5, + /** stages id, see cublasLtMatmulStages_t + * + * uint32_t, default: CUBLASLT_MATMUL_STAGES_UNDEFINED + */ + CUBLASLT_ALGO_CONFIG_STAGES_ID = 6, + /** inner shape id, see cublasLtMatmulInnerShape_t + * + * uint16_t, default: 0 (CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED) + */ + CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID = 7, + /** Thread Block Cluster shape id, see cublasLtClusterShape_t. Defines cluster size to use. + * + * uint16_t, default: 0 (CUBLASLT_CLUSTER_SHAPE_AUTO) + */ + CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID = 8, +} cublasLtMatmulAlgoConfigAttributes_t; + +/** Set algo configuration attribute. + * + * \param[in] algo The algo descriptor + * \param[in] attr The attribute + * \param[in] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * + * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(cublasLtMatmulAlgo_t* algo, + cublasLtMatmulAlgoConfigAttributes_t attr, + const void* buf, + size_t sizeInBytes); + +/** Get algo configuration attribute. + * + * \param[in] algo The algo descriptor + * \param[in] attr The attribute + * \param[out] buf memory address containing the new value + * \param[in] sizeInBytes size of buf buffer for verification (in bytes) + * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of + * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents + * + * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero + * and buf is NULL or sizeInBytes doesn't match size of internal storage for + * selected attribute + * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory + */ +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(const cublasLtMatmulAlgo_t* algo, + cublasLtMatmulAlgoConfigAttributes_t attr, + void* buf, + size_t sizeInBytes, + size_t* sizeWritten); + +/** Experimental: Logger callback type. + */ +typedef void (*cublasLtLoggerCallback_t)(int logLevel, const char* functionName, const char* message); + +/** Experimental: Logger callback setter. + * + * \param[in] callback a user defined callback function to be called by the logger + * + * \retval CUBLAS_STATUS_SUCCESS if callback was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtLoggerSetCallback(cublasLtLoggerCallback_t callback); + +/** Experimental: Log file setter. + * + * \param[in] file an open file with write permissions + * + * \retval CUBLAS_STATUS_SUCCESS if log file was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtLoggerSetFile(FILE* file); + +/** Experimental: Open log file. + * + * \param[in] logFile log file path. if the log file does not exist, it will be created + * + * \retval CUBLAS_STATUS_SUCCESS if log file was created successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtLoggerOpenFile(const char* logFile); + +/** Experimental: Log level setter. + * + * \param[in] level log level, should be one of the following: + * 0. Off + * 1. Errors + * 2. Performance Trace + * 3. Performance Hints + * 4. Heuristics Trace + * 5. API Trace + * + * \retval CUBLAS_STATUS_INVALID_VALUE if log level is not one of the above levels + * + * \retval CUBLAS_STATUS_SUCCESS if log level was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtLoggerSetLevel(int level); + +/** Experimental: Log mask setter. + * + * \param[in] mask log mask, should be a combination of the following masks: + * 0. Off + * 1. Errors + * 2. Performance Trace + * 4. Performance Hints + * 8. Heuristics Trace + * 16. API Trace + * + * \retval CUBLAS_STATUS_SUCCESS if log mask was set successfully + */ +cublasStatus_t CUBLASWINAPI cublasLtLoggerSetMask(int mask); + +/** Experimental: Disable logging for the entire session. + * + * \retval CUBLAS_STATUS_SUCCESS if disabled logging + */ +cublasStatus_t CUBLASWINAPI cublasLtLoggerForceDisable(); + +#if defined(__cplusplus) +} +#endif /* __cplusplus */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasXt.h b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasXt.h new file mode 100644 index 0000000000000000000000000000000000000000..fe0e6f99b952514874c45208e751f5330e71570c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasXt.h @@ -0,0 +1,693 @@ +/* + * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +/* cublasXt : Host API, Out of Core and Multi-GPU BLAS Library + +*/ + +#if !defined(CUBLAS_XT_H_) +#define CUBLAS_XT_H_ + +#include "driver_types.h" +#include "cuComplex.h" /* import complex data type */ + +#include "cublas_v2.h" + +#if defined(__cplusplus) +extern "C" { +#endif /* __cplusplus */ + +struct cublasXtContext; +typedef struct cublasXtContext* cublasXtHandle_t; + +cublasStatus_t CUBLASWINAPI cublasXtCreate(cublasXtHandle_t* handle); +cublasStatus_t CUBLASWINAPI cublasXtDestroy(cublasXtHandle_t handle); +cublasStatus_t CUBLASWINAPI cublasXtGetNumBoards(int nbDevices, int deviceId[], int* nbBoards); +cublasStatus_t CUBLASWINAPI cublasXtMaxBoards(int* nbGpuBoards); +/* This routine selects the Gpus that the user want to use for CUBLAS-XT */ +cublasStatus_t CUBLASWINAPI cublasXtDeviceSelect(cublasXtHandle_t handle, int nbDevices, int deviceId[]); + +/* This routine allows to change the dimension of the tiles ( blockDim x blockDim ) */ +cublasStatus_t CUBLASWINAPI cublasXtSetBlockDim(cublasXtHandle_t handle, int blockDim); +cublasStatus_t CUBLASWINAPI cublasXtGetBlockDim(cublasXtHandle_t handle, int* blockDim); + +typedef enum { CUBLASXT_PINNING_DISABLED = 0, CUBLASXT_PINNING_ENABLED = 1 } cublasXtPinnedMemMode_t; +/* This routine allows to CUBLAS-XT to pin the Host memory if it find out that some of the matrix passed + are not pinned : Pinning/Unpinning the Host memory is still a costly operation + It is better if the user controls the memory on its own (by pinning/unpinning oly when necessary) +*/ +cublasStatus_t CUBLASWINAPI cublasXtGetPinningMemMode(cublasXtHandle_t handle, cublasXtPinnedMemMode_t* mode); +cublasStatus_t CUBLASWINAPI cublasXtSetPinningMemMode(cublasXtHandle_t handle, cublasXtPinnedMemMode_t mode); + +/* This routines is to provide a CPU Blas routines, used for too small sizes or hybrid computation */ +typedef enum { + CUBLASXT_FLOAT = 0, + CUBLASXT_DOUBLE = 1, + CUBLASXT_COMPLEX = 2, + CUBLASXT_DOUBLECOMPLEX = 3, +} cublasXtOpType_t; + +typedef enum { + CUBLASXT_GEMM = 0, + CUBLASXT_SYRK = 1, + CUBLASXT_HERK = 2, + CUBLASXT_SYMM = 3, + CUBLASXT_HEMM = 4, + CUBLASXT_TRSM = 5, + CUBLASXT_SYR2K = 6, + CUBLASXT_HER2K = 7, + + CUBLASXT_SPMM = 8, + CUBLASXT_SYRKX = 9, + CUBLASXT_HERKX = 10, + CUBLASXT_TRMM = 11, + CUBLASXT_ROUTINE_MAX = 12, +} cublasXtBlasOp_t; + +/* Currently only 32-bit integer BLAS routines are supported */ +cublasStatus_t CUBLASWINAPI cublasXtSetCpuRoutine(cublasXtHandle_t handle, + cublasXtBlasOp_t blasOp, + cublasXtOpType_t type, + void* blasFunctor); + +/* Specified the percentage of work that should done by the CPU, default is 0 (no work) */ +cublasStatus_t CUBLASWINAPI cublasXtSetCpuRatio(cublasXtHandle_t handle, + cublasXtBlasOp_t blasOp, + cublasXtOpType_t type, + float ratio); + +/* GEMM */ +cublasStatus_t CUBLASWINAPI cublasXtSgemm(cublasXtHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + size_t m, + size_t n, + size_t k, + const float* alpha, + const float* A, + size_t lda, + const float* B, + size_t ldb, + const float* beta, + float* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtDgemm(cublasXtHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + size_t m, + size_t n, + size_t k, + const double* alpha, + const double* A, + size_t lda, + const double* B, + size_t ldb, + const double* beta, + double* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtCgemm(cublasXtHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + size_t m, + size_t n, + size_t k, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + const cuComplex* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZgemm(cublasXtHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + size_t m, + size_t n, + size_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + size_t ldc); +/* ------------------------------------------------------- */ +/* SYRK */ +cublasStatus_t CUBLASWINAPI cublasXtSsyrk(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const float* alpha, + const float* A, + size_t lda, + const float* beta, + float* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtDsyrk(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const double* alpha, + const double* A, + size_t lda, + const double* beta, + double* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtCsyrk(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZsyrk(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + size_t ldc); +/* -------------------------------------------------------------------- */ +/* HERK */ +cublasStatus_t CUBLASWINAPI cublasXtCherk(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const float* alpha, + const cuComplex* A, + size_t lda, + const float* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZherk(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const double* alpha, + const cuDoubleComplex* A, + size_t lda, + const double* beta, + cuDoubleComplex* C, + size_t ldc); +/* -------------------------------------------------------------------- */ +/* SYR2K */ +cublasStatus_t CUBLASWINAPI cublasXtSsyr2k(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const float* alpha, + const float* A, + size_t lda, + const float* B, + size_t ldb, + const float* beta, + float* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtDsyr2k(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const double* alpha, + const double* A, + size_t lda, + const double* B, + size_t ldb, + const double* beta, + double* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtCsyr2k(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + const cuComplex* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZsyr2k(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + size_t ldc); +/* -------------------------------------------------------------------- */ +/* HERKX : variant extension of HERK */ +cublasStatus_t CUBLASWINAPI cublasXtCherkx(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + const float* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZherkx(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + const double* beta, + cuDoubleComplex* C, + size_t ldc); + +/* -------------------------------------------------------------------- */ +/* TRSM */ +cublasStatus_t CUBLASWINAPI cublasXtStrsm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const float* alpha, + const float* A, + size_t lda, + float* B, + size_t ldb); + +cublasStatus_t CUBLASWINAPI cublasXtDtrsm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const double* alpha, + const double* A, + size_t lda, + double* B, + size_t ldb); + +cublasStatus_t CUBLASWINAPI cublasXtCtrsm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + cuComplex* B, + size_t ldb); + +cublasStatus_t CUBLASWINAPI cublasXtZtrsm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + cuDoubleComplex* B, + size_t ldb); +/* -------------------------------------------------------------------- */ +/* SYMM : Symmetric Multiply Matrix*/ +cublasStatus_t CUBLASWINAPI cublasXtSsymm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const float* alpha, + const float* A, + size_t lda, + const float* B, + size_t ldb, + const float* beta, + float* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtDsymm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const double* alpha, + const double* A, + size_t lda, + const double* B, + size_t ldb, + const double* beta, + double* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtCsymm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + const cuComplex* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZsymm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + size_t ldc); +/* -------------------------------------------------------------------- */ +/* HEMM : Hermitian Matrix Multiply */ +cublasStatus_t CUBLASWINAPI cublasXtChemm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + const cuComplex* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZhemm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + size_t ldc); + +/* -------------------------------------------------------------------- */ +/* SYRKX : variant extension of SYRK */ +cublasStatus_t CUBLASWINAPI cublasXtSsyrkx(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const float* alpha, + const float* A, + size_t lda, + const float* B, + size_t ldb, + const float* beta, + float* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtDsyrkx(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const double* alpha, + const double* A, + size_t lda, + const double* B, + size_t ldb, + const double* beta, + double* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtCsyrkx(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + const cuComplex* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZsyrkx(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + size_t ldc); +/* -------------------------------------------------------------------- */ +/* HER2K : variant extension of HERK */ +cublasStatus_t CUBLASWINAPI cublasXtCher2k(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + const float* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZher2k(cublasXtHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + size_t n, + size_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + const double* beta, + cuDoubleComplex* C, + size_t ldc); + +/* -------------------------------------------------------------------- */ +/* SPMM : Symmetric Packed Multiply Matrix*/ +cublasStatus_t CUBLASWINAPI cublasXtSspmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const float* alpha, + const float* AP, + const float* B, + size_t ldb, + const float* beta, + float* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtDspmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const double* alpha, + const double* AP, + const double* B, + size_t ldb, + const double* beta, + double* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtCspmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const cuComplex* alpha, + const cuComplex* AP, + const cuComplex* B, + size_t ldb, + const cuComplex* beta, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZspmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + size_t m, + size_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* AP, + const cuDoubleComplex* B, + size_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + size_t ldc); + +/* -------------------------------------------------------------------- */ +/* TRMM */ +cublasStatus_t CUBLASWINAPI cublasXtStrmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const float* alpha, + const float* A, + size_t lda, + const float* B, + size_t ldb, + float* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtDtrmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const double* alpha, + const double* A, + size_t lda, + const double* B, + size_t ldb, + double* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtCtrmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const cuComplex* alpha, + const cuComplex* A, + size_t lda, + const cuComplex* B, + size_t ldb, + cuComplex* C, + size_t ldc); + +cublasStatus_t CUBLASWINAPI cublasXtZtrmm(cublasXtHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + size_t m, + size_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + size_t lda, + const cuDoubleComplex* B, + size_t ldb, + cuDoubleComplex* C, + size_t ldc); + +#if defined(__cplusplus) +} +#endif /* __cplusplus */ + +#endif /* !defined(CUBLAS_XT_H_) */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h new file mode 100644 index 0000000000000000000000000000000000000000..b46b3c41386e3740ae99c246d0fdc606111d8593 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h @@ -0,0 +1,5793 @@ +/* + * Copyright 1993-2022 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +/* + * This is the public header file for the CUBLAS library, defining the API + * + * CUBLAS is an implementation of BLAS (Basic Linear Algebra Subroutines) + * on top of the CUDA runtime. + */ + +#if !defined(CUBLAS_API_H_) +#define CUBLAS_API_H_ + +#ifndef CUBLASWINAPI +#ifdef _WIN32 +#define CUBLASWINAPI __stdcall +#else +#define CUBLASWINAPI +#endif +#endif + +#ifndef CUBLASAPI +#error "This file should not be included without defining CUBLASAPI" +#endif + +#include + +#include "driver_types.h" +#include "cuComplex.h" /* import complex data type */ + +#include +#include + +#include + +#if defined(__cplusplus) +extern "C" { +#endif /* __cplusplus */ + +#define CUBLAS_VER_MAJOR 12 +#define CUBLAS_VER_MINOR 4 +#define CUBLAS_VER_PATCH 5 +#define CUBLAS_VER_BUILD 8 +#define CUBLAS_VERSION (CUBLAS_VER_MAJOR * 10000 + CUBLAS_VER_MINOR * 100 + CUBLAS_VER_PATCH) + +/* CUBLAS status type returns */ +typedef enum { + CUBLAS_STATUS_SUCCESS = 0, + CUBLAS_STATUS_NOT_INITIALIZED = 1, + CUBLAS_STATUS_ALLOC_FAILED = 3, + CUBLAS_STATUS_INVALID_VALUE = 7, + CUBLAS_STATUS_ARCH_MISMATCH = 8, + CUBLAS_STATUS_MAPPING_ERROR = 11, + CUBLAS_STATUS_EXECUTION_FAILED = 13, + CUBLAS_STATUS_INTERNAL_ERROR = 14, + CUBLAS_STATUS_NOT_SUPPORTED = 15, + CUBLAS_STATUS_LICENSE_ERROR = 16 +} cublasStatus_t; + +typedef enum { CUBLAS_FILL_MODE_LOWER = 0, CUBLAS_FILL_MODE_UPPER = 1, CUBLAS_FILL_MODE_FULL = 2 } cublasFillMode_t; + +typedef enum { CUBLAS_DIAG_NON_UNIT = 0, CUBLAS_DIAG_UNIT = 1 } cublasDiagType_t; + +typedef enum { CUBLAS_SIDE_LEFT = 0, CUBLAS_SIDE_RIGHT = 1 } cublasSideMode_t; + +typedef enum { + CUBLAS_OP_N = 0, + CUBLAS_OP_T = 1, + CUBLAS_OP_C = 2, + CUBLAS_OP_HERMITAN = 2, /* synonym if CUBLAS_OP_C */ + CUBLAS_OP_CONJG = 3 /* conjugate, placeholder - not supported in the current release */ +} cublasOperation_t; + +typedef enum { CUBLAS_POINTER_MODE_HOST = 0, CUBLAS_POINTER_MODE_DEVICE = 1 } cublasPointerMode_t; + +typedef enum { CUBLAS_ATOMICS_NOT_ALLOWED = 0, CUBLAS_ATOMICS_ALLOWED = 1 } cublasAtomicsMode_t; + +/*For different GEMM algorithm */ +typedef enum { + CUBLAS_GEMM_DFALT = -1, + CUBLAS_GEMM_DEFAULT = -1, + CUBLAS_GEMM_ALGO0 = 0, + CUBLAS_GEMM_ALGO1 = 1, + CUBLAS_GEMM_ALGO2 = 2, + CUBLAS_GEMM_ALGO3 = 3, + CUBLAS_GEMM_ALGO4 = 4, + CUBLAS_GEMM_ALGO5 = 5, + CUBLAS_GEMM_ALGO6 = 6, + CUBLAS_GEMM_ALGO7 = 7, + CUBLAS_GEMM_ALGO8 = 8, + CUBLAS_GEMM_ALGO9 = 9, + CUBLAS_GEMM_ALGO10 = 10, + CUBLAS_GEMM_ALGO11 = 11, + CUBLAS_GEMM_ALGO12 = 12, + CUBLAS_GEMM_ALGO13 = 13, + CUBLAS_GEMM_ALGO14 = 14, + CUBLAS_GEMM_ALGO15 = 15, + CUBLAS_GEMM_ALGO16 = 16, + CUBLAS_GEMM_ALGO17 = 17, + CUBLAS_GEMM_ALGO18 = 18, // sliced 32x32 + CUBLAS_GEMM_ALGO19 = 19, // sliced 64x32 + CUBLAS_GEMM_ALGO20 = 20, // sliced 128x32 + CUBLAS_GEMM_ALGO21 = 21, // sliced 32x32 -splitK + CUBLAS_GEMM_ALGO22 = 22, // sliced 64x32 -splitK + CUBLAS_GEMM_ALGO23 = 23, // sliced 128x32 -splitK + CUBLAS_GEMM_DEFAULT_TENSOR_OP = 99, + CUBLAS_GEMM_DFALT_TENSOR_OP = 99, + CUBLAS_GEMM_ALGO0_TENSOR_OP = 100, + CUBLAS_GEMM_ALGO1_TENSOR_OP = 101, + CUBLAS_GEMM_ALGO2_TENSOR_OP = 102, + CUBLAS_GEMM_ALGO3_TENSOR_OP = 103, + CUBLAS_GEMM_ALGO4_TENSOR_OP = 104, + CUBLAS_GEMM_ALGO5_TENSOR_OP = 105, + CUBLAS_GEMM_ALGO6_TENSOR_OP = 106, + CUBLAS_GEMM_ALGO7_TENSOR_OP = 107, + CUBLAS_GEMM_ALGO8_TENSOR_OP = 108, + CUBLAS_GEMM_ALGO9_TENSOR_OP = 109, + CUBLAS_GEMM_ALGO10_TENSOR_OP = 110, + CUBLAS_GEMM_ALGO11_TENSOR_OP = 111, + CUBLAS_GEMM_ALGO12_TENSOR_OP = 112, + CUBLAS_GEMM_ALGO13_TENSOR_OP = 113, + CUBLAS_GEMM_ALGO14_TENSOR_OP = 114, + CUBLAS_GEMM_ALGO15_TENSOR_OP = 115 +} cublasGemmAlgo_t; + +/*Enum for default math mode/tensor operation*/ +typedef enum { + CUBLAS_DEFAULT_MATH = 0, + + /* deprecated, same effect as using CUBLAS_COMPUTE_32F_FAST_16F, will be removed in a future release */ + CUBLAS_TENSOR_OP_MATH = 1, + + /* same as using matching _PEDANTIC compute type when using cublasroutine calls or cublasEx() calls with + cudaDataType as compute type */ + CUBLAS_PEDANTIC_MATH = 2, + + /* allow accelerating single precision routines using TF32 tensor cores */ + CUBLAS_TF32_TENSOR_OP_MATH = 3, + + /* flag to force any reductons to use the accumulator type and not output type in case of mixed precision routines + with lower size output type */ + CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION = 16, +} cublasMath_t; + +/* For backward compatibility purposes */ +typedef cudaDataType cublasDataType_t; + +/* Enum for compute type + * + * - default types provide best available performance using all available hardware features + * and guarantee internal storage precision with at least the same precision and range; + * - _PEDANTIC types ensure standard arithmetic and exact specified internal storage format; + * - _FAST types allow for some loss of precision to enable higher throughput arithmetic. + */ +typedef enum { + CUBLAS_COMPUTE_16F = 64, /* half - default */ + CUBLAS_COMPUTE_16F_PEDANTIC = 65, /* half - pedantic */ + CUBLAS_COMPUTE_32F = 68, /* float - default */ + CUBLAS_COMPUTE_32F_PEDANTIC = 69, /* float - pedantic */ + CUBLAS_COMPUTE_32F_FAST_16F = 74, /* float - fast, allows down-converting inputs to half or TF32 */ + CUBLAS_COMPUTE_32F_FAST_16BF = 75, /* float - fast, allows down-converting inputs to bfloat16 or TF32 */ + CUBLAS_COMPUTE_32F_FAST_TF32 = 77, /* float - fast, allows down-converting inputs to TF32 */ + CUBLAS_COMPUTE_64F = 70, /* double - default */ + CUBLAS_COMPUTE_64F_PEDANTIC = 71, /* double - pedantic */ + CUBLAS_COMPUTE_32I = 72, /* signed 32-bit int - default */ + CUBLAS_COMPUTE_32I_PEDANTIC = 73, /* signed 32-bit int - pedantic */ +} cublasComputeType_t; + +/* Opaque structure holding CUBLAS library context */ +struct cublasContext; +typedef struct cublasContext* cublasHandle_t; + +/* Cublas logging */ +typedef void (*cublasLogCallback)(const char* msg); + +/* cuBLAS Exported API {{{ */ + +/* --------------- CUBLAS Helper Functions ---------------- */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCreate_v2(cublasHandle_t* handle); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDestroy_v2(cublasHandle_t handle); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetVersion_v2(cublasHandle_t handle, int* version); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetProperty(libraryPropertyType type, int* value); + +CUBLASAPI size_t CUBLASWINAPI cublasGetCudartVersion(void); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSetWorkspace_v2(cublasHandle_t handle, + void* workspace, + size_t workspaceSizeInBytes); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSetStream_v2(cublasHandle_t handle, cudaStream_t streamId); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetStream_v2(cublasHandle_t handle, cudaStream_t* streamId); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetPointerMode_v2(cublasHandle_t handle, cublasPointerMode_t* mode); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSetPointerMode_v2(cublasHandle_t handle, cublasPointerMode_t mode); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t* mode); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t mode); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetMathMode(cublasHandle_t handle, cublasMath_t* mode); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSetMathMode(cublasHandle_t handle, cublasMath_t mode); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetSmCountTarget(cublasHandle_t handle, int* smCountTarget); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSetSmCountTarget(cublasHandle_t handle, int smCountTarget); + +CUBLASAPI const char* CUBLASWINAPI cublasGetStatusName(cublasStatus_t status); + +CUBLASAPI const char* CUBLASWINAPI cublasGetStatusString(cublasStatus_t status); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasLoggerConfigure(int logIsOn, + int logToStdOut, + int logToStdErr, + const char* logFileName); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSetLoggerCallback(cublasLogCallback userCallback); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGetLoggerCallback(cublasLogCallback* userCallback); + +cublasStatus_t CUBLASWINAPI cublasSetVector(int n, int elemSize, const void* x, int incx, void* devicePtr, int incy); + +cublasStatus_t CUBLASWINAPI +cublasSetVector_64(int64_t n, int64_t elemSize, const void* x, int64_t incx, void* devicePtr, int64_t incy); + +cublasStatus_t CUBLASWINAPI cublasGetVector(int n, int elemSize, const void* x, int incx, void* y, int incy); + +cublasStatus_t CUBLASWINAPI +cublasGetVector_64(int64_t n, int64_t elemSize, const void* x, int64_t incx, void* y, int64_t incy); + +cublasStatus_t CUBLASWINAPI cublasSetMatrix(int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb); + +cublasStatus_t CUBLASWINAPI +cublasSetMatrix_64(int64_t rows, int64_t cols, int64_t elemSize, const void* A, int64_t lda, void* B, int64_t ldb); + +cublasStatus_t CUBLASWINAPI cublasGetMatrix(int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb); + +cublasStatus_t CUBLASWINAPI +cublasGetMatrix_64(int64_t rows, int64_t cols, int64_t elemSize, const void* A, int64_t lda, void* B, int64_t ldb); + +cublasStatus_t CUBLASWINAPI cublasSetVectorAsync( + int n, int elemSize, const void* hostPtr, int incx, void* devicePtr, int incy, cudaStream_t stream); + +cublasStatus_t CUBLASWINAPI cublasSetVectorAsync_64( + int64_t n, int64_t elemSize, const void* hostPtr, int64_t incx, void* devicePtr, int64_t incy, cudaStream_t stream); + +cublasStatus_t CUBLASWINAPI cublasGetVectorAsync( + int n, int elemSize, const void* devicePtr, int incx, void* hostPtr, int incy, cudaStream_t stream); + +cublasStatus_t CUBLASWINAPI cublasGetVectorAsync_64( + int64_t n, int64_t elemSize, const void* devicePtr, int64_t incx, void* hostPtr, int64_t incy, cudaStream_t stream); + +cublasStatus_t CUBLASWINAPI +cublasSetMatrixAsync(int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, cudaStream_t stream); + +cublasStatus_t CUBLASWINAPI cublasSetMatrixAsync_64(int64_t rows, + int64_t cols, + int64_t elemSize, + const void* A, + int64_t lda, + void* B, + int64_t ldb, + cudaStream_t stream); + +cublasStatus_t CUBLASWINAPI +cublasGetMatrixAsync(int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, cudaStream_t stream); + +cublasStatus_t CUBLASWINAPI cublasGetMatrixAsync_64(int64_t rows, + int64_t cols, + int64_t elemSize, + const void* A, + int64_t lda, + void* B, + int64_t ldb, + cudaStream_t stream); + +CUBLASAPI void CUBLASWINAPI cublasXerbla(const char* srName, int info); + +/* --------------- CUBLAS BLAS1 Functions ---------------- */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasNrm2Ex(cublasHandle_t handle, + int n, + const void* x, + cudaDataType xType, + int incx, + void* result, + cudaDataType resultType, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasNrm2Ex_64(cublasHandle_t handle, + int64_t n, + const void* x, + cudaDataType xType, + int64_t incx, + void* result, + cudaDataType resultType, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSnrm2_v2(cublasHandle_t handle, int n, const float* x, int incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSnrm2_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDnrm2_v2(cublasHandle_t handle, int n, const double* x, int incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDnrm2_v2_64(cublasHandle_t handle, int64_t n, const double* x, int64_t incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasScnrm2_v2(cublasHandle_t handle, int n, const cuComplex* x, int incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasScnrm2_v2_64(cublasHandle_t handle, int64_t n, const cuComplex* x, int64_t incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDznrm2_v2(cublasHandle_t handle, int n, const cuDoubleComplex* x, int incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDznrm2_v2_64(cublasHandle_t handle, int64_t n, const cuDoubleComplex* x, int64_t incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDotEx(cublasHandle_t handle, + int n, + const void* x, + cudaDataType xType, + int incx, + const void* y, + cudaDataType yType, + int incy, + void* result, + cudaDataType resultType, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDotEx_64(cublasHandle_t handle, + int64_t n, + const void* x, + cudaDataType xType, + int64_t incx, + const void* y, + cudaDataType yType, + int64_t incy, + void* result, + cudaDataType resultType, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDotcEx(cublasHandle_t handle, + int n, + const void* x, + cudaDataType xType, + int incx, + const void* y, + cudaDataType yType, + int incy, + void* result, + cudaDataType resultType, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDotcEx_64(cublasHandle_t handle, + int64_t n, + const void* x, + cudaDataType xType, + int64_t incx, + const void* y, + cudaDataType yType, + int64_t incy, + void* result, + cudaDataType resultType, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSdot_v2(cublasHandle_t handle, int n, const float* x, int incx, const float* y, int incy, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSdot_v2_64( + cublasHandle_t handle, int64_t n, const float* x, int64_t incx, const float* y, int64_t incy, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDdot_v2(cublasHandle_t handle, int n, const double* x, int incx, const double* y, int incy, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDdot_v2_64( + cublasHandle_t handle, int64_t n, const double* x, int64_t incx, const double* y, int64_t incy, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCdotu_v2( + cublasHandle_t handle, int n, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCdotu_v2_64(cublasHandle_t handle, + int64_t n, + const cuComplex* x, + int64_t incx, + const cuComplex* y, + int64_t incy, + cuComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCdotc_v2( + cublasHandle_t handle, int n, const cuComplex* x, int incx, const cuComplex* y, int incy, cuComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCdotc_v2_64(cublasHandle_t handle, + int64_t n, + const cuComplex* x, + int64_t incx, + const cuComplex* y, + int64_t incy, + cuComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdotu_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdotu_v2_64(cublasHandle_t handle, + int64_t n, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* y, + int64_t incy, + cuDoubleComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdotc_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdotc_v2_64(cublasHandle_t handle, + int64_t n, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* y, + int64_t incy, + cuDoubleComplex* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasScalEx(cublasHandle_t handle, + int n, + const void* alpha, + cudaDataType alphaType, + void* x, + cudaDataType xType, + int incx, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasScalEx_64(cublasHandle_t handle, + int64_t n, + const void* alpha, + cudaDataType alphaType, + void* x, + cudaDataType xType, + int64_t incx, + cudaDataType executionType); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSscal_v2(cublasHandle_t handle, int n, const float* alpha, float* x, int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSscal_v2_64(cublasHandle_t handle, int64_t n, const float* alpha, float* x, int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDscal_v2(cublasHandle_t handle, int n, const double* alpha, double* x, int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDscal_v2_64(cublasHandle_t handle, int64_t n, const double* alpha, double* x, int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCscal_v2(cublasHandle_t handle, int n, const cuComplex* alpha, cuComplex* x, int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCscal_v2_64(cublasHandle_t handle, int64_t n, const cuComplex* alpha, cuComplex* x, int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCsscal_v2(cublasHandle_t handle, int n, const float* alpha, cuComplex* x, int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCsscal_v2_64(cublasHandle_t handle, int64_t n, const float* alpha, cuComplex* x, int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZscal_v2(cublasHandle_t handle, int n, const cuDoubleComplex* alpha, cuDoubleComplex* x, int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZscal_v2_64(cublasHandle_t handle, int64_t n, const cuDoubleComplex* alpha, cuDoubleComplex* x, int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZdscal_v2(cublasHandle_t handle, int n, const double* alpha, cuDoubleComplex* x, int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZdscal_v2_64(cublasHandle_t handle, int64_t n, const double* alpha, cuDoubleComplex* x, int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasAxpyEx(cublasHandle_t handle, + int n, + const void* alpha, + cudaDataType alphaType, + const void* x, + cudaDataType xType, + int incx, + void* y, + cudaDataType yType, + int incy, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasAxpyEx_64(cublasHandle_t handle, + int64_t n, + const void* alpha, + cudaDataType alphaType, + const void* x, + cudaDataType xType, + int64_t incx, + void* y, + cudaDataType yType, + int64_t incy, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSaxpy_v2(cublasHandle_t handle, int n, const float* alpha, const float* x, int incx, float* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSaxpy_v2_64( + cublasHandle_t handle, int64_t n, const float* alpha, const float* x, int64_t incx, float* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDaxpy_v2(cublasHandle_t handle, int n, const double* alpha, const double* x, int incx, double* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDaxpy_v2_64( + cublasHandle_t handle, int64_t n, const double* alpha, const double* x, int64_t incx, double* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCaxpy_v2( + cublasHandle_t handle, int n, const cuComplex* alpha, const cuComplex* x, int incx, cuComplex* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCaxpy_v2_64(cublasHandle_t handle, + int64_t n, + const cuComplex* alpha, + const cuComplex* x, + int64_t incx, + cuComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZaxpy_v2(cublasHandle_t handle, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZaxpy_v2_64(cublasHandle_t handle, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int64_t incx, + cuDoubleComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCopyEx( + cublasHandle_t handle, int n, const void* x, cudaDataType xType, int incx, void* y, cudaDataType yType, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCopyEx_64(cublasHandle_t handle, + int64_t n, + const void* x, + cudaDataType xType, + int64_t incx, + void* y, + cudaDataType yType, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasScopy_v2(cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasScopy_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, float* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDcopy_v2(cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDcopy_v2_64(cublasHandle_t handle, int64_t n, const double* x, int64_t incx, double* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCcopy_v2(cublasHandle_t handle, int n, const cuComplex* x, int incx, cuComplex* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCcopy_v2_64(cublasHandle_t handle, int64_t n, const cuComplex* x, int64_t incx, cuComplex* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZcopy_v2(cublasHandle_t handle, int n, const cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZcopy_v2_64( + cublasHandle_t handle, int64_t n, const cuDoubleComplex* x, int64_t incx, cuDoubleComplex* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSswap_v2(cublasHandle_t handle, int n, float* x, int incx, float* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSswap_v2_64(cublasHandle_t handle, int64_t n, float* x, int64_t incx, float* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDswap_v2(cublasHandle_t handle, int n, double* x, int incx, double* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDswap_v2_64(cublasHandle_t handle, int64_t n, double* x, int64_t incx, double* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCswap_v2(cublasHandle_t handle, int n, cuComplex* x, int incx, cuComplex* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCswap_v2_64(cublasHandle_t handle, int64_t n, cuComplex* x, int64_t incx, cuComplex* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZswap_v2(cublasHandle_t handle, int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZswap_v2_64(cublasHandle_t handle, int64_t n, cuDoubleComplex* x, int64_t incx, cuDoubleComplex* y, int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSwapEx( + cublasHandle_t handle, int n, void* x, cudaDataType xType, int incx, void* y, cudaDataType yType, int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSwapEx_64(cublasHandle_t handle, + int64_t n, + void* x, + cudaDataType xType, + int64_t incx, + void* y, + cudaDataType yType, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIsamax_v2(cublasHandle_t handle, int n, const float* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIsamax_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIdamax_v2(cublasHandle_t handle, int n, const double* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIdamax_v2_64(cublasHandle_t handle, int64_t n, const double* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIcamax_v2(cublasHandle_t handle, int n, const cuComplex* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIcamax_v2_64(cublasHandle_t handle, int64_t n, const cuComplex* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIzamax_v2(cublasHandle_t handle, int n, const cuDoubleComplex* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIzamax_v2_64(cublasHandle_t handle, int64_t n, const cuDoubleComplex* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIamaxEx(cublasHandle_t handle, int n, const void* x, cudaDataType xType, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIamaxEx_64(cublasHandle_t handle, int64_t n, const void* x, cudaDataType xType, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIsamin_v2(cublasHandle_t handle, int n, const float* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIsamin_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIdamin_v2(cublasHandle_t handle, int n, const double* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIdamin_v2_64(cublasHandle_t handle, int64_t n, const double* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIcamin_v2(cublasHandle_t handle, int n, const cuComplex* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIcamin_v2_64(cublasHandle_t handle, int64_t n, const cuComplex* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIzamin_v2(cublasHandle_t handle, int n, const cuDoubleComplex* x, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIzamin_v2_64(cublasHandle_t handle, int64_t n, const cuDoubleComplex* x, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIaminEx(cublasHandle_t handle, int n, const void* x, cudaDataType xType, int incx, int* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasIaminEx_64(cublasHandle_t handle, int64_t n, const void* x, cudaDataType xType, int64_t incx, int64_t* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasAsumEx(cublasHandle_t handle, + int n, + const void* x, + cudaDataType xType, + int incx, + void* result, + cudaDataType resultType, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasAsumEx_64(cublasHandle_t handle, + int64_t n, + const void* x, + cudaDataType xType, + int64_t incx, + void* result, + cudaDataType resultType, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSasum_v2(cublasHandle_t handle, int n, const float* x, int incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSasum_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDasum_v2(cublasHandle_t handle, int n, const double* x, int incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDasum_v2_64(cublasHandle_t handle, int64_t n, const double* x, int64_t incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasScasum_v2(cublasHandle_t handle, int n, const cuComplex* x, int incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasScasum_v2_64(cublasHandle_t handle, int64_t n, const cuComplex* x, int64_t incx, float* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDzasum_v2(cublasHandle_t handle, int n, const cuDoubleComplex* x, int incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDzasum_v2_64(cublasHandle_t handle, int64_t n, const cuDoubleComplex* x, int64_t incx, double* result); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSrot_v2(cublasHandle_t handle, int n, float* x, int incx, float* y, int incy, const float* c, const float* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSrot_v2_64( + cublasHandle_t handle, int64_t n, float* x, int64_t incx, float* y, int64_t incy, const float* c, const float* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDrot_v2(cublasHandle_t handle, int n, double* x, int incx, double* y, int incy, const double* c, const double* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDrot_v2_64(cublasHandle_t handle, + int64_t n, + double* x, + int64_t incx, + double* y, + int64_t incy, + const double* c, + const double* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCrot_v2( + cublasHandle_t handle, int n, cuComplex* x, int incx, cuComplex* y, int incy, const float* c, const cuComplex* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCrot_v2_64(cublasHandle_t handle, + int64_t n, + cuComplex* x, + int64_t incx, + cuComplex* y, + int64_t incy, + const float* c, + const cuComplex* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsrot_v2( + cublasHandle_t handle, int n, cuComplex* x, int incx, cuComplex* y, int incy, const float* c, const float* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsrot_v2_64(cublasHandle_t handle, + int64_t n, + cuComplex* x, + int64_t incx, + cuComplex* y, + int64_t incy, + const float* c, + const float* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZrot_v2(cublasHandle_t handle, + int n, + cuDoubleComplex* x, + int incx, + cuDoubleComplex* y, + int incy, + const double* c, + const cuDoubleComplex* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZrot_v2_64(cublasHandle_t handle, + int64_t n, + cuDoubleComplex* x, + int64_t incx, + cuDoubleComplex* y, + int64_t incy, + const double* c, + const cuDoubleComplex* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdrot_v2(cublasHandle_t handle, + int n, + cuDoubleComplex* x, + int incx, + cuDoubleComplex* y, + int incy, + const double* c, + const double* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdrot_v2_64(cublasHandle_t handle, + int64_t n, + cuDoubleComplex* x, + int64_t incx, + cuDoubleComplex* y, + int64_t incy, + const double* c, + const double* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasRotEx(cublasHandle_t handle, + int n, + void* x, + cudaDataType xType, + int incx, + void* y, + cudaDataType yType, + int incy, + const void* c, + const void* s, + cudaDataType csType, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasRotEx_64(cublasHandle_t handle, + int64_t n, + void* x, + cudaDataType xType, + int64_t incx, + void* y, + cudaDataType yType, + int64_t incy, + const void* c, + const void* s, + cudaDataType csType, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSrotg_v2(cublasHandle_t handle, float* a, float* b, float* c, float* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDrotg_v2(cublasHandle_t handle, double* a, double* b, double* c, double* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCrotg_v2(cublasHandle_t handle, cuComplex* a, cuComplex* b, float* c, cuComplex* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasZrotg_v2(cublasHandle_t handle, cuDoubleComplex* a, cuDoubleComplex* b, double* c, cuDoubleComplex* s); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasRotgEx(cublasHandle_t handle, + void* a, + void* b, + cudaDataType abType, + void* c, + void* s, + cudaDataType csType, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSrotm_v2(cublasHandle_t handle, int n, float* x, int incx, float* y, int incy, const float* param); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSrotm_v2_64(cublasHandle_t handle, int64_t n, float* x, int64_t incx, float* y, int64_t incy, const float* param); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDrotm_v2(cublasHandle_t handle, int n, double* x, int incx, double* y, int incy, const double* param); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDrotm_v2_64( + cublasHandle_t handle, int64_t n, double* x, int64_t incx, double* y, int64_t incy, const double* param); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasRotmEx(cublasHandle_t handle, + int n, + void* x, + cudaDataType xType, + int incx, + void* y, + cudaDataType yType, + int incy, + const void* param, + cudaDataType paramType, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasRotmEx_64(cublasHandle_t handle, + int64_t n, + void* x, + cudaDataType xType, + int64_t incx, + void* y, + cudaDataType yType, + int64_t incy, + const void* param, + cudaDataType paramType, + cudaDataType executiontype); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSrotmg_v2(cublasHandle_t handle, float* d1, float* d2, float* x1, const float* y1, float* param); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDrotmg_v2(cublasHandle_t handle, double* d1, double* d2, double* x1, const double* y1, double* param); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasRotmgEx(cublasHandle_t handle, + void* d1, + cudaDataType d1Type, + void* d2, + cudaDataType d2Type, + void* x1, + cudaDataType x1Type, + const void* y1, + cudaDataType y1Type, + void* param, + cudaDataType paramType, + cudaDataType executiontype); + +/* --------------- CUBLAS BLAS2 Functions ---------------- */ + +/* GEMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const float* A, + int lda, + const float* x, + int incx, + const float* beta, + float* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const float* A, + int64_t lda, + const float* x, + int64_t incx, + const float* beta, + float* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const double* alpha, + const double* A, + int lda, + const double* x, + int incx, + const double* beta, + double* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const double* alpha, + const double* A, + int64_t lda, + const double* x, + int64_t incx, + const double* beta, + double* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + const cuComplex* beta, + cuComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* x, + int64_t incx, + const cuComplex* beta, + cuComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int64_t incy); + +/* GBMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgbmv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const float* alpha, + const float* A, + int lda, + const float* x, + int incx, + const float* beta, + float* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgbmv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + int64_t kl, + int64_t ku, + const float* alpha, + const float* A, + int64_t lda, + const float* x, + int64_t incx, + const float* beta, + float* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgbmv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const double* alpha, + const double* A, + int lda, + const double* x, + int incx, + const double* beta, + double* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgbmv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + int64_t kl, + int64_t ku, + const double* alpha, + const double* A, + int64_t lda, + const double* x, + int64_t incx, + const double* beta, + double* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgbmv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + const cuComplex* beta, + cuComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgbmv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + int64_t kl, + int64_t ku, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* x, + int64_t incx, + const cuComplex* beta, + cuComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgbmv_v2(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int kl, + int ku, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgbmv_v2_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + int64_t kl, + int64_t ku, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int64_t incy); + +/* TRMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float* A, + int lda, + float* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const float* A, + int64_t lda, + float* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double* A, + int lda, + double* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const double* A, + int64_t lda, + double* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex* A, + int lda, + cuComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuComplex* A, + int64_t lda, + cuComplex* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuDoubleComplex* A, + int64_t lda, + cuDoubleComplex* x, + int64_t incx); + +/* TBMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const float* A, + int lda, + float* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const float* A, + int64_t lda, + float* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const double* A, + int lda, + double* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const double* A, + int64_t lda, + double* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuComplex* A, + int lda, + cuComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const cuComplex* A, + int64_t lda, + cuComplex* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const cuDoubleComplex* A, + int64_t lda, + cuDoubleComplex* x, + int64_t incx); + +/* TPMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float* AP, + float* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStpmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const float* AP, + float* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double* AP, + double* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtpmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const double* AP, + double* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex* AP, + cuComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtpmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuComplex* AP, + cuComplex* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex* AP, + cuDoubleComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtpmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuDoubleComplex* AP, + cuDoubleComplex* x, + int64_t incx); + +/* TRSV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float* A, + int lda, + float* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const float* A, + int64_t lda, + float* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double* A, + int lda, + double* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const double* A, + int64_t lda, + double* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex* A, + int lda, + cuComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuComplex* A, + int64_t lda, + cuComplex* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuDoubleComplex* A, + int64_t lda, + cuDoubleComplex* x, + int64_t incx); + +/* TPSV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStpsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const float* AP, + float* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStpsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const float* AP, + float* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtpsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const double* AP, + double* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtpsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const double* AP, + double* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtpsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuComplex* AP, + cuComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtpsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuComplex* AP, + cuComplex* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtpsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + const cuDoubleComplex* AP, + cuDoubleComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtpsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + const cuDoubleComplex* AP, + cuDoubleComplex* x, + int64_t incx); + +/* TBSV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const float* A, + int lda, + float* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStbsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const float* A, + int64_t lda, + float* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const double* A, + int lda, + double* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtbsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const double* A, + int64_t lda, + double* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuComplex* A, + int lda, + cuComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtbsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const cuComplex* A, + int64_t lda, + cuComplex* x, + int64_t incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int n, + int k, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex* x, + int incx); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtbsv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t n, + int64_t k, + const cuDoubleComplex* A, + int64_t lda, + cuDoubleComplex* x, + int64_t incx); + +/* SYMV/HEMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsymv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float* alpha, + const float* A, + int lda, + const float* x, + int incx, + const float* beta, + float* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsymv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const float* A, + int64_t lda, + const float* x, + int64_t incx, + const float* beta, + float* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsymv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double* alpha, + const double* A, + int lda, + const double* x, + int incx, + const double* beta, + double* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsymv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const double* A, + int64_t lda, + const double* x, + int64_t incx, + const double* beta, + double* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsymv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + const cuComplex* beta, + cuComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsymv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* x, + int64_t incx, + const cuComplex* beta, + cuComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsymv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsymv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChemv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + const cuComplex* beta, + cuComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChemv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* x, + int64_t incx, + const cuComplex* beta, + cuComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhemv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhemv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int64_t incy); + +/* SBMV/HBMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const float* alpha, + const float* A, + int lda, + const float* x, + int incx, + const float* beta, + float* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + int64_t k, + const float* alpha, + const float* A, + int64_t lda, + const float* x, + int64_t incx, + const float* beta, + float* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const double* alpha, + const double* A, + int lda, + const double* x, + int incx, + const double* beta, + double* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + int64_t k, + const double* alpha, + const double* A, + int64_t lda, + const double* x, + int64_t incx, + const double* beta, + double* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + const cuComplex* beta, + cuComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* x, + int64_t incx, + const cuComplex* beta, + cuComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhbmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int64_t incy); + +/* SPMV/HPMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSspmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float* alpha, + const float* AP, + const float* x, + int incx, + const float* beta, + float* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSspmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const float* AP, + const float* x, + int64_t incx, + const float* beta, + float* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDspmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double* alpha, + const double* AP, + const double* x, + int incx, + const double* beta, + double* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDspmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const double* AP, + const double* x, + int64_t incx, + const double* beta, + double* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex* alpha, + const cuComplex* AP, + const cuComplex* x, + int incx, + const cuComplex* beta, + cuComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChpmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuComplex* alpha, + const cuComplex* AP, + const cuComplex* x, + int64_t incx, + const cuComplex* beta, + cuComplex* y, + int64_t incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* AP, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int incy); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpmv_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* AP, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int64_t incy); + +/* GER */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSger_v2(cublasHandle_t handle, + int m, + int n, + const float* alpha, + const float* x, + int incx, + const float* y, + int incy, + float* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSger_v2_64(cublasHandle_t handle, + int64_t m, + int64_t n, + const float* alpha, + const float* x, + int64_t incx, + const float* y, + int64_t incy, + float* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDger_v2(cublasHandle_t handle, + int m, + int n, + const double* alpha, + const double* x, + int incx, + const double* y, + int incy, + double* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDger_v2_64(cublasHandle_t handle, + int64_t m, + int64_t n, + const double* alpha, + const double* x, + int64_t incx, + const double* y, + int64_t incy, + double* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgeru_v2(cublasHandle_t handle, + int m, + int n, + const cuComplex* alpha, + const cuComplex* x, + int incx, + const cuComplex* y, + int incy, + cuComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgeru_v2_64(cublasHandle_t handle, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* x, + int64_t incx, + const cuComplex* y, + int64_t incy, + cuComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgerc_v2(cublasHandle_t handle, + int m, + int n, + const cuComplex* alpha, + const cuComplex* x, + int incx, + const cuComplex* y, + int incy, + cuComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgerc_v2_64(cublasHandle_t handle, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* x, + int64_t incx, + const cuComplex* y, + int64_t incy, + cuComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgeru_v2(cublasHandle_t handle, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgeru_v2_64(cublasHandle_t handle, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* y, + int64_t incy, + cuDoubleComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgerc_v2(cublasHandle_t handle, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgerc_v2_64(cublasHandle_t handle, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* y, + int64_t incy, + cuDoubleComplex* A, + int64_t lda); + +/* SYR/HER */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyr_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float* alpha, + const float* x, + int incx, + float* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const float* x, + int64_t incx, + float* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyr_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double* alpha, + const double* x, + int incx, + double* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const double* x, + int64_t incx, + double* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyr_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex* alpha, + const cuComplex* x, + int incx, + cuComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuComplex* alpha, + const cuComplex* x, + int64_t incx, + cuComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyr_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int64_t incx, + cuDoubleComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCher_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float* alpha, + const cuComplex* x, + int incx, + cuComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCher_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const cuComplex* x, + int64_t incx, + cuComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZher_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double* alpha, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZher_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const cuDoubleComplex* x, + int64_t incx, + cuDoubleComplex* A, + int64_t lda); + +/* SPR/HPR */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSspr_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, const float* alpha, const float* x, int incx, float* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSspr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const float* x, + int64_t incx, + float* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDspr_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, const double* alpha, const double* x, int incx, double* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDspr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const double* x, + int64_t incx, + double* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChpr_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float* alpha, + const cuComplex* x, + int incx, + cuComplex* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChpr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const cuComplex* x, + int64_t incx, + cuComplex* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpr_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double* alpha, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpr_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const cuDoubleComplex* x, + int64_t incx, + cuDoubleComplex* AP); + +/* SYR2/HER2 */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float* alpha, + const float* x, + int incx, + const float* y, + int incy, + float* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const float* x, + int64_t incx, + const float* y, + int64_t incy, + float* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double* alpha, + const double* x, + int incx, + const double* y, + int incy, + double* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const double* x, + int64_t incx, + const double* y, + int64_t incy, + double* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex* alpha, + const cuComplex* x, + int incx, + const cuComplex* y, + int incy, + cuComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuComplex* alpha, + const cuComplex* x, + int64_t incx, + const cuComplex* y, + int64_t incy, + cuComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* y, + int64_t incy, + cuDoubleComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCher2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex* alpha, + const cuComplex* x, + int incx, + const cuComplex* y, + int incy, + cuComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCher2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuComplex* alpha, + const cuComplex* x, + int64_t incx, + const cuComplex* y, + int64_t incy, + cuComplex* A, + int64_t lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZher2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* A, + int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZher2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* y, + int64_t incy, + cuDoubleComplex* A, + int64_t lda); + +/* SPR2/HPR2 */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSspr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const float* alpha, + const float* x, + int incx, + const float* y, + int incy, + float* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSspr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const float* alpha, + const float* x, + int64_t incx, + const float* y, + int64_t incy, + float* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDspr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const double* alpha, + const double* x, + int incx, + const double* y, + int incy, + double* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDspr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const double* alpha, + const double* x, + int64_t incx, + const double* y, + int64_t incy, + double* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChpr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex* alpha, + const cuComplex* x, + int incx, + const cuComplex* y, + int incy, + cuComplex* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChpr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuComplex* alpha, + const cuComplex* x, + int64_t incx, + const cuComplex* y, + int64_t incy, + cuComplex* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpr2_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int incx, + const cuDoubleComplex* y, + int incy, + cuDoubleComplex* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhpr2_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* x, + int64_t incx, + const cuDoubleComplex* y, + int64_t incy, + cuDoubleComplex* AP); + +/* BATCH GEMV */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const float* const Aarray[], + int lda, + const float* const xarray[], + int incx, + const float* beta, + float* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const float* const Aarray[], + int64_t lda, + const float* const xarray[], + int64_t incx, + const float* beta, + float* const yarray[], + int64_t incy, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const double* alpha, + const double* const Aarray[], + int lda, + const double* const xarray[], + int incx, + const double* beta, + double* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const double* alpha, + const double* const Aarray[], + int64_t lda, + const double* const xarray[], + int64_t incx, + const double* beta, + double* const yarray[], + int64_t incy, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuComplex* alpha, + const cuComplex* const Aarray[], + int lda, + const cuComplex* const xarray[], + int incx, + const cuComplex* beta, + cuComplex* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* const Aarray[], + int64_t lda, + const cuComplex* const xarray[], + int64_t incx, + const cuComplex* beta, + cuComplex* const yarray[], + int64_t incy, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* const Aarray[], + int lda, + const cuDoubleComplex* const xarray[], + int incx, + const cuDoubleComplex* beta, + cuDoubleComplex* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* const Aarray[], + int64_t lda, + const cuDoubleComplex* const xarray[], + int64_t incx, + const cuDoubleComplex* beta, + cuDoubleComplex* const yarray[], + int64_t incy, + int64_t batchCount); + +#if defined(__cplusplus) + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __half* const Aarray[], + int lda, + const __half* const xarray[], + int incx, + const float* beta, + __half* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __half* const Aarray[], + int64_t lda, + const __half* const xarray[], + int64_t incx, + const float* beta, + __half* const yarray[], + int64_t incy, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __half* const Aarray[], + int lda, + const __half* const xarray[], + int incx, + const float* beta, + float* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __half* const Aarray[], + int64_t lda, + const __half* const xarray[], + int64_t incx, + const float* beta, + float* const yarray[], + int64_t incy, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __nv_bfloat16* const Aarray[], + int lda, + const __nv_bfloat16* const xarray[], + int incx, + const float* beta, + __nv_bfloat16* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __nv_bfloat16* const Aarray[], + int64_t lda, + const __nv_bfloat16* const xarray[], + int64_t incx, + const float* beta, + __nv_bfloat16* const yarray[], + int64_t incy, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __nv_bfloat16* const Aarray[], + int lda, + const __nv_bfloat16* const xarray[], + int incx, + const float* beta, + float* const yarray[], + int incy, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __nv_bfloat16* const Aarray[], + int64_t lda, + const __nv_bfloat16* const xarray[], + int64_t incx, + const float* beta, + float* const yarray[], + int64_t incy, + int64_t batchCount); + +#endif + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const float* A, + int lda, + long long int strideA, + const float* x, + int incx, + long long int stridex, + const float* beta, + float* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const float* A, + int64_t lda, + long long int strideA, + const float* x, + int64_t incx, + long long int stridex, + const float* beta, + float* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const double* alpha, + const double* A, + int lda, + long long int strideA, + const double* x, + int incx, + long long int stridex, + const double* beta, + double* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const double* alpha, + const double* A, + int64_t lda, + long long int strideA, + const double* x, + int64_t incx, + long long int stridex, + const double* beta, + double* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + long long int strideA, + const cuComplex* x, + int incx, + long long int stridex, + const cuComplex* beta, + cuComplex* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + long long int strideA, + const cuComplex* x, + int64_t incx, + long long int stridex, + const cuComplex* beta, + cuComplex* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + long long int strideA, + const cuDoubleComplex* x, + int incx, + long long int stridex, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + long long int strideA, + const cuDoubleComplex* x, + int64_t incx, + long long int stridex, + const cuDoubleComplex* beta, + cuDoubleComplex* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +#if defined(__cplusplus) + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __half* A, + int lda, + long long int strideA, + const __half* x, + int incx, + long long int stridex, + const float* beta, + __half* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __half* A, + int64_t lda, + long long int strideA, + const __half* x, + int64_t incx, + long long int stridex, + const float* beta, + __half* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __half* A, + int lda, + long long int strideA, + const __half* x, + int incx, + long long int stridex, + const float* beta, + float* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __half* A, + int64_t lda, + long long int strideA, + const __half* x, + int64_t incx, + long long int stridex, + const float* beta, + float* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __nv_bfloat16* A, + int lda, + long long int strideA, + const __nv_bfloat16* x, + int incx, + long long int stridex, + const float* beta, + __nv_bfloat16* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __nv_bfloat16* A, + int64_t lda, + long long int strideA, + const __nv_bfloat16* x, + int64_t incx, + long long int stridex, + const float* beta, + __nv_bfloat16* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvStridedBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + const float* alpha, + const __nv_bfloat16* A, + int lda, + long long int strideA, + const __nv_bfloat16* x, + int incx, + long long int stridex, + const float* beta, + float* y, + int incy, + long long int stridey, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvStridedBatched_64(cublasHandle_t handle, + cublasOperation_t trans, + int64_t m, + int64_t n, + const float* alpha, + const __nv_bfloat16* A, + int64_t lda, + long long int strideA, + const __nv_bfloat16* x, + int64_t incx, + long long int stridex, + const float* beta, + float* y, + int64_t incy, + long long int stridey, + int64_t batchCount); + +#endif + +/* ---------------- CUBLAS BLAS3 Functions ---------------- */ + +/* GEMM */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* A, + int lda, + const float* B, + int ldb, + const float* beta, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const float* alpha, + const float* A, + int64_t lda, + const float* B, + int64_t ldb, + const float* beta, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemm_v2(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const double* alpha, + const double* A, + int lda, + const double* B, + int ldb, + const double* beta, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemm_v2_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const double* alpha, + const double* A, + int64_t lda, + const double* B, + int64_t ldb, + const double* beta, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm_v2(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const cuComplex* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm_v2_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const cuComplex* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3m(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const cuComplex* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3m_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const cuComplex* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3mEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int lda, + const void* B, + cudaDataType Btype, + int ldb, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3mEx_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const void* B, + cudaDataType Btype, + int64_t ldb, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemm_v2(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemm_v2_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemm3m(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemm3m_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc); + +#if defined(__cplusplus) + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemm(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const __half* alpha, + const __half* A, + int lda, + const __half* B, + int ldb, + const __half* beta, + __half* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemm_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const __half* alpha, + const __half* A, + int64_t lda, + const __half* B, + int64_t ldb, + const __half* beta, + __half* C, + int64_t ldc); + +#endif + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const void* A, + cudaDataType Atype, + int lda, + const void* B, + cudaDataType Btype, + int ldb, + const float* beta, + void* C, + cudaDataType Ctype, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmEx_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const float* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const void* B, + cudaDataType Btype, + int64_t ldb, + const float* beta, + void* C, + cudaDataType Ctype, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, + const void* A, + cudaDataType Atype, + int lda, + const void* B, + cudaDataType Btype, + int ldb, + const void* beta, + void* C, + cudaDataType Ctype, + int ldc, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmEx_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const void* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const void* B, + cudaDataType Btype, + int64_t ldb, + const void* beta, + void* C, + cudaDataType Ctype, + int64_t ldc, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int lda, + const void* B, + cudaDataType Btype, + int ldb, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmEx_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const void* B, + cudaDataType Btype, + int64_t ldb, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int64_t ldc); + +/* SYRK */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyrk_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float* alpha, + const float* A, + int lda, + const float* beta, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyrk_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const float* alpha, + const float* A, + int64_t lda, + const float* beta, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyrk_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double* alpha, + const double* A, + int lda, + const double* beta, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyrk_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const double* alpha, + const double* A, + int64_t lda, + const double* beta, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrk_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrk_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyrk_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyrk_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrkEx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int lda, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrkEx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int lda, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuComplex* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const cuComplex* beta, + void* C, + cudaDataType Ctype, + int64_t ldc); + +/* HERK */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherk_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float* alpha, + const cuComplex* A, + int lda, + const float* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherk_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const float* alpha, + const cuComplex* A, + int64_t lda, + const float* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZherk_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double* alpha, + const cuDoubleComplex* A, + int lda, + const double* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZherk_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const double* alpha, + const cuDoubleComplex* A, + int64_t lda, + const double* beta, + cuDoubleComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherkEx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float* alpha, + const void* A, + cudaDataType Atype, + int lda, + const float* beta, + void* C, + cudaDataType Ctype, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherkEx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const float* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const float* beta, + void* C, + cudaDataType Ctype, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherk3mEx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float* alpha, + const void* A, + cudaDataType Atype, + int lda, + const float* beta, + void* C, + cudaDataType Ctype, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherk3mEx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const float* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + const float* beta, + void* C, + cudaDataType Ctype, + int64_t ldc); + +/* SYR2K / HER2K */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float* alpha, + const float* A, + int lda, + const float* B, + int ldb, + const float* beta, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const float* alpha, + const float* A, + int64_t lda, + const float* B, + int64_t ldb, + const float* beta, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double* alpha, + const double* A, + int lda, + const double* B, + int ldb, + const double* beta, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const double* alpha, + const double* A, + int64_t lda, + const double* B, + int64_t ldb, + const double* beta, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const cuComplex* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const cuComplex* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCher2k_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const float* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCher2k_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const float* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZher2k_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const double* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZher2k_v2_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const double* beta, + cuDoubleComplex* C, + int64_t ldc); + +/* SYRKX / HERKX */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyrkx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const float* alpha, + const float* A, + int lda, + const float* B, + int ldb, + const float* beta, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsyrkx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const float* alpha, + const float* A, + int64_t lda, + const float* B, + int64_t ldb, + const float* beta, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyrkx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const double* alpha, + const double* A, + int lda, + const double* B, + int ldb, + const double* beta, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsyrkx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const double* alpha, + const double* A, + int64_t lda, + const double* B, + int64_t ldb, + const double* beta, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrkx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const cuComplex* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsyrkx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const cuComplex* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyrkx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsyrkx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherkx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const float* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCherkx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const float* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZherkx(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const double* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZherkx_64(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const double* beta, + cuDoubleComplex* C, + int64_t ldc); + +/* SYMM */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsymm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const float* alpha, + const float* A, + int lda, + const float* B, + int ldb, + const float* beta, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSsymm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int64_t m, + int64_t n, + const float* alpha, + const float* A, + int64_t lda, + const float* B, + int64_t ldb, + const float* beta, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsymm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const double* alpha, + const double* A, + int lda, + const double* B, + int ldb, + const double* beta, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDsymm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int64_t m, + int64_t n, + const double* alpha, + const double* A, + int64_t lda, + const double* B, + int64_t ldb, + const double* beta, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsymm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const cuComplex* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCsymm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const cuComplex* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsymm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZsymm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc); + +/* HEMM */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChemm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + const cuComplex* beta, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasChemm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + const cuComplex* beta, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhemm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZhemm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc); + +/* TRSM */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const float* alpha, + const float* A, + int lda, + float* B, + int ldb); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const float* alpha, + const float* A, + int64_t lda, + float* B, + int64_t ldb); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const double* alpha, + const double* A, + int lda, + double* B, + int ldb); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const double* alpha, + const double* A, + int64_t lda, + double* B, + int64_t ldb); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + cuComplex* B, + int ldb); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + cuComplex* B, + int64_t ldb); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + cuDoubleComplex* B, + int ldb); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + cuDoubleComplex* B, + int64_t ldb); + +/* TRMM */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrmm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const float* alpha, + const float* A, + int lda, + const float* B, + int ldb, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrmm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const float* alpha, + const float* A, + int64_t lda, + const float* B, + int64_t ldb, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrmm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const double* alpha, + const double* A, + int lda, + const double* B, + int ldb, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrmm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const double* alpha, + const double* A, + int64_t lda, + const double* B, + int64_t ldb, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrmm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* B, + int ldb, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrmm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* B, + int64_t ldb, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrmm_v2(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* B, + int ldb, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrmm_v2_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* B, + int64_t ldb, + cuDoubleComplex* C, + int64_t ldc); + +/* BATCH GEMM */ + +#if defined(__cplusplus) + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const __half* alpha, + const __half* const Aarray[], + int lda, + const __half* const Barray[], + int ldb, + const __half* beta, + __half* const Carray[], + int ldc, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const __half* alpha, + const __half* const Aarray[], + int64_t lda, + const __half* const Barray[], + int64_t ldb, + const __half* beta, + __half* const Carray[], + int64_t ldc, + int64_t batchCount); + +#endif + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* const Aarray[], + int lda, + const float* const Barray[], + int ldb, + const float* beta, + float* const Carray[], + int ldc, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const float* alpha, + const float* const Aarray[], + int64_t lda, + const float* const Barray[], + int64_t ldb, + const float* beta, + float* const Carray[], + int64_t ldc, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemmBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const double* alpha, + const double* const Aarray[], + int lda, + const double* const Barray[], + int ldb, + const double* beta, + double* const Carray[], + int ldc, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemmBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const double* alpha, + const double* const Aarray[], + int64_t lda, + const double* const Barray[], + int64_t ldb, + const double* beta, + double* const Carray[], + int64_t ldc, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const cuComplex* const Aarray[], + int lda, + const cuComplex* const Barray[], + int ldb, + const cuComplex* beta, + cuComplex* const Carray[], + int ldc, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* const Aarray[], + int64_t lda, + const cuComplex* const Barray[], + int64_t ldb, + const cuComplex* beta, + cuComplex* const Carray[], + int64_t ldc, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const cuComplex* const Aarray[], + int lda, + const cuComplex* const Barray[], + int ldb, + const cuComplex* beta, + cuComplex* const Carray[], + int ldc, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* const Aarray[], + int64_t lda, + const cuComplex* const Barray[], + int64_t ldb, + const cuComplex* beta, + cuComplex* const Carray[], + int64_t ldc, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemmBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* const Aarray[], + int lda, + const cuDoubleComplex* const Barray[], + int ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* const Carray[], + int ldc, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemmBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* const Aarray[], + int64_t lda, + const cuDoubleComplex* const Barray[], + int64_t ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* const Carray[], + int64_t ldc, + int64_t batchCount); + +#if defined(__cplusplus) + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmStridedBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const __half* alpha, + const __half* A, + int lda, + long long int strideA, + const __half* B, + int ldb, + long long int strideB, + const __half* beta, + __half* C, + int ldc, + long long int strideC, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmStridedBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const __half* alpha, + const __half* A, + int64_t lda, + long long int strideA, + const __half* B, + int64_t ldb, + long long int strideB, + const __half* beta, + __half* C, + int64_t ldc, + long long int strideC, + int64_t batchCount); + +#endif + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* A, + int lda, + long long int strideA, + const float* B, + int ldb, + long long int strideB, + const float* beta, + float* C, + int ldc, + long long int strideC, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const float* alpha, + const float* A, + int64_t lda, + long long int strideA, + const float* B, + int64_t ldb, + long long int strideB, + const float* beta, + float* C, + int64_t ldc, + long long int strideC, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const double* alpha, + const double* A, + int lda, + long long int strideA, + const double* B, + int ldb, + long long int strideB, + const double* beta, + double* C, + int ldc, + long long int strideC, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const double* alpha, + const double* A, + int64_t lda, + long long int strideA, + const double* B, + int64_t ldb, + long long int strideB, + const double* beta, + double* C, + int64_t ldc, + long long int strideC, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + long long int strideA, + const cuComplex* B, + int ldb, + long long int strideB, + const cuComplex* beta, + cuComplex* C, + int ldc, + long long int strideC, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + long long int strideA, + const cuComplex* B, + int64_t ldb, + long long int strideB, + const cuComplex* beta, + cuComplex* C, + int64_t ldc, + long long int strideC, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuComplex* alpha, + const cuComplex* A, + int lda, + long long int strideA, + const cuComplex* B, + int ldb, + long long int strideB, + const cuComplex* beta, + cuComplex* C, + int ldc, + long long int strideC, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + long long int strideA, + const cuComplex* B, + int64_t ldb, + long long int strideB, + const cuComplex* beta, + cuComplex* C, + int64_t ldc, + long long int strideC, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + long long int strideA, + const cuDoubleComplex* B, + int ldb, + long long int strideB, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int ldc, + long long int strideC, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + long long int strideA, + const cuDoubleComplex* B, + int64_t ldb, + long long int strideB, + const cuDoubleComplex* beta, + cuDoubleComplex* C, + int64_t ldc, + long long int strideC, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmBatchedEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, + const void* const Aarray[], + cudaDataType Atype, + int lda, + const void* const Barray[], + cudaDataType Btype, + int ldb, + const void* beta, + void* const Carray[], + cudaDataType Ctype, + int ldc, + int batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmBatchedEx_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const void* alpha, + const void* const Aarray[], + cudaDataType Atype, + int64_t lda, + const void* const Barray[], + cudaDataType Btype, + int64_t ldb, + const void* beta, + void* const Carray[], + cudaDataType Ctype, + int64_t ldc, + int64_t batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmStridedBatchedEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, + const void* A, + cudaDataType Atype, + int lda, + long long int strideA, + const void* B, + cudaDataType Btype, + int ldb, + long long int strideB, + const void* beta, + void* C, + cudaDataType Ctype, + int ldc, + long long int strideC, + int batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmStridedBatchedEx_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + int64_t k, + const void* alpha, + const void* A, + cudaDataType Atype, + int64_t lda, + long long int strideA, + const void* B, + cudaDataType Btype, + int64_t ldb, + long long int strideB, + const void* beta, + void* C, + cudaDataType Ctype, + int64_t ldc, + long long int strideC, + int64_t batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmGroupedBatched(cublasHandle_t handle, + const cublasOperation_t transa_array[], + const cublasOperation_t transb_array[], + const int m_array[], + const int n_array[], + const int k_array[], + const float alpha_array[], + const float* const Aarray[], + const int lda_array[], + const float* const Barray[], + const int ldb_array[], + const float beta_array[], + float* const Carray[], + const int ldc_array[], + int group_count, + const int group_size[]); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmGroupedBatched_64(cublasHandle_t handle, + const cublasOperation_t transa_array[], + const cublasOperation_t transb_array[], + const int64_t m_array[], + const int64_t n_array[], + const int64_t k_array[], + const float alpha_array[], + const float* const Aarray[], + const int64_t lda_array[], + const float* const Barray[], + const int64_t ldb_array[], + const float beta_array[], + float* const Carray[], + const int64_t ldc_array[], + int64_t group_count, + const int64_t group_size[]); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemmGroupedBatched(cublasHandle_t handle, + const cublasOperation_t transa_array[], + const cublasOperation_t transb_array[], + const int m_array[], + const int n_array[], + const int k_array[], + const double alpha_array[], + const double* const Aarray[], + const int lda_array[], + const double* const Barray[], + const int ldb_array[], + const double beta_array[], + double* const Carray[], + const int ldc_array[], + int group_count, + const int group_size[]); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemmGroupedBatched_64(cublasHandle_t handle, + const cublasOperation_t transa_array[], + const cublasOperation_t transb_array[], + const int64_t m_array[], + const int64_t n_array[], + const int64_t k_array[], + const double alpha_array[], + const double* const Aarray[], + const int64_t lda_array[], + const double* const Barray[], + const int64_t ldb_array[], + const double beta_array[], + double* const Carray[], + const int64_t ldc_array[], + int64_t group_count, + const int64_t group_size[]); + +/* ---------------- CUBLAS BLAS-like Extension ---------------- */ + +/* GEAM */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const float* alpha, + const float* A, + int lda, + const float* beta, + const float* B, + int ldb, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgeam_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + const float* alpha, + const float* A, + int64_t lda, + const float* beta, + const float* B, + int64_t ldb, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const double* alpha, + const double* A, + int lda, + const double* beta, + const double* B, + int ldb, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgeam_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + const double* alpha, + const double* A, + int64_t lda, + const double* beta, + const double* B, + int64_t ldb, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const cuComplex* alpha, + const cuComplex* A, + int lda, + const cuComplex* beta, + const cuComplex* B, + int ldb, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgeam_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* A, + int64_t lda, + const cuComplex* beta, + const cuComplex* B, + int64_t ldb, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgeam(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* beta, + const cuDoubleComplex* B, + int ldb, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgeam_64(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* beta, + const cuDoubleComplex* B, + int64_t ldb, + cuDoubleComplex* C, + int64_t ldc); + +/* TRSM - Batched Triangular Solver */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsmBatched(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const float* alpha, + const float* const A[], + int lda, + float* const B[], + int ldb, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsmBatched_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const float* alpha, + const float* const A[], + int64_t lda, + float* const B[], + int64_t ldb, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsmBatched(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const double* alpha, + const double* const A[], + int lda, + double* const B[], + int ldb, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsmBatched_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const double* alpha, + const double* const A[], + int64_t lda, + double* const B[], + int64_t ldb, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsmBatched(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuComplex* alpha, + const cuComplex* const A[], + int lda, + cuComplex* const B[], + int ldb, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsmBatched_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const cuComplex* alpha, + const cuComplex* const A[], + int64_t lda, + cuComplex* const B[], + int64_t ldb, + int64_t batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsmBatched(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int m, + int n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* const A[], + int lda, + cuDoubleComplex* const B[], + int ldb, + int batchCount); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsmBatched_64(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, + int64_t m, + int64_t n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* const A[], + int64_t lda, + cuDoubleComplex* const B[], + int64_t ldb, + int64_t batchCount); + +/* DGMM */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const float* A, + int lda, + const float* x, + int incx, + float* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSdgmm_64(cublasHandle_t handle, + cublasSideMode_t mode, + int64_t m, + int64_t n, + const float* A, + int64_t lda, + const float* x, + int64_t incx, + float* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const double* A, + int lda, + const double* x, + int incx, + double* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDdgmm_64(cublasHandle_t handle, + cublasSideMode_t mode, + int64_t m, + int64_t n, + const double* A, + int64_t lda, + const double* x, + int64_t incx, + double* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const cuComplex* A, + int lda, + const cuComplex* x, + int incx, + cuComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCdgmm_64(cublasHandle_t handle, + cublasSideMode_t mode, + int64_t m, + int64_t n, + const cuComplex* A, + int64_t lda, + const cuComplex* x, + int64_t incx, + cuComplex* C, + int64_t ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdgmm(cublasHandle_t handle, + cublasSideMode_t mode, + int m, + int n, + const cuDoubleComplex* A, + int lda, + const cuDoubleComplex* x, + int incx, + cuDoubleComplex* C, + int ldc); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZdgmm_64(cublasHandle_t handle, + cublasSideMode_t mode, + int64_t m, + int64_t n, + const cuDoubleComplex* A, + int64_t lda, + const cuDoubleComplex* x, + int64_t incx, + cuDoubleComplex* C, + int64_t ldc); + +/* Batched - MATINV*/ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSmatinvBatched(cublasHandle_t handle, + int n, + const float* const A[], + int lda, + float* const Ainv[], + int lda_inv, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDmatinvBatched(cublasHandle_t handle, + int n, + const double* const A[], + int lda, + double* const Ainv[], + int lda_inv, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCmatinvBatched(cublasHandle_t handle, + int n, + const cuComplex* const A[], + int lda, + cuComplex* const Ainv[], + int lda_inv, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZmatinvBatched(cublasHandle_t handle, + int n, + const cuDoubleComplex* const A[], + int lda, + cuDoubleComplex* const Ainv[], + int lda_inv, + int* info, + int batchSize); + +/* Batch QR Factorization */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgeqrfBatched(cublasHandle_t handle, + int m, + int n, + float* const Aarray[], + int lda, + float* const TauArray[], + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgeqrfBatched(cublasHandle_t handle, + int m, + int n, + double* const Aarray[], + int lda, + double* const TauArray[], + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgeqrfBatched(cublasHandle_t handle, + int m, + int n, + cuComplex* const Aarray[], + int lda, + cuComplex* const TauArray[], + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgeqrfBatched(cublasHandle_t handle, + int m, + int n, + cuDoubleComplex* const Aarray[], + int lda, + cuDoubleComplex* const TauArray[], + int* info, + int batchSize); + +/* Least Square Min only m >= n and Non-transpose supported */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgelsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int nrhs, + float* const Aarray[], + int lda, + float* const Carray[], + int ldc, + int* info, + int* devInfoArray, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgelsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int nrhs, + double* const Aarray[], + int lda, + double* const Carray[], + int ldc, + int* info, + int* devInfoArray, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgelsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int nrhs, + cuComplex* const Aarray[], + int lda, + cuComplex* const Carray[], + int ldc, + int* info, + int* devInfoArray, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgelsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int m, + int n, + int nrhs, + cuDoubleComplex* const Aarray[], + int lda, + cuDoubleComplex* const Carray[], + int ldc, + int* info, + int* devInfoArray, + int batchSize); + +/* TPTTR : Triangular Pack format to Triangular format */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasStpttr(cublasHandle_t handle, cublasFillMode_t uplo, int n, const float* AP, float* A, int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDtpttr(cublasHandle_t handle, cublasFillMode_t uplo, int n, const double* AP, double* A, int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCtpttr(cublasHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex* AP, cuComplex* A, int lda); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtpttr( + cublasHandle_t handle, cublasFillMode_t uplo, int n, const cuDoubleComplex* AP, cuDoubleComplex* A, int lda); + +/* TRTTP : Triangular format to Triangular Pack format */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasStrttp(cublasHandle_t handle, cublasFillMode_t uplo, int n, const float* A, int lda, float* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDtrttp(cublasHandle_t handle, cublasFillMode_t uplo, int n, const double* A, int lda, double* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCtrttp(cublasHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex* A, int lda, cuComplex* AP); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrttp( + cublasHandle_t handle, cublasFillMode_t uplo, int n, const cuDoubleComplex* A, int lda, cuDoubleComplex* AP); + +/* Batched LU - GETRF*/ + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasSgetrfBatched(cublasHandle_t handle, int n, float* const A[], int lda, int* P, int* info, int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasDgetrfBatched(cublasHandle_t handle, int n, double* const A[], int lda, int* P, int* info, int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI +cublasCgetrfBatched(cublasHandle_t handle, int n, cuComplex* const A[], int lda, int* P, int* info, int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgetrfBatched( + cublasHandle_t handle, int n, cuDoubleComplex* const A[], int lda, int* P, int* info, int batchSize); + +/* Batched inversion based on LU factorization from getrf */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgetriBatched(cublasHandle_t handle, + int n, + const float* const A[], + int lda, + const int* P, + float* const C[], + int ldc, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgetriBatched(cublasHandle_t handle, + int n, + const double* const A[], + int lda, + const int* P, + double* const C[], + int ldc, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgetriBatched(cublasHandle_t handle, + int n, + const cuComplex* const A[], + int lda, + const int* P, + cuComplex* const C[], + int ldc, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgetriBatched(cublasHandle_t handle, + int n, + const cuDoubleComplex* const A[], + int lda, + const int* P, + cuDoubleComplex* const C[], + int ldc, + int* info, + int batchSize); + +/* Batched solver based on LU factorization from getrf */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgetrsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const float* const Aarray[], + int lda, + const int* devIpiv, + float* const Barray[], + int ldb, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgetrsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const double* const Aarray[], + int lda, + const int* devIpiv, + double* const Barray[], + int ldb, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgetrsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const cuComplex* const Aarray[], + int lda, + const int* devIpiv, + cuComplex* const Barray[], + int ldb, + int* info, + int batchSize); + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgetrsBatched(cublasHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const cuDoubleComplex* const Aarray[], + int lda, + const int* devIpiv, + cuDoubleComplex* const Barray[], + int ldb, + int* info, + int batchSize); + +/* Deprecated */ + +CUBLASAPI cublasStatus_t CUBLASWINAPI cublasUint8gemmBias(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + cublasOperation_t transc, + int m, + int n, + int k, + const unsigned char* A, + int A_bias, + int lda, + const unsigned char* B, + int B_bias, + int ldb, + unsigned char* C, + int C_bias, + int ldc, + int C_mult, + int C_shift); + +/* }}} cuBLAS Exported API */ + +#if defined(__cplusplus) +} + +static inline cublasStatus_t cublasMigrateComputeType(cublasHandle_t handle, + cudaDataType_t dataType, + cublasComputeType_t* computeType) { + cublasMath_t mathMode = CUBLAS_DEFAULT_MATH; + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + status = cublasGetMathMode(handle, &mathMode); + if (status != CUBLAS_STATUS_SUCCESS) { + return status; + } + + bool isPedantic = ((mathMode & 0xf) == CUBLAS_PEDANTIC_MATH); + + switch (dataType) { + case CUDA_R_32F: + case CUDA_C_32F: + *computeType = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; + return CUBLAS_STATUS_SUCCESS; + case CUDA_R_64F: + case CUDA_C_64F: + *computeType = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; + return CUBLAS_STATUS_SUCCESS; + case CUDA_R_16F: + *computeType = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; + return CUBLAS_STATUS_SUCCESS; + case CUDA_R_32I: + *computeType = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; + return CUBLAS_STATUS_SUCCESS; + default: + return CUBLAS_STATUS_NOT_SUPPORTED; + } +} +/* wrappers to accept old code with cudaDataType computeType when referenced from c++ code */ +static inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, /* host or device pointer */ + const void* A, + cudaDataType Atype, + int lda, + const void* B, + cudaDataType Btype, + int ldb, + const void* beta, /* host or device pointer */ + void* C, + cudaDataType Ctype, + int ldc, + cudaDataType computeType, + cublasGemmAlgo_t algo) { + cublasComputeType_t migratedComputeType = CUBLAS_COMPUTE_32F; + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + status = cublasMigrateComputeType(handle, computeType, &migratedComputeType); + if (status != CUBLAS_STATUS_SUCCESS) { + return status; + } + + return cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + migratedComputeType, + algo); +} + +static inline cublasStatus_t cublasGemmBatchedEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, /* host or device pointer */ + const void* const Aarray[], + cudaDataType Atype, + int lda, + const void* const Barray[], + cudaDataType Btype, + int ldb, + const void* beta, /* host or device pointer */ + void* const Carray[], + cudaDataType Ctype, + int ldc, + int batchCount, + cudaDataType computeType, + cublasGemmAlgo_t algo) { + cublasComputeType_t migratedComputeType; + cublasStatus_t status; + status = cublasMigrateComputeType(handle, computeType, &migratedComputeType); + if (status != CUBLAS_STATUS_SUCCESS) { + return status; + } + + return cublasGemmBatchedEx(handle, + transa, + transb, + m, + n, + k, + alpha, + Aarray, + Atype, + lda, + Barray, + Btype, + ldb, + beta, + Carray, + Ctype, + ldc, + batchCount, + migratedComputeType, + algo); +} + +static inline cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void* alpha, /* host or device pointer */ + const void* A, + cudaDataType Atype, + int lda, + long long int strideA, /* purposely signed */ + const void* B, + cudaDataType Btype, + int ldb, + long long int strideB, + const void* beta, /* host or device pointer */ + void* C, + cudaDataType Ctype, + int ldc, + long long int strideC, + int batchCount, + cudaDataType computeType, + cublasGemmAlgo_t algo) { + cublasComputeType_t migratedComputeType; + cublasStatus_t status; + status = cublasMigrateComputeType(handle, computeType, &migratedComputeType); + if (status != CUBLAS_STATUS_SUCCESS) { + return status; + } + + return cublasGemmStridedBatchedEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + strideA, + B, + Btype, + ldb, + strideB, + beta, + C, + Ctype, + ldc, + strideC, + batchCount, + migratedComputeType, + algo); +} +#endif /* __cplusplus */ + +#endif /* !defined(CUBLAS_API_H_) */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h new file mode 100644 index 0000000000000000000000000000000000000000..bd81a3b1d8e7e3d04d6c54f4c0640af7d8893eab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h @@ -0,0 +1,478 @@ +/* + * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +/* + * This is the public header file for the new CUBLAS library API, it mapped the generic + * Cublas name functions to the actual _v2 implementations. + */ + +#if !defined(CUBLAS_V2_H_) +#define CUBLAS_V2_H_ + +#if defined(CUBLAS_H_) +#error "It is an error to include both cublas.h and cublas_v2.h" +#endif + +#undef CUBLASAPI +#ifdef __CUDACC__ +#define CUBLASAPI __host__ __device__ +#else +#define CUBLASAPI +#endif + +#include "cublas_api.h" + +#define cublasCreate cublasCreate_v2 +#define cublasDestroy cublasDestroy_v2 +#define cublasGetVersion cublasGetVersion_v2 +#define cublasSetWorkspace cublasSetWorkspace_v2 +#define cublasSetStream cublasSetStream_v2 +#define cublasGetStream cublasGetStream_v2 +#define cublasGetPointerMode cublasGetPointerMode_v2 +#define cublasSetPointerMode cublasSetPointerMode_v2 + +/* 32-bit integer */ + +/* Blas1 Routines */ + +#define cublasSnrm2 cublasSnrm2_v2 +#define cublasDnrm2 cublasDnrm2_v2 +#define cublasScnrm2 cublasScnrm2_v2 +#define cublasDznrm2 cublasDznrm2_v2 + +#define cublasSdot cublasSdot_v2 +#define cublasDdot cublasDdot_v2 +#define cublasCdotu cublasCdotu_v2 +#define cublasCdotc cublasCdotc_v2 +#define cublasZdotu cublasZdotu_v2 +#define cublasZdotc cublasZdotc_v2 + +#define cublasSscal cublasSscal_v2 +#define cublasDscal cublasDscal_v2 +#define cublasCscal cublasCscal_v2 +#define cublasCsscal cublasCsscal_v2 +#define cublasZscal cublasZscal_v2 +#define cublasZdscal cublasZdscal_v2 + +#define cublasSaxpy cublasSaxpy_v2 +#define cublasDaxpy cublasDaxpy_v2 +#define cublasCaxpy cublasCaxpy_v2 +#define cublasZaxpy cublasZaxpy_v2 + +#define cublasScopy cublasScopy_v2 +#define cublasDcopy cublasDcopy_v2 +#define cublasCcopy cublasCcopy_v2 +#define cublasZcopy cublasZcopy_v2 + +#define cublasSswap cublasSswap_v2 +#define cublasDswap cublasDswap_v2 +#define cublasCswap cublasCswap_v2 +#define cublasZswap cublasZswap_v2 + +#define cublasIsamax cublasIsamax_v2 +#define cublasIdamax cublasIdamax_v2 +#define cublasIcamax cublasIcamax_v2 +#define cublasIzamax cublasIzamax_v2 + +#define cublasIsamin cublasIsamin_v2 +#define cublasIdamin cublasIdamin_v2 +#define cublasIcamin cublasIcamin_v2 +#define cublasIzamin cublasIzamin_v2 + +#define cublasSasum cublasSasum_v2 +#define cublasDasum cublasDasum_v2 +#define cublasScasum cublasScasum_v2 +#define cublasDzasum cublasDzasum_v2 + +#define cublasSrot cublasSrot_v2 +#define cublasDrot cublasDrot_v2 +#define cublasCrot cublasCrot_v2 +#define cublasCsrot cublasCsrot_v2 +#define cublasZrot cublasZrot_v2 +#define cublasZdrot cublasZdrot_v2 + +#define cublasSrotg cublasSrotg_v2 +#define cublasDrotg cublasDrotg_v2 +#define cublasCrotg cublasCrotg_v2 +#define cublasZrotg cublasZrotg_v2 + +#define cublasSrotm cublasSrotm_v2 +#define cublasDrotm cublasDrotm_v2 + +#define cublasSrotmg cublasSrotmg_v2 +#define cublasDrotmg cublasDrotmg_v2 + +/* Blas2 Routines */ + +#define cublasSgemv cublasSgemv_v2 +#define cublasDgemv cublasDgemv_v2 +#define cublasCgemv cublasCgemv_v2 +#define cublasZgemv cublasZgemv_v2 + +#define cublasSgbmv cublasSgbmv_v2 +#define cublasDgbmv cublasDgbmv_v2 +#define cublasCgbmv cublasCgbmv_v2 +#define cublasZgbmv cublasZgbmv_v2 + +#define cublasStrmv cublasStrmv_v2 +#define cublasDtrmv cublasDtrmv_v2 +#define cublasCtrmv cublasCtrmv_v2 +#define cublasZtrmv cublasZtrmv_v2 + +#define cublasStbmv cublasStbmv_v2 +#define cublasDtbmv cublasDtbmv_v2 +#define cublasCtbmv cublasCtbmv_v2 +#define cublasZtbmv cublasZtbmv_v2 + +#define cublasStpmv cublasStpmv_v2 +#define cublasDtpmv cublasDtpmv_v2 +#define cublasCtpmv cublasCtpmv_v2 +#define cublasZtpmv cublasZtpmv_v2 + +#define cublasStrsv cublasStrsv_v2 +#define cublasDtrsv cublasDtrsv_v2 +#define cublasCtrsv cublasCtrsv_v2 +#define cublasZtrsv cublasZtrsv_v2 + +#define cublasStpsv cublasStpsv_v2 +#define cublasDtpsv cublasDtpsv_v2 +#define cublasCtpsv cublasCtpsv_v2 +#define cublasZtpsv cublasZtpsv_v2 + +#define cublasStbsv cublasStbsv_v2 +#define cublasDtbsv cublasDtbsv_v2 +#define cublasCtbsv cublasCtbsv_v2 +#define cublasZtbsv cublasZtbsv_v2 + +#define cublasSsymv cublasSsymv_v2 +#define cublasDsymv cublasDsymv_v2 +#define cublasCsymv cublasCsymv_v2 +#define cublasZsymv cublasZsymv_v2 +#define cublasChemv cublasChemv_v2 +#define cublasZhemv cublasZhemv_v2 + +#define cublasSsbmv cublasSsbmv_v2 +#define cublasDsbmv cublasDsbmv_v2 +#define cublasChbmv cublasChbmv_v2 +#define cublasZhbmv cublasZhbmv_v2 + +#define cublasSspmv cublasSspmv_v2 +#define cublasDspmv cublasDspmv_v2 +#define cublasChpmv cublasChpmv_v2 +#define cublasZhpmv cublasZhpmv_v2 + +#define cublasSger cublasSger_v2 +#define cublasDger cublasDger_v2 +#define cublasCgeru cublasCgeru_v2 +#define cublasCgerc cublasCgerc_v2 +#define cublasZgeru cublasZgeru_v2 +#define cublasZgerc cublasZgerc_v2 + +#define cublasSsyr cublasSsyr_v2 +#define cublasDsyr cublasDsyr_v2 +#define cublasCsyr cublasCsyr_v2 +#define cublasZsyr cublasZsyr_v2 +#define cublasCher cublasCher_v2 +#define cublasZher cublasZher_v2 + +#define cublasSspr cublasSspr_v2 +#define cublasDspr cublasDspr_v2 +#define cublasChpr cublasChpr_v2 +#define cublasZhpr cublasZhpr_v2 + +#define cublasSsyr2 cublasSsyr2_v2 +#define cublasDsyr2 cublasDsyr2_v2 +#define cublasCsyr2 cublasCsyr2_v2 +#define cublasZsyr2 cublasZsyr2_v2 +#define cublasCher2 cublasCher2_v2 +#define cublasZher2 cublasZher2_v2 + +#define cublasSspr2 cublasSspr2_v2 +#define cublasDspr2 cublasDspr2_v2 +#define cublasChpr2 cublasChpr2_v2 +#define cublasZhpr2 cublasZhpr2_v2 + +/* Blas3 Routines */ + +#define cublasSgemm cublasSgemm_v2 +#define cublasDgemm cublasDgemm_v2 +#define cublasCgemm cublasCgemm_v2 +#define cublasZgemm cublasZgemm_v2 + +#define cublasSsyrk cublasSsyrk_v2 +#define cublasDsyrk cublasDsyrk_v2 +#define cublasCsyrk cublasCsyrk_v2 +#define cublasZsyrk cublasZsyrk_v2 +#define cublasCherk cublasCherk_v2 +#define cublasZherk cublasZherk_v2 + +#define cublasSsyr2k cublasSsyr2k_v2 +#define cublasDsyr2k cublasDsyr2k_v2 +#define cublasCsyr2k cublasCsyr2k_v2 +#define cublasZsyr2k cublasZsyr2k_v2 +#define cublasCher2k cublasCher2k_v2 +#define cublasZher2k cublasZher2k_v2 + +#define cublasSsymm cublasSsymm_v2 +#define cublasDsymm cublasDsymm_v2 +#define cublasCsymm cublasCsymm_v2 +#define cublasZsymm cublasZsymm_v2 +#define cublasChemm cublasChemm_v2 +#define cublasZhemm cublasZhemm_v2 + +#define cublasStrsm cublasStrsm_v2 +#define cublasDtrsm cublasDtrsm_v2 +#define cublasCtrsm cublasCtrsm_v2 +#define cublasZtrsm cublasZtrsm_v2 + +#define cublasStrmm cublasStrmm_v2 +#define cublasDtrmm cublasDtrmm_v2 +#define cublasCtrmm cublasCtrmm_v2 +#define cublasZtrmm cublasZtrmm_v2 + +/* 64-bit integer */ + +/* Blas1 Routines */ + +#define cublasSnrm2_64 cublasSnrm2_v2_64 +#define cublasDnrm2_64 cublasDnrm2_v2_64 +#define cublasScnrm2_64 cublasScnrm2_v2_64 +#define cublasDznrm2_64 cublasDznrm2_v2_64 + +#define cublasSdot_64 cublasSdot_v2_64 +#define cublasDdot_64 cublasDdot_v2_64 +#define cublasCdotu_64 cublasCdotu_v2_64 +#define cublasCdotc_64 cublasCdotc_v2_64 +#define cublasZdotu_64 cublasZdotu_v2_64 +#define cublasZdotc_64 cublasZdotc_v2_64 + +#define cublasSscal_64 cublasSscal_v2_64 +#define cublasDscal_64 cublasDscal_v2_64 +#define cublasCscal_64 cublasCscal_v2_64 +#define cublasCsscal_64 cublasCsscal_v2_64 +#define cublasZscal_64 cublasZscal_v2_64 +#define cublasZdscal_64 cublasZdscal_v2_64 + +#define cublasSaxpy_64 cublasSaxpy_v2_64 +#define cublasDaxpy_64 cublasDaxpy_v2_64 +#define cublasCaxpy_64 cublasCaxpy_v2_64 +#define cublasZaxpy_64 cublasZaxpy_v2_64 + +#define cublasScopy_64 cublasScopy_v2_64 +#define cublasDcopy_64 cublasDcopy_v2_64 +#define cublasCcopy_64 cublasCcopy_v2_64 +#define cublasZcopy_64 cublasZcopy_v2_64 + +#define cublasSswap_64 cublasSswap_v2_64 +#define cublasDswap_64 cublasDswap_v2_64 +#define cublasCswap_64 cublasCswap_v2_64 +#define cublasZswap_64 cublasZswap_v2_64 + +#define cublasIsamax_64 cublasIsamax_v2_64 +#define cublasIdamax_64 cublasIdamax_v2_64 +#define cublasIcamax_64 cublasIcamax_v2_64 +#define cublasIzamax_64 cublasIzamax_v2_64 + +#define cublasIsamin_64 cublasIsamin_v2_64 +#define cublasIdamin_64 cublasIdamin_v2_64 +#define cublasIcamin_64 cublasIcamin_v2_64 +#define cublasIzamin_64 cublasIzamin_v2_64 + +#define cublasSasum_64 cublasSasum_v2_64 +#define cublasDasum_64 cublasDasum_v2_64 +#define cublasScasum_64 cublasScasum_v2_64 +#define cublasDzasum_64 cublasDzasum_v2_64 + +#define cublasSrot_64 cublasSrot_v2_64 +#define cublasDrot_64 cublasDrot_v2_64 +#define cublasCrot_64 cublasCrot_v2_64 +#define cublasCsrot_64 cublasCsrot_v2_64 +#define cublasZrot_64 cublasZrot_v2_64 +#define cublasZdrot_64 cublasZdrot_v2_64 + +#define cublasSrotg_64 cublasSrotg_v2_64 +#define cublasDrotg_64 cublasDrotg_v2_64 +#define cublasCrotg_64 cublasCrotg_v2_64 +#define cublasZrotg_64 cublasZrotg_v2_64 + +#define cublasSrotm_64 cublasSrotm_v2_64 +#define cublasDrotm_64 cublasDrotm_v2_64 + +#define cublasSrotmg_64 cublasSrotmg_v2_64 +#define cublasDrotmg_64 cublasDrotmg_v2_64 + +/* Blas2 Routines */ + +#define cublasSgemv_64 cublasSgemv_v2_64 +#define cublasDgemv_64 cublasDgemv_v2_64 +#define cublasCgemv_64 cublasCgemv_v2_64 +#define cublasZgemv_64 cublasZgemv_v2_64 + +#define cublasSgbmv_64 cublasSgbmv_v2_64 +#define cublasDgbmv_64 cublasDgbmv_v2_64 +#define cublasCgbmv_64 cublasCgbmv_v2_64 +#define cublasZgbmv_64 cublasZgbmv_v2_64 + +#define cublasStrmv_64 cublasStrmv_v2_64 +#define cublasDtrmv_64 cublasDtrmv_v2_64 +#define cublasCtrmv_64 cublasCtrmv_v2_64 +#define cublasZtrmv_64 cublasZtrmv_v2_64 + +#define cublasStbmv_64 cublasStbmv_v2_64 +#define cublasDtbmv_64 cublasDtbmv_v2_64 +#define cublasCtbmv_64 cublasCtbmv_v2_64 +#define cublasZtbmv_64 cublasZtbmv_v2_64 + +#define cublasStpmv_64 cublasStpmv_v2_64 +#define cublasDtpmv_64 cublasDtpmv_v2_64 +#define cublasCtpmv_64 cublasCtpmv_v2_64 +#define cublasZtpmv_64 cublasZtpmv_v2_64 + +#define cublasStrsv_64 cublasStrsv_v2_64 +#define cublasDtrsv_64 cublasDtrsv_v2_64 +#define cublasCtrsv_64 cublasCtrsv_v2_64 +#define cublasZtrsv_64 cublasZtrsv_v2_64 + +#define cublasStpsv_64 cublasStpsv_v2_64 +#define cublasDtpsv_64 cublasDtpsv_v2_64 +#define cublasCtpsv_64 cublasCtpsv_v2_64 +#define cublasZtpsv_64 cublasZtpsv_v2_64 + +#define cublasStbsv_64 cublasStbsv_v2_64 +#define cublasDtbsv_64 cublasDtbsv_v2_64 +#define cublasCtbsv_64 cublasCtbsv_v2_64 +#define cublasZtbsv_64 cublasZtbsv_v2_64 + +#define cublasSsymv_64 cublasSsymv_v2_64 +#define cublasDsymv_64 cublasDsymv_v2_64 +#define cublasCsymv_64 cublasCsymv_v2_64 +#define cublasZsymv_64 cublasZsymv_v2_64 +#define cublasChemv_64 cublasChemv_v2_64 +#define cublasZhemv_64 cublasZhemv_v2_64 + +#define cublasSsbmv_64 cublasSsbmv_v2_64 +#define cublasDsbmv_64 cublasDsbmv_v2_64 +#define cublasChbmv_64 cublasChbmv_v2_64 +#define cublasZhbmv_64 cublasZhbmv_v2_64 + +#define cublasSspmv_64 cublasSspmv_v2_64 +#define cublasDspmv_64 cublasDspmv_v2_64 +#define cublasChpmv_64 cublasChpmv_v2_64 +#define cublasZhpmv_64 cublasZhpmv_v2_64 + +#define cublasSger_64 cublasSger_v2_64 +#define cublasDger_64 cublasDger_v2_64 +#define cublasCgeru_64 cublasCgeru_v2_64 +#define cublasCgerc_64 cublasCgerc_v2_64 +#define cublasZgeru_64 cublasZgeru_v2_64 +#define cublasZgerc_64 cublasZgerc_v2_64 + +#define cublasSsyr_64 cublasSsyr_v2_64 +#define cublasDsyr_64 cublasDsyr_v2_64 +#define cublasCsyr_64 cublasCsyr_v2_64 +#define cublasZsyr_64 cublasZsyr_v2_64 +#define cublasCher_64 cublasCher_v2_64 +#define cublasZher_64 cublasZher_v2_64 + +#define cublasSspr_64 cublasSspr_v2_64 +#define cublasDspr_64 cublasDspr_v2_64 +#define cublasChpr_64 cublasChpr_v2_64 +#define cublasZhpr_64 cublasZhpr_v2_64 + +#define cublasSsyr2_64 cublasSsyr2_v2_64 +#define cublasDsyr2_64 cublasDsyr2_v2_64 +#define cublasCsyr2_64 cublasCsyr2_v2_64 +#define cublasZsyr2_64 cublasZsyr2_v2_64 +#define cublasCher2_64 cublasCher2_v2_64 +#define cublasZher2_64 cublasZher2_v2_64 + +#define cublasSspr2_64 cublasSspr2_v2_64 +#define cublasDspr2_64 cublasDspr2_v2_64 +#define cublasChpr2_64 cublasChpr2_v2_64 +#define cublasZhpr2_64 cublasZhpr2_v2_64 + +/* Blas3 Routines */ + +#define cublasSgemm_64 cublasSgemm_v2_64 +#define cublasDgemm_64 cublasDgemm_v2_64 +#define cublasCgemm_64 cublasCgemm_v2_64 +#define cublasZgemm_64 cublasZgemm_v2_64 + +#define cublasSsyrk_64 cublasSsyrk_v2_64 +#define cublasDsyrk_64 cublasDsyrk_v2_64 +#define cublasCsyrk_64 cublasCsyrk_v2_64 +#define cublasZsyrk_64 cublasZsyrk_v2_64 +#define cublasCherk_64 cublasCherk_v2_64 +#define cublasZherk_64 cublasZherk_v2_64 + +#define cublasSsyr2k_64 cublasSsyr2k_v2_64 +#define cublasDsyr2k_64 cublasDsyr2k_v2_64 +#define cublasCsyr2k_64 cublasCsyr2k_v2_64 +#define cublasZsyr2k_64 cublasZsyr2k_v2_64 +#define cublasCher2k_64 cublasCher2k_v2_64 +#define cublasZher2k_64 cublasZher2k_v2_64 + +#define cublasSsymm_64 cublasSsymm_v2_64 +#define cublasDsymm_64 cublasDsymm_v2_64 +#define cublasCsymm_64 cublasCsymm_v2_64 +#define cublasZsymm_64 cublasZsymm_v2_64 +#define cublasChemm_64 cublasChemm_v2_64 +#define cublasZhemm_64 cublasZhemm_v2_64 + +#define cublasStrsm_64 cublasStrsm_v2_64 +#define cublasDtrsm_64 cublasDtrsm_v2_64 +#define cublasCtrsm_64 cublasCtrsm_v2_64 +#define cublasZtrsm_64 cublasZtrsm_v2_64 + +#define cublasStrmm_64 cublasStrmm_v2_64 +#define cublasDtrmm_64 cublasDtrmm_v2_64 +#define cublasCtrmm_64 cublasCtrmm_v2_64 +#define cublasZtrmm_64 cublasZtrmm_v2_64 + +#endif /* !defined(CUBLAS_V2_H_) */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/include/nvblas.h b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/nvblas.h new file mode 100644 index 0000000000000000000000000000000000000000..29ea9153faf7b3e62a6d53c0be1980ae79c49f51 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cublas/include/nvblas.h @@ -0,0 +1,824 @@ +/* + * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#if !defined(NVBLAS_H_) +#define NVBLAS_H_ + +#include "driver_types.h" +#include "cuComplex.h" /* import complex data type */ + +#if defined(__cplusplus) +extern "C" { +#endif + +/* GEMM */ +void sgemm_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc); + +void dgemm_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc); + +void cgemm_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zgemm_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +void sgemm(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc); + +void dgemm(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc); + +void cgemm(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zgemm(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +/* SYRK */ +void ssyrk_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* beta, + float* c, + const int* ldc); + +void dsyrk_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* beta, + double* c, + const int* ldc); + +void csyrk_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zsyrk_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +void ssyrk(const char* uplo, + const char* trans, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* beta, + float* c, + const int* ldc); + +void dsyrk(const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* beta, + double* c, + const int* ldc); + +void csyrk(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zsyrk(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +/* HERK */ +void cherk_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const float* alpha, + const cuComplex* a, + const int* lda, + const float* beta, + cuComplex* c, + const int* ldc); + +void zherk_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const cuDoubleComplex* a, + const int* lda, + const double* beta, + cuDoubleComplex* c, + const int* ldc); + +void cherk(const char* uplo, + const char* trans, + const int* n, + const int* k, + const float* alpha, + const cuComplex* a, + const int* lda, + const float* beta, + cuComplex* c, + const int* ldc); + +void zherk(const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const cuDoubleComplex* a, + const int* lda, + const double* beta, + cuDoubleComplex* c, + const int* ldc); + +/* TRSM */ +void strsm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + float* b, + const int* ldb); + +void dtrsm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + double* b, + const int* ldb); + +void ctrsm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + cuComplex* b, + const int* ldb); + +void ztrsm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + cuDoubleComplex* b, + const int* ldb); + +void strsm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + float* b, + const int* ldb); + +void dtrsm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + double* b, + const int* ldb); + +void ctrsm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + cuComplex* b, + const int* ldb); + +void ztrsm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + cuDoubleComplex* b, + const int* ldb); + +/* SYMM */ +void ssymm_(const char* side, + const char* uplo, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc); + +void dsymm_(const char* side, + const char* uplo, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc); + +void csymm_(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zsymm_(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +void ssymm(const char* side, + const char* uplo, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc); + +void dsymm(const char* side, + const char* uplo, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc); + +void csymm(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zsymm(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +/* HEMM */ +void chemm_(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zhemm_(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +/* HEMM with no underscore*/ +void chemm(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zhemm(const char* side, + const char* uplo, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +/* SYR2K */ +void ssyr2k_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc); + +void dsyr2k_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc); + +void csyr2k_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zsyr2k_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +/* SYR2K no_underscore*/ +void ssyr2k(const char* uplo, + const char* trans, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc); + +void dsyr2k(const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc); + +void csyr2k(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const cuComplex* beta, + cuComplex* c, + const int* ldc); + +void zsyr2k(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const cuDoubleComplex* beta, + cuDoubleComplex* c, + const int* ldc); + +/* HERK */ +void cher2k_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const float* beta, + cuComplex* c, + const int* ldc); + +void zher2k_(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const double* beta, + cuDoubleComplex* c, + const int* ldc); + +/* HER2K with no underscore */ +void cher2k(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + const cuComplex* b, + const int* ldb, + const float* beta, + cuComplex* c, + const int* ldc); + +void zher2k(const char* uplo, + const char* trans, + const int* n, + const int* k, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + const cuDoubleComplex* b, + const int* ldb, + const double* beta, + cuDoubleComplex* c, + const int* ldc); + +/* TRMM */ +void strmm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + float* b, + const int* ldb); + +void dtrmm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + double* b, + const int* ldb); + +void ctrmm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + cuComplex* b, + const int* ldb); + +void ztrmm_(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + cuDoubleComplex* b, + const int* ldb); + +void strmm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + float* b, + const int* ldb); + +void dtrmm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + double* b, + const int* ldb); + +void ctrmm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuComplex* alpha, + const cuComplex* a, + const int* lda, + cuComplex* b, + const int* ldb); + +void ztrmm(const char* side, + const char* uplo, + const char* transa, + const char* diag, + const int* m, + const int* n, + const cuDoubleComplex* alpha, + const cuDoubleComplex* a, + const int* lda, + cuDoubleComplex* b, + const int* ldb); + +#if defined(__cplusplus) +} +#endif /* __cplusplus */ + +#endif /* !defined(NVBLAS_H_) */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py b/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c45d0aa00d2061228f50bd66e0f0769a9e8a5ef6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 b/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 new file mode 100644 index 0000000000000000000000000000000000000000..7dc8c88191367bee01e752de108445a72c669208 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c2a58dc54154208392301d0fe3d53a120e4c1ebeab9e80ce91fe9948baeadc9 +size 757496 diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__init__.py b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62ddebfa6c7a0b463facf5b53350f2a6554ed333 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__init__.py b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd2af756dec150094c29e9a777dbe4fd43d31c85 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/nvrtc.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/nvrtc.h new file mode 100644 index 0000000000000000000000000000000000000000..cd9dbf51efe8d0803ce470bbb2b90ab21ad993d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/include/nvrtc.h @@ -0,0 +1,869 @@ +// +// NVIDIA_COPYRIGHT_BEGIN +// +// Copyright (c) 2014-2023, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. +// +// NVIDIA_COPYRIGHT_END +// + +#ifndef __NVRTC_H__ +#define __NVRTC_H__ + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#include + + +/*************************************************************************//** + * + * \defgroup error Error Handling + * + * NVRTC defines the following enumeration type and function for API call + * error handling. + * + ****************************************************************************/ + + +/** + * \ingroup error + * \brief The enumerated type nvrtcResult defines API call result codes. + * NVRTC API functions return nvrtcResult to indicate the call + * result. + */ +typedef enum { + NVRTC_SUCCESS = 0, + NVRTC_ERROR_OUT_OF_MEMORY = 1, + NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2, + NVRTC_ERROR_INVALID_INPUT = 3, + NVRTC_ERROR_INVALID_PROGRAM = 4, + NVRTC_ERROR_INVALID_OPTION = 5, + NVRTC_ERROR_COMPILATION = 6, + NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7, + NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8, + NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9, + NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10, + NVRTC_ERROR_INTERNAL_ERROR = 11, + NVRTC_ERROR_TIME_FILE_WRITE_FAILED = 12 +} nvrtcResult; + + +/** + * \ingroup error + * \brief nvrtcGetErrorString is a helper function that returns a string + * describing the given nvrtcResult code, e.g., NVRTC_SUCCESS to + * \c "NVRTC_SUCCESS". + * For unrecognized enumeration values, it returns + * \c "NVRTC_ERROR unknown". + * + * \param [in] result CUDA Runtime Compilation API result code. + * \return Message string for the given #nvrtcResult code. + */ +const char *nvrtcGetErrorString(nvrtcResult result); + + +/*************************************************************************//** + * + * \defgroup query General Information Query + * + * NVRTC defines the following function for general information query. + * + ****************************************************************************/ + + +/** + * \ingroup query + * \brief nvrtcVersion sets the output parameters \p major and \p minor + * with the CUDA Runtime Compilation version number. + * + * \param [out] major CUDA Runtime Compilation major version number. + * \param [out] minor CUDA Runtime Compilation minor version number. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * + */ +nvrtcResult nvrtcVersion(int *major, int *minor); + + +/** + * \ingroup query + * \brief nvrtcGetNumSupportedArchs sets the output parameter \p numArchs + * with the number of architectures supported by NVRTC. This can + * then be used to pass an array to ::nvrtcGetSupportedArchs to + * get the supported architectures. + * + * \param [out] numArchs number of supported architectures. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * + * see ::nvrtcGetSupportedArchs + */ +nvrtcResult nvrtcGetNumSupportedArchs(int* numArchs); + + +/** + * \ingroup query + * \brief nvrtcGetSupportedArchs populates the array passed via the output parameter + * \p supportedArchs with the architectures supported by NVRTC. The array is + * sorted in the ascending order. The size of the array to be passed can be + * determined using ::nvrtcGetNumSupportedArchs. + * + * \param [out] supportedArchs sorted array of supported architectures. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * + * see ::nvrtcGetNumSupportedArchs + */ +nvrtcResult nvrtcGetSupportedArchs(int* supportedArchs); + + +/*************************************************************************//** + * + * \defgroup compilation Compilation + * + * NVRTC defines the following type and functions for actual compilation. + * + ****************************************************************************/ + + +/** + * \ingroup compilation + * \brief nvrtcProgram is the unit of compilation, and an opaque handle for + * a program. + * + * To compile a CUDA program string, an instance of nvrtcProgram must be + * created first with ::nvrtcCreateProgram, then compiled with + * ::nvrtcCompileProgram. + */ +typedef struct _nvrtcProgram *nvrtcProgram; + + +/** + * \ingroup compilation + * \brief nvrtcCreateProgram creates an instance of nvrtcProgram with the + * given input parameters, and sets the output parameter \p prog with + * it. + * + * \param [out] prog CUDA Runtime Compilation program. + * \param [in] src CUDA program source. + * \param [in] name CUDA program name.\n + * \p name can be \c NULL; \c "default_program" is + * used when \p name is \c NULL or "". + * \param [in] numHeaders Number of headers used.\n + * \p numHeaders must be greater than or equal to 0. + * \param [in] headers Sources of the headers.\n + * \p headers can be \c NULL when \p numHeaders is + * 0. + * \param [in] includeNames Name of each header by which they can be + * included in the CUDA program source.\n + * \p includeNames can be \c NULL when \p numHeaders + * is 0. These headers must be included with the exact + * names specified here. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_OUT_OF_MEMORY \endlink + * - \link #nvrtcResult NVRTC_ERROR_PROGRAM_CREATION_FAILURE \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcDestroyProgram + */ +nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog, + const char *src, + const char *name, + int numHeaders, + const char * const *headers, + const char * const *includeNames); + + +/** + * \ingroup compilation + * \brief nvrtcDestroyProgram destroys the given program. + * + * \param [in] prog CUDA Runtime Compilation program. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcCreateProgram + */ +nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog); + + +/** + * \ingroup compilation + * \brief nvrtcCompileProgram compiles the given program. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [in] numOptions Number of compiler options passed. + * \param [in] options Compiler options in the form of C string array.\n + * \p options can be \c NULL when \p numOptions is 0. + * + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_OUT_OF_MEMORY \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_OPTION \endlink + * - \link #nvrtcResult NVRTC_ERROR_COMPILATION \endlink + * - \link #nvrtcResult NVRTC_ERROR_BUILTIN_OPERATION_FAILURE \endlink + * - \link #nvrtcResult NVRTC_ERROR_TIME_FILE_WRITE_FAILED \endlink + * + * It supports compile options listed in \ref options. + */ +nvrtcResult nvrtcCompileProgram(nvrtcProgram prog, + int numOptions, const char * const *options); + + +/** + * \ingroup compilation + * \brief nvrtcGetPTXSize sets the value of \p ptxSizeRet with the size of the PTX + * generated by the previous compilation of \p prog (including the + * trailing \c NULL). + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] ptxSizeRet Size of the generated PTX (including the trailing + * \c NULL). + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetPTX + */ +nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet); + + +/** + * \ingroup compilation + * \brief nvrtcGetPTX stores the PTX generated by the previous compilation + * of \p prog in the memory pointed by \p ptx. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] ptx Compiled result. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetPTXSize + */ +nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx); + + +/** + * \ingroup compilation + * \brief nvrtcGetCUBINSize sets the value of \p cubinSizeRet with the size of the cubin + * generated by the previous compilation of \p prog. The value of + * cubinSizeRet is set to 0 if the value specified to \c -arch is a + * virtual architecture instead of an actual architecture. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] cubinSizeRet Size of the generated cubin. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetCUBIN + */ +nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog, size_t *cubinSizeRet); + + +/** + * \ingroup compilation + * \brief nvrtcGetCUBIN stores the cubin generated by the previous compilation + * of \p prog in the memory pointed by \p cubin. No cubin is available + * if the value specified to \c -arch is a virtual architecture instead + * of an actual architecture. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] cubin Compiled and assembled result. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetCUBINSize + */ +nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin); + + +#if defined(_WIN32) +# define __DEPRECATED__(msg) __declspec(deprecated(msg)) +#elif (defined(__GNUC__) && (__GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 5 && !defined(__clang__)))) +# define __DEPRECATED__(msg) __attribute__((deprecated)) +#elif (defined(__GNUC__)) +# define __DEPRECATED__(msg) __attribute__((deprecated(msg))) +#else +# define __DEPRECATED__(msg) +#endif + +/** + * \ingroup compilation + * \brief + * DEPRECATION NOTICE: This function will be removed in a future release. Please use + * nvrtcGetLTOIRSize (and nvrtcGetLTOIR) instead. + */ +__DEPRECATED__("This function will be removed in a future release. Please use nvrtcGetLTOIRSize instead") +nvrtcResult nvrtcGetNVVMSize(nvrtcProgram prog, size_t *nvvmSizeRet); + +/** + * \ingroup compilation + * \brief + * DEPRECATION NOTICE: This function will be removed in a future release. Please use + * nvrtcGetLTOIR (and nvrtcGetLTOIRSize) instead. + */ +__DEPRECATED__("This function will be removed in a future release. Please use nvrtcGetLTOIR instead") +nvrtcResult nvrtcGetNVVM(nvrtcProgram prog, char *nvvm); + +#undef __DEPRECATED__ + +/** + * \ingroup compilation + * \brief nvrtcGetLTOIRSize sets the value of \p LTOIRSizeRet with the size of the LTO IR + * generated by the previous compilation of \p prog. The value of + * LTOIRSizeRet is set to 0 if the program was not compiled with + * \c -dlto. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] LTOIRSizeRet Size of the generated LTO IR. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetLTOIR + */ +nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *LTOIRSizeRet); + + +/** + * \ingroup compilation + * \brief nvrtcGetLTOIR stores the LTO IR generated by the previous compilation + * of \p prog in the memory pointed by \p LTOIR. No LTO IR is available + * if the program was compiled without \c -dlto. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] LTOIR Compiled result. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetLTOIRSize + */ +nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *LTOIR); + + +/** + * \ingroup compilation + * \brief nvrtcGetOptiXIRSize sets the value of \p optixirSizeRet with the size of the OptiX IR + * generated by the previous compilation of \p prog. The value of + * nvrtcGetOptiXIRSize is set to 0 if the program was compiled with + * options incompatible with OptiX IR generation. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] optixirSizeRet Size of the generated LTO IR. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetOptiXIR + */ +nvrtcResult nvrtcGetOptiXIRSize(nvrtcProgram prog, size_t *optixirSizeRet); + + +/** + * \ingroup compilation + * \brief nvrtcGetOptiXIR stores the OptiX IR generated by the previous compilation + * of \p prog in the memory pointed by \p optixir. No OptiX IR is available + * if the program was compiled with options incompatible with OptiX IR generation. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] Optix IR Compiled result. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetOptiXIRSize + */ +nvrtcResult nvrtcGetOptiXIR(nvrtcProgram prog, char *optixir); + +/** + * \ingroup compilation + * \brief nvrtcGetProgramLogSize sets \p logSizeRet with the size of the + * log generated by the previous compilation of \p prog (including the + * trailing \c NULL). + * + * Note that compilation log may be generated with warnings and informative + * messages, even when the compilation of \p prog succeeds. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] logSizeRet Size of the compilation log + * (including the trailing \c NULL). + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetProgramLog + */ +nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet); + + +/** + * \ingroup compilation + * \brief nvrtcGetProgramLog stores the log generated by the previous + * compilation of \p prog in the memory pointed by \p log. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [out] log Compilation log. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * + * \see ::nvrtcGetProgramLogSize + */ +nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log); + + +/** + * \ingroup compilation + * \brief nvrtcAddNameExpression notes the given name expression + * denoting the address of a __global__ function + * or __device__/__constant__ variable. + * + * The identical name expression string must be provided on a subsequent + * call to nvrtcGetLoweredName to extract the lowered name. + * \param [in] prog CUDA Runtime Compilation program. + * \param [in] name_expression constant expression denoting the address of + * a __global__ function or __device__/__constant__ variable. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION \endlink + * + * \see ::nvrtcGetLoweredName + */ +nvrtcResult nvrtcAddNameExpression(nvrtcProgram prog, + const char * const name_expression); + +/** + * \ingroup compilation + * \brief nvrtcGetLoweredName extracts the lowered (mangled) name + * for a __global__ function or __device__/__constant__ variable, + * and updates *lowered_name to point to it. The memory containing + * the name is released when the NVRTC program is destroyed by + * nvrtcDestroyProgram. + * The identical name expression must have been previously + * provided to nvrtcAddNameExpression. + * + * \param [in] prog CUDA Runtime Compilation program. + * \param [in] name_expression constant expression denoting the address of + * a __global__ function or __device__/__constant__ variable. + * \param [out] lowered_name initialized by the function to point to a + * C string containing the lowered (mangled) + * name corresponding to the provided name expression. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_PROGRAM \endlink + * - \link #nvrtcResult NVRTC_ERROR_INVALID_INPUT \endlink + * - \link #nvrtcResult NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID \endlink + * + * \see ::nvrtcAddNameExpression + */ +nvrtcResult nvrtcGetLoweredName(nvrtcProgram prog, + const char *const name_expression, + const char** lowered_name); + + +/** + * \defgroup options Supported Compile Options + * + * NVRTC supports the compile options below. + * Option names with two preceding dashs (\c --) are long option names and + * option names with one preceding dash (\c -) are short option names. + * Short option names can be used instead of long option names. + * When a compile option takes an argument, an assignment operator (\c =) + * is used to separate the compile option argument from the compile option + * name, e.g., \c "--gpu-architecture=compute_60". + * Alternatively, the compile option name and the argument can be specified in + * separate strings without an assignment operator, .e.g, + * \c "--gpu-architecture" \c "compute_60". + * Single-character short option names, such as \c -D, \c -U, and \c -I, do + * not require an assignment operator, and the compile option name and the + * argument can be present in the same string with or without spaces between + * them. + * For instance, \c "-D=", \c "-D", and \c "-D " are all + * supported. + * + * The valid compiler options are: + * + * - Compilation targets + * - \c --gpu-architecture=\ (\c -arch)\n + * Specify the name of the class of GPU architectures for which the + * input must be compiled.\n + * - Valid \s: + * - \c compute_50 + * - \c compute_52 + * - \c compute_53 + * - \c compute_60 + * - \c compute_61 + * - \c compute_62 + * - \c compute_70 + * - \c compute_72 + * - \c compute_75 + * - \c compute_80 + * - \c compute_87 + * - \c compute_89 + * - \c compute_90 + * - \c compute_90a + * - \c sm_50 + * - \c sm_52 + * - \c sm_53 + * - \c sm_60 + * - \c sm_61 + * - \c sm_62 + * - \c sm_70 + * - \c sm_72 + * - \c sm_75 + * - \c sm_80 + * - \c sm_87 + * - \c sm_89 + * - \c sm_90 + * - \c sm_90a + * - Default: \c compute_52 + * - Separate compilation / whole-program compilation + * - \c --device-c (\c -dc)\n + * Generate relocatable code that can be linked with other relocatable + * device code. It is equivalent to --relocatable-device-code=true. + * - \c --device-w (\c -dw)\n + * Generate non-relocatable code. It is equivalent to + * \c --relocatable-device-code=false. + * - \c --relocatable-device-code={true|false} (\c -rdc)\n + * Enable (disable) the generation of relocatable device code. + * - Default: \c false + * - \c --extensible-whole-program (\c -ewp)\n + * Do extensible whole program compilation of device code. + * - Default: \c false + * - Debugging support + * - \c --device-debug (\c -G)\n + * Generate debug information. If --dopt is not specified, + * then turns off all optimizations. + * - \c --generate-line-info (\c -lineinfo)\n + * Generate line-number information. + * - Code generation + * - \c --dopt on (\c -dopt)\n + * - \c --dopt=on \n + * Enable device code optimization. When specified along with '-G', enables + * limited debug information generation for optimized device code (currently, + * only line number information). + * When '-G' is not specified, '-dopt=on' is implicit. + * - \c --ptxas-options \ (\c -Xptxas)\n + * - \c --ptxas-options=\ \n + * Specify options directly to ptxas, the PTX optimizing assembler. + * - \c --maxrregcount=\ (\c -maxrregcount)\n + * Specify the maximum amount of registers that GPU functions can use. + * Until a function-specific limit, a higher value will generally + * increase the performance of individual GPU threads that execute this + * function. However, because thread registers are allocated from a + * global register pool on each GPU, a higher value of this option will + * also reduce the maximum thread block size, thereby reducing the amount + * of thread parallelism. Hence, a good maxrregcount value is the result + * of a trade-off. If this option is not specified, then no maximum is + * assumed. Value less than the minimum registers required by ABI will + * be bumped up by the compiler to ABI minimum limit. + * - \c --ftz={true|false} (\c -ftz)\n + * When performing single-precision floating-point operations, flush + * denormal values to zero or preserve denormal values. + * \c --use_fast_math implies \c --ftz=true. + * - Default: \c false + * - \c --prec-sqrt={true|false} (\c -prec-sqrt)\n + * For single-precision floating-point square root, use IEEE + * round-to-nearest mode or use a faster approximation. + * \c --use_fast_math implies \c --prec-sqrt=false. + * - Default: \c true + * - \c --prec-div={true|false} (\c -prec-div)\n + * For single-precision floating-point division and reciprocals, use IEEE + * round-to-nearest mode or use a faster approximation. + * \c --use_fast_math implies \c --prec-div=false. + * - Default: \c true + * - \c --fmad={true|false} (\c -fmad)\n + * Enables (disables) the contraction of floating-point multiplies and + * adds/subtracts into floating-point multiply-add operations (FMAD, + * FFMA, or DFMA). \c --use_fast_math implies \c --fmad=true. + * - Default: \c true + * - \c --use_fast_math (\c -use_fast_math)\n + * Make use of fast math operations. + * \c --use_fast_math implies \c --ftz=true \c --prec-div=false + * \c --prec-sqrt=false \c --fmad=true. + * - \c --extra-device-vectorization (\c -extra-device-vectorization)\n + * Enables more aggressive device code vectorization in the NVVM optimizer. + * - \c --modify-stack-limit={true|false} (\c -modify-stack-limit)\n + * On Linux, during compilation, use \c setrlimit() to increase stack size + * to maximum allowed. The limit is reset to the previous value at the + * end of compilation. + * Note: \c setrlimit() changes the value for the entire process. + * - Default: \c true + * - \c --dlink-time-opt (\c -dlto)\n + * Generate intermediate code for later link-time optimization. + * It implies \c -rdc=true. + * Note: when this option is used the nvrtcGetLTOIR API should be used, + * as PTX or Cubin will not be generated. + * - \c --gen-opt-lto (\c -gen-opt-lto)\n + * Run the optimizer passes before generating the LTO IR. + * - \c --optix-ir (\c -optix-ir)\n + * Generate OptiX IR. The Optix IR is only intended for consumption by OptiX + * through appropriate APIs. This feature is not supported with + * link-time-optimization (\c -dlto)\n. + * Note: when this option is used the nvrtcGetOptiX API should be used, + * as PTX or Cubin will not be generated. + * - \c --jump-table-density=[0-101] (\c -jtd)\n + * Specify the case density percentage in switch statements, and use it as + * a minimal threshold to determine whether jump table(brx.idx instruction) + * will be used to implement a switch statement. Default value is 101. The + * percentage ranges from 0 to 101 inclusively. + * - Preprocessing + * - \c --define-macro=\ (\c -D)\n + * \c \ can be either \c \ or \c \. + * - \c \ \n + * Predefine \c \ as a macro with definition \c 1. + * - \c \=\ \n + * The contents of \c \ are tokenized and preprocessed + * as if they appeared during translation phase three in a \c \#define + * directive. In particular, the definition will be truncated by + * embedded new line characters. + * - \c --undefine-macro=\ (\c -U)\n + * Cancel any previous definition of \c \. + * - \c --include-path=\ (\c -I)\n + * Add the directory \c \ to the list of directories to be + * searched for headers. These paths are searched after the list of + * headers given to ::nvrtcCreateProgram. + * - \c --pre-include=\ (\c -include)\n + * Preinclude \c \ during preprocessing. + * - \c --no-source-include (\c -no-source-include) + * The preprocessor by default adds the directory of each input sources + * to the include path. This option disables this feature and only + * considers the path specified explicitly. + * - Language Dialect + * - \c --std={c++03|c++11|c++14|c++17|c++20} + * (\c -std={c++11|c++14|c++17|c++20})\n + * Set language dialect to C++03, C++11, C++14, C++17 or C++20 + * - Default: \c c++17 + * - \c --builtin-move-forward={true|false} (\c -builtin-move-forward)\n + * Provide builtin definitions of \c std::move and \c std::forward, + * when C++11 or later language dialect is selected. + * - Default: \c true + * - \c --builtin-initializer-list={true|false} + * (\c -builtin-initializer-list)\n + * Provide builtin definitions of \c std::initializer_list class and + * member functions when C++11 or later language dialect is selected. + * - Default: \c true + * - Misc. + * - \c --disable-warnings (\c -w)\n + * Inhibit all warning messages. + * - \c --restrict (\c -restrict)\n + * Programmer assertion that all kernel pointer parameters are restrict + * pointers. + * - \c --device-as-default-execution-space + * (\c -default-device)\n + * Treat entities with no execution space annotation as \c __device__ + * entities. + * - \c --device-int128 (\c -device-int128)\n + * Allow the \c __int128 type in device code. Also causes the macro \c __CUDACC_RTC_INT128__ + * to be defined. + * - \c --optimization-info=\ (\c -opt-info)\n + * Provide optimization reports for the specified kind of optimization. + * The following kind tags are supported: + * - \c inline : emit a remark when a function is inlined. + * - \c --display-error-number (\c -err-no)\n + * Display diagnostic number for warning messages. (Default) + * - \c --no-display-error-number (\c -no-err-no)\n + * Disables the display of a diagnostic number for warning messages. + * - \c --diag-error=,... (\c -diag-error)\n + * Emit error for specified diagnostic message number(s). Message numbers can be separated by comma. + * - \c --diag-suppress=,... (\c -diag-suppress)\n + * Suppress specified diagnostic message number(s). Message numbers can be separated by comma. + * - \c --diag-warn=,... (\c -diag-warn)\n + * Emit warning for specified diagnostic message number(s). Message numbers can be separated by comma. + * - \c --brief-diagnostics={true|false} (\c -brief-diag)\n + * This option disables or enables showing source line and column info + * in a diagnostic. + * The --brief-diagnostics=true will not show the source line and column info. + * - Default: \c false + * - \c --time= (\c -time)\n + * Generate a comma separated value table with the time taken by each compilation + * phase, and append it at the end of the file given as the option argument. + * If the file does not exist, the column headings are generated in the first row + * of the table. If the file name is '-', the timing data is written to the compilation log. + * - \c --split-compile= (\c -split-compile=)\n + * Perform compiler optimizations in parallel. + * Split compilation attempts to reduce compile time by enabling the compiler to run certain + * optimization passes concurrently. This option accepts a numerical value that specifies the + * maximum number of threads the compiler can use. One can also allow the compiler to use the maximum + * threads available on the system by setting --split-compile=0. + * Setting --split-compile=1 will cause this option to be ignored. + * - \c --fdevice-syntax-only (\c -fdevice-syntax-only)\n + * Ends device compilation after front-end syntax checking. This option does not generate valid + * device code. + * - \c --minimal (\c -minimal)\n + * Omit certain language features to reduce compile time for small programs. + * In particular, the following are omitted: + * - Texture and surface functions and associated types, e.g., \c cudaTextureObject_t. + * - CUDA Runtime Functions that are provided by the cudadevrt device code library, + * typically named with prefix "cuda", e.g., \c cudaMalloc. + * - Kernel launch from device code. + * - Types and macros associated with CUDA Runtime and Driver APIs, + * provided by cuda/tools/cudart/driver_types.h, typically named with prefix "cuda", e.g., \c cudaError_t. + * + */ + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + + +/* The utility function 'nvrtcGetTypeName' is not available by default. Define + the macro 'NVRTC_GET_TYPE_NAME' to a non-zero value to make it available. +*/ + +#if NVRTC_GET_TYPE_NAME || __DOXYGEN_ONLY__ + +#if NVRTC_USE_CXXABI || __clang__ || __GNUC__ || __DOXYGEN_ONLY__ +#include +#include + +#elif defined(_WIN32) +#include +#include +#endif /* NVRTC_USE_CXXABI || __clang__ || __GNUC__ */ + + +#include +#include + +template struct __nvrtcGetTypeName_helper_t { }; + +/*************************************************************************//** + * + * \defgroup hosthelper Host Helper + * + * NVRTC defines the following functions for easier interaction with host code. + * + ****************************************************************************/ + +/** + * \ingroup hosthelper + * \brief nvrtcGetTypeName stores the source level name of a type in the given + * std::string location. + * + * This function is only provided when the macro NVRTC_GET_TYPE_NAME is + * defined with a non-zero value. It uses abi::__cxa_demangle or UnDecorateSymbolName + * function calls to extract the type name, when using gcc/clang or cl.exe compilers, + * respectively. If the name extraction fails, it will return NVRTC_INTERNAL_ERROR, + * otherwise *result is initialized with the extracted name. + * + * Windows-specific notes: + * - nvrtcGetTypeName() is not multi-thread safe because it calls UnDecorateSymbolName(), + * which is not multi-thread safe. + * - The returned string may contain Microsoft-specific keywords such as __ptr64 and __cdecl. + * + * \param [in] tinfo: reference to object of type std::type_info for a given type. + * \param [in] result: pointer to std::string in which to store the type name. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INTERNAL_ERROR \endlink + * + */ +inline nvrtcResult nvrtcGetTypeName(const std::type_info &tinfo, std::string *result) +{ +#if USE_CXXABI || __clang__ || __GNUC__ + const char *name = tinfo.name(); + int status; + char *undecorated_name = abi::__cxa_demangle(name, 0, 0, &status); + if (status == 0) { + *result = undecorated_name; + free(undecorated_name); + return NVRTC_SUCCESS; + } +#elif defined(_WIN32) + const char *name = tinfo.raw_name(); + if (!name || *name != '.') { + return NVRTC_ERROR_INTERNAL_ERROR; + } + char undecorated_name[4096]; + //name+1 skips over the '.' prefix + if(UnDecorateSymbolName(name+1, undecorated_name, + sizeof(undecorated_name) / sizeof(*undecorated_name), + //note: doesn't seem to work correctly without UNDNAME_NO_ARGUMENTS. + UNDNAME_NO_ARGUMENTS | UNDNAME_NAME_ONLY ) ) { + *result = undecorated_name; + return NVRTC_SUCCESS; + } +#endif /* USE_CXXABI || __clang__ || __GNUC__ */ + + return NVRTC_ERROR_INTERNAL_ERROR; +} + +/** + * \ingroup hosthelper + * \brief nvrtcGetTypeName stores the source level name of the template type argument + * T in the given std::string location. + * + * This function is only provided when the macro NVRTC_GET_TYPE_NAME is + * defined with a non-zero value. It uses abi::__cxa_demangle or UnDecorateSymbolName + * function calls to extract the type name, when using gcc/clang or cl.exe compilers, + * respectively. If the name extraction fails, it will return NVRTC_INTERNAL_ERROR, + * otherwise *result is initialized with the extracted name. + * + * Windows-specific notes: + * - nvrtcGetTypeName() is not multi-thread safe because it calls UnDecorateSymbolName(), + * which is not multi-thread safe. + * - The returned string may contain Microsoft-specific keywords such as __ptr64 and __cdecl. + * + * \param [in] result: pointer to std::string in which to store the type name. + * \return + * - \link #nvrtcResult NVRTC_SUCCESS \endlink + * - \link #nvrtcResult NVRTC_ERROR_INTERNAL_ERROR \endlink + * + */ + +template +nvrtcResult nvrtcGetTypeName(std::string *result) +{ + nvrtcResult res = nvrtcGetTypeName(typeid(__nvrtcGetTypeName_helper_t), + result); + if (res != NVRTC_SUCCESS) + return res; + + std::string repr = *result; + std::size_t idx = repr.find("__nvrtcGetTypeName_helper_t"); + idx = (idx != std::string::npos) ? repr.find("<", idx) : idx; + std::size_t last_idx = repr.find_last_of('>'); + if (idx == std::string::npos || last_idx == std::string::npos) { + return NVRTC_ERROR_INTERNAL_ERROR; + } + ++idx; + *result = repr.substr(idx, last_idx - idx); + return NVRTC_SUCCESS; +} + +#endif /* NVRTC_GET_TYPE_NAME */ + +#endif /* __NVRTC_H__ */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e04729f15a3b8ddf20f7a48ab8f77c6581f032e7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28bd84c3870eef18edb05fb58247a276b91d6b04 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/async.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/async.h new file mode 100644 index 0000000000000000000000000000000000000000..1b7dcb2433f2cb7d1ef61290995ac871a901b1e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/async.h @@ -0,0 +1,452 @@ +/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * The source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * The Licensed Deliverables contained herein are PROPRIETARY and + * CONFIDENTIAL to NVIDIA and are being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _CG_ASYNC_H +#define _CG_ASYNC_H + +#include "helpers.h" +#include "info.h" + +#include + +_CG_BEGIN_NAMESPACE + +namespace details { +// Groups supported by memcpy_async +template +struct _async_copy_group_supported : public _CG_STL_NAMESPACE::false_type {}; + +template +struct _async_copy_group_supported> + : public _CG_STL_NAMESPACE::true_type {}; +template <> +struct _async_copy_group_supported : public _CG_STL_NAMESPACE::true_type {}; +template <> +struct _async_copy_group_supported : public _CG_STL_NAMESPACE::true_type {}; + +template +using async_copy_group_supported = _async_copy_group_supported>; + +// Groups that require optimization +template +struct _async_copy_optimize_tile : public _CG_STL_NAMESPACE::false_type {}; + +template +struct _async_copy_optimize_tile> + : public _CG_STL_NAMESPACE::false_type {}; + +template +struct _async_copy_optimize_tile> + : public _CG_STL_NAMESPACE::true_type {}; + +template +using async_copy_optimize_tile = _async_copy_optimize_tile>; + +// SFINAE helpers for tile optimizations +template +using enable_tile_optimization = + typename _CG_STL_NAMESPACE::enable_if::value, void *>::type; + +template +using disable_tile_optimization = + typename _CG_STL_NAMESPACE::enable_if::value, void *>::type; + +// Segment for punning to aligned types +template +struct _Segment { + int _seg[N]; +}; + +// Trivial layout guaranteed-aligned copy-async compatible segments +template +struct Segment; +template <> +struct __align__(4) Segment<1> : public _Segment<1>{}; +template <> +struct __align__(8) Segment<2> : public _Segment<2>{}; +template <> +struct __align__(16) Segment<4> : public _Segment<4>{}; + +// Interleaved element by element copies from source to dest +template +_CG_STATIC_QUALIFIER void inline_copy(TyGroup &group, TyElem *__restrict__ dst, const TyElem *__restrict__ src, + size_t count) { + const unsigned int rank = group.thread_rank(); + const unsigned int stride = group.size(); + + for (size_t idx = rank; idx < count; idx += stride) { + dst[idx] = src[idx]; + } +} + +template = nullptr> +_CG_STATIC_QUALIFIER void accelerated_async_copy(TyGroup &group, TyElem *__restrict__ dst, + const TyElem *__restrict__ src, size_t count) { + static_assert(async_copy_group_supported::value, + "Async copy is only supported for groups that represent private shared memory"); + + if (count == 0) { + return; + } + + const bool dstIsNotShared = !__isShared(dst); + const bool srcIsNotGlobal = !__isGlobal(src); + + if (dstIsNotShared || srcIsNotGlobal) { + inline_copy(group, dst, src, count); + return; + } + + const unsigned int stride = group.size(); + const unsigned int rank = group.thread_rank(); + // Efficient copies require warps to operate on the same amount of work at each step. + // remainders are handled in a separate stage to prevent branching + const unsigned int subWarpMask = (stride - 1); + const unsigned int subwarpCopies = (subWarpMask & (unsigned int)count); + const unsigned int maxSubwarpRank = min(rank, subwarpCopies - 1); + + const size_t warpCopies = (count & (~subWarpMask)); + + for (size_t idx = 0; idx < warpCopies; idx += stride) { + size_t _srcIdx = rank + idx; + size_t _dstIdx = rank + idx; + __pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem)); + } + + if (subwarpCopies) { + size_t _srcIdx = warpCopies + maxSubwarpRank; + size_t _dstIdx = warpCopies + maxSubwarpRank; + __pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem)); + } +} + +template = nullptr> +_CG_STATIC_QUALIFIER void accelerated_async_copy(TyGroup &group, TyElem *__restrict__ dst, + const TyElem *__restrict__ src, size_t count) { + static_assert(async_copy_group_supported::value, + "Async copy is only supported for groups that represent private shared memory"); + + const bool dstIsNotShared = !__isShared(dst); + const bool srcIsNotGlobal = !__isGlobal(src); + + if (dstIsNotShared || srcIsNotGlobal) { + inline_copy(group, dst, src, count); + return; + } + + unsigned int stride = group.size(); + unsigned int rank = group.thread_rank(); + + for (size_t idx = rank; idx < count; idx += stride) { + size_t _srcIdx = idx; + size_t _dstIdx = idx; + __pipeline_memcpy_async(dst + _dstIdx, src + _srcIdx, sizeof(TyElem)); + } +} + +// Determine best possible alignment given an input and initial conditions +// Attempts to generate as little code as possible, most likely should only be used with 1 and 2 byte alignments +template +_CG_STATIC_QUALIFIER uint32_t find_best_alignment(void *__restrict__ dst, const void *__restrict__ src) { + // Narrowing conversion intentional + uint32_t base1 = (uint32_t) reinterpret_cast(src); + uint32_t base2 = (uint32_t) reinterpret_cast(dst); + + uint32_t diff = ((base1) ^ (base2)) & (MaxAlignment - 1); + + // range [MaxAlignment, alignof(elem)], step: x >> 1 + // over range of possible alignments, choose best available out of range + uint32_t out = MaxAlignment; +#pragma unroll + for (uint32_t alignment = (MaxAlignment >> 1); alignment >= MinAlignment; alignment >>= 1) { + if (alignment & diff) + out = alignment; + } + + return out; +} + +// Determine best possible alignment given an input and initial conditions +// Attempts to generate as little code as possible, most likely should only be used with 1 and 2 byte alignments +template +_CG_STATIC_QUALIFIER void copy_like(const TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src, + size_t count) { + const char *src = reinterpret_cast(_src); + char *dst = reinterpret_cast(_dst); + + constexpr uint32_t targetAlignment = (uint32_t)alignof(TyType); + + uint32_t base = (uint32_t) reinterpret_cast(src); + uint32_t alignOffset = ((~base) + 1) & (targetAlignment - 1); + + inline_copy(group, dst, src, alignOffset); + count -= alignOffset; + src += alignOffset; + dst += alignOffset; + + // Copy using the best available alignment, async_copy expects n-datums, not bytes + size_t asyncCount = count / sizeof(TyType); + accelerated_async_copy(group, reinterpret_cast(dst), reinterpret_cast(src), asyncCount); + asyncCount *= sizeof(TyType); + + count -= asyncCount; + src += asyncCount; + dst += asyncCount; + inline_copy(group, dst, src, count); +} + +// We must determine alignment and manually align src/dst ourselves +template +struct _memcpy_async_align_dispatch { + template + _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ dst, const void *__restrict__ src, size_t count) { + uint32_t alignment = find_best_alignment(dst, src); + + // Avoid copying the extra bytes if desired copy count is smaller + alignment = count < alignment ? AlignHint : alignment; + + switch (alignment) { + default: + case 1: + inline_copy(group, reinterpret_cast(dst), reinterpret_cast(src), count); + break; + case 2: + inline_copy(group, reinterpret_cast(dst), reinterpret_cast(src), count >> 1); + break; + case 4: + copy_like>(group, dst, src, count); + break; + case 8: + copy_like>(group, dst, src, count); + break; + case 16: + copy_like>(group, dst, src, count); + break; + } + } +}; + +// Specialization for 4 byte alignments +template <> +struct _memcpy_async_align_dispatch<4> { + template + _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src, + size_t count) { + const Segment<1> *src = reinterpret_cast *>(_src); + Segment<1> *dst = reinterpret_cast *>(_dst); + + // Dispatch straight to aligned LDGSTS calls + accelerated_async_copy(group, dst, src, count / sizeof(*dst)); + } +}; + +// Specialization for 8 byte alignments +template <> +struct _memcpy_async_align_dispatch<8> { + template + _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src, + size_t count) { + const Segment<2> *src = reinterpret_cast *>(_src); + Segment<2> *dst = reinterpret_cast *>(_dst); + + // Dispatch straight to aligned LDGSTS calls + accelerated_async_copy(group, dst, src, count / sizeof(*dst)); + } +}; + +// Alignments over 16 are truncated to 16 and bypass alignment +// This is the highest performing memcpy available +template <> +struct _memcpy_async_align_dispatch<16> { + template + _CG_STATIC_QUALIFIER void copy(TyGroup &group, void *__restrict__ _dst, const void *__restrict__ _src, + size_t count) { + const Segment<4> *src = reinterpret_cast *>(_src); + Segment<4> *dst = reinterpret_cast *>(_dst); + + // Dispatch straight to aligned LDGSTS calls + accelerated_async_copy(group, dst, src, count / sizeof(*dst)); + } +}; + +// byte-wide API +template +_CG_STATIC_QUALIFIER void _memcpy_async_dispatch_to_aligned_copy(const TyGroup &group, void *__restrict__ _dst, + const void *__restrict__ _src, size_t count) { + static_assert(!(Alignment & (Alignment - 1)), "Known static alignment dispatch must be a power of 2"); + details::_memcpy_async_align_dispatch::copy(group, _dst, _src, count); +} + +// Internal dispatch APIs +// These deduce the alignments and sizes necessary to invoke the underlying copy engine +template +using is_void = _CG_STL_NAMESPACE::is_same; + +template +using enable_if_not_void = typename _CG_STL_NAMESPACE::enable_if::value, void *>::type; + +template +using enable_if_void = typename _CG_STL_NAMESPACE::enable_if::value, void *>::type; + +template +using enable_if_integral = + typename _CG_STL_NAMESPACE::enable_if<_CG_STL_NAMESPACE::is_integral::value, void *>::type; + +// byte-wide API using aligned_sized_t +template typename Alignment, size_t Hint> +_CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, void *__restrict__ _dst, + const void *__restrict__ _src, const Alignment &count) { + constexpr size_t _align = (Hint > 16) ? 16 : Hint; + + details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, _dst, _src, (size_t)count); +} + +// byte-wide API using type for aligment +template = nullptr, enable_if_integral = nullptr> +_CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, TyElem *__restrict__ _dst, + const TyElem *__restrict__ _src, const TySize& count) { + constexpr size_t _align = (Hint > 16) ? 16 : Hint; + + details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, _dst, _src, count); +} + +// byte-wide API with full alignment deduction required +template = nullptr, + enable_if_integral = nullptr> +_CG_STATIC_QUALIFIER void _memcpy_async_bytes(const TyGroup &group, TyElem *__restrict__ _dst, + const TyElem *__restrict__ _src, const TySize& count) { + details::_memcpy_async_dispatch_to_aligned_copy<1>(group, _dst, _src, count); +} + +// 1d-datum API +template +_CG_STATIC_QUALIFIER void _memcpy_async_datum(const TyGroup &group, TyElem *__restrict__ dst, const size_t dstCount, + const TyElem *__restrict__ src, const size_t srcCount) { + constexpr unsigned int _align = Hint; + const size_t totalCount = min(dstCount, srcCount) * sizeof(TyElem); + + details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, dst, src, totalCount); +} + +// 1d-datum API using aligned_size_t +template typename Alignment, size_t Hint> +_CG_STATIC_QUALIFIER void _memcpy_async_datum(const TyGroup &group, TyElem *__restrict__ dst, const Alignment &dstCount, + const TyElem *__restrict__ src, const Alignment &srcCount) { + constexpr unsigned int _align = Hint; + const size_t totalCount = min((size_t)dstCount, (size_t)srcCount) * sizeof(TyElem); + + details::_memcpy_async_dispatch_to_aligned_copy<_align>(group, dst, src, totalCount); +} + +} // namespace details + +/* + * Group submit batch of async-copy to cover contiguous 1D array + * and commit that batch to eventually wait for completion. + */ +template +_CG_STATIC_QUALIFIER void memcpy_async(const TyGroup &group, TyElem *__restrict__ _dst, const TyElem *__restrict__ _src, + const TySizeT &count) { + details::_memcpy_async_bytes(group, _dst, _src, count); + __pipeline_commit(); +} + +/* + * Group submit batch of async-copy to cover contiguous 1D array + * and commit that batch to eventually wait for completion. + * Object counts are in datum sized chunks, not bytes. + */ +template +_CG_STATIC_QUALIFIER void memcpy_async(const TyGroup &group, TyElem *__restrict__ dst, const DstLayout &dstLayout, + const TyElem *__restrict__ src, const SrcLayout &srcLayout) { + details::_memcpy_async_datum(group, dst, dstLayout, src, srcLayout); + __pipeline_commit(); +} + +/* Group wait for prior Nth stage of memcpy_async to complete. */ +template +_CG_STATIC_QUALIFIER void wait_prior(const TyGroup &group) { + __pipeline_wait_prior(Stage); + group.sync(); +} + +/* Group wait all previously submitted memcpy_async to complete. */ +template +_CG_STATIC_QUALIFIER void wait(const TyGroup &group) { + __pipeline_wait_prior(0); + group.sync(); +} + +/***************** CG APIs including pipeline are deprecated *****************/ + +/* Group submit batch of async-copy to cover of contiguous 1D array + to a pipeline and commit the batch*/ +template +_CG_DEPRECATED _CG_STATIC_QUALIFIER void memcpy_async(TyGroup &group, TyElem *dst, size_t dstCount, const TyElem *src, size_t srcCount, + nvcuda::experimental::pipeline &pipe) { + details::_memcpy_async_datum(group, dst, dstCount, src, srcCount); + pipe.commit(); +} + +/* Group wait for prior Nth stage of memcpy_async to complete. */ +template +_CG_DEPRECATED _CG_STATIC_QUALIFIER void wait_prior(TyGroup &group, nvcuda::experimental::pipeline &pipe) { + pipe.wait_prior(); + group.sync(); +} + +/* Group wait for stage-S of memcpy_async to complete. */ +template +_CG_DEPRECATED _CG_STATIC_QUALIFIER void wait(TyGroup &group, nvcuda::experimental::pipeline &pipe, size_t stage) { + pipe.wait(stage); + group.sync(); +} +_CG_END_NAMESPACE + +#endif // _CG_ASYNC_H diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/coalesced_scan.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/coalesced_scan.h new file mode 100644 index 0000000000000000000000000000000000000000..383f4bde059dd8daad7d1c56e99152ea7ee28a08 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/coalesced_scan.h @@ -0,0 +1,174 @@ +/* Copyright 1993-2016 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * The source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * The Licensed Deliverables contained herein are PROPRIETARY and + * CONFIDENTIAL to NVIDIA and are being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _CG_COALESCED_SCAN_H_ +#define _CG_COALESCED_SCAN_H_ + +#include "info.h" +#include "helpers.h" +#include "cooperative_groups.h" +#include "partitioning.h" +#include "functional.h" + +_CG_BEGIN_NAMESPACE + +namespace details { + +template +_CG_QUALIFIER auto inclusive_scan_contiguous(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) { + auto out = val; + for (int mask = 1; mask < group.size(); mask <<= 1) { + auto tmp = group.shfl_up(out, mask); + if (mask <= group.thread_rank()) { + out = op(out, tmp); + } + } + + return out; +} + +template +_CG_QUALIFIER auto inclusive_scan_non_contiguous(const TyGroup& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) { + const unsigned int groupSize = group.size(); + auto out = val; + + const unsigned int mask = details::_coalesced_group_data_access::get_mask(group); + unsigned int lanemask = details::lanemask32_lt() & mask; + unsigned int srcLane = details::laneid(); + + const unsigned int base = __ffs(mask)-1; /* lane with rank == 0 */ + const unsigned int rank = __popc(lanemask); + + for (unsigned int i = 1, j = 1; i < groupSize; i <<= 1) { + if (i <= rank) { + srcLane -= j; + j = i; /* maximum possible lane */ + + unsigned int begLane = base + rank - i; /* minimum possible lane */ + + /* Next source lane is in the range [ begLane .. srcLane ] + * If begLane < srcLane then do a binary search. + */ + while (begLane < srcLane) { + const unsigned int halfLane = (begLane + srcLane) >> 1; + const unsigned int halfMask = lanemask >> halfLane; + const unsigned int d = __popc(halfMask); + if (d < i) { + srcLane = halfLane - 1; /* halfLane too large */ + } + else if ((i < d) || !(halfMask & 0x01)) { + begLane = halfLane + 1; /* halfLane too small */ + } + else { + begLane = srcLane = halfLane; /* happen to hit */ + } + } + } + + auto tmp = details::tile::shuffle_dispatch::shfl(out, mask, srcLane, 32); + if (i <= rank) { + out = op(out, tmp); + } + } + return out; +} + +template +_CG_QUALIFIER auto coalesced_inclusive_scan(const __single_warp_thread_block_tile& group, + TyVal&& val, + TyOp&& op) -> decltype(op(val, val)) { + return inclusive_scan_contiguous(group, _CG_STL_NAMESPACE::forward(val), _CG_STL_NAMESPACE::forward(op)); +} + +template +_CG_QUALIFIER auto coalesced_inclusive_scan(const coalesced_group& group, TyVal&& val, TyOp&& op) -> decltype(op(val, val)) { + if (group.size() == 32) { + return inclusive_scan_contiguous(group, _CG_STL_NAMESPACE::forward(val), _CG_STL_NAMESPACE::forward(op)); + } + else { + return inclusive_scan_non_contiguous(group, _CG_STL_NAMESPACE::forward(val), _CG_STL_NAMESPACE::forward(op)); + } +} + +template +struct scan_choose_convertion; + +template<> +struct scan_choose_convertion { + template + _CG_STATIC_QUALIFIER details::remove_qual convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val) { + return result - val; + } +}; + +template<> +struct scan_choose_convertion { + template + _CG_STATIC_QUALIFIER details::remove_qual convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val) { + auto ret = group.shfl_up(result, 1); + if (group.thread_rank() == 0) { + return {}; + } + else { + return ret; + } + } +}; + +template +_CG_QUALIFIER auto convert_inclusive_to_exclusive(const TyGroup& group, TyRes& result, TyVal&& val, TyFn&& op) -> decltype(op(val, val)) { + using conversion = scan_choose_convertion<_CG_STL_NAMESPACE::is_same, cooperative_groups::plus>>::value + && _CG_STL_NAMESPACE::is_integral>::value>; + return conversion::convert_inclusive_to_exclusive(group, result, _CG_STL_NAMESPACE::forward(val)); +} + +} // details + +_CG_END_NAMESPACE + +#endif // _CG_COALESCED_SCAN_H_ \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/driver_abi.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/driver_abi.h new file mode 100644 index 0000000000000000000000000000000000000000..9c866fcf740beb709a106057d28e8a2a1ac37924 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/driver_abi.h @@ -0,0 +1,99 @@ + /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * The source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * The Licensed Deliverables contained herein are PROPRIETARY and + * CONFIDENTIAL to NVIDIA and are being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _CG_DRIVER_API_H +#define _CG_DRIVER_API_H + +#include "info.h" + +_CG_BEGIN_NAMESPACE + +namespace details { + template + _CG_QUALIFIER unsigned int load_env_reg() { + // Abort by default + _CG_ABORT(); + return 0; + } + + template + _CG_QUALIFIER unsigned long long load_env_reg64() { + unsigned long long registerLo = load_env_reg(); + unsigned long long registerHi = load_env_reg(); + + return (registerHi << 32) | registerLo; + } + +// inline PTX for accessing registers requires an immediate for the special reg +# define LOAD_ENVREG(NUMBER) \ + template <> _CG_QUALIFIER unsigned int load_env_reg() { \ + unsigned int r; \ + asm ("mov.u32 %0, %%envreg" #NUMBER ";" : "=r"(r)); \ + return r; \ + } + + // Instantiate loaders for registers used + LOAD_ENVREG(0); + LOAD_ENVREG(1); + LOAD_ENVREG(2); +# undef LOAD_ENVREG + + struct grid_workspace { + unsigned int wsSize; + unsigned int barrier; + }; + + _CG_QUALIFIER grid_workspace* get_grid_workspace() { + unsigned long long gridWsAbiAddress = load_env_reg64<1, 2>(); + // Interpret the address from envreg 1 and 2 as the driver's grid workspace + return (reinterpret_cast(gridWsAbiAddress)); + } +} +_CG_END_NAMESPACE + +#endif // _CG_DRIVER_API_H diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/info.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/info.h new file mode 100644 index 0000000000000000000000000000000000000000..9a860402ea9e8be784d384d756217fd4c656538a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/info.h @@ -0,0 +1,344 @@ + /* Copyright 1993-2021 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * The source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * The Licensed Deliverables contained herein are PROPRIETARY and + * CONFIDENTIAL to NVIDIA and are being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + + + +#ifndef _CG_INFO_H_ +#define _CG_INFO_H_ +/* +** Define: _CG_VERSION +*/ +#define _CG_VERSION 1000 + +/* +** Define: _CG_ABI_VERSION +*/ +#ifndef _CG_ABI_VERSION +# define _CG_ABI_VERSION 1 +#endif + +/* +** Define: _CG_ABI_EXPERIMENTAL +** Desc: If enabled, sets all features enabled (ABI-breaking or experimental) +*/ +#if defined(_CG_ABI_EXPERIMENTAL) +#endif + +#define _CG_CONCAT_INNER(x, y) x ## y +#define _CG_CONCAT_OUTER(x, y) _CG_CONCAT_INNER(x, y) +#define _CG_NAMESPACE _CG_CONCAT_OUTER(__v, _CG_ABI_VERSION) + +#define _CG_BEGIN_NAMESPACE \ + namespace cooperative_groups { namespace _CG_NAMESPACE { +#define _CG_END_NAMESPACE \ + }; using namespace _CG_NAMESPACE; }; + +#if (defined(__cplusplus) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MSC_VER >= 1900)) +# define _CG_CPP11_FEATURES +#endif + +#if !defined(_CG_QUALIFIER) +# define _CG_QUALIFIER __forceinline__ __device__ +#endif +#if !defined(_CG_STATIC_QUALIFIER) +# define _CG_STATIC_QUALIFIER static __forceinline__ __device__ +#endif +#if !defined(_CG_CONSTEXPR_QUALIFIER) +# if defined(_CG_CPP11_FEATURES) +# define _CG_CONSTEXPR_QUALIFIER constexpr __forceinline__ __device__ +# else +# define _CG_CONSTEXPR_QUALIFIER _CG_QUALIFIER +# endif +#endif +#if !defined(_CG_STATIC_CONSTEXPR_QUALIFIER) +# if defined(_CG_CPP11_FEATURES) +# define _CG_STATIC_CONSTEXPR_QUALIFIER static constexpr __forceinline__ __device__ +# else +# define _CG_STATIC_CONSTEXPR_QUALIFIER _CG_STATIC_QUALIFIER +# endif +#endif + +#if defined(_MSC_VER) +# define _CG_DEPRECATED __declspec(deprecated) +#else +# define _CG_DEPRECATED __attribute__((deprecated)) +#endif + +#if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__) +# define _CG_HAS_GRID_GROUP +#endif +#if (__CUDA_ARCH__ >= 600) || !defined(__CUDA_ARCH__) +# define _CG_HAS_MULTI_GRID_GROUP +#endif +#if (__CUDA_ARCH__ >= 700) || !defined(__CUDA_ARCH__) +# define _CG_HAS_MATCH_COLLECTIVE +#endif + +#if (__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__) && (defined(__NVCC__) || defined(__CUDACC_RTC__)) +# define _CG_HAS_OP_REDUX +#endif + +#if ((__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__)) && !defined(_CG_USER_PROVIDED_SHARED_MEMORY) +# define _CG_HAS_RESERVED_SHARED +#endif + +#if ((__CUDA_ARCH__ >= 900) || !defined(__CUDA_ARCH__)) && \ + (defined(__NVCC__) || defined(__CUDACC_RTC__) || defined(_CG_CLUSTER_INTRINSICS_AVAILABLE)) && \ + defined(_CG_CPP11_FEATURES) +# define _CG_HAS_CLUSTER_GROUP +#endif + +#if (__CUDA_ARCH__ >= 900) || !defined(__CUDA_ARCH__) +# define _CG_HAS_INSTR_ELECT +#endif + +// Has __half and __half2 +// Only usable if you include the cuda_fp16.h extension, and +// _before_ including cooperative_groups.h +#ifdef __CUDA_FP16_TYPES_EXIST__ +# define _CG_HAS_FP16_COLLECTIVE +#endif + +// Include libcu++ where supported. +#if defined(_CG_CPP11_FEATURES) && !defined(__QNX__) && !defined(__ibmxl__) && \ + (defined(__NVCC__) || defined(__CUDACC_RTC__)) && \ + (defined(__x86_64__) || defined(__aarch64__) || defined(__ppc64__)|| defined(_M_X64) || defined(_M_ARM64)) && \ + (defined(_MSC_VER) || defined(__GNUC__) || defined(__clang__)) +# define _CG_USE_CUDA_STL +#else +# define _CG_USE_OWN_TRAITS +#endif + +#if defined(_CG_USE_CUDA_STL) && (!defined(__CUDA_ARCH__) || \ + ((!defined(_MSC_VER) && __CUDA_ARCH__ >= 600) || (defined(_MSC_VER) && __CUDA_ARCH__ >= 700))) +# define _CG_HAS_STL_ATOMICS +#endif + +#ifdef _CG_CPP11_FEATURES +// Use cuda::std:: for type_traits +# if defined(_CG_USE_CUDA_STL) +# define _CG_STL_NAMESPACE cuda::std +# include +// Use CG's implementation of type traits +# else +# define _CG_STL_NAMESPACE cooperative_groups::details::templates +# endif +#endif + +#ifdef _CG_CPP11_FEATURES +# define _CG_STATIC_CONST_DECL static constexpr +# define _CG_CONST_DECL constexpr +#else +# define _CG_STATIC_CONST_DECL static const +# define _CG_CONST_DECL const +#endif + +#if (defined(_MSC_VER) && !defined(_WIN64)) || defined(__arm__) +# define _CG_ASM_PTR_CONSTRAINT "r" +#else +# define _CG_ASM_PTR_CONSTRAINT "l" +#endif + +/* +** Define: CG_DEBUG +** What: Enables various runtime safety checks +*/ +#if defined(__CUDACC_DEBUG__) && defined(CG_DEBUG) && !defined(NDEBUG) +# define _CG_DEBUG +#endif + +#if defined(_CG_DEBUG) +# include +# define _CG_ASSERT(x) assert((x)); +# define _CG_ABORT() assert(0); +#else +# define _CG_ASSERT(x) +# define _CG_ABORT() __trap(); +#endif + +_CG_BEGIN_NAMESPACE + +namespace details { + _CG_STATIC_CONST_DECL unsigned int default_max_block_size = 1024; + +#if defined(_CG_CPP11_FEATURES) && !defined(_CG_USE_CUDA_STL) +namespace templates { + +/** + * Integral constants + **/ +template +struct integral_constant { + static constexpr Ty value = Val; + typedef Ty type; + + _CG_QUALIFIER constexpr operator type() const noexcept { return value; } + _CG_QUALIFIER constexpr type operator()() const noexcept { return value; } +}; + +typedef integral_constant true_type; +typedef integral_constant false_type; + +/** + * CV Qualifiers + **/ +template struct is_lvalue_reference : public details::templates::false_type {}; +template struct is_lvalue_reference : public details::templates::true_type {}; + +template struct remove_reference {typedef Ty type;}; +template struct remove_reference {typedef Ty type;}; +template struct remove_reference {typedef Ty type;}; + +template +using remove_reference_t = typename details::templates::remove_reference::type; + +template struct remove_const {typedef Ty type;}; +template struct remove_const {typedef Ty type;}; + +template struct remove_volatile {typedef Ty type;}; +template struct remove_volatile {typedef Ty type;}; + +template struct remove_cv {typedef typename details::templates::remove_volatile::type>::type type;}; + +template +using remove_cv_t = typename details::templates::remove_cv::type; + +template +_CG_QUALIFIER Ty&& forward(remove_reference_t &t) noexcept { + return static_cast(t); +} + +template +_CG_QUALIFIER Ty&& forward(remove_reference_t &&t) noexcept { + static_assert(!details::templates::is_lvalue_reference::value, "Forwarding an rvalue as an lvalue is not allowed."); + return static_cast(t); +} + +/** + * is_integral + **/ +template struct _is_integral : public details::templates::false_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +template <> struct _is_integral : public details::templates::true_type {}; +//Vector type support? + +template +struct is_integral : public details::templates::_is_integral::type> {}; + +/** + * is_floating_point + **/ +template struct _is_floating_point : public details::templates::false_type {}; +template <> struct _is_floating_point : public details::templates::true_type {}; +template <> struct _is_floating_point : public details::templates::true_type {}; +template <> struct _is_floating_point : public details::templates::true_type {}; +# ifdef __CUDA_FP16_TYPES_EXIST__ +template <> struct _is_floating_point<__half> : public details::templates::true_type {}; +template <> struct _is_floating_point<__half2> : public details::templates::true_type {}; +# endif +//Vector type support? + +template +struct is_floating_point : public details::templates::_is_floating_point::type> {}; + +template +struct is_arithmetic : details::templates::integral_constant< + bool, + details::templates::is_integral::value || + details::templates::is_floating_point::value> {}; + +template ::value> +struct _is_unsigned : details::templates::integral_constant {}; + +template +struct _is_unsigned : details::templates::false_type {}; + +template +struct is_unsigned : _is_unsigned::type> {}; + +template struct _is_pointer : public details::templates::false_type {}; +template struct _is_pointer : public details::templates::true_type {}; + +template +struct is_pointer : _is_pointer::type> {}; + +/** + * programmatic type traits + **/ +template +struct enable_if {}; + +template +struct enable_if { typedef Ty type; }; + +template +using enable_if_t = typename details::templates::enable_if::type; + +template +struct is_same : details::templates::false_type {}; + +template +struct is_same : details::templates::true_type {}; + +} // templates +#endif // _CG_CPP11_FEATURES + +} // details +_CG_END_NAMESPACE + + +#endif // _CG_INFO_H_ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/invoke.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/invoke.h new file mode 100644 index 0000000000000000000000000000000000000000..f00314ce140e390be90a1ab3c328fd73d73c0d46 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/invoke.h @@ -0,0 +1,189 @@ +/* + * Copyright 1993-2022 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _CG_INVOKE_H +#define _CG_INVOKE_H + +#include "info.h" +#include "helpers.h" + +#if defined(_CG_CPP11_FEATURES) + +_CG_BEGIN_NAMESPACE + +namespace details { + + template + struct _elect_group_supported : _CG_STL_NAMESPACE::false_type {}; +#ifdef _CG_HAS_INSTR_ELECT + template<> + struct _elect_group_supported : _CG_STL_NAMESPACE::true_type {}; + template + struct _elect_group_supported> : + _CG_STL_NAMESPACE::integral_constant {}; +#endif + + template + struct elect_group_supported : public _elect_group_supported> {}; + + template + _CG_STATIC_QUALIFIER bool elect_one(const Group& group, unsigned int mask, unsigned int& leader_lane) { + int is_leader = 0; +#ifdef _CG_HAS_INSTR_ELECT + asm("{\n\t" + " .reg .pred p;\n\t" + " elect.sync %0|p, %2;\n\t" + " @p mov.s32 %1, 1;\n\t" + "}" + : "+r"(leader_lane), "+r"(is_leader) : "r" (mask)); +#endif + return is_leader; + } + + template + struct invoke_one_impl {}; + + template<> + struct invoke_one_impl { + template + _CG_STATIC_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) { + auto mask = details::_coalesced_group_data_access::get_mask(group); + unsigned int leader_lane = 0; + + if (elect_one(group, mask, leader_lane)) { + _CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...); + } + } + + template + _CG_STATIC_QUALIFIER auto invoke_one_broadcast(const Group& group, Fn&& fn, Args&&... args) + -> typename _CG_STL_NAMESPACE::remove_reference< + decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...))>::type { + + using ResultType = decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...)); + details::remove_qual result; + auto mask = details::_coalesced_group_data_access::get_mask(group); + unsigned int leader_lane = 0; + + if (elect_one(group, mask, leader_lane)) { + result = _CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...); + } + + // Need to use low level api instead of group.shfl, because elect_one returns lane id, not group rank. + return tile::shuffle_dispatch::shfl(result, mask, leader_lane, 32); + } + }; + + template<> + struct invoke_one_impl { + template + _CG_STATIC_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) { + if (group.thread_rank() == 0) { + _CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...); + } + } + + template + _CG_STATIC_QUALIFIER auto invoke_one_broadcast(const Group& group, Fn&& fn, Args&&... args) + -> typename _CG_STL_NAMESPACE::remove_reference< + decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...))>::type { + + using ResultType = decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...)); + details::remove_qual result; + + if (group.thread_rank() == 0) { + result = _CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...); + } + + return group.shfl(result, 0); + } + }; + + +}; // namespace details + +template +_CG_QUALIFIER void invoke_one(const Group& group, Fn&& fn, Args&&... args) { + using impl = details::invoke_one_impl::value>; + impl::invoke_one(group, _CG_STL_NAMESPACE::forward(fn), _CG_STL_NAMESPACE::forward(args)...); +} + +template +_CG_QUALIFIER auto invoke_one_broadcast(const coalesced_group& group, Fn&& fn, Args&&... args) + -> typename _CG_STL_NAMESPACE::remove_reference< + decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...))>::type { + + using ResultType = decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...)); + static_assert(!_CG_STL_NAMESPACE::is_same::value, + "For invocables returning void invoke_one should be used instead"); + using impl = details::invoke_one_impl::value>; + return impl::invoke_one_broadcast(group, + _CG_STL_NAMESPACE::forward(fn), + _CG_STL_NAMESPACE::forward(args)...); +} + +template +_CG_QUALIFIER auto invoke_one_broadcast(const thread_block_tile& group, Fn&& fn, Args&&... args) + -> typename _CG_STL_NAMESPACE::remove_reference< + decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...))>::type { + + using ResultType = decltype(_CG_STL_NAMESPACE::forward(fn)(_CG_STL_NAMESPACE::forward(args)...)); + static_assert(!_CG_STL_NAMESPACE::is_same::value, + "For invocables returning void invoke_one should be used instead"); + using impl = details::invoke_one_impl>::value>; + return impl::invoke_one_broadcast(group, + _CG_STL_NAMESPACE::forward(fn), + _CG_STL_NAMESPACE::forward(args)...); +} + +_CG_END_NAMESPACE + +#endif //_CG_CPP11_FEATURES + +#endif // _CG_INVOKE_H diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/memory.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/memory.h new file mode 100644 index 0000000000000000000000000000000000000000..47cf260f3b4e0b29bf08c948697102bf027616db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/memory.h @@ -0,0 +1,135 @@ +/* Copyright 1993-2022 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * The source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * The Licensed Deliverables contained herein are PROPRIETARY and + * CONFIDENTIAL to NVIDIA and are being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _COOPERATIVE_GROUPS_MEMORY_H_ +# define _COOPERATIVE_GROUPS_MEMORY_H_ + +#include "info.h" + +_CG_BEGIN_NAMESPACE + +#if defined(_CG_CPP11_FEATURES) +namespace details { + _CG_STATIC_CONST_DECL int scratch_num_reserved_bytes = 12; + +#if defined(_CG_HAS_RESERVED_SHARED) + _CG_STATIC_QUALIFIER void* reserved_shared_ptr() + { + void *ptr; + asm ("{\n\t" + " .reg .u32 start;\n\t" + " .reg .u64 extended;\n\t" + " mov.u32 start, %%reserved_smem_offset_1;\n\t" + " cvt.u64.u32 extended, start;\n\t" + " cvta.shared.u64 %0, extended;\n\t" + "}" + : "=" _CG_ASM_PTR_CONSTRAINT(ptr)); + return ptr; + } +#endif + + struct multi_warp_scratch { + // One barrier per possible size of the group. + _CG_STATIC_CONST_DECL unsigned int memory_barriers_count = 5; + _CG_STATIC_CONST_DECL size_t sync_memory_size = memory_barriers_count * sizeof(barrier_t); + + using communication_type = unsigned long long; + _CG_STATIC_CONST_DECL size_t communication_size = sizeof(communication_type); + + // Layout of the scratch space: + barrier_t barriers[memory_barriers_count]; + char reserved[scratch_num_reserved_bytes]; // Reserve 12 bytes for future use + communication_type communication_memory[default_max_block_size / 32]; + + _CG_STATIC_CONSTEXPR_QUALIFIER unsigned int scratch_size_needed(unsigned int max_block_size) { + // One slot of collectives memory per warp. + return scratch_num_reserved_bytes + sync_memory_size + max_block_size / 32 * communication_size; + } + + _CG_QUALIFIER void init_barriers(unsigned int thread_rank) { + if (thread_rank < memory_barriers_count) { + barriers[thread_rank] = 0; + } + } + }; + +#if defined(_CG_HAS_RESERVED_SHARED) + // CG can expect at least 288 bytes available in reserved shared + static_assert(sizeof(multi_warp_scratch) <= 288, "multi-warp scratch size is too large"); +#endif + + // Make sure the structure can fit into the user provided memory + static_assert(sizeof(multi_warp_scratch) <= multi_warp_scratch::scratch_size_needed(default_max_block_size), + "multi-warp scratch size is too large"); + + + _CG_QUALIFIER multi_warp_scratch* get_scratch_ptr(void* user_scratch) { + void *ptr; +#if defined(_CG_HAS_RESERVED_SHARED) + ptr = reserved_shared_ptr(); +#else + ptr = user_scratch; +#endif + return static_cast(ptr); + + } + +} + +template +struct __align__(details::multi_warp_scratch::communication_size) block_tile_memory { +private: +#if !defined(_CG_HAS_RESERVED_SHARED) + char scratch[details::multi_warp_scratch::scratch_size_needed(MaxBlockSize)]; +#endif +}; +#endif + +_CG_END_NAMESPACE + +#endif /* !_COOPERATIVE_GROUPS_MEMORY_H_ */ diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/partitioning.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/partitioning.h new file mode 100644 index 0000000000000000000000000000000000000000..9c219756c594c87be85a0a154cfa5579241a861f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/partitioning.h @@ -0,0 +1,159 @@ +/* + * Copyright 1993-2016 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _CG_PARTITIONING_H +#define _CG_PARTITIONING_H + +#include "info.h" +#include "helpers.h" + +_CG_BEGIN_NAMESPACE + +namespace details { + + template + _CG_STATIC_QUALIFIER coalesced_group _binary_partition(const TyGroup &tile, bool pred) { + const unsigned int fullMask = ~0u; + + unsigned int thisMask = _coalesced_group_data_access::get_mask(tile); + unsigned int predMask = pred ? 0 : fullMask; + unsigned int setMask = __ballot_sync(thisMask, pred); + + if (setMask == thisMask || setMask == 0) { + coalesced_group subTile = _coalesced_group_data_access::construct_from_mask(thisMask); + _coalesced_group_data_access::modify_meta_group(subTile, 0, 1); + return subTile; + } + else { + unsigned int subMask = thisMask & (setMask ^ predMask); + coalesced_group subTile = _coalesced_group_data_access::construct_from_mask(subMask); + _coalesced_group_data_access::modify_meta_group(subTile, pred, 2); + return subTile; + } + } + +#if defined(_CG_HAS_MATCH_COLLECTIVE) && defined(_CG_CPP11_FEATURES) + template + struct _labeled_partition_dispatch { + template + _CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, TyPredicate pred) { + unsigned int thisMask = _coalesced_group_data_access::get_mask(tile); + unsigned int thisBias = __ffs(thisMask) - 1; // Subtract 1 to index properly from [1-32] + unsigned int subMask = __match_any_sync(thisMask, pred); + + coalesced_group subTile = _coalesced_group_data_access::construct_from_mask(subMask); + + int leaderLaneId = subTile.shfl(details::laneid(), 0); + + bool isLeader = !subTile.thread_rank(); + unsigned int leaderMask = __ballot_sync(thisMask, isLeader); + unsigned int tileRank = __fns(leaderMask, leaderLaneId, 0) - thisBias; + + _coalesced_group_data_access::modify_meta_group(subTile, tileRank, __popc(leaderMask)); + + return subTile; + } + }; + + template <> + struct _labeled_partition_dispatch { + template + _CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, bool pred) { + return _binary_partition(tile, pred); + } + }; + + template + struct _labeled_partition_dispatch { + template + _CG_QUALIFIER coalesced_group operator()(const TyGroup &tile, TyPredicate* pred) { + auto impl = _labeled_partition_dispatch(); + return impl(tile, reinterpret_cast(pred)); + } + }; +#endif +}; // namespace details + +_CG_STATIC_QUALIFIER coalesced_group binary_partition(const coalesced_group &tile, bool pred) { + return details::_binary_partition(tile, pred); +} + +template +_CG_STATIC_QUALIFIER coalesced_group binary_partition(const thread_block_tile &tile, bool pred) { +#ifdef _CG_CPP11_FEATURES + static_assert(Size <= 32, "Binary partition is available only for tiles of size smaller or equal to 32"); +#endif + return details::_binary_partition(tile, pred); +} + + +#if defined(_CG_HAS_MATCH_COLLECTIVE) && defined(_CG_CPP11_FEATURES) +template +_CG_STATIC_QUALIFIER coalesced_group labeled_partition(const coalesced_group &tile, TyPredicate pred) { + static_assert(_CG_STL_NAMESPACE::is_integral::value || + _CG_STL_NAMESPACE::is_pointer::value, + "labeled_partition predicate must be an integral or pointer type"); + auto dispatch = details::_labeled_partition_dispatch>(); + return dispatch(tile, pred); +} + +template +_CG_STATIC_QUALIFIER coalesced_group labeled_partition(const thread_block_tile &tile, TyPredicate pred) { + static_assert(_CG_STL_NAMESPACE::is_integral::value || + _CG_STL_NAMESPACE::is_pointer::value, + "labeled_partition predicate must be an integral or pointer type"); + static_assert(Size <= 32, "Labeled partition is available only for tiles of size smaller or equal to 32"); + auto dispatch = details::_labeled_partition_dispatch>(); + return dispatch(tile, pred); +} +#endif + +_CG_END_NAMESPACE + +#endif // _CG_PARTITIONING_H diff --git a/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/reduce.h b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..0313b52a23f440e283509993d6f7997ba5df2365 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/include/cooperative_groups/details/reduce.h @@ -0,0 +1,419 @@ + /* Copyright 1993-2016 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * The source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * The Licensed Deliverables contained herein are PROPRIETARY and + * CONFIDENTIAL to NVIDIA and are being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef _CG_REDUCE_H_ +#define _CG_REDUCE_H_ + +#include "info.h" +#include "helpers.h" +#include "coalesced_reduce.h" +#include "functional.h" +#include "cooperative_groups.h" + +_CG_BEGIN_NAMESPACE + +namespace details { + + template + using _redux_is_add_supported = _CG_STL_NAMESPACE::integral_constant< + bool, + _CG_STL_NAMESPACE::is_integral::value && (sizeof(Ty) <= 4)>; + + template + using redux_is_add_supported = _redux_is_add_supported; + + // A specialization for 64 bit logical operations is possible + // but for now only accelerate 32 bit bitwise ops + template + using redux_is_logical_supported = redux_is_add_supported; + + // Base operator support case + template struct _redux_op_supported : public _CG_STL_NAMESPACE::false_type {}; +#ifdef _CG_HAS_OP_REDUX + template struct _redux_op_supported, Ty> : public redux_is_add_supported {}; + template struct _redux_op_supported, Ty> : public redux_is_add_supported {}; + template struct _redux_op_supported, Ty> : public redux_is_add_supported {}; + template struct _redux_op_supported, Ty> : public redux_is_logical_supported {}; + template struct _redux_op_supported, Ty> : public redux_is_logical_supported {}; + template struct _redux_op_supported, Ty> : public redux_is_logical_supported {}; +#endif + + template class TyOp> + using redux_op_supported = _redux_op_supported< + typename details::remove_qual>, + Ty>; + + // Groups smaller than 16 actually have worse performance characteristics when used with redux + // tiles of size 16 and 32 perform the same or better and have better code generation profiles + template struct _redux_group_optimized : public _CG_STL_NAMESPACE::false_type {}; + + template + struct _redux_group_optimized> : public _CG_STL_NAMESPACE::integral_constant< + bool, + (Sz >= 16)> {}; + template + struct _redux_group_optimized> : public _CG_STL_NAMESPACE::integral_constant< + bool, + (Sz >= 16)> {}; + template <> + struct _redux_group_optimized : public _CG_STL_NAMESPACE::true_type {}; + + template + using redux_group_optimized = _redux_group_optimized>; + + template