"""Adapter for A-Mem (new API: agentic_memory.AgenticMemorySystem)."""

from __future__ import annotations

import logging
import os
import sys
from pathlib import Path
from typing import Any

from dotenv import load_dotenv

load_dotenv(Path(__file__).resolve().parents[2] / ".env")

# noqa E402: these imports deliberately follow load_dotenv, as in the original
# ordering (presumably so .env values are loaded first — confirm).
from eval_framework.datasets.schemas import (  # noqa: E402
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter  # noqa: E402

logger = logging.getLogger(__name__)

# Default checkout location of the A-Mem repository; override via ``source_root``.
_DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/A-Mem")


class AMemV2Adapter(MemoryAdapter):
    """Adapter for A-Mem (new agentic_memory API).

    Each ingested turn is stored as one A-Mem note. Memory deltas are computed
    by diffing successive snapshots of ``backend.memories`` by memory id.
    """

    def __init__(
        self,
        *,
        source_root: str | os.PathLike[str] | None = None,
        embedding_model: str = "all-MiniLM-L6-v2",
        llm_backend: str = "openai",
        llm_model: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Import A-Mem from *source_root* and build a fresh backend.

        Args:
            source_root: Directory containing the ``agentic_memory`` package;
                defaults to ``_DEFAULT_SOURCE``. It is prepended to
                ``sys.path`` if not already present.
            embedding_model: Passed to the backend as ``model_name``.
            llm_backend: Passed to the backend as ``llm_backend``.
            llm_model: Passed to the backend as ``llm_model``. When falsy, the
                ``OPENAI_MODEL`` environment variable is used, else
                ``"gpt-4o"``; this is resolved every time the backend is built.
            **kwargs: Accepted and ignored.
        """
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))
        # Imported lazily because the package only becomes importable after
        # the sys.path insertion above.
        from agentic_memory.memory_system import AgenticMemorySystem

        self._cls = AgenticMemorySystem
        self._embedding_model = embedding_model
        self._llm_backend = llm_backend
        self._llm_model = llm_model
        self._backend: Any = None
        self._session_id = ""
        # Memory ids seen by the previous export_memory_delta call.
        self._prev_snapshot_ids: set[str] = set()
        self._init_backend()

    def _init_backend(self) -> None:
        """Construct a new backend instance and bind it to ``self._backend``."""
        self._backend = self._cls(
            model_name=self._embedding_model,
            llm_backend=self._llm_backend,
            llm_model=self._llm_model or os.getenv("OPENAI_MODEL") or "gpt-4o",
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def reset(self) -> None:
        """Rebuild the backend and clear all adapter-side bookkeeping."""
        self._init_backend()
        self._prev_snapshot_ids = set()
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store *turn* as a single A-Mem note.

        The note text is ``"<role>: <text>"`` followed by one
        ``"[<type>] <caption>"`` line per attachment. The turn's session id
        becomes the adapter's current session id.
        """
        self._session_id = turn.session_id
        lines = [f"{turn.role}: {turn.text}"]
        # A missing caption would otherwise be stored as the literal "None".
        lines.extend(f"[{att.type}] {att.caption or ''}" for att in turn.attachments)
        self._backend.add_note("\n".join(lines), time=turn.timestamp)

    def end_session(self, session_id: str) -> None:
        """Record *session_id* as the adapter's current session id."""
        self._session_id = session_id

    @staticmethod
    def _format_note(note: Any) -> str:
        """Render a note's content plus optional ``[context]``/``[keywords]`` lines."""
        parts = [str(getattr(note, "content", ""))]
        context = getattr(note, "context", "")
        if context:
            parts.append(f"[context] {context}")
        keywords = getattr(note, "keywords", None) or []
        if isinstance(keywords, str):
            # list("abc") would split a bare string into single characters.
            keywords = [keywords]
        if keywords:
            parts.append(f"[keywords] {', '.join(str(k) for k in keywords)}")
        return "\n".join(parts)

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one ``active`` snapshot record per note in ``backend.memories``.

        NOTE(review): every record is tagged with the adapter's current
        session id (the last one seen by ``ingest_turn``/``end_session``),
        not the session in which the note was created — confirm this is the
        intended meaning of ``session_id`` here.
        """
        return [
            MemorySnapshotRecord(
                memory_id=str(mid),
                text=self._format_note(note),
                session_id=self._session_id,
                status="active",
                source="A-Mem",
                raw_backend_id=str(mid),
                raw_backend_type="a_mem_note",
                metadata={},
            )
            for mid, note in self._backend.memories.items()
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return ``add`` records for memories not seen by the previous call.

        The "previous" id set is cleared by ``reset``. Only additions are
        detected: notes whose text changed in place, or that disappeared,
        produce no record.
        """
        current = self.snapshot_memories()
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=s.text,
                linked_previous=(),
                raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "A-Mem"},
            )
            for s in current
            if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {s.memory_id for s in current}
        return deltas

    @staticmethod
    def _to_item(rank: int, result: Any) -> RetrievalItem:
        """Convert one ``backend.search`` result into a ``RetrievalItem``.

        Dict results supply ``content``/``id``/``score``; anything else is
        stringified. A missing or non-numeric score defaults to ``1/(rank+1)``.
        """
        default_score = 1.0 / (rank + 1)
        if isinstance(result, dict):
            text = result.get("content", str(result))
            mid = result.get("id", str(rank))
            try:
                score = float(result.get("score", default_score))
            except (TypeError, ValueError):
                score = default_score
        else:
            text, mid, score = str(result), str(rank), default_score
        # NOTE(review): confirm whether the backend "score" is a similarity
        # (higher is better) or a distance; it is passed through unchanged.
        return RetrievalItem(
            rank=rank,
            memory_id=str(mid),
            text=text,
            score=score,
            raw_backend_id=str(mid),
        )

    def _fallback_items(
        self, query: str, top_k: int, raw_trace: dict[str, Any]
    ) -> list[RetrievalItem]:
        """Best-effort fallback wrapping ``find_related_memories_raw`` as one item.

        Returns an empty list when the raw result is falsy or the call
        raises; in the latter case the error is logged and written to
        ``raw_trace["fallback_error"]``.
        """
        try:
            raw = self._backend.find_related_memories_raw(query, k=top_k)
        except Exception as exc:  # best-effort: never fail retrieval here
            logger.warning("A-Mem raw fallback failed: %s", exc)
            raw_trace["fallback_error"] = repr(exc)
            return []
        if not raw:
            return []
        return [
            RetrievalItem(
                rank=0,
                memory_id="bundle",
                text=str(raw),
                score=1.0,
                raw_backend_id=None,
            )
        ]

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Return up to *top_k* memories relevant to *query*.

        Uses ``backend.search``. If that call raises, falls back to a single
        ``"bundle"`` item holding the string form of
        ``backend.find_related_memories_raw``. Backend failures are logged and
        recorded under ``raw_trace["error"]`` instead of being raised.
        A non-positive *top_k* returns an empty record without querying.
        """
        raw_trace: dict[str, Any] = {"baseline": "A-Mem"}
        if top_k <= 0:
            return RetrievalRecord(
                query=query, top_k=top_k, items=[], raw_trace=raw_trace
            )
        # Keep the try narrow: a failure here means the backend search itself
        # failed, so the fallback never mixes with partially converted results.
        try:
            results = self._backend.search(query, k=top_k)
        except Exception as exc:  # backend errors degrade to the fallback
            logger.warning("A-Mem search failed; using raw fallback: %s", exc)
            raw_trace["error"] = repr(exc)
            items = self._fallback_items(query, top_k, raw_trace)
        else:
            items = [
                self._to_item(i, r)
                for i, r in enumerate(list(results or [])[:top_k])
            ]
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:top_k],
            raw_trace=raw_trace,
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter's backend and delta/snapshot behaviour."""
        return {
            "backend": "A-Mem",
            "baseline": "A-Mem",
            "available": self._backend is not None,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }