File size: 1,791 Bytes
8cadb90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
retriever.py
BM25-based retrieval over the support corpus.
No LLM needed — pure keyword matching, fast and reliable.
"""

import re
from rank_bm25 import BM25Okapi


def _tokenize(text: str) -> list[str]:
    return re.findall(r"[a-z0-9]+", text.lower())


class DomainRetriever:
    """BM25 keyword retriever over the documents of a single domain.

    Every document is a dict that must provide ``"title"`` and ``"text"``
    entries; the two are tokenized together to build the index.
    """

    # Maximum number of characters of a document's "text" copied into the
    # "snippet" field of a hit.
    SNIPPET_CHARS = 1000

    def __init__(self, docs: list[dict]):
        """Tokenize *docs* and build the BM25 index.

        ``self.bm25`` is left as ``None`` when there is nothing to index:
        either *docs* is empty or no document produces a single token.
        """
        self.docs = docs
        tokenized = [_tokenize(f"{d['title']} {d['text']}") for d in docs]
        # BM25Okapi averages over the corpus and over its vocabulary while
        # building its statistics, so it raises ZeroDivisionError when either
        # is empty. ``any`` covers both the empty list and the case where
        # every document tokenizes to nothing (e.g. text with no ASCII
        # letters or digits).
        self.bm25 = BM25Okapi(tokenized) if any(tokenized) else None

    def retrieve(self, query: str, top_k: int = 4) -> list[dict]:
        """Return up to *top_k* documents matching *query*, best score first.

        Only documents with a strictly positive BM25 score are returned. Each
        hit is a shallow copy of the stored document with two added keys:
        ``"score"`` (the score rounded to 3 decimals) and ``"snippet"`` (the
        first ``SNIPPET_CHARS`` characters of ``"text"``).

        Returns an empty list when nothing is indexed or when *top_k* is not
        positive.
        """
        # A negative top_k would otherwise slice ``ranked[:-k]`` and return
        # nearly the whole ranking instead of nothing.
        if top_k <= 0 or not self.docs or self.bm25 is None:
            return []
        scores = self.bm25.get_scores(_tokenize(query))
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        results = []
        for idx in ranked[:top_k]:
            score = float(scores[idx])
            if score > 0:
                # Copy so the caller can mutate a hit without touching the corpus.
                doc = dict(self.docs[idx])
                doc["score"] = round(score, 3)
                doc["snippet"] = doc["text"][: self.SNIPPET_CHARS]
                results.append(doc)
        return results


class MultiDomainRetriever:
    """Holds one DomainRetriever per domain and dispatches queries to them."""

    def __init__(self, corpus: dict[str, list[dict]]):
        """Create a retriever for each ``domain -> documents`` entry of *corpus*."""
        by_domain = {}
        for name, domain_docs in corpus.items():
            by_domain[name] = DomainRetriever(domain_docs)
        self.retrievers = by_domain

    def retrieve_for_domain(self, domain: str, query: str, top_k: int = 4) -> list[dict]:
        """Query a single domain; an unknown *domain* yields an empty list."""
        retriever = self.retrievers.get(domain)
        if not retriever:
            return []
        return retriever.retrieve(query, top_k)

    def retrieve_all(self, query: str, top_k_per_domain: int = 2) -> list[dict]:
        """Query every domain and merge the hits, highest score first.

        Each hit is tagged with the name of the domain it came from under the
        ``"domain"`` key. Hits with equal scores keep the order in which they
        were collected.
        """
        merged = []
        for name in self.retrievers:
            domain_hits = self.retrievers[name].retrieve(query, top_k_per_domain)
            for hit in domain_hits:
                hit["domain"] = name
                merged.append(hit)
        return sorted(merged, key=lambda hit: hit["score"], reverse=True)