#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
|
| |
|
| | struct ngram_data {
|
| | bool active = false;
|
| |
|
| | llama_seq_id seq_id = -1;
|
| |
|
| | std::vector<int> i_batch;
|
| |
|
| | std::vector<llama_token> tokens;
|
| | };
|
| |
|
| |
|
| | struct ngram_container {
|
| | ngram_container(int n_vocab, int N, int G) {
|
| | cnt.resize(n_vocab);
|
| | head.resize(n_vocab);
|
| | tokens.resize(n_vocab * G * (N - 1));
|
| | }
|
| |
|
| | int n_total = 0;
|
| |
|
| | std::vector<int> cnt;
|
| | std::vector<int> head;
|
| |
|
| |
|
| |
|
| | std::vector<llama_token> tokens;
|
| | };
|
| |
|
| | int main(int argc, char ** argv) {
|
| | common_params params;
|
| |
|
| | if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
| | return 1;
|
| | }
|
| |
|
| | common_init();
|
| |
|
| | const int W = 15;
|
| | const int N = 5;
|
| | const int G = 15;
|
| |
|
| |
|
| | params.n_parallel = W + G + 1;
|
| |
|
| |
|
| | params.kv_unified = true;
|
| |
|
| |
|
| | llama_backend_init();
|
| | llama_numa_init(params.numa);
|
| |
|
| |
|
| | auto llama_init = common_init_from_params(params);
|
| |
|
| | auto * model = llama_init->model();
|
| | auto * ctx = llama_init->context();
|
| |
|
| | auto * mem = llama_get_memory(ctx);
|
| |
|
| | const llama_vocab * vocab = llama_model_get_vocab(model);
|
| |
|
| |
|
| | std::vector<llama_token> inp;
|
| | std::vector<llama_token> all;
|
| |
|
| | inp = common_tokenize(ctx, params.prompt, true, true);
|
| | all = inp;
|
| |
|
| | const int max_context_size = llama_n_ctx(ctx);
|
| | const int max_tokens_list_size = max_context_size - 4;
|
| |
|
| | if ((int) inp.size() > max_tokens_list_size) {
|
| | LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
|
| | return 1;
|
| | }
|
| |
|
| | LOG("\n\n");
|
| |
|
| | for (auto id : inp) {
|
| | LOG("%s", common_token_to_piece(ctx, id).c_str());
|
| | }
|
| |
|
| | fflush(stderr);
|
| |
|
| | const int n_input = inp.size();
|
| |
|
| | const auto t_enc_start = ggml_time_us();
|
| |
|
| |
|
| | llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
|
| | llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
|
| |
|
| | for (int s = 1; s < W + G + 1; ++s) {
|
| | llama_memory_seq_cp(mem, 0, s, -1, -1);
|
| | }
|
| |
|
| | const auto t_enc_end = ggml_time_us();
|
| |
|
| | int n_predict = 0;
|
| | int n_accept = 0;
|
| |
|
| | int n_past = inp.size();
|
| |
|
| | llama_token id = 0;
|
| |
|
| |
|
| | bool has_eos = false;
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1);
|
| |
|
| |
|
| | struct common_sampler * smpl = common_sampler_init(model, params.sampling);
|
| |
|
| |
|
| | std::vector<ngram_data> ngrams_cur(G);
|
| |
|
| |
|
| | std::vector<llama_token> tokens_j_prev(W);
|
| | std::vector<std::vector<llama_token>> tokens_j(N - 1);
|
| | for (int j = 0; j < N - 1; j++) {
|
| | tokens_j[j].resize(W);
|
| |
|
| | for (int i = 0; i < W; i++) {
|
| |
|
| | if (0) {
|
| |
|
| | tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
|
| | } else {
|
| |
|
| | tokens_j[j][i] = 100 + i;
|
| | }
|
| | }
|
| | }
|
| |
|
| | std::vector<llama_seq_id> seq_id_look;
|
| |
|
| |
|
| | std::vector<llama_seq_id> seq_id_all(W + G + 1);
|
| | for (int i = 0; i < W + G + 1; i++) {
|
| | seq_id_all[i] = i;
|
| | }
|
| |
|
| |
|
| | ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G);
|
| |
|
| | const auto t_dec_start = ggml_time_us();
|
| |
|
| |
|
| | {
|
| | id = common_sampler_sample(smpl, ctx, 0);
|
| |
|
| | common_sampler_accept(smpl, id, true);
|
| |
|
| | {
|
| | const std::string token_str = common_token_to_piece(ctx, id);
|
| |
|
| | LOG("%s", token_str.c_str());
|
| | fflush(stdout);
|
| | }
|
| | }
|
| |
|
| | while (true) {
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | {
|
| | common_batch_clear(batch);
|
| |
|
| |
|
| | common_batch_add(batch, id, n_past, seq_id_all, true);
|
| |
|
| |
|
| | {
|
| | const int g_cur = ngrams_observed.cnt[id];
|
| |
|
| | ngrams_cur.resize(g_cur);
|
| | for (int g = 0; g < g_cur; g++) {
|
| | ngrams_cur[g].active = true;
|
| | ngrams_cur[g].tokens.resize(N);
|
| | ngrams_cur[g].i_batch.resize(N);
|
| | ngrams_cur[g].seq_id = W + 1 + g;
|
| | ngrams_cur[g].i_batch[0] = 0;
|
| | ngrams_cur[g].tokens [0] = id;
|
| | }
|
| |
|
| | for (int j = 0; j < N - 1; j++) {
|
| | for (int g = 0; g < g_cur; g++) {
|
| | const int idx = id*(N - 1)*G + g*(N - 1);
|
| |
|
| | const llama_token t = ngrams_observed.tokens[idx + j];
|
| |
|
| | ngrams_cur[g].tokens [j + 1] = t;
|
| | ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
|
| |
|
| | common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
|
| | }
|
| | }
|
| | }
|
| |
|
| |
|
| | for (int i = 1; i < W; i++) {
|
| | seq_id_look.resize(W - i);
|
| | for (int j = 0; j < W - i; j++) {
|
| | seq_id_look[j] = i + j + 1;
|
| | }
|
| |
|
| | common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
|
| | }
|
| |
|
| |
|
| | for (int j = 1; j < N - 1; j++) {
|
| | for (int i = 0; i < W; i++) {
|
| | common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
|
| | }
|
| | }
|
| | }
|
| |
|
| | if (llama_decode(ctx, batch) != 0) {
|
| | LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
|
| | return 1;
|
| | }
|
| |
|
| | int seq_id_best = 0;
|
| |
|
| | for (int v = 0; v < N; ++v) {
|
| | int i_batch = 0;
|
| |
|
| |
|
| | if (v > 0) {
|
| | for (int g = 0; g < (int) ngrams_cur.size(); g++) {
|
| | if (ngrams_cur[g].active) {
|
| | i_batch = ngrams_cur[g].i_batch[v];
|
| | seq_id_best = ngrams_cur[g].seq_id;
|
| |
|
| | ++n_accept;
|
| | break;
|
| | }
|
| | }
|
| |
|
| |
|
| | if (i_batch == 0) {
|
| | break;
|
| | }
|
| | }
|
| |
|
| |
|
| | id = common_sampler_sample(smpl, ctx, i_batch);
|
| |
|
| | common_sampler_accept(smpl, id, true);
|
| |
|
| |
|
| | {
|
| | const std::string token_str = common_token_to_piece(ctx, id);
|
| |
|
| | if (v == 0) {
|
| | LOG("%s", token_str.c_str());
|
| | } else {
|
| |
|
| | LOG("\033[0;96m%s\033[0m", token_str.c_str());
|
| | }
|
| | fflush(stdout);
|
| |
|
| | if (llama_vocab_is_eog(vocab, id)) {
|
| | has_eos = true;
|
| | }
|
| |
|
| | all.push_back(id);
|
| | }
|
| |
|
| | ++n_predict;
|
| | ++n_past;
|
| |
|
| | if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
| | break;
|
| | }
|
| |
|
| |
|
| | for (int g = 0; g < (int) ngrams_cur.size(); g++) {
|
| | if (ngrams_cur[g].active) {
|
| | if (v == N - 1) {
|
| | ngrams_cur[g].active = false;
|
| | } else {
|
| | if (id != ngrams_cur[g].tokens[v + 1]) {
|
| | ngrams_cur[g].active = false;
|
| | }
|
| | }
|
| | }
|
| | }
|
| |
|
| |
|
| | if (0 && v == 0) {
|
| | if (ngrams_observed.cnt[id] > 0) {
|
| | LOG("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], common_token_to_piece(ctx, id).c_str());
|
| | }
|
| |
|
| | for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
|
| | LOG(" - ngram %2d: ", i);
|
| |
|
| | const int idx = id*(N - 1)*G + i*(N - 1);
|
| |
|
| | for (int j = 0; j < N - 1; j++) {
|
| | const std::string token_str = common_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
|
| |
|
| | LOG("%s", token_str.c_str());
|
| | }
|
| |
|
| | LOG("\n");
|
| | }
|
| | }
|
| |
|
| |
|
| | {
|
| | for (int i = 0; i < W; i++) {
|
| | tokens_j_prev[i] = tokens_j[0][i];
|
| | }
|
| |
|
| | for (int j = 0; j < N - 2; j++) {
|
| | tokens_j[j] = tokens_j[j + 1];
|
| | }
|
| |
|
| | if (v == 0) {
|
| |
|
| | for (int i = 0; i < W; i++) {
|
| | tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
| | }
|
| | } else {
|
| | for (int i = 0; i < W; i++) {
|
| |
|
| | if (0) {
|
| |
|
| | tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
|
| | } else {
|
| |
|
| | tokens_j[N - 2][i] = tokens_j[0][i];
|
| | }
|
| | }
|
| | }
|
| | }
|
| |
|
| |
|
| | if (v == 0) {
|
| |
|
| | std::vector<llama_token> ngram(N - 1);
|
| |
|
| |
|
| |
|
| | for (int f = 0; f < W; ++f) {
|
| | const int ft = tokens_j_prev[f];
|
| |
|
| | for (int j = 0; j < N - 1; ++j) {
|
| | ngram[j] = tokens_j[j][f];
|
| | }
|
| |
|
| |
|
| | {
|
| | bool is_unique = true;
|
| |
|
| | for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
|
| | const int idx = ft*(N - 1)*G + k*(N - 1);
|
| |
|
| | bool is_match = true;
|
| | for (int j = 0; j < N - 1; ++j) {
|
| | if (ngrams_observed.tokens[idx + j] != ngram[j]) {
|
| | is_match = false;
|
| | break;
|
| | }
|
| | }
|
| |
|
| | if (is_match) {
|
| | is_unique = false;
|
| | break;
|
| | }
|
| | }
|
| |
|
| | if (!is_unique) {
|
| | continue;
|
| | }
|
| | }
|
| |
|
| | const int head = ngrams_observed.head[ft];
|
| | const int idx = ft*(N - 1)*G + head*(N - 1);
|
| |
|
| | for (int i = 0; i < N - 1; i++) {
|
| | ngrams_observed.tokens[idx + i] = ngram[i];
|
| | }
|
| |
|
| | ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
|
| | ngrams_observed.head[ft] = (head + 1) % G;
|
| |
|
| | ngrams_observed.n_total++;
|
| | }
|
| | }
|
| | }
|
| |
|
| | if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
| | break;
|
| | }
|
| |
|
| |
|
| |
|
| | llama_memory_seq_rm(mem, -1, n_past, -1);
|
| |
|
| | if (seq_id_best != 0) {
|
| |
|
| |
|
| | llama_memory_seq_keep(mem, seq_id_best);
|
| | llama_memory_seq_cp (mem, seq_id_best, 0, -1, -1);
|
| | llama_memory_seq_rm (mem, seq_id_best, -1, -1);
|
| |
|
| | for (int s = 1; s < W + G + 1; ++s) {
|
| | llama_memory_seq_cp(mem, 0, s, -1, -1);
|
| | }
|
| | }
|
| | }
|
| |
|
| | auto t_dec_end = ggml_time_us();
|
| |
|
| | LOG("\n\n");
|
| |
|
| | LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
| | LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
| |
|
| | LOG_INF("\n");
|
| | LOG_INF("W = %2d\n", W);
|
| | LOG_INF("N = %2d\n", N);
|
| | LOG_INF("G = %2d\n", G);
|
| | LOG_INF("\n");
|
| | LOG_INF("n_predict = %d\n", n_predict);
|
| | LOG_INF("n_accept = %d\n", n_accept);
|
| |
|
| | LOG_INF("\n");
|
| | common_perf_print(ctx, smpl);
|
| |
|
| | common_sampler_free(smpl);
|
| |
|
| | llama_batch_free(batch);
|
| |
|
| | llama_backend_free();
|
| |
|
| | LOG("\n\n");
|
| |
|
| | return 0;
|
| | }
|
| |
|