Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends, HTTPException, status, Query | |
| from pydantic import BaseModel, Field, computed_field | |
| from typing import List, Optional, Dict, Any | |
| import logging | |
| import numpy as np | |
| from sentence_transformers import CrossEncoder | |
| from vector_store import get_vector_store, VectorStoreManager | |
| logger = logging.getLogger(__name__) | |
| router = APIRouter(prefix="/retrieval", tags=["retrieval"]) | |
# Process-wide cache for the cross-encoder model; filled in on first use.
_reranker = None


def get_reranker():
    """Return the shared cross-encoder reranker, creating it on first call.

    The model is instantiated once and stored in the module-level
    ``_reranker`` variable; every later call returns that same object.
    """
    global _reranker

    # Fast path: the model has already been loaded.
    if _reranker is not None:
        return _reranker

    logger.info("Loading cross-encoder reranker...")
    _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return _reranker
class RetrievalRequest(BaseModel):
    """Parameters accepted by the retrieval endpoint.

    Groups the query text, optional metadata/text filters, the distance
    threshold, and the switches for hybrid (BM25 + semantic) retrieval and
    cross-encoder reranking.
    """

    # Query text; must be between 1 and 500 characters long.
    question: str = Field(..., min_length=1, max_length=500)
    # Number of documents returned when reranking is not applied (1-20).
    top_k: int = Field(5, ge=1, le=20)

    # Optional metadata filters; a filter is only applied when it is set.
    filter_by_cluster: Optional[str] = None
    filter_by_source: Optional[str] = None
    filter_by_topic: Optional[str] = None
    # Optional substring the document text must contain.
    contains_text: Optional[str] = None

    # Candidates whose distance is greater than this value are dropped.
    similarity_threshold: float = Field(1.0, ge=0.0, le=2.0)

    # Hybrid retrieval toggles
    enable_bm25: bool = Field(
        False,
        description="Enable BM25 + semantic hybrid retrieval",
    )
    bm25_k: int = Field(
        20,
        ge=5,
        le=100,
        description="How many BM25 candidates to consider",
    )
    hybrid_alpha: float = Field(
        0.4,
        ge=0.0,
        le=1.0,
        description="Dense weight in hybrid fusion (alpha=1 => semantic only)",
    )

    # Reranking
    enable_rerank: bool = False
    # Number of documents kept after reranking (1-10).
    rerank_top_k: int = Field(3, ge=1, le=10)
class DocumentResult(BaseModel):
    """One retrieved chunk together with its metadata and scores."""

    chunk_id: str
    text: str
    source: str
    topic: Optional[str]
    cluster: Optional[str]
    # Distance reported for the chunk; smaller values rank as more relevant.
    distance: float
    # Set by the reranker; stays None when reranking has not run.
    rerank_score: Optional[float] = None

    def relevance_label(self) -> str:
        """Translate ``distance`` into a coarse, human-readable label."""
        # (exclusive upper bound, label) pairs, checked in ascending order.
        bands = (
            (0.8, "Highly Relevant"),
            (1.0, "Relevant"),
            (1.5, "Somewhat Relevant"),
        )
        for upper_bound, label in bands:
            if self.distance < upper_bound:
                return label
        return "Low Relevance"
class RetrievalResponse(BaseModel):
    """Response body returned by the retrieval endpoint."""

    # Documents that survived filtering (and reranking, when it ran).
    documents: List[DocumentResult]
    # Number of entries in ``documents``.
    count: int
    # The question text the results were retrieved for.
    query: str
    # Echo of the filter and hybrid settings taken from the request.
    filters_applied: Dict[str, Any]
    # Diagnostics about the retrieval run (method, counts, distances).
    retrieval_stats: Dict[str, Any]
def rerank_documents(query: str, documents: List[DocumentResult], top_k: int = 3):
    """Reorder ``documents`` by cross-encoder score and keep the best ``top_k``.

    Each document's ``rerank_score`` attribute is set in place. An empty (or
    falsy) input and a single-element input are returned unchanged. If
    anything goes wrong while scoring, the error is logged and the first
    ``top_k`` documents are returned in their incoming order.
    """
    worth_reranking = bool(documents) and len(documents) > 1
    if not worth_reranking:
        return documents

    def _score_of(doc):
        # A missing or zero score sorts as 0.0.
        return doc.rerank_score or 0.0

    try:
        model = get_reranker()

        # Only the first 1500 characters of each text are sent to the model.
        query_text_pairs = [[query, doc.text[:1500]] for doc in documents]
        raw_scores = model.predict(query_text_pairs)

        for doc, raw_score in zip(documents, raw_scores):
            doc.rerank_score = float(raw_score)

        # Stable descending sort: equal scores keep their incoming order.
        ordered = list(documents)
        ordered.sort(key=_score_of, reverse=True)
        return ordered[:top_k]
    except Exception as e:
        logger.error(f"Reranking failed: {str(e)}, returning original results")
        return documents[:top_k]
# Stand-in distance for candidates that carry no semantic distance (for
# example hits that came only from BM25); treats them as weak matches.
# NOTE(review): this is larger than the default similarity_threshold (1.0),
# so such candidates are dropped unless the caller raises the threshold.
# Confirm that this is intended.
_MISSING_DISTANCE_FALLBACK = 1.5


