VcRlAgent committed on
Commit
08d20f8
·
1 Parent(s): ec1d52e

Reranker and Debug Added

Browse files
app/main.py CHANGED
@@ -27,6 +27,11 @@ app.add_middleware(
27
  app.include_router(ingest_routes.router, prefix="/api", tags=["Ingestion"])
28
  app.include_router(ask_routes.router, prefix="/api", tags=["Query"])
29
  app.include_router(metrics_routes.router, prefix="/api", tags=["Metrics"])
 
 
 
 
 
30
 
31
  @app.get("/")
32
  async def root():
 
27
  app.include_router(ingest_routes.router, prefix="/api", tags=["Ingestion"])
28
  app.include_router(ask_routes.router, prefix="/api", tags=["Query"])
29
  app.include_router(metrics_routes.router, prefix="/api", tags=["Metrics"])
30
# NOTE(review): this change only touches the router registration, not the
# import block, so `debug_routes` may not be bound yet. Importing it here is a
# no-op if it is already imported; move it to the top-level imports if preferred.
from app.routes import debug_routes

app.include_router(debug_routes.router, prefix="/api", tags=["Debug"])

# Log every registered path once at startup so a missing router is easy to spot.
logger.info("✅ Routers initialized:")
for route in app.routes:
    logger.info(f" - {route.path}")
35
 
36
  @app.get("/")
37
  async def root():
app/routes/ask_routes.py CHANGED
@@ -4,6 +4,7 @@ from fastapi import APIRouter, HTTPException
4
  from app.models.jira_schema import QueryRequest, QueryResponse
5
  from app.services.retriever import retriever
6
  from app.services.generator import generator
 
7
  from app.utils.response_builder import build_query_response, extract_chart_intent
8
  from app.utils.logger import setup_logger
9
  from collections import Counter
@@ -32,8 +33,14 @@ async def ask_question(request: QueryRequest):
32
  sources=[]
33
  )
34
 
 
 
 
 
35
  # Format context
36
- context = retriever.format_context(results)
 
 
37
 
38
  # Generate answer
39
  answer = generator.generate_rag_response(request.query, context)
 
4
  from app.models.jira_schema import QueryRequest, QueryResponse
5
  from app.services.retriever import retriever
6
  from app.services.generator import generator
7
+ from app.services.reranker import reranker
8
  from app.utils.response_builder import build_query_response, extract_chart_intent
9
  from app.utils.logger import setup_logger
10
  from collections import Counter
 
33
  sources=[]
34
  )
35
 
36
# Re-score the retrieved candidates with the cross-encoder and keep the
# five highest-scoring ones.
logger.info("[RERANKER] Starting re-ranking process...")
reranked_results = reranker.rerank(request.query, results, top_k=5)

# Build the prompt context from the re-ranked documents.
# format_context() takes the result list as its only argument; passing an
# already-formatted string back in as `context=` raises TypeError and also
# discards the re-ranking above.
context = retriever.format_context(reranked_results)
44
 
45
  # Generate answer
46
  answer = generator.generate_rag_response(request.query, context)
app/routes/debug_routes.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
"""Debug endpoints for inspecting retrieval and re-ranking scores."""
from fastapi import APIRouter

from app.models.jira_schema import QueryRequest
from app.services.reranker import reranker
from app.services.retriever import retriever

# main.py mounts this object as `debug_routes.router`, so it must be defined here.
router = APIRouter()


@router.post("/debug/retrieval")
async def debug_retrieval(request: QueryRequest):
    """Return raw retrieval scores next to the re-ranked scores for one query.

    Args:
        request: Request body; only ``request.query`` is read.

    Returns:
        A dict with the query text, the ``score`` of every retrieved result,
        the ``rerank_score`` of up to ten re-ranked results, and the payload
        ``summary`` of the first five re-ranked results (``None`` where a
        payload has no summary).
    """
    results = retriever.retrieve(request.query)
    reranked = reranker.rerank(request.query, results, top_k=10)
    return {
        "query": request.query,
        "raw_faiss_scores": [r["score"] for r in results],
        "reranked_scores": [r["rerank_score"] for r in reranked],
        "top_docs": [r["payload"].get("summary") for r in reranked[:5]],
    }
