# Hybrid retriever: blends FAISS dense search with BM25+ sparse scoring.
from __future__ import annotations
import re
from langchain_core.documents import Document
from rank_bm25 import BM25Plus
from memory_agent.models import KnowledgeRecord
from memory_agent.storage import FaissMemoryStore
class HybridRetriever:
    """Blend dense (vector) and sparse (BM25+) retrieval over a memory store.

    Each channel's scores are max-normalized, weighted, and summed per
    record id; the top-k blended results are returned.
    """

    def __init__(
        self,
        store: FaissMemoryStore,
        dense_weight: float = 0.55,
        sparse_weight: float = 0.45,
    ) -> None:
        # A zero or negative weight would silently disable a channel,
        # so reject such configurations up front.
        if dense_weight <= 0 or sparse_weight <= 0:
            raise ValueError("Dense and sparse weights must be positive.")
        self._store = store
        self._dense_weight = dense_weight
        self._sparse_weight = sparse_weight

    def retrieve(self, namespace: str, query: str, k: int = 6) -> list[tuple[Document, float]]:
        """Return up to *k* ``(document, blended_score)`` pairs for *query*."""
        records = self._store.fetch_records(namespace=namespace, limit=2000)
        if not records:
            return []
        # Over-fetch from each channel so their union has enough candidates.
        candidate_count = max(k, 8)
        dense_hits = self._store.dense_search(namespace=namespace, query=query, k=candidate_count)
        sparse_hits = self._bm25_search(records=records, query=query, k=candidate_count)
        # Max-normalize per channel so the configured weights are comparable.
        dense_top = max((hit_score for _, hit_score in dense_hits), default=1.0)
        sparse_top = max((hit_score for _, hit_score in sparse_hits), default=1.0)
        blended: dict[str, tuple[Document, float]] = {}
        for doc, raw in dense_hits:
            key = self._record_id(document=doc)
            scaled = raw / dense_top if dense_top else 0.0
            blended[key] = (doc, self._dense_weight * scaled)
        for doc, raw in sparse_hits:
            key = self._record_id(document=doc)
            scaled = raw / sparse_top if sparse_top else 0.0
            # Keep the document seen first; accumulate the sparse contribution.
            prior_doc, prior_score = blended.get(key, (doc, 0.0))
            blended[key] = (prior_doc, prior_score + self._sparse_weight * scaled)
        ordered = sorted(blended.values(), key=lambda pair: pair[1], reverse=True)
        return ordered[:k]

    def _bm25_search(
        self,
        records: list[KnowledgeRecord],
        query: str,
        k: int,
    ) -> list[tuple[Document, float]]:
        """Score *records* against *query* with BM25+ and return the top *k*."""
        # The "_" placeholder keeps BM25Plus valid when a text (or the query)
        # tokenizes to nothing.
        corpus = [self._tokenize(record.content) or ["_"] for record in records]
        scorer = BM25Plus(corpus)
        scores = scorer.get_scores(self._tokenize(query) or ["_"])
        best = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)[:k]
        return [
            (self._record_to_document(record=records[position]), float(value))
            for position, value in best
        ]

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Lowercase *text* and split it into alphanumeric/underscore tokens."""
        return re.findall(r"[a-zA-Z0-9_]+", text.lower())

    @staticmethod
    def _record_to_document(record: KnowledgeRecord) -> Document:
        """Wrap a KnowledgeRecord as a LangChain Document carrying its metadata."""
        return Document(
            page_content=record.content,
            metadata={
                "record_id": record.record_id,
                "namespace": record.namespace,
                "fact_key": record.fact_key,
                "fact_value": record.fact_value,
                "updated_at": record.updated_at.isoformat(),
            },
        )

    @staticmethod
    def _record_id(document: Document) -> str:
        """Read a document's ``record_id`` metadata entry as a string."""
        return str(document.metadata.get("record_id"))