|
|
""" |
|
|
Main RAG (Retrieval-Augmented Generation) engine implementation. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
from typing import List, Dict, Any, Optional, Tuple, Union |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class RAGEngine:
    """Retrieval-Augmented Generation (RAG) engine for question answering."""

    def __init__(
        self,
        embedder,
        vector_db,
        llm=None,
        top_k: int = 5,
        search_type: str = "hybrid",
        prompt_template: Optional[str] = None
    ):
        """
        Initialize the RAG engine.

        Args:
            embedder: Embedding model (must expose an ``embed`` method)
            vector_db: Vector database for document storage and retrieval
            llm: Language model for text generation (optional)
            top_k: Number of documents to retrieve
            search_type: Type of search ('semantic', 'keyword', 'hybrid')
            prompt_template: Optional custom prompt template; must contain
                ``{context}`` and ``{query}`` format fields
        """
        self.embedder = embedder
        self.vector_db = vector_db
        self.llm = llm
        self.top_k = top_k
        self.search_type = search_type

        # Import is deferred so callers who pass an explicit template never
        # touch the package config module.
        if prompt_template is None:
            from ..config import DEFAULT_PROMPT_TEMPLATE
            self.prompt_template = DEFAULT_PROMPT_TEMPLATE
        else:
            self.prompt_template = prompt_template

    def add_documents(
        self,
        texts: List[str],
        metadata: Optional[List[Dict[str, Any]]] = None,
        batch_size: int = 32
    ) -> List[str]:
        """
        Add documents to the database.

        Embeds the texts in batches of ``batch_size`` and stores the
        resulting ``Document`` objects in the vector database.

        Args:
            texts: List of text chunks
            metadata: Optional list of metadata dictionaries for each text
            batch_size: Batch size for embedding generation

        Returns:
            List of document IDs as reported by the vector database

        Raises:
            ValueError: If ``metadata`` is given but its length differs
                from ``texts``.
        """
        from ..storage.vector_db import Document

        if metadata is None:
            metadata = [{} for _ in texts]
        elif len(metadata) != len(texts):
            raise ValueError(f"Length mismatch: got {len(texts)} texts but {len(metadata)} metadata entries")

        doc_ids = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_metadata = metadata[i:i+batch_size]

            logger.info(f"Generating embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
            batch_embeddings = self.embedder.embed(batch_texts)

            documents = []
            for text, meta, embedding in zip(batch_texts, batch_metadata, batch_embeddings):
                doc = Document(text=text, metadata=meta, embedding=embedding)
                documents.append(doc)

            batch_ids = self.vector_db.add_documents(documents)
            doc_ids.extend(batch_ids)

        logger.info(f"Added {len(doc_ids)} documents to database")
        return doc_ids

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_type: Optional[str] = None,
        filter_dict: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for relevant documents.

        Args:
            query: Query string
            top_k: Number of results to return (defaults to self.top_k)
            search_type: Type of search (defaults to self.search_type)
            filter_dict: Dictionary of metadata filters; keys containing
                a dot (e.g. ``"a.b"``) traverse nested metadata dicts

        Returns:
            List of document dictionaries with keys ``id``, ``text``,
            ``metadata``, and ``score``
        """
        if top_k is None:
            top_k = self.top_k

        if search_type is None:
            search_type = self.search_type
        # NOTE(review): the resolved `search_type` is never forwarded to
        # vector_db.search below — presumably the database applies its own
        # default strategy; confirm whether it should be passed through.

        filter_func = None
        if filter_dict:
            def filter_func(doc):
                for key, value in filter_dict.items():
                    if "." in key:
                        # Dotted key: walk nested metadata dictionaries.
                        parts = key.split(".")
                        current = doc.metadata
                        for part in parts[:-1]:
                            # Guard against non-dict intermediates: `in` on
                            # e.g. an int would raise TypeError otherwise.
                            if not isinstance(current, dict) or part not in current:
                                return False
                            current = current[part]
                        if not isinstance(current, dict) or parts[-1] not in current or current[parts[-1]] != value:
                            return False
                    elif key not in doc.metadata or doc.metadata[key] != value:
                        return False
                return True

        query_embedding = self.embedder.embed(query)

        results = self.vector_db.search(query_embedding, top_k, filter_func)

        return [
            {
                "id": doc.id,
                "text": doc.text,
                "metadata": doc.metadata,
                "score": score
            }
            for doc, score in results
        ]

    def generate_response(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_type: Optional[str] = None,
        filter_dict: Optional[Dict[str, Any]] = None,
        max_tokens: int = 512
    ) -> Dict[str, Any]:
        """
        Generate a response to a query using RAG.

        Retrieves relevant documents, formats them into a prompt, and
        (if an LLM is configured) generates an answer.

        Args:
            query: Query string
            top_k: Number of documents to retrieve
            search_type: Type of search
            filter_dict: Optional filter for document retrieval
            max_tokens: Maximum number of tokens in the response

        Returns:
            Dictionary with query, response, and retrieved documents
        """
        retrieved_docs = self.search(query, top_k, search_type, filter_dict)

        # No hits: short-circuit with a canned answer rather than prompting
        # the LLM with an empty context.
        if not retrieved_docs:
            return {
                "query": query,
                "response": "I couldn't find any relevant information to answer your question.",
                "retrieved_documents": [],
                "search_type": search_type or self.search_type
            }

        context = self._format_context(retrieved_docs)

        prompt = self.prompt_template.format(context=context, query=query)

        if self.llm is None:
            logger.warning("No LLM provided, returning only retrieved documents")
            response = "No language model available to generate a response. Here's what I found in the documents."
        else:
            response = self._generate_llm_response(prompt, max_tokens)

        return {
            "query": query,
            "response": response,
            "retrieved_documents": retrieved_docs,
            "search_type": search_type or self.search_type
        }

    def _format_context(self, documents: List[Dict[str, Any]]) -> str:
        """
        Format retrieved documents into context for the prompt.

        Args:
            documents: List of retrieved documents

        Returns:
            Formatted context string, one numbered entry per document with
            its metadata ``source`` (or "Unknown")
        """
        context_parts = []

        for i, doc in enumerate(documents):
            text = doc["text"]
            metadata = doc["metadata"]
            source = metadata.get("source", "Unknown")

            doc_text = f"Document {i+1}: [Source: {source}]\n{text}\n"
            context_parts.append(doc_text)

        return "\n".join(context_parts)

    def _generate_llm_response(self, prompt: str, max_tokens: int) -> str:
        """
        Generate a response using the LLM.

        Dispatches on whichever generation method the configured LLM
        exposes (OpenAI-style, HuggingFace-style, or generic).

        Args:
            prompt: The formatted prompt
            max_tokens: Maximum number of tokens in the response

        Returns:
            Generated response, or an error message if the generic
            fallback raises
        """
        if hasattr(self.llm, "generate_openai_response"):
            return self.llm.generate_openai_response(prompt, max_tokens)
        elif hasattr(self.llm, "generate_huggingface_response"):
            return self.llm.generate_huggingface_response(prompt, max_tokens)
        else:
            # Generic fallback: any object with a generate_response method.
            try:
                return self.llm.generate_response(prompt, max_tokens)
            except Exception as e:
                logger.error(f"Error generating response: {e}")
                return "I encountered an error while generating a response."

    def update_prompt_template(self, new_template: str) -> None:
        """
        Update the prompt template.

        Args:
            new_template: New prompt template (should contain ``{context}``
                and ``{query}`` fields)
        """
        self.prompt_template = new_template
        logger.info("Updated prompt template")

    def count_documents(self) -> int:
        """
        Get the number of documents in the database.

        Returns:
            Number of documents
        """
        return self.vector_db.count_documents()

    def clear_documents(self) -> None:
        """Clear all documents from the database."""
        self.vector_db.clear()
        logger.info("Cleared all documents from database")
