# enterprise-rag-assistant/src/streamlit_app.py
# Improved SimplePDFRAG with better error handling and model optimization
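# Run locally with: streamlit run src/streamlit_app.py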
import streamlit as st
import PyPDF2
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging
import os
import tempfile
import gc
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SimplePDFRAG:
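    """Minimal retrieval-augmented generation (RAG) pipeline over a single PDF.

    Text is extracted with PyPDF2, split into overlapping chunks, embedded
    with sentence-transformers, retrieved by cosine similarity, and answers
    are generated with an IBM Granite causal language model.
    """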
def __init__(self):
self.documents = []
self.embeddings = []
self.embedding_model = None
self.granite_model = None
self.tokenizer = None
self.pdf_name = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def setup_cache_directory(self):
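        """Point the Hugging Face caches at a writable temp directory.

        Useful on hosts (e.g. Hugging Face Spaces) where the default cache
        path may be read-only; returns the directory, or None on failure.
        """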
try:
cache_dir = tempfile.mkdtemp(prefix="model_cache_")
            os.environ['HF_HOME'] = cache_dir
            # TRANSFORMERS_CACHE is deprecated in favor of HF_HOME in newer
            # transformers releases; set both for compatibility
            os.environ['TRANSFORMERS_CACHE'] = cache_dir
            os.environ['SENTENCE_TRANSFORMERS_HOME'] = cache_dir
st.info(f"Using cache directory: {cache_dir}")
st.info(f"Using device: {self.device}")
return cache_dir
except Exception as e:
st.error(f"Error setting up cache directory: {e}")
return None
def load_models(self):
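        """Load the embedding model and the Granite LM; returns True on success."""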
try:
cache_dir = self.setup_cache_directory()
st.info("Loading embedding model...")
self.embedding_model = SentenceTransformer(
'all-MiniLM-L6-v2', cache_folder=cache_dir, device=self.device
)
st.info("Loading IBM Granite model...")
            # Alternative models you could try:
            # model_name = "ibm-granite/granite-3.0-8b-instruct"  # Larger, better answers, needs more memory
            # model_name = "microsoft/DialoGPT-medium"
            # model_name = "google/flan-t5-base"  # Seq2seq: would need AutoModelForSeq2SeqLM instead
            # Note: Granite repo ids on the Hub include the minor version (e.g. 3.0)
            model_name = "ibm-granite/granite-3.0-2b-instruct"
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=cache_dir,
trust_remote_code=True
)
# Optimize model loading based on available resources
model_kwargs = {
"cache_dir": cache_dir,
"trust_remote_code": True,
"low_cpu_mem_usage": True,
}
# Use appropriate dtype based on device
if self.device.type == "cuda":
model_kwargs["torch_dtype"] = torch.float16
else:
model_kwargs["torch_dtype"] = torch.float32
self.granite_model = AutoModelForCausalLM.from_pretrained(
model_name, **model_kwargs
).to(self.device)
# Set pad token if not available
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
st.success("Models loaded successfully!")
return True
except Exception as e:
st.error(f"Error loading models: {e}")
logger.error(f"Model loading error: {e}")
return False
def extract_pdf_text(self, pdf_file):
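        """Extract plain text from every page of the PDF, with per-page progress."""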
try:
pdf_file.seek(0)
pdf_reader = PyPDF2.PdfReader(pdf_file)
text = ""
st.info(f"PDF has {len(pdf_reader.pages)} pages")
progress_bar = st.progress(0)
for page_num, page in enumerate(pdf_reader.pages):
try:
page_text = page.extract_text()
if page_text:
text += page_text + "\n"
st.write(f"βœ… Extracted text from page {page_num + 1}")
else:
st.warning(f"⚠️ No text found on page {page_num + 1}")
except Exception as page_error:
st.error(f"Error extracting page {page_num + 1}: {page_error}")
# Update progress
progress_bar.progress((page_num + 1) / len(pdf_reader.pages))
progress_bar.empty()
if text.strip():
st.success(f"Extracted {len(text)} characters from {len(pdf_reader.pages)} pages")
st.write("πŸ“„ **Text Preview:**")
st.text(text[:500] + "..." if len(text) > 500 else text)
return text
else:
st.error("No text could be extracted from the PDF")
return None
except Exception as e:
st.error(f"Error reading PDF file: {e}")
logger.error(f"PDF extraction error: {e}")
return None
def chunk_text(self, text, chunk_size=400, overlap=50):
"""Improved chunking with overlap for better context preservation"""
if not text or not text.strip():
return []
words = text.split()
chunks = []
        step = max(chunk_size - overlap, 1)  # Guard against overlap >= chunk_size
        for i in range(0, len(words), step):
chunk = " ".join(words[i:i + chunk_size])
if chunk.strip(): # Only add non-empty chunks
chunks.append(chunk)
return chunks
def process_pdf(self, pdf_file, pdf_name):
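        """Extract, chunk, and embed a PDF; stores chunks and embeddings on self."""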
try:
self.pdf_name = pdf_name
st.info("πŸ” Extracting text from PDF...")
text = self.extract_pdf_text(pdf_file)
if not text:
return False
st.info("βœ‚οΈ Splitting text into chunks with overlap...")
chunks = self.chunk_text(text)
if not chunks:
st.error("No valid text chunks created")
return False
st.info(f"πŸ”„ Creating embeddings for {len(chunks)} chunks...")
# Create embeddings in batches to manage memory
batch_size = 32
embeddings = []
progress_bar = st.progress(0)
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
batch_embeddings = self.embedding_model.encode(batch, show_progress_bar=False)
embeddings.extend(batch_embeddings)
progress_bar.progress(min(i + batch_size, len(chunks)) / len(chunks))
progress_bar.empty()
self.documents = chunks
self.embeddings = np.array(embeddings)
st.success(f"βœ… Successfully processed PDF: {len(chunks)} chunks created with embeddings")
return True
except Exception as e:
st.error(f"❌ Error processing PDF: {e}")
logger.error(f"PDF processing error: {e}")
return False
def search_documents(self, query, top_k=3):
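        """Return the top_k chunks most similar to the query by cosine similarity."""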
if not self.documents or len(self.embeddings) == 0:
st.warning("No documents available for search")
return []
try:
query_embedding = self.embedding_model.encode([query])
similarities = cosine_similarity(query_embedding, self.embeddings)[0]
# Filter out very low similarity scores
min_threshold = 0.1
valid_indices = np.where(similarities > min_threshold)[0]
if len(valid_indices) == 0:
return []
# Get top k from valid indices
valid_similarities = similarities[valid_indices]
top_valid_indices = np.argsort(valid_similarities)[-top_k:][::-1]
top_indices = valid_indices[top_valid_indices]
return [{'text': self.documents[i], 'score': similarities[i]}
for i in top_indices]
except Exception as e:
st.error(f"Error searching documents: {e}")
logger.error(f"Search error: {e}")
return []
def generate_answer(self, query, context_docs):
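        """Generate an answer grounded in the retrieved chunks, with fallbacks."""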
if not self.granite_model or not context_docs:
return "I don't have enough information to answer your question."
# Create better context from top documents
context = "\n\n".join([f"Context {i+1}: {doc['text'][:300]}"
for i, doc in enumerate(context_docs[:2])]) # Use top 2 docs
# Improved prompt formatting
prompt = f"""Based on the following context, provide a clear and accurate answer to the question. If the context doesn't contain enough information, say so.
Context:
{context}
Question: {query}
Answer:"""
try:
            # Tokenize with truncation, keeping the attention mask so
            # generate() receives explicit padding information
            inputs = self.tokenizer(
                prompt,
                return_tensors='pt',
                max_length=1024,
                truncation=True
            ).to(self.device)
            with torch.no_grad():
                outputs = self.granite_model.generate(
                    **inputs,
                    max_new_tokens=150,  # Bound new tokens rather than total length
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,
                    top_p=0.9
                )
            # Decode only the newly generated tokens, not the echoed prompt
            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
# Clean up the response
response = response.strip()
if len(response) < 10:
return f"Based on the provided context: {context[:200]}..."
return response
except Exception as e:
logger.error(f"Generation error: {e}")
return f"Error generating response. Here's what I found: {context[:200]}..."
finally:
# Clean up GPU memory
if self.device.type == "cuda":
torch.cuda.empty_cache()
def answer_question(self, query):
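        """End-to-end QA: retrieve relevant chunks, then generate an answer."""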
if not self.documents:
return {'answer': "No PDF has been processed yet.", 'sources': []}
relevant_docs = self.search_documents(query)
if not relevant_docs:
return {'answer': "No relevant information found in the document for your question.", 'sources': []}
answer = self.generate_answer(query, relevant_docs)
return {
'answer': answer,
'sources': relevant_docs
}
def main():
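    """Streamlit UI: model loading, PDF upload/processing, and Q&A."""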
st.set_page_config(
page_title="PDF RAG with IBM Granite",
page_icon="πŸ“„",
layout="wide"
)
st.title("πŸ“„ PDF RAG with IBM Granite")
st.write("Upload a PDF and ask questions about its content using AI")
# Initialize session state
if 'rag_system' not in st.session_state:
st.session_state.rag_system = SimplePDFRAG()
if 'models_loaded' not in st.session_state:
st.session_state.models_loaded = False
if 'pdf_processed' not in st.session_state:
st.session_state.pdf_processed = False
if 'current_pdf_name' not in st.session_state:
st.session_state.current_pdf_name = None
    if 'uploaded_file_path' not in st.session_state:
        st.session_state.uploaded_file_path = None
    if 'uploaded_file_name' not in st.session_state:
        st.session_state.uploaded_file_name = None
# Status indicators
col1, col2, col3 = st.columns(3)
with col1:
if st.session_state.models_loaded:
st.success("πŸ€– Models: Loaded")
else:
st.error("πŸ€– Models: Not Loaded")
with col2:
if st.session_state.pdf_processed:
st.success(f"πŸ“„ PDF: {st.session_state.current_pdf_name}")
else:
st.error("πŸ“„ PDF: Not Processed")
with col3:
if st.session_state.models_loaded and st.session_state.pdf_processed:
st.success("🟒 Ready")
else:
st.error("πŸ”΄ Not Ready")
# Model loading section
if not st.session_state.models_loaded:
st.markdown("---")
st.subheader("πŸ€– Model Loading")
st.info("Click below to load the AI models. This may take a few minutes.")
if st.button("πŸ€– Load Models", type="primary"):
with st.spinner("Loading models... This may take a few minutes."):
success = st.session_state.rag_system.load_models()
st.session_state.models_loaded = success
if success:
st.balloons()
st.rerun()
# PDF processing section
if st.session_state.models_loaded:
st.markdown("---")
st.subheader("πŸ“ PDF Upload and Processing")
uploaded_file = st.file_uploader(
"Upload PDF",
type=["pdf"],
key="pdf_uploader",
help="Upload a PDF file to analyze and ask questions about"
)
        if uploaded_file:
            # Save and reset state only for a newly uploaded file; otherwise the
            # rerun after processing would flip pdf_processed back to False
            if uploaded_file.name != st.session_state.get('uploaded_file_name'):
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                    tmp.write(uploaded_file.read())
                    st.session_state.uploaded_file_path = tmp.name
                st.session_state.uploaded_file_name = uploaded_file.name
                st.session_state.pdf_processed = False
                st.session_state.current_pdf_name = None
            st.success(f"📄 Uploaded: {uploaded_file.name}")
if st.session_state.uploaded_file_path and not st.session_state.pdf_processed:
if st.button("πŸ“– Process PDF", type="primary"):
with st.spinner("Processing PDF... This may take a moment."):
try:
with open(st.session_state.uploaded_file_path, "rb") as f:
success = st.session_state.rag_system.process_pdf(
f, st.session_state.uploaded_file_name
)
if success:
st.session_state.pdf_processed = True
st.session_state.current_pdf_name = st.session_state.uploaded_file_name
st.success("βœ… PDF processed successfully!")
st.balloons()
st.rerun()
else:
st.error("❌ Failed to process PDF")
except Exception as e:
st.error(f"❌ Error processing PDF: {e}")
# Q&A section
if st.session_state.models_loaded and st.session_state.pdf_processed:
st.markdown("---")
st.subheader("❓ Ask Questions")
st.info(f"πŸ“š Current document: **{st.session_state.current_pdf_name}**")
query = st.text_input(
"Ask a question about your PDF:",
placeholder="What is the main topic discussed in this document?",
help="Ask specific questions about the content in your PDF"
)
if query and st.button("πŸ” Get Answer", type="primary"):
with st.spinner("Searching document and generating answer..."):
result = st.session_state.rag_system.answer_question(query)
st.markdown("### πŸ€– Answer:")
st.write(result['answer'])
if result.get('sources'):
st.markdown("### πŸ“š Sources:")
for i, src in enumerate(result['sources']):
with st.expander(f"Source {i+1} (Relevance: {src['score']:.3f})"):
st.write(src['text'][:500] + "..." if len(src['text']) > 500 else src['text'])
# Sidebar
with st.sidebar:
st.header("πŸ“‹ How to Use")
st.markdown("""
1. **Load Models** - Click to download and load AI models
2. **Upload PDF** - Select your PDF file
3. **Process PDF** - Extract and analyze the text
4. **Ask Questions** - Query your document
""")
st.header("πŸ’‘ Tips")
st.markdown("""
- Ask specific questions for better results
- Try different phrasings if unsatisfied
- The AI uses context from your document
""")
st.header("πŸ”§ System Info")
device_info = "GPU" if torch.cuda.is_available() else "CPU"
st.write(f"**Device:** {device_info}")
st.write(f"**Models:** {'βœ… Loaded' if st.session_state.models_loaded else '❌ Not loaded'}")
st.write(f"**PDF:** {'βœ… Processed' if st.session_state.pdf_processed else '❌ Not processed'}")
if st.button("πŸ”„ Reset Everything"):
# Clear all session state
for key in list(st.session_state.keys()):
del st.session_state[key]
# Force garbage collection
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
st.rerun()
if __name__ == "__main__":
main()