File size: 5,133 Bytes
d770268 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from typing import Dict, Any, List, Optional
from ...core.base import LatticeComponent, LatticeError
from pydantic import BaseModel
import anthropic
import asyncio
from datetime import datetime
class RAGConfig(BaseModel):
    """RAG configuration.

    Consumed by RAGEngine: retrieval settings drive initialize()/generate(),
    model settings are forwarded to the Anthropic messages API.
    """
    # Retrieval backend selector: "hybrid", "dense", or anything else
    # falls back to sparse (see RAGEngine.initialize).
    retriever_type: str = "hybrid"
    # Model id passed to messages.create.
    # NOTE(review): bare "claude-3-opus" may need a dated suffix
    # (e.g. "claude-3-opus-20240229") to be a valid API model id — confirm.
    model_name: str = "claude-3-opus"
    temperature: float = 0.7  # sampling temperature forwarded to the API
    max_tokens: int = 1000  # per-response generation cap forwarded to the API
    top_k: int = 4  # number of chunks requested from the retriever
    rerank_top_k: Optional[int] = None  # when set, enables the cross-encoder reranker
    chunk_size: int = 500  # presumably used by the indexing/chunking side — not read in this file
    chunk_overlap: int = 50  # presumably used by the indexing/chunking side — not read in this file
class RetrievalResult(BaseModel):
    """A single retrieved document chunk with its relevance score."""
    content: str  # chunk text; interpolated into the generation prompt
    score: float  # relevance score assigned by the retriever/reranker
    metadata: Dict[str, Any]  # arbitrary per-chunk metadata carried from the index
    source_id: str  # identifier of the originating document
    chunk_id: int  # index of this chunk within its source document
class RAGResponse(BaseModel):
    """Final answer plus provenance, returned by RAGEngine.generate()."""
    answer: str  # model-generated answer text
    sources: List[RetrievalResult]  # the chunks that grounded the answer
    metadata: Dict[str, Any]  # generation settings (model, retriever, temperature)
    usage: Dict[str, Any]  # token accounting derived from the API response
    timestamp: datetime  # when the response was produced (naive local time)
class RAGEngine(LatticeComponent):
    """Main RAG engine: retrieve -> (optionally rerank) -> generate.

    Configuration is parsed/validated through RAGConfig. The retrieval
    backends and reranker are project components resolved in initialize().
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self.rag_config = RAGConfig(**(config or {}))
        # BUG FIX: generate() awaits messages.create(), so the async client
        # is required — the sync Anthropic client's create() returns a
        # Message, not an awaitable, and the original `await` would fail.
        self.client = anthropic.AsyncAnthropic()

    async def initialize(self) -> None:
        """Initialize the retriever (and reranker, when enabled).

        Raises:
            LatticeError: if any component fails to initialize.
        """
        try:
            # Select the retrieval backend; anything other than
            # "hybrid"/"dense" falls back to sparse retrieval.
            if self.rag_config.retriever_type == "hybrid":
                self.retriever = HybridRetriever(self.config)
            elif self.rag_config.retriever_type == "dense":
                self.retriever = DenseRetriever(self.config)
            else:
                self.retriever = SparseRetriever(self.config)
            # Reranking is opt-in via rerank_top_k. Always define the
            # attribute so later access can never raise AttributeError
            # (the original left it unset when reranking was disabled).
            self.reranker = (
                CrossEncoderReranker(self.config)
                if self.rag_config.rerank_top_k
                else None
            )
            await self.retriever.initialize()
            self._initialized = True
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise LatticeError(f"Failed to initialize RAG engine: {str(e)}") from e

    async def validate_config(self) -> bool:
        """Return True iff self.config parses into a valid RAGConfig."""
        try:
            RAGConfig(**(self.config or {}))
            return True
        except Exception as e:
            self.logger.error(f"Invalid configuration: {str(e)}")
            return False

    async def generate(
        self,
        query: str,
        context: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> RAGResponse:
        """Answer `query` using retrieval-augmented generation.

        Args:
            query: The user question.
            context: Currently unused; kept for interface compatibility.
            **kwargs: Currently unused; kept for interface compatibility.

        Returns:
            RAGResponse with the answer, source chunks, and usage stats.

        Raises:
            LatticeError: on any retrieval or generation failure.
        """
        self.ensure_initialized()
        try:
            # Retrieve candidate chunks for the query.
            retrieved_chunks = await self.retriever.retrieve(
                query,
                top_k=self.rag_config.top_k,
            )
            # Optionally rerank down to a tighter top-k.
            if self.rag_config.rerank_top_k:
                retrieved_chunks = await self.reranker.rerank(
                    query,
                    retrieved_chunks,
                    top_k=self.rag_config.rerank_top_k,
                )
            # Build the grounded prompt and generate the answer.
            prompt = self._construct_prompt(query, retrieved_chunks)
            response = await self.client.messages.create(
                model=self.rag_config.model_name,
                max_tokens=self.rag_config.max_tokens,
                temperature=self.rag_config.temperature,
                messages=[{"role": "user", "content": prompt}],
            )
            # BUG FIX: the Anthropic API reports usage as input_tokens /
            # output_tokens — the original read prompt_tokens /
            # completion_tokens / total_tokens, which do not exist and
            # would raise AttributeError. The old dict keys are kept so
            # existing consumers of RAGResponse.usage keep working.
            input_tokens = response.usage.input_tokens
            output_tokens = response.usage.output_tokens
            return RAGResponse(
                answer=response.content[0].text,
                sources=[
                    RetrievalResult(
                        content=chunk.content,
                        score=chunk.score,
                        metadata=chunk.metadata,
                        source_id=chunk.source_id,
                        chunk_id=chunk.chunk_id,
                    )
                    for chunk in retrieved_chunks
                ],
                metadata={
                    "model": self.rag_config.model_name,
                    "retriever": self.rag_config.retriever_type,
                    "temperature": self.rag_config.temperature,
                },
                usage={
                    "prompt_tokens": input_tokens,
                    "completion_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                },
                # NOTE(review): naive local time — consider tz-aware UTC.
                timestamp=datetime.now(),
            )
        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}")
            raise LatticeError(f"RAG generation failed: {str(e)}") from e

    def _construct_prompt(self, query: str, chunks: List[RetrievalResult]) -> str:
        """Construct the generation prompt from the retrieved chunks."""
        context = "\n\n".join([
            f"Context {i+1}:\n{chunk.content}"
            for i, chunk in enumerate(chunks)
        ])
        return f"""Use the following retrieved contexts to answer the question.
Include only information from the provided contexts.
{context}
Question: {query}
Answer:"""