"""Simple short + long-term memory store.""" from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from typing import Any import numpy as np from .rag.embeddings import AzureEmbeddingClient @dataclass class MemoryEntry: role: str text: str class MemoryStore: def __init__(self, store_path: Path, top_k: int, summary_every: int) -> None: self._store_path = store_path self._top_k = top_k self._summary_every = summary_every self._vectors_path = self._store_path.with_suffix(".npy") self._short: list[MemoryEntry] = [] self._long: list[str] = [] self._vectors: np.ndarray | None = None self._embedder = AzureEmbeddingClient() self._store_path.parent.mkdir(parents=True, exist_ok=True) self._load() def add_turn(self, user: str, assistant: str) -> None: self._short.append(MemoryEntry(role="user", text=user)) self._short.append(MemoryEntry(role="assistant", text=assistant)) if len(self._short) // 2 >= self._summary_every: summary = self._summarize_short() if summary: self._long.append(summary) self._short = [] self._save() def short_context(self, max_turns: int = 6) -> str: if not self._short: return "" lines: list[str] = [] for entry in self._short[-max_turns * 2 :]: prefix = "User" if entry.role == "user" else "Assistant" lines.append(f"{prefix}: {entry.text}") return "\n".join(lines) def load_short(self, entries: list[dict[str, str]]) -> None: self._short = [MemoryEntry(role=e["role"], text=e["text"]) for e in entries] def short_entries(self) -> list[dict[str, str]]: return [{"role": e.role, "text": e.text} for e in self._short] def long_context(self, query: str) -> str: if not self._long: return "" if self._vectors is None or len(self._vectors) != len(self._long): self._vectors = self._embedder.embed(self._long) if self._vectors.size: self._vectors = self._normalize(self._vectors) if self._vectors is None or self._vectors.size == 0: return "" q = self._embedder.embed([query]) if q.size == 0: return "" q = self._normalize(q) scores = np.dot(self._vectors, q.T).flatten() top_indices = scores.argsort()[::-1][: self._top_k] selected = [self._long[i] for i in top_indices if scores[i] > 0] return "\n".join(selected) def reset(self) -> None: self._short = [] self._long = [] self._vectors = None if self._store_path.exists(): self._store_path.unlink() if self._vectors_path.exists(): self._vectors_path.unlink() def _summarize_short(self) -> str: if not self._short: return "" lines: list[str] = [] for entry in self._short: prefix = "User" if entry.role == "user" else "Assistant" lines.append(f"{prefix}: {entry.text}") return " | ".join(lines) def _save(self) -> None: payload = {"long": self._long} self._store_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2)) if self._vectors is not None and self._vectors.size: np.save(self._vectors_path, self._vectors) def _load(self) -> None: if not self._store_path.exists(): return try: payload: dict[str, Any] = json.loads(self._store_path.read_text()) self._long = list(payload.get("long", [])) if self._vectors_path.exists(): self._vectors = np.load(self._vectors_path) except Exception: self._long = [] self._vectors = None def _normalize(self, vectors: np.ndarray) -> np.ndarray: norms = np.linalg.norm(vectors, axis=1, keepdims=True) norms[norms == 0] = 1.0 return vectors / norms