LCZZZZ's picture
Upload eval_framework source code
85b19cf verified
"""Registry and factory for native Mem-Gallery and external placeholder adapters."""
from __future__ import annotations
import os
import sys
import types
from contextlib import nullcontext
from functools import partial
import importlib
from pathlib import Path
from typing import Any, Callable
from eval_framework.memory_adapters.amem import AMemAdapter
from eval_framework.memory_adapters.base import MemoryAdapter
from eval_framework.memory_adapters.dummy import DummyAdapter
from eval_framework.memory_adapters.memgallery_native import MemGalleryNativeAdapter
from eval_framework.memory_adapters.memoryos import MemoryOSAdapter
# Baseline names accepted by ``create_memgallery_adapter``; any other name is
# rejected there with a ``KeyError``.
MEMGALLERY_NATIVE_BASELINES: frozenset[str] = frozenset(
    (
        # text-only baselines
        "FUMemory",
        "STMemory",
        "LTMemory",
        "GAMemory",
        "MGMemory",
        "RFMemory",
        # multimodal baselines
        "MMMemory",
        "MMFUMemory",
        "NGMemory",
        "AUGUSTUSMemory",
        "UniversalRAGMemory",
    )
)
def _word_mode_truncation(number: int = 50_000) -> dict[str, Any]:
    """Build an ``LMTruncation`` config block that operates in ``"word"`` mode.

    Args:
        number: Value stored under the ``"number"`` key (presumably the word
            budget -- confirm against the truncation implementation).

    Returns:
        A freshly created dict on every call, so callers can embed or mutate
        it without affecting other configs.
    """
    return dict(method="LMTruncation", mode="word", number=number, path="")
def _text_encoder_override() -> dict[str, Any]:
    """Build the text-encoder config block shared by retrieval overrides.

    Returns:
        A new dict selecting the ``STEncoder`` method with the
        ``all-MiniLM-L6-v2`` path.
    """
    encoder_cfg: dict[str, Any] = dict(method="STEncoder", path="all-MiniLM-L6-v2")
    return encoder_cfg
def _openai_llm_override() -> dict[str, Any]:
    """Build an ``APILLM`` config block from the ``OPENAI_*`` environment.

    Every setting falls back to its default when the corresponding variable
    is unset *or* set to an empty string:

    * ``OPENAI_MODEL``       -> ``"gpt-5.1"``
    * ``OPENAI_API_KEY``     -> ``""``
    * ``OPENAI_BASE_URL``    -> ``"https://api.openai.com/v1"``
    * ``OPENAI_TEMPERATURE`` -> ``0.0``

    Returns:
        A new dict with the keys ``method``, ``name``, ``api_key``,
        ``base_url`` and ``temperature``.

    Raises:
        ValueError: If ``OPENAI_TEMPERATURE`` is non-empty but not a valid
            float literal.
    """
    # Use ``or`` (not a getenv default) so an empty OPENAI_TEMPERATURE behaves
    # like the other settings instead of crashing in float("").
    temperature = float(os.getenv("OPENAI_TEMPERATURE") or "0.0")
    return {
        "method": "APILLM",
        "name": os.getenv("OPENAI_MODEL") or "gpt-5.1",
        "api_key": os.getenv("OPENAI_API_KEY") or "",
        "base_url": os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1",
        "temperature": temperature,
    }
