|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
import pandas as pd |
|
|
|
|
|
from schemas.candidates import Candidate, CandidateSet |
|
|
|
|
|
|
|
|
def retrieve_candidates(
    plan,
    bm25,
    vector,
    topn: int = 200,
    catalog_df: pd.DataFrame | None = None,
    *,
    k_rrf: float = 60.0,
) -> CandidateSet:
    """Retrieve the union of BM25 and vector hits, fused with weighted RRF.

    Both retrievers are queried with their plan query, the hits are merged
    (BM25 order first, then vector-only hits) and every id is scored as
    ``w_bm25 / (k_rrf + bm25_rank) + w_vec / (k_rrf + vector_rank)``, where a
    retriever that did not return the id contributes ``0.0``.

    Args:
        plan: Query plan. Reads ``bm25_query``, ``vec_query``,
            ``rerank_query``, ``must_have_skills`` and ``soft_skills``.
        bm25: Retriever called as
            ``recommend(plan.bm25_query, k=topn, return_scores=True)``.
        vector: Retriever called as
            ``recommend(plan.vec_query, k=topn, return_scores=True)``.
        topn: Number of hits requested from each retriever and the maximum
            number of fused candidates returned.
        catalog_df: Currently unused by this function.
        k_rrf: Constant added to every rank in the RRF denominator. The
            default is the value that used to be hard-coded here.

    Returns:
        A ``CandidateSet`` whose ``candidates`` are ordered by descending
        fused score with ``hybrid_rank`` numbered from 1, together with the
        raw id lists of both retrievers and the fusion parameters used.
    """

    def parse_hits(res) -> Tuple[List[str], Dict[str, float], Dict[str, int]]:
        """Split retriever output into ordered ids, id -> score, id -> rank.

        Accepted hit shapes: a dict with an ``"assessment_id"`` key (and an
        optional ``"score"``), a tuple/list of at least ``(id, score)``, or a
        bare id string. Anything else is ignored.
        """
        ids: List[str] = []
        scores: Dict[str, float] = {}
        ranks: Dict[str, int] = {}
        if res is None:
            # A retriever with nothing to return is treated as "no hits".
            return ids, scores, ranks
        for hit in res:
            if isinstance(hit, dict) and "assessment_id" in hit:
                aid, raw_score = hit["assessment_id"], hit.get("score")
            elif isinstance(hit, (tuple, list)) and len(hit) >= 2:
                aid, raw_score = hit[0], hit[1]
            elif isinstance(hit, str):
                aid, raw_score = hit, None
            else:
                continue
            ids.append(aid)
            if aid in ranks:
                # Duplicate id: keep the rank and score of its best-placed
                # (first) occurrence instead of letting a later one win.
                continue
            ranks[aid] = len(ids)  # 1-based position in the raw list
            if raw_score is None:
                continue
            try:
                scores[aid] = float(raw_score)
            except (TypeError, ValueError, OverflowError):
                # Scores are best effort; a non-numeric score is dropped
                # while the id keeps its rank.
                pass
        return ids, scores, ranks

    def _choose_fusion_weights(plan, raw_query: str) -> Tuple[float, float]:
        """Return normalised ``(w_bm25, w_vec)`` weights for this plan."""
        q = (plan.rerank_query or raw_query or "").strip()
        n_words = len(q.split())
        n_skills = len(plan.must_have_skills or [])
        n_soft = len(plan.soft_skills or [])

        w_b, w_v = 0.5, 0.5
        if n_skills >= 2:
            # Several must-have skills: shift weight towards BM25.
            w_b += 0.2
            w_v -= 0.2
        if n_words >= 18:
            # Long query text: shift weight towards the vector retriever.
            w_v += 0.2
            w_b -= 0.2
        if n_soft >= 2:
            # Several soft skills: shift a little more towards vector.
            w_v += 0.1
            w_b -= 0.1

        # Clamp each weight to [0.1, 0.9], then renormalise to sum to 1.
        w_b = max(0.1, min(0.9, w_b))
        w_v = max(0.1, min(0.9, w_v))
        total = w_b + w_v
        return w_b / total, w_v / total

    bm25_res = bm25.recommend(plan.bm25_query, k=topn, return_scores=True)
    vec_res = vector.recommend(plan.vec_query, k=topn, return_scores=True)

    bm25_ids, bm25_scores, bm25_pos = parse_hits(bm25_res)
    vec_ids, vec_scores, vec_pos = parse_hits(vec_res)

    # Order-preserving union: BM25 hits first, then vector-only hits.
    union_ids: List[str] = list(dict.fromkeys([*bm25_ids, *vec_ids]))

    w_bm25, w_vec = _choose_fusion_weights(plan, raw_query=plan.bm25_query)
    k_rrf = float(k_rrf)

    scored: List[Tuple[float, Candidate]] = []
    for aid in union_ids:
        rb = bm25_pos.get(aid)
        rv = vec_pos.get(aid)
        rrf_b = w_bm25 / (k_rrf + rb) if rb is not None else 0.0
        rrf_v = w_vec / (k_rrf + rv) if rv is not None else 0.0
        fused = rrf_b + rrf_v
        cand = Candidate(
            assessment_id=aid,
            source="union",
            bm25_rank=rb,
            vector_rank=rv,
            hybrid_rank=None,
            bm25_score=bm25_scores.get(aid),
            vector_score=vec_scores.get(aid),
            score=fused,
        )
        scored.append((fused, cand))

    # list.sort is stable even with reverse=True, so ids with equal fused
    # scores keep their union order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    candidates: List[Candidate] = []
    for rank, (_, cand) in enumerate(scored[:topn], start=1):
        cand.hybrid_rank = rank
        candidates.append(cand)

    return CandidateSet(
        candidates=candidates,
        raw_bm25=bm25_ids,
        raw_vector=vec_ids,
        fusion={"w_bm25": w_bm25, "w_vec": w_vec, "k_rrf": k_rrf},
    )
|
|
|