"""Retrieval API router.

Exposes ``POST /retrieval/search``: semantic or hybrid (semantic + BM25)
document retrieval from the vector store, with optional metadata filters,
a distance threshold, and optional cross-encoder reranking.
"""

import logging
from typing import Any, Dict, List, Optional

import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field, computed_field
from sentence_transformers import CrossEncoder

from vector_store import VectorStoreManager, get_vector_store

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/retrieval", tags=["retrieval"])

# Lazily-loaded module-level singleton: the cross-encoder model is slow to
# load, so it is created on first use instead of at import time.
_reranker = None


def get_reranker():
    """Return the shared CrossEncoder reranker, loading it on first call.

    NOTE(review): not guarded by a lock — two concurrent first requests may
    both load the model. Harmless (last assignment wins) but wasteful;
    confirm whether this router runs under a multi-worker/async server
    where that matters.
    """
    global _reranker
    if _reranker is None:
        logger.info("Loading cross-encoder reranker...")
        _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return _reranker


class RetrievalRequest(BaseModel):
    """Body for ``POST /retrieval/search``."""

    question: str = Field(..., min_length=1, max_length=500)
    top_k: int = Field(default=5, ge=1, le=20)
    # Optional exact-match metadata filters (passed through to the store).
    filter_by_cluster: Optional[str] = None
    filter_by_source: Optional[str] = None
    filter_by_topic: Optional[str] = None
    # Optional substring filter on the document text itself.
    contains_text: Optional[str] = None
    # Candidates with distance above this are dropped (lower = closer).
    similarity_threshold: float = Field(default=1.0, ge=0.0, le=2.0)
    # Hybrid retrieval toggles
    enable_bm25: bool = Field(
        default=False,
        description="Enable BM25 + semantic hybrid retrieval",
    )
    bm25_k: int = Field(
        default=20,
        ge=5,
        le=100,
        description="How many BM25 candidates to consider",
    )
    hybrid_alpha: float = Field(
        default=0.4,
        ge=0.0,
        le=1.0,
        description="Dense weight in hybrid fusion (alpha=1 => semantic only)",
    )
    # Reranking
    enable_rerank: bool = Field(default=False)
    rerank_top_k: int = Field(default=3, ge=1, le=10)


class DocumentResult(BaseModel):
    """One retrieved chunk, with its distance and optional rerank score."""

    chunk_id: str
    text: str
    source: str
    topic: Optional[str]
    cluster: Optional[str]
    # Vector-store distance (lower = more similar).
    distance: float
    # Cross-encoder score; set only when reranking ran.
    rerank_score: Optional[float] = None

    @computed_field
    @property
    def relevance_label(self) -> str:
        """Human-readable bucket derived from ``distance``."""
        if self.distance < 0.8:
            return "Highly Relevant"
        elif self.distance < 1.0:
            return "Relevant"
        elif self.distance < 1.5:
            return "Somewhat Relevant"
        return "Low Relevance"


class RetrievalResponse(BaseModel):
    """Response for ``POST /retrieval/search``."""

    documents: List[DocumentResult]
    count: int
    query: str
    filters_applied: Dict[str, Any]
    retrieval_stats: Dict[str, Any]


def rerank_documents(query: str, documents: List[DocumentResult], top_k: int = 3):
    """Rerank *documents* against *query* with the cross-encoder.

    Returns the top *top_k* documents sorted by descending rerank score.
    On any reranker failure, logs and falls back to the original order
    (truncated to *top_k*) rather than failing the request.
    """
    if not documents or len(documents) <= 1:
        return documents
    try:
        reranker = get_reranker()
        # Truncate each text so very long chunks don't blow the model's
        # input limit; 1500 chars is a safe bound for this MiniLM encoder.
        pairs = [[query, doc.text[:1500]] for doc in documents]
        scores = reranker.predict(pairs)
        for doc, score in zip(documents, scores):
            doc.rerank_score = float(score)
        reranked = sorted(documents, key=lambda x: x.rerank_score or 0.0, reverse=True)
        return reranked[:top_k]
    except Exception as e:
        logger.error("Reranking failed: %s, returning original results", e)
        return documents[:top_k]


