# RAG-Insurance / retrieval.py
# Author: mokhles — "Initial commit: Insurance RAG API" (commit af37875)
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field, computed_field
from typing import List, Optional, Dict, Any
import logging
import numpy as np
from sentence_transformers import CrossEncoder
from vector_store import get_vector_store, VectorStoreManager
# Standard per-module logger.
logger = logging.getLogger(__name__)
# All endpoints in this module are mounted under the /retrieval prefix.
router = APIRouter(prefix="/retrieval", tags=["retrieval"])
# Lazily-initialized cross-encoder singleton; populated on first call to get_reranker().
_reranker = None
def get_reranker():
    """Return the shared cross-encoder reranker, loading it on first use.

    The model is cached in the module-level ``_reranker`` so the expensive
    load happens at most once per process.
    """
    global _reranker
    if _reranker is not None:
        return _reranker
    logger.info("Loading cross-encoder reranker...")
    _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return _reranker
class RetrievalRequest(BaseModel):
    """Request payload for the /retrieval/search endpoint."""

    # User question to embed and search with; bounded to keep latency sane.
    question: str = Field(..., min_length=1, max_length=500)
    # How many documents to return in the non-rerank path.
    top_k: int = Field(default=5, ge=1, le=20)
    # Optional exact-match metadata filters (passed through to the vector store).
    filter_by_cluster: Optional[str] = None
    filter_by_source: Optional[str] = None
    filter_by_topic: Optional[str] = None
    # Optional substring filter applied to the document text itself.
    contains_text: Optional[str] = None
    # Candidates with distance above this are dropped; default 1.0 is permissive
    # (distance scale appears to run 0.0–2.0 — presumably cosine distance; verify
    # against the vector store's metric).
    similarity_threshold: float = Field(default=1.0, ge=0.0, le=2.0)
    # ✅ Hybrid retrieval toggles
    enable_bm25: bool = Field(
        default=False,
        description="Enable BM25 + semantic hybrid retrieval",
    )
    bm25_k: int = Field(
        default=20,
        ge=5,
        le=100,
        description="How many BM25 candidates to consider",
    )
    hybrid_alpha: float = Field(
        default=0.4,
        ge=0.0,
        le=1.0,
        description="Dense weight in hybrid fusion (alpha=1 => semantic only)",
    )
    # Reranking
    # When enabled, a cross-encoder rescoring pass keeps the best rerank_top_k.
    enable_rerank: bool = Field(default=False)
    rerank_top_k: int = Field(default=3, ge=1, le=10)
class DocumentResult(BaseModel):
    """A single retrieved chunk together with its scoring metadata."""

    chunk_id: str
    text: str
    source: str
    topic: Optional[str]
    cluster: Optional[str]
    distance: float
    rerank_score: Optional[float] = None

    @computed_field
    @property
    def relevance_label(self) -> str:
        """Map the raw distance onto a coarse, human-readable bucket."""
        # Exclusive upper bounds checked in ascending order; anything at or
        # beyond 1.5 falls through to the catch-all label.
        for upper_bound, label in (
            (0.8, "Highly Relevant"),
            (1.0, "Relevant"),
            (1.5, "Somewhat Relevant"),
        ):
            if self.distance < upper_bound:
                return label
        return "Low Relevance"
class RetrievalResponse(BaseModel):
    """Envelope returned by /retrieval/search."""

    # Final ranked documents (post-threshold, optionally post-rerank).
    documents: List[DocumentResult]
    # Convenience count; equals len(documents).
    count: int
    # Echo of the original question.
    query: str
    # Which request filters/toggles were in effect for this call.
    filters_applied: Dict[str, Any]
    # Diagnostics: retrieval method, candidate counts, distance summary, etc.
    retrieval_stats: Dict[str, Any]
def rerank_documents(query: str, documents: List[DocumentResult], top_k: int = 3):
    """Re-score *documents* against *query* with the cross-encoder and keep the best *top_k*.

    Mutates each document's ``rerank_score`` in place. If the reranker fails
    for any reason, falls back to truncating the original ordering.
    """
    # Nothing to reorder for zero or one candidates.
    if len(documents) < 2:
        return documents
    try:
        model = get_reranker()
        # Cap passage length so the cross-encoder input stays bounded.
        pair_inputs = [[query, candidate.text[:1500]] for candidate in documents]
        raw_scores = model.predict(pair_inputs)
        for candidate, raw_score in zip(documents, raw_scores):
            candidate.rerank_score = float(raw_score)
        ordered = sorted(
            documents,
            key=lambda d: d.rerank_score if d.rerank_score else 0.0,
            reverse=True,
        )
        return ordered[:top_k]
    except Exception as e:
        logger.error(f"Reranking failed: {str(e)}, returning original results")
        return documents[:top_k]
