| #pragma once
|
|
|
|
|
|
|
| #include "simd-mappings.h"
|
|
|
|
|
| #if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
|
|
|
|
|
|
|
// Micro-kernel tile shape for simd_gemm_ukernel:
//   GEMM_RM = rows of C accumulated per tile,
//   GEMM_RN = SIMD vectors (of GGML_F32_EPR floats each) per row.
// The tile therefore holds GEMM_RM * GEMM_RN vector accumulators live at once;
// the per-ISA values are presumably sized to the architectural vector register
// count (32 regs on AVX-512/NEON, 16 on AVX/AVX2) — TODO confirm.
#if defined(__AVX512F__) || defined (__ARM_NEON__)

static constexpr int GEMM_RM = 4;

static constexpr int GEMM_RN = 4;

#elif defined(__AVX2__) || defined(__AVX__)

static constexpr int GEMM_RM = 6;

static constexpr int GEMM_RN = 2;

#else

// Conservative fallback tile for unknown SIMD targets.
static constexpr int GEMM_RM = 2;

static constexpr int GEMM_RN = 2;

#endif
|
|
|
| template <int RM, int RN>
|
| static inline void simd_gemm_ukernel(
|
| float * GGML_RESTRICT C,
|
| const float * GGML_RESTRICT A,
|
| const float * GGML_RESTRICT B,
|
| int K, int N)
|
| {
|
| static constexpr int KN = GGML_F32_EPR;
|
|
|
| GGML_F32_VEC acc[RM][RN];
|
| for (int64_t i = 0; i < RM; i++) {
|
| for (int r = 0; r < RN; r++) {
|
| acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
|
| }
|
| }
|
|
|
| for (int64_t kk = 0; kk < K; kk++) {
|
| GGML_F32_VEC Bv[RN];
|
| for (int r = 0; r < RN; r++) {
|
| Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
|
| }
|
| for (int64_t i = 0; i < RM; i++) {
|
| GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
|
| for (int r = 0; r < RN; r++) {
|
| acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
|
| }
|
| }
|
| }
|
|
|
| for (int64_t i = 0; i < RM; i++) {
|
| for (int r = 0; r < RN; r++) {
|
| GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
|
| }
|
| }
|
| }
|
|
|
|
|
| static void simd_gemm(
|
| float * GGML_RESTRICT C,
|
| const float * GGML_RESTRICT A,
|
| const float * GGML_RESTRICT B,
|
| int M, int K, int N)
|
| {
|
| static constexpr int KN = GGML_F32_EPR;
|
|
|
| int64_t ii = 0;
|
| for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
|
| int64_t jj = 0;
|
| for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
| simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
|
| }
|
| for (; jj + KN <= N; jj += KN) {
|
| simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
|
| }
|
| for (; jj < N; jj++) {
|
| for (int64_t i = 0; i < GEMM_RM; i++) {
|
| float a = C[i * N + jj];
|
| for (int64_t kk = 0; kk < K; kk++) {
|
| a += A[i + kk] * B[kk * N + jj];
|
| }
|
| C[i * N + jj] = a;
|
| }
|
| }
|
|
|
| A += GEMM_RM * K;
|
| C += GEMM_RM * N;
|
| }
|
|
|
|
|
| for (; ii < M; ii++) {
|
| int64_t jj = 0;
|
| for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
| simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
|
| }
|
| for (; jj + KN <= N; jj += KN) {
|
| simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
|
| }
|
| for (; jj < N; jj++) {
|
| float a = C[jj];
|
| for (int64_t kk = 0; kk < K; kk++) {
|
| a += A[kk] * B[kk * N + jj];
|
| }
|
| C[jj] = a;
|
| }
|
|
|
| A += K;
|
| C += N;
|
| }
|
| }
|
|
|
#if defined(__GNUC__) && !defined(__clang__)

// NOTE(review): no matching `#pragma GCC diagnostic push` appears anywhere in
// this file. Unless the push is supplied by an included header (the pragma
// stack is per translation unit — check simd-mappings.h), this pop is
// unmatched and GCC warns under -Wpragmas. Confirm and either restore the
// push or remove this pop.
#pragma GCC diagnostic pop

#endif
|
|
|
| #else
|
|
|
| static void simd_gemm(
|
| float * GGML_RESTRICT C,
|
| const float * GGML_RESTRICT A,
|
| const float * GGML_RESTRICT B,
|
| int M, int K, int N)
|
| {
|
| for (int64_t i = 0; i < M; i++) {
|
| for (int64_t j = 0; j < N; j++) {
|
| float sum = C[i * N + j];
|
| for (int64_t kk = 0; kk < K; kk++) {
|
| sum += A[i * K + kk] * B[kk * N + j];
|
| }
|
| C[i * N + j] = sum;
|
| }
|
| }
|
| }
|
|
|
| #endif
|
|
|