# Source: Hugging Face Space "app.py" by rbbist (commit 2c0aa0c, verified)
import streamlit as st
import os
import tempfile
from typing import List, Optional
import pickle
# Core libraries
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
# Document loaders
from langchain.document_loaders import PyPDFLoader
# Configure Streamlit page
st.set_page_config(
page_title="PDF RAG System",
page_icon="πŸ“š",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS for better styling
st.markdown("""
<style>
.main-header {
font-size: 2.5rem;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.sidebar-header {
font-size: 1.5rem;
color: #ff7f0e;
margin-bottom: 1rem;
}
.success-message {
padding: 1rem;
background-color: #d4edda;
border: 1px solid #c3e6cb;
border-radius: 0.5rem;
color: #155724;
margin: 1rem 0;
}
.error-message {
padding: 1rem;
background-color: #f8d7da;
border: 1px solid #f5c6cb;
border-radius: 0.5rem;
color: #721c24;
margin: 1rem 0;
}
.source-box {
background-color: #f8f9fa;
border-left: 4px solid #007bff;
padding: 1rem;
margin: 0.5rem 0;
border-radius: 0 0.5rem 0.5rem 0;
}
</style>
""", unsafe_allow_html=True)
# Initialize session state: seed each key only on first run so values
# survive Streamlit's script reruns.
_SESSION_DEFAULTS = {
    'qa_chain': None,
    'vectorstore': None,
    'documents_processed': False,
    'chat_history': [],
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
@st.cache_resource
def setup_llm(model_name="google/flan-t5-small"):
    """Build and cache a LangChain LLM backed by a local HF seq2seq pipeline.

    Args:
        model_name: Hugging Face model id of a text2text (seq2seq) model.

    Returns:
        A HuggingFacePipeline wrapper, or None when loading fails (the
        error is surfaced in the Streamlit UI).
    """
    with st.spinner("🤖 Loading language model..."):
        try:
            tok = AutoTokenizer.from_pretrained(model_name)
            seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            generation_pipe = pipeline(
                "text2text-generation",
                model=seq2seq,
                tokenizer=tok,
                max_new_tokens=300,
                temperature=0.3,
                do_sample=True,
                device=-1,  # CPU
            )
            return HuggingFacePipeline(pipeline=generation_pipe)
        except Exception as exc:
            st.error(f"Error loading model: {exc}")
            return None
@st.cache_resource
def setup_embeddings(model_name="all-MiniLM-L6-v2"):
    """Build and cache the sentence-embedding model used for vectorization.

    Returns the HuggingFaceEmbeddings instance, or None on failure
    (the error is shown in the Streamlit UI).
    """
    with st.spinner("🔢 Loading embedding model..."):
        try:
            return HuggingFaceEmbeddings(model_name=model_name)
        except Exception as exc:
            st.error(f"Error loading embeddings: {exc}")
            return None
def process_uploaded_files(uploaded_files, embeddings):
    """Load uploaded PDFs, split them into chunks, and build a FAISS index.

    Args:
        uploaded_files: list of Streamlit UploadedFile objects (PDFs).
        embeddings: embedding model used to vectorize the chunks.

    Returns:
        (vectorstore, text_chunks) on success, or (None, []) when no file
        could be processed or index creation failed.
    """
    if not uploaded_files:
        return None, []
    documents = []
    for uploaded_file in uploaded_files:
        try:
            # Spool the upload to a real file: PyPDFLoader needs a filesystem path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.read())
                tmp_file_path = tmp_file.name
            try:
                docs = PyPDFLoader(tmp_file_path).load()
            finally:
                # BUGFIX: previously the temp file was only removed on success,
                # so every failed PDF leaked a file in the temp directory.
                os.unlink(tmp_file_path)
            # Tag every page with its originating file name for source display.
            for doc in docs:
                doc.metadata['source_file'] = uploaded_file.name
            documents.extend(docs)
            st.success(f"✅ Processed: {uploaded_file.name} ({len(docs)} pages)")
        except Exception as e:
            st.error(f"❌ Error processing {uploaded_file.name}: {e}")
    if not documents:
        return None, []
    # 1000-char chunks with 200-char overlap keep context across boundaries.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        separators=["\n\n", "\n", " ", ""]
    )
    text_chunks = text_splitter.split_documents(documents)
    # Record chunk ids/sizes so the UI can reference individual chunks.
    for i, text in enumerate(text_chunks):
        text.metadata.update({
            "chunk_id": i,
            "chunk_size": len(text.page_content)
        })
    st.info(f"✂️ Created {len(text_chunks)} text chunks")
    try:
        vectorstore = FAISS.from_documents(text_chunks, embeddings)
        st.success(f"✅ Successfully created vector database with {len(text_chunks)} chunks!")
        return vectorstore, text_chunks
    except Exception as e:
        st.error(f"❌ Error creating vector database: {e}")
        return None, []
def create_qa_chain(llm, vectorstore, k=5):
    """Wire the LLM and a retriever into a 'stuff' RetrievalQA chain.

    Args:
        llm: the language model wrapper from setup_llm().
        vectorstore: FAISS store from process_uploaded_files().
        k: how many retrieved chunks are stuffed into the prompt.

    Returns:
        The RetrievalQA chain, or None when inputs are missing or setup fails.
    """
    if not (vectorstore and llm):
        return None
    # Instruct the model to refuse rather than invent an answer.
    prompt_text = """Use the following context to answer the question. If you cannot find the answer in the context, say "I cannot find this information in the provided documents."
Context: {context}
Question: {question}
Answer:"""
    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["context", "question"]
    )
    try:
        retriever = vectorstore.as_retriever(search_kwargs={"k": k})
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": qa_prompt},
            return_source_documents=True
        )
    except Exception as exc:
        st.error(f"Error creating QA chain: {exc}")
        return None
