from typing import Dict, Any, List, Optional
from ...core.base import LatticeComponent, LatticeError
from pydantic import BaseModel
import anthropic
import asyncio
from datetime import datetime, timezone


class RAGConfig(BaseModel):
    """Configuration for the RAG engine (retriever, model, and chunking)."""

    retriever_type: str = "hybrid"      # one of: "hybrid", "dense", "sparse"
    model_name: str = "claude-3-opus"
    temperature: float = 0.7
    max_tokens: int = 1000
    top_k: int = 4                      # number of chunks to retrieve
    rerank_top_k: Optional[int] = None  # when set, enables cross-encoder reranking
    chunk_size: int = 500
    chunk_overlap: int = 50


class RetrievalResult(BaseModel):
    """A single retrieved document chunk with its relevance score."""

    content: str
    score: float
    metadata: Dict[str, Any]
    source_id: str
    chunk_id: int


class RAGResponse(BaseModel):
    """Generated answer plus the retrieved sources and usage accounting."""

    answer: str
    sources: List[RetrievalResult]
    metadata: Dict[str, Any]
    usage: Dict[str, Any]
    timestamp: datetime


class RAGEngine(LatticeComponent):
    """Main RAG engine: retrieve relevant chunks, optionally rerank, then
    generate an answer with an Anthropic model grounded in those chunks."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self.rag_config = RAGConfig(**(config or {}))
        # Async client is required: generate() awaits messages.create().
        # (The sync anthropic.Anthropic client is not awaitable.)
        self.client = anthropic.AsyncAnthropic()
        # Populated by initialize(); reranker stays None unless enabled.
        self.retriever = None
        self.reranker = None

    async def initialize(self) -> None:
        """Initialize the retriever (and reranker, if configured).

        Raises:
            LatticeError: if any sub-component fails to initialize.
        """
        try:
            # NOTE(review): HybridRetriever / DenseRetriever / SparseRetriever /
            # CrossEncoderReranker are not defined in this module — presumably
            # imported elsewhere in the package; verify against the full file.
            if self.rag_config.retriever_type == "hybrid":
                self.retriever = HybridRetriever(self.config)
            elif self.rag_config.retriever_type == "dense":
                self.retriever = DenseRetriever(self.config)
            else:
                self.retriever = SparseRetriever(self.config)

            # Reranking is opt-in: only build the reranker when configured.
            if self.rag_config.rerank_top_k:
                self.reranker = CrossEncoderReranker(self.config)

            await self.retriever.initialize()
            self._initialized = True
        except Exception as e:
            raise LatticeError(f"Failed to initialize RAG engine: {str(e)}") from e

    async def validate_config(self) -> bool:
        """Return True if the raw config parses into a valid RAGConfig."""
        try:
            RAGConfig(**(self.config or {}))
            return True
        except Exception as e:
            self.logger.error(f"Invalid configuration: {str(e)}")
            return False

    async def generate(
        self,
        query: str,
        context: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> RAGResponse:
        """Answer *query* using retrieval-augmented generation.

        Pipeline: retrieve top_k chunks -> optional rerank -> build a
        grounded prompt -> call the model -> package a RAGResponse.

        Raises:
            LatticeError: if retrieval or generation fails.
        """
        self.ensure_initialized()

        try:
            # Retrieve relevant chunks
            retrieved_chunks = await self.retriever.retrieve(
                query, top_k=self.rag_config.top_k
            )

            # Rerank if enabled
            if self.rag_config.rerank_top_k:
                retrieved_chunks = await self.reranker.rerank(
                    query, retrieved_chunks, top_k=self.rag_config.rerank_top_k
                )

            # Construct prompt with retrieved context
            prompt = self._construct_prompt(query, retrieved_chunks)

            # Generate response (async client; create() returns an awaitable)
            response = await self.client.messages.create(
                model=self.rag_config.model_name,
                max_tokens=self.rag_config.max_tokens,
                temperature=self.rag_config.temperature,
                messages=[{"role": "user", "content": prompt}]
            )

            # Anthropic Messages API reports usage as input_tokens /
            # output_tokens; map them onto the prompt/completion keys this
            # response schema has always exposed (backward compatible).
            input_tokens = response.usage.input_tokens
            output_tokens = response.usage.output_tokens

            return RAGResponse(
                answer=response.content[0].text,
                sources=[
                    RetrievalResult(
                        content=chunk.content,
                        score=chunk.score,
                        metadata=chunk.metadata,
                        source_id=chunk.source_id,
                        chunk_id=chunk.chunk_id
                    )
                    for chunk in retrieved_chunks
                ],
                metadata={
                    "model": self.rag_config.model_name,
                    "retriever": self.rag_config.retriever_type,
                    "temperature": self.rag_config.temperature
                },
                usage={
                    "prompt_tokens": input_tokens,
                    "completion_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens
                },
                # Timezone-aware timestamp; naive datetime.now() is ambiguous.
                timestamp=datetime.now(timezone.utc)
            )
        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}")
            raise LatticeError(f"RAG generation failed: {str(e)}") from e

    def _construct_prompt(self, query: str, chunks: List[RetrievalResult]) -> str:
        """Construct prompt with retrieved context"""
        context = "\n\n".join([
            f"Context {i+1}:\n{chunk.content}"
            for i, chunk in enumerate(chunks)
        ])

        return f"""Use the following retrieved contexts to answer the question. Include only information from the provided contexts.

{context}

Question: {query}

Answer:"""