Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers.cache_utils import DynamicCache | |
| import io | |
| import os | |
| from time import time | |
| # Import from our utility modules | |
| from model_utils import load_model_and_tokenizer, generate | |
| from cache_utils import create_cache_from_text, clone_cache, clean_up, save_cache, load_cache | |
| from document_utils import get_document_text | |
# ---------------------------------------------------------------------------
# Main page: upload a document, build its KV cache, ask a question about it,
# and optionally download the cache.
#
# Streamlit re-executes this script from the top on every widget interaction,
# and st.button() is True only during the run triggered by its own click.
# The answer is therefore kept in st.session_state so that it (and the
# "Save Cache" button below it) is still rendered on the rerun caused by
# clicking "Save Cache".
# ---------------------------------------------------------------------------
st.title("DeepSeek QA with KV Cache")
st.write("Upload a document (PDF or TXT) and ask questions about it")

# File uploader for the document; the radio choice selects the accepted extension.
file_type = st.radio("Select file type:", ["Text (.txt)", "PDF (.pdf)"])
accepted_extension = "txt" if file_type == "Text (.txt)" else "pdf"
uploaded_file = st.file_uploader("Upload your document", type=accepted_extension)

doc_text = None
cache = None
origin_len = None

if uploaded_file:
    with st.spinner("Processing document..."):
        # The reported duration is measured from the start of document
        # processing, so it includes text extraction and cache creation.
        started_at = time()
        doc_text = get_document_text(uploaded_file, file_type)
        if doc_text:
            # Create cache from text
            cache, origin_len = create_cache_from_text(doc_text)

    if doc_text:
        # Display document preview (first 500 characters only).
        with st.expander("Document Preview"):
            st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)

        # Get user query
        query = st.text_input("Ask a question about the document:")

        if query and st.button("Generate Answer"):
            with st.spinner("Generating answer..."):
                model, tokenizer = load_model_and_tokenizer()

                # NOTE(review): the shared `cache` object is handed to
                # generate() directly. If generate() appends to the cache,
                # the object later passed to save_cache() in the same run is
                # no longer the pristine document cache. `clone_cache` is
                # imported at the top of the file but its signature is not
                # visible here -- confirm and pass a clone instead.
                prompt = (
                    "<|user|>\n"
                    f"Question: {query}\n"
                    "<|assistant|>"
                )
                input_ids = tokenizer(prompt, return_tensors="pt").input_ids

                # Generate response
                output_ids = generate(model, input_ids, cache)
                response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
                elapsed = time() - started_at

            # Remember the answer together with the inputs that produced it,
            # so a stale answer is not shown for a different file or question.
            st.session_state["last_answer"] = {
                "document": uploaded_file.name,
                "query": query,
                "response": response,
                "elapsed": elapsed,
            }

        answer = st.session_state.get("last_answer")
        answer_is_current = (
            answer is not None
            and answer["document"] == uploaded_file.name
            and answer["query"] == query
        )
        if answer_is_current:
            # Display the response followed by the elapsed time in seconds.
            st.success("Answer:")
            st.write(answer["response"], answer["elapsed"])

            # Option to save the cache. This lives outside the
            # "Generate Answer" branch on purpose: a button nested inside
            # another button's branch can never fire, because the outer
            # button is False again on the rerun the inner click triggers.
            if st.button("Save Cache"):
                cache_file = save_cache(cache, origin_len)
                # Provide download button for the saved cache
                with open(cache_file, "rb") as f:
                    cache_bytes = f.read()
                st.download_button(
                    label="Download Cache File",
                    data=cache_bytes,
                    file_name="document_cache.pth",
                    mime="application/octet-stream",
                )
else:
    st.info("Please upload a document to start.")
# ---------------------------------------------------------------------------
# Sidebar: optionally load a previously saved cache together with the
# document it was built from.
#
# NOTE(review): `cache`, `origin_len` and `doc_text` are reassigned here,
# i.e. after the question-answering section above has already executed in
# this script run, so the loaded values are not picked up by that section
# during the same run. Confirm whether this section should run earlier.
# ---------------------------------------------------------------------------
st.sidebar.header("Advanced Options")
load_saved_cache = st.sidebar.checkbox("Load saved cache")

if load_saved_cache:
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    doc_file = st.sidebar.file_uploader(
        "Upload corresponding document", type=["txt", "pdf"]
    )

    # Both uploads are required before anything is loaded.
    if cache_file and doc_file:
        loaded_cache, loaded_origin_len, success = load_cache(cache_file)

        if success:
            st.sidebar.success("Cache loaded successfully!")

            # Get document text. The parser is chosen from the file extension;
            # the comparison is case-insensitive so "REPORT.PDF" is treated
            # as a PDF rather than being sent to the plain-text path.
            is_pdf = doc_file.name.lower().endswith(".pdf")
            doc_text = get_document_text(
                doc_file, "PDF (.pdf)" if is_pdf else "Text (.txt)"
            )

            # Show that we're ready to use the loaded cache
            st.sidebar.info("Using pre-loaded cache and document")
            cache = loaded_cache
            origin_len = loaded_origin_len

            # Display document preview. Guarded because the main flow treats
            # the result of get_document_text() as possibly empty/None, and
            # len(None) would raise a TypeError here.
            if doc_text:
                with st.expander("Document Preview (Loaded)"):
                    st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
        else:
            st.sidebar.error("Failed to load cache file")