File size: 3,613 Bytes
aca8ab4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
"""
RAG retrieval functions with context formatting.
"""
import logging
from typing import List, Optional, Dict, Any
from rag.vector_store import VectorStore
from rag.embeddings import EmbeddingGenerator
from utils.langfuse_client import observe
# NOTE(review): logging.basicConfig() at import time configures the *root*
# logger as a side effect of importing this module, which can override the
# host application's logging setup — consider moving configuration to the
# application entry point. TODO confirm no other module depends on this.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
class RAGRetriever:
    """RAG retrieval with semantic search and context formatting."""

    def __init__(
        self,
        vector_store: VectorStore,
        embedding_generator: EmbeddingGenerator,
        top_k: int = 5
    ):
        """
        Initialize RAG retriever.

        Args:
            vector_store: Vector store instance
            embedding_generator: Embedding generator instance
            top_k: Default number of chunks to retrieve per query
        """
        self.vector_store = vector_store
        self.embedding_generator = embedding_generator
        self.top_k = top_k

    @observe(name="rag_retrieve", as_type="span")
    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        paper_ids: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Retrieve relevant chunks for a query.

        Args:
            query: Search query
            top_k: Number of chunks to retrieve (overrides the instance
                default). Only None falls back to the default; an explicit
                0 is honored.
            paper_ids: Optional filter restricting results to these paper IDs

        Returns:
            Dictionary with keys "query", "chunks" (list of dicts with
            chunk_id / content / metadata / distance) and "chunk_ids".
        """
        # BUG FIX: the original `top_k or self.top_k` treated an explicit
        # top_k=0 as falsy and silently used the default; compare to None.
        k = self.top_k if top_k is None else top_k

        # Embed the query, then run nearest-neighbor search in the store.
        query_embedding = self.embedding_generator.generate_embedding(query)
        results = self.vector_store.search(
            query_embedding=query_embedding,
            top_k=k,
            paper_ids=paper_ids
        )

        # Flatten the store's batched response (index [0] = the single query)
        # into one dict per chunk; "distances" is optional in the response.
        chunks = []
        for i, chunk_id in enumerate(results["ids"][0]):
            chunks.append({
                "chunk_id": chunk_id,
                "content": results["documents"][0][i],
                "metadata": results["metadatas"][0][i],
                "distance": results["distances"][0][i] if "distances" in results else None
            })

        # Lazy %-style args avoid string formatting when INFO is disabled;
        # %.50s truncates the query like the original query[:50].
        logger.info("Retrieved %d chunks for query: %.50s...", len(chunks), query)

        return {
            "query": query,
            "chunks": chunks,
            "chunk_ids": [c["chunk_id"] for c in chunks]
        }

    def format_context(
        self,
        chunks: List[Dict[str, Any]],
        include_metadata: bool = True
    ) -> str:
        """
        Format retrieved chunks into a context string.

        Args:
            chunks: List of chunk dictionaries (as produced by retrieve())
            include_metadata: Whether to prefix each chunk with a short header

        Returns:
            Formatted context string; chunks separated by blank lines
        """
        context_parts = []
        for i, chunk in enumerate(chunks, 1):
            content = chunk["content"]
            if not include_metadata:
                context_parts.append(content)
                continue

            metadata = chunk["metadata"]
            # Concise header to keep prompt token usage low.
            # BUG FIX: join the optional fields so a section without a page
            # number no longer leaves a dangling " | " separator.
            details = []
            if metadata.get('section'):
                details.append(f"Section: {metadata['section']}")
            if metadata.get('page_number'):
                details.append(f"Page {metadata['page_number']}")
            header = f"[Chunk {i}] {metadata.get('title', 'Unknown')}\n"
            header += " | ".join(details)
            header += "\n" + "=" * 40 + "\n"
            context_parts.append(header + content)
        return "\n\n".join(context_parts)
|