| | #include "models.h"
|
| |
|
| | #define CHUNK_SIZE 64
|
| |
|
| |
|
| |
|
| |
|
| | static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
|
| | return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
|
| | t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
|
| | }
|
| |
|
| | llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
|
| |
|
| | std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
|
| | ggml_tensor * q,
|
| | ggml_tensor * k,
|
| | ggml_tensor * v,
|
| | ggml_tensor * g,
|
| | ggml_tensor * b,
|
| | ggml_tensor * s,
|
| | int il) {
|
| | const int64_t S_k = q->ne[0];
|
| | const int64_t H_k = q->ne[1];
|
| | const int64_t n_tokens = q->ne[2];
|
| | const int64_t n_seqs = q->ne[3];
|
| |
|
| | const int64_t S_v = v->ne[0];
|
| | const int64_t H_v = v->ne[1];
|
| | const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k);
|
| |
|
| | GGML_ASSERT(S_k == S_v);
|
| | GGML_ASSERT(H_v % H_k == 0);
|
| |
|
| | GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
| | GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
| | GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
|
| |
|
| | GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
|
| | GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
|
| | GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
| | GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
| |
|
| | const float scale = 1.0f / sqrtf(S_k);
|
| |
|
| | q = ggml_scale(ctx0, q, scale);
|
| |
|
| | cb(q, "q_in", il);
|
| | cb(k, "k_in", il);
|
| | cb(v, "v_in", il);
|
| | cb(b, "b_in", il);
|
| | cb(g, "g_in", il);
|
| |
|
| | q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
| | k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
| | v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
| | g = ggml_permute(ctx0, g, 0, 2, 1, 3);
|
| | b = ggml_permute(ctx0, b, 0, 2, 1, 3);
|
| |
|
| | const int CS = CHUNK_SIZE;
|
| |
|
| | const int pad = (CS - n_tokens % CS) % CS;
|
| | const int n_chunks = (n_tokens + pad) / CS;
|
| |
|
| | q = ggml_pad(ctx0, q, 0, pad, 0, 0);
|
| | k = ggml_pad(ctx0, k, 0, pad, 0, 0);
|
| | v = ggml_pad(ctx0, v, 0, pad, 0, 0);
|
| | g = ggml_pad(ctx0, g, 0, pad, 0, 0);
|
| | b = ggml_pad(ctx0, b, 0, pad, 0, 0);
|
| |
|
| | ggml_tensor * v_b = ggml_mul(ctx0, v, b);
|
| | ggml_tensor * k_b = ggml_mul(ctx0, k, b);
|
| |
|
| | cb(v_b, "v_b", il);
|
| | cb(k_b, "k_b", il);
|
| |
|
| | q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs);
|
| | k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs);
|
| | k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
|
| | v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs);
|
| | v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
|
| |
|
| | g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
|
| | b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
|
| |
|
| |
|
| |
|
| | ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
|
| | cb(g_cs, "g_cs", il);
|
| |
|
| | ggml_tensor * kb = nullptr;
|
| | ggml_tensor * kq = nullptr;
|
| | if (kda) {
|
| | const int64_t CHB = n_chunks * H_k * n_seqs;
|
| |
|
| | ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB);
|
| | ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB);
|
| |
|
| | g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB);
|
| |
|
| |
|
| | ggml_tensor * decay_mask;
|
| | decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
|
| | decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
|
| | decay_mask = ggml_exp(ctx0, decay_mask);
|
| | cb(decay_mask, "decay_mask", il);
|
| |
|
| |
|
| | decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);
|
| |
|
| | ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB);
|
| | ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, CS, CHB);
|
| | ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, CS, 1, CHB);
|
| |
|
| | ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
|
| | ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
|
| |
|
| |
|
| | kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
|
| | kq = ggml_mul_mat(ctx0, decay_q_i, k_j);
|
| |
|
| | kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
|
| | kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
|
| | } else {
|
| | ggml_tensor * g_cs_i = g_cs;
|
| | ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
|
| |
|
| | g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
|
| |
|
| |
|
| | ggml_tensor * decay_mask;
|
| | decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
|
| | decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
|
| | decay_mask = ggml_exp(ctx0, decay_mask);
|
| | cb(decay_mask, "decay_mask", il);
|
| |
|
| |
|
| | kb = ggml_mul_mat(ctx0, k, k_b);
|
| | kb = ggml_mul (ctx0, kb, decay_mask);
|
| |
|
| |
|
| | kq = ggml_mul_mat(ctx0, k, q);
|
| | kq = ggml_mul(ctx0, kq, decay_mask);
|
| | }
|
| |
|
| | kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
|
| | cb(kq, "kq", il);
|
| |
|
| |
|
| | ggml_tensor * attn;
|
| | attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
|
| | cb(attn, "attn", il);
|
| |
|
| | ggml_tensor * identity;
|
| | identity = ggml_view_1d(ctx0, attn, CS, 0);
|
| | identity = ggml_fill (ctx0, identity, 1.0f);
|
| | identity = ggml_diag (ctx0, identity);
|
| |
|
| | ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
|
| | cb(lhs, "dnet_add_ch_lhs", il);
|
| |
|
| | attn = ggml_neg(ctx0, attn);
|
| | cb(attn, "attn_pre_solve", il);
|
| |
|
| | ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
|
| | attn = ggml_add(ctx0, lin_solve, identity);
|
| | cb(attn, "dnet_add_ch_attn_solved", il);
|
| |
|
| |
|
| | v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
|
| |
|
| |
|
| | ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
|
| |
|
| | k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
|
| |
|
| |
|
| | ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
|
| | cb(kbg, "k_beta_g_exp", il);
|
| |
|
| |
|
| | ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
|
| | cb(k_cd, "k_cumdecay", il);
|
| |
|
| |
|
| | ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
|
| | ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
|
| | g_cs->nb[1],
|
| | g_cs->nb[2],
|
| | g_cs->nb[3],
|
| | ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
|
| | cb(g_last, "g_last", il);
|
| |
|
| |
|
| | g_last = ggml_cont(ctx0, g_last);
|
| |
|
| |
|
| | ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
|
| | cb(g_last_exp_t, "g_last_exp_t", il);
|
| |
|
| |
|
| | ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
|
| | cb(g_diff, "g_diff", il);
|
| |
|
| | ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));
|
| |
|
| |
|
| | ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
|
| | cb(kg, "key_gdiff", il);
|
| |
|
| |
|
| | ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
|
| | cb(kg_t, "key_gdiff_t", il);
|
| |
|
| | ggml_tensor * s_t = ggml_transpose(ctx0, s);
|
| | s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
|
| | cb(s_t, "dnet_add_ch_state", il);
|
| |
|
| |
|
| | ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
|
| |
|
| | for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
|
| | ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk);
|
| | ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk);
|
| | ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk);
|
| | ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk);
|
| | ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk);
|
| |
|
| |
|
| | ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
|
| | cb(v_t_p, "v_prime", il);
|
| |
|
| |
|
| | ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
|
| | cb(v_t_new, "v_t_new", il);
|
| |
|
| |
|
| | ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
|
| | cb(v_attn, "v_attn", il);
|
| |
|
| |
|
| | ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
|
| | cb(attn_inter, "attn_inter", il);
|
| |
|
| |
|
| | ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
|
| | cb(o_ch, "dnet_add_ch_attn_out", il);
|
| |
|
| | v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
|
| |
|
| |
|
| |
|
| | ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new);
|
| |
|
| |
|
| | ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
|
| |
|
| | s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
|
| | s_t = ggml_add(ctx0, s_t, kgv);
|
| | cb(s_t, "dnet_add_ch_state", il);
|
| | }
|
| |
|
| | s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
|
| |
|
| |
|
| | ggml_tensor * o = ggml_view_4d(ctx0, v,
|
| | S_v, n_tokens, H_v, n_seqs,
|
| | ggml_row_size(v->type, S_v),
|
| | ggml_row_size(v->type, S_v * CS * n_chunks),
|
| | ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
|
| | o = ggml_permute (ctx0, o, 0, 2, 1, 3);
|
| | s = ggml_transpose(ctx0, s_t);
|
| | cb(s, "output_state", il);
|
| |
|
| | return {o, s};
|
| | }
|
| |
|
| | std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
|
| | ggml_tensor * q,
|
| | ggml_tensor * k,
|
| | ggml_tensor * v,
|
| | ggml_tensor * g,
|
| | ggml_tensor * b,
|
| | ggml_tensor * s,
|
| | int il) {
|
| | const int64_t S_k = q->ne[0];
|
| | const int64_t H_k = q->ne[1];
|
| | const int64_t n_tokens = q->ne[2];
|
| | const int64_t n_seqs = q->ne[3];
|
| |
|
| | const int64_t S_v = v->ne[0];
|
| | const int64_t H_v = v->ne[1];
|
| |
|
| | GGML_ASSERT(n_tokens == 1);
|
| |
|
| | GGML_ASSERT(S_k == S_v);
|
| | GGML_ASSERT(H_v % H_k == 0);
|
| |
|
| | GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
| | GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
| | GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
|
| |
|
| | GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
|
| | GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
|
| | GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
| | GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
| |
|
| | const float scale = 1.0f / sqrtf(S_k);
|
| |
|
| | q = ggml_scale(ctx0, q, scale);
|
| |
|
| | q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
| | k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
| | v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
| |
|
| | cb(q, "q_in", il);
|
| | cb(k, "k_in", il);
|
| | cb(v, "v_in", il);
|
| | cb(b, "b_in", il);
|
| | cb(g, "g_in", il);
|
| |
|
| |
|
| |
|
| | g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
|
| | b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
|
| |
|
| |
|
| | g = ggml_exp(ctx0, g);
|
| | s = ggml_mul(ctx0, s, g);
|
| |
|
| | ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
|
| |
|
| |
|
| | ggml_tensor * sk;
|
| | sk = ggml_mul (ctx0, s_t, k);
|
| | sk = ggml_sum_rows(ctx0, sk);
|
| |
|
| |
|
| | ggml_tensor * d;
|
| | d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
|
| | d = ggml_mul(ctx0, d, b);
|
| |
|
| |
|
| | ggml_tensor * d_t;
|
| | d_t = ggml_transpose(ctx0, d);
|
| |
|
| |
|
| | ggml_tensor * kd;
|
| | k = ggml_repeat(ctx0, k, s);
|
| | kd = ggml_mul (ctx0, k, d_t);
|
| |
|
| | s_t = ggml_add(ctx0, s_t, kd);
|
| |
|
| | cb(s_t, "dnet_add_ar_state", il);
|
| |
|
| | ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
|
| | ggml_tensor * o = ggml_sum_rows(ctx0, s_q);
|
| |
|
| | o = ggml_permute (ctx0, o, 2, 0, 1, 3);
|
| | s = ggml_transpose(ctx0, s_t);
|
| |
|
| | return {o, s};
|
| | }
|
| |
|