| |
| |
| |
| |
| |
| |
| |
| |
| #pragma once |
| #include "acl_common.h" |
| #include "aclnn_ops.h" |
| #include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h> |
|
|
| |
| |
| |
| |
| |
| |
| |
| inline void apply_rope_fused(aclrtStream stream, |
| void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh, |
| void* k_data, int64_t Nk, |
| void* cos_data, void* sin_data) { |
| const aclDataType dt = ACL_BF16; |
| auto t_q = make_contig_tensor(q_data, dt, {B, S, Nq, Dh}); |
| auto t_k = make_contig_tensor(k_data, dt, {B, S, Nk, Dh}); |
| auto t_cos = make_contig_tensor(cos_data, dt, {1, S, 1, Dh}); |
| auto t_sin = make_contig_tensor(sin_data, dt, {1, S, 1, Dh}); |
| uint64_t ws = 0; aclOpExecutor* exec = nullptr; |
| char mode[] = "half"; |
| ACLNN_CHECK(aclnnApplyRotaryPosEmbV2GetWorkspaceSize( |
| t_q.get(), t_k.get(), t_cos.get(), t_sin.get(), |
| 1, mode, &ws, &exec)); |
| void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr; |
| ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(wp, ws, exec, stream)); |
| } |
|
|
| |
| |
| |
| |
| |
| inline void apply_rope_manual(aclrtStream stream, |
| void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh, |
| void* k_data, int64_t Nk, |
| void* cos_data, void* sin_data, |
| void* scratch_data) { |
| const aclDataType dt = ACL_BF16; |
| const size_t elem = 2; |
| const int64_t halfDh = Dh / 2; |
|
|
| auto process = [&](void* x_data, int64_t N) { |
| |
| |
| const std::vector<int64_t> full_shape = {B, S, N, Dh}; |
| const std::vector<int64_t> full_stride = {S*N*Dh, N*Dh, Dh, 1}; |
| const std::vector<int64_t> half_shape = {B, S, N, halfDh}; |
| const std::vector<int64_t> half_stride = full_stride; |
|
|
| |
| auto t_x = make_acl_tensor(x_data, dt, full_shape, full_stride); |
|
|
| |
| auto t_x_left = make_acl_tensor(x_data, dt, half_shape, half_stride); |
| auto t_x_right = make_acl_tensor((char*)x_data + halfDh*elem, dt, half_shape, half_stride); |
|
|
| |
| const std::vector<int64_t> rh_stride = {S*N*Dh, N*Dh, Dh, 1}; |
| auto t_rh = make_acl_tensor(scratch_data, dt, full_shape, rh_stride); |
| auto t_rh_left = make_acl_tensor(scratch_data, dt, half_shape, rh_stride); |
| auto t_rh_right = make_acl_tensor((char*)scratch_data + halfDh*elem, dt, half_shape, rh_stride); |
|
|
| |
| neg(stream, t_x_right.get(), t_rh_left.get()); |
| |
| inplace_copy(stream, t_rh_right.get(), t_x_left.get()); |
|
|
| |
| |
| auto t_cos = make_acl_tensor(cos_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1}); |
| auto t_sin = make_acl_tensor(sin_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1}); |
|
|
| |
| |
| |
| |
| mul(stream, t_x.get(), t_cos.get(), t_x.get()); |
| addcmul(stream, t_x.get(), t_rh.get(), t_sin.get(), 1); |
| }; |
|
|
| process(q_data, Nq); |
| process(k_data, Nk); |
| } |
|
|