Spaces:
Runtime error
Runtime error
File size: 1,791 Bytes
8cadb90 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | """
retriever.py
BM25-based retrieval over the support corpus.
No LLM needed — pure keyword matching, fast and reliable.
"""
import re
from rank_bm25 import BM25Okapi
def _tokenize(text: str) -> list[str]:
    """Split *text* into lowercase ASCII alphanumeric tokens.

    The text is lowercased first; each maximal run of the characters
    ``a-z`` / ``0-9`` becomes one token, and every other character acts
    as a separator.
    """
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"[a-z0-9]+", lowered)]
class DomainRetriever:
    """BM25 keyword retriever over the documents of a single domain.

    Every document is a dict that must provide ``"title"`` and ``"text"``
    entries; the two are joined and tokenized together to build the index.
    """

    # Maximum number of characters of a document's text copied into "snippet".
    SNIPPET_CHARS = 1000

    def __init__(self, docs: list[dict]):
        """Tokenize *docs* and build the BM25 index.

        Args:
            docs: Documents to index; each needs ``"title"`` and ``"text"`` keys.

        Raises:
            KeyError: If a document lacks ``"title"`` or ``"text"``.
        """
        self.docs = docs
        tokenized = [_tokenize(f"{d['title']} {d['text']}") for d in docs]
        # Only build the index when at least one document produced a token.
        # An empty corpus, or one whose documents all tokenize to nothing,
        # gives BM25Okapi nothing to average over, so it is skipped and
        # retrieve() then simply returns no hits.
        self.bm25 = BM25Okapi(tokenized) if any(tokenized) else None

    def retrieve(self, query: str, top_k: int = 4) -> list[dict]:
        """Return up to *top_k* documents matching *query*, best first.

        Args:
            query: Free-text query; tokenized the same way as the documents.
            top_k: Maximum number of hits. Values <= 0 yield an empty list.

        Returns:
            Shallow copies of the matching documents, each extended with a
            ``"score"`` entry (BM25 score rounded to 3 decimals) and a
            ``"snippet"`` entry (the first ``SNIPPET_CHARS`` characters of
            the text). Documents whose score is not positive are omitted,
            so fewer than *top_k* hits may be returned.
        """
        # A negative top_k would otherwise slice from the end of the ranking
        # and return almost every document instead of none.
        if top_k <= 0 or not self.docs or not self.bm25:
            return []

        scores = self.bm25.get_scores(_tokenize(query))
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

        results = []
        for idx in ranked[:top_k]:
            if scores[idx] > 0:
                # Copy so the annotations never leak into the stored corpus.
                doc = dict(self.docs[idx])
                doc["score"] = round(float(scores[idx]), 3)
                doc["snippet"] = doc["text"][: self.SNIPPET_CHARS]
                results.append(doc)
        return results
class MultiDomainRetriever:
    """Collection of per-domain retrievers, addressed by domain name."""

    def __init__(self, corpus: dict[str, list[dict]]):
        """Build one ``DomainRetriever`` for every domain in *corpus*.

        Args:
            corpus: Mapping from domain name to that domain's documents.
        """
        self.retrievers = {
            name: DomainRetriever(domain_docs)
            for name, domain_docs in corpus.items()
        }

    def retrieve_for_domain(self, domain: str, query: str, top_k: int = 4) -> list[dict]:
        """Search a single domain; an unknown domain yields an empty list."""
        retriever = self.retrievers.get(domain)
        if not retriever:
            return []
        return retriever.retrieve(query, top_k)

    def retrieve_all(self, query: str, top_k_per_domain: int = 2) -> list[dict]:
        """Search every domain and merge the hits, highest score first.

        Each hit gets a ``"domain"`` entry naming the domain it came from.
        Up to *top_k_per_domain* hits are requested from each domain.
        """
        merged = []
        for name, retriever in self.retrievers.items():
            hits = retriever.retrieve(query, top_k_per_domain)
            for hit in hits:
                # Tag the hit in place so callers can tell where it came from.
                hit["domain"] = name
            merged.extend(hits)
        return sorted(merged, key=lambda hit: hit["score"], reverse=True)
|