File size: 1,923 Bytes
d4ae547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import io

import streamlit as st
import torch
from transformers.cache_utils import DynamicCache

from model_utils import get_kv_cache, load_model_and_tokenizer

def clean_up(cache, origin_len):
    """
    Truncate every layer of the KV cache back to ``origin_len`` tokens.

    Strips any key/value entries appended after the original knowledge
    prompt (dim 2 is the sequence axis), mutating ``cache`` in place.

    Returns the same cache object for convenient chaining.
    """
    num_layers = len(cache.key_cache)
    for layer in range(num_layers):
        keys = cache.key_cache[layer]
        values = cache.value_cache[layer]
        cache.key_cache[layer] = keys[:, :, :origin_len, :]
        cache.value_cache[layer] = values[:, :, :origin_len, :]
    return cache

def clone_cache(cache):
    """
    Return an independent deep copy of a DynamicCache.

    Each per-layer key/value tensor is cloned so mutating the copy
    (e.g. truncating it in clean_up) never touches the source cache.
    """
    copied = DynamicCache()
    for layer in range(len(cache.key_cache)):
        copied.key_cache.append(cache.key_cache[layer].clone())
        copied.value_cache.append(cache.value_cache[layer].clone())
    return copied

# Cached at the Streamlit level so the document is only processed once.
@st.cache_resource
def create_cache_from_text(doc_text):
    """
    Build a KV cache from document text.

    Embeds ``doc_text`` into a fixed system/user prompt and runs it
    through the model once via get_kv_cache.

    Returns a (cache, origin_len) pair, where origin_len is the token
    length of the prompt-only cache.
    """
    model, tokenizer = load_model_and_tokenizer()

    system_prompt = f"""
    <|system|>
    Answer concisely and precisely, You are an assistant who provides concise factual answers.
    <|user|>
    Context:
    {doc_text}
    Question:
    """.strip()

    kv_cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
    return kv_cache, origin_len

def save_cache(cache, origin_len, filename="saved_cache.pth"):
    """
    Persist a cleaned snapshot of the cache to disk.

    The cache is cloned first so the live cache is left untouched, then
    truncated to ``origin_len`` tokens before being serialized with
    torch.save.

    Returns the filename that was written.
    """
    snapshot = clone_cache(cache)
    trimmed = clean_up(snapshot, origin_len)
    torch.save(trimmed, filename)
    return filename

def load_cache(cache_file):
    """
    Load a KV cache from an uploaded file object.

    Args:
        cache_file: file-like object exposing ``getvalue()`` (e.g. a
            Streamlit ``UploadedFile``) whose bytes were produced by
            save_cache / torch.save.

    Returns:
        (cache, origin_len, True) on success, where origin_len is taken
        from the sequence axis of the first layer's key tensor;
        (None, None, False) on any failure.
    """
    try:
        # Deserialize directly from memory -- avoids leaking a hard-coded
        # temp file on disk and races between concurrent users.
        buffer = io.BytesIO(cache_file.getvalue())
        # SECURITY: torch.load unpickles arbitrary objects; only load
        # caches from trusted sources. weights_only=False is required
        # because the payload is a full cache object, not a tensor dict
        # (and torch>=2.6 defaults weights_only to True).
        loaded_cache = torch.load(buffer, weights_only=False)
        origin_len = loaded_cache.key_cache[0].shape[-2]
        return loaded_cache, origin_len, True
    except Exception:
        # Best-effort: signal failure via the flag instead of raising.
        return None, None, False