|
|
#pragma once |
|
|
|
|
|
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>
|
|
|
|
|
struct llama_cparams; |
|
|
struct llama_hparams; |
|
|
struct llama_model; |
|
|
struct llama_context; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//
// llama_kv_cache_unified
//
// KV cache implementation of the llama_memory_i interface.
// Owns per-layer K/V tensors (see kv_layer below) organized into one or more
// streams, with per-stream cell bookkeeping (llama_kv_cells_unified) and
// sequence->stream routing. Supports position shifting, defragmentation and
// state (de)serialization.
//
class llama_kv_cache_unified : public llama_memory_i {
public:
    // required cell-count padding for the given context params
    static uint32_t get_padding(const llama_cparams & cparams);

    // construction-time filter: return false to skip allocating KV storage for layer il
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // plan produced by defrag_prepare() and consumed by update()
    struct defrag_info {
        // true when there is nothing to defragment
        bool empty() const {
            return ids.empty();
        }

        // cell move plan; empty means no defragmentation is needed
        // NOTE(review): the exact src->dst encoding is defined by defrag_prepare() in the .cpp - confirm there
        std::vector<uint32_t> ids;
    };

    // pending stream-to-stream copies, consumed by update()
    struct stream_copy_info {
        bool empty() const {
            // ssrc/sdst are parallel arrays: entry i copies stream ssrc[i] -> sdst[i]
            assert(ssrc.size() == sdst.size());
            return ssrc.empty();
        }

        std::vector<uint32_t> ssrc; // source stream ids
        std::vector<uint32_t> sdst; // destination stream ids (parallel to ssrc)
    };

    // describes where a ubatch is placed in the cache, as computed by find_slot()
    struct slot_info {
        // list of cell indices within a single stream
        using idx_vec_t = std::vector<uint32_t>;

        // NOTE(review): appears to be the sequence-id range covered by this slot - confirm in the .cpp
        llama_seq_id s0;
        llama_seq_id s1;

        std::vector<llama_seq_id> strm; // stream id per entry
        std::vector<idx_vec_t>    idxs; // cell indices per stream (parallel to strm)

        // index of the first cell - only meaningful for single-stream slots
        uint32_t head() const {
            GGML_ASSERT(idxs.size() == 1);
            GGML_ASSERT(!idxs[0].empty());

            return idxs[0][0];
        }

        // set the number of streams tracked by this slot
        void resize(size_t n) {
            strm.resize(n);
            idxs.resize(n);
        }

        // number of cells per stream
        // NOTE(review): returns the count for stream 0 only - assumes all streams hold equal counts
        size_t size() const {
            GGML_ASSERT(idxs.size() == strm.size());
            GGML_ASSERT(!idxs.empty());

            return idxs[0].size();
        }

        size_t n_stream() const {
            return strm.size();
        }

        bool empty() const {
            return idxs.empty();
        }

        void clear() {
            idxs.clear();
        }
    };

    // one slot_info per ubatch
    using slot_info_vec_t = std::vector<slot_info>;

    llama_kv_cache_unified(
            const llama_model & model,
             layer_filter_cb && filter,    // which layers get KV storage (see layer_filter_cb)
                    ggml_type   type_k,    // element type for the K tensors
                    ggml_type   type_v,    // element type for the V tensors
                         bool   v_trans,   // store V transposed (see the v_trans member)
                         bool   offload,   // place buffers on the accelerator backend - TODO confirm in .cpp
                         bool   unified,   // single shared stream vs one stream per sequence - TODO confirm in .cpp
                     uint32_t   kv_size,   // number of cells
                     uint32_t   n_seq_max, // max number of distinct sequence ids
                     uint32_t   n_pad,     // cell-count padding (see get_padding)
                     uint32_t   n_swa,     // sliding-window size (0 = disabled, see n_swa member)
               llama_swa_type   swa_type); // sliding-window attention variant

    ~llama_kv_cache_unified() = default;

    //
    // llama_memory_i interface
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state save/load (seq_id == -1 means all sequences)
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;

    //
    // llama_kv_cache_unified specific API
    //

    uint32_t get_size() const;
    uint32_t get_n_stream() const;

    bool get_has_shift() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv() const;

    // whether the backend supports the ggml set-rows path - TODO confirm semantics in .cpp
    bool get_supports_set_rows() const;

    // views of the K/V tensors for layer il, restricted to the given slot
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

    // build graph ops that store k_cur/v_cur into the cache at the cells selected by k_idxs/v_idxs
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

