"""Hybrid retrieval that blends dense-search scores with BM25 lexical scores."""

from __future__ import annotations

import re

from langchain_core.documents import Document
from rank_bm25 import BM25Plus

from memory_agent.models import KnowledgeRecord
from memory_agent.storage import FaissMemoryStore

# Upper bound on how many records are pulled from the store for BM25 scoring.
_RECORD_FETCH_LIMIT = 2000

# Each retriever is asked for at least this many candidates, even for small k,
# so the merge step has a reasonable pool to rank from.
_MIN_CANDIDATES = 8

# Unicode-aware word pattern. For ASCII text it yields exactly the same tokens
# as ``[a-zA-Z0-9_]+``; it additionally keeps non-ASCII letters and digits
# instead of dropping them.
_TOKEN_PATTERN = re.compile(r"\w+")


class HybridRetriever:
    """Rank a namespace's records by a weighted sum of dense and BM25 scores.

    Dense candidates come from ``store.dense_search``; sparse candidates come
    from a ``BM25Plus`` index built over the records returned by
    ``store.fetch_records``. Each list of scores is divided by its own maximum
    before being weighted and summed per record id.
    """

    def __init__(
        self,
        store: FaissMemoryStore,
        dense_weight: float = 0.55,
        sparse_weight: float = 0.45,
    ) -> None:
        """Store the backing store and the two blending weights.

        Args:
            store: Source of records and of dense search results.
            dense_weight: Multiplier applied to normalised dense scores.
            sparse_weight: Multiplier applied to normalised BM25 scores.

        Raises:
            ValueError: If either weight is zero or negative.
        """
        if dense_weight <= 0 or sparse_weight <= 0:
            raise ValueError("Dense and sparse weights must be positive.")
        self._store = store
        self._dense_weight = dense_weight
        self._sparse_weight = sparse_weight

    def retrieve(self, namespace: str, query: str, k: int = 6) -> list[tuple[Document, float]]:
        """Return up to ``k`` documents with their combined scores, best first.

        Args:
            namespace: Namespace passed through to the store for both lookups.
            query: Free-text query used for dense and BM25 search.
            k: Maximum number of results. Values of zero or below yield ``[]``.

        Returns:
            ``(document, combined_score)`` pairs sorted by descending score.
            Empty when ``k <= 0`` or the namespace has no records.
        """
        # A non-positive k can never produce results; bail out before doing
        # any store access (a negative k would otherwise slice from the end).
        if k <= 0:
            return []

        records = self._store.fetch_records(namespace=namespace, limit=_RECORD_FETCH_LIMIT)
        if not records:
            return []

        candidate_count = max(k, _MIN_CANDIDATES)
        dense_hits = self._store.dense_search(namespace=namespace, query=query, k=candidate_count)
        sparse_hits = self._bm25_search(records=records, query=query, k=candidate_count)

        # NOTE(review): dividing by the maximum assumes dense_search returns
        # higher-is-better, positive scores -- confirm against the store.
        dense_max = max((score for _, score in dense_hits), default=1.0)
        sparse_max = max((score for _, score in sparse_hits), default=1.0)

        merged: dict[str, tuple[Document, float]] = {}
        for document, score in dense_hits:
            record_id = self._record_id(document=document)
            merged[record_id] = (
                document,
                self._dense_weight * self._normalize(score, dense_max),
            )

        for document, score in sparse_hits:
            record_id = self._record_id(document=document)
            # Keep the dense-side document when both retrievers found the
            # record; otherwise start from the sparse document with score 0.
            base_document, base_score = merged.get(record_id, (document, 0.0))
            merged[record_id] = (
                base_document,
                base_score + self._sparse_weight * self._normalize(score, sparse_max),
            )

        ranked = sorted(merged.values(), key=lambda item: item[1], reverse=True)
        return ranked[:k]

    def _bm25_search(
        self,
        records: list[KnowledgeRecord],
        query: str,
        k: int,
    ) -> list[tuple[Document, float]]:
        """Score ``records`` against ``query`` with BM25Plus and keep the top ``k``.

        The index is rebuilt from ``records`` on every call. Texts that produce
        no tokens are represented by a single ``"_"`` placeholder token
        (presumably so the index never receives an empty document); as a
        consequence a token-less query shares that token with token-less
        records.

        Returns:
            ``(document, score)`` pairs in descending score order, with each
            score coerced to a built-in ``float``.
        """
        tokenized_corpus = [self._tokenize(record.content) or ["_"] for record in records]
        bm25 = BM25Plus(tokenized_corpus)
        query_tokens = self._tokenize(query) or ["_"]
        scores = bm25.get_scores(query_tokens)

        top_scored = sorted(enumerate(scores), key=lambda item: item[1], reverse=True)[:k]
        return [
            (self._record_to_document(record=records[index]), float(score))
            for index, score in top_scored
        ]

    @staticmethod
    def _normalize(score: float, peak: float) -> float:
        """Divide ``score`` by ``peak``, returning 0.0 when ``peak`` is zero."""
        return score / peak if peak else 0.0

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Lower-case ``text`` and split it into word tokens.

        A token is a maximal run of word characters (letters, digits and
        underscore, including non-ASCII ones).
        """
        return _TOKEN_PATTERN.findall(text.lower())

    @staticmethod
    def _record_to_document(record: KnowledgeRecord) -> Document:
        """Wrap ``record`` in a ``Document`` carrying its identifying metadata."""
        return Document(
            page_content=record.content,
            metadata={
                "record_id": record.record_id,
                "namespace": record.namespace,
                "fact_key": record.fact_key,
                "fact_value": record.fact_value,
                "updated_at": record.updated_at.isoformat(),
            },
        )

    @staticmethod
    def _record_id(document: Document) -> str:
        """Return the key used to merge dense and sparse hits for ``document``.

        Uses the ``record_id`` metadata entry when present. Documents without
        one fall back to a content-derived key so that they do not all
        collapse onto the single string ``"None"`` and overwrite each other.
        """
        record_id = document.metadata.get("record_id")
        if record_id is None:
            return f"content:{document.page_content}"
        return str(record_id)