| """ | |
| Retrieval pipeline with confidence scoring and filtering. | |
| """ | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import logging | |
| import re | |
| from app.config import settings | |
| from app.rag.embeddings import get_embedding_service | |
| from app.rag.vectorstore import get_vector_store | |
| from app.rag.intent import detect_intents, check_direct_match, get_intent_keywords | |
| from app.models.schemas import RetrievalResult | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |


class RetrievalService:
    """
    Handles document retrieval with confidence scoring.

    Implements threshold-based filtering for quality control.
    """

    def __init__(
        self,
        top_k: int = settings.TOP_K,
        similarity_threshold: float = settings.SIMILARITY_THRESHOLD
    ):
        """
        Initialize the retrieval service.

        Args:
            top_k: Number of results to retrieve
            similarity_threshold: Minimum similarity score to consider relevant
        """
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.embedding_service = get_embedding_service()
        self.vector_store = get_vector_store()

    def retrieve(
        self,
        query: str,
        tenant_id: str,  # CRITICAL: Multi-tenant isolation
        kb_id: str,
        user_id: str,
        top_k: Optional[int] = None
    ) -> Tuple[List[RetrievalResult], float, bool]:
        """
        Retrieve relevant documents for a query.

        Args:
            query: User's question
            tenant_id: Tenant ID for multi-tenant isolation (CRITICAL)
            kb_id: Knowledge base ID to search
            user_id: User ID for filtering
            top_k: Optional override for number of results

        Returns:
            Tuple of (results, average_confidence, has_relevant_results)
        """
        k = top_k or self.top_k

        # Generate query embedding
        logger.info(f"Generating embedding for query: {query[:50]}...")
        query_embedding = self.embedding_service.embed_query(query)

        # Search vector store with filters - MUST include tenant_id for isolation
        filter_dict = {
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id
        }
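        # All three filter keys are assumed to exist in each stored chunk's
        # metadata; how a missing key is treated depends on the vector store's
        # filter semantics.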
        logger.info(f"Searching vector store with filters: {filter_dict}")
        raw_results = self.vector_store.search(
            query_embedding=query_embedding,
            top_k=k,
            filter_dict=filter_dict
        )

        if not raw_results:
            logger.warning(f"No results found for query in kb_id={kb_id}")
            return [], 0.0, False

        # Convert to RetrievalResult objects
        results = []
        for r in raw_results:
            results.append(RetrievalResult(
                chunk_id=r['id'],
                content=r['content'],
                metadata=r['metadata'],
                similarity_score=r['similarity_score']
            ))

        # HEAVY CONFIDENCE MODE: use the maximum similarity score from the top
        # results, so confidence reflects the best match found rather than being
        # dragged down by weaker results.
        if results:
            # Take the top 3 results and use the maximum similarity score:
            # a single strong match is enough for maximum confidence.
            top_results = results[:3]
            max_score = max(r.similarity_score for r in top_results)

            # If the max score is strong (>= 0.4), use it directly; otherwise use
            # a weighted average of the top 3 to avoid over-inflating weak matches.
            if max_score >= 0.4:
                avg_confidence = max_score
            else:
                scores = [r.similarity_score for r in top_results]
                weights = [1.0, 0.7, 0.5][:len(scores)]  # Aggressive weighting
                weighted_sum = sum(s * w for s, w in zip(scores, weights))
                total_weight = sum(weights)
                avg_confidence = weighted_sum / total_weight if total_weight > 0 else max_score
        else:
            avg_confidence = 0.0
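        # Worked example of the heuristic above (illustrative scores):
        #   [0.52, 0.31, 0.28] -> max 0.52 >= 0.4 -> confidence = 0.52
        #   [0.35, 0.30, 0.20] -> max < 0.4  -> weighted average:
        #     (0.35*1.0 + 0.30*0.7 + 0.20*0.5) / (1.0 + 0.7 + 0.5) = 0.66 / 2.2 = 0.30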

        # Filter results above threshold
        filtered_results = [
            r for r in results
            if r.similarity_score >= self.similarity_threshold
        ]

        # If no results pass the threshold but we have results, fall back to the
        # top results anyway; this prevents over-filtering when the threshold is
        # too strict.
        if not filtered_results and results:
            logger.warning(
                f"No results passed threshold {self.similarity_threshold}, "
                f"using top {min(3, len(results))} results anyway"
            )
            filtered_results = results[:3]
            # Recalculate confidence with the fallback results
            scores = [r.similarity_score for r in filtered_results]
            avg_confidence = sum(scores) / len(scores)

        # DIRECT MATCH GATE: check whether at least one chunk directly matches
        # the query intent. For integration/API questions this gate is stricter.
        has_direct_match = False
        if filtered_results:
            chunk_texts = [r.content for r in filtered_results]
            intents = detect_intents(query)
            intent_keywords = get_intent_keywords(intents)

            if "integration" in intents or "api" in query.lower():
                # Integration/API questions require a direct match
                has_direct_match = check_direct_match(query, chunk_texts, intent_keywords)
                logger.info(
                    f"Direct match check (strict for integration): {has_direct_match} "
                    f"(intents: {intents})"
                )
            else:
                # For other questions, be more lenient: just check whether any
                # important (non stop-word) query words appear in the chunks.
                query_words = set(re.findall(r'\b\w+\b', query.lower()))
                stop_words = {
                    "the", "a", "an", "is", "are", "was", "were", "be", "been",
                    "to", "of", "and", "or", "but", "in", "on", "at", "for",
                    "with", "how", "what", "when", "where", "why", "do", "does"
                }
                important_words = query_words - stop_words
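                # e.g. "How do I reset my password?" keeps {"i", "reset", "my",
                # "password"} after stop-word removal; one hit in any chunk
                # counts as a direct match.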

                # Check whether at least one important word appears in a chunk
                for chunk in chunk_texts:
                    chunk_lower = chunk.lower()
                    matches = sum(1 for word in important_words if word in chunk_lower)
                    if matches >= 1:
                        has_direct_match = True
                        break
                logger.info(f"Direct match check (lenient): {has_direct_match} (intents: {intents})")

        # Only consider the retrieval relevant if we have filtered results AND
        # (a direct match OR high confidence). High confidence (> 0.40) can
        # bypass the direct-match requirement for non-integration questions.
        has_relevant = len(filtered_results) > 0 and (has_direct_match or avg_confidence > 0.40)

        logger.info(
            f"Retrieved {len(results)} results, "
            f"{len(filtered_results)} above threshold ({self.similarity_threshold}), "
            f"avg confidence: {avg_confidence:.3f}, "
            f"direct match: {has_direct_match}"
        )
        return filtered_results, avg_confidence, has_relevant

    def get_context_for_llm(
        self,
        results: List[RetrievalResult],
        max_tokens: int = settings.MAX_CONTEXT_TOKENS
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Format retrieved results into context for the LLM.

        Args:
            results: List of retrieval results
            max_tokens: Maximum tokens for context

        Returns:
            Tuple of (formatted_context, citation_info)
        """
        if not results:
            return "", []

        context_parts = []
        citations = []
        current_tokens = 0

        # Estimate tokens (rough approximation: 1 token ≈ 4 chars)
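        # e.g. a 2,000-character chunk is budgeted as ~500 tokens against max_tokens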
        for i, result in enumerate(results):
            chunk_text = result.content
            estimated_tokens = len(chunk_text) // 4
            if current_tokens + estimated_tokens > max_tokens:
                logger.info(f"Truncating context at {i} chunks due to token limit")
                break

            # Format chunk with source info
            source_info = f"[Source {i+1}: {result.metadata.get('file_name', 'Unknown')}]"
            if result.metadata.get('page_number'):
                source_info += f" (Page {result.metadata['page_number']})"
            context_parts.append(f"{source_info}\n{chunk_text}")

            # Build citation info
            citations.append({
                "index": i + 1,
                "file_name": result.metadata.get('file_name', 'Unknown'),
                "chunk_id": result.chunk_id,
                "page_number": result.metadata.get('page_number'),
                "similarity_score": result.similarity_score,
                "excerpt": (chunk_text[:200] + "...") if len(chunk_text) > 200 else chunk_text
            })
            current_tokens += estimated_tokens

        formatted_context = "\n\n---\n\n".join(context_parts)
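        # Illustrative shape of the joined context (file names are placeholders):
        #   [Source 1: handbook.pdf] (Page 3)
        #   <chunk text>
        #
        #   ---
        #
        #   [Source 2: faq.md]
        #   <chunk text>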
        return formatted_context, citations


# Global retrieval service instance
_retrieval_service: Optional[RetrievalService] = None


def get_retrieval_service() -> RetrievalService:
    """Get the global retrieval service instance."""
    global _retrieval_service
    if _retrieval_service is None:
        _retrieval_service = RetrievalService()
    return _retrieval_service
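

# Usage sketch (illustrative; the IDs below are placeholders and assume a
# tenant/knowledge base that already holds ingested, embedded documents):
#
#   service = get_retrieval_service()
#   results, confidence, has_relevant = service.retrieve(
#       query="How do I rotate my API keys?",
#       tenant_id="tenant-123",
#       kb_id="kb-456",
#       user_id="user-789",
#   )
#   if has_relevant:
#       context, citations = service.get_context_for_llm(results)
#       # `context` feeds the LLM prompt; `citations` back the answer's sources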