|
|
|
|
|
|
|
|
|
|
|
def create_rag_engine(
    embedder=None,
    vector_db=None,
    llm=None,
    config=None
) -> RAGEngine:
    """
    Factory function to create a RAG engine.

    Args:
        embedder: Embedding model (if None, created based on config)
        vector_db: Vector database (if None, created based on config)
        llm: Language model (if None, created based on config)
        config: Configuration module or dictionary

    Returns:
        Configured RAGEngine instance
    """
    if config is None:
        from ..config import (
            TOP_K,
            SEARCH_TYPE,
            DEFAULT_PROMPT_TEMPLATE
        )
    else:
        def _cfg(key, default):
            # The docstring promises "module or dictionary": dict-likes
            # expose .get(), while module objects need getattr with a
            # default (calling config.get on a module would raise).
            if hasattr(config, "get"):
                return config.get(key, default)
            return getattr(config, key, default)

        TOP_K = _cfg("TOP_K", 5)
        SEARCH_TYPE = _cfg("SEARCH_TYPE", "hybrid")
        DEFAULT_PROMPT_TEMPLATE = _cfg(
            "DEFAULT_PROMPT_TEMPLATE",
            """
Answer the following question based ONLY on the provided context.

Context:
{context}

Question: {query}

Answer:
"""
        )

    # Lazily build any component the caller did not supply; imports are
    # deferred so explicitly-provided components skip those modules.
    if embedder is None:
        from ..embedding.model import create_embedding_model
        embedder = create_embedding_model()

    if vector_db is None:
        from ..storage.vector_db import create_vector_database
        vector_db = create_vector_database(dimension=embedder.dimension)

    if llm is None:
        # Best-effort: the engine works without an LLM (retrieval only).
        try:
            from ..llm.model import create_llm
            llm = create_llm()
        except (ImportError, ModuleNotFoundError):
            logger.warning("LLM module not found, proceeding without an LLM")

    return RAGEngine(
        embedder=embedder,
        vector_db=vector_db,
        llm=llm,
        top_k=TOP_K,
        search_type=SEARCH_TYPE,
        prompt_template=DEFAULT_PROMPT_TEMPLATE
    )
|
|
|