Marik1337's picture
Add application file
6059138
from __future__ import annotations
import random
import time
from langchain_core.embeddings import Embeddings
from mistralai import Mistral
from memory_agent.errors import is_rate_limit_error
class MistralEmbedEmbeddings(Embeddings):
    """Embeddings client backed by Mistral embeddings API.

    Implements the ``Embeddings`` interface (``embed_documents`` and
    ``embed_query``) by calling ``embeddings.create`` on a ``Mistral``
    client. Requests that fail with an error recognised by
    ``is_rate_limit_error`` are retried with exponential backoff; any other
    error is re-raised immediately.
    """

    def __init__(
        self,
        model_name: str,
        api_token: str,
        max_retries: int = 6,
        base_delay_seconds: float = 1.0,
        max_delay_seconds: float = 16.0,
    ) -> None:
        """Store the configuration and build the underlying API client.

        Args:
            model_name: Value passed as ``model`` on every embeddings request.
            api_token: Key passed as ``api_key`` to the ``Mistral`` client.
            max_retries: How many times a rate-limited request is retried
                after the first attempt. Values below zero behave like zero.
            base_delay_seconds: Backoff before the first retry; the value
                doubles with each further retry.
            max_delay_seconds: Upper bound for the exponential part of the
                backoff. A random jitter of up to 0.25 s is added on top of
                this bound.
        """
        self._model_name = model_name
        self._client = Mistral(api_key=api_token)
        self._max_retries = max_retries
        self._base_delay_seconds = base_delay_seconds
        self._max_delay_seconds = max_delay_seconds

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts with a single API request.

        Args:
            texts: Texts to embed. An empty list returns ``[]`` without
                contacting the API.

        Returns:
            One vector per input text.

        Raises:
            RuntimeError: If the response cannot be parsed or the number of
                returned vectors differs from ``len(texts)``.
        """
        if not texts:
            return []

        response = self._create_embeddings_with_retry(texts=texts)
        vectors = self._extract_vectors(response=response)
        if len(vectors) != len(texts):
            raise RuntimeError(
                f"Unexpected embeddings count from Mistral API: got {len(vectors)}, expected {len(texts)}."
            )
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text by delegating to ``embed_documents``.

        Args:
            text: Text to embed.

        Returns:
            The embedding vector for ``text``.
        """
        embeddings = self.embed_documents([text])
        return embeddings[0]

    def _create_embeddings_with_retry(self, texts: list[str]) -> object:
        """Request embeddings, retrying rate-limit failures with backoff.

        Args:
            texts: Texts sent as ``inputs`` in one request.

        Returns:
            Whatever ``embeddings.create`` returns, unmodified.

        Raises:
            Exception: The error raised by the client, re-raised when it is
                not a rate-limit error or when the retry budget is exhausted.
        """
        # Always make at least one attempt, even when max_retries is
        # negative; otherwise the loop would be empty and no request would
        # ever be sent.
        total_attempts = max(self._max_retries, 0) + 1
        for attempt in range(total_attempts):
            try:
                return self._client.embeddings.create(
                    model=self._model_name,
                    inputs=texts,
                )
            except Exception as error:
                # Only rate-limit errors are retried, and only while retries
                # remain; everything else propagates unchanged.
                if not is_rate_limit_error(error) or attempt >= self._max_retries:
                    raise
                # Exponential backoff capped at max_delay_seconds, plus a
                # small jitter that is added after the cap.
                delay = min(
                    self._max_delay_seconds,
                    self._base_delay_seconds * (2**attempt),
                ) + random.uniform(0.0, 0.25)
                time.sleep(delay)
        # Defensive: every loop iteration either returns or raises.
        raise RuntimeError("Failed to request embeddings after retries.")

    def _extract_vectors(self, response: object) -> list[list[float]]:
        """Pull the embedding vectors out of an API response.

        The ``data`` collection is read from, in order: a ``data`` attribute,
        the ``"data"`` key of ``response.model_dump()``, or the ``"data"``
        key of a plain dict. Items may be dicts or objects exposing
        ``index`` and ``embedding``.

        Args:
            response: Response returned by the embeddings request.

        Returns:
            Vectors as lists of floats. They are sorted by ``index`` when
            every item carries one; otherwise the response order is kept.

        Raises:
            RuntimeError: If no ``data`` collection can be found or an item
                has no embedding.
        """
        payload_data = getattr(response, "data", None)
        # No default here: a missing "data" must stay None so the format
        # check below can report it instead of yielding an empty result.
        if payload_data is None and hasattr(response, "model_dump"):
            payload_data = response.model_dump().get("data")
        if payload_data is None and isinstance(response, dict):
            payload_data = response.get("data")
        if payload_data is None:
            raise RuntimeError("Unexpected embeddings response format from Mistral API.")

        indexed_vectors: list[tuple[int | None, list[float]]] = []
        for item in payload_data:
            if isinstance(item, dict):
                index = item.get("index")
                raw_vector = item.get("embedding")
            else:
                index = getattr(item, "index", None)
                raw_vector = getattr(item, "embedding", None)
            if raw_vector is None:
                raise RuntimeError("Missing embedding vector in Mistral API response.")
            indexed_vectors.append((index, [float(value) for value in raw_vector]))

        # Reorder only when every item has an index; a partial set of
        # indices cannot define a total order, so response order is kept.
        if indexed_vectors and all(index is not None for index, _ in indexed_vectors):
            indexed_vectors.sort(key=lambda pair: int(pair[0]))
        return [vector for _, vector in indexed_vectors]