# eval_framework/memory_adapters/mem0_adapter.py
# (Hugging Face upload header: uploaded by LCZZZZ, commit 85b19cf)
"""Adapters for Mem0 and Mem0-Graph baselines."""
from __future__ import annotations
import os
import uuid as _uuid
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
# Load environment variables (OPENAI_API_KEY, OPENAI_MODEL, OPENAI_BASE_URL)
# from the repo-root .env at import time, before the config below reads them.
load_dotenv(Path(__file__).resolve().parents[2] / ".env")
from eval_framework.datasets.schemas import (
MemoryDeltaRecord,
MemorySnapshotRecord,
NormalizedTurn,
RetrievalItem,
RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter
class Mem0Adapter(MemoryAdapter):
    """Adapter for Mem0 (vector mode) and Mem0-Graph (``use_graph=True``).

    Wraps a ``mem0.Memory`` instance configured from environment variables
    (``OPENAI_API_KEY``, ``OPENAI_MODEL``, ``OPENAI_BASE_URL``) and exposes
    it through the :class:`MemoryAdapter` interface used by the eval harness.
    Each adapter instance operates under a fresh random pseudo-user id so
    concurrent or repeated runs never share backend state.
    """

    # Truncation limit (characters) for ingested turns; overly long inputs
    # have been observed to break Mem0's graph entity extraction.
    _MAX_TURN_CHARS = 2000

    def __init__(self, *, use_graph: bool = False, **kwargs: Any) -> None:
        """Build the Mem0 backend.

        Args:
            use_graph: When True, additionally configure a Kuzu graph store
                so Mem0 extracts entity relations (the "Mem0-Graph" baseline).
            **kwargs: Accepted for adapter-interface compatibility; unused.
        """
        # Local import: mem0 is an optional, baseline-specific dependency.
        from mem0 import Memory

        # Fresh pseudo-user per adapter instance keeps runs isolated.
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._session_id = ""
        # Memory ids seen at the last delta export; used for snapshot diffing.
        self._prev_snapshot_ids: set[str] = set()
        config: dict[str, Any] = {
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv("OPENAI_MODEL") or "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": "text-embedding-3-small",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                    "embedding_dims": 1536,
                },
            },
        }
        # Optional custom gateway (e.g. a proxy) applies to both LLM and embedder.
        base_url = os.getenv("OPENAI_BASE_URL")
        if base_url:
            config["llm"]["config"]["openai_base_url"] = base_url
            config["embedder"]["config"]["openai_base_url"] = base_url
        if use_graph:
            config["graph_store"] = {
                "provider": "kuzu",
                "config": {
                    "url": "/tmp/mem0_kuzu_eval",
                },
            }
        self._memory = Memory.from_config(config)
        self._use_graph = use_graph

    @property
    def _baseline_name(self) -> str:
        """Human-readable baseline label used in records and capabilities."""
        return "Mem0-Graph" if self._use_graph else "Mem0"

    @staticmethod
    def _relation_text(rel: dict[str, Any]) -> str:
        """Render a graph-relation dict as readable ``source relation target``.

        Accepts either ``target`` or ``destination`` for the object node
        (Mem0 versions differ on the key name). Joins the triple with spaces;
        plain concatenation produced unreadable strings like "aliceknowsbob".
        """
        src = rel.get("source", "")
        rtype = rel.get("relationship", "")
        tgt = rel.get("target") or rel.get("destination", "")
        return f"{src} {rtype} {tgt}"

    def reset(self) -> None:
        """Wipe the current user's memories and start over under a new user."""
        self._memory.delete_all(user_id=self._user_id)
        # Rotate the user id so stale backend state can never leak across runs.
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._session_id = ""
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Feed one conversation turn (plus attachment captions) into Mem0."""
        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        for att in turn.attachments:
            text += f"\n[{att.type}] {att.caption}"
        # Truncate to avoid excessively long inputs that break graph entity
        # extraction.
        text = text[: self._MAX_TURN_CHARS]
        try:
            self._memory.add(
                messages=[{"role": turn.role, "content": text}],
                user_id=self._user_id,
            )
        except Exception:
            # Best-effort ingestion: graph mode can fail on entity embedding,
            # and a single bad turn should not abort the whole eval run.
            pass

    def end_session(self, session_id: str) -> None:
        """Record the id of the session that just ended."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return every stored memory: vector facts plus graph relations."""
        all_mems = self._memory.get_all(user_id=self._user_id)
        rows: list[MemorySnapshotRecord] = []
        # Vector results (standard mode). Newer mem0 versions return a dict
        # with a "results" key; older ones return the list directly.
        results = all_mems.get("results", []) if isinstance(all_mems, dict) else all_mems
        for mem in results:
            mid = str(mem.get("id", ""))
            rows.append(MemorySnapshotRecord(
                memory_id=mid, text=str(mem.get("memory", "")),
                session_id=self._session_id, status="active",
                source="Mem0", raw_backend_id=mid,
                raw_backend_type="mem0_vector", metadata={},
            ))
        # Graph relations (graph mode only). Relations carry no backend id,
        # so a positional surrogate id is used.
        relations = all_mems.get("relations", []) if isinstance(all_mems, dict) else []
        for i, rel in enumerate(relations):
            if isinstance(rel, dict):
                mid = f"rel_{i}"
                rows.append(MemorySnapshotRecord(
                    memory_id=mid, text=self._relation_text(rel),
                    session_id=self._session_id, status="active",
                    source="Mem0-Graph", raw_backend_id=mid,
                    raw_backend_type="mem0_graph_relation", metadata=rel,
                ))
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Diff the current snapshot against the previous export.

        Mem0 exposes no native change log, so deltas are approximated as
        "add" ops for every memory id not seen at the last export. Deletions
        and updates are therefore invisible to this adapter.
        """
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=s.text,
                linked_previous=(),
                raw_backend_id=s.raw_backend_id,
                metadata={"baseline": self._baseline_name},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search memories for ``query`` and return at most ``top_k`` items.

        Vector hits are ranked first; graph relations (when present) fill any
        remaining slots with a fixed score, since Mem0 does not score them.
        """
        results = self._memory.search(query=query, user_id=self._user_id, limit=top_k)
        items: list[RetrievalItem] = []
        # Vector results (same dict-or-list duality as get_all).
        search_results = results.get("results", []) if isinstance(results, dict) else results
        for i, r in enumerate(search_results[:top_k]):
            items.append(RetrievalItem(
                rank=len(items),
                memory_id=str(r.get("id", i)),
                text=str(r.get("memory", "")),
                # Fall back to a rank-based pseudo-score if mem0 omits one.
                score=float(r.get("score", 1.0 / (i + 1))),
                raw_backend_id=str(r.get("id", "")),
            ))
        # Graph relations fill remaining slots (graph mode only).
        relations = results.get("relations", []) if isinstance(results, dict) else []
        for rel in relations:
            if isinstance(rel, dict) and len(items) < top_k:
                items.append(RetrievalItem(
                    rank=len(items),
                    memory_id=f"rel_{len(items)}",
                    text=self._relation_text(rel),
                    score=0.9,  # mem0 graph search returns no score
                    raw_backend_id=None,
                ))
        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k],
            raw_trace={"baseline": self._baseline_name},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend's identity and delta semantics to the harness."""
        name = self._baseline_name
        return {
            "backend": name,
            "baseline": name,
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }