"""Streamlit PDF RAG system.

Upload PDF documents, index them into a local FAISS vector store using
sentence-transformer embeddings, and answer questions about their content
with a Flan-T5 model through LangChain's RetrievalQA chain.
"""

import os
import pickle  # NOTE(review): currently unused; presumably for vector-store persistence — confirm
import tempfile
from typing import List, Optional  # NOTE(review): currently unused

import streamlit as st

# Core libraries
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document  # NOTE(review): currently unused
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS

# Document loaders
from langchain.document_loaders import PyPDFLoader

# Configure Streamlit page
st.set_page_config(
    page_title="PDF RAG System",
    page_icon="📚",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for better styling.
# NOTE(review): the CSS payload is empty in this file; add <style>...</style>
# rules between the triple quotes if custom styling is desired.
st.markdown("""
""", unsafe_allow_html=True)

# Initialize session state — one entry per piece of app-wide state so that
# reruns triggered by widget interaction don't lose the loaded models/index.
_SESSION_DEFAULTS = {
    "qa_chain": None,            # RetrievalQA chain, built after processing
    "vectorstore": None,         # FAISS index over the uploaded documents
    "documents_processed": False,  # gates the question UI
    "chat_history": [],          # list of {question, answer, sources} dicts
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


@st.cache_resource
def setup_llm(model_name="google/flan-t5-small"):
    """Load a seq2seq language model wrapped for LangChain.

    Args:
        model_name: Hugging Face model id of a text2text-generation model.

    Returns:
        A ``HuggingFacePipeline`` LLM, or ``None`` if loading failed
        (the error is surfaced in the UI via ``st.error``).
    """
    with st.spinner("🤖 Loading language model..."):
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            pipe = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=300,
                temperature=0.3,
                do_sample=True,
                device=-1,  # -1 = CPU
            )
            return HuggingFacePipeline(pipeline=pipe)
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None


@st.cache_resource
def setup_embeddings(model_name="all-MiniLM-L6-v2"):
    """Load the sentence-embedding model used to vectorize text chunks.

    Args:
        model_name: Hugging Face / sentence-transformers model id.

    Returns:
        A ``HuggingFaceEmbeddings`` instance, or ``None`` on failure.
    """
    with st.spinner("🔢 Loading embedding model..."):
        try:
            return HuggingFaceEmbeddings(model_name=model_name)
        except Exception as e:
            st.error(f"Error loading embeddings: {e}")
            return None


def process_uploaded_files(uploaded_files, embeddings):
    """Load, split, and index uploaded PDFs into a FAISS vector store.

    Args:
        uploaded_files: Streamlit ``UploadedFile`` objects (PDFs).
        embeddings: Embedding model used to vectorize the text chunks.

    Returns:
        ``(vectorstore, text_chunks)`` on success, ``(None, [])`` when no
        files were given or every step failed. Per-file errors are reported
        in the UI and the remaining files are still processed.
    """
    if not uploaded_files:
        return None, []

    documents = []
    for uploaded_file in uploaded_files:
        try:
            # PyPDFLoader needs a real filesystem path, so spill the upload
            # into a temporary file first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.read())
                tmp_file_path = tmp_file.name
            try:
                docs = PyPDFLoader(tmp_file_path).load()
            finally:
                # BUGFIX: always remove the temp file, even if loading raises;
                # previously an exception in load() leaked the file.
                os.unlink(tmp_file_path)

            # Tag every page with the originating file name so answers can
            # cite their source later.
            for doc in docs:
                doc.metadata['source_file'] = uploaded_file.name
            documents.extend(docs)

            st.success(f"✅ Processed: {uploaded_file.name} ({len(docs)} pages)")
        except Exception as e:
            st.error(f"❌ Error processing {uploaded_file.name}: {e}")

    if not documents:
        return None, []

    # Split documents into overlapping chunks sized for retrieval.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    text_chunks = text_splitter.split_documents(documents)

    # Record per-chunk bookkeeping metadata.
    for i, chunk in enumerate(text_chunks):
        chunk.metadata.update({
            "chunk_id": i,
            "chunk_size": len(chunk.page_content),
        })

    st.info(f"✂️ Created {len(text_chunks)} text chunks")

    # Create FAISS vector store from the embedded chunks.
    try:
        vectorstore = FAISS.from_documents(text_chunks, embeddings)
        st.success(f"✅ Successfully created vector database with {len(text_chunks)} chunks!")
        return vectorstore, text_chunks
    except Exception as e:
        st.error(f"❌ Error creating vector database: {e}")
        return None, []


def create_qa_chain(llm, vectorstore, k=5):
    """Create a question-answering chain with retrieval.

    Args:
        llm: LangChain LLM used for answer generation.
        vectorstore: FAISS store to retrieve context chunks from.
        k: Number of chunks retrieved per question.

    Returns:
        A ``RetrievalQA`` chain, or ``None`` if inputs are missing or
        construction fails.
    """
    if not vectorstore or not llm:
        return None

    # Instruct the model to refuse rather than hallucinate when the
    # retrieved context lacks the answer.
    prompt_template = """Use the following context to answer the question. If you cannot find the answer in the context, say "I cannot find this information in the provided documents."

Context: {context}

Question: {question}

Answer:"""

    PROMPT = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )

    try:
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # stuff all retrieved chunks into one prompt
            retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
            chain_type_kwargs={"prompt": PROMPT},
            return_source_documents=True,
        )
    except Exception as e:
        st.error(f"Error creating QA chain: {e}")
        return None


