File size: 3,691 Bytes
85b19cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Helpers to map turns, backend memory dicts, and recall outputs into shared schemas."""

from __future__ import annotations

from typing import Any, Mapping

from eval_framework.datasets.schemas import (
    MemorySnapshotRecord,
    NormalizedTurn,
    RetrievalItem,
    RetrievalRecord,
)


def turn_to_observation_dict(turn: NormalizedTurn) -> dict[str, Any]:
    """Convert a normalized turn into the dict stored as a Mem-Gallery observation.

    The ``text`` entry is the turn text followed by one ``[type] caption`` line
    per attachment, joined with newlines. ``timestamp`` is included only when
    the turn has a truthy one. ``dialogue_id`` is always set to
    ``"<session_id>:<turn_index>"``.
    """
    lines: list[str] = [
        turn.text,
        *(f"[{item.type}] {item.caption}" for item in turn.attachments),
    ]
    observation: dict[str, Any] = {"text": "\n".join(lines)}
    if turn.timestamp:
        observation["timestamp"] = turn.timestamp
    observation["dialogue_id"] = f"{turn.session_id}:{turn.turn_index}"
    return observation


def memory_element_text(element: Mapping[str, Any]) -> str:
    """Best-effort text extraction from a Mem-Gallery memory dict.

    The ``"text"`` entry may be a list (its items are stringified and joined
    with single spaces), ``None`` or absent (treated as empty), or any other
    value (stringified with ``str``). When ``"image"`` is a dict with a truthy
    ``"caption"``, a ``[image] <caption>`` line is appended and the combined
    string is stripped of surrounding whitespace.

    Args:
        element: Memory dict to read ``"text"`` and ``"image"`` from.

    Returns:
        The extracted text; an empty string when nothing usable is present.
    """
    raw = element.get("text", "")
    if isinstance(raw, list):
        # Do not return early here: list-valued text must still pick up the
        # image caption below, exactly like string-valued text does.
        base = " ".join(str(part) for part in raw)
    elif raw is None:
        base = ""
    else:
        base = str(raw)

    image = element.get("image")
    if isinstance(image, dict):
        caption = image.get("caption")
        if caption:
            # strip() removes the leading newline left when base is empty.
            base = f"{base}\n[image] {caption}".strip()
    return base


def linear_element_to_snapshot(
    element: Mapping[str, Any],
    *,
    memory_id: str,
    session_id: str,
    source: str,
    status: str = "active",
) -> MemorySnapshotRecord:
    """Build a MemorySnapshotRecord from a linear-storage memory dict.

    The record's ``raw_backend_id`` is the stringified ``counter_id`` of
    *element* when that key holds a non-``None`` value, and *memory_id*
    otherwise. ``raw_backend_type`` is always ``"linear"`` and ``metadata``
    is a fresh empty dict.
    """
    counter = element.get("counter_id")
    # Fall back to the caller-supplied id when the backend gave no counter.
    backend_id = memory_id if counter is None else str(counter)
    extracted = memory_element_text(element)
    return MemorySnapshotRecord(
        memory_id=memory_id,
        session_id=session_id,
        source=source,
        status=status,
        text=extracted,
        raw_backend_id=backend_id,
        raw_backend_type="linear",
        metadata={},
    )


def _coerce_score(value: Any, default: float = 1.0) -> float:
    """Return *value* as a float, or *default* when it is None or not numeric."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def normalize_recall_to_retrieval(
    query: str,
    top_k: int,
    raw: Any,
    *,
    raw_trace: dict[str, Any] | None = None,
) -> RetrievalRecord:
    """Normalize Mem-Gallery recall outputs into RetrievalRecord.

    Three shapes of *raw* are handled:

    * ``str`` - wrapped as a single item with id ``memgallery:string_bundle``.
    * ``list`` - at most ``top_k`` leading rows become items ranked from 0.
      Dict rows take their id from ``counter_id`` (falling back to the rank),
      their text from ``memory_element_text`` and their score from ``score``.
      Any other row is stringified and given score 1.0.
    * anything else - stringified into a single item with id
      ``memgallery:object_bundle``.

    Args:
        query: Query string stored on the record.
        top_k: Maximum number of items to keep; values below zero keep none.
        raw: Recall output to normalize.
        raw_trace: Optional trace mapping; it is shallow-copied, not mutated.

    Returns:
        A RetrievalRecord holding the normalized items and the copied trace.
    """
    trace = dict(raw_trace or {})
    # Clamp once so the list slice and the final slice agree for negative top_k.
    limit = max(0, top_k)
    items: list[RetrievalItem] = []

    if isinstance(raw, str):
        items.append(
            RetrievalItem(
                rank=0,
                memory_id="memgallery:string_bundle",
                text=raw,
                score=1.0,
                raw_backend_id=None,
            )
        )
    elif isinstance(raw, list):
        for rank, row in enumerate(raw[:limit]):
            if isinstance(row, dict):
                counter = row.get("counter_id")
                items.append(
                    RetrievalItem(
                        rank=rank,
                        memory_id=str(counter if counter is not None else rank),
                        text=memory_element_text(row),
                        # A row may carry "score": None or a non-numeric value;
                        # fall back to 1.0 instead of aborting the whole recall.
                        score=_coerce_score(row.get("score", 1.0)),
                        raw_backend_id=str(counter) if counter is not None else None,
                    )
                )
            else:
                items.append(
                    RetrievalItem(
                        rank=rank,
                        memory_id=str(rank),
                        text=str(row),
                        score=1.0,
                        raw_backend_id=None,
                    )
                )
    else:
        items.append(
            RetrievalItem(
                rank=0,
                memory_id="memgallery:object_bundle",
                text=str(raw),
                score=1.0,
                raw_backend_id=None,
            )
        )

    return RetrievalRecord(query=query, top_k=top_k, items=items[:limit], raw_trace=trace)