File size: 1,652 Bytes
b7f3196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from .index_bm25 import BM25Index
from .index_dense import DenseIndex
from .rrf import rrf
try:
    from .rerank import CrossEncoderReranker
except Exception:
    CrossEncoderReranker = None
from .ingest import load_jsonl

class Retriever:
    """Hybrid retriever combining BM25 and dense search.

    Documents from every configured corpus are loaded and indexed together in
    one ``BM25Index`` and one ``DenseIndex``. A query is run against both
    indexes, the two ranked lists are merged with reciprocal-rank fusion
    (``rrf``), and the fused list is optionally re-scored by a cross-encoder.
    """

    # Maximum number of characters of document text shown in a UI snippet.
    SNIPPET_CHARS = 300
    # Default number of candidates pulled from each index before fusion.
    CANDIDATES_PER_INDEX = 100

    def __init__(self, corpora_config, use_reranker=False, embedding_model=None):
        """Load every configured corpus and build the shared indexes.

        Args:
            corpora_config: Mapping of corpus name -> config dict. Each config
                must contain ``"path"`` (passed to ``load_jsonl``) and may
                contain ``"text_fields"`` (default ``("question", "answer")``).
            use_reranker: If true, re-score fused results with
                ``CrossEncoderReranker``. If that class could not be imported,
                a ``RuntimeWarning`` is emitted and reranking is disabled.
            embedding_model: Forwarded unchanged to ``DenseIndex``.
        """
        self.corpora = {}  # corpus name -> list of docs loaded for it
        docs_all = []
        for name, cfg in corpora_config.items():
            text_fields = tuple(cfg.get("text_fields", ("question", "answer")))
            docs = load_jsonl(cfg["path"], text_fields)
            self.corpora[name] = docs
            docs_all.extend(docs)
        # All corpora share a single pair of indexes.
        self.bm25 = BM25Index(docs_all)
        self.dense = DenseIndex(docs_all, embedding_model=embedding_model)

        self.reranker = None
        if use_reranker:
            if CrossEncoderReranker is None:
                # The optional import at module load failed; tell the caller
                # instead of silently ignoring the request.
                import warnings

                warnings.warn(
                    "use_reranker=True but CrossEncoderReranker could not be "
                    "imported; results will not be reranked.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                self.reranker = CrossEncoderReranker()

    def retrieve(self, query, k=10, for_ui=True, *, candidates=None):
        """Return the top-``k`` documents for ``query``.

        Args:
            query: Query passed to both indexes (and to the reranker, if any).
            k: Number of results to return. ``k <= 0`` yields ``[]``.
            for_ui: If true, return display dicts with keys ``id``, ``title``,
                ``snippet``, ``score`` and ``meta``. Otherwise return the raw
                ``(doc, score)`` pairs.
            candidates: Number of candidates to pull from each index before
                fusion. Defaults to ``CANDIDATES_PER_INDEX`` and is never
                smaller than ``k``, so large ``k`` is not starved of candidates.

        Returns:
            A list of at most ``k`` results, in the format selected by ``for_ui``.
        """
        # A negative k would make the [:k] slices below drop items from the
        # end instead of returning nothing.
        if k <= 0:
            return []
        per_index = self.CANDIDATES_PER_INDEX if candidates is None else candidates
        pool = max(k, per_index)
        bm = self.bm25.search(query, k=pool)
        de = self.dense.search(query, k=pool)
        # Fuse at least 20 candidates even for small k; presumably this gives
        # the reranker room to reorder. The plain path slices back to k.
        fused = rrf([bm, de], k=max(k, 20))
        if self.reranker:
            reranked = self.reranker.rerank(query, [d for d, _ in fused])[:k]
            results = [(d, float(s)) for d, s in reranked]
        else:
            results = fused[:k]
        if not for_ui:
            return results
        return [self._format_hit(d, s) for d, s in results]

    @classmethod
    def _snippet(cls, text):
        """Truncate ``text`` to ``SNIPPET_CHARS`` chars, appending "..." if cut."""
        if len(text) > cls.SNIPPET_CHARS:
            return text[:cls.SNIPPET_CHARS] + "..."
        return text

    def _format_hit(self, d, s):
        """Convert one ``(doc, score)`` pair into the UI result dict."""
        return {
            "id": d.id,
            "title": d.title,
            "snippet": self._snippet(d.text),
            "score": s,
            "meta": d.meta,
        }

    def get_index_progress(self):
        """Return indexing progress from the dense index as ``(current, total)``."""
        return self.dense.get_progress()