SIMPLE_AI / memory_manager.py
tiahchia's picture
Update memory_manager.py
09fc4f6 verified
import heapq
import logging
import os
from typing import List, Sequence, Tuple

import numpy as np
from openai import OpenAI

import db
# Logger for this module, named after the module itself.
LOGGER = logging.getLogger(name=__name__)

# OpenAI embedding model requested by get_embedding().
EMBEDDING_MODEL = "text-embedding-3-small"

# Shared OpenAI client; None until _get_client() builds it on first use.
_CLIENT: OpenAI | None = None
def _get_client() -> OpenAI:
    """Return the module-wide OpenAI client, constructing it on the first call.

    The ``OPENAI_API_KEY`` environment variable is only consulted when the
    client has not been constructed yet; later calls reuse the cached client.

    Raises:
        RuntimeError: If the client must be constructed and
            ``OPENAI_API_KEY`` is unset or empty.
    """
    global _CLIENT
    # Fast path: reuse the client built by an earlier call.
    if _CLIENT is not None:
        return _CLIENT
    # NOTE: no lock is taken here, so concurrent first calls could each build
    # a client (the last assignment wins).
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is not set. Please configure the environment.")
    _CLIENT = OpenAI(api_key=key)
    return _CLIENT
def get_embedding(text: str) -> List[float]:
    """Embed a single string and return its embedding vector.

    The text is sent as a one-element batch to the OpenAI embeddings endpoint
    with ``EMBEDDING_MODEL``, and the vector of the first (only) result is
    returned. Errors from client construction or the API call propagate.
    """
    batch = [text]
    response = _get_client().embeddings.create(model=EMBEDDING_MODEL, input=batch)
    first_result = response.data[0]
    return first_result.embedding
def _cosine_similarity(vector_a: Sequence[float], vector_b: Sequence[float]) -> float:
    """Return the cosine similarity of two vectors as a Python float.

    Args:
        vector_a: First vector (any sequence of numbers).
        vector_b: Second vector; must have the same length as ``vector_a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, in ``[-1.0, 1.0]`` up to rounding, or
        ``0.0`` when either vector has zero magnitude (the angle is undefined
        there, and this avoids a division by zero).

    Raises:
        ValueError: If both vectors are non-zero but their shapes differ,
            e.g. embeddings produced by models with different dimensions.
    """
    # Compute in float64 regardless of the input dtype (ints, float32 arrays).
    a = np.asarray(vector_a, dtype=float)
    b = np.asarray(vector_b, dtype=float)
    # Each norm is computed once and reused for the check and the division.
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    if a.shape != b.shape:
        # np.dot would also fail here, but with a less helpful message.
        raise ValueError(
            f"Cannot compare vectors of different shapes: {a.shape} vs {b.shape}"
        )
    return float(np.dot(a, b) / (norm_a * norm_b))
def retrieve_relevant_memories(user_id: str, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Return up to ``top_k`` of a user's stored memories, most similar first.

    Every record returned by ``db.get_user_embeddings(user_id)`` is read via
    its ``"text"`` and ``"embedding"`` keys. The query is embedded with
    ``get_embedding`` and compared to each stored embedding by cosine
    similarity.

    Args:
        user_id: Identifier passed through to ``db.get_user_embeddings``.
        query: Text to match memories against. An empty string returns ``[]``
            without calling the embeddings API, which rejects empty input.
        top_k: Maximum number of results to return; values ``<= 0`` return
            ``[]``.

    Returns:
        ``(text, similarity)`` pairs in descending order of similarity; equal
        scores keep the order in which the database returned the records.

    Raises:
        Any exception from ``get_embedding`` (for example ``RuntimeError`` when
        ``OPENAI_API_KEY`` is unset). Failures while loading stored
        embeddings are logged and yield ``[]`` instead of raising.
    """
    # Check arguments first so a request that can only return nothing costs
    # no DB round-trip or paid API call. Previously a negative top_k sliced
    # from the end, silently dropping the lowest-ranked results.
    if top_k <= 0 or not query:
        return []

    try:
        stored_embeddings = db.get_user_embeddings(user_id)
    except Exception:  # pragma: no cover - defensive logging
        # Best effort: a storage failure degrades to "no memories".
        LOGGER.exception("Failed to load embeddings for user %s", user_id)
        return []
    if not stored_embeddings:
        return []

    query_embedding = get_embedding(query)
    scored_memories: List[Tuple[str, float]] = [
        (record["text"], _cosine_similarity(record["embedding"], query_embedding))
        for record in stored_embeddings
    ]
    # Documented as equivalent to sorted(..., reverse=True)[:top_k] (same tie
    # order), but avoids fully sorting every stored memory.
    return heapq.nlargest(top_k, scored_memories, key=lambda item: item[1])