import os
import tempfile
import streamlit as st
import fitz  # PyMuPDF
from typing import Dict, Any
from langchain_community.llms import HuggingFaceEndpoint
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

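# Assumed dependencies (not pinned in the original): streamlit, pymupdf, langchain,
# langchain-community, chromadb, sentence-transformers, huggingface_hub

# set_page_config must be the first Streamlit call in the script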
st.set_page_config(
    page_title="PDF Q&A Assistant",
    page_icon="📄",
    layout="wide"
)

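# Initialize session state so values survive Streamlit reruns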
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
if "conversation_chain" not in st.session_state:
    st.session_state.conversation_chain = None
if "document_processed" not in st.session_state:
    st.session_state.document_processed = False
if "file_names" not in st.session_state:
    st.session_state.file_names = []


class PDFQAAssistant:
    def __init__(self,
                 hf_token: str = None,
                 model_name: str = "google/flan-t5-base",
                 embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize the PDF Q&A Assistant with Hugging Face models.

        Args:
            hf_token: Hugging Face API token
            model_name: HF model to use for Q&A
            embedding_model_name: HF model to use for embeddings
        """
        self.model_name = model_name
        self.embedding_model_name = embedding_model_name
        self.hf_token = hf_token
        self.conversation_chain = None  # created in process_pdf(); ask() checks for it

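        # Keep the Chroma index in a temp directory so repeat uploads in a session reuse it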
        self.persist_directory = os.path.join(tempfile.gettempdir(), "pdf_qa_vectorstore")

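        # Remote LLM served through the Hugging Face Inference API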
        self.llm = HuggingFaceEndpoint(
            repo_id=model_name,
            huggingfacehub_api_token=hf_token,
            max_new_tokens=512,  # cap on generated tokens
            temperature=0.5
        )

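        # Sentence-transformer embeddings computed locally on CPU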
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={'device': 'cpu'}
        )

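        # Overlapping chunks preserve context that straddles chunk boundaries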
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=150,
            length_function=len
        )

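        # The vector store is built lazily in process_pdf(); memory holds the running chat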
        self.vectorstore = None
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"  # required because the chain also returns source_documents
        )

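        # Make sure the persistence directory exists before Chroma writes to it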
        os.makedirs(self.persist_directory, exist_ok=True)

    def extract_text_from_pdf(self, pdf_file) -> str:
        """
        Extract text from a PDF file using PyMuPDF.

        Args:
            pdf_file: Uploaded PDF file

        Returns:
            Extracted text as a string
        """
        try:
            # Write the upload to a temp file so it can be opened by path
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(pdf_file.getvalue())
                tmp_path = tmp_file.name

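            # Parse the saved copy with PyMuPDF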
            doc = fitz.open(tmp_path)

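            # Concatenate the text of every page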
            text = ""
            for page in doc:
                text += page.get_text()

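            # Close the document and delete the temp file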
            doc.close()
            os.unlink(tmp_path)

            return text

        except Exception as e:
            st.error(f"Error extracting text from PDF: {e}")
            raise

    def process_pdf(self, pdf_file, document_name: str) -> None:
        """
        Process a PDF file and prepare it for question answering.

        Args:
            pdf_file: Uploaded PDF file
            document_name: Name to identify the document
        """
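        # Step 1: pull the raw text out of the PDF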
        with st.status("Extracting text from PDF..."):
            text = self.extract_text_from_pdf(pdf_file)
            st.write(f"Extracted {len(text)} characters")

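        # Step 2: split the text into overlapping chunks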
        with st.status("Splitting document into chunks..."):
            chunks = self.text_splitter.split_text(text)
            st.write(f"Document split into {len(chunks)} chunks")

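        # Step 3: embed the chunks and index them in Chroma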
        with st.status("Creating vector embeddings..."):
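            # Tag each chunk with its source document and position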
            metadatas = [{"source": document_name, "chunk": i} for i in range(len(chunks))]

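            # Create the store for the first document, append chunks for later ones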
            if self.vectorstore is None:
                self.vectorstore = Chroma.from_texts(
                    texts=chunks,
                    embedding=self.embeddings,
                    metadatas=metadatas,
                    persist_directory=self.persist_directory
                )
            else:
                self.vectorstore.add_texts(texts=chunks, metadatas=metadatas)

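            # Older Chroma releases persist only on request; newer ones persist automatically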
            if hasattr(self.vectorstore, 'persist'):
                self.vectorstore.persist()

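        # Step 4: assemble the conversational retrieval chain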
        with st.status("Setting up Q&A system..."):
            retriever = self.vectorstore.as_retriever(
                search_kwargs={"k": 4}
            )

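            # Prompt that grounds answers in the retrieved context and prior turns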
            qa_prompt = PromptTemplate(
                input_variables=["context", "question", "chat_history"],
                template="""
You are an AI assistant specializing in answering questions about documents.
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say you don't know. Don't try to make up an answer.
Always cite the specific source or page number when possible.

Context:
{context}

Chat History:
{chat_history}

Question:
{question}

Answer:
"""
            )

            self.conversation_chain = ConversationalRetrievalChain.from_llm(
                llm=self.llm,
                retriever=retriever,
                memory=self.memory,
                combine_docs_chain_kwargs={"prompt": qa_prompt},
                return_source_documents=True
            )

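            # Share the chain with the rest of the Streamlit app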
            st.session_state.conversation_chain = self.conversation_chain

        st.success(f"Successfully processed {document_name}")
        st.session_state.document_processed = True

    def ask(self, question: str) -> Dict[str, Any]:
        """
        Ask a question about the loaded documents.

        Args:
            question: The question to ask

        Returns:
            Dictionary with the answer and source documents
        """
        if self.conversation_chain is None:
            return {"answer": "Please load a document first before asking questions.", "sources": []}

        try:
            result = self.conversation_chain.invoke({"question": question})

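            # Report each distinct source document once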
            sources = []
            if "source_documents" in result:
                for doc in result["source_documents"]:
                    source = doc.metadata.get("source", "Unknown")
                    chunk = doc.metadata.get("chunk", "Unknown")
                    if source not in [s["source"] for s in sources]:
                        sources.append({"source": source, "chunk": chunk})

            return {
                "answer": result["answer"],
                "sources": sources
            }

        except Exception as e:
            st.error(f"Error processing question: {e}")
            return {"answer": f"Error processing your question: {e}", "sources": []}

    def clear_memory(self) -> None:
        """Clear the conversation memory."""
        self.memory.clear()


def get_document_summary(assistant, document_name):
    """Summarize the loaded document by asking a few canned questions."""
    st.subheader(f"Document Summary: {document_name}")

    with st.status("Generating document summary..."):
        questions = [
            "What is the main topic of this document?",
            "What are the key points from this document?",
            "Could you provide a summary of this document in 3-5 bullet points?"
        ]

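        # Ask each prompt in turn and render the model's answer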
        for question in questions:
            result = assistant.ask(question)
            st.write(f"**{question}**")
            st.write(result["answer"])
            st.divider()


def main():
    st.title("📄 AI-Powered PDF Reader & Q&A Assistant")

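    # Sidebar: credentials, model choices, and document uploads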
    with st.sidebar:
        st.header("Settings")

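        # Resolve the Hugging Face token: app secrets first, then the environment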
        try:
            hf_token = st.secrets["HF_TOKEN"]
            token_source = "Using HF_TOKEN from app secrets"
        except (KeyError, FileNotFoundError):
            # st.secrets raises if no secrets.toml is present
            hf_token = os.environ.get("HF_TOKEN")
            token_source = ("Using HF_TOKEN from environment variables"
                            if hf_token else "No HF_TOKEN found")

        st.info(token_source)

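        # Fall back to manual entry when no token was found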
        use_manual_token = st.checkbox("Enter token manually", value=not hf_token)

        if use_manual_token:
            hf_token = st.text_input("Enter Hugging Face API Token:", type="password")

        st.subheader("Model Settings")
        # Models must support text generation on the HF Inference API;
        # encoder-only models will not work here
        model_name = st.selectbox(
            "Select LLM model:",
            [
                "google/flan-t5-base",
                "google/flan-t5-small",
                "google/flan-t5-large",
                "facebook/bart-large-cnn"
            ],
            index=0
        )

        embedding_model = st.selectbox(
            "Select Embedding model:",
            [
                "sentence-transformers/all-MiniLM-L6-v2",
                "sentence-transformers/paraphrase-MiniLM-L3-v2"
            ],
            index=0
        )

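        # Document upload and processing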
        st.subheader("Upload Documents")
        uploaded_files = st.file_uploader("Upload PDF documents",
                                          type="pdf",
                                          accept_multiple_files=True)

        if uploaded_files:
            process_btn = st.button("Process Documents")
            if process_btn:
                if not hf_token:
                    st.error("Please provide a valid Hugging Face API token.")
                else:
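                    # Build the assistant with the selected models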
                    try:
                        assistant = PDFQAAssistant(
                            hf_token=hf_token,
                            model_name=model_name,
                            embedding_model_name=embedding_model
                        )

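                        # Index each uploaded file once per session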
                        for pdf_file in uploaded_files:
                            file_name = pdf_file.name
                            if file_name not in st.session_state.file_names:
                                st.session_state.file_names.append(file_name)
                                assistant.process_pdf(pdf_file, file_name)

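                        # Keep the assistant around for the chat loop below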
                        st.session_state.assistant = assistant
                    except Exception as e:
                        st.error(f"Error initializing assistant: {e}")
                        st.error("Try selecting a different model or check your token permissions.")

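        # Management actions, shown once a document has been processed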
        if st.session_state.get("document_processed", False):
            st.subheader("Document Management")

            if st.button("Clear Chat History"):
                if "assistant" in st.session_state:
                    st.session_state.assistant.clear_memory()
                st.session_state.chat_history = []
                st.success("Chat history cleared!")

            if st.button("Generate Document Summary"):
                if "assistant" in st.session_state and len(st.session_state.file_names) > 0:
                    get_document_summary(st.session_state.assistant,
                                         st.session_state.file_names[0])

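    # Main panel: onboarding before a document is processed, chat afterwards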
    if not st.session_state.get("document_processed", False):
        st.info("👈 Please upload and process a PDF document to get started.")

        st.header("How It Works")
        col1, col2, col3 = st.columns(3)

        with col1:
            st.subheader("1. Upload PDF")
            st.markdown("Upload any PDF document you want to query.")

        with col2:
            st.subheader("2. Process Document")
            st.markdown("The AI will extract text and create searchable embeddings.")

        with col3:
            st.subheader("3. Ask Questions")
            st.markdown("Ask any question about your document and get accurate answers.")
    else:
        st.header("Ask Questions About Your Documents")

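        # Show which files back the current index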
        st.caption(f"Processed Files: {', '.join(st.session_state.file_names)}")

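        # Replay the stored conversation so the chat survives Streamlit reruns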
        for message in st.session_state.chat_history:
            if message["role"] == "user":
                st.chat_message("user").write(message["content"])
            else:
                st.chat_message("assistant").write(message["content"])
                if message.get("sources"):
                    with st.expander("View Sources"):
                        for source in message["sources"]:
                            st.write(f"- {source['source']} (chunk {source['chunk']})")

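        # Handle a newly submitted question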
        if question := st.chat_input("Ask a question about your documents..."):
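            # Record the user turn before answering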
            st.session_state.chat_history.append({
                "role": "user",
                "content": question
            })

            st.chat_message("user").write(question)

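            # Generate and display the assistant's answer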
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        result = st.session_state.assistant.ask(question)

                        st.write(result["answer"])

                        if result.get("sources"):
                            with st.expander("View Sources"):
                                for source in result["sources"]:
                                    st.write(f"- {source['source']} (chunk {source['chunk']})")

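                        # Persist the assistant turn in the chat history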
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": result["answer"],
                            "sources": result.get("sources", [])
                        })
                    except Exception as e:
                        st.error(f"Error getting response: {e}")
                        st.error("Please try a different question or model.")


if __name__ == "__main__":
    main()
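
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py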