"""Mem-Gallery native baseline wrappers with conservative schema normalization.""" from __future__ import annotations import copy from typing import Any, Callable from eval_framework.datasets.schemas import ( MemoryDeltaRecord, MemorySnapshotRecord, NormalizedTurn, RetrievalRecord, ) from eval_framework.memory_adapters.base import MemoryAdapter from eval_framework.memory_adapters.export_utils import ( linear_element_to_snapshot, memory_element_text, normalize_recall_to_retrieval, turn_to_observation_dict, ) def _deep_merge_dict(base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]: out = copy.deepcopy(base) for key, val in overrides.items(): if ( key in out and isinstance(out[key], dict) and isinstance(val, dict) ): out[key] = _deep_merge_dict(out[key], val) else: out[key] = copy.deepcopy(val) return out def _default_config_for_baseline(name: str) -> dict[str, Any]: import default_config.DefaultMemoryConfig as dmc # type: ignore[import-not-found] key = { "FUMemory": "DEFAULT_FUMEMORY", "STMemory": "DEFAULT_STMEMORY", "LTMemory": "DEFAULT_LTMEMORY", "GAMemory": "DEFAULT_GAMEMORY", "MGMemory": "DEFAULT_MGMEMORY", "RFMemory": "DEFAULT_RFMEMORY", "MMMemory": "DEFAULT_MMMEMORY", "MMFUMemory": "DEFAULT_MMFUMEMORY", "NGMemory": "DEFAULT_NGMEMORY", "AUGUSTUSMemory": "DEFAULT_AUGUSTUSMEMORY", "UniversalRAGMemory": "DEFAULT_UNIVERSALRAGMEMORY", }[name] cfg = getattr(dmc, key) return copy.deepcopy(cfg) def _import_memory_class(name: str) -> Callable[..., Any]: modmap = { "FUMemory": ("memengine.memory.FUMemory", "FUMemory"), "STMemory": ("memengine.memory.STMemory", "STMemory"), "LTMemory": ("memengine.memory.LTMemory", "LTMemory"), "GAMemory": ("memengine.memory.GAMemory", "GAMemory"), "MGMemory": ("memengine.memory.MGMemory", "MGMemory"), "RFMemory": ("memengine.memory.RFMemory", "RFMemory"), "MMMemory": ("memengine.memory.MMMemory", "MMMemory"), "MMFUMemory": ("memengine.memory.MMFUMemory", "MMFUMemory"), "NGMemory": ("memengine.memory.NGMemory", "NGMemory"), "AUGUSTUSMemory": ("memengine.memory.AUGUSTUSMemory", "AUGUSTUSMemory"), "UniversalRAGMemory": ("memengine.memory.UniversalRAGMemory", "UniversalRAGMemory"), } module_path, cls_name = modmap[name] import importlib mod = importlib.import_module(module_path) return getattr(mod, cls_name) def instantiate_memgallery_memory( baseline_name: str, config: dict[str, Any] | None = None, ) -> Any: """Construct a Mem-Gallery memory object with optional config overrides.""" base_cfg = _default_config_for_baseline(baseline_name) merged = _deep_merge_dict(base_cfg, config or {}) from memengine.config.Config import MemoryConfig # type: ignore[import-not-found] cls = _import_memory_class(baseline_name) return cls(MemoryConfig(merged)) def _graph_nodes_to_snapshots( storage: Any, *, session_id: str, source: str, include_concepts: bool = False, ) -> list[MemorySnapshotRecord]: out: list[MemorySnapshotRecord] = [] order = getattr(storage, "memory_order_map", []) or [] node_concepts = getattr(storage, "node_concepts", {}) for mid_idx, node_id in enumerate(order): node = storage.node[node_id] cid = node.get("counter_id", mid_idx) memory_id = f"n{node_id}" text = memory_element_text(node) # For AUGUSTUS: append concept tags extracted by the system if include_concepts: concepts = node_concepts.get(node_id, set()) if concepts: text = f"{text}\n[concepts] {', '.join(sorted(concepts))}" out.append( MemorySnapshotRecord( memory_id=memory_id, text=text, session_id=session_id, status="active", source=source, raw_backend_id=str(cid), raw_backend_type="graph_node", metadata={"node_id": node_id}, ) ) return out def _linear_storage_snapshots( storage: Any, *, session_id: str, source: str, ) -> list[MemorySnapshotRecord]: rows: list[MemorySnapshotRecord] = [] for i, m in enumerate(storage.memory_list): cid = m.get("counter_id", i) rows.append( linear_element_to_snapshot( m, memory_id=str(cid), session_id=session_id, source=source, ) ) return rows def collect_memgallery_snapshots( memory: Any, baseline_name: str, session_id: str, ) -> list[MemorySnapshotRecord]: """Best-effort snapshot of backend-visible memories.""" source = baseline_name if baseline_name == "MGMemory": out: list[MemorySnapshotRecord] = [] # store_op/recall_op have their own main_context references; # prefer store_op's view as it holds the actual stored data. mc = getattr(memory.store_op, "main_context", None) or memory.main_context recall_storage = getattr(memory.recall_op, "recall_storage", getattr(memory, "recall_storage", None)) archival_storage = getattr(memory.recall_op, "archival_storage", getattr(memory, "archival_storage", None)) storages = [("wm", mc["working_context"]), ("fifo", mc["FIFO_queue"])] if recall_storage is not None: storages.append(("recall", recall_storage)) if archival_storage is not None: storages.append(("archival", archival_storage)) for prefix, st in storages: for i, m in enumerate(st.memory_list): cid = m.get("counter_id", i) mid = f"{prefix}-{cid}" rows = linear_element_to_snapshot( m, memory_id=mid, session_id=session_id, source=source, ) out.append(rows) gsum = mc.get("recursive_summary", {}).get("global") if gsum and str(gsum) != "None": out.append( MemorySnapshotRecord( memory_id="recursive_summary", text=str(gsum), session_id=session_id, status="active", source=source, raw_backend_id=None, raw_backend_type="mg_summary", metadata={}, ) ) return out if baseline_name == "RFMemory": rows = _linear_storage_snapshots( memory.storage, session_id=session_id, source=source ) insight = getattr(memory, "insight", {}).get("global_insight", "") if insight: rows.append( MemorySnapshotRecord( memory_id="rf_insight", text=str(insight), session_id=session_id, status="active", source=source, raw_backend_id=None, raw_backend_type="rf_insight", metadata={}, ) ) return rows if baseline_name == "NGMemory": return _graph_nodes_to_snapshots( memory.storage, session_id=session_id, source=source ) if baseline_name == "AUGUSTUSMemory": return _graph_nodes_to_snapshots( memory.contextual_memory, session_id=session_id, source=source, include_concepts=True, ) if baseline_name == "UniversalRAGMemory": return _linear_storage_snapshots( memory.storage, session_id=session_id, source=source ) if hasattr(memory, "storage") and hasattr(memory.storage, "memory_list"): return _linear_storage_snapshots( memory.storage, session_id=session_id, source=source ) return [] class MemGalleryNativeAdapter(MemoryAdapter): """Thin wrapper that forwards to Mem-Gallery memories and normalizes I/O.""" def __init__(self, memory: Any, *, baseline_name: str) -> None: self._memory = memory self._baseline_name = baseline_name self._session_id = "" self._prev_snapshot_ids: set[str] = set() self._pending_user_turn: NormalizedTurn | None = None self._session_turns: list[str] = [] # collect turn texts for RF optimize @classmethod def from_baseline( cls, baseline_name: str, *, config: dict[str, Any] | None = None, ) -> MemGalleryNativeAdapter: mem = instantiate_memgallery_memory(baseline_name, config) return cls(mem, baseline_name=baseline_name) def ingest_turn(self, turn: NormalizedTurn) -> None: """Buffer user turns; store merged user+assistant pair on assistant turn. This matches the original Mem-Gallery benchmark behavior where each dialogue round (user + assistant) is merged into a single observation before calling store(). """ self._session_id = turn.session_id if turn.role == "user": # Flush any prior unpaired user turn, then buffer this one if self._pending_user_turn is not None: self._store_observation(self._pending_user_turn, assistant_turn=None) self._pending_user_turn = turn else: # Assistant turn: merge with buffered user turn and store self._store_observation(self._pending_user_turn, assistant_turn=turn) self._pending_user_turn = None def _store_observation( self, user_turn: NormalizedTurn | None, assistant_turn: NormalizedTurn | None, ) -> None: """Build a merged observation dict (matching original benchmark format) and store.""" parts: list[str] = [] timestamp = None dialogue_id = "" if user_turn is not None: parts.append(f"user: {user_turn.text}") for att in user_turn.attachments: parts.append(f"[{att.type}] {att.caption}") timestamp = user_turn.timestamp dialogue_id = f"{user_turn.session_id}:{user_turn.turn_index}" if assistant_turn is not None: parts.append(f"assistant: {assistant_turn.text}") for att in assistant_turn.attachments: parts.append(f"[{att.type}] {att.caption}") if timestamp is None: timestamp = assistant_turn.timestamp if not dialogue_id: dialogue_id = f"{assistant_turn.session_id}:{assistant_turn.turn_index}" obs: dict[str, Any] = {"text": "\n".join(parts)} if timestamp: obs["timestamp"] = timestamp obs["dialogue_id"] = dialogue_id self._memory.store(obs) self._session_turns.append(obs["text"]) def end_session(self, session_id: str) -> None: # Flush any remaining unpaired user turn if self._pending_user_turn is not None: self._store_observation(self._pending_user_turn, assistant_turn=None) self._pending_user_turn = None # --- Trigger backend-specific post-session processing --- # GAMemory: self-reflection generates insights and stores them if self._baseline_name == "GAMemory": try: self._memory.manage("reflect") except Exception: pass # reflection may fail if accumulated importance < threshold # RFMemory: optimize generates a global insight from the session trial if self._baseline_name == "RFMemory" and self._session_turns: try: trial = "\n".join(self._session_turns) self._memory.optimize(new_trial=trial) except Exception: pass self._session_turns = [] def snapshot_memories(self) -> list[MemorySnapshotRecord]: sid = self._session_id or "" return collect_memgallery_snapshots( self._memory, self._baseline_name, sid ) def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]: """Export delta by diffing current backend snapshot against previous snapshot. This reflects what the backend ACTUALLY stores, not what was fed in. For FU/ST/LT/GA/RF (LinearStorage), this is the raw observations added. For MGMemory, this includes FIFO items, summaries, and archival entries. """ current_snapshot = self.snapshot_memories() prev_ids = self._prev_snapshot_ids deltas: list[MemoryDeltaRecord] = [] current_ids: set[str] = set() for snap in current_snapshot: current_ids.add(snap.memory_id) if snap.memory_id not in prev_ids: deltas.append( MemoryDeltaRecord( session_id=session_id, op="add", text=snap.text, linked_previous=(), raw_backend_id=snap.raw_backend_id, metadata={ "baseline": self._baseline_name, "source": snap.source, "backend_type": snap.raw_backend_type, }, ) ) self._prev_snapshot_ids = current_ids return deltas def reset(self) -> None: self._memory.reset() self._prev_snapshot_ids = set() self._pending_user_turn = None self._session_turns = [] def retrieve(self, query: str, top_k: int) -> RetrievalRecord: raw = self._memory.recall(query) trace: dict[str, Any] = {"baseline": self._baseline_name} ro = getattr(self._memory, "recall_op", None) if ro is not None and hasattr(ro, "last_retrieved_ids"): trace["last_retrieved_ids"] = list(ro.last_retrieved_ids) return normalize_recall_to_retrieval(query, top_k, raw, raw_trace=trace) def get_capabilities(self) -> dict[str, Any]: return { "backend": "MemGallery", "baseline": self._baseline_name, "delta_granularity": "ingest_turn_only", "snapshot_mode": "conservative", "notes": ( "Deltas record adapter ingest only; backend-internal rewrite, reflection, " "or graph reshaping is not diffed. Snapshots read observable storage where supported." ), }