"""Adapters for Mem0 and Mem0-Graph baselines."""
from __future__ import annotations

import logging
import os
import uuid as _uuid
from pathlib import Path
from typing import Any

from dotenv import load_dotenv

load_dotenv(Path(__file__).resolve().parents[2] / ".env")

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter
class Mem0Adapter(MemoryAdapter):
    """Adapter for the Mem0 baseline (vector mode, optionally with a graph store).

    Memories are scoped to a random per-run ``user_id`` so repeated or
    concurrent evaluation runs never observe each other's state.
    """

    def __init__(self, *, use_graph: bool = False, **kwargs: Any) -> None:
        """Build a Mem0 ``Memory`` backend from environment configuration.

        Args:
            use_graph: When True, additionally configure a Kuzu graph store so
                relation triples are extracted alongside vector memories.
            **kwargs: Ignored; accepted for adapter-factory call compatibility.
        """
        # Local import: mem0 is an optional heavy dependency of this adapter.
        from mem0 import Memory

        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._session_id = ""
        # Memory IDs seen at the last delta export; diffed against the current
        # snapshot by export_memory_delta to produce "add" operations.
        self._prev_snapshot_ids: set[str] = set()
        config: dict[str, Any] = {
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv("OPENAI_MODEL") or "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": "text-embedding-3-small",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                    "embedding_dims": 1536,
                },
            },
        }
        base_url = os.getenv("OPENAI_BASE_URL")
        if base_url:
            # Route both LLM and embedder through the same gateway/proxy.
            config["llm"]["config"]["openai_base_url"] = base_url
            config["embedder"]["config"]["openai_base_url"] = base_url
        if use_graph:
            config["graph_store"] = {
                "provider": "kuzu",
                "config": {
                    "url": "/tmp/mem0_kuzu_eval",
                },
            }
        self._memory = Memory.from_config(config)
        self._use_graph = use_graph

    @staticmethod
    def _relation_text(rel: dict[str, Any]) -> str:
        """Render a graph relation dict as a human-readable triple.

        Shared by snapshot_memories and retrieve so the two stay consistent.
        Mem0 has used both "target" and "destination" keys across versions.
        """
        src = rel.get("source", "")
        rtype = rel.get("relationship", "")
        tgt = rel.get("target") or rel.get("destination", "")
        return f"{src} → {rtype} → {tgt}"

    def reset(self) -> None:
        """Delete all memories for the current user and rotate to a fresh one."""
        self._memory.delete_all(user_id=self._user_id)
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Add one conversation turn (plus attachment captions) to memory."""
        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        for att in turn.attachments:
            text += f"\n[{att.type}] {att.caption}"
        # Truncate to avoid excessively long inputs that break graph entity extraction
        text = text[:2000]
        try:
            self._memory.add(
                messages=[{"role": turn.role, "content": text}],
                user_id=self._user_id,
            )
        except Exception:
            # Graph mode can fail on entity embedding; ingestion is deliberately
            # best-effort, but leave a trace instead of swallowing silently.
            logging.getLogger(__name__).debug(
                "Mem0 add failed for session %s", self._session_id, exc_info=True
            )

    def end_session(self, session_id: str) -> None:
        """Record the active session id; Mem0 needs no per-session flush."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return every stored memory: vector rows plus graph relations."""
        all_mems = self._memory.get_all(user_id=self._user_id)
        rows: list[MemorySnapshotRecord] = []
        # Vector results (standard mode); get_all may return a dict or a list
        # depending on the mem0 version.
        results = all_mems.get("results", []) if isinstance(all_mems, dict) else all_mems
        for mem in results:
            mid = str(mem.get("id", ""))
            rows.append(MemorySnapshotRecord(
                memory_id=mid, text=str(mem.get("memory", "")),
                session_id=self._session_id, status="active",
                source="Mem0", raw_backend_id=mid,
                raw_backend_type="mem0_graph_relation" if False else "mem0_vector", metadata={},
            ))
        # Graph relations (graph mode only); relations carry no backend id, so
        # synthesize stable positional ids.
        relations = all_mems.get("relations", []) if isinstance(all_mems, dict) else []
        for i, rel in enumerate(relations):
            if isinstance(rel, dict):
                mid = f"rel_{i}"
                rows.append(MemorySnapshotRecord(
                    memory_id=mid, text=self._relation_text(rel),
                    session_id=self._session_id, status="active",
                    source="Mem0-Graph", raw_backend_id=mid,
                    raw_backend_type="mem0_graph_relation", metadata=rel,
                ))
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Diff the current snapshot against the previous one as "add" ops.

        Deletions/updates are not distinguishable from Mem0's flat listing,
        hence delta_granularity "snapshot_diff" with add operations only.
        """
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=s.text,
                linked_previous=(),
                raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "Mem0"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search memory; vector hits rank first, graph relations fill the rest."""
        results = self._memory.search(query=query, user_id=self._user_id, limit=top_k)
        items: list[RetrievalItem] = []
        # Vector results; search may return a dict or a list by mem0 version.
        search_results = results.get("results", []) if isinstance(results, dict) else results
        for i, r in enumerate(search_results[:top_k]):
            score = r.get("score")
            items.append(RetrievalItem(
                rank=len(items),
                memory_id=str(r.get("id", i)),
                text=str(r.get("memory", "")),
                # A present-but-None score must not crash float(); fall back to
                # the same rank-based pseudo-score used for a missing key.
                score=float(score) if score is not None else 1.0 / (i + 1),
                raw_backend_id=str(r.get("id", "")),
            ))
        # Graph relations as lower-priority filler up to top_k.
        relations = results.get("relations", []) if isinstance(results, dict) else []
        for rel in relations:
            if len(items) >= top_k:
                break
            if isinstance(rel, dict):
                items.append(RetrievalItem(
                    rank=len(items),
                    memory_id=f"rel_{len(items)}",
                    text=self._relation_text(rel),
                    score=0.9,  # fixed nominal score: graph hits carry no similarity
                    raw_backend_id=None,
                ))
        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k],
            raw_trace={"baseline": "Mem0-Graph" if self._use_graph else "Mem0"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter for the harness's capability report."""
        name = "Mem0-Graph" if self._use_graph else "Mem0"
        return {
            "backend": name,
            "baseline": name,
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
|