| from cython cimport floating |
|
|
| from scipy.linalg.cython_blas cimport sdot, ddot |
| from scipy.linalg.cython_blas cimport sasum, dasum |
| from scipy.linalg.cython_blas cimport saxpy, daxpy |
| from scipy.linalg.cython_blas cimport snrm2, dnrm2 |
| from scipy.linalg.cython_blas cimport scopy, dcopy |
| from scipy.linalg.cython_blas cimport sscal, dscal |
| from scipy.linalg.cython_blas cimport srotg, drotg |
| from scipy.linalg.cython_blas cimport srot, drot |
| from scipy.linalg.cython_blas cimport sgemv, dgemv |
| from scipy.linalg.cython_blas cimport sger, dger |
| from scipy.linalg.cython_blas cimport sgemm, dgemm |
|
|
|
|
| |
| |
| |
|
|
cdef floating _dot(int n, const floating *x, int incx,
                   const floating *y, int incy) noexcept nogil:
    """Return the inner product x.T.y of two strided vectors.

    ``n`` is the number of elements and ``incx`` / ``incy`` are the strides
    forwarded to BLAS. The fused type selects ``ddot`` for double and
    ``sdot`` for float.
    """
    # The pointer casts strip the const qualifier, presumably because the
    # cimported BLAS prototypes take non-const pointers.
    if floating is double:
        return ddot(&n, <double *> x, &incx, <double *> y, &incy)
    else:
        return sdot(&n, <float *> x, &incx, <float *> y, &incy)
|
|
|
|
cpdef _dot_memview(const floating[::1] x, const floating[::1] y):
    """Return x.T.y for two contiguous 1-D memoryviews.

    The number of elements processed is ``x.shape[0]``; both views are
    walked with unit stride.

    Raises
    ------
    ValueError
        If ``y`` holds fewer elements than ``x``. Without this guard BLAS
        would read past the end of ``y``.
    """
    if y.shape[0] < x.shape[0]:
        raise ValueError("y must have at least as many elements as x")
    return _dot(x.shape[0], &x[0], 1, &y[0], 1)
|
|
|
|
cdef floating _asum(int n, const floating *x, int incx) noexcept nogil:
    """Return sum(|x_i|) over ``n`` elements of ``x`` taken with stride ``incx``.

    The fused type selects ``dasum`` for double and ``sasum`` for float.
    """
    if floating is double:
        return dasum(&n, <double *> x, &incx)
    else:
        return sasum(&n, <float *> x, &incx)
|
|
|
|
cpdef _asum_memview(const floating[::1] x):
    """Return sum(|x_i|) over every element of the contiguous view ``x``."""
    cdef int n_elements = x.shape[0]

    return _asum(n_elements, &x[0], 1)
|
|
|
|
cdef void _axpy(int n, floating alpha, const floating *x, int incx,
                floating *y, int incy) noexcept nogil:
    """Update ``y`` in place: y := alpha * x + y.

    ``n`` elements are processed, reading ``x`` with stride ``incx`` and
    updating ``y`` with stride ``incy``. The fused type selects ``daxpy``
    for double and ``saxpy`` for float.
    """
    if floating is double:
        daxpy(&n, &alpha, <double *> x, &incx, y, &incy)
    else:
        saxpy(&n, &alpha, <float *> x, &incx, y, &incy)
|
|
|
|
cpdef _axpy_memview(floating alpha, const floating[::1] x, floating[::1] y):
    """Update ``y`` in place: y := alpha * x + y, for contiguous 1-D views.

    The number of elements processed is ``x.shape[0]``.

    Raises
    ------
    ValueError
        If ``y`` holds fewer elements than ``x``. Without this guard BLAS
        would write past the end of ``y``.
    """
    if y.shape[0] < x.shape[0]:
        raise ValueError("y must have at least as many elements as x")
    _axpy(x.shape[0], alpha, &x[0], 1, &y[0], 1)
|
|
|
|
cdef floating _nrm2(int n, const floating *x, int incx) noexcept nogil:
    """Return sqrt(sum((x_i)^2)) over ``n`` elements of ``x`` with stride ``incx``.

    The fused type selects ``dnrm2`` for double and ``snrm2`` for float.
    """
    if floating is double:
        return dnrm2(&n, <double *> x, &incx)
    else:
        return snrm2(&n, <float *> x, &incx)
