| from __future__ import annotations |
|
|
| import random |
| import time |
|
|
| from langchain_core.embeddings import Embeddings |
| from mistralai import Mistral |
|
|
| from memory_agent.errors import is_rate_limit_error |
|
|
|
|
class MistralEmbedEmbeddings(Embeddings):
    """Embeddings client backed by Mistral embeddings API.

    Rate-limited requests are retried with capped exponential backoff plus
    jitter; response payloads are validated and returned in input order.
    """

    def __init__(
        self,
        model_name: str,
        api_token: str,
        max_retries: int = 6,
        base_delay_seconds: float = 1.0,
        max_delay_seconds: float = 16.0,
    ) -> None:
        """Create a client.

        Args:
            model_name: Mistral embeddings model identifier.
            api_token: API key used to authenticate with Mistral.
            max_retries: Number of retries (after the first attempt) allowed
                for rate-limit errors before the error is re-raised.
            base_delay_seconds: Initial backoff delay; doubles per attempt.
            max_delay_seconds: Upper bound on the exponential part of the
                backoff delay (jitter is added on top of this cap).
        """
        self._model_name = model_name
        self._client = Mistral(api_key=api_token)
        self._max_retries = max_retries
        self._base_delay_seconds = base_delay_seconds
        self._max_delay_seconds = max_delay_seconds

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents, preserving input order.

        Raises:
            RuntimeError: If the API returns a malformed payload or a
                vector count that does not match ``len(texts)``.
        """
        if not texts:
            return []
        response = self._create_embeddings_with_retry(texts=texts)
        vectors = self._extract_vectors(response=response)
        if len(vectors) != len(texts):
            raise RuntimeError(
                f"Unexpected embeddings count from Mistral API: got {len(vectors)}, expected {len(texts)}."
            )
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string."""
        embeddings = self.embed_documents([text])
        return embeddings[0]

    def _create_embeddings_with_retry(self, texts: list[str]) -> object:
        """Call the embeddings endpoint, retrying rate-limit errors.

        Backoff is exponential in the attempt number, capped at
        ``max_delay_seconds``, with up to 0.25 s of random jitter added so
        concurrent clients do not retry in lockstep. Non-rate-limit errors
        are raised immediately; rate-limit errors are re-raised once
        retries are exhausted.
        """
        for attempt in range(self._max_retries + 1):
            try:
                return self._client.embeddings.create(
                    model=self._model_name,
                    inputs=texts,
                )
            except Exception as error:
                if not is_rate_limit_error(error) or attempt >= self._max_retries:
                    raise
                delay = min(
                    self._max_delay_seconds,
                    self._base_delay_seconds * (2**attempt),
                ) + random.uniform(0.0, 0.25)
                time.sleep(delay)
        # Unreachable: every loop iteration either returns or raises, but a
        # defensive raise keeps the function from ever falling through silently.
        raise RuntimeError("Failed to request embeddings after retries.")

    def _extract_vectors(self, response: object) -> list[list[float]]:
        """Pull embedding vectors out of an API response.

        Accepts SDK response objects (a ``.data`` attribute or a pydantic
        ``model_dump()``) as well as plain dicts. Vectors are re-ordered by
        each item's ``index`` field when every item carries one; otherwise
        they are returned in response order.

        Raises:
            RuntimeError: If the payload shape is unrecognized or any item
                is missing its embedding vector.
        """
        payload_data = getattr(response, "data", None)
        if payload_data is None and hasattr(response, "model_dump"):
            payload_data = response.model_dump().get("data", [])
        if payload_data is None and isinstance(response, dict):
            payload_data = response.get("data", [])
        if payload_data is None:
            raise RuntimeError("Unexpected embeddings response format from Mistral API.")

        indexed_vectors: list[tuple[int | None, list[float]]] = []
        for item in payload_data:
            if isinstance(item, dict):
                index = item.get("index")
                raw_vector = item.get("embedding")
            else:
                index = getattr(item, "index", None)
                raw_vector = getattr(item, "embedding", None)
            if raw_vector is None:
                raise RuntimeError("Missing embedding vector in Mistral API response.")
            indexed_vectors.append((index, [float(value) for value in raw_vector]))

        if indexed_vectors and all(index is not None for index, _ in indexed_vectors):
            indexed_vectors.sort(key=lambda pair: int(pair[0]))
        # BUG FIX: the method previously fell through and implicitly returned
        # None whenever the payload was empty or any item lacked an ``index``,
        # causing a confusing TypeError in embed_documents. Always return the
        # vectors — sorted by index when possible, response order otherwise.
        return [vector for _, vector in indexed_vectors]
|
|