// adasdadsd's picture
// Replace with cleaned vector-only implementation
// 86d4a3b verified
// Raw
// History Blame Contribute Delete
// 5.65 kB
#include <cerrno>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>

#include "acl/acl.h"
#include "aclnn_smallq_flash_attention.h"
// Evaluates `expr` exactly once. When the result compares unequal to 0, prints
// the source file, line and numeric code to stderr and terminates the process
// with exit status 1. Wrapped in do/while(0) so it behaves as one statement.
#define CHECK_ACL(expr)                                             \
    do {                                                            \
        const auto aclStatus_ = (expr);                             \
        if (aclStatus_ != 0) {                                      \
            std::fprintf(stderr, "[ACL ERROR] %s:%d code=%d\n",     \
                         __FILE__, __LINE__,                        \
                         static_cast<int>(aclStatus_));             \
            std::exit(1);                                           \
        }                                                           \
    } while (0)
// Reads environment variable `name` as a base-10 integer.
//
// Returns `def` when the variable is unset. When it is set but is not a
// complete, in-range decimal integer (empty, trailing garbage, overflow), a
// warning is printed to stderr and `def` is returned instead of silently
// yielding 0 (or invoking undefined behaviour) as atoi() would.
static int getEnvInt(const char* name, int def) {
    const char* text = getenv(name);
    if (text == nullptr) {
        return def;
    }

    errno = 0;
    char* end = nullptr;
    const long parsed = strtol(text, &end, 10);

    const bool noDigits = (end == text);
    const bool trailing = (*end != '\0');
    const bool outOfRange =
        (errno == ERANGE) || parsed < INT_MIN || parsed > INT_MAX;
    if (noDigits || trailing || outOfRange) {
        fprintf(stderr, "[WARN] ignoring invalid %s=\"%s\"; using default %d\n",
                name, text, def);
        return def;
    }
    return static_cast<int>(parsed);
}
// Loads the entire file at `path` into a byte vector.
//
// The stream is opened positioned at the end so that tellg() yields the file
// size. Any failure (cannot open, size unavailable, short read) prints a
// message to stderr and terminates the process with exit status 1.
static std::vector<uint8_t> readFile(const char* path) {
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    if (!f) {
        fprintf(stderr, "cannot open %s\n", path);
        exit(1);
    }

    // tellg() reports failure as -1; converting that straight to size_t would
    // request an enormous allocation.
    const std::streamoff endPos = f.tellg();
    if (endPos < 0) {
        fprintf(stderr, "cannot determine size of %s\n", path);
        exit(1);
    }
    const size_t sz = static_cast<size_t>(endPos);

    std::vector<uint8_t> buf(sz);
    f.seekg(0, std::ios::beg);
    if (sz > 0 &&
        !f.read(reinterpret_cast<char*>(buf.data()),
                static_cast<std::streamsize>(sz))) {
        fprintf(stderr, "failed to read %zu bytes from %s\n", sz, path);
        exit(1);
    }
    return buf;
}
// Writes `sz` bytes starting at `data` to the file at `path`, replacing any
// existing contents.
//
// If the file cannot be opened or the write/flush fails, a message is printed
// to stderr and the process terminates with exit status 1, so callers never
// report success for a file that was not actually written.
static void writeFile(const char* path, const void* data, size_t sz) {
    std::ofstream f(path, std::ios::binary);
    if (!f) {
        fprintf(stderr, "cannot open %s for writing\n", path);
        exit(1);
    }
    f.write(reinterpret_cast<const char*>(data),
            static_cast<std::streamsize>(sz));
    f.flush();  // surface buffered write errors before the stream is destroyed
    if (!f) {
        fprintf(stderr, "failed to write %zu bytes to %s\n", sz, path);
        exit(1);
    }
}
// Ensures the host buffer loaded from `path` holds at least `expected` bytes.
// A shorter buffer would make the later host-to-device copy read past its end,
// so that case terminates the process; a longer one only triggers a warning
// because just the first `expected` bytes are copied.
static void requireHostBytes(const std::vector<uint8_t>& data, size_t expected,
                             const char* path) {
    if (data.size() < expected) {
        fprintf(stderr, "[ERROR] %s holds %zu bytes, expected at least %zu\n",
                path, data.size(), expected);
        exit(1);
    }
    if (data.size() > expected) {
        fprintf(stderr, "[WARN] %s holds %zu bytes, only the first %zu are used\n",
                path, data.size(), expected);
    }
}

