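"""FAISS-backed persistent memory store.

Maintains one FAISS index per namespace on disk, alongside a JSON file of
structured knowledge records, so keyed facts can be upserted with
"latest value wins" semantics and retrieved via dense similarity search.
"""
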
from __future__ import annotations
import hashlib
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from uuid import NAMESPACE_URL, uuid4, uuid5
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from memory_agent.config import AppConfig
from memory_agent.models import KnowledgeRecord
class FaissMemoryStore:
def __init__(self, config: AppConfig, embeddings: Embeddings) -> None:
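        """Resolve the data directory, create it if needed, and load persisted records."""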
self._config = config
self._embeddings = embeddings
self._data_dir = Path(self._config.memory_data_dir).expanduser().resolve()
self._indexes_dir = self._data_dir / "indexes"
self._records_file = self._data_dir / "knowledge_records.json"
self._vectorstores: dict[str, FAISS] = {}
self._embedding_dimension: int | None = None
self._data_dir.mkdir(parents=True, exist_ok=True)
self._indexes_dir.mkdir(parents=True, exist_ok=True)
self._records_by_namespace = self._load_records()
def upsert_knowledge(
self,
namespace: str,
content: str,
fact_key: str | None = None,
fact_value: str | None = None,
) -> KnowledgeRecord:
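        """Insert or update a knowledge record in `namespace`.

        Keyed facts get a deterministic UUIDv5 derived from the namespace and
        normalized key, so repeating a key overwrites the previous record;
        free-form content gets a random UUIDv4 and is always appended.
        """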
normalized_key = self.normalize_key(fact_key) if fact_key else None
record_id = (
str(uuid5(NAMESPACE_URL, f"{namespace}:{normalized_key}"))
if normalized_key
else str(uuid4())
)
record = KnowledgeRecord(
record_id=record_id,
namespace=namespace,
content=content,
fact_key=normalized_key,
fact_value=fact_value,
updated_at=datetime.now(tz=timezone.utc),
)
namespace_records = self._records_by_namespace.setdefault(namespace, {})
namespace_records[record_id] = record
self._persist_records()
vector_store = self._get_or_create_vector_store(namespace=namespace)
        document = self._record_to_document(record=record)
# Ensure "latest value wins" at vector layer for keyed facts.
try:
vector_store.delete(ids=[record.record_id])
except Exception:
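            # FAISS raises if the id does not exist yet (first insert); ignore it.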
pass
vector_store.add_documents([document], ids=[record.record_id])
self._persist_vector_store(namespace=namespace, vector_store=vector_store)
return record
def fetch_records(self, namespace: str, limit: int = 2000) -> list[KnowledgeRecord]:
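        """Return up to `limit` records for `namespace`, newest first."""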
records = list(self._records_by_namespace.get(namespace, {}).values())
records.sort(key=lambda record: record.updated_at, reverse=True)
return records[:limit]
def fetch_fact_map(self, namespace: str) -> dict[str, str]:
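        """Collapse keyed records into a fact_key -> fact_value map.

        Records arrive newest first, so the first value seen for a key is the latest one.
        """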
facts: dict[str, str] = {}
for record in self.fetch_records(namespace=namespace, limit=5000):
key = record.fact_key
value = record.fact_value
if key is None or value is None:
continue
if key not in facts:
facts[key] = value
return facts
def dense_search(self, namespace: str, query: str, k: int = 6) -> list[tuple[Document, float]]:
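        """Run a similarity search over `namespace` and return (document, score) pairs.

        FAISS returns L2 distances (lower is closer); they are converted to a
        bounded score in (0, 1] via 1 / (1 + distance).
        """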
vector_store = self._get_or_create_vector_store(namespace=namespace)
if not vector_store.index_to_docstore_id:
return []
results = vector_store.similarity_search_with_score(
query,
k=min(max(1, k), len(vector_store.index_to_docstore_id)),
)
dense_results: list[tuple[Document, float]] = []
for document, distance in results:
score = 1.0 / (1.0 + float(distance))
dense_results.append((document, score))
return dense_results
@staticmethod
def normalize_key(raw_key: str) -> str:
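        """Normalize a fact key to snake_case, dropping filler words like "number"/"value"."""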
text = raw_key.strip().lower()
text = re.sub(r"\b(number|value)\b", "", text)
text = re.sub(r"[^a-z0-9_ ]+", "", text)
text = re.sub(r"\s+", "_", text).strip("_")
return text
def _get_or_create_vector_store(self, namespace: str) -> FAISS:
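        """Return the cached store for `namespace`, loading it from disk or rebuilding it from records."""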
if namespace in self._vectorstores:
return self._vectorstores[namespace]
namespace_path = self._namespace_index_path(namespace=namespace)
if (namespace_path / "index.faiss").exists():
try:
vector_store = FAISS.load_local(
str(namespace_path),
self._embeddings,
allow_dangerous_deserialization=True,
)
self._vectorstores[namespace] = vector_store
return vector_store
except Exception:
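                # Corrupt or incompatible index on disk: fall through and rebuild from records.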
pass
vector_store = self._build_empty_vector_store()
records = self.fetch_records(namespace=namespace, limit=5000)
if records:
documents = [self._record_to_document(record=record) for record in records]
ids = [record.record_id for record in records]
vector_store.add_documents(documents=documents, ids=ids)
self._persist_vector_store(namespace=namespace, vector_store=vector_store)
self._vectorstores[namespace] = vector_store
return vector_store
def _build_empty_vector_store(self) -> FAISS:
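        """Create an empty FAISS store backed by a flat L2 index sized to the embedding model."""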
embedding_dimension = self._get_embedding_dimension()
index = faiss.IndexFlatL2(embedding_dimension)
return FAISS(
embedding_function=self._embeddings,
index=index,
docstore=InMemoryDocstore({}),
index_to_docstore_id={},
)
def _get_embedding_dimension(self) -> int:
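        """Probe the embedding model once and cache its output dimension."""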
if self._embedding_dimension is None:
probe_vector = self._embeddings.embed_query("embedding_dimension_probe")
if not probe_vector:
raise RuntimeError("Embedding model returned an empty vector.")
self._embedding_dimension = len(probe_vector)
return self._embedding_dimension
def _persist_vector_store(self, namespace: str, vector_store: FAISS) -> None:
namespace_path = self._namespace_index_path(namespace=namespace)
namespace_path.mkdir(parents=True, exist_ok=True)
vector_store.save_local(str(namespace_path))
def _namespace_index_path(self, namespace: str) -> Path:
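        """Map a namespace to a filesystem-safe index directory via a truncated SHA-256 digest."""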
digest = hashlib.sha256(namespace.encode("utf-8")).hexdigest()[:24]
return self._indexes_dir / digest
def _load_records(self) -> dict[str, dict[str, KnowledgeRecord]]:
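        """Load persisted knowledge records from the JSON file, grouped by namespace."""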
if not self._records_file.exists():
return {}
payload = json.loads(self._records_file.read_text(encoding="utf-8"))
namespaces = payload.get("namespaces", {})
loaded: dict[str, dict[str, KnowledgeRecord]] = {}
for namespace, items in namespaces.items():
namespace_records: dict[str, KnowledgeRecord] = {}
for item in items:
record = self._payload_to_record(item)
namespace_records[record.record_id] = record
loaded[namespace] = namespace_records
return loaded
def _persist_records(self) -> None:
payload = {"namespaces": {}}
for namespace, records in self._records_by_namespace.items():
serialized = [self._record_to_payload(record) for record in records.values()]
payload["namespaces"][namespace] = serialized
self._records_file.write_text(json.dumps(payload, indent=2), encoding="utf-8")
@staticmethod
def _record_to_payload(record: KnowledgeRecord) -> dict[str, str | None]:
return {
"record_id": record.record_id,
"namespace": record.namespace,
"content": record.content,
"fact_key": record.fact_key,
"fact_value": record.fact_value,
"updated_at": record.updated_at.isoformat(),
}
@staticmethod
def _payload_to_record(payload: dict[str, str | None]) -> KnowledgeRecord:
return KnowledgeRecord(
record_id=str(payload["record_id"]),
namespace=str(payload["namespace"]),
content=str(payload["content"]),
fact_key=str(payload["fact_key"]) if payload.get("fact_key") is not None else None,
fact_value=str(payload["fact_value"]) if payload.get("fact_value") is not None else None,
updated_at=FaissMemoryStore._to_datetime(str(payload["updated_at"])),
)
@staticmethod
def _record_to_document(record: KnowledgeRecord) -> Document:
return Document(
page_content=record.content,
metadata={
"record_id": record.record_id,
"namespace": record.namespace,
"fact_key": record.fact_key,
"fact_value": record.fact_value,
"updated_at": record.updated_at.isoformat(),
},
)
@staticmethod
def _to_datetime(value: datetime | str) -> datetime:
if isinstance(value, datetime):
return value
return datetime.fromisoformat(value)
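

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API surface.
    # Assumptions: AppConfig accepts memory_data_dir as a constructor argument, and
    # FakeEmbeddings (random vectors) stands in for the real embedding backend,
    # so this only exercises the wiring, not retrieval quality.
    from langchain_community.embeddings import FakeEmbeddings

    store = FaissMemoryStore(
        config=AppConfig(memory_data_dir="./memory_data"),
        embeddings=FakeEmbeddings(size=384),
    )
    store.upsert_knowledge(
        namespace="user:demo",
        content="The user's favourite number is 42.",
        fact_key="Favourite Number",
        fact_value="42",
    )
    print(store.fetch_fact_map("user:demo"))
    for document, score in store.dense_search(namespace="user:demo", query="favourite number", k=3):
        print(f"{score:.3f}  {document.page_content}")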