Spaces:
Runtime error
Runtime error
Upload llama_cpp_python_streamingllm.py
Browse files- llama_cpp_python_streamingllm.py +543 -8
llama_cpp_python_streamingllm.py
CHANGED
|
@@ -1,14 +1,516 @@
|
|
| 1 |
-
from
|
| 2 |
-
|
| 3 |
-
from llama_cpp import
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
from KMP_list import kmp_search, compute_lps_array
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class StreamingLLM(Llama):
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
self._venv_init()
|
| 13 |
|
| 14 |
def str_detokenize(self, tokens) -> str:
|
|
@@ -63,8 +565,9 @@ class StreamingLLM(Llama):
|
|
| 63 |
if name not in self.venv_idx_map:
|
| 64 |
return False
|
| 65 |
venv_idx = self.venv_idx_map.index(name) + 1
|
|
|
|
| 66 |
while self.venv_idx_map:
|
| 67 |
-
if keep_last and
|
| 68 |
break # 保留最后n个
|
| 69 |
self.venv_idx_map.pop(venv_idx - 1) # 删除
|
| 70 |
if venv_idx == len(self.venv) - 1:
|
|
@@ -81,6 +584,7 @@ class StreamingLLM(Llama):
|
|
| 81 |
venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
|
| 82 |
except ValueError: # 没有了
|
| 83 |
break
|
|
|
|
| 84 |
return True
|
| 85 |
|
| 86 |
def venv_pop_token(self, n=1):
|
|
@@ -92,6 +596,36 @@ class StreamingLLM(Llama):
|
|
| 92 |
def venv_info(self):
|
| 93 |
return str((self.n_tokens, self.venv, self.venv_idx_map))
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
| 96 |
if n_keep < 0:
|
| 97 |
return
|
|
@@ -106,6 +640,7 @@ class StreamingLLM(Llama):
|
|
| 106 |
_idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
|
| 107 |
if _idx >= n_keep:
|
| 108 |
n_keep = _idx + len(im_start) # 至少保留一个 im_start 序列结构
|
|
|
|
| 109 |
self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
|
| 110 |
self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
|
| 111 |
self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
|
|
@@ -287,7 +822,7 @@ class StreamingLLM(Llama):
|
|
| 287 |
tokens = [token]
|
| 288 |
|
| 289 |
def load_session(self, filepath: str):
|
| 290 |
-
n_tokens = POINTER(
|
| 291 |
tokens = (llama_cpp.llama_token * self.n_ctx())()
|
| 292 |
retn = llama_cpp.llama_load_session_file(self._ctx.ctx,
|
| 293 |
filepath.encode('utf-8'),
|
|
|
|
| 1 |
+
from llama_cpp import *
|
| 2 |
+
from ctypes import POINTER, c_size_t
|
| 3 |
+
from llama_cpp._internals import (
|
| 4 |
+
_LlamaModel, # type: ignore
|
| 5 |
+
_LlamaContext, # type: ignore
|
| 6 |
+
_LlamaBatch, # type: ignore
|
| 7 |
+
_LlamaTokenDataArray, # type: ignore
|
| 8 |
+
)
|
| 9 |
|
| 10 |
from KMP_list import kmp_search, compute_lps_array
|
| 11 |
+
from Turbo_Colormap import map_value_to_color, NOCOLOR, LEGEND, BACK_WHITE
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LLMGenerate:
    """Drive token-by-token generation on a model with fixed settings.

    The constructor captures the eviction settings (``n_keep``, ``n_discard``,
    ``im_start``) and the sampling settings in two closures that forward to
    ``model.eval_t`` and ``model.sample_t``.  The instance additionally keeps
    a record of the produced output:

    * ``t_bot``: tokens whose text has been fully decoded,
    * ``completion_tokens``: tokens still waiting to decode to non-empty text
      (e.g. an incomplete multi-byte UTF-8 character),
    * ``history``: the decoded text accumulated so far,
    * ``token``: the most recently sampled token.
    """

    def __init__(
            self,
            model,
            n_keep,
            n_discard: int = 256,
            im_start=None,
            top_k: int = 40,
            top_p: float = 0.95,
            min_p: float = 0.05,
            typical_p: float = 1.0,
            temp: float = 0.80,
            repeat_penalty: float = 1.1,
            repeat_last_n: int = 64,
            frequency_penalty: float = 0.0,
            presence_penalty: float = 0.0,
            tfs_z: float = 1.0,
            mirostat_mode: int = 0,
            mirostat_tau: float = 5.0,
            mirostat_eta: float = 0.1
    ):
        """Bind ``model`` and the generation settings.

        Args:
            model: Object providing ``eval_t``, ``sample_t``,
                ``str_detokenize`` and ``venv_pop_token``.
            n_keep, n_discard, im_start: Forwarded unchanged to
                ``model.eval_t`` on every evaluation.
            top_k ... mirostat_eta: Forwarded unchanged to ``model.sample_t``
                on every sampling step.
        """
        def _eval_t(tokens):
            return model.eval_t(
                tokens=tokens,
                n_keep=n_keep,
                n_discard=n_discard,
                im_start=im_start
            )

        def _sample_t(logits_processor):
            return model.sample_t(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                repeat_last_n=repeat_last_n,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor
            )

        self._eval_t = _eval_t
        self._sample_t = _sample_t
        self.str_detokenize = model.str_detokenize
        self.venv_pop_token = model.venv_pop_token
        # ========== recorded output ==========
        self.t_bot = []
        self.completion_tokens = []
        self.history = ''
        self.token = None

    def eval_t(self, tokens):
        """Record ``tokens`` in the output buffers, then evaluate them.

        Returns whatever ``model.eval_t`` returns.
        """
        # ========== avoid emitting incomplete UTF-8 sequences ==========
        self.completion_tokens.extend(tokens)
        all_text = self.str_detokenize(self.completion_tokens)
        if all_text:
            # The pending tokens decode to text: commit them.
            self.t_bot.extend(self.completion_tokens)
            self.history += all_text
            self.completion_tokens = []
        return self._eval_t(tokens)

    def sample_t(self, logits_processor):
        """Sample the next token, remember it in ``self.token`` and return it."""
        self.token = self._sample_t(logits_processor)
        return self.token

    def detokenize_sample_t(self):
        """Append ``self.token`` to the pending tokens and try to decode them.

        Returns:
            ``False`` while the pending tokens decode to an empty string,
            ``True`` once they decode to text (which is then committed to
            ``t_bot`` / ``history``).
        """
        self.completion_tokens.append(self.token)
        all_text = self.str_detokenize(self.completion_tokens)
        if not all_text:
            return False
        self.t_bot.extend(self.completion_tokens)
        self.history += all_text
        self.completion_tokens = []
        return True

    def eval_sample_t(self):
        """Evaluate the most recently sampled token."""
        return self._eval_t([self.token])

    def endswith_t(self, token_list):
        """Return whether the most recently sampled token is in ``token_list``."""
        return self.token in token_list

    def endswith_s(self, start_func, str_list, com_func=str.rstrip):
        """Detect a stop string at the end of ``history`` and strip it.

        Args:
            start_func: Predicate called with the current history; the stop
                strings are only checked when it returns a truthy value.
            str_list: Candidate stop strings.
            com_func: Normalisation applied to the history and to every
                decoded token suffix before comparing (default
                ``str.rstrip``).

        Returns:
            ``True`` when a stop string was found.  In that case the tokens
            making up the matching suffix, except the last one, are popped
            through ``venv_pop_token`` and the matching tail is removed from
            ``self.history``.  ``False`` otherwise, including while tokens
            are still pending in ``completion_tokens``.
        """
        if self.completion_tokens:  # text is incomplete, nothing to compare yet
            return False

        history = self.history
        t_bot = self.t_bot

        if start_func(history):
            history = com_func(history)
            for x in str_list:
                if history.endswith(x):
                    n = len(t_bot)
                    # Find how many trailing tokens have to be discarded.
                    # The upper bound is n inclusive so that a stop string
                    # spanning every buffered token (e.g. a single-token
                    # output) is still detected.
                    for i in range(1, n + 1):
                        tmp = self.str_detokenize(t_bot[n - i:])
                        tmp = com_func(tmp)
                        if tmp.endswith(x):
                            if i > 1:  # the last token has not entered the kv_cache
                                self.venv_pop_token(i - 1)
                            # Guard against an empty tail: history[:-0]
                            # would wipe the whole history.
                            if tmp and history.endswith(tmp):
                                self.history = history[:-len(tmp)]  # drop the trailing tmp
                            return True
        return False
|
| 123 |
+
|
| 124 |
|
| 125 |
+
# KV-cache precision names accepted for `type_k` / `type_v`, mapped to the
# integer type codes stored in the llama.cpp context params.
# NOTE(review): the numbers presumably mirror ggml's `ggml_type` enum - confirm
# against the llama.cpp version in use.
kv_cache_type = dict(
    f32=0,
    f16=1,
    q8_0=8,
    q4_0=2,
    q4_1=3,
    iq4_nl=20,
    q5_0=6,
    q5_1=7,
)
|
| 135 |
|
| 136 |
class StreamingLLM(Llama):
|
| 137 |
+
|
| 138 |
+
__backend_initialized = False
|
| 139 |
+
|
| 140 |
+
def __init__(
        self,
        model_path: str,
        *,
        # Model Params
        n_gpu_layers: int = 0,
        split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
        main_gpu: int = 0,
        tensor_split: Optional[List[float]] = None,
        vocab_only: bool = False,
        use_mmap: bool = True,
        use_mlock: bool = False,
        kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
        # Context Params
        seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
        n_ctx: int = 512,
        n_batch: int = 512,
        n_threads: Optional[int] = None,
        n_threads_batch: Optional[int] = None,
        rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
        pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
        rope_freq_base: float = 0.0,
        rope_freq_scale: float = 0.0,
        yarn_ext_factor: float = -1.0,
        yarn_attn_factor: float = 1.0,
        yarn_beta_fast: float = 32.0,
        yarn_beta_slow: float = 1.0,
        yarn_orig_ctx: int = 0,
        logits_all: bool = False,
        embedding: bool = False,
        offload_kqv: bool = True,
        # Sampling Params
        last_n_tokens_size: int = 64,
        # LoRA Params
        lora_base: Optional[str] = None,
        lora_scale: float = 1.0,
        lora_path: Optional[str] = None,
        # Backend Params
        numa: Union[bool, int] = False,
        # Chat Format Params
        chat_format: Optional[str] = None,
        chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
        # Speculative Decoding
        draft_model: Optional[LlamaDraftModel] = None,
        # Tokenizer Override
        tokenizer: Optional[BaseLlamaTokenizer] = None,
        # Misc
        verbose: bool = True,
        # Extra Params
        type_k: str = 'f16',
        type_v: str = 'f16',
        **kwargs,  # type: ignore
):
    """Load a llama.cpp model from `model_path`.

    This re-implements the parent constructor (it does not call
    ``super().__init__``) so that the KV-cache data types can be selected
    through `type_k` / `type_v`, and finishes by calling ``self._venv_init()``.

    Examples:
        Basic usage

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ... )
        >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
        the lazy dog

        Loading a chat model

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ...     chat_format="llama-2",
        ... )
        >>> print(model.create_chat_completion(
        ...     messages=[{
        ...         "role": "user",
        ...         "content": "what is the meaning of life?"
        ...     }]
        ... ))

    Args:
        model_path: Path to the model.
        n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
        split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
        main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
        tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
        vocab_only: Only load the vocabulary no weights.
        use_mmap: Use mmap if possible. Forced to False when `lora_path` is given.
        use_mlock: Force the system to keep the model in RAM.
        kv_overrides: Key-value overrides for the model.
        seed: RNG seed, -1 for random
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size
        n_threads: Number of threads to use for generation
        n_threads_batch: Number of threads to use for batch processing
        rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
        pooling_type: Pooling type, from `enum llama_pooling_type`.
        rope_freq_base: RoPE base frequency, 0 = from model
        rope_freq_scale: RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
        yarn_attn_factor: YaRN magnitude scaling factor
        yarn_beta_fast: YaRN low correction dim
        yarn_beta_slow: YaRN high correction dim
        yarn_orig_ctx: YaRN original context size
        logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. Forced to True when `draft_model` is given.
        embedding: Embedding mode only.
        offload_kqv: Offload K, Q, V to GPU.
        last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_scale: Scale passed to the LoRA loader together with `lora_path`.
        lora_path: Path to a LoRA file to apply to the model.
        numa: numa policy
        chat_format: String specifying the chat format to use when calling create_chat_completion.
        chat_handler: Optional chat handler to use when calling create_chat_completion.
        draft_model: Optional draft model to use for speculative decoding.
        tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
        verbose: Print verbose output to stderr.
        type_k: Name of the KV-cache K data type; must be a key of `kv_cache_type`.
        type_v: Name of the KV-cache V data type; must be a key of `kv_cache_type`.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If the model path does not exist, if `tensor_split` has
            more entries than LLAMA_MAX_DEVICES, or if a `kv_overrides`
            value is not a bool, int or float.
        KeyError: If `type_k` or `type_v` is not a key of `kv_cache_type`.
        RuntimeError: If applying the LoRA file fails.

    Returns:
        A StreamingLLM instance.
    """
    self.verbose = verbose

    set_verbose(verbose)

    # The llama.cpp backend is initialised once per process (class-level flag).
    if not StreamingLLM.__backend_initialized:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_backend_init()
        StreamingLLM.__backend_initialized = True

    # A boolean `numa` selects between the DISTRIBUTE and DISABLED strategies;
    # any other value is taken as the strategy itself.
    if isinstance(numa, bool):
        self.numa = (
            llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
            if numa
            else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
        )
    else:
        self.numa = numa

    if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_numa_init(self.numa)

    self.model_path = model_path

    # Model Params
    self.model_params = llama_cpp.llama_model_default_params()
    self.model_params.n_gpu_layers = (
        0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
    )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
    self.model_params.split_mode = split_mode
    self.model_params.main_gpu = main_gpu
    self.tensor_split = tensor_split
    self._c_tensor_split = None
    if self.tensor_split is not None:
        if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
            raise ValueError(
                f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
            )
        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
        self._c_tensor_split = FloatArray(
            *tensor_split  # type: ignore
        )  # keep a reference to the array so it is not gc'd
        self.model_params.tensor_split = self._c_tensor_split
    self.model_params.vocab_only = vocab_only
    self.model_params.use_mmap = use_mmap if lora_path is None else False
    self.model_params.use_mlock = use_mlock

    # kv_overrides is the original python dict
    self.kv_overrides = kv_overrides
    if kv_overrides is not None:
        # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
        kvo_array_len = len(kv_overrides) + 1  # for sentinel element
        self._kv_overrides_array = (
            llama_cpp.llama_model_kv_override * kvo_array_len
        )()

        for i, (k, v) in enumerate(kv_overrides.items()):
            self._kv_overrides_array[i].key = k.encode("utf-8")
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(v, bool):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                self._kv_overrides_array[i].value.bool_value = v
            elif isinstance(v, int):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                self._kv_overrides_array[i].value.int_value = v
            elif isinstance(v, float):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                self._kv_overrides_array[i].value.float_value = v
            else:
                raise ValueError(f"Unknown value type for {k}: {v}")

        self._kv_overrides_array[-1].key = (
            b"\0"  # ensure sentinel element is zeroed
        )
        self.model_params.kv_overrides = self._kv_overrides_array

    # Clamp the batch to the requested context; recomputed below when
    # n_ctx == 0 (context size taken from the model).
    self.n_batch = min(n_ctx, n_batch)  # ???
    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    self.n_threads_batch = n_threads_batch or max(
        multiprocessing.cpu_count() // 2, 1
    )

    # Context Params
    self.context_params = llama_cpp.llama_context_default_params()
    self.context_params.seed = seed
    self.context_params.n_ctx = n_ctx
    self.context_params.n_batch = self.n_batch
    self.context_params.n_threads = self.n_threads
    self.context_params.n_threads_batch = self.n_threads_batch
    self.context_params.rope_scaling_type = (
        rope_scaling_type
        if rope_scaling_type is not None
        else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    self.context_params.pooling_type = pooling_type
    self.context_params.rope_freq_base = (
        rope_freq_base if rope_freq_base != 0.0 else 0
    )
    self.context_params.rope_freq_scale = (
        rope_freq_scale if rope_freq_scale != 0.0 else 0
    )
    self.context_params.yarn_ext_factor = (
        yarn_ext_factor if yarn_ext_factor != 0.0 else 0
    )
    self.context_params.yarn_attn_factor = (
        yarn_attn_factor if yarn_attn_factor != 0.0 else 0
    )
    self.context_params.yarn_beta_fast = (
        yarn_beta_fast if yarn_beta_fast != 0.0 else 0
    )
    self.context_params.yarn_beta_slow = (
        yarn_beta_slow if yarn_beta_slow != 0.0 else 0
    )
    self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
    self.context_params.logits_all = (
        logits_all if draft_model is None else True
    )  # Must be set to True for speculative decoding
    self.context_params.embeddings = embedding  # TODO: Rename to embeddings

    # KV cache quantization
    if self.verbose:
        # Library defaults, reported before they are overridden below.
        print(self.context_params.type_k, self.context_params.type_v, file=sys.stderr)
    self.context_params.type_k = kv_cache_type[type_k]
    self.context_params.type_v = kv_cache_type[type_v]

    self.context_params.offload_kqv = offload_kqv

    # Sampling Params
    self.last_n_tokens_size = last_n_tokens_size

    self.cache: Optional[BaseLlamaCache] = None

    self.lora_base = lora_base
    self.lora_scale = lora_scale
    self.lora_path = lora_path

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    self._model = _LlamaModel(
        path_model=self.model_path, params=self.model_params, verbose=self.verbose
    )

    # Override tokenizer
    self.tokenizer_ = tokenizer or LlamaTokenizer(self)

    # Set the default value for the context and correct the batch
    if n_ctx == 0:
        n_ctx = self._model.n_ctx_train()
        self.n_batch = min(n_ctx, n_batch)
        self.context_params.n_ctx = self._model.n_ctx_train()
        self.context_params.n_batch = self.n_batch

    self._ctx = _LlamaContext(
        model=self._model,
        params=self.context_params,
        verbose=self.verbose,
    )

    self._batch = _LlamaBatch(
        n_tokens=self.n_batch,
        embd=0,
        n_seq_max=self.context_params.n_ctx,
        verbose=self.verbose,
    )

    if self.lora_path:
        # A truthy return value from apply_lora_from_file is treated as failure.
        if self._model.apply_lora_from_file(
                self.lora_path,
                self.lora_scale,
                self.lora_base,
                self.n_threads,
        ):
            raise RuntimeError(
                f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
            )

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self.chat_format = chat_format
    self.chat_handler = chat_handler

    self.draft_model = draft_model

    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()

    self._token_nl = self.token_nl()
    self._token_eos = self.token_eos()

    self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab)

    self.n_tokens = 0
    # np.ndarray(...) allocates without initialising the contents.
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx, self._n_vocab), dtype=np.single
    )

    self._mirostat_mu = ctypes.c_float(
        2.0 * 5.0
    )  # TODO: Move this to sampling context

    # Metadata is best-effort: a failure leaves an empty dict.
    try:
        self.metadata = self._model.metadata()
    except Exception as e:
        self.metadata = {}
        if self.verbose:
            print(f"Failed to load metadata: {e}", file=sys.stderr)

    if self.verbose:
        print(f"Model metadata: {self.metadata}", file=sys.stderr)

    # With neither a chat format nor a handler given, try to derive one
    # from the chat template stored in the model metadata.
    if (
            self.chat_format is None
            and self.chat_handler is None
            and "tokenizer.chat_template" in self.metadata
    ):
        chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
            self.metadata
        )

        if chat_format is not None:
            self.chat_format = chat_format
            if self.verbose:
                print(f"Guessed chat format: {chat_format}", file=sys.stderr)
        else:
            template = self.metadata["tokenizer.chat_template"]
            # Fall back to the model's own token ids when the metadata
            # entry is missing or not an integer.
            try:
                eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
            except Exception:
                eos_token_id = self.token_eos()
            try:
                bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
            except Exception:
                bos_token_id = self.token_bos()

            eos_token = self._model.token_get_text(eos_token_id)
            bos_token = self._model.token_get_text(bos_token_id)

            if self.verbose:
                print(f"Using gguf chat template: {template}", file=sys.stderr)
                print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

            self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
                template=template, eos_token=eos_token, bos_token=bos_token
            ).to_chat_handler()

    if self.chat_format is None and self.chat_handler is None:
        self.chat_format = "llama-2"
        if self.verbose:
            # Report the format actually selected; the local `chat_format`
            # is always None on this path.
            print(f"Using fallback chat format: {self.chat_format}", file=sys.stderr)
    self._venv_init()
