# fifth_try_CAG / cache_utils.py
# Uploaded by kouki321 — "Create cache_utils.py" (commit d4ae547, verified)
import io

import torch
import streamlit as st
from transformers.cache_utils import DynamicCache

from model_utils import get_kv_cache, load_model_and_tokenizer
def clean_up(cache, origin_len):
    """
    Truncate every layer's key/value tensors back to the first
    ``origin_len`` positions, discarding tokens appended after the
    original knowledge was cached. Mutates ``cache`` in place and
    returns it for convenience.
    """
    # Slicing happens on dim 2 — presumably the sequence dimension of
    # (batch, heads, seq_len, head_dim) tensors; verify against the model.
    for layer, (keys, values) in enumerate(zip(cache.key_cache, cache.value_cache)):
        cache.key_cache[layer] = keys[:, :, :origin_len, :]
        cache.value_cache[layer] = values[:, :, :origin_len, :]
    return cache
def clone_cache(cache):
    """
    Build and return a deep copy of a DynamicCache: each per-layer key
    and value tensor is cloned, so mutating the copy (e.g. truncation)
    never touches the original cache's storage.
    """
    duplicate = DynamicCache()
    for layer_keys, layer_values in zip(cache.key_cache, cache.value_cache):
        # .clone() allocates fresh tensors rather than aliasing storage.
        duplicate.key_cache.append(layer_keys.clone())
        duplicate.value_cache.append(layer_values.clone())
    return duplicate
# Load document and create cache if not already done
@st.cache_resource
def create_cache_from_text(doc_text):
    """
    Create a KV cache from document text.

    Wraps ``doc_text`` in a chat-style system prompt and pre-computes the
    model's key/value cache over it via ``get_kv_cache``, so later questions
    can reuse the cached context. Memoized with ``st.cache_resource`` so the
    (expensive) cache build runs once per distinct ``doc_text``.

    Args:
        doc_text: Raw document text to embed as the assistant's context.

    Returns:
        (cache, origin_len): the pre-computed KV cache and — presumably —
        the prompt's token length as reported by ``get_kv_cache``; verify
        against model_utils. ``origin_len`` is what ``clean_up`` later
        truncates back to.
    """
    model, tokenizer = load_model_and_tokenizer()
    system_prompt = f"""
<|system|>
Answer concisely and precisely, You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
    cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
    return cache, origin_len
def save_cache(cache, origin_len, filename="saved_cache.pth"):
    """
    Persist a trimmed snapshot of the cache to ``filename`` and return
    the path. The live cache is cloned first so truncating appended
    tokens never mutates the caller's object.
    """
    snapshot = clone_cache(cache)
    torch.save(clean_up(snapshot, origin_len), filename)
    return filename
def load_cache(cache_file):
    """
    Load a serialized KV cache from an uploaded file-like object.

    Args:
        cache_file: File-like object (e.g. a Streamlit upload) exposing
            ``getvalue()`` that returns the serialized cache bytes.

    Returns:
        (cache, origin_len, ok): the deserialized cache, the sequence
        length of its first layer's key tensor (dim -2), and ``True`` on
        success; ``(None, None, False)`` when loading fails for any reason
        (the caller is expected to check the boolean flag).
    """
    try:
        # Deserialize straight from memory instead of round-tripping through
        # a temp file on disk (the previous "temp_cache.pth" was never removed).
        buffer = io.BytesIO(cache_file.getvalue())
        # weights_only=False is required: the payload is a pickled cache
        # object, not a plain state dict (torch >= 2.6 defaults to True).
        # NOTE(security): only load files from trusted sources — this is
        # arbitrary pickle deserialization.
        loaded_cache = torch.load(buffer, weights_only=False)
        origin_len = loaded_cache.key_cache[0].shape[-2]
        return loaded_cache, origin_len, True
    except Exception:
        # Best-effort load: signal failure via the flag rather than raising.
        return None, None, False