def ask_question(qa_chain, question):
    """Ask a question and get an answer with source documents.

    Args:
        qa_chain: The ``RetrievalQA`` chain to query.
        question: The user's question text.

    Returns:
        A dict with ``question``, ``answer`` and ``source_documents`` keys,
        or ``None`` if the chain is missing or the query fails.
    """
    if not qa_chain:
        return None
    try:
        result = qa_chain({"query": question})
        return {
            "question": question,
            "answer": result["result"],
            "source_documents": result.get("source_documents", []),
        }
    except Exception as e:
        st.error(f"❌ Error processing question: {e}")
        return None


def search_similar_chunks(vectorstore, query, k=5):
    """Search for similar chunks without generating an answer.

    Args:
        vectorstore: FAISS store to search.
        query: Free-text query.
        k: Number of matches to return.

    Returns:
        A list of matching ``Document`` objects (empty on error/no store).
    """
    if not vectorstore:
        return []
    try:
        return vectorstore.similarity_search(query, k=k)
    except Exception as e:
        st.error(f"Error searching: {e}")
        return []


def _render_sidebar():
    """Render the configuration sidebar.

    Returns:
        ``(llm_model, embedding_model, retrieval_k)`` as chosen by the user.
    """
    with st.sidebar:
        # NOTE(review): the original header HTML was stripped from this file.
        st.markdown('', unsafe_allow_html=True)

        st.subheader("🤖 Model Settings")
        llm_model = st.selectbox(
            "Language Model",
            ["google/flan-t5-small", "google/flan-t5-base"],
            help="Choose the language model (smaller models are faster)",
        )
        embedding_model = st.selectbox(
            "Embedding Model",
            ["all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"],
            help="Choose the embedding model",
        )
        retrieval_k = st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=10,
            value=5,
            help="How many relevant chunks to use for answering questions",
        )

        st.subheader("💾 Vector Store")
        st.info("Using FAISS (local vector storage)")

        # Option to save the current index to disk for later reuse.
        if st.session_state.vectorstore:
            if st.button("💾 Save Vector Store"):
                try:
                    st.session_state.vectorstore.save_local("faiss_index")
                    st.success("Vector store saved!")
                except Exception as e:
                    st.error(f"Error saving: {e}")

    return llm_model, embedding_model, retrieval_k


def _render_upload_column(llm_model, embedding_model, retrieval_k):
    """Render the document-upload column and wire up processing."""
    st.subheader("📁 Upload Documents")
    uploaded_files = st.file_uploader(
        "Choose PDF files",
        type=['pdf'],
        accept_multiple_files=True,
        help="Upload one or more PDF files to analyze",
    )

    if st.button("🚀 Process Documents", type="primary"):
        if not uploaded_files:
            st.warning("Please upload at least one PDF file.")
            return
        with st.spinner("Processing documents..."):
            llm = setup_llm(llm_model)
            embeddings = setup_embeddings(embedding_model)
            if not (llm and embeddings):
                st.error("Failed to load models.")
                return
            vectorstore, text_chunks = process_uploaded_files(uploaded_files, embeddings)
            if vectorstore:
                qa_chain = create_qa_chain(llm, vectorstore, k=retrieval_k)
                if qa_chain:
                    # Persist across reruns so questions can be asked later.
                    st.session_state.qa_chain = qa_chain
                    st.session_state.vectorstore = vectorstore
                    st.session_state.documents_processed = True
                    st.balloons()
                    st.success("🎉 Documents processed successfully! You can now ask questions.")
                else:
                    st.error("Failed to create QA chain.")