def _default_memgallery_runtime_overrides(baseline_name: str) -> dict[str, Any]:
    """Return the default runtime config overrides for one baseline.

    Every known baseline gets word-mode truncation under ``recall``; the
    remaining keys depend on which encoder / LLM-backed components the
    baseline's config references.

    Args:
        baseline_name: Baseline identifier, e.g. ``"FUMemory"``.

    Returns:
        A newly built nested dict. Names that match no branch below yield an
        empty dict.
    """
    # --- text-only baselines ---
    if baseline_name in {"FUMemory", "STMemory"}:
        return {"recall": {"truncation": _word_mode_truncation()}}

    if baseline_name == "LTMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
                "text_retrieval": {"encoder": _text_encoder_override()},
            },
        }

    if baseline_name == "RFMemory":
        return {
            "recall": {"truncation": _word_mode_truncation()},
            "optimize": {
                "reflector": {"LLM_config": _openai_llm_override()},
            },
        }

    if baseline_name == "GAMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
                "text_retrieval": {"encoder": _text_encoder_override()},
                "importance_judge": {"LLM_config": _openai_llm_override()},
            },
            "reflect": {
                "reflector": {"LLM_config": _openai_llm_override()},
            },
        }

    if baseline_name == "MGMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
                "recall_retrieval": {"encoder": _text_encoder_override()},
                "archival_retrieval": {"encoder": _text_encoder_override()},
                "trigger": {"LLM_config": _openai_llm_override()},
            },
            "store": {
                # The flush checker receives the same word-mode truncation block.
                "flush_checker": _word_mode_truncation(),
                "summarizer": {"LLM_config": _openai_llm_override()},
            },
        }

    # --- multimodal baselines ---
    if baseline_name in {"MMMemory", "MMFUMemory", "NGMemory"}:
        return {"recall": {"truncation": _word_mode_truncation()}}

    if baseline_name == "AUGUSTUSMemory":
        return {
            "recall": {"truncation": _word_mode_truncation()},
            "concept_extractor": {"llm": _openai_llm_override()},
        }

    if baseline_name == "UniversalRAGMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
                "text_retrieval": {"encoder": _text_encoder_override()},
            },
            "routing": {"llm": _openai_llm_override()},
        }

    return {}
def _resolve_baselines_root() -> Path:
    """Locate the ``baselines/`` directory that sits beside ``eval_framework/``.

    Expected layout::

        nips26/
        ├── eval_framework/
        └── baselines/
            ├── memengine/
            └── default_config/

    Returns:
        ``<project root>/baselines``, where the project root is three levels
        above this file. The path is not checked for existence here.
    """
    this_file = Path(__file__).resolve()
    # parents[0] = memory_adapters/, parents[1] = eval_framework/,
    # parents[2] = nips26/ (the project root).
    project_root = this_file.parents[2]
    return project_root / "baselines"
def _ensure_memgallery_benchmark_on_path() -> Path:
    """Make the ``memengine`` and ``default_config`` packages importable.

    Puts ``baselines/`` at the front of ``sys.path`` (unless it is already
    listed) and then pre-seeds the lightweight ``memengine`` package shells.

    Returns:
        The resolved ``baselines/`` directory.

    Raises:
        FileNotFoundError: If ``baselines/memengine`` is not a directory.
    """
    root = _resolve_baselines_root()
    memengine_dir = root / "memengine"
    if not memengine_dir.is_dir():
        raise FileNotFoundError(
            f"memengine/ not found under {root}. "
            "Clone MemEngine into baselines/memengine."
        )

    root_entry = str(root)
    if root_entry not in sys.path:
        # Front of the list so this checkout is found first.
        sys.path.insert(0, root_entry)

    _bootstrap_memengine_namespace(root)
    return root
