// rope.h — Manual HF-style RoPE using basic aclnn ops. // // Formula: q_out = q * cos + rotate_half(q) * sin // where rotate_half(q) = concat(-q[..., d/2:], q[..., :d/2], dim=-1) // // Tensor layout: q/k [B, S, N, Dh] BF16, cos/sin [1, S, Dh] BF16 // (cos/sin are broadcast across B and N dims) // #pragma once #include "acl_common.h" #include "aclnn_ops.h" #include // Fused RoPE via aclnnApplyRotaryPosEmbV2 — replaces the 8-op manual version with a single // op, saving ~7 launches per layer × 94 layers = ~658 launches/token. Validated on 910 initial: // layout=1 + rotaryMode="half" matches HF rotate_half semantics (rel=1.24e-3 vs manual). // // q_data: [B, S, Nq, Dh] BF16 (modified in place) // k_data: [B, S, Nk, Dh] BF16 (modified in place) // cos_data / sin_data: [1, S, 1, Dh] BF16 (single contiguous buffer slice from RopeCache) inline void apply_rope_fused(aclrtStream stream, void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh, void* k_data, int64_t Nk, void* cos_data, void* sin_data) { const aclDataType dt = ACL_BF16; auto t_q = make_contig_tensor(q_data, dt, {B, S, Nq, Dh}); auto t_k = make_contig_tensor(k_data, dt, {B, S, Nk, Dh}); auto t_cos = make_contig_tensor(cos_data, dt, {1, S, 1, Dh}); auto t_sin = make_contig_tensor(sin_data, dt, {1, S, 1, Dh}); uint64_t ws = 0; aclOpExecutor* exec = nullptr; char mode[] = "half"; ACLNN_CHECK(aclnnApplyRotaryPosEmbV2GetWorkspaceSize( t_q.get(), t_k.get(), t_cos.get(), t_sin.get(), /*layout=*/1, mode, &ws, &exec)); void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr; ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(wp, ws, exec, stream)); } // Apply RoPE in-place to q and k. // q_data: pointer to [B, S, Nq, Dh] BF16 (modified in place) // k_data: pointer to [B, S, Nk, Dh] BF16 (modified in place) // cos_data, sin_data: [1, S, Dh] BF16 // scratch_data: pointer to contiguous [B, S, max(Nq,Nk), Dh] BF16 scratch buffer for rotate_half inline void apply_rope_manual(aclrtStream stream, void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh, void* k_data, int64_t Nk, void* cos_data, void* sin_data, void* scratch_data) { const aclDataType dt = ACL_BF16; const size_t elem = 2; const int64_t halfDh = Dh / 2; auto process = [&](void* x_data, int64_t N) { // Strides in elements (row-major [B, S, N, Dh]): // stride = [S*N*Dh, N*Dh, Dh, 1] const std::vector full_shape = {B, S, N, Dh}; const std::vector full_stride = {S*N*Dh, N*Dh, Dh, 1}; const std::vector half_shape = {B, S, N, halfDh}; const std::vector half_stride = full_stride; // same leading 3 strides // View x as full auto t_x = make_acl_tensor(x_data, dt, full_shape, full_stride); // View of x left half and right half (shifted pointers, same layout, last dim half) auto t_x_left = make_acl_tensor(x_data, dt, half_shape, half_stride); auto t_x_right = make_acl_tensor((char*)x_data + halfDh*elem, dt, half_shape, half_stride); // rotate_half buffer view (contiguous [B, S, N, Dh]) const std::vector rh_stride = {S*N*Dh, N*Dh, Dh, 1}; auto t_rh = make_acl_tensor(scratch_data, dt, full_shape, rh_stride); auto t_rh_left = make_acl_tensor(scratch_data, dt, half_shape, rh_stride); auto t_rh_right = make_acl_tensor((char*)scratch_data + halfDh*elem, dt, half_shape, rh_stride); // rh[..., :Dh/2] = -x[..., Dh/2:] neg(stream, t_x_right.get(), t_rh_left.get()); // rh[..., Dh/2:] = x[..., :Dh/2] inplace_copy(stream, t_rh_right.get(), t_x_left.get()); // cos/sin views broadcastable to [B, S, N, Dh] // Original storage: [1, S, Dh]. For broadcast, use shape [1, S, 1, Dh] with strides [0, Dh, 0, 1]. auto t_cos = make_acl_tensor(cos_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1}); auto t_sin = make_acl_tensor(sin_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1}); // q_rot = q * cos + rh * sin (use addcmul: q *= cos, then q += rh * sin) // Compute tmp = q * cos (fresh buffer needed; use scratch_data is occupied by rh) // Better: multiply x in place: x *= cos, then x += rh * sin // aclnnMul with x as both in and out is inplace. mul(stream, t_x.get(), t_cos.get(), t_x.get()); // x = x * cos addcmul(stream, t_x.get(), t_rh.get(), t_sin.get(), 1); // x += 1 * (rh * sin) }; process(q_data, Nq); process(k_data, Nk); }