@router.post("/search", response_model=RetrievalResponse)
async def retrieve_documents_endpoint(
    request: RetrievalRequest,
    vector_store: VectorStoreManager = Depends(get_vector_store),
):
    """Retrieve document chunks relevant to the user's question.

    Pipeline: optional metadata/text filters -> (hybrid BM25 + dense or pure
    dense) candidate fetch -> distance-threshold filter -> optional
    cross-encoder rerank -> truncation. Returns the final documents plus
    diagnostic retrieval stats.

    Raises:
        HTTPException: 500 if any stage of retrieval fails.
    """
    try:
        logger.info(f"Processing query: '{request.question}' top_k={request.top_k}")

        # Build the metadata (where) filter only from filters actually set.
        where_filters: Dict[str, Any] = {}
        if request.filter_by_cluster:
            where_filters["cluster"] = request.filter_by_cluster
        if request.filter_by_source:
            where_filters["source"] = request.filter_by_source
        if request.filter_by_topic:
            where_filters["topic"] = request.filter_by_topic
        where_document = {"$contains": request.contains_text} if request.contains_text else None

        # Over-fetch when a later stage (rerank / hybrid fusion) will prune.
        n_candidates = request.top_k * 3 if (request.enable_rerank or request.enable_bm25) else request.top_k
        candidates = vector_store.retrieve_documents(
            question=request.question,
            n_results=n_candidates,
            where_filters=where_filters if where_filters else None,
            where_document=where_document,
            enable_bm25=request.enable_bm25,
            bm25_k=request.bm25_k,
            alpha=request.hybrid_alpha,
        )

        # Apply the distance threshold while converting raw hits to models.
        documents: List[DocumentResult] = []
        filtered_count = 0
        for c in candidates:
            distance = c.get("distance")
            # A candidate surfaced only by BM25 carries no dense distance;
            # treat it as a weak semantic match rather than discarding it.
            if distance is None:
                distance = 1.5
            if distance <= request.similarity_threshold:
                meta = c.get("metadata") or {}
                documents.append(
                    DocumentResult(
                        chunk_id=c["id"],
                        text=c["text"],
                        source=meta.get("source", "Unknown"),
                        topic=meta.get("topic"),
                        cluster=meta.get("cluster"),
                        distance=float(distance),
                    )
                )
            else:
                filtered_count += 1
        total_retrieved = len(candidates)

        # Rerank only when requested AND more than one candidate survived;
        # record whether it actually ran so the stats below don't over-report.
        rerank_applied = request.enable_rerank and len(documents) > 1
        if rerank_applied:
            documents = rerank_documents(request.question, documents, request.rerank_top_k)
            retrieval_method = "hybrid_with_rerank" if request.enable_bm25 else "semantic_with_rerank"
        else:
            documents = documents[:request.top_k]
            retrieval_method = "hybrid" if request.enable_bm25 else "semantic"

        # Summary statistics over the returned set.
        distances = [d.distance for d in documents]
        avg_distance = float(np.mean(distances)) if distances else None
        best_distance = min(distances) if distances else None

        return RetrievalResponse(
            documents=documents,
            count=len(documents),
            query=request.question,
            filters_applied={
                "cluster": request.filter_by_cluster,
                "source": request.filter_by_source,
                "topic": request.filter_by_topic,
                "contains_text": request.contains_text,
                "similarity_threshold": request.similarity_threshold,
                "enable_bm25": request.enable_bm25,
                "bm25_k": request.bm25_k,
                "hybrid_alpha": request.hybrid_alpha,
            },
            retrieval_stats={
                "method": retrieval_method,
                "total_retrieved": total_retrieved,
                "filtered_by_threshold": filtered_count,
                "returned": len(documents),
                "best_distance": best_distance,
                "avg_distance": avg_distance,
                # FIX: previously echoed request.enable_rerank even when the
                # rerank branch was skipped for <2 candidates.
                "reranking_applied": rerank_applied,
                "bm25_applied": request.enable_bm25,
            },
        )
    except Exception as e:
        logger.error(f"Retrieval failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Retrieval failed: {str(e)}",
        )