File size: 5,601 Bytes
85b19cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Adapter for SimpleMem and Omni-SimpleMem baselines."""

from __future__ import annotations

import logging
import os
import sys
from pathlib import Path
from typing import Any

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter

# Default SimpleMem checkout, used when the adapter is given no ``source_root``.
# The built-in path is machine-specific, so it can be overridden without a code
# change by setting the SIMPLEMEM_SOURCE_ROOT environment variable; an unset or
# empty variable keeps the original location.
_DEFAULT_SOURCE = Path(
    os.environ.get("SIMPLEMEM_SOURCE_ROOT")
    or "/data1/toby/nips26/baselines/SimpleMem"
)


class SimpleMemAdapter(MemoryAdapter):
    """Adapter for SimpleMem (text mode) or Omni-SimpleMem (omni mode).

    The adapter wraps a backend object obtained from
    ``simplemem_router.create`` and additionally keeps an in-process mirror of
    every ingested turn (``_stored_texts``).  Snapshots, deltas and the
    retrieval fallback are served from that mirror.

    Mode handling: ``"omni"`` selects the ``add_text`` / ``query`` code path;
    any other value takes the ``add_dialogue`` / ``ask`` path, and only the
    exact value ``"text"`` triggers ``finalize()`` at session end.
    """

    def __init__(
        self,
        *,
        mode: str = "text",
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Import the SimpleMem router and create a fresh backend store.

        Args:
            mode: ``"text"`` or ``"omni"``.
            source_root: Directory that contains ``simplemem_router``.  When
                falsy, the module-level ``_DEFAULT_SOURCE`` is used.
            **kwargs: Accepted and ignored.

        Raises:
            ImportError: If ``simplemem_router`` cannot be imported after the
                source root has been put on ``sys.path``.
        """
        self._mode = mode  # "text" or "omni"
        root_str = str(Path(source_root or _DEFAULT_SOURCE).resolve())
        # Side effect: the source root is prepended to sys.path (once) so the
        # router module becomes importable.
        if root_str not in sys.path:
            sys.path.insert(0, root_str)

        try:
            import simplemem_router as simplemem
        except ImportError:
            # Re-raise unchanged; the log line only adds the searched path.
            logging.getLogger(__name__).error(
                "Could not import simplemem_router (source root: %s)", root_str
            )
            raise

        self._simplemem = simplemem
        self._mem: Any = None
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()
        self._stored_texts: list[dict[str, str]] = []
        self._init_mem()

    def _init_mem(self) -> None:
        """Create a new backend (``clear_db=True``) and empty the local mirror."""
        self._mem = self._simplemem.create(mode=self._mode, clear_db=True)
        self._stored_texts = []

    def reset(self) -> None:
        """Close the current backend (best effort) and start from a clean state."""
        close = getattr(self._mem, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                # Closing is best-effort: a failure must not block the reset,
                # but it should be visible in the logs.
                logging.getLogger(__name__).warning(
                    "SimpleMem backend close() failed during reset",
                    exc_info=True,
                )
        self._init_mem()
        self._prev_snapshot_ids = set()
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn in the backend and in the local mirror.

        The stored text is ``"<role>: <text>"`` followed by one
        ``"[<type>] <caption>"`` line per attachment.  The memory id is the
        turn's position in the mirror, as a string.
        """
        self._session_id = turn.session_id
        lines = [f"{turn.role}: {turn.text}"]
        lines.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
        text = "\n".join(lines)

        mid = str(len(self._stored_texts))
        if self._mode == "omni":
            self._mem.add_text(text, tags=[f"session:{turn.session_id}"])
        else:
            speaker = "User" if turn.role == "user" else "Assistant"
            ts = turn.timestamp or ""
            self._mem.add_dialogue(speaker, text, ts)
        # Mirror only after the backend accepted the turn, so a backend error
        # does not leave a phantom entry behind.
        self._stored_texts.append(
            {"id": mid, "text": text, "session_id": turn.session_id}
        )

    def end_session(self, session_id: str) -> None:
        """Mark the session as ended; in text mode call ``finalize()`` (best effort)."""
        self._session_id = session_id
        if self._mode != "text":
            return
        finalize = getattr(self._mem, "finalize", None)
        if not callable(finalize):
            return
        try:
            finalize()
        except Exception:
            logging.getLogger(__name__).warning(
                "SimpleMem backend finalize() failed for session %r",
                session_id,
                exc_info=True,
            )

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one ``active`` snapshot record per entry of the local mirror."""
        return [
            MemorySnapshotRecord(
                memory_id=t["id"], text=t["text"],
                session_id=t["session_id"], status="active",
                source=f"SimpleMem-{self._mode}",
                raw_backend_id=t["id"], raw_backend_type="simplemem",
                metadata={},
            )
            for t in self._stored_texts
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return ``add`` deltas for memories not present at the previous export.

        Each call records the current set of memory ids, so the next call only
        reports memories added in between.  All deltas are tagged with the
        ``session_id`` passed in.
        """
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": f"SimpleMem-{self._mode}"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Retrieve up to ``top_k`` items for ``query``.

        The backend is asked first; when it yields nothing (or fails), a
        word-overlap search over the local mirror is used instead.  A
        non-positive ``top_k`` returns an empty item list without touching
        the backend.
        """
        trace = {"baseline": f"SimpleMem-{self._mode}"}
        # Clamp so a negative top_k cannot turn into a "drop the last n" slice.
        limit = max(top_k, 0)
        if limit == 0:
            return RetrievalRecord(
                query=query, top_k=top_k, items=[], raw_trace=trace,
            )

        items = self._backend_retrieve(query, limit)
        if not items:
            items = self._overlap_fallback(query, limit)

        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:limit], raw_trace=trace,
        )

    def _backend_retrieve(self, query: str, limit: int) -> list[RetrievalItem]:
        """Query the backend; return whatever was collected, even after an error."""
        items: list[RetrievalItem] = []
        try:
            if self._mode == "omni":
                result = self._mem.query(query, top_k=limit)
                if isinstance(result, list):
                    for i, r in enumerate(result[:limit]):
                        text = r.get("text", str(r)) if isinstance(r, dict) else str(r)
                        # The score is derived from the rank only: 1 / (rank + 1).
                        items.append(RetrievalItem(
                            rank=i, memory_id=str(i), text=text,
                            score=1.0 / (i + 1), raw_backend_id=None,
                        ))
            else:
                # This path produces a single item holding the backend's answer.
                answer = self._mem.ask(query)
                if answer:
                    items.append(RetrievalItem(
                        rank=0, memory_id="answer", text=str(answer),
                        score=1.0, raw_backend_id=None,
                    ))
        except Exception:
            # Best effort: the caller falls back to the local search when
            # nothing was collected, but the failure is no longer invisible.
            logging.getLogger(__name__).warning(
                "SimpleMem backend retrieval failed; using local fallback",
                exc_info=True,
            )
        return items

    def _overlap_fallback(self, query: str, limit: int) -> list[RetrievalItem]:
        """Rank mirrored texts by the number of distinct words shared with the query."""
        query_words = set(query.lower().split())  # tokenized once, not per record
        denominator = max(len(query.split()), 1)
        scored = [
            (len(query_words & set(t["text"].lower().split())), t)
            for t in self._stored_texts
        ]
        # Stable sort: records with equal overlap keep their ingestion order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievalItem(
                rank=i, memory_id=t["id"], text=t["text"],
                score=float(overlap) / denominator,
                raw_backend_id=t["id"],
            )
            for i, (overlap, t) in enumerate(scored[:limit])
        ]

    def get_capabilities(self) -> dict[str, Any]:
        """Describe the wrapped baseline and whether a backend object exists."""
        name = "Omni-SimpleMem" if self._mode == "omni" else "SimpleMem"
        return {
            "backend": name, "baseline": name,
            "available": self._mem is not None,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }