File size: 848 Bytes
5a3b322 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from __future__ import annotations
from typing import List
import pandas as pd
from schemas.candidates import RankedItem, RankedList, CandidateSet
def rerank_candidates(
    plan,
    candidate_set: CandidateSet,
    reranker,
    catalog_df: pd.DataFrame,
    k: int = 10,
) -> RankedList:
    """Rerank the candidate union with a cross-encoder and keep the top ``k``.

    Each candidate whose ``assessment_id`` appears in the catalog is scored
    with ``reranker.score(plan.rerank_query, doc_text)``; candidates that are
    absent from the catalog are skipped. Results are ordered by descending
    score (ties keep their original candidate order) and truncated to ``k``.

    Args:
        plan: Object exposing a ``rerank_query`` attribute used as the query.
        candidate_set: Provides ``candidates``, each with an ``assessment_id``.
        reranker: Object exposing ``score(query, doc_text)``; the result is
            converted with ``float``.
        catalog_df: Catalog with an ``assessment_id`` column and, optionally,
            a ``doc_text`` column. When ``doc_text`` is absent or the cell is
            missing, an empty string is scored instead.
        k: Maximum number of items to return. Values below zero are treated
            as zero.

    Returns:
        A ``RankedList`` holding at most ``k`` ``RankedItem`` entries.

    Raises:
        KeyError: If ``catalog_df`` has no ``assessment_id`` column.
    """
    catalog = catalog_df.set_index("assessment_id")
    # A repeated assessment_id would make a label lookup return several rows
    # (and hence a Series instead of a single text); keep the first row per id.
    catalog = catalog[~catalog.index.duplicated(keep="first")]
    # Hoisted out of the loop: only this one column is ever read.
    doc_texts = catalog["doc_text"] if "doc_text" in catalog.columns else None
    query = plan.rerank_query

    scored: List[RankedItem] = []
    for cand in candidate_set.candidates:
        assessment_id = cand.assessment_id
        if assessment_id not in catalog.index:
            continue
        doc_text = "" if doc_texts is None else doc_texts.at[assessment_id]
        # Missing cells surface as NaN/None/NA; score them as empty text
        # rather than handing a non-string sentinel to the reranker.
        if pd.api.types.is_scalar(doc_text) and pd.isna(doc_text):
            doc_text = ""
        score = reranker.score(query, doc_text)
        scored.append(RankedItem(assessment_id=assessment_id, score=float(score)))

    scored.sort(key=lambda item: item.score, reverse=True)
    # Clamp so a negative k yields an empty list instead of "all but the last |k|".
    return RankedList(items=scored[: max(k, 0)])
|