taraky's picture
Upload folder using huggingface_hub
b7f3196 verified
from .index_bm25 import BM25Index
from .index_dense import DenseIndex
from .rrf import rrf
try:
from .rerank import CrossEncoderReranker
except Exception:
CrossEncoderReranker = None
from .ingest import load_jsonl
class Retriever:
def __init__(self, corpora_config, use_reranker=False, embedding_model=None):
self.corpora = {}
docs_all = []
for name, cfg in corpora_config.items():
docs = load_jsonl(cfg["path"], tuple(cfg.get("text_fields", ("question","answer"))))
self.corpora[name] = docs
docs_all.extend(docs)
self.bm25 = BM25Index(docs_all)
self.dense = DenseIndex(docs_all, embedding_model=embedding_model)
self.reranker = CrossEncoderReranker() if (use_reranker and CrossEncoderReranker) else None
def retrieve(self, query, k=10, for_ui=True):
bm = self.bm25.search(query, k=100)
de = self.dense.search(query, k=100)
fused = rrf([bm, de], k=max(k, 20))
if self.reranker:
reranked = self.reranker.rerank(query, [d for d, _ in fused])[:k]
results = [(d, float(s)) for d, s in reranked]
else:
results = fused[:k]
if not for_ui:
return results
return [{
"id": d.id,
"title": d.title,
"snippet": d.text[:300] + ("..." if len(d.text) > 300 else ""),
"score": s,
"meta": d.meta
} for d, s in results]
def get_index_progress(self):
"""Returns (current, total) from dense index."""
return self.dense.get_progress()