taraky's picture
Upload folder using huggingface_hub
b7f3196 verified
from typing import Dict, List
from retriever import Retriever
from retriever.rrf import rrf
from team.interfaces import Candidate
from pathlib import Path
def _default_corpora_config() -> Dict[str, dict]:
return {
"medical_qa": {"path":"data/corpora/medical_qa.jsonl",
"text_fields":["question","answer","title"]},
"miriad": {"path":"data/corpora/miriad_text.jsonl",
"text_fields":["question","answer","title"]},
"pubmed": {"path":"data/corpora/pubmed.jsonl",
"text_fields":["contents","title"]},
"unidoc": {"path":"data/corpora/unidoc_qa.jsonl",
"text_fields":["question","answer","title"]},
}
def _available(cfg: Dict[str, dict]) -> Dict[str, dict]:
return {k:v for k,v in cfg.items() if Path(v["path"]).exists()}
def get_candidates(
query: str,
retriever: Retriever,
k_retrieve: int = 50,
) -> List[Candidate]:
"""
Returns top-N fused candidates with component scores (bm25, dense, rrf).
"""
r = retriever
# get separate result lists (doc, score)
bm = r.bm25.search(query, k=max(k_retrieve, 100))
de = r.dense.search(query, k=max(k_retrieve, 100))
# maps for score lookup
bm_map = {d.id: float(s) for d, s in bm}
de_map = {d.id: float(s) for d, s in de}
# fuse and pick candidate set
fused = rrf([bm, de], k=max(k_retrieve, 50))
# compute RRF per candidate using rank positions
K = 60
bm_rank = {d.id:i for i,(d,_) in enumerate(bm)}
de_rank = {d.id:i for i,(d,_) in enumerate(de)}
out: List[Candidate] = []
for doc, _ in fused[:k_retrieve]:
rrf_score = 0.0
if doc.id in bm_rank:
rrf_score += 1.0 / (K + bm_rank[doc.id] + 1)
if doc.id in de_rank:
rrf_score += 1.0 / (K + de_rank[doc.id] + 1)
out.append(Candidate(
id=doc.id,
title=doc.title or "",
text=doc.text,
meta=doc.meta or {},
bm25=bm_map.get(doc.id, 0.0),
dense=de_map.get(doc.id, 0.0),
rrf=rrf_score,
))
# baseline order: RRF
out.sort(key=lambda c: c.rrf, reverse=True)
return out
#how to call/run below for everyone
# from team.candidates import get_candidates
# q = "worst headache of my life with fever and stiff neck"
# cands = get_candidates(q, k_retrieve=60) # returns List[Candidate]
# for c in cands[:3]:
# print(c.id, c.bm25, c.dense, c.rrf, c.title)