from typing import List, Dict, Optional

from pdf_parser import extract_text_from_pdfs
from vector_store import VectorStore
from embeddings import CLIPEmbedder
from multimodal_model import Gemma3Model
import logging

logger = logging.getLogger(__name__)


class RAGPipeline:
    """Retrieval-augmented generation pipeline over a directory of PDFs.

    Wires together three components:
      * a CLIP embedder for vectorizing text,
      * a Chroma-backed vector store for retrieval,
      * a Gemma 3 1B model for answer generation / summarization.

    All public methods log their progress and errors via the module logger.
    """

    def __init__(self, pdf_dir: str, chroma_dir: str = "./chroma_db", device: str = "cpu"):
        """Initialize embedder, vector store, and LLM.

        Args:
            pdf_dir: Directory containing the PDFs to index.
            chroma_dir: Persistence directory for the Chroma vector store.
            device: Torch device string for the embedder and LLM ("cpu"/"cuda").

        Raises:
            Exception: re-raised after logging if any component fails to load.
        """
        self.pdf_dir = pdf_dir
        self.device = device
        logger.info("→ Initializing RAG Pipeline...")
        try:
            # Text/image embedder (CLIP ViT-B/32).
            logger.debug("Loading embedder...")
            self.embedder = CLIPEmbedder(
                model_name="openai/clip-vit-base-patch32",
                device=device
            )

            # Persistent vector store; collection is created lazily if absent.
            logger.debug("Initializing vector store...")
            self.vector_store = VectorStore(persist_dir=chroma_dir)
            self.vector_store.get_or_create_collection()

            # Generation model (Gemma 3 1B, pretrained checkpoint).
            logger.debug("Loading Gemma 3 1B model...")
            self.llm = Gemma3Model(model_name="unsloth/gemma-3-1b-pt", device=device)

            logger.info("✓ RAG Pipeline initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize pipeline: {str(e)}", exc_info=True)
            raise

    def index_pdfs(self):
        """Extract text from all PDFs in ``self.pdf_dir`` and add the chunks
        to the vector store.

        No-op (with a warning) when extraction yields nothing. Errors are
        logged and re-raised.
        """
        logger.info("→ Starting PDF indexing...")
        try:
            documents, metadatas = extract_text_from_pdfs(self.pdf_dir)
            if not documents:
                logger.warning("No documents extracted")
                return
            logger.debug(f"Extracted {len(documents)} document chunks")
            # Simple positional IDs; re-indexing overwrites the same ID space.
            ids = [f"doc_{i}" for i in range(len(documents))]
            self.vector_store.add_documents(documents, metadatas, ids)
            logger.info(f"✓ Indexed {len(documents)} document chunks")
        except Exception as e:
            logger.error(f"Error during indexing: {str(e)}", exc_info=True)
            raise

    def retrieve_documents(self, query: str, n_results: int = 5) -> List[Dict]:
        """Search the vector store and return matching chunks.

        Args:
            query: Natural-language search string.
            n_results: Maximum number of chunks to return.

        Returns:
            List of dicts with ``content`` (chunk text) and ``source``
            ("filename (pN)"). Empty list on any retrieval error.
        """
        try:
            logger.debug(f"Searching for: {query[:50]}...")
            results = self.vector_store.search(query, n_results=n_results)
            retrieved_docs = []
            # Chroma returns nested lists (one inner list per query); we
            # issue a single query, so index [0] on both fields.
            for doc, metadata in zip(results["documents"][0],
                                     results["metadatas"][0]):
                retrieved_docs.append({
                    "content": doc,
                    "source": f"{metadata.get('filename', 'Unknown')} (p{metadata.get('page', '?')})"
                })
            logger.debug(f"Retrieved {len(retrieved_docs)} documents")
            return retrieved_docs
        except Exception as e:
            logger.error(f"Error retrieving documents: {str(e)}", exc_info=True)
            return []

    def answer_question(self, question: str, n_context_docs: int = 3) -> Dict:
        """Answer a question using retrieved context (RAG).

        Args:
            question: User question.
            n_context_docs: How many chunks to retrieve as context.

        Returns:
            Dict with ``answer`` (truncated to 1000 chars), ``sources``
            (list of source labels), and ``context_used`` (chunk count).
            On failure, ``answer`` carries the error message.
        """
        logger.info(f"Processing question: {question[:50]}...")
        try:
            logger.debug(f"Retrieving {n_context_docs} documents...")
            retrieved_docs = self.retrieve_documents(question, n_results=n_context_docs)

            if not retrieved_docs:
                logger.warning("No documents retrieved")
                return {
                    "answer": "No relevant documents found.",
                    "sources": [],
                    "context_used": 0
                }

            logger.debug(f"Retrieved {len(retrieved_docs)} documents")

            # Cap each chunk at 500 chars and the combined context at 2000
            # to keep the prompt within the small model's budget.
            context = "\n\n".join([
                f"[{doc['source']}]\n{doc['content'][:500]}"
                for doc in retrieved_docs
            ])[:2000]

            logger.debug("Generating answer with Gemma 3...")
            try:
                answer = self.llm.answer_question(question, context)
            except Exception as e:
                # Fall back to context-free greedy decoding rather than
                # failing the whole request.
                logger.warning(f"Answer generation failed ({e}), using greedy fallback...")
                answer = self.llm.generate_response_greedy(
                    f"Q: {question}\nA:"
                )

            # Strip any echoed prompt preamble the model may emit.
            if "Answer:" in answer:
                answer = answer.split("Answer:")[-1].strip()

            logger.info("✓ Answer generated successfully")
            return {
                "answer": answer[:1000],  # Limit output length
                "sources": [doc["source"] for doc in retrieved_docs],
                "context_used": len(retrieved_docs)
            }
        except Exception as e:
            logger.error(f"Error in answer_question: {str(e)}", exc_info=True)
            return {
                "answer": f"Error generating answer: {str(e)}",
                "sources": [],
                "context_used": 0
            }

    def summarize_documents(self) -> str:
        """Summarize the indexed corpus from a small sample of chunks.

        Returns:
            The LLM-produced summary, or a notice when the collection
            is empty.
        """
        collection_info = self.vector_store.get_collection_info()
        doc_count = collection_info.get("document_count", 0)
        if doc_count == 0:
            return "No documents to summarize"

        # Sample a handful of representative chunks via a generic query.
        results = self.vector_store.search("main topic summary", n_results=5)
        # BUGFIX: truncate each chunk's *text* to 200 chars (doc[:200]);
        # the original sliced the chunk *list* (docs[:200]), a no-op that
        # fed full-length chunks to the summarizer.
        sampled_content = " ".join([doc[:200] for docs in results["documents"] for doc in docs])
        summary = self.llm.summarize_text(sampled_content)
        return summary

    def get_collection_info(self) -> Dict:
        """Return vector-store collection statistics (delegates to the store)."""
        return self.vector_store.get_collection_info()