File size: 3,103 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// test_rope_manual.cpp — verify manual HF-style RoPE against Python reference.
#include "acl_common.h"
#include "acl_runtime.h"
#include "rope.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Widen a bfloat16 bit pattern to a 32-bit float.
// The 16 input bits are placed in the high half of a 32-bit word (low half zero)
// and the word is reinterpreted as a float via memcpy, which avoids
// strict-aliasing issues.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &widened, sizeof(float));
    return out;
}

// Read the entire file at `p` into memory as raw bytes.
// Throws std::runtime_error (message includes the path) if the file cannot be
// opened, its size cannot be determined, or the full contents cannot be read.
static std::vector<uint8_t> read_file(const std::string& p) {
    // Open positioned at the end so tellg() yields the file size.
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) throw std::runtime_error("read_file: cannot open " + p);

    // tellg() returns -1 on failure; converting that straight to size_t would
    // request an enormous allocation instead of reporting the real problem.
    const std::streamoff end = f.tellg();
    if (end < 0) throw std::runtime_error("read_file: cannot determine size of " + p);

    std::vector<uint8_t> v(static_cast<size_t>(end));
    f.seekg(0);
    if (!v.empty() &&
        !f.read(reinterpret_cast<char*>(v.data()), static_cast<std::streamsize>(v.size()))) {
        throw std::runtime_error("read_file: short read from " + p);
    }
    return v;
}

// Verify apply_rope_manual (manual HF-style RoPE) against reference tensors in
// tests/attn_data/ (header comment: "Python reference").
// Returns 0 when both Q and K match with relative L2 error < 1e-2, otherwise 1
// (including when a fixture file has an unexpected size).
int main() {
    const std::string data = "tests/attn_data";
    AclRuntime rt;
    rt.init(0);

    const int64_t B = 1, S = 5, Hq = 64, Hkv = 4, Dh = 128;
    // Q/K tensors are compared as bf16 (uint16_t per element) with B*S*heads*Dh
    // elements each; these are the byte counts compare() reads back below.
    const size_t q_bytes = static_cast<size_t>(B * S * Hq * Dh) * sizeof(uint16_t);
    const size_t k_bytes = static_cast<size_t>(B * S * Hkv * Dh) * sizeof(uint16_t);

    auto qn_h = read_file(data + "/q_normed.bin");
    auto kn_h = read_file(data + "/k_normed.bin");
    auto cos_h = read_file(data + "/cos.bin");
    auto sin_h = read_file(data + "/sin.bin");
    auto qr_h = read_file(data + "/q_roped.bin");
    auto kr_h = read_file(data + "/k_roped.bin");

    // Validate fixture sizes up front: compare() reads exactly q_bytes / k_bytes
    // from both the device buffer and the reference vector, so a truncated or
    // mis-shaped file would otherwise be an out-of-bounds read, not a failure.
    auto size_ok = [](const char* name, const std::vector<uint8_t>& buf, size_t want) {
        if (buf.size() == want) return true;
        fprintf(stderr, "%s: expected %zu bytes, got %zu\n", name, want, buf.size());
        return false;
    };
    bool sizes_ok = true;
    sizes_ok = size_ok("q_normed.bin", qn_h, q_bytes) && sizes_ok;
    sizes_ok = size_ok("q_roped.bin", qr_h, q_bytes) && sizes_ok;
    sizes_ok = size_ok("k_normed.bin", kn_h, k_bytes) && sizes_ok;
    sizes_ok = size_ok("k_roped.bin", kr_h, k_bytes) && sizes_ok;
    // The cos/sin layout is defined by apply_rope_manual (not visible here), so
    // only check that both tables are present and agree in size.
    if (cos_h.empty() || cos_h.size() != sin_h.size()) {
        fprintf(stderr, "cos.bin/sin.bin: expected equal non-zero sizes, got %zu and %zu\n",
                cos_h.size(), sin_h.size());
        sizes_ok = false;
    }
    if (!sizes_ok) {
        printf("\n=== test_rope_manual FAIL ===\n");
        return 1;
    }

    DeviceBuffer q_d(qn_h.size()), k_d(kn_h.size()), cos_d(cos_h.size()), sin_d(sin_h.size());
    // Scratch must be able to hold the larger of the Q and K tensors.
    DeviceBuffer scratch_d(static_cast<size_t>(B * S * std::max(Hq, Hkv) * Dh) * sizeof(uint16_t));

    ACL_CHECK(aclrtMemcpy(q_d.get(),   qn_h.size(), qn_h.data(),   qn_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(k_d.get(),   kn_h.size(), kn_h.data(),   kn_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(cos_d.get(), cos_h.size(), cos_h.data(), cos_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sin_d.get(), sin_h.size(), sin_h.data(), sin_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));

    // q_d / k_d are read back afterwards as the roped outputs, i.e. the rotation
    // is presumably applied in place — TODO confirm against rope.h.
    apply_rope_manual(rt.stream(),
                      q_d.get(), B, S, Hq, Dh,
                      k_d.get(), Hkv,
                      cos_d.get(), sin_d.get(),
                      scratch_d.get());
    rt.sync();

    // Copy N heads of bf16 results back from `buf` and compare them element-wise
    // with `ref_h`. Prints relative L2 error and max absolute difference; returns
    // true when rel < 1e-2. A NaN anywhere makes rel NaN, so the check fails.
    auto compare = [&](const char* tag, const DeviceBuffer& buf, int64_t N, const std::vector<uint8_t>& ref_h) {
        std::vector<uint16_t> cxx(static_cast<size_t>(B * S * N * Dh));
        const size_t nbytes = cxx.size() * sizeof(uint16_t);
        ACL_CHECK(aclrtMemcpy(cxx.data(), nbytes, buf.get(), nbytes, ACL_MEMCPY_DEVICE_TO_HOST));
        const auto* ref = reinterpret_cast<const uint16_t*>(ref_h.data());
        double l2d = 0, l2r = 0, maxd = 0;
        for (size_t i = 0; i < cxx.size(); i++) {
            const double a = bf16_to_float(cxx[i]);
            const double b = bf16_to_float(ref[i]);
            const double d = a - b;
            l2d += d * d;
            l2r += b * b;
            maxd = std::max(maxd, std::abs(d));
        }
        const double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
        printf("%s rel=%.4e max_abs=%.4f\n", tag, rel, maxd);
        printf("  cxx[0,0,:4]: "); for (int i = 0; i < 4; i++) printf("%.4f ", bf16_to_float(cxx[i]));
        printf("\n  ref[0,0,:4]: "); for (int i = 0; i < 4; i++) printf("%.4f ", bf16_to_float(ref[i])); printf("\n");
        return rel < 1e-2;
    };

    bool ok_q = compare("Q", q_d, Hq, qr_h);
    bool ok_k = compare("K", k_d, Hkv, kr_h);
    bool pass = ok_q && ok_k;
    printf("\n%s\n", pass ? "=== test_rope_manual PASS ===" : "=== test_rope_manual FAIL ===");
    return pass ? 0 : 1;
}