def _bootstrap_memengine_namespace(root: Path) -> None:
    """Pre-seed lightweight package shells for the co-located ``memengine``.

    memengine's package-level ``__init__.py`` eagerly imports all memories and
    function modules, which pulls in heavyweight optional dependencies such as
    ``torch`` even for simple baselines like ``FUMemory``. Registering empty
    package shells in ``sys.modules`` first lets us import only the specific
    submodules we need.

    Args:
        root: Directory that contains ``memengine/`` (the ``baselines/``
            directory when called from ``_ensure_memgallery_benchmark_on_path``).
    """
    engine_dir = root / "memengine"
    shells: dict[str, Path] = {"memengine": engine_dir}
    for subpackage in ("config", "memory", "function", "operation", "utils"):
        shells[f"memengine.{subpackage}"] = engine_dir / subpackage

    # Pass 1: register a shell for every package that is not already loaded.
    for dotted_name, location in shells.items():
        if sys.modules.get(dotted_name) is not None:
            continue
        shell = types.ModuleType(dotted_name)
        shell.__path__ = [str(location)]  # type: ignore[attr-defined]
        shell.__package__ = dotted_name
        sys.modules[dotted_name] = shell

    # Pass 2: expose each subpackage as an attribute of its parent, as a
    # regular import would, without overwriting an attribute already present.
    for dotted_name in shells:
        parent_name, separator, leaf = dotted_name.rpartition(".")
        if not separator:
            continue  # top-level package has no parent to attach to
        parent = sys.modules.get(parent_name)
        child = sys.modules.get(dotted_name)
        if parent is None or child is None or hasattr(parent, leaf):
            continue
        setattr(parent, leaf, child)

    _bootstrap_optional_dependency_stubs()
    _populate_safe_memengine_function_exports()
def _bootstrap_optional_dependency_stubs() -> None:
    """Provide narrow stubs for optional imports needed only on unused code paths.

    For ``torch`` and then ``transformers``: if the module is not already in
    ``sys.modules``, try a real import; if that fails for any reason, install
    a minimal stand-in whose heavy entry points raise ``RuntimeError`` when
    they are actually called.
    """

    def _real_module_available(module_name: str) -> bool:
        """Best-effort import; report whether ``module_name`` ends up loaded."""
        if module_name in sys.modules:
            return True
        try:
            sys.modules[module_name] = importlib.import_module(module_name)
        except Exception:
            # Deliberately broad: any import-time failure means "use the stub".
            pass
        return module_name in sys.modules

    if not _real_module_available("torch"):

        def _needs_torch(*args: Any, **kwargs: Any) -> Any:
            del args, kwargs
            raise RuntimeError(
                "PyTorch is required for encoder-backed or tensor-based Mem-Gallery "
                "baselines, but `torch` is not installed in this environment."
            )

        torch_stub = types.ModuleType("torch")
        # Cheap, side-effect-free pieces get working stand-ins.
        torch_stub.cuda = types.SimpleNamespace(is_available=lambda: False)  # type: ignore[attr-defined]
        torch_stub.device = lambda spec: spec  # type: ignore[attr-defined]
        torch_stub.no_grad = lambda: nullcontext()  # type: ignore[attr-defined]
        # Tensor operations fail loudly at call time.
        for op_name in ("from_numpy", "stack", "sort", "matmul", "ones"):
            setattr(torch_stub, op_name, _needs_torch)
        torch_stub.nn = types.SimpleNamespace(  # type: ignore[attr-defined]
            functional=types.SimpleNamespace(normalize=_needs_torch)
        )
        sys.modules["torch"] = torch_stub

    if not _real_module_available("transformers"):

        class _UnavailableAutoTokenizer:
            @classmethod
            def from_pretrained(cls, *args: Any, **kwargs: Any) -> Any:
                del args, kwargs
                raise RuntimeError(
                    "transformers.AutoTokenizer is required for token-mode truncation "
                    "or encoder-backed baselines, but `transformers` is not installed."
                )

        transformers_stub = types.ModuleType("transformers")
        transformers_stub.AutoTokenizer = _UnavailableAutoTokenizer  # type: ignore[attr-defined]
        sys.modules["transformers"] = transformers_stub
