import logging

import torch
from transformers.cache_utils import DynamicCache
import streamlit as st

from model_utils import get_kv_cache, load_model_and_tokenizer

logger = logging.getLogger(__name__)


def clean_up(cache, origin_len):
    """Truncate the cache back to the original knowledge length, in place.

    Removes any tokens appended after position ``origin_len`` along the
    sequence dimension of every layer's key/value tensors.

    Args:
        cache: DynamicCache-like object exposing ``key_cache`` and
            ``value_cache`` lists of tensors; assumed layout is
            (batch, heads, seq_len, head_dim) — TODO confirm, the slice
            below truncates dim 2.
        origin_len: Sequence length of the original context to keep.

    Returns:
        The same cache object, truncated in place.
    """
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
    return cache


def clone_cache(cache):
    """Create a deep copy of a DynamicCache object.

    Each key/value tensor is cloned so mutating the copy (e.g. via
    ``clean_up``) never touches the original cache.

    Args:
        cache: Source DynamicCache whose tensors are copied.

    Returns:
        A new ``DynamicCache`` holding cloned key/value tensors.
    """
    new_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        new_cache.key_cache.append(key.clone())
        new_cache.value_cache.append(value.clone())
    return new_cache


# Load document and create cache if not already done
@st.cache_resource
def create_cache_from_text(doc_text):
    """Create a KV cache from document text.

    Builds a system prompt embedding ``doc_text`` as context and pre-fills
    the model's KV cache with it. Cached by Streamlit across reruns via
    ``st.cache_resource`` (keyed on ``doc_text``).

    Args:
        doc_text: Raw document text to embed as context.

    Returns:
        Tuple of (cache, origin_len) as produced by ``get_kv_cache``.
    """
    model, tokenizer = load_model_and_tokenizer()
    system_prompt = f"""
<|system|>
Answer concisely and precisely, You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
    cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
    return cache, origin_len


def save_cache(cache, origin_len, filename="saved_cache.pth"):
    """Save the cache to a file.

    Serializes a cleaned *copy* of the cache (truncated to ``origin_len``)
    so the live cache keeps any appended generation tokens.

    Args:
        cache: DynamicCache to persist.
        origin_len: Original context length to truncate the copy to.
        filename: Destination path for ``torch.save``.

    Returns:
        The filename the cache was written to.
    """
    cache_to_save = clean_up(clone_cache(cache), origin_len)
    torch.save(cache_to_save, filename)
    return filename


def load_cache(cache_file):
    """Load a cache from an uploaded file object.

    Args:
        cache_file: File-like object exposing ``getvalue()`` (e.g. a
            Streamlit ``UploadedFile``).

    Returns:
        (cache, origin_len, True) on success, (None, None, False) on
        any failure (best-effort contract preserved; the error is logged).
    """
    try:
        with open("temp_cache.pth", "wb") as f:
            f.write(cache_file.getvalue())
        # weights_only=False is required to unpickle DynamicCache objects
        # (torch >= 2.6 defaults to weights_only=True, which rejects them).
        # NOTE: this is a full pickle load — only load caches from trusted
        # sources.
        loaded_cache = torch.load("temp_cache.pth", weights_only=False)
        origin_len = loaded_cache.key_cache[0].shape[-2]
        return loaded_cache, origin_len, True
    except Exception:
        logger.exception("Failed to load KV cache from uploaded file")
        return None, None, False