File size: 4,252 Bytes
85b19cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Adapter for Zep memory system (community/self-hosted edition)."""

from __future__ import annotations

import logging
import os
import uuid as _uuid
from typing import Any

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter


class ZepAdapter(MemoryAdapter):
    """Adapter for Zep community edition (self-hosted).

    All turns are written to a single, randomly named Zep session
    (``self._thread_id``) that is replaced on :meth:`reset`. The evaluation
    framework's own session id is only remembered so that emitted records can
    be tagged with it.
    """

    # Logger named after this module; kept on the class so backend failures
    # that are deliberately tolerated still leave a trace.
    _logger = logging.getLogger(__name__)

    def __init__(self, *, base_url: str | None = None, **kwargs: Any) -> None:
        """Create the Zep client and an empty evaluation thread.

        Args:
            base_url: URL of the Zep server. When falsy, the ``ZEP_BASE_URL``
                environment variable is used, then ``http://localhost:8000``.
            **kwargs: Accepted and ignored.
        """
        # Function-scope import: zep_python is only needed once an adapter is
        # actually instantiated.
        from zep_python import ZepClient

        self._base_url = base_url or os.getenv("ZEP_BASE_URL", "http://localhost:8000")
        self._client = ZepClient(base_url=self._base_url)
        self._session_id = ""
        self._thread_id = self._new_thread_id()
        self._prev_snapshot_ids: set[str] = set()

    @staticmethod
    def _new_thread_id() -> str:
        """Return a fresh ``eval_<8 hex chars>`` name for a Zep session."""
        return f"eval_{_uuid.uuid4().hex[:8]}"

    @staticmethod
    def _field(obj: Any, name: str) -> Any:
        """Read *name* from a dict or an attribute-style object.

        Returns ``None`` when *obj* is ``None`` or the field is absent, so
        callers can treat "missing" and "present but None" the same way.
        """
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    def reset(self) -> None:
        """Drop the current Zep thread (best effort) and start a new one."""
        try:
            self._client.memory.delete_memory(self._thread_id)
        except Exception:
            # Deliberately best effort: a failed delete must not abort the
            # run, because a brand-new thread id is used from here on anyway.
            self._logger.debug(
                "Zep delete_memory failed for thread %s; continuing",
                self._thread_id,
                exc_info=True,
            )
        self._thread_id = self._new_thread_id()
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Append one conversation turn to the current Zep thread.

        The stored message content is ``"<role>: <text>"`` followed by one
        ``"[<type>] <caption>"`` line per attachment.
        """
        from zep_python.memory import Memory
        from zep_python.message import Message

        self._session_id = turn.session_id

        lines = [f"{turn.role}: {turn.text}"]
        # ``or ()`` tolerates a turn whose attachments field is None.
        lines.extend(
            f"[{att.type}] {att.caption}" for att in (turn.attachments or ())
        )

        # Every role other than "user" is stored with role_type "ai".
        role_type = "user" if turn.role == "user" else "ai"
        msg = Message(role=turn.role, role_type=role_type, content="\n".join(lines))
        self._client.memory.add_memory(self._thread_id, Memory(messages=[msg]))

    def end_session(self, session_id: str) -> None:
        """Record *session_id* as the session used to tag later records."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one record per message currently stored in the thread.

        Returns:
            The messages reported by ``get_memory`` for the current thread,
            or an empty list when the call fails or nothing is stored.
        """
        try:
            memory = self._client.memory.get_memory(self._thread_id)
        except Exception:
            self._logger.warning(
                "Zep get_memory failed for thread %s; returning empty snapshot",
                self._thread_id,
                exc_info=True,
            )
            return []

        rows: list[MemorySnapshotRecord] = []
        for i, msg in enumerate(self._field(memory, "messages") or []):
            # Fall back to the position when the message has no usable uuid;
            # a plain getattr default would turn a None uuid into "None" for
            # every message and collapse the snapshot diff.
            backend_uuid = self._field(msg, "uuid")
            mid = str(backend_uuid) if backend_uuid is not None else str(i)
            rows.append(MemorySnapshotRecord(
                memory_id=mid,
                text=self._field(msg, "content") or "",
                session_id=self._session_id,
                status="active",
                source="Zep",
                raw_backend_id=mid,
                raw_backend_type="zep_message",
                metadata={},
            ))
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return ``add`` deltas for memories not seen in the previous export.

        The ids of the current snapshot become the baseline for the next
        call, so each memory is reported at most once between resets.
        """
        current = self.snapshot_memories()
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "Zep"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {s.memory_id for s in current}
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search the current thread and return at most *top_k* items.

        Ranks are the zero-based positions in the backend's result list. When
        a result carries no score, ``1 / (rank + 1)`` is used instead. A
        failed search yields a record with no items.
        """
        try:
            # NOTE(review): the query is passed as a plain string; confirm the
            # installed zep_python client accepts that rather than a
            # MemorySearchPayload. A rejected call is now logged below instead
            # of silently looking like "no results".
            results = self._client.memory.search_memory(
                self._thread_id, query, limit=top_k,
            )
        except Exception:
            self._logger.warning(
                "Zep search_memory failed for thread %s; returning no items",
                self._thread_id,
                exc_info=True,
            )
            results = []

        items: list[RetrievalItem] = []
        for i, r in enumerate(list(results or [])[:top_k]):
            message = self._field(r, "message")
            backend_uuid = self._field(message, "uuid")
            score = self._field(r, "score")
            if message:
                text = self._field(message, "content") or ""
            else:
                # No message attached: keep the raw result as the text.
                text = str(r)
            items.append(RetrievalItem(
                rank=i,
                memory_id=str(backend_uuid) if backend_uuid is not None else str(i),
                text=text,
                # A score that is present but None must not reach float().
                score=float(score) if score is not None else 1.0 / (i + 1),
                raw_backend_id=str(backend_uuid) if backend_uuid is not None else None,
            ))

        return RetrievalRecord(
            query=query, top_k=top_k, items=items,
            raw_trace={"baseline": "Zep"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend for the evaluation framework."""
        return {
            "backend": "Zep",
            "baseline": "Zep",
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }