# fifth_try_CAG / app.py — Streamlit QA app with a precomputed KV cache
# Source: Hugging Face Space by kouki321 (commit b5241c5, verified; "Update app.py")
import streamlit as st
import torch
from transformers.cache_utils import DynamicCache
import io
import os
from time import time
# Import from our utility modules
from model_utils import load_model_and_tokenizer, generate
from cache_utils import create_cache_from_text, clone_cache, clean_up, save_cache, load_cache
from document_utils import get_document_text
# --- Main app: document upload, KV-cache creation, and Q&A ---
st.title("DeepSeek QA with KV Cache")
st.write("Upload a document (PDF or TXT) and ask questions about it")

# File uploader for the document (type filter follows the radio choice)
file_type = st.radio("Select file type:", ["Text (.txt)", "PDF (.pdf)"])
if file_type == "Text (.txt)":
    uploaded_file = st.file_uploader("Upload your document", type="txt")
else:
    uploaded_file = st.file_uploader("Upload your document", type="pdf")

doc_text = None
cache = None
origin_len = None

if uploaded_file:
    with st.spinner("Processing document..."):
        # Start timing here so the reported duration covers document
        # processing plus answer generation.
        t1 = time()
        doc_text = get_document_text(uploaded_file, file_type)
        if doc_text:
            # Pre-fill the KV cache with the document once, so each
            # question only needs to process the (short) query tokens.
            cache, origin_len = create_cache_from_text(doc_text)

    if doc_text:
        # Display document preview
        with st.expander("Document Preview"):
            st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)

        # Get user query
        query = st.text_input("Ask a question about the document:")
        if query and st.button("Generate Answer"):
            with st.spinner("Generating answer..."):
                model, tokenizer = load_model_and_tokenizer()
                # Generation appends key/value states to the DynamicCache
                # in place, so run it on a clone to keep the document-only
                # cache pristine for subsequent questions.
                current_cache = clone_cache(cache)
                full_prompt = f"""
<|user|>
Question: {query}
<|assistant|>
""".strip()
                input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
                output_ids = generate(model, input_ids, current_cache)
                response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
                t2 = time()
            # Show the answer and the elapsed time separately (the original
            # passed the raw float into st.write alongside the answer).
            st.success("Answer:")
            st.write(response)
            st.caption(f"Elapsed: {t2 - t1:.2f} s")

        # Option to save the document cache for later reuse
        if st.button("Save Cache"):
            cache_file = save_cache(cache, origin_len)
            # Provide a download button for the saved cache file
            with open(cache_file, "rb") as f:
                cache_bytes = f.read()
            st.download_button(
                label="Download Cache File",
                data=cache_bytes,
                file_name="document_cache.pth",
                mime="application/octet-stream",
            )
else:
    st.info("Please upload a document to start.")
# --- Sidebar: optionally reload a previously saved KV cache + document ---
st.sidebar.header("Advanced Options")
load_saved_cache = st.sidebar.checkbox("Load saved cache")
if load_saved_cache:
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    doc_file = st.sidebar.file_uploader("Upload corresponding document", type=["txt", "pdf"])
    if cache_file and doc_file:
        loaded_cache, loaded_origin_len, success = load_cache(cache_file)
        if success:
            st.sidebar.success("Cache loaded successfully!")
            # Re-read the document text so the preview matches the cache
            if doc_file.name.endswith(".pdf"):
                doc_text = get_document_text(doc_file, "PDF (.pdf)")
            else:
                doc_text = get_document_text(doc_file, "Text (.txt)")
            # Ready to answer questions against the pre-loaded cache
            st.sidebar.info("Using pre-loaded cache and document")
            cache = loaded_cache
            origin_len = loaded_origin_len
            # Guard: get_document_text may fail and return None; avoid
            # slicing None in the preview.
            if doc_text:
                with st.expander("Document Preview (Loaded)"):
                    st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
        else:
            st.sidebar.error("Failed to load cache file")