project2 / src /rag_pipeline.py
dnj0's picture
Update src/rag_pipeline.py
4396f57 verified
from typing import List, Dict, Optional
from pdf_parser import extract_text_from_pdfs
from vector_store import VectorStore
from embeddings import CLIPEmbedder
from multimodal_model import Gemma3Model # ← Changed from GemmaVisionModel
import logging
logger = logging.getLogger(__name__)
class RAGPipeline:
def __init__(self, pdf_dir: str, chroma_dir: str = "./chroma_db", device: str = "cpu"):
self.pdf_dir = pdf_dir
self.device = device
logger.info("β†’ Initializing RAG Pipeline...")
try:
# Initialize embedder
logger.debug("Loading embedder...")
self.embedder = CLIPEmbedder(
model_name="openai/clip-vit-base-patch32",
device=device
)
# Initialize vector store
logger.debug("Initializing vector store...")
self.vector_store = VectorStore(persist_dir=chroma_dir)
self.vector_store.get_or_create_collection()
# Initialize LLM with Gemma3Model
logger.debug("Loading Gemma 3 1B model...")
self.llm = Gemma3Model(model_name="unsloth/gemma-3-1b-pt", device=device) # ← Use Gemma3Model
logger.info("βœ“ RAG Pipeline initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize pipeline: {str(e)}", exc_info=True)
raise
def index_pdfs(self):
"""Index PDFs with error logging"""
logger.info("β†’ Starting PDF indexing...")
try:
documents, metadatas = extract_text_from_pdfs(self.pdf_dir)
if not documents:
logger.warning("No documents extracted")
return
logger.debug(f"Extracted {len(documents)} document chunks")
ids = [f"doc_{i}" for i in range(len(documents))]
self.vector_store.add_documents(documents, metadatas, ids)
logger.info(f"βœ“ Indexed {len(documents)} document chunks")
except Exception as e:
logger.error(f"Error during indexing: {str(e)}", exc_info=True)
raise
def retrieve_documents(self, query: str, n_results: int = 5) -> List[Dict]:
"""Retrieve documents with error handling"""
try:
logger.debug(f"Searching for: {query[:50]}...")
results = self.vector_store.search(query, n_results=n_results)
retrieved_docs = []
for doc, metadata in zip(results["documents"][0], results["metadatas"][0]):
retrieved_docs.append({
"content": doc,
"source": f"{metadata.get('filename', 'Unknown')} (p{metadata.get('page', '?')})"
})
logger.debug(f"Retrieved {len(retrieved_docs)} documents")
return retrieved_docs
except Exception as e:
logger.error(f"Error retrieving documents: {str(e)}", exc_info=True)
return []
def answer_question(self, question: str, n_context_docs: int = 3) -> Dict:
"""Answer question using RAG with comprehensive error handling"""
logger.info(f"Processing question: {question[:50]}...")
try:
# Retrieve relevant documents
logger.debug(f"Retrieving {n_context_docs} documents...")
retrieved_docs = self.retrieve_documents(question, n_results=n_context_docs)
if not retrieved_docs:
logger.warning("No documents retrieved")
return {
"answer": "No relevant documents found.",
"sources": [],
"context_used": 0
}
logger.debug(f"Retrieved {len(retrieved_docs)} documents")
# Combine context (limit to prevent memory issues)
context = "\n\n".join([
f"[{doc['source']}]\n{doc['content'][:500]}"
for doc in retrieved_docs
])[:2000]
logger.debug("Generating answer with Gemma 3...")
try:
# Use greedy decoding for faster inference with Gemma 3
answer = self.llm.answer_question(question, context)
except Exception as e:
logger.warning(f"Answer generation failed ({e}), using greedy fallback...")
# Fallback to greedy
answer = self.llm.generate_response_greedy(
f"Q: {question}\nA:"
)
# Extract answer
if "Answer:" in answer:
answer = answer.split("Answer:")[-1].strip()
logger.info("βœ“ Answer generated successfully")
return {
"answer": answer[:1000], # Limit output length
"sources": [doc["source"] for doc in retrieved_docs],
"context_used": len(retrieved_docs)
}
except Exception as e:
logger.error(f"Error in answer_question: {str(e)}", exc_info=True)
return {
"answer": f"Error generating answer: {str(e)}",
"sources": [],
"context_used": 0
}
def summarize_documents(self) -> str:
"""Summarize all indexed documents"""
collection_info = self.vector_store.get_collection_info()
doc_count = collection_info.get("document_count", 0)
if doc_count == 0:
return "No documents to summarize"
# Sample documents
results = self.vector_store.search("main topic summary", n_results=5)
sampled_content = " ".join([doc for docs in results["documents"] for doc in docs[:200]])
summary = self.llm.summarize_text(sampled_content)
return summary
def get_collection_info(self) -> Dict:
"""Get collection statistics"""
return self.vector_store.get_collection_info()