File size: 15,847 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
// aclnn_ops.h — thin wrappers around common aclnn operators used in forward pass.
// Each wrapper does GetWorkspaceSize + op call on the provided stream.
//
// All tensors are passed as raw aclTensor* (caller owns them).
// Workspace allocation uses DeviceBuffer (RAII).
#pragma once
#include <cstdint>
#include <memory>
#include <vector>

#include "acl_common.h"
#include "workspace_pool.h"

// Per-thread workspace pool shared by every aclnn wrapper in this header.
// Assumes ops are issued serially on a single stream per thread, so one buffer can be
// reused from one op call to the next.
// Opting out through an environment switch such as `GGML_CANN_WP=0` is not supported
// here; a flag would have to be wired in if that is ever needed.
// NOTE(review): a leading-underscore name at global scope is reserved in C++; it is kept
// unchanged because code outside this view may call it.
inline WorkspacePool& _lca_pool() {
    thread_local WorkspacePool tls_pool;  // constructed on a thread's first call
    return tls_pool;
}

#include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_addcmul.h>
#include <aclnnop/aclnn_grouped_matmul_v4.h>
#include <aclnnop/aclnn_moe_finalize_routing.h>
#include <aclnnop/aclnn_moe_finalize_routing_v2.h>
#include <aclnnop/aclnn_moe_gating_top_k_softmax.h>
#include <aclnnop/aclnn_moe_init_routing_v3.h>
#include <aclnnop/aclnn_cast.h>
#include <aclnnop/aclnn_copy.h>
#include <aclnnop/aclnn_div.h>
#include <aclnnop/aclnn_fused_infer_attention_score.h>
#include <aclnnop/aclnn_index_select.h>
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/aclnn_mul.h>
#include <aclnnop/aclnn_neg.h>
#include <aclnnop/aclnn_reduce_sum.h>
#include <aclnnop/aclnn_silu.h>

// ---- RmsNorm ----
// Signature (based on ggml-cann usage): aclnnRmsNorm(x, gamma, eps, y, rstd)
// where rstd (rsqrt of mean-square) is an extra output we usually discard.

// NOTE: the RmsNorm header is included right here, wrapped in an extern "C" block.
extern "C" {
#include <aclnnop/aclnn_rms_norm.h>
}

// Runs aclnnRmsNorm on `stream`: normalises `x` with scale `gamma` and epsilon `eps`,
// writing the result to `y` and the op's extra output to `rstd`.
// Any workspace the op requests is taken from the thread-local pool.
inline void rms_norm(aclrtStream stream,
                     aclTensor* x,         // [N, D] BF16/FP16
                     aclTensor* gamma,     // [D] same dtype as x
                     double     eps,
                     aclTensor* y,         // [N, D]
                     aclTensor* rstd       // [N] fp32 (required output)
) {
    // Phase 1: ask the op how much scratch memory it needs and obtain its executor.
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnRmsNormGetWorkspaceSize(x, gamma, eps, y, rstd, &workspace_bytes, &executor));

    // Phase 2: borrow scratch memory only when some was requested, then issue the op.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnRmsNorm(workspace, workspace_bytes, executor, stream));
}

// ---- Silu ----
// Runs aclnnSilu on `stream`, reading `x` and writing `y`.
inline void silu(aclrtStream stream, aclTensor* x, aclTensor* y) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnSiluGetWorkspaceSize(x, y, &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnSilu(workspace, workspace_bytes, executor, stream));
}

// ---- Mul (element-wise) ----
// Runs aclnnMul on `stream` with operands `a` and `b`, writing the product to `out`.
inline void mul(aclrtStream stream, aclTensor* a, aclTensor* b, aclTensor* out) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnMulGetWorkspaceSize(a, b, out, &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnMul(workspace, workspace_bytes, executor, stream));
}

// ---- Cast ----
// Runs aclnnCast on `stream`: converts `x` to `dst_dtype` and writes the result to `y`.
inline void cast(aclrtStream stream, aclTensor* x, aclDataType dst_dtype, aclTensor* y) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnCastGetWorkspaceSize(x, dst_dtype, y, &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnCast(workspace, workspace_bytes, executor, stream));
}

