import streamlit as st
import PyPDF2
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging
import os
import tempfile
import gc

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
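
# A minimal PDF question-answering (RAG) app: extract text with PyPDF2,
# embed overlapping chunks with a SentenceTransformer, retrieve the chunks
# closest to a query by cosine similarity, and generate a grounded answer
# with an IBM Granite instruct model.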


class SimplePDFRAG:

    def __init__(self):
        self.documents = []          # text chunks extracted from the PDF
        self.embeddings = []         # vectors aligned one-to-one with self.documents
        self.embedding_model = None  # SentenceTransformer used for retrieval
        self.granite_model = None    # IBM Granite causal LM used for generation
        self.tokenizer = None
        self.pdf_name = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def setup_cache_directory(self):
        """Create a temporary cache dir and point the model libraries at it."""
        try:
            cache_dir = tempfile.mkdtemp(prefix="model_cache_")
            # HF_HOME supersedes TRANSFORMERS_CACHE in newer transformers
            # releases; setting both covers old and new versions.
            os.environ['HF_HOME'] = cache_dir
            os.environ['TRANSFORMERS_CACHE'] = cache_dir
            os.environ['SENTENCE_TRANSFORMERS_HOME'] = cache_dir
            st.info(f"Using cache directory: {cache_dir}")
            st.info(f"Using device: {self.device}")
            return cache_dir
        except Exception as e:
            st.error(f"Error setting up cache directory: {e}")
            return None

    def load_models(self):
        """Load the sentence-embedding model and the IBM Granite instruct model."""
        try:
            cache_dir = self.setup_cache_directory()

            st.info("Loading embedding model...")
            self.embedding_model = SentenceTransformer(
                'all-MiniLM-L6-v2', cache_folder=cache_dir, device=self.device
            )

            st.info("Loading IBM Granite model...")
            model_name = "ibm-granite/granite-3-2b-instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            model_kwargs = {
                "cache_dir": cache_dir,
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,
            }

            # Half precision on GPU roughly halves memory use; CPUs need float32.
            if self.device.type == "cuda":
                model_kwargs["torch_dtype"] = torch.float16
            else:
                model_kwargs["torch_dtype"] = torch.float32

            self.granite_model = AutoModelForCausalLM.from_pretrained(
                model_name, **model_kwargs
            ).to(self.device)

            # Granite's tokenizer may not define a pad token; fall back to EOS.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            st.success("Models loaded successfully!")
            return True

        except Exception as e:
            st.error(f"Error loading models: {e}")
            logger.error(f"Model loading error: {e}")
            return False

    def extract_pdf_text(self, pdf_file):
        """Extract plain text from every page of the uploaded PDF."""
        try:
            pdf_file.seek(0)
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            text = ""
            st.info(f"PDF has {len(pdf_reader.pages)} pages")

            progress_bar = st.progress(0)
            for page_num, page in enumerate(pdf_reader.pages):
                try:
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n"
                        st.write(f"✅ Extracted text from page {page_num + 1}")
                    else:
                        st.warning(f"⚠️ No text found on page {page_num + 1}")
                except Exception as page_error:
                    st.error(f"Error extracting page {page_num + 1}: {page_error}")

                progress_bar.progress((page_num + 1) / len(pdf_reader.pages))

            progress_bar.empty()

            if text.strip():
                st.success(f"Extracted {len(text)} characters from {len(pdf_reader.pages)} pages")
                st.write("📄 **Text Preview:**")
                st.text(text[:500] + "..." if len(text) > 500 else text)
                return text
            else:
                st.error("No text could be extracted from the PDF")
                return None

        except Exception as e:
            st.error(f"Error reading PDF file: {e}")
            logger.error(f"PDF extraction error: {e}")
            return None
    def chunk_text(self, text, chunk_size=400, overlap=50):
        """Improved chunking with overlap for better context preservation."""
        if not text or not text.strip():
            return []

        words = text.split()
        chunks = []

        # Step by (chunk_size - overlap) so consecutive chunks share `overlap`
        # words, keeping sentences that straddle a boundary retrievable.
        for i in range(0, len(words), chunk_size - overlap):
            chunk = " ".join(words[i:i + chunk_size])
            if chunk.strip():
                chunks.append(chunk)

        return chunks
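    # Worked example for chunk_text: a 1,000-word document with the default
    # chunk_size=400 and overlap=50 steps through the words in strides of
    # 350, producing chunks that start at words 0, 350, and 700, so each
    # pair of neighbouring chunks shares 50 words of context.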

    def process_pdf(self, pdf_file, pdf_name):
        """Run the full ingestion pipeline: extract, chunk, and embed a PDF."""
        try:
            self.pdf_name = pdf_name
            st.info("📄 Extracting text from PDF...")
            text = self.extract_pdf_text(pdf_file)

            if not text:
                return False

            st.info("✂️ Splitting text into chunks with overlap...")
            chunks = self.chunk_text(text)

            if not chunks:
                st.error("No valid text chunks created")
                return False

            st.info(f"🔍 Creating embeddings for {len(chunks)} chunks...")

            # Embed in batches so the progress bar stays responsive and memory
            # use stays bounded on long documents.
            batch_size = 32
            embeddings = []

            progress_bar = st.progress(0)
            for i in range(0, len(chunks), batch_size):
                batch = chunks[i:i + batch_size]
                batch_embeddings = self.embedding_model.encode(batch, show_progress_bar=False)
                embeddings.extend(batch_embeddings)
                progress_bar.progress(min(i + batch_size, len(chunks)) / len(chunks))

            progress_bar.empty()

            self.documents = chunks
            self.embeddings = np.array(embeddings)

            st.success(f"✅ Successfully processed PDF: {len(chunks)} chunks created with embeddings")
            return True

        except Exception as e:
            st.error(f"❌ Error processing PDF: {e}")
            logger.error(f"PDF processing error: {e}")
            return False
    def search_documents(self, query, top_k=3):
        """Return the top_k chunks most similar to the query by cosine similarity."""
        if not self.documents or len(self.embeddings) == 0:
            st.warning("No documents available for search")
            return []

        try:
            query_embedding = self.embedding_model.encode([query])
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]

            # Discard chunks with near-zero similarity before ranking.
            min_threshold = 0.1
            valid_indices = np.where(similarities > min_threshold)[0]

            if len(valid_indices) == 0:
                return []

            # Rank the surviving chunks and keep the best top_k, highest first.
            valid_similarities = similarities[valid_indices]
            top_valid_indices = np.argsort(valid_similarities)[-top_k:][::-1]
            top_indices = valid_indices[top_valid_indices]

            return [{'text': self.documents[i], 'score': similarities[i]}
                    for i in top_indices]

        except Exception as e:
            st.error(f"Error searching documents: {e}")
            logger.error(f"Search error: {e}")
            return []
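    # Note on scores: all-MiniLM-L6-v2 normalises its embeddings, so the
    # cosine similarities above fall in [-1, 1]; the 0.1 floor merely drops
    # near-noise matches rather than enforcing strong relevance.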

    def generate_answer(self, query, context_docs):
        """Generate an answer with the Granite model, grounded in retrieved chunks."""
        if not self.granite_model or not context_docs:
            return "I don't have enough information to answer your question."

        # Keep the prompt compact: only the two best chunks, truncated to 300
        # characters each, so the 1024-token input budget is not exceeded.
        context = "\n\n".join([f"Context {i+1}: {doc['text'][:300]}"
                               for i, doc in enumerate(context_docs[:2])])

        prompt = f"""Based on the following context, provide a clear and accurate answer to the question. If the context doesn't contain enough information, say so.

Context:
{context}

Question: {query}

Answer:"""

        try:
            inputs = self.tokenizer.encode(
                prompt,
                return_tensors='pt',
                max_length=1024,
                truncation=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.granite_model.generate(
                    inputs,
                    max_new_tokens=150,
                    temperature=0.7,         # mild randomness
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,  # discourage looping output
                    top_p=0.9                # nucleus sampling
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True
            )

            # Fall back to raw context if the model returns a trivial answer.
            response = response.strip()
            if len(response) < 10:
                return f"Based on the provided context: {context[:200]}..."

            return response

        except Exception as e:
            logger.error(f"Generation error: {e}")
            return f"Error generating response. Here's what I found: {context[:200]}..."
        finally:
            # Release GPU memory held by the generation pass.
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
    def answer_question(self, query):
        """Retrieve relevant chunks for the query and generate an answer."""
        if not self.documents:
            return {'answer': "No PDF has been processed yet.", 'sources': []}

        relevant_docs = self.search_documents(query)

        if not relevant_docs:
            return {'answer': "No relevant information found in the document for your question.", 'sources': []}

        answer = self.generate_answer(query, relevant_docs)

        return {
            'answer': answer,
            'sources': relevant_docs
        }
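
# Rough sketch of headless (non-UI) use of the class above, assuming a local
# file "doc.pdf" exists (hypothetical name). Untested assumption: Streamlit's
# st.* calls reduce to console warnings outside a `streamlit run` session.
#
#   rag = SimplePDFRAG()
#   if rag.load_models():
#       with open("doc.pdf", "rb") as f:
#           rag.process_pdf(f, "doc.pdf")
#       print(rag.answer_question("What is the main topic?")["answer"])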

def main():
    st.set_page_config(
        page_title="PDF RAG with IBM Granite",
        page_icon="📄",
        layout="wide"
    )

    st.title("📄 PDF RAG with IBM Granite")
    st.write("Upload a PDF and ask questions about its content using AI")

    # Persist the RAG system and workflow flags across Streamlit reruns.
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = SimplePDFRAG()
    if 'models_loaded' not in st.session_state:
        st.session_state.models_loaded = False
    if 'pdf_processed' not in st.session_state:
        st.session_state.pdf_processed = False
    if 'current_pdf_name' not in st.session_state:
        st.session_state.current_pdf_name = None
    if 'uploaded_file_path' not in st.session_state:
        st.session_state.uploaded_file_path = None

    # Status row: models, document, and overall readiness.
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.session_state.models_loaded:
            st.success("🤖 Models: Loaded")
        else:
            st.error("🤖 Models: Not Loaded")

    with col2:
        if st.session_state.pdf_processed:
            st.success(f"📄 PDF: {st.session_state.current_pdf_name}")
        else:
            st.error("📄 PDF: Not Processed")

    with col3:
        if st.session_state.models_loaded and st.session_state.pdf_processed:
            st.success("🟢 Ready")
        else:
            st.error("🔴 Not Ready")

    # Step 1: load models.
    if not st.session_state.models_loaded:
        st.markdown("---")
        st.subheader("🤖 Model Loading")
        st.info("Click below to load the AI models. This may take a few minutes.")

        if st.button("🤖 Load Models", type="primary"):
            with st.spinner("Loading models... This may take a few minutes."):
                success = st.session_state.rag_system.load_models()
                st.session_state.models_loaded = success
                if success:
                    st.balloons()
                    st.rerun()
    # Step 2: upload and process a PDF.
    if st.session_state.models_loaded:
        st.markdown("---")
        st.subheader("📄 PDF Upload and Processing")

        uploaded_file = st.file_uploader(
            "Upload PDF",
            type=["pdf"],
            key="pdf_uploader",
            help="Upload a PDF file to analyze and ask questions about"
        )

        # Save the upload to a temp file, but only when a new file arrives;
        # otherwise every rerun would reset the processed state.
        if uploaded_file and uploaded_file.name != st.session_state.get('uploaded_file_name'):
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(uploaded_file.read())
                st.session_state.uploaded_file_path = tmp.name
                st.session_state.uploaded_file_name = uploaded_file.name
                st.session_state.pdf_processed = False
                st.session_state.current_pdf_name = None

            st.success(f"📄 Uploaded: {uploaded_file.name}")

        if st.session_state.uploaded_file_path and not st.session_state.pdf_processed:
            if st.button("🔄 Process PDF", type="primary"):
                with st.spinner("Processing PDF... This may take a moment."):
                    try:
                        with open(st.session_state.uploaded_file_path, "rb") as f:
                            success = st.session_state.rag_system.process_pdf(
                                f, st.session_state.uploaded_file_name
                            )

                        if success:
                            st.session_state.pdf_processed = True
                            st.session_state.current_pdf_name = st.session_state.uploaded_file_name
                            st.success("✅ PDF processed successfully!")
                            st.balloons()
                            st.rerun()
                        else:
                            st.error("❌ Failed to process PDF")

                    except Exception as e:
                        st.error(f"❌ Error processing PDF: {e}")
    # Step 3: ask questions about the processed document.
    if st.session_state.models_loaded and st.session_state.pdf_processed:
        st.markdown("---")
        st.subheader("❓ Ask Questions")
        st.info(f"📄 Current document: **{st.session_state.current_pdf_name}**")

        query = st.text_input(
            "Ask a question about your PDF:",
            placeholder="What is the main topic discussed in this document?",
            help="Ask specific questions about the content in your PDF"
        )

        if query and st.button("🔍 Get Answer", type="primary"):
            with st.spinner("Searching document and generating answer..."):
                result = st.session_state.rag_system.answer_question(query)

                st.markdown("### 🤖 Answer:")
                st.write(result['answer'])

                if result.get('sources'):
                    st.markdown("### 📚 Sources:")
                    for i, src in enumerate(result['sources']):
                        with st.expander(f"Source {i+1} (Relevance: {src['score']:.3f})"):
                            st.write(src['text'][:500] + "..." if len(src['text']) > 500 else src['text'])
    # Sidebar: instructions, tips, and system status.
    with st.sidebar:
        st.header("📖 How to Use")
        st.markdown("""
1. **Load Models** - Click to download and load AI models
2. **Upload PDF** - Select your PDF file
3. **Process PDF** - Extract and analyze the text
4. **Ask Questions** - Query your document
""")

        st.header("💡 Tips")
        st.markdown("""
- Ask specific questions for better results
- Try different phrasings if unsatisfied
- The AI uses context from your document
""")

        st.header("🔧 System Info")
        device_info = "GPU" if torch.cuda.is_available() else "CPU"
        st.write(f"**Device:** {device_info}")
        st.write(f"**Models:** {'✅ Loaded' if st.session_state.models_loaded else '❌ Not loaded'}")
        st.write(f"**PDF:** {'✅ Processed' if st.session_state.pdf_processed else '❌ Not processed'}")

        if st.button("🔄 Reset Everything"):
            # Drop all session state, free memory, and restart the app flow.
            for key in list(st.session_state.keys()):
                del st.session_state[key]

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            st.rerun()


if __name__ == "__main__":
    main()
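
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py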