|
|
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <set>
#include <utility>
#include <vector>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class llama_memory_recurrent : public llama_memory_i { |
|
|
public: |
|
|
|
|
|
|
|
|
using layer_filter_cb = std::function<bool(int32_t il)>; |
|
|
|
|
|
llama_memory_recurrent( |
|
|
const llama_model & model, |
|
|
layer_filter_cb && filter, |
|
|
ggml_type type_r, |
|
|
ggml_type type_s, |
|
|
bool offload, |
|
|
uint32_t mem_size, |
|
|
uint32_t n_seq_max); |
|
|
|
|
|
~llama_memory_recurrent() = default; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
llama_memory_context_ptr init_batch( |
|
|
llama_batch_allocr & balloc, |
|
|
uint32_t n_ubatch, |
|
|
bool embd_all) override; |
|
|
|
|
|
llama_memory_context_ptr init_full() override; |
|
|
|
|
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; |
|
|
|
|
|
void clear(bool data) override; |
|
|
|
|
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; |
|
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; |
|
|
void seq_keep(llama_seq_id seq_id) override; |
|
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; |
|
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; |
|
|
|
|
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override; |
|
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override; |
|
|
|
|
|
bool prepare(const std::vector<llama_ubatch> & ubatches); |
|
|
|
|
|
|
|
|
bool find_slot(const llama_ubatch & ubatch); |
|
|
|
|
|
bool get_can_shift() const override; |
|
|
|
|
|
|
|
|
|
|
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; |
|
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; |
|
|
|
|
|
uint32_t head = 0; |
|
|
uint32_t size = 0; |
|
|
uint32_t used = 0; |
|
|
|
|
|
|
|
|
uint32_t n = 0; |
|
|
|
|
|
|
|
|
int32_t rs_z = -1; |
|
|
|
|
|
|
|
|
struct mem_cell { |
|
|
llama_pos pos = -1; |
|
|
int32_t src = -1; |
|
|
int32_t src0 = -1; |
|
|
int32_t tail = -1; |
|
|
|
|
|
std::set<llama_seq_id> seq_id; |
|
|
|
|
|
bool has_seq_id(const llama_seq_id & id) const { |
|
|
return seq_id.find(id) != seq_id.end(); |
|
|
} |
|
|
|
|
|
bool is_empty() const { |
|
|
return seq_id.empty(); |
|
|
} |
|
|
|
|
|
bool is_same_seq(const mem_cell & other) const { |
|
|
return seq_id == other.seq_id; |
|
|
} |
|
|
}; |
|
|
|
|
|
std::vector<mem_cell> cells; |
|
|
|
|
|
|
|
|
std::vector<ggml_tensor *> r_l; |
|
|
std::vector<ggml_tensor *> s_l; |
|
|
|
|
|
private: |
|
|
|
|
|
const llama_hparams & hparams; |
|
|
|
|
|
const uint32_t n_seq_max = 1; |
|
|
|
|
|
std::vector<ggml_context_ptr> ctxs; |
|
|
std::vector<ggml_backend_buffer_ptr> bufs; |
|
|
|
|
|
size_t total_size() const; |
|
|
|
|
|
size_t size_r_bytes() const; |
|
|
size_t size_s_bytes() const; |
|
|
|
|
|
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const; |
|
|
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const; |
|
|
|
|
|
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); |
|
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); |
|
|
}; |
|
|
|
|
|
class llama_memory_recurrent_context : public llama_memory_context_i { |
|
|
public: |
|
|
|
|
|
llama_memory_recurrent_context(llama_memory_status status); |
|
|
|
|
|
|
|
|
llama_memory_recurrent_context( |
|
|
llama_memory_recurrent * mem); |
|
|
|
|
|
|
|
|
llama_memory_recurrent_context( |
|
|
llama_memory_recurrent * mem, |
|
|
std::vector<llama_ubatch> ubatches); |
|
|
|
|
|
virtual ~llama_memory_recurrent_context(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool next() override; |
|
|
bool apply() override; |
|
|
|
|
|
llama_memory_status get_status() const override; |
|
|
const llama_ubatch & get_ubatch() const override; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uint32_t get_n_rs() const; |
|
|
uint32_t get_head() const; |
|
|
int32_t get_rs_z() const; |
|
|
uint32_t get_size() const; |
|
|
|
|
|
ggml_tensor * get_r_l(int32_t il) const; |
|
|
ggml_tensor * get_s_l(int32_t il) const; |
|
|
|
|
|
int32_t s_copy(int i) const; |
|
|
|
|
|
private: |
|
|
const llama_memory_status status; |
|
|
|
|
|
llama_memory_recurrent * mem; |
|
|
|
|
|
size_t i_next = 0; |
|
|
|
|
|
std::vector<llama_ubatch> ubatches; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const bool is_full = false; |
|
|
}; |
|
|
|