# diabetesLLM/core/embeddings.py
# Author: KS00Max — "first commit" (f27bb68)
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Dict, List
import numpy as np
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_exponential
from .config import CACHE_DIR, get_settings
logger = logging.getLogger(__name__)
class EmbeddingClient:
    """Thin wrapper around the OpenAI embeddings API with a JSON file cache.

    Embeddings are cached on disk keyed by the exact input text, so a text
    that has been embedded before is never sent to the API again — including
    duplicates within a single batch.
    """

    def __init__(self, settings=None, cache_path: Path | None = None):
        """Initialise the client.

        Args:
            settings: Settings object exposing ``openai_api_key``,
                ``openai_base_url`` and ``embedding_model``; defaults to
                ``get_settings()``.
            cache_path: Location of the JSON cache file; defaults to
                ``CACHE_DIR / "embedding_cache.json"``.
        """
        self.settings = settings or get_settings()
        self.client = OpenAI(
            api_key=self.settings.openai_api_key,
            base_url=self.settings.openai_base_url,
        )
        self.model = self.settings.embedding_model
        self.cache_path = cache_path or (CACHE_DIR / "embedding_cache.json")
        # Maps input text -> embedding vector.
        self._cache: Dict[str, List[float]] = {}
        self._load_cache()

    def _load_cache(self) -> None:
        """Load the on-disk cache; fall back to an empty cache on any error."""
        if self.cache_path.exists():
            try:
                self._cache = json.loads(self.cache_path.read_text(encoding="utf-8"))
            except Exception as exc:
                # A corrupt cache file should not break the client.
                logger.warning("Failed to load embedding cache: %s", exc)
                self._cache = {}

    def _save_cache(self) -> None:
        """Persist the cache to disk; failures are logged, never raised."""
        try:
            self.cache_path.parent.mkdir(parents=True, exist_ok=True)
            self.cache_path.write_text(json.dumps(self._cache), encoding="utf-8")
        except Exception as exc:
            logger.warning("Failed to save embedding cache: %s", exc)

    @retry(wait=wait_exponential(multiplier=1, min=1, max=10), stop=stop_after_attempt(5))
    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Call the embeddings API for *texts*, with exponential-backoff retry."""
        response = self.client.embeddings.create(model=self.model, input=texts)
        return [item.embedding for item in response.data]

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text, in order.

        Cached texts are served locally; only cache misses are sent to the
        API. Duplicate misses within the batch are requested only once
        (the original implementation embedded each duplicate separately).
        """
        missing: List[str] = []
        seen: set = set()
        for text in texts:
            if text not in self._cache and text not in seen:
                seen.add(text)
                missing.append(text)
        if missing:
            for text, vec in zip(missing, self._embed(missing)):
                self._cache[text] = vec
            self._save_cache()
        # Every text is now cached; assemble results in input order.
        return [self._cache[text] for text in texts]

    def embed_one(self, text: str) -> List[float]:
        """Return the embedding vector for a single text."""
        return self.embed([text])[0]
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
denom = np.linalg.norm(a) * np.linalg.norm(b)
return float(np.dot(a, b) / denom) if denom else 0.0