# LCZZZZ's picture
# Upload eval_framework source code
# 85b19cf verified
"""Adapter for A-Mem (new API: agentic_memory.AgenticMemorySystem)."""
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
# Load environment variables from the `.env` file in
# `Path(__file__).resolve().parents[2]`, i.e. three directories above this
# file (presumably the project root — verify). AMemV2Adapter._init_backend
# reads OPENAI_MODEL and OPENAI_API_KEY from the environment.
load_dotenv(Path(__file__).resolve().parents[2] / ".env")
from eval_framework.datasets.schemas import (
MemoryDeltaRecord,
MemorySnapshotRecord,
NormalizedTurn,
RetrievalItem,
RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter
# Default root of the A-Mem checkout, i.e. the directory that contains the
# `agentic_memory` package (it is put on sys.path before importing it).
# Set AMEM_SOURCE_ROOT in the environment (or in the .env file loaded
# earlier) to point elsewhere; otherwise the original machine-specific path
# is used.
_DEFAULT_SOURCE = Path(
    os.getenv("AMEM_SOURCE_ROOT") or "/data1/toby/nips26/baselines/A-Mem"
)
class AMemV2Adapter(MemoryAdapter):
    """Adapter for A-Mem (new agentic_memory API).

    Wraps ``agentic_memory.memory_system.AgenticMemorySystem`` behind the
    eval framework's ``MemoryAdapter`` interface:

    * each ingested turn becomes one A-Mem note (``add_note``);
    * snapshots list every note currently in ``backend.memories``;
    * deltas are computed by diffing snapshot ids between calls.
    """

    def __init__(
        self,
        *,
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Import A-Mem from *source_root* and build a fresh backend.

        Args:
            source_root: Directory containing the ``agentic_memory`` package.
                Defaults to ``_DEFAULT_SOURCE``. It is prepended to
                ``sys.path`` (only if not already present).
            **kwargs: Accepted and ignored (presumably so callers can pass
                adapter-agnostic options — confirm against the caller).
        """
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))
        # Imported here because the package only becomes importable once
        # `root` is on sys.path.
        from agentic_memory.memory_system import AgenticMemorySystem

        self._cls = AgenticMemorySystem
        self._backend: Any = None
        # Session id from the most recent ingest_turn/end_session call; used
        # to tag snapshot rows.
        self._session_id = ""
        # Memory ids observed by the previous export_memory_delta call.
        self._prev_snapshot_ids: set[str] = set()
        self._init_backend()

    def _init_backend(self) -> None:
        """Create a new A-Mem backend, replacing the current one.

        The LLM model comes from ``OPENAI_MODEL`` (default ``"gpt-4o"``) and
        the API key from ``OPENAI_API_KEY``. ``model_name`` is presumably the
        sentence-embedding model used for retrieval — confirm in A-Mem.
        """
        self._backend = self._cls(
            model_name="all-MiniLM-L6-v2",
            llm_backend="openai",
            llm_model=os.getenv("OPENAI_MODEL") or "gpt-4o",
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def reset(self) -> None:
        """Discard all memories and return to the freshly-constructed state."""
        self._init_backend()
        self._session_id = ""
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one dialogue turn as a single A-Mem note.

        The note text is ``"<role>: <text>"`` followed by one
        ``"[<type>] <caption>"`` line per attachment; ``turn.timestamp`` is
        passed through as the note time.
        """
        self._session_id = turn.session_id
        lines = [f"{turn.role}: {turn.text}"]
        for att in turn.attachments:
            # Without the `or ""`, a missing caption would be stored as "None".
            lines.append(f"[{att.type}] {att.caption or ''}")
        self._backend.add_note("\n".join(lines), time=turn.timestamp)

    def end_session(self, session_id: str) -> None:
        """Record *session_id* as the current session; no backend call."""
        self._session_id = session_id

    @staticmethod
    def _render_note(note: Any) -> str:
        """Flatten a note's content, context and keywords into one string."""
        parts = [str(getattr(note, "content", ""))]
        context = getattr(note, "context", "")
        if context:
            parts.append(f"[context] {context}")
        keywords = getattr(note, "keywords", None) or []
        if isinstance(keywords, str):
            # list("abc") would otherwise split a bare string into characters.
            keywords = [keywords]
        if keywords:
            parts.append(f"[keywords] {', '.join(str(k) for k in keywords)}")
        return "\n".join(parts)

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one active snapshot row per note currently in the backend.

        Row text is the note content plus optional ``[context]`` and
        ``[keywords]`` lines. NOTE(review): every row is tagged with the most
        recently seen session id, not the session the note was created in.
        """
        return [
            MemorySnapshotRecord(
                memory_id=str(mid),
                text=self._render_note(note),
                session_id=self._session_id,
                status="active",
                source="A-Mem",
                raw_backend_id=str(mid),
                raw_backend_type="a_mem_note",
                metadata={},
            )
            for mid, note in self._backend.memories.items()
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Report notes that appeared since the previous call as ``add`` ops.

        The delta is an id-level diff of snapshots, so notes whose text
        changed in place, or that were removed, are not reported.
        """
        current = self.snapshot_memories()
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=s.text,
                linked_previous=(),
                raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "A-Mem"},
            )
            for s in current
            if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {s.memory_id for s in current}
        return deltas

    def _search_items(self, query: str, top_k: int) -> list[RetrievalItem]:
        """Convert ``backend.search`` results into ranked retrieval items."""
        results = self._backend.search(query, k=top_k)
        items: list[RetrievalItem] = []
        for i, r in enumerate(results[:top_k]):
            default_score = 1.0 / (i + 1)
            if isinstance(r, dict):
                text = r.get("content", str(r))
                mid = r.get("id", str(i))
                # NOTE(review): treated as higher-is-better; confirm the
                # backend's "score" is not a vector distance.
                score = float(r.get("score", default_score))
            else:
                text, mid, score = str(r), str(i), default_score
            items.append(RetrievalItem(
                rank=i, memory_id=str(mid), text=text,
                score=score, raw_backend_id=str(mid),
            ))
        return items

    def _fallback_items(self, query: str, top_k: int) -> list[RetrievalItem]:
        """Wrap ``find_related_memories_raw`` output as one "bundle" item."""
        raw = self._backend.find_related_memories_raw(query, k=top_k)
        if not raw:
            return []
        return [RetrievalItem(
            rank=0, memory_id="bundle", text=str(raw),
            score=1.0, raw_backend_id=None,
        )]

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Return up to *top_k* memories relevant to *query*.

        Tries ``backend.search`` first. If that raises, falls back to a single
        "bundle" item built from ``backend.find_related_memories_raw``.
        Retrieval is best-effort: failures are not raised but are listed under
        ``raw_trace["errors"]``. A non-positive *top_k* yields no items.
        """
        raw_trace: dict[str, Any] = {"baseline": "A-Mem"}
        if top_k <= 0:
            # Negative slice bounds would drop trailing items, not return none.
            return RetrievalRecord(
                query=query, top_k=top_k, items=[], raw_trace=raw_trace,
            )
        errors: list[str] = []
        try:
            # Items are only used if the whole conversion succeeds, so partial
            # search results are never mixed with the fallback bundle.
            items = self._search_items(query, top_k)
        except Exception as exc:
            errors.append(f"search: {exc!r}")
            try:
                items = self._fallback_items(query, top_k)
            except Exception as fallback_exc:
                errors.append(f"find_related_memories_raw: {fallback_exc!r}")
                items = []
        if errors:
            raw_trace["errors"] = errors
        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k], raw_trace=raw_trace,
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter's backend and delta/snapshot behavior."""
        return {
            "backend": "A-Mem",
            "baseline": "A-Mem",
            "available": self._backend is not None,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }