File size: 4,252 Bytes
85b19cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | """Adapter for Zep memory system (community/self-hosted edition)."""
from __future__ import annotations

import logging
import os
import uuid as _uuid
from typing import Any

from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.memory_adapters.base import MemoryAdapter
from eval_framework.memory_adapters.base import MemoryAdapter
class ZepAdapter(MemoryAdapter):
    """Adapter for Zep community edition (self-hosted).

    All turns ingested between two :meth:`reset` calls are written into a
    single Zep session (``self._thread_id``) whose id is generated locally.
    Backend read/delete/search failures are treated as best-effort: they are
    logged and an empty result is returned instead of raising.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, *, base_url: str | None = None, **kwargs: Any) -> None:
        """Create a client for a Zep server.

        Args:
            base_url: Zep server URL. Falls back to the ``ZEP_BASE_URL``
                environment variable, then ``http://localhost:8000``.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        # Imported lazily, presumably so this module imports without
        # zep_python installed.
        from zep_python import ZepClient

        self._base_url = base_url or os.getenv("ZEP_BASE_URL", "http://localhost:8000")
        self._client = ZepClient(base_url=self._base_url)
        # Session id of the latest ingested turn / ended session; it is
        # stamped onto snapshot rows.
        self._session_id = ""
        self._thread_id = self._new_thread_id()
        # memory_ids observed by the previous export_memory_delta call.
        self._prev_snapshot_ids: set[str] = set()

    @staticmethod
    def _new_thread_id() -> str:
        """Return a fresh, randomly suffixed Zep session id."""
        return f"eval_{_uuid.uuid4().hex[:8]}"

    @staticmethod
    def _field(obj: Any, name: str) -> Any:
        """Read ``name`` from a model object or a plain dict; None if absent.

        NOTE(review): zep_python 1.x appears to type search-result messages as
        plain dicts while ``get_memory`` returns ``Message`` objects, so both
        shapes are accepted here. Confirm against the installed version.
        """
        if obj is None:
            return None
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    def reset(self) -> None:
        """Best-effort delete of the current Zep session, then start a new one."""
        try:
            self._client.memory.delete_memory(self._thread_id)
        except Exception:
            # Best-effort: the session may never have been created.
            self._logger.debug(
                "Zep: failed to delete session %s", self._thread_id, exc_info=True
            )
        # Restore the same state __init__ establishes.
        self._thread_id = self._new_thread_id()
        self._session_id = ""
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Append one conversation turn to the current Zep session.

        The message text is ``"<role>: <text>"`` followed by one
        ``"[<type>] <caption>"`` line per attachment. Client errors propagate.
        """
        from zep_python.memory import Memory
        from zep_python.message import Message

        self._session_id = turn.session_id
        lines = [f"{turn.role}: {turn.text}"]
        lines.extend(f"[{att.type}] {att.caption}" for att in turn.attachments or ())
        text = "\n".join(lines)
        # Every role other than "user" is sent with role_type "ai".
        role_type = "user" if turn.role == "user" else "ai"
        msg = Message(role=turn.role, role_type=role_type, content=text)
        self._client.memory.add_memory(self._thread_id, Memory(messages=[msg]))

    def end_session(self, session_id: str) -> None:
        """Record ``session_id`` as the session stamped onto snapshot rows."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return the messages Zep currently holds for this adapter's session.

        Each row's ``memory_id`` is the Zep message uuid, or the message's
        position in the returned list when no uuid is set. An unreadable
        session yields an empty list.

        NOTE(review): if ``get_memory`` only returns a recent window of
        messages, older ones will drop out of the snapshot. Confirm.
        """
        try:
            memory = self._client.memory.get_memory(self._thread_id)
        except Exception:
            self._logger.debug(
                "Zep: could not read session %s", self._thread_id, exc_info=True
            )
            return []
        messages = (memory.messages if memory else None) or []
        rows: list[MemorySnapshotRecord] = []
        for i, msg in enumerate(messages):
            # getattr's default only covers a *missing* attribute; a uuid that
            # is None must also fall back, or every row would get id "None".
            uuid = self._field(msg, "uuid")
            mid = str(uuid) if uuid else str(i)
            rows.append(MemorySnapshotRecord(
                memory_id=mid,
                text=self._field(msg, "content") or "",
                session_id=self._session_id,
                status="active",
                source="Zep",
                raw_backend_id=mid,
                raw_backend_type="zep_message",
                metadata={},
            ))
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Report snapshot rows not seen at the previous call as ``add`` ops.

        Only additions are reported; rows that disappeared are ignored.
        """
        current = self.snapshot_memories()
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=s.text,
                linked_previous=(),
                raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "Zep"},
            )
            for s in current
            if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {s.memory_id for s in current}
        return deltas

    @staticmethod
    def _search_payload(query: str) -> Any:
        """Build the second argument for ``memory.search_memory``.

        NOTE(review): zep_python 1.x's ``search_memory`` takes a
        ``MemorySearchPayload`` rather than a bare string. If that model is
        not importable, the raw query is passed through unchanged.
        """
        try:
            from zep_python.memory import MemorySearchPayload
        except ImportError:
            return query
        return MemorySearchPayload(text=query)

    def _to_retrieval_item(self, rank: int, result: Any) -> RetrievalItem:
        """Convert one Zep search result into a RetrievalItem at ``rank``."""
        message = getattr(result, "message", None)
        uuid = self._field(message, "uuid")
        # A present-but-None score must not reach float().
        score = getattr(result, "score", None)
        return RetrievalItem(
            rank=rank,
            memory_id=str(uuid) if uuid else str(rank),
            text=(self._field(message, "content") or "") if message else str(result),
            score=float(score) if score is not None else 1.0 / (rank + 1),
            raw_backend_id=str(uuid) if uuid else None,
        )

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search the current session's memory for ``query``.

        Args:
            query: Free-text search query.
            top_k: Maximum number of items to return.

        Returns:
            A RetrievalRecord whose items have 0-based ranks. When Zep reports
            no score, ``1 / (rank + 1)`` is used. A failed search is logged
            and yields no items.
        """
        try:
            results = self._client.memory.search_memory(
                self._thread_id, self._search_payload(query), limit=top_k,
            )
        except Exception:
            self._logger.warning(
                "Zep: search failed for session %s", self._thread_id, exc_info=True
            )
            results = []
        items = [
            self._to_retrieval_item(i, r)
            for i, r in enumerate((results or [])[:top_k])
        ]
        return RetrievalRecord(
            query=query, top_k=top_k, items=items,
            raw_trace={"baseline": "Zep"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend: deltas come from diffing full snapshots."""
        return {
            "backend": "Zep",
            "baseline": "Zep",
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
|