|
|
from typing import Dict, List
|
|
|
from retriever import Retriever
|
|
|
from retriever.rrf import rrf
|
|
|
from team.interfaces import Candidate
|
|
|
from pathlib import Path
|
|
|
|
|
|
def _default_corpora_config() -> Dict[str, dict]:
    """Return the built-in corpus registry.

    Each key is a corpus name; each value is a dict holding the JSONL
    ``path`` of that corpus and the ``text_fields`` listed for it.
    A fresh dict (with fresh lists) is built on every call.
    """
    qa_fields = ("question", "answer", "title")
    # (corpus name, JSONL path, text fields), in registry order.
    specs = (
        ("medical_qa", "data/corpora/medical_qa.jsonl", qa_fields),
        ("miriad", "data/corpora/miriad_text.jsonl", qa_fields),
        ("pubmed", "data/corpora/pubmed.jsonl", ("contents", "title")),
        ("unidoc", "data/corpora/unidoc_qa.jsonl", qa_fields),
    )
    return {
        name: {"path": path, "text_fields": list(fields)}
        for name, path, fields in specs
    }
|
|
|
|
|
|
def _available(cfg: Dict[str, dict]) -> Dict[str, dict]:
    """Filter *cfg* down to the entries whose ``path`` exists on disk.

    Entries keep their original order, and the value dicts are reused
    as-is (not copied).
    """
    present: Dict[str, dict] = {}
    for name, spec in cfg.items():
        if not Path(spec["path"]).exists():
            continue
        present[name] = spec
    return present
|
|
|
|
|
|
def get_candidates(
    query: str,
    retriever: Retriever,
    k_retrieve: int = 50,
    *,
    rrf_k: int = 60,
) -> List[Candidate]:
    """Return up to ``k_retrieve`` fused candidates with component scores.

    Runs the BM25 and dense searches of *retriever* for *query*, fuses the
    two result lists with ``rrf``, and wraps the leading fused documents in
    ``Candidate`` objects carrying their ``bm25``, ``dense`` and ``rrf``
    scores.

    Args:
        query: Query string passed to both searches.
        retriever: Object exposing ``bm25.search(query, k=...)`` and
            ``dense.search(query, k=...)``. Each result is iterated several
            times as ``(doc, score)`` pairs, so it must be re-iterable
            (e.g. a list).
        k_retrieve: Maximum number of candidates to return. Values ``<= 0``
            yield an empty list.
        rrf_k: Smoothing constant of the reciprocal-rank term
            ``1 / (rrf_k + rank)`` (rank is 1-based) used for the ``rrf``
            score stored on each candidate. Defaults to 60.

    Returns:
        Candidates sorted by descending ``rrf`` score. A document missing
        from one result list gets ``0.0`` for that list's score and no
        reciprocal-rank contribution from it.
    """
    if k_retrieve <= 0:
        # A negative value would otherwise slice from the tail of the fused
        # list (``fused[:-n]``) and return an arbitrary number of results.
        return []

    # Both searches go at least 100 deep, even when fewer candidates are
    # requested.
    depth = max(k_retrieve, 100)
    bm_hits = retriever.bm25.search(query, k=depth)
    de_hits = retriever.dense.search(query, k=depth)

    # Raw component scores keyed by document id.
    bm_score = {doc.id: float(score) for doc, score in bm_hits}
    de_score = {doc.id: float(score) for doc, score in de_hits}

    # NOTE(review): what ``k`` means to ``rrf`` (rank constant vs. result
    # cutoff) is defined in retriever.rrf and not visible here -- confirm.
    # The value passed is unchanged from the previous implementation.
    fused = rrf([bm_hits, de_hits], k=max(k_retrieve, 50))

    # 0-based position of each document in its own result list.
    bm_rank = {doc.id: pos for pos, (doc, _) in enumerate(bm_hits)}
    de_rank = {doc.id: pos for pos, (doc, _) in enumerate(de_hits)}

    out: List[Candidate] = []
    # The score returned by ``rrf`` is ignored; only its ordering decides
    # which documents are kept. The stored score is recomputed below.
    for doc, _ in fused[:k_retrieve]:
        rrf_score = 0.0
        for ranks in (bm_rank, de_rank):
            pos = ranks.get(doc.id)
            if pos is not None:
                # +1 converts the 0-based position to a 1-based rank.
                rrf_score += 1.0 / (rrf_k + pos + 1)
        out.append(Candidate(
            id=doc.id,
            title=doc.title or "",
            text=doc.text,
            meta=doc.meta or {},
            bm25=bm_score.get(doc.id, 0.0),
            dense=de_score.get(doc.id, 0.0),
            rrf=rrf_score,
        ))

    out.sort(key=lambda c: c.rrf, reverse=True)
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|