# eval_framework/memory_adapters/simplemem_adapter.py
# Uploaded by LCZZZZ ("Upload eval_framework source code"), commit 85b19cf (verified).
"""Adapter for SimpleMem and Omni-SimpleMem baselines."""
from __future__ import annotations

import logging
import os
import re
import sys
from pathlib import Path
from typing import Any

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter
# Default location of the SimpleMem checkout that provides ``simplemem_router``.
# Can be overridden with the SIMPLEMEM_SOURCE_ROOT environment variable (an
# empty value falls back to the default), or per adapter instance via the
# ``source_root`` constructor argument.
_DEFAULT_SOURCE = Path(
    os.environ.get("SIMPLEMEM_SOURCE_ROOT") or "/data1/toby/nips26/baselines/SimpleMem"
)
class SimpleMemAdapter(MemoryAdapter):
    """Adapter for SimpleMem (text mode) or Omni-SimpleMem (omni mode).

    Each ingested turn is forwarded to the SimpleMem backend and also recorded
    locally.  The local record backs ``snapshot_memories`` and
    ``export_memory_delta``.  It also powers a keyword-overlap fallback in
    ``retrieve``, used when the backend raises or returns nothing.

    Args:
        mode: ``"text"`` uses SimpleMem (``add_dialogue`` / ``finalize`` /
            ``ask``); ``"omni"`` uses Omni-SimpleMem (``add_text`` /
            ``query``).  NOTE(review): other values are not validated here;
            they are passed to the backend factory and otherwise take the
            non-omni code paths.
        source_root: Directory containing the ``simplemem_router`` module.
            It is prepended to ``sys.path``.  Defaults to ``_DEFAULT_SOURCE``.
        **kwargs: Accepted and ignored (presumably so a shared adapter config
            can be passed to every adapter; confirm with the caller).
    """

    # Word tokenizer for the keyword fallback.  Punctuation is ignored, so
    # "cat?" in a query matches "cat." in a stored turn.
    _WORD_RE = re.compile(r"\w+")
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        *,
        mode: str = "text",
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        self._mode = mode  # "text" or "omni"
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        # Make the SimpleMem checkout importable without installing it.
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))
        import simplemem_router as simplemem

        self._simplemem = simplemem
        self._mem: Any = None
        self._session_id = ""
        # Memory ids already reported by export_memory_delta.
        self._prev_snapshot_ids: set[str] = set()
        # Local mirror of ingested turns: {"id", "text", "session_id"}.
        self._stored_texts: list[dict[str, str]] = []
        self._init_mem()

    def _init_mem(self) -> None:
        """Create a new backend with ``clear_db=True`` and clear the local record."""
        self._mem = self._simplemem.create(mode=self._mode, clear_db=True)
        self._stored_texts = []

    def reset(self) -> None:
        """Close the current backend and start again from an empty memory."""
        if self._mem is not None:
            try:
                self._mem.close()
            except Exception:
                # Best-effort: a backend that fails to close must not block the reset.
                self._logger.debug("SimpleMem close() failed during reset", exc_info=True)
            # Drop the closed handle so that, if re-creation below raises,
            # get_capabilities() reports the backend as unavailable.
            self._mem = None
        self._init_mem()
        self._session_id = ""
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Send one turn to the backend and record it locally.

        Attachments are appended to the turn text as ``[type] caption`` lines.
        The local memory id is the turn's position since the last reset.
        """
        self._session_id = turn.session_id
        parts = [f"{turn.role}: {turn.text}"]
        parts.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
        text = "\n".join(parts)
        mid = str(len(self._stored_texts))
        if self._mode == "omni":
            self._mem.add_text(text, tags=[f"session:{turn.session_id}"])
        else:
            speaker = "User" if turn.role == "user" else "Assistant"
            self._mem.add_dialogue(speaker, text, turn.timestamp or "")
        # Appended only after the backend call returns, so a turn whose
        # ingestion raised is not recorded locally.
        self._stored_texts.append({"id": mid, "text": text, "session_id": turn.session_id})

    def end_session(self, session_id: str) -> None:
        """Mark the end of ``session_id``; in text mode, call the backend's ``finalize()``."""
        self._session_id = session_id
        if self._mode != "text":
            return
        try:
            self._mem.finalize()
        except Exception:
            # Best-effort: a failed finalize must not abort the evaluation run,
            # but it should not disappear silently either.
            self._logger.warning(
                "SimpleMem finalize() failed for session %s", session_id, exc_info=True
            )

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return every locally recorded turn as an active memory record."""
        return [
            MemorySnapshotRecord(
                memory_id=t["id"],
                text=t["text"],
                session_id=t["session_id"],
                status="active",
                source=f"SimpleMem-{self._mode}",
                raw_backend_id=t["id"],
                raw_backend_type="simplemem",
                metadata={},
            )
            for t in self._stored_texts
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return ``add`` deltas for memories not reported by the previous call.

        Every delta is tagged with the given ``session_id``.
        """
        current = self.snapshot_memories()
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=s.text,
                linked_previous=(),
                raw_backend_id=s.raw_backend_id,
                metadata={"baseline": f"SimpleMem-{self._mode}"},
            )
            for s in current
            if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {s.memory_id for s in current}
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Return up to ``top_k`` memories relevant to ``query``.

        - Omni mode: the backend's ``query`` results, ranked in returned
          order with reciprocal-rank scores ``1 / (rank + 1)``.
        - Text mode: the backend's ``ask`` answer as a single item.
        - Fallback: if the backend raises or yields no items, the locally
          stored turns are ranked by keyword overlap instead.

        ``raw_trace`` holds the baseline name.  When the fallback was used it
        also holds ``"fallback"``, and when the backend raised it holds
        ``"backend_error"``.  A non-positive ``top_k`` yields no items and
        does not query the backend.
        """
        limit = max(top_k, 0)
        trace: dict[str, Any] = {"baseline": f"SimpleMem-{self._mode}"}
        items: list[RetrievalItem] = []
        if limit:
            try:
                items = self._backend_retrieve(query, limit)
            except Exception as exc:
                # Keep the run going, but make the failure visible: otherwise
                # fallback results would silently stand in for the backend.
                self._logger.warning(
                    "SimpleMem backend retrieval failed; using keyword fallback",
                    exc_info=True,
                )
                trace["backend_error"] = repr(exc)
            if not items:
                items = self._fallback_retrieve(query, limit)
                trace["fallback"] = "keyword_overlap"
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:limit],
            raw_trace=trace,
        )

    def _backend_retrieve(self, query: str, limit: int) -> list[RetrievalItem]:
        """Query the SimpleMem backend; propagates any exception it raises."""
        if self._mode == "omni":
            result = self._mem.query(query, top_k=limit)
            if not isinstance(result, list):
                return []
            return [
                RetrievalItem(
                    rank=i,
                    memory_id=str(i),
                    text=r.get("text", str(r)) if isinstance(r, dict) else str(r),
                    score=1.0 / (i + 1),
                    raw_backend_id=None,
                )
                for i, r in enumerate(result[:limit])
            ]
        answer = self._mem.ask(query)
        if not answer:
            return []
        return [
            RetrievalItem(
                rank=0,
                memory_id="answer",
                text=str(answer),
                score=1.0,
                raw_backend_id=None,
            )
        ]

    def _fallback_retrieve(self, query: str, limit: int) -> list[RetrievalItem]:
        """Rank stored turns by how many distinct query words they contain.

        The score is the fraction of distinct query words present (0.0 to 1.0).
        Ties keep ingestion order, because ``list.sort`` is stable even with
        ``reverse=True``.
        """
        query_words = set(self._WORD_RE.findall(query.lower()))
        denom = max(len(query_words), 1)  # empty query -> every score is 0.0
        scored = [
            (len(query_words & set(self._WORD_RE.findall(t["text"].lower()))), t)
            for t in self._stored_texts
        ]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievalItem(
                rank=i,
                memory_id=t["id"],
                text=t["text"],
                score=overlap / denom,
                raw_backend_id=t["id"],
            )
            for i, (overlap, t) in enumerate(scored[:limit])
        ]

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this baseline; ``available`` is True while a backend handle exists."""
        name = "Omni-SimpleMem" if self._mode == "omni" else "SimpleMem"
        return {
            "backend": name,
            "baseline": name,
            "available": self._mem is not None,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }