|
|
#pragma once |
|
|
|
|
|
#include "llama.h" |
|
|
#include "llama-cparams.h" |
|
|
|
|
|
#include <bitset> |
|
|
#include <cassert> |
|
|
#include <vector> |
|
|
#include <set> |
|
|
#include <map> |
|
|
|
|
|
|
|
|
|
|
|
class llama_kv_cells_unified { |
|
|
public: |
|
|
void reset() { |
|
|
for (uint32_t i = 0; i < pos.size(); ++i) { |
|
|
pos[i] = -1; |
|
|
shift[i] = 0; |
|
|
seq[i].reset(); |
|
|
} |
|
|
|
|
|
has_shift = false; |
|
|
|
|
|
used.clear(); |
|
|
|
|
|
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { |
|
|
seq_pos[s].clear(); |
|
|
} |
|
|
} |
|
|
|
|
|
void reset_shift() { |
|
|
has_shift = false; |
|
|
|
|
|
for (uint32_t i = 0; i < shift.size(); ++i) { |
|
|
shift[i] = 0; |
|
|
} |
|
|
} |
|
|
|
|
|
uint32_t size() const { |
|
|
return pos.size(); |
|
|
} |
|
|
|
|
|
void resize(uint32_t n) { |
|
|
pos.resize(n); |
|
|
shift.resize(n); |
|
|
seq.resize(n); |
|
|
|
|
|
reset(); |
|
|
} |
|
|
|
|
|
bool is_empty(uint32_t i) const { |
|
|
assert(i < pos.size()); |
|
|
assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); |
|
|
|
|
|
return pos[i] == -1; |
|
|
} |
|
|
|
|
|
uint32_t get_used() const { |
|
|
return used.size(); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
uint32_t used_min() const { |
|
|
return used.empty() ? 0 : *used.begin(); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
uint32_t used_max_p1() const { |
|
|
return used.empty() ? 0 : *used.rbegin() + 1; |
|
|
} |
|
|
|
|
|
bool get_has_shift() const { |
|
|
return has_shift; |
|
|
} |
|
|
|
|
|
|
|
|
void mv(uint32_t isrc, uint32_t idst) { |
|
|
assert(isrc < pos.size()); |
|
|
assert(idst < pos.size()); |
|
|
|
|
|
assert(pos[idst] == -1); |
|
|
assert(pos[isrc] != -1); |
|
|
|
|
|
pos [idst] = pos [isrc]; |
|
|
shift[idst] = shift[isrc]; |
|
|
seq [idst] = seq [isrc]; |
|
|
|
|
|
pos [isrc] = -1; |
|
|
shift[isrc] = 0; |
|
|
seq [isrc].reset(); |
|
|
|
|
|
used.erase (isrc); |
|
|
used.insert(idst); |
|
|
} |
|
|
|
|
|
|
|
|
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { |
|
|
assert(i + n <= pos.size()); |
|
|
|
|
|
llama_kv_cells_unified res; |
|
|
|
|
|
res.resize(n); |
|
|
|
|
|
for (uint32_t j = 0; j < n; ++j) { |
|
|
const auto idx = i + j; |
|
|
|
|
|
res.pos[j] = pos[idx]; |
|
|
res.seq[j] = seq[idx]; |
|
|
|
|
|
assert(shift[idx] == 0); |
|
|
} |
|
|
|
|
|
return res; |
|
|
} |
|
|
|
|
|
|
|
|
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const { |
|
|
llama_kv_cells_unified res; |
|
|
|
|
|
res.resize(idxs.size()); |
|
|
|
|
|
for (uint32_t j = 0; j < idxs.size(); ++j) { |
|
|
const auto idx = idxs[j]; |
|
|
|
|
|
res.pos[j] = pos[idx]; |
|
|
res.seq[j] = seq[idx]; |
|
|
|
|
|
assert(shift[idx] == 0); |
|
|
} |
|
|
|
|
|
return res; |
|
|
} |
|
|
|
|
|
|
|
|
void set(uint32_t i, const llama_kv_cells_unified & other) { |
|
|
assert(i + other.pos.size() <= pos.size()); |
|
|
|
|
|
for (uint32_t j = 0; j < other.pos.size(); ++j) { |
|
|
const auto idx = i + j; |
|
|
|
|
|
if (pos[idx] == -1 && other.pos[j] != -1) { |
|
|
used.insert(i + j); |
|
|
} |
|
|
|
|
|
if (pos[idx] != -1 && other.pos[j] == -1) { |
|
|
used.erase(i + j); |
|
|
} |
|
|
|
|
|
if (pos[idx] != -1) { |
|
|
seq_pos_rm(i + j); |
|
|
} |
|
|
|
|
|
pos[idx] = other.pos[j]; |
|
|
seq[idx] = other.seq[j]; |
|
|
|
|
|
if (pos[idx] != -1) { |
|
|
seq_pos_add(i + j); |
|
|
} |
|
|
|
|
|
assert(shift[idx] == 0); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) { |
|
|
assert(idxs.size() == other.pos.size()); |
|
|
|
|
|
for (uint32_t j = 0; j < other.pos.size(); ++j) { |
|
|
const auto idx = idxs[j]; |
|
|
|
|
|
if (pos[idx] == -1 && other.pos[j] != -1) { |
|
|
used.insert(idx); |
|
|
} |
|
|
|
|
|
if (pos[idx] != -1 && other.pos[j] == -1) { |
|
|
used.erase(idx); |
|
|
} |
|
|
|
|
|
if (pos[idx] != -1) { |
|
|
seq_pos_rm(idx); |
|
|
} |
|
|
|
|
|
pos[idx] = other.pos[j]; |
|
|
seq[idx] = other.seq[j]; |
|
|
|
|
|
if (pos[idx] != -1) { |
|
|
seq_pos_add(idx); |
|
|
} |
|
|
|
|
|
assert(shift[idx] == 0); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void rm(uint32_t i) { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] != -1); |
|
|
|
|
|
seq_pos_rm(i); |
|
|
seq[i].reset(); |
|
|
|
|
|
pos[i] = -1; |
|
|
shift[i] = 0; |
|
|
|
|
|
used.erase(i); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool seq_rm(uint32_t i, llama_seq_id seq_id) { |
|
|
assert(i < pos.size()); |
|
|
assert(seq[i].test(seq_id)); |
|
|
assert(pos[i] != -1); |
|
|
assert(seq_id >= 0); |
|
|
|
|
|
seq[i].reset(seq_id); |
|
|
seq_pos_dec(seq_id, pos[i]); |
|
|
|
|
|
if (seq[i].none()) { |
|
|
pos[i] = -1; |
|
|
shift[i] = 0; |
|
|
|
|
|
used.erase(i); |
|
|
|
|
|
return true; |
|
|
} |
|
|
|
|
|
return false; |
|
|
} |
|
|
|
|
|
|
|
|
bool seq_keep(uint32_t i, llama_seq_id seq_id) { |
|
|
assert(i < pos.size()); |
|
|
|
|
|
if (seq[i].test(seq_id)) { |
|
|
seq_pos_rm(i); |
|
|
seq[i].reset(); |
|
|
|
|
|
seq[i].set(seq_id); |
|
|
seq_pos_inc(seq_id, pos[i]); |
|
|
|
|
|
return false; |
|
|
} |
|
|
|
|
|
if (seq[i].any()) { |
|
|
seq_pos_rm(i); |
|
|
seq[i].reset(); |
|
|
|
|
|
pos[i] = -1; |
|
|
shift[i] = 0; |
|
|
|
|
|
used.erase(i); |
|
|
|
|
|
return true; |
|
|
} |
|
|
|
|
|
assert(pos[i] == -1); |
|
|
|
|
|
return false; |
|
|
} |
|
|
|
|
|
|
|
|
int seq_count(uint32_t i) const { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] != -1); |
|
|
|
|
|
return seq[i].count(); |
|
|
} |
|
|
|
|
|
|
|
|
bool seq_has(uint32_t i, llama_seq_id seq_id) const { |
|
|
assert(i < pos.size()); |
|
|
assert(seq_id >= 0); |
|
|
|
|
|
return seq[i].test(seq_id); |
|
|
} |
|
|
|
|
|
|
|
|
void seq_add(uint32_t i, llama_seq_id seq_id) { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] != -1); |
|
|
assert(!seq[i].test(seq_id)); |
|
|
|
|
|
seq[i].set(seq_id); |
|
|
seq_pos_inc(seq_id, pos[i]); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
llama_seq_id seq_get(uint32_t i) const { |
|
|
assert(seq[i].count() == 1); |
|
|
|
|
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
|
|
if (seq[i].test(s)) { |
|
|
return s; |
|
|
} |
|
|
} |
|
|
|
|
|
return -1; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
llama_pos seq_pos_min(llama_seq_id seq_id) const { |
|
|
assert(seq_id >= 0); |
|
|
assert(seq_id < LLAMA_MAX_SEQ); |
|
|
|
|
|
if (seq_pos[seq_id].empty()) { |
|
|
return -1; |
|
|
} |
|
|
|
|
|
assert(seq_pos[seq_id].begin()->second > 0); |
|
|
|
|
|
return seq_pos[seq_id].begin()->first; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
llama_pos seq_pos_max(llama_seq_id seq_id) const { |
|
|
assert(seq_id >= 0); |
|
|
assert(seq_id < LLAMA_MAX_SEQ); |
|
|
|
|
|
if (seq_pos[seq_id].empty()) { |
|
|
return -1; |
|
|
} |
|
|
|
|
|
assert(seq_pos[seq_id].rbegin()->second > 0); |
|
|
|
|
|
return seq_pos[seq_id].rbegin()->first; |
|
|
} |
|
|
|
|
|
|
|
|
llama_pos pos_get(uint32_t i) const { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] != -1); |
|
|
|
|
|
return pos[i]; |
|
|
} |
|
|
|
|
|
|
|
|
llama_pos get_shift(uint32_t i) const { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] != -1); |
|
|
|
|
|
return shift[i]; |
|
|
} |
|
|
|
|
|
|
|
|
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { |
|
|
assert(i < pos.size()); |
|
|
|
|
|
return pos[i] >= p0 && pos[i] < p1; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void pos_set(uint32_t i, llama_pos p) { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] == -1); |
|
|
assert(seq[i].none()); |
|
|
|
|
|
pos[i] = p; |
|
|
|
|
|
used.insert(i); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool pos_add(uint32_t i, llama_pos d) { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] != -1); |
|
|
|
|
|
seq_pos_rm(i); |
|
|
|
|
|
pos[i] += d; |
|
|
shift[i] += d; |
|
|
|
|
|
has_shift = true; |
|
|
|
|
|
if (pos[i] < 0) { |
|
|
seq[i].reset(); |
|
|
pos[i] = -1; |
|
|
shift[i] = 0; |
|
|
|
|
|
used.erase(i); |
|
|
|
|
|
return true; |
|
|
} |
|
|
|
|
|
seq_pos_add(i); |
|
|
|
|
|
return false; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void pos_div(uint32_t i, int d) { |
|
|
assert(i < pos.size()); |
|
|
assert(pos[i] != -1); |
|
|
|
|
|
const llama_pos p_old = pos[i]; |
|
|
|
|
|
seq_pos_rm(i); |
|
|
|
|
|
pos[i] /= d; |
|
|
shift[i] += p_old - pos[i]; |
|
|
|
|
|
seq_pos_add(i); |
|
|
|
|
|
has_shift = true; |
|
|
} |
|
|
|
|
|
private: |
|
|
bool has_shift = false; |
|
|
|
|
|
|
|
|
std::set<uint32_t> used; |
|
|
|
|
|
std::vector<llama_pos> pos; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<llama_pos> shift; |
|
|
|
|
|
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>; |
|
|
|
|
|
|
|
|
std::vector<seq_set_t> seq; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ]; |
|
|
|
|
|
|
|
|
|
|
|
void seq_pos_dec(llama_seq_id s, llama_pos p) { |
|
|
auto it = seq_pos[s].find(p); |
|
|
assert(it != seq_pos[s].end()); |
|
|
|
|
|
if (--it->second == 0) { |
|
|
seq_pos[s].erase(it); |
|
|
} |
|
|
} |
|
|
|
|
|
void seq_pos_inc(llama_seq_id s, llama_pos p) { |
|
|
seq_pos[s][p]++; |
|
|
} |
|
|
|
|
|
|
|
|
void seq_pos_rm(uint32_t i) { |
|
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
|
|
if (seq[i].test(s)) { |
|
|
seq_pos_dec(s, pos[i]); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void seq_pos_add(uint32_t i) { |
|
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
|
|
if (seq[i].test(s)) { |
|
|
seq_pos_inc(s, pos[i]); |
|
|
} |
|
|
} |
|
|
} |
|
|
}; |
|
|
|