app/services/embeddings.py CHANGED
@@ -39,7 +39,7 @@ class EmbeddingService:
39
  def embed_batch(
40
  self,
41
  texts: List[str],
42
- batch_size: int = 32,
43
  is_query: bool = False,
44
  ) -> List[List[float]]:
45
  """Generate embeddings for a batch of texts (queries or passages)."""
 
39
  def embed_batch(
40
  self,
41
  texts: List[str],
42
+ batch_size: int = 16,
43
  is_query: bool = False,
44
  ) -> List[List[float]]:
45
  """Generate embeddings for a batch of texts (queries or passages)."""
app/services/reranker.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/services/reranker.py
2
+ from sentence_transformers import CrossEncoder
3
+ from app.utils.logger import setup_logger
4
+
5
+ logger = setup_logger(__name__)
6
+
7
class RerankerService:
    """Cross-encoder re-ranker that re-scores retrieved documents against a query."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder identified by ``model_name``."""
        logger.info(f"Loading reranker model: {model_name}")
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, results: list, top_k: int = 5) -> list:
        """Re-rank retrieved documents by cross-encoder score.

        Args:
            query: User query text.
            results: Retrieved items, each a dict with a ``"payload"`` dict;
                the payload's ``"searchable_text"`` is what gets scored
                (missing or ``None`` text is scored as an empty string).
            top_k: Maximum number of items to return.

        Returns:
            Up to ``top_k`` shallow copies of the input items, each with an
            added ``"rerank_score"`` float, sorted by that score in descending
            order. The caller's dicts are left untouched. Returns ``[]`` when
            ``results`` is empty or ``top_k`` is not positive.
        """
        # A negative top_k would slice from the wrong end, so treat any
        # non-positive value as "nothing requested" and skip the model call.
        if not results or top_k <= 0:
            return []

        pairs = [
            (query, doc["payload"].get("searchable_text") or "")
            for doc in results
        ]

        logger.info(f"[RERANKER] Scoring {len(pairs)} query-document pairs...")
        scores = self.model.predict(pairs)

        # Copy each item before attaching the score so the caller's list of
        # raw retrieval results is not modified as a side effect.
        scored = [
            {**doc, "rerank_score": float(score)}
            for doc, score in zip(results, scores)
        ]
        reranked = sorted(
            scored, key=lambda doc: doc["rerank_score"], reverse=True
        )[:top_k]

        logger.info(
            f"[RERANKER] Top reranked scores: "
            f"{[round(doc['rerank_score'], 3) for doc in reranked]}"
        )

        return reranked
48
+
49
+ # Global instance
50
+ reranker = RerankerService()
app/services/retriever.py CHANGED
@@ -24,7 +24,7 @@ class RetrieverService:
24
  top_k = settings.TOP_K
25
 
26
  # Generate query embedding
27
- logger.info(f"Retrieving documents for query: {query}")
28
  query_embedding = self.embedding_service.embed_text(query,is_query=True)
29
  #logger.debug(f"Embedded query: {query_embedding}")
30
 
@@ -59,7 +59,12 @@ class RetrieverService:
59
  # score_threshold=settings.SCORE_THRESHOLD
60
  # )
61
 
62
- logger.info(f"Retrieved {len(results)} documents")
 
 
 
 
 
63
  return results
64
 
65
  def format_context(self, results: List[Dict[str, Any]]) -> str:
 
24
  top_k = settings.TOP_K
25
 
26
  # Generate query embedding
27
+ logger.info(f"[RETRIEVER] Retrieving documents for query: {query}")
28
  query_embedding = self.embedding_service.embed_text(query,is_query=True)
29
  #logger.debug(f"Embedded query: {query_embedding}")
30
 
 
59
  # score_threshold=settings.SCORE_THRESHOLD
60
  # )
61
 
62
+ logger.info(f"[RETRIEVER] Retrieved {len(results)} documents")
63
+
64
+ if results:
65
+ logger.debug("[RETRIEVER] Raw FAISS top-5 scores: " +
66
+ ", ".join(f"{r['score']:.4f}" for r in results[:5]))
67
+
68
  return results
69
 
70
  def format_context(self, results: List[Dict[str, Any]]) -> str: