Spaces:
Build error
Build error
| import streamlit as st | |
| import os | |
| import tempfile | |
| from typing import List, Optional | |
| import pickle | |
| # Core libraries | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline | |
| from langchain.llms import HuggingFacePipeline | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.schema import Document | |
| from langchain import PromptTemplate | |
| from langchain.chains import RetrievalQA | |
| from langchain.vectorstores import FAISS | |
| # Document loaders | |
| from langchain.document_loaders import PyPDFLoader | |
# Configure Streamlit page — must run before any other st.* call.
# NOTE(review): page_icon "π" looks like a mojibake'd emoji from a bad
# encoding pass — confirm the intended icon before shipping.
st.set_page_config(
    page_title="PDF RAG System",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS for better styling.
# Class names here (.main-header, .sidebar-header, .source-box) are referenced
# by the raw-HTML st.markdown calls in main(); keep them in sync.
st.markdown("""
<style>
.main-header {
    font-size: 2.5rem;
    color: #1f77b4;
    text-align: center;
    margin-bottom: 2rem;
}
.sidebar-header {
    font-size: 1.5rem;
    color: #ff7f0e;
    margin-bottom: 1rem;
}
.success-message {
    padding: 1rem;
    background-color: #d4edda;
    border: 1px solid #c3e6cb;
    border-radius: 0.5rem;
    color: #155724;
    margin: 1rem 0;
}
.error-message {
    padding: 1rem;
    background-color: #f8d7da;
    border: 1px solid #f5c6cb;
    border-radius: 0.5rem;
    color: #721c24;
    margin: 1rem 0;
}
.source-box {
    background-color: #f8f9fa;
    border-left: 4px solid #007bff;
    padding: 1rem;
    margin: 0.5rem 0;
    border-radius: 0 0.5rem 0.5rem 0;
}
</style>
""", unsafe_allow_html=True)
# Initialize session state: seed each key only on the first run of the
# session so reruns don't clobber existing values.
_SESSION_DEFAULTS = {
    'qa_chain': None,
    'vectorstore': None,
    'documents_processed': False,
    'chat_history': [],
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
@st.cache_resource(show_spinner=False)
def _load_seq2seq_pipeline(model_name):
    """Download/load the seq2seq model and build a transformers pipeline.

    Cached with st.cache_resource so the expensive load happens once per
    model name instead of on every Streamlit rerun / button click (the
    original reloaded the model each time setup_llm was called). Raises on
    failure, so a broken load is never memoized.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=300,
        temperature=0.3,
        do_sample=True,
        device=-1,  # -1 = CPU
    )


def setup_llm(model_name="google/flan-t5-small"):
    """Setup the language model for text generation.

    Args:
        model_name: HuggingFace model id of a seq2seq (text2text) model.

    Returns:
        A LangChain HuggingFacePipeline wrapper, or None on failure
        (the error is surfaced in the UI rather than raised).
    """
    with st.spinner("π€ Loading language model..."):
        try:
            pipe = _load_seq2seq_pipeline(model_name)
            return HuggingFacePipeline(pipeline=pipe)
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None
@st.cache_resource(show_spinner=False)
def _load_embeddings(model_name):
    """Cached embedding-model load.

    st.cache_resource keeps one instance per model name across Streamlit
    reruns (the original re-instantiated the model on every call). Raises
    on failure so a broken load is not memoized.
    """
    return HuggingFaceEmbeddings(model_name=model_name)


def setup_embeddings(model_name="all-MiniLM-L6-v2"):
    """Setup the embedding model for vector generation.

    Args:
        model_name: sentence-transformers model id.

    Returns:
        HuggingFaceEmbeddings instance, or None on failure (error shown
        in the UI rather than raised).
    """
    with st.spinner("π’ Loading embedding model..."):
        try:
            return _load_embeddings(model_name)
        except Exception as e:
            st.error(f"Error loading embeddings: {e}")
            return None
def process_uploaded_files(uploaded_files, embeddings):
    """Process uploaded PDF files and create a FAISS vector store.

    Args:
        uploaded_files: list of Streamlit UploadedFile objects (PDFs).
        embeddings: embedding model used to vectorize the text chunks.

    Returns:
        (vectorstore, text_chunks) on success; (None, []) when there is
        nothing to process or vector-store creation fails.
    """
    if not uploaded_files:
        return None, []
    documents = []
    # Process each uploaded file
    for uploaded_file in uploaded_files:
        tmp_file_path = None
        try:
            # PyPDFLoader needs a real filesystem path, so spill the
            # in-memory upload to a temporary file first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.read())
                tmp_file_path = tmp_file.name
            loader = PyPDFLoader(tmp_file_path)
            docs = loader.load()
            # Record which upload each page came from, for source display.
            for doc in docs:
                doc.metadata['source_file'] = uploaded_file.name
            documents.extend(docs)
            st.success(f"β Processed: {uploaded_file.name} ({len(docs)} pages)")
        except Exception as e:
            st.error(f"β Error processing {uploaded_file.name}: {e}")
        finally:
            # BUGFIX: always remove the temp file. The original unlinked it
            # only on the success path, leaking a temp PDF whenever
            # loader.load() raised.
            if tmp_file_path is not None and os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
    if not documents:
        return None, []
    # Split documents into overlapping chunks suitable for retrieval.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        separators=["\n\n", "\n", " ", ""]
    )
    text_chunks = text_splitter.split_documents(documents)
    # Tag each chunk with an id and its size for later tracing.
    for i, text in enumerate(text_chunks):
        text.metadata.update({
            "chunk_id": i,
            "chunk_size": len(text.page_content)
        })
    st.info(f"βοΈ Created {len(text_chunks)} text chunks")
    # Create FAISS vector store
    try:
        vectorstore = FAISS.from_documents(text_chunks, embeddings)
        st.success(f"β Successfully created vector database with {len(text_chunks)} chunks!")
        return vectorstore, text_chunks
    except Exception as e:
        st.error(f"β Error creating vector database: {e}")
        return None, []
def create_qa_chain(llm, vectorstore, k=5):
    """Create a question-answering chain with retrieval.

    Wires a top-k similarity retriever and a fixed prompt into a
    "stuff"-type RetrievalQA chain that also returns its source documents.

    Args:
        llm: LangChain LLM used to generate answers.
        vectorstore: FAISS store to retrieve context chunks from.
        k: number of chunks retrieved per question.

    Returns:
        A RetrievalQA chain, or None when a dependency is missing or
        construction fails (the error is shown in the UI).
    """
    if not vectorstore or not llm:
        return None
    answer_prompt = PromptTemplate(
        template=(
            'Use the following context to answer the question. If you cannot find the answer in the context, say "I cannot find this information in the provided documents."\n'
            "Context: {context}\n"
            "Question: {question}\n"
            "Answer:"
        ),
        input_variables=["context", "question"]
    )
    try:
        retriever = vectorstore.as_retriever(search_kwargs={"k": k})
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": answer_prompt},
            return_source_documents=True
        )
    except Exception as err:
        st.error(f"Error creating QA chain: {err}")
        return None
def ask_question(qa_chain, question):
    """Ask a question and get an answer with sources.

    Args:
        qa_chain: RetrievalQA chain built by create_qa_chain.
        question: user's question string.

    Returns:
        Dict with "question", "answer", and "source_documents" keys, or
        None when the chain is missing or the query fails.
    """
    if not qa_chain:
        return None
    try:
        # Keep the whole call-and-unpack inside the try so a malformed
        # result (missing keys) is handled like any other failure.
        raw = qa_chain({"query": question})
        return {
            "question": question,
            "answer": raw["result"],
            "source_documents": raw.get("source_documents", []),
        }
    except Exception as err:
        st.error(f"β Error processing question: {err}")
        return None
def search_similar_chunks(vectorstore, query, k=5):
    """Search for similar chunks without generating an answer.

    Args:
        vectorstore: FAISS store to search; may be None.
        query: text to match against stored chunks.
        k: maximum number of matches to return.

    Returns:
        List of matching Documents (empty on missing store or error).
    """
    if not vectorstore:
        return []
    try:
        return vectorstore.similarity_search(query, k=k)
    except Exception as err:
        st.error(f"Error searching: {err}")
        return []
# Main App Interface
def main():
    """Render the Streamlit UI.

    Layout: sidebar (model/retrieval configuration, vector-store save),
    left column (upload + processing), right column (Q&A and similarity
    search), followed by chat history and a session-reset button. All
    cross-rerun state lives in st.session_state.
    """
    st.markdown('<h1 class="main-header">π PDF RAG System</h1>', unsafe_allow_html=True)
    st.markdown("Upload PDF documents and ask questions about their content using AI-powered retrieval!")
    # Sidebar for configuration
    with st.sidebar:
        st.markdown('<h2 class="sidebar-header">βοΈ Configuration</h2>', unsafe_allow_html=True)
        # Model configuration
        st.subheader("π€ Model Settings")
        llm_model = st.selectbox(
            "Language Model",
            ["google/flan-t5-small", "google/flan-t5-base"],
            help="Choose the language model (smaller models are faster)"
        )
        embedding_model = st.selectbox(
            "Embedding Model",
            ["all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"],
            help="Choose the embedding model"
        )
        retrieval_k = st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=10,
            value=5,
            help="How many relevant chunks to use for answering questions"
        )
        st.subheader("πΎ Vector Store")
        st.info("Using FAISS (local vector storage)")
        # Option to save/load vector store
        if st.session_state.vectorstore:
            if st.button("πΎ Save Vector Store"):
                try:
                    # Persist the FAISS index to the local "faiss_index" dir.
                    st.session_state.vectorstore.save_local("faiss_index")
                    st.success("Vector store saved!")
                except Exception as e:
                    st.error(f"Error saving: {e}")
    # Main content area: two equal-width columns.
    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("π Upload Documents")
        uploaded_files = st.file_uploader(
            "Choose PDF files",
            type=['pdf'],
            accept_multiple_files=True,
            help="Upload one or more PDF files to analyze"
        )
        if st.button("π Process Documents", type="primary"):
            if not uploaded_files:
                st.warning("Please upload at least one PDF file.")
            else:
                with st.spinner("Processing documents..."):
                    # Setup models
                    llm = setup_llm(llm_model)
                    embeddings = setup_embeddings(embedding_model)
                    if llm and embeddings:
                        # Process files into a FAISS index
                        vectorstore, text_chunks = process_uploaded_files(uploaded_files, embeddings)
                        if vectorstore:
                            # Create QA chain
                            qa_chain = create_qa_chain(llm, vectorstore, k=retrieval_k)
                            if qa_chain:
                                # Store in session state so later reruns can reuse it
                                st.session_state.qa_chain = qa_chain
                                st.session_state.vectorstore = vectorstore
                                st.session_state.documents_processed = True
                                st.balloons()
                                st.success("π Documents processed successfully! You can now ask questions.")
                            else:
                                st.error("Failed to create QA chain.")
                    else:
                        st.error("Failed to load models.")
    with col2:
        st.subheader("π¬ Ask Questions")
        if st.session_state.documents_processed:
            question = st.text_input(
                "Your question:",
                placeholder="What are the main topics discussed in the documents?",
                help="Ask any question about your uploaded documents"
            )
            col2a, col2b = st.columns([1, 1])
            with col2a:
                if st.button("π Get Answer"):
                    if question:
                        with st.spinner("Searching for answer..."):
                            result = ask_question(st.session_state.qa_chain, question)
                            if result:
                                # Add to chat history
                                st.session_state.chat_history.append({
                                    "question": question,
                                    "answer": result["answer"],
                                    "sources": result["source_documents"]
                                })
                                # Display answer
                                st.subheader("π‘ Answer:")
                                st.write(result["answer"])
                                # Display sources (top 3 retrieved chunks, truncated)
                                if result["source_documents"]:
                                    st.subheader("π Sources:")
                                    for i, doc in enumerate(result["source_documents"][:3]):
                                        with st.expander(f"Source {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
                                            st.write(doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content)
                    else:
                        st.warning("Please enter a question.")
            with col2b:
                if st.button("π Search Similar"):
                    if question:
                        with st.spinner("Searching for similar content..."):
                            # Retrieval only — no LLM answer generation.
                            results = search_similar_chunks(st.session_state.vectorstore, question, k=5)
                            if results:
                                st.subheader("π Similar Content:")
                                for i, doc in enumerate(results):
                                    with st.expander(f"Match {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
                                        st.write(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
        else:
            st.info("π Please upload and process documents first to start asking questions.")
    # Chat History
    if st.session_state.chat_history:
        st.subheader("π Chat History")
        for i, chat in enumerate(reversed(st.session_state.chat_history[-5:])):  # Show last 5, newest first
            with st.expander(f"Q: {chat['question'][:50]}..."):
                st.write("**Question:**", chat['question'])
                st.write("**Answer:**", chat['answer'])
                if chat['sources']:
                    st.write("**Sources:**")
                    for j, doc in enumerate(chat['sources'][:2]):  # Show top 2 sources
                        st.write(f"{j+1}. {doc.metadata.get('source_file', 'Unknown')}")
    # Clear session button — resets all state and forces a rerun.
    if st.session_state.documents_processed:
        if st.button("ποΈ Clear Session"):
            st.session_state.qa_chain = None
            st.session_state.vectorstore = None
            st.session_state.documents_processed = False
            st.session_state.chat_history = []
            st.success("Session cleared! You can upload new documents.")
            st.rerun()


if __name__ == "__main__":
    main()