"""Lattice RAG engine (core/rag/engine.py).

Retrieval-augmented generation: retriever selection, optional reranking,
and answer generation via the Anthropic API.
"""
from typing import Dict, Any, List, Optional
from ...core.base import LatticeComponent, LatticeError
from pydantic import BaseModel
import anthropic
import asyncio
from datetime import datetime
class RAGConfig(BaseModel):
    """Configuration options for the RAG engine."""

    # Retrieval strategy: "hybrid" or "dense"; any other value falls back to sparse.
    retriever_type: str = "hybrid"
    # Anthropic model identifier used for answer generation.
    model_name: str = "claude-3-opus"
    # Sampling temperature forwarded to the model.
    temperature: float = 0.7
    # Maximum number of tokens in the generated answer.
    max_tokens: int = 1000
    # Number of chunks requested from the retriever.
    top_k: int = 4
    # When set (truthy), a reranking pass trims results down to this many chunks.
    rerank_top_k: Optional[int] = None
    # Chunking parameters — units (characters vs. tokens) depend on the
    # chunker implementation, which is not visible here; TODO confirm.
    chunk_size: int = 500
    chunk_overlap: int = 50
class RetrievalResult(BaseModel):
    """A single retrieved document chunk plus its relevance score."""

    content: str               # raw text of the chunk
    score: float               # retriever/reranker relevance score (scale depends on retriever)
    metadata: Dict[str, Any]   # arbitrary per-chunk metadata from the index
    source_id: str             # identifier of the originating document
    chunk_id: int              # position of this chunk within the source document
class RAGResponse(BaseModel):
    """Full result of one RAG generation call."""

    answer: str                     # model-generated answer text
    sources: List[RetrievalResult]  # chunks that grounded the answer
    metadata: Dict[str, Any]        # generation settings (model, retriever, temperature)
    usage: Dict[str, Any]           # token accounting for the model call
    timestamp: datetime             # when the response was produced (naive local time)
class RAGEngine(LatticeComponent):
    """Main RAG engine: retrieve relevant chunks, optionally rerank them,
    then generate a grounded answer with an Anthropic model.

    NOTE(review): HybridRetriever, DenseRetriever, SparseRetriever and
    CrossEncoderReranker are referenced but not imported in this file —
    confirm they are brought into scope elsewhere (e.g. package __init__).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self.rag_config = RAGConfig(**(config or {}))
        # BUGFIX: every call site awaits `self.client.messages.create`, so the
        # client must be the async one. The previous `anthropic.Anthropic()`
        # is synchronous — awaiting its plain Message result raises TypeError.
        self.client = anthropic.AsyncAnthropic()

    async def initialize(self) -> None:
        """Construct the retriever (and optional reranker) and mark the engine ready.

        Raises:
            LatticeError: if any sub-component fails to initialize.
        """
        try:
            # Select the retriever implementation; anything other than
            # "hybrid"/"dense" falls back to sparse retrieval.
            if self.rag_config.retriever_type == "hybrid":
                self.retriever = HybridRetriever(self.config)
            elif self.rag_config.retriever_type == "dense":
                self.retriever = DenseRetriever(self.config)
            else:
                self.retriever = SparseRetriever(self.config)
            # The reranker only exists when reranking is enabled; `generate`
            # guards every access with the same `rerank_top_k` check.
            if self.rag_config.rerank_top_k:
                self.reranker = CrossEncoderReranker(self.config)
            await self.retriever.initialize()
            self._initialized = True
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise LatticeError(f"Failed to initialize RAG engine: {str(e)}") from e

    async def validate_config(self) -> bool:
        """Return True when `self.config` parses into a valid RAGConfig."""
        try:
            RAGConfig(**(self.config or {}))
            return True
        except Exception as e:
            self.logger.error(f"Invalid configuration: {str(e)}")
            return False

    async def generate(
        self,
        query: str,
        context: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> RAGResponse:
        """Answer `query` using retrieval-augmented generation.

        Args:
            query: The user question.
            context: Unused here; accepted for interface compatibility.
            **kwargs: Unused here; accepted for forward compatibility.

        Returns:
            RAGResponse with the answer, source chunks, settings and usage.

        Raises:
            LatticeError: on any retrieval or generation failure.
        """
        self.ensure_initialized()
        try:
            # Retrieve candidate chunks for the query.
            retrieved_chunks = await self.retriever.retrieve(
                query,
                top_k=self.rag_config.top_k
            )
            # Optional reranking pass narrows the candidate set.
            if self.rag_config.rerank_top_k:
                retrieved_chunks = await self.reranker.rerank(
                    query,
                    retrieved_chunks,
                    top_k=self.rag_config.rerank_top_k
                )
            prompt = self._construct_prompt(query, retrieved_chunks)
            response = await self.client.messages.create(
                model=self.rag_config.model_name,
                max_tokens=self.rag_config.max_tokens,
                temperature=self.rag_config.temperature,
                messages=[{"role": "user", "content": prompt}]
            )
            return RAGResponse(
                answer=response.content[0].text,
                sources=[
                    RetrievalResult(
                        content=chunk.content,
                        score=chunk.score,
                        metadata=chunk.metadata,
                        source_id=chunk.source_id,
                        chunk_id=chunk.chunk_id
                    )
                    for chunk in retrieved_chunks
                ],
                metadata={
                    "model": self.rag_config.model_name,
                    "retriever": self.rag_config.retriever_type,
                    "temperature": self.rag_config.temperature
                },
                # BUGFIX: Anthropic's Usage object exposes `input_tokens` /
                # `output_tokens` (the previous OpenAI-style attribute reads
                # `prompt_tokens`/`completion_tokens`/`total_tokens` would
                # raise AttributeError). The outgoing dict keeps the original
                # keys so downstream consumers of RAGResponse.usage still work.
                usage={
                    "prompt_tokens": response.usage.input_tokens,
                    "completion_tokens": response.usage.output_tokens,
                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens
                },
                # NOTE(review): naive local time; consider datetime.now(timezone.utc)
                # if consumers compare timestamps across hosts — verify first.
                timestamp=datetime.now()
            )
        except Exception as e:
            self.logger.exception(f"Error generating response: {str(e)}")
            raise LatticeError(f"RAG generation failed: {str(e)}") from e

    def _construct_prompt(self, query: str, chunks: List[RetrievalResult]) -> str:
        """Assemble the grounded prompt: numbered contexts, then the question."""
        context = "\n\n".join([
            f"Context {i+1}:\n{chunk.content}"
            for i, chunk in enumerate(chunks)
        ])
        return f"""Use the following retrieved contexts to answer the question.
Include only information from the provided contexts.
{context}
Question: {query}
Answer:"""