def _build_where_filters(request: RetrievalRequest) -> Dict[str, Any]:
    """Collect the metadata filters that are set (truthy) on the request."""
    requested = {
        "cluster": request.filter_by_cluster,
        "source": request.filter_by_source,
        "topic": request.filter_by_topic,
    }
    return {key: value for key, value in requested.items() if value}


def _collect_documents(candidates, similarity_threshold: float):
    """Turn raw candidates into results, dropping those above the threshold.

    Returns a ``(documents, filtered_count)`` tuple, where ``filtered_count``
    is the number of candidates rejected by the distance threshold.
    """
    documents: List[DocumentResult] = []
    filtered_count = 0
    for candidate in candidates:
        distance = candidate.get("distance")
        if distance is None:
            distance = _MISSING_DISTANCE_FALLBACK

        if distance > similarity_threshold:
            filtered_count += 1
            continue

        meta = candidate.get("metadata") or {}
        documents.append(
            DocumentResult(
                chunk_id=candidate["id"],
                text=candidate["text"],
                source=meta.get("source", "Unknown"),
                topic=meta.get("topic"),
                cluster=meta.get("cluster"),
                distance=float(distance),
            )
        )
    return documents, filtered_count


async def retrieve_documents_endpoint(
    request: RetrievalRequest,
    vector_store: VectorStoreManager = Depends(get_vector_store),
):
    """Retrieve documents for a question, optionally hybrid and reranked.

    Steps: build the metadata/text filters, fetch candidates from the vector
    store (three times ``top_k`` when reranking or BM25 is enabled), drop
    candidates whose distance exceeds ``similarity_threshold``, optionally
    rerank the survivors, and return them with diagnostic statistics.

    Args:
        request: Query text, filters and retrieval options.
        vector_store: Vector store manager resolved through ``Depends``.

    Returns:
        A ``RetrievalResponse`` with the documents, the applied filters and
        retrieval statistics. ``retrieval_stats["reranking_applied"]`` and the
        ``*_with_rerank`` method labels report whether rerank scores were
        actually produced, not merely whether reranking was requested.

    Raises:
        HTTPException: status 500 for any unexpected failure. An
            ``HTTPException`` raised by a collaborator is passed through
            unchanged.
    """
    # NOTE(review): the vector store query and the reranker are called
    # synchronously from this coroutine; if they block, they block the event
    # loop. Consider a thread pool once their thread-safety is confirmed.
    try:
        logger.info(
            "Processing query: '%s' top_k=%s", request.question, request.top_k
        )

        where_filters = _build_where_filters(request)
        where_document = (
            {"$contains": request.contains_text} if request.contains_text else None
        )

        # Reranking and hybrid fusion both need a wider candidate pool.
        needs_wider_pool = request.enable_rerank or request.enable_bm25
        n_candidates = request.top_k * 3 if needs_wider_pool else request.top_k

        candidates = vector_store.retrieve_documents(
            question=request.question,
            n_results=n_candidates,
            where_filters=where_filters or None,
            where_document=where_document,
            enable_bm25=request.enable_bm25,
            bm25_k=request.bm25_k,
            alpha=request.hybrid_alpha,
        )

        documents, filtered_count = _collect_documents(
            candidates, request.similarity_threshold
        )
        total_retrieved = len(candidates)

        reranking_applied = False
        if request.enable_rerank and len(documents) > 1:
            # NOTE(review): the reranked list is cut to rerank_top_k only and
            # may therefore be longer than top_k; confirm this is intended.
            documents = rerank_documents(
                request.question, documents, request.rerank_top_k
            )
            # rerank_documents swallows its own failures and then returns the
            # documents unscored, so inspect the scores instead of trusting
            # the request flag.
            reranking_applied = any(
                doc.rerank_score is not None for doc in documents
            )
        else:
            documents = documents[: request.top_k]

        base_method = "hybrid" if request.enable_bm25 else "semantic"
        retrieval_method = (
            f"{base_method}_with_rerank" if reranking_applied else base_method
        )

        distances = [doc.distance for doc in documents]
        avg_distance = float(np.mean(distances)) if distances else None
        best_distance = min(distances) if distances else None

        return RetrievalResponse(
            documents=documents,
            count=len(documents),
            query=request.question,
            filters_applied={
                "cluster": request.filter_by_cluster,
                "source": request.filter_by_source,
                "topic": request.filter_by_topic,
                "contains_text": request.contains_text,
                "similarity_threshold": request.similarity_threshold,
                "enable_bm25": request.enable_bm25,
                "bm25_k": request.bm25_k,
                "hybrid_alpha": request.hybrid_alpha,
            },
            retrieval_stats={
                "method": retrieval_method,
                "total_retrieved": total_retrieved,
                "filtered_by_threshold": filtered_count,
                "returned": len(documents),
                "best_distance": best_distance,
                "avg_distance": avg_distance,
                "reranking_applied": reranking_applied,
                "bm25_applied": request.enable_bm25,
            },
        )
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewriting them as 500s.
        raise
    except Exception as e:
        logger.error("Retrieval failed: %s", e, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Retrieval failed: {str(e)}",
        ) from e