SIMPLE_AI / memory_manager.py
tiahchia's picture
Upload 6 files
015dbc8 verified
raw
history blame
1.94 kB
import logging
import os
from typing import List, Sequence, Tuple
import numpy as np
from openai import OpenAI
from backend import db
# Model name passed as ``model=`` to the embeddings endpoint in ``get_embedding``.
EMBEDDING_MODEL = "text-embedding-3-small"
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Cached OpenAI client; stays ``None`` until ``_get_client`` first builds it.
_CLIENT: OpenAI | None = None
def _get_client() -> OpenAI:
    """Return the shared OpenAI client, building it on first use.

    The client is stored in the module-level ``_CLIENT`` so later calls reuse
    the same instance.

    Raises:
        RuntimeError: If the ``OPENAI_API_KEY`` environment variable is unset
            or empty when the client has to be created.
    """
    global _CLIENT

    # Fast path: a client was already created by an earlier call.
    if _CLIENT is not None:
        return _CLIENT

    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is not set. Please configure the environment.")

    _CLIENT = OpenAI(api_key=key)
    return _CLIENT
def get_embedding(text: str) -> List[float]:
    """Generate an embedding vector for the provided text.

    Sends ``text`` as a single-item batch to the embeddings endpoint using
    ``EMBEDDING_MODEL`` and returns the embedding of the first (only) result.

    Args:
        text: The text to embed.

    Returns:
        The ``embedding`` attribute of the first item in the response data.
    """
    embeddings_api = _get_client().embeddings
    result = embeddings_api.create(input=[text], model=EMBEDDING_MODEL)
    # One input was sent, so the vector of interest is the first entry.
    first_item = result.data[0]
    return first_item.embedding
def _cosine_similarity(vector_a: Sequence[float], vector_b: Sequence[float]) -> float:
a = np.asarray(vector_a)
b = np.asarray(vector_b)
if np.linalg.norm(a) == 0 or np.linalg.norm(b) == 0:
return 0.0
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
def retrieve_relevant_memories(user_id: str, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Return the top K memories most similar to the query for the given user.

    Args:
        user_id: Identifier passed to ``db.get_user_embeddings``.
        query: Text that is embedded and compared against each stored record.
        top_k: Maximum number of results to return. A value of zero or less
            yields an empty list.

    Returns:
        ``(text, similarity)`` tuples sorted by descending cosine similarity,
        at most ``top_k`` long. An empty list is returned when the user has no
        stored embeddings or when loading them fails (the failure is logged).

    Raises:
        Any exception raised by ``get_embedding`` while embedding ``query``
        propagates to the caller.
    """
    # A negative ``top_k`` would make the final slice count from the end
    # (e.g. ``[:-1]`` returns all but the last item). Bail out before doing
    # any database or API work. ``None`` is tolerated so it keeps slicing to
    # "everything", as it did before this guard existed.
    if top_k is not None and top_k <= 0:
        return []

    # Only the database call is guarded: a storage failure degrades to
    # "no memories" rather than breaking the caller.
    try:
        stored_embeddings = db.get_user_embeddings(user_id)
    except Exception:  # pragma: no cover - defensive logging
        LOGGER.exception("Failed to load embeddings for user %s", user_id)
        return []

    # Skip the embedding request entirely when there is nothing to compare.
    if not stored_embeddings:
        return []

    query_embedding = get_embedding(query)
    # Each record is indexed by "text" and "embedding".
    scored_memories: List[Tuple[str, float]] = [
        (record["text"], _cosine_similarity(record["embedding"], query_embedding))
        for record in stored_embeddings
    ]
    scored_memories.sort(key=lambda item: item[1], reverse=True)
    return scored_memories[:top_k]