// ---- InplaceCopy: copy src (possibly non-contiguous via strides) into contiguous dst ----
// Runs aclnnInplaceCopy on `stream`. Note the argument order: destination first.
inline void inplace_copy(aclrtStream stream, aclTensor* dst, aclTensor* src) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnInplaceCopyGetWorkspaceSize(dst, src, &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnInplaceCopy(workspace, workspace_bytes, executor, stream));
}

// ---- Matmul: out = a @ b ----
// Runs aclnnMatmul on `stream`.
// cube_math_type selects the precision mode handed to the op:
//   0 = KEEP_DTYPE, 1 = ALLOW_FP32_DOWN_PRECISION, 2 = USE_FP16, 3 = USE_HF32
// The default is 1.
inline void matmul(aclrtStream stream,
                   aclTensor* a, aclTensor* b, aclTensor* out,
                   int8_t cube_math_type = 1) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnMatmulGetWorkspaceSize(a, b, out, cube_math_type,
                                            &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnMatmul(workspace, workspace_bytes, executor, stream));
}

// ---- Neg ----
// Runs aclnnNeg on `stream`, reading `x` and writing `y`.
inline void neg(aclrtStream stream, aclTensor* x, aclTensor* y) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnNegGetWorkspaceSize(x, y, &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnNeg(workspace, workspace_bytes, executor, stream));
}

// ---- Addcmul: self = self + value * (tensor1 * tensor2) ----
// Runs aclnnAddcmul on `stream`. `self_io` is passed as both the input and the output
// tensor, so the update happens in place. `value` is wrapped in an ACL_FLOAT scalar that
// lives for the duration of this call.
inline void addcmul(aclrtStream stream, aclTensor* self_io, aclTensor* t1, aclTensor* t2, float value) {
    // Owning handle: the scalar is destroyed on every exit path, so it no longer leaks
    // should ACLNN_CHECK leave the function early.
    const std::unique_ptr<aclScalar, decltype(&aclDestroyScalar)> scale(
        aclCreateScalar(&value, ACL_FLOAT), &aclDestroyScalar);

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    ACLNN_CHECK(aclnnAddcmulGetWorkspaceSize(self_io, t1, t2, scale.get(), self_io, &ws, &exec));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
    ACLNN_CHECK(aclnnAddcmul(wp, ws, exec, stream));
}

// ---- MoE Gating TopK Softmax ----
// x [N, E] → y [N, K] (top-K softmax probs), expert_idx [N, K] int32, row_idx [N, K] int32
// Runs aclnnMoeGatingTopKSoftmax on `stream`. The op's second input is always passed as
// nullptr (presumably an optional "finished" mask — confirm against the aclnn header).
inline void moe_gating_topk_softmax(aclrtStream stream,
                                    aclTensor* x, int64_t k,
                                    aclTensor* y_out, aclTensor* idx_out, aclTensor* row_idx_out) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnMoeGatingTopKSoftmaxGetWorkspaceSize(
        x, nullptr, k,
        y_out, idx_out, row_idx_out,
        &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnMoeGatingTopKSoftmax(workspace, workspace_bytes, executor, stream));
}

