"""Streamlit app for question answering over an uploaded PDF using a simple
retrieval-augmented generation (RAG) pipeline: sentence-transformers embeddings
for retrieval and an IBM Granite instruct model for answer generation."""

import streamlit as st
import PyPDF2
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging
import os
import tempfile
import gc

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SimplePDFRAG:
    """Minimal in-memory RAG pipeline over a single PDF document."""

    def __init__(self):
        self.documents = []
        self.embeddings = []
        self.embedding_model = None
        self.granite_model = None
        self.tokenizer = None
        self.pdf_name = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def setup_cache_directory(self):
        """Create a temporary cache directory for model downloads."""
        try:
            cache_dir = tempfile.mkdtemp(prefix="model_cache_")
            os.environ['HF_HOME'] = cache_dir
            os.environ['TRANSFORMERS_CACHE'] = cache_dir
            os.environ['SENTENCE_TRANSFORMERS_HOME'] = cache_dir
            st.info(f"Using cache directory: {cache_dir}")
            st.info(f"Using device: {self.device}")
            return cache_dir
        except Exception as e:
            st.error(f"Error setting up cache directory: {e}")
            return None

    def load_models(self):
        """Load the embedding model and the Granite generation model."""
        try:
            cache_dir = self.setup_cache_directory()

            st.info("Loading embedding model...")
            self.embedding_model = SentenceTransformer(
                'all-MiniLM-L6-v2',
                cache_folder=cache_dir,
                device=self.device
            )

            st.info("Loading IBM Granite model...")
            # Alternative models you could try:
            # model_name = "ibm-granite/granite-3.1-8b-instruct"  # Larger, better performance
            # model_name = "microsoft/DialoGPT-medium"
            # model_name = "google/flan-t5-base"
            model_name = "ibm-granite/granite-3.1-2b-instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            # Optimize model loading based on available resources
            model_kwargs = {
                "cache_dir": cache_dir,
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,
            }

            # Use appropriate dtype based on device
            if self.device.type == "cuda":
                model_kwargs["torch_dtype"] = torch.float16
            else:
                model_kwargs["torch_dtype"] = torch.float32

            self.granite_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                **model_kwargs
            ).to(self.device)

            # Set pad token if not available
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            st.success("Models loaded successfully!")
            return True
        except Exception as e:
            st.error(f"Error loading models: {e}")
            logger.error(f"Model loading error: {e}")
            return False

    def extract_pdf_text(self, pdf_file):
        """Extract plain text from every page of the uploaded PDF."""
        try:
            pdf_file.seek(0)
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            text = ""

            st.info(f"PDF has {len(pdf_reader.pages)} pages")
            progress_bar = st.progress(0)

            for page_num, page in enumerate(pdf_reader.pages):
                try:
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n"
                        st.write(f"✅ Extracted text from page {page_num + 1}")
                    else:
                        st.warning(f"⚠️ No text found on page {page_num + 1}")
                except Exception as page_error:
                    st.error(f"Error extracting page {page_num + 1}: {page_error}")

                # Update progress
                progress_bar.progress((page_num + 1) / len(pdf_reader.pages))

            progress_bar.empty()

            if text.strip():
                st.success(f"Extracted {len(text)} characters from {len(pdf_reader.pages)} pages")
                st.write("📄 **Text Preview:**")
                st.text(text[:500] + "..." if len(text) > 500 else text)
                return text
            else:
                st.error("No text could be extracted from the PDF")
                return None
        except Exception as e:
            st.error(f"Error reading PDF file: {e}")
            logger.error(f"PDF extraction error: {e}")
            return None
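
    # A note on chunk_text below: with the defaults chunk_size=400 and
    # overlap=50, the loop advances 350 words per step, so consecutive chunks
    # share 50 words of context. These defaults are illustrative starting
    # points, not values tuned for any particular document type.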
    def chunk_text(self, text, chunk_size=400, overlap=50):
        """Improved chunking with overlap for better context preservation"""
        if not text or not text.strip():
            return []

        words = text.split()
        chunks = []
        for i in range(0, len(words), chunk_size - overlap):
            chunk = " ".join(words[i:i + chunk_size])
            if chunk.strip():  # Only add non-empty chunks
                chunks.append(chunk)
        return chunks

    def process_pdf(self, pdf_file, pdf_name):
        try:
            self.pdf_name = pdf_name

            st.info("🔍 Extracting text from PDF...")
            text = self.extract_pdf_text(pdf_file)
            if not text:
                return False

            st.info("✂️ Splitting text into chunks with overlap...")
            chunks = self.chunk_text(text)
            if not chunks:
                st.error("No valid text chunks created")
                return False

            st.info(f"🔄 Creating embeddings for {len(chunks)} chunks...")

            # Create embeddings in batches to manage memory
            batch_size = 32
            embeddings = []
            progress_bar = st.progress(0)

            for i in range(0, len(chunks), batch_size):
                batch = chunks[i:i + batch_size]
                batch_embeddings = self.embedding_model.encode(batch, show_progress_bar=False)
                embeddings.extend(batch_embeddings)
                progress_bar.progress(min(i + batch_size, len(chunks)) / len(chunks))

            progress_bar.empty()

            self.documents = chunks
            self.embeddings = np.array(embeddings)

            st.success(f"✅ Successfully processed PDF: {len(chunks)} chunks created with embeddings")
            return True
        except Exception as e:
            st.error(f"❌ Error processing PDF: {e}")
            logger.error(f"PDF processing error: {e}")
            return False

    def search_documents(self, query, top_k=3):
        if not self.documents or len(self.embeddings) == 0:
            st.warning("No documents available for search")
            return []

        try:
            query_embedding = self.embedding_model.encode([query])
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]

            # Filter out very low similarity scores
            min_threshold = 0.1
            valid_indices = np.where(similarities > min_threshold)[0]

            if len(valid_indices) == 0:
                return []

            # Get top k from valid indices
            valid_similarities = similarities[valid_indices]
            top_valid_indices = np.argsort(valid_similarities)[-top_k:][::-1]
            top_indices = valid_indices[top_valid_indices]

            return [{'text': self.documents[i], 'score': similarities[i]} for i in top_indices]
        except Exception as e:
            st.error(f"Error searching documents: {e}")
            logger.error(f"Search error: {e}")
            return []

    def generate_answer(self, query, context_docs):
        if not self.granite_model or not context_docs:
            return "I don't have enough information to answer your question."

        # Create better context from top documents
        context = "\n\n".join(
            [f"Context {i+1}: {doc['text'][:300]}"
             for i, doc in enumerate(context_docs[:2])]  # Use top 2 docs
        )

        # Improved prompt formatting
        prompt = f"""Based on the following context, provide a clear and accurate answer to the question. If the context doesn't contain enough information, say so.

Context:
{context}

Question: {query}

Answer:"""
        try:
            # Tokenize with proper attention to length
            inputs = self.tokenizer.encode(
                prompt,
                return_tensors='pt',
                max_length=1024,
                truncation=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.granite_model.generate(
                    inputs,
                    max_new_tokens=150,  # Use max_new_tokens instead of max_length
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,
                    top_p=0.9
                )

            # Decode only the new tokens
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True
            )

            # Clean up the response
            response = response.strip()
            if len(response) < 10:
                return f"Based on the provided context: {context[:200]}..."

            return response
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return f"Error generating response. Here's what I found: {context[:200]}..."
        finally:
            # Clean up GPU memory
            if self.device.type == "cuda":
                torch.cuda.empty_cache()

    def answer_question(self, query):
        if not self.documents:
            return {'answer': "No PDF has been processed yet.", 'sources': []}

        relevant_docs = self.search_documents(query)
        if not relevant_docs:
            return {'answer': "No relevant information found in the document for your question.", 'sources': []}

        answer = self.generate_answer(query, relevant_docs)
        return {
            'answer': answer,
            'sources': relevant_docs
        }


def main():
    st.set_page_config(
        page_title="PDF RAG with IBM Granite",
        page_icon="📄",
        layout="wide"
    )

    st.title("📄 PDF RAG with IBM Granite")
    st.write("Upload a PDF and ask questions about its content using AI")

    # Initialize session state
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = SimplePDFRAG()
    if 'models_loaded' not in st.session_state:
        st.session_state.models_loaded = False
    if 'pdf_processed' not in st.session_state:
        st.session_state.pdf_processed = False
    if 'current_pdf_name' not in st.session_state:
        st.session_state.current_pdf_name = None
    if 'uploaded_file_path' not in st.session_state:
        st.session_state.uploaded_file_path = None

    # Status indicators
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.session_state.models_loaded:
            st.success("🤖 Models: Loaded")
        else:
            st.error("🤖 Models: Not Loaded")
    with col2:
        if st.session_state.pdf_processed:
            st.success(f"📄 PDF: {st.session_state.current_pdf_name}")
        else:
            st.error("📄 PDF: Not Processed")
    with col3:
        if st.session_state.models_loaded and st.session_state.pdf_processed:
            st.success("🟢 Ready")
        else:
            st.error("🔴 Not Ready")

    # Model loading section
    if not st.session_state.models_loaded:
        st.markdown("---")
        st.subheader("🤖 Model Loading")
        st.info("Click below to load the AI models. This may take a few minutes.")

        if st.button("🤖 Load Models", type="primary"):
            with st.spinner("Loading models... This may take a few minutes."):
This may take a few minutes."): success = st.session_state.rag_system.load_models() st.session_state.models_loaded = success if success: st.balloons() st.rerun() # PDF processing section if st.session_state.models_loaded: st.markdown("---") st.subheader("📁 PDF Upload and Processing") uploaded_file = st.file_uploader( "Upload PDF", type=["pdf"], key="pdf_uploader", help="Upload a PDF file to analyze and ask questions about" ) if uploaded_file: # Save uploaded file with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp: tmp.write(uploaded_file.read()) st.session_state.uploaded_file_path = tmp.name st.session_state.uploaded_file_name = uploaded_file.name st.session_state.pdf_processed = False st.session_state.current_pdf_name = None st.success(f"📄 Uploaded: {uploaded_file.name}") if st.session_state.uploaded_file_path and not st.session_state.pdf_processed: if st.button("📖 Process PDF", type="primary"): with st.spinner("Processing PDF... This may take a moment."): try: with open(st.session_state.uploaded_file_path, "rb") as f: success = st.session_state.rag_system.process_pdf( f, st.session_state.uploaded_file_name ) if success: st.session_state.pdf_processed = True st.session_state.current_pdf_name = st.session_state.uploaded_file_name st.success("✅ PDF processed successfully!") st.balloons() st.rerun() else: st.error("❌ Failed to process PDF") except Exception as e: st.error(f"❌ Error processing PDF: {e}") # Q&A section if st.session_state.models_loaded and st.session_state.pdf_processed: st.markdown("---") st.subheader("❓ Ask Questions") st.info(f"📚 Current document: **{st.session_state.current_pdf_name}**") query = st.text_input( "Ask a question about your PDF:", placeholder="What is the main topic discussed in this document?", help="Ask specific questions about the content in your PDF" ) if query and st.button("🔍 Get Answer", type="primary"): with st.spinner("Searching document and generating answer..."): result = st.session_state.rag_system.answer_question(query) st.markdown("### 🤖 Answer:") st.write(result['answer']) if result.get('sources'): st.markdown("### 📚 Sources:") for i, src in enumerate(result['sources']): with st.expander(f"Source {i+1} (Relevance: {src['score']:.3f})"): st.write(src['text'][:500] + "..." if len(src['text']) > 500 else src['text']) # Sidebar with st.sidebar: st.header("📋 How to Use") st.markdown(""" 1. **Load Models** - Click to download and load AI models 2. **Upload PDF** - Select your PDF file 3. **Process PDF** - Extract and analyze the text 4. **Ask Questions** - Query your document """) st.header("💡 Tips") st.markdown(""" - Ask specific questions for better results - Try different phrasings if unsatisfied - The AI uses context from your document """) st.header("🔧 System Info") device_info = "GPU" if torch.cuda.is_available() else "CPU" st.write(f"**Device:** {device_info}") st.write(f"**Models:** {'✅ Loaded' if st.session_state.models_loaded else '❌ Not loaded'}") st.write(f"**PDF:** {'✅ Processed' if st.session_state.pdf_processed else '❌ Not processed'}") if st.button("🔄 Reset Everything"): # Clear all session state for key in list(st.session_state.keys()): del st.session_state[key] # Force garbage collection gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() st.rerun() if __name__ == "__main__": main()