def _populate_safe_memengine_function_exports() -> None:
    """Re-export public symbols of ``memengine.function`` submodules.

    This runs instead of the package's own ``__init__``: each listed submodule
    is imported individually and its public attributes are copied onto the
    ``memengine.function`` package object. Does nothing when that package is
    not registered in ``sys.modules``.
    """
    function_pkg = sys.modules.get("memengine.function")
    if function_pkg is None:
        return

    # Complete list -- covers every module referenced by any of the 11 baselines:
    # FU/ST/LT/GA/MG/RF (text-only) + MM/MMFU/NG/AUGUSTUS/UniversalRAG (multimodal)
    submodule_names = (
        # --- text-only baselines ---
        "Encoder",
        "Retrieval",
        "LLM",
        "Judge",
        "Reflector",
        "Summarizer",
        "Truncation",
        "Trigger",
        "Utilization",
        "Forget",
        # --- multimodal / graph / concept baselines ---
        "MultiModalEncoder",
        "MultiModalRetrieval",
        "ConceptExtractor",
        "ConceptBasedRetrieval",
        "EntityExtractor",
        "FactExtractor",
        "UniversalRAGRouting",
        "UniversalRAGRetrieval",
    )

    for short_name in submodule_names:
        try:
            loaded = importlib.import_module(f"memengine.function.{short_name}")
        except Exception:
            # Some modules may depend on optional heavy deps (torch, transformers).
            # Skip gracefully -- they will fail loudly if the baseline actually
            # needs them.
            continue

        for attr_name, value in vars(loaded).items():
            is_private = attr_name.startswith("_")
            # First writer wins: never replace something already on the package.
            if is_private or hasattr(function_pkg, attr_name):
                continue
            setattr(function_pkg, attr_name, value)
def create_memgallery_adapter(
    baseline_name: str,
    *,
    config_overrides: dict[str, Any] | None = None,
) -> MemGalleryNativeAdapter:
    """Instantiate a native Mem-Gallery adapter for a known baseline name.

    Loads default_config + memengine from the Mem-Gallery benchmark tree.

    Args:
        baseline_name: One of ``MEMGALLERY_NATIVE_BASELINES``.
        config_overrides: Optional overrides layered on top of the built-in
            defaults. The merge is shallow: a top-level key supplied here
            replaces the default value for that key wholesale.

    Returns:
        The adapter built by ``MemGalleryNativeAdapter.from_baseline``.

    Raises:
        KeyError: If ``baseline_name`` is not a known baseline.
        FileNotFoundError: If the ``memengine`` checkout is missing.
    """
    if baseline_name not in MEMGALLERY_NATIVE_BASELINES:
        raise KeyError(f"unknown Mem-Gallery baseline: {baseline_name!r}")

    _ensure_memgallery_benchmark_on_path()

    merged = _default_memgallery_runtime_overrides(baseline_name)
    if config_overrides:
        merged = {**merged, **config_overrides}

    # An empty override dict is passed on as ``None``.
    effective_config = merged if merged else None
    return MemGalleryNativeAdapter.from_baseline(baseline_name, config=effective_config)
# Maps each native baseline name to a factory with that name pre-bound, so
# callers only supply the optional keyword arguments.
MEMGALLERY_NATIVE_REGISTRY: dict[str, Callable[..., MemGalleryNativeAdapter]] = dict(
    (baseline, partial(create_memgallery_adapter, baseline))
    for baseline in MEMGALLERY_NATIVE_BASELINES
)
# Names accepted by ``create_external_adapter``.
EXTERNAL_ADAPTER_KEYS: frozenset[str] = frozenset(
    (
        "A-Mem",
        "MemoryOS",
        "Dummy",
        "Mem0",
        "Mem0-Graph",
        "SimpleMem",
        "Omni-SimpleMem",
        "MemVerse",
        "Zep",
    )
)
def create_amem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build the A-Mem adapter, forwarding ``kwargs`` to ``AMemV2Adapter``.

    The adapter module is imported inside the call rather than at module
    import time.
    """
    from eval_framework.memory_adapters.amem_v2 import AMemV2Adapter as adapter_cls

    return adapter_cls(**kwargs)
def create_memoryos_adapter(**kwargs: Any) -> MemoryOSAdapter:
    """Build a ``MemoryOSAdapter``, forwarding ``kwargs`` to its constructor."""
    adapter = MemoryOSAdapter(**kwargs)
    return adapter
def create_dummy_adapter(**kwargs: Any) -> DummyAdapter:
    """Build a ``DummyAdapter``.

    ``kwargs`` are accepted so the signature matches the other factories, but
    they are discarded: the adapter is always constructed without arguments.
    """
    del kwargs  # intentionally unused
    return DummyAdapter()
def create_mem0_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build a ``Mem0Adapter`` with ``use_graph=False``.

    Remaining ``kwargs`` are forwarded to the constructor; the adapter module
    is imported inside the call.
    """
    from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter as adapter_cls

    return adapter_cls(use_graph=False, **kwargs)
def create_mem0_graph_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build a ``Mem0Adapter`` with ``use_graph=True``.

    Remaining ``kwargs`` are forwarded to the constructor; the adapter module
    is imported inside the call.
    """
    from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter as adapter_cls

    return adapter_cls(use_graph=True, **kwargs)
def create_simplemem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build a ``SimpleMemAdapter`` in ``"text"`` mode.

    Remaining ``kwargs`` are forwarded to the constructor; the adapter module
    is imported inside the call.
    """
    from eval_framework.memory_adapters.simplemem_adapter import (
        SimpleMemAdapter as adapter_cls,
    )

    return adapter_cls(mode="text", **kwargs)
def create_omni_simplemem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build a ``SimpleMemAdapter`` in ``"omni"`` mode.

    Remaining ``kwargs`` are forwarded to the constructor; the adapter module
    is imported inside the call.
    """
    from eval_framework.memory_adapters.simplemem_adapter import (
        SimpleMemAdapter as adapter_cls,
    )

    return adapter_cls(mode="omni", **kwargs)
def create_memverse_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build a ``MemVerseAdapter``, forwarding ``kwargs`` to its constructor.

    The adapter module is imported inside the call.
    """
    from eval_framework.memory_adapters.memverse_adapter import (
        MemVerseAdapter as adapter_cls,
    )

    return adapter_cls(**kwargs)
def create_zep_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build a ``ZepAdapter``, forwarding ``kwargs`` to its constructor.

    The adapter module is imported inside the call.
    """
    from eval_framework.memory_adapters.zep_adapter import ZepAdapter as adapter_cls

    return adapter_cls(**kwargs)
# Factory lookup used by ``create_external_adapter``; its keys mirror
# ``EXTERNAL_ADAPTER_KEYS``.
EXTERNAL_ADAPTER_REGISTRY: dict[str, Callable[..., MemoryAdapter]] = dict(
    (
        ("A-Mem", create_amem_adapter),
        ("MemoryOS", create_memoryos_adapter),
        ("Dummy", create_dummy_adapter),
        ("Mem0", create_mem0_adapter),
        ("Mem0-Graph", create_mem0_graph_adapter),
        ("SimpleMem", create_simplemem_adapter),
        ("Omni-SimpleMem", create_omni_simplemem_adapter),
        ("MemVerse", create_memverse_adapter),
        ("Zep", create_zep_adapter),
    )
)
def create_external_adapter(
    name: str,
    *,
    config_overrides: dict[str, Any] | None = None,
) -> MemoryAdapter:
    """Instantiate an external adapter for a known baseline name.

    Args:
        name: One of ``EXTERNAL_ADAPTER_KEYS``.
        config_overrides: Optional mapping expanded into keyword arguments for
            the selected factory; ``None`` or an empty mapping passes none.

    Returns:
        The adapter produced by the factory registered under ``name``.

    Raises:
        KeyError: If ``name`` is not a known external adapter.
    """
    if name not in EXTERNAL_ADAPTER_KEYS:
        raise KeyError(f"unknown external adapter: {name!r}")

    factory = EXTERNAL_ADAPTER_REGISTRY[name]
    factory_kwargs = config_overrides if config_overrides else {}
    return factory(**factory_kwargs)