# Last commit 0b2d478: "RAG, language updates" (Sbboss).
"""Simple short + long-term memory store."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from .rag.embeddings import AzureEmbeddingClient
@dataclass(init=True, repr=True, eq=True)
class MemoryEntry:
    """A single message held in the short-term conversation buffer."""

    # Who produced the message; MemoryStore writes "user" or "assistant", and
    # any value other than "user" is rendered with the "Assistant" prefix.
    role: str
    # The raw message text, stored verbatim.
    text: str
class MemoryStore:
    """Two-tier conversation memory: a short-term buffer plus long-term summaries.

    Short-term memory keeps the raw user/assistant messages of the current
    conversation. Once the buffer holds ``summary_every`` turns it is collapsed
    into a single summary line that is appended to long-term memory. Long-term
    summaries are persisted as JSON at ``store_path``; their embeddings are
    cached in a sibling ``.npy`` file and searched by cosine similarity.
    """

    def __init__(self, store_path: Path, top_k: int, summary_every: int) -> None:
        """Create the store and load any previously persisted long-term memory.

        Args:
            store_path: JSON file holding long-term summaries. Its parent
                directory is created if missing; the vector cache lives next
                to it with a ``.npy`` suffix.
            top_k: Maximum number of summaries returned by ``long_context``.
            summary_every: Number of user/assistant turns buffered before the
                short-term memory is collapsed into long-term memory.
        """
        self._store_path = store_path
        self._top_k = top_k
        self._summary_every = summary_every
        self._vectors_path = self._store_path.with_suffix(".npy")
        self._short: list[MemoryEntry] = []
        self._long: list[str] = []
        # Unit-normalized embeddings of ``self._long``, one row per summary;
        # rebuilt lazily whenever the row count no longer matches.
        self._vectors: np.ndarray | None = None
        self._embedder = AzureEmbeddingClient()
        self._store_path.parent.mkdir(parents=True, exist_ok=True)
        self._load()

    def add_turn(self, user: str, assistant: str) -> None:
        """Record one user/assistant exchange and persist long-term memory.

        When the short-term buffer reaches ``summary_every`` turns it is
        collapsed into one summary line, appended to long-term memory, and
        cleared.
        """
        self._short.append(MemoryEntry(role="user", text=user))
        self._short.append(MemoryEntry(role="assistant", text=assistant))
        # Each turn contributes two entries (user + assistant).
        if len(self._short) // 2 >= self._summary_every:
            summary = self._summarize_short()
            if summary:
                self._long.append(summary)
            self._short = []
        self._save()

    def short_context(self, max_turns: int = 6) -> str:
        """Render the most recent ``max_turns`` turns as ``"Role: text"`` lines.

        Returns an empty string when the buffer is empty or ``max_turns`` is
        not positive.
        """
        if not self._short or max_turns <= 0:
            # Guard: ``self._short[-0:]`` would otherwise return the whole buffer.
            return ""
        return "\n".join(self._format_entries(self._short[-max_turns * 2 :]))

    def load_short(self, entries: list[dict[str, str]]) -> None:
        """Replace the short-term buffer with dicts carrying "role"/"text" keys."""
        self._short = [MemoryEntry(role=e["role"], text=e["text"]) for e in entries]

    def short_entries(self) -> list[dict[str, str]]:
        """Export the short-term buffer as "role"/"text" dicts (inverse of load_short)."""
        return [{"role": e.role, "text": e.text} for e in self._short]

    def long_context(self, query: str) -> str:
        """Return up to ``top_k`` long-term summaries most similar to ``query``.

        Similarity is the dot product of unit-normalized embeddings (cosine).
        Only summaries with a strictly positive score are returned, best first,
        joined by newlines. Returns ``""`` when nothing qualifies.
        """
        if not self._long:
            return ""
        vectors = self._ensure_vectors()
        if vectors is None:
            return ""
        q = self._embedder.embed([query])
        if q.size == 0:
            return ""
        q = self._normalize(q)
        if vectors.shape[1] != q.shape[1]:
            # The cached vectors have a different embedding dimension (e.g. the
            # cache predates a model change); rebuild once instead of letting
            # ``np.dot`` raise.
            vectors = self._ensure_vectors(force=True)
            if vectors is None or vectors.shape[1] != q.shape[1]:
                return ""
        scores = np.dot(vectors, q.T).flatten()
        top_indices = scores.argsort()[::-1][: self._top_k]
        selected = [self._long[i] for i in top_indices if scores[i] > 0]
        return "\n".join(selected)

    def reset(self) -> None:
        """Forget all memory and delete the persisted JSON and vector files."""
        self._short = []
        self._long = []
        self._vectors = None
        if self._store_path.exists():
            self._store_path.unlink()
        if self._vectors_path.exists():
            self._vectors_path.unlink()

    def _ensure_vectors(self, force: bool = False) -> np.ndarray | None:
        """Return normalized embeddings for ``self._long``, re-embedding if stale.

        The cache is rebuilt when missing, when its row count differs from the
        number of summaries, or when ``force`` is true. Returns ``None`` if the
        embedder produced no vectors.
        """
        if force or self._vectors is None or len(self._vectors) != len(self._long):
            self._vectors = self._embedder.embed(self._long)
            if self._vectors.size:
                self._vectors = self._normalize(self._vectors)
        if self._vectors is None or self._vectors.size == 0:
            return None
        return self._vectors

    @staticmethod
    def _format_entries(entries: list[MemoryEntry]) -> list[str]:
        """Render entries as ``"User: ..."`` / ``"Assistant: ..."`` lines."""
        lines: list[str] = []
        for entry in entries:
            prefix = "User" if entry.role == "user" else "Assistant"
            lines.append(f"{prefix}: {entry.text}")
        return lines

    def _summarize_short(self) -> str:
        """Collapse the short-term buffer into one ``" | "``-separated line."""
        return " | ".join(self._format_entries(self._short))

    def _save(self) -> None:
        """Persist long-term summaries (and cached vectors, if any) to disk.

        The JSON is written to a sibling temp file and then swapped into place,
        so a failed or interrupted write does not clobber the existing file.
        """
        payload = {"long": self._long}
        tmp_path = self._store_path.with_name(self._store_path.name + ".tmp")
        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII text that the
        # platform default encoding may not be able to represent.
        tmp_path.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        tmp_path.replace(self._store_path)
        if self._vectors is not None and self._vectors.size:
            np.save(self._vectors_path, self._vectors)

    def _load(self) -> None:
        """Load persisted summaries and, if usable, the cached vectors.

        An unreadable JSON file leaves long-term memory empty (with a logged
        warning). The vector cache is loaded separately: if it is missing,
        corrupt, or not 2-D it is dropped and rebuilt on demand, without
        discarding the summaries themselves.
        """
        import logging

        logger = logging.getLogger(__name__)
        if not self._store_path.exists():
            return
        try:
            payload: dict[str, Any] = json.loads(
                self._store_path.read_text(encoding="utf-8")
            )
            self._long = list(payload.get("long", []))
        except Exception:  # best-effort: a bad store must not break startup
            logger.warning(
                "Could not read memory store %s; starting empty",
                self._store_path,
                exc_info=True,
            )
            self._long = []
            self._vectors = None
            return
        if not self._vectors_path.exists():
            return
        try:
            vectors = np.load(self._vectors_path)
        except Exception:  # cache only: long_context re-embeds on demand
            logger.warning(
                "Ignoring unreadable vector cache %s",
                self._vectors_path,
                exc_info=True,
            )
            return
        # long_context needs a (rows, dim) matrix; anything else is unusable.
        if isinstance(vectors, np.ndarray) and vectors.ndim == 2:
            self._vectors = vectors

    def _normalize(self, vectors: np.ndarray) -> np.ndarray:
        """Scale each row to unit L2 norm, leaving all-zero rows unchanged."""
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Avoid division by zero for all-zero rows.
        norms[norms == 0] = 1.0
        return vectors / norms