"""
Main RAG (Retrieval-Augmented Generation) engine implementation.
"""
import logging
from typing import Any, Dict, List, Optional

# Module-level logger
logger = logging.getLogger(__name__)
class RAGEngine:
"""Retrieval-Augmented Generation (RAG) engine for question answering."""
def __init__(
self,
embedder,
vector_db,
llm=None,
top_k: int = 5,
search_type: str = "hybrid",
prompt_template: Optional[str] = None
):
"""
Initialize the RAG engine.
Args:
embedder: Embedding model
vector_db: Vector database for document storage and retrieval
llm: Language model for text generation (optional)
top_k: Number of documents to retrieve
search_type: Type of search ('semantic', 'keyword', 'hybrid')
prompt_template: Optional custom prompt template
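
        Example (illustrative; any objects exposing the ``embed`` and
        ``search`` interfaces used by this class will work):

            engine = RAGEngine(
                embedder=my_embedder,
                vector_db=my_vector_db,
                top_k=3,
            )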
"""
self.embedder = embedder
self.vector_db = vector_db
self.llm = llm
self.top_k = top_k
self.search_type = search_type
# Set default prompt template if none provided
if prompt_template is None:
from ..config import DEFAULT_PROMPT_TEMPLATE
self.prompt_template = DEFAULT_PROMPT_TEMPLATE
else:
self.prompt_template = prompt_template
def add_documents(
self,
texts: List[str],
metadata: Optional[List[Dict[str, Any]]] = None,
batch_size: int = 32
) -> List[str]:
"""
Add documents to the database.
Args:
texts: List of text chunks
metadata: Optional list of metadata dictionaries for each text
batch_size: Batch size for embedding generation
Returns:
List of document IDs
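
        Example (illustrative):

            ids = engine.add_documents(
                ["RAG retrieves context before generating."],
                metadata=[{"source": "notes.md"}],
            )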
"""
from ..storage.vector_db import Document
# Handle metadata
if metadata is None:
metadata = [{} for _ in texts]
elif len(metadata) != len(texts):
raise ValueError(f"Length mismatch: got {len(texts)} texts but {len(metadata)} metadata entries")
# Generate embeddings in batches
doc_ids = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
batch_metadata = metadata[i:i+batch_size]
# Generate embeddings
logger.info(f"Generating embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
batch_embeddings = self.embedder.embed(batch_texts)
# Create document objects
documents = []
for text, meta, embedding in zip(batch_texts, batch_metadata, batch_embeddings):
doc = Document(text=text, metadata=meta, embedding=embedding)
documents.append(doc)
# Add to database
batch_ids = self.vector_db.add_documents(documents)
doc_ids.extend(batch_ids)
logger.info(f"Added {len(doc_ids)} documents to database")
return doc_ids
def search(
self,
query: str,
top_k: Optional[int] = None,
search_type: Optional[str] = None,
filter_dict: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Search for relevant documents.
Args:
query: Query string
top_k: Number of results to return (defaults to self.top_k)
search_type: Type of search (defaults to self.search_type)
filter_dict: Dictionary of metadata filters
Returns:
List of document dictionaries
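
        Example (illustrative; dotted keys address nested metadata):

            hits = engine.search(
                "What is retrieval-augmented generation?",
                top_k=3,
                filter_dict={"source": "notes.md", "author.name": "Jane"},
            )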
"""
        if top_k is None:
            top_k = self.top_k
        if search_type is None:
            search_type = self.search_type
        # NOTE: search_type is resolved here so callers can see which strategy
        # was requested; the actual ranking strategy is delegated to the
        # vector database backend.

        # Build a metadata filter function if filter_dict is provided
        filter_func = None
        if filter_dict:
            def filter_func(doc):
                for key, value in filter_dict.items():
                    # Dotted keys address nested metadata, e.g. "author.name"
                    # matches doc.metadata["author"]["name"]
                    if "." in key:
                        parts = key.split(".")
                        current = doc.metadata
                        for part in parts[:-1]:
                            if not isinstance(current, dict) or part not in current:
                                return False
                            current = current[part]
                        if (
                            not isinstance(current, dict)
                            or parts[-1] not in current
                            or current[parts[-1]] != value
                        ):
                            return False
                    elif key not in doc.metadata or doc.metadata[key] != value:
                        return False
                return True

        # Generate the query embedding (the embedder is assumed to accept a
        # single string as well as a list of strings)
        query_embedding = self.embedder.embed(query)

        # Delegate the search to the vector database
        results = self.vector_db.search(query_embedding, top_k, filter_func)
# Convert results to dictionaries
return [
{
"id": doc.id,
"text": doc.text,
"metadata": doc.metadata,
"score": score
}
for doc, score in results
]
def generate_response(
self,
query: str,
top_k: Optional[int] = None,
search_type: Optional[str] = None,
filter_dict: Optional[Dict[str, Any]] = None,
max_tokens: int = 512
) -> Dict[str, Any]:
"""
Generate a response to a query using RAG.
Args:
query: Query string
top_k: Number of documents to retrieve
search_type: Type of search
filter_dict: Optional filter for document retrieval
max_tokens: Maximum number of tokens in the response
Returns:
Dictionary with query, response, and retrieved documents
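
        Example (illustrative):

            result = engine.generate_response("Summarize the design.", top_k=3)
            print(result["response"])
            for doc in result["retrieved_documents"]:
                print(doc["score"], doc["metadata"].get("source"))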
"""
# Retrieve relevant documents
retrieved_docs = self.search(query, top_k, search_type, filter_dict)
# If no documents were found, return a default message
if not retrieved_docs:
return {
"query": query,
"response": "I couldn't find any relevant information to answer your question.",
"retrieved_documents": [],
"search_type": search_type or self.search_type
}
# Format context from retrieved documents
context = self._format_context(retrieved_docs)
# Format prompt with context and query
prompt = self.prompt_template.format(context=context, query=query)
# Generate response using LLM
if self.llm is None:
logger.warning("No LLM provided, returning only retrieved documents")
response = "No language model available to generate a response. Here's what I found in the documents."
else:
response = self._generate_llm_response(prompt, max_tokens)
# Return the results
return {
"query": query,
"response": response,
"retrieved_documents": retrieved_docs,
"search_type": search_type or self.search_type
}
def _format_context(self, documents: List[Dict[str, Any]]) -> str:
"""
Format retrieved documents into context for the prompt.
Args:
documents: List of retrieved documents
Returns:
Formatted context string
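
        Example output for a single retrieved document:

            Document 1: [Source: notes.md]
            RAG retrieves context before generating.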
"""
context_parts = []
for i, doc in enumerate(documents):
# Extract relevant fields
text = doc["text"]
metadata = doc["metadata"]
source = metadata.get("source", "Unknown")
# Format the document
doc_text = f"Document {i+1}: [Source: {source}]\n{text}\n"
context_parts.append(doc_text)
return "\n".join(context_parts)
def _generate_llm_response(self, prompt: str, max_tokens: int) -> str:
"""
Generate a response using the LLM.
Args:
prompt: The formatted prompt
max_tokens: Maximum number of tokens in the response
Returns:
Generated response
"""
if hasattr(self.llm, "generate_openai_response"):
# OpenAI-compatible LLM
return self.llm.generate_openai_response(prompt, max_tokens)
elif hasattr(self.llm, "generate_huggingface_response"):
# HuggingFace-compatible LLM
return self.llm.generate_huggingface_response(prompt, max_tokens)
else:
# Default implementation
try:
return self.llm.generate_response(prompt, max_tokens)
except Exception as e:
logger.error(f"Error generating response: {e}")
return "I encountered an error while generating a response."
def update_prompt_template(self, new_template: str) -> None:
"""
Update the prompt template.
Args:
new_template: New prompt template
"""
self.prompt_template = new_template
logger.info("Updated prompt template")
def count_documents(self) -> int:
"""
Get the number of documents in the database.
Returns:
Number of documents
"""
return self.vector_db.count_documents()
def clear_documents(self) -> None:
"""Clear all documents from the database."""
self.vector_db.clear()
logger.info("Cleared all documents from database")
# Factory function to create the RAG engine
def create_rag_engine(
embedder=None,
vector_db=None,
llm=None,
config=None
) -> RAGEngine:
"""
Factory function to create a RAG engine.
Args:
embedder: Embedding model (if None, created based on config)
vector_db: Vector database (if None, created based on config)
llm: Language model (if None, created based on config)
config: Configuration module or dictionary
Returns:
Configured RAGEngine instance
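
    Example (illustrative; the dict keys mirror the package config names):

        engine = create_rag_engine(config={"TOP_K": 3, "SEARCH_TYPE": "semantic"})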
"""
    # Load configuration (config may be a module or a plain dict)
    if config is None:
        from ..config import (
            TOP_K,
            SEARCH_TYPE,
            DEFAULT_PROMPT_TEMPLATE
        )
    else:
        def _get(name, default):
            # Support both dict-style and module-style configuration objects
            if isinstance(config, dict):
                return config.get(name, default)
            return getattr(config, name, default)

        TOP_K = _get("TOP_K", 5)
        SEARCH_TYPE = _get("SEARCH_TYPE", "hybrid")
        DEFAULT_PROMPT_TEMPLATE = _get(
            "DEFAULT_PROMPT_TEMPLATE",
            "Answer the following question based ONLY on the provided context.\n\n"
            "Context:\n{context}\n\n"
            "Question: {query}\n\n"
            "Answer:"
        )
# Create embedding model if not provided
if embedder is None:
from ..embedding.model import create_embedding_model
embedder = create_embedding_model()
# Create vector database if not provided
if vector_db is None:
from ..storage.vector_db import create_vector_database
vector_db = create_vector_database(dimension=embedder.dimension)
# Create language model if not provided and requested
if llm is None:
try:
from ..llm.model import create_llm
llm = create_llm()
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
logger.warning("LLM module not found, proceeding without an LLM")
# Create and return the RAG engine
return RAGEngine(
embedder=embedder,
vector_db=vector_db,
llm=llm,
top_k=TOP_K,
search_type=SEARCH_TYPE,
prompt_template=DEFAULT_PROMPT_TEMPLATE
)
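

if __name__ == "__main__":
    # Minimal smoke test with in-memory stubs. This is an illustrative sketch
    # only: the real embedder, vector database, and LLM implementations live
    # elsewhere in the package, and the stub interfaces below (embed(),
    # search(), count_documents()) are assumed from how RAGEngine calls them.
    from dataclasses import dataclass, field

    @dataclass
    class _StubDoc:
        id: str
        text: str
        metadata: Dict[str, Any] = field(default_factory=dict)

    class _StubEmbedder:
        dimension = 3

        def embed(self, texts):
            # Return one dummy vector per input (or one vector for a string)
            if isinstance(texts, str):
                return [0.0] * self.dimension
            return [[0.0] * self.dimension for _ in texts]

    class _StubVectorDB:
        def __init__(self):
            self._docs = [
                _StubDoc("1", "RAG retrieves context before generating.", {"source": "notes.md"})
            ]

        def search(self, query_embedding, top_k, filter_func=None):
            docs = [d for d in self._docs if filter_func is None or filter_func(d)]
            return [(d, 1.0) for d in docs[:top_k]]

        def count_documents(self):
            return len(self._docs)

    engine = RAGEngine(
        embedder=_StubEmbedder(),
        vector_db=_StubVectorDB(),
        prompt_template="Context:\n{context}\n\nQuestion: {query}\nAnswer:",
    )
    print(engine.search("What is RAG?"))
    print(engine.generate_response("What is RAG?")["response"])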