// ---- MoE Init Routing V3 ----
// x [N, D], expert_idx [N, K] int32 → expanded_x [N*K, D], expanded_row_idx [N*K] int32,
// tokens_per_expert [E] int64
// Runs aclnnMoeInitRoutingV3 on `stream` with the expert range [0, n_experts) and no
// quantisation (quant_mode = -1).
inline void moe_init_routing_v3(aclrtStream stream,
                                aclTensor* x, aclTensor* expert_idx,
                                int64_t n_experts, int64_t active_num,
                                aclTensor* expanded_x, aclTensor* expanded_row_idx,
                                aclTensor* tokens_per_expert)
{
    // Owning handle for the expert range: released on every exit path, so it no longer
    // leaks should ACLNN_CHECK leave the function early.
    const int64_t expert_range[2] = {0, n_experts};
    const std::unique_ptr<aclIntArray, decltype(&aclDestroyIntArray)> range_arr(
        aclCreateIntArray(expert_range, 2), &aclDestroyIntArray);

    // The op still wants a real tensor for its scale output even with quant_mode = -1
    // (passing a real tensor worked in the earlier POC test), so hand it a throwaway
    // [active_num] float buffer.
    // NOTE(review): `dummy` is released when this function returns. If the op is still
    // pending on `stream` at that point, its scale output would target freed memory —
    // confirm DeviceBuffer / stream semantics.
    DeviceBuffer dummy(active_num * static_cast<int64_t>(sizeof(float)));
    auto t_dummy = make_contig_tensor(dummy.get(), ACL_FLOAT, {active_num});

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    // rowIdxType=1: expanded_row_idx[i] = sorted_position p for i-th original (n,k) flat index.
    // This lets us use expanded_row_idx directly as the gather index (forward permutation).
    ACLNN_CHECK(aclnnMoeInitRoutingV3GetWorkspaceSize(
        x, expert_idx, nullptr, nullptr,
        active_num, 0, n_experts, 0, 1, true, -1,
        range_arr.get(), 1,
        expanded_x, expanded_row_idx, tokens_per_expert, t_dummy.get(),
        &ws, &exec));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
    ACLNN_CHECK(aclnnMoeInitRoutingV3(wp, ws, exec, stream));
}

// ---- GroupedMatmulV4 (single-in single-out, M-axis split) ----
// x [T, K_in], w [E, K_in, N_out] contiguous row-major, group_list [E] int64 → y [T, N_out]
// group_list_type: 0=cumsum, 1=counts (V4 doc)
// Runs aclnnGroupedMatmulV4 on `stream`, wrapping x, w and y each in a one-element
// tensor list. All optional inputs and outputs of the op are passed as nullptr.
inline void grouped_matmul_v4(aclrtStream stream,
                              aclTensor* x, aclTensor* w, aclTensor* group_list, aclTensor* y,
                              int64_t group_list_type = 1)
{
    // WARNING: these lists are deliberately never destroyed. A TensorList takes ownership
    // of the raw tensors it holds, so destroying it would free tensors the caller's
    // AclTensorPtr also frees (double free). The cost is a small leak per call.
    // A cleaner API would accept (ptr, shape, dtype) triples and build tensors internally.
    // TODO(M6): refactor for long-running use.
    aclTensor* x_items[] = {x};
    aclTensorList* inputs = aclCreateTensorList(x_items, 1);
    aclTensor* w_items[] = {w};
    aclTensorList* weights = aclCreateTensorList(w_items, 1);
    aclTensor* y_items[] = {y};
    aclTensorList* outputs = aclCreateTensorList(y_items, 1);

    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    // The four integers before `outputs` are presumably splitItem, groupType,
    // groupListType and actType — confirm against the aclnn header.
    ACLNN_CHECK(aclnnGroupedMatmulV4GetWorkspaceSize(
        inputs, weights,
        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        group_list,
        nullptr, nullptr, nullptr,
        3, 0, group_list_type, 0,
        outputs, nullptr, nullptr,
        &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnGroupedMatmulV4(workspace, workspace_bytes, executor, stream));
}

// ---- MoE Finalize Routing V2: out = x1 + weighted_sum of top-K outputs ----
// V2 has all inputs optional except expandedX/expandedRowIdx/out; pass nullptr for x1 to
// skip the residual add, or pass the residual to fuse it into this op.
// Runs aclnnMoeFinalizeRoutingV2 on `stream`; x2 and bias are always passed as nullptr.
inline void moe_finalize_routing(aclrtStream stream,
                                 aclTensor* expanded_x,
                                 aclTensor* x1_skip,               // [N, D] added to output (nullable)
                                 aclTensor* scales,                // weights [N, K]
                                 aclTensor* expanded_row_idx,
                                 aclTensor* expert_idx,             // [N, K] topk expert indices (nullable)
                                 aclTensor* out)
{
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnMoeFinalizeRoutingV2GetWorkspaceSize(
        expanded_x,
        expanded_row_idx,
        x1_skip,        // x1Optional
        nullptr,        // x2Optional
        nullptr,        // biasOptional
        scales,         // scalesOptional
        expert_idx,     // expertIdxOptional (needed for correct routing)
        0,              // dropPadMode (0 = dropless, which matches our pipeline)
        out,
        &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnMoeFinalizeRoutingV2(workspace, workspace_bytes, executor, stream));
}

// ---- Div: self / other (broadcast supported) ----
// Runs aclnnDiv on `stream`, writing the quotient to `out`.
inline void div_tensor(aclrtStream stream, aclTensor* self, aclTensor* other, aclTensor* out) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnDivGetWorkspaceSize(self, other, out, &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnDiv(workspace, workspace_bytes, executor, stream));
}

