feat: implement Stage 2 reranker with BGE cross-encoder and RRF fusion
Browse files- backend/src/matching/scorer.py +37 -0
- backend/src/matching/stage2.py +96 -0
backend/src/matching/scorer.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
DEFAULT_WEIGHTS = {
|
| 5 |
+
"semantic": 0.20,
|
| 6 |
+
"skill": 0.35,
|
| 7 |
+
"yoe": 0.15,
|
| 8 |
+
"company": 0.10,
|
| 9 |
+
"growth": 0.10,
|
| 10 |
+
"education": 0.10,
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def normalize_weights(weights: dict[str, float]) -> dict[str, float]:
|
| 15 |
+
total = sum(weights.values())
|
| 16 |
+
if total == 0:
|
| 17 |
+
return DEFAULT_WEIGHTS.copy()
|
| 18 |
+
return {k: v / total for k, v in weights.items()}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def rerank_with_weights(
|
| 22 |
+
match_results: list[dict[str, Any]],
|
| 23 |
+
weights: dict[str, float],
|
| 24 |
+
) -> list[dict[str, Any]]:
|
| 25 |
+
w = normalize_weights({**DEFAULT_WEIGHTS, **weights})
|
| 26 |
+
|
| 27 |
+
reranked = []
|
| 28 |
+
for item in match_results:
|
| 29 |
+
components = item.get("component_scores") or {}
|
| 30 |
+
new_score = sum(w.get(k, 0) * v for k, v in components.items())
|
| 31 |
+
reranked.append({**item, "final_score": round(new_score, 4), "weights_used": w})
|
| 32 |
+
|
| 33 |
+
reranked.sort(key=lambda x: x["final_score"], reverse=True)
|
| 34 |
+
for i, item in enumerate(reranked):
|
| 35 |
+
item["rank"] = i + 1
|
| 36 |
+
|
| 37 |
+
return reranked
|
backend/src/matching/stage2.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
from ..ml.reranker import rerank
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _compute_gaps(jd: dict, candidate: dict) -> list[dict]:
|
| 6 |
+
gaps = []
|
| 7 |
+
|
| 8 |
+
jd_skills = {s.lower().strip() for s in (jd.get("required_skills") or [])}
|
| 9 |
+
cand_skills = set()
|
| 10 |
+
for lst in [
|
| 11 |
+
candidate.get("programming_languages") or [],
|
| 12 |
+
candidate.get("backend_frameworks") or [],
|
| 13 |
+
candidate.get("frontend_technologies") or [],
|
| 14 |
+
]:
|
| 15 |
+
cand_skills.update(s.lower().strip() for s in lst if s)
|
| 16 |
+
if candidate.get("parsed_skills"):
|
| 17 |
+
cand_skills.update(s.strip().lower() for s in candidate["parsed_skills"].split(",") if s.strip())
|
| 18 |
+
|
| 19 |
+
missing_skills = jd_skills - cand_skills
|
| 20 |
+
for skill in sorted(missing_skills)[:8]:
|
| 21 |
+
gaps.append({"type": "missing_skill", "detail": skill})
|
| 22 |
+
|
| 23 |
+
min_yoe = jd.get("min_yoe")
|
| 24 |
+
cand_yoe = candidate.get("years_of_experience")
|
| 25 |
+
if min_yoe and cand_yoe is not None and float(cand_yoe) < float(min_yoe):
|
| 26 |
+
gaps.append({
|
| 27 |
+
"type": "yoe_gap",
|
| 28 |
+
"detail": f"Requires {min_yoe}+ years, candidate has {cand_yoe}",
|
| 29 |
+
})
|
| 30 |
+
|
| 31 |
+
jd_location = (jd.get("location") or "").lower()
|
| 32 |
+
cand_location = (candidate.get("open_to_working_at") or "").lower()
|
| 33 |
+
if jd_location and jd_location not in ("remote", "") and cand_location:
|
| 34 |
+
if jd_location not in cand_location and cand_location not in jd_location:
|
| 35 |
+
remote_allowed = jd.get("remote_allowed", False)
|
| 36 |
+
gaps.append({
|
| 37 |
+
"type": "location_mismatch",
|
| 38 |
+
"detail": f"JD is in {jd.get('location')}, candidate is open to {candidate.get('open_to_working_at')}",
|
| 39 |
+
"mitigated_by_remote": bool(remote_allowed),
|
| 40 |
+
})
|
| 41 |
+
|
| 42 |
+
jd_engineer_type = (jd.get("engineer_type") or "").lower()
|
| 43 |
+
cand_engineer_type = (candidate.get("engineer_type") or "").lower()
|
| 44 |
+
if jd_engineer_type and cand_engineer_type and jd_engineer_type not in cand_engineer_type and cand_engineer_type not in jd_engineer_type:
|
| 45 |
+
gaps.append({
|
| 46 |
+
"type": "engineer_type_mismatch",
|
| 47 |
+
"detail": f"JD needs {jd.get('engineer_type')}, candidate is {candidate.get('engineer_type')}",
|
| 48 |
+
})
|
| 49 |
+
|
| 50 |
+
return gaps
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _reciprocal_rank_fusion(stage1_scored: list[dict], reranker_scores: list[float], k: int = 60) -> list[dict]:
|
| 54 |
+
stage1_ranks = {item["candidate_id"]: i + 1 for i, item in enumerate(stage1_scored)}
|
| 55 |
+
reranker_ranks = {}
|
| 56 |
+
reranker_order = sorted(range(len(reranker_scores)), key=lambda i: reranker_scores[i], reverse=True)
|
| 57 |
+
for rank, idx in enumerate(reranker_order):
|
| 58 |
+
cid = stage1_scored[idx]["candidate_id"]
|
| 59 |
+
reranker_ranks[cid] = rank + 1
|
| 60 |
+
|
| 61 |
+
results = []
|
| 62 |
+
for item in stage1_scored:
|
| 63 |
+
cid = item["candidate_id"]
|
| 64 |
+
rrf_score = 1.0 / (k + stage1_ranks.get(cid, k)) + 1.0 / (k + reranker_ranks.get(cid, k))
|
| 65 |
+
results.append({**item, "stage2_score": round(reranker_scores[stage1_scored.index(item)], 4), "final_score": round(rrf_score, 6)})
|
| 66 |
+
|
| 67 |
+
results.sort(key=lambda x: x["final_score"], reverse=True)
|
| 68 |
+
return results
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
async def stage2_rerank(jd: dict, shortlist: list[dict]) -> list[dict]:
|
| 72 |
+
if not shortlist:
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
jd_query = f"{jd.get('title', '')} {jd.get('raw_text', '')}"
|
| 76 |
+
|
| 77 |
+
passages = []
|
| 78 |
+
for cand in shortlist:
|
| 79 |
+
parts = []
|
| 80 |
+
if cand.get("parsed_summary"):
|
| 81 |
+
parts.append(cand["parsed_summary"])
|
| 82 |
+
if cand.get("parsed_skills"):
|
| 83 |
+
parts.append(f"Skills: {cand['parsed_skills']}")
|
| 84 |
+
langs = cand.get("programming_languages") or []
|
| 85 |
+
if langs:
|
| 86 |
+
parts.append(f"Languages: {', '.join(langs[:10])}")
|
| 87 |
+
passages.append(" ".join(parts) or "No profile text")
|
| 88 |
+
|
| 89 |
+
reranker_scores = rerank(jd_query, passages)
|
| 90 |
+
|
| 91 |
+
results = _reciprocal_rank_fusion(shortlist, reranker_scores)
|
| 92 |
+
|
| 93 |
+
for cand in results:
|
| 94 |
+
cand["gaps"] = _compute_gaps(jd, cand)
|
| 95 |
+
|
| 96 |
+
return results[:20]
|