# RAG_backend/app/rag/retrieval.py
"""
Retrieval pipeline with confidence scoring and filtering.
"""
from typing import List, Dict, Any, Optional, Tuple
import logging
import re
from app.config import settings
from app.rag.embeddings import get_embedding_service
from app.rag.vectorstore import get_vector_store
from app.rag.intent import detect_intents, check_direct_match, get_intent_keywords
from app.models.schemas import RetrievalResult
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RetrievalService:
"""
Handles document retrieval with confidence scoring.
Implements threshold-based filtering for quality control.
"""
def __init__(
self,
top_k: int = settings.TOP_K,
similarity_threshold: float = settings.SIMILARITY_THRESHOLD
):
"""
Initialize the retrieval service.
Args:
top_k: Number of results to retrieve
similarity_threshold: Minimum similarity score to consider relevant
"""
self.top_k = top_k
self.similarity_threshold = similarity_threshold
self.embedding_service = get_embedding_service()
self.vector_store = get_vector_store()
def retrieve(
self,
query: str,
tenant_id: str, # CRITICAL: Multi-tenant isolation
kb_id: str,
user_id: str,
top_k: Optional[int] = None
) -> Tuple[List[RetrievalResult], float, bool]:
"""
Retrieve relevant documents for a query.
Args:
query: User's question
tenant_id: Tenant ID for multi-tenant isolation (CRITICAL)
kb_id: Knowledge base ID to search
user_id: User ID for filtering
top_k: Optional override for number of results
Returns:
Tuple of (results, average_confidence, has_relevant_results)
"""
k = top_k or self.top_k
# Generate query embedding
logger.info(f"Generating embedding for query: {query[:50]}...")
query_embedding = self.embedding_service.embed_query(query)
# Search vector store with filters - MUST include tenant_id for isolation
filter_dict = {
"tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation
"kb_id": kb_id,
"user_id": user_id
}
logger.info(f"Searching vector store with filters: {filter_dict}")
raw_results = self.vector_store.search(
query_embedding=query_embedding,
top_k=k,
filter_dict=filter_dict
)
if not raw_results:
logger.warning(f"No results found for query in kb_id={kb_id}")
return [], 0.0, False
# Convert to RetrievalResult objects
results = []
for r in raw_results:
results.append(RetrievalResult(
chunk_id=r['id'],
content=r['content'],
metadata=r['metadata'],
similarity_score=r['similarity_score']
))
        # HEAVY CONFIDENCE MODE: use the maximum similarity score from the top results
        # so that confidence reflects the best match found rather than being
        # dragged down by weaker results.
if results:
# Get top 3 results and use the maximum similarity score
# This gives maximum confidence if there's at least one strong match
top_results = results[:3]
max_score = max(r.similarity_score for r in top_results)
# If max score is good (>=0.4), use it directly
# Otherwise, use weighted average of top 3 to avoid over-inflating weak matches
if max_score >= 0.4:
avg_confidence = max_score
else:
# For weaker matches, use weighted average of top 3
scores = [r.similarity_score for r in top_results]
weights = [1.0, 0.7, 0.5][:len(scores)] # Aggressive weighting
weighted_sum = sum(s * w for s, w in zip(scores, weights))
total_weight = sum(weights[:len(scores)])
avg_confidence = weighted_sum / total_weight if total_weight > 0 else max_score
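                # Illustrative example (assumed numbers): scores [0.35, 0.30, 0.20]
                # give (0.35*1.0 + 0.30*0.7 + 0.20*0.5) / 2.2 = 0.66 / 2.2 = 0.30,
                # which stays below the 0.4 gate, as intended for weak matches.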
else:
avg_confidence = 0.0
# Filter results above threshold
filtered_results = [
r for r in results
if r.similarity_score >= self.similarity_threshold
]
# If no results pass threshold but we have results, use top results anyway
# This prevents over-filtering when threshold is too strict
if not filtered_results and results:
logger.warning(f"No results passed threshold {self.similarity_threshold}, using top {min(3, len(results))} results anyway")
filtered_results = results[:min(3, len(results))]
# Recalculate confidence with the fallback results
if filtered_results:
scores = [r.similarity_score for r in filtered_results]
avg_confidence = sum(scores) / len(scores) if scores else 0.0
# DIRECT MATCH GATE: Check if at least one chunk directly matches query intent
# For integration/API questions, this gate is stricter
has_direct_match = False
if filtered_results:
chunk_texts = [r.content for r in filtered_results]
intents = detect_intents(query)
intent_keywords = get_intent_keywords(intents)
# For integration/API questions, require direct match
if "integration" in intents or "api" in query.lower():
has_direct_match = check_direct_match(query, chunk_texts, intent_keywords)
logger.info(f"Direct match check (strict for integration): {has_direct_match} (intents: {intents})")
else:
# For other questions, be more lenient - just check if important words match
query_words = set(re.findall(r'\b\w+\b', query.lower()))
stop_words = {"the", "a", "an", "is", "are", "was", "were", "be", "been",
"to", "of", "and", "or", "but", "in", "on", "at", "for",
"with", "how", "what", "when", "where", "why", "do", "does"}
important_words = query_words - stop_words
                # Check whether at least one important word appears in a chunk.
                # Compare whole words to avoid substring false positives
                # (e.g. "cat" matching inside "category").
                for chunk in chunk_texts:
                    chunk_words = set(re.findall(r'\b\w+\b', chunk.lower()))
                    if important_words & chunk_words:
                        has_direct_match = True
                        break
logger.info(f"Direct match check (lenient): {has_direct_match} (intents: {intents})")
# Only consider relevant if we have filtered results AND (direct match OR high confidence)
# High confidence (>0.40) can bypass direct match requirement for non-integration questions
has_relevant = len(filtered_results) > 0 and (has_direct_match or avg_confidence > 0.40)
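        # Illustrative outcomes (assumed numbers): with filtered results present,
        # a direct match always counts as relevant; without one, confidence must
        # exceed 0.40 (e.g. 0.45 passes, 0.38 does not).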
logger.info(
f"Retrieved {len(results)} results, "
f"{len(filtered_results)} above threshold ({self.similarity_threshold}), "
f"avg confidence: {avg_confidence:.3f}, "
f"direct match: {has_direct_match}"
)
return filtered_results, avg_confidence, has_relevant
def get_context_for_llm(
self,
results: List[RetrievalResult],
max_tokens: int = settings.MAX_CONTEXT_TOKENS
) -> Tuple[str, List[Dict[str, Any]]]:
"""
Format retrieved results into context for the LLM.
Args:
results: List of retrieval results
max_tokens: Maximum tokens for context
Returns:
Tuple of (formatted_context, citation_info)
"""
if not results:
return "", []
context_parts = []
citations = []
current_tokens = 0
# Estimate tokens (rough approximation: 1 token ≈ 4 chars)
for i, result in enumerate(results):
chunk_text = result.content
estimated_tokens = len(chunk_text) // 4
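            # e.g. a 2,000-character chunk counts as roughly 500 tokens toward
            # the max_tokens budget (a heuristic, not a real tokenizer count).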
if current_tokens + estimated_tokens > max_tokens:
logger.info(f"Truncating context at {i} chunks due to token limit")
break
# Format chunk with source info
source_info = f"[Source {i+1}: {result.metadata.get('file_name', 'Unknown')}]"
if result.metadata.get('page_number'):
source_info += f" (Page {result.metadata['page_number']})"
context_parts.append(f"{source_info}\n{chunk_text}")
# Build citation info
citations.append({
"index": i + 1,
"file_name": result.metadata.get('file_name', 'Unknown'),
"chunk_id": result.chunk_id,
"page_number": result.metadata.get('page_number'),
"similarity_score": result.similarity_score,
"excerpt": chunk_text[:200] + "..." if len(chunk_text) > 200 else chunk_text
})
current_tokens += estimated_tokens
formatted_context = "\n\n---\n\n".join(context_parts)
return formatted_context, citations
# Global retrieval service instance
_retrieval_service: Optional[RetrievalService] = None
def get_retrieval_service() -> RetrievalService:
"""Get the global retrieval service instance."""
global _retrieval_service
if _retrieval_service is None:
_retrieval_service = RetrievalService()
return _retrieval_service
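

# Usage sketch (illustrative only): how a caller might exercise the retrieval
# pipeline end to end. The tenant/kb/user IDs below are hypothetical
# placeholders; the embedding service and vector store are assumed to be
# configured via app.config.settings, as in the rest of this module.
if __name__ == "__main__":
    service = get_retrieval_service()
    results, confidence, has_relevant = service.retrieve(
        query="How do I connect the API to Slack?",
        tenant_id="tenant-123",  # hypothetical tenant
        kb_id="kb-456",          # hypothetical knowledge base
        user_id="user-789",      # hypothetical user
    )
    if has_relevant:
        context, citations = service.get_context_for_llm(results)
        print(f"confidence={confidence:.3f}, citations={len(citations)}")
    else:
        print("No relevant results above threshold; caller should fall back.")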