|
|
#pragma once |
|
|
|
|
|
#include "llama.h" |
|
|
|
|
|
#include <memory> |
|
|
|
|
|
struct llama_ubatch; |
|
|
|
|
|
class llama_batch_allocr; |
|
|
|
|
|
class llama_io_write_i; |
|
|
class llama_io_read_i; |
|
|
|
|
|
struct llama_memory_params { |
|
|
|
|
|
ggml_type type_k; |
|
|
ggml_type type_v; |
|
|
|
|
|
|
|
|
bool swa_full; |
|
|
}; |
|
|
|
|
|
enum llama_memory_status { |
|
|
LLAMA_MEMORY_STATUS_SUCCESS = 0, |
|
|
LLAMA_MEMORY_STATUS_NO_UPDATE, |
|
|
LLAMA_MEMORY_STATUS_FAILED_PREPARE, |
|
|
LLAMA_MEMORY_STATUS_FAILED_COMPUTE, |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); |
|
|
|
|
|
|
|
|
bool llama_memory_status_is_fail(llama_memory_status status); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct llama_memory_context_i { |
|
|
virtual ~llama_memory_context_i() = default; |
|
|
|
|
|
|
|
|
|
|
|
virtual bool next() = 0; |
|
|
|
|
|
|
|
|
|
|
|
virtual bool apply() = 0; |
|
|
|
|
|
|
|
|
virtual const llama_ubatch & get_ubatch() const = 0; |
|
|
|
|
|
|
|
|
virtual llama_memory_status get_status() const = 0; |
|
|
}; |
|
|
|
|
|
using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>; |
|
|
|
|
|
|
|
|
|
|
|
struct llama_memory_i { |
|
|
virtual ~llama_memory_i() = default; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual llama_memory_context_ptr init_batch( |
|
|
llama_batch_allocr & balloc, |
|
|
uint32_t n_ubatch, |
|
|
bool embd_all) = 0; |
|
|
|
|
|
|
|
|
virtual llama_memory_context_ptr init_full() = 0; |
|
|
|
|
|
|
|
|
|
|
|
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0; |
|
|
|
|
|
|
|
|
virtual bool get_can_shift() const = 0; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual void clear(bool data) = 0; |
|
|
|
|
|
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; |
|
|
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; |
|
|
virtual void seq_keep(llama_seq_id seq_id) = 0; |
|
|
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; |
|
|
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; |
|
|
|
|
|
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; |
|
|
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; |
|
|
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; |
|
|
}; |
|
|
|
|
|
using llama_memory_ptr = std::unique_ptr<llama_memory_i>; |
|
|
|
|
|
|
|
|
struct llama_kv_cache : public llama_memory_i { |
|
|
virtual ~llama_kv_cache() = default; |
|
|
}; |
|
|
|