// FireRedASR-AED / cpp / main.cpp
// Last change: 90f0b29 "Shorten kv cache" (author: inoryQwQ)
#include "Encoder.hpp"
#include "DecoderMain.hpp"
#include "DecoderLoop.hpp"

#include <stdio.h>
#include <sys/time.h>

#include <chrono>
#include <cstring>
#include <ctime>
#include <string>
#include <vector>

#include <ax_sys_api.h>
/**
 * Returns a monotonic timestamp in milliseconds.
 *
 * Only the difference between two calls is meaningful. The value counts from
 * std::chrono::steady_clock's unspecified epoch. Unlike the wall clock read by
 * gettimeofday(), it never jumps backwards or forwards when the system time is
 * adjusted, so measured durations cannot come out negative or inflated.
 */
static double get_current_time()
{
    using Milliseconds = std::chrono::duration<double, std::milli>;
    const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
    return Milliseconds(since_epoch).count();
}
int main(int argc, char** argv) {
int ret = AX_SYS_Init();
if (0 != ret) {
fprintf(stderr, "AX_SYS_Init failed! ret = 0x%x\n", ret);
return -1;
}
AX_ENGINE_NPU_ATTR_T npu_attr;
memset(&npu_attr, 0, sizeof(npu_attr));
npu_attr.eHardMode = static_cast<AX_ENGINE_NPU_MODE_T>(0);
ret = AX_ENGINE_Init(&npu_attr);
if (0 != ret) {
fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret);
return -1;
}
Encoder encoder;
DecoderMain decoder_main;
DecoderLoop decoder_loop;
double start, end;
double whole_start, whole_end;
start = get_current_time();
if (0 != encoder.Init("../axmodel/encoder.axmodel")) {
printf("Init encoder failed!\n");
return -1;
}
end = get_current_time();
printf("Load encoder take %.2fms\n", end - start);
start = get_current_time();
if (0 != decoder_main.Init("../axmodel/decoder_main.axmodel")) {
printf("Init decoder_main failed!\n");
return -1;
}
end = get_current_time();
printf("Load decoder_main take %.2fms\n", end - start);
start = get_current_time();
if (0 != decoder_loop.Init("../axmodel/decoder_loop.axmodel")) {
printf("Init decoder_loop failed!\n");
return -1;
}
end = get_current_time();
printf("Load decoder_loop take %.2fms\n", end - start);
std::vector<float> encoder_inputs(encoder.GetInputSize(0) / sizeof(float));
std::vector<float> encoder_input_lengths(encoder.GetInputSize(1) / sizeof(float));
encoder_input_lengths[0] = 100;
std::vector<float> n_layer_cross_k(encoder.GetOutputSize(0) / sizeof(float));
std::vector<float> n_layer_cross_v(encoder.GetOutputSize(1) / sizeof(float));
std::vector<float> cross_attn_mask(encoder.GetOutputSize(2) / sizeof(float));
start = get_current_time();
whole_start = start;
encoder.SetInput(encoder_inputs.data(), 0);
encoder.SetInput(encoder_input_lengths.data(), 1);
encoder.Run();
// encoder.GetOutput(n_layer_cross_k.data(), 0);
// encoder.GetOutput(n_layer_cross_v.data(), 1);
// encoder.GetOutput(cross_attn_mask.data(), 2);
end = get_current_time();
printf("Run encoder take %.2fms\n", end - start);
std::vector<int> tokens(decoder_main.GetInputSize(0) / sizeof(int));
std::vector<int> logits(decoder_main.GetOutputSize(0) / sizeof(int));
std::vector<float> n_layer_self_k_cache(decoder_main.GetOutputSize(1) / sizeof(float));
std::vector<float> n_layer_self_v_cache(decoder_main.GetOutputSize(2) / sizeof(float));
start = get_current_time();
decoder_main.SetInput(tokens.data(), 0);
// decoder_main.SetInput(encoder.GetOutputPtr(0), 1);
// decoder_main.SetInput(encoder.GetOutputPtr(1), 2);
// decoder_main.SetInput(encoder.GetOutputPtr(2), 3);
decoder_main.SetInput(n_layer_cross_k.data(), 1);
decoder_main.SetInput(n_layer_cross_v.data(), 2);
decoder_main.SetInput(cross_attn_mask.data(), 3);
decoder_main.Run();
decoder_main.GetOutput(logits.data(), 0);
// decoder_main.GetOutput(n_layer_self_k_cache.data(), 1);
// decoder_main.GetOutput(n_layer_self_v_cache.data(), 2);
end = get_current_time();
printf("Run decoder_main take %.2fms\n", end - start);
std::vector<float> pe(decoder_loop.GetOutputSize(5) / sizeof(float));
std::vector<float> self_attn_mask(decoder_loop.GetOutputSize(6) / sizeof(float));
decoder_loop.SetInput(n_layer_cross_k.data(), 3);
decoder_loop.SetInput(n_layer_cross_v.data(), 4);
for (int i = 0; i < 14; i++) {
// "tokens": tokens,
// "in_n_layer_self_k_cache": n_layer_self_k_cache,
// "in_n_layer_self_v_cache": n_layer_self_v_cache,
// "n_layer_cross_k": n_layer_cross_k_cache,
// "n_layer_cross_v": n_layer_cross_v_cache,
// "pe": pe,
// "self_attn_mask": self_attn_mask,
// "cross_attn_mask": cross_attn_mask,
start = get_current_time();
decoder_loop.SetInput(tokens.data(), 0);
decoder_loop.SetInput(decoder_loop.GetOutputPtr(1), 1);
decoder_loop.SetInput(decoder_loop.GetOutputPtr(2), 2);
// decoder_loop.SetInput(encoder.GetOutputPtr(0), 3);
// decoder_loop.SetInput(encoder.GetOutputPtr(1), 4);
decoder_loop.Run();
decoder_loop.GetOutput(logits.data(), 0);
// decoder_main.GetOutput(n_layer_self_k_cache.data(), 1);
// decoder_main.GetOutput(n_layer_self_v_cache.data(), 2);
end = get_current_time();
printf("Run decoder_loop take %.2fms\n", end - start);
}
whole_end = get_current_time();
printf("Whole duration %.2fms\n", whole_end - whole_start);
printf("RTF: %.4f\n", (whole_end - whole_start) / 4000.0);
return 0;
}