// aclnn_ops.h — thin wrappers around common aclnn operators used in forward pass.
// Each wrapper does GetWorkspaceSize + op call on the provided stream.
//
// All tensors are passed as raw aclTensor* (caller owns them).
// Workspace allocation uses DeviceBuffer (RAII).
#pragma once
#include "acl_common.h"
#include "workspace_pool.h"
// Thread-local shared workspace pool for all aclnn wrappers below. A single-threaded stream
// means we can safely reuse one buffer across serial op calls. Disabling the pool via
// `GGML_CANN_WP=0` is not supported here — if truly needed, we'd wire a flag.
inline WorkspacePool& _lca_pool() {
thread_local WorkspacePool pool;
return pool;
}
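// All wrappers below follow the same two-phase aclnn pattern:
//   1) aclnnXxxGetWorkspaceSize(...) computes the workspace byte count and builds an opaque
//      aclOpExecutor describing the launch;
//   2) aclnnXxx(workspace, ws, exec, stream) enqueues the kernel on `stream`.
// By default the executor is single-shot and released by the op call itself, so the wrappers
// never destroy it explicitly.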
#include <cstdint>
#include <vector>
#include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_addcmul.h>
#include <aclnnop/aclnn_argsort.h>
#include <aclnnop/aclnn_grouped_matmul_v4.h>
#include <aclnnop/aclnn_moe_finalize_routing.h>
#include <aclnnop/aclnn_moe_finalize_routing_v2.h>
#include <aclnnop/aclnn_moe_gating_top_k_softmax.h>
#include <aclnnop/aclnn_moe_init_routing_v3.h>
#include <aclnnop/aclnn_cast.h>
#include <aclnnop/aclnn_copy.h>
#include <aclnnop/aclnn_div.h>
#include <aclnnop/aclnn_fused_infer_attention_score.h>
#include <aclnnop/aclnn_index_select.h>
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/aclnn_mul.h>
#include <aclnnop/aclnn_neg.h>
#include <aclnnop/aclnn_reduce_sum.h>
#include <aclnnop/aclnn_silu.h>
// ---- RmsNorm ----
// Signature (based on ggml-cann usage): aclnnRmsNorm(x, gamma, eps, y, rstd)
// where rstd (rsqrt of mean-square) is an extra output we usually discard.
// The vendor header is included here directly; the extern "C" wrap is defensive in case it
// lacks its own C++ linkage guards.
extern "C" {
#include <aclnnop/aclnn_rms_norm.h>
}
inline void rms_norm(aclrtStream stream,
aclTensor* x, // [N, D] BF16/FP16
aclTensor* gamma, // [D] same dtype as x
double eps,
aclTensor* y, // [N, D]
aclTensor* rstd // [N] fp32 (required output)
) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnRmsNormGetWorkspaceSize(x, gamma, eps, y, rstd, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnRmsNorm(wp, ws, exec, stream));
}
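// Convenience sketch (an assumption, not used elsewhere: for callers that don't care about
// rstd). Allocates a scratch rstd tensor, runs rms_norm, then synchronizes so the RAII
// scratch buffer outlives the async kernel. n_rows is x's leading dimension; DeviceBuffer
// and make_contig_tensor are the helpers from workspace_pool.h / acl_common.h used below.
inline void rms_norm_discard_rstd(aclrtStream stream, aclTensor* x, aclTensor* gamma,
                                  double eps, aclTensor* y, int64_t n_rows) {
    DeviceBuffer rstd_buf(n_rows * sizeof(float));
    auto rstd = make_contig_tensor(rstd_buf.get(), ACL_FLOAT, {n_rows});
    rms_norm(stream, x, gamma, eps, y, rstd.get());
    aclrtSynchronizeStream(stream);  // error check omitted in this sketch
}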
// ---- Silu ----
inline void silu(aclrtStream stream, aclTensor* x, aclTensor* y) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnSiluGetWorkspaceSize(x, y, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnSilu(wp, ws, exec, stream));
}
// ---- Mul (element-wise) ----
inline void mul(aclrtStream stream, aclTensor* a, aclTensor* b, aclTensor* out) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnMulGetWorkspaceSize(a, b, out, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnMul(wp, ws, exec, stream));
}
// ---- Cast ----
inline void cast(aclrtStream stream, aclTensor* x, aclDataType dst_dtype, aclTensor* y) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnCastGetWorkspaceSize(x, dst_dtype, y, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnCast(wp, ws, exec, stream));
}
// ---- InplaceCopy: copy src (possibly non-contiguous via strides) into contiguous dst ----
inline void inplace_copy(aclrtStream stream, aclTensor* dst, aclTensor* src) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnInplaceCopyGetWorkspaceSize(dst, src, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnInplaceCopy(wp, ws, exec, stream));
}
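// Usage sketch (illustrative; shapes are assumptions, make_acl_tensor / make_contig_tensor
// are the strided-view helpers from acl_common.h used elsewhere in this header): compact
// every other row of a [2N, D] buffer into a contiguous [N, D] tensor.
inline void compact_even_rows_example(aclrtStream stream, void* src_data, void* dst_data,
                                      aclDataType dtype, int64_t n, int64_t d) {
    auto src_view = make_acl_tensor(src_data, dtype, {n, d}, {2 * d, 1});  // row stride skips odd rows
    auto dst = make_contig_tensor(dst_data, dtype, {n, d});
    inplace_copy(stream, dst.get(), src_view.get());
}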
// ---- Matmul: out = a @ b ----
// cube_math_type:
// 0 = KEEP_DTYPE, 1 = ALLOW_FP32_DOWN_PRECISION, 2 = USE_FP16, 3 = USE_HF32
inline void matmul(aclrtStream stream,
aclTensor* a, aclTensor* b, aclTensor* out,
int8_t cube_math_type = 1) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnMatmulGetWorkspaceSize(a, b, out, cube_math_type, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnMatmul(wp, ws, exec, stream));
}
// ---- Neg ----
inline void neg(aclrtStream stream, aclTensor* x, aclTensor* y) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnNegGetWorkspaceSize(x, y, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnNeg(wp, ws, exec, stream));
}
// ---- Addcmul: self = self + value * (tensor1 * tensor2) ----
inline void addcmul(aclrtStream stream, aclTensor* self_io, aclTensor* t1, aclTensor* t2, float value) {
aclScalar* v = aclCreateScalar(&value, ACL_FLOAT);
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnAddcmulGetWorkspaceSize(self_io, t1, t2, v, self_io, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnAddcmul(wp, ws, exec, stream));
aclDestroyScalar(v);
}
// ---- MoE Gating TopK Softmax ----
// x [N, E] → y [N, K] (top-K softmax probs), expert_idx [N, K] int32, row_idx [N, K] int32
inline void moe_gating_topk_softmax(aclrtStream stream,
aclTensor* x, int64_t k,
aclTensor* y_out, aclTensor* idx_out, aclTensor* row_idx_out) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnMoeGatingTopKSoftmaxGetWorkspaceSize(x, /*finishedOptional=*/nullptr, k, y_out, idx_out, row_idx_out, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnMoeGatingTopKSoftmax(wp, ws, exec, stream));
}
// ---- MoE Init Routing V3 ----
// x [N, D], expert_idx [N, K] int32 → expanded_x [N*K, D], expanded_row_idx [N*K] int32,
// tokens_per_expert [E] int64
inline void moe_init_routing_v3(aclrtStream stream,
aclTensor* x, aclTensor* expert_idx,
int64_t n_experts, int64_t active_num,
aclTensor* expanded_x, aclTensor* expanded_row_idx,
aclTensor* tokens_per_expert)
{
int64_t range[2] = {0, n_experts};
aclIntArray* r = aclCreateIntArray(range, 2);
// scaleOutOptional: even with quantMode = -1 (no quantization) the op still seems to require
// a placeholder output; per our earlier POC test, passing a real tensor for scale_out works.
// For simplicity, allocate a dummy [active_num] float tensor.
DeviceBuffer dummy(active_num * sizeof(float));
auto t_dummy = make_contig_tensor(dummy.get(), ACL_FLOAT, {active_num});
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
// rowIdxType=1: expanded_row_idx[i] = sorted_position p for i-th original (n,k) flat index.
// This lets us use expanded_row_idx directly as the gather index (forward permutation).
ACLNN_CHECK(aclnnMoeInitRoutingV3GetWorkspaceSize(
x, expert_idx, nullptr, nullptr,
active_num, 0, n_experts, 0, 1, true, -1,
r, 1,
expanded_x, expanded_row_idx, tokens_per_expert, t_dummy.get(),
&ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnMoeInitRoutingV3(wp, ws, exec, stream));
aclDestroyIntArray(r);
}
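// Worked example of the rowIdxType=1 forward permutation (illustrative values; N=2 tokens,
// K=2, E=3 experts, assuming a stable sort by expert id):
//   expert_idx (flattened row-major) = [2, 0, 1, 0]  (token0 -> {e2, e0}, token1 -> {e1, e0})
//   sorted order of flat slots       = [1, 3, 2, 0]  (e0's slots first, then e1, then e2)
//   expanded_row_idx[i]              = [3, 0, 2, 1]  (sorted position of flat slot i)
//   tokens_per_expert                = [2, 1, 1]
// Row p of expanded_x holds the token of the p-th sorted slot (slot / K gives the token id),
// and expanded_row_idx is exactly the gather/scatter index moe_finalize_routing expects.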
// ---- GroupedMatmulV4 (single-in single-out, M-axis split) ----
// x [T, K_in], w [E, K_in, N_out] contiguous row-major, group_list [E] int64 → y [T, N_out]
// group_list_type: 0=cumsum, 1=counts (V4 doc)
inline void grouped_matmul_v4(aclrtStream stream,
aclTensor* x, aclTensor* w, aclTensor* group_list, aclTensor* y,
int64_t group_list_type = 1)
{
aclTensor* xa[] = {x}; aclTensorList* x_list = aclCreateTensorList(xa, 1);
aclTensor* wa[] = {w}; aclTensorList* w_list = aclCreateTensorList(wa, 1);
aclTensor* ya[] = {y}; aclTensorList* y_list = aclCreateTensorList(ya, 1);
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnGroupedMatmulV4GetWorkspaceSize(
x_list, w_list,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, // bias/scale/offset/antiquant scale+offset/per-token scale
group_list,
nullptr, nullptr, nullptr, // activation input, activation quant scale/offset
/*splitItem=*/3, /*groupType=*/0, group_list_type, /*actType=*/0,
y_list, nullptr, nullptr,
&ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnGroupedMatmulV4(wp, ws, exec, stream));
// NOTE: TensorList takes ownership of the raw tensors. Destroying the list frees them,
// which would cause double-free in the caller's AclTensorPtr. Leak the list (small cost).
// A cleaner API would accept (ptr, shape, dtype) triples and build tensors internally.
// TODO(M6): refactor for long-running use.
}
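// Usage note: in our pipeline the tokens_per_expert output of moe_init_routing_v3 is what
// gets passed as group_list (see the end-to-end sketch after moe_finalize_routing below);
// make sure its format (counts vs. cumsum) matches group_list_type.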
// ---- MoE Finalize Routing V2: out = x1 + weighted_sum of top-K outputs ----
// V2 has all inputs optional except expandedX/expandedRowIdx/out; pass nullptr for x1 to
// skip the residual add, or pass the residual to fuse it into this op.
inline void moe_finalize_routing(aclrtStream stream,
aclTensor* expanded_x,
aclTensor* x1_skip, // [N, D] added to output (nullable)
aclTensor* scales, // weights [N, K]
aclTensor* expanded_row_idx,
aclTensor* expert_idx, // [N, K] topk expert indices (nullable)
aclTensor* out)
{
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnMoeFinalizeRoutingV2GetWorkspaceSize(
expanded_x,
expanded_row_idx,
x1_skip, // x1Optional
nullptr, // x2Optional
nullptr, // biasOptional
scales, // scalesOptional
expert_idx, // expertIdxOptional (needed for correct routing)
0, // dropPadMode (0 = dropless, which matches our pipeline)
out,
&ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnMoeFinalizeRoutingV2(wp, ws, exec, stream));
}
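// End-to-end MoE usage sketch (illustrative comment only; tensor names and the fused
// gate/up weight layout w_gate_up [E, D, 2F] are assumptions; adapt to the real weights):
//
//   moe_gating_topk_softmax(stream, router_logits, K, probs, expert_idx, row_idx);
//   moe_init_routing_v3(stream, x, expert_idx, E, /*active_num=*/N * K,
//                       expanded_x, expanded_row_idx, tokens_per_expert);
//   grouped_matmul_v4(stream, expanded_x, w_gate_up, tokens_per_expert, h);
//   silu(stream, h_gate, h_gate);            // h viewed as gate/up halves
//   mul(stream, h_gate, h_up, h_act);
//   grouped_matmul_v4(stream, h_act, w_down, tokens_per_expert, expert_out);
//   moe_finalize_routing(stream, expert_out, residual, probs, expanded_row_idx, expert_idx, out);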
// ---- Div: self / other (broadcast supported) ----
inline void div_tensor(aclrtStream stream, aclTensor* self, aclTensor* other, aclTensor* out) {
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnDivGetWorkspaceSize(self, other, out, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnDiv(wp, ws, exec, stream));
}
// ---- Argsort: indices that would sort self along dim (returns INT64) ----
inline void argsort(aclrtStream stream, aclTensor* self, int64_t dim, bool descending,
aclTensor* indices_out) {
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnArgsortGetWorkspaceSize(self, dim, descending, indices_out, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnArgsort(wp, ws, exec, stream));
}
// ---- In-place scalar add: self += scalar (aclnnInplaceAdds) ----
inline void inplace_adds(aclrtStream stream, aclTensor* self, double value) {
float v = (float)value;
aclScalar* s = aclCreateScalar(&v, ACL_FLOAT);
float alpha_v = 1.0f;
aclScalar* al = aclCreateScalar(&alpha_v, ACL_FLOAT);
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnInplaceAddsGetWorkspaceSize(self, s, al, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnInplaceAdds(wp, ws, exec, stream));
aclDestroyScalar(s);
aclDestroyScalar(al);
}
// ---- ReduceSum over specified dims ----
inline void reduce_sum(aclrtStream stream, aclTensor* self, const std::vector<int64_t>& dims,
bool keep_dims, aclDataType out_dtype, aclTensor* out) {
aclIntArray* d = aclCreateIntArray(dims.data(), dims.size());
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnReduceSumGetWorkspaceSize(self, d, keep_dims, out_dtype, out, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnReduceSum(wp, ws, exec, stream));
aclDestroyIntArray(d);
}
// ---- IndexSelect: out[j] = self[index[j], ...] ----
inline void index_select(aclrtStream stream, aclTensor* self, int64_t dim, aclTensor* index, aclTensor* out) {
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnIndexSelectGetWorkspaceSize(self, dim, index, out, &ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnIndexSelect(wp, ws, exec, stream));
}
// ---- FusedInferAttentionScore: simplified wrapper for prefill/decode, BSH layout, no quant ----
// Caller owns q/k/v/mask/out; k/v are passed as single-tensor lists.
inline void fused_infer_attention_score(
aclrtStream stream,
aclTensor* q, // [B, S, Hq*Dh] BF16
aclTensor* k, // [B, S, Hkv*Dh] BF16
aclTensor* v, // [B, S, Hkv*Dh] BF16
aclTensor* atten_mask, // [1, 1, M, M] bool, sparse_mode=3 needs M=2048
const std::vector<int64_t>& actual_seq_lens,
const std::vector<int64_t>& actual_seq_lens_kv,
int64_t num_heads, int64_t num_kv_heads,
double scale, int64_t sparse_mode,
aclTensor* out) // [B, S, Hq*Dh]
{
aclTensor* k_arr[] = {k};
aclTensor* v_arr[] = {v};
aclTensorList* k_list = aclCreateTensorList(k_arr, 1);
aclTensorList* v_list = aclCreateTensorList(v_arr, 1);
aclIntArray* sq = aclCreateIntArray(actual_seq_lens.data(), (uint64_t)actual_seq_lens.size());
aclIntArray* skv = aclCreateIntArray(actual_seq_lens_kv.data(), (uint64_t)actual_seq_lens_kv.size());
uint64_t ws = 0;
aclOpExecutor* exec = nullptr;
ACLNN_CHECK(aclnnFusedInferAttentionScoreGetWorkspaceSize(
q, k_list, v_list,
nullptr, // pseShift
atten_mask,
sq, skv,
nullptr, nullptr, nullptr, nullptr, nullptr, // dequant/quant scales
nullptr, nullptr, // antiquant
nullptr, nullptr, nullptr, // block_table, q_padding, kv_padding
num_heads,
scale,
2147483647, 2147483647, // pre/next tokens (no limit)
(char*)"BSH",
num_kv_heads,
sparse_mode,
0, // inner_precise
0, 0, // block_size, antiquant_mode
false, // softmax_lse_flag
out, nullptr,
&ws, &exec));
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
ACLNN_CHECK(aclnnFusedInferAttentionScore(wp, ws, exec, stream));
// See note on grouped_matmul_v4 — intentionally leak lists to avoid double-free with caller RAII.
(void)k_list; (void)v_list;
aclDestroyIntArray(sq);
aclDestroyIntArray(skv);
}
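// Host-side sketch (an assumption, not part of this file's API): sparse_mode=3 expects a
// fixed 2048x2048 bool mask with the strict upper triangle set to true (masked); verify
// against the FusedInferAttentionScore docs of your CANN toolkit before relying on this.
inline std::vector<uint8_t> make_causal_mask_2048() {
    const int64_t M = 2048;
    std::vector<uint8_t> mask(M * M, 0);
    for (int64_t i = 0; i < M; ++i)
        for (int64_t j = i + 1; j < M; ++j)
            mask[i * M + j] = 1;  // mask out positions after the query (future tokens)
    return mask;
}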
// ---- "Linear" helper: y = x @ W.T where W is stored as [out_features, in_features] (HF convention).
// Achieved by viewing W as [in_features, out_features] with stride [1, in_features] (elements).
// Returns y [N, out_features].
// Caller allocates y.
inline void linear_hf(aclrtStream stream,
aclTensor* x, // [N, in_features]
void* W_data, aclDataType dtype,
int64_t out_features, int64_t in_features,
aclTensor* y_out) // [N, out_features]
{
auto W_view = make_acl_tensor(W_data, dtype,
{in_features, out_features},
{1, in_features}); // strides: d0=1 elem, d1=in_features elems
matmul(stream, x, W_view.get(), y_out);
}
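// Example (hypothetical names, HF checkpoint convention): with q_proj.weight stored as
// [n_heads * head_dim, hidden], the query projection is
//   linear_hf(stream, x, q_proj_data, ACL_BF16, n_heads * head_dim, hidden, q_out);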