File size: 4,612 Bytes
6d882b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from httpx import AsyncClient, HTTPError, TimeoutException
from typing import List, Dict
import asyncio


class EmbeddingAPIClient:
    """
    A client for interacting with an embedding API.

    The client can be closed explicitly with ``close()`` or used as an
    asynchronous context manager, which closes it on exit::

        async with EmbeddingAPIClient("http://localhost:8000") as api:
            vectors = await api.get_dense_embeddings(["hello"])

    Attributes:
        base_url (str): The base URL of the embedding API.
        timeout (int): The timeout duration for requests in seconds.
        max_retries (int): The maximum number of retry attempts for failed requests.
        client (AsyncClient): An instance of AsyncClient for making HTTP requests.
    """

    # 4xx status codes that are still worth retrying (request timeout,
    # too early, rate limited). Every other 4xx is treated as a permanent
    # failure because resending the same payload cannot fix it.
    _RETRYABLE_STATUS_CODES = frozenset({408, 425, 429})

    def __init__(self, base_url: str, timeout: int = 60, max_retries: int = 3) -> None:
        """
        Initializes the EmbeddingAPIClient with the specified parameters.

        Args:
            base_url (str): The base URL of the embedding API.
            timeout (int): Timeout handed to the underlying AsyncClient. Defaults to 60.
            max_retries (int): How many times a failed request is retried. Defaults to 3.
        """
        self.base_url = base_url
        self.timeout = timeout
        self.max_retries = max_retries
        self.client = AsyncClient(base_url=base_url, timeout=timeout)

    async def __aenter__(self) -> "EmbeddingAPIClient":
        """Enter the async context manager and return this client."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP client when leaving the ``async with`` block."""
        await self.close()

    @staticmethod
    def _is_retryable(error: Exception) -> bool:
        """
        Decide whether a failed request is worth sending again.

        Args:
            error (Exception): The exception raised while making the request.

        Returns:
            bool: True when the error carries no ``response`` attribute
            (presumably a transport-level failure such as a timeout or a
            connection error), or when the response status is 5xx or one of
            408/425/429. False for every other status code.
        """
        response = getattr(error, "response", None)
        if response is None:
            return True
        status = getattr(response, "status_code", None)
        if status is None:
            return True
        return status >= 500 or status in EmbeddingAPIClient._RETRYABLE_STATUS_CODES

    async def _make_request_with_retry(
        self, endpoint: str, payload: Dict, retry_count: int = 0
    ) -> Dict:
        """
        Helper method to make a POST request with retry logic.

        Retryable failures are retried with exponential backoff (1s, 2s,
        4s, ...) until ``max_retries`` is reached. Failures judged permanent
        by ``_is_retryable`` are raised immediately without sleeping.

        Args:
            endpoint (str): The endpoint URL to which the request is sent.
            payload (Dict): The JSON data to be sent in the request.
            retry_count (int, optional): The retry attempt count to start from. Defaults to 0.

        Returns:
            Dict: The JSON response from the API.

        Raises:
            Exception: If the request fails with a non-retryable error, or
                still fails after the maximum number of retries. The
                original httpx error is attached as ``__cause__``.
        """
        attempt = retry_count
        # Loop instead of recursing so long retry chains do not nest
        # coroutine frames; at least one request is always sent.
        while True:
            try:
                response = await self.client.post(endpoint, json=payload)
                response.raise_for_status()
                return response.json()
            except (HTTPError, TimeoutException) as e:
                if not self._is_retryable(e):
                    raise Exception(
                        f"❌ Request failed with a non-retryable error: {str(e)}"
                    ) from e
                if attempt >= self.max_retries:
                    raise Exception(
                        f"❌ Failed after {self.max_retries} retries: {str(e)}"
                    ) from e
                wait_time = 2**attempt
                print(
                    f"⚠️  Request failed, retrying in {wait_time}s... (attempt {attempt + 1}/{self.max_retries})"
                )
                await asyncio.sleep(wait_time)
                attempt += 1

    async def get_dense_embeddings(
        self, texts: List[str], model: str = "qwen3-0.6b"
    ) -> List[List[float]]:
        """
        Retrieve dense embeddings from the API.

        Args:
            texts (List[str]): A list of texts for which to retrieve embeddings.
            model (str): The model to use for generating embeddings. Defaults to "qwen3-0.6b".

        Returns:
            List[List[float]]: A list of dense embeddings corresponding to the input texts.
            An empty ``texts`` returns ``[]`` without contacting the API.
        """
        if not texts:
            return []
        data = await self._make_request_with_retry(
            "/embeddings", {"input": texts, "model": model}
        )
        return [item["embedding"] for item in data["data"]]

    async def get_sparse_embeddings(
        self, texts: List[str], model: str = "splade-large-query"
    ) -> List[Dict[str, List]]:
        """
        Retrieve sparse embeddings from the API.

        Args:
            texts (List[str]): A list of texts for which to retrieve embeddings.
            model (str): The model to use for generating embeddings. Defaults to "splade-large-query".

        Returns:
            List[Dict[str, List]]: A list of sparse embeddings corresponding to the input texts.
            An empty ``texts`` returns ``[]`` without contacting the API.
        """
        if not texts:
            return []
        data = await self._make_request_with_retry(
            "/embed_sparse", {"input": texts, "model": model}
        )
        return data["embeddings"]

    async def rerank_documents(
        self, query: str, documents: List[str], top_k: int = 5, model: str = "bge-v2-m3"
    ) -> List[Dict]:
        """
        Rerank a list of documents based on a query.

        Args:
            query (str): The query string used for reranking.
            documents (List[str]): A list of documents to be reranked.
            top_k (int): The number of top documents to return. Defaults to 5.
            model (str): The model to use for reranking. Defaults to "bge-v2-m3".

        Returns:
            List[Dict]: A list of reranked documents with their scores.
            An empty ``documents`` returns ``[]`` without contacting the API.
        """
        if not documents:
            return []
        data = await self._make_request_with_retry(
            "/rerank",
            {"query": query, "documents": documents, "top_k": top_k, "model": model},
        )
        return data["results"]

    async def close(self) -> None:
        """
        Close the HTTP client.

        This method closes the AsyncClient instance to free up resources.
        """
        await self.client.aclose()