# Provenance: uploaded as part of the eval_framework source drop (commit 85b19cf, verified).
"""Adapter for the external A-Mem baseline."""
from __future__ import annotations
import importlib
import os
import sys
from pathlib import Path
from typing import Any, Callable
from eval_framework.datasets.schemas import (
MemoryDeltaRecord,
MemorySnapshotRecord,
NormalizedTurn,
RetrievalItem,
RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter
_BACKEND_ID = "A-Mem"
INTEGRATION_ERROR = (
f"{_BACKEND_ID} backend unavailable."
)
class AMemAdapter(MemoryAdapter):
    """Thin wrapper around A-Mem's robust memory system.

    The adapter lazily imports A-Mem's ``memory_layer_robust`` module from a
    sibling checkout and exposes it through the eval framework's
    :class:`MemoryAdapter` interface.  Construction never raises: if the
    backend cannot be built, the failure is recorded in
    ``self._integration_error`` and surfaced either through
    :meth:`get_capabilities` or as a ``RuntimeError`` on first real use.
    """

    def __init__(
        self,
        *,
        backend: Any | None = None,
        backend_factory: Callable[[], Any] | None = None,
        source_root: str | os.PathLike[str] | None = None,
        model_name: str = "all-MiniLM-L6-v2",
        llm_backend: str = "openai",
        llm_model: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        sglang_host: str = "http://localhost",
        sglang_port: int = 30000,
    ) -> None:
        """Build the adapter, preferring an injected backend over a factory.

        Args:
            backend: Pre-built backend instance; skips factory construction.
            backend_factory: Zero-arg callable producing a backend; also used
                by :meth:`reset` to rebuild a fresh store.
            source_root: Directory containing ``memory_layer_robust.py``.
                Defaults to a path derived from this file's location.
            model_name: Sentence-embedding model passed to the backend.
            llm_backend / llm_model / api_key / api_base: LLM configuration;
                ``llm_model`` falls back to ``$OPENAI_MODEL`` then "gpt-5.1".
            sglang_host / sglang_port: SGLang server location, if used.
        """
        self._source_root = (
            Path(source_root).resolve() if source_root else self._default_source_root()
        )
        # Precedence: explicit argument > OPENAI_MODEL env var > hard default.
        resolved_llm_model = llm_model or os.getenv("OPENAI_MODEL") or "gpt-5.1"
        self._backend: Any | None = None
        self._backend_factory = backend_factory
        self._integration_error: str | None = None
        self._session_id = ""
        # Memory ids seen in the previous snapshot; used by export_memory_delta.
        self._prev_snapshot_ids: set[str] = set()
        # note_id -> session_id of the turn that created the note.
        self._note_session_map: dict[str, str] = {}
        if backend is not None:
            self._backend = backend
            return
        try:
            if self._backend_factory is None:
                self._backend_factory = self._build_backend_factory(
                    model_name=model_name,
                    llm_backend=llm_backend,
                    llm_model=resolved_llm_model,
                    api_key=api_key,
                    api_base=api_base,
                    sglang_host=sglang_host,
                    sglang_port=sglang_port,
                )
            self._backend = self._backend_factory()
        except Exception as exc:  # deliberate best-effort: defer to first use
            self._integration_error = str(exc)

    @staticmethod
    def _default_source_root() -> Path:
        """Locate the A-Mem checkout relative to this file.

        Layout: memory_adapters/ -> eval_framework/ -> our/ -> Benchmark/,
        so the fourth ancestor is the Benchmark root.
        """
        here = Path(__file__).resolve()
        # parents[3] is the Benchmark root (was written as parents[2].parent).
        # NOTE: "data_pipline" is the on-disk spelling — do not "fix" the typo.
        return (here.parents[3] / "data_pipline" / "A-mem").resolve()

    def _build_backend_factory(
        self,
        *,
        model_name: str,
        llm_backend: str,
        llm_model: str,
        api_key: str | None,
        api_base: str | None,
        sglang_host: str,
        sglang_port: int,
    ) -> Callable[[], Any]:
        """Import A-Mem from ``self._source_root`` and return a constructor.

        Raises:
            RuntimeError: If the source root directory does not exist.
            ImportError/AttributeError: If the module or class is missing
                (propagated to the caller, which records them).
        """
        if not self._source_root.is_dir():
            raise RuntimeError(
                f"{_BACKEND_ID}: source root not found at {self._source_root}"
            )
        src = str(self._source_root)
        # Make the checkout importable exactly once; insert at the front so it
        # shadows any same-named module elsewhere on the path.
        if src not in sys.path:
            sys.path.insert(0, src)
        mod = importlib.import_module("memory_layer_robust")
        backend_cls = getattr(mod, "RobustAgenticMemorySystem")
        # Credentials fall back to the standard OpenAI environment variables.
        return lambda: backend_cls(
            model_name=model_name,
            llm_backend=llm_backend,
            llm_model=llm_model,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            api_base=api_base or os.getenv("OPENAI_BASE_URL"),
            sglang_host=sglang_host,
            sglang_port=sglang_port,
        )

    def _runtime_error(self) -> RuntimeError:
        """Build the error raised when the backend is missing or broken."""
        detail = self._integration_error or INTEGRATION_ERROR
        return RuntimeError(
            f"{_BACKEND_ID}: backend unavailable — {detail}"
        )

    def reset(self) -> None:
        """Discard all adapter state and, if possible, rebuild the backend.

        Raises:
            RuntimeError: If neither a backend nor a factory is available.
        """
        if self._backend_factory is None and self._backend is None:
            raise self._runtime_error()
        if self._backend_factory is not None:
            self._backend = self._backend_factory()
            # Fix: a successful rebuild supersedes any construction-time
            # failure; otherwise get_capabilities() would keep reporting the
            # adapter as unavailable even though the backend now works.
            self._integration_error = None
        self._prev_snapshot_ids = set()
        self._note_session_map = {}
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one conversation turn as an A-Mem note."""
        backend = self._require_backend()
        self._session_id = turn.session_id
        text = self._turn_text(turn)
        note_id = backend.add_note(text, time=turn.timestamp)
        # Remember which session produced the note so snapshots can attribute it.
        self._note_session_map[str(note_id)] = turn.session_id

    def end_session(self, session_id: str) -> None:
        """Mark the end of a session; A-Mem needs no per-session flush."""
        self._require_backend()
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return the full memory store as normalized snapshot records.

        Each note's text is augmented with A-Mem's enrichments (context,
        keywords, tags) so the eval captures what the system actually
        processed, not just the raw input.
        """
        backend = self._require_backend()
        rows: list[MemorySnapshotRecord] = []
        for note_id, note in getattr(backend, "memories", {}).items():
            sid = self._note_session_map.get(str(note_id), self._session_id)
            content = str(getattr(note, "content", ""))
            context = getattr(note, "context", "")
            keywords = list(getattr(note, "keywords", []) or [])
            tags = list(getattr(note, "tags", []) or [])
            enriched_parts = [content]
            if context:
                enriched_parts.append(f"[context] {context}")
            if keywords:
                enriched_parts.append(f"[keywords] {', '.join(keywords)}")
            if tags:
                enriched_parts.append(f"[tags] {', '.join(tags)}")
            rows.append(
                MemorySnapshotRecord(
                    memory_id=str(getattr(note, "id", note_id)),
                    text="\n".join(enriched_parts),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=str(getattr(note, "id", note_id)),
                    raw_backend_type="a_mem_note",
                    metadata={
                        "timestamp": getattr(note, "timestamp", None),
                        "context": context,
                        "keywords": keywords,
                        "tags": tags,
                        "links": list(getattr(note, "links", []) or []),
                    },
                )
            )
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current snapshot against previous snapshot.

        A-Mem does not report fine-grained mutations, so only ``add`` ops for
        newly seen memory ids are emitted; in-place updates and deletions of
        existing notes are not detected.
        """
        self._require_backend()
        current_snapshot = self.snapshot_memories()
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=snap.text,
                linked_previous=(),
                raw_backend_id=snap.raw_backend_id,
                metadata={
                    "baseline": _BACKEND_ID,
                    "backend_type": snap.raw_backend_type,
                },
            )
            for snap in current_snapshot
            if snap.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {snap.memory_id for snap in current_snapshot}
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Retrieve up to ``top_k`` memories for ``query``.

        Prefers the backend's vector retriever (scores are synthetic 1/rank
        since the retriever returns indices only — assumes retriever indices
        align with ``memories`` insertion order; TODO confirm upstream).
        Falls back to ``find_related_memories_raw``, which returns one opaque
        text bundle, emitted as a single rank-0 item.
        """
        backend = self._require_backend()
        items: list[RetrievalItem] = []
        memories = list(getattr(backend, "memories", {}).values())
        retriever = getattr(backend, "retriever", None)
        if retriever is not None and hasattr(retriever, "search"):
            for rank, idx in enumerate(retriever.search(query, top_k)):
                if 0 <= int(idx) < len(memories):
                    note = memories[int(idx)]
                    items.append(
                        RetrievalItem(
                            rank=rank,
                            memory_id=str(getattr(note, "id", idx)),
                            text=str(getattr(note, "content", "")),
                            score=1.0 / float(rank + 1),
                            raw_backend_id=str(getattr(note, "id", idx)),
                        )
                    )
        if not items and hasattr(backend, "find_related_memories_raw"):
            raw = backend.find_related_memories_raw(query, k=top_k)
            if raw:
                items.append(
                    RetrievalItem(
                        rank=0,
                        memory_id="a_mem:bundle",
                        text=str(raw),
                        score=1.0,
                        raw_backend_id=None,
                    )
                )
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:top_k],
            raw_trace={"baseline": _BACKEND_ID},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe backend availability and the adapter's delta semantics."""
        available = self._backend is not None or self._backend_factory is not None
        ok = available and self._integration_error is None
        return {
            "backend": _BACKEND_ID,
            "baseline": _BACKEND_ID,
            "available": ok,
            "integration_status": "integrated" if ok else "unavailable",
            # Always carries a message, even when integrated (kept for
            # backward compatibility with existing consumers).
            "integration_error": self._integration_error or INTEGRATION_ERROR,
            "delta_granularity": "ingest_turn_only",
            "snapshot_mode": "full_store",
        }

    def _require_backend(self) -> Any:
        """Return the live backend or raise a descriptive RuntimeError."""
        if self._backend is None:
            raise self._runtime_error()
        return self._backend

    @staticmethod
    def _turn_text(turn: NormalizedTurn) -> str:
        """Flatten a turn (role, text, attachment captions) into one string."""
        parts = [f"{turn.role}: {turn.text}"]
        for att in turn.attachments:
            parts.append(f"[{att.type}] {att.caption}")
        return "\n".join(parts)