File size: 4,874 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// rope.h — HF-style RoPE for BF16 q/k, in two variants:
//   apply_rope_fused  — a single aclnnApplyRotaryPosEmbV2 launch
//   apply_rope_manual — built from basic aclnn ops (neg, copy, mul, addcmul)
//
// Formula: q_out = q * cos + rotate_half(q) * sin
//   where rotate_half(q) = concat(-q[..., d/2:], q[..., :d/2], dim=-1)
//
// Tensor layout: q/k [B, S, N, Dh] BF16, cos/sin [1, S, Dh] BF16
//   (cos/sin are broadcast across the B and N dims)
//
#pragma once
#include "acl_common.h"
#include "aclnn_ops.h"
#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>

// Fused HF-style RoPE: rotates q and k in place with one aclnnApplyRotaryPosEmbV2 launch
// instead of the 8-op manual sequence (apply_rope_manual). The earlier estimate was ~7
// launches saved per layer, i.e. ~658 launches/token over 94 layers. Validated on 910:
// layout=1 with rotaryMode "half" matches HF rotate_half semantics (rel=1.24e-3 vs the
// manual path).
//
// q_data:   [B, S, Nq, Dh] BF16, overwritten in place
// k_data:   [B, S, Nk, Dh] BF16, overwritten in place
// cos_data: [1, S, 1, Dh] BF16 (single contiguous buffer slice from RopeCache)
// sin_data: [1, S, 1, Dh] BF16 (same layout as cos_data)
// If the op asks for a workspace, it is allocated from _lca_pool() and not released here.
inline void apply_rope_fused(aclrtStream stream,
                             void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
                             void* k_data, int64_t Nk,
                             void* cos_data, void* sin_data) {
    // Layout code passed to the op; the q/k tensors built below are [B, S, N, Dh].
    const int64_t kLayoutCode = 1;
    // The API takes a non-const char*, so keep the mode string in a writable buffer.
    char rotary_mode[] = "half";

    auto query   = make_contig_tensor(q_data,   ACL_BF16, {B, S, Nq, Dh});
    auto key     = make_contig_tensor(k_data,   ACL_BF16, {B, S, Nk, Dh});
    auto cos_tab = make_contig_tensor(cos_data, ACL_BF16, {1, S, 1, Dh});
    auto sin_tab = make_contig_tensor(sin_data, ACL_BF16, {1, S, 1, Dh});

    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnApplyRotaryPosEmbV2GetWorkspaceSize(
        query.get(), key.get(), cos_tab.get(), sin_tab.get(),
        kLayoutCode, rotary_mode, &workspace_bytes, &executor));

    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(workspace, workspace_bytes, executor, stream));
}

// Apply HF-style RoPE in place to q and k using basic aclnn ops (the unfused equivalent of
// apply_rope_fused):
//   x_out = x * cos + rotate_half(x) * sin
//   rotate_half(x) = concat(-x[..., Dh/2:], x[..., :Dh/2], dim=-1)
//
// q_data:       [B, S, Nq, Dh] BF16, overwritten in place
// k_data:       [B, S, Nk, Dh] BF16, overwritten in place
// cos_data:     [1, S, Dh] BF16, broadcast over B and the head dim
// sin_data:     [1, S, Dh] BF16, broadcast over B and the head dim
// scratch_data: contiguous [B, S, max(Nq, Nk), Dh] BF16 buffer that receives rotate_half(x).
//               The same buffer is reused for k after q. All ops are issued on `stream`,
//               so this relies on in-order execution on that stream.
// NOTE(review): assumes Dh is even (Dh / 2 truncates otherwise) — confirm at call sites.
inline void apply_rope_manual(aclrtStream stream,
                              void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
                              void* k_data, int64_t Nk,
                              void* cos_data, void* sin_data,
                              void* scratch_data) {
    const aclDataType dtype = ACL_BF16;
    const size_t bytes_per_elem = 2;  // BF16
    const int64_t half_dh = Dh / 2;
    // Byte offset from the start of a row to its second half.
    const size_t half_bytes = static_cast<size_t>(half_dh) * bytes_per_elem;

    auto rotate_in_place = [&](void* x, int64_t heads) {
        // Row-major [B, S, heads, Dh] element strides. The half-width views reuse these
        // strides unchanged, so each one selects Dh/2 lanes out of every Dh-wide row.
        const std::vector<int64_t> strides    = {S * heads * Dh, heads * Dh, Dh, 1};
        const std::vector<int64_t> shape      = {B, S, heads, Dh};
        const std::vector<int64_t> half_shape = {B, S, heads, half_dh};

        char* const x_bytes  = (char*)x;
        char* const rh_bytes = (char*)scratch_data;

        // x and its two half-width views (low lanes at offset 0, high lanes at Dh/2).
        auto x_full = make_acl_tensor(x,                    dtype, shape,      strides);
        auto x_lo   = make_acl_tensor(x,                    dtype, half_shape, strides);
        auto x_hi   = make_acl_tensor(x_bytes + half_bytes, dtype, half_shape, strides);

        // Scratch laid out contiguously as [B, S, heads, Dh], plus its two halves.
        auto rh_full = make_acl_tensor(scratch_data,          dtype, shape,      strides);
        auto rh_lo   = make_acl_tensor(scratch_data,          dtype, half_shape, strides);
        auto rh_hi   = make_acl_tensor(rh_bytes + half_bytes, dtype, half_shape, strides);

        // scratch = rotate_half(x): the low half gets -x_hi, the high half gets x_lo.
        neg(stream, x_hi.get(), rh_lo.get());
        inplace_copy(stream, rh_hi.get(), x_lo.get());

        // The [1, S, Dh] tables are viewed as [1, S, 1, Dh]. Zero strides on the size-1
        // dims let them broadcast across B and heads.
        const std::vector<int64_t> table_shape   = {1, S, 1, Dh};
        const std::vector<int64_t> table_strides = {0, Dh, 0, 1};
        auto cos_view = make_acl_tensor(cos_data, dtype, table_shape, table_strides);
        auto sin_view = make_acl_tensor(sin_data, dtype, table_shape, table_strides);

        // This must run after scratch is filled, since it overwrites x:
        // x = x * cos, then x += 1 * (scratch * sin).
        mul(stream, x_full.get(), cos_view.get(), x_full.get());
        addcmul(stream, x_full.get(), rh_full.get(), sin_view.get(), 1);
    };

    rotate_in_place(q_data, Nq);
    rotate_in_place(k_data, Nk);
}