#include <cerrno>
#include <cinttypes>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>

#include "acl/acl.h"
#include "aclnn_smallq_flash_attention.h"

// Evaluates an ACL/aclnn call exactly once. If the returned status is non-zero,
// prints the source file, line and numeric status to stderr and terminates the
// process with exit code 1. The do/while(0) wrapper lets the macro be used as a
// single statement (e.g. as the body of an unbraced `if`).
#define CHECK_ACL(expr)                                                    \
    do {                                                                   \
        const auto aclCheckStatus = (expr);                                \
        if (aclCheckStatus == 0) {                                         \
            break;                                                         \
        }                                                                  \
        fprintf(stderr, "[ACL ERROR] %s:%d code=%d\n", __FILE__, __LINE__, \
                static_cast<int>(aclCheckStatus));                         \
        exit(1);                                                           \
    } while (0)

// Reads an integer setting from environment variable `name`.
// Returns `def` when the variable is unset or set to the empty string.
// Unlike atoi(), a malformed value (e.g. "12abc", "x") or one outside the
// range of int is rejected: an error is printed to stderr and the process
// exits with code 1, so a typo cannot silently become 0 or a truncated number.
// Leading whitespace is accepted (strtol skips it); trailing characters are not.
static int getEnvInt(const char* name, int def) {
    const char* raw = getenv(name);
    if (raw == nullptr || *raw == '\0') {
        return def;
    }
    errno = 0;
    char* end = nullptr;
    const long parsed = strtol(raw, &end, 10);
    // end == raw: no digits; *end != '\0': trailing junk; ERANGE: overflowed long.
    if (end == raw || *end != '\0' || errno == ERANGE ||
        parsed < INT_MIN || parsed > INT_MAX) {
        fprintf(stderr, "invalid integer in %s: '%s'\n", name, raw);
        exit(1);
    }
    return static_cast<int>(parsed);
}

// Reads the entire binary file at `path` into a byte vector.
// Exits the process (code 1, message on stderr) if the file cannot be opened,
// its size cannot be determined, or the read does not deliver every byte.
static std::vector<uint8_t> readFile(const char* path) {
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    if (!f) { fprintf(stderr, "cannot open %s\n", path); exit(1); }
    // Opened with ios::ate, so the current position is the file size.
    // tellg() reports failure as -1, which must not be cast to size_t.
    const std::streamoff endPos = f.tellg();
    if (endPos < 0) { fprintf(stderr, "cannot determine size of %s\n", path); exit(1); }
    const size_t sz = static_cast<size_t>(endPos);
    f.seekg(0, std::ios::beg);
    std::vector<uint8_t> buf(sz);
    if (sz > 0 &&
        !f.read(reinterpret_cast<char*>(buf.data()), static_cast<std::streamsize>(sz))) {
        fprintf(stderr, "short read from %s\n", path);
        exit(1);
    }
    return buf;
}

// Writes `sz` bytes starting at `data` to `path`, creating or truncating it.
// Exits the process (code 1, message on stderr) if the file cannot be opened
// (e.g. its directory does not exist) or the data cannot be fully written.
// This keeps callers from reporting success for output that never reached disk.
static void writeFile(const char* path, const void* data, size_t sz) {
    std::ofstream f(path, std::ios::binary);
    if (!f) { fprintf(stderr, "cannot open %s for writing\n", path); exit(1); }
    f.write(static_cast<const char*>(data), static_cast<std::streamsize>(sz));
    f.close();  // flush now so buffered-write failures are detected here
    if (!f) { fprintf(stderr, "failed to write %s\n", path); exit(1); }
}

