| """Adapter for the external A-Mem baseline.""" |
|
|
| from __future__ import annotations |
|
|
| import importlib |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Any, Callable |
|
|
| from eval_framework.datasets.schemas import ( |
| MemoryDeltaRecord, |
| MemorySnapshotRecord, |
| NormalizedTurn, |
| RetrievalItem, |
| RetrievalRecord, |
| ) |
| from eval_framework.memory_adapters.base import MemoryAdapter |
|
|
| _BACKEND_ID = "A-Mem" |
|
|
| INTEGRATION_ERROR = ( |
| f"{_BACKEND_ID} backend unavailable." |
| ) |
|
|
|
|
class AMemAdapter(MemoryAdapter):
    """Thin wrapper around A-Mem's robust memory system.

    The adapter drives a ``RobustAgenticMemorySystem`` (imported from the
    ``memory_layer_robust`` module found under ``source_root``) through the
    framework's ``MemoryAdapter`` interface:

    * turns are stored as notes via the backend's ``add_note``;
    * snapshots are read from the backend's ``memories`` mapping;
    * retrieval uses the backend's ``retriever.search``, falling back to
      ``find_related_memories_raw`` when that yields nothing.

    Construction never raises because of backend problems. The error text is
    kept in ``_integration_error``. It is surfaced by ``get_capabilities`` and
    by the ``RuntimeError`` raised from any method that needs the backend.
    """

    def __init__(
        self,
        *,
        backend: Any | None = None,
        backend_factory: Callable[[], Any] | None = None,
        source_root: str | os.PathLike[str] | None = None,
        model_name: str = "all-MiniLM-L6-v2",
        llm_backend: str = "openai",
        llm_model: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        sglang_host: str = "http://localhost",
        sglang_port: int = 30000,
    ) -> None:
        """Create the adapter and try to build the backend eagerly.

        Args:
            backend: Pre-built backend object, used as-is. When given, no
                backend is constructed here.
            backend_factory: Zero-argument callable that returns a fresh
                backend. It is called here when ``backend`` is None, and again
                on every ``reset``.
            source_root: Directory that contains ``memory_layer_robust``.
                Defaults to ``_default_source_root()``.
            model_name: Forwarded to ``RobustAgenticMemorySystem`` when the
                default factory is built.
            llm_backend: Forwarded to ``RobustAgenticMemorySystem`` when the
                default factory is built.
            llm_model: Forwarded to the default factory. Falls back to
                ``$OPENAI_MODEL`` and then to ``"gpt-5.1"``.
            api_key: Forwarded to the default factory. Falls back to
                ``$OPENAI_API_KEY``.
            api_base: Forwarded to the default factory. Falls back to
                ``$OPENAI_BASE_URL``.
            sglang_host: Forwarded to the default factory.
            sglang_port: Forwarded to the default factory.
        """
        self._source_root = (
            Path(source_root).resolve() if source_root else self._default_source_root()
        )
        resolved_llm_model = llm_model or os.getenv("OPENAI_MODEL") or "gpt-5.1"
        self._backend: Any | None = None
        self._backend_factory = backend_factory
        self._integration_error: str | None = None
        # Most recently seen session id; used as the fallback owner of notes
        # whose id was not recorded by ingest_turn.
        self._session_id = ""
        # Memory ids seen by the previous export_memory_delta call.
        self._prev_snapshot_ids: set[str] = set()
        # str(note_id) -> session id of the turn that created the note.
        self._note_session_map: dict[str, str] = {}

        if backend is not None:
            self._backend = backend
            return
        try:
            if self._backend_factory is None:
                self._backend_factory = self._build_backend_factory(
                    model_name=model_name,
                    llm_backend=llm_backend,
                    llm_model=resolved_llm_model,
                    api_key=api_key,
                    api_base=api_base,
                    sglang_host=sglang_host,
                    sglang_port=sglang_port,
                )
            self._backend = self._backend_factory()
        except Exception as exc:
            # Deliberately best-effort: an unavailable external baseline must
            # not break adapter construction. The error is reported via
            # get_capabilities() and _runtime_error().
            self._integration_error = str(exc)

    @staticmethod
    def _default_source_root() -> Path:
        """Return ``data_pipline/A-mem`` under this file's fourth ancestor directory."""
        here = Path(__file__).resolve()
        # NOTE(review): "data_pipline" is the directory name exactly as written.
        # It looks like a typo for "pipeline"; confirm the on-disk name before
        # changing it.
        return (here.parents[2].parent / "data_pipline" / "A-mem").resolve()

    def _build_backend_factory(
        self,
        *,
        model_name: str,
        llm_backend: str,
        llm_model: str,
        api_key: str | None,
        api_base: str | None,
        sglang_host: str,
        sglang_port: int,
    ) -> Callable[[], Any]:
        """Import A-Mem from ``self._source_root`` and return a backend factory.

        ``self._source_root`` is prepended to ``sys.path`` (once) so that
        ``memory_layer_robust`` can be imported. The returned callable builds
        a new ``RobustAgenticMemorySystem`` on each call. The ``api_key`` and
        ``api_base`` fallbacks to ``$OPENAI_API_KEY`` / ``$OPENAI_BASE_URL``
        are evaluated at call time, not here.

        Raises:
            RuntimeError: If the source root is not a directory.
            ImportError / AttributeError: If the module or class is missing.
        """
        if not self._source_root.is_dir():
            raise RuntimeError(
                f"{_BACKEND_ID}: source root not found at {self._source_root}"
            )
        src = str(self._source_root)
        if src not in sys.path:
            sys.path.insert(0, src)
        mod = importlib.import_module("memory_layer_robust")
        backend_cls = getattr(mod, "RobustAgenticMemorySystem")
        return lambda: backend_cls(
            model_name=model_name,
            llm_backend=llm_backend,
            llm_model=llm_model,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            api_base=api_base or os.getenv("OPENAI_BASE_URL"),
            sglang_host=sglang_host,
            sglang_port=sglang_port,
        )

    def _runtime_error(self) -> RuntimeError:
        """Build the error raised when the backend is needed but unavailable."""
        detail = self._integration_error or INTEGRATION_ERROR
        return RuntimeError(
            f"{_BACKEND_ID}: backend unavailable — {detail}"
        )

    def reset(self) -> None:
        """Start a fresh run: rebuild the backend if possible and clear tracking state.

        Raises:
            RuntimeError: If there is neither a backend nor a factory.
            Exception: Whatever the factory raises if the rebuild fails. In
                that case the previous backend and tracking state are left
                untouched.
        """
        if self._backend_factory is None and self._backend is None:
            raise self._runtime_error()
        if self._backend_factory is not None:
            self._backend = self._backend_factory()
            # A successful rebuild supersedes any error captured at
            # construction time, so capabilities must report it as available.
            self._integration_error = None
        # NOTE(review): without a factory the same backend object is kept, so
        # notes from earlier runs stay in its store. The next
        # export_memory_delta will then report them as new additions.
        self._prev_snapshot_ids = set()
        self._note_session_map = {}
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn as an A-Mem note and remember which session created it."""
        backend = self._require_backend()
        self._session_id = turn.session_id
        text = self._turn_text(turn)
        note_id = backend.add_note(text, time=turn.timestamp)
        self._note_session_map[str(note_id)] = turn.session_id

    def end_session(self, session_id: str) -> None:
        """Record ``session_id`` as current. A-Mem needs no per-session flush."""
        self._require_backend()
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one active ``MemorySnapshotRecord`` per note in the backend store.

        The record text is the note content, followed by optional
        ``[context]``, ``[keywords]`` and ``[tags]`` lines built from the
        note's attributes. The same fields are also kept structured in
        ``metadata``. Notes whose id was not recorded by ``ingest_turn`` are
        attributed to the most recent session id.
        """
        backend = self._require_backend()
        rows: list[MemorySnapshotRecord] = []
        for note_id, note in getattr(backend, "memories", {}).items():
            memory_id = str(getattr(note, "id", note_id))
            sid = self._note_session_map.get(str(note_id), self._session_id)
            content = str(getattr(note, "content", ""))
            context = getattr(note, "context", "")
            keywords = list(getattr(note, "keywords", []) or [])
            tags = list(getattr(note, "tags", []) or [])

            enriched_parts = [content]
            if context:
                enriched_parts.append(f"[context] {context}")
            if keywords:
                # str() guards against non-string entries, which would make join() raise.
                enriched_parts.append(f"[keywords] {', '.join(map(str, keywords))}")
            if tags:
                enriched_parts.append(f"[tags] {', '.join(map(str, tags))}")
            rows.append(
                MemorySnapshotRecord(
                    memory_id=memory_id,
                    text="\n".join(enriched_parts),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=memory_id,
                    raw_backend_type="a_mem_note",
                    metadata={
                        "timestamp": getattr(note, "timestamp", None),
                        "context": context,
                        "keywords": keywords,
                        "tags": tags,
                        "links": list(getattr(note, "links", []) or []),
                    },
                )
            )
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current snapshot against previous snapshot.

        Only additions are detected. Notes whose id was not present in the
        previous call's snapshot produce an ``op="add"`` record. Notes that
        disappeared, or whose text changed, produce no record.
        """
        self._require_backend()
        current_snapshot = self.snapshot_memories()
        deltas: list[MemoryDeltaRecord] = []
        current_ids: set[str] = set()

        for snap in current_snapshot:
            current_ids.add(snap.memory_id)
            if snap.memory_id not in self._prev_snapshot_ids:
                deltas.append(
                    MemoryDeltaRecord(
                        session_id=session_id,
                        op="add",
                        text=snap.text,
                        linked_previous=(),
                        raw_backend_id=snap.raw_backend_id,
                        metadata={
                            "baseline": _BACKEND_ID,
                            "backend_type": snap.raw_backend_type,
                        },
                    )
                )

        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Retrieve up to ``top_k`` notes for ``query``.

        Indices returned by ``retriever.search`` are mapped onto the backend's
        ``memories`` in iteration order. Out-of-range indices are skipped.
        Ranks are contiguous positions in the result list, and the score is
        ``1 / (rank + 1)``. If that yields nothing, the raw text from
        ``find_related_memories_raw`` (when the backend has it) is returned as
        a single bundle item. A non-positive ``top_k`` returns no items.
        """
        backend = self._require_backend()
        limit = max(top_k, 0)
        items: list[RetrievalItem] = []
        if limit > 0:
            # NOTE(review): assumes the retriever's indices follow the
            # insertion order of backend.memories — confirm in memory_layer_robust.
            memories = list(getattr(backend, "memories", {}).values())
            retriever = getattr(backend, "retriever", None)
            if retriever is not None and hasattr(retriever, "search"):
                for idx in retriever.search(query, limit):
                    position = int(idx)
                    if not 0 <= position < len(memories):
                        continue
                    note = memories[position]
                    # Rank by position among kept items, so that skipped
                    # indices do not leave gaps in ranks or scores.
                    rank = len(items)
                    note_key = str(getattr(note, "id", idx))
                    items.append(
                        RetrievalItem(
                            rank=rank,
                            memory_id=note_key,
                            text=str(getattr(note, "content", "")),
                            score=1.0 / float(rank + 1),
                            raw_backend_id=note_key,
                        )
                    )
            if not items and hasattr(backend, "find_related_memories_raw"):
                raw = backend.find_related_memories_raw(query, k=limit)
                if raw:
                    items.append(
                        RetrievalItem(
                            rank=0,
                            memory_id="a_mem:bundle",
                            text=str(raw),
                            score=1.0,
                            raw_backend_id=None,
                        )
                    )
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:limit],
            raw_trace={"baseline": _BACKEND_ID},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe availability and delta/snapshot behaviour of this adapter.

        ``integration_error`` is ``None`` while the adapter is available.
        Otherwise it holds the captured error, or the generic
        ``INTEGRATION_ERROR`` text.
        """
        has_backend = self._backend is not None or self._backend_factory is not None
        available = has_backend and self._integration_error is None
        return {
            "backend": _BACKEND_ID,
            "baseline": _BACKEND_ID,
            "available": available,
            "integration_status": "integrated" if available else "unavailable",
            "integration_error": (
                None if available else (self._integration_error or INTEGRATION_ERROR)
            ),
            "delta_granularity": "ingest_turn_only",
            "snapshot_mode": "full_store",
        }

    def _require_backend(self) -> Any:
        """Return the live backend, or raise ``_runtime_error()`` if there is none."""
        if self._backend is None:
            raise self._runtime_error()
        return self._backend

    @staticmethod
    def _turn_text(turn: NormalizedTurn) -> str:
        """Render a turn as ``"role: text"`` plus one ``[type] caption`` line per attachment."""
        parts = [f"{turn.role}: {turn.text}"]
        for att in turn.attachments:
            parts.append(f"[{att.type}] {att.caption}")
        return "\n".join(parts)
|
|