eval_framework / memory_adapters /memgallery_native.py
LCZZZZ's picture
Upload eval_framework source code
85b19cf verified
"""Mem-Gallery native baseline wrappers with conservative schema normalization."""
from __future__ import annotations

import copy
import logging
from typing import Any, Callable

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter
from eval_framework.memory_adapters.export_utils import (
    linear_element_to_snapshot,
    memory_element_text,
    normalize_recall_to_retrieval,
    turn_to_observation_dict,
)
def _deep_merge_dict(base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
out = copy.deepcopy(base)
for key, val in overrides.items():
if (
key in out
and isinstance(out[key], dict)
and isinstance(val, dict)
):
out[key] = _deep_merge_dict(out[key], val)
else:
out[key] = copy.deepcopy(val)
return out
def _default_config_for_baseline(name: str) -> dict[str, Any]:
import default_config.DefaultMemoryConfig as dmc # type: ignore[import-not-found]
key = {
"FUMemory": "DEFAULT_FUMEMORY",
"STMemory": "DEFAULT_STMEMORY",
"LTMemory": "DEFAULT_LTMEMORY",
"GAMemory": "DEFAULT_GAMEMORY",
"MGMemory": "DEFAULT_MGMEMORY",
"RFMemory": "DEFAULT_RFMEMORY",
"MMMemory": "DEFAULT_MMMEMORY",
"MMFUMemory": "DEFAULT_MMFUMEMORY",
"NGMemory": "DEFAULT_NGMEMORY",
"AUGUSTUSMemory": "DEFAULT_AUGUSTUSMEMORY",
"UniversalRAGMemory": "DEFAULT_UNIVERSALRAGMEMORY",
}[name]
cfg = getattr(dmc, key)
return copy.deepcopy(cfg)
def _import_memory_class(name: str) -> Callable[..., Any]:
modmap = {
"FUMemory": ("memengine.memory.FUMemory", "FUMemory"),
"STMemory": ("memengine.memory.STMemory", "STMemory"),
"LTMemory": ("memengine.memory.LTMemory", "LTMemory"),
"GAMemory": ("memengine.memory.GAMemory", "GAMemory"),
"MGMemory": ("memengine.memory.MGMemory", "MGMemory"),
"RFMemory": ("memengine.memory.RFMemory", "RFMemory"),
"MMMemory": ("memengine.memory.MMMemory", "MMMemory"),
"MMFUMemory": ("memengine.memory.MMFUMemory", "MMFUMemory"),
"NGMemory": ("memengine.memory.NGMemory", "NGMemory"),
"AUGUSTUSMemory": ("memengine.memory.AUGUSTUSMemory", "AUGUSTUSMemory"),
"UniversalRAGMemory": ("memengine.memory.UniversalRAGMemory", "UniversalRAGMemory"),
}
module_path, cls_name = modmap[name]
import importlib
mod = importlib.import_module(module_path)
return getattr(mod, cls_name)
def instantiate_memgallery_memory(
    baseline_name: str,
    config: dict[str, Any] | None = None,
) -> Any:
    """Construct a Mem-Gallery memory object with optional config overrides.

    The baseline's default config is deep-merged with *config* (a falsy
    *config* means "no overrides"), wrapped in ``MemoryConfig``, and passed to
    the baseline's memory class.
    """
    merged_cfg = _deep_merge_dict(
        _default_config_for_baseline(baseline_name), config or {}
    )
    from memengine.config.Config import MemoryConfig  # type: ignore[import-not-found]

    memory_cls = _import_memory_class(baseline_name)
    return memory_cls(MemoryConfig(merged_cfg))
def _graph_nodes_to_snapshots(
storage: Any,
*,
session_id: str,
source: str,
include_concepts: bool = False,
) -> list[MemorySnapshotRecord]:
out: list[MemorySnapshotRecord] = []
order = getattr(storage, "memory_order_map", []) or []
node_concepts = getattr(storage, "node_concepts", {})
for mid_idx, node_id in enumerate(order):
node = storage.node[node_id]
cid = node.get("counter_id", mid_idx)
memory_id = f"n{node_id}"
text = memory_element_text(node)
# For AUGUSTUS: append concept tags extracted by the system
if include_concepts:
concepts = node_concepts.get(node_id, set())
if concepts:
text = f"{text}\n[concepts] {', '.join(sorted(concepts))}"
out.append(
MemorySnapshotRecord(
memory_id=memory_id,
text=text,
session_id=session_id,
status="active",
source=source,
raw_backend_id=str(cid),
raw_backend_type="graph_node",
metadata={"node_id": node_id},
)
)
return out
def _linear_storage_snapshots(
storage: Any,
*,
session_id: str,
source: str,
) -> list[MemorySnapshotRecord]:
rows: list[MemorySnapshotRecord] = []
for i, m in enumerate(storage.memory_list):
cid = m.get("counter_id", i)
rows.append(
linear_element_to_snapshot(
m,
memory_id=str(cid),
session_id=session_id,
source=source,
)
)
return rows
def _mg_memory_snapshots(
    memory: Any,
    *,
    session_id: str,
    source: str,
) -> list[MemorySnapshotRecord]:
    """Snapshot MGMemory's working context, FIFO queue, recall/archival stores and summary."""
    # store_op/recall_op have their own main_context references;
    # prefer store_op's view as it holds the actual stored data.
    mc = getattr(memory.store_op, "main_context", None) or memory.main_context
    recall_storage = getattr(
        memory.recall_op, "recall_storage", getattr(memory, "recall_storage", None)
    )
    archival_storage = getattr(
        memory.recall_op, "archival_storage", getattr(memory, "archival_storage", None)
    )
    storages = [("wm", mc["working_context"]), ("fifo", mc["FIFO_queue"])]
    if recall_storage is not None:
        storages.append(("recall", recall_storage))
    if archival_storage is not None:
        storages.append(("archival", archival_storage))

    out: list[MemorySnapshotRecord] = []
    for prefix, st in storages:
        for i, m in enumerate(st.memory_list):
            cid = m.get("counter_id", i)
            # Prefix by storage so ids stay distinct across the four stores.
            out.append(
                linear_element_to_snapshot(
                    m,
                    memory_id=f"{prefix}-{cid}",
                    session_id=session_id,
                    source=source,
                )
            )

    # "recursive_summary" may be missing or explicitly None; treat both as empty.
    gsum = (mc.get("recursive_summary") or {}).get("global")
    # The string "None" is also treated as "no summary".
    if gsum and str(gsum) != "None":
        out.append(
            MemorySnapshotRecord(
                memory_id="recursive_summary",
                text=str(gsum),
                session_id=session_id,
                status="active",
                source=source,
                raw_backend_id=None,
                raw_backend_type="mg_summary",
                metadata={},
            )
        )
    return out


def _rf_memory_snapshots(
    memory: Any,
    *,
    session_id: str,
    source: str,
) -> list[MemorySnapshotRecord]:
    """Snapshot RFMemory's linear storage plus its global insight, if any."""
    rows = _linear_storage_snapshots(memory.storage, session_id=session_id, source=source)
    # ``insight`` may be missing or explicitly None; treat both as empty.
    insight = (getattr(memory, "insight", None) or {}).get("global_insight", "")
    if insight:
        rows.append(
            MemorySnapshotRecord(
                memory_id="rf_insight",
                text=str(insight),
                session_id=session_id,
                status="active",
                source=source,
                raw_backend_id=None,
                raw_backend_type="rf_insight",
                metadata={},
            )
        )
    return rows


def collect_memgallery_snapshots(
    memory: Any,
    baseline_name: str,
    session_id: str,
) -> list[MemorySnapshotRecord]:
    """Best-effort snapshot of backend-visible memories.

    Dispatches on *baseline_name* to read each backend's storage layout:
    MGMemory (multi-store plus summary), RFMemory (linear storage plus
    insight), NGMemory / AUGUSTUSMemory (graph nodes), UniversalRAGMemory
    (linear storage). Any other baseline is read from ``memory.storage.memory_list``
    when that exists; otherwise an empty list is returned. Every record's
    ``source`` is *baseline_name*.
    """
    source = baseline_name
    if baseline_name == "MGMemory":
        return _mg_memory_snapshots(memory, session_id=session_id, source=source)
    if baseline_name == "RFMemory":
        return _rf_memory_snapshots(memory, session_id=session_id, source=source)
    if baseline_name == "NGMemory":
        return _graph_nodes_to_snapshots(
            memory.storage, session_id=session_id, source=source
        )
    if baseline_name == "AUGUSTUSMemory":
        return _graph_nodes_to_snapshots(
            memory.contextual_memory,
            session_id=session_id,
            source=source,
            include_concepts=True,
        )
    if baseline_name == "UniversalRAGMemory":
        return _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )
    # Generic fallback for LinearStorage-backed baselines.
    if hasattr(memory, "storage") and hasattr(memory.storage, "memory_list"):
        return _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )
    return []
