"""Registry and factory for native Mem-Gallery and external placeholder adapters.""" from __future__ import annotations import os import sys import types from contextlib import nullcontext from functools import partial import importlib from pathlib import Path from typing import Any, Callable from eval_framework.memory_adapters.amem import AMemAdapter from eval_framework.memory_adapters.base import MemoryAdapter from eval_framework.memory_adapters.dummy import DummyAdapter from eval_framework.memory_adapters.memgallery_native import MemGalleryNativeAdapter from eval_framework.memory_adapters.memoryos import MemoryOSAdapter MEMGALLERY_NATIVE_BASELINES: frozenset[str] = frozenset( { "FUMemory", "STMemory", "LTMemory", "GAMemory", "MGMemory", "RFMemory", "MMMemory", "MMFUMemory", "NGMemory", "AUGUSTUSMemory", "UniversalRAGMemory", } ) def _word_mode_truncation(number: int = 50_000) -> dict[str, Any]: return { "method": "LMTruncation", "mode": "word", "number": number, "path": "", } def _text_encoder_override() -> dict[str, Any]: return { "method": "STEncoder", "path": "all-MiniLM-L6-v2", } def _openai_llm_override() -> dict[str, Any]: return { "method": "APILLM", "name": os.getenv("OPENAI_MODEL") or "gpt-5.1", "api_key": os.getenv("OPENAI_API_KEY") or "", "base_url": os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1", "temperature": float(os.getenv("OPENAI_TEMPERATURE", "0.0")), } def _default_memgallery_runtime_overrides(baseline_name: str) -> dict[str, Any]: overrides: dict[str, Any] = {} # --- text-only baselines --- if baseline_name in {"FUMemory", "STMemory", "LTMemory", "RFMemory"}: overrides["recall"] = {"truncation": _word_mode_truncation()} if baseline_name == "LTMemory": overrides.setdefault("recall", {}) overrides["recall"]["text_retrieval"] = {"encoder": _text_encoder_override()} if baseline_name == "GAMemory": overrides = { "recall": { "truncation": _word_mode_truncation(), "text_retrieval": {"encoder": _text_encoder_override()}, "importance_judge": {"LLM_config": _openai_llm_override()}, }, "reflect": { "reflector": {"LLM_config": _openai_llm_override()}, }, } if baseline_name == "MGMemory": overrides = { "recall": { "truncation": _word_mode_truncation(), "recall_retrieval": {"encoder": _text_encoder_override()}, "archival_retrieval": {"encoder": _text_encoder_override()}, "trigger": {"LLM_config": _openai_llm_override()}, }, "store": { "flush_checker": _word_mode_truncation(), "summarizer": {"LLM_config": _openai_llm_override()}, }, } if baseline_name == "RFMemory": overrides.setdefault("optimize", {}) overrides["optimize"] = { "reflector": {"LLM_config": _openai_llm_override()}, } # --- multimodal baselines --- if baseline_name == "MMMemory": overrides = { "recall": { "truncation": _word_mode_truncation(), }, } if baseline_name == "MMFUMemory": overrides = { "recall": { "truncation": _word_mode_truncation(), }, } if baseline_name == "NGMemory": overrides = { "recall": { "truncation": _word_mode_truncation(), }, } if baseline_name == "AUGUSTUSMemory": overrides = { "recall": { "truncation": _word_mode_truncation(), }, "concept_extractor": { "llm": _openai_llm_override(), }, } if baseline_name == "UniversalRAGMemory": overrides = { "recall": { "truncation": _word_mode_truncation(), "text_retrieval": {"encoder": _text_encoder_override()}, }, "routing": { "llm": _openai_llm_override(), }, } return overrides def _resolve_baselines_root() -> Path: """Return the ``baselines/`` directory (sibling of eval_framework/). Layout:: nips26/ ├── eval_framework/ └── baselines/ ├── memengine/ └── default_config/ """ # registry.py -> memory_adapters/ -> eval_framework/ -> nips26/ return Path(__file__).resolve().parents[2] / "baselines" def _ensure_memgallery_benchmark_on_path() -> Path: """Add ``baselines/`` to sys.path so that ``memengine`` and ``default_config`` packages are importable.""" baselines_root = _resolve_baselines_root() if not (baselines_root / "memengine").is_dir(): raise FileNotFoundError( f"memengine/ not found under {baselines_root}. " f"Clone MemEngine into baselines/memengine." ) s = str(baselines_root) if s not in sys.path: sys.path.insert(0, s) _bootstrap_memengine_namespace(baselines_root) return baselines_root def _bootstrap_memengine_namespace(root: Path) -> None: """ Pre-seed lightweight namespace packages for the co-located memengine package. memengine's package-level ``__init__.py`` eagerly imports all memories and function modules, which pulls in heavyweight optional dependencies like ``torch`` even for simple baselines such as ``FUMemory``. By registering package shells in ``sys.modules`` first, we can import only the specific submodules we need. *root* is the ``our/`` directory that contains ``memengine/``. """ package_paths = { "memengine": root / "memengine", "memengine.config": root / "memengine" / "config", "memengine.memory": root / "memengine" / "memory", "memengine.function": root / "memengine" / "function", "memengine.operation": root / "memengine" / "operation", "memengine.utils": root / "memengine" / "utils", } for package_name, package_path in package_paths.items(): existing = sys.modules.get(package_name) if existing is not None: continue module = types.ModuleType(package_name) module.__path__ = [str(package_path)] # type: ignore[attr-defined] module.__package__ = package_name sys.modules[package_name] = module for package_name in package_paths: if "." not in package_name: continue parent_name, child_name = package_name.rsplit(".", 1) parent = sys.modules.get(parent_name) child = sys.modules.get(package_name) if parent is not None and child is not None and not hasattr(parent, child_name): setattr(parent, child_name, child) _bootstrap_optional_dependency_stubs() _populate_safe_memengine_function_exports() def _bootstrap_optional_dependency_stubs() -> None: """Provide narrow stubs for optional imports needed only on unused code paths.""" if "torch" not in sys.modules: try: sys.modules["torch"] = importlib.import_module("torch") except Exception: pass if "torch" not in sys.modules: torch_module = types.ModuleType("torch") def _torch_unavailable(*args: Any, **kwargs: Any) -> Any: del args, kwargs raise RuntimeError( "PyTorch is required for encoder-backed or tensor-based Mem-Gallery " "baselines, but `torch` is not installed in this environment." ) torch_module.cuda = types.SimpleNamespace(is_available=lambda: False) # type: ignore[attr-defined] torch_module.device = lambda spec: spec # type: ignore[attr-defined] torch_module.no_grad = lambda: nullcontext() # type: ignore[attr-defined] torch_module.from_numpy = _torch_unavailable # type: ignore[attr-defined] torch_module.stack = _torch_unavailable # type: ignore[attr-defined] torch_module.sort = _torch_unavailable # type: ignore[attr-defined] torch_module.matmul = _torch_unavailable # type: ignore[attr-defined] torch_module.ones = _torch_unavailable # type: ignore[attr-defined] torch_module.nn = types.SimpleNamespace( # type: ignore[attr-defined] functional=types.SimpleNamespace(normalize=_torch_unavailable) ) sys.modules["torch"] = torch_module if "transformers" not in sys.modules: try: sys.modules["transformers"] = importlib.import_module("transformers") except Exception: pass if "transformers" not in sys.modules: transformers_module = types.ModuleType("transformers") class _UnavailableAutoTokenizer: @classmethod def from_pretrained(cls, *args: Any, **kwargs: Any) -> Any: del args, kwargs raise RuntimeError( "transformers.AutoTokenizer is required for token-mode truncation " "or encoder-backed baselines, but `transformers` is not installed." ) transformers_module.AutoTokenizer = _UnavailableAutoTokenizer # type: ignore[attr-defined] sys.modules["transformers"] = transformers_module def _populate_safe_memengine_function_exports() -> None: """Expose all function symbols for complete baseline deployment without running package __init__.""" function_pkg = sys.modules.get("memengine.function") if function_pkg is None: return # Complete list — covers every module referenced by any of the 11 baselines: # FU/ST/LT/GA/MG/RF (text-only) + MM/MMFU/NG/AUGUSTUS/UniversalRAG (multimodal) for module_name in ( # --- text-only baselines --- "memengine.function.Encoder", "memengine.function.Retrieval", "memengine.function.LLM", "memengine.function.Judge", "memengine.function.Reflector", "memengine.function.Summarizer", "memengine.function.Truncation", "memengine.function.Trigger", "memengine.function.Utilization", "memengine.function.Forget", # --- multimodal / graph / concept baselines --- "memengine.function.MultiModalEncoder", "memengine.function.MultiModalRetrieval", "memengine.function.ConceptExtractor", "memengine.function.ConceptBasedRetrieval", "memengine.function.EntityExtractor", "memengine.function.FactExtractor", "memengine.function.UniversalRAGRouting", "memengine.function.UniversalRAGRetrieval", ): try: module = importlib.import_module(module_name) except Exception: # Some modules may depend on optional heavy deps (torch, transformers). # Skip gracefully — they will fail loudly if the baseline actually needs them. continue for attr_name, value in vars(module).items(): if attr_name.startswith("_"): continue if not hasattr(function_pkg, attr_name): setattr(function_pkg, attr_name, value) def create_memgallery_adapter( baseline_name: str, *, config_overrides: dict[str, Any] | None = None, ) -> MemGalleryNativeAdapter: """ Instantiate a native Mem-Gallery adapter for a known baseline name. Loads default_config + memengine from the Mem-Gallery benchmark tree. """ if baseline_name not in MEMGALLERY_NATIVE_BASELINES: raise KeyError(f"unknown Mem-Gallery baseline: {baseline_name!r}") _ensure_memgallery_benchmark_on_path() runtime_overrides = _default_memgallery_runtime_overrides(baseline_name) if config_overrides: runtime_overrides = { **runtime_overrides, **config_overrides, } return MemGalleryNativeAdapter.from_baseline( baseline_name, config=runtime_overrides or None ) MEMGALLERY_NATIVE_REGISTRY: dict[str, Callable[..., MemGalleryNativeAdapter]] = { name: partial(create_memgallery_adapter, name) for name in MEMGALLERY_NATIVE_BASELINES } EXTERNAL_ADAPTER_KEYS: frozenset[str] = frozenset({ "A-Mem", "MemoryOS", "Dummy", "Mem0", "Mem0-Graph", "SimpleMem", "Omni-SimpleMem", "MemVerse", "Zep", }) def create_amem_adapter(**kwargs: Any) -> MemoryAdapter: from eval_framework.memory_adapters.amem_v2 import AMemV2Adapter return AMemV2Adapter(**kwargs) def create_memoryos_adapter(**kwargs: Any) -> MemoryOSAdapter: return MemoryOSAdapter(**kwargs) def create_dummy_adapter(**kwargs: Any) -> DummyAdapter: return DummyAdapter() def create_mem0_adapter(**kwargs: Any) -> MemoryAdapter: from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter return Mem0Adapter(use_graph=False, **kwargs) def create_mem0_graph_adapter(**kwargs: Any) -> MemoryAdapter: from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter return Mem0Adapter(use_graph=True, **kwargs) def create_simplemem_adapter(**kwargs: Any) -> MemoryAdapter: from eval_framework.memory_adapters.simplemem_adapter import SimpleMemAdapter return SimpleMemAdapter(mode="text", **kwargs) def create_omni_simplemem_adapter(**kwargs: Any) -> MemoryAdapter: from eval_framework.memory_adapters.simplemem_adapter import SimpleMemAdapter return SimpleMemAdapter(mode="omni", **kwargs) def create_memverse_adapter(**kwargs: Any) -> MemoryAdapter: from eval_framework.memory_adapters.memverse_adapter import MemVerseAdapter return MemVerseAdapter(**kwargs) def create_zep_adapter(**kwargs: Any) -> MemoryAdapter: from eval_framework.memory_adapters.zep_adapter import ZepAdapter return ZepAdapter(**kwargs) EXTERNAL_ADAPTER_REGISTRY: dict[str, Callable[..., MemoryAdapter]] = { "A-Mem": create_amem_adapter, "MemoryOS": create_memoryos_adapter, "Dummy": create_dummy_adapter, "Mem0": create_mem0_adapter, "Mem0-Graph": create_mem0_graph_adapter, "SimpleMem": create_simplemem_adapter, "Omni-SimpleMem": create_omni_simplemem_adapter, "MemVerse": create_memverse_adapter, "Zep": create_zep_adapter, } def create_external_adapter( name: str, *, config_overrides: dict[str, Any] | None = None, ) -> MemoryAdapter: """Instantiate an external adapter for a known baseline name.""" if name not in EXTERNAL_ADAPTER_KEYS: raise KeyError(f"unknown external adapter: {name!r}") return EXTERNAL_ADAPTER_REGISTRY[name](**(config_overrides or {}))