// Author: xianglarry
// Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
// Commit: 4b9fefd
// rope.h — Manual HF-style RoPE using basic aclnn ops.
//
// Formula: q_out = q * cos + rotate_half(q) * sin
// where rotate_half(q) = concat(-q[..., d/2:], q[..., :d/2], dim=-1)
//
// Tensor layout: q/k [B, S, N, Dh] BF16, cos/sin [1, S, Dh] BF16
// (cos/sin are broadcast across B and N dims)
//
#pragma once
#include "acl_common.h"
#include "aclnn_ops.h"

#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>
// Fused RoPE via aclnnApplyRotaryPosEmbV2 — collapses the 8-op manual path into one
// kernel, saving ~7 launches per layer × 94 layers = ~658 launches/token. Validated
// on 910 initial: layout=1 + rotaryMode="half" matches HF rotate_half semantics
// (rel=1.24e-3 vs the manual implementation).
//
// q_data: [B, S, Nq, Dh] BF16, rotated in place
// k_data: [B, S, Nk, Dh] BF16, rotated in place
// cos_data / sin_data: [1, S, 1, Dh] BF16 (contiguous slices from RopeCache)
inline void apply_rope_fused(aclrtStream stream,
                             void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
                             void* k_data, int64_t Nk,
                             void* cos_data, void* sin_data) {
  const aclDataType dtype = ACL_BF16;
  // Contiguous device views over the caller-owned buffers (q/k are mutated in place).
  auto query = make_contig_tensor(q_data, dtype, {B, S, Nq, Dh});
  auto key = make_contig_tensor(k_data, dtype, {B, S, Nk, Dh});
  auto cos_t = make_contig_tensor(cos_data, dtype, {1, S, 1, Dh});
  auto sin_t = make_contig_tensor(sin_data, dtype, {1, S, 1, Dh});

  // Standard two-phase aclnn call: size the workspace, then launch on the stream.
  uint64_t workspace_bytes = 0;
  aclOpExecutor* executor = nullptr;
  char rotary_mode[] = "half";
  ACLNN_CHECK(aclnnApplyRotaryPosEmbV2GetWorkspaceSize(
      query.get(), key.get(), cos_t.get(), sin_t.get(),
      /*layout=*/1, rotary_mode, &workspace_bytes, &executor));
  void* workspace = (workspace_bytes > 0) ? _lca_pool().alloc(workspace_bytes)
                                          : nullptr;
  ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(workspace, workspace_bytes, executor, stream));
}
// Apply HF-style RoPE in-place to q and k with basic aclnn ops:
//   x_out = x * cos + rotate_half(x) * sin
//   rotate_half(x) = concat(-x[..., Dh/2:], x[..., :Dh/2], dim=-1)
//
// q_data: pointer to [B, S, Nq, Dh] BF16 (modified in place)
// k_data: pointer to [B, S, Nk, Dh] BF16 (modified in place)
// cos_data, sin_data: [1, S, Dh] BF16
// scratch_data: pointer to contiguous [B, S, max(Nq,Nk), Dh] BF16 scratch buffer
//               used to hold rotate_half(x)
inline void apply_rope_manual(aclrtStream stream,
                              void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
                              void* k_data, int64_t Nk,
                              void* cos_data, void* sin_data,
                              void* scratch_data) {
  // rotate_half splits the head dim exactly in two; an odd Dh would silently
  // truncate (2*halfDh != Dh) and corrupt the rotation, so fail fast instead.
  assert(Dh % 2 == 0 && "apply_rope_manual: head dim Dh must be even");
  const aclDataType dt = ACL_BF16;
  const size_t elem = 2;  // bytes per BF16 element (matches dt above)
  const int64_t halfDh = Dh / 2;
  auto process = [&](void* x_data, int64_t N) {
    // Strides in elements (row-major [B, S, N, Dh]):
    //   stride = [S*N*Dh, N*Dh, Dh, 1]
    const std::vector<int64_t> full_shape = {B, S, N, Dh};
    const std::vector<int64_t> full_stride = {S*N*Dh, N*Dh, Dh, 1};
    // Half views keep the full tensor's strides (same row pitch); only the
    // last-dim extent shrinks to halfDh.
    const std::vector<int64_t> half_shape = {B, S, N, halfDh};
    const std::vector<int64_t> half_stride = full_stride;
    // View x as the full tensor, plus left/right half views. The right half is
    // the same layout with the base pointer advanced by halfDh elements.
    auto t_x = make_acl_tensor(x_data, dt, full_shape, full_stride);
    auto t_x_left = make_acl_tensor(x_data, dt, half_shape, half_stride);
    auto t_x_right = make_acl_tensor((char*)x_data + halfDh*elem, dt, half_shape, half_stride);
    // rotate_half buffer views over scratch (contiguous [B, S, N, Dh]).
    const std::vector<int64_t> rh_stride = {S*N*Dh, N*Dh, Dh, 1};
    auto t_rh = make_acl_tensor(scratch_data, dt, full_shape, rh_stride);
    auto t_rh_left = make_acl_tensor(scratch_data, dt, half_shape, rh_stride);
    auto t_rh_right = make_acl_tensor((char*)scratch_data + halfDh*elem, dt, half_shape, rh_stride);
    // rh[..., :Dh/2] = -x[..., Dh/2:]
    neg(stream, t_x_right.get(), t_rh_left.get());
    // rh[..., Dh/2:] = x[..., :Dh/2]
    inplace_copy(stream, t_rh_right.get(), t_x_left.get());
    // cos/sin broadcast views: storage is [1, S, Dh]; shape [1, S, 1, Dh] with
    // strides [0, Dh, 0, 1] repeats each row across the B and N dims.
    auto t_cos = make_acl_tensor(cos_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1});
    auto t_sin = make_acl_tensor(sin_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1});
    // x = x * cos, then x += 1 * (rh * sin). Statement order is load-bearing:
    // rh must be fully built from the ORIGINAL x (the two ops above) before x
    // is scaled in place by cos.
    mul(stream, t_x.get(), t_cos.get(), t_x.get());
    addcmul(stream, t_x.get(), t_rh.get(), t_sin.get(), 1);
  };
  process(q_data, Nq);
  process(k_data, Nk);
}