class MemGalleryNativeAdapter(MemoryAdapter):
    """Thin wrapper that forwards to Mem-Gallery memories and normalizes I/O.

    Each user turn is buffered and stored together with the following
    non-user turn as one merged observation. Memory deltas are derived by
    diffing successive backend snapshots by ``memory_id``.
    """

    def __init__(self, memory: Any, *, baseline_name: str) -> None:
        self._memory = memory
        self._baseline_name = baseline_name
        # Session id of the most recently ingested turn; used to tag snapshots.
        self._session_id = ""
        # memory_ids present in the snapshot taken by the last export_memory_delta().
        self._prev_snapshot_ids: set[str] = set()
        # User turn waiting to be merged with the next non-user turn.
        self._pending_user_turn: NormalizedTurn | None = None
        self._session_turns: list[str] = []  # collect turn texts for RF optimize

    @classmethod
    def from_baseline(
        cls,
        baseline_name: str,
        *,
        config: dict[str, Any] | None = None,
    ) -> MemGalleryNativeAdapter:
        """Build the named Mem-Gallery memory (with optional config overrides) and wrap it."""
        mem = instantiate_memgallery_memory(baseline_name, config)
        return cls(mem, baseline_name=baseline_name)

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Buffer user turns; store merged user+assistant pair on assistant turn.

        This matches the original Mem-Gallery benchmark behavior where each
        dialogue round (user + assistant) is merged into a single observation
        before calling store(). Any role other than ``"user"`` is treated as
        the assistant side of a round. A user turn that is followed by another
        user turn is stored on its own.
        """
        self._session_id = turn.session_id
        if turn.role == "user":
            # Flush any prior unpaired user turn, then buffer this one
            if self._pending_user_turn is not None:
                self._store_observation(self._pending_user_turn, assistant_turn=None)
            self._pending_user_turn = turn
        else:
            # Assistant turn: merge with buffered user turn (if any) and store
            self._store_observation(self._pending_user_turn, assistant_turn=turn)
            self._pending_user_turn = None

    def _store_observation(
        self,
        user_turn: NormalizedTurn | None,
        assistant_turn: NormalizedTurn | None,
    ) -> None:
        """Build a merged observation dict (matching original benchmark format) and store.

        The text is ``user: ...`` / ``assistant: ...`` lines, each followed by
        one ``[type] caption`` line per attachment. ``timestamp`` and
        ``dialogue_id`` come from the user turn when present, otherwise from
        the assistant turn; ``timestamp`` is omitted when falsy.
        """
        parts: list[str] = []
        timestamp = None
        dialogue_id = ""
        if user_turn is not None:
            parts.append(f"user: {user_turn.text}")
            for att in user_turn.attachments:
                parts.append(f"[{att.type}] {att.caption}")
            timestamp = user_turn.timestamp
            dialogue_id = f"{user_turn.session_id}:{user_turn.turn_index}"
        if assistant_turn is not None:
            parts.append(f"assistant: {assistant_turn.text}")
            for att in assistant_turn.attachments:
                parts.append(f"[{att.type}] {att.caption}")
            if timestamp is None:
                timestamp = assistant_turn.timestamp
            if not dialogue_id:
                dialogue_id = f"{assistant_turn.session_id}:{assistant_turn.turn_index}"
        obs: dict[str, Any] = {"text": "\n".join(parts)}
        if timestamp:
            obs["timestamp"] = timestamp
        obs["dialogue_id"] = dialogue_id
        self._memory.store(obs)
        self._session_turns.append(obs["text"])

    def end_session(self, session_id: str) -> None:
        """Flush buffered input and run backend-specific post-session hooks.

        The hooks are best-effort: failures are logged rather than raised, so
        a failed reflection/optimization does not abort the run. *session_id*
        is accepted for interface compatibility and is not used here.
        """
        # Flush any remaining unpaired user turn
        if self._pending_user_turn is not None:
            self._store_observation(self._pending_user_turn, assistant_turn=None)
            self._pending_user_turn = None
        logger = logging.getLogger(__name__)
        # --- Trigger backend-specific post-session processing ---
        # GAMemory: self-reflection generates insights and stores them
        if self._baseline_name == "GAMemory":
            try:
                self._memory.manage("reflect")
            except Exception:
                # reflection may fail if accumulated importance < threshold
                logger.debug("GAMemory reflect failed; skipping", exc_info=True)
        # RFMemory: optimize generates a global insight from the session trial
        if self._baseline_name == "RFMemory" and self._session_turns:
            trial = "\n".join(self._session_turns)
            try:
                self._memory.optimize(new_trial=trial)
            except Exception:
                logger.warning(
                    "RFMemory optimize failed; no insight for this session",
                    exc_info=True,
                )
        self._session_turns = []

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return the backend-visible memories, tagged with the current session id."""
        sid = self._session_id or ""
        return collect_memgallery_snapshots(
            self._memory, self._baseline_name, sid
        )

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current backend snapshot against previous snapshot.

        This reflects what the backend ACTUALLY stores, not what was fed in.
        For FU/ST/LT/GA/RF (LinearStorage), this is the raw observations added.
        For MGMemory, this includes FIFO, recall and archival entries.

        Only additions are reported, meaning memory_ids absent from the previous
        export. Two kinds of change are not reported:

        - entries whose text changes under a stable id; ``recursive_summary``
          and ``rf_insight`` appear only the first time they exist;
        - entries that disappear.
        """
        current_snapshot = self.snapshot_memories()
        prev_ids = self._prev_snapshot_ids
        deltas: list[MemoryDeltaRecord] = []
        current_ids: set[str] = set()
        for snap in current_snapshot:
            current_ids.add(snap.memory_id)
            if snap.memory_id not in prev_ids:
                deltas.append(
                    MemoryDeltaRecord(
                        session_id=session_id,
                        op="add",
                        text=snap.text,
                        linked_previous=(),
                        raw_backend_id=snap.raw_backend_id,
                        metadata={
                            "baseline": self._baseline_name,
                            "source": snap.source,
                            "backend_type": snap.raw_backend_type,
                        },
                    )
                )
        self._prev_snapshot_ids = current_ids
        return deltas

    def reset(self) -> None:
        """Reset the backend memory and all adapter-side buffers and diff state."""
        self._memory.reset()
        self._session_id = ""
        self._prev_snapshot_ids = set()
        self._pending_user_turn = None
        self._session_turns = []

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Recall from the backend and normalize the result.

        *top_k* is not forwarded to the backend's ``recall``; it is passed to
        ``normalize_recall_to_retrieval``. When the backend's ``recall_op``
        exposes ``last_retrieved_ids``, they are included in the trace.
        """
        raw = self._memory.recall(query)
        trace: dict[str, Any] = {"baseline": self._baseline_name}
        ro = getattr(self._memory, "recall_op", None)
        if ro is not None and hasattr(ro, "last_retrieved_ids"):
            trace["last_retrieved_ids"] = list(ro.last_retrieved_ids)
        return normalize_recall_to_retrieval(query, top_k, raw, raw_trace=trace)

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter's backend, delta and snapshot behavior."""
        return {
            "backend": "MemGallery",
            "baseline": self._baseline_name,
            # NOTE(review): value kept for downstream compatibility; deltas are
            # actually computed by snapshot diffing (see export_memory_delta).
            "delta_granularity": "ingest_turn_only",
            "snapshot_mode": "conservative",
            "notes": (
                "Deltas are computed by diffing successive backend snapshots by memory_id "
                "and report additions only; in-place rewrites of existing entries (e.g. "
                "summaries/insights), removals, and graph reshaping are not diffed. "
                "Snapshots read observable storage where supported."
            ),
        }