Spaces:
Sleeping
Sleeping
| """Simple short + long-term memory store.""" | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| from .rag.embeddings import AzureEmbeddingClient | |
@dataclass
class MemoryEntry:
    """One message held in short-term memory.

    The missing ``@dataclass`` decorator is the fix: the class is
    instantiated as ``MemoryEntry(role=..., text=...)`` by callers, which
    requires the generated keyword ``__init__``.
    """

    role: str  # speaker tag: "user" or "assistant"
    text: str  # verbatim message content
class MemoryStore:
    """Simple short + long-term memory store.

    Recent turns are kept verbatim in a short-term buffer; every
    ``summary_every`` turns the buffer is condensed into a single summary
    string appended to long-term memory.  Long-term entries are retrieved
    by cosine similarity between their embeddings and a query embedding.
    State is persisted as a JSON file (summaries) plus a sibling ``.npy``
    file (embedding matrix).
    """

    def __init__(self, store_path: Path, top_k: int, summary_every: int) -> None:
        """Create the store, ensure the parent directory exists, and load
        any previously persisted state.

        Args:
            store_path: JSON file holding the long-term summaries.
            top_k: Maximum number of long-term entries returned per query.
            summary_every: Turns buffered before a summary is taken.
        """
        self._store_path = store_path
        self._top_k = top_k
        self._summary_every = summary_every
        # Embedding matrix is cached beside the JSON store as a .npy file.
        self._vectors_path = self._store_path.with_suffix(".npy")
        self._short: list[MemoryEntry] = []
        self._long: list[str] = []
        self._vectors: np.ndarray | None = None
        self._embedder = AzureEmbeddingClient()
        self._store_path.parent.mkdir(parents=True, exist_ok=True)
        self._load()

    def add_turn(self, user: str, assistant: str) -> None:
        """Record one user/assistant exchange; summarize and persist when
        the short-term buffer reaches ``summary_every`` turns."""
        self._short.append(MemoryEntry(role="user", text=user))
        self._short.append(MemoryEntry(role="assistant", text=assistant))
        # Two entries per turn, hence the // 2.
        if len(self._short) // 2 >= self._summary_every:
            summary = self._summarize_short()
            if summary:
                self._long.append(summary)
            self._short = []
            self._save()

    def short_context(self, max_turns: int = 6) -> str:
        """Return the last ``max_turns`` turns as "User:/Assistant:" lines,
        or "" when the buffer is empty."""
        if not self._short:
            return ""
        lines: list[str] = []
        for entry in self._short[-max_turns * 2 :]:
            prefix = "User" if entry.role == "user" else "Assistant"
            lines.append(f"{prefix}: {entry.text}")
        return "\n".join(lines)

    def load_short(self, entries: list[dict[str, str]]) -> None:
        """Replace the short-term buffer with ``entries`` (role/text dicts)."""
        self._short = [MemoryEntry(role=e["role"], text=e["text"]) for e in entries]

    def short_entries(self) -> list[dict[str, str]]:
        """Serialize the short-term buffer to role/text dicts."""
        return [{"role": e.role, "text": e.text} for e in self._short]

    def long_context(self, query: str) -> str:
        """Return up to ``top_k`` long-term summaries most similar to
        ``query`` (cosine similarity; only strictly positive scores)."""
        if not self._long:
            return ""
        # Re-embed when the cache is missing or out of sync with _long.
        if self._vectors is None or len(self._vectors) != len(self._long):
            self._vectors = self._embedder.embed(self._long)
            if self._vectors.size:
                self._vectors = self._normalize(self._vectors)
        if self._vectors is None or self._vectors.size == 0:
            return ""
        q = self._embedder.embed([query])
        if q.size == 0:
            return ""
        q = self._normalize(q)
        # Rows are unit-normalized, so the dot product is cosine similarity.
        scores = np.dot(self._vectors, q.T).flatten()
        top_indices = scores.argsort()[::-1][: self._top_k]
        selected = [self._long[i] for i in top_indices if scores[i] > 0]
        return "\n".join(selected)

    def reset(self) -> None:
        """Clear all memory and delete both persisted files."""
        self._short = []
        self._long = []
        self._vectors = None
        if self._store_path.exists():
            self._store_path.unlink()
        if self._vectors_path.exists():
            self._vectors_path.unlink()

    def _summarize_short(self) -> str:
        """Collapse the short-term buffer into one ' | '-joined line."""
        if not self._short:
            return ""
        lines: list[str] = []
        for entry in self._short:
            prefix = "User" if entry.role == "user" else "Assistant"
            lines.append(f"{prefix}: {entry.text}")
        return " | ".join(lines)

    def _save(self) -> None:
        """Persist long-term summaries (JSON) and their embeddings (.npy)."""
        payload = {"long": self._long}
        # Explicit UTF-8 is the fix: ensure_ascii=False can emit non-ASCII
        # text, and write_text would otherwise use the platform default
        # encoding (e.g. cp1252 on Windows) and fail or corrupt the store.
        self._store_path.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        if self._vectors is not None and self._vectors.size:
            np.save(self._vectors_path, self._vectors)

    def _load(self) -> None:
        """Best-effort load of persisted state; a corrupt or unreadable
        store resets long-term memory rather than crashing at startup."""
        if not self._store_path.exists():
            return
        try:
            # Mirror _save: the store is always written as UTF-8.
            payload: dict[str, Any] = json.loads(
                self._store_path.read_text(encoding="utf-8")
            )
            self._long = list(payload.get("long", []))
            if self._vectors_path.exists():
                self._vectors = np.load(self._vectors_path)
        except Exception:
            # Deliberate broad catch: any parse/IO failure falls back to
            # an empty long-term memory.
            self._long = []
            self._vectors = None

    def _normalize(self, vectors: np.ndarray) -> np.ndarray:
        """Row-normalize ``vectors`` to unit length; all-zero rows are
        left unchanged (their norm is replaced by 1 to avoid div-by-zero)."""
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return vectors / norms