kouki321 committed on
Commit
d4ae547
·
verified ·
1 Parent(s): 0d5d472

Create cache_utils.py

Browse files
Files changed (1) hide show
  1. cache_utils.py +65 -0
cache_utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import logging

import streamlit as st
import torch
from transformers.cache_utils import DynamicCache

from model_utils import get_kv_cache, load_model_and_tokenizer
5
+
6
def clean_up(cache, origin_len):
    """Truncate every layer of ``cache`` to its first ``origin_len`` positions.

    Removes any tokens that were appended after the original knowledge was
    encoded. The cache is modified in place and also returned, so callers
    that need to keep the longer cache must pass in a copy.

    Args:
        cache: Object exposing parallel ``key_cache`` / ``value_cache`` lists
            of 4-D tensors. The third axis is the one truncated; this assumes
            a (batch, heads, seq_len, head_dim) layout -- TODO confirm.
        origin_len: Number of leading positions to keep along that axis.

    Returns:
        The same ``cache`` object, after truncation.
    """
    for layer, keys in enumerate(cache.key_cache):
        cache.key_cache[layer] = keys[:, :, :origin_len, :]
        cache.value_cache[layer] = cache.value_cache[layer][:, :, :origin_len, :]
    return cache
14
+
15
def clone_cache(cache):
    """Return a new ``DynamicCache`` holding cloned key/value tensors.

    Each per-layer tensor is copied with ``.clone()``, so later in-place
    edits to the returned cache's tensors do not touch ``cache``.

    Args:
        cache: Object exposing parallel ``key_cache`` / ``value_cache`` lists
            of tensors (one entry per layer).

    Returns:
        A fresh ``DynamicCache`` whose ``key_cache`` / ``value_cache`` lists
        contain the clones, in the same layer order.
    """
    new_cache = DynamicCache()
    for layer_keys, layer_values in zip(cache.key_cache, cache.value_cache):
        new_cache.key_cache.append(layer_keys.clone())
        new_cache.value_cache.append(layer_values.clone())
    return new_cache
24
+
25
# Load document and create cache if not already done
@st.cache_resource
def create_cache_from_text(doc_text):
    """Build a KV cache for a prompt that embeds ``doc_text`` as context.

    The prompt ends with an open ``Question:`` slot, so the cache covers the
    instructions and the document but no question text.

    NOTE(review): the result is memoized by ``st.cache_resource``; callers
    may all receive the same cache object rather than a copy -- confirm
    against the Streamlit docs before mutating the returned cache.

    Args:
        doc_text: Document text inserted under ``Context:`` in the prompt.

    Returns:
        ``(cache, origin_len)`` exactly as returned by ``get_kv_cache``.
    """
    model, tokenizer = load_model_and_tokenizer()

    # Surrounding newlines are stripped; the prompt body is sent verbatim.
    system_prompt = f"""
<|system|>
Answer concisely and precisely, You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()

    cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
    return cache, origin_len
44
+
45
def save_cache(cache, origin_len, filename="saved_cache.pth"):
    """Write a trimmed copy of ``cache`` to ``filename`` with ``torch.save``.

    The cache is cloned first and the clone is truncated to ``origin_len``,
    so the caller's ``cache`` is left untouched.

    Args:
        cache: Cache object to persist.
        origin_len: Number of leading positions to keep in the saved copy.
        filename: Destination path; an existing file is overwritten.

    Returns:
        ``filename``, unchanged.
    """
    trimmed_copy = clean_up(clone_cache(cache), origin_len)
    torch.save(trimmed_copy, filename)
    return filename
52
+
53
def load_cache(cache_file):
    """Load a saved cache from an in-memory uploaded file.

    Args:
        cache_file: Object whose ``getvalue()`` returns the serialized cache
            as bytes (presumably a Streamlit upload -- verify against caller).

    Returns:
        ``(cache, origin_len, True)`` on success, where ``origin_len`` is the
        size of the second-to-last axis of the first layer's key tensor.
        ``(None, None, False)`` if anything goes wrong; the failure is logged
        rather than raised.
    """
    try:
        # Deserialize straight from memory. A fixed temp file in the working
        # directory could be clobbered by concurrent sessions and was never
        # cleaned up.
        buffer = io.BytesIO(cache_file.getvalue())
        # SECURITY: torch.load unpickles its input; only load cache files
        # from trusted sources.
        # NOTE(review): newer PyTorch releases default to weights_only=True,
        # which may refuse to unpickle a whole cache object -- confirm against
        # the deployed torch version before relaxing it.
        loaded_cache = torch.load(buffer)
        origin_len = loaded_cache.key_cache[0].shape[-2]
    except Exception:
        # Best-effort boundary: keep the (None, None, False) contract but
        # record why loading failed instead of discarding the error.
        logging.getLogger(__name__).exception("Failed to load cache file")
        return None, None, False
    return loaded_cache, origin_len, True