File size: 5,106 Bytes
90f0b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
#include "Encoder.hpp"
#include "DecoderMain.hpp"
#include "DecoderLoop.hpp"
#include <stdio.h>
#include <sys/time.h>
#include <cstring>
#include <ctime>
#include <vector>
#include <ax_sys_api.h>
// Returns the time reported by gettimeofday() expressed in milliseconds:
// the seconds field scaled up by 1000 plus the microseconds field scaled
// down by 1000. Callers subtract two readings to obtain an elapsed time.
static double get_current_time()
{
    struct timeval now;
    gettimeofday(&now, NULL);

    const double whole_seconds_ms = now.tv_sec * 1000.0;
    const double sub_second_ms = now.tv_usec / 1000.0;
    return whole_seconds_ms + sub_second_ms;
}
int main(int argc, char** argv) {
int ret = AX_SYS_Init();
if (0 != ret) {
fprintf(stderr, "AX_SYS_Init failed! ret = 0x%x\n", ret);
return -1;
}
AX_ENGINE_NPU_ATTR_T npu_attr;
memset(&npu_attr, 0, sizeof(npu_attr));
npu_attr.eHardMode = static_cast<AX_ENGINE_NPU_MODE_T>(0);
ret = AX_ENGINE_Init(&npu_attr);
if (0 != ret) {
fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret);
return -1;
}
Encoder encoder;
DecoderMain decoder_main;
DecoderLoop decoder_loop;
double start, end;
double whole_start, whole_end;
start = get_current_time();
if (0 != encoder.Init("../axmodel/encoder.axmodel")) {
printf("Init encoder failed!\n");
return -1;
}
end = get_current_time();
printf("Load encoder take %.2fms\n", end - start);
start = get_current_time();
if (0 != decoder_main.Init("../axmodel/decoder_main.axmodel")) {
printf("Init decoder_main failed!\n");
return -1;
}
end = get_current_time();
printf("Load decoder_main take %.2fms\n", end - start);
start = get_current_time();
if (0 != decoder_loop.Init("../axmodel/decoder_loop.axmodel")) {
printf("Init decoder_loop failed!\n");
return -1;
}
end = get_current_time();
printf("Load decoder_loop take %.2fms\n", end - start);
std::vector<float> encoder_inputs(encoder.GetInputSize(0) / sizeof(float));
std::vector<float> encoder_input_lengths(encoder.GetInputSize(1) / sizeof(float));
encoder_input_lengths[0] = 100;
std::vector<float> n_layer_cross_k(encoder.GetOutputSize(0) / sizeof(float));
std::vector<float> n_layer_cross_v(encoder.GetOutputSize(1) / sizeof(float));
std::vector<float> cross_attn_mask(encoder.GetOutputSize(2) / sizeof(float));
start = get_current_time();
whole_start = start;
encoder.SetInput(encoder_inputs.data(), 0);
encoder.SetInput(encoder_input_lengths.data(), 1);
encoder.Run();
// encoder.GetOutput(n_layer_cross_k.data(), 0);
// encoder.GetOutput(n_layer_cross_v.data(), 1);
// encoder.GetOutput(cross_attn_mask.data(), 2);
end = get_current_time();
printf("Run encoder take %.2fms\n", end - start);
std::vector<int> tokens(decoder_main.GetInputSize(0) / sizeof(int));
std::vector<int> logits(decoder_main.GetOutputSize(0) / sizeof(int));
std::vector<float> n_layer_self_k_cache(decoder_main.GetOutputSize(1) / sizeof(float));
std::vector<float> n_layer_self_v_cache(decoder_main.GetOutputSize(2) / sizeof(float));
start = get_current_time();
decoder_main.SetInput(tokens.data(), 0);
// decoder_main.SetInput(encoder.GetOutputPtr(0), 1);
// decoder_main.SetInput(encoder.GetOutputPtr(1), 2);
// decoder_main.SetInput(encoder.GetOutputPtr(2), 3);
decoder_main.SetInput(n_layer_cross_k.data(), 1);
decoder_main.SetInput(n_layer_cross_v.data(), 2);
decoder_main.SetInput(cross_attn_mask.data(), 3);
decoder_main.Run();
decoder_main.GetOutput(logits.data(), 0);
// decoder_main.GetOutput(n_layer_self_k_cache.data(), 1);
// decoder_main.GetOutput(n_layer_self_v_cache.data(), 2);
end = get_current_time();
printf("Run decoder_main take %.2fms\n", end - start);
std::vector<float> pe(decoder_loop.GetOutputSize(5) / sizeof(float));
std::vector<float> self_attn_mask(decoder_loop.GetOutputSize(6) / sizeof(float));
decoder_loop.SetInput(n_layer_cross_k.data(), 3);
decoder_loop.SetInput(n_layer_cross_v.data(), 4);
for (int i = 0; i < 14; i++) {
// "tokens": tokens,
// "in_n_layer_self_k_cache": n_layer_self_k_cache,
// "in_n_layer_self_v_cache": n_layer_self_v_cache,
// "n_layer_cross_k": n_layer_cross_k_cache,
// "n_layer_cross_v": n_layer_cross_v_cache,
// "pe": pe,
// "self_attn_mask": self_attn_mask,
// "cross_attn_mask": cross_attn_mask,
start = get_current_time();
decoder_loop.SetInput(tokens.data(), 0);
decoder_loop.SetInput(decoder_loop.GetOutputPtr(1), 1);
decoder_loop.SetInput(decoder_loop.GetOutputPtr(2), 2);
// decoder_loop.SetInput(encoder.GetOutputPtr(0), 3);
// decoder_loop.SetInput(encoder.GetOutputPtr(1), 4);
decoder_loop.Run();
decoder_loop.GetOutput(logits.data(), 0);
// decoder_main.GetOutput(n_layer_self_k_cache.data(), 1);
// decoder_main.GetOutput(n_layer_self_v_cache.data(), 2);
end = get_current_time();
printf("Run decoder_loop take %.2fms\n", end - start);
}
whole_end = get_current_time();
printf("Whole duration %.2fms\n", whole_end - whole_start);
printf("RTF: %.4f\n", (whole_end - whole_start) / 4000.0);
return 0;
} |