// ---- Extra headers for the argsort and in-place scalar add wrappers below ----
#include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_argsort.h>

// ---- Argsort: indices that would sort self along dim (returns INT64) ----
// Runs aclnnArgsort on `stream`; `descending` picks the sort direction and the indices
// are written to `indices_out`.
inline void argsort(aclrtStream stream, aclTensor* self, int64_t dim, bool descending,
                    aclTensor* indices_out) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnArgsortGetWorkspaceSize(self, dim, descending, indices_out,
                                             &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnArgsort(workspace, workspace_bytes, executor, stream));
}

// ---- In-place scalar add: self += value ----
// Runs aclnnInplaceAdds on `stream` with alpha fixed at 1. `value` is narrowed to float
// because both scalars are created as ACL_FLOAT.
inline void inplace_adds(aclrtStream stream, aclTensor* self, double value) {
    using ScalarPtr = std::unique_ptr<aclScalar, decltype(&aclDestroyScalar)>;

    // Owning handles: both scalars are destroyed on every exit path, so neither leaks
    // should ACLNN_CHECK leave the function early.
    float addend_v = static_cast<float>(value);
    const ScalarPtr addend(aclCreateScalar(&addend_v, ACL_FLOAT), &aclDestroyScalar);
    float alpha_v = 1.0f;
    const ScalarPtr alpha(aclCreateScalar(&alpha_v, ACL_FLOAT), &aclDestroyScalar);

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    ACLNN_CHECK(aclnnInplaceAddsGetWorkspaceSize(self, addend.get(), alpha.get(), &ws, &exec));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
    ACLNN_CHECK(aclnnInplaceAdds(wp, ws, exec, stream));
}

// ---- ReduceSum over specified dims ----
// Runs aclnnReduceSum on `stream`: sums `self` over `dims` and writes the result to `out`.
// `keep_dims` and `out_dtype` are forwarded to the op unchanged.
inline void reduce_sum(aclrtStream stream, aclTensor* self, const std::vector<int64_t>& dims,
                       bool keep_dims, aclDataType out_dtype, aclTensor* out) {
    // Owning handle: the dims array is destroyed on every exit path, so it no longer
    // leaks should ACLNN_CHECK leave the function early.
    const std::unique_ptr<aclIntArray, decltype(&aclDestroyIntArray)> dims_arr(
        aclCreateIntArray(dims.data(), static_cast<uint64_t>(dims.size())),
        &aclDestroyIntArray);

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    ACLNN_CHECK(aclnnReduceSumGetWorkspaceSize(self, dims_arr.get(), keep_dims, out_dtype, out,
                                               &ws, &exec));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
    ACLNN_CHECK(aclnnReduceSum(wp, ws, exec, stream));
}

// ---- IndexSelect: out[j] = self[index[j], ...] ----
// Runs aclnnIndexSelect on `stream`, gathering along `dim` into `out`.
inline void index_select(aclrtStream stream, aclTensor* self, int64_t dim, aclTensor* index, aclTensor* out) {
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnIndexSelectGetWorkspaceSize(self, dim, index, out,
                                                 &workspace_bytes, &executor));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnIndexSelect(workspace, workspace_bytes, executor, stream));
}

