File size: 3,349 Bytes
6059138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

import random
import time

from langchain_core.embeddings import Embeddings
from mistralai import Mistral

from memory_agent.errors import is_rate_limit_error


class MistralEmbedEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation backed by the Mistral embeddings API.

    Rate-limited requests are retried with capped exponential backoff plus a
    small random jitter; all other errors propagate to the caller unchanged.
    """

    def __init__(
        self,
        model_name: str,
        api_token: str,
        max_retries: int = 6,
        base_delay_seconds: float = 1.0,
        max_delay_seconds: float = 16.0,
    ) -> None:
        """Create a client for the given Mistral embeddings model.

        Args:
            model_name: Identifier of the Mistral embeddings model to use.
            api_token: Mistral API key used to authenticate requests.
            max_retries: Number of additional attempts after the first request
                when a rate-limit error is detected.
            base_delay_seconds: Backoff delay before the first retry; doubles
                on each subsequent retry.
            max_delay_seconds: Upper bound on any single backoff sleep
                (including jitter).
        """
        self._model_name = model_name
        self._client = Mistral(api_key=api_token)
        self._max_retries = max_retries
        self._base_delay_seconds = base_delay_seconds
        self._max_delay_seconds = max_delay_seconds

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents, returning one vector per input text.

        Raises:
            RuntimeError: If the API returns a malformed response or a vector
                count that does not match ``len(texts)``.
        """
        # Short-circuit: the API need not be called for an empty batch.
        if not texts:
            return []
        response = self._create_embeddings_with_retry(texts=texts)
        vectors = self._extract_vectors(response=response)
        # Guard against silent truncation/duplication by the API: a mismatch
        # here would desynchronize vectors from their source texts downstream.
        if len(vectors) != len(texts):
            raise RuntimeError(
                f"Unexpected embeddings count from Mistral API: got {len(vectors)}, expected {len(texts)}."
            )
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string via the batch path."""
        embeddings = self.embed_documents([text])
        return embeddings[0]

    def _create_embeddings_with_retry(self, texts: list[str]) -> object:
        """Call the embeddings endpoint, retrying rate-limit errors with backoff.

        Non-rate-limit errors, and rate-limit errors on the final attempt,
        are re-raised unchanged.

        Returns:
            The raw SDK response object (shape handled by ``_extract_vectors``).
        """
        # One initial attempt plus up to ``max_retries`` retries.
        for attempt in range(self._max_retries + 1):
            try:
                return self._client.embeddings.create(
                    model=self._model_name,
                    inputs=texts,
                )
            except Exception as error:
                if not is_rate_limit_error(error) or attempt >= self._max_retries:
                    raise
                # Exponential backoff with jitter; the cap is applied AFTER
                # adding jitter so the sleep never exceeds max_delay_seconds.
                delay = min(
                    self._max_delay_seconds,
                    self._base_delay_seconds * (2**attempt)
                    + random.uniform(0.0, 0.25),
                )
                time.sleep(delay)
        # Defensive: only reachable if max_retries is negative (empty loop).
        raise RuntimeError("Failed to request embeddings after retries.")

    def _extract_vectors(self, response: object) -> list[list[float]]:
        """Normalize an SDK response into a list of float vectors.

        Accepts attribute-style objects, pydantic models (``model_dump``), and
        plain dicts. When every item carries an ``index``, vectors are sorted
        by it so output order matches input order regardless of API ordering.

        Raises:
            RuntimeError: If the response payload or any embedding is missing.
        """
        payload_data = getattr(response, "data", None)
        if payload_data is None and hasattr(response, "model_dump"):
            payload_data = response.model_dump().get("data", [])
        if payload_data is None and isinstance(response, dict):
            payload_data = response.get("data", [])
        if payload_data is None:
            raise RuntimeError("Unexpected embeddings response format from Mistral API.")

        indexed_vectors: list[tuple[int | None, list[float]]] = []
        for item in payload_data:
            # Items may be dicts or attribute-style objects depending on SDK version.
            if isinstance(item, dict):
                index = item.get("index")
                raw_vector = item.get("embedding")
            else:
                index = getattr(item, "index", None)
                raw_vector = getattr(item, "embedding", None)
            if raw_vector is None:
                raise RuntimeError("Missing embedding vector in Mistral API response.")
            indexed_vectors.append((index, [float(value) for value in raw_vector]))

        # Sort by index only when every item has one; with any index missing,
        # arrival order is preserved as the best available ordering.
        if indexed_vectors and all(index is not None for index, _ in indexed_vectors):
            indexed_vectors.sort(key=lambda pair: int(pair[0]))
        return [vector for _, vector in indexed_vectors]