def _render_question_column(retrieval_k):
    """Render the question column: answer generation and similarity search."""
    st.subheader("💬 Ask Questions")

    if not st.session_state.documents_processed:
        st.info("👆 Please upload and process documents first to start asking questions.")
        return

    question = st.text_input(
        "Your question:",
        placeholder="What are the main topics discussed in the documents?",
        help="Ask any question about your uploaded documents",
    )

    col2a, col2b = st.columns([1, 1])

    with col2a:
        if st.button("🔍 Get Answer"):
            if not question:
                st.warning("Please enter a question.")
            else:
                with st.spinner("Searching for answer..."):
                    result = ask_question(st.session_state.qa_chain, question)
                    if result:
                        st.session_state.chat_history.append({
                            "question": question,
                            "answer": result["answer"],
                            "sources": result["source_documents"],
                        })
                        st.subheader("💡 Answer:")
                        st.write(result["answer"])
                        if result["source_documents"]:
                            st.subheader("📚 Sources:")
                            for i, doc in enumerate(result["source_documents"][:3]):
                                label = f"Source {i+1}: {doc.metadata.get('source_file', 'Unknown')}"
                                with st.expander(label):
                                    text = doc.page_content
                                    st.write(text[:500] + "..." if len(text) > 500 else text)

    with col2b:
        if st.button("🔍 Search Similar"):
            if question:
                with st.spinner("Searching for similar content..."):
                    # BUGFIX: honor the sidebar's retrieval_k instead of a
                    # hard-coded k=5, consistent with the QA chain.
                    results = search_similar_chunks(
                        st.session_state.vectorstore, question, k=retrieval_k
                    )
                    if results:
                        st.subheader("🔍 Similar Content:")
                        for i, doc in enumerate(results):
                            label = f"Match {i+1}: {doc.metadata.get('source_file', 'Unknown')}"
                            with st.expander(label):
                                text = doc.page_content
                                st.write(text[:300] + "..." if len(text) > 300 else text)


def _render_chat_history():
    """Render up to the last five Q&A exchanges, newest first."""
    if not st.session_state.chat_history:
        return
    st.subheader("📝 Chat History")
    for chat in reversed(st.session_state.chat_history[-5:]):  # Show last 5
        with st.expander(f"Q: {chat['question'][:50]}..."):
            st.write("**Question:**", chat['question'])
            st.write("**Answer:**", chat['answer'])
            if chat['sources']:
                st.write("**Sources:**")
                for j, doc in enumerate(chat['sources'][:2]):  # Show top 2 sources
                    st.write(f"{j+1}. {doc.metadata.get('source_file', 'Unknown')}")


def _render_clear_button():
    """Offer a button that wipes all session state for a fresh start."""
    if st.session_state.documents_processed:
        if st.button("🗑️ Clear Session"):
            st.session_state.qa_chain = None
            st.session_state.vectorstore = None
            st.session_state.documents_processed = False
            st.session_state.chat_history = []
            st.success("Session cleared! You can upload new documents.")
            st.rerun()


def main():
    """Entry point: render the full Streamlit page."""
    # NOTE(review): the header HTML was garbled in this file; reconstructed
    # as a single div around the visible "📚 PDF RAG System" title.
    st.markdown(
        '<div class="main-header"><h1>📚 PDF RAG System</h1></div>',
        unsafe_allow_html=True,
    )
    st.markdown("Upload PDF documents and ask questions about their content using AI-powered retrieval!")

    llm_model, embedding_model, retrieval_k = _render_sidebar()

    # Main content area: upload on the left, Q&A on the right.
    col1, col2 = st.columns([1, 1])
    with col1:
        _render_upload_column(llm_model, embedding_model, retrieval_k)
    with col2:
        _render_question_column(retrieval_k)

    _render_chat_history()
    _render_clear_button()


if __name__ == "__main__":
    main()