|
|
#include <sys/time.h>

#include <chrono>
#include <cstring>
#include <ctime>
#include <stdio.h>
#include <vector>

#include <ax_sys_api.h>

#include "Encoder.hpp"
#include "DecoderMain.hpp"
#include "DecoderLoop.hpp"
|
|
|
|
|
/// Returns a monotonic timestamp in milliseconds (with sub-ms fraction).
///
/// Only the difference between two calls is meaningful. The value comes from
/// std::chrono::steady_clock, whose epoch is unspecified. Unlike
/// gettimeofday(), it never jumps backwards or forwards when the wall clock is
/// adjusted (NTP, manual changes), so measured intervals cannot go negative.
static double get_current_time()
{
    using MillisecondsF = std::chrono::duration<double, std::milli>;
    const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
    // Converting an integral duration to a floating-point one is implicit and lossless in range.
    return MillisecondsF(since_epoch).count();
}
|
|
|
|
|
int main(int argc, char** argv) { |
|
|
int ret = AX_SYS_Init(); |
|
|
if (0 != ret) { |
|
|
fprintf(stderr, "AX_SYS_Init failed! ret = 0x%x\n", ret); |
|
|
return -1; |
|
|
} |
|
|
|
|
|
AX_ENGINE_NPU_ATTR_T npu_attr; |
|
|
memset(&npu_attr, 0, sizeof(npu_attr)); |
|
|
npu_attr.eHardMode = static_cast<AX_ENGINE_NPU_MODE_T>(0); |
|
|
ret = AX_ENGINE_Init(&npu_attr); |
|
|
if (0 != ret) { |
|
|
fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret); |
|
|
return -1; |
|
|
} |
|
|
|
|
|
Encoder encoder; |
|
|
DecoderMain decoder_main; |
|
|
DecoderLoop decoder_loop; |
|
|
|
|
|
double start, end; |
|
|
double whole_start, whole_end; |
|
|
|
|
|
start = get_current_time(); |
|
|
if (0 != encoder.Init("../axmodel/encoder.axmodel")) { |
|
|
printf("Init encoder failed!\n"); |
|
|
return -1; |
|
|
} |
|
|
end = get_current_time(); |
|
|
printf("Load encoder take %.2fms\n", end - start); |
|
|
|
|
|
start = get_current_time(); |
|
|
if (0 != decoder_main.Init("../axmodel/decoder_main.axmodel")) { |
|
|
printf("Init decoder_main failed!\n"); |
|
|
return -1; |
|
|
} |
|
|
end = get_current_time(); |
|
|
printf("Load decoder_main take %.2fms\n", end - start); |
|
|
|
|
|
start = get_current_time(); |
|
|
if (0 != decoder_loop.Init("../axmodel/decoder_loop.axmodel")) { |
|
|
printf("Init decoder_loop failed!\n"); |
|
|
return -1; |
|
|
} |
|
|
end = get_current_time(); |
|
|
printf("Load decoder_loop take %.2fms\n", end - start); |
|
|
|
|
|
std::vector<float> encoder_inputs(encoder.GetInputSize(0) / sizeof(float)); |
|
|
std::vector<float> encoder_input_lengths(encoder.GetInputSize(1) / sizeof(float)); |
|
|
encoder_input_lengths[0] = 100; |
|
|
|
|
|
std::vector<float> n_layer_cross_k(encoder.GetOutputSize(0) / sizeof(float)); |
|
|
std::vector<float> n_layer_cross_v(encoder.GetOutputSize(1) / sizeof(float)); |
|
|
std::vector<float> cross_attn_mask(encoder.GetOutputSize(2) / sizeof(float)); |
|
|
|
|
|
start = get_current_time(); |
|
|
whole_start = start; |
|
|
encoder.SetInput(encoder_inputs.data(), 0); |
|
|
encoder.SetInput(encoder_input_lengths.data(), 1); |
|
|
encoder.Run(); |
|
|
|
|
|
|
|
|
|
|
|
end = get_current_time(); |
|
|
printf("Run encoder take %.2fms\n", end - start); |
|
|
|
|
|
std::vector<int> tokens(decoder_main.GetInputSize(0) / sizeof(int)); |
|
|
|
|
|
std::vector<int> logits(decoder_main.GetOutputSize(0) / sizeof(int)); |
|
|
std::vector<float> n_layer_self_k_cache(decoder_main.GetOutputSize(1) / sizeof(float)); |
|
|
std::vector<float> n_layer_self_v_cache(decoder_main.GetOutputSize(2) / sizeof(float)); |
|
|
|
|
|
start = get_current_time(); |
|
|
decoder_main.SetInput(tokens.data(), 0); |
|
|
|
|
|
|
|
|
|
|
|
decoder_main.SetInput(n_layer_cross_k.data(), 1); |
|
|
decoder_main.SetInput(n_layer_cross_v.data(), 2); |
|
|
decoder_main.SetInput(cross_attn_mask.data(), 3); |
|
|
decoder_main.Run(); |
|
|
decoder_main.GetOutput(logits.data(), 0); |
|
|
|
|
|
|
|
|
end = get_current_time(); |
|
|
printf("Run decoder_main take %.2fms\n", end - start); |
|
|
|
|
|
std::vector<float> pe(decoder_loop.GetOutputSize(5) / sizeof(float)); |
|
|
std::vector<float> self_attn_mask(decoder_loop.GetOutputSize(6) / sizeof(float)); |
|
|
|
|
|
decoder_loop.SetInput(n_layer_cross_k.data(), 3); |
|
|
decoder_loop.SetInput(n_layer_cross_v.data(), 4); |
|
|
for (int i = 0; i < 14; i++) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
start = get_current_time(); |
|
|
decoder_loop.SetInput(tokens.data(), 0); |
|
|
decoder_loop.SetInput(decoder_loop.GetOutputPtr(1), 1); |
|
|
decoder_loop.SetInput(decoder_loop.GetOutputPtr(2), 2); |
|
|
|
|
|
|
|
|
|
|
|
decoder_loop.Run(); |
|
|
decoder_loop.GetOutput(logits.data(), 0); |
|
|
|
|
|
|
|
|
end = get_current_time(); |
|
|
printf("Run decoder_loop take %.2fms\n", end - start); |
|
|
} |
|
|
|
|
|
whole_end = get_current_time(); |
|
|
printf("Whole duration %.2fms\n", whole_end - whole_start); |
|
|
printf("RTF: %.4f\n", (whole_end - whole_start) / 4000.0); |
|
|
|
|
|
return 0; |
|
|
} |