|
|
#pragma once |
|
|
|
|
|
#include "llama.h" |
|
|
|
|
|
#include "llama-cparams.h" |
|
|
|
|
|
#include <array> |
|
|
#include <vector> |
|
|
#include <set> |
|
|
#include <bitset> |
|
|
#include <memory> |
|
|
#include <unordered_map> |
|
|
|
|
|
|
|
|
struct llama_ubatch { |
|
|
bool equal_seqs() const { |
|
|
return b_equal_seqs != 0; |
|
|
} |
|
|
|
|
|
uint32_t b_equal_seqs; |
|
|
|
|
|
|
|
|
|
|
|
uint32_t n_tokens; |
|
|
uint32_t n_seq_tokens; |
|
|
uint32_t n_seqs; |
|
|
uint32_t n_seqs_unq; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
llama_token * token; |
|
|
float * embd; |
|
|
llama_pos * pos; |
|
|
int32_t * n_seq_id; |
|
|
llama_seq_id ** seq_id; |
|
|
llama_seq_id * seq_id_unq; |
|
|
int32_t * seq_idx; |
|
|
int8_t * output; |
|
|
|
|
|
struct data_t { |
|
|
std::vector<llama_token> token; |
|
|
std::vector<float> embd; |
|
|
std::vector<llama_pos> pos; |
|
|
std::vector<int32_t> n_seq_id; |
|
|
std::vector<llama_seq_id *> seq_id; |
|
|
std::vector<llama_seq_id> seq_id_unq; |
|
|
std::vector<int32_t> seq_idx; |
|
|
std::vector<int8_t> output; |
|
|
}; |
|
|
|
|
|
|
|
|
std::shared_ptr<data_t> data; |
|
|
}; |
|
|
|
|
|
|
|
|
class llama_batch_allocr { |
|
|
public: |
|
|
llama_batch_allocr(uint32_t n_pos_per_embd); |
|
|
|
|
|
|
|
|
|
|
|
bool init( |
|
|
const llama_batch & batch_inp, |
|
|
const llama_vocab & vocab, |
|
|
const llama_memory_i * memory, |
|
|
uint32_t n_embd, |
|
|
uint32_t n_seq_max, |
|
|
bool output_all); |
|
|
|
|
|
const llama_batch & get_batch() const; |
|
|
|
|
|
uint32_t get_n_tokens() const; |
|
|
uint32_t get_n_outputs() const; |
|
|
uint32_t get_n_used() const; |
|
|
|
|
|
|
|
|
std::vector<int32_t> & get_out_ids(); |
|
|
|
|
|
|
|
|
llama_pos seq_pos_min(llama_seq_id seq_id) const; |
|
|
llama_pos seq_pos_max(llama_seq_id seq_id) const; |
|
|
|
|
|
|
|
|
void split_reset(); |
|
|
|
|
|
|
|
|
llama_ubatch split_simple(uint32_t n_ubatch); |
|
|
|
|
|
|
|
|
|
|
|
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); |
|
|
|
|
|
|
|
|
llama_ubatch split_seq(uint32_t n_ubatch); |
|
|
|
|
|
|
|
|
|
|
|
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs); |
|
|
|
|
|
private: |
|
|
void clear(); |
|
|
|
|
|
|
|
|
|
|
|
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs); |
|
|
|
|
|
|
|
|
void ubatch_print(const llama_ubatch & ubatch, int debug); |
|
|
|
|
|
llama_batch batch; |
|
|
|
|
|
|
|
|
const llama_vocab * vocab; |
|
|
|
|
|
|
|
|
|
|
|
const uint32_t n_pos_per_embd; |
|
|
|
|
|
uint32_t n_embd; |
|
|
uint32_t n_seq_max; |
|
|
uint32_t n_outputs; |
|
|
|
|
|
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; |
|
|
|
|
|
std::vector<llama_pos> pos; |
|
|
std::vector<int32_t> n_seq_id; |
|
|
std::vector<llama_seq_id *> seq_id; |
|
|
std::vector<llama_seq_id> seq_id_unq; |
|
|
std::vector<int32_t> seq_idx; |
|
|
std::vector<int8_t> output; |
|
|
|
|
|
using pos_set_t = std::set<llama_pos>; |
|
|
using seq_cpl_t = std::vector<bool>; |
|
|
|
|
|
|
|
|
bool has_cpl = false; |
|
|
|
|
|
std::vector<pos_set_t> seq_pos; |
|
|
std::vector<seq_cpl_t> seq_cpl; |
|
|
|
|
|
using idx_vec_t = std::vector<int32_t>; |
|
|
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>; |
|
|
|
|
|
std::vector<seq_set_t> seq_set; |
|
|
|
|
|
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; |
|
|
|
|
|
|
|
|
std::vector<int32_t> out_ids; |
|
|
|
|
|
uint32_t n_used; |
|
|
|
|
|
|
|
|
std::vector<bool> used; |
|
|
|
|
|
int debug; |
|
|
}; |
|
|
|