@router.post("/search", response_model=RetrievalResponse)
async def retrieve_documents_endpoint(
    request: RetrievalRequest,
    vector_store: VectorStoreManager = Depends(get_vector_store),
):
    """Retrieve documents for ``request.question``.

    Pipeline: build metadata/text filters -> fetch candidates from the
    vector store (3x over-fetch when reranking or BM25 is enabled) ->
    drop candidates over the similarity threshold -> optionally rerank ->
    assemble response with retrieval statistics.

    Raises:
        HTTPException(500): on any unexpected failure in the pipeline.
    """
    try:
        logger.info(
            "Processing query: '%s' top_k=%s", request.question, request.top_k
        )

        where_filters: Dict[str, Any] = {}
        if request.filter_by_cluster:
            where_filters["cluster"] = request.filter_by_cluster
        if request.filter_by_source:
            where_filters["source"] = request.filter_by_source
        if request.filter_by_topic:
            where_filters["topic"] = request.filter_by_topic

        where_document = (
            {"$contains": request.contains_text} if request.contains_text else None
        )

        # If reranking or hybrid retrieval is on, over-fetch candidates so
        # the later filtering/reranking stages have material to work with.
        n_candidates = (
            request.top_k * 3
            if (request.enable_rerank or request.enable_bm25)
            else request.top_k
        )

        candidates = vector_store.retrieve_documents(
            question=request.question,
            n_results=n_candidates,
            where_filters=where_filters if where_filters else None,
            where_document=where_document,
            enable_bm25=request.enable_bm25,
            bm25_k=request.bm25_k,
            alpha=request.hybrid_alpha,
        )

        documents: List[DocumentResult] = []
        filtered_count = 0
        for c in candidates:
            distance = c.get("distance")
            # Candidates surfaced only by BM25 carry no semantic distance;
            # treat them as a weak semantic match.
            # NOTE(review): 1.5 exceeds the default similarity_threshold of
            # 1.0, so BM25-only hits are filtered out unless the caller
            # raises the threshold — confirm this is intended.
            if distance is None:
                distance = 1.5
            if distance <= request.similarity_threshold:
                meta = c.get("metadata") or {}
                documents.append(
                    DocumentResult(
                        chunk_id=c["id"],
                        text=c["text"],
                        source=meta.get("source", "Unknown"),
                        topic=meta.get("topic"),
                        cluster=meta.get("cluster"),
                        distance=float(distance),
                    )
                )
            else:
                filtered_count += 1

        total_retrieved = len(candidates)

        # Rerank if enabled and there is more than one survivor.
        # Track whether reranking actually ran: previously the stats
        # reported the *requested* flag, which lied when reranking was
        # skipped for having <= 1 document.
        rerank_applied = False
        if request.enable_rerank and len(documents) > 1:
            documents = rerank_documents(
                request.question, documents, request.rerank_top_k
            )
            rerank_applied = True
            retrieval_method = (
                "hybrid_with_rerank" if request.enable_bm25 else "semantic_with_rerank"
            )
        else:
            documents = documents[: request.top_k]
            retrieval_method = "hybrid" if request.enable_bm25 else "semantic"

        distances = [d.distance for d in documents]
        avg_distance = float(np.mean(distances)) if distances else None
        best_distance = min(distances) if distances else None

        return RetrievalResponse(
            documents=documents,
            count=len(documents),
            query=request.question,
            filters_applied={
                "cluster": request.filter_by_cluster,
                "source": request.filter_by_source,
                "topic": request.filter_by_topic,
                "contains_text": request.contains_text,
                "similarity_threshold": request.similarity_threshold,
                "enable_bm25": request.enable_bm25,
                "bm25_k": request.bm25_k,
                "hybrid_alpha": request.hybrid_alpha,
            },
            retrieval_stats={
                "method": retrieval_method,
                "total_retrieved": total_retrieved,
                "filtered_by_threshold": filtered_count,
                "returned": len(documents),
                "best_distance": best_distance,
                "avg_distance": avg_distance,
                "reranking_applied": rerank_applied,
                "bm25_applied": request.enable_bm25,
            },
        )
    except Exception as e:
        logger.error(f"Retrieval failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Retrieval failed: {str(e)}",
        )