    // find placements for a sequence of ubatches; one slot_info per ubatch
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    // apply pending shift/defrag/stream-copy work; returns whether anything was updated
    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);

    // find a placement for ubatch; cont requests a contiguous slot - TODO confirm in .cpp
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    // commit the ubatch into the cells described by sinfo
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

    //
    // input tensors: construction and filling
    //

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    const llama_model & model;
    const llama_hparams & hparams;

    // per-layer KV storage
    struct kv_layer {
        // model layer index this entry belongs to (layers may be filtered, see layer_filter_cb)
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;

        // per-stream views into k/v - TODO confirm view vs separate allocation in .cpp
        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;
    };

    // whether V is stored transposed (set from the constructor argument)
    bool v_trans = true;

    const uint32_t n_seq_max = 1;
    const uint32_t n_stream  = 1;

    // cell-count padding (see get_padding)
    const uint32_t n_pad = 1;

    // sliding-window size; 0 = disabled
    const uint32_t n_swa = 0;

    // debug verbosity level - TODO confirm how it is set (env var?) in the .cpp
    int debug = 0;

    // cached result of the set-rows capability check (see get_supports_set_rows)
    bool supports_set_rows = false;

    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    // ggml contexts/buffers owning the tensor data
    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    // per-stream head positions - TODO confirm exact meaning in .cpp
    std::vector<uint32_t> v_heads;

    // per-stream cell metadata
    std::vector<llama_kv_cells_unified> v_cells;

    // maps a sequence id to the stream that stores it
    std::vector<uint32_t> seq_to_stream;

    // pending stream copies accumulated between updates
    stream_copy_info sc_info;

    std::vector<kv_layer> layers;

    // model layer index -> index into `layers` (needed because layers may be filtered)
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    // compute a defragmentation plan; n_max_nodes bounds the resulting graph size - TODO confirm
    defrag_info defrag_prepare(int32_t n_max_nodes) const;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    // whether the position pair (p0, p1) falls outside the sliding window
    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    // build the RoPE-based position-shift op for one layer's cache tensor
    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_tensor * cur,
                    ggml_tensor * shift,
                    ggml_tensor * factors,
                          float   freq_base,
                          float   freq_scale) const;

    ggml_cgraph * build_graph_shift(
            llm_graph_result * res,
               llama_context * lctx) const;

    ggml_cgraph * build_graph_defrag(
             llm_graph_result * res,
                llama_context * lctx,
            const defrag_info & dinfo) const;

    // contiguous cell ranges within one stream, used for serialization
    struct cell_ranges_t {
        uint32_t strm;

        // [begin, end) pairs - TODO confirm inclusivity in .cpp
        std::vector<std::pair<uint32_t, uint32_t>> data;
    };

    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};
|
|
|
|
|
//
// llama_kv_cache_unified_context
//
// Memory context bound to a llama_kv_cache_unified instance. Holds the
// ubatches being processed plus their slot placements, and forwards the
// graph-input getters/setters to the cache. Iterated via next()/apply().
//
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
    // re-export the cache's helper types
    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
    using defrag_info      = llama_kv_cache_unified::defrag_info;
    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;

    // error-state context carrying only a status (e.g. failure during init)
    llama_kv_cache_unified_context(llama_memory_status status);

    // full-cache context (used by llama_kv_cache_unified::init_full)
    llama_kv_cache_unified_context(
            llama_kv_cache_unified * kv);

    // update context: carries pending shift/defrag/stream-copy work
    // (used by llama_kv_cache_unified::init_update)
    llama_kv_cache_unified_context(
            llama_kv_cache_unified * kv,
            llama_context * lctx,
            bool do_shift,
            defrag_info dinfo,
            stream_copy_info sc_info);

    // batch-processing context: one slot placement per ubatch
    // (used by llama_kv_cache_unified::init_batch)
    llama_kv_cache_unified_context(
            llama_kv_cache_unified * kv,
            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_context();

    //
    // llama_memory_context_i interface
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_context specific API
    //

    uint32_t get_n_kv() const;

    // forwarded from llama_kv_cache_unified::get_supports_set_rows
    bool get_supports_set_rows() const;

    // K/V tensor views for layer il (slot info supplied implicitly from the current ubatch)
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur/v_cur into the cache at the cells selected by k_idxs/v_idxs
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift  (ggml_tensor * dst) const;
    void set_input_kq_mask  (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    llama_memory_status status;

    // non-owning; may be null for the error-state constructor - TODO confirm in .cpp
    llama_kv_cache_unified * kv;
    llama_context * lctx;

    //
    // update-context state
    //

    bool do_shift = false;

    defrag_info dinfo;

    stream_copy_info sc_info;

    //
    // batch-processing state
    //

    // index of the ubatch currently being processed (advanced by next())
    size_t i_cur = 0;

    slot_info_vec_t sinfos; // one placement per ubatch

    std::vector<llama_ubatch> ubatches;

    //
    // graph-build state
    //

    // number of KV cells visible to the current graph - TODO confirm when this is set in .cpp
    int32_t n_kv;
};
|
|
|