Spaces:
Running
Running
File size: 4,612 Bytes
6d882b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from httpx import AsyncClient, HTTPError, TimeoutException
from typing import List, Dict
import asyncio
class EmbeddingAPIClient:
    """
    Async client for interacting with an embedding API.

    Wraps a single ``AsyncClient`` and exposes helpers for dense embeddings,
    sparse embeddings and document reranking. Every call is a JSON POST that
    is retried with exponential backoff when the request fails.

    The client can be used as an async context manager so that the
    underlying HTTP client is always closed::

        async with EmbeddingAPIClient("http://localhost:8000") as api:
            vectors = await api.get_dense_embeddings(["hello"])

    Attributes:
        base_url (str): The base URL of the embedding API.
        timeout (int): The timeout duration for requests in seconds.
        max_retries (int): The maximum number of retry attempts for failed requests.
        client (AsyncClient): An instance of AsyncClient for making HTTP requests.
    """

    # 4xx statuses that are worth retrying (request timeout, too early,
    # rate limited). Every other 4xx is deterministic: resending the same
    # payload would only delay the failure.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 425, 429})

    def __init__(self, base_url: str, timeout: int = 60, max_retries: int = 3) -> None:
        """
        Initializes the EmbeddingAPIClient with the specified parameters.

        Args:
            base_url (str): The base URL of the embedding API.
            timeout (int): Request timeout in seconds. Defaults to 60.
            max_retries (int): Maximum number of retries after the first
                failed attempt. Defaults to 3.
        """
        self.base_url = base_url
        self.timeout = timeout
        self.max_retries = max_retries
        self.client = AsyncClient(base_url=base_url, timeout=timeout)

    async def __aenter__(self) -> "EmbeddingAPIClient":
        """Enter the async context and return this client."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP client when leaving the async context."""
        await self.close()

    async def _make_request_with_retry(
        self, endpoint: str, payload: Dict, retry_count: int = 0
    ) -> Dict:
        """
        Helper method to make a POST request with retry logic.

        Failed attempts are retried after ``2**attempt`` seconds (1s, 2s,
        4s, ...) until ``max_retries`` retries have been used. Responses
        with a 4xx status other than 408/425/429 are not retried.

        Args:
            endpoint (str): The endpoint URL to which the request is sent.
            payload (Dict): The JSON data to be sent in the request.
            retry_count (int, optional): The retry attempt count to start
                from. Defaults to 0.

        Returns:
            Dict: The JSON response from the API.

        Raises:
            Exception: If the request fails with a non-retryable status, or
                still fails after the maximum number of retries. The
                underlying httpx error is attached as ``__cause__``.
        """
        attempt = retry_count
        # Iterative retry loop: keeps the stack and traceback flat instead of
        # nesting one recursive call (and one exception context) per attempt.
        while True:
            try:
                response = await self.client.post(endpoint, json=payload)
                response.raise_for_status()
                return response.json()
            except (HTTPError, TimeoutException) as e:
                # Only status errors carry a ``response``; transport errors
                # and timeouts leave ``status`` as None and stay retryable.
                status = getattr(getattr(e, "response", None), "status_code", None)
                if (
                    status is not None
                    and 400 <= status < 500
                    and status not in self._RETRYABLE_CLIENT_STATUSES
                ):
                    raise Exception(
                        f"❌ Request failed with non-retryable status {status}: {str(e)}"
                    ) from e
                if attempt >= self.max_retries:
                    raise Exception(
                        f"❌ Failed after {self.max_retries} retries: {str(e)}"
                    ) from e
                wait_time = 2**attempt
                print(
                    f"⚠️ Request failed, retrying in {wait_time}s... (attempt {attempt + 1}/{self.max_retries})"
                )
                await asyncio.sleep(wait_time)
                attempt += 1

    async def get_dense_embeddings(
        self, texts: List[str], model: str = "qwen3-0.6b"
    ) -> List[List[float]]:
        """
        Retrieve dense embeddings from the API.

        Args:
            texts (List[str]): A list of texts for which to retrieve embeddings.
            model (str): The model to use for generating embeddings. Defaults to "qwen3-0.6b".

        Returns:
            List[List[float]]: A list of dense embeddings corresponding to the
            input texts. An empty input yields an empty list without a request.
        """
        if not texts:
            # Nothing to embed: skip the round-trip (and any retry backoff).
            return []
        data = await self._make_request_with_retry(
            "/embeddings", {"input": texts, "model": model}
        )
        return [item["embedding"] for item in data["data"]]

    async def get_sparse_embeddings(
        self, texts: List[str], model: str = "splade-large-query"
    ) -> List[Dict[str, List]]:
        """
        Retrieve sparse embeddings from the API.

        Args:
            texts (List[str]): A list of texts for which to retrieve embeddings.
            model (str): The model to use for generating embeddings. Defaults to "splade-large-query".

        Returns:
            List[Dict[str, List]]: A list of sparse embeddings corresponding to
            the input texts. An empty input yields an empty list without a request.
        """
        if not texts:
            # Nothing to embed: skip the round-trip (and any retry backoff).
            return []
        data = await self._make_request_with_retry(
            "/embed_sparse", {"input": texts, "model": model}
        )
        return data["embeddings"]

    async def rerank_documents(
        self, query: str, documents: List[str], top_k: int = 5, model: str = "bge-v2-m3"
    ) -> List[Dict]:
        """
        Rerank a list of documents based on a query.

        Args:
            query (str): The query string used for reranking.
            documents (List[str]): A list of documents to be reranked.
            top_k (int): The number of top documents to return. Defaults to 5.
            model (str): The model to use for reranking. Defaults to "bge-v2-m3".

        Returns:
            List[Dict]: The ``results`` list from the API response. An empty
            ``documents`` list yields an empty list without a request.
        """
        if not documents:
            # Nothing to rank: skip the round-trip (and any retry backoff).
            return []
        data = await self._make_request_with_retry(
            "/rerank",
            {"query": query, "documents": documents, "top_k": top_k, "model": model},
        )
        return data["results"]

    async def close(self) -> None:
        """
        Close the HTTP client.

        This method closes the AsyncClient instance to free up resources.
        """
        await self.client.aclose()
|