|
|
from typing import Dict, Any, List, Optional |
|
|
from ...core.base import LatticeComponent, LatticeError |
|
|
from pydantic import BaseModel |
|
|
import anthropic |
|
|
import asyncio |
|
|
from datetime import datetime |
|
|
|
|
|
class RAGConfig(BaseModel):
    """Configuration for the RAG engine.

    Validated by pydantic at construction; invalid values raise
    ``ValidationError`` (surfaced by ``RAGEngine.validate_config``).
    """

    retriever_type: str = "hybrid"  # "hybrid" or "dense"; any other value selects the sparse retriever
    model_name: str = "claude-3-opus"  # Anthropic model id passed to messages.create
    temperature: float = 0.7  # sampling temperature forwarded to the model
    max_tokens: int = 1000  # generation cap forwarded to the model
    top_k: int = 4  # number of chunks requested from the retriever
    rerank_top_k: Optional[int] = None  # when set, enables cross-encoder reranking truncated to this depth
    chunk_size: int = 500  # chunk size for ingestion (unit not visible here — presumably tokens or chars; confirm against the chunker)
    chunk_overlap: int = 50  # overlap between consecutive chunks (same unit as chunk_size)
|
|
|
|
|
class RetrievalResult(BaseModel):
    """A single retrieved document chunk plus its relevance score."""

    content: str  # chunk text; interpolated verbatim into the generation prompt
    score: float  # relevance score from the retriever/reranker (presumably higher = better — confirm)
    metadata: Dict[str, Any]  # arbitrary per-chunk metadata propagated from the retriever
    source_id: str  # identifier of the originating document/source
    chunk_id: int  # chunk identifier within its source (presumably a sequential index — confirm)
|
|
|
|
|
class RAGResponse(BaseModel):
    """Result of one RAG generation call."""

    answer: str  # model-generated answer text
    sources: List[RetrievalResult]  # chunks (post-rerank, if enabled) that grounded the answer
    metadata: Dict[str, Any]  # generation settings used (model, retriever type, temperature)
    usage: Dict[str, Any]  # token accounting: prompt_tokens / completion_tokens / total_tokens
    timestamp: datetime  # local, naive timestamp of response creation
|
|
|
|
|
class RAGEngine(LatticeComponent):
    """Main RAG engine.

    Retrieves relevant chunks for a query, optionally reranks them with a
    cross-encoder, and generates an answer with an Anthropic model. Raw
    configuration is validated against :class:`RAGConfig`.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Create the engine.

        Args:
            config: Optional raw configuration dict, validated via RAGConfig.

        Raises:
            pydantic.ValidationError: if ``config`` contains invalid values.
        """
        super().__init__(config)
        self.rag_config = RAGConfig(**(config or {}))
        # BUG FIX: generate() awaits messages.create(), so the async client is
        # required. The sync anthropic.Anthropic() client returns a plain
        # Message object, and awaiting it raises a TypeError at runtime.
        self.client = anthropic.AsyncAnthropic()
        # Populated in initialize(); declared here so attribute access before
        # initialization fails predictably instead of with AttributeError on
        # a missing attribute.
        self.retriever = None
        self.reranker = None

    async def initialize(self) -> None:
        """Build the retriever (and reranker, when configured) and mark ready.

        Raises:
            LatticeError: wrapping any failure during setup.
        """
        try:
            # Select the retriever implementation by configured type; anything
            # other than "hybrid"/"dense" falls back to sparse retrieval.
            if self.rag_config.retriever_type == "hybrid":
                self.retriever = HybridRetriever(self.config)
            elif self.rag_config.retriever_type == "dense":
                self.retriever = DenseRetriever(self.config)
            else:
                self.retriever = SparseRetriever(self.config)

            # Reranking is opt-in: only construct the cross-encoder when a
            # rerank depth was configured (same truthiness check generate() uses).
            if self.rag_config.rerank_top_k:
                self.reranker = CrossEncoderReranker(self.config)

            await self.retriever.initialize()
            self._initialized = True

        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise LatticeError(f"Failed to initialize RAG engine: {str(e)}") from e

    async def validate_config(self) -> bool:
        """Return True if the raw config parses as a valid RAGConfig.

        Validation errors are logged rather than raised.
        """
        try:
            RAGConfig(**(self.config or {}))
            return True
        except Exception as e:
            self.logger.error(f"Invalid configuration: {str(e)}")
            return False

    async def generate(
        self,
        query: str,
        context: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> RAGResponse:
        """Generate a grounded answer for ``query`` using RAG.

        Args:
            query: The user question.
            context: Reserved for caller-supplied context (currently unused
                by this method — kept for interface compatibility).
            **kwargs: Reserved for future generation options (currently unused).

        Returns:
            RAGResponse with the answer, the chunks used, and token usage.

        Raises:
            LatticeError: wrapping any retrieval or generation failure.
        """
        self.ensure_initialized()

        try:
            # 1) Retrieve candidate chunks.
            retrieved_chunks = await self.retriever.retrieve(
                query,
                top_k=self.rag_config.top_k
            )

            # 2) Optionally rerank and truncate to the configured depth.
            if self.rag_config.rerank_top_k:
                retrieved_chunks = await self.reranker.rerank(
                    query,
                    retrieved_chunks,
                    top_k=self.rag_config.rerank_top_k
                )

            # 3) Build the grounded prompt and generate.
            prompt = self._construct_prompt(query, retrieved_chunks)

            response = await self.client.messages.create(
                model=self.rag_config.model_name,
                max_tokens=self.rag_config.max_tokens,
                temperature=self.rag_config.temperature,
                messages=[{"role": "user", "content": prompt}]
            )

            # BUG FIX: the Anthropic SDK Usage object exposes input_tokens /
            # output_tokens (no prompt_tokens/completion_tokens/total_tokens),
            # so the previous attribute reads raised AttributeError on every
            # call. Map to the existing dict keys to keep the response schema
            # backward-compatible for callers.
            prompt_tokens = response.usage.input_tokens
            completion_tokens = response.usage.output_tokens

            return RAGResponse(
                answer=response.content[0].text,
                sources=[
                    RetrievalResult(
                        content=chunk.content,
                        score=chunk.score,
                        metadata=chunk.metadata,
                        source_id=chunk.source_id,
                        chunk_id=chunk.chunk_id
                    )
                    for chunk in retrieved_chunks
                ],
                metadata={
                    "model": self.rag_config.model_name,
                    "retriever": self.rag_config.retriever_type,
                    "temperature": self.rag_config.temperature
                },
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                },
                timestamp=datetime.now()
            )

        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise LatticeError(f"RAG generation failed: {str(e)}") from e

    def _construct_prompt(self, query: str, chunks: List[RetrievalResult]) -> str:
        """Construct the generation prompt from the query and retrieved chunks.

        Each chunk is rendered as a numbered "Context N:" section ahead of
        the question.
        """
        context = "\n\n".join([
            f"Context {i+1}:\n{chunk.content}"
            for i, chunk in enumerate(chunks)
        ])

        return f"""Use the following retrieved contexts to answer the question.
Include only information from the provided contexts.

{context}

Question: {query}

Answer:"""