| """Adapters for Mem0 and Mem0-Graph baselines.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import uuid as _uuid |
| from pathlib import Path |
| from typing import Any |
|
|
| from dotenv import load_dotenv |
|
|
# Load credentials (e.g. OPENAI_API_KEY) from the repo-root .env before the
# eval_framework imports below run, so env-driven config is already in place.
# parents[2] assumes this file sits two directories below the repo root —
# TODO(review): confirm if the package layout changes.
load_dotenv(Path(__file__).resolve().parents[2] / ".env")
|
|
| from eval_framework.datasets.schemas import ( |
| MemoryDeltaRecord, |
| MemorySnapshotRecord, |
| NormalizedTurn, |
| RetrievalItem, |
| RetrievalRecord, |
| ) |
| from eval_framework.memory_adapters.base import MemoryAdapter |
|
|
|
|
class Mem0Adapter(MemoryAdapter):
    """Adapter for Mem0 (vector mode) and Mem0-Graph (vector + graph mode).

    Memories are namespaced under a random per-run ``user_id`` so concurrent
    evaluation runs do not observe each other's state. Mem0 exposes no native
    change log, so memory deltas are computed by diffing successive snapshots.
    """

    # Cap on per-turn text forwarded to Mem0's extraction pipeline.
    _MAX_TEXT_LEN = 2000

    def __init__(self, *, use_graph: bool = False, **kwargs: Any) -> None:
        """Build a Mem0 ``Memory`` backend from environment configuration.

        Args:
            use_graph: When True, additionally configure the Kuzu graph store
                (the "Mem0-Graph" baseline).
            **kwargs: Accepted for adapter-interface compatibility; ignored.
        """
        # Imported lazily so this module can be imported without mem0 installed.
        from mem0 import Memory

        # Fresh namespace per adapter instance; regenerated on reset().
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._session_id = ""
        # Memory ids observed at the last export; used for snapshot diffing.
        self._prev_snapshot_ids: set[str] = set()

        config: dict[str, Any] = {
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv("OPENAI_MODEL") or "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": "text-embedding-3-small",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                    "embedding_dims": 1536,
                },
            },
        }

        # Route both the LLM and the embedder through a proxy/alternate
        # endpoint when one is configured.
        base_url = os.getenv("OPENAI_BASE_URL")
        if base_url:
            config["llm"]["config"]["openai_base_url"] = base_url
            config["embedder"]["config"]["openai_base_url"] = base_url

        if use_graph:
            config["graph_store"] = {
                "provider": "kuzu",
                "config": {
                    # NOTE(review): hard-coded path is shared by every run on
                    # the same host — consider a per-run temp directory.
                    "url": "/tmp/mem0_kuzu_eval",
                },
            }

        self._memory = Memory.from_config(config)
        self._use_graph = use_graph

    def _baseline_name(self) -> str:
        """Return the baseline label for this adapter's mode."""
        return "Mem0-Graph" if self._use_graph else "Mem0"

    def reset(self) -> None:
        """Drop all memories for the current user and start a fresh namespace."""
        self._memory.delete_all(user_id=self._user_id)
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        # Fix: clear stale per-session state too, not only the diff set.
        self._session_id = ""
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Feed one conversation turn (plus attachment captions) into Mem0.

        Ingestion is best-effort: a failure on a single turn is skipped so the
        evaluation run is not aborted.
        """
        self._session_id = turn.session_id
        parts = [f"{turn.role}: {turn.text}"]
        parts.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
        text = "\n".join(parts)[:self._MAX_TEXT_LEN]
        try:
            self._memory.add(
                messages=[{"role": turn.role, "content": text}],
                user_id=self._user_id,
            )
        except Exception:
            # Deliberate best-effort swallow: mem0 extraction occasionally
            # fails on a single turn; skip it rather than crash the run.
            pass

    def end_session(self, session_id: str) -> None:
        """Record the session id; Mem0 needs no explicit per-session flush."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Export every stored memory (and graph relation) as snapshot rows."""
        all_mems = self._memory.get_all(user_id=self._user_id)
        rows: list[MemorySnapshotRecord] = []

        # Vector memories: recent mem0 returns {"results": [...]}, older
        # versions return a bare list.
        results = all_mems.get("results", []) if isinstance(all_mems, dict) else all_mems
        for mem in results:
            mid = str(mem.get("id", ""))
            rows.append(MemorySnapshotRecord(
                memory_id=mid,
                text=str(mem.get("memory", "")),
                session_id=self._session_id,
                status="active",
                source="Mem0",
                raw_backend_id=mid,
                raw_backend_type="mem0_vector",
                metadata={},
            ))

        # Graph relations (populated only in graph mode), rendered as triples.
        relations = all_mems.get("relations", []) if isinstance(all_mems, dict) else []
        for i, rel in enumerate(relations):
            if not isinstance(rel, dict):
                continue
            src = rel.get("source", "")
            rtype = rel.get("relationship", "")
            # Field name varies across mem0 versions: "target" vs "destination".
            tgt = rel.get("target") or rel.get("destination", "")
            mid = f"rel_{i}"
            rows.append(MemorySnapshotRecord(
                memory_id=mid,
                text=f"{src} → {rtype} → {tgt}",
                session_id=self._session_id,
                status="active",
                source="Mem0-Graph",
                raw_backend_id=mid,
                raw_backend_type="mem0_graph_relation",
                metadata=rel,
            ))

        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return "add" deltas for memories that appeared since the last export.

        Computed by diffing snapshot ids against the previous export, so
        updates/deletes are not detected (delta_granularity "snapshot_diff").
        """
        current = self.snapshot_memories()
        current_ids = {snap.memory_id for snap in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=snap.text,
                linked_previous=(),
                raw_backend_id=snap.raw_backend_id,
                # Fix: label with the actual baseline — this was hard-coded to
                # "Mem0" even in graph mode, unlike retrieve/get_capabilities.
                metadata={"baseline": self._baseline_name()},
            )
            for snap in current
            if snap.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search Mem0 and return up to ``top_k`` ranked items.

        Vector hits come first; in graph mode, relation triples fill any
        remaining slots with a fixed placeholder score of 0.9.
        """
        results = self._memory.search(query=query, user_id=self._user_id, limit=top_k)
        items: list[RetrievalItem] = []

        search_results = results.get("results", []) if isinstance(results, dict) else results
        for i, hit in enumerate(search_results[:top_k]):
            raw_score = hit.get("score")
            items.append(RetrievalItem(
                rank=len(items),
                memory_id=str(hit.get("id", i)),
                text=str(hit.get("memory", "")),
                # Fix: a present-but-None score used to crash float(None);
                # fall back to the reciprocal-rank placeholder instead.
                score=float(raw_score) if raw_score is not None else 1.0 / (i + 1),
                raw_backend_id=str(hit.get("id", "")),
            ))

        relations = results.get("relations", []) if isinstance(results, dict) else []
        for rel in relations:
            if len(items) >= top_k:
                break
            if not isinstance(rel, dict):
                continue
            src = rel.get("source", "")
            rtype = rel.get("relationship", "")
            tgt = rel.get("target") or rel.get("destination", "")
            items.append(RetrievalItem(
                rank=len(items),
                memory_id=f"rel_{len(items)}",
                text=f"{src} → {rtype} → {tgt}",
                score=0.9,
                raw_backend_id=None,
            ))

        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:top_k],
            raw_trace={"baseline": self._baseline_name()},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend for the evaluation harness."""
        name = self._baseline_name()
        return {
            "backend": name,
            "baseline": name,
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
|
|