|
|
from typing import List, Dict, Optional |
|
|
from pdf_parser import extract_text_from_pdfs |
|
|
from vector_store import VectorStore |
|
|
from embeddings import CLIPEmbedder |
|
|
from multimodal_model import Gemma3Model |
|
|
import logging |
|
|
|
|
|
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
class RAGPipeline:
    """End-to-end retrieval-augmented generation pipeline over a directory of PDFs.

    Wires together a CLIP text embedder, a persistent vector store, and a
    Gemma 3 language model: PDFs are chunked and indexed into the store, then
    questions are answered from the retrieved context.
    """

    def __init__(self, pdf_dir: str, chroma_dir: str = "./chroma_db", device: str = "cpu"):
        """Initialize the embedder, vector store, and LLM components.

        Args:
            pdf_dir: Directory containing the PDF files to index.
            chroma_dir: Persistence directory for the vector store.
            device: Device string for the embedder and LLM (e.g. "cpu", "cuda").

        Raises:
            Exception: Re-raises any component-initialization error after
                logging it with a traceback — a half-built pipeline is unusable.
        """
        self.pdf_dir = pdf_dir
        self.device = device

        logger.info("Initializing RAG Pipeline...")

        try:
            logger.debug("Loading embedder...")
            self.embedder = CLIPEmbedder(
                model_name="openai/clip-vit-base-patch32",
                device=device,
            )

            logger.debug("Initializing vector store...")
            self.vector_store = VectorStore(persist_dir=chroma_dir)
            self.vector_store.get_or_create_collection()

            logger.debug("Loading Gemma 3 1B model...")
            self.llm = Gemma3Model(model_name="unsloth/gemma-3-1b-pt", device=device)

            logger.info("RAG Pipeline initialized successfully")

        except Exception:
            # logger.exception logs at ERROR level with the traceback attached.
            logger.exception("Failed to initialize pipeline")
            raise

    def index_pdfs(self) -> None:
        """Extract text chunks from every PDF in ``self.pdf_dir`` and index them.

        No-op (with a warning) when extraction yields nothing. Re-raises any
        extraction/indexing error after logging it.
        """
        logger.info("Starting PDF indexing...")

        try:
            documents, metadatas = extract_text_from_pdfs(self.pdf_dir)

            if not documents:
                logger.warning("No documents extracted")
                return

            logger.debug("Extracted %d document chunks", len(documents))

            # Stable positional ids; re-indexing the same corpus reuses them.
            ids = [f"doc_{i}" for i in range(len(documents))]
            self.vector_store.add_documents(documents, metadatas, ids)

            logger.info("Indexed %d document chunks", len(documents))

        except Exception:
            logger.exception("Error during indexing")
            raise

    def retrieve_documents(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to ``n_results`` chunks most similar to ``query``.

        Each item is ``{"content": <chunk text>, "source": "<filename> (p<page>)"}``.
        Best-effort: any retrieval error is logged with a traceback and an
        empty list is returned instead of raising.
        """
        try:
            logger.debug("Searching for: %s...", query[:50])
            results = self.vector_store.search(query, n_results=n_results)

            # The store returns per-query lists; a single query lives at index 0.
            retrieved_docs = [
                {
                    "content": doc,
                    "source": f"{metadata.get('filename', 'Unknown')} (p{metadata.get('page', '?')})",
                }
                for doc, metadata in zip(results["documents"][0], results["metadatas"][0])
            ]

            logger.debug("Retrieved %d documents", len(retrieved_docs))
            return retrieved_docs

        except Exception:
            logger.exception("Error retrieving documents")
            return []

    def answer_question(self, question: str, n_context_docs: int = 3) -> Dict:
        """Answer ``question`` using retrieved PDF context.

        Args:
            question: Natural-language question.
            n_context_docs: Number of context chunks to retrieve.

        Returns:
            Dict with ``answer`` (str, capped at 1000 chars), ``sources``
            (list of source labels), and ``context_used`` (chunk count).
            Never raises: failures produce an error-message answer instead.
        """
        logger.info("Processing question: %s...", question[:50])

        try:
            logger.debug("Retrieving %d documents...", n_context_docs)
            retrieved_docs = self.retrieve_documents(question, n_results=n_context_docs)

            if not retrieved_docs:
                logger.warning("No documents retrieved")
                return {
                    "answer": "No relevant documents found.",
                    "sources": [],
                    "context_used": 0,
                }

            logger.debug("Retrieved %d documents", len(retrieved_docs))

            # Truncate each chunk to 500 chars, then cap total context at 2000
            # chars to keep the prompt within the small model's budget.
            context = "\n\n".join(
                f"[{doc['source']}]\n{doc['content'][:500]}" for doc in retrieved_docs
            )[:2000]

            logger.debug("Generating answer with Gemma 3...")

            try:
                answer = self.llm.answer_question(question, context)
            except Exception as e:
                # Fall back to plain greedy decoding if the primary path fails.
                logger.warning("Answer generation failed (%s), using greedy fallback...", e)
                answer = self.llm.generate_response_greedy(
                    f"Q: {question}\nA:"
                )

            # Strip any echoed prompt preamble up to the last "Answer:" marker.
            if "Answer:" in answer:
                answer = answer.split("Answer:")[-1].strip()

            logger.info("Answer generated successfully")

            return {
                "answer": answer[:1000],
                "sources": [doc["source"] for doc in retrieved_docs],
                "context_used": len(retrieved_docs),
            }

        except Exception as e:
            logger.exception("Error in answer_question")
            return {
                "answer": f"Error generating answer: {str(e)}",
                "sources": [],
                "context_used": 0,
            }

    def summarize_documents(self) -> str:
        """Summarize the indexed corpus by sampling top chunks and asking the LLM."""
        collection_info = self.vector_store.get_collection_info()
        doc_count = collection_info.get("document_count", 0)

        if doc_count == 0:
            return "No documents to summarize"

        results = self.vector_store.search("main topic summary", n_results=5)
        # BUG FIX: the original wrote ``docs[:200]`` which sliced the *list* of
        # documents to 200 entries and joined each one in full; the intent —
        # consistent with the 500/1000-char caps elsewhere — is to truncate
        # each document's text to 200 characters.
        sampled_content = " ".join(
            doc[:200] for docs in results["documents"] for doc in docs
        )

        return self.llm.summarize_text(sampled_content)

    def get_collection_info(self) -> Dict:
        """Return the vector store's collection statistics (delegates to the store)."""
        return self.vector_store.get_collection_info()
|
|
|