"""Adapter for SimpleMem and Omni-SimpleMem baselines.""" from __future__ import annotations import os import sys from pathlib import Path from typing import Any from eval_framework.datasets.schemas import ( MemoryDeltaRecord, MemorySnapshotRecord, NormalizedTurn, RetrievalItem, RetrievalRecord, ) from eval_framework.memory_adapters.base import MemoryAdapter _DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/SimpleMem") class SimpleMemAdapter(MemoryAdapter): """Adapter for SimpleMem (text mode) or Omni-SimpleMem (omni mode).""" def __init__( self, *, mode: str = "text", source_root: str | os.PathLike[str] | None = None, **kwargs: Any, ) -> None: self._mode = mode # "text" or "omni" root = Path(source_root or _DEFAULT_SOURCE).resolve() if str(root) not in sys.path: sys.path.insert(0, str(root)) import simplemem_router as simplemem self._simplemem = simplemem self._mem: Any = None self._session_id = "" self._prev_snapshot_ids: set[str] = set() self._stored_texts: list[dict[str, str]] = [] self._init_mem() def _init_mem(self) -> None: self._mem = self._simplemem.create(mode=self._mode, clear_db=True) self._stored_texts = [] def reset(self) -> None: if self._mem is not None: try: self._mem.close() except Exception: pass self._init_mem() self._prev_snapshot_ids = set() def ingest_turn(self, turn: NormalizedTurn) -> None: self._session_id = turn.session_id text = f"{turn.role}: {turn.text}" for att in turn.attachments: text += f"\n[{att.type}] {att.caption}" mid = str(len(self._stored_texts)) if self._mode == "omni": self._mem.add_text(text, tags=[f"session:{turn.session_id}"]) else: speaker = "User" if turn.role == "user" else "Assistant" ts = turn.timestamp or "" self._mem.add_dialogue(speaker, text, ts) self._stored_texts.append({"id": mid, "text": text, "session_id": turn.session_id}) def end_session(self, session_id: str) -> None: self._session_id = session_id if self._mode == "text": try: self._mem.finalize() except Exception: pass def snapshot_memories(self) -> list[MemorySnapshotRecord]: return [ MemorySnapshotRecord( memory_id=t["id"], text=t["text"], session_id=t["session_id"], status="active", source=f"SimpleMem-{self._mode}", raw_backend_id=t["id"], raw_backend_type="simplemem", metadata={}, ) for t in self._stored_texts ] def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]: current = self.snapshot_memories() current_ids = {s.memory_id for s in current} deltas = [ MemoryDeltaRecord( session_id=session_id, op="add", text=s.text, linked_previous=(), raw_backend_id=s.raw_backend_id, metadata={"baseline": f"SimpleMem-{self._mode}"}, ) for s in current if s.memory_id not in self._prev_snapshot_ids ] self._prev_snapshot_ids = current_ids return deltas def retrieve(self, query: str, top_k: int) -> RetrievalRecord: items: list[RetrievalItem] = [] try: if self._mode == "omni": result = self._mem.query(query, top_k=top_k) if isinstance(result, list): for i, r in enumerate(result[:top_k]): text = r.get("text", str(r)) if isinstance(r, dict) else str(r) items.append(RetrievalItem( rank=i, memory_id=str(i), text=text, score=1.0 / (i + 1), raw_backend_id=None, )) else: answer = self._mem.ask(query) if answer: items.append(RetrievalItem( rank=0, memory_id="answer", text=str(answer), score=1.0, raw_backend_id=None, )) except Exception: pass if not items: # Fallback: simple text search over stored memories query_lower = query.lower() scored = [] for t in self._stored_texts: overlap = len(set(query_lower.split()) & set(t["text"].lower().split())) scored.append((overlap, t)) scored.sort(key=lambda x: x[0], reverse=True) for i, (sc, t) in enumerate(scored[:top_k]): items.append(RetrievalItem( rank=i, memory_id=t["id"], text=t["text"], score=float(sc) / max(len(query.split()), 1), raw_backend_id=t["id"], )) return RetrievalRecord( query=query, top_k=top_k, items=items[:top_k], raw_trace={"baseline": f"SimpleMem-{self._mode}"}, ) def get_capabilities(self) -> dict[str, Any]: name = "Omni-SimpleMem" if self._mode == "omni" else "SimpleMem" return { "backend": name, "baseline": name, "available": self._mem is not None, "delta_granularity": "per_turn", "snapshot_mode": "full", }