File size: 4,953 Bytes
85b19cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Adapter for A-Mem (new API: agentic_memory.AgenticMemorySystem)."""

from __future__ import annotations

import logging
import os
import sys
from pathlib import Path
from typing import Any

from dotenv import load_dotenv

load_dotenv(Path(__file__).resolve().parents[2] / ".env")

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter

_DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/A-Mem")


class AMemV2Adapter(MemoryAdapter):
    """Adapter exposing A-Mem's ``AgenticMemorySystem`` as a ``MemoryAdapter``.

    Each ingested turn becomes one A-Mem note. Snapshots and deltas are
    derived from the backend's ``memories`` mapping. Retrieval goes through
    ``AgenticMemorySystem.search``, with ``find_related_memories_raw`` as a
    fallback.
    """

    def __init__(
        self,
        *,
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Import A-Mem from *source_root* and build a fresh backend.

        Args:
            source_root: Directory containing the ``agentic_memory`` package.
                Falls back to ``_DEFAULT_SOURCE`` when ``None`` (or falsy).
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        # A-Mem is used from a source checkout, so make it importable.
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))

        # Imported lazily: it only resolves after sys.path is patched above.
        from agentic_memory.memory_system import AgenticMemorySystem

        self._cls = AgenticMemorySystem
        self._backend: Any = None
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()
        self._init_backend()

    def _init_backend(self) -> None:
        """(Re)create the backend with an OpenAI LLM and MiniLM embeddings.

        The LLM model comes from ``OPENAI_MODEL`` (default ``gpt-4o``). The
        API key comes from ``OPENAI_API_KEY``.
        """
        self._backend = self._cls(
            model_name="all-MiniLM-L6-v2",
            llm_backend="openai",
            llm_model=os.getenv("OPENAI_MODEL") or "gpt-4o",
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def reset(self) -> None:
        """Replace the backend and clear all per-run adapter state.

        Restores the same state ``__init__`` produces.
        """
        self._init_backend()
        self._session_id = ""
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store *turn* as a single A-Mem note.

        The note text is ``"<role>: <text>"``, followed by one
        ``"[<type>] <caption>"`` line per attachment. The turn's timestamp is
        forwarded as the note time.
        """
        self._session_id = turn.session_id
        lines = [f"{turn.role}: {turn.text}"]
        lines.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
        self._backend.add_note("\n".join(lines), time=turn.timestamp)

    def end_session(self, session_id: str) -> None:
        """Record *session_id* as the current session; no backend call is made."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one snapshot row per note currently held by the backend.

        Row text is the note content, followed by optional ``[context]`` and
        ``[keywords]`` lines when those attributes are non-empty. Every row
        is tagged with the most recently recorded session id and status
        ``"active"``.
        """
        rows: list[MemorySnapshotRecord] = []
        for mid, note in self._backend.memories.items():
            parts = [str(getattr(note, "content", ""))]
            context = getattr(note, "context", "")
            if context:
                parts.append(f"[context] {context}")
            keywords = list(getattr(note, "keywords", []) or [])
            if keywords:
                # str() so a non-string keyword cannot break the join.
                parts.append(f"[keywords] {', '.join(map(str, keywords))}")
            rows.append(MemorySnapshotRecord(
                memory_id=str(mid),
                text="\n".join(parts),
                session_id=self._session_id,
                status="active",
                source="A-Mem",
                raw_backend_id=str(mid),
                raw_backend_type="a_mem_note",
                metadata={},
            ))
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Emit an ``add`` delta for each note not present at the previous call.

        Only additions are reported. Notes whose text changed in place, or
        that disappeared, produce no delta. The current id set becomes the
        baseline for the next call.
        """
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "A-Mem"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Return up to *top_k* memories relevant to *query*.

        ``search`` is tried first. If the call or the conversion of its
        results raises, the partial results are discarded.
        ``find_related_memories_raw`` is then tried, and its output is
        returned as a single ``"bundle"`` item. Failures are logged and never
        propagated. A non-positive *top_k* yields no items without querying
        the backend.
        """
        log = logging.getLogger(__name__)
        items: list[RetrievalItem] = []
        if top_k > 0:
            try:
                items = self._search_items(query, top_k)
            except Exception:
                log.warning(
                    "A-Mem search failed; falling back to raw retrieval",
                    exc_info=True,
                )
                try:
                    items = self._raw_bundle_items(query, top_k)
                except Exception:
                    log.warning(
                        "A-Mem raw retrieval failed; returning no items",
                        exc_info=True,
                    )

        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k],
            raw_trace={"baseline": "A-Mem"},
        )

    def _search_items(self, query: str, top_k: int) -> list[RetrievalItem]:
        """Run ``search`` and convert its hits into ranked retrieval items."""
        results = self._backend.search(query, k=top_k)
        items: list[RetrievalItem] = []
        for rank, hit in enumerate(results[:top_k]):
            positional = 1.0 / (rank + 1)
            if isinstance(hit, dict):
                text = hit.get("content", str(hit))
                mid = hit.get("id", str(rank))
                # NOTE(review): depending on the A-Mem version this may be a
                # vector distance (lower = closer), not a similarity. Confirm
                # before comparing it with the 1/(rank+1) positional score.
                score = self._coerce_score(hit.get("score"), positional)
            else:
                text, mid, score = str(hit), str(rank), positional
            items.append(RetrievalItem(
                rank=rank, memory_id=str(mid), text=text,
                score=score, raw_backend_id=str(mid),
            ))
        return items

    def _raw_bundle_items(self, query: str, top_k: int) -> list[RetrievalItem]:
        """Wrap ``find_related_memories_raw`` output as a single bundle item."""
        raw = self._backend.find_related_memories_raw(query, k=top_k)
        if not raw:
            return []
        return [RetrievalItem(
            rank=0, memory_id="bundle", text=str(raw),
            score=1.0, raw_backend_id=None,
        )]

    @staticmethod
    def _coerce_score(value: Any, default: float) -> float:
        """Return *value* as a float, or *default* if missing or non-numeric."""
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter's backend and delta/snapshot modes."""
        return {
            "backend": "A-Mem",
            "baseline": "A-Mem",
            "available": self._backend is not None,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }