# NOTE(review): "Spaces: Sleeping" below was Hugging Face Spaces page residue
# captured by the scrape, not part of the module — preserved here as a comment.
import io

import streamlit as st
import torch
from transformers.cache_utils import DynamicCache

from model_utils import get_kv_cache, load_model_and_tokenizer
def clean_up(cache, origin_len):
    """
    Trim the cache back to its original knowledge length, in place.

    Any key/value entries appended after position ``origin_len`` (e.g. tokens
    generated after the cached context) are sliced off for every layer.

    Args:
        cache: Cache object exposing ``key_cache`` / ``value_cache`` lists of
            4-D tensors (sequence length on axis 2 — presumably
            batch, heads, seq_len, head_dim; confirm against the model).
        origin_len: Sequence length of the original cached context.

    Returns:
        The same cache object, truncated.
    """
    for layer, (keys, values) in enumerate(zip(cache.key_cache, cache.value_cache)):
        cache.key_cache[layer] = keys[:, :, :origin_len, :]
        cache.value_cache[layer] = values[:, :, :origin_len, :]
    return cache
def clone_cache(cache):
    """
    Return a deep copy of a DynamicCache.

    Every per-layer key/value tensor is ``clone()``-d so mutating the copy
    (e.g. via ``clean_up``) never touches the original cache's storage.

    Args:
        cache: Source ``DynamicCache`` (or compatible object with
            ``key_cache`` / ``value_cache`` tensor lists).

    Returns:
        A new ``DynamicCache`` holding cloned tensors.
    """
    duplicate = DynamicCache()
    duplicate.key_cache.extend(k.clone() for k in cache.key_cache)
    duplicate.value_cache.extend(v.clone() for v in cache.value_cache)
    return duplicate
# Load document and create cache if not already done
def create_cache_from_text(doc_text):
    """
    Build a KV cache that pre-encodes the given document as prompt context.

    The document is wrapped in a chat-style system/user prompt ending just
    before the question, so later queries can reuse the cached prefix.

    Args:
        doc_text: Raw document text to embed in the prompt.

    Returns:
        tuple: ``(cache, origin_len)`` as produced by ``get_kv_cache``.
    """
    model, tokenizer = load_model_and_tokenizer()
    prompt = (
        "<|system|>\n"
        "Answer concisely and precisely, You are an assistant who provides concise factual answers.\n"
        "<|user|>\n"
        "Context:\n"
        f"{doc_text}\n"
        "Question:\n"
    ).strip()
    return get_kv_cache(model, tokenizer, prompt)
def save_cache(cache, origin_len, filename="saved_cache.pth"):
    """
    Persist a trimmed copy of the KV cache to disk.

    The cache is cloned first so the live cache is untouched, then trimmed
    to ``origin_len`` before serialization.

    Args:
        cache: Live cache to snapshot.
        origin_len: Sequence length of the original cached context.
        filename: Destination path for ``torch.save``.

    Returns:
        The filename the cache was written to.
    """
    snapshot = clone_cache(cache)
    trimmed = clean_up(snapshot, origin_len)
    torch.save(trimmed, filename)
    return filename
def load_cache(cache_file):
    """
    Load a previously saved KV cache from an uploaded file.

    Args:
        cache_file: File-like object (e.g. a Streamlit ``UploadedFile``)
            whose ``getvalue()`` returns the bytes written by ``save_cache``.

    Returns:
        tuple: ``(cache, origin_len, success)``. On failure, ``success`` is
        ``False`` and the other two values are ``None``. ``origin_len`` is
        read from the stored keys' sequence axis (``shape[-2]``).
    """
    try:
        # Deserialize straight from memory — no need to round-trip the
        # upload through a "temp_cache.pth" file on disk.
        buffer = io.BytesIO(cache_file.getvalue())
        # weights_only=False is required to unpickle a full cache object
        # (torch >= 2.6 defaults weights_only to True and would reject it).
        # SECURITY NOTE: this runs the pickle machinery on the uploaded
        # bytes — only load caches from trusted sources.
        loaded_cache = torch.load(buffer, weights_only=False)
        origin_len = loaded_cache.key_cache[0].shape[-2]
        return loaded_cache, origin_len, True
    except Exception:
        # Best-effort contract: callers check the success flag instead of
        # handling exceptions.
        return None, None, False