from __future__ import annotations

import random
import time

from langchain_core.embeddings import Embeddings
from mistralai import Mistral

from memory_agent.errors import is_rate_limit_error


class MistralEmbedEmbeddings(Embeddings):
    """Embeddings client backed by Mistral embeddings API.

    Requests that fail with an error classified by ``is_rate_limit_error``
    are retried with capped exponential backoff plus a small random jitter.
    Any other error is re-raised immediately.
    """

    def __init__(
        self,
        model_name: str,
        api_token: str,
        max_retries: int = 6,
        base_delay_seconds: float = 1.0,
        max_delay_seconds: float = 16.0,
    ) -> None:
        """Store the retry settings and create the Mistral client.

        Args:
            model_name: Model identifier sent with every embeddings request.
            api_token: API key passed to the ``Mistral`` client.
            max_retries: Number of retries allowed after the first attempt
                when a rate-limit error occurs. Negative values are treated
                as zero when requests are made.
            base_delay_seconds: Backoff delay used for the first retry; it is
                doubled on each subsequent retry.
            max_delay_seconds: Cap applied to the exponential part of the
                delay (the jitter is added on top of this cap).
        """
        self._model_name = model_name
        self._client = Mistral(api_key=api_token)
        self._max_retries = max_retries
        self._base_delay_seconds = base_delay_seconds
        self._max_delay_seconds = max_delay_seconds

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` in a single request and return one vector per text.

        An empty input returns an empty list without calling the API.

        Raises:
            RuntimeError: If the response cannot be parsed or the number of
                returned vectors differs from ``len(texts)``.
        """
        if not texts:
            return []

        response = self._create_embeddings_with_retry(texts=texts)
        vectors = self._extract_vectors(response=response)
        if len(vectors) != len(texts):
            raise RuntimeError(
                f"Unexpected embeddings count from Mistral API: got {len(vectors)}, expected {len(texts)}."
            )
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text by delegating to ``embed_documents``."""
        embeddings = self.embed_documents([text])
        return embeddings[0]

    def _create_embeddings_with_retry(self, texts: list[str]) -> object:
        """Call the embeddings endpoint, retrying on rate-limit errors.

        The delay before retry ``k`` (0-based) is
        ``min(max_delay_seconds, base_delay_seconds * 2**k)`` plus a uniform
        jitter in ``[0, 0.25]`` seconds, so a sleep may exceed
        ``max_delay_seconds`` by up to 0.25 s.

        Raises:
            Exception: The original client error, when it is not a rate-limit
                error or when the retry budget is exhausted.
        """
        # A negative budget would make the loop below empty and fail without
        # ever contacting the API; always allow at least one attempt.
        retry_budget = max(self._max_retries, 0)

        for attempt in range(retry_budget + 1):
            try:
                return self._client.embeddings.create(
                    model=self._model_name,
                    inputs=texts,
                )
            except Exception as error:
                # Only rate-limit failures are worth waiting out; everything
                # else (and the final failed attempt) propagates unchanged.
                if not is_rate_limit_error(error) or attempt >= retry_budget:
                    raise
                delay = min(
                    self._max_delay_seconds,
                    self._base_delay_seconds * (2**attempt),
                ) + random.uniform(0.0, 0.25)
                time.sleep(delay)

        # Defensive guard: every loop iteration returns, raises, or continues,
        # and the last iteration always returns or raises.
        raise RuntimeError("Failed to request embeddings after retries.")

    def _extract_vectors(self, response: object) -> list[list[float]]:
        """Pull the embedding vectors out of an embeddings response.

        The payload is looked up, in order, as a ``data`` attribute, as the
        ``"data"`` key of ``response.model_dump()``, and as the ``"data"`` key
        of a plain ``dict`` response. Items may be dicts or objects exposing
        ``index`` and ``embedding``.

        Returns:
            The vectors as lists of floats. They are ordered by ``index``
            when every item carries one; otherwise response order is kept.

        Raises:
            RuntimeError: If no ``data`` payload is found, or an item has no
                embedding vector.
        """
        payload_data = getattr(response, "data", None)
        # Use ``.get("data")`` without a default so a missing field is
        # reported as a malformed response instead of being mistaken for an
        # empty list of embeddings.
        if payload_data is None and hasattr(response, "model_dump"):
            payload_data = response.model_dump().get("data")
        if payload_data is None and isinstance(response, dict):
            payload_data = response.get("data")
        if payload_data is None:
            raise RuntimeError("Unexpected embeddings response format from Mistral API.")

        indexed_vectors: list[tuple[int | None, list[float]]] = []
        for item in payload_data:
            if isinstance(item, dict):
                index = item.get("index")
                raw_vector = item.get("embedding")
            else:
                index = getattr(item, "index", None)
                raw_vector = getattr(item, "embedding", None)

            if raw_vector is None:
                raise RuntimeError("Missing embedding vector in Mistral API response.")

            indexed_vectors.append((index, [float(value) for value in raw_vector]))

        # Reorder only when every item has an index; a partial set of indexes
        # cannot define a total order, so response order is trusted instead.
        if indexed_vectors and all(index is not None for index, _ in indexed_vectors):
            indexed_vectors.sort(key=lambda pair: int(pair[0]))

        return [vector for _, vector in indexed_vectors]