// Driver for the aclnnSmallqFlashAttention operator.
//
// Environment variables:
//   NUM_HEADS (128), Q_LEN (5), HEAD_DIM (128), KV_LEN (8192) - tensor dims
//   DEVICE_ID (1)                                              - device index
//   IO_DIR                                                     - if set and
//       non-empty, q.bin/k.bin/v.bin are read from and output_o.bin is written
//       to this directory; otherwise input/ and output/ are used.
//
// Q and O are laid out as [NUM_HEADS, Q_LEN, HEAD_DIM]; K and V as
// [NUM_HEADS, KV_LEN, HEAD_DIM]; all are contiguous 2-byte (ACL_FLOAT16)
// elements. Returns 0 on success; any failure terminates with status 1.
int main() {
    const int numHeads = getEnvInt("NUM_HEADS", 128);
    const int qLen = getEnvInt("Q_LEN", 5);
    const int headDim = getEnvInt("HEAD_DIM", 128);
    const int kvLen = getEnvInt("KV_LEN", 8192);
    printf("[INFO] NUM_HEADS=%d, Q_LEN=%d, HEAD_DIM=%d, KV_LEN=%d\n",
           numHeads, qLen, headDim, kvLen);

    // Negative values would wrap to enormous sizes in the size_t math below.
    if (numHeads <= 0 || qLen <= 0 || headDim <= 0 || kvLen <= 0) {
        fprintf(stderr,
                "[ERROR] NUM_HEADS, Q_LEN, HEAD_DIM and KV_LEN must be positive\n");
        return 1;
    }

    const size_t qBytes =
        static_cast<size_t>(numHeads) * qLen * headDim * sizeof(uint16_t);
    const size_t kvBytes =
        static_cast<size_t>(numHeads) * kvLen * headDim * sizeof(uint16_t);
    const size_t oBytes = qBytes;

    // Resolve input/output paths (allow override via env).
    const char* inDir = getenv("IO_DIR");
    const std::string ioDir = inDir ? std::string(inDir) : std::string("");
    const std::string qPath = ioDir.empty() ? "input/q.bin" : ioDir + "/q.bin";
    const std::string kPath = ioDir.empty() ? "input/k.bin" : ioDir + "/k.bin";
    const std::string vPath = ioDir.empty() ? "input/v.bin" : ioDir + "/v.bin";
    const std::string oPath =
        ioDir.empty() ? "output/output_o.bin" : ioDir + "/output_o.bin";

    // Load and validate host inputs before touching the device so that bad
    // input fails without leaving ACL state behind.
    const auto qData = readFile(qPath.c_str());
    const auto kData = readFile(kPath.c_str());
    const auto vData = readFile(vPath.c_str());
    requireHostBytes(qData, qBytes, qPath.c_str());
    requireHostBytes(kData, kvBytes, kPath.c_str());
    requireHostBytes(vData, kvBytes, vPath.c_str());

    CHECK_ACL(aclInit(nullptr));
    const int deviceId = getEnvInt("DEVICE_ID", 1);
    CHECK_ACL(aclrtSetDevice(deviceId));
    aclrtStream stream = nullptr;
    CHECK_ACL(aclrtCreateStream(&stream));

    // Allocate device memory and upload the inputs.
    void* qDev = nullptr;
    void* kDev = nullptr;
    void* vDev = nullptr;
    void* oDev = nullptr;
    CHECK_ACL(aclrtMalloc(&qDev, qBytes, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc(&kDev, kvBytes, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc(&vDev, kvBytes, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc(&oDev, oBytes, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMemcpy(qDev, qBytes, qData.data(), qBytes, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(kDev, kvBytes, kData.data(), kvBytes, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(vDev, kvBytes, vData.data(), kvBytes, ACL_MEMCPY_HOST_TO_DEVICE));

    // Create aclTensors: 3-D shapes with row-major (contiguous) strides.
    int64_t qShape[] = {numHeads, qLen, headDim};
    int64_t kvShape[] = {numHeads, kvLen, headDim};
    int64_t oShape[] = {numHeads, qLen, headDim};
    int64_t qStrides[] = {static_cast<int64_t>(qLen) * headDim,
                          static_cast<int64_t>(headDim), 1};
    int64_t kvStrides[] = {static_cast<int64_t>(kvLen) * headDim,
                           static_cast<int64_t>(headDim), 1};
    int64_t oStrides[] = {static_cast<int64_t>(qLen) * headDim,
                          static_cast<int64_t>(headDim), 1};
    aclTensor* qTensor = aclCreateTensor(qShape, 3, ACL_FLOAT16, qStrides, 0,
                                         ACL_FORMAT_ND, qShape, 3, qDev);
    aclTensor* kTensor = aclCreateTensor(kvShape, 3, ACL_FLOAT16, kvStrides, 0,
                                         ACL_FORMAT_ND, kvShape, 3, kDev);
    aclTensor* vTensor = aclCreateTensor(kvShape, 3, ACL_FLOAT16, kvStrides, 0,
                                         ACL_FORMAT_ND, kvShape, 3, vDev);
    aclTensor* oTensor = aclCreateTensor(oShape, 3, ACL_FLOAT16, oStrides, 0,
                                         ACL_FORMAT_ND, oShape, 3, oDev);
    // NOTE(review): assumes aclCreateTensor signals failure with a null
    // pointer - confirm against the aclnn reference.
    if (!qTensor || !kTensor || !vTensor || !oTensor) {
        fprintf(stderr, "[ACL ERROR] aclCreateTensor returned a null tensor\n");
        exit(1);
    }

    // Call aclnn API: query workspace size, then launch on the stream.
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;
    CHECK_ACL(aclnnSmallqFlashAttentionGetWorkspaceSize(
        qTensor, kTensor, vTensor, oTensor, &workspaceSize, &executor));
    // Cast so the format specifier matches regardless of how uint64_t is
    // defined on the platform.
    printf("[INFO] workspaceSize = %llu\n",
           static_cast<unsigned long long>(workspaceSize));
    void* workspaceDev = nullptr;
    if (workspaceSize > 0) {
        CHECK_ACL(aclrtMalloc(&workspaceDev, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    CHECK_ACL(aclnnSmallqFlashAttention(workspaceDev, workspaceSize, executor, stream));
    CHECK_ACL(aclrtSynchronizeStream(stream));

    // Copy output back and persist it.
    std::vector<uint8_t> oHost(oBytes);
    CHECK_ACL(aclrtMemcpy(oHost.data(), oBytes, oDev, oBytes, ACL_MEMCPY_DEVICE_TO_HOST));
    writeFile(oPath.c_str(), oHost.data(), oBytes);
    printf("[INFO] Output written to %s\n", oPath.c_str());

    // Cleanup (return codes intentionally ignored; the process is exiting).
    aclDestroyTensor(qTensor);
    aclDestroyTensor(kTensor);
    aclDestroyTensor(vTensor);
    aclDestroyTensor(oTensor);
    if (workspaceDev) aclrtFree(workspaceDev);
    aclrtFree(qDev);
    aclrtFree(kDev);
    aclrtFree(vDev);
    aclrtFree(oDev);
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}