|
|
import logging |
|
|
from typing import Dict, List, Any, Optional, Tuple |
|
|
from dataclasses import dataclass, field |
|
|
import time |
|
|
import os |
|
|
import sys |
|
|
|
|
|
|
|
|
# Make the project root importable so the `src.*` absolute imports below
# resolve when this file is executed directly as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
from src.embedding_system import EmbeddingSystem, RerankResult |
|
|
from src.vector_store import QdrantVectorStore, SearchResult |
|
|
from src.groq_client import LLMSystem |
|
|
from src.document_processor import DocumentChunk |
|
|
from src.utilites import load_yaml_config |
|
|
|
|
|
|
|
|
# Prefer the project's structured logger when available; fall back to the
# stdlib logger so this module still imports outside the full application tree.
try:
    from logger.custom_logger import CustomLoggerTracker

    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_engine")

except ImportError:
    # Project logger not on the path — use a plain stdlib logger instead.
    logger = logging.getLogger("rag_engine")
|
|
|
|
|
|
|
|
@dataclass
class Citation:
    """Provenance record linking an answer back to one supporting source chunk."""

    source_file: str                      # originating document identifier
    page_number: Optional[int] = None     # page within a paginated document (e.g. PDF), when known
    worksheet_name: Optional[str] = None  # spreadsheet sheet name, if the source is a workbook
    cell_range: Optional[str] = None      # spreadsheet cell range, if the source is a workbook
    section_title: Optional[str] = None   # section heading within the document, when known
    text_snippet: str = ""                # short excerpt of the cited chunk's content
    confidence: float = 0.0               # similarity score of the cited chunk
    chunk_id: str = ""                    # identifier of the source chunk
|
|
|
|
|
|
|
|
@dataclass
class RAGResponse:
    """Complete result of one RAG question-answering call, including timings."""

    answer: str                           # generated answer text (or apology message on failure)
    confidence_score: float               # heuristic confidence, clamped to [0, 1]
    citations: List[Citation] = field(default_factory=list)            # one Citation per context chunk
    context_chunks: List[DocumentChunk] = field(default_factory=list)  # chunks fed to the LLM
    processing_time: float = 0.0          # end-to-end wall time, seconds
    retrieval_time: float = 0.0           # vector similarity search time, seconds
    generation_time: float = 0.0          # LLM generation time, seconds
    rerank_time: float = 0.0              # reranking time, seconds
    total_chunks_retrieved: int = 0       # hits returned by the vector search
    total_chunks_reranked: int = 0        # chunks that survived reranking
    model_used: str = ""                  # LLM model identifier
    success: bool = True                  # False when any stage failed
    error_message: Optional[str] = None   # populated when success is False
|
|
|
|
|
|
|
|
class RAGEngine:
    """Retrieval-augmented generation pipeline.

    Orchestrates the full question-answering flow: embed the query, run a
    vector similarity search, rerank candidates with the embedding system's
    reranker, assemble a length-bounded context, generate an answer with the
    LLM, and attach citations for every context chunk.
    """

    def __init__(self, config: Dict[str, Any]):
        """Build the collaborating subsystems and read retrieval settings.

        Args:
            config: Application configuration dict shared by all subsystems.
        """
        self.config = config

        # Collaborating subsystems, all driven by the same config dict.
        self.embedding_system = EmbeddingSystem(config)
        self.vector_store = QdrantVectorStore(config)
        self.llm_system = LLMSystem(config)

        # Retrieval / generation tuning knobs (with conservative defaults).
        self.max_context_chunks = config.get('max_context_chunks', 5)
        self.similarity_threshold = config.get('similarity_threshold', 0.7)
        self.rerank_top_k = config.get('rerank_top_k', 20)    # candidates fetched for reranking
        self.final_top_k = config.get('final_top_k', 5)       # chunks kept after reranking
        self.max_context_length = config.get('max_context_length', 4000)  # context budget, in characters
        # NOTE(review): max_context_chunks and similarity_threshold are read
        # from config but never applied in the retrieval path below — confirm
        # whether score-based filtering was intended.
        logger.info(f"RAG engine initialized with max_context_chunks={self.max_context_chunks}")

    def answer_question(self, question: str, filters: Optional[Dict[str, Any]] = None) -> RAGResponse:
        """Answer *question* using retrieved document context.

        Args:
            question: Natural-language user question.
            filters: Optional metadata filters forwarded to the vector store.

        Returns:
            A RAGResponse. This method never raises: every failure path is
            converted into a response with ``success=False`` and an
            ``error_message`` (except the empty-retrieval case, which is a
            valid ``success=True`` outcome).
        """
        start_time = time.time()
        try:
            logger.info(f"Processing question: {question[:100]}...")

            # 1. Embed the query; bail out early if embedding fails.
            query_embedding = self.embedding_system.generate_query_embedding(question)
            if not query_embedding:
                return RAGResponse(
                    answer="I apologize, but I'm unable to process your question due to an embedding generation error.",
                    confidence_score=0.0,
                    success=False,
                    error_message="Failed to generate query embedding")

            # 2. Vector similarity search for rerank candidates.
            retrieval_start = time.time()
            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                k=self.rerank_top_k,
                filters=filters)
            retrieval_time = time.time() - retrieval_start

            # Zero hits is a valid outcome, not an error.
            if not search_results:
                return RAGResponse(
                    answer="I couldn't find any relevant information to answer your question. Please try rephrasing your question or check if the relevant documents have been uploaded.",
                    confidence_score=0.0,
                    retrieval_time=retrieval_time,
                    processing_time=time.time() - start_time,
                    success=True)

            logger.info(f"Retrieved {len(search_results)} chunks from vector store in {retrieval_time:.2f}s")

            # 3. Rerank candidates (best-effort; falls back to vector order).
            rerank_start = time.time()
            reranked_chunks = self._rerank_chunks(question, search_results)
            rerank_time = time.time() - rerank_start

            # 4. Keep the top chunks and assemble the prompt context.
            context_chunks = reranked_chunks[:self.final_top_k]
            context_text = self._build_context(context_chunks)
            logger.info(f"Built context from top {len(context_chunks)} chunks")

            # 5. Generate the answer with the LLM.
            generation_start = time.time()
            answer = self.llm_system.answer_question(question, context_text)
            generation_time = time.time() - generation_start
            if not answer:
                return RAGResponse(
                    answer="I apologize, but I was unable to generate an answer to your question.",
                    confidence_score=0.0,
                    retrieval_time=retrieval_time,
                    generation_time=generation_time,
                    rerank_time=rerank_time,
                    processing_time=time.time() - start_time,
                    success=False,
                    error_message="LLM failed to generate an answer")

            logger.info(f"Generated answer in {generation_time:.2f}s")

            # 6. Attach citations and a heuristic confidence score.
            citations = self._extract_citations(context_chunks)
            logger.info(f"Extracted {len(citations)} citations")

            confidence_score = self._calculate_confidence_score(search_results, answer)
            logger.info(f"Calculated confidence score: {confidence_score:.2f}")

            total_time = time.time() - start_time
            response = RAGResponse(
                answer=answer,
                confidence_score=confidence_score,
                citations=citations,
                context_chunks=[result.chunk for result in context_chunks],
                processing_time=total_time,
                retrieval_time=retrieval_time,
                generation_time=generation_time,
                rerank_time=rerank_time,
                total_chunks_retrieved=len(search_results),
                total_chunks_reranked=len(reranked_chunks),
                model_used=self.llm_system.default_model,
                success=True)
            logger.info(f"Question answered successfully in {total_time:.2f}s")
            return response

        except Exception as e:
            error_msg = f"RAG processing failed: {str(e)}"
            # logger.exception records the traceback, which logger.error drops.
            logger.exception(error_msg)
            return RAGResponse(
                answer="I apologize, but I encountered an error while processing your question. Please try again.",
                confidence_score=0.0,
                processing_time=time.time() - start_time,
                success=False,
                error_message=error_msg
            )

    def get_relevant_context(self, question: str, k: int = 5,
                             filters: Optional[Dict[str, Any]] = None) -> List[DocumentChunk]:
        """Return up to *k* reranked chunks relevant to *question*.

        Retrieval-only variant of :meth:`answer_question` (no LLM call).

        Args:
            question: Natural-language query.
            k: Maximum number of chunks to return.
            filters: Optional metadata filters forwarded to the vector store.

        Returns:
            Up to *k* DocumentChunk objects; an empty list on any failure
            (this method never raises).
        """
        try:
            query_embedding = self.embedding_system.generate_query_embedding(question)
            if not query_embedding:
                return []

            # Over-fetch (2x, capped at rerank_top_k) so the reranker has
            # candidates to discard.
            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                k=min(k * 2, self.rerank_top_k),
                filters=filters)
            if not search_results:
                return []

            reranked_chunks = self._rerank_chunks(question, search_results)
            return [result.chunk for result in reranked_chunks[:k]]

        except Exception as e:
            # logger.exception records the traceback, which logger.error drops.
            logger.exception(f"Failed to get relevant context: {e}")
            return []

    def _rerank_chunks(self, question: str, search_results: List[SearchResult]) -> List[SearchResult]:
        """Reorder *search_results* by reranker relevance to *question*.

        Reranking is best-effort: on any failure the original vector-search
        order is returned unchanged. Side effect: sets ``rerank_score`` on
        each SearchResult that appears in the reranked output.
        """
        try:
            if len(search_results) <= 1:
                return search_results  # nothing to reorder

            documents = [result.chunk.content for result in search_results]
            rerank_results = self.embedding_system.rerank_results(
                query=question,
                documents=documents,
                top_k=len(documents)
            )

            # Map reranker output indices back onto the original results,
            # defensively skipping any out-of-range index.
            reranked_search_results = []
            for rerank_result in rerank_results:
                original_index = rerank_result.index
                if 0 <= original_index < len(search_results):
                    search_result = search_results[original_index]
                    search_result.rerank_score = rerank_result.score
                    reranked_search_results.append(search_result)

            logger.debug(f"Reranked {len(search_results)} chunks")
            return reranked_search_results

        except Exception as e:
            # Degrade gracefully to the vector-search ordering.
            logger.warning(f"Reranking failed, using original order: {e}")
            return search_results

    def _build_context(self, search_results: List[SearchResult]) -> str:
        """
        Build context text from search results.

        Chunks are concatenated as "[Source N: <citation>]" headers followed
        by the chunk content, until max_context_length characters are used.
        The first chunk that would overflow is truncated (only when more than
        100 characters of it still fit) and iteration stops there.

        Args:
            search_results: List of search results

        Returns:
            Formatted context text
        """
        context_parts = []
        current_length = 0

        for i, result in enumerate(search_results):
            chunk = result.chunk
            citation_info = self._format_citation_info(chunk)
            content = f"[Source {i+1}: {citation_info}]\n{chunk.content}\n"

            if current_length + len(content) > self.max_context_length:
                # Budget exceeded: fit a truncated tail if it is worth it
                # (>100 chars of content, minus a 20-char safety margin for
                # the ellipsis and separators), then stop either way.
                remaining_space = self.max_context_length - current_length - len(f"[Source {i+1}: {citation_info}]\n") - 20
                if remaining_space > 100:
                    truncated_content = chunk.content[:remaining_space] + "..."
                    content = f"[Source {i+1}: {citation_info}]\n{truncated_content}\n"
                    context_parts.append(content)
                break

            context_parts.append(content)
            current_length += len(content)

        return "\n".join(context_parts)

    def _format_citation_info(self, chunk: DocumentChunk) -> str:
        """
        Format citation information for a chunk.

        Only metadata fields that are present and truthy are included; the
        result degrades to "Unknown source" when nothing is available.

        Args:
            chunk: Document chunk

        Returns:
            Formatted citation string
        """
        metadata = chunk.metadata
        parts = []

        if hasattr(metadata, 'document_id'):
            parts.append(f"Doc: {metadata.document_id}")

        # getattr with a None default keeps this robust against metadata
        # objects that lack an optional field entirely (the original code
        # guarded only document_id and dereferenced the rest unconditionally).
        page_number = getattr(metadata, 'page_number', None)
        if page_number:
            parts.append(f"Page {page_number}")

        worksheet_name = getattr(metadata, 'worksheet_name', None)
        if worksheet_name:
            parts.append(f"Sheet: {worksheet_name}")
        cell_range = getattr(metadata, 'cell_range', None)
        if cell_range:
            parts.append(f"Range: {cell_range}")

        section_title = getattr(metadata, 'section_title', None)
        if section_title:
            parts.append(f"Section: {section_title}")

        return ", ".join(parts) if parts else "Unknown source"

    def _extract_citations(self, search_results: List[SearchResult]) -> List[Citation]:
        """Build one Citation per search result, with the snippet capped at 200 chars."""
        citations = []
        for result in search_results:
            chunk = result.chunk
            metadata = chunk.metadata

            # getattr with defaults keeps this robust against metadata objects
            # missing optional fields, consistent with _format_citation_info.
            citation = Citation(
                source_file=getattr(metadata, 'document_id', 'Unknown'),
                page_number=getattr(metadata, 'page_number', None),
                worksheet_name=getattr(metadata, 'worksheet_name', None),
                cell_range=getattr(metadata, 'cell_range', None),
                section_title=getattr(metadata, 'section_title', None),
                text_snippet=chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                confidence=result.similarity_score,
                chunk_id=getattr(metadata, 'chunk_id', '')
            )

            citations.append(citation)

        return citations

    def _calculate_confidence_score(self, search_results: List[SearchResult], answer: str) -> float:
        """Heuristic confidence in [0, 1] for a generated answer.

        Weighted blend: 60% mean retrieval similarity, 20% how close the hit
        count comes to final_top_k, 20% answer length (saturating at 100
        characters).
        """
        if not search_results:
            return 0.0

        similarity_scores = [result.similarity_score for result in search_results]
        avg_similarity = sum(similarity_scores) / len(similarity_scores)

        # More sources (up to final_top_k) -> more support for the answer.
        source_factor = min(len(search_results) / self.final_top_k, 1.0)

        # Very short answers are treated as weaker evidence.
        length_factor = min(len(answer) / 100, 1.0) if answer else 0.0

        confidence = (avg_similarity * 0.6 + source_factor * 0.2 + length_factor * 0.2)

        # Clamp defensively; similarity scores outside [0, 1] could push it out.
        return min(max(confidence, 0.0), 1.0)

    def health_check(self) -> Dict[str, bool]:
        """Report liveness of each subsystem.

        The embedding system exposes no health probe here, so it is reported
        as healthy unconditionally.
        """
        return {
            "vector_store": self.vector_store.health_check(),
            "llm_system": self.llm_system.client.health_check(),
            "embedding_system": True
        }

    def get_stats(self) -> Dict[str, Any]:
        """Return vector-store, embedding-cache and configuration statistics.

        Returns:
            A stats dict, or ``{"error": <message>}`` on failure (never raises).
        """
        try:
            vector_stats = self.vector_store.get_collection_info()
            embedding_stats = self.embedding_system.get_cache_stats()

            return {
                "vector_store": vector_stats.__dict__ if vector_stats else {},
                "embedding_cache": embedding_stats,
                "config": {
                    "max_context_chunks": self.max_context_chunks,
                    "similarity_threshold": self.similarity_threshold,
                    "rerank_top_k": self.rerank_top_k,
                    "final_top_k": self.final_top_k
                }
            }

        except Exception as e:
            # logger.exception records the traceback, which logger.error drops.
            logger.exception(f"Failed to get RAG stats: {e}")
            return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: verify API keys are configured, then build an engine
    # from the YAML config.
    from src.utilites import validate_api_keys

    key_status = validate_api_keys()
    if key_status['valid']:
        logger.info("All required API keys are present.")
    else:
        logger.error("Missing required API keys. Please set them in the environment variables.")

    engine_config = load_yaml_config("src/config.yaml")
    engine = RAGEngine(engine_config)
|
|
|
|
|
|