File size: 9,429 Bytes
6059138 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 | from __future__ import annotations
import hashlib
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from uuid import NAMESPACE_URL, uuid4, uuid5
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from memory_agent.config import AppConfig
from memory_agent.models import KnowledgeRecord
class FaissMemoryStore:
    """Namespaced knowledge store: a JSON record file plus per-namespace FAISS indexes.

    ``knowledge_records.json`` under ``config.memory_data_dir`` is the source of
    truth. Each namespace's FAISS index (stored under ``indexes/<sha256 prefix>``)
    is derived from those records and is rebuilt from them whenever it is
    missing, unreadable, or out of sync with the records.
    """

    def __init__(self, config: AppConfig, embeddings: Embeddings) -> None:
        """Create the data directories and load all persisted records.

        Args:
            config: Application config; only ``memory_data_dir`` is read here.
            embeddings: Embedding model used for indexing and for queries.
        """
        self._config = config
        self._embeddings = embeddings
        self._data_dir = Path(self._config.memory_data_dir).expanduser().resolve()
        self._indexes_dir = self._data_dir / "indexes"
        self._records_file = self._data_dir / "knowledge_records.json"
        # Per-namespace vector stores, loaded or built lazily on first use.
        self._vectorstores: dict[str, FAISS] = {}
        # Probed from the embedding model on first need (see _get_embedding_dimension).
        self._embedding_dimension: int | None = None
        self._data_dir.mkdir(parents=True, exist_ok=True)
        self._indexes_dir.mkdir(parents=True, exist_ok=True)
        # namespace -> record_id -> record
        self._records_by_namespace: dict[str, dict[str, KnowledgeRecord]] = self._load_records()

    def upsert_knowledge(
        self,
        namespace: str,
        content: str,
        fact_key: str | None = None,
        fact_value: str | None = None,
    ) -> KnowledgeRecord:
        """Insert or replace a knowledge record and index it for dense search.

        Keyed facts get a deterministic id (uuid5 of ``"namespace:normalized_key"``),
        so writing the same key again replaces the earlier record and its vector
        ("latest value wins"). Unkeyed content always gets a fresh uuid4 record.

        Args:
            namespace: Partition the record belongs to.
            content: Text that is embedded and returned by dense search.
            fact_key: Optional fact name, normalized with ``normalize_key``. A key
                that normalizes to an empty string is treated as no key.
            fact_value: Optional value stored alongside the key.

        Returns:
            The stored record.
        """
        # A key that normalizes to "" (e.g. "value") cannot identify a fact; treat
        # it as unkeyed so the id scheme and the stored fact_key agree.
        normalized_key = (self.normalize_key(fact_key) if fact_key else "") or None
        record_id = (
            str(uuid5(NAMESPACE_URL, f"{namespace}:{normalized_key}"))
            if normalized_key
            else str(uuid4())
        )
        record = KnowledgeRecord(
            record_id=record_id,
            namespace=namespace,
            content=content,
            fact_key=normalized_key,
            fact_value=fact_value,
            updated_at=datetime.now(tz=timezone.utc),
        )
        # Resolve the store BEFORE mutating the records: a rebuild or sync check
        # must see the pre-upsert state, otherwise the new record would be
        # embedded during the rebuild and then deleted and embedded again below.
        vector_store = self._get_or_create_vector_store(namespace=namespace)

        self._records_by_namespace.setdefault(namespace, {})[record_id] = record
        self._persist_records()

        # Latest value wins at the vector layer: drop the previous vector for a
        # keyed fact. FAISS.delete raises ValueError for unknown ids, so check first.
        if record_id in vector_store.index_to_docstore_id.values():
            vector_store.delete(ids=[record_id])
        vector_store.add_documents([self._record_to_document(record=record)], ids=[record_id])
        self._persist_vector_store(namespace=namespace, vector_store=vector_store)
        return record

    def fetch_records(self, namespace: str, limit: int | None = 2000) -> list[KnowledgeRecord]:
        """Return a namespace's records, newest ``updated_at`` first.

        Args:
            namespace: Namespace to read; an unknown namespace yields ``[]``.
            limit: Maximum number of records to return; ``None`` returns all.
        """
        records = sorted(
            self._records_by_namespace.get(namespace, {}).values(),
            key=lambda record: record.updated_at,
            reverse=True,
        )
        return records[:limit]

    def fetch_fact_map(self, namespace: str) -> dict[str, str]:
        """Return ``{fact_key: fact_value}`` for every keyed fact in the namespace.

        Records missing either the key or the value are skipped. If several
        records share a key, the most recently updated one wins.
        """
        facts: dict[str, str] = {}
        # All records, not a capped prefix: old facts must not disappear just
        # because many newer (unkeyed) records exist.
        for record in self.fetch_records(namespace=namespace, limit=None):
            if record.fact_key is None or record.fact_value is None:
                continue
            # Records arrive newest first, so the first value seen for a key wins.
            facts.setdefault(record.fact_key, record.fact_value)
        return facts

    def dense_search(self, namespace: str, query: str, k: int = 6) -> list[tuple[Document, float]]:
        """Return up to ``k`` ``(document, score)`` pairs most similar to ``query``.

        ``k`` is clamped to at least 1 and at most the number of indexed
        documents. The index distance is mapped to ``1 / (1 + distance)``, so a
        higher score means a closer match. An empty namespace returns ``[]``.
        """
        vector_store = self._get_or_create_vector_store(namespace=namespace)
        indexed_count = len(vector_store.index_to_docstore_id)
        if not indexed_count:
            return []
        results = vector_store.similarity_search_with_score(
            query,
            k=min(max(1, k), indexed_count),
        )
        return [(document, 1.0 / (1.0 + float(distance))) for document, distance in results]

    @staticmethod
    def normalize_key(raw_key: str) -> str:
        """Canonicalize a fact key: lowercase, drop the words "number"/"value",
        strip characters other than ``[a-z0-9_ ]``, and join words with ``_``.

        NOTE(review): ``\\b`` does not match between ``_`` and a letter, so
        "phone number" -> "phone" but "phone_number" stays "phone_number".
        Changing this would change persisted record ids, so it is left as is.
        """
        text = raw_key.strip().lower()
        text = re.sub(r"\b(number|value)\b", "", text)
        text = re.sub(r"[^a-z0-9_ ]+", "", text)
        text = re.sub(r"\s+", "_", text).strip("_")
        return text

    def _get_or_create_vector_store(self, namespace: str) -> FAISS:
        """Return the namespace's vector store: cached, loaded from disk, or rebuilt."""
        cached = self._vectorstores.get(namespace)
        if cached is not None:
            return cached
        vector_store = self._load_vector_store(namespace=namespace)
        if vector_store is None:
            vector_store = self._rebuild_vector_store(namespace=namespace)
        self._vectorstores[namespace] = vector_store
        return vector_store

    def _load_vector_store(self, namespace: str) -> FAISS | None:
        """Load the persisted index, or return ``None`` if it is absent, unreadable, or stale."""
        namespace_path = self._namespace_index_path(namespace=namespace)
        if not (namespace_path / "index.faiss").exists():
            return None
        try:
            # SECURITY NOTE: load_local unpickles index.pkl. This is only safe
            # because the directory is written by this class; anyone who can
            # write to the data dir can execute code here.
            vector_store = FAISS.load_local(
                str(namespace_path),
                self._embeddings,
                allow_dangerous_deserialization=True,
            )
        except Exception:
            # Deliberate best effort: a corrupt or incompatible index is simply
            # rebuilt from the JSON records, which are the source of truth.
            return None
        if not self._is_index_in_sync(namespace=namespace, vector_store=vector_store):
            return None
        return vector_store

    def _is_index_in_sync(self, namespace: str, vector_store: FAISS) -> bool:
        """Check that the index holds exactly the namespace's records, at their current version.

        Guards against a crash between ``_persist_records`` and
        ``_persist_vector_store`` leaving the on-disk index behind the records.
        """
        records = self._records_by_namespace.get(namespace, {})
        if set(vector_store.index_to_docstore_id.values()) != set(records):
            return False
        for record_id, record in records.items():
            document = vector_store.docstore.search(record_id)
            # InMemoryDocstore.search returns a message string for unknown ids,
            # so only trust objects that actually carry metadata.
            metadata = getattr(document, "metadata", None)
            if not isinstance(metadata, dict):
                return False
            if metadata.get("updated_at") != record.updated_at.isoformat():
                return False
        return True

    def _rebuild_vector_store(self, namespace: str) -> FAISS:
        """Build a fresh index from ALL of the namespace's records and persist it if non-empty."""
        vector_store = self._build_empty_vector_store()
        records = self.fetch_records(namespace=namespace, limit=None)
        if records:
            vector_store.add_documents(
                documents=[self._record_to_document(record=record) for record in records],
                ids=[record.record_id for record in records],
            )
            self._persist_vector_store(namespace=namespace, vector_store=vector_store)
        return vector_store

    def _build_empty_vector_store(self) -> FAISS:
        """Create an empty in-memory FAISS store using an exact L2 (``IndexFlatL2``) index."""
        index = faiss.IndexFlatL2(self._get_embedding_dimension())
        return FAISS(
            embedding_function=self._embeddings,
            index=index,
            docstore=InMemoryDocstore({}),
            index_to_docstore_id={},
        )

    def _get_embedding_dimension(self) -> int:
        """Return the embedding vector length, probing the model once and caching it.

        Raises:
            RuntimeError: If the embedding model returns an empty vector.
        """
        if self._embedding_dimension is None:
            probe_vector = self._embeddings.embed_query("embedding_dimension_probe")
            if not probe_vector:
                raise RuntimeError("Embedding model returned an empty vector.")
            self._embedding_dimension = len(probe_vector)
        return self._embedding_dimension

    def _persist_vector_store(self, namespace: str, vector_store: FAISS) -> None:
        """Save the namespace's index into its per-namespace directory."""
        namespace_path = self._namespace_index_path(namespace=namespace)
        namespace_path.mkdir(parents=True, exist_ok=True)
        vector_store.save_local(str(namespace_path))

    def _namespace_index_path(self, namespace: str) -> Path:
        """Map a namespace to a filesystem-safe directory (first 24 hex chars of its SHA-256)."""
        digest = hashlib.sha256(namespace.encode("utf-8")).hexdigest()[:24]
        return self._indexes_dir / digest

    def _load_records(self) -> dict[str, dict[str, KnowledgeRecord]]:
        """Read the records file into ``{namespace: {record_id: record}}``; ``{}`` if absent."""
        if not self._records_file.exists():
            return {}
        payload = json.loads(self._records_file.read_text(encoding="utf-8"))
        loaded: dict[str, dict[str, KnowledgeRecord]] = {}
        for namespace, items in payload.get("namespaces", {}).items():
            namespace_records: dict[str, KnowledgeRecord] = {}
            for item in items:
                record = self._payload_to_record(item)
                namespace_records[record.record_id] = record
            loaded[namespace] = namespace_records
        return loaded

    def _persist_records(self) -> None:
        """Write every namespace's records to the JSON file, atomically.

        The payload goes to a sibling temp file that then replaces the target,
        so a crash mid-write cannot leave a truncated source-of-truth file.
        """
        payload = {
            "namespaces": {
                namespace: [self._record_to_payload(record) for record in records.values()]
                for namespace, records in self._records_by_namespace.items()
            }
        }
        temp_file = self._records_file.with_name(self._records_file.name + ".tmp")
        temp_file.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        temp_file.replace(self._records_file)

    @staticmethod
    def _record_to_payload(record: KnowledgeRecord) -> dict[str, str | None]:
        """Serialize a record to a JSON-compatible dict (``updated_at`` as ISO 8601)."""
        return {
            "record_id": record.record_id,
            "namespace": record.namespace,
            "content": record.content,
            "fact_key": record.fact_key,
            "fact_value": record.fact_value,
            "updated_at": record.updated_at.isoformat(),
        }

    @staticmethod
    def _payload_to_record(payload: dict[str, str | None]) -> KnowledgeRecord:
        """Inverse of ``_record_to_payload``; missing/null key and value become ``None``."""
        return KnowledgeRecord(
            record_id=str(payload["record_id"]),
            namespace=str(payload["namespace"]),
            content=str(payload["content"]),
            fact_key=str(payload["fact_key"]) if payload.get("fact_key") is not None else None,
            fact_value=str(payload["fact_value"]) if payload.get("fact_value") is not None else None,
            updated_at=FaissMemoryStore._to_datetime(str(payload["updated_at"])),
        )

    @staticmethod
    def _record_to_document(record: KnowledgeRecord) -> Document:
        """Wrap a record as a LangChain Document whose metadata mirrors the record fields."""
        return Document(
            page_content=record.content,
            metadata={
                "record_id": record.record_id,
                "namespace": record.namespace,
                "fact_key": record.fact_key,
                "fact_value": record.fact_value,
                "updated_at": record.updated_at.isoformat(),
            },
        )

    @staticmethod
    def _to_datetime(value: datetime | str) -> datetime:
        """Parse an ISO 8601 timestamp (or pass a datetime through) as a timezone-aware value.

        Naive values are assumed to be UTC, matching what ``upsert_knowledge``
        writes. Mixing naive and aware datetimes would make sorting raise.
        """
        if isinstance(value, datetime):
            parsed = value
        else:
            # fromisoformat rejects a trailing "Z" before Python 3.11.
            text = value[:-1] + "+00:00" if value.endswith("Z") else value
            parsed = datetime.fromisoformat(text)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed
|