| from __future__ import annotations |
|
|
| import hashlib |
| import json |
| import re |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from uuid import NAMESPACE_URL, uuid4, uuid5 |
|
|
| import faiss |
| from langchain_community.docstore.in_memory import InMemoryDocstore |
| from langchain_community.vectorstores import FAISS |
| from langchain_core.documents import Document |
| from langchain_core.embeddings import Embeddings |
|
|
| from memory_agent.config import AppConfig |
| from memory_agent.models import KnowledgeRecord |
|
|
|
|
class FaissMemoryStore:
    """Persistent memory store backed by per-namespace FAISS indexes.

    Knowledge records are kept authoritative in a JSON file
    (``knowledge_records.json``); each namespace additionally gets a FAISS
    vector index on disk for dense similarity search. Records carrying a
    ``fact_key`` get a deterministic UUIDv5 id (so re-upserting the same key
    overwrites the old value); free-form records get random UUIDv4 ids.
    """

    def __init__(self, config: AppConfig, embeddings: Embeddings) -> None:
        """Create the store, ensure data directories exist, and load records."""
        self._config = config
        self._embeddings = embeddings
        self._data_dir = Path(self._config.memory_data_dir).expanduser().resolve()
        self._indexes_dir = self._data_dir / "indexes"
        self._records_file = self._data_dir / "knowledge_records.json"
        # Lazily populated cache of per-namespace FAISS stores.
        self._vectorstores: dict[str, FAISS] = {}
        # Probed once from the embedding model; see _get_embedding_dimension.
        self._embedding_dimension: int | None = None

        self._data_dir.mkdir(parents=True, exist_ok=True)
        self._indexes_dir.mkdir(parents=True, exist_ok=True)
        self._records_by_namespace = self._load_records()

    def upsert_knowledge(
        self,
        namespace: str,
        content: str,
        fact_key: str | None = None,
        fact_value: str | None = None,
    ) -> KnowledgeRecord:
        """Insert or update a knowledge record and sync it into the vector index.

        A ``fact_key`` yields a deterministic record id, so repeated upserts of
        the same normalized key replace the previous value; without a key the
        record is append-only. Returns the stored record.
        """
        normalized_key = self.normalize_key(fact_key) if fact_key else None
        record_id = (
            str(uuid5(NAMESPACE_URL, f"{namespace}:{normalized_key}"))
            if normalized_key
            else str(uuid4())
        )
        record = KnowledgeRecord(
            record_id=record_id,
            namespace=namespace,
            content=content,
            fact_key=normalized_key,
            fact_value=fact_value,
            updated_at=datetime.now(tz=timezone.utc),
        )
        namespace_records = self._records_by_namespace.setdefault(namespace, {})
        namespace_records[record_id] = record
        self._persist_records()

        vector_store = self._get_or_create_vector_store(namespace=namespace)
        # Remove any stale vector for this id before re-adding. Check membership
        # explicitly instead of calling delete() inside a broad try/except that
        # would also swallow unrelated failures.
        if record.record_id in vector_store.index_to_docstore_id.values():
            vector_store.delete(ids=[record.record_id])
        document = self._record_to_document(record=record)
        vector_store.add_documents([document], ids=[record.record_id])
        self._persist_vector_store(namespace=namespace, vector_store=vector_store)
        return record

    def fetch_records(self, namespace: str, limit: int = 2000) -> list[KnowledgeRecord]:
        """Return up to ``limit`` records for ``namespace``, newest first."""
        records = list(self._records_by_namespace.get(namespace, {}).values())
        records.sort(key=lambda record: record.updated_at, reverse=True)
        return records[:limit]

    def fetch_fact_map(self, namespace: str) -> dict[str, str]:
        """Return a ``fact_key -> fact_value`` map for ``namespace``.

        ``fetch_records`` returns records newest-first, so keeping only the
        first value seen per key makes the most recent value win.
        """
        facts: dict[str, str] = {}
        for record in self.fetch_records(namespace=namespace, limit=5000):
            if record.fact_key is None or record.fact_value is None:
                continue
            facts.setdefault(record.fact_key, record.fact_value)
        return facts

    def dense_search(self, namespace: str, query: str, k: int = 6) -> list[tuple[Document, float]]:
        """Run dense similarity search; returns ``(document, score)`` pairs.

        FAISS L2 distances are mapped to similarity scores in (0, 1] via
        ``1 / (1 + distance)`` — smaller distances score closer to 1.
        """
        vector_store = self._get_or_create_vector_store(namespace=namespace)
        if not vector_store.index_to_docstore_id:
            return []

        results = vector_store.similarity_search_with_score(
            query,
            # Never request more neighbors than the index holds (and at least 1).
            k=min(max(1, k), len(vector_store.index_to_docstore_id)),
        )
        return [
            (document, 1.0 / (1.0 + float(distance)))
            for document, distance in results
        ]

    @staticmethod
    def normalize_key(raw_key: str) -> str:
        """Normalize a fact key to lowercase snake_case.

        Drops the filler words "number"/"value", strips characters outside
        ``[a-z0-9_ ]``, and collapses whitespace runs into single underscores.
        """
        text = raw_key.strip().lower()
        text = re.sub(r"\b(number|value)\b", "", text)
        text = re.sub(r"[^a-z0-9_ ]+", "", text)
        text = re.sub(r"\s+", "_", text).strip("_")
        return text

    def _get_or_create_vector_store(self, namespace: str) -> FAISS:
        """Return the FAISS store for ``namespace``, loading or rebuilding it.

        Resolution order: in-memory cache, then the on-disk index, then a full
        rebuild from the JSON records.
        """
        if namespace in self._vectorstores:
            return self._vectorstores[namespace]

        namespace_path = self._namespace_index_path(namespace=namespace)
        if (namespace_path / "index.faiss").exists():
            try:
                vector_store = FAISS.load_local(
                    str(namespace_path),
                    self._embeddings,
                    allow_dangerous_deserialization=True,
                )
                self._vectorstores[namespace] = vector_store
                return vector_store
            except Exception:
                # Deliberate best-effort: a corrupt or incompatible on-disk
                # index falls through to a full rebuild from the JSON records.
                pass

        vector_store = self._build_empty_vector_store()
        records = self.fetch_records(namespace=namespace, limit=5000)
        if records:
            documents = [self._record_to_document(record=record) for record in records]
            ids = [record.record_id for record in records]
            vector_store.add_documents(documents=documents, ids=ids)
            self._persist_vector_store(namespace=namespace, vector_store=vector_store)

        self._vectorstores[namespace] = vector_store
        return vector_store

    def _build_empty_vector_store(self) -> FAISS:
        """Create an empty FAISS store sized to the embedding dimension."""
        embedding_dimension = self._get_embedding_dimension()
        index = faiss.IndexFlatL2(embedding_dimension)
        return FAISS(
            embedding_function=self._embeddings,
            index=index,
            docstore=InMemoryDocstore({}),
            index_to_docstore_id={},
        )

    def _get_embedding_dimension(self) -> int:
        """Return the embedding dimension, probing the model once and caching it.

        Raises:
            RuntimeError: if the embedding model returns an empty vector.
        """
        if self._embedding_dimension is None:
            probe_vector = self._embeddings.embed_query("embedding_dimension_probe")
            if not probe_vector:
                raise RuntimeError("Embedding model returned an empty vector.")
            self._embedding_dimension = len(probe_vector)
        return self._embedding_dimension

    def _persist_vector_store(self, namespace: str, vector_store: FAISS) -> None:
        """Save the namespace's FAISS index under its hashed directory."""
        namespace_path = self._namespace_index_path(namespace=namespace)
        namespace_path.mkdir(parents=True, exist_ok=True)
        vector_store.save_local(str(namespace_path))

    def _namespace_index_path(self, namespace: str) -> Path:
        """Map a namespace to a filesystem-safe index directory (hashed name)."""
        digest = hashlib.sha256(namespace.encode("utf-8")).hexdigest()[:24]
        return self._indexes_dir / digest

    def _load_records(self) -> dict[str, dict[str, KnowledgeRecord]]:
        """Load all records from the JSON file, grouped by namespace then id."""
        if not self._records_file.exists():
            return {}
        payload = json.loads(self._records_file.read_text(encoding="utf-8"))
        namespaces = payload.get("namespaces", {})
        loaded: dict[str, dict[str, KnowledgeRecord]] = {}
        for namespace, items in namespaces.items():
            namespace_records: dict[str, KnowledgeRecord] = {}
            for item in items:
                record = self._payload_to_record(item)
                namespace_records[record.record_id] = record
            loaded[namespace] = namespace_records
        return loaded

    def _persist_records(self) -> None:
        """Write all records to disk atomically (temp file + rename).

        Writing to a sibling temp file and renaming over the target avoids
        leaving a truncated records file if the process dies mid-write.
        """
        payload: dict[str, dict[str, list[dict[str, str | None]]]] = {"namespaces": {}}
        for namespace, records in self._records_by_namespace.items():
            payload["namespaces"][namespace] = [
                self._record_to_payload(record) for record in records.values()
            ]
        temp_file = self._records_file.with_name(self._records_file.name + ".tmp")
        temp_file.write_text(
            json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        # Atomic on POSIX; replaces the old file in a single step.
        temp_file.replace(self._records_file)

    @staticmethod
    def _record_to_payload(record: KnowledgeRecord) -> dict[str, str | None]:
        """Serialize one record into a JSON-safe dict."""
        return {
            "record_id": record.record_id,
            "namespace": record.namespace,
            "content": record.content,
            "fact_key": record.fact_key,
            "fact_value": record.fact_value,
            "updated_at": record.updated_at.isoformat(),
        }

    @staticmethod
    def _payload_to_record(payload: dict[str, str | None]) -> KnowledgeRecord:
        """Deserialize one JSON payload dict into a KnowledgeRecord."""
        fact_key = payload.get("fact_key")
        fact_value = payload.get("fact_value")
        return KnowledgeRecord(
            record_id=str(payload["record_id"]),
            namespace=str(payload["namespace"]),
            content=str(payload["content"]),
            fact_key=str(fact_key) if fact_key is not None else None,
            fact_value=str(fact_value) if fact_value is not None else None,
            updated_at=FaissMemoryStore._to_datetime(str(payload["updated_at"])),
        )

    @staticmethod
    def _record_to_document(record: KnowledgeRecord) -> Document:
        """Build the Document (content + metadata) stored in the vector index."""
        return Document(
            page_content=record.content,
            metadata={
                "record_id": record.record_id,
                "namespace": record.namespace,
                "fact_key": record.fact_key,
                "fact_value": record.fact_value,
                "updated_at": record.updated_at.isoformat(),
            },
        )

    @staticmethod
    def _to_datetime(value: datetime | str) -> datetime:
        """Coerce an ISO-8601 string to datetime; pass datetimes through."""
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)
|
|