Spaces:
Sleeping
Sleeping
File size: 1,652 Bytes
b7f3196 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from .index_bm25 import BM25Index
from .index_dense import DenseIndex
from .rrf import rrf
try:
from .rerank import CrossEncoderReranker
except Exception:
CrossEncoderReranker = None
from .ingest import load_jsonl
class Retriever:
def __init__(self, corpora_config, use_reranker=False, embedding_model=None):
self.corpora = {}
docs_all = []
for name, cfg in corpora_config.items():
docs = load_jsonl(cfg["path"], tuple(cfg.get("text_fields", ("question","answer"))))
self.corpora[name] = docs
docs_all.extend(docs)
self.bm25 = BM25Index(docs_all)
self.dense = DenseIndex(docs_all, embedding_model=embedding_model)
self.reranker = CrossEncoderReranker() if (use_reranker and CrossEncoderReranker) else None
def retrieve(self, query, k=10, for_ui=True):
bm = self.bm25.search(query, k=100)
de = self.dense.search(query, k=100)
fused = rrf([bm, de], k=max(k, 20))
if self.reranker:
reranked = self.reranker.rerank(query, [d for d, _ in fused])[:k]
results = [(d, float(s)) for d, s in reranked]
else:
results = fused[:k]
if not for_ui:
return results
return [{
"id": d.id,
"title": d.title,
"snippet": d.text[:300] + ("..." if len(d.text) > 300 else ""),
"score": s,
"meta": d.meta
} for d, s in results]
def get_index_progress(self):
"""Returns (current, total) from dense index."""
return self.dense.get_progress()
|