ketannnn committed on
Commit
88462d6
·
1 Parent(s): 7770c5f

feat: implement matching router and stage 2 neural reranking logic

Browse files
backend/src/matching/stage2.py CHANGED
@@ -104,7 +104,7 @@ async def stage2_rerank(jd: dict, shortlist: list[dict]) -> list[dict]:
104
  passages.append(" ".join(parts) or "No profile text")
105
 
106
  from fastapi.concurrency import run_in_threadpool
107
- from .reranker import rerank
108
  try:
109
  logger.info(f"[Stage 2] Starting neural reranking of {len(passages)} candidates...")
110
  reranker_scores = await run_in_threadpool(rerank, jd_query, passages)
 
104
  passages.append(" ".join(parts) or "No profile text")
105
 
106
  from fastapi.concurrency import run_in_threadpool
107
+ from ..ml.reranker import rerank
108
  try:
109
  logger.info(f"[Stage 2] Starting neural reranking of {len(passages)} candidates...")
110
  reranker_scores = await run_in_threadpool(rerank, jd_query, passages)
backend/src/routers/matching.py CHANGED
@@ -81,19 +81,24 @@ async def trigger_match(
81
  stage2_top_k: int = Query(100, description="How many Stage 1 candidates to pass to the neural reranker (Stage 2)"),
82
  db: AsyncSession = Depends(get_db),
83
  ):
84
- jd = await _load_jd(jd_id, db)
85
- qdrant = _get_qdrant(request)
86
- jd_dict = _build_jd_dict(jd)
87
- sid_str = str(session_id) if session_id else None
88
-
89
- # Stage 1: Retrieve top-K from vector DB using composite weighted score
90
- shortlist = await stage1_retrieve(jd_dict, db, qdrant, session_id=sid_str, top_k=stage1_top_k)
91
-
92
- # Stage 2: Run neural cross-encoder reranker on only the top stage2_top_k from Stage 1
93
- rerank_input = shortlist[:stage2_top_k]
94
- final_ranked = await stage2_rerank(jd_dict, rerank_input)
95
-
96
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  await db.execute(
98
  delete(MatchResult).where(
99
  MatchResult.jd_id == jd_id,
@@ -118,14 +123,17 @@ async def trigger_match(
118
  inserted_mrs.append(mr)
119
 
120
  await db.commit()
 
 
121
  except Exception as exc:
122
- logger.exception(f"[trigger_match] FATAL ERROR for JD {jd_id}: {exc}")
 
123
  await db.rollback()
124
  raise
125
 
126
  from ..workers.explain import generate_top_explanations
127
 
128
- # Pre-generate LLM explanations async for the top 20 matches implicitly in background
129
  top_20_ids = [str(mr.id) for mr in inserted_mrs[:20]]
130
  if top_20_ids:
131
  generate_top_explanations.delay(top_20_ids)
 
81
  stage2_top_k: int = Query(100, description="How many Stage 1 candidates to pass to the neural reranker (Stage 2)"),
82
  db: AsyncSession = Depends(get_db),
83
  ):
 
 
 
 
 
 
 
 
 
 
 
 
84
  try:
85
+ jd = await _load_jd(jd_id, db)
86
+ qdrant = _get_qdrant(request)
87
+ jd_dict = _build_jd_dict(jd)
88
+ sid_str = str(session_id) if session_id else None
89
+
90
+ # Stage 1: Retrieve top-K from vector DB
91
+ logger.info(f"[trigger_match] JD={jd_id} | Stage 1 starting (top_k={stage1_top_k})")
92
+ shortlist = await stage1_retrieve(jd_dict, db, qdrant, session_id=sid_str, top_k=stage1_top_k)
93
+ logger.info(f"[trigger_match] JD={jd_id} | Stage 1 complete — {len(shortlist)} candidates retrieved")
94
+
95
+ # Stage 2: Neural cross-encoder reranker
96
+ rerank_input = shortlist[:stage2_top_k]
97
+ logger.info(f"[trigger_match] JD={jd_id} | Stage 2 starting (reranking {len(rerank_input)} candidates)")
98
+ final_ranked = await stage2_rerank(jd_dict, rerank_input)
99
+ logger.info(f"[trigger_match] JD={jd_id} | Stage 2 complete — {len(final_ranked)} candidates ranked")
100
+
101
+ # Persist results to DB
102
  await db.execute(
103
  delete(MatchResult).where(
104
  MatchResult.jd_id == jd_id,
 
123
  inserted_mrs.append(mr)
124
 
125
  await db.commit()
126
+ logger.info(f"[trigger_match] JD={jd_id} | {len(inserted_mrs)} match results saved to DB")
127
+
128
  except Exception as exc:
129
+ # Log the FULL traceback so it appears in HF container logs
130
+ logger.exception(f"[trigger_match] FATAL — JD={jd_id} session={session_id} | {type(exc).__name__}: {exc}")
131
  await db.rollback()
132
  raise
133
 
134
  from ..workers.explain import generate_top_explanations
135
 
136
+ # Pre-generate LLM explanations async for top 20 in background
137
  top_20_ids = [str(mr.id) for mr in inserted_mrs[:20]]
138
  if top_20_ids:
139
  generate_top_explanations.delay(top_20_ids)