File size: 4,102 Bytes
0b2d478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Simple short + long-term memory store."""
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np

from .rag.embeddings import AzureEmbeddingClient

@dataclass
class MemoryEntry:
    """One utterance kept in the short-term conversation buffer."""

    # Speaker tag. The store writes "user" or "assistant"; when rendering,
    # any value other than "user" is labelled as the assistant.
    role: str
    # The utterance text, stored as given.
    text: str


class MemoryStore:
    """Conversation memory with a short-term buffer and a long-term log.

    Short-term memory is an in-process list of ``MemoryEntry`` objects. Once
    ``summary_every`` turns have accumulated they are collapsed into a single
    summary string, appended to the long-term list and written to
    ``store_path`` as JSON (``{"long": [...]}``).

    Long-term summaries are retrieved by dot-product similarity between
    row-normalised embeddings. The embeddings are cached next to the JSON
    file (same name, ``.npy`` suffix) and are rebuilt whenever the cache does
    not have exactly one row per summary.
    """

    def __init__(self, store_path: Path, top_k: int, summary_every: int) -> None:
        """Create the store and load any previously persisted state.

        Args:
            store_path: JSON file holding the long-term summaries. Its parent
                directory is created if missing.
            top_k: Maximum number of summaries ``long_context`` returns.
            summary_every: Number of turns (user + assistant pairs) after
                which the short-term buffer is folded into long-term memory.
        """
        self._store_path = store_path
        self._top_k = top_k
        self._summary_every = summary_every
        self._vectors_path = self._store_path.with_suffix(".npy")
        self._short: list[MemoryEntry] = []
        self._long: list[str] = []
        self._vectors: np.ndarray | None = None
        # NOTE(review): the code below relies on ``embed(list[str])`` returning
        # a 2-D numpy array with one row per input text -- confirm against the
        # embedding client.
        self._embedder = AzureEmbeddingClient()
        self._store_path.parent.mkdir(parents=True, exist_ok=True)
        self._load()

    def add_turn(self, user: str, assistant: str) -> None:
        """Record one exchange; summarise and persist once enough have piled up.

        Args:
            user: The user's message.
            assistant: The assistant's reply.
        """
        self._short.append(MemoryEntry(role="user", text=user))
        self._short.append(MemoryEntry(role="assistant", text=assistant))
        # Two entries per turn, hence the integer division.
        if len(self._short) // 2 >= self._summary_every:
            summary = self._summarize_short()
            if summary:
                self._long.append(summary)
            self._short = []
            self._save()

    def short_context(self, max_turns: int = 6) -> str:
        """Return the most recent turns as ``"User: ..."``/``"Assistant: ..."`` lines.

        Args:
            max_turns: Number of trailing turns to include. Values <= 0 yield
                an empty string.

        Returns:
            Newline-joined lines, or ``""`` when there is nothing to show.
        """
        # Guard first: ``[-0:]`` would select the *entire* buffer, not nothing.
        if max_turns <= 0 or not self._short:
            return ""
        return "\n".join(self._format_entries(self._short[-max_turns * 2 :]))

    def load_short(self, entries: list[dict[str, str]]) -> None:
        """Replace the short-term buffer with ``entries``.

        Args:
            entries: Dicts with ``"role"`` and ``"text"`` keys, in the shape
                produced by ``short_entries``.

        Raises:
            KeyError: If an entry lacks ``"role"`` or ``"text"``.
        """
        self._short = [MemoryEntry(role=e["role"], text=e["text"]) for e in entries]

    def short_entries(self) -> list[dict[str, str]]:
        """Return the short-term buffer as a list of ``{"role", "text"}`` dicts."""
        return [{"role": e.role, "text": e.text} for e in self._short]

    def long_context(self, query: str) -> str:
        """Return the long-term summaries most similar to ``query``.

        Args:
            query: Text to compare against the stored summaries.

        Returns:
            Up to ``top_k`` summaries with a strictly positive similarity
            score, best first, joined by newlines; ``""`` when nothing
            qualifies.
        """
        # A non-positive ``top_k`` can never select anything, so skip the
        # embedding calls entirely.
        if not self._long or self._top_k <= 0:
            return ""
        if not self._ensure_vectors():
            return ""
        q = self._embedder.embed([query])
        if q.size == 0:
            return ""
        q = self._normalize(q)
        # Rows on both sides are unit length, so the dot product is the
        # cosine similarity.
        scores = np.dot(self._vectors, q.T).flatten()
        top_indices = scores.argsort()[::-1][: self._top_k]
        return "\n".join(self._long[i] for i in top_indices if scores[i] > 0)

    def reset(self) -> None:
        """Clear all in-memory state and delete the persisted files."""
        self._short = []
        self._long = []
        self._vectors = None
        if self._store_path.exists():
            self._store_path.unlink()
        if self._vectors_path.exists():
            self._vectors_path.unlink()

    @staticmethod
    def _format_entries(entries) -> list[str]:
        """Render entries as ``"User: ..."`` / ``"Assistant: ..."`` strings."""
        return [
            f"{'User' if entry.role == 'user' else 'Assistant'}: {entry.text}"
            for entry in entries
        ]

    def _summarize_short(self) -> str:
        """Collapse the short-term buffer into a single ``" | "``-joined line."""
        if not self._short:
            return ""
        return " | ".join(self._format_entries(self._short))

    def _ensure_vectors(self) -> bool:
        """Make ``_vectors`` match ``_long``; return whether it is usable."""
        if self._vectors is None or len(self._vectors) != len(self._long):
            vectors = self._embedder.embed(self._long)
            if vectors.size:
                vectors = self._normalize(vectors)
            self._vectors = vectors
            # Persist immediately: this is the only moment the cache is known
            # to be in sync with ``_long``, so writing it here is what makes
            # it reusable after a restart.
            try:
                self._save_vectors()
            except OSError:
                # The cache is an optimisation; a failed write must not break
                # retrieval. It is simply rebuilt next time.
                pass
        return self._vectors is not None and self._vectors.size != 0

    def _save_vectors(self) -> None:
        """Write the embedding cache, but only if it has one row per summary."""
        vectors = self._vectors
        if vectors is None or not vectors.size:
            return
        if len(vectors) != len(self._long):
            # Out of date (e.g. a summary was just appended); writing it would
            # only store a cache that is thrown away on the next load.
            return
        np.save(self._vectors_path, vectors)

    def _save(self) -> None:
        """Persist the long-term summaries and, when in sync, their embeddings."""
        payload = {"long": self._long}
        # ``ensure_ascii=False`` emits raw non-ASCII characters, so the file
        # encoding must be pinned rather than left to the platform locale.
        self._store_path.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        self._save_vectors()

    def _load(self) -> None:
        """Load persisted state, falling back to empty memory on bad data.

        An unreadable or malformed JSON store yields empty long-term memory.
        An unreadable or mismatched vector cache is ignored (it is rebuilt on
        demand) and does *not* discard the summaries.
        """
        if not self._store_path.exists():
            return
        self._long = []
        self._vectors = None
        try:
            try:
                raw = self._store_path.read_text(encoding="utf-8")
            except UnicodeDecodeError:
                # Files written before UTF-8 was enforced used the locale
                # encoding; retry with the platform default.
                raw = self._store_path.read_text()
            payload: Any = json.loads(raw)
        except (OSError, ValueError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            return
        entries = payload.get("long", []) if isinstance(payload, dict) else []
        if not isinstance(entries, list):
            return
        self._long = list(entries)

        if not self._vectors_path.exists():
            return
        try:
            vectors = np.load(self._vectors_path)
        except Exception:
            # Deliberately broad: the cache is disposable and any failure to
            # read it is handled the same way, by re-embedding on demand.
            return
        if (
            isinstance(vectors, np.ndarray)
            and vectors.ndim == 2
            and len(vectors) == len(self._long)
        ):
            self._vectors = vectors

    def _normalize(self, vectors: np.ndarray) -> np.ndarray:
        """Scale each row of ``vectors`` to unit L2 norm.

        All-zero rows are left as zeros instead of dividing by zero.
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return vectors / norms