| int main() { |
| int numHeads = getEnvInt("NUM_HEADS", 128); |
| int qLen = getEnvInt("Q_LEN", 5); |
| int headDim = getEnvInt("HEAD_DIM", 128); |
| int kvLen = getEnvInt("KV_LEN", 8192); |
|
|
| printf("[INFO] NUM_HEADS=%d, Q_LEN=%d, HEAD_DIM=%d, KV_LEN=%d\n", |
| numHeads, qLen, headDim, kvLen); |
|
|
| size_t qBytes = (size_t)numHeads * qLen * headDim * sizeof(uint16_t); |
| size_t kvBytes = (size_t)numHeads * kvLen * headDim * sizeof(uint16_t); |
| size_t oBytes = qBytes; |
|
|
| CHECK_ACL(aclInit(nullptr)); |
| int deviceId = getEnvInt("DEVICE_ID", 1); |
| CHECK_ACL(aclrtSetDevice(deviceId)); |
| aclrtStream stream = nullptr; |
| CHECK_ACL(aclrtCreateStream(&stream)); |
|
|
| |
| const char* inDir = getenv("IO_DIR"); |
| std::string ioDir = inDir ? std::string(inDir) : std::string(""); |
| std::string qPath = ioDir.empty() ? "input/q.bin" : ioDir + "/q.bin"; |
| std::string kPath = ioDir.empty() ? "input/k.bin" : ioDir + "/k.bin"; |
| std::string vPath = ioDir.empty() ? "input/v.bin" : ioDir + "/v.bin"; |
| std::string oPath = ioDir.empty() ? "output/output_o.bin" : ioDir + "/output_o.bin"; |
| auto qData = readFile(qPath.c_str()); |
| auto kData = readFile(kPath.c_str()); |
| auto vData = readFile(vPath.c_str()); |
|
|
| |
| void *qDev, *kDev, *vDev, *oDev; |
| CHECK_ACL(aclrtMalloc(&qDev, qBytes, ACL_MEM_MALLOC_HUGE_FIRST)); |
| CHECK_ACL(aclrtMalloc(&kDev, kvBytes, ACL_MEM_MALLOC_HUGE_FIRST)); |
| CHECK_ACL(aclrtMalloc(&vDev, kvBytes, ACL_MEM_MALLOC_HUGE_FIRST)); |
| CHECK_ACL(aclrtMalloc(&oDev, oBytes, ACL_MEM_MALLOC_HUGE_FIRST)); |
|
|
| CHECK_ACL(aclrtMemcpy(qDev, qBytes, qData.data(), qBytes, ACL_MEMCPY_HOST_TO_DEVICE)); |
| CHECK_ACL(aclrtMemcpy(kDev, kvBytes, kData.data(), kvBytes, ACL_MEMCPY_HOST_TO_DEVICE)); |
| CHECK_ACL(aclrtMemcpy(vDev, kvBytes, vData.data(), kvBytes, ACL_MEMCPY_HOST_TO_DEVICE)); |
|
|
| |
| int64_t qShape[] = {numHeads, qLen, headDim}; |
| int64_t kvShape[] = {numHeads, kvLen, headDim}; |
| int64_t oShape[] = {numHeads, qLen, headDim}; |
|
|
| int64_t qStrides[] = {(int64_t)qLen * headDim, (int64_t)headDim, 1}; |
| int64_t kvStrides[] = {(int64_t)kvLen * headDim, (int64_t)headDim, 1}; |
| int64_t oStrides[] = {(int64_t)qLen * headDim, (int64_t)headDim, 1}; |
|
|
| aclTensor* qTensor = aclCreateTensor(qShape, 3, ACL_FLOAT16, qStrides, 0, |
| ACL_FORMAT_ND, qShape, 3, qDev); |
| aclTensor* kTensor = aclCreateTensor(kvShape, 3, ACL_FLOAT16, kvStrides, 0, |
| ACL_FORMAT_ND, kvShape, 3, kDev); |
| aclTensor* vTensor = aclCreateTensor(kvShape, 3, ACL_FLOAT16, kvStrides, 0, |
| ACL_FORMAT_ND, kvShape, 3, vDev); |
| aclTensor* oTensor = aclCreateTensor(oShape, 3, ACL_FLOAT16, oStrides, 0, |
| ACL_FORMAT_ND, oShape, 3, oDev); |
|
|
| |
| uint64_t workspaceSize = 0; |
| aclOpExecutor* executor = nullptr; |
|
|
| CHECK_ACL(aclnnSmallqFlashAttentionGetWorkspaceSize( |
| qTensor, kTensor, vTensor, oTensor, &workspaceSize, &executor)); |
|
|
| printf("[INFO] workspaceSize = %lu\n", workspaceSize); |
|
|
| void* workspaceDev = nullptr; |
| if (workspaceSize > 0) { |
| CHECK_ACL(aclrtMalloc(&workspaceDev, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST)); |
| } |
|
|
| CHECK_ACL(aclnnSmallqFlashAttention(workspaceDev, workspaceSize, executor, stream)); |
| CHECK_ACL(aclrtSynchronizeStream(stream)); |
|
|
| |
| std::vector<uint8_t> oHost(oBytes); |
| CHECK_ACL(aclrtMemcpy(oHost.data(), oBytes, oDev, oBytes, ACL_MEMCPY_DEVICE_TO_HOST)); |
| writeFile(oPath.c_str(), oHost.data(), oBytes); |
|
|
| printf("[INFO] Output written to %s\n", oPath.c_str()); |
|
|
| |
| aclDestroyTensor(qTensor); |
| aclDestroyTensor(kTensor); |
| aclDestroyTensor(vTensor); |
| aclDestroyTensor(oTensor); |
| if (workspaceDev) aclrtFree(workspaceDev); |
| aclrtFree(qDev); |
| aclrtFree(kDev); |
| aclrtFree(vDev); |
| aclrtFree(oDev); |
| aclrtDestroyStream(stream); |
| aclrtResetDevice(deviceId); |
| aclFinalize(); |
| return 0; |
| } |
|
|