| #include <android/log.h> |
| #include <jni.h> |
| #include <iomanip> |
| #include <cmath> |
| #include <string> |
| #include <unistd.h> |
| #include <sampling.h> |
|
|
| #include "logging.h" |
| #include "chat.h" |
| #include "common.h" |
| #include "llama.h" |
|
|
| template<class T> |
| static std::string join(const std::vector<T> &values, const std::string &delim) { |
| std::ostringstream str; |
| for (size_t i = 0; i < values.size(); i++) { |
| str << values[i]; |
| if (i < values.size() - 1) { str << delim; } |
| } |
| return str.str(); |
| } |
|
|
| |
| |
| |
| constexpr int N_THREADS_MIN = 2; |
| constexpr int N_THREADS_MAX = 4; |
| constexpr int N_THREADS_HEADROOM = 2; |
|
|
| constexpr int DEFAULT_CONTEXT_SIZE = 8192; |
| constexpr int OVERFLOW_HEADROOM = 4; |
| constexpr int BATCH_SIZE = 512; |
| constexpr float DEFAULT_SAMPLER_TEMP = 0.3f; |
|
|
| static llama_model * g_model; |
| static llama_context * g_context; |
| static llama_batch g_batch; |
| static common_chat_templates_ptr g_chat_templates; |
| static common_sampler * g_sampler; |
|
|
| extern "C" |
| JNIEXPORT void JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject , jstring nativeLibDir) { |
| |
| llama_log_set(aichat_android_log_callback, nullptr); |
|
|
| |
| const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0); |
| LOGi("Loading backends from %s", path_to_backend); |
| ggml_backend_load_all_from_path(path_to_backend); |
| env->ReleaseStringUTFChars(nativeLibDir, path_to_backend); |
|
|
| |
| llama_backend_init(); |
| LOGi("Backend initiated; Log handler set."); |
| } |
|
|
| extern "C" |
| JNIEXPORT jint JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) { |
| llama_model_params model_params = llama_model_default_params(); |
|
|
| const auto *model_path = env->GetStringUTFChars(jmodel_path, 0); |
| LOGd("%s: Loading model from: \n%s\n", __func__, model_path); |
|
|
| auto *model = llama_model_load_from_file(model_path, model_params); |
| env->ReleaseStringUTFChars(jmodel_path, model_path); |
| if (!model) { |
| return 1; |
| } |
| g_model = model; |
| return 0; |
| } |
|
|
| static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) { |
| if (!model) { |
| LOGe("%s: model cannot be null", __func__); |
| return nullptr; |
| } |
|
|
| |
| const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX, |
| (int) sysconf(_SC_NPROCESSORS_ONLN) - |
| N_THREADS_HEADROOM)); |
| LOGi("%s: Using %d threads", __func__, n_threads); |
|
|
| |
| llama_context_params ctx_params = llama_context_default_params(); |
| const int trained_context_size = llama_model_n_ctx_train(model); |
| if (n_ctx > trained_context_size) { |
| LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...", |
| __func__, trained_context_size, n_ctx); |
| } |
| ctx_params.n_ctx = n_ctx; |
| ctx_params.n_batch = BATCH_SIZE; |
| ctx_params.n_ubatch = BATCH_SIZE; |
| ctx_params.n_threads = n_threads; |
| ctx_params.n_threads_batch = n_threads; |
| auto *context = llama_init_from_model(g_model, ctx_params); |
| if (context == nullptr) { |
| LOGe("%s: llama_new_context_with_model() returned null)", __func__); |
| } |
| return context; |
| } |
|
|
| static common_sampler *new_sampler(float temp) { |
| common_params_sampling sparams; |
| sparams.temp = temp; |
| return common_sampler_init(g_model, sparams); |
| } |
|
|
| extern "C" |
| JNIEXPORT jint JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * , jobject ) { |
| auto *context = init_context(g_model); |
| if (!context) { return 1; } |
| g_context = context; |
| g_batch = llama_batch_init(BATCH_SIZE, 0, 1); |
| g_chat_templates = common_chat_templates_init(g_model, ""); |
| g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP); |
| return 0; |
| } |
|
|
| static std::string get_backend() { |
| std::vector<std::string> backends; |
| for (size_t i = 0; i < ggml_backend_reg_count(); i++) { |
| auto *reg = ggml_backend_reg_get(i); |
| std::string name = ggml_backend_reg_name(reg); |
| if (name != "CPU") { |
| backends.push_back(ggml_backend_reg_name(reg)); |
| } |
| } |
| return backends.empty() ? "CPU" : join(backends, ","); |
| } |
|
|
| extern "C" |
| JNIEXPORT jstring JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject ) { |
| return env->NewStringUTF(llama_print_system_info()); |
| } |
|
|
| extern "C" |
| JNIEXPORT jstring JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject , jint pp, jint tg, |
| jint pl, jint nr) { |
| auto *context = init_context(g_model, pp); |
| if (!context) { |
| const auto *const err_msg = "Fail to init_context! Bench aborted."; |
| LOGe(err_msg); |
| return env->NewStringUTF(err_msg); |
| } |
|
|
| auto pp_avg = 0.0; |
| auto tg_avg = 0.0; |
| auto pp_std = 0.0; |
| auto tg_std = 0.0; |
|
|
| const uint32_t n_ctx = llama_n_ctx(context); |
| LOGi("n_ctx = %d", n_ctx); |
|
|
| int i, j; |
| int nri; |
| for (nri = 0; nri < nr; nri++) { |
| LOGi("Benchmark prompt processing (pp = %d)", pp); |
|
|
| common_batch_clear(g_batch); |
|
|
| const int n_tokens = pp; |
| for (i = 0; i < n_tokens; i++) { |
| common_batch_add(g_batch, 0, i, {0}, false); |
| } |
|
|
| g_batch.logits[g_batch.n_tokens - 1] = true; |
| llama_memory_clear(llama_get_memory(context), false); |
|
|
| const auto t_pp_start = ggml_time_us(); |
| if (llama_decode(context, g_batch) != 0) { |
| LOGe("llama_decode() failed during prompt processing"); |
| } |
| const auto t_pp_end = ggml_time_us(); |
|
|
| |
|
|
| LOGi("Benchmark text generation (tg = %d)", tg); |
|
|
| llama_memory_clear(llama_get_memory(context), false); |
| const auto t_tg_start = ggml_time_us(); |
| for (i = 0; i < tg; i++) { |
| common_batch_clear(g_batch); |
| for (j = 0; j < pl; j++) { |
| common_batch_add(g_batch, 0, i, {j}, true); |
| } |
|
|
| if (llama_decode(context, g_batch) != 0) { |
| LOGe("llama_decode() failed during text generation"); |
| } |
| } |
| const auto t_tg_end = ggml_time_us(); |
|
|
| llama_memory_clear(llama_get_memory(context), false); |
|
|
| const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; |
| const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; |
|
|
| const auto speed_pp = double(pp) / t_pp; |
| const auto speed_tg = double(pl * tg) / t_tg; |
|
|
| pp_avg += speed_pp; |
| tg_avg += speed_tg; |
|
|
| pp_std += speed_pp * speed_pp; |
| tg_std += speed_tg * speed_tg; |
|
|
| LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg); |
| } |
|
|
| llama_free(context); |
|
|
| pp_avg /= double(nr); |
| tg_avg /= double(nr); |
|
|
| if (nr > 1) { |
| pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)); |
| tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)); |
| } else { |
| pp_std = 0; |
| tg_std = 0; |
| } |
|
|
| char model_desc[128]; |
| llama_model_desc(g_model, model_desc, sizeof(model_desc)); |
|
|
| const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0; |
| const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9; |
|
|
| const auto backend = get_backend(); |
| std::stringstream result; |
| result << std::setprecision(3); |
| result << "| model | size | params | backend | test | t/s |\n"; |
| result << "| --- | --- | --- | --- | --- | --- |\n"; |
| result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " |
| << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n"; |
| result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " |
| << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n"; |
| return env->NewStringUTF(result.str().c_str()); |
| } |
|
|
|
|
| |
| |
| |
| |
| |
| constexpr const char *ROLE_SYSTEM = "system"; |
| constexpr const char *ROLE_USER = "user"; |
| constexpr const char *ROLE_ASSISTANT = "assistant"; |
|
|
| static std::vector<common_chat_msg> chat_msgs; |
| static llama_pos system_prompt_position; |
| static llama_pos current_position; |
|
|
| static void reset_long_term_states(const bool clear_kv_cache = true) { |
| chat_msgs.clear(); |
| system_prompt_position = 0; |
| current_position = 0; |
|
|
| if (clear_kv_cache) |
| llama_memory_clear(llama_get_memory(g_context), false); |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| static void shift_context() { |
| const int n_discard = (current_position - system_prompt_position) / 2; |
| LOGi("%s: Discarding %d tokens", __func__, n_discard); |
| llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard); |
| llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard); |
| current_position -= n_discard; |
| LOGi("%s: Context shifting done! Current position: %d", __func__, current_position); |
| } |
|
|
| static std::string chat_add_and_format(const std::string &role, const std::string &content) { |
| common_chat_msg new_msg; |
| new_msg.role = role; |
| new_msg.content = content; |
| auto formatted = common_chat_format_single( |
| g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, false); |
| chat_msgs.push_back(new_msg); |
| LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str()); |
| return formatted; |
| } |
|
|
| |
| |
| |
| |
| |
| |
| static llama_pos stop_generation_position; |
| static std::string cached_token_chars; |
| static std::ostringstream assistant_ss; |
|
|
| static void reset_short_term_states() { |
| stop_generation_position = 0; |
| cached_token_chars.clear(); |
| assistant_ss.str(""); |
| } |
|
|
| static int decode_tokens_in_batches( |
| llama_context *context, |
| llama_batch &batch, |
| const llama_tokens &tokens, |
| const llama_pos start_pos, |
| const bool compute_last_logit = false) { |
| |
| LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos); |
| for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) { |
| const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE); |
| common_batch_clear(batch); |
| LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i); |
|
|
| |
| if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) { |
| LOGw("%s: Current batch won't fit into context! Shifting...", __func__); |
| shift_context(); |
| } |
|
|
| |
| for (int j = 0; j < cur_batch_size; j++) { |
| const llama_token token_id = tokens[i + j]; |
| const llama_pos position = start_pos + i + j; |
| const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1); |
| common_batch_add(batch, token_id, position, {0}, want_logit); |
| } |
|
|
| |
| const int decode_result = llama_decode(context, batch); |
| if (decode_result) { |
| LOGe("%s: llama_decode failed w/ %d", __func__, decode_result); |
| return 1; |
| } |
| } |
| return 0; |
| } |
|
|
| extern "C" |
| JNIEXPORT jint JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt( |
| JNIEnv *env, |
| jobject , |
| jstring jsystem_prompt |
| ) { |
| |
| reset_long_term_states(); |
| reset_short_term_states(); |
|
|
| |
| const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr); |
| LOGd("%s: System prompt received: \n%s", __func__, system_prompt); |
| std::string formatted_system_prompt(system_prompt); |
| env->ReleaseStringUTFChars(jsystem_prompt, system_prompt); |
|
|
| |
| const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get()); |
| if (has_chat_template) { |
| formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt); |
| } |
|
|
| |
| const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, |
| has_chat_template, has_chat_template); |
| for (auto id: system_tokens) { |
| LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); |
| } |
|
|
| |
| const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM; |
| if ((int) system_tokens.size() > max_batch_size) { |
| LOGe("%s: System prompt too long for context! %d tokens, max: %d", |
| __func__, (int) system_tokens.size(), max_batch_size); |
| return 1; |
| } |
|
|
| |
| if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) { |
| LOGe("%s: llama_decode() failed!", __func__); |
| return 2; |
| } |
|
|
| |
| system_prompt_position = current_position = (int) system_tokens.size(); |
| return 0; |
| } |
|
|
| extern "C" |
| JNIEXPORT jint JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt( |
| JNIEnv *env, |
| jobject , |
| jstring juser_prompt, |
| jint n_predict |
| ) { |
| |
| reset_short_term_states(); |
|
|
| |
| const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr); |
| LOGd("%s: User prompt received: \n%s", __func__, user_prompt); |
| std::string formatted_user_prompt(user_prompt); |
| env->ReleaseStringUTFChars(juser_prompt, user_prompt); |
|
|
| |
| const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get()); |
| if (has_chat_template) { |
| formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt); |
| } |
|
|
| |
| auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template); |
| for (auto id: user_tokens) { |
| LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); |
| } |
|
|
| |
| const int user_prompt_size = (int) user_tokens.size(); |
| const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM; |
| if (user_prompt_size > max_batch_size) { |
| const int skipped_tokens = user_prompt_size - max_batch_size; |
| user_tokens.resize(max_batch_size); |
| LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens); |
| } |
|
|
| |
| if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) { |
| LOGe("%s: llama_decode() failed!", __func__); |
| return 2; |
| } |
|
|
| |
| current_position += user_prompt_size; |
| stop_generation_position = current_position + user_prompt_size + n_predict; |
| return 0; |
| } |
|
|
| static bool is_valid_utf8(const char *string) { |
| if (!string) { return true; } |
|
|
| const auto *bytes = (const unsigned char *) string; |
| int num; |
|
|
| while (*bytes != 0x00) { |
| if ((*bytes & 0x80) == 0x00) { |
| |
| num = 1; |
| } else if ((*bytes & 0xE0) == 0xC0) { |
| |
| num = 2; |
| } else if ((*bytes & 0xF0) == 0xE0) { |
| |
| num = 3; |
| } else if ((*bytes & 0xF8) == 0xF0) { |
| |
| num = 4; |
| } else { |
| return false; |
| } |
|
|
| bytes += 1; |
| for (int i = 1; i < num; ++i) { |
| if ((*bytes & 0xC0) != 0x80) { |
| return false; |
| } |
| bytes += 1; |
| } |
| } |
| return true; |
| } |
|
|
| extern "C" |
| JNIEXPORT jstring JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken( |
| JNIEnv *env, |
| jobject |
| ) { |
| |
| if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) { |
| LOGw("%s: Context full! Shifting...", __func__); |
| shift_context(); |
| } |
|
|
| |
| if (current_position >= stop_generation_position) { |
| LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position); |
| return nullptr; |
| } |
|
|
| |
| const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1); |
| common_sampler_accept(g_sampler, new_token_id, true); |
|
|
| |
| common_batch_clear(g_batch); |
| common_batch_add(g_batch, new_token_id, current_position, {0}, true); |
| if (llama_decode(g_context, g_batch) != 0) { |
| LOGe("%s: llama_decode() failed for generated token", __func__); |
| return nullptr; |
| } |
|
|
| |
| current_position++; |
|
|
| |
| if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) { |
| LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id); |
| chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str()); |
| return nullptr; |
| } |
|
|
| |
| auto new_token_chars = common_token_to_piece(g_context, new_token_id); |
| cached_token_chars += new_token_chars; |
|
|
| |
| jstring result = nullptr; |
| if (is_valid_utf8(cached_token_chars.c_str())) { |
| result = env->NewStringUTF(cached_token_chars.c_str()); |
| LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str()); |
|
|
| assistant_ss << cached_token_chars; |
| cached_token_chars.clear(); |
| } else { |
| LOGv("id: %d,\tappend to cache", new_token_id); |
| result = env->NewStringUTF(""); |
| } |
| return result; |
| } |
|
|
|
|
| extern "C" |
| JNIEXPORT void JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * , jobject ) { |
| |
| reset_long_term_states(); |
| reset_short_term_states(); |
|
|
| |
| common_sampler_free(g_sampler); |
| g_chat_templates.reset(); |
| llama_batch_free(g_batch); |
| llama_free(g_context); |
| llama_model_free(g_model); |
| } |
|
|
| extern "C" |
| JNIEXPORT void JNICALL |
| Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *, jobject ) { |
| llama_backend_free(); |
| } |
|
|