"""Adapters for Mem0 and Mem0-Graph baselines.""" from __future__ import annotations import os import uuid as _uuid from pathlib import Path from typing import Any from dotenv import load_dotenv load_dotenv(Path(__file__).resolve().parents[2] / ".env") from eval_framework.datasets.schemas import ( MemoryDeltaRecord, MemorySnapshotRecord, NormalizedTurn, RetrievalItem, RetrievalRecord, ) from eval_framework.memory_adapters.base import MemoryAdapter class Mem0Adapter(MemoryAdapter): """Adapter for Mem0 (vector mode).""" def __init__(self, *, use_graph: bool = False, **kwargs: Any) -> None: from mem0 import Memory self._user_id = f"eval_{_uuid.uuid4().hex[:8]}" self._session_id = "" self._prev_snapshot_ids: set[str] = set() config: dict[str, Any] = { "llm": { "provider": "openai", "config": { "model": os.getenv("OPENAI_MODEL") or "gpt-4o", "api_key": os.getenv("OPENAI_API_KEY") or "", }, }, "embedder": { "provider": "openai", "config": { "model": "text-embedding-3-small", "api_key": os.getenv("OPENAI_API_KEY") or "", "embedding_dims": 1536, }, }, } base_url = os.getenv("OPENAI_BASE_URL") if base_url: config["llm"]["config"]["openai_base_url"] = base_url config["embedder"]["config"]["openai_base_url"] = base_url if use_graph: config["graph_store"] = { "provider": "kuzu", "config": { "url": "/tmp/mem0_kuzu_eval", }, } self._memory = Memory.from_config(config) self._use_graph = use_graph def reset(self) -> None: self._memory.delete_all(user_id=self._user_id) self._user_id = f"eval_{_uuid.uuid4().hex[:8]}" self._prev_snapshot_ids = set() def ingest_turn(self, turn: NormalizedTurn) -> None: self._session_id = turn.session_id text = f"{turn.role}: {turn.text}" for att in turn.attachments: text += f"\n[{att.type}] {att.caption}" # Truncate to avoid excessively long inputs that break graph entity extraction text = text[:2000] try: self._memory.add( messages=[{"role": turn.role, "content": text}], user_id=self._user_id, ) except Exception: # Graph mode can fail on entity embedding; fall back silently pass def end_session(self, session_id: str) -> None: self._session_id = session_id def snapshot_memories(self) -> list[MemorySnapshotRecord]: all_mems = self._memory.get_all(user_id=self._user_id) rows: list[MemorySnapshotRecord] = [] # Vector results (standard mode) results = all_mems.get("results", []) if isinstance(all_mems, dict) else all_mems for mem in results: mid = str(mem.get("id", "")) text = str(mem.get("memory", "")) rows.append(MemorySnapshotRecord( memory_id=mid, text=text, session_id=self._session_id, status="active", source="Mem0", raw_backend_id=mid, raw_backend_type="mem0_vector", metadata={}, )) # Graph relations (graph mode) relations = all_mems.get("relations", []) if isinstance(all_mems, dict) else [] for i, rel in enumerate(relations): if isinstance(rel, dict): src = rel.get("source", "") rtype = rel.get("relationship", "") tgt = rel.get("target") or rel.get("destination", "") text = f"{src} → {rtype} → {tgt}" mid = f"rel_{i}" rows.append(MemorySnapshotRecord( memory_id=mid, text=text, session_id=self._session_id, status="active", source="Mem0-Graph", raw_backend_id=mid, raw_backend_type="mem0_graph_relation", metadata=rel, )) return rows def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]: current = self.snapshot_memories() current_ids = {s.memory_id for s in current} deltas = [ MemoryDeltaRecord( session_id=session_id, op="add", text=s.text, linked_previous=(), raw_backend_id=s.raw_backend_id, metadata={"baseline": "Mem0"}, ) for s in current if s.memory_id not in self._prev_snapshot_ids ] self._prev_snapshot_ids = current_ids return deltas def retrieve(self, query: str, top_k: int) -> RetrievalRecord: results = self._memory.search(query=query, user_id=self._user_id, limit=top_k) items: list[RetrievalItem] = [] # Vector results search_results = results.get("results", []) if isinstance(results, dict) else results for i, r in enumerate(search_results[:top_k]): items.append(RetrievalItem( rank=len(items), memory_id=str(r.get("id", i)), text=str(r.get("memory", "")), score=float(r.get("score", 1.0 / (i + 1))), raw_backend_id=str(r.get("id", "")), )) # Graph relations relations = results.get("relations", []) if isinstance(results, dict) else [] for rel in relations: if isinstance(rel, dict) and len(items) < top_k: src = rel.get("source", "") rtype = rel.get("relationship", "") tgt = rel.get("target") or rel.get("destination", "") items.append(RetrievalItem( rank=len(items), memory_id=f"rel_{len(items)}", text=f"{src} → {rtype} → {tgt}", score=0.9, raw_backend_id=None, )) return RetrievalRecord( query=query, top_k=top_k, items=items[:top_k], raw_trace={"baseline": "Mem0-Graph" if self._use_graph else "Mem0"}, ) def get_capabilities(self) -> dict[str, Any]: return { "backend": "Mem0-Graph" if self._use_graph else "Mem0", "baseline": "Mem0-Graph" if self._use_graph else "Mem0", "available": True, "delta_granularity": "snapshot_diff", "snapshot_mode": "full", }