|
|
|
|
cpdef _nrm2_memview(const floating[::1] x):
    """Return sqrt(sum((x_i)^2)) over every element of the contiguous view ``x``."""
    cdef int n_elements = x.shape[0]

    return _nrm2(n_elements, &x[0], 1)
|
|
|
|
cdef void _copy(int n, const floating *x, int incx, const floating *y, int incy) noexcept nogil:
    """Copy ``x`` into ``y``: y := x.

    ``n`` elements are copied, reading ``x`` with stride ``incx`` and writing
    ``y`` with stride ``incy``. The fused type selects ``dcopy`` for double
    and ``scopy`` for float.
    """
    # NOTE: ``y`` is declared const in the signature although it is the
    # destination; the casts below drop the qualifier before calling BLAS.
    if floating is double:
        dcopy(&n, <double *> x, &incx, <double *> y, &incy)
    else:
        scopy(&n, <float *> x, &incx, <float *> y, &incy)
|
|
|
|
cpdef _copy_memview(const floating[::1] x, const floating[::1] y):
    """Copy the contiguous view ``x`` into ``y``: y := x.

    The number of elements copied is ``x.shape[0]``. ``y`` is written to
    even though its view is declared const.

    Raises
    ------
    ValueError
        If ``y`` holds fewer elements than ``x``. Without this guard BLAS
        would write past the end of ``y``.
    """
    if y.shape[0] < x.shape[0]:
        raise ValueError("y must have at least as many elements as x")
    _copy(x.shape[0], &x[0], 1, &y[0], 1)
|
|
|
|
cdef void _scal(int n, floating alpha, const floating *x, int incx) noexcept nogil:
    """Scale ``x`` in place: x := alpha * x.

    ``n`` elements are scaled, stepping through ``x`` with stride ``incx``.
    The fused type selects ``dscal`` for double and ``sscal`` for float.
    """
    # NOTE: ``x`` is declared const in the signature although it is updated
    # in place; the casts below drop the qualifier before calling BLAS.
    if floating is double:
        dscal(&n, &alpha, <double *> x, &incx)
    else:
        sscal(&n, &alpha, <float *> x, &incx)
|
|
|
|
cpdef _scal_memview(floating alpha, const floating[::1] x):
    """Scale every element of the contiguous view ``x`` in place by ``alpha``.

    ``x`` is modified even though its view is declared const.
    """
    cdef int n_elements = x.shape[0]

    _scal(n_elements, alpha, &x[0], 1)
|
|
|
|
cdef void _rotg(floating *a, floating *b, floating *c, floating *s) noexcept nogil:
    """Generate a plane rotation.

    The four pointers are handed straight to BLAS ``drotg`` (double) or
    ``srotg`` (float), which writes its results back through them.
    """
    if floating is double:
        drotg(a, b, c, s)
    else:
        srotg(a, b, c, s)
|
|
|
|
cpdef _rotg_memview(floating a, floating b, floating c, floating s):
    """Generate a plane rotation and return the updated ``(a, b, c, s)``.

    The scalars are passed by address to ``_rotg``; their values after the
    call are returned as a tuple.
    """
    _rotg(&a, &b, &c, &s)
    rotation = (a, b, c, s)
    return rotation
|
|
|
|
cdef void _rot(int n, floating *x, int incx, floating *y, int incy,
               floating c, floating s) noexcept nogil:
    """Apply a plane rotation to the vectors ``x`` and ``y`` in place.

    ``n`` elements of each vector are rotated using the coefficients ``c``
    and ``s``; ``incx`` / ``incy`` are the strides. The fused type selects
    ``drot`` for double and ``srot`` for float.
    """
    if floating is double:
        drot(&n, x, &incx, y, &incy, &c, &s)
    else:
        srot(&n, x, &incx, y, &incy, &c, &s)
|
|
|
|
cpdef _rot_memview(floating[::1] x, floating[::1] y, floating c, floating s):
    """Apply a plane rotation to the contiguous views ``x`` and ``y`` in place.

    The number of elements rotated is ``x.shape[0]``.

    Raises
    ------
    ValueError
        If ``y`` holds fewer elements than ``x``. Without this guard BLAS
        would read and write past the end of ``y``.
    """
    if y.shape[0] < x.shape[0]:
        raise ValueError("y must have at least as many elements as x")
    _rot(x.shape[0], &x[0], 1, &y[0], 1, c, s)
|
|
|
|
| |
| |
| |
|
|
cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
                const floating *A, int lda, const floating *x, int incx,
                floating beta, floating *y, int incy) noexcept nogil:
    """Matrix-vector product: y := alpha * op(A).x + beta * y.

    ``A`` is an ``m`` x ``n`` matrix stored in ``order`` with leading
    dimension ``lda``; ``ta`` selects op(A). The fused type selects
    ``dgemv`` for double and ``sgemv`` for float.
    """
    cdef:
        char op_flag = ta
        int blas_rows = m
        int blas_cols = n

    if order == RowMajor:
        # A row-major m x n matrix is handed to BLAS as its column-major
        # transpose: flip the transpose flag and swap the dimensions.
        op_flag = NoTrans if ta == Trans else Trans
        blas_rows = n
        blas_cols = m

    if floating is double:
        dgemv(&op_flag, &blas_rows, &blas_cols, &alpha, <double *> A, &lda,
              <double *> x, &incx, &beta, y, &incy)
    else:
        sgemv(&op_flag, &blas_rows, &blas_cols, &alpha, <float *> A, &lda,
              <float *> x, &incx, &beta, y, &incy)
|
|
|
|
cpdef _gemv_memview(BLAS_Trans ta, floating alpha, const floating[:, :] A,
                    const floating[::1] x, floating beta, floating[::1] y):
    """Matrix-vector product on memoryviews: y := alpha * op(A).x + beta * y.

    ``A`` is treated as column-major when its first stride equals the item
    size and as row-major otherwise; the leading dimension is derived from
    its shape, so ``A`` is assumed to be contiguous in that layout.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is shorter than the shape of op(A) requires.
        Without this guard BLAS would access memory past the end of the
        vectors.
    """
    cdef:
        int m = A.shape[0]
        int n = A.shape[1]
        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
        int lda = m if order == ColMajor else n
        # op(A) is m x n for NoTrans and n x m otherwise.
        int x_needed = n if ta == NoTrans else m
        int y_needed = m if ta == NoTrans else n

    if x.shape[0] < x_needed or y.shape[0] < y_needed:
        raise ValueError("x or y is too short for the shape of op(A)")

    _gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)
|
|
|
|
cdef void _ger(BLAS_Order order, int m, int n, floating alpha,
               const floating *x, int incx, const floating *y,
               int incy, floating *A, int lda) noexcept nogil:
    """Rank-1 update: A := alpha * x.y.T + A.

    ``A`` is an ``m`` x ``n`` matrix stored in ``order`` with leading
    dimension ``lda``. The fused type selects ``dger`` for double and
    ``sger`` for float.
    """
    if order != RowMajor:
        if floating is double:
            dger(&m, &n, &alpha, <double *> x, &incx, <double *> y, &incy, A, &lda)
        else:
            sger(&m, &n, &alpha, <float *> x, &incx, <float *> y, &incy, A, &lda)
        return

    # Row-major: update the column-major transpose instead, which swaps the
    # roles of x and y as well as the two dimensions.
    if floating is double:
        dger(&n, &m, &alpha, <double *> y, &incy, <double *> x, &incx, A, &lda)
    else:
        sger(&n, &m, &alpha, <float *> y, &incy, <float *> x, &incx, A, &lda)
|
|
|
|
cpdef _ger_memview(floating alpha, const floating[::1] x,
                   const floating[::1] y, floating[:, :] A):
    """Rank-1 update on memoryviews: A := alpha * x.y.T + A.

    ``A`` is treated as column-major when its first stride equals the item
    size and as row-major otherwise; the leading dimension is derived from
    its shape, so ``A`` is assumed to be contiguous in that layout.

    Raises
    ------
    ValueError
        If ``x`` has fewer elements than ``A`` has rows, or ``y`` has fewer
        elements than ``A`` has columns. Without this guard BLAS would read
        past the end of the vectors.
    """
    cdef:
        int m = A.shape[0]
        int n = A.shape[1]
        BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
        int lda = m if order == ColMajor else n

    if x.shape[0] < m or y.shape[0] < n:
        raise ValueError("x or y is too short for the shape of A")

    _ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
|
|
|
|
| |
| |
| |
|
|
cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
                int k, floating alpha, const floating *A, int lda, const floating *B,
                int ldb, floating beta, floating *C, int ldc) noexcept nogil:
    """Matrix-matrix product: C := alpha * op(A).op(B) + beta * C.

    op(A) is ``m`` x ``k``, op(B) is ``k`` x ``n`` and C is ``m`` x ``n``,
    all stored in ``order`` with leading dimensions ``lda``, ``ldb`` and
    ``ldc``. The fused type selects ``dgemm`` for double and ``sgemm`` for
    float.
    """
    cdef:
        char flag_a = ta
        char flag_b = tb

    if order != RowMajor:
        if floating is double:
            dgemm(&flag_a, &flag_b, &m, &n, &k, &alpha, <double*>A,
                  &lda, <double*>B, &ldb, &beta, C, &ldc)
        else:
            sgemm(&flag_a, &flag_b, &m, &n, &k, &alpha, <float*>A,
                  &lda, <float*>B, &ldb, &beta, C, &ldc)
        return

    # Row-major: compute C.T = op(B).T . op(A).T in column-major terms by
    # swapping the two operands and the m / n dimensions.
    if floating is double:
        dgemm(&flag_b, &flag_a, &n, &m, &k, &alpha, <double*>B,
              &ldb, <double*>A, &lda, &beta, C, &ldc)
    else:
        sgemm(&flag_b, &flag_a, &n, &m, &k, &alpha, <float*>B,
              &ldb, <float*>A, &lda, &beta, C, &ldc)
|
|
|
|
cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
                    const floating[:, :] A, const floating[:, :] B, floating beta,
                    floating[:, :] C):
    """Matrix-matrix product on memoryviews: C := alpha * op(A).op(B) + beta * C.

    The storage order is inferred from the operands and the leading
    dimensions are derived from their shapes, so A, B and C are assumed to
    be contiguous and to share one layout.

    Raises
    ------
    ValueError
        If the shapes of op(A), op(B) and C do not line up, or if the
        operands mix row-major and column-major layouts.
    """
    cdef:
        int m = A.shape[0] if ta == NoTrans else A.shape[1]
        int n = B.shape[1] if tb == NoTrans else B.shape[0]
        int k = A.shape[1] if ta == NoTrans else A.shape[0]
        int k_of_b = B.shape[0] if tb == NoTrans else B.shape[1]
        int lda, ldb, ldc
        bint col_seen = False
        bint row_seen = False
        BLAS_Order order

    # Mismatched operands would make BLAS read or write out of bounds.
    if k_of_b != k or C.shape[0] != m or C.shape[1] != n:
        raise ValueError("shapes of op(A), op(B) and C are not aligned")

    # An operand with a dimension of size <= 1 is contiguous in both orders,
    # so its strides say nothing about the layout (a C-ordered single-column
    # A has strides[0] == itemsize). Only operands with both dimensions > 1
    # are allowed to decide the order.
    if A.shape[0] > 1 and A.shape[1] > 1:
        if A.strides[0] == A.itemsize:
            col_seen = True
        else:
            row_seen = True
    if B.shape[0] > 1 and B.shape[1] > 1:
        if B.strides[0] == B.itemsize:
            col_seen = True
        else:
            row_seen = True
    if C.shape[0] > 1 and C.shape[1] > 1:
        if C.strides[0] == C.itemsize:
            col_seen = True
        else:
            row_seen = True

    if col_seen and row_seen:
        raise ValueError("A, B and C must share the same memory layout")

    if col_seen:
        order = ColMajor
    elif row_seen:
        order = RowMajor
    else:
        # Every operand is vector-like: fall back to the stride of A.
        order = ColMajor if A.strides[0] == A.itemsize else RowMajor

    if order == RowMajor:
        lda = k if ta == NoTrans else m
        ldb = n if tb == NoTrans else k
        ldc = n
    else:
        lda = m if ta == NoTrans else k
        ldb = k if tb == NoTrans else n
        ldc = m

    _gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
          lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)
|
|