"""Hybrid retrieval that fuses dense FAISS search with sparse BM25 search."""

from __future__ import annotations

import re

from langchain_core.documents import Document
from rank_bm25 import BM25Plus

from memory_agent.models import KnowledgeRecord
from memory_agent.storage import FaissMemoryStore

class HybridRetriever:
    """Rank memories by a weighted blend of dense vector and BM25 sparse scores."""

    def __init__(
        self,
        store: FaissMemoryStore,
        dense_weight: float = 0.55,
        sparse_weight: float = 0.45,
    ) -> None:
        if dense_weight <= 0 or sparse_weight <= 0:
            raise ValueError("Dense and sparse weights must be positive.")
        self._store = store
        self._dense_weight = dense_weight
        self._sparse_weight = sparse_weight

    def retrieve(self, namespace: str, query: str, k: int = 6) -> list[tuple[Document, float]]:
        """Return the top-k documents for ``query`` ranked by the fused hybrid score."""
        records = self._store.fetch_records(namespace=namespace, limit=2000)
        if not records:
            return []

        # Over-fetch candidates from both retrievers so the fused ranking has enough to merge.
        dense_hits = self._store.dense_search(namespace=namespace, query=query, k=max(k, 8))
        sparse_hits = self._bm25_search(records=records, query=query, k=max(k, 8))

        # Max-normalize each score list so the dense and sparse scales are comparable.
        dense_max = max((score for _, score in dense_hits), default=1.0)
        sparse_max = max((score for _, score in sparse_hits), default=1.0)

        merged: dict[str, tuple[Document, float]] = {}

        for document, score in dense_hits:
            record_id = self._record_id(document=document)
            normalized = score / dense_max if dense_max else 0.0
            merged[record_id] = (document, self._dense_weight * normalized)

        for document, score in sparse_hits:
            record_id = self._record_id(document=document)
            normalized = score / sparse_max if sparse_max else 0.0
            # Add the sparse contribution on top of any dense score for the same record.
            base_document, base_score = merged.get(record_id, (document, 0.0))
            merged[record_id] = (
                base_document,
                base_score + self._sparse_weight * normalized,
            )

        ranked = sorted(merged.values(), key=lambda item: item[1], reverse=True)
        return ranked[:k]
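
    # Scoring example (illustrative): with the default weights, a record that is the top
    # dense hit (normalized 1.0) and a mid-ranked sparse hit (normalized 0.8) receives
    # 0.55 * 1.0 + 0.45 * 0.8 = 0.91, while a record found only by dense search tops out at 0.55.
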
    def _bm25_search(
        self,
        records: list[KnowledgeRecord],
        query: str,
        k: int,
    ) -> list[tuple[Document, float]]:
        # Use a placeholder token so BM25Plus never receives an empty document or query.
        tokenized_corpus = [self._tokenize(record.content) or ["_"] for record in records]
        bm25 = BM25Plus(tokenized_corpus)
        query_tokens = self._tokenize(query) or ["_"]
        scores = bm25.get_scores(query_tokens)

        indexed = list(enumerate(scores))
        ranked_indexes = sorted(indexed, key=lambda item: item[1], reverse=True)[:k]

        hits: list[tuple[Document, float]] = []
        for index, score in ranked_indexes:
            record = records[index]
            hits.append((self._record_to_document(record=record), float(score)))
        return hits

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        return re.findall(r"[a-zA-Z0-9_]+", text.lower())

    @staticmethod
    def _record_to_document(record: KnowledgeRecord) -> Document:
        return Document(
            page_content=record.content,
            metadata={
                "record_id": record.record_id,
                "namespace": record.namespace,
                "fact_key": record.fact_key,
                "fact_value": record.fact_value,
                "updated_at": record.updated_at.isoformat(),
            },
        )

    @staticmethod
    def _record_id(document: Document) -> str:
        return str(document.metadata.get("record_id"))
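

# Usage sketch (illustrative only). Construction of FaissMemoryStore is outside this
# module, so the setup line below is an assumption rather than the store's real API:
#
#     store = FaissMemoryStore(...)  # build the store however your application does
#     retriever = HybridRetriever(store=store, dense_weight=0.55, sparse_weight=0.45)
#     for document, score in retriever.retrieve(namespace="user-123", query="preferred language", k=5):
#         print(f"{score:.3f}  {document.page_content}")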