def ask_question(qa_chain, question):
    """Run the QA chain on one question.

    Returns a dict with 'question', 'answer' and 'source_documents',
    or None when the chain is missing or the query fails.
    """
    if not qa_chain:
        return None
    try:
        outcome = qa_chain({"query": question})
        return {
            "question": question,
            "answer": outcome["result"],
            "source_documents": outcome.get("source_documents", []),
        }
    except Exception as exc:
        st.error(f"❌ Error processing question: {exc}")
        return None
def search_similar_chunks(vectorstore, query, k=5):
    """Return up to k chunks most similar to the query (no LLM involved)."""
    if not vectorstore:
        return []
    try:
        return vectorstore.similarity_search(query, k=k)
    except Exception as exc:
        st.error(f"Error searching: {exc}")
        return []
# Main App Interface
def main():
    """Render the Streamlit UI: sidebar config, upload/QA columns, history.

    All cross-rerun state (QA chain, vector store, chat history) lives in
    st.session_state, which is seeded at module import time.
    """
    st.markdown('<h1 class="main-header">📚 PDF RAG System</h1>', unsafe_allow_html=True)
    st.markdown("Upload PDF documents and ask questions about their content using AI-powered retrieval!")

    # --- Sidebar: model and retrieval configuration ---
    with st.sidebar:
        st.markdown('<h2 class="sidebar-header">⚙️ Configuration</h2>', unsafe_allow_html=True)
        st.subheader("🤖 Model Settings")
        llm_model = st.selectbox(
            "Language Model",
            ["google/flan-t5-small", "google/flan-t5-base"],
            help="Choose the language model (smaller models are faster)"
        )
        embedding_model = st.selectbox(
            "Embedding Model",
            ["all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"],
            help="Choose the embedding model"
        )
        retrieval_k = st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=10,
            value=5,
            help="How many relevant chunks to use for answering questions"
        )
        st.subheader("💾 Vector Store")
        st.info("Using FAISS (local vector storage)")
        # Offer persistence only once an index exists.
        if st.session_state.vectorstore:
            if st.button("💾 Save Vector Store"):
                try:
                    st.session_state.vectorstore.save_local("faiss_index")
                    st.success("Vector store saved!")
                except Exception as e:
                    st.error(f"Error saving: {e}")

    # --- Main area: two columns (upload | ask) ---
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("📁 Upload Documents")
        uploaded_files = st.file_uploader(
            "Choose PDF files",
            type=['pdf'],
            accept_multiple_files=True,
            help="Upload one or more PDF files to analyze"
        )
        if st.button("🚀 Process Documents", type="primary"):
            if not uploaded_files:
                st.warning("Please upload at least one PDF file.")
            else:
                with st.spinner("Processing documents..."):
                    llm = setup_llm(llm_model)
                    embeddings = setup_embeddings(embedding_model)
                    if llm and embeddings:
                        vectorstore, text_chunks = process_uploaded_files(uploaded_files, embeddings)
                        if vectorstore:
                            qa_chain = create_qa_chain(llm, vectorstore, k=retrieval_k)
                            if qa_chain:
                                # Persist across reruns so later questions work.
                                st.session_state.qa_chain = qa_chain
                                st.session_state.vectorstore = vectorstore
                                st.session_state.documents_processed = True
                                st.balloons()
                                st.success("🎉 Documents processed successfully! You can now ask questions.")
                            else:
                                st.error("Failed to create QA chain.")
                    else:
                        st.error("Failed to load models.")

    with col2:
        st.subheader("💬 Ask Questions")
        if st.session_state.documents_processed:
            question = st.text_input(
                "Your question:",
                placeholder="What are the main topics discussed in the documents?",
                help="Ask any question about your uploaded documents"
            )
            col2a, col2b = st.columns([1, 1])
            with col2a:
                if st.button("🔍 Get Answer"):
                    if question:
                        with st.spinner("Searching for answer..."):
                            result = ask_question(st.session_state.qa_chain, question)
                            if result:
                                st.session_state.chat_history.append({
                                    "question": question,
                                    "answer": result["answer"],
                                    "sources": result["source_documents"]
                                })
                                st.subheader("💡 Answer:")
                                st.write(result["answer"])
                                if result["source_documents"]:
                                    st.subheader("📚 Sources:")
                                    # Show at most the top 3 supporting chunks.
                                    for i, doc in enumerate(result["source_documents"][:3]):
                                        with st.expander(f"Source {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
                                            st.write(doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content)
                    else:
                        st.warning("Please enter a question.")
            with col2b:
                if st.button("🔍 Search Similar"):
                    if question:
                        with st.spinner("Searching for similar content..."):
                            # BUGFIX: honor the sidebar slider instead of a
                            # hard-coded k=5, matching the QA-chain behavior.
                            results = search_similar_chunks(st.session_state.vectorstore, question, k=retrieval_k)
                            if results:
                                st.subheader("🔍 Similar Content:")
                                for i, doc in enumerate(results):
                                    with st.expander(f"Match {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
                                        st.write(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
        else:
            st.info("👆 Please upload and process documents first to start asking questions.")

    # --- Chat history: last 5 exchanges, newest first ---
    if st.session_state.chat_history:
        st.subheader("📝 Chat History")
        for i, chat in enumerate(reversed(st.session_state.chat_history[-5:])):  # Show last 5
            with st.expander(f"Q: {chat['question'][:50]}..."):
                st.write("**Question:**", chat['question'])
                st.write("**Answer:**", chat['answer'])
                if chat['sources']:
                    st.write("**Sources:**")
                    for j, doc in enumerate(chat['sources'][:2]):  # Show top 2 sources
                        st.write(f"{j+1}. {doc.metadata.get('source_file', 'Unknown')}")

    # --- Reset everything and rerun the script ---
    if st.session_state.documents_processed:
        if st.button("🗑️ Clear Session"):
            st.session_state.qa_chain = None
            st.session_state.vectorstore = None
            st.session_state.documents_processed = False
            st.session_state.chat_history = []
            st.success("Session cleared! You can upload new documents.")
            st.rerun()


if __name__ == "__main__":
    main()