feat: implement matching router and stage 2 neural reranking logic
Browse files
backend/src/matching/stage2.py
CHANGED
|
@@ -104,7 +104,7 @@ async def stage2_rerank(jd: dict, shortlist: list[dict]) -> list[dict]:
|
|
| 104 |
passages.append(" ".join(parts) or "No profile text")
|
| 105 |
|
| 106 |
from fastapi.concurrency import run_in_threadpool
|
| 107 |
-
from .reranker import rerank
|
| 108 |
try:
|
| 109 |
logger.info(f"[Stage 2] Starting neural reranking of {len(passages)} candidates...")
|
| 110 |
reranker_scores = await run_in_threadpool(rerank, jd_query, passages)
|
|
|
|
| 104 |
passages.append(" ".join(parts) or "No profile text")
|
| 105 |
|
| 106 |
from fastapi.concurrency import run_in_threadpool
|
| 107 |
+
from ..ml.reranker import rerank
|
| 108 |
try:
|
| 109 |
logger.info(f"[Stage 2] Starting neural reranking of {len(passages)} candidates...")
|
| 110 |
reranker_scores = await run_in_threadpool(rerank, jd_query, passages)
|
backend/src/routers/matching.py
CHANGED
|
@@ -81,19 +81,24 @@ async def trigger_match(
|
|
| 81 |
stage2_top_k: int = Query(100, description="How many Stage 1 candidates to pass to the neural reranker (Stage 2)"),
|
| 82 |
db: AsyncSession = Depends(get_db),
|
| 83 |
):
|
| 84 |
-
jd = await _load_jd(jd_id, db)
|
| 85 |
-
qdrant = _get_qdrant(request)
|
| 86 |
-
jd_dict = _build_jd_dict(jd)
|
| 87 |
-
sid_str = str(session_id) if session_id else None
|
| 88 |
-
|
| 89 |
-
# Stage 1: Retrieve top-K from vector DB using composite weighted score
|
| 90 |
-
shortlist = await stage1_retrieve(jd_dict, db, qdrant, session_id=sid_str, top_k=stage1_top_k)
|
| 91 |
-
|
| 92 |
-
# Stage 2: Run neural cross-encoder reranker on only the top stage2_top_k from Stage 1
|
| 93 |
-
rerank_input = shortlist[:stage2_top_k]
|
| 94 |
-
final_ranked = await stage2_rerank(jd_dict, rerank_input)
|
| 95 |
-
|
| 96 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
await db.execute(
|
| 98 |
delete(MatchResult).where(
|
| 99 |
MatchResult.jd_id == jd_id,
|
|
@@ -118,14 +123,17 @@ async def trigger_match(
|
|
| 118 |
inserted_mrs.append(mr)
|
| 119 |
|
| 120 |
await db.commit()
|
|
|
|
|
|
|
| 121 |
except Exception as exc:
|
| 122 |
-
|
|
|
|
| 123 |
await db.rollback()
|
| 124 |
raise
|
| 125 |
|
| 126 |
from ..workers.explain import generate_top_explanations
|
| 127 |
|
| 128 |
-
# Pre-generate LLM explanations async for
|
| 129 |
top_20_ids = [str(mr.id) for mr in inserted_mrs[:20]]
|
| 130 |
if top_20_ids:
|
| 131 |
generate_top_explanations.delay(top_20_ids)
|
|
|
|
| 81 |
stage2_top_k: int = Query(100, description="How many Stage 1 candidates to pass to the neural reranker (Stage 2)"),
|
| 82 |
db: AsyncSession = Depends(get_db),
|
| 83 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
try:
|
| 85 |
+
jd = await _load_jd(jd_id, db)
|
| 86 |
+
qdrant = _get_qdrant(request)
|
| 87 |
+
jd_dict = _build_jd_dict(jd)
|
| 88 |
+
sid_str = str(session_id) if session_id else None
|
| 89 |
+
|
| 90 |
+
# Stage 1: Retrieve top-K from vector DB
|
| 91 |
+
logger.info(f"[trigger_match] JD={jd_id} | Stage 1 starting (top_k={stage1_top_k})")
|
| 92 |
+
shortlist = await stage1_retrieve(jd_dict, db, qdrant, session_id=sid_str, top_k=stage1_top_k)
|
| 93 |
+
logger.info(f"[trigger_match] JD={jd_id} | Stage 1 complete — {len(shortlist)} candidates retrieved")
|
| 94 |
+
|
| 95 |
+
# Stage 2: Neural cross-encoder reranker
|
| 96 |
+
rerank_input = shortlist[:stage2_top_k]
|
| 97 |
+
logger.info(f"[trigger_match] JD={jd_id} | Stage 2 starting (reranking {len(rerank_input)} candidates)")
|
| 98 |
+
final_ranked = await stage2_rerank(jd_dict, rerank_input)
|
| 99 |
+
logger.info(f"[trigger_match] JD={jd_id} | Stage 2 complete — {len(final_ranked)} candidates ranked")
|
| 100 |
+
|
| 101 |
+
# Persist results to DB
|
| 102 |
await db.execute(
|
| 103 |
delete(MatchResult).where(
|
| 104 |
MatchResult.jd_id == jd_id,
|
|
|
|
| 123 |
inserted_mrs.append(mr)
|
| 124 |
|
| 125 |
await db.commit()
|
| 126 |
+
logger.info(f"[trigger_match] JD={jd_id} | {len(inserted_mrs)} match results saved to DB")
|
| 127 |
+
|
| 128 |
except Exception as exc:
|
| 129 |
+
# Log the FULL traceback so it appears in HF container logs
|
| 130 |
+
logger.exception(f"[trigger_match] FATAL — JD={jd_id} session={session_id} | {type(exc).__name__}: {exc}")
|
| 131 |
await db.rollback()
|
| 132 |
raise
|
| 133 |
|
| 134 |
from ..workers.explain import generate_top_explanations
|
| 135 |
|
| 136 |
+
# Pre-generate LLM explanations async for top 20 in background
|
| 137 |
top_20_ids = [str(mr.id) for mr in inserted_mrs[:20]]
|
| 138 |
if top_20_ids:
|
| 139 |
generate_top_explanations.delay(top_20_ids)
|