File size: 15,070 Bytes
7dfe46c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 |
import logging
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
import time
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.embedding_system import EmbeddingSystem, RerankResult
from src.vector_store import QdrantVectorStore, SearchResult
from src.groq_client import LLMSystem
from src.document_processor import DocumentChunk
from src.utilites import load_yaml_config
# Prefer the project-wide custom logger; fall back to the stdlib logger when
# the logger package is unavailable (e.g. running this module standalone).
try:
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_engine")
except ImportError:
    logger = logging.getLogger("rag_engine")
@dataclass
class Citation:
    """Provenance record for one retrieved chunk used in an answer."""
    source_file: str  # document id of the source, or 'Unknown'
    page_number: Optional[int] = None  # page number (set for PDF chunks)
    worksheet_name: Optional[str] = None  # Excel worksheet name, when applicable
    cell_range: Optional[str] = None  # Excel cell range, when applicable
    section_title: Optional[str] = None  # section heading, when available
    text_snippet: str = ""  # first ~200 chars of the chunk content
    confidence: float = 0.0  # vector-store similarity score for this chunk
    chunk_id: str = ""  # unique id of the underlying chunk
@dataclass
class RAGResponse:
    """Full result of one RAG query, including timing and diagnostics."""
    answer: str  # generated answer text (or an apology message on failure)
    confidence_score: float  # heuristic in [0, 1]
    citations: List[Citation] = field(default_factory=list)
    context_chunks: List[DocumentChunk] = field(default_factory=list)  # chunks fed to the LLM
    processing_time: float = 0.0  # total wall-clock seconds for the request
    retrieval_time: float = 0.0  # seconds spent in vector search
    generation_time: float = 0.0  # seconds spent in the LLM call
    rerank_time: float = 0.0  # seconds spent reranking candidates
    total_chunks_retrieved: int = 0  # candidates returned by vector search
    total_chunks_reranked: int = 0  # candidates that survived reranking
    model_used: str = ""  # LLM model identifier used for generation
    success: bool = True  # False when a pipeline stage failed
    error_message: Optional[str] = None  # populated only when success is False
class RAGEngine:
    """Retrieval-augmented generation engine.

    Pipeline per question: embed the query, retrieve candidate chunks from
    the vector store, rerank them, build a length-capped context, generate
    an answer with the LLM, then attach citations and a heuristic confidence
    score.
    """

    def __init__(self, config: Dict[str, Any]):
        """Build all pipeline components from a flat config mapping.

        Args:
            config: Configuration dict; missing keys fall back to the
                defaults below.
        """
        self.config = config
        # Initialize components
        self.embedding_system = EmbeddingSystem(config)
        self.vector_store = QdrantVectorStore(config)
        self.llm_system = LLMSystem(config)
        # RAG parameters
        self.max_context_chunks = config.get('max_context_chunks', 5)
        self.similarity_threshold = config.get('similarity_threshold', 0.7)
        self.rerank_top_k = config.get('rerank_top_k', 20)  # candidates fetched for reranking
        self.final_top_k = config.get('final_top_k', 5)  # chunks kept after reranking
        self.max_context_length = config.get('max_context_length', 4000)  # context budget, in characters
        logger.info("RAG engine initialized with max_context_chunks=%s", self.max_context_chunks)

    def answer_question(self, question: str, filters: Optional[Dict[str, Any]] = None) -> RAGResponse:
        """Answer a natural-language question end-to-end.

        Args:
            question: The user's question.
            filters: Optional metadata filters forwarded to the vector store.

        Returns:
            A RAGResponse. On failure `success` is False and `error_message`
            explains why; this method never raises.
        """
        start_time = time.time()
        try:
            logger.info("Processing question: %s...", question[:100])
            # Step 1: Generate query embedding
            query_embedding = self.embedding_system.generate_query_embedding(question)
            if not query_embedding:
                return RAGResponse(
                    answer="I apologize, but I'm unable to process your question due to an embedding generation error.",
                    confidence_score=0.0,
                    success=False,
                    error_message="Failed to generate query embedding")
            # Step 2: Retrieve relevant chunks
            retrieval_start = time.time()
            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                k=self.rerank_top_k,
                filters=filters)
            retrieval_time = time.time() - retrieval_start
            if not search_results:
                # No hits is a valid outcome, not an error.
                return RAGResponse(
                    answer="I couldn't find any relevant information to answer your question. Please try rephrasing your question or check if the relevant documents have been uploaded.",
                    confidence_score=0.0,
                    retrieval_time=retrieval_time,
                    processing_time=time.time() - start_time,
                    success=True)
            logger.info("Retrieved %d chunks from vector store in %.2fs", len(search_results), retrieval_time)
            # Step 3: Rerank results
            rerank_start = time.time()
            reranked_chunks = self._rerank_chunks(question, search_results)
            rerank_time = time.time() - rerank_start
            # Step 4: Select top chunks and build context
            context_chunks = reranked_chunks[:self.final_top_k]
            context_text = self._build_context(context_chunks)
            logger.info("Built context from top %d chunks", len(context_chunks))
            # Step 5: Generate answer
            generation_start = time.time()
            answer = self.llm_system.answer_question(question, context_text)
            generation_time = time.time() - generation_start
            if not answer:
                return RAGResponse(
                    answer="I apologize, but I was unable to generate an answer to your question.",
                    confidence_score=0.0,
                    retrieval_time=retrieval_time,
                    generation_time=generation_time,
                    rerank_time=rerank_time,
                    processing_time=time.time() - start_time,
                    success=False,
                    error_message="LLM failed to generate an answer")
            logger.info("Generated answer in %.2fs", generation_time)
            # Step 6: Extract citations (from the chunks actually used)
            citations = self._extract_citations(context_chunks)
            logger.info("Extracted %d citations", len(citations))
            # Step 7: Calculate confidence score
            confidence_score = self._calculate_confidence_score(search_results, answer)
            logger.info("Calculated confidence score: %.2f", confidence_score)
            total_time = time.time() - start_time
            response = RAGResponse(
                answer=answer,
                confidence_score=confidence_score,
                citations=citations,
                context_chunks=[result.chunk for result in context_chunks],
                processing_time=total_time,
                retrieval_time=retrieval_time,
                generation_time=generation_time,
                rerank_time=rerank_time,
                total_chunks_retrieved=len(search_results),
                total_chunks_reranked=len(reranked_chunks),
                model_used=self.llm_system.default_model,
                success=True)
            logger.info("Question answered successfully in %.2fs", total_time)
            return response
        except Exception as e:
            error_msg = f"RAG processing failed: {str(e)}"
            # logger.exception preserves the traceback, unlike a bare error log.
            logger.exception(error_msg)
            return RAGResponse(
                answer="I apologize, but I encountered an error while processing your question. Please try again.",
                confidence_score=0.0,
                processing_time=time.time() - start_time,
                success=False,
                error_message=error_msg
            )

    def get_relevant_context(self, question: str, k: int = 5,
                             filters: Optional[Dict[str, Any]] = None) -> List[DocumentChunk]:
        """Return the top-k reranked chunks for a question (no generation).

        Args:
            question: Natural-language query.
            k: Maximum number of chunks to return.
            filters: Optional metadata filters for the vector store.

        Returns:
            Up to k DocumentChunk objects, best first; [] on any failure.
        """
        try:
            query_embedding = self.embedding_system.generate_query_embedding(question)
            if not query_embedding:
                return []
            # Over-fetch (2x) so reranking has room to reorder, capped at rerank_top_k.
            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                k=min(k * 2, self.rerank_top_k),
                filters=filters)
            if not search_results:
                return []
            # Rerank and return top k
            reranked_chunks = self._rerank_chunks(question, search_results)
            return [result.chunk for result in reranked_chunks[:k]]
        except Exception as e:
            logger.exception("Failed to get relevant context: %s", e)
            return []

    def _rerank_chunks(self, question: str, search_results: List[SearchResult]) -> List[SearchResult]:
        """Rerank retrieved chunks against the question.

        Falls back to the original similarity order whenever reranking fails
        or produces no usable mapping, so retrieval never degrades to an
        empty context because of the reranker.

        Args:
            question: The user's question.
            search_results: Candidates from the vector store.

        Returns:
            Search results in reranked order, each annotated with
            `rerank_score`; the input order on any fallback.
        """
        try:
            if len(search_results) <= 1:
                return search_results
            # Extract documents for reranking
            documents = [result.chunk.content for result in search_results]
            # Perform reranking
            rerank_results = self.embedding_system.rerank_results(
                query=question,
                documents=documents,
                top_k=len(documents)
            )
            # Map rerank results back to search results
            reranked_search_results = []
            for rerank_result in rerank_results:
                original_index = rerank_result.index
                # Guard against out-of-range indices from the reranker.
                if 0 <= original_index < len(search_results):
                    search_result = search_results[original_index]
                    search_result.rerank_score = rerank_result.score
                    reranked_search_results.append(search_result)
            if not reranked_search_results:
                # Every index was unusable; keep the original order rather
                # than silently dropping all candidates.
                logger.warning("Reranker produced no usable results; using original order")
                return search_results
            logger.debug("Reranked %d chunks", len(search_results))
            return reranked_search_results
        except Exception as e:
            logger.warning("Reranking failed, using original order: %s", e)
            return search_results

    def _build_context(self, search_results: List[SearchResult]) -> str:
        """Build the prompt context text from search results.

        Each chunk is prefixed with a numbered citation header. Chunks are
        appended until max_context_length characters would be exceeded; the
        final chunk may be truncated to fit the remaining space.

        Args:
            search_results: Reranked search results, best first.

        Returns:
            Formatted context text.
        """
        context_parts = []
        current_length = 0
        for i, result in enumerate(search_results):
            chunk = result.chunk
            # Create context entry with citation info
            citation_info = self._format_citation_info(chunk)
            content = f"[Source {i+1}: {citation_info}]\n{chunk.content}\n"
            # Check if adding this chunk would exceed max context length
            if current_length + len(content) > self.max_context_length:
                # Try to fit a truncated version; the 20-char margin covers
                # the ellipsis and trailing newline.
                remaining_space = self.max_context_length - current_length - len(f"[Source {i+1}: {citation_info}]\n") - 20
                if remaining_space > 100:  # Only add if we have reasonable space
                    truncated_content = chunk.content[:remaining_space] + "..."
                    content = f"[Source {i+1}: {citation_info}]\n{truncated_content}\n"
                    context_parts.append(content)
                break
            context_parts.append(content)
            current_length += len(content)
        return "\n".join(context_parts)

    def _format_citation_info(self, chunk: DocumentChunk) -> str:
        """Format a human-readable source label for one chunk.

        Args:
            chunk: Document chunk.

        Returns:
            Comma-separated citation string (e.g. "Doc: x, Page 3"), or
            "Unknown source" when no metadata is set.
        """
        parts = []
        metadata = chunk.metadata
        # getattr(..., None) everywhere so sparse metadata never raises.
        # Add document ID or filename if available
        if getattr(metadata, 'document_id', None):
            parts.append(f"Doc: {metadata.document_id}")
        # Add page number for PDFs
        if getattr(metadata, 'page_number', None):
            parts.append(f"Page {metadata.page_number}")
        # Add worksheet info for Excel
        if getattr(metadata, 'worksheet_name', None):
            parts.append(f"Sheet: {metadata.worksheet_name}")
        if getattr(metadata, 'cell_range', None):
            parts.append(f"Range: {metadata.cell_range}")
        # Add section title if available
        if getattr(metadata, 'section_title', None):
            parts.append(f"Section: {metadata.section_title}")
        return ", ".join(parts) if parts else "Unknown source"

    def _extract_citations(self, search_results: List[SearchResult]) -> List[Citation]:
        """Build Citation records for the chunks used in the answer.

        Args:
            search_results: The context chunks actually used.

        Returns:
            One Citation per result, with the snippet capped at 200 chars.
        """
        citations = []
        for result in search_results:
            chunk = result.chunk
            metadata = chunk.metadata
            snippet = chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content
            # getattr with defaults keeps this safe for sparse metadata.
            citation = Citation(
                source_file=getattr(metadata, 'document_id', 'Unknown'),
                page_number=getattr(metadata, 'page_number', None),
                worksheet_name=getattr(metadata, 'worksheet_name', None),
                cell_range=getattr(metadata, 'cell_range', None),
                section_title=getattr(metadata, 'section_title', None),
                text_snippet=snippet,
                confidence=result.similarity_score,  # pre-rerank similarity
                chunk_id=getattr(metadata, 'chunk_id', '')
            )
            citations.append(citation)
        return citations

    def _calculate_confidence_score(self, search_results: List[SearchResult], answer: str) -> float:
        """Heuristic confidence in [0, 1] for a generated answer.

        Combines mean retrieval similarity (60%), source-count coverage
        relative to final_top_k (20%), and answer length (20%; very short
        answers are treated as less reliable).

        Args:
            search_results: All retrieved candidates.
            answer: The generated answer text.

        Returns:
            Confidence score clamped to [0, 1]; 0.0 for no results.
        """
        if not search_results:
            return 0.0
        # Base confidence on similarity scores
        similarity_scores = [result.similarity_score for result in search_results]
        avg_similarity = sum(similarity_scores) / len(similarity_scores)
        # Adjust based on number of sources
        source_factor = min(len(search_results) / self.final_top_k, 1.0)
        # Adjust based on answer length (very short answers might be less reliable)
        length_factor = min(len(answer) / 100, 1.0) if answer else 0.0
        # Combine factors
        confidence = (avg_similarity * 0.6 + source_factor * 0.2 + length_factor * 0.2)
        return min(max(confidence, 0.0), 1.0)  # Clamp to [0, 1]

    def health_check(self) -> Dict[str, bool]:
        """Report liveness of each pipeline component."""
        return {
            "vector_store": self.vector_store.health_check(),
            "llm_system": self.llm_system.client.health_check(),
            "embedding_system": True  # Silicon Flow doesn't have a direct health check
        }

    def get_stats(self) -> Dict[str, Any]:
        """Return vector-store, embedding-cache, and config statistics.

        Returns:
            Stats dict, or {"error": ...} if collection info was unavailable.
        """
        try:
            vector_stats = self.vector_store.get_collection_info()
            embedding_stats = self.embedding_system.get_cache_stats()
            return {
                "vector_store": vector_stats.__dict__ if vector_stats else {},
                "embedding_cache": embedding_stats,
                "config": {
                    "max_context_chunks": self.max_context_chunks,
                    "similarity_threshold": self.similarity_threshold,
                    "rerank_top_k": self.rerank_top_k,
                    "final_top_k": self.final_top_k
                }
            }
        except Exception as e:
            logger.exception("Failed to get RAG stats: %s", e)
            return {"error": str(e)}
if __name__ == "__main__":
    from src.utilites import validate_api_keys

    # Sanity-check credentials first; the engine is still constructed either
    # way, matching best-effort demo behavior.
    if validate_api_keys()['valid']:
        logger.info("All required API keys are present.")
    else:
        logger.error("Missing required API keys. Please set them in the environment variables.")

    ## Example usage
    config = load_yaml_config("src/config.yaml")
    rag_engine = RAGEngine(config)
|