|
| 515 |
|
| 516 |
def str_detokenize(self, tokens) -> str:
|
|
|
|
| 565 |
if name not in self.venv_idx_map:
|
| 566 |
return False
|
| 567 |
venv_idx = self.venv_idx_map.index(name) + 1
|
| 568 |
+
count_name = self.venv_idx_map.count(name) if keep_last else 0
|
| 569 |
while self.venv_idx_map:
|
| 570 |
+
if keep_last and count_name <= keep_last:
|
| 571 |
break # 保留最后n个
|
| 572 |
self.venv_idx_map.pop(venv_idx - 1) # 删除
|
| 573 |
if venv_idx == len(self.venv) - 1:
|
|
|
|
| 584 |
venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
|
| 585 |
except ValueError: # 没有了
|
| 586 |
break
|
| 587 |
+
count_name -= 1 # 计数减一
|
| 588 |
return True
|
| 589 |
|
| 590 |
def venv_pop_token(self, n=1):
|
|
|
|
| 596 |
def venv_info(self):
|
| 597 |
return str((self.n_tokens, self.venv, self.venv_idx_map))
|
| 598 |
|
| 599 |
+
def venv_viz(self):
    """Render the last ``self.venv[-1]`` context tokens as coloured text.

    Each piece of decoded text is preceded by a colour code chosen from the
    softmax probability of its first token, taken from the score row stored
    for the previous position.  Tokens with probability below 0.001, and a
    token at position 0 (which has no previous row), are emitted with
    ``NOCOLOR``.

    Returns:
        ``LEGEND`` and a newline, followed by the coloured ``repr``-escaped
        text, terminated by ``NOCOLOR``.
    """
    completion_tokens = []
    parts = [LEGEND + '\n']
    text_color = NOCOLOR
    n_last = self.venv[-1]
    start = self.n_tokens - n_last
    for i in range(n_last):
        idx = start + i
        token = self._input_ids[idx]
        if not completion_tokens:  # a multi-token character is coloured by its first token
            # ========== probability of this token ==========
            if idx > 0:
                # The score for token i was predicted from the tokens
                # before it, hence row idx - 1.
                logits = self.scores[idx - 1]
                # Subtract the maximum before exponentiating so large
                # logits cannot overflow to inf/nan; the softmax value is
                # unchanged.  An untouched (all-zero) row still yields a
                # uniform distribution.
                weights = np.exp(logits - np.max(logits))
                probabilities = weights[token] / np.sum(weights)
            else:
                # No preceding row exists for the very first position.
                probabilities = 0.0
            if probabilities < 0.001:
                text_color = NOCOLOR
            elif text_color is NOCOLOR:
                text_color = BACK_WHITE + map_value_to_color(probabilities)
            else:
                text_color = map_value_to_color(probabilities)
            parts.append(text_color)
        # ========== avoid emitting incomplete UTF-8 sequences ==========
        completion_tokens.append(token)
        all_text = self.str_detokenize(completion_tokens)
        if not all_text:
            continue
        completion_tokens = []  # decoded completely: clear the pending buffer
        parts.append(repr(all_text)[1:-1])
    parts.append(NOCOLOR)
    return ''.join(parts)
|
| 628 |
+
|
| 629 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
| 630 |
if n_keep < 0:
|
| 631 |
return
|
|
|
|
| 640 |
_idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
|
| 641 |
if _idx >= n_keep:
|
| 642 |
n_keep = _idx + len(im_start) # 至少保留一个 im_start 序列结构
|
| 643 |
+
print(im_start, n_keep, n_discard, _idx)
|
| 644 |
self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
|
| 645 |
self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
|
| 646 |
self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
|
|
|
|
| 822 |
tokens = [token]
|
| 823 |
|
| 824 |
def load_session(self, filepath: str):
|
| 825 |
+
n_tokens = POINTER(c_size_t)(c_size_t(0))
|
| 826 |
tokens = (llama_cpp.llama_token * self.n_ctx())()
|
| 827 |
retn = llama_cpp.llama_load_session_file(self._ctx.ctx,
|
| 828 |
filepath.encode('utf-8'),
|