// ---- FusedInferAttentionScore (simplified wrapper for prefill/decode without quant, BSH layout).
// Caller owns q/k/v/mask/out; k and v are each wrapped in a single-element tensor list.
// Runs aclnnFusedInferAttentionScore on `stream` with every quant/antiquant, pse-shift,
// block-table and padding input passed as nullptr, and the softmax-LSE output disabled.
inline void fused_infer_attention_score(
    aclrtStream stream,
    aclTensor* q,                       // [B, S, Hq*Dh] BF16
    aclTensor* k,                       // [B, S, Hkv*Dh] BF16
    aclTensor* v,                       // [B, S, Hkv*Dh] BF16
    aclTensor* atten_mask,              // [1, 1, M, M] bool, sparse_mode=3 needs M=2048
    std::vector<int64_t> actual_seq_lens,
    std::vector<int64_t> actual_seq_lens_kv,
    int64_t num_heads, int64_t num_kv_heads,
    double scale, int64_t sparse_mode,
    aclTensor* out)                     // [B, S, Hq*Dh]
{
    using IntArrayPtr = std::unique_ptr<aclIntArray, decltype(&aclDestroyIntArray)>;

    // WARNING: the two tensor lists are deliberately never destroyed. A TensorList takes
    // ownership of the raw tensors it holds, so destroying it would free k/v a second time
    // when the caller's RAII wrapper releases them. The cost is a small leak per call.
    aclTensor* k_arr[] = {k};
    aclTensor* v_arr[] = {v};
    aclTensorList* k_list = aclCreateTensorList(k_arr, 1);
    aclTensorList* v_list = aclCreateTensorList(v_arr, 1);

    // Owning handles: the sequence-length arrays are destroyed on every exit path, so
    // they no longer leak should ACLNN_CHECK leave the function early.
    const IntArrayPtr sq(
        aclCreateIntArray(actual_seq_lens.data(), static_cast<uint64_t>(actual_seq_lens.size())),
        &aclDestroyIntArray);
    const IntArrayPtr skv(
        aclCreateIntArray(actual_seq_lens_kv.data(), static_cast<uint64_t>(actual_seq_lens_kv.size())),
        &aclDestroyIntArray);

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    ACLNN_CHECK(aclnnFusedInferAttentionScoreGetWorkspaceSize(
        q, k_list, v_list,
        nullptr,            // pseShift
        atten_mask,
        sq.get(), skv.get(),
        nullptr, nullptr, nullptr, nullptr, nullptr,    // dequant/quant scales
        nullptr, nullptr,                                // antiquant
        nullptr, nullptr, nullptr,                       // block_table, q_padding, kv_padding
        num_heads,
        scale,
        2147483647, 2147483647,                         // pre/next tokens (no limit)
        const_cast<char*>("BSH"),                        // layout; the API takes a non-const char*
        num_kv_heads,
        sparse_mode,
        0,                                               // inner_precise
        0, 0,                                            // block_size, antiquant_mode
        false,                                           // softmax_lse_flag
        out, nullptr,
        &ws, &exec));

    // Scratch memory comes from the thread-local pool, and only if the op wants any.
    void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
    ACLNN_CHECK(aclnnFusedInferAttentionScore(wp, ws, exec, stream));
}

// ---- "Linear" helper: y = x @ W.T where W is stored as [out_features, in_features] (HF convention).
// No data is moved for the transpose: W's buffer is re-described as an
// [in_features, out_features] tensor with element strides [1, in_features], and that view
// is handed to matmul() with its default cube_math_type.
// The caller allocates y_out, shaped [N, out_features].
inline void linear_hf(aclrtStream stream,
                      aclTensor* x,                       // [N, in_features]
                      void* W_data, aclDataType dtype,
                      int64_t out_features, int64_t in_features,
                      aclTensor* y_out)                   // [N, out_features]
{
    // Strided view of W: stepping dim 0 moves 1 element, stepping dim 1 moves in_features.
    auto transposed_weight = make_acl_tensor(W_data, dtype,
                                             {in_features, out_features},
                                             {1, in_features});
    matmul(stream, x, transposed_weight.get(), y_out);
}