Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- llama_cpp_python_streamingllm.py +49 -23
llama_cpp_python_streamingllm.py
CHANGED
|
@@ -31,7 +31,7 @@ def get_complete_UTF8(all_text):
|
|
| 31 |
class StreamingLLM(Llama):
|
| 32 |
def __init__(self, model_path: str, **kwargs):
|
| 33 |
super().__init__(model_path, **kwargs)
|
| 34 |
-
self.
|
| 35 |
|
| 36 |
def str_detokenize(self, tokens) -> str:
|
| 37 |
return get_complete_UTF8(self.detokenize(tokens))
|
|
@@ -39,38 +39,63 @@ class StreamingLLM(Llama):
|
|
| 39 |
def kv_cache_seq_trim(self):
|
| 40 |
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
| 41 |
|
| 42 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
self.venv.append(0)
|
| 44 |
-
|
|
|
|
| 45 |
|
| 46 |
-
def venv_disband(self):
|
| 47 |
if len(self.venv) <= 1:
|
| 48 |
-
return
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
def venv_remove(self,
|
| 54 |
-
if
|
| 55 |
-
|
| 56 |
-
if
|
| 57 |
-
return
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
self.
|
| 61 |
-
self.
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def venv_pop_token(self):
|
| 70 |
self.n_tokens -= 1
|
| 71 |
self.venv[-1] -= 1
|
| 72 |
self.kv_cache_seq_trim()
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
| 75 |
if n_past < 0:
|
| 76 |
n_past = self.n_tokens
|
|
@@ -274,6 +299,7 @@ class StreamingLLM(Llama):
|
|
| 274 |
n_tokens)
|
| 275 |
self.n_tokens = n_tokens.contents.value
|
| 276 |
self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
|
|
|
|
| 277 |
return retn
|
| 278 |
|
| 279 |
def save_session(self, filepath: str):
|
|
|
|
| 31 |
class StreamingLLM(Llama):
|
| 32 |
def __init__(self, model_path: str, **kwargs):
|
| 33 |
super().__init__(model_path, **kwargs)
|
| 34 |
+
self._venv_init()
|
| 35 |
|
| 36 |
def str_detokenize(self, tokens) -> str:
|
| 37 |
return get_complete_UTF8(self.detokenize(tokens))
|
|
|
|
| 39 |
def kv_cache_seq_trim(self):
|
| 40 |
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
| 41 |
|
| 42 |
+
def _venv_init(self):
|
| 43 |
+
self.venv = [0]
|
| 44 |
+
self.venv_idx_map = []
|
| 45 |
+
|
| 46 |
+
def venv_create(self, name: str):
|
| 47 |
self.venv.append(0)
|
| 48 |
+
self.venv_idx_map.append(name)
|
| 49 |
+
return name
|
| 50 |
|
| 51 |
+
def venv_disband(self, name_set):
|
| 52 |
if len(self.venv) <= 1:
|
| 53 |
+
return name_set
|
| 54 |
+
name_set = {x for x in name_set if x in self.venv_idx_map}
|
| 55 |
+
if not name_set:
|
| 56 |
+
return name_set
|
| 57 |
+
while self.venv_idx_map:
|
| 58 |
+
if self.venv_idx_map[0] in name_set:
|
| 59 |
+
self.venv_idx_map.pop(0) # 删除
|
| 60 |
+
tmp = self.venv.pop(1) # 对应的 venv 移入上一层
|
| 61 |
+
self.venv[0] += tmp
|
| 62 |
+
else:
|
| 63 |
+
break
|
| 64 |
+
return name_set
|
| 65 |
|
| 66 |
+
def venv_remove(self, name: str):
|
| 67 |
+
if len(self.venv) <= 1:
|
| 68 |
+
return name
|
| 69 |
+
if name not in self.venv_idx_map:
|
| 70 |
+
return name
|
| 71 |
+
venv_idx = self.venv_idx_map.index(name) + 1
|
| 72 |
+
while self.venv_idx_map:
|
| 73 |
+
self.venv_idx_map.pop(venv_idx - 1) # 删除
|
| 74 |
+
if venv_idx == len(self.venv) - 1:
|
| 75 |
+
# 最后一层
|
| 76 |
+
self.n_tokens -= min(self.venv.pop(), self.n_tokens)
|
| 77 |
+
self.kv_cache_seq_trim()
|
| 78 |
+
break
|
| 79 |
+
else:
|
| 80 |
+
# 非最后一层
|
| 81 |
+
n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
|
| 82 |
+
n_discard = self.venv.pop(venv_idx)
|
| 83 |
+
self.kv_cache_seq_ltrim(n_keep, n_discard)
|
| 84 |
+
try:
|
| 85 |
+
venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
|
| 86 |
+
except ValueError: # 没有了
|
| 87 |
+
break
|
| 88 |
+
return name
|
| 89 |
|
| 90 |
def venv_pop_token(self):
|
| 91 |
self.n_tokens -= 1
|
| 92 |
self.venv[-1] -= 1
|
| 93 |
self.kv_cache_seq_trim()
|
| 94 |
|
| 95 |
+
@property
|
| 96 |
+
def venv_info(self):
|
| 97 |
+
return str((self.n_tokens, self.venv, self.venv_idx_map))
|
| 98 |
+
|
| 99 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
| 100 |
if n_past < 0:
|
| 101 |
n_past = self.n_tokens
|
|
|
|
| 299 |
n_tokens)
|
| 300 |
self.n_tokens = n_tokens.contents.value
|
| 301 |
self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
|
| 302 |
+
self._venv_init()
|
| 303 |
return retn
|
| 304 |
|
| 305 |
def save_session(self, filepath: str):
|