"""Adapter for the external A-Mem baseline."""

from __future__ import annotations

import importlib
import os
import sys
from pathlib import Path
from typing import Any, Callable

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter

# Identifier used as the `source`/`baseline` tag on every record this adapter emits.
_BACKEND_ID = "A-Mem"
# Generic fallback message when no specific construction error was recorded.
INTEGRATION_ERROR = (
    f"{_BACKEND_ID} backend unavailable."
)


class AMemAdapter(MemoryAdapter):
    """Thin wrapper around A-Mem's robust memory system.

    The backend comes from one of three places, in priority order:

    1. an instance passed as ``backend``;
    2. the result of calling ``backend_factory``;
    3. ``memory_layer_robust.RobustAgenticMemorySystem``, imported from
       ``source_root`` (``source_root`` is prepended to ``sys.path``).

    Construction failures in ``__init__`` do not raise. The error message is
    stored, reported by :meth:`get_capabilities`, and included in the
    ``RuntimeError`` raised by any method that needs the backend.
    """

    def __init__(
        self,
        *,
        backend: Any | None = None,
        backend_factory: Callable[[], Any] | None = None,
        source_root: str | os.PathLike[str] | None = None,
        model_name: str = "all-MiniLM-L6-v2",
        llm_backend: str = "openai",
        llm_model: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        sglang_host: str = "http://localhost",
        sglang_port: int = 30000,
    ) -> None:
        """Create the adapter and try to build the backend.

        Args:
            backend: Ready-made backend instance. When given, no factory is
                built. Without ``backend_factory``, :meth:`reset` reuses this
                instance instead of rebuilding it.
            backend_factory: Zero-argument callable that returns a fresh
                backend. When omitted (and ``backend`` is None), a factory
                for ``RobustAgenticMemorySystem`` is built from ``source_root``.
            source_root: Directory containing the A-Mem sources. Defaults to
                :meth:`_default_source_root`.
            model_name, llm_backend, llm_model, api_key, api_base,
            sglang_host, sglang_port: Forwarded to the default backend
                constructor. ``llm_model`` falls back to the
                ``OPENAI_MODEL`` environment variable, then ``"gpt-5.1"``.
        """
        self._source_root = (
            Path(source_root).resolve() if source_root else self._default_source_root()
        )
        resolved_llm_model = llm_model or os.getenv("OPENAI_MODEL") or "gpt-5.1"
        self._backend: Any | None = None
        self._backend_factory = backend_factory
        # Message of the last construction failure; None while the backend is healthy.
        self._integration_error: str | None = None
        # Latest session id seen. Used as the fallback session for notes
        # whose originating turn was not recorded in _note_session_map.
        self._session_id = ""
        # Memory ids seen by the previous export_memory_delta() call.
        self._prev_snapshot_ids: set[str] = set()
        # str(note id) -> session id of the turn that created the note.
        self._note_session_map: dict[str, str] = {}

        if backend is not None:
            self._backend = backend
        else:
            try:
                if self._backend_factory is None:
                    self._backend_factory = self._build_backend_factory(
                        model_name=model_name,
                        llm_backend=llm_backend,
                        llm_model=resolved_llm_model,
                        api_key=api_key,
                        api_base=api_base,
                        sglang_host=sglang_host,
                        sglang_port=sglang_port,
                    )
                self._backend = self._backend_factory()
            except Exception as exc:
                # Deliberately best-effort: an unavailable baseline is reported
                # through get_capabilities() instead of aborting construction.
                self._integration_error = str(exc)

    @staticmethod
    def _default_source_root() -> Path:
        """Return ``<Benchmark>/data_pipline/A-mem``, located relative to this file."""
        here = Path(__file__).resolve()
        # memory_adapters/ -> eval_framework/ -> our/ -> Benchmark/
        return (here.parents[2].parent / "data_pipline" / "A-mem").resolve()

    def _build_backend_factory(
        self,
        *,
        model_name: str,
        llm_backend: str,
        llm_model: str,
        api_key: str | None,
        api_base: str | None,
        sglang_host: str,
        sglang_port: int,
    ) -> Callable[[], Any]:
        """Import ``RobustAgenticMemorySystem`` and return a zero-argument factory for it.

        The API key and base URL fall back to ``OPENAI_API_KEY`` and
        ``OPENAI_BASE_URL``. These environment variables are read each time
        the factory is called, not when it is built.

        Raises:
            RuntimeError: if ``self._source_root`` is not a directory.
            ImportError / AttributeError: if the module or class is missing.
        """
        if not self._source_root.is_dir():
            raise RuntimeError(
                f"{_BACKEND_ID}: source root not found at {self._source_root}"
            )
        src = str(self._source_root)
        if src not in sys.path:
            sys.path.insert(0, src)
        mod = importlib.import_module("memory_layer_robust")
        backend_cls = getattr(mod, "RobustAgenticMemorySystem")
        return lambda: backend_cls(
            model_name=model_name,
            llm_backend=llm_backend,
            llm_model=llm_model,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            api_base=api_base or os.getenv("OPENAI_BASE_URL"),
            sglang_host=sglang_host,
            sglang_port=sglang_port,
        )

    def _runtime_error(self) -> RuntimeError:
        """Build the error raised when the backend is needed but unavailable."""
        detail = self._integration_error or INTEGRATION_ERROR
        return RuntimeError(
            f"{_BACKEND_ID}: backend unavailable — {detail}"
        )

    def reset(self) -> None:
        """Rebuild the backend (when a factory exists) and clear per-run bookkeeping.

        An injected backend with no factory is kept as-is. Its stored notes
        are not cleared here.

        Raises:
            RuntimeError: if there is neither a backend nor a factory.
            Exception: whatever the factory raises. Its message is also
                recorded as the integration error before re-raising.
        """
        if self._backend_factory is None and self._backend is None:
            raise self._runtime_error()
        if self._backend_factory is not None:
            try:
                self._backend = self._backend_factory()
            except Exception as exc:
                self._integration_error = str(exc)
                raise
            # A successful rebuild supersedes any earlier construction
            # failure. Otherwise get_capabilities() would keep reporting a
            # working backend as unavailable.
            self._integration_error = None
        self._prev_snapshot_ids = set()
        self._note_session_map = {}
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn as an A-Mem note and remember the session that produced it."""
        backend = self._require_backend()
        self._session_id = turn.session_id
        text = self._turn_text(turn)
        note_id = backend.add_note(text, time=turn.timestamp)
        self._note_session_map[str(note_id)] = turn.session_id

    def end_session(self, session_id: str) -> None:
        """Record ``session_id`` as current. No backend call is made beyond the availability check."""
        self._require_backend()
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return every note in ``backend.memories`` as an active snapshot record.

        Each record's text is the note content followed by the A-Mem
        enrichments (context, keywords, tags) when present. Notes not created
        through :meth:`ingest_turn` are attributed to the current session.
        """
        backend = self._require_backend()
        rows: list[MemorySnapshotRecord] = []
        for note_id, note in getattr(backend, "memories", {}).items():
            sid = self._note_session_map.get(str(note_id), self._session_id)
            memory_id = str(getattr(note, "id", note_id))
            content = str(getattr(note, "content", ""))
            context = getattr(note, "context", "")
            keywords = list(getattr(note, "keywords", []) or [])
            tags = list(getattr(note, "tags", []) or [])

            # Include A-Mem enrichments in the snapshot text so that the
            # eval captures what the system actually processed, not just
            # the raw input.
            enriched_parts = [content]
            if context:
                enriched_parts.append(f"[context] {context}")
            if keywords:
                # str() each entry so non-string keywords cannot break join().
                enriched_parts.append(f"[keywords] {', '.join(map(str, keywords))}")
            if tags:
                enriched_parts.append(f"[tags] {', '.join(map(str, tags))}")

            rows.append(
                MemorySnapshotRecord(
                    memory_id=memory_id,
                    text="\n".join(enriched_parts),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=memory_id,
                    raw_backend_type="a_mem_note",
                    metadata={
                        "timestamp": getattr(note, "timestamp", None),
                        "context": context,
                        "keywords": keywords,
                        "tags": tags,
                        "links": list(getattr(note, "links", []) or []),
                    },
                )
            )
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current snapshot against previous snapshot.

        Only newly appearing memory ids are reported, each as an ``"add"``
        op tagged with ``session_id``. Notes that were removed, or whose text
        changed, since the previous call produce no record.
        """
        self._require_backend()
        current_snapshot = self.snapshot_memories()
        deltas: list[MemoryDeltaRecord] = []
        current_ids: set[str] = set()
        for snap in current_snapshot:
            current_ids.add(snap.memory_id)
            if snap.memory_id in self._prev_snapshot_ids:
                continue
            deltas.append(
                MemoryDeltaRecord(
                    session_id=session_id,
                    op="add",
                    text=snap.text,
                    linked_previous=(),
                    raw_backend_id=snap.raw_backend_id,
                    metadata={
                        "baseline": _BACKEND_ID,
                        "backend_type": snap.raw_backend_type,
                    },
                )
            )
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Return up to ``top_k`` memories relevant to ``query``.

        Per-note hits come from ``backend.retriever.search`` when available.
        Ranks are dense (0, 1, 2, …) and scores are ``1 / (rank + 1)``. When
        that yields nothing, the raw text from
        ``backend.find_related_memories_raw`` is returned as one bundled item.
        A non-positive ``top_k`` yields no items.
        """
        backend = self._require_backend()
        items: list[RetrievalItem] = []

        if top_k > 0:
            retriever = getattr(backend, "retriever", None)
            if retriever is not None and hasattr(retriever, "search"):
                # NOTE(review): assumes retriever indices follow the iteration
                # order of backend.memories — confirm against A-Mem.
                memories = list(getattr(backend, "memories", {}).values())
                for idx in retriever.search(query, top_k):
                    pos = int(idx)
                    if not 0 <= pos < len(memories):
                        continue
                    note = memories[pos]
                    # Rank by position in the result list, not in the raw hit
                    # list, so skipped hits do not leave gaps in ranks/scores.
                    rank = len(items)
                    items.append(
                        RetrievalItem(
                            rank=rank,
                            memory_id=str(getattr(note, "id", idx)),
                            text=str(getattr(note, "content", "")),
                            score=1.0 / float(rank + 1),
                            raw_backend_id=str(getattr(note, "id", idx)),
                        )
                    )

            if not items and hasattr(backend, "find_related_memories_raw"):
                raw = backend.find_related_memories_raw(query, k=top_k)
                if raw:
                    items.append(
                        RetrievalItem(
                            rank=0,
                            memory_id="a_mem:bundle",
                            text=str(raw),
                            score=1.0,
                            raw_backend_id=None,
                        )
                    )

        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:top_k],
            raw_trace={"baseline": _BACKEND_ID},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter and whether its backend is usable.

        ``available`` is True only when a backend or factory exists and no
        construction error is recorded.
        """
        has_backend = self._backend is not None or self._backend_factory is not None
        healthy = has_backend and self._integration_error is None
        return {
            "backend": _BACKEND_ID,
            "baseline": _BACKEND_ID,
            "available": healthy,
            "integration_status": "integrated" if healthy else "unavailable",
            "integration_error": self._integration_error or INTEGRATION_ERROR,
            "delta_granularity": "ingest_turn_only",
            "snapshot_mode": "full_store",
        }

    def _require_backend(self) -> Any:
        """Return the live backend or raise the adapter's ``RuntimeError``."""
        if self._backend is None:
            raise self._runtime_error()
        return self._backend

    @staticmethod
    def _turn_text(turn: NormalizedTurn) -> str:
        """Render a turn as ``"role: text"`` plus one ``"[type] caption"`` line per attachment."""
        parts = [f"{turn.role}: {turn.text}"]
        for att in turn.attachments:
            parts.append(f"[{att.type}